[rllib] Fix Preprocessor for ATARI (#1066)

* Removing squeeze, fix atari preprocessing

* nit comment

* comments

* jenkins

* Lint
This commit is contained in:
Richard Liaw
2017-10-03 18:45:02 -07:00
committed by Eric Liang
parent 0dcf36c91e
commit cb6dea94bc
5 changed files with 17 additions and 8 deletions
+1 -1
View File
@@ -30,7 +30,7 @@ class RLLibPreprocessing(gym.ObservationWrapper):
self.observation_space = Box(-1.0, 1.0, self._process_shape)
def _observation(self, observation):
return self.preprocessor.transform(observation).squeeze(0)
return self.preprocessor.transform(observation)
class Diagnostic(gym.Wrapper):
+2 -4
View File
@@ -99,15 +99,13 @@ class Policy:
t = 0
if save_obs:
obs = []
# TODO(ekl) the squeeze() is needed for Pong-v0, but we should fix
# this in the preprocessor instead
ob = preprocessor.transform(env.reset()).squeeze()
ob = preprocessor.transform(env.reset())
for _ in range(timestep_limit):
ac = self.act(ob[None], random_stream=random_stream)[0]
if save_obs:
obs.append(ob)
ob, rew, done, _ = env.step(ac)
ob = preprocessor.transform(ob).squeeze()
ob = preprocessor.transform(ob)
rews.append(rew)
t += 1
if render:
+4 -1
View File
@@ -42,12 +42,15 @@ class AtariPixelPreprocessor(Preprocessor):
scaled = observation[25:-25, :, :]
if self.dim < 80:
scaled = cv2.resize(scaled, (80, 80))
# OpenAI: Resize by half, then down to 42x42 (essentially mipmapping).
# If we resize directly we lose pixels that, when mapped to 42x42,
# aren't close enough to the pixel boundary.
scaled = cv2.resize(scaled, (self.dim, self.dim))
if self.grayscale:
scaled = scaled.mean(2)
scaled = scaled.astype(np.float32)
# Rescale needed for maintaining 1 channel
scaled = np.reshape(scaled, [self.dim, self.dim, 1])
scaled = scaled[None]
if self.zero_mean:
scaled = (scaled - 128) / 128
else:
+3 -2
View File
@@ -22,7 +22,8 @@ class BatchedEnv(object):
def reset(self):
observations = [
self.preprocessor.transform(env.reset()) for env in self.envs]
self.preprocessor.transform(env.reset())[None]
for env in self.envs]
self.shape = observations[0].shape
self.dones = [False for _ in range(self.batchsize)]
return np.vstack(observations)
@@ -43,7 +44,7 @@ class BatchedEnv(object):
break
if render:
self.envs[0].render()
observations.append(self.preprocessor.transform(observation))
observations.append(self.preprocessor.transform(observation)[None])
rewards.append(reward)
self.dones[i] = done
return (np.vstack(observations), np.array(rewards, dtype="float32"),