FastChat 训练脚本代码逐行解析-Train.py 【FastChat 系列第 1 篇】
FastChat 训练脚本代码逐行解析-Train.py 【FastChat 系列第 1 篇】
在本文中,我们将深入探讨 FastChat 的 train.py
脚本,这是一个用于训练和优化大型语言模型的关键组件。FastChat 是一个先进的开源平台,专注于开发、部署和评估基于大型语言模型(LLM)的聊天机器人。该平台不仅提供对顶尖模型如 Vicuna 和 MT-Bench 的支持,还包括一个分布式的多模型服务系统,配备了 Web UI 和与 OpenAI 兼容的 RESTful API,使用户能够高效地训练和评估他们的模型。
本文的深入分析将聚焦于 train.py
脚本的源代码。这个脚本是基于 transformers 库的自然语言处理模型训练脚本,涵盖了数据预处理、模型训练和保存等关键步骤。我们旨在提供对 train.py
中每个类和函数的详细解释,包括它们的功能和在整个训练过程中的作用。
1. 导入模块
1. 内置模块
这些是 Python 自带的标准库模块,无需额外安装。
1 | from dataclasses import dataclass, field |
导入 Python 的dataclasses
模块,用于创建带有默认值的类。
1 | import json |
导入json
模块,用于处理 JSON 格式的数据。
1 | import math |
导入math
模块,用于数学运算。
1 | import pathlib |
导入pathlib
模块,用于处理文件路径。
1 | from typing import Dict, Optional, Sequence |
导入typing
模块,用于类型注解。
1 | import numpy as np |
2. 依赖库
这些是外部安装的依赖库,通常通过包管理器如 pip 安装。
导入numpy
库,一个常用的科学计算库。
1 | import torch |
导入PyTorch
,一个流行的深度学习框架。
1 | from torch.utils.data import Dataset |
从torch
中导入Dataset
,用于创建自定义数据集。
1 | import transformers |
导入transformers
库,一个流行的自然语言处理库。
1 | from transformers import Trainer |
从transformers
中导入Trainer
,用于训练模型。
1 | from transformers.trainer_pt_utils import LabelSmoother |
从transformers
中导入LabelSmoother
,用于标签平滑。
3. 项目特定函数
这些是在 Fast Chat 项目中自定义实现的函数或类。
1 | from fastchat.conversation import SeparatorStyle |
从fastchat
包导入SeparatorStyle
,用于定义对话分隔符风格。SeparatorStyle
类是一个使用 Python 的 enum
模块创建的枚举类,用于定义一系列的分隔符样式。枚举(Enumeration)是一种编程概念,用于定义一组命名的常数,使代码更加清晰和易于维护。
在 SeparatorStyle
类中,每个成员代表一种特定的分隔符样式。这些样式通常用于文本处理中,特别是在需要区分不同部分或元素的情况下。例如,在处理对话或文本数据时,可能需要不同的方式来区分用户输入和机器回复。
关于 auto()
函数的使用:
auto()
是 Pythonenum
模块提供的一个特殊函数。它在枚举类中自动分配一个唯一的值给每个成员。- 在不使用
auto()
的情况下,你需要手动为每个枚举成员指定一个唯一的值。使用auto()
可以简化这个过程,让 Python 自动处理这些值的分配。 auto()
分配的值通常是整数,从 1 开始依次递增。
具体到 SeparatorStyle
类,auto()
被用来为每种分隔符样式自动分配一个唯一的整数值。例如,ADD_COLON_SINGLE
、ADD_COLON_TWO
等将分别被赋予不同的整数值。
每个枚举成员的名称(如 ADD_COLON_SINGLE
、NO_COLON_SINGLE
等)通常描述了该分隔符样式的特点。例如,ADD_COLON_SINGLE
可能表示在某个元素后添加一个冒号作为分隔符,而 NO_COLON_SINGLE
则表示不添加冒号。
这种方式使得在代码中引用和处理这些分隔符样式变得更加方便和清晰。例如,可以根据不同的场景或需求选择使用不同的分隔符样式,而无需记住它们对应的具体值。
1 | from fastchat.model.model_adapter import get_conversation_template |
从fastchat
包导入get_conversation_template
,用于获取对话模板。
在这段代码中,调用逻辑主要涉及到获取特定模型的默认对话模板。调用链路如下:
起始调用 -
get_conversation_template(model_path: str)
- 这个函数是调用链的起点。它接收一个参数
model_path
,用于指定模型的路径。 - 这个函数的目的是获取给定模型路径的默认对话模板。
- 这个函数是调用链的起点。它接收一个参数
调用
get_model_adapter(model_path: str)
get_conversation_template
函数首先调用get_model_adapter
,传入模型路径。get_model_adapter
的目的是根据提供的模型路径,找到并返回一个适合该模型的BaseModelAdapter
对象。- 这个函数首先尝试匹配
model_path
的基本名称(basename),如果没有找到匹配项,它会尝试匹配完整的路径。 - 如果找到合适的适配器,则返回该适配器;如果没有找到,则抛出一个
ValueError
。
执行
BaseModelAdapter.get_default_conv_template(model_path: str)
- 在获取到适当的模型适配器后,
get_conversation_template
通过调用该适配器的get_default_conv_template
方法来获取默认的对话模板。 - 注意这个方法在
BaseModelAdapter
类中定义,但可能在子类中被重写。
- 在获取到适当的模型适配器后,
调用
get_conv_template(name: str)
- 在
get_default_conv_template
方法内部,它调用get_conv_template
函数,通常传入一个预定义的模板名称,比如"one_shot"
。 get_conv_template
的作用是从全局注册的对话模板字典conv_templates
中获取指定名称的模板。
- 在
获取并返回
Conversation
对象get_conv_template
函数返回Conversation
类的一个实例,这通常是从conv_templates
字典中复制得到的。- 最终,这个
Conversation
实例被返回到最初调用get_conversation_template
的地方。
总结调用链路:
1 | get_conversation_template(model_path) |
在这个过程中,代码通过一系列函数调用,根据提供的模型路径,找到相应的模型适配器,并从中获取特定的对话模板。这种设计模式允许灵活地为不同的模型提供不同的对话模板,从而提高了代码的可重用性和可扩展性。
2. 配置类
这些类是使用 Python 的 dataclass
装饰器定义的,主要用于存储配置和参数。这些类通常不包含复杂的方法或逻辑,而是用于定义和存储数据结构。这些类包括:
ModelArguments
: 存储与模型相关的参数,如模型路径、远程代码信任等。DataArguments
: 存储与数据相关的参数,如数据路径、评估数据路径以及是否使用懒加载预处理。TrainingArguments
: 存储与训练相关的参数,如缓存目录、优化器类型、模型最大长度等。这个类继承自transformers.TrainingArguments
,增加了一些自定义参数。
这些类主要用于简化和组织代码中的参数管理,使得参数的修改和访问更加方便。
1. ModelArguments 类
Code
1 |
|
Explanation
ModelArguments
是一个数据类(dataclass
),用于存储与模型相关的配置参数。
属性:
model_name_or_path
: 指定预训练模型的名称或路径。trust_remote_code
: 是否允许使用自定义模型,这些模型在 Hub 上有自己的模型文件。padding_side
: 指定在分词器(tokenizer
)中使用的填充方式,通常是左填充或右填充。
`@dataclass`装饰器的介绍,点击展开
`@dataclass` 是一个装饰器,用于自动化生成特殊方法,如 `__init__()`、`__repr__()`、`__eq__()` 等,从而简化数据类的编写。这个装饰器是 Python 3.7 中引入的一部分,属于 `dataclasses` 模块。当你在一个类定义前使用 @dataclass
装饰器时,Python 会自动为这个类添加一些由属性定义的特殊方法。这对于创建存储少量数据但不需要复杂方法的类非常有用。
具体来说,使用 @dataclass
时:
自动生成构造函数(
__init__
方法):Python 会根据类中定义的字段自动创建一个__init__
方法,这样你就不需要手动编写这个方法来初始化类的实例了。自动生成
__repr__
方法:这使得打印类的实例时能够得到更具可读性的字符串表示,通常包含类名和其中的字段及其值。自动生成
__eq__
方法:这使得可以使用==
操作符来比较两个类的实例,比较的是实例中字段的值。支持类型注解:在定义字段时,你可以使用类型注解,这不仅有助于代码清晰性,还可以通过一些工具进行类型检查。
在ModelArguments
类的例子中,@dataclass
装饰器会为这个类生成上述的方法。这意味着你可以很方便地创建ModelArguments
的实例,并在打印或比较这些实例时得到预期的行为。
例如,当你创建一个ModelArguments
实例时:
1 | args = ModelArguments() |
这将调用自动生成的__init__
方法,使用默认值”facebook/opt-125m”为model_name_or_path
、False
为trust_remote_code
和”right”为padding_side
。
当你打印这个实例:
1 | print(args) |
这将调用自动生成的__repr__
方法,显示类实例的详细信息,如ModelArguments(model_name_or_path="facebook/opt-125m", trust_remote_code=False, padding_side="right")
。
这样,@dataclass
装饰器简化了类的创建过程,使得代码更加简洁和易于维护。
总的来说,@dataclass
装饰器是 Python 提供的一个便捷工具,用于快速创建主要用于存储数据的类。
2. DataArguments 类
Code
1 |
|
Explanation
DataArguments 类
DataArguments
也是一个数据类,用于存储数据相关的配置参数。- 属性:
data_path
: 训练数据的路径。eval_data_path
: 评估数据的路径。lazy_preprocess
: 是否在数据预处理时使用延迟加载,即在需要时才加载和处理数据。
3. TrainingArguments 类
Code
1 |
|
Explanation
TrainingArguments
类继承自 transformers.TrainingArguments
。。
TrainingArguments 类
TrainingArguments
是一个数据类,它通过继承transformers.TrainingArguments
,获得了处理训练参数的能力。- 在
TrainingArguments
中定义的属性:cache_dir
: 用于指定模型和分词器缓存的目录路径。optim
: 定义了要使用的优化器类型,例如'adamw_torch'
。model_max_length
: 指定模型能处理的最大序列长度。
transformers.TrainingArguments 类
transformers.TrainingArguments
是transformers
库中的一个类,用于配置模型训练过程中的各种参数。- 这个类包含大量的属性,用于控制训练过程,例如:
output_dir
: 指定保存模型和训练结果的目录。num_train_epochs
: 训练的轮数(epochs)。per_device_train_batch_size
: 每个设备上的训练批次大小。save_steps
: 保存模型的步数间隔。evaluation_strategy
: 评估模型的策略,如在每个 epoch 结束时进行评估。learning_rate
: 学习率。warmup_steps
: 在学习率调度中用于预热的步数。
transformers.TrainingArguments
还包含了许多其他参数,用于微调训练过程,包括日志记录、模型保存策略、学习率调度等。
通过继承 transformers.TrainingArguments
,TrainingArguments
类不仅继承了所有这些训练参数的配置能力,而且还可以添加一些自定义的训练参数,如本例中的 cache_dir
、optim
和 model_max_length
。这种做法提高了代码的可复用性和灵活性,使得您可以根据项目的具体需求调整和扩展训练配置。
3.功能型函数 (Functional Utility Functions)
1. rank0_print(*args)
Code
1 | local_rank = None |
Explanation
定义一个全局变量 local_rank,用于分布式训练。
定义一个函数 rank0_print,只在 local_rank 为 0 时打印信息,用于分布式训练中的信息输出控制。这样可以避免在多个节点上重复打印相同的信息,使得输出更加清晰和简洁。
- 用于只在分布式训练环境中的主节点(rank 0)上打印信息。
- 参数:可变数量的参数,用于打印。
2. trainer_save_model_safe(trainer: transformers.Trainer)
Code
1 | def trainer_save_model_safe(trainer: transformers.Trainer): |
函数 trainer_save_model_safe(trainer: transformers.Trainer)
旨在安全地保存使用 PyTorch 分布式框架训练的模型。让我们详细了解此函数及其涉及的关键组件。
Explanation
- 参数:
trainer
:transformers.Trainer
的实例。这个类是 Hugging Face Transformers 库的核心组件之一,用于训练和评估模型。
- 功能:
- 此函数的主要目的是在分布式训练环境中安全地保存模型。它特别考虑了使用
FullyShardedDataParallel
(FSDP) 进行训练时的模型保存策略。
- FSDP
- FullyShardedDataParallel (FSDP)
- 这是 PyTorch 分布式训练框架的一个组件。FSDP 通过将模型参数分片到多个 GPU 上来减少每个 GPU 的内存占用,从而实现更大模型的训练。
- 在此场景中,FSDP 主要用于处理和保存分布式训练中的模型状态。
- StateDictType
- 这是一个枚举类型,定义了如何保存模型的状态字典(state dict)。在 FSDP 环境中,保存和加载模型状态可能需要特殊的处理。
- FullStateDictConfig
- 这个类用于配置保存完整状态字典时的参数。它是 FSDP 功能的一部分,用于控制如何保存模型状态。
- 函数实现
- 设置保存策略
save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
创建了一个保存策略。这里指定两个关键参数:offload_to_cpu
: 在保存状态字典之前,将模型参数卸载到 CPU,这有助于减少 GPU 内存的使用。rank0_only
: 只在 rank 0(通常是主节点)上保存模型。在分布式训练中,这可以避免每个节点都保存相同的模型副本,节省存储空间。
- 保存模型
- 使用
with FSDP.state_dict_type(trainer.model, StateDictType.FULL_STATE_DICT, save_policy)
上下文管理器设置模型保存的状态字典类型和策略。 - 在这个上下文内,调用
trainer.save_model()
来保存模型。由于使用了save_policy
,模型将根据上述配置安全地保存。
- 使用
函数 trainer_save_model_safe
封装了一个安全的模型保存逻辑,特别是针对使用 PyTorch 的 FSDP 进行分布式训练的场景。它确保了只在一个节点上保存完整的模型状态,并且在保存之前将模型参数转移到 CPU,从而优化内存使用和存储效率。这对于训练大型模型和管理大规模分布式训练环境至关重要。
3.preprocess(sources,tokenizer: transformers.PreTrainedTokenizer) -> Dict
Code
1 |
|
函数 preprocess(sources, tokenizer: transformers.PreTrainedTokenizer) -> Dict
用于预处理对话数据,使其适用于机器学习模型的训练。这个函数可以分为几个主要部分进行详细介绍:
1. 获取对话模板和角色定义
1 | conv = get_conversation_template("vicuna") |
- 功能: 初始化对话模板和定义对话参与者的角色。
- 实现:
conv = get_conversation_template("vicuna")
获取指定模型(如 “vicuna”)的对话模板。roles
字典将 “human” 和 “gpt” 分别映射到对话模板中定义的角色。
- 示例:
- 如果对话模板是 “vicuna”,则
roles
可能是{"human": "user", "gpt": "assistant"}
。conv = get_conversation_template("vicuna")
得到的模板如下:1
Conversation(name='vicuna_v1.1', system_template='{system_message}', system_message="A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.", roles=('USER', 'ASSISTANT'), messages=[], offset=0, sep_style=<SeparatorStyle.ADD_COLON_TWO: 2>, sep=' ', sep2='</s>', stop_str=None, stop_token_ids=None)
roles
将 “human” 映射到 “USER”,将 “gpt” 映射到 “ASSISTANT”。{'human': 'USER', 'gpt': 'ASSISTANT'}
- 如果对话模板是 “vicuna”,则
2. prompt 模板
1 | # Apply prompt templates |
功能: 为源数据应用提示模板,构建对话。
实现:
- 遍历
sources
(原始对话数据),将每个对话源转换为模板格式的对话。 - 如果对话的第一部分不是 “human” 角色发起,则跳过该部分。
- 为每个句子指定角色,并将其添加到对话模板中。
- 最终,每个处理后的对话被添加到
conversations
列表中。
- 遍历
示例:
- 假如我们的 source 是 dummy input 中的第一条数据:
python source = [{'from': 'human', 'value': 'Who are you?'}, {'from': 'gpt', 'value': 'I am Vicuna, a language model trained by researchers from Large Model Systems Organization (LMSYS).'}, {'from': 'human', 'value': 'Have a nice day!'}, {'from': 'gpt', 'value': 'You too!'}]
conversations
在 Vicuna template 下,我们会使用SeparatorStyle.ADD_COLON_TWO
作为分隔符风格,构成的数据可能是 [“A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user’s questions. USER: Who are you? ASSISTANT: I am Vicuna, a language model trained by researchers from Large Model Systems Organization (LMSYS).USER: Have a nice day! ASSISTANT: You too!“]get_prompt的实现
`get_prompt` 方法的实现根据不同的 `SeparatorStyle` 有着不同的行为。下面是一个表格,详细介绍了各种风格的 `get_prompt` 方法,以及对应的英文示例:分隔符风格 ( SeparatorStyle
)描述 示例 ADD_COLON_SINGLE
在每个消息后加冒号和分隔符。 USER: Hello there!\nASSISTANT: Hi, how can I help?\n ADD_COLON_TWO
使用两种分隔符交替,通常在不同角色之间切换。 USER: What’s the weather?\nASSISTANT: It’s sunny today.\n\n ADD_COLON_SPACE_SINGLE
消息后加冒号、空格和分隔符。 USER: Can you book a flight?\nASSISTANT: Sure, where to?\n NO_COLON_SINGLE
消息直接跟在角色后,不加冒号,后接分隔符。 USERWhat are you doing?\nASSISTANTI’m here to assist you.\n NO_COLON_TWO
无冒号,使用两种分隔符交替。 USERHow’s the project going?\nASSISTANTIt’s on track.\n\n ADD_NEW_LINE_SINGLE
每条消息前换行,消息后加分隔符。 USER\nHow can I reset my password?\nASSISTANT\nYou can reset it via email.\n RWKV
特殊格式,通常用于特定模型。 USER: What is AI?\n\nASSISTANT: AI stands for Artificial Intelligence.\n\n LLAMA2
特殊标签格式,针对特定模型。 [INST] USER How does blockchain work?\nASSISTANT It is a distributed ledger.\n\n CHATGLM
特定于 CHATGLM
模型的格式。[Round 1]\nUSER: Tell me a joke.\nASSISTANT: Why did the chicken cross the road?\n CHATML
类似 CHATGLM
,但每条消息前后都有换行。USER\nDo you like music?\n\nASSISTANT\nYes, I enjoy many genres.\n\n CHATGLM3
适用于 CHATGLM3
模型的格式。USER\nCan you play chess?\nASSISTANTYes, I can play.\n CHATINTERN
适用于 CHATINTERN
模型的格式,使用特殊标记。USER:Where is the nearest ATM?\nASSISTANT:It’s next to the post office.\nDOLLY
特定于 DOLLY
模型的格式。USER:\nWhat is quantum computing?\nASSISTANT:\nIt involves computation using quantum-mechanical phenomena.\n\n PHOENIX
适用于 PHOENIX
模型,消息被特殊标记包裹。USER: How to bake a cake?\nASSISTANT:You need flour, sugar, and eggs.\nROBIN
类似 ADD_NEW_LINE_SINGLE
,但角色后有换行。USER:\nIs AI dangerous?\nASSISTANT:\nIt depends on how it’s used.\n FALCON_CHAT
类似 ADD_COLON_SINGLE
,但可适用于FALCON
模型。USER: What is the capital of France?\nASSISTANT: It’s Paris.\n METAMATH
对话中使用特殊前缀和后缀,适用于 METAMATH
模型。USER:\nWhat is 2+2?\n: It’s 4\n DEEPSEEK_CHAT
适用于 DEEPSEEK
模型的特定格式。USER: What’s your favorite color?\nASSISTANT: I like blue.\n\n YUAN2
适用于 YUAN2
模型,特殊的分隔符应用。How are you today? I’m fine, thank you! get_prompt
方法的不同实现,可以灵活地适应各种需求,使对话生成或处理更加准确和高效。
- 假如我们的 source 是 dummy input 中的第一条数据:
3. 对话的分词
1 | # Tokenize conversations |
功能: 文本对话首先被分词处理,转换成模型能够处理的数值序列。然后,这些序列被克隆以形成初始的训练目标。这样做的目的是为了在训练过程中提供一个基准,指导模型学习生成正确的输出。在后续步骤中,这些目标可能会根据特定的训练目标进行调整。
实现:
tokenizer
函数接收文本列表(这里是conversations
),并返回一个包含数值化表示的input_ids
。return_tensors="pt"
指定返回的数据类型为 PyTorch 张量。padding="max_length"
和max_length=tokenizer.model_max_length
确保所有输入长度统一,不足的部分使用填充。truncation=True
表示如果输入过长,将其截断到最大长度。
- 在训练期间,模型需要知道期望的输出以计算损失和进行反向传播。这些期望的输出被称为 “targets”。
targets = input_ids.clone()
表示创建input_ids
的一个副本作为初始的目标。- 之所以需要克隆
input_ids
,是因为在许多语言模型训练任务中(特别是像自回归模型这样的生成任务),模型的目标输出往往与输入非常相似,但在某些细节上存在差异。 - 在后续步骤中,这个
targets
可能会根据特定的训练需求进一步修改或掩码(例如,在对话任务中,可能只对模型生成的回复部分计算损失,而不是整个对话)。
- 之所以需要克隆
4. 目标掩码
1 | # Mask targets. Only compute loss on the assistant outputs. |
功能: 对目标输出进行掩码处理,以便模型只对特定输出计算损失。目标是对生成的 targets(即模型的输出标签)进行掩码处理。这是为了确保在训练过程中只对助手(assistant)的输出计算损失,而不是整个对话。
实现:
sep = conv.sep + conv.roles[1] + ": "
定义了用于识别助手回复的分隔符。在这个例子中,sep
可能是 “\n\nAssistant: “。- 循环遍历每个处理后的对话 (
conversation
) 及其对应的目标 (target
)。 total_len
是当前目标序列中非填充(padding)部分的长度。turns
是将对话根据conv.sep2
分隔成不同轮次的列表。
- 对每个轮次进行处理
- 每个轮次(turn)包含用户和助手的消息。
- 使用
tokenizer(turn)
将每个轮次的文本转换为模型能理解的 ID 序列。 - 通过
parts = turn.split(sep)
分离用户和助手的消息。 instruction_len
是用户消息部分的长度(在某些情况下需要调整,比如-2
是为了适应特定的分词器)。
- 掩码目标
target[cur_len : cur_len + instruction_len] = IGNORE_TOKEN_ID
将用户消息部分的目标 ID 替换为IGNORE_TOKEN_ID
,这意味着在计算损失时会忽略这部分。cur_len
用于跟踪当前处理到的位置。- 每处理完一个轮次,更新
cur_len
。
- 最终处理
target[cur_len:] = IGNORE_TOKEN_ID
确保在最后一个轮次之后的所有内容都被忽略。- 如果
cur_len
小于tokenizer.model_max_length
,但不等于total_len
,则表示有不一致性,此时会发出警告,并将整个目标序列设置为IGNORE_TOKEN_ID
。
5. 返回处理后的数据
- 功能: 返回预处理后的数据,包括输入 ID、目标标签和注意力掩码。
- 实现:
- 返回一个字典,包含
input_ids
(模型输入)、labels
(训练目标)和attention_mask
(指示哪些部分是有效输入的掩码)。
- 返回一个字典,包含
总结
这个 preprocess
函数通过将原始文本数据转换为模型可以理解的格式,为训练准备数据。它涵盖了从文本处理到分词,再到目标掩码的整个预处理流程。这个过程对于任何基于对话的自然语言处理任务至关重要,特别是在需要模型专注于对特定部分的响应时。
4. 数据集类
这些类继承自 PyTorch 的 Dataset
类,并且是为特定的数据处理任务定制的。这些类包含具体的方法来处理和准备数据,以便用于模型训练。这些类包括:
SupervisedDataset
: 用于有监督学习的数据集。它处理原始数据,将其转换为适合模型训练的格式。LazySupervisedDataset
: 类似于SupervisedDataset
,但使用懒加载方式处理数据。这意味着数据只在需要时才被加载和处理,这对于处理大型数据集特别有用。
这些类通常包含 __init__
, __len__
, 和 __getitem__
方法,分别用于初始化数据集、获取数据集大小和检索特定索引的数据。这样的设计模式使得数据集可以轻松地与 PyTorch 的 DataLoader 配合使用,从而实现高效的数据加载和批处理。
1. SupervisedDataset 类
1 | class SupervisedDataset(Dataset): |
SupervisedDataset
类是一个用于有监督学习的数据集类,特别是为了微调(fine-tuning)任务设计。这个类继承自 PyTorch 的 Dataset
类,并重写了其方法以适应特定的数据处理需求。下面是对这个类的详细介绍:SupervisedDataset
类提供了一种结构化和高效的方法来处理和加载用于有监督学习的对话数据。它遵循 PyTorch 数据集(Dataset
)的标准结构,使得与 PyTorch 的数据加载器(DataLoader
)等其他组件兼容,从而方便在训练循环中使用。通过预处理步骤,该类确保数据以适当的格式提供给模型,以便进行有效的训练。
- 类名:
SupervisedDataset
- 继承:
Dataset
(来自 PyTorch) - 目的: 用于有监督的模型微调任务。
1.1 __init__(self, raw_data, tokenizer: transformers.PreTrainedTokenizer)
- 调用时机:创建
SupervisedDataset
类的实例时。这通常发生在准备训练数据集的阶段,当你创建数据加载器(DataLoader)之前。 - 功能:初始化数据集实例,处理原始对话数据,并将其转换为模型可以理解的格式。
- 参数:
raw_data
:包含对话数据的列表或类似结构。tokenizer
:一个预训练的分词器实例,用于将文本转换为模型可以处理的格式。
- 返回值:无返回值,但此方法会设置
input_ids
、labels
和attention_mask
作为类的内部状态。 - 实现细节:
- 使用列表推导式从
raw_data
中提取每个样本的对话内容。 - 调用
preprocess
函数处理这些对话,将其转换为适合模型输入的格式。 - 从返回的
data_dict
中提取input_ids
(模型输入 ID)、labels
(目标标签)和attention_mask
(注意力掩码)。
- 使用列表推导式从
1.2 __len__(self)
- 调用时机:当需要获取数据集大小时,例如在设置数据加载器时,或者在训练循环中迭代数据集时。
- 功能:返回数据集中的样本数量。
- 返回值:一个整数,表示数据集中的样本数量。
- 实现: 直接返回
input_ids
的长度,即样本的数量。
1.3 __getitem__(self, i)
- 调用时机:在数据加载器请求数据集的特定样本时,这通常发生在训练或评估循环的每个迭代中。
- 功能:获取指定索引
i
处的数据样本。 - 参数:
i
:所请求样本的索引。
- 返回值:一个字典,包含索引
i
处样本的input_ids
、labels
和attention_mask
。这些是 PyTorch 张量(torch.Tensor
),适用于模型的训练或评估。
在有监督学习的场景中,SupervisedDataset
类扮演着数据预处理和封装的角色,确保数据以正确的格式提供给模型。__init__
方法在数据集实例化时调用,负责数据的初始化和预处理。__len__
和 __getitem__
方法则在训练和评估过程中被频繁调用,分别用于获取数据集的大小和提取特定的数据样本。这些方法的设计和实现使得 SupervisedDataset
类可以无缝地与 PyTorch 的其他数据处理和训练工具集成。
2. LazySupervisedDataset 类
1 | class LazySupervisedDataset(Dataset): |
LazySupervisedDataset
类是另一种数据集实现,用于有监督的模型微调。与 SupervisedDataset
相比,它采用了一种“懒加载”(lazy loading)的策略。以下是对该类的详细解释,以及它与非懒加载版本的比较。
2.1 __init__(self, raw_data, tokenizer: transformers.PreTrainedTokenizer)
- 作用:初始化
LazySupervisedDataset
实例。 - 实现细节:
- 将原始数据 (
raw_data
) 和分词器 (tokenizer
) 保存为类的属性。 - 初始化一个空字典
cached_data_dict
,用于缓存已处理的数据。
- 将原始数据 (
- 与
SupervisedDataset
的差异:- 在
LazySupervisedDataset
中,原始数据不是在初始化时立即处理,而是存储原始形式以便稍后处理。 cached_data_dict
用于缓存按需处理的数据,以避免重复处理。
- 在
2.2 __len__(self)
作用:返回数据集中的样本数量。
实现:直接返回原始数据 (
raw_data
) 的长度。与
SupervisedDataset
的差异:在
LazySupervisedDataset
类中,__len__
方法确实返回的是数据集中样本的数量,但是这里的“样本数量”是指原始数据 (raw_data
) 中的样本数量,而不是处理后的数据的数量。由于LazySupervisedDataset
采用懒加载策略,数据在初始化时并未被处理,因此__len__
方法基于原始数据计算长度是合理的。这意味着即便数据尚未被转换为模型可用的格式,
__len__
方法仍能准确反映数据集中待处理样本的数量。这与SupervisedDataset
的主要区别在于后者在初始化时就对所有数据进行预处理,因此其__len__
方法返回的是已处理数据的数量。而在LazySupervisedDataset
中,数据处理是按需进行的,因此__len__
返回的是原始数据中的样本数。
__getitem__(self, i)
- 作用:按需获取并处理指定索引
i
处的数据样本。 - 实现细节:
- 首先检查索引
i
是否在缓存cached_data_dict
中。 - 如果是,则直接返回缓存的数据;如果不是,则处理原始数据中索引
i
处的样本,并将处理后的结果添加到缓存中。 - 返回一个包含
input_ids
、labels
和attention_mask
的字典。
- 首先检查索引
- 与
SupervisedDataset
的差异:LazySupervisedDataset
在__getitem__
被调用时才处理数据,而SupervisedDataset
在初始化时就处理所有数据。LazySupervisedDataset
使用缓存来避免重复处理同一样本,而SupervisedDataset
不需要这种机制,因为所有数据在初始化时就已经被处理。
懒加载 vs 非懒加载
- 懒加载(Lazy Loading):
- 优点:减少内存占用,因为只有需要时才处理数据。对于大型数据集非常有用。
- 缺点:可能增加训练时的数据加载时间,尤其是当缓存未命中时。
- 非懒加载(Eager Loading):
- 优点:在训练开始前一次性处理所有数据,可以减少训练过程中的延迟。
- 缺点:需要更多的初始内存来存储处理后的所有数据,对于非常大的数据集可能不实用。
3. make_supervised_data_module
函数
1 | def make_supervised_data_module( |
函数 make_supervised_data_module
的目的是为有监督的模型微调创建数据集和数据整理器(collator)。这个函数根据提供的参数构建适合训练和评估的数据集。下面是对这个函数的超级详细解释:
函数签名
- 函数名:
make_supervised_data_module
- 参数:
tokenizer
:transformers.PreTrainedTokenizer
的实例,用于文本的分词处理。data_args
: 包含数据相关设置的对象,通常包括数据文件路径等信息。
- 返回值:一个字典,包含训练和评估数据集。
函数实现细节
1. 选择数据集类
- 根据
data_args.lazy_preprocess
的值选择使用LazySupervisedDataset
还是SupervisedDataset
类。- 如果
data_args.lazy_preprocess
为True
,则使用LazySupervisedDataset
实现懒加载。 - 否则,使用
SupervisedDataset
进行预加载。
- 如果
- 这一选择影响数据的加载方式,即数据是一次性全部加载并预处理,还是按需加载和处理。
2. 加载训练数据
- 使用
json.load(open(data_args.data_path, "r"))
加载训练数据。- 这里假设训练数据以 JSON 格式存储,并且
data_args.data_path
包含了数据文件的路径。
- 这里假设训练数据以 JSON 格式存储,并且
- 创建训练数据集
train_dataset
实例,传入加载的训练数据和分词器。
3. 加载评估数据(如果提供)
- 检查是否提供了评估数据路径
data_args.eval_data_path
。- 如果提供,同样使用
json.load
加载评估数据。 - 创建评估数据集
eval_dataset
实例。
- 如果提供,同样使用
- 如果没有提供评估数据路径,将
eval_dataset
设置为None
。
4. 返回结果
- 返回一个字典,包含两个键:
train_dataset
和eval_dataset
。train_dataset
对应训练数据集实例。eval_dataset
对应评估数据集实例(如果有的话)。
5. 训练流程
了解了,我将为 train()
函数中每行代码提供更详细的解释:
解析命令行参数
1 | parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments)) |
- 创建一个
HfArgumentParser
实例,这是一个帮助解析命令行参数的工具,特别用于处理 Hugging Face transformers 库中的参数。
1 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() |
- 解析命令行参数并将它们映射到三个数据类 (
ModelArguments
,DataArguments
,TrainingArguments
) 的实例中。
1 | local_rank = training_args.local_rank |
local_rank
用于标识分布式训练中的进程编号。training_args.local_rank
获取这个编号。
设置模型配置
1 | config = transformers.AutoConfig.from_pretrained( |
- 从预训练模型的配置创建
AutoConfig
实例。它自动加载与特定模型相关的配置。 model_args.model_name_or_path
: 指定模型的名称或路径。cache_dir=training_args.cache_dir
: 指定缓存目录。trust_remote_code=model_args.trust_remote_code
: 指定是否信任从远程下载的代码。
1 | orig_ctx_len = getattr(config, "max_position_embeddings", None) |
- 使用
getattr
函数从配置中获取max_position_embeddings
属性,该属性指示模型的最大位置嵌入数(即模型能处理的最大序列长度)。如果不存在该属性,则返回None
。getattr
是一个 Python 内置函数,用于获取对象的属性值。如果属性不存在,返回第三个参数指定的默认值(此处为None
)。 orig_ctx_len
存储模型配置中的max_position_embeddings
属性值,即模型可以处理的最大位置嵌入数(通常与最大序列长度相关)。
1 | if orig_ctx_len and training_args.model_max_length > orig_ctx_len: |
- 如果提供的模型最大长度 (
training_args.model_max_length
) 超过了原始模型的最大长度 (orig_ctx_len
),则计算一个缩放因子以进行位置编码的调整。这通常用于处理超出预训练模型原始设计的序列长度。 rope_scaling
用于调整相对位置编码。scaling_factor
和 RoPE 缩放- 如果模型的最大长度超过原始配置的最大长度,
scaling_factor
被用来计算缩放因子。 - 这涉及到 Rotary Positional Embedding(RoPE)的概念,即在位置嵌入中使用的技术,可以随序列长度线性缩放。
- 缩放因子用于调整位置嵌入,使其适应更长的序列。
- 如果模型的最大长度超过原始配置的最大长度,
1 | config.use_cache = False |
- 禁用模型在前向传播时缓存中间计算结果的功能,这有助于减少内存消耗。这个设置告诉模型在前向传播时不使用或保存缓存。
加载模型和分词器
1 | model = transformers.AutoModelForCausalLM.from_pretrained( |
- 加载预训练的因果语言模型(Causal Language Model)。这类模型通常用于生成任务。
trust_remote_code
这个参数用于确定是否信任从远程(如 Hugging Face Hub)加载的自定义模型代码。cache_dir=training_args.cache_dir
- 指定下载和缓存预训练模型和分词器的目录。
- 如果指定,模型和分词器将从这个目录加载,如果不存在,将从远程下载并缓存到此目录。
1 | tokenizer = transformers.AutoTokenizer.from_pretrained( |
- 加载与模型对应的分词器。
use_fast=False
: 表示不使用快速分词器,快速分词器通常是基于 Rust 的分词器,提供更高效的分词处理。padding_side=model_args.padding_side
: 指定填充(padding)应该发生在序列的哪一侧。
1 | if tokenizer.pad_token != tokenizer.unk_token: |
- 将分词器的填充令牌设置为未知令牌(
unk_token
),如果它们不一致的话。这是因为某些模型需要在填充位置使用特定的令牌。
加载数据
1 | data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args) |
- 调用
make_supervised_data_module
函数,为训练和评估准备数据集。这个函数会根据data_args
中的设置,选择使用懒加载或预加载的方式处理数据。
初始化并启动训练器
1 | trainer = Trainer( |
- 初始化
Trainer
对象,传入模型、分词器、训练参数以及通过make_supervised_data_module
函数准备好的数据。
1 | if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")): |
- 检查是否存在训练检查点,如果存在,则从检查点恢复训练;如果不存在,开始新的训练过程。
保存模型
1 | model.config.use_cache = True |
- 启用模型的缓存并保存训练器的状态。
1 | if trainer.is_deepspeed_enabled: |
- 检查是否启用了 DeepSpeed。如果启用了,则使用
trainer.save_model()
保存模型。如果没有启用 DeepSpeed,则使用trainer_save_model_safe
安全地保存模型,特别是在使用分布式训练时。
Trainer 类解释
Trainer
是 Hugging Face Transformers 库提供的一个类,用于封装模型的训练逻辑。以下是对 Trainer
类的功能的详细介绍:
模型训练与评估:
Trainer
类负责设置和执行模型的训练和评估过程。它自动处理数据的批处理、梯度计算、优化器步骤和设备管理等任务。参数:在初始化时,
Trainer
接受多种参数,包括模型(model
)、分词器(tokenizer
)、训练参数(如学习率、批大小等,通过training_args
传入)和数据集。灵活性和高级功能:
Trainer
支持多种训练设置,如多 GPU 训练、混合精度训练和 TPU 训练。它还支持自定义回调函数,用于在训练过程中执行特定操作。简化 API:
Trainer
类提供了一个简化的 API,使得用户可以用几行代码配置和运行模型训练。它抽象了许多底层细节,使得用户可以专注于模型的构建和训练策略。检查点和恢复:
Trainer
支持保存和加载检查点,这意味着训练过程可以在中断后从上次保存的状态恢复。
总体来说,Trainer
类是一个功能强大且灵活的工具,为训练复杂的 Transformer 模型提供了便利和高效性。
6. run bash
1 | torchrun --nproc_per_node=8 --master_port=20001 fastchat/train/train.py \ |
torchrun
torchrun
是 PyTorch 提供的一个命令行工具,用于启动分布式训练。它是 torch.distributed.launch
模块的一部分,旨在简化在多个进程上运行 PyTorch 程序的过程。以下是对 torchrun
中使用的参数的详细解释:
--nproc_per_node=8
--nproc_per_node
指定每个节点(在这种情况下通常是一台机器)上要启动的进程数。这里设置为 8,意味着在当前节点上将启动 8 个训练进程。- 作用:用于控制每个节点上的并行度。在多 GPU 系统中,这通常等于 GPU 的数量。
--master_port=20001
--master_port
指定主节点用于通信的端口。这里设置为 20001。- 作用:在分布式训练中,不同进程需要通过网络进行通信。这个参数指定了用于进程间通信的端口。
fastchat/train/train.py
- 这不是
torchrun
的参数,而是指定了要执行的 Python 脚本,即训练脚本的路径。
- 这不是
在分布式训练中,torchrun
负责在每个进程中正确地设置环境变量,如 LOCAL_RANK
(当前进程在其节点上的排名)、WORLD_SIZE
(总进程数)和 RANK
(全局进程排名)。这些环境变量对于使用 PyTorch 分布式包(如 torch.distributed
)进行有效通信至关重要。
使用示例
假设您有一台拥有 8 个 GPU 的机器,您想在所有 GPU 上并行运行训练。使用 torchrun
,您的命令可能如下所示:
1 | torchrun --nproc_per_node=8 --master_port=20001 fastchat/train/train.py --其他参数 |
这个命令会在每个 GPU 上启动一个训练进程,每个进程运行 train.py
脚本,并且所有进程能够通过分布式通信有效协作。
torchrun
是分布式训练的关键工具,它简化了在多个进程上启动 PyTorch 程序的流程,特别是在多 GPU 环境中。通过自动设置必要的环境变量,torchrun
使得实现和运行分布式训练变得更加容易和可靠。
2. 参数
--model_name_or_path
- 可以是预训练模型的官方名称(如 “bert-base-uncased”)、自定义训练的模型路径或 Hugging Face Model Hub 上的模型。
- 作用:指定用于训练的模型。
--data_path
- 路径可以是本地文件系统上的路径。
- 作用:指定训练使用的数据文件。
--fp16
- 可取值为 True 或 False。
- 作用:启用或禁用混合精度训练,以提高训练速度和降低显存使用。
--output_dir
- 任何有效的文件路径。
- 作用:指定输出目录,用于保存训练过程中产生的文件。
--num_train_epochs
- 任何正整数。
- 作用:指定训练的轮次。
--per_device_train_batch_size
和--per_device_eval_batch_size
- 任何正整数。
- 作用:分别指定每个设备上的训练和评估批次大小。
--gradient_accumulation_steps
- 任何正整数。
- 作用:指定梯度累积的步骤数,用于在有限的显存下增加有效的批次大小。
--evaluation_strategy
- 可取值包括 “no”、”steps”、”epoch”。
- 作用:指定评估的策略,如每个 epoch 或特定步数后进行评估,或不进行评估。
--save_strategy
- 可取值包括 “no”、”steps”、”epoch”。
- 作用:指定模型保存的策略。
--save_steps
和--save_total_limit
--save_steps
取任何正整数。--save_total_limit
取任何正整数或 None。- 作用:分别指定保存模型的步数间隔和最大保存的检查点数量。
--learning_rate
- 任何正浮点数。
- 作用:指定优化器的学习率。
--weight_decay
- 任何非负浮点数。
- 作用:指定权重衰减,用于正则化。
--warmup_ratio
- 任何非负浮点数,通常在 0 到 1 之间。
- 作用:指定预热的比例,即学习率在初始阶段逐渐增加的过程。
--lr_scheduler_type
- 可取值如 “linear”、”cosine”、”cosine_with_restarts”、”polynomial” 等。
- 作用:指定学习率调度器的类型。
--logging_steps
- 任何正整数。
- 作用:指定记录日志的步数间隔。
--fsdp
- 可取值如 “full_shard”、”auto_wrap” 等,或它们的组合。
- 作用:指定使用全分片数据并行(Fully Sharded Data Parallel)的配置。
--fsdp_transformer_layer_cls_to_wrap
- 指定要在 FSDP 中包装的特定层的类名。
- 作用:针对大型模型的分布式训练进行优化。
--model_max_length
- 任何正整数。
- 作用:指定模型处理的最大序列长度。
--gradient_checkpointing
- 可取值为 True 或 False。
- 作用:启用或禁用梯度检查点,以减少显存使用。
--lazy_preprocess
- 可取值为 True 或 False。
- 作用:启用或禁用懒加载预处理,即按需加载和处理数据。
这些参数共同构成了一个复杂的训练配置,允许用户根据特定需求灵活调整模型训练过程。
7. 总结
随着本文的结束,我们完成了对 FastChat 平台中 train.py 脚本的深入解析,这只是我们系列技术博客中的第一部分。在这一部分中,我们聚焦于 train.py 脚本的结构和功能,涵盖了从数据预处理到模型训练和保存等关键步骤。通过这次解析,读者不仅能够更好地理解 FastChat 平台的工作原理,还能获得如何有效利用这个工具进行大型语言模型训练的宝贵知识。
随着我们技术博客系列的不断展开,我们将继续深入探索 FastChat 的其他组件和功能。接下来的文章将进一步拓展我们的讨论范围,涉及到更多高级功能和实际应用场景。我们期望这些内容能够为对 AI 和机器学习感兴趣的读者提供更全面、深入的见解。
最后,我们鼓励读者持续关注我们的博客,以获取关于 FastChat 及其在大型语言模型训练领域应用的最新信息和分析。无论您是该领域的专家还是初学者,我们相信这个系列将为您提供价值和启发。敬请期待我们下一篇文章的发布,它将为您揭开 FastChat 更多令人兴奋的面纱。