[rllib] Add downscale and frameskip options for Montezumas (#908)

* up

* update

* fix

* update

* update

* update

* api break

* Update run_multi_node_tests.sh

* fix
This commit is contained in:
Eric Liang
2017-09-02 17:20:56 -07:00
committed by Philipp Moritz
parent 7a36430399
commit 1ebfe9608f
12 changed files with 88 additions and 22 deletions
+2 -2
View File
@@ -26,7 +26,7 @@ from ray.rllib.dqn.replay_buffer import ReplayBuffer, PrioritizedReplayBuffer
whether to use double dqn
hiddens: array<int>
hidden layer sizes of the state and action value networks
model_config: dict
model: dict
config options to pass to the model constructor
lr: float
learning rate for adam optimizer
@@ -79,7 +79,7 @@ DEFAULT_CONFIG = dict(
dueling=True,
double_q=True,
hiddens=[256],
model_config={},
model={},
lr=5e-4,
schedule_max_timesteps=100000,
timesteps_per_iteration=1000,
+1 -1
View File
@@ -11,7 +11,7 @@ from ray.rllib.models import ModelCatalog
def _build_q_network(inputs, num_actions, config):
dueling = config["dueling"]
hiddens = config["hiddens"]
frontend = ModelCatalog.get_model(inputs, 1, config["model_config"])
frontend = ModelCatalog.get_model(inputs, 1, config["model"])
frontend_out = frontend.last_layer
with tf.variable_scope("action_value"):
+20 -4
View File
@@ -12,6 +12,16 @@ from ray.rllib.models.fcnet import FullyConnectedNetwork
from ray.rllib.models.visionnet import VisionNetwork
MODEL_CONFIGS = [
"conv_filters",
"downscale_factor",
"extra_frameskip",
"fcnet_activation",
"fcnet_hiddens",
"free_log_std"
]
class ModelCatalog(object):
"""Registry of default models and action distributions for envs.
@@ -67,7 +77,7 @@ class ModelCatalog(object):
return FullyConnectedNetwork(inputs, num_outputs, options)
@staticmethod
def get_preprocessor(env_name, obs_shape):
def get_preprocessor(env_name, obs_shape, options=dict()):
"""Returns a suitable processor for the given environment.
Args:
@@ -81,12 +91,18 @@ class ModelCatalog(object):
ATARI_OBS_SHAPE = (210, 160, 3)
ATARI_RAM_OBS_SHAPE = (128,)
for k in options.keys():
if k not in MODEL_CONFIGS:
raise Exception(
"Unknown config key `{}`, all keys: {}".format(
k, MODEL_CONFIGS))
if obs_shape == ATARI_OBS_SHAPE:
print("Assuming Atari pixel env, using AtariPixelPreprocessor.")
return AtariPixelPreprocessor()
return AtariPixelPreprocessor(options)
elif obs_shape == ATARI_RAM_OBS_SHAPE:
print("Assuming Atari ram env, using AtariRamPreprocessor.")
return AtariRamPreprocessor()
return AtariRamPreprocessor(options)
print("Non-atari env, not using any observation preprocessor.")
return NoPreprocessor()
return NoPreprocessor(options)
+16 -3
View File
@@ -6,6 +6,13 @@ from __future__ import print_function
class Preprocessor(object):
"""Defines an abstract observation preprocessor function."""
def __init__(self, options):
self.options = options
self._init()
def _init(self):
pass
def transform_shape(self, obs_shape):
"""Returns the preprocessed observation shape."""
raise NotImplementedError
@@ -16,13 +23,19 @@ class Preprocessor(object):
class AtariPixelPreprocessor(Preprocessor):
def _init(self):
self.downscale_factor = self.options.get("downscale_factor", 2)
self.dim = int(160 / self.downscale_factor)
def transform_shape(self, obs_shape):
return (80, 80, 3)
return (self.dim, self.dim, 3)
# TODO(ekl) why does this need to return an extra size-1 dim (the [None])
def transform(self, observation):
"""Downsamples images from (210, 160, 3) to (80, 80, 3)."""
return (observation[25:-25:2, ::2, :][None] - 128) / 128
"""Downsamples images from (210, 160, 3) by the configured factor."""
scaled = observation[
25:-25:self.downscale_factor, ::self.downscale_factor, :][None]
return (scaled - 128) / 128
class AtariRamPreprocessor(Preprocessor):
+11 -3
View File
@@ -12,11 +12,19 @@ class VisionNetwork(Model):
"""Generic vision network."""
def _init(self, inputs, num_outputs, options):
filters = options.get("conv_filters", [
[16, [8, 8], 4],
[32, [4, 4], 2],
[512, [10, 10], 1],
])
with tf.name_scope("vision_net"):
conv1 = slim.conv2d(inputs, 16, [8, 8], 4, scope="conv1")
conv2 = slim.conv2d(conv1, 32, [4, 4], 2, scope="conv2")
for i, (out_size, kernel, stride) in enumerate(filters[:-1], 1):
inputs = slim.conv2d(
inputs, out_size, kernel, stride,
scope="conv{}".format(i))
out_size, kernel, stride = filters[-1]
fc1 = slim.conv2d(
conv2, 512, [10, 10], padding="VALID", scope="fc1")
inputs, out_size, kernel, stride, padding="VALID", scope="fc1")
fc2 = slim.conv2d(fc1, num_outputs, [1, 1], activation_fn=None,
normalizer_fn=None, scope="fc2")
return tf.squeeze(fc2, [1, 2]), tf.squeeze(fc1, [1, 2])
+10 -3
View File
@@ -10,13 +10,15 @@ from ray.rllib.models import ModelCatalog
class BatchedEnv(object):
"""This holds multiple gym envs and performs steps on all of them."""
def __init__(self, name, batchsize):
def __init__(self, name, batchsize, options):
self.envs = [gym.make(name) for _ in range(batchsize)]
self.observation_space = self.envs[0].observation_space
self.action_space = self.envs[0].action_space
self.batchsize = batchsize
self.preprocessor = ModelCatalog.get_preprocessor(
name, self.envs[0].observation_space.shape)
name, self.envs[0].observation_space.shape, options["model"])
self.extra_frameskip = options.get("extra_frameskip", 1)
assert self.extra_frameskip >= 1
def reset(self):
observations = [
@@ -33,7 +35,12 @@ class BatchedEnv(object):
observations.append(np.zeros(self.shape))
rewards.append(0.0)
continue
observation, reward, done, info = self.envs[i].step(action)
reward = 0.0
for j in range(self.extra_frameskip):
observation, r, done, info = self.envs[i].step(action)
reward += r
if done:
break
if render:
self.envs[0].render()
observations.append(self.preprocessor.transform(observation))
+3
View File
@@ -53,7 +53,10 @@ DEFAULT_CONFIG = {
"clip_param": 0.3,
# Target value for KL divergence
"kl_target": 0.01,
# Config params to pass to the model
"model": {"free_log_std": False},
# If >1, adds frameskip
"extra_frameskip": 1,
# Number of timesteps collected in each outer loop
"timesteps_per_batch": 40000,
# Each tasks performs rollouts until at least this
+1 -1
View File
@@ -46,7 +46,7 @@ class Runner(object):
self.devices = devices
self.config = config
self.logdir = logdir
self.env = BatchedEnv(name, batchsize)
self.env = BatchedEnv(name, batchsize, config)
if is_remote:
config_proto = tf.ConfigProto()
else:
+12 -4
View File
@@ -41,25 +41,33 @@ if __name__ == "__main__":
ray.init(redis_address=args.redis_address)
def _check_and_update(config, json):
for k in json.keys():
if k not in config:
raise Exception(
"Unknown model config `{}`, all model configs: {}".format(
k, config.keys()))
config.update(json)
env_name = args.env
if args.alg == "PPO":
config = ppo.DEFAULT_CONFIG.copy()
config.update(json_config)
_check_and_update(config, json_config)
alg = ppo.PPOAgent(
env_name, config, upload_dir=args.upload_dir)
elif args.alg == "ES":
config = es.DEFAULT_CONFIG.copy()
config.update(json_config)
_check_and_update(config, json_config)
alg = es.ESAgent(
env_name, config, upload_dir=args.upload_dir)
elif args.alg == "DQN":
config = dqn.DEFAULT_CONFIG.copy()
config.update(json_config)
_check_and_update(config, json_config)
alg = dqn.DQNAgent(
env_name, config, upload_dir=args.upload_dir)
elif args.alg == "A3C":
config = a3c.DEFAULT_CONFIG.copy()
config.update(json_config)
_check_and_update(config, json_config)
alg = a3c.A3CAgent(
env_name, config, upload_dir=args.upload_dir)
else: