mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 21:23:10 +08:00
[rllib] Pick preprocessor based on obs shape (#855)
* update * auto choose
This commit is contained in:
committed by
Philipp Moritz
parent
58d06e3f4d
commit
e2f2a7e57a
@@ -73,7 +73,8 @@ class Worker(object):
|
||||
self.noise = SharedNoiseTable(noise)
|
||||
|
||||
self.env = gym.make(env_name)
|
||||
self.preprocessor = ModelCatalog.get_preprocessor(env_name)
|
||||
self.preprocessor = ModelCatalog.get_preprocessor(
|
||||
env_name, self.env.observation_space.shape)
|
||||
self.preprocessor_shape = self.preprocessor.transform_shape(
|
||||
self.env.observation_space.shape)
|
||||
|
||||
@@ -167,7 +168,8 @@ class EvolutionStrategies(Algorithm):
|
||||
}
|
||||
|
||||
env = gym.make(env_name)
|
||||
preprocessor = ModelCatalog.get_preprocessor(env_name)
|
||||
preprocessor = ModelCatalog.get_preprocessor(
|
||||
env_name, env.observation_space.shape)
|
||||
preprocessor_shape = preprocessor.transform_shape(
|
||||
env.observation_space.shape)
|
||||
|
||||
|
||||
@@ -75,27 +75,26 @@ class ModelCatalog(object):
|
||||
return ConvolutionalNetwork(inputs, num_outputs, options)
|
||||
|
||||
@staticmethod
|
||||
def get_preprocessor(env_name):
|
||||
def get_preprocessor(env_name, obs_shape):
|
||||
"""Returns a suitable processor for the given environment.
|
||||
|
||||
Args:
|
||||
env_name (str): The name of the environment.
|
||||
obs_shape (tuple): The shape of the env observation space.
|
||||
|
||||
Returns:
|
||||
preprocessor (Preprocessor): Preprocessor for the env observations.
|
||||
"""
|
||||
|
||||
if env_name == "Pong-v0":
|
||||
ATARI_OBS_SHAPE = (210, 160, 3)
|
||||
ATARI_RAM_OBS_SHAPE = (128,)
|
||||
|
||||
if obs_shape == ATARI_OBS_SHAPE:
|
||||
print("Assuming Atari pixel env, using AtariPixelPreprocessor.")
|
||||
return AtariPixelPreprocessor()
|
||||
elif env_name == "Pong-ram-v3":
|
||||
elif obs_shape == ATARI_RAM_OBS_SHAPE:
|
||||
print("Assuming Atari ram env, using AtariRamPreprocessor.")
|
||||
return AtariRamPreprocessor()
|
||||
elif env_name == "CartPole-v0" or env_name == "CartPole-v1":
|
||||
return NoPreprocessor()
|
||||
elif env_name == "Hopper-v1":
|
||||
return NoPreprocessor()
|
||||
elif env_name == "Walker2d-v1":
|
||||
return NoPreprocessor()
|
||||
elif env_name == "Humanoid-v1" or env_name == "Pendulum-v0":
|
||||
return NoPreprocessor()
|
||||
else:
|
||||
return AtariPixelPreprocessor()
|
||||
|
||||
print("Non-atari env, not using any observation preprocessor.")
|
||||
return NoPreprocessor()
|
||||
|
||||
@@ -35,8 +35,7 @@ class Agent(object):
|
||||
network weights. When run as a remote agent, only this graph is used.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, name, batchsize, preprocessor, config, logdir, is_remote):
|
||||
def __init__(self, name, batchsize, config, logdir, is_remote):
|
||||
if is_remote:
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = ""
|
||||
devices = ["/cpu:0"]
|
||||
@@ -45,12 +44,12 @@ class Agent(object):
|
||||
self.devices = devices
|
||||
self.config = config
|
||||
self.logdir = logdir
|
||||
self.env = BatchedEnv(name, batchsize, preprocessor=preprocessor)
|
||||
self.env = BatchedEnv(name, batchsize)
|
||||
if is_remote:
|
||||
config_proto = tf.ConfigProto()
|
||||
else:
|
||||
config_proto = tf.ConfigProto(**config["tf_session_args"])
|
||||
self.preprocessor = preprocessor
|
||||
self.preprocessor = self.env.preprocessor
|
||||
self.sess = tf.Session(config=config_proto)
|
||||
if config["use_tf_debugger"] and not is_remote:
|
||||
self.sess = tf_debug.LocalCLIDebugWrapperSession(self.sess)
|
||||
@@ -61,7 +60,7 @@ class Agent(object):
|
||||
self.kl_coeff = tf.placeholder(
|
||||
name="newkl", shape=(), dtype=tf.float32)
|
||||
|
||||
self.preprocessor_shape = preprocessor.transform_shape(
|
||||
self.preprocessor_shape = self.preprocessor.transform_shape(
|
||||
self.env.observation_space.shape)
|
||||
self.observations = tf.placeholder(
|
||||
tf.float32, shape=(None,) + self.preprocessor_shape)
|
||||
|
||||
@@ -5,16 +5,18 @@ from __future__ import print_function
|
||||
import gym
|
||||
import numpy as np
|
||||
|
||||
from ray.rllib.models import ModelCatalog
|
||||
|
||||
|
||||
class BatchedEnv(object):
|
||||
"""This holds multiple gym envs and performs steps on all of them."""
|
||||
def __init__(self, name, batchsize, preprocessor=None):
|
||||
def __init__(self, name, batchsize):
|
||||
self.envs = [gym.make(name) for _ in range(batchsize)]
|
||||
self.observation_space = self.envs[0].observation_space
|
||||
self.action_space = self.envs[0].action_space
|
||||
self.batchsize = batchsize
|
||||
self.preprocessor = (preprocessor if preprocessor
|
||||
else lambda obs: obs[None])
|
||||
self.preprocessor = ModelCatalog.get_preprocessor(
|
||||
name, self.envs[0].observation_space.shape)
|
||||
|
||||
def reset(self):
|
||||
observations = [
|
||||
|
||||
@@ -10,7 +10,6 @@ import tensorflow as tf
|
||||
|
||||
import ray
|
||||
from ray.rllib.common import Algorithm, TrainingResult
|
||||
from ray.rllib.models import ModelCatalog
|
||||
from ray.rllib.policy_gradient.agent import Agent, RemoteAgent
|
||||
from ray.rllib.policy_gradient.rollout import collect_samples
|
||||
from ray.rllib.policy_gradient.utils import shuffle
|
||||
@@ -74,17 +73,13 @@ class PolicyGradient(Algorithm):
|
||||
|
||||
Algorithm.__init__(self, env_name, config, upload_dir=upload_dir)
|
||||
|
||||
self.preprocessor = ModelCatalog.get_preprocessor(self.env_name)
|
||||
self.global_step = 0
|
||||
self.j = 0
|
||||
self.kl_coeff = config["kl_coeff"]
|
||||
self.model = Agent(
|
||||
self.env_name, 1, self.preprocessor, self.config, self.logdir,
|
||||
False)
|
||||
self.model = Agent(self.env_name, 1, self.config, self.logdir, False)
|
||||
self.agents = [
|
||||
RemoteAgent.remote(
|
||||
self.env_name, 1, self.preprocessor, self.config,
|
||||
self.logdir, True)
|
||||
self.env_name, 1, self.config, self.logdir, True)
|
||||
for _ in range(config["num_agents"])]
|
||||
self.start_time = time.time()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user