Detailed Explanation of LoRA, DPO, KTO, and SFT Technologies
Introduction to LLM Training Terminology: LoRA, DPO, KTO, and SFT Technologies
This document provides a detailed introduction to several important techniques used in fine-tuning and optimizing large language models (such as LLAMA3), including SFT (Supervised Fine-Tuning), LoRA (Low-Rank Adaptation), Alignment technologies, KTO (Kahneman-Tversky Optimization), and DPO (Direct Preference Optimization). The document also elaborates on the principles of each technique, specific implementation methods, as well as the selection of corresponding loss functions and optimizers.
1. SFT (Supervised Fine-Tuning)
1.1 Principle
SFT is a traditional fine-tuning method that adjusts the parameters of a pre-trained model through supervised learning to improve its performance on specific tasks. SFT is typically used to fine-tune models on specific labeled datasets, with the training process resembling standard supervised learning.
1.2 Implementation Method
- Select a Pre-trained Model: Such as GPT, BERT, and other language models.
- Prepare a Labeled Dataset: The dataset includes input-output pairs.
- Train the Model: Use a standard cross-entropy loss function to train the model, optimizing parameters through gradient descent.
1.3 Core Code
Using Hugging Face’s `Trainer` interface for SFT:

```python
from transformers import Trainer, TrainingArguments, AutoModelForSeq2SeqLM, AutoTokenizer
```
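The listing above survives only in part. The following is a hedged, self-contained sketch of how an SFT run with the `Trainer` API might look; the checkpoint name `t5-small` and the pre-tokenized `train_dataset` are illustrative assumptions, not values from the original listing:

```python
from transformers import (Trainer, TrainingArguments, AutoModelForSeq2SeqLM,
                          AutoTokenizer, DataCollatorForSeq2Seq)

model_name = "t5-small"                              # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# train_dataset is assumed to be an already-tokenized dataset with
# "input_ids", "attention_mask", and "labels" columns.
training_args = TrainingArguments(
    output_dir="./sft_out",
    per_device_train_batch_size=8,
    num_train_epochs=3,
    learning_rate=5e-5,
)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,                     # assumed to exist
    data_collator=DataCollatorForSeq2Seq(tokenizer, model=model),
)
trainer.train()                                      # standard cross-entropy SFT
```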
2. LoRA (Low-Rank Adaptation)
2.1 Principle
LoRA is a parameter-efficient fine-tuning technique that applies a low-rank decomposition to the weight updates of a large model. Instead of updating the full weight matrix $W$, it represents the update as the product of two low-rank matrices $B$ and $A$ (so the adapted weight is $W + BA$) and fine-tunes only these low-rank matrices. The design goal of LoRA is to drastically reduce the number of trainable parameters while keeping the pre-trained weights frozen, optimizing model performance purely by adjusting the low-rank matrices.
2.2 Implementation Method
- Weight Decomposition: For the model’s linear layers (such as the `q_proj` and `v_proj` layers in the attention mechanism), represent the weight update as the product of two low-rank matrices $B$ and $A$, so that the effective weight becomes $W + BA$ (a minimal illustration follows this list).
- Fine-Tune Specific Layers: Apply LoRA only to these specific linear layers, keeping the other layers of the model unchanged.
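The following is a minimal sketch of the decomposition itself, assuming a plain PyTorch linear layer; it illustrates the idea rather than the `peft` implementation used later:

```python
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Illustrative LoRA wrapper: y = W x + (alpha / r) * B(A(x)), with W frozen."""
    def __init__(self, base: nn.Linear, r: int = 8, alpha: int = 16):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad_(False)                   # keep pre-trained W (and bias) frozen
        self.A = nn.Linear(base.in_features, r, bias=False)    # shape: r x d_in
        self.B = nn.Linear(r, base.out_features, bias=False)   # shape: d_out x r
        nn.init.normal_(self.A.weight, std=0.01)
        nn.init.zeros_(self.B.weight)                 # BA starts as a no-op update
        self.scaling = alpha / r

    def forward(self, x):
        return self.base(x) + self.scaling * self.B(self.A(x))
```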
2.3 Layers to Fine-Tune vs. Layers to Keep Unchanged
Layers to Fine-Tune
LoRA is typically applied to the linear projection layers in Transformer models, especially several key layers in the multi-head attention mechanism:
- q_proj (Query Projection Layer)
- k_proj (Key Projection Layer)
- v_proj (Value Projection Layer)
- o_proj (Output Projection Layer)
- ffn_up_proj and ffn_down_proj (Up and Down Projection Layers of the Feedforward Neural Network)
Layers to Keep Unchanged
- Embedding Layers: Responsible for encoding inputs and outputs, usually do not require fine-tuning.
- LayerNorm Layers: These layers are mainly used for normalization, do not contain many parameters, and are typically kept unchanged.
- Activation Function Layers: Non-linear activation functions like ReLU or GELU do not involve parameters and do not require fine-tuning.
2.4 Loss Function
The loss function for LoRA is usually task-specific. In language generation tasks, LoRA uses cross-entropy loss to measure the difference between the generated text and the target text:
$$
\mathcal{L}_{\text{LoRA}} = - \sum_{i} y_i \log(\hat{y}_i)
$$
where $y_i$ is the true label, and $\hat{y}_i$ is the model’s output probability.
2.5 Optimizer
LoRA fine-tuning typically uses the AdamW optimizer, as shown in the following code:
```python
optimizer = torch.optim.AdamW(lora_model.parameters(), lr=5e-5)
```
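In practice the optimizer is often restricted to the parameters that LoRA leaves trainable. A small variant of the line above, assuming `lora_model` was built so that only the adapter weights require gradients:

```python
import torch

# Pass only the trainable (LoRA adapter) parameters to the optimizer.
trainable_params = [p for p in lora_model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable_params, lr=5e-5, weight_decay=0.01)
```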
2.6 Core Code
Implementing LoRA using the `peft` library:

```python
from peft import LoraConfig, get_peft_model
```
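The listing above is only the import. A fuller hedged sketch of a possible configuration for a LLaMA-style model follows; the checkpoint name, rank, scaling, and target module list are illustrative choices, not values prescribed by this document:

```python
from peft import LoraConfig, get_peft_model, TaskType
from transformers import AutoModelForCausalLM

# Illustrative base checkpoint (placeholder).
base_model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B")

lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=8,                        # rank of the low-rank matrices A and B
    lora_alpha=16,              # scaling factor
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # attention projections
)

lora_model = get_peft_model(base_model, lora_config)
lora_model.print_trainable_parameters()   # only the LoRA matrices are trainable
```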
3. Alignment Techniques
Before introducing KL divergence, we first need to clarify how LLM alignment is achieved, along with the underlying principles and mathematical formulas.
3.1 What is Model Alignment?
The core objective of model alignment is to ensure that the language model’s outputs meet human expectations or preferences. Typically, the model is initially trained through large-scale supervised learning (SFT, Supervised Fine-Tuning) to generate a model with basic capabilities. Subsequently, through alignment techniques, the model is further adjusted to ensure that its generated content better aligns with human preferences or avoids producing harmful or erroneous information.
Core Mechanism of Alignment:
- Positive Samples: Outputs that meet human expectations (e.g., correct answers).
- Negative Samples: Outputs that do not meet human expectations (e.g., incorrect answers).
By using paired preference data or labels (correct/incorrect), the model’s outputs are further fine-tuned to generate more positive samples while reducing the probability of generating negative samples.
3.2 Mathematical Principles of Model Alignment
During the alignment process, the model generates outputs through a policy model, which is typically an SFT-trained language model used to generate outputs given an input. To optimize the model’s outputs to better align with human preferences, the following loss functions and optimization methods are commonly used:
3.2.1 Policy Model
Assume the current policy of the model is $\pi_\theta$, which represents the probability of the model generating output $y$ given input $x$:
$$
\pi_\theta(y|x)
$$
The objective of the policy model is to adjust the parameters $\theta$ to increase the probability of generating correct outputs (positive samples) and decrease the probability of generating incorrect outputs (negative samples).
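As a concrete illustration (an assumption-laden sketch rather than part of the original text), $\log \pi_\theta(y|x)$ for a causal language model can be computed by summing token-level log-probabilities over the response tokens:

```python
import torch
import torch.nn.functional as F

def sequence_log_prob(model, input_ids, labels):
    """Sum of log pi_theta(y_t | x, y_<t) over the response tokens.

    `labels` is assumed to equal `input_ids` with prompt positions set to -100
    so that only response tokens contribute.
    """
    logits = model(input_ids=input_ids).logits[:, :-1, :]   # predict the next token
    targets = labels[:, 1:]
    log_probs = F.log_softmax(logits, dim=-1)
    mask = targets != -100
    token_logp = log_probs.gather(-1, targets.clamp(min=0).unsqueeze(-1)).squeeze(-1)
    return (token_logp * mask).sum(dim=-1)                  # shape: (batch,)
```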
3.2.2 Mechanism for Increasing Positive Sample Probability and Decreasing Negative Sample Probability
To achieve this goal, loss functions with preference comparisons or labels are typically used for optimization:
- Optimization of Positive Samples: By increasing the loss weight of positive samples, the model is guided to generate outputs that meet human expectations with higher probability when faced with the same prompt.
- Penalty for Negative Samples: By applying higher loss weights to negative samples, the model is penalized more strongly when it generates incorrect answers, and thus learns to reduce the probability of producing such outputs.
In some methods, such as DPO and KTO, KL divergence between the current policy model and a reference model is calculated to prevent the model from deviating excessively from the original pre-trained model during optimization.
3.3 Role of Loss Functions and KL Divergence
In the model alignment process, the loss function typically consists of two parts:
- Preference Loss or Label Loss, used to optimize the model to generate outputs that meet human expectations.
- KL Divergence, used to constrain the model from deviating from the reference model.
3.3.1 Role of KL Divergence
KL divergence (Kullback-Leibler Divergence) measures the difference between two probability distributions. In model alignment, KL divergence is used to limit the distribution difference between the current model $\pi_\theta$ and the reference model $\pi_{\text{ref}}$, ensuring that the model’s outputs do not deviate excessively from the pre-trained model during optimization. The specific formula is:
$$
\text{KL}\left(\pi_\theta(y|x) \,\|\, \pi_{\text{ref}}(y|x)\right) = \sum_y \pi_\theta(y|x) \log \frac{\pi_\theta(y|x)}{\pi_{\text{ref}}(y|x)}
$$
- If the KL divergence is large, it indicates that the current model’s generated distribution significantly differs from the reference model, which may mean the model is producing unreasonable outputs.
- By minimizing KL divergence, the model can be further optimized while ensuring the reasonableness of its outputs.
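The KL term defined above can be estimated from the two models' token logits. The following is a minimal sketch under the assumption that both models share the same vocabulary:

```python
import torch.nn.functional as F

def token_kl(policy_logits, ref_logits):
    """Mean KL(pi_theta || pi_ref) over positions, from raw token logits."""
    policy_logp = F.log_softmax(policy_logits, dim=-1)
    ref_logp = F.log_softmax(ref_logits, dim=-1)
    # sum_y pi_theta(y|x) * (log pi_theta(y|x) - log pi_ref(y|x))
    kl = (policy_logp.exp() * (policy_logp - ref_logp)).sum(dim=-1)
    return kl.mean()
```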
3.3.2 Loss Function Formulas
Based on preferences or labels, the model’s loss function can be expressed in the following forms:
Loss Function in DPO:
$$
L_{\text{DPO}} = -\mathbb{E}_{(x, y_w, y_l) \sim D}\left[\log \sigma\!\left(\beta\left(\log \frac{\pi_\theta(y_w|x)}{\pi_{\text{ref}}(y_w|x)} - \log \frac{\pi_\theta(y_l|x)}{\pi_{\text{ref}}(y_l|x)}\right)\right)\right]
$$
- $y_w$: the preferred (higher-preference) answer.
- $y_l$: the dispreferred (lower-preference) answer.
- $\beta$: a temperature coefficient that controls how strongly the policy is kept close to the reference model.
In DPO, the KL constraint is built in implicitly: the log-ratios against $\pi_{\text{ref}}$ act as a regularizer, and increasing $\beta$ penalizes deviations from the reference model more strongly, so the model’s outputs do not drift too far from the reference model.
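A compact sketch of this loss in PyTorch follows, assuming the per-sequence log-probabilities of the chosen and rejected answers have already been computed under both the policy and a frozen reference model (e.g., with a helper like the `sequence_log_prob` sketch above):

```python
import torch.nn.functional as F

def dpo_loss(policy_chosen_logp, policy_rejected_logp,
             ref_chosen_logp, ref_rejected_logp, beta=0.1):
    """-log sigma(beta * [(log-ratio of chosen) - (log-ratio of rejected)])."""
    chosen_ratio = policy_chosen_logp - ref_chosen_logp        # log pi/pi_ref for y_w
    rejected_ratio = policy_rejected_logp - ref_rejected_logp  # log pi/pi_ref for y_l
    return -F.logsigmoid(beta * (chosen_ratio - rejected_ratio)).mean()
```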
Loss Function in KTO:
The loss function in KTO is based on prospect theory and uses KL divergence as a reference point. For an undesirable (negative) sample $y$ (e.g., an incorrect answer), the per-example loss can be written as:
$$
L_{\text{KTO}}(x, y) = \lambda_U \cdot \left(1 - \sigma\!\left(\beta\left(z_{\text{ref}} - r_\theta(x, y)\right)\right)\right), \qquad r_\theta(x, y) = \log \frac{\pi_\theta(y|x)}{\pi_{\text{ref}}(y|x)}
$$
- $r_\theta(x, y)$: the implied reward, i.e., the current policy’s confidence in the negative sample relative to the reference model.
- $z_{\text{ref}}$: a KL-divergence-based reference point between the current model and the reference model, ensuring that while the model reduces the generation of negative samples it does not deviate from the original reference model.
By increasing the loss weight for negative samples (i.e., increasing $\lambda_U$), the model lowers its confidence in negative samples more aggressively, thereby decreasing the probability of generating similar incorrect answers in the future.
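A hedged sketch of a per-batch KTO-style loss follows; it also includes the symmetric term for desirable samples. Function and argument names are illustrative, and `kl_ref` is assumed to be estimated separately and treated as a constant:

```python
import torch

def kto_loss(reward, kl_ref, desirable, lambda_d=1.0, lambda_u=1.0, beta=0.1):
    """Sketch of a KTO-style loss.

    reward    : r_theta(x, y) = log pi_theta(y|x) - log pi_ref(y|x), shape (B,)
    kl_ref    : scalar KL reference point between pi_theta and pi_ref
    desirable : boolean tensor, True where y is a positive (desired) sample
    """
    value = torch.where(
        desirable,
        lambda_d * torch.sigmoid(beta * (reward - kl_ref)),   # gains for positives
        lambda_u * torch.sigmoid(beta * (kl_ref - reward)),   # losses for negatives
    )
    weight = torch.where(
        desirable,
        lambda_d * torch.ones_like(reward),
        lambda_u * torch.ones_like(reward),
    )
    return (weight - value).mean()
```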
3.4 How to Optimize the Model
Through the loss functions introduced above, model optimization is typically performed using Gradient Descent. The gradients of the loss function reflect the differences between the model’s outputs and the expected outputs, and the optimization goal is to minimize the loss function.
Gradient Update Formula:
$$
\theta_{\text{new}} = \theta_{\text{old}} - \eta \nabla_{\theta} L
$$
where:
- $\eta$ is the learning rate, determining the step size of each parameter update.
- $\nabla_{\theta} L$ is the gradient of the loss function with respect to the model parameters, indicating the contribution of the current parameters to the loss.
Through continuous iteration, the model gradually increases the probability of generating positive samples and decreases the probability of generating negative samples, ultimately achieving model alignment.
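For illustration only, one iteration of this update in PyTorch; `compute_alignment_loss` is a hypothetical placeholder for any of the losses above, and with AdamW the raw gradient is replaced by an adaptive update:

```python
# `model`, `batch`, and `optimizer` are assumed to be defined elsewhere.
loss = compute_alignment_loss(model, batch)   # placeholder for a DPO/KTO/SFT loss
optimizer.zero_grad()
loss.backward()                               # computes nabla_theta L
optimizer.step()                              # theta <- theta - eta * update(grad)
```

In summary: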
- The core objective of Model Alignment is to optimize the model’s outputs to meet human expectations through preference or label data.
- The Policy Model ($\pi_\theta$) generates outputs, and KL divergence is used to control the degree of deviation from the reference model, preventing unreasonable biases during optimization.
- The Probability of Positive Samples is gradually increased through the optimization of the loss function, while the Probability of Negative Samples is reduced by increasing loss weights and lowering confidence.
- Gradient descent is used to update model parameters, ultimately achieving model alignment.
4. DPO (Direct Preference Optimization)
4.1 Principle
DPO directly optimizes the model’s output preference function to make the model’s outputs more aligned with human preferences. It compares different outputs generated by the model and uses a preference function to evaluate which of the two outputs is better, thereby guiding the optimization of the model parameters.
4.2 Loss Function
DPO uses a pairwise preference loss to compare the quality of two candidate outputs. In simplified form:
$$
\mathcal{L}_{\text{DPO}} = \log\left(1 + \exp\left(-\beta \cdot p \cdot (\hat{y}_a - \hat{y}_b)\right)\right) = -\log \sigma\left(\beta \cdot p \cdot (\hat{y}_a - \hat{y}_b)\right)
$$
- $\hat{y}_a$ and $\hat{y}_b$ are the model’s scores for the two candidate responses (in the full DPO loss given in the alignment section, these are reference-adjusted log-probabilities).
- $p$ is the human preference label ($+1$ indicates preference for $a$, $-1$ indicates preference for $b$).
- $\beta$ is a scaling (temperature) coefficient.
4.3 Optimizer
DPO typically uses the AdamW optimizer, which is suitable for optimizing large-scale parameter models. The code is as follows:
```python
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
```
4.4 Core Code
The following are the training steps for DPO:
```python
import torch
```
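The listing above is truncated after the import. The following is a hedged sketch of one DPO training step using the simplified pairwise loss from Section 4.2; the helper `sequence_log_prob`, the batch tensors, and the preference labels are illustrative assumptions:

```python
import torch
import torch.nn.functional as F

def pairwise_preference_loss(score_a, score_b, preference, beta=1.0):
    """log(1 + exp(-beta * p * (y_hat_a - y_hat_b))) == -log sigma(beta * p * diff)."""
    diff = score_a - score_b
    return -F.logsigmoid(beta * preference * diff).mean()

# One illustrative training step; the scores could come from a helper such as the
# sequence_log_prob sketch given earlier (all names here are assumptions).
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
score_a = sequence_log_prob(model, input_ids_a, labels_a)
score_b = sequence_log_prob(model, input_ids_b, labels_b)
preference = torch.ones_like(score_a)         # +1: humans prefer answer a
loss = pairwise_preference_loss(score_a, score_b, preference)
optimizer.zero_grad()
loss.backward()
optimizer.step()
```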
5. KTO (Kahneman-Tversky Optimization)
5.1 Principle
KTO is based on Kahneman and Tversky’s Prospect Theory, which uses an asymmetric utility function to measure the model’s gains and losses. It aims to optimize the model’s performance, especially in scenarios with asymmetric risks and rewards. The utility function is defined as follows:
$$
\mathcal{U}(x) =
\begin{cases}
x^{\alpha}, & x \geq 0 \\
-\lambda (-x)^{\alpha}, & x < 0
\end{cases}
$$
- $x$ is the difference between the model’s prediction and the true value.
- $\alpha$ is the non-linear coefficient, typically 0.88.
- $\lambda$ is the loss penalty weight, typically 2.25.
5.2 Loss Function
The loss function for KTO is based on the utility function from Prospect Theory and is used to penalize the model’s prediction errors:
$$
\mathcal{L}_{\text{KTO}} = -\mathbb{E}\left[\mathcal{U}\left(y_{\text{pred}} - y_{\text{true}}\right)\right]
$$
5.3 Optimizer
KTO commonly uses the AdamW optimizer to ensure stability during the training process:
```python
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5)
```
5.4 Core Code
The following is the code for calculating the KTO loss function:
```python
import torch
```
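The listing above is truncated after the import. The following is a hedged sketch of the prospect-theory utility and the resulting KTO loss from Sections 5.1–5.2; the function names and the form of `y_pred`/`y_true` are illustrative assumptions:

```python
import torch

def prospect_utility(x, alpha=0.88, lam=2.25):
    """Kahneman-Tversky value function: concave gains, losses weighted by lam."""
    gains = torch.clamp(x, min=0) ** alpha
    losses = lam * torch.clamp(-x, min=0) ** alpha
    return gains - losses

def kto_utility_loss(y_pred, y_true):
    """-E[U(y_pred - y_true)]: losses are penalized more than equivalent gains."""
    return -prospect_utility(y_pred - y_true).mean()

# Illustrative usage with placeholder tensors:
y_pred = torch.randn(4)
y_true = torch.randn(4)
loss = kto_utility_loss(y_pred, y_true)
```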
Summary
| Method | Loss Function | Optimizer |
|---|---|---|
| SFT | Cross-Entropy Loss | AdamW, RMSprop, SGD |
| LoRA | Cross-Entropy Loss | AdamW, RMSprop, SGD |
| DPO | Pairwise Preference Loss: $-\log \sigma(\beta \cdot p \cdot (\hat{y}_a - \hat{y}_b))$ | AdamW |
| KTO | Prospect Theory Utility Loss: $-\mathbb{E}[\mathcal{U}(y_{\text{pred}} - y_{\text{true}})]$ | AdamW |
This document has outlined the principles, implementation steps, loss function designs, and optimizer choices for SFT, LoRA, DPO, and KTO, with an emphasis on fine-tuning large-scale pre-trained models such as LLAMA3.