[rllib] Pick preprocessor based on obs shape (#855)

* update

* auto choose
This commit is contained in:
Eric Liang
2017-08-23 01:46:55 +02:00
committed by Philipp Moritz
parent 58d06e3f4d
commit e2f2a7e57a
5 changed files with 27 additions and 30 deletions
@@ -73,7 +73,8 @@ class Worker(object):
self.noise = SharedNoiseTable(noise)
self.env = gym.make(env_name)
self.preprocessor = ModelCatalog.get_preprocessor(env_name)
self.preprocessor = ModelCatalog.get_preprocessor(
env_name, self.env.observation_space.shape)
self.preprocessor_shape = self.preprocessor.transform_shape(
self.env.observation_space.shape)
@@ -167,7 +168,8 @@ class EvolutionStrategies(Algorithm):
}
env = gym.make(env_name)
preprocessor = ModelCatalog.get_preprocessor(env_name)
preprocessor = ModelCatalog.get_preprocessor(
env_name, env.observation_space.shape)
preprocessor_shape = preprocessor.transform_shape(
env.observation_space.shape)
+12 -13
View File
@@ -75,27 +75,26 @@ class ModelCatalog(object):
return ConvolutionalNetwork(inputs, num_outputs, options)
@staticmethod
def get_preprocessor(env_name):
def get_preprocessor(env_name, obs_shape):
"""Returns a suitable processor for the given environment.
Args:
env_name (str): The name of the environment.
obs_shape (tuple): The shape of the env observation space.
Returns:
preprocessor (Preprocessor): Preprocessor for the env observations.
"""
if env_name == "Pong-v0":
ATARI_OBS_SHAPE = (210, 160, 3)
ATARI_RAM_OBS_SHAPE = (128,)
if obs_shape == ATARI_OBS_SHAPE:
print("Assuming Atari pixel env, using AtariPixelPreprocessor.")
return AtariPixelPreprocessor()
elif env_name == "Pong-ram-v3":
elif obs_shape == ATARI_RAM_OBS_SHAPE:
print("Assuming Atari ram env, using AtariRamPreprocessor.")
return AtariRamPreprocessor()
elif env_name == "CartPole-v0" or env_name == "CartPole-v1":
return NoPreprocessor()
elif env_name == "Hopper-v1":
return NoPreprocessor()
elif env_name == "Walker2d-v1":
return NoPreprocessor()
elif env_name == "Humanoid-v1" or env_name == "Pendulum-v0":
return NoPreprocessor()
else:
return AtariPixelPreprocessor()
print("Non-atari env, not using any observation preprocessor.")
return NoPreprocessor()
+4 -5
View File
@@ -35,8 +35,7 @@ class Agent(object):
network weights. When run as a remote agent, only this graph is used.
"""
def __init__(
self, name, batchsize, preprocessor, config, logdir, is_remote):
def __init__(self, name, batchsize, config, logdir, is_remote):
if is_remote:
os.environ["CUDA_VISIBLE_DEVICES"] = ""
devices = ["/cpu:0"]
@@ -45,12 +44,12 @@ class Agent(object):
self.devices = devices
self.config = config
self.logdir = logdir
self.env = BatchedEnv(name, batchsize, preprocessor=preprocessor)
self.env = BatchedEnv(name, batchsize)
if is_remote:
config_proto = tf.ConfigProto()
else:
config_proto = tf.ConfigProto(**config["tf_session_args"])
self.preprocessor = preprocessor
self.preprocessor = self.env.preprocessor
self.sess = tf.Session(config=config_proto)
if config["use_tf_debugger"] and not is_remote:
self.sess = tf_debug.LocalCLIDebugWrapperSession(self.sess)
@@ -61,7 +60,7 @@ class Agent(object):
self.kl_coeff = tf.placeholder(
name="newkl", shape=(), dtype=tf.float32)
self.preprocessor_shape = preprocessor.transform_shape(
self.preprocessor_shape = self.preprocessor.transform_shape(
self.env.observation_space.shape)
self.observations = tf.placeholder(
tf.float32, shape=(None,) + self.preprocessor_shape)
+5 -3
View File
@@ -5,16 +5,18 @@ from __future__ import print_function
import gym
import numpy as np
from ray.rllib.models import ModelCatalog
class BatchedEnv(object):
"""This holds multiple gym envs and performs steps on all of them."""
def __init__(self, name, batchsize, preprocessor=None):
def __init__(self, name, batchsize):
self.envs = [gym.make(name) for _ in range(batchsize)]
self.observation_space = self.envs[0].observation_space
self.action_space = self.envs[0].action_space
self.batchsize = batchsize
self.preprocessor = (preprocessor if preprocessor
else lambda obs: obs[None])
self.preprocessor = ModelCatalog.get_preprocessor(
name, self.envs[0].observation_space.shape)
def reset(self):
observations = [
@@ -10,7 +10,6 @@ import tensorflow as tf
import ray
from ray.rllib.common import Algorithm, TrainingResult
from ray.rllib.models import ModelCatalog
from ray.rllib.policy_gradient.agent import Agent, RemoteAgent
from ray.rllib.policy_gradient.rollout import collect_samples
from ray.rllib.policy_gradient.utils import shuffle
@@ -74,17 +73,13 @@ class PolicyGradient(Algorithm):
Algorithm.__init__(self, env_name, config, upload_dir=upload_dir)
self.preprocessor = ModelCatalog.get_preprocessor(self.env_name)
self.global_step = 0
self.j = 0
self.kl_coeff = config["kl_coeff"]
self.model = Agent(
self.env_name, 1, self.preprocessor, self.config, self.logdir,
False)
self.model = Agent(self.env_name, 1, self.config, self.logdir, False)
self.agents = [
RemoteAgent.remote(
self.env_name, 1, self.preprocessor, self.config,
self.logdir, True)
self.env_name, 1, self.config, self.logdir, True)
for _ in range(config["num_agents"])]
self.start_time = time.time()