-
Notifications
You must be signed in to change notification settings - Fork 16
/
Copy pathgenerate.py
91 lines (75 loc) · 4.37 KB
/
generate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
# MIT License
# Copyright (c) [2023] [Anima-Lab]
from argparse import ArgumentParser
import os
import json
from omegaconf import OmegaConf
import torch
from models.maskdit import Precond_models
from sample import generate_with_net
from utils import parse_float_none, parse_int_list, init_processes
def generate(args):
    """Sample images for one class from a trained checkpoint.

    Builds the preconditioned model described by ``args.config``, loads the
    EMA weights stored under ``'ema'`` in ``args.ckpt_path`` and hands the
    model to ``generate_with_net`` for sampling.

    Side effect: ``args.outdir`` is set to ``<args.results_dir>/<class label>``
    (the directory is created here); presumably ``generate_with_net`` writes
    its samples there -- confirm against sample.py.

    Args:
        args: Parsed command-line namespace (see the ``__main__`` block). It
            must also carry ``global_rank`` and ``global_size``.

    Raises:
        ValueError: if ``args.ckpt_path`` is not set.
        KeyError: if ``args.class_idx`` is missing from the label dict, or the
            checkpoint has no ``'ema'`` entry.
    """
    rank = args.global_rank
    size = args.global_size
    config = OmegaConf.load(args.config)

    if args.ckpt_path is None:
        # Fail fast, before building/compiling the model, instead of letting
        # torch.load(None) fail at the very end of setup.
        raise ValueError('--ckpt_path is required to load model weights')

    if args.class_idx is None:
        # --class_idx defaults to None ("random" per its CLI help), so there is
        # no single label to look up; use a generic name for the output dir.
        # NOTE(review): assumes generate_with_net samples random class labels
        # when args.class_idx is None -- confirm against sample.py.
        class_label = 'random'
    else:
        # Close the label file deterministically instead of leaking the handle.
        with open(args.label_dict, 'r') as f:
            label_dict = json.load(f)
        # Entries are keyed by the stringified class index; element [1] is
        # used as the printable label and the output directory name.
        class_label = label_dict[str(args.class_idx)][1]
    print(f'start sampling class {class_label}...')
    device = torch.device('cuda')

    # setup directory
    sample_dir = os.path.join(args.results_dir, class_label)
    os.makedirs(sample_dir, exist_ok=True)
    args.outdir = sample_dir

    # setup model
    model = Precond_models[config.model.precond](
        img_resolution=config.model.in_size,
        img_channels=config.model.in_channels,
        num_classes=config.model.num_classes,
        model_type=config.model.model_type,
        use_decoder=config.model.use_decoder,
        mae_loss_coef=config.model.mae_loss_coef,
        pad_cls_token=config.model.pad_cls_token,
        use_encoder_feat=config.model.self_cond,
    ).to(device)
    model.eval()
    print(f"{config.model.model_type} ((use_decoder: {config.model.use_decoder})) Model Parameters: {sum(p.numel() for p in model.parameters()):,}")
    print(f'extras: {model.model.extras}, cls_token: {model.model.cls_token}')

    ckpt = torch.load(args.ckpt_path, map_location=device)
    # A state dict saved from a torch.compile'd module has every key prefixed
    # with '_orig_mod.'. Strip that prefix (if present) and load into the plain
    # module, so checkpoints written with or without compilation both load.
    prefix = '_orig_mod.'
    state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in ckpt['ema'].items()
    }
    model.load_state_dict(state_dict)
    # load_state_dict copied the weights into the model; drop the loaded
    # checkpoint (mapped onto the GPU above) before sampling.
    del ckpt, state_dict
    # Compile only after the weights are in place.
    model = torch.compile(model)

    generate_with_net(args, model, device, rank, size)
    print(f'sampling class {class_label} done!')
if __name__ == '__main__':
    # Command-line entry point: parse sampling options, then run `generate`
    # through `init_processes`.
    parser = ArgumentParser('Sample from a trained model')
    # basic config
    parser.add_argument('--config', type=str, required=True, help='path to config file')
    parser.add_argument('--label_dict', type=str, default='assets/imagenet_label.json', help='path to label dict')
    parser.add_argument("--results_dir", type=str, default="samples", help='path to save samples')
    # `generate` cannot run without weights, so make the flag mandatory and let
    # argparse report a clear usage error instead of failing inside torch.load.
    parser.add_argument('--ckpt_path', type=str, required=True, help='path to ckpt')
    # sampling
    parser.add_argument('--seeds', type=parse_int_list, default='100-131', help='Random seeds (e.g. 1,2,5-10)')
    parser.add_argument('--subdirs', action='store_true', help='Create subdirectory for every 1000 seeds')
    parser.add_argument('--class_idx', type=int, default=None, help='Class label [default: random]')
    parser.add_argument("--cfg_scale", type=parse_float_none, default=None, help='Classifier-free guidance scale; None = no guidance (e.g. 4.0)')
    parser.add_argument('--num_steps', type=int, default=40, help='Number of sampling steps')
    # float rather than int so fractional stochasticity strengths are accepted;
    # integer inputs still parse.
    parser.add_argument('--S_churn', type=float, default=0, help='Stochasticity strength')
    parser.add_argument('--solver', type=str, default=None, choices=['euler', 'heun'], help='Ablate ODE solver')
    parser.add_argument('--discretization', type=str, default=None, choices=['vp', 've', 'iddpm', 'edm'], help='Ablate time step discretization {t_i}')
    parser.add_argument('--schedule', type=str, default=None, choices=['vp', 've', 'linear'], help='Ablate noise schedule sigma(t)')
    parser.add_argument('--scaling', type=str, default=None, choices=['vp', 'none'], help='Ablate signal scaling s(t)')
    parser.add_argument('--pretrained_path', type=str, default='assets/autoencoder_kl.pth', help='Autoencoder ckpt')
    parser.add_argument('--max_batch_size', type=int, default=32, help='Maximum batch size per GPU during sampling')
    parser.add_argument('--num_expected', type=int, default=32, help='Number of images to use')
    parser.add_argument("--global_seed", type=int, default=0)
    parser.add_argument('--fid_batch_size', type=int, default=32, help='Maximum batch size')
    # ddp
    parser.add_argument('--num_proc_node', type=int, default=1, help='The number of nodes in multi node env.')
    parser.add_argument('--num_process_per_node', type=int, default=1, help='number of gpus')
    parser.add_argument('--node_rank', type=int, default=0, help='The index of node.')
    parser.add_argument('--local_rank', type=int, default=0, help='rank of process in the node')
    parser.add_argument('--master_address', type=str, default='localhost', help='address for master')
    args = parser.parse_args()

    # This entry point always runs as a single process: rank and world size are
    # pinned here, overriding any --local_rank given on the command line.
    # NOTE(review): --num_proc_node / --num_process_per_node / --node_rank are
    # parsed but not used to derive rank or world size in this script.
    args.global_rank = 0
    args.local_rank = 0
    args.global_size = 1
    init_processes(generate, args)