LoRA, DPO, KTO 与 SFT 技术详解
LoRA, DPO, KTO 与 SFT 技术详解
本篇文档将详细介绍几种在大型语言模型(如 LLAMA3)微调和优化中的重要技术,包括 SFT(Supervised Fine-Tuning)、LoRA(Low-Rank Adaptation)、Alignment 技术、KTO(Kahneman-Tversky Optimization) 和 DPO(Direct Preference Optimization)。文中还将详细阐述每种技术的原理、具体实现方法以及相应的损失函数与优化器选择。
1. SFT(Supervised Fine-Tuning)
1.1 原理
SFT 是一种传统的微调方法,通过监督学习对预训练模型进行微调,调整模型的参数使其在特定任务上表现更好。SFT 通常用于针对特定的标注数据进行模型微调,训练的过程类似于常规的监督学习。
1.2 实现方法
- 选择预训练模型:如 GPT、BERT 等语言模型。
- 准备标注数据集:数据集包含输入和输出对。
- 训练模型:使用标准的交叉熵损失函数对模型进行训练,通过梯度下降优化参数。
1.3 核心代码
使用 Hugging Face 的 Trainer
接口进行 SFT:
1 | from transformers import Trainer, TrainingArguments, AutoModelForSeq2SeqLM, AutoTokenizer |
2. LoRA(Low-Rank Adaptation)
2.1 原理
LoRA 是一种参数高效的微调技术,通过对大模型中的权重矩阵进行低秩分解,将原始权重矩阵 $W$ 分解为两个低秩矩阵 $B$ 和 $A$,并仅对这些低秩矩阵进行微调。LoRA 的设计目标是减少微调参数的数量,在保留预训练模型权重的同时,通过调整低秩矩阵来优化模型表现。
2.2 实现方法
- 权重分解:对于模型的线性层(如注意力机制中的
q_proj
和v_proj
层),将权重矩阵分解为两个低秩矩阵 $B$ 和 $A$。 - 微调特定层:仅对这些特定的线性层应用 LoRA,而模型中的其他层保持不变。
2.3 可微调的层与不变的层
可微调的层
LoRA 通常应用于 Transformer 模型中的线性投影层,尤其是多头注意力机制中的几个关键层:
- q_proj(Query 投影层)
- k_proj(Key 投影层)
- v_proj(Value 投影层)
- o_proj(Output 投影层)
- ffn_up_proj 和 ffn_down_proj(前馈神经网络的上下投影层)
不变的层
- Embedding 层:负责输入和输出的编码,通常不需要微调。
- LayerNorm 层:这些层主要用于归一化,不含大量参数,通常保持不变。
- 激活函数层:如 ReLU 或 GELU 等非线性激活函数不涉及参数,不需要进行微调。
2.4 损失函数
LoRA 的损失函数通常与具体任务相关。在语言生成任务中,LoRA 使用交叉熵损失来度量生成文本和目标文本之间的差异:
$$
\mathcal{L}{\text{LoRA}} = - \sum{i} y_i \log(\hat{y}_i)
$$
其中 $y_i$ 是真实标签,$\hat{y}_i$ 是模型的输出概率。
2.5 优化器
LoRA 微调通常使用 AdamW 优化器,具体代码如下:
1 | optimizer = torch.optim.AdamW(lora_model.parameters(), lr=5e-5) |
2.6 核心代码
使用 peft
库实现 LoRA:
1 | from peft import LoraConfig, get_peft_model |
3. Alignment(对齐技术)
在引入KL散度之前,我们首先需要明确LLM对齐(Alignment)是如何实现的,以及背后的原理和数学公式。
1. 什么是模型对齐(Alignment)?
模型对齐的核心目标是让语言模型的输出符合人类的期望或偏好。通常,模型最初通过大规模监督学习(SFT,Supervised Fine-Tuning)训练,生成具有基础能力的模型。接下来,通过对齐技术,进一步调整模型,使其生成的内容更符合人类偏好或避免产生有害、错误的信息。
对齐的核心机制:
- 正样本:符合人类预期的输出(如正确回答)。
- 负样本:不符合人类预期的输出(如错误回答)。
通过使用成对偏好数据或标签(正确/错误),对模型的输出进行进一步微调,使模型能够生成更多的正样本,同时减少负样本的生成概率。
2. 模型对齐的数学原理
在对齐过程中,模型会通过策略模型(Policy Model)来生成输出,策略模型通常是经过SFT训练的语言模型,用来在给定输入下生成输出。为了优化模型的输出,使其更加符合人类偏好,常常使用以下损失函数和优化方法:
2.1 策略模型
假设当前模型的策略为 $\pi_\theta$,它表示在给定输入 $x$ 时,模型生成输出 $y$ 的概率:
$$
\pi_\theta(y|x)
$$
策略模型的目标是通过调整参数 $\theta$,提高生成正确输出(正样本)的概率,降低生成错误输出(负样本)的概率。
2.2 提高正样本概率与降低负样本概率的机制
为了实现这个目标,通常使用带有偏好比较或标签的损失函数进行优化:
正样本的优化:通过增加正样本的损失权重,使得模型生成正样本的概率更高。
- 正样本的损失函数会引导模型在面对相同问题时,生成更多符合人类期望的答案。
负样本的惩罚:对负样本施加更高的损失权重,模型会学习到减少这些错误输出的概率。
- 负样本的损失函数旨在让模型在生成错误答案时感知到更大的惩罚,从而减少这些输出的生成。
在某些方法中,例如DPO和KTO,还会通过计算当前策略模型与参考模型之间的KL散度,来防止模型在优化过程中过度偏离原始预训练模型。
3. 损失函数与KL散度的作用
在模型对齐的过程中,损失函数通常包含两部分:
- 偏好损失或标签损失,用于优化模型生成符合人类期望的输出。
- KL散度,用于约束模型不要偏离参考模型。
3.1 KL散度的作用
KL散度(Kullback-Leibler Divergence)衡量的是两个概率分布之间的差异。在模型对齐中,KL散度用于限制当前模型 \(\pi_\theta\) 和参考模型 \(\pi_{\text{ref}}\) 的分布差异,确保在优化过程中模型的输出不会过度偏离预训练模型。具体公式为:
$$
\text{KL}(\pi_\theta(y|x) | \pi_{\text{ref}}(y|x)) = \sum_y \pi_\theta(y|x) \log \frac{\pi_\theta(y|x)}{\pi_{\text{ref}}(y|x)}
$$
- 如果KL散度较大,表示当前模型生成的分布与参考模型有较大的差异,这可能意味着模型生成了不合理的输出。
- 通过最小化KL散度,模型能够在保证输出合理性的基础上,进行进一步的优化。
3.2 损失函数公式
根据偏好或标签,模型的损失函数可以表达为以下形式:
DPO中的损失函数:
$$
L_{DPO} = -\mathbb{E}_{x, y_w, y_l \sim D} [\log \sigma(\beta (\log(\pi_\theta(y_w|x)) - \log(\pi_\theta(y_l|x))))]
$$
- $y_w$:偏好较高的答案。
- $y_l$:偏好较低的答案。
DPO中可以引入KL散度作为正则化项:
$$
L_{DPO} = -\log \sigma(\log(\pi_\theta(y_w|x)) - \log(\pi_\theta(y_l|x)) + \beta \cdot \text{KL}(\pi_\theta | \pi_{\text{ref}}))
$$
通过控制KL散度,模型的输出不会偏离参考模型太多。
KTO中的损失函数:
KTO的损失函数基于前景理论,并将KL散度作为核心部分,表达为:
$$
L_{KTO} = \lambda_U \cdot \sigma(\beta \cdot (\text{KL}(\pi_\theta(\text{Answer 2}) | \pi_{\text{ref}}) - r_{\theta}(\text{Answer 2})))
$$
- $r_{\theta}(x, y)$:当前策略对负样本(错误答案)的置信度。
- KL散度用于衡量当前模型与参考模型的差异,确保模型在减少负样本生成的同时,不偏离原始参考模型。
通过增加负样本的损失(即增加 $\lambda_U$ 的值),模型会降低负样本的置信度,使未来生成类似错误答案的概率变小。
4. 如何优化模型
通过上面介绍的损失函数,模型的优化通常是通过梯度下降(Gradient Descent)来完成的。损失函数的梯度反映了模型输出与期望输出之间的差异,优化目标是最小化损失函数。
梯度更新公式:
$$
\theta_{\text{new}} = \theta_{\text{old}} - \eta \nabla_{\theta} L
$$
其中:
- $\eta$ 是学习率,决定每次参数更新的步长。
- $\nabla_{\theta} L$ 是损失函数对模型参数的梯度,表示当前参数对损失的贡献。
通过不断迭代,模型会逐渐提高生成正样本的概率,减少负样本的生成概率,最终实现模型对齐。
- 模型对齐(Alignment)的核心目标是通过偏好或标签数据,优化模型的输出,使其符合人类期望。
- 策略模型($\pi_\theta$)生成输出,KL散度用于控制模型与参考模型的偏离程度,避免模型在优化过程中产生不合理的偏差。
- 正样本的概率通过损失函数的优化逐步提升,负样本的概率通过增加损失权重和降低置信度来减少。
- 梯度下降用于更新模型参数,最终实现模型对齐
4. DPO(Direct Preference Optimization)
4.1 原理
DPO 通过直接优化模型输出的偏好函数,使模型的输出更加符合人类偏好。它比较模型的不同输出,并通过偏好函数评估这两个输出哪个更好,从而指导模型参数的优化。
4.2 损失函数
DPO 使用偏好损失函数(Preference Loss),用于比较两个输出的优劣:
$$
\mathcal{L}_{\text{DPO}} = \log(1 + \exp(-\sigma \cdot (\hat{y}_a - \hat{y}_b) \cdot p))
$$
- $ \hat{y}_a $ 和 $ \hat{y}_b $ 是模型对两个样本的预测值。
- $ p $ 是人类偏好(1 表示偏好 $a$,-1 表示偏好 $b$)。
- $ \sigma $ 是平滑参数。
4.3 优化器
DPO 通常使用 AdamW 优化器,适用于大规模参数模型的优化,代码如下:
1 | optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5) |
4.4 核心代码
以下是 DPO 的训练步骤:
1 | import torch |
5. KTO(Kahneman-Tversky Optimization)
5.1 原理
KTO 基于 Kahneman 和 Tversky 的前景理论(Prospect Theory),通过非对称效用函数衡量模型的增益和损失,旨在优化模型的表现,尤其在风险和收益不对称的场景下。效用函数定义如下:
$$
\mathcal{U}(x) =
\begin{cases}
x^{\alpha}, & x \geq 0 \
-\lambda (-x)^{\alpha}, & x < 0
\end{cases}
$$
- $x$ 是模型预测与真实值的差异。
- $\alpha$ 是非线性系数,通常为 0
.88。
- $\lambda$ 是损失的惩罚权重,通常为 2.25。
5.2 损失函数
KTO 的损失函数基于前景理论的效用函数,用于惩罚模型的预测误差:
$$
\mathcal{L}{\text{KTO}} = -\mathbb{E}[\mathcal{U}(y{\text{pred}} - y_{\text{true}})]
$$
5.3 优化器
KTO 常使用 AdamW 优化器,以确保训练过程的稳定性:
1 | optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5) |
5.4 核心代码
以下是 KTO 损失函数的计算代码:
1 | import torch |
总结
方法 | 损失函数 | 优化器 |
---|---|---|
SFT | 交叉熵损失 | AdamW,RMSprop,SGD |
LoRA | 交叉熵损失 | AdamW,RMSprop,SGD |
DPO | 偏好损失函数: $\log(1 + \exp(-\sigma (\hat{y}_a - \hat{y}_b)p))$ | AdamW |
KTO | 前景理论效用函数: $-\mathbb{E}[\mathcal{U}(y_{\text{pred}} - y_{\text{true}})]$ | AdamW |
通过本文档的整理,读者能够清晰理解 SFT、LoRA、DPO 和 KTO 等技术的原理、具体实现步骤、损失函数设计和优化器选择,特别是在 LLAMA3 这种大规模预训练模型的微调场景下的实际应用。