BMTrain’s Documentation!
BMTrain is an efficient large model training toolkit that can be used to train large models with tens of billions of parameters. It can train models in a distributed manner while keeping the code as simple as that of stand-alone training.
Installation
Install BMTrain
1. From PyPI (Recommended)
$ pip install bmtrain
2. From Source
$ git clone https://github.com/OpenBMB/BMTrain.git
$ cd BMTrain
$ python3 setup.py install
Compilation Options
By setting environment variables, you can configure the compilation options of BMTrain (by default, the build automatically adapts to the compilation environment).
AVX Instructions
Force the use of AVX256 instructions:
BMT_AVX256=ON
Force the use of AVX512 instructions:
BMT_AVX512=ON
CUDA Compute Capability
Specify the target CUDA compute capabilities, for example:
TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0 7.5 8.0+PTX"
Recommended Configuration
Network: InfiniBand 100Gbps / RoCE 100Gbps
GPU: NVIDIA Tesla V100 / NVIDIA Tesla A100 / RTX 3090
CPU: a CPU that supports the AVX512 instruction set, 32 cores or above
RAM: 256GB or above
FAQ
If the following error message is reported during compilation, try using a newer version of the gcc compiler.
error: invalid static_cast from type `const torch::OrderedDict<...>`
Quick Start
Step 1: Initialize BMTrain
Before you can use BMTrain, you need to initialize it at the beginning of your code. Just as PyTorch's distributed module requires calling init_process_group first, BMTrain requires calling init_distributed first.
import bmtrain as bmt
bmt.init_distributed(
    seed=0,
    # ...
)
NOTE: Do not use PyTorch’s distributed module and its associated communication functions when using BMTrain.
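After initialization, each process can check its position in the distributed group. Below is a minimal sketch, assuming the helpers bmt.world_size() and bmt.print_rank are available as in the examples folder:

import bmtrain as bmt

# Initialize BMTrain before creating any model or parameter.
bmt.init_distributed(seed=0)

# Assumption: bmt.world_size() returns the number of processes in the group,
# and bmt.print_rank prints only on rank 0 to avoid duplicated logs.
bmt.print_rank("world size:", bmt.world_size())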
Step 2: Enable ZeRO-3 Optimization
To enable ZeRO-3 optimization, you need to make some simple replacements to the original model’s code.
torch.nn.Module -> bmtrain.DistributedModule
torch.nn.Parameter -> bmtrain.DistributedParameter
And wrap the transformer blocks with bmtrain.CheckpointBlock.
Here is an example.
Original
import torch

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.empty(1024))
        self.module_list = torch.nn.ModuleList([
            SomeTransformerBlock(),
            SomeTransformerBlock(),
            SomeTransformerBlock()
        ])

    def forward(self):
        x = self.param
        for module in self.module_list:
            x = module(x, 1, 2, 3)
        return x
Replaced
import torch
import bmtrain as bmt

class MyModule(bmt.DistributedModule): # changed here
    def __init__(self):
        super().__init__()
        self.param = bmt.DistributedParameter(torch.empty(1024)) # changed here
        self.module_list = torch.nn.ModuleList([
            bmt.CheckpointBlock(SomeTransformerBlock()), # changed here
            bmt.CheckpointBlock(SomeTransformerBlock()), # changed here
            bmt.CheckpointBlock(SomeTransformerBlock())  # changed here
        ])

    def forward(self):
        x = self.param
        for module in self.module_list:
            x = module(x, 1, 2, 3)
        return x
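After these replacements, the model is constructed as usual. The following is a minimal sketch, assuming bmt.init_parameters initializes the values of the distributed parameters on every rank, as used in the examples folder:

import bmtrain as bmt

bmt.init_distributed(seed=0)

model = MyModule()          # the replaced module defined above
# Assumption: bmt.init_parameters fills the DistributedParameter objects with
# their initial values on every rank before training starts.
bmt.init_parameters(model)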
Step 3: Enable Communication Optimization
To further reduce communication overhead and overlap communication with computation, TransformerBlockList can be used for optimization.
You can enable it by making the following substitutions in the code:
torch.nn.ModuleList -> bmtrain.TransformerBlockList
for module in self.module_list: x = module(x, ...) -> x = self.module_list(x, ...)
Original
import torch
import bmtrain as bmt

class MyModule(bmt.DistributedModule):
    def __init__(self):
        super().__init__()
        self.param = bmt.DistributedParameter(torch.empty(1024))
        self.module_list = torch.nn.ModuleList([
            bmt.CheckpointBlock(SomeTransformerBlock()),
            bmt.CheckpointBlock(SomeTransformerBlock()),
            bmt.CheckpointBlock(SomeTransformerBlock())
        ])

    def forward(self):
        x = self.param
        for module in self.module_list:
            x = module(x, 1, 2, 3)
        return x
Replaced
import torch
import bmtrain as bmt

class MyModule(bmt.DistributedModule):
    def __init__(self):
        super().__init__()
        self.param = bmt.DistributedParameter(torch.empty(1024))
        self.module_list = bmt.TransformerBlockList([ # changed here
            bmt.CheckpointBlock(SomeTransformerBlock()),
            bmt.CheckpointBlock(SomeTransformerBlock()),
            bmt.CheckpointBlock(SomeTransformerBlock())
        ])

    def forward(self):
        x = self.param
        x = self.module_list(x, 1, 2, 3) # changed here
        return x
Step 4: Launch Distributed Training
BMTrain uses the same launch commands as the distributed module of PyTorch.
You can choose either of the following, depending on your version of PyTorch.
${MASTER_ADDR} means the IP address of the master node.
${MASTER_PORT} means the port of the master node.
${NNODES} means the total number of nodes.
${GPU_PER_NODE} means the number of GPUs per node.
${NODE_RANK} means the rank of this node.
torch.distributed.launch
$ python3 -m torch.distributed.launch --master_addr ${MASTER_ADDR} --master_port ${MASTER_PORT} --nproc_per_node ${GPU_PER_NODE} --nnodes ${NNODES} --node_rank ${NODE_RANK} train.py
torchrun
$ torchrun --nnodes=${NNODES} --nproc_per_node=${GPU_PER_NODE} --rdzv_id=1 --rdzv_backend=c10d --rdzv_endpoint=${MASTER_ADDR}:${MASTER_PORT} train.py
For more information, please refer to the documentation.
Other Notes
BMTrain makes underlying changes to PyTorch, so if your program produces unexpected results, please submit the details in an issue.
For more examples, please refer to the examples folder.
Introduction to Core Technology
ZeRO-3 Optimization
Overlap Communication and Computation
CPU Offload
bmtrain
Initialization
Distributed Parameters and Modules
Methods for Parameters
Utilities
bmtrain.nccl
bmtrain.inspect
The bmtrain.inspect module provides tools for debugging and analyzing distributed code.
We recommend using these tools to obtain the parameters and computation results of distributed models.
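For example, a common pattern is to print a summary of all parameters only on rank 0. The following is a minimal sketch, assuming inspect_model and format_summary behave as in the examples folder:

import torch
import bmtrain as bmt

bmt.init_distributed(seed=0)

class TinyModel(bmt.DistributedModule):
    def __init__(self):
        super().__init__()
        self.weight = bmt.DistributedParameter(torch.ones(16))

model = TinyModel()

# Assumption: inspect_model collects per-parameter statistics for names matching
# the pattern, format_summary renders them as a readable table, and print_rank
# prints only on rank 0.
bmt.print_rank(
    bmt.inspect.format_summary(
        bmt.inspect.inspect_model(model, "*")
    )
)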
bmtrain.lr_scheduler
The bmtrain.lr_scheduler module provides common learning rate schedulers for large model training.
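As an illustration, the sketch below assumes the Noam scheduler and the bmt.optim.AdamOptimizer arguments used in the examples folder; treat the exact names as assumptions and check the API reference for the exact signatures:

import bmtrain as bmt

# Assumption: bmt.optim.AdamOptimizer and bmt.lr_scheduler.Noam exist with these
# arguments; 'model' is a distributed module such as the one built in the Quick Start.
optimizer = bmt.optim.AdamOptimizer(model.parameters(), lr=1e-3)
lr_scheduler = bmt.lr_scheduler.Noam(
    optimizer,
    start_lr=1e-3,    # peak learning rate reached after warmup
    warmup_iter=100,  # number of warmup iterations
    end_iter=-1,      # -1 leaves the decay without a fixed end iteration
    num_iter=0        # starting iteration, useful when resuming training
)

lr_scheduler.step() is then called once per optimizer step during training.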