FastChat Training Script Code Analysis - Train.py 【FastChat Series Part 1】
In this article, we delve into the train.py script (https://github.com/lm-sys/FastChat/blob/main/fastchat/train/train.py) of FastChat (https://github.com/lm-sys/FastChat), a key component for training and optimizing large language models (LLMs). FastChat is an advanced open-source platform focused on developing, deploying, and evaluating chatbots based on LLMs. The platform not only supports state-of-the-art models like Vicuna and evaluation benchmarks like MT-Bench but also includes a distributed multi-model serving system equipped with a Web UI and an OpenAI-compatible RESTful API, enabling efficient training and evaluation of models.
We provide a detailed analysis of the train.py source code. The script trains natural language processing models with the transformers library, covering critical steps such as data preprocessing, model training, and model saving. Our goal is to explain each class and function in train.py, including its functionality and role in the overall training process.
1. Importing Modules
1. Built-in Modules
These are standard library modules that come with Python and don’t require additional installation.
```python
from dataclasses import dataclass, field
```
Imports `dataclass` and `field` from Python's `dataclasses` module, used for creating data classes with default values.
```python
import json
```
Imports the `json` module for handling JSON-format data.
```python
import math
```
Imports the `math` module for mathematical operations.
```python
import pathlib
```
Imports the `pathlib` module for handling file paths.
```python
from typing import Dict, Optional, Sequence
```
Imports the type annotations `Dict`, `Optional`, and `Sequence` from the `typing` module.
2. Dependency Libraries
These are external libraries typically installed via a package manager like pip.
```python
import numpy as np
```
Imports the `numpy` library, commonly used for scientific computing.
```python
import torch
```
Imports PyTorch, a popular deep learning framework.
```python
from torch.utils.data import Dataset
```
Imports `Dataset` from `torch.utils.data` for creating custom datasets.
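For illustration, a custom dataset only needs to implement `__len__` and `__getitem__`. The toy class below is hypothetical and not part of train.py (which defines its own `SupervisedDataset` and `LazySupervisedDataset` subclasses later); it just shows the pattern:

```python
from torch.utils.data import Dataset


class ToyDataset(Dataset):
    """Hypothetical example: a Dataset needs __len__ and __getitem__."""

    def __init__(self, samples):
        self.samples = samples

    def __len__(self):
        # Total number of samples available.
        return len(self.samples)

    def __getitem__(self, idx):
        # Return the sample at position idx (real datasets usually return tensors).
        return self.samples[idx]
```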
```python
import transformers
```
Imports the `transformers` library, a popular natural language processing library.
```python
from transformers import Trainer
```
Imports `Trainer` from `transformers` for training models.
```python
from transformers.trainer_pt_utils import LabelSmoother
```
Imports `LabelSmoother` from `transformers` for label smoothing.
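In train.py this import is used mainly for its `ignore_index` constant: label positions set to this value are excluded from the loss. Right after the imports, the script defines:

```python
# Positions labeled with this ID are ignored by the loss computation.
IGNORE_TOKEN_ID = LabelSmoother.ignore_index  # equals -100
```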
3. Project-Specific Functions
These are functions or classes custom-implemented in the FastChat project.
```python
from fastchat.conversation import SeparatorStyle
```
Imports `SeparatorStyle` from the `fastchat` package for defining conversation separator styles. The `SeparatorStyle` class is an enumeration created using Python's `enum` module, defining a series of separator styles. Enumerations are a programming concept used to define a named set of constants, making code clearer and more maintainable.

In the `SeparatorStyle` class, each member represents a specific separator style. These styles are often used in text processing, especially in scenarios where different sections or elements need to be distinguished. For instance, when handling dialogue data, different methods might be needed to differentiate user input from machine responses.
Regarding the use of the `auto()` function:

- `auto()` is a special function provided by Python's `enum` module. It automatically assigns a unique value to each member of an enumeration class.
- Without `auto()`, you would need to manually assign a unique value to each enumeration member. `auto()` simplifies this process by letting Python handle the assignment of these values automatically.
- The values assigned by `auto()` are usually integers, starting from 1 and increasing sequentially.
In the case of the `SeparatorStyle` class, `auto()` is used to automatically assign a unique integer value to each separator style. For example, `ADD_COLON_SINGLE`, `ADD_COLON_TWO`, etc., are given different integer values.
The name of each enumeration member (such as `ADD_COLON_SINGLE`, `NO_COLON_SINGLE`, etc.) typically describes the characteristics of that separator style. For instance, `ADD_COLON_SINGLE` might represent adding a colon as a separator after a certain element, whereas `NO_COLON_SINGLE` means no colon is added.
This approach makes referencing and handling these separator styles in the code more convenient and clear. For example, different separator styles can be chosen based on different scenarios or requirements without having to remember their specific values.
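As an abridged sketch of the pattern (the real class in fastchat/conversation.py defines many more members):

```python
from enum import IntEnum, auto


class SeparatorStyle(IntEnum):
    """Abridged sketch: separator styles as auto-numbered enum members."""

    ADD_COLON_SINGLE = auto()  # assigned 1
    ADD_COLON_TWO = auto()     # assigned 2
    NO_COLON_SINGLE = auto()   # assigned 3


print(SeparatorStyle.ADD_COLON_TWO.value)  # -> 2
```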
```python
from fastchat.model.model_adapter import get_conversation_template
```
Imports `get_conversation_template` from the `fastchat` package for obtaining conversation templates. In this code segment, the call logic primarily involves obtaining the default conversation template for a specific model. The call chain is as follows:
1. Starting call: `get_conversation_template(model_path: str)`
   - This function is the starting point of the call chain. It accepts a parameter `model_path` specifying the path of the model.
   - Its purpose is to obtain the default conversation template for the given model path.
2. Call `get_model_adapter(model_path: str)`
   - `get_conversation_template` first calls `get_model_adapter`, passing in the model path.
   - The purpose of `get_model_adapter` is to find and return a suitable `BaseModelAdapter` object for the provided model path.
   - This function first tries to match the basename of `model_path`; if no match is found, it tries the full path.
   - If a suitable adapter is found, it is returned; otherwise, a `ValueError` is thrown.
3. Execute `BaseModelAdapter.get_default_conv_template(model_path: str)`
   - Once the appropriate model adapter is obtained, `get_conversation_template` retrieves the default conversation template by calling the adapter's `get_default_conv_template` method.
   - Note that this method is defined in the `BaseModelAdapter` class but might be overridden in subclasses.
4. Call `get_conv_template(name: str)`
   - Inside the `get_default_conv_template` method, the `get_conv_template` function is called, usually with a predefined template name like `"one_shot"`.
   - The purpose of `get_conv_template` is to retrieve the template with the specified name from the global registry of conversation templates, `conv_templates`.
5. Obtain and return a `Conversation` object
   - The `get_conv_template` function returns an instance of the `Conversation` class, usually a copy taken from the `conv_templates` dictionary.
   - Finally, this `Conversation` instance is returned to the original call site of `get_conversation_template`.
Summarizing the call chain:
```
get_conversation_template(model_path)
└── get_model_adapter(model_path)          # find a matching BaseModelAdapter
    └── adapter.get_default_conv_template(model_path)
        └── get_conv_template(name)        # look up the global conv_templates registry
            └── Conversation               # a copy is returned to the caller
```
In this process, the code navigates through a series of function calls to find a suitable model adapter based on the provided model path and retrieve a specific conversation template from it. This design pattern allows flexibility in providing different conversation templates for different models, enhancing the reusability and extensibility of the code.
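As a usage sketch of this entry point (the exact role names and prompt text depend on the registered template):

```python
from fastchat.model.model_adapter import get_conversation_template

# Walk the chain above: model path -> adapter -> Conversation template.
conv = get_conversation_template("vicuna")

# Conversation objects expose roles and message-building helpers.
conv.append_message(conv.roles[0], "Who are you?")
conv.append_message(conv.roles[1], None)  # empty slot for the model's reply
print(conv.get_prompt())
```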
2. Configuration Classes
These classes are defined using Python's `dataclass` decorator and are mainly used for storing configurations and parameters. They usually do not contain complex methods or logic but define and store data structures. The classes are:

- `ModelArguments`: Stores parameters related to the model, like the model path, whether to trust remote code, etc.
- `DataArguments`: Stores parameters related to data, like the data path, evaluation data path, and whether to use lazy preprocessing.
- `TrainingArguments`: Stores parameters related to training, like the cache directory, optimizer type, model maximum length, etc. This class extends `transformers.TrainingArguments` and adds some custom parameters.

These classes are mainly used to simplify and organize parameter management in the code, making parameters more convenient to modify and access.
1. ModelArguments Class
Code
```python
@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
    trust_remote_code: bool = field(
        default=False,
        metadata={
            "help": "Whether or not to allow for custom models defined on the Hub in their own modeling files"
        },
    )
    padding_side: str = field(
        default="right", metadata={"help": "The padding side in tokenizer"}
    )
```
Explanation
`ModelArguments` is a data class (`dataclass`) used for storing model-related configuration parameters.
Attributes:

- `model_name_or_path`: Specifies the name or path of the pretrained model.
- `trust_remote_code`: Whether to allow custom models that have their modeling files defined on the Hub.
- `padding_side`: Specifies the padding side in the tokenizer, typically right or left padding.
Introduction to the `@dataclass` decorator

`@dataclass` is a decorator used to automate the generation of special methods like `__init__()`, `__repr__()`, and `__eq__()`, thus simplifying the writing of data classes. It was introduced in Python 3.7 and lives in the `dataclasses` module. When you use `@dataclass` before a class definition, Python automatically adds these special methods based on the fields defined in the class. This is very useful for creating classes that store a small amount of data but do not need complex methods.
Specifically, using `@dataclass`:

- Automatically generates a constructor (`__init__` method): Python creates an `__init__` method based on the fields defined in the class, so you don't need to write this method manually to initialize your class instances.
- Automatically generates a `__repr__` method: printing a class instance yields a more readable string representation, usually including the class name and its fields with their values.
- Automatically generates an `__eq__` method: this allows you to use the `==` operator to compare two instances of the class by comparing the values of their fields.
- Supports type annotations: when defining fields, you can use type annotations, which not only clarify the code but can also be checked for type correctness by tooling.
In the case of the `ModelArguments` class, the `@dataclass` decorator generates the methods mentioned above. This means you can easily create an instance of `ModelArguments`, and printing or comparing these instances will behave as expected.
For example, when you create an instance of `ModelArguments`:

```python
args = ModelArguments()
```
This will call the automatically generated `__init__` method, using the default values: `"facebook/opt-125m"` for `model_name_or_path`, `False` for `trust_remote_code`, and `"right"` for `padding_side`.
When you print this instance:
```python
print(args)
```
This will call the automatically generated `__repr__` method, showing a detailed view of the class instance, like `ModelArguments(model_name_or_path='facebook/opt-125m', trust_remote_code=False, padding_side='right')`.
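Similarly, comparing instances exercises the generated `__eq__` (the second value below is arbitrary, chosen only for contrast):

```python
a = ModelArguments()
b = ModelArguments(model_name_or_path="facebook/opt-350m")

print(a == ModelArguments())  # True: all field values are equal
print(a == b)                 # False: model_name_or_path differs
```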
Thus, the `@dataclass` decorator simplifies the process of creating classes, making the code more concise and maintainable. Overall, it is a convenient tool provided by Python for quickly creating classes whose main purpose is storing data.
2. DataArguments Class
Code
```python
@dataclass
class DataArguments:
    data_path: str = field(
        default=None, metadata={"help": "Path to the training data."}
    )
    eval_data_path: str = field(
        default=None, metadata={"help": "Path to the evaluation data."}
    )
    lazy_preprocess: bool = False
```
Explanation
- `DataArguments` is also a data class, used for storing data-related configuration parameters.
- Attributes:
  - `data_path`: Path to the training data.
  - `eval_data_path`: Path to the evaluation data.
  - `lazy_preprocess`: Whether to use lazy loading for data preprocessing, i.e., load and process data only as needed.
3. TrainingArguments Class
Code
```python
@dataclass
class TrainingArguments(transformers.TrainingArguments):
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(
        default=512,
        metadata={
            "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
        },
    )
```
Explanation
The `TrainingArguments` class extends `transformers.TrainingArguments`.

TrainingArguments Class

- `TrainingArguments` is a data class that, by extending `transformers.TrainingArguments`, gains the capability to handle training parameters.
- Attributes defined in `TrainingArguments`:
  - `cache_dir`: Specifies the directory path for caching the model and tokenizer.
  - `optim`: Defines the type of optimizer to use, like `'adamw_torch'`.
  - `model_max_length`: Specifies the maximum sequence length the model can handle.
transformers.TrainingArguments Class
- `transformers.TrainingArguments` is a class in the transformers library used for configuring the various parameters of the model training process.
- This class contains a plethora of attributes for controlling the training process, such as:
  - `output_dir`: Specifies the directory to save the model and training results.
  - `num_train_epochs`: Number of training epochs.
  - `per_device_train_batch_size`: Batch size per device for training.
  - `save_steps`: Interval in steps for saving the model.
  - `evaluation_strategy`: Strategy for evaluating the model, like at the end of each epoch.
  - `learning_rate`: Learning rate.
  - `warmup_steps`: Number of steps used for warmup in the learning-rate schedule.
`transformers.TrainingArguments` also contains many other parameters for fine-tuning the training process, including logging, model-saving strategies, learning-rate scheduling, and more.
By extending `transformers.TrainingArguments`, the `TrainingArguments` class not only inherits all these training-parameter configurations but can also add custom training parameters, in this case `cache_dir`, `optim`, and `model_max_length`. This approach enhances code reusability and flexibility, allowing you to adjust and extend training configurations per the specific requirements of your project.
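In train.py itself, these three dataclasses are parsed from the command line together via `transformers.HfArgumentParser`:

```python
import transformers

# Turn command-line flags (e.g. --model_name_or_path, --data_path,
# --output_dir) into instances of the three dataclasses above.
parser = transformers.HfArgumentParser(
    (ModelArguments, DataArguments, TrainingArguments)
)
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
```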
3. Functional Utility Functions
1. rank0_print(*args)
Code
```python
local_rank = None


def rank0_print(*args):
    if local_rank == 0:
        print(*args)
```
Explanation
Defines a global variable `local_rank` for distributed training, and a function `rank0_print` that prints information only if `local_rank` is 0, used to control output in distributed training. This avoids the same information being printed repeatedly across multiple nodes, keeping the output clearer and more concise.

- Used to print information only on the main node (rank 0) in a distributed training environment.
- Parameters: a variable number of arguments to print.
2. trainer_save_model_safe(trainer: transformers.Trainer)
Code
```python
def trainer_save_model_safe(trainer: transformers.Trainer):
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp import StateDictType, FullStateDictConfig

    save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(
        trainer.model, StateDictType.FULL_STATE_DICT, save_policy
    ):
        trainer.save_model()
```
The function `trainer_save_model_safe(trainer: transformers.Trainer)` aims to safely save models trained with PyTorch's distributed framework. Let's delve into the details of this function and its key components.
Explanation
Parameters:

- `trainer`: An instance of `transformers.Trainer`. This class is one of the core components of the Hugging Face Transformers library, used for training and evaluating models.
Functionality:
- The main purpose of this function is to safely save models in a distributed training environment. It particularly considers the model saving strategy when using Fully Sharded Data Parallel (FSDP).
FSDP

- `FullyShardedDataParallel` (FSDP)
  - This is a component of PyTorch's distributed training framework. FSDP helps reduce memory usage on each GPU by sharding model parameters across multiple GPUs, allowing the training of larger models.
  - In this context, FSDP is primarily used for handling and saving model states in distributed training.
- `StateDictType`
  - This is an enumeration type that defines how to save the model's state dictionary. In FSDP environments, saving and loading model states might require special handling.
- `FullStateDictConfig`
  - This class configures parameters for saving the full state dictionary. It is part of FSDP's functionality and is used to control how the model state is saved.
Function Implementation

- Setting the save policy
  - `save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)` creates a save policy. Two key parameters are specified:
    - `offload_to_cpu`: Offload model parameters to the CPU before saving the state dictionary, which helps reduce GPU memory usage.
    - `rank0_only`: Save the model only on rank 0 (usually the main node). In distributed training, this avoids saving an identical model copy on every node, saving storage space.
- Saving the model
  - The `with FSDP.state_dict_type(trainer.model, StateDictType.FULL_STATE_DICT, save_policy)` context manager sets the type and policy used when saving the model's state dictionary.
  - Within this context, `trainer.save_model()` is called to save the model. Because of the `save_policy`, the model is saved safely following the specified configuration.
The function `trainer_save_model_safe` encapsulates safe model-saving logic, particularly for scenarios involving PyTorch's FSDP in distributed training. It ensures that the complete model state is saved on only one node and that model parameters are offloaded to the CPU before saving, optimizing memory usage and storage efficiency. This is crucial for training large models and managing large-scale distributed training environments.
3. preprocess(sources, tokenizer: transformers.PreTrainedTokenizer) -> Dict
Code
```python
def preprocess(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    ...  # the body is examined section by section below
```
The function `preprocess(sources, tokenizer: transformers.PreTrainedTokenizer) -> Dict` preprocesses dialogue data so it is suitable for training machine learning models. The function can be broken down into several main parts:
1. Obtaining Conversation Templates and Role Definitions
```python
conv = get_conversation_template("vicuna")
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
```
- Functionality: Initializes the conversation template and defines the roles of the dialogue participants.
- Implementation:
  - `conv = get_conversation_template("vicuna")` obtains the conversation template for the specified model (here, "vicuna").
  - The `roles` dictionary maps "human" and "gpt" to the roles defined in the conversation template.
- Example:
  - If the conversation template is for "vicuna", then `roles` might map "human" to "USER" and "gpt" to "ASSISTANT", i.e., `{'human': 'USER', 'gpt': 'ASSISTANT'}`.
2. Applying Prompt Templates
```python
# Apply prompt templates
conversations = []
for i, source in enumerate(sources):
    if roles[source[0]["from"]] != conv.roles[0]:
        # Skip the first one if it is not from human
        source = source[1:]

    conv.messages = []
    for j, sentence in enumerate(source):
        role = roles[sentence["from"]]
        assert role == conv.roles[j % 2], f"{i}"
        conv.append_message(role, sentence["value"])
    conversations.append(conv.get_prompt())
```
- Functionality: Applies prompt templates to the source data to construct dialogues.
- Implementation:
  - Iterates through `sources` (the original dialogue data), transforming each dialogue source into a conversation in template format.
  - If the first turn of a dialogue is not initiated by the "human" role, that turn is skipped.
  - Assigns a role to each sentence and adds it to the conversation template.
  - Ultimately, each processed dialogue is appended to the `conversations` list.
- Example:
  - Suppose we have a source that is the first item in the dummy input:

```python
source = [
    {'from': 'human', 'value': 'Who are you?'},
    {'from': 'gpt', 'value': 'I am Vicuna, a language model trained by researchers from Large Model Systems Organization (LMSYS).'},
    {'from': 'human', 'value': 'Have a nice day!'},
    {'from': 'gpt', 'value': 'You too!'},
]
```

  - Under the Vicuna template, which uses `SeparatorStyle.ADD_COLON_TWO` as the separator style, `conversations` might look like: ["A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: Who are you? ASSISTANT: I am Vicuna, a language model trained by researchers from Large Model Systems Organization (LMSYS).USER: Have a nice day! ASSISTANT: You too!"]

Implementation of get_prompt
The `get_prompt` method implementation varies depending on the `SeparatorStyle`. Below is a table detailing the `get_prompt` behavior for the various styles, along with English examples:

| Separator Style (`SeparatorStyle`) | Description | Example |
| --- | --- | --- |
| `ADD_COLON_SINGLE` | Adds a colon and separator after each message. | `USER: Hello there!\nASSISTANT: Hi, how can I help?\n` |
| `ADD_COLON_TWO` | Uses two alternating separators, usually between different roles. | `USER: What's the weather?\nASSISTANT: It's sunny today.\n\n` |
| `ADD_COLON_SPACE_SINGLE` | Adds a colon, space, and separator after each message. | `USER: Can you book a flight?\nASSISTANT: Sure, where to?\n` |
| `NO_COLON_SINGLE` | Messages directly follow roles without a colon, followed by a separator. | `USERWhat are you doing?\nASSISTANTI'm here to assist you.\n` |
| `NO_COLON_TWO` | No colons, with two alternating separators. | `USERHow's the project going?\nASSISTANTIt's on track.\n\n` |
| `ADD_NEW_LINE_SINGLE` | Each message is preceded by a newline, followed by a separator. | `USER\nHow can I reset my password?\nASSISTANT\nYou can reset it via email.\n` |
| `RWKV` | Special format, usually for specific models. | `USER: What is AI?\n\nASSISTANT: AI stands for Artificial Intelligence.\n\n` |
| `LLAMA2` | Special label format for specific models. | `[INST] USER How does blockchain work?\nASSISTANT It is a distributed ledger.\n\n` |
| `CHATGLM` | Specific format for the CHATGLM model. | `[Round 1]\nUSER: Tell me a joke.\nASSISTANT: Why did the chicken cross the road?\n` |
| `CHATML` | Similar to CHATGLM, but with newlines before and after each message. | `USER\nDo you like music?\n\nASSISTANT\nYes, I enjoy many genres.\n\n` |
| `CHATGLM3` | Format for the CHATGLM3 model. | `USER\nCan you play chess?\nASSISTANTYes, I can play.\n` |
| `CHATINTERN` | Format for the CHATINTERN model, using special markers. | `USER:Where is the nearest ATM?\nASSISTANT:It's next to the post office.\n` |
| `DOLLY` | Specific format for the DOLLY model. | `USER:\nWhat is quantum computing?\nASSISTANT:\nIt involves computation using quantum-mechanical phenomena.\n\n` |
| `PHOENIX` | For the PHOENIX model, messages are wrapped in special markers. | `USER: How to bake a cake?\nASSISTANT:You need flour, sugar, and eggs.\n` |
| `ROBIN` | Similar to `ADD_NEW_LINE_SINGLE`, but with a newline after roles. | `USER:\nIs AI dangerous?\nASSISTANT:\nIt depends on how it's used.\n` |
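To make the `ADD_COLON_TWO` row concrete, here is a simplified, hypothetical re-implementation of that style's prompt assembly (FastChat's real logic lives in `Conversation.get_prompt` in fastchat/conversation.py):

```python
def build_prompt_add_colon_two(system, messages, sep=" ", sep2="</s>"):
    """Hypothetical sketch: two alternating separators, one per role."""
    seps = [sep, sep2]
    ret = system + seps[0]
    for i, (role, message) in enumerate(messages):
        if message:
            ret += role + ": " + message + seps[i % 2]
        else:
            ret += role + ":"  # open slot for the model to complete
    return ret


messages = [
    ("USER", "Who are you?"),
    ("ASSISTANT", "I am Vicuna."),
    ("USER", "Have a nice day!"),
    ("ASSISTANT", "You too!"),
]
print(build_prompt_add_colon_two(
    "A chat between a curious user and an artificial intelligence assistant.",
    messages,
))
```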
- Suppose we have a source which is the first item in dummy input: