mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 21:38:18 +08:00
[rllib] Add downscale and frameskip options for Montezumas (#908)
* up * update * fix * update * update * update * api break * Update run_multi_node_tests.sh * fix
This commit is contained in:
committed by
Philipp Moritz
parent
7a36430399
commit
1ebfe9608f
@@ -26,7 +26,7 @@ from ray.rllib.dqn.replay_buffer import ReplayBuffer, PrioritizedReplayBuffer
|
||||
whether to use double dqn
|
||||
hiddens: array<int>
|
||||
hidden layer sizes of the state and action value networks
|
||||
model_config: dict
|
||||
model: dict
|
||||
config options to pass to the model constructor
|
||||
lr: float
|
||||
learning rate for adam optimizer
|
||||
@@ -79,7 +79,7 @@ DEFAULT_CONFIG = dict(
|
||||
dueling=True,
|
||||
double_q=True,
|
||||
hiddens=[256],
|
||||
model_config={},
|
||||
model={},
|
||||
lr=5e-4,
|
||||
schedule_max_timesteps=100000,
|
||||
timesteps_per_iteration=1000,
|
||||
|
||||
@@ -11,7 +11,7 @@ from ray.rllib.models import ModelCatalog
|
||||
def _build_q_network(inputs, num_actions, config):
|
||||
dueling = config["dueling"]
|
||||
hiddens = config["hiddens"]
|
||||
frontend = ModelCatalog.get_model(inputs, 1, config["model_config"])
|
||||
frontend = ModelCatalog.get_model(inputs, 1, config["model"])
|
||||
frontend_out = frontend.last_layer
|
||||
|
||||
with tf.variable_scope("action_value"):
|
||||
|
||||
@@ -12,6 +12,16 @@ from ray.rllib.models.fcnet import FullyConnectedNetwork
|
||||
from ray.rllib.models.visionnet import VisionNetwork
|
||||
|
||||
|
||||
# Option keys accepted in a model config dict. ModelCatalog.get_preprocessor
# raises on any key that is not listed here.
MODEL_CONFIGS = [
    "conv_filters",  # per-layer conv specs read by VisionNetwork
    "downscale_factor",  # subsampling stride read by AtariPixelPreprocessor
    "extra_frameskip",
    "fcnet_activation",
    "fcnet_hiddens",
    "free_log_std",
]
|
||||
|
||||
|
||||
class ModelCatalog(object):
|
||||
"""Registry of default models and action distributions for envs.
|
||||
|
||||
@@ -67,7 +77,7 @@ class ModelCatalog(object):
|
||||
return FullyConnectedNetwork(inputs, num_outputs, options)
|
||||
|
||||
@staticmethod
|
||||
def get_preprocessor(env_name, obs_shape):
|
||||
def get_preprocessor(env_name, obs_shape, options=dict()):
|
||||
"""Returns a suitable processor for the given environment.
|
||||
|
||||
Args:
|
||||
@@ -81,12 +91,18 @@ class ModelCatalog(object):
|
||||
ATARI_OBS_SHAPE = (210, 160, 3)
|
||||
ATARI_RAM_OBS_SHAPE = (128,)
|
||||
|
||||
for k in options.keys():
|
||||
if k not in MODEL_CONFIGS:
|
||||
raise Exception(
|
||||
"Unknown config key `{}`, all keys: {}".format(
|
||||
k, MODEL_CONFIGS))
|
||||
|
||||
if obs_shape == ATARI_OBS_SHAPE:
|
||||
print("Assuming Atari pixel env, using AtariPixelPreprocessor.")
|
||||
return AtariPixelPreprocessor()
|
||||
return AtariPixelPreprocessor(options)
|
||||
elif obs_shape == ATARI_RAM_OBS_SHAPE:
|
||||
print("Assuming Atari ram env, using AtariRamPreprocessor.")
|
||||
return AtariRamPreprocessor()
|
||||
return AtariRamPreprocessor(options)
|
||||
|
||||
print("Non-atari env, not using any observation preprocessor.")
|
||||
return NoPreprocessor()
|
||||
return NoPreprocessor(options)
|
||||
|
||||
@@ -6,6 +6,13 @@ from __future__ import print_function
|
||||
class Preprocessor(object):
|
||||
"""Defines an abstract observation preprocessor function."""
|
||||
|
||||
def __init__(self, options):
    """Keeps the given options and then runs the `_init` hook.

    Args:
        options: option mapping stored as `self.options` before
            `_init` is invoked, so the hook can read it.
    """
    self.options = options
    self._init()
|
||||
|
||||
def _init(self):
    """Setup hook called by `__init__`; the base version does nothing."""
    return None
|
||||
|
||||
def transform_shape(self, obs_shape):
|
||||
"""Returns the preprocessed observation shape."""
|
||||
raise NotImplementedError
|
||||
@@ -16,13 +23,19 @@ class Preprocessor(object):
|
||||
|
||||
|
||||
class AtariPixelPreprocessor(Preprocessor):
    """Crops, subsamples and rescales Atari pixel observations.

    Reads the optional `downscale_factor` option (default 2), which is the
    stride used to subsample both image axes.
    """

    def _init(self):
        """Validates `downscale_factor` and derives the output side length.

        Raises:
            ValueError: if `downscale_factor` is not a positive whole number.
        """
        factor = self.options.get("downscale_factor", 2)
        # The factor is used as a slice step in transform(), so it has to be
        # a positive whole number; reject anything else up front instead of
        # failing later inside the slicing expression.
        if factor != int(factor) or factor < 1:
            raise ValueError(
                "downscale_factor must be a positive integer, got {!r}".format(
                    factor))
        self.downscale_factor = int(factor)
        # Both axes are 160 wide after the crop in transform(), and a strided
        # slice keeps ceil(160 / factor) entries. Flooring here would report
        # a shape one smaller than what transform() produces whenever the
        # factor does not divide 160 (e.g. factor 3: 54 entries, not 53).
        self.dim = -(-160 // self.downscale_factor)

    def transform_shape(self, obs_shape):
        """Returns the preprocessed observation shape."""
        return (self.dim, self.dim, 3)

    # TODO(ekl) why does this need to return an extra size-1 dim (the [None])
    def transform(self, observation):
        """Downsamples images from (210, 160, 3) by the configured factor."""
        step = self.downscale_factor
        # Drop 25 rows from the top and bottom, take every `step`-th row and
        # column, and prepend a size-1 axis.
        scaled = observation[25:-25:step, ::step, :][None]
        return (scaled - 128) / 128
|
||||
|
||||
|
||||
class AtariRamPreprocessor(Preprocessor):
|
||||
|
||||
@@ -12,11 +12,19 @@ class VisionNetwork(Model):
|
||||
"""Generic vision network."""
|
||||
|
||||
def _init(self, inputs, num_outputs, options):
    """Builds the convolutional stack and returns its two output tensors.

    Args:
        inputs: input tensor fed to the first convolution.
        num_outputs: number of output channels of the final 1x1 convolution.
        options: model options. `conv_filters`, when present, is a list of
            `[out_size, kernel, stride]` entries, one per layer. The last
            entry is built with VALID padding and named "fc1".

    Returns:
        A tuple `(fc2, fc1)` with axes 1 and 2 squeezed out of each tensor.

    Raises:
        ValueError: if `conv_filters` is empty.
    """
    filters = options.get("conv_filters", [
        [16, [8, 8], 4],
        [32, [4, 4], 2],
        [512, [10, 10], 1],
    ])
    # At least one entry is needed for the final "fc1" layer; fail with a
    # clear message before any graph ops are created.
    if not filters:
        raise ValueError(
            "conv_filters must contain at least one [out_size, kernel, "
            "stride] entry")
    with tf.name_scope("vision_net"):
        # Every entry except the last is an ordinary conv layer, named
        # conv1, conv2, ... in order.
        for i, (out_size, kernel, stride) in enumerate(filters[:-1], 1):
            inputs = slim.conv2d(
                inputs, out_size, kernel, stride,
                scope="conv{}".format(i))
        out_size, kernel, stride = filters[-1]
        # NOTE(review): the squeeze below needs fc1 to be 1x1 spatially, so
        # the last kernel presumably has to match the remaining feature map;
        # the default 10x10 looks sized for 80x80 inputs -- confirm when
        # using a non-default downscale_factor.
        fc1 = slim.conv2d(
            inputs, out_size, kernel, stride, padding="VALID", scope="fc1")
        fc2 = slim.conv2d(fc1, num_outputs, [1, 1], activation_fn=None,
                          normalizer_fn=None, scope="fc2")
        return tf.squeeze(fc2, [1, 2]), tf.squeeze(fc1, [1, 2])
|
||||
|
||||
@@ -10,13 +10,15 @@ from ray.rllib.models import ModelCatalog
|
||||
|
||||
class BatchedEnv(object):
|
||||
"""This holds multiple gym envs and performs steps on all of them."""
|
||||
def __init__(self, name, batchsize, options):
    """Creates `batchsize` copies of the gym env `name` plus a preprocessor.

    Args:
        name: gym environment id passed to `gym.make`.
        batchsize: number of environment copies to create.
        options: agent config mapping. `options["model"]` is forwarded to
            `ModelCatalog.get_preprocessor`; `extra_frameskip` (default 1)
            is stored on the instance.

    Raises:
        ValueError: if `extra_frameskip` is smaller than 1.
    """
    # Validate before creating any environments so a bad config fails fast.
    # An explicit raise is used because asserts are stripped under -O.
    extra_frameskip = options.get("extra_frameskip", 1)
    if extra_frameskip < 1:
        raise ValueError(
            "extra_frameskip must be >= 1, got {!r}".format(extra_frameskip))
    self.extra_frameskip = extra_frameskip
    self.envs = [gym.make(name) for _ in range(batchsize)]
    first_env = self.envs[0]
    self.observation_space = first_env.observation_space
    self.action_space = first_env.action_space
    self.batchsize = batchsize
    self.preprocessor = ModelCatalog.get_preprocessor(
        name, first_env.observation_space.shape, options["model"])
|
||||
|
||||
def reset(self):
|
||||
observations = [
|
||||
@@ -33,7 +35,12 @@ class BatchedEnv(object):
|
||||
observations.append(np.zeros(self.shape))
|
||||
rewards.append(0.0)
|
||||
continue
|
||||
observation, reward, done, info = self.envs[i].step(action)
|
||||
reward = 0.0
|
||||
for j in range(self.extra_frameskip):
|
||||
observation, r, done, info = self.envs[i].step(action)
|
||||
reward += r
|
||||
if done:
|
||||
break
|
||||
if render:
|
||||
self.envs[0].render()
|
||||
observations.append(self.preprocessor.transform(observation))
|
||||
|
||||
@@ -53,7 +53,10 @@ DEFAULT_CONFIG = {
|
||||
"clip_param": 0.3,
|
||||
# Target value for KL divergence
|
||||
"kl_target": 0.01,
|
||||
# Config params to pass to the model
|
||||
"model": {"free_log_std": False},
|
||||
# If >1, adds frameskip
|
||||
"extra_frameskip": 1,
|
||||
# Number of timesteps collected in each outer loop
|
||||
"timesteps_per_batch": 40000,
|
||||
# Each tasks performs rollouts until at least this
|
||||
|
||||
@@ -46,7 +46,7 @@ class Runner(object):
|
||||
self.devices = devices
|
||||
self.config = config
|
||||
self.logdir = logdir
|
||||
self.env = BatchedEnv(name, batchsize)
|
||||
self.env = BatchedEnv(name, batchsize, config)
|
||||
if is_remote:
|
||||
config_proto = tf.ConfigProto()
|
||||
else:
|
||||
|
||||
@@ -41,25 +41,33 @@ if __name__ == "__main__":
|
||||
|
||||
ray.init(redis_address=args.redis_address)
|
||||
|
||||
def _check_and_update(config, json):
    """Merges `json` into `config` after checking every key is known.

    All keys are checked before anything is written, so `config` is left
    untouched when an unknown key is found.

    Args:
        config: dict of known settings; updated in place.
        json: dict of overrides whose keys must all exist in `config`.

    Raises:
        Exception: on the first key of `json` that is absent from `config`.
    """
    for key in json:
        if key in config:
            continue
        raise Exception(
            "Unknown model config `{}`, all model configs: {}".format(
                key, config.keys()))
    config.update(json)
|
||||
|
||||
env_name = args.env
|
||||
if args.alg == "PPO":
|
||||
config = ppo.DEFAULT_CONFIG.copy()
|
||||
config.update(json_config)
|
||||
_check_and_update(config, json_config)
|
||||
alg = ppo.PPOAgent(
|
||||
env_name, config, upload_dir=args.upload_dir)
|
||||
elif args.alg == "ES":
|
||||
config = es.DEFAULT_CONFIG.copy()
|
||||
config.update(json_config)
|
||||
_check_and_update(config, json_config)
|
||||
alg = es.ESAgent(
|
||||
env_name, config, upload_dir=args.upload_dir)
|
||||
elif args.alg == "DQN":
|
||||
config = dqn.DEFAULT_CONFIG.copy()
|
||||
config.update(json_config)
|
||||
_check_and_update(config, json_config)
|
||||
alg = dqn.DQNAgent(
|
||||
env_name, config, upload_dir=args.upload_dir)
|
||||
elif args.alg == "A3C":
|
||||
config = a3c.DEFAULT_CONFIG.copy()
|
||||
config.update(json_config)
|
||||
_check_and_update(config, json_config)
|
||||
alg = a3c.A3CAgent(
|
||||
env_name, config, upload_dir=args.upload_dir)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user