forked from dgriff777/rl_a3c_pytorch
-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathenvironment.py
64 lines (52 loc) · 2.05 KB
/
environment.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
from __future__ import division
import gym
import numpy as np
from gym.spaces.box import Box
from universe import vectorized
from universe.wrappers import Unvectorize, Vectorize
from skimage.color import rgb2gray
from cv2 import resize
#from skimage.transform import resize
#from scipy.misc import imresize as resize
def atari_env(env_id, env_conf):
    """Create the gym environment ``env_id`` and attach the preprocessing wrappers.

    Args:
        env_id: Environment identifier handed to ``gym.make``.
        env_conf: Per-game configuration forwarded to ``AtariRescale``.

    Returns:
        The raw environment when its observation space has at most one
        dimension; otherwise the environment wrapped (inside a
        Vectorize/Unvectorize pair) with ``AtariRescale`` and
        ``NormalizedEnv``.
    """
    env = gym.make(env_id)

    # NOTE(review): the indentation of this file was reconstructed; the whole
    # wrapper chain is assumed to be conditional on the check below - confirm.
    has_multi_dim_obs = len(env.observation_space.shape) > 1
    if not has_multi_dim_obs:
        return env

    # The two observation wrappers operate on lists of observations, so they
    # are sandwiched between Vectorize and Unvectorize.
    vectorized_env = Vectorize(env)
    rescaled_env = AtariRescale(vectorized_env, env_conf)
    normalized_env = NormalizedEnv(rescaled_env)
    return Unvectorize(normalized_env)
def _process_frame(frame, conf):
    """Crop, grayscale and shrink one frame to an array of shape [1, 80, 80].

    Args:
        frame: Image array indexed as ``[row, column, ...]``; it is passed to
            ``rgb2gray`` after cropping (presumably an RGB Atari screen -
            confirm against the caller).
        conf: Mapping that provides the keys ``"crop1"``, ``"crop2"`` and
            ``"dimension2"``.

    Returns:
        The processed frame reshaped to ``[1, 80, 80]``.
    """
    # Keep rows crop1 .. crop2 + 160 and only the first 160 columns.
    row_start = conf["crop1"]
    row_stop = conf["crop2"] + 160
    cropped = frame[row_start:row_stop, :160]

    gray = rgb2gray(cropped)

    # Two-stage resize: first to the per-game intermediate size, then to the
    # final 80x80. (cv2.resize takes its target size as (width, height).)
    intermediate = resize(gray, (80, conf["dimension2"]))
    square = resize(intermediate, (80, 80))

    # Add a leading axis of length 1 in front of the 80x80 image.
    return np.reshape(square, [1, 80, 80])
class AtariRescale(vectorized.ObservationWrapper):
    """Vectorized observation wrapper that runs every frame through ``_process_frame``."""

    def __init__(self, env, env_conf):
        """Wrap ``env`` and remember the crop/resize configuration.

        Args:
            env: The (vectorized) environment to wrap.
            env_conf: Configuration mapping later passed to ``_process_frame``.
        """
        super(AtariRescale, self).__init__(env)
        self.conf = env_conf
        # Advertised observation space: bounds 0.0 .. 1.0, shape [1, 80, 80].
        self.observation_space = Box(0.0, 1.0, [1, 80, 80])

    def _observation(self, observation_n):
        """Return a list with the processed version of each observation."""
        processed = []
        for raw_frame in observation_n:
            processed.append(_process_frame(raw_frame, self.conf))
        return processed
class NormalizedEnv(vectorized.ObservationWrapper):
    """Vectorized observation wrapper that standardizes observations.

    It keeps exponential moving averages of each observation's mean and
    standard deviation, corrects them for their zero initialization, and
    uses the corrected values to shift and scale every observation.
    """

    def __init__(self, env=None):
        """Wrap ``env`` and reset the running statistics."""
        super(NormalizedEnv, self).__init__(env)
        self.num_steps = 0       # number of observations folded in so far
        self.alpha = 0.9999      # weight kept from the previous running value
        self.state_mean = 0      # running average of observation.mean()
        self.state_std = 0       # running average of observation.std()

    def _observation(self, observation_n):
        """Update the running statistics and return normalized observations.

        Args:
            observation_n: Iterable of array-like observations exposing
                ``mean()`` and ``std()``.

        Returns:
            A list with ``(obs - mean) / (std + 1e-8)`` for every observation,
            using the bias-corrected running mean and std.
        """
        # NOTE(review): loop extent reconstructed from a flattened source -
        # the statistics update is per observation, the correction is done
        # once after the loop; confirm.
        for obs in observation_n:
            self.num_steps += 1
            self.state_mean = (
                self.state_mean * self.alpha + obs.mean() * (1 - self.alpha)
            )
            self.state_std = (
                self.state_std * self.alpha + obs.std() * (1 - self.alpha)
            )

        # The running values start at 0, so divide by (1 - alpha ** steps) to
        # undo the bias towards zero during the first steps.
        correction = 1 - pow(self.alpha, self.num_steps)
        unbiased_mean = self.state_mean / correction
        unbiased_std = self.state_std / correction

        # The small epsilon keeps the division finite when the std is 0.
        normalized = []
        for obs in observation_n:
            normalized.append((obs - unbiased_mean) / (unbiased_std + 1e-8))
        return normalized