Support bmt.save/load saving model partitions instead of the whole model
MayDomine committed May 8, 2024
1 parent 151f679 commit fd7ac11
Showing 6 changed files with 104 additions and 29 deletions.
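
Taken together, the changes add a partitioned (per-rank) checkpoint mode alongside the existing gather-to-rank-0 mode. A minimal usage sketch based on the API added in this commit; the model construction is a placeholder and not part of the diff:

    import bmtrain as bmt

    bmt.init_distributed()
    model = build_model()  # hypothetical helper; any bmt-wrapped model

    # Gathered checkpoint (previous behaviour, still the default): rank 0 writes one full file.
    bmt.save(model, "model.pt", save_gather=True)
    bmt.load(model, "model.pt", load_gather=True)

    # Partitioned checkpoint: every rank writes/reads its own shard, named model.pt_rank_<rank>.
    bmt.save(model, "model.pt", save_gather=False)
    bmt.load(model, "model.pt", load_gather=False)

    # Remove all checkpoint files whose names start with the given prefix (executed on rank 0).
    bmt.clean("model.pt")
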
2 changes: 1 addition & 1 deletion bmtrain/__init__.py
@@ -14,7 +14,7 @@
 from .wrapper import BMTrainModelWrapper
 from .pipe_layer import PipelineTransformerBlockList
 from . import debug
-from .store import save, load
+from .store import save, load, clean
 
 from . import loss
 from . import distributed
14 changes: 11 additions & 3 deletions bmtrain/block_layer.py
@@ -314,8 +314,12 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
     def state_dict(self, destination=None, prefix='', keep_vars=False):
         # gather here
         with torch.no_grad():
-            with ZeroContext(self):
+            if config['save_param_gather']:
+                with ZeroContext(self):
+                    return self._module.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
+            else:
                 return self._module.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
+
 
     def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                               missing_keys, unexpected_keys, error_msgs):
@@ -330,8 +334,10 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                 tp_mode = param._tp_mode
                 if input_param.__class__.__name__ == "DistributedTensorWrapper":
                     input_param = input_param.broadcast()
-
-                verify_shape = torch.Size(it["shape"] if not tp_mode else param._tp_original_shape)
+                if config['load_param_gather']:
+                    verify_shape = torch.Size(it["shape"] if not tp_mode else param._tp_original_shape)
+                else:
+                    verify_shape = param.shape
                 if input_param.shape != verify_shape:
                     error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
                                       'the shape in current model is {}.'
@@ -353,6 +359,8 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                 # copy to buffer
                 verify_size = verify_shape.numel()
                 assert input_param.numel() == verify_size
+                if not config['load_param_gather']:
+                    continue
 
                 contiguous_param = input_param.to(it["parameter"].dtype).cuda().contiguous()
 
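
Reading of the Block.state_dict change above: with config['save_param_gather'] left at True, ZeroContext gathers the ZeRO-partitioned parameters before delegating to the wrapped module, so the returned tensors are full-sized; with it set to False the wrapped module is queried directly and each rank keeps only its local partition. The flag is normally toggled for you by bmt.save (see bmtrain/store.py below); a small sketch of toggling it by hand, assuming `block` is a bmt.Block built after bmt.init_distributed():

    import bmtrain as bmt

    full_sd = block.state_dict()             # default: gathered, full-shape tensors

    bmt.config['save_param_gather'] = False  # what bmt.save(..., save_gather=False) does internally
    local_sd = block.state_dict()            # per-rank partition only
    bmt.config['save_param_gather'] = True   # restore the default
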
2 changes: 2 additions & 0 deletions bmtrain/init.py
@@ -83,6 +83,8 @@ def init_distributed(
     config["tp_rank"] = config['topology'].get_group_rank("tp")
     config["tp_zero_rank"] = config['topology'].get_group_rank("tp_zero")
     config["save_param_to_cpu"] = True
+    config["save_param_gather"] = True
+    config["load_param_gather"] = True
     cpus_this_worker = None
 
     all_available_cpus = sorted(list(os.sched_getaffinity(0)))
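
Both new flags default to True, so code that calls bmt.save/bmt.load without the new keyword arguments keeps the old gather-everything behaviour. A trivial check, assuming init has run:

    import bmtrain as bmt

    bmt.init_distributed()
    assert bmt.config["save_param_gather"] and bmt.config["load_param_gather"]  # defaults set above
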
16 changes: 11 additions & 5 deletions bmtrain/layer.py
@@ -33,10 +33,13 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
         for name, param in self._parameters.items():
             if param is not None:
                 if isinstance(param, DistributedParameter):#and not param._in_block:
-                    if param._in_block:
-                        destination[prefix + name] = param.tp_gather().detach() # sync operation
+                    if config["save_param_gather"]:
+                        if param._in_block:
+                            destination[prefix + name] = param.tp_gather().detach() # sync operation
+                        else:
+                            destination[prefix + name] = param.gather_all().detach() # sync operation
                     else:
-                        destination[prefix + name] = param.gather_all().detach() # sync operation
+                        destination[prefix + name] = param.clone().detach() # sync operation
                     if config['save_param_to_cpu']:
                         destination[prefix + name] = destination[prefix + name].cpu()
                     else:
@@ -110,14 +113,17 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                                       'the shape in current model is {}.'
                                       .format(key, input_param.shape, param.shape))
                     continue
-                verify_shape = torch.Size(param._original_shape if not tp_mode else param._tp_original_shape)
+                if config['load_param_gather']:
+                    verify_shape = torch.Size(param._original_shape if not tp_mode else param._tp_original_shape)
+                else:
+                    verify_shape = param.shape
                 if not is_param_lazy and isinstance(param, DistributedParameter) and input_param.shape != verify_shape:
                     error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
                                       'the shape in current model is {}.'
                                       .format(key, input_param.shape, verify_shape))
                 try:
                     with torch.no_grad():
-                        if isinstance(param, DistributedParameter):
+                        if isinstance(param, DistributedParameter) and config['load_param_gather']:
                             tp_split_dim = param._tp_split_dim
                             if tp_mode and tp_split_dim >= 0:
                                 input_param = tp_split_tensor(input_param, tp_split_dim)
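
With the change above, shape verification follows the loading mode: in gather mode the checkpoint tensor is checked against the parameter's original (full) shape and, if needed, split for tensor parallelism on the fly; in partition mode it must already equal the local param.shape and is copied without splitting. In practice a partitioned checkpoint is therefore expected to be loaded back under the same partitioning (same world size and parallel layout) it was saved with. A short sketch, with placeholder file names:

    import bmtrain as bmt

    # Save and reload shards under the same distributed setup.
    bmt.save(model, "sharded.pt", save_gather=False)   # rank r writes sharded.pt_rank_r
    bmt.load(model, "sharded.pt", load_gather=False)   # rank r reads sharded.pt_rank_r; shapes must match param.shape
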
67 changes: 59 additions & 8 deletions bmtrain/store.py
@@ -11,6 +11,7 @@
 from typing import Mapping
 import threading
 import bmtrain as bmt
+import os
 
 def _save_to_state_dict(model : torch.nn.Module, rank, destination, prefix):
     if isinstance(model, Block):
@@ -24,6 +25,21 @@ def _save_to_state_dict(model : torch.nn.Module, rank, destination, prefix):
             destination._metadata = OrderedDict()
         model._save_to_state_dict(destination, prefix, False)
 
+def _save_to_each_rank(model : torch.nn.Module, destination=None, prefix=''):
+    if destination is None:
+        destination = OrderedDict()
+        destination._metadata = OrderedDict()
+    destination._metadata[prefix[:-1]] = local_metadata = dict(version=model._version)
+    _save_to_state_dict(model, 0, destination, prefix)
+    for name, module in model._modules.items():
+        if module is not None:
+            _save_to_each_rank(module, destination, prefix + name + '.')
+    for hook in model._state_dict_hooks.values():
+        hook_result = hook(model, destination, prefix, local_metadata)
+        if hook_result is not None:
+            destination = hook_result
+    return destination
+
 def _save_to_local_rank0(model : torch.nn.Module, destination=None, prefix=''):
     if destination is None:
         destination = OrderedDict()
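
The new _save_to_each_rank mirrors the recursive traversal of _save_to_local_rank0/_save_to_rank0 but always passes rank 0 into _save_to_state_dict, so every process writes into its own destination instead of funnelling tensors to rank 0. A hedged sketch of where the data ends up, using the helpers defined in this file:

    gathered = _save_to_rank0(model)       # full tensors; meaningful contents only on rank 0
    per_rank = _save_to_each_rank(model)   # populated on every rank with that rank's local partition
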
@@ -88,7 +104,7 @@ def async_save_to_file(state_dict, file_path):
         config['finish_save'] = True
         print("finish save state_dict to ", file_path)
 
-def save(model : torch.nn.Module, file_name : str, non_blocking : bool=False):
+def save(model : torch.nn.Module, file_name : str, non_blocking : bool=False, save_gather : bool=True):
     """Saves the model to the file.
 
     Similar to torch.save, but it used for distributed modules.
@@ -100,11 +116,18 @@ def save(model : torch.nn.Module, file_name : str, non_blocking : bool=False):
     Examples:
         >>> bmtrain.save(model, "model.pt")
         >>> bmtrain
     """
     torch.cuda.synchronize()
-    state_dict = _save_to_rank0(model)
-    if config["rank"] == 0:
+    if save_gather:
+        save_method = _save_to_rank0
+    else:
+        save_method = _save_to_each_rank
+        file_name = f"{file_name}_rank_{bmt.rank()}"
+    tmp = bmt.config['save_param_gather']
+    bmt.config['save_param_gather'] = save_gather
+    state_dict = save_method(model)
+    if config["rank"] == 0 or not save_gather:
         if non_blocking is False:
             torch.save(state_dict, file_name)
         else:
@@ -118,6 +141,9 @@ def save(model : torch.nn.Module, file_name : str, non_blocking : bool=False):
             config['save_thread'] = threading.Thread(target=async_save_to_file, args=(state_dict, file_name))
             config['save_thread'].start()
     bmt.synchronize()
+    bmt.config['save_param_gather'] = tmp
+
+
 
 DTYPE_LIST = [
     torch.float64,
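
Note the interplay of the two branches above: in partition mode the rank-0 check is bypassed (`or not save_gather`), so every process writes, and each one writes to `<file_name>_rank_<rank>`. Illustratively (the world size is an assumption):

    import bmtrain as bmt

    bmt.save(model, "model.pt", save_gather=False)
    # With 4 processes this is expected to produce:
    #   model.pt_rank_0  model.pt_rank_1  model.pt_rank_2  model.pt_rank_3
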
@@ -299,7 +325,7 @@ def __iter__(self):
         # pytorch 1.12.0 updated the load_state_dict method, which needs the state_dict to be a `Mapping`.
         return iter(self.keys())
 
-def load(model : torch.nn.Module, file_name : str, strict : bool = True):
+def load(model : torch.nn.Module, file_name : str, strict : bool = True, load_gather : bool = True):
     """Loads the model from the file.
 
     Similar to torch.load, but it uses less memory when loading large models.
@@ -312,14 +338,39 @@ def load(model : torch.nn.Module, file_name : str, strict : bool = True):
     Example:
         >>> bmtrain.load(model, "model.pt", strict=True)
     """
-    if config['rank'] == 0:
-        state_dict = DistributedStateDictWrapper(torch.load(file_name))
+    tmp = config['load_param_gather']
+    config['load_param_gather'] = load_gather
+    if load_gather:
+        if config['rank'] == 0:
+            state_dict = DistributedStateDictWrapper(torch.load(file_name))
+        else:
+            state_dict = DistributedStateDictWrapper({})
     else:
-        state_dict = DistributedStateDictWrapper({})
+        if "rank" not in file_name:
+            file_name = f"{file_name}_rank_{bmt.rank()}"
+        state_dict = torch.load(file_name)
 
     ret = model.load_state_dict(
         state_dict,
         strict = strict
     )
+    config['load_param_gather'] = tmp
     torch.cuda.synchronize()
     return ret
+
+def clean(file_name : str):
+    """Cleans the file.
+
+    Args:
+        file_name (str): The file name of the checkpoint.
+
+    Example:
+        >>> bmtrain.clean("model.pt")
+    """
+    if bmt.rank() == 0:
+        parent = os.path.dirname(os.path.abspath(file_name))
+        for f in os.listdir(parent):
+            if f.startswith(file_name):
+                os.remove(os.path.join(parent, f))
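
Two practical notes on the load/clean pair above, sketched below: in partition mode each rank opens `<file_name>_rank_<rank>` (the suffix is only appended when the given name does not already contain "rank"), so a partitioned checkpoint should be loaded with the same number of processes it was saved with; clean() runs on rank 0 and deletes every file in the checkpoint's directory whose name starts with the given prefix, which covers the per-rank shards.

    import bmtrain as bmt

    bmt.load(model, "model.pt", load_gather=False)  # rank r opens model.pt_rank_r
    bmt.clean("model.pt")                           # removes model.pt_rank_0, model.pt_rank_1, ...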


32 changes: 20 additions & 12 deletions tests/test_load_ckpt.py
@@ -3,6 +3,7 @@
 import torch.nn.functional as F
 import bmtrain as bmt
 import os
+from collections import OrderedDict
 
 class Linear_Normal(torch.nn.Module):
     def __init__(self, in_features : int, out_features: int, bias: bool = True, dtype = None) -> None:
@@ -36,41 +37,48 @@ def __init__(self, in_features : int, out_features: int, bias: bool = True, dtype = None) -> None:
     def forward(self, input):
         return F.linear(input, self.weight, self.bias)
 
+def test_save_load(m):
+    bmt.save(m, "test.pt", non_blocking=False)
+    bmt.load(m, "test.pt")
+    bmt.save(m, "test.pt", non_blocking=True)
+    bmt.load(m, "test.pt")
+    bmt.save(m, "test.pt", non_blocking=False, save_gather=True)
+    bmt.load(m, "test.pt", load_gather=True)
+    bmt.clean("test.pt")
+
 
 def test_main():
-    ckpt_path = "test_ckpt.pt"
     # Transformer BlockList
     m = Linear_Normal(256, 256).cuda()
     m2 = bmt.TransformerBlockList([bmt.Block(Linear_BMT(256, 256))])
-    if bmt.rank() == 0:
-        torch.save(m.state_dict(), ckpt_path)
-    dic2 = m.state_dict()
-    dic2["0.weight"] = dic2.pop("weight")
-    dic2["0.bias"] = dic2.pop("bias")
-    m2.load_state_dict(dic2)
+    m2_state = m.state_dict().copy()
+    m2_state["0.weight"] = m2_state.pop("weight")
+    m2_state["0.bias"] = m2_state.pop("bias")
+    test_save_load(m2)
+    m2.load_state_dict(m2_state)
     for key in m.state_dict():
         bmt_key = f"0.{key}"
         assert bmt_key in m2.state_dict(), "wrong key in bmtrain model"
         assert (m2.state_dict()[bmt_key].cuda() == m.state_dict()[key]).all() , "wrong param in bmtrain model"
-    if bmt.rank() == 0:
-        os.remove(ckpt_path)
-    print("Transformer Blocklist load_state_dict and state_dict test passed")
+    print("Transformer Blocklist load_state_dict ,state_dict, bmt.load/save test passed")
 
     # Block
     m3 = bmt.Block(Linear_BMT(256, 256))
     m3.load_state_dict(m.state_dict())
     for key in m.state_dict():
         assert key in m3.state_dict(), "wrong key in bmtrain model"
         assert (m.state_dict()[key] == m3.state_dict()[key].cuda()).all(), "wrong param in bmtrain model"
-    print("Block load_state_dict and state_dict test passed")
+    test_save_load(m2)
+    print("Block load_state_dict ,state_dict, bmt.load/save test passed")
 
     # normal Distributed module
     m4 = Linear_BMT(256, 256)
     m4.load_state_dict(m.state_dict())
     for key in m.state_dict():
         assert key in m4.state_dict(), "wrong key in bmtrain model"
         assert (m.state_dict()[key] == m4.state_dict()[key].cuda()).all(), "wrong param in bmtrain model"
-    print("bmt.distributedmodule load_state_dict and state_dict test passed")
+    test_save_load(m2)
+    print("bmt.distributedmodule load_state_dict, state_dict, bmt.load/save test passed")
 
 if __name__ == "__main__":
     bmt.init_distributed()
