-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmain.py
133 lines (113 loc) · 5.12 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import loading_utilities as lu
import run_inference
import sys
import os
from simple_slurm import Slurm
from datetime import datetime
from pathlib import Path
from mpi4py import MPI
from mpi4py.util import pkl5
def main(param_name, folder_name=None, extra_train_steps=None, prune_frac=None):
    """Start a ``run_inference`` job, either in-process or via a Slurm submission.

    The run type is inferred from which optional arguments are supplied:

    * ``folder_name`` is None                    -> ``'new'``   (the function named by
      ``run_params['fit_file']``, saved to a fresh timestamped folder)
    * ``folder_name`` only                       -> ``'post'``  (``infer_posterior``)
    * ``folder_name`` + ``extra_train_steps``    -> ``'cont'``  (``continue_fit``)
    * all three                                  -> ``'prune'`` (``prune_model``)

    Earlier arguments take precedence: without ``folder_name`` the run is always
    ``'new'``, and without ``extra_train_steps`` a given ``prune_frac`` is ignored.

    If the run parameters contain a ``'slurm'`` entry, MPI rank 0 builds an sbatch
    script and submits it; other ranks do nothing. Otherwise every rank calls the
    selected ``run_inference`` function directly.

    Args:
        param_name: Path (str or Path) to the parameter file; its stem names the
            sub-folder under ``trained_models``.
        folder_name: Name of an existing run folder under
            ``trained_models/<param stem>/``, or None to start a new run.
        extra_train_steps: Additional training steps for ``'cont'``/``'prune'`` runs.
        prune_frac: Pruning fraction forwarded to ``prune_model``.

    Returns:
        The save folder as a ``Path`` on MPI rank 0, and None on every other rank.

    Raises:
        Exception: If the run type is not one of the four recognized values.
    """
    comm = pkl5.Intracomm(MPI.COMM_WORLD)
    cpu_id = comm.Get_rank()

    if folder_name is None:
        run_type = 'new'
    elif extra_train_steps is None:
        run_type = 'post'
    elif prune_frac is None:
        run_type = 'cont'
    else:
        run_type = 'prune'

    param_name = Path(param_name)
    run_params = lu.get_run_params(param_name=param_name)

    # Only rank 0 resolves (and, for a new run, creates) the save folder.
    save_folder = None
    if cpu_id == 0:
        model_root = Path(__file__).parent.resolve() / 'trained_models' / param_name.stem
        if run_type == 'new':
            save_folder = model_root / datetime.today().strftime('%Y%m%d_%H%M%S')
            os.makedirs(save_folder)
        else:
            save_folder = model_root / folder_name

    if 'slurm' in run_params:
        if cpu_id == 0:
            job_suffix = ''

            # Python call that the submitted job executes via `python -c`.
            if run_type == 'new':
                call = f"{run_params['fit_file']}('{param_name}','{save_folder}')"
            elif run_type == 'post':
                call = f"infer_posterior('{param_name}','{save_folder}', infer_missing=True)"
            elif run_type == 'cont':
                call = f"continue_fit('{param_name}','{save_folder}',{extra_train_steps})"
            elif run_type == 'prune':
                call = f"prune_model('{param_name}','{save_folder}',{extra_train_steps},{prune_frac})"
                # round() rather than int(): 0.29 * 100 is 28.999..., which int()
                # would truncate to 28 and mislabel the job.
                job_suffix = f'_es{int(extra_train_steps):03d}_pf{round(prune_frac * 100):03d}'
            else:
                raise Exception('run type not recognized')

            slurm_output_path = save_folder / f'slurm_%A_{run_type}.out'
            job_name = f'{param_name.stem}_{run_type}{job_suffix}'
            slurm_fit = Slurm(**run_params['slurm'], output=slurm_output_path, job_name=job_name)

            # Pin the numerical libraries' thread pools to the CPUs Slurm allocates per task.
            cpus_per_task = run_params['slurm']['cpus_per_task']
            run_command = [
                'module purge',
                'module load anaconda3/2022.10',
                'module load openmpi/gcc/4.1.2',
                'conda activate fast-mpi4py',
                f'export MKL_NUM_THREADS={cpus_per_task}',
                f'export OPENBLAS_NUM_THREADS={cpus_per_task}',
                f'export OMP_NUM_THREADS={cpus_per_task}',
                f'srun python -uc "import run_inference; run_inference.{call}"',
            ]
            slurm_fit.sbatch('\n'.join(run_command))
    else:
        # Runs on every rank; save_folder is None on ranks other than 0.
        # NOTE(review): presumably run_inference shares the folder across ranks - confirm.
        if run_type == 'new':
            fit_method = getattr(run_inference, run_params['fit_file'])
            fit_method(param_name, save_folder)
        elif run_type == 'post':
            run_inference.infer_posterior(param_name, save_folder, infer_missing=True)
        elif run_type == 'cont':
            # Same function and positional call shape as the Slurm command above.
            run_inference.continue_fit(param_name, save_folder, extra_train_steps)
        elif run_type == 'prune':
            run_inference.prune_model(param_name, save_folder,
                                      extra_train_steps=extra_train_steps, prune_frac=prune_frac)
        else:
            raise Exception('run type not recognized')

    return save_folder
if __name__ == '__main__':
    # Command-line usage (all arguments positional and optional):
    #   main.py [param_name [folder_name [extra_train_steps [prune_frac]]]]
    # Arguments that are not given are passed to main() as None, except
    # param_name, which falls back to the default parameter file below.
    max_cli_args = 4
    cli_args = sys.argv[1:]

    if len(cli_args) > max_cli_args:
        raise Exception(f'Unsupported number of arguments: {len(cli_args)} '
                        f'(expected at most {max_cli_args})')

    # Pad with None so the four values can be unpacked in one step.
    padded_args = cli_args + [None] * (max_cli_args - len(cli_args))
    param_name, folder_name, extra_train_steps, prune_frac = padded_args

    if param_name is None:
        param_name = 'submission_params/syn_test.yml'
        # param_name = 'submission_params/exp_test.yml'

    if extra_train_steps is not None:
        extra_train_steps = int(extra_train_steps)

    if prune_frac is not None:
        prune_frac = float(prune_frac)

    main(param_name, folder_name, extra_train_steps=extra_train_steps, prune_frac=prune_frac)