Files
ray/python/ray/rllib/models/catalog.py
T
Eric Liang b6a18cb39b [rllib] Also refactor DQN to use shared RLlib models (#730)
* wip

* works with cartpole

* lint

* fix pg

* comment

* action dist rename

* preprocessor

* fix test

* typo

* fix the action[0] nonsense

* revert

* satisfy the lint

* wip

* works with cartpole

* lint

* fix pg

* comment

* action dist rename

* preprocessor

* fix test

* typo

* fix the action[0] nonsense

* revert

* satisfy the lint

* Minor indentation changes.

* fix merge

* add humanoid

* initial dqn refactor

* remove tfutil

* fix calls

* fix tf errors 1

* closer

* runs now

* lint

* tensorboard graph

* fix linting

* more 4 space

* fix

* fix linT

* more lint

* oops

* es parity

* remove example.py

* fix training bug

* add cartpole demo

* try fixing cartpole

* allow model options, configure cartpole

* debug

* simplify

* no dueling

* avoid out of file handles

* Test dqn in jenkins.

* Minor formatting.

* fix issue

* fix another

* Fix problem in which we log to a directory that hasn't been created.
2017-07-26 12:29:00 -07:00

81 lines
2.5 KiB
Python

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import gym
from ray.rllib.models.action_dist import (
Categorical, Deterministic, DiagGaussian)
from ray.rllib.models.fcnet import FullyConnectedNetwork
from ray.rllib.models.visionnet import VisionNetwork
class ModelCatalog(object):
    """Registry of default models and action distributions for envs.

    All entry points are static methods, so the class is used as a
    namespace and never needs to be instantiated.

    Example:
        dist_class, dist_dim = ModelCatalog.get_action_dist(env.action_space)
        model = ModelCatalog.get_model(inputs, dist_dim)
        dist = dist_class(model.outputs)
        action_op = dist.sample()
    """

    @staticmethod
    def get_action_dist(action_space, dist_type=None):
        """Returns the action distribution class and size for a space.

        Args:
            action_space (Space): Action space of the target gym env.
            dist_type (str): Optional distribution selector. For Box
                spaces, None selects DiagGaussian and 'deterministic'
                selects Deterministic. It is not consulted for Discrete
                spaces, which always map to Categorical.

        Returns:
            dist_class (ActionDistribution): Python class of the distribution.
            dist_dim (int): The size of the input vector to the distribution.

        Raises:
            NotImplementedError: If the space is neither Box nor Discrete,
                or if dist_type is not recognized for a Box space.
        """
        if isinstance(action_space, gym.spaces.Box):
            # Only the first dimension of the Box shape is used for sizing.
            if dist_type is None:
                # DiagGaussian takes twice as many inputs as that dimension.
                return DiagGaussian, action_space.shape[0] * 2
            elif dist_type == 'deterministic':
                return Deterministic, action_space.shape[0]
            # Any other dist_type falls through to the error below.
        elif isinstance(action_space, gym.spaces.Discrete):
            return Categorical, action_space.n

        raise NotImplementedError(
            "Unsupported args: {} {}".format(action_space, dist_type))

    @staticmethod
    def get_model(inputs, num_outputs, options=None):
        """Returns a suitable model conforming to given input and output specs.

        A VisionNetwork is built when the input has more than one dimension
        beyond its leading one; otherwise a FullyConnectedNetwork is built.

        Args:
            inputs (Tensor): The input tensor to the model.
            num_outputs (int): The size of the output vector of the model.
            options (dict): Optional args to pass to the model constructor.
                None is replaced by a fresh empty dict on every call.

        Returns:
            model (Model): Neural network model.
        """
        if options is None:
            options = {}

        # Rank of the input minus its leading dimension (presumably the
        # batch dimension -- confirm against callers).
        obs_rank = len(inputs.get_shape()) - 1

        if obs_rank > 1:
            return VisionNetwork(inputs, num_outputs, options)

        return FullyConnectedNetwork(inputs, num_outputs, options)

    @staticmethod
    def get_preprocessor(env_name):
        """Returns a suitable processor for the given environment.

        Not implemented yet: every call raises NotImplementedError.

        Args:
            env_name (str): The name of the environment.

        Returns:
            preprocessor (Preprocessor): Preprocessor for the env observations.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError