BMTrain’s Documentation!

BMTrain is an efficient large model training toolkit that can be used to train large models with tens of billions of parameters. It can train models in a distributed manner while keeping the code as simple as that of stand-alone training.


Installation

Install BMTrain

1. From PyPI (Recommended)

$ pip install bmtrain

2. From Source

$ git clone https://github.com/OpenBMB/BMTrain.git
$ cd BMTrain
$ python3 setup.py install

Compilation Options

You can configure BMTrain's compilation options by setting environment variables (by default, the build adapts automatically to the compilation environment).

AVX Instructions

  • Force the use of AVX instructions: BMT_AVX256=ON

  • Force the use of AVX512 instructions: BMT_AVX512=ON
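
For example, to force AVX512 support when building from source (a sketch; this assumes the variable is visible to the build step):

$ BMT_AVX512=ON python3 setup.py install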

CUDA Compute Capability

TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0 7.5 8.0+PTX"
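
For example, to build from source only for compute capability 8.0 GPUs (a sketch; the value is illustrative, adjust it to your hardware):

$ TORCH_CUDA_ARCH_LIST="8.0+PTX" python3 setup.py install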

FAQ

If the following error message is reported during compilation, try using a newer version of the gcc compiler.

error: invalid static_cast from type `const torch::OrderedDict<...>`

Quick Start

Step 1: Initialize BMTrain

Before you can use BMTrain, you need to initialize it at the beginning of your code. Just as PyTorch's distributed module requires calling init_process_group at the start of your code, BMTrain requires calling init_distributed there.

import bmtrain as bmt
bmt.init_distributed(
    seed=0,
    # ...
)

NOTE: Do not use PyTorch’s distributed module and its associated communication functions when using BMTrain.
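
After initialization, every GPU runs the same script in its own process. A minimal sanity check (a sketch; it assumes the bmt.world_size() and bmt.print_rank helpers, the latter printing only on rank 0 by default):

import bmtrain as bmt
bmt.init_distributed(seed=0)

# print_rank writes the message once (on rank 0) instead of once per process.
bmt.print_rank("world size:", bmt.world_size())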

Step 2: Enable ZeRO-3 Optimization

To enable ZeRO-3 optimization, you need to make a few simple replacements in the original model's code.

  • torch.nn.Module -> bmtrain.DistributedModule

  • torch.nn.Parameter -> bmtrain.DistributedParameter

Then wrap each transformer block with bmtrain.CheckpointBlock.

Here is an example.

Original

import torch
class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.empty(1024))
        self.module_list = torch.nn.ModuleList([
            SomeTransformerBlock(),
            SomeTransformerBlock(),
            SomeTransformerBlock()
        ])
    
    def forward(self):
        x = self.param
        for module in self.module_list:
            x = module(x, 1, 2, 3)
        return x

Replaced

import torch
import bmtrain as bmt
class MyModule(bmt.DistributedModule): # changed here
    def __init__(self):
        super().__init__()
        self.param = bmt.DistributedParameter(torch.empty(1024)) # changed here
        self.module_list = torch.nn.ModuleList([
            bmt.CheckpointBlock(SomeTransformerBlock()), # changed here
            bmt.CheckpointBlock(SomeTransformerBlock()), # changed here
            bmt.CheckpointBlock(SomeTransformerBlock())  # changed here
        ])
    
    def forward(self):
        x = self.param
        for module in self.module_list:
            x = module(x, 1, 2, 3)
        return x
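
Because the parameters above are created with torch.empty, they still need to be initialized before training. A minimal sketch (assuming bmt.init_parameters is available in your BMTrain version for initializing DistributedParameter objects across workers):

model = MyModule()
# Initialize the distributed parameters on every worker.
bmt.init_parameters(model)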
    

Step 3: Enable Communication Optimization

To further reduce communication overhead and overlap communication with computation, TransformerBlockList can be used for optimization.

You can enable it by making the following substitutions in the code:

  • torch.nn.ModuleList -> bmtrain.TransformerBlockList

  • for module in self.module_list: x = module(x, ...) -> x = self.module_list(x, ...)

Original

import torch
import bmtrain as bmt
class MyModule(bmt.DistributedModule):
    def __init__(self):
        super().__init__()
        self.param = bmt.DistributedParameter(torch.empty(1024))
        self.module_list = torch.nn.ModuleList([
            bmt.CheckpointBlock(SomeTransformerBlock()),
            bmt.CheckpointBlock(SomeTransformerBlock()),
            bmt.CheckpointBlock(SomeTransformerBlock())
        ])
    
    def forward(self):
        x = self.param
        for module in self.module_list:
            x = module(x, 1, 2, 3)
        return x
    

Replaced

import torch
import bmtrain as bmt
class MyModule(bmt.DistributedModule):
    def __init__(self):
        super().__init__()
        self.param = bmt.DistributedParameter(torch.empty(1024))
        self.module_list = bmt.TransformerBlockList([ # changed here
            bmt.CheckpointBlock(SomeTransformerBlock()),
            bmt.CheckpointBlock(SomeTransformerBlock()),
            bmt.CheckpointBlock(SomeTransformerBlock())
        ])
    
    def forward(self):
        x = self.param
        x = self.module_list(x, 1, 2, 3) # changed here
        return x
    

Step 4: Launch Distributed Training

BMTrain uses the same launch commands as the distributed module of PyTorch.

You can choose either of the following commands depending on your version of PyTorch.

  • ${MASTER_ADDR} means the IP address of the master node.

  • ${MASTER_PORT} means the port of the master node.

  • ${NNODES} means the total number of nodes.

  • ${GPU_PER_NODE} means the number of GPUs per node.

  • ${NODE_RANK} means the rank of this node.

torch.distributed.launch

$ python3 -m torch.distributed.launch --master_addr ${MASTER_ADDR} --master_port ${MASTER_PORT} --nproc_per_node ${GPU_PER_NODE} --nnodes ${NNODES} --node_rank ${NODE_RANK} train.py

torchrun

$ torchrun --nnodes=${NNODES} --nproc_per_node=${GPU_PER_NODE} --rdzv_id=1 --rdzv_backend=c10d --rdzv_endpoint=${MASTER_ADDR}:${MASTER_PORT} train.py
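
For example, on a single machine with 4 GPUs (the values below are illustrative, not defaults), the torchrun command reduces to:

$ torchrun --nnodes=1 --nproc_per_node=4 --rdzv_id=1 --rdzv_backend=c10d --rdzv_endpoint=localhost:29500 train.py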

For more information about these launchers, please refer to the PyTorch documentation.

Other Notes

BMTrain makes low-level changes to PyTorch, so if your program produces unexpected results, please report the details in an issue.

For more examples, please refer to the examples folder.

Introduction to Core Technology

ZeRO-3 Optimization

[Figure: ZeRO-3 optimization]

Overlap Communication and Computation

[Figure: overlapping communication and computation]

CPU Offload

[Figure: CPU offload]

bmtrain

Initialization

Distributed Parameters and Modules

Methods for Parameters

Utilities

bmtrain.nccl

bmtrain.inspect

The bmtrain.inspect module provides tools for debugging and analyzing distributed code.

We recommend using the tools in this module to obtain parameters and intermediate computation results from distributed models.
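
For instance, a parameter summary can be printed on rank 0 (a sketch based on a common usage pattern; inspect_model and format_summary are assumed here):

import bmtrain as bmt
# Gather information about all parameters of the distributed model
# and print a human-readable summary once, on rank 0.
summary = bmt.inspect.inspect_model(model, "*")
bmt.print_rank(bmt.inspect.format_summary(summary))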

bmtrain.lr_scheduler

The bmtrain.lr_scheduler module provides common learning rate schedulers for large model training.
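
A typical setup pairs a BMTrain optimizer with one of these schedulers (a sketch; the AdamOffloadOptimizer and Noam names follow a common BMTrain usage pattern and may differ across versions):

import bmtrain as bmt
optimizer = bmt.optim.AdamOffloadOptimizer(model.parameters())
lr_scheduler = bmt.lr_scheduler.Noam(
    optimizer,
    start_lr=1e-3,
    warmup_iter=100,
    end_iter=-1,  # -1: no fixed end point for the decay in this sketch
    num_iter=0
)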

LR Schedulers

API