-
Notifications
You must be signed in to change notification settings - Fork 606
/
Copy pathinitialize.py
116 lines (90 loc) · 4.1 KB
/
initialize.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import argparse
import torch
import time
from quantization import quantize
from SwissArmyTransformer import get_args, get_tokenizer
from SwissArmyTransformer.arguments import initialize_distributed
from SwissArmyTransformer.training import load_checkpoint
from SwissArmyTransformer.model import GLM130B
from SwissArmyTransformer.mpu import get_model_parallel_world_size, get_model_parallel_rank, get_model_parallel_group
def add_bminf_args(parser):
"""Arguments for BMInf"""
group = parser.add_argument_group("BMInf")
group.add_argument("--bminf", action="store_true", help="Use BMInf to support low resource evaluation")
group.add_argument("--bminf-memory-limit", type=int, default=20, help="Max memory for model per GPU (in GB)")
return parser
def add_quantization_args(parser):
group = parser.add_argument_group("Quantization")
group.add_argument("--quantization-bit-width", type=int, default=None)
group.add_argument("--from-quantized-checkpoint", action="store_true", help="Loading from a quantized checkpoint")
def add_initialization_args(parser):
group = parser.add_argument_group("Initialization")
group.add_argument(
"--sequential-initialization",
action="store_true",
help="Initialize sequentially in tensor parallel group (reduce CPU RAM for initialization)",
)
def initialize(extra_args_provider):
parser = argparse.ArgumentParser(add_help=False)
add_bminf_args(parser)
add_quantization_args(parser)
add_initialization_args(parser)
GLM130B.add_model_specific_args(parser)
extra_args_provider(parser)
known, args_list = parser.parse_known_args()
args = get_args(args_list)
args = argparse.Namespace(**vars(args), **vars(known))
args.do_train = False
initialize_distributed(args)
return args
def initialize_model_and_tokenizer(args):
tokenizer = get_tokenizer(args)
torch.distributed.barrier()
start = time.time()
for i in range(get_model_parallel_world_size()):
if get_model_parallel_rank() == i:
# Initialize model
model = GLM130B(args).half()
if args.from_quantized_checkpoint:
assert args.quantization_bit_width is not None
# Quantize model before moving to GPU
model = quantize(model, args.quantization_bit_width)
# Load checkpoint
load_checkpoint(model, args)
if args.quantization_bit_width is not None and not args.from_quantized_checkpoint:
# Quantize model before moving to GPU
model = quantize(model, args.quantization_bit_width)
if args.bminf:
import bminf
if torch.distributed.get_rank() == 0:
print(f"> BMInf activated, memory limit: {args.bminf_memory_limit} GB")
with torch.cuda.device(args.device):
model = bminf.wrapper(model, quantization=False, memory_limit=args.bminf_memory_limit << 30)
else:
model = model.to(args.device)
if args.sequential_initialization:
torch.distributed.barrier(group=get_model_parallel_group())
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print(f"> Model initialized in {time.time() - start:.1f}s")
torch.cuda.empty_cache()
model.eval()
# generate rotary embedding cache
original_parallel_output = model.transformer.parallel_output
model.transformer.parallel_output = True
with torch.no_grad():
_, *_ = model(
torch.ones(1, args.max_sequence_length, device=torch.cuda.current_device(), dtype=torch.int64),
torch.arange(args.max_sequence_length, device=torch.cuda.current_device(), dtype=torch.int64).view(1, -1),
torch.randn(
1,
1,
args.max_sequence_length,
args.max_sequence_length,
device=torch.cuda.current_device(),
)
< 0.5,
)
model.transformer.parallel_output = original_parallel_output
torch.distributed.barrier()
return model, tokenizer