-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtuning_continuous.py
66 lines (47 loc) · 2 KB
/
tuning_continuous.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import itertools
import os
import random
import time

import gym
import matplotlib.pyplot as plt
import numpy as np
import pybulletgym
import torch

from continuous import train
from util import *
# Fixed seed; the tuning loop applies it to both the environment and torch.
SEED = 1564
ENV_NAME = "InvertedDoublePendulumPyBulletEnv-v0"

# Per-dimension observation statistics handed to `train` as `scale` and `std`.
# NOTE(review): SCALE looks like a per-dimension mean (it is paired with a
# standard deviation); confirm how continuous.train uses it before renaming.
SCALE = [
    0.00868285, 0.03400105, -0.00312787, 0.95092393, -0.01797627,
    -0.10439248, 0.86726532, 0.01176883, 0.12335652,
]
STD = [
    0.11101651, 0.58301397, 0.09502404, 0.07712284, 0.29911971,
    1.78995357, 0.20914456, 0.45163139, 3.08248822,
]
# Observation dimensionality, derived from the statistics above so the
# constants cannot drift apart.
DIMS = len(SCALE)
if len(STD) != DIMS:
    raise ValueError(
        "SCALE and STD must have the same length, got {} and {}".format(DIMS, len(STD))
    )

# Total episode budget per configuration (the tuning loop divides it by the
# batch size) and the per-episode step cap.
MAX_EPISODES = 10000
MAX_TIMESTEPS = 200
def iter_param_grid(grid, rng=random):
    """Yield every combination of the candidate values in *grid*, in random order.

    Each combination is produced exactly once as a dict mapping every key of
    *grid* to one of its candidates. The iterator is therefore exhausted after
    ``prod(len(v) for v in grid.values())`` items. A key with no candidates
    yields nothing.

    Args:
        grid: Mapping from hyperparameter name to a sequence of candidate values.
        rng: Object with a ``shuffle`` method (the ``random`` module or a
            ``random.Random`` instance) used to randomise the visit order.

    Yields:
        dict: One hyperparameter assignment per iteration, keyed like *grid*.
    """
    keys = list(grid)
    combos = list(itertools.product(*(grid[key] for key in keys)))
    rng.shuffle(combos)
    for values in combos:
        yield dict(zip(keys, values))


if __name__ == "__main__":
    # Candidate values for the random hyperparameter search.
    para = {
        "gamma": [0.95, 0.99, 1],
        "sigma": [0.5, 1],
        "alpha": [0.001, 0.005, 0.01],
        "batch_size": [1, 4, 16, 32],
    }
    # Visit each combination once, in random order. Drawing values at random
    # inside `while True` re-drew finished combinations forever and never
    # terminated once the whole grid had been trained.
    for combo in iter_param_grid(para):
        gamma = combo["gamma"]
        sigma = combo["sigma"]
        alpha = combo["alpha"]
        batch_size = combo["batch_size"]
        params = TrainingParameters(batch_size=batch_size, n_layers=1, lr=alpha,
                                    gamma=gamma, discrete=False, sigma=sigma)
        model_name = params.get_model_name()
        # An existing results file means this configuration was already trained
        # (possibly by another run of this script), so skip it.
        if os.path.isfile(get_data_path(model_name)):
            print("Training already done for params {} ".format(model_name))
            continue
        env = gym.make(ENV_NAME)
        try:
            env.seed(seed=SEED)
            torch.manual_seed(SEED)
            print("########################################################################")
            print("Training {}".format(model_name))
            print("########################################################################")
            start = time.time()
            # The episode budget is split by batch size. Presumably train() uses
            # `batch_size` episodes per update; TODO confirm in continuous.py.
            _, cum_rewards, alive_time = train(
                env,
                params,
                max_episodes=MAX_EPISODES // batch_size,
                max_timesteps=MAX_TIMESTEPS,
                dims=DIMS,
                scale=SCALE,
                std=STD,
            )
        finally:
            # Release the simulator even if training raises.
            env.close()
        results = np.array([cum_rewards, alive_time])
        save_results(model_name, results)
        print("Time taken : {:.0f} seconds".format(time.time() - start))
        print("\n\n\n")
    print("All hyperparameter combinations have been trained.")