diff --git a/python/ray/rllib/evolution_strategies/evolution_strategies.py b/python/ray/rllib/evolution_strategies/evolution_strategies.py index 65bb4e379..138eb00c3 100644 --- a/python/ray/rllib/evolution_strategies/evolution_strategies.py +++ b/python/ray/rllib/evolution_strategies/evolution_strategies.py @@ -73,7 +73,8 @@ class Worker(object): self.noise = SharedNoiseTable(noise) self.env = gym.make(env_name) - self.preprocessor = ModelCatalog.get_preprocessor(env_name) + self.preprocessor = ModelCatalog.get_preprocessor( + env_name, self.env.observation_space.shape) self.preprocessor_shape = self.preprocessor.transform_shape( self.env.observation_space.shape) @@ -167,7 +168,8 @@ class EvolutionStrategies(Algorithm): } env = gym.make(env_name) - preprocessor = ModelCatalog.get_preprocessor(env_name) + preprocessor = ModelCatalog.get_preprocessor( + env_name, env.observation_space.shape) preprocessor_shape = preprocessor.transform_shape( env.observation_space.shape) diff --git a/python/ray/rllib/models/catalog.py b/python/ray/rllib/models/catalog.py index 29289c812..effaa3326 100644 --- a/python/ray/rllib/models/catalog.py +++ b/python/ray/rllib/models/catalog.py @@ -75,27 +75,26 @@ class ModelCatalog(object): return ConvolutionalNetwork(inputs, num_outputs, options) @staticmethod - def get_preprocessor(env_name): + def get_preprocessor(env_name, obs_shape): """Returns a suitable processor for the given environment. Args: env_name (str): The name of the environment. + obs_shape (tuple): The shape of the env observation space. Returns: preprocessor (Preprocessor): Preprocessor for the env observations. """ - if env_name == "Pong-v0": + ATARI_OBS_SHAPE = (210, 160, 3) + ATARI_RAM_OBS_SHAPE = (128,) + + if obs_shape == ATARI_OBS_SHAPE: + print("Assuming Atari pixel env, using AtariPixelPreprocessor.") return AtariPixelPreprocessor() - elif env_name == "Pong-ram-v3": + elif obs_shape == ATARI_RAM_OBS_SHAPE: + print("Assuming Atari ram env, using AtariRamPreprocessor.") return AtariRamPreprocessor() - elif env_name == "CartPole-v0" or env_name == "CartPole-v1": - return NoPreprocessor() - elif env_name == "Hopper-v1": - return NoPreprocessor() - elif env_name == "Walker2d-v1": - return NoPreprocessor() - elif env_name == "Humanoid-v1" or env_name == "Pendulum-v0": - return NoPreprocessor() - else: - return AtariPixelPreprocessor() + + print("Non-atari env, not using any observation preprocessor.") + return NoPreprocessor() diff --git a/python/ray/rllib/policy_gradient/agent.py b/python/ray/rllib/policy_gradient/agent.py index 1e0b39fe5..ddddc3270 100644 --- a/python/ray/rllib/policy_gradient/agent.py +++ b/python/ray/rllib/policy_gradient/agent.py @@ -35,8 +35,7 @@ class Agent(object): network weights. When run as a remote agent, only this graph is used. """ - def __init__( - self, name, batchsize, preprocessor, config, logdir, is_remote): + def __init__(self, name, batchsize, config, logdir, is_remote): if is_remote: os.environ["CUDA_VISIBLE_DEVICES"] = "" devices = ["/cpu:0"] @@ -45,12 +44,12 @@ class Agent(object): self.devices = devices self.config = config self.logdir = logdir - self.env = BatchedEnv(name, batchsize, preprocessor=preprocessor) + self.env = BatchedEnv(name, batchsize) if is_remote: config_proto = tf.ConfigProto() else: config_proto = tf.ConfigProto(**config["tf_session_args"]) - self.preprocessor = preprocessor + self.preprocessor = self.env.preprocessor self.sess = tf.Session(config=config_proto) if config["use_tf_debugger"] and not is_remote: self.sess = tf_debug.LocalCLIDebugWrapperSession(self.sess) @@ -61,7 +60,7 @@ class Agent(object): self.kl_coeff = tf.placeholder( name="newkl", shape=(), dtype=tf.float32) - self.preprocessor_shape = preprocessor.transform_shape( + self.preprocessor_shape = self.preprocessor.transform_shape( self.env.observation_space.shape) self.observations = tf.placeholder( tf.float32, shape=(None,) + self.preprocessor_shape) diff --git a/python/ray/rllib/policy_gradient/env.py b/python/ray/rllib/policy_gradient/env.py index 418929bab..63509b06f 100644 --- a/python/ray/rllib/policy_gradient/env.py +++ b/python/ray/rllib/policy_gradient/env.py @@ -5,16 +5,18 @@ from __future__ import print_function import gym import numpy as np +from ray.rllib.models import ModelCatalog + class BatchedEnv(object): """This holds multiple gym envs and performs steps on all of them.""" - def __init__(self, name, batchsize, preprocessor=None): + def __init__(self, name, batchsize): self.envs = [gym.make(name) for _ in range(batchsize)] self.observation_space = self.envs[0].observation_space self.action_space = self.envs[0].action_space self.batchsize = batchsize - self.preprocessor = (preprocessor if preprocessor - else lambda obs: obs[None]) + self.preprocessor = ModelCatalog.get_preprocessor( + name, self.envs[0].observation_space.shape) def reset(self): observations = [ diff --git a/python/ray/rllib/policy_gradient/policy_gradient.py b/python/ray/rllib/policy_gradient/policy_gradient.py index 074ffdf79..6cb9d22f5 100644 --- a/python/ray/rllib/policy_gradient/policy_gradient.py +++ b/python/ray/rllib/policy_gradient/policy_gradient.py @@ -10,7 +10,6 @@ import tensorflow as tf import ray from ray.rllib.common import Algorithm, TrainingResult -from ray.rllib.models import ModelCatalog from ray.rllib.policy_gradient.agent import Agent, RemoteAgent from ray.rllib.policy_gradient.rollout import collect_samples from ray.rllib.policy_gradient.utils import shuffle @@ -74,17 +73,13 @@ class PolicyGradient(Algorithm): Algorithm.__init__(self, env_name, config, upload_dir=upload_dir) - self.preprocessor = ModelCatalog.get_preprocessor(self.env_name) self.global_step = 0 self.j = 0 self.kl_coeff = config["kl_coeff"] - self.model = Agent( - self.env_name, 1, self.preprocessor, self.config, self.logdir, - False) + self.model = Agent(self.env_name, 1, self.config, self.logdir, False) self.agents = [ RemoteAgent.remote( - self.env_name, 1, self.preprocessor, self.config, - self.logdir, True) + self.env_name, 1, self.config, self.logdir, True) for _ in range(config["num_agents"])] self.start_time = time.time()