From cb6dea94bce34a3c0a3033a023815b5aff9db79d Mon Sep 17 00:00:00 2001 From: Richard Liaw Date: Tue, 3 Oct 2017 18:45:02 -0700 Subject: [PATCH] [rllib] Fix Preprocessor for ATARI (#1066) * Removing squeeze, fix atari preprocessing * nit comment * comments * jenkins * Lint --- python/ray/rllib/a3c/envs.py | 2 +- python/ray/rllib/es/policies.py | 6 ++---- python/ray/rllib/models/preprocessors.py | 5 ++++- python/ray/rllib/ppo/env.py | 5 +++-- test/jenkins_tests/run_multi_node_tests.sh | 7 +++++++ 5 files changed, 17 insertions(+), 8 deletions(-) diff --git a/python/ray/rllib/a3c/envs.py b/python/ray/rllib/a3c/envs.py index f25f8e0a1..8a56b92ce 100644 --- a/python/ray/rllib/a3c/envs.py +++ b/python/ray/rllib/a3c/envs.py @@ -30,7 +30,7 @@ class RLLibPreprocessing(gym.ObservationWrapper): self.observation_space = Box(-1.0, 1.0, self._process_shape) def _observation(self, observation): - return self.preprocessor.transform(observation).squeeze(0) + return self.preprocessor.transform(observation) class Diagnostic(gym.Wrapper): diff --git a/python/ray/rllib/es/policies.py b/python/ray/rllib/es/policies.py index 617d23a7d..c152328de 100644 --- a/python/ray/rllib/es/policies.py +++ b/python/ray/rllib/es/policies.py @@ -99,15 +99,13 @@ class Policy: t = 0 if save_obs: obs = [] - # TODO(ekl) the squeeze() is needed for Pong-v0, but we should fix - # this in the preprocessor instead - ob = preprocessor.transform(env.reset()).squeeze() + ob = preprocessor.transform(env.reset()) for _ in range(timestep_limit): ac = self.act(ob[None], random_stream=random_stream)[0] if save_obs: obs.append(ob) ob, rew, done, _ = env.step(ac) - ob = preprocessor.transform(ob).squeeze() + ob = preprocessor.transform(ob) rews.append(rew) t += 1 if render: diff --git a/python/ray/rllib/models/preprocessors.py b/python/ray/rllib/models/preprocessors.py index e0ee469f1..740358c0b 100644 --- a/python/ray/rllib/models/preprocessors.py +++ b/python/ray/rllib/models/preprocessors.py @@ -42,12 +42,15 @@ class 
AtariPixelPreprocessor(Preprocessor): scaled = observation[25:-25, :, :] if self.dim < 80: scaled = cv2.resize(scaled, (80, 80)) + # OpenAI: Resize by half, then down to 42x42 (essentially mipmapping). + # If we resize directly we lose pixels that, when mapped to 42x42, + # aren't close enough to the pixel boundary. scaled = cv2.resize(scaled, (self.dim, self.dim)) if self.grayscale: scaled = scaled.mean(2) scaled = scaled.astype(np.float32) + # Reshape needed for maintaining 1 channel scaled = np.reshape(scaled, [self.dim, self.dim, 1]) - scaled = scaled[None] if self.zero_mean: scaled = (scaled - 128) / 128 else: diff --git a/python/ray/rllib/ppo/env.py b/python/ray/rllib/ppo/env.py index d569f7c01..1dba5973f 100644 --- a/python/ray/rllib/ppo/env.py +++ b/python/ray/rllib/ppo/env.py @@ -22,7 +22,8 @@ class BatchedEnv(object): def reset(self): observations = [ - self.preprocessor.transform(env.reset()) for env in self.envs] + self.preprocessor.transform(env.reset())[None] + for env in self.envs] self.shape = observations[0].shape self.dones = [False for _ in range(self.batchsize)] return np.vstack(observations) @@ -43,7 +44,7 @@ class BatchedEnv(object): break if render: self.envs[0].render() - observations.append(self.preprocessor.transform(observation)) + observations.append(self.preprocessor.transform(observation)[None]) rewards.append(reward) self.dones[i] = done return (np.vstack(observations), np.array(rewards, dtype="float32"), diff --git a/test/jenkins_tests/run_multi_node_tests.sh b/test/jenkins_tests/run_multi_node_tests.sh index 0dea1ee84..852148091 100755 --- a/test/jenkins_tests/run_multi_node_tests.sh +++ b/test/jenkins_tests/run_multi_node_tests.sh @@ -84,6 +84,13 @@ docker run --shm-size=10G --memory=10G $DOCKER_SHA \ --num-iterations 2 \ --config '{"stepsize": 0.01}' +docker run --shm-size=10G --memory=10G $DOCKER_SHA \ + python /ray/python/ray/rllib/train.py \ + --env CartPole-v0 \ + --alg A3C \ + --num-iterations 2 \ + --config '{"use_lstm": 
false}' + docker run --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env CartPole-v0 \