Files
ray/python/ray/rllib/models/catalog.py
T
Richard Liaw bc082e9a9e [rllib] Additional support for Shared Models in A3C (#866)
* Code for Supporting Shared Models

Running (with vnet modification) - needs to be tested for performance

Summaries

Small refactoring + generalized to more domains

Small fix for jenkins

Linting

linting

Addressing changes

Addressing changes

Update envs.py

Addressing changes

convnet

Merge - new model

final touches

final linting

Changing iterations back

removed extra change

changes for fast experimentation

changes to enable a2c

TEMP FOR DEBUGGING

ContinuousActions - Still doesn't work

InvertedPendulum trains with 8 workers - k=200

huber loss

Maxes for InvertedPendulum-v1 - 16w,200steps

temp: working with a2c

Back to shared model

more fixes

small

nit

LSTM to shared models

need to fix last_features

tuning pong

Best record for hitting 0 - with k=16,n=20

nit

a2c removal

remove A2c reference and nits

nit

removed a2c vestiges

removing a2c

removing example.py

Linting

nit

* Linting + Removing vestigial code

* Final Touches

* nits

* rerun travis
2017-08-28 12:23:14 -07:00

93 lines
3.2 KiB
Python

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import gym
from ray.rllib.models.action_dist import (
Categorical, Deterministic, DiagGaussian)
from ray.rllib.models.preprocessors import (
NoPreprocessor, AtariRamPreprocessor, AtariPixelPreprocessor)
from ray.rllib.models.fcnet import FullyConnectedNetwork
from ray.rllib.models.visionnet import VisionNetwork
class ModelCatalog(object):
    """Registry of default models and action distributions for envs.

    Example:
        dist_class, dist_dim = ModelCatalog.get_action_dist(env.action_space)
        model = ModelCatalog.get_model(inputs, dist_dim)
        dist = dist_class(model.outputs)
        action_op = dist.sample()
    """

    @staticmethod
    def get_action_dist(action_space, dist_type=None):
        """Returns action distribution class and size for the given action space.

        Args:
            action_space (Space): Action space of the target gym env.
            dist_type (Optional[str]): Identifier of the action distribution.
                For a Box space, None selects DiagGaussian and
                'deterministic' selects Deterministic. It is ignored for a
                Discrete space.

        Returns:
            dist_class (ActionDistribution): Python class of the distribution.
            dist_dim (int): The size of the input vector to the distribution.

        Raises:
            NotImplementedError: If the combination of action space type and
                dist_type is not supported.
        """
        if isinstance(action_space, gym.spaces.Box):
            if dist_type is None:
                # Two input values per action dimension.
                return DiagGaussian, action_space.shape[0] * 2
            elif dist_type == 'deterministic':
                # One input value per action dimension.
                return Deterministic, action_space.shape[0]
        elif isinstance(action_space, gym.spaces.Discrete):
            # One input value per discrete action.
            return Categorical, action_space.n
        raise NotImplementedError(
            "Unsupported args: {} {}".format(action_space, dist_type))

    @staticmethod
    def get_model(inputs, num_outputs, options=None):
        """Returns a suitable model conforming to given input and output specs.

        Inputs whose per-observation rank is greater than 1 get a
        VisionNetwork. All other inputs get a FullyConnectedNetwork.

        Args:
            inputs (Tensor): The input tensor to the model.
            num_outputs (int): The size of the output vector of the model.
            options (Optional[dict]): Optional args to pass to the model
                constructor. Defaults to a fresh empty dict.

        Returns:
            model (Model): Neural network model.
        """
        # Build a new dict on each call. A shared mutable default would let
        # any mutation made by one model constructor leak into later calls.
        if options is None:
            options = {}
        # Subtract one for the leading dimension (presumably the batch
        # dimension; TODO confirm against callers).
        obs_rank = len(inputs.get_shape()) - 1
        if obs_rank > 1:
            return VisionNetwork(inputs, num_outputs, options)
        return FullyConnectedNetwork(inputs, num_outputs, options)

    @staticmethod
    def get_preprocessor(env_name, obs_shape):
        """Returns a suitable processor for the given environment.

        The choice is made purely from obs_shape. env_name is not consulted.

        Args:
            env_name (str): The name of the environment.
            obs_shape (tuple): The shape of the env observation space.

        Returns:
            preprocessor (Preprocessor): Preprocessor for the env observations.
        """
        # Observation shapes used to recognise Atari pixel and RAM envs.
        ATARI_OBS_SHAPE = (210, 160, 3)
        ATARI_RAM_OBS_SHAPE = (128,)
        if obs_shape == ATARI_OBS_SHAPE:
            print("Assuming Atari pixel env, using AtariPixelPreprocessor.")
            return AtariPixelPreprocessor()
        elif obs_shape == ATARI_RAM_OBS_SHAPE:
            print("Assuming Atari ram env, using AtariRamPreprocessor.")
            return AtariRamPreprocessor()
        print("Non-atari env, not using any observation preprocessor.")
        return NoPreprocessor()