-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
80 lines (69 loc) · 3.26 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Ray Imports
import ray
from ray.tune import run_experiments
from ray.tune.registry import register_env
from ray import tune
from v2i import V2I
from v2i.src.core.common import readYaml, raiseValueError
import argparse
# Command-line interface for the training script.
parser = argparse.ArgumentParser(description="Training Script for v2i simulator")
# Path to the simulator configuration YAML (consumed by V2I and the PPO helper).
parser.add_argument("-sc", "--sim-config", type=str, required=True, help="v2i simulation configuration file")
# Path to the algorithm YAML; its filename prefix (e.g. "ppo-...", "impala-...") selects the algorithm.
parser.add_argument("-tc", "--training-algo-config", required=True, type=str, help="training algorithm configuration file")
parser.add_argument("-w", "--num-workers", type=int, default=4, help="number of parallel worker to use for training, (default=4)")
# Experiment name: becomes the top-level key of the Ray Tune experiment spec.
parser.add_argument("-n", "--name", type=str, required=True, help="experiment name")
parser.add_argument("-f", "--checkpoint-freq", type=int, default=100, help="checkpoint frequency")
def doIMPALAEssentials(algoConfig, args):
    """Fill in the IMPALA-specific fields of a Ray Tune experiment spec.

    Re-keys the "EXP_NAME" placeholder entry under the CLI experiment name,
    then wires in the worker count, a derived train batch size, the
    environment id, the algorithm name, and the checkpoint frequency.

    Args:
        algoConfig: dict parsed from the training-algo YAML; must contain an
            "EXP_NAME" placeholder key with a "config" sub-dict.
        args: parsed CLI namespace providing name, num_workers and
            checkpoint_freq.

    Returns:
        The same dict, mutated in place.
    """
    # Move the placeholder experiment block under the user-chosen name.
    experiment = algoConfig.pop("EXP_NAME")
    algoConfig[args.name] = experiment

    workers = int(args.num_workers)
    runnerCfg = experiment["config"]
    runnerCfg["num_workers"] = workers
    # IMPALA batch = workers * per-worker sample batch * envs per worker.
    runnerCfg["train_batch_size"] = workers * runnerCfg["sample_batch_size"] * int(runnerCfg["num_envs_per_worker"])

    experiment["env"] = "v2i-v0"
    experiment["run"] = "IMPALA"
    experiment["checkpoint_freq"] = args.checkpoint_freq
    return algoConfig
def doPPOEssentials(algoConfig, simConfig, args):
    """Fill in the PPO-specific fields of a Ray Tune experiment spec.

    Re-keys the "EXP_NAME" placeholder entry under the CLI experiment name,
    sets worker count, derives the train batch size when the YAML left it
    unset, and mirrors the simulator's LSTM flag into the model config.

    Args:
        algoConfig: dict parsed from the training-algo YAML; must contain an
            "EXP_NAME" placeholder key with a "config" sub-dict (including
            "train_batch_size", "sgd_minibatch_size" and a "model" dict).
        simConfig: dict parsed from the simulation YAML; read for
            simConfig["config"]["enable-lstm"].
        args: parsed CLI namespace providing name, num_workers and
            checkpoint_freq.

    Returns:
        The same dict, mutated in place.
    """
    # Move the placeholder experiment block under the user-chosen name.
    algoConfig[args.name] = algoConfig.pop("EXP_NAME")
    expConfig = algoConfig[args.name]

    expConfig["config"]["num_workers"] = int(args.num_workers)
    # Derive batch size only when the YAML left it unset.
    # Fixed: "== None" replaced with the idiomatic identity test "is None".
    if expConfig["config"]["train_batch_size"] is None:
        expConfig["config"]["train_batch_size"] = int(args.num_workers) * expConfig["config"]["sgd_minibatch_size"]

    expConfig["env"] = "v2i-v0"
    expConfig["run"] = "PPO"
    expConfig["checkpoint_freq"] = args.checkpoint_freq
    # Mirror the simulator's LSTM flag into the RLlib model config
    # (replaces the redundant if/else that assigned literal True/False).
    expConfig["config"]["model"]["use_lstm"] = bool(simConfig['config']['enable-lstm'])
    return algoConfig
if __name__ == "__main__":
    args = parser.parse_args()

    # Read config files.
    algoConfig = readYaml(args.training_algo_config)
    simConfig = readYaml(args.sim_config)

    # Infer the algorithm from the config filename prefix,
    # e.g. "ppo-train.yml" -> "PPO".
    trainAlgo = args.training_algo_config.split("/")[-1].split('-')[0].upper()
    if trainAlgo == 'PPO':
        algoConfig = doPPOEssentials(algoConfig, simConfig, args)
    elif trainAlgo == 'IMPALA':
        # BUG FIX: the original used "==" (comparison) instead of "="
        # (assignment) here, silently discarding the expression's value.
        # It only appeared to work because doIMPALAEssentials mutates
        # algoConfig in place.
        algoConfig = doIMPALAEssentials(algoConfig, args)
    else:
        raiseValueError("invalid training algo %s"%(trainAlgo))

    # Register the environment under the id the experiment specs reference.
    register_env("v2i-v0", lambda config: V2I.V2I(args.sim_config, "train"))

    # Start the training.
    ray.init()
    run_experiments(algoConfig)