-
Notifications
You must be signed in to change notification settings - Fork 1.2k
/
compress_model.py
89 lines (71 loc) · 2.36 KB
/
compress_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
from collections import OrderedDict
from text.symbols import symbols
import torch
from tools.log import logger
import utils
from models import SynthesizerTrn
import os
def copyStateDict(state_dict):
    """Return an ordered copy of ``state_dict`` without a leading ``module`` prefix.

    If the first key starts with ``"module"``, the first dot-separated
    component is dropped from every key; otherwise the keys are copied
    unchanged. Values are not copied or modified.

    NOTE(review): the ``module`` prefix is presumably the one added by
    ``torch.nn.DataParallel``-style wrappers -- confirm against the callers.

    Args:
        state_dict: Mapping whose keys are (possibly dotted) strings.

    Returns:
        A new ``OrderedDict`` holding the same values in the same order under
        the rewritten keys. An empty mapping yields an empty ``OrderedDict``.
    """
    new_state_dict = OrderedDict()
    if not state_dict:
        # Nothing to inspect; avoids indexing the first key of an empty mapping.
        return new_state_dict

    first_key = next(iter(state_dict))
    start_idx = 1 if first_key.startswith("module") else 0

    for key, value in state_dict.items():
        # Re-join with "." (the separator the key was split on) so that nested
        # names such as "enc.weight" keep their original dotted form.
        new_state_dict[".".join(key.split(".")[start_idx:])] = value
    return new_state_dict
def removeOptimizer(config: str, input_model: str, ishalf: bool, output_model: str):
    """Write a slimmed-down copy of a generator checkpoint to ``output_model``.

    The checkpoint at ``input_model`` is loaded on the CPU, every entry of its
    ``"model"`` dict whose key contains ``"enc_q"`` is dropped, and the
    remaining entries are optionally converted with ``.half()``. The saved file
    carries the state of a freshly constructed ``AdamW`` optimizer (not the one
    stored in the input checkpoint), ``iteration`` reset to ``0`` and a fixed
    ``learning_rate`` of ``0.0001``.

    Args:
        config: Path to the hyper-parameter file read by
            ``utils.get_hparams_from_file``.
        input_model: Path of the checkpoint to load.
        ishalf: When true, call ``.half()`` on every kept entry.
        output_model: Path the new checkpoint is saved to.
    """
    hparams = utils.get_hparams_from_file(config)

    # A model instance is built only so that a pristine optimizer state dict
    # can be produced for it; its weights are never written to the output.
    generator = SynthesizerTrn(
        len(symbols),
        hparams.data.filter_length // 2 + 1,
        hparams.train.segment_size // hparams.data.hop_length,
        n_speakers=hparams.data.n_speakers,
        **hparams.model,
    )
    fresh_optimizer = torch.optim.AdamW(
        generator.parameters(),
        hparams.train.learning_rate,
        betas=hparams.train.betas,
        eps=hparams.train.eps,
    )

    checkpoint = copyStateDict(torch.load(input_model, map_location="cpu"))
    weights = checkpoint["model"]

    kept_weights = {}
    for key, tensor in weights.items():
        # NOTE(review): "enc_q" presumably names a training-only sub-module;
        # confirm before relying on the output for anything but inference.
        if "enc_q" in key:
            continue
        kept_weights[key] = tensor.half() if ishalf else tensor

    payload = {
        "model": kept_weights,
        "iteration": 0,
        "optimizer": fresh_optimizer.state_dict(),
        "learning_rate": 0.0001,
    }
    torch.save(payload, output_model)
if __name__ == "__main__":
    # Command-line entry point: compress one checkpoint and log where the
    # result was written.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("-c", "--config", type=str, default="configs/config.json")
    # The input checkpoint has no meaningful default, so let argparse reject a
    # missing value with a usage message instead of failing later on ``None``.
    parser.add_argument("-i", "--input", type=str, required=True)
    parser.add_argument("-o", "--output", type=str, default=None)
    parser.add_argument(
        "-hf", "--half", action="store_true", default=False, help="Save as FP16"
    )
    args = parser.parse_args()

    output = args.output
    if output is None:
        # Default output name: "<input without ext>_release[_half]<ext>".
        filename, ext = os.path.splitext(args.input)
        half = "_half" if args.half else ""
        output = filename + "_release" + half + ext

    removeOptimizer(args.config, args.input, args.half, output)
    # Log message reads: "Model compressed successfully, output model: <path>".
    logger.info(f"压缩模型成功, 输出模型: {os.path.abspath(output)}")