mirror of
https://github.com/wassname/ray.git
synced 2026-06-30 15:55:01 +08:00
[rllib] Include config dicts in the sphinx docs (#3064)
This commit is contained in:
@@ -10,6 +10,7 @@ from ray.rllib.optimizers import AsyncGradientsOptimizer
|
||||
from ray.rllib.utils import merge_dicts
|
||||
from ray.tune.trial import Resources
|
||||
|
||||
# __sphinx_doc_begin__
|
||||
DEFAULT_CONFIG = with_common_config({
|
||||
# Size of rollout batch
|
||||
"sample_batch_size": 10,
|
||||
@@ -34,31 +35,10 @@ DEFAULT_CONFIG = with_common_config({
|
||||
# Workers sample async. Note that this increases the effective
|
||||
# sample_batch_size by up to 5x due to async buffering of batches.
|
||||
"sample_async": True,
|
||||
# Model and preprocessor options
|
||||
"model": {
|
||||
# Use LSTM model. Requires TF.
|
||||
"use_lstm": False,
|
||||
# Max seq length for LSTM training.
|
||||
"max_seq_len": 20,
|
||||
# (Image statespace) - Converts image to Channels = 1
|
||||
"grayscale": True,
|
||||
# (Image statespace) - Each pixel
|
||||
"zero_mean": False,
|
||||
# (Image statespace) - Converts image to (dim, dim, C)
|
||||
"dim": 84,
|
||||
# (Image statespace) - Converts image shape to (C, dim, dim)
|
||||
"channel_major": False,
|
||||
},
|
||||
# Configure TF for single-process operation
|
||||
"tf_session_args": {
|
||||
"intra_op_parallelism_threads": 1,
|
||||
"inter_op_parallelism_threads": 1,
|
||||
"gpu_options": {
|
||||
"allow_growth": True,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
# __sphinx_doc_end__
|
||||
|
||||
|
||||
class A3CAgent(Agent):
|
||||
"""A3C implementations in TensorFlow and PyTorch."""
|
||||
|
||||
@@ -10,6 +10,7 @@ from datetime import datetime
|
||||
import tensorflow as tf
|
||||
|
||||
import ray
|
||||
from ray.rllib.models import MODEL_DEFAULTS
|
||||
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
|
||||
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
|
||||
from ray.rllib.utils import FilterManager, deep_update, merge_dicts
|
||||
@@ -18,10 +19,11 @@ from ray.tune.trainable import Trainable
|
||||
from ray.tune.logger import UnifiedLogger
|
||||
from ray.tune.result import DEFAULT_RESULTS_DIR
|
||||
|
||||
# __sphinx_doc_begin__
|
||||
COMMON_CONFIG = {
|
||||
# Discount factor of the MDP
|
||||
"gamma": 0.99,
|
||||
# Number of steps after which the rollout gets cut
|
||||
# Number of steps after which the episode is forced to terminate
|
||||
"horizon": None,
|
||||
# Number of environments to evaluate vectorwise per worker.
|
||||
"num_envs_per_worker": 1,
|
||||
@@ -36,7 +38,7 @@ COMMON_CONFIG = {
|
||||
"batch_mode": "truncate_episodes",
|
||||
# Whether to use a background thread for sampling (slightly off-policy)
|
||||
"sample_async": False,
|
||||
# Which observation filter to apply to the observation
|
||||
# Element-wise observation filter, either "NoFilter" or "MeanStdFilter"
|
||||
"observation_filter": "NoFilter",
|
||||
# Whether to synchronize the statistics of remote filters.
|
||||
"synchronize_filters": True,
|
||||
@@ -50,14 +52,12 @@ COMMON_CONFIG = {
|
||||
# Environment name can also be passed via config
|
||||
"env": None,
|
||||
# Arguments to pass to model
|
||||
"model": {
|
||||
"use_lstm": False,
|
||||
"max_seq_len": 20,
|
||||
},
|
||||
# Arguments to pass to the rllib optimizer
|
||||
"model": MODEL_DEFAULTS,
|
||||
# Arguments to pass to the policy optimizer. These vary by optimizer.
|
||||
"optimizer": {},
|
||||
# Configure TF for single-process operation by default
|
||||
"tf_session_args": {
|
||||
# note: parallelism_threads is set to auto for the local evaluator
|
||||
"intra_op_parallelism_threads": 1,
|
||||
"inter_op_parallelism_threads": 1,
|
||||
"gpu_options": {
|
||||
@@ -88,6 +88,8 @@ COMMON_CONFIG = {
|
||||
},
|
||||
}
|
||||
|
||||
# __sphinx_doc_end__
|
||||
|
||||
|
||||
def with_common_config(extra_config):
|
||||
"""Returns the given config dict merged with common agent confs."""
|
||||
|
||||
@@ -24,6 +24,7 @@ Result = namedtuple("Result", [
|
||||
"eval_returns", "eval_lengths"
|
||||
])
|
||||
|
||||
# __sphinx_doc_begin__
|
||||
DEFAULT_CONFIG = with_common_config({
|
||||
"noise_stdev": 0.02, # std deviation of parameter noise
|
||||
"num_rollouts": 32, # number of perturbs to try
|
||||
@@ -34,9 +35,9 @@ DEFAULT_CONFIG = with_common_config({
|
||||
"noise_size": 250000000,
|
||||
"eval_prob": 0.03, # probability of evaluating the parameter rewards
|
||||
"report_length": 10, # how many of the last rewards we average over
|
||||
"env_config": {},
|
||||
"offset": 0,
|
||||
})
|
||||
# __sphinx_doc_end__
|
||||
|
||||
|
||||
@ray.remote
|
||||
|
||||
@@ -7,7 +7,7 @@ from ray.rllib.utils import merge_dicts
|
||||
from ray.tune.trial import Resources
|
||||
|
||||
APEX_DDPG_DEFAULT_CONFIG = merge_dicts(
|
||||
DDPG_CONFIG,
|
||||
DDPG_CONFIG, # see also the options in ddpg.py, which are also supported
|
||||
{
|
||||
"optimizer_class": "AsyncReplayOptimizer",
|
||||
"optimizer": merge_dicts(
|
||||
|
||||
@@ -13,6 +13,7 @@ OPTIMIZER_SHARED_CONFIGS = [
|
||||
"train_batch_size", "learning_starts"
|
||||
]
|
||||
|
||||
# __sphinx_doc_begin__
|
||||
DEFAULT_CONFIG = with_common_config({
|
||||
# === Model ===
|
||||
# Hidden layer sizes of the policy network
|
||||
@@ -108,6 +109,8 @@ DEFAULT_CONFIG = with_common_config({
|
||||
"min_iter_time_s": 1,
|
||||
})
|
||||
|
||||
# __sphinx_doc_end__
|
||||
|
||||
|
||||
class DDPGAgent(DQNAgent):
|
||||
"""DDPG implementation in TensorFlow."""
|
||||
|
||||
@@ -6,8 +6,9 @@ from ray.rllib.agents.dqn.dqn import DQNAgent, DEFAULT_CONFIG as DQN_CONFIG
|
||||
from ray.rllib.utils import merge_dicts
|
||||
from ray.tune.trial import Resources
|
||||
|
||||
# __sphinx_doc_begin__
|
||||
APEX_DEFAULT_CONFIG = merge_dicts(
|
||||
DQN_CONFIG,
|
||||
DQN_CONFIG, # see also the options in dqn.py, which are also supported
|
||||
{
|
||||
"optimizer_class": "AsyncReplayOptimizer",
|
||||
"optimizer": merge_dicts(
|
||||
@@ -31,6 +32,8 @@ APEX_DEFAULT_CONFIG = merge_dicts(
|
||||
},
|
||||
)
|
||||
|
||||
# __sphinx_doc_end__
|
||||
|
||||
|
||||
class ApexAgent(DQNAgent):
|
||||
"""DQN variant that uses the Ape-X distributed policy optimizer.
|
||||
|
||||
@@ -13,7 +13,7 @@ def wrap_dqn(env, options):
|
||||
|
||||
# Override atari default to use the deepmind wrappers.
|
||||
# TODO(ekl) this logic should be pushed to the catalog.
|
||||
if is_atari and "custom_preprocessor" not in options:
|
||||
if is_atari and not options.get("custom_preprocessor"):
|
||||
return wrap_deepmind(env, dim=options.get("dim", 84))
|
||||
|
||||
return ModelCatalog.get_preprocessor_as_wrapper(env, options)
|
||||
|
||||
@@ -20,6 +20,7 @@ OPTIMIZER_SHARED_CONFIGS = [
|
||||
"learning_starts"
|
||||
]
|
||||
|
||||
# __sphinx_doc_begin__
|
||||
DEFAULT_CONFIG = with_common_config({
|
||||
# === Model ===
|
||||
# Number of atoms for representing the distribution of return. When
|
||||
@@ -116,6 +117,8 @@ DEFAULT_CONFIG = with_common_config({
|
||||
"min_iter_time_s": 1,
|
||||
})
|
||||
|
||||
# __sphinx_doc_end__
|
||||
|
||||
|
||||
class DQNAgent(Agent):
|
||||
"""DQN implementation in TensorFlow."""
|
||||
|
||||
@@ -24,6 +24,7 @@ Result = namedtuple("Result", [
|
||||
"eval_returns", "eval_lengths"
|
||||
])
|
||||
|
||||
# __sphinx_doc_begin__
|
||||
DEFAULT_CONFIG = with_common_config({
|
||||
"l2_coeff": 0.005,
|
||||
"noise_stdev": 0.02,
|
||||
@@ -36,10 +37,8 @@ DEFAULT_CONFIG = with_common_config({
|
||||
"observation_filter": "MeanStdFilter",
|
||||
"noise_size": 250000000,
|
||||
"report_length": 10,
|
||||
"env": None,
|
||||
"env_config": {},
|
||||
"model": {},
|
||||
})
|
||||
# __sphinx_doc_end__
|
||||
|
||||
|
||||
@ray.remote
|
||||
@@ -77,7 +76,8 @@ class Worker(object):
|
||||
|
||||
self.env = env_creator(config["env_config"])
|
||||
from ray.rllib import models
|
||||
self.preprocessor = models.ModelCatalog.get_preprocessor(self.env)
|
||||
self.preprocessor = models.ModelCatalog.get_preprocessor(
|
||||
self.env, config["model"])
|
||||
|
||||
self.sess = utils.make_session(single_threaded=True)
|
||||
self.policy = policies.GenericPolicy(
|
||||
|
||||
@@ -23,6 +23,7 @@ OPTIMIZER_SHARED_CONFIGS = [
|
||||
"max_sample_requests_in_flight_per_worker",
|
||||
]
|
||||
|
||||
# __sphinx_doc_begin__
|
||||
DEFAULT_CONFIG = with_common_config({
|
||||
# V-trace params (see vtrace.py).
|
||||
"vtrace": True,
|
||||
@@ -63,15 +64,10 @@ DEFAULT_CONFIG = with_common_config({
|
||||
# balancing the three losses
|
||||
"vf_loss_coeff": 0.5,
|
||||
"entropy_coeff": -0.01,
|
||||
|
||||
# Model and preprocessor options.
|
||||
"model": {
|
||||
"use_lstm": False,
|
||||
"max_seq_len": 20,
|
||||
"dim": 84,
|
||||
},
|
||||
})
|
||||
|
||||
# __sphinx_doc_end__
|
||||
|
||||
|
||||
class ImpalaAgent(Agent):
|
||||
"""IMPALA implementation using DeepMind's V-trace."""
|
||||
|
||||
@@ -8,20 +8,16 @@ from ray.rllib.optimizers import SyncSamplesOptimizer
|
||||
from ray.rllib.utils import merge_dicts
|
||||
from ray.tune.trial import Resources
|
||||
|
||||
# __sphinx_doc_begin__
|
||||
DEFAULT_CONFIG = with_common_config({
|
||||
# No remote workers by default
|
||||
"num_workers": 0,
|
||||
# Learning rate
|
||||
"lr": 0.0004,
|
||||
# Override model config
|
||||
"model": {
|
||||
# Use LSTM model.
|
||||
"use_lstm": False,
|
||||
# Max seq length for LSTM training.
|
||||
"max_seq_len": 20,
|
||||
},
|
||||
})
|
||||
|
||||
# __sphinx_doc_end__
|
||||
|
||||
|
||||
class PGAgent(Agent):
|
||||
"""Simple policy gradient agent.
|
||||
|
||||
@@ -8,6 +8,7 @@ from ray.rllib.utils import merge_dicts
|
||||
from ray.rllib.optimizers import SyncSamplesOptimizer, LocalMultiGPUOptimizer
|
||||
from ray.tune.trial import Resources
|
||||
|
||||
# __sphinx_doc_begin__
|
||||
DEFAULT_CONFIG = with_common_config({
|
||||
# If true, use the Generalized Advantage Estimator (GAE)
|
||||
# with a value function, see https://arxiv.org/pdf/1506.02438.pdf.
|
||||
@@ -53,15 +54,10 @@ DEFAULT_CONFIG = with_common_config({
|
||||
"observation_filter": "MeanStdFilter",
|
||||
# Use the sync samples optimizer instead of the multi-gpu one
|
||||
"simple_optimizer": False,
|
||||
# Override model config
|
||||
"model": {
|
||||
# Whether to use LSTM model
|
||||
"use_lstm": False,
|
||||
# Max seq length for LSTM training.
|
||||
"max_seq_len": 20,
|
||||
},
|
||||
})
|
||||
|
||||
# __sphinx_doc_end__
|
||||
|
||||
|
||||
class PPOAgent(Agent):
|
||||
"""Multi-GPU optimized implementation of PPO in TensorFlow."""
|
||||
|
||||
@@ -187,7 +187,7 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
def wrap(env):
|
||||
return env # we can't auto-wrap these env types
|
||||
elif is_atari(self.env) and \
|
||||
"custom_preprocessor" not in model_config and \
|
||||
not model_config.get("custom_preprocessor") and \
|
||||
preprocessor_pref == "deepmind":
|
||||
|
||||
if clip_rewards is None:
|
||||
@@ -196,9 +196,8 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
def wrap(env):
|
||||
env = wrap_deepmind(
|
||||
env,
|
||||
dim=model_config.get("dim", 84),
|
||||
framestack=not model_config.get("use_lstm")
|
||||
and not model_config.get("no_framestack"))
|
||||
dim=model_config.get("dim"),
|
||||
framestack=model_config.get("framestack"))
|
||||
if monitor_path:
|
||||
env = _monitor(env, monitor_path)
|
||||
return env
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from ray.rllib.models.catalog import ModelCatalog
|
||||
from ray.rllib.models.catalog import ModelCatalog, MODEL_DEFAULTS
|
||||
from ray.rllib.models.action_dist import (ActionDistribution, Categorical,
|
||||
DiagGaussian, Deterministic)
|
||||
from ray.rllib.models.model import Model
|
||||
@@ -7,6 +7,14 @@ from ray.rllib.models.fcnet import FullyConnectedNetwork
|
||||
from ray.rllib.models.lstm import LSTM
|
||||
|
||||
__all__ = [
|
||||
"ActionDistribution", "Categorical", "DiagGaussian", "Deterministic",
|
||||
"ModelCatalog", "Model", "Preprocessor", "FullyConnectedNetwork", "LSTM"
|
||||
"ActionDistribution",
|
||||
"Categorical",
|
||||
"DiagGaussian",
|
||||
"Deterministic",
|
||||
"ModelCatalog",
|
||||
"Model",
|
||||
"Preprocessor",
|
||||
"FullyConnectedNetwork",
|
||||
"LSTM",
|
||||
"MODEL_DEFAULTS",
|
||||
]
|
||||
|
||||
@@ -18,29 +18,52 @@ from ray.rllib.models.fcnet import FullyConnectedNetwork
|
||||
from ray.rllib.models.visionnet import VisionNetwork
|
||||
from ray.rllib.models.lstm import LSTM
|
||||
|
||||
MODEL_CONFIGS = [
|
||||
# __sphinx_doc_begin__
|
||||
MODEL_DEFAULTS = {
|
||||
# === Built-in options ===
|
||||
# Filter config. List of [out_channels, kernel, stride] for each filter
|
||||
"conv_filters",
|
||||
"conv_activation", # Nonlinearity for built-in convnet
|
||||
"fcnet_activation", # Nonlinearity for fully connected net (tanh, relu)
|
||||
"fcnet_hiddens", # Number of hidden layers for fully connected net
|
||||
"dim", # Dimension for ATARI
|
||||
"grayscale", # Converts ATARI frame to 1 Channel Grayscale image
|
||||
"zero_mean", # Changes frame to range from [-1, 1] if true
|
||||
"extra_frameskip", # (int) for number of frames to skip
|
||||
"free_log_std", # Documented in ray.rllib.models.Model
|
||||
"channel_major", # Pytorch conv requires images to be channel-major
|
||||
"squash_to_range", # Whether to squash the action output to space range
|
||||
"use_lstm", # Whether to wrap the model with a LSTM
|
||||
"max_seq_len", # Max seq len for training the LSTM, defaults to 20
|
||||
"lstm_cell_size", # Size of the LSTM cell
|
||||
"conv_filters": None,
|
||||
# Nonlinearity for built-in convnet
|
||||
"conv_activation": "relu",
|
||||
# Nonlinearity for fully connected net (tanh, relu)
|
||||
"fcnet_activation": "tanh",
|
||||
# Number of hidden layers for fully connected net
|
||||
"fcnet_hiddens": [256, 256],
|
||||
# For control envs, documented in ray.rllib.models.Model
|
||||
"free_log_std": False,
|
||||
# Whether to squash the action output to space range
|
||||
"squash_to_range": False,
|
||||
|
||||
# == LSTM ==
|
||||
# Whether to wrap the model with a LSTM
|
||||
"use_lstm": False,
|
||||
# Max seq len for training the LSTM, defaults to 20
|
||||
"max_seq_len": 20,
|
||||
# Size of the LSTM cell
|
||||
"lstm_cell_size": 256,
|
||||
|
||||
# == Atari ==
|
||||
# Whether to enable framestack for Atari envs
|
||||
"framestack": True,
|
||||
# Final resized frame dimension
|
||||
"dim": 84,
|
||||
# Pytorch conv requires images to be channel-major
|
||||
"channel_major": False,
|
||||
# (deprecated) Converts ATARI frame to 1 Channel Grayscale image
|
||||
"grayscale": False,
|
||||
# (deprecated) Changes frame to range from [-1, 1] if true
|
||||
"zero_mean": True,
|
||||
|
||||
# === Options for custom models ===
|
||||
"custom_preprocessor", # Name of a custom preprocessor to use
|
||||
"custom_model", # Name of a custom model to use
|
||||
"custom_options", # Extra options to pass to the custom classes
|
||||
]
|
||||
# Name of a custom preprocessor to use
|
||||
"custom_preprocessor": None,
|
||||
# Name of a custom model to use
|
||||
"custom_model": None,
|
||||
# Extra options to pass to the custom classes
|
||||
"custom_options": {},
|
||||
}
|
||||
|
||||
# __sphinx_doc_end__
|
||||
|
||||
|
||||
class ModelCatalog(object):
|
||||
@@ -71,10 +94,7 @@ class ModelCatalog(object):
|
||||
dist_dim (int): The size of the input vector to the distribution.
|
||||
"""
|
||||
|
||||
# TODO(ekl) are list spaces valid?
|
||||
if isinstance(action_space, list):
|
||||
action_space = gym.spaces.Tuple(action_space)
|
||||
config = config or {}
|
||||
config = config or MODEL_DEFAULTS
|
||||
if isinstance(action_space, gym.spaces.Box):
|
||||
if dist_type is None:
|
||||
dist = DiagGaussian
|
||||
@@ -82,7 +102,7 @@ class ModelCatalog(object):
|
||||
dist = squash_to_range(dist, action_space.low,
|
||||
action_space.high)
|
||||
return dist, action_space.shape[0] * 2
|
||||
elif dist_type == 'deterministic':
|
||||
elif dist_type == "deterministic":
|
||||
return Deterministic, action_space.shape[0]
|
||||
elif isinstance(action_space, gym.spaces.Discrete):
|
||||
return Categorical, action_space.n
|
||||
@@ -154,6 +174,7 @@ class ModelCatalog(object):
|
||||
model (Model): Neural network model.
|
||||
"""
|
||||
|
||||
options = options or MODEL_DEFAULTS
|
||||
model = ModelCatalog._get_model(inputs, num_outputs, options, state_in,
|
||||
seq_lens)
|
||||
|
||||
@@ -165,7 +186,7 @@ class ModelCatalog(object):
|
||||
|
||||
@staticmethod
|
||||
def _get_model(inputs, num_outputs, options, state_in, seq_lens):
|
||||
if "custom_model" in options:
|
||||
if options.get("custom_model"):
|
||||
model = options["custom_model"]
|
||||
print("Using custom model {}".format(model))
|
||||
return _global_registry.get(RLLIB_MODEL, model)(
|
||||
@@ -183,7 +204,7 @@ class ModelCatalog(object):
|
||||
return FullyConnectedNetwork(inputs, num_outputs, options)
|
||||
|
||||
@staticmethod
|
||||
def get_torch_model(input_shape, num_outputs, options={}):
|
||||
def get_torch_model(input_shape, num_outputs, options=None):
|
||||
"""Returns a PyTorch suitable model. This is currently only supported
|
||||
in A3C.
|
||||
|
||||
@@ -200,7 +221,8 @@ class ModelCatalog(object):
|
||||
from ray.rllib.models.pytorch.visionnet import (VisionNetwork as
|
||||
PyTorchVisionNet)
|
||||
|
||||
if "custom_model" in options:
|
||||
options = options or MODEL_DEFAULTS
|
||||
if options.get("custom_model"):
|
||||
model = options["custom_model"]
|
||||
print("Using custom torch model {}".format(model))
|
||||
return _global_registry.get(RLLIB_MODEL, model)(
|
||||
@@ -217,7 +239,7 @@ class ModelCatalog(object):
|
||||
return PyTorchFCNet(input_shape[0], num_outputs, options)
|
||||
|
||||
@staticmethod
|
||||
def get_preprocessor(env, options={}):
|
||||
def get_preprocessor(env, options=None):
|
||||
"""Returns a suitable processor for the given environment.
|
||||
|
||||
Args:
|
||||
@@ -227,12 +249,13 @@ class ModelCatalog(object):
|
||||
Returns:
|
||||
preprocessor (Preprocessor): Preprocessor for the env observations.
|
||||
"""
|
||||
options = options or MODEL_DEFAULTS
|
||||
for k in options.keys():
|
||||
if k not in MODEL_CONFIGS:
|
||||
if k not in MODEL_DEFAULTS:
|
||||
raise Exception("Unknown config key `{}`, all keys: {}".format(
|
||||
k, MODEL_CONFIGS))
|
||||
k, list(MODEL_DEFAULTS)))
|
||||
|
||||
if "custom_preprocessor" in options:
|
||||
if options.get("custom_preprocessor"):
|
||||
preprocessor = options["custom_preprocessor"]
|
||||
print("Using custom preprocessor {}".format(preprocessor))
|
||||
return _global_registry.get(RLLIB_PREPROCESSOR, preprocessor)(
|
||||
@@ -242,7 +265,7 @@ class ModelCatalog(object):
|
||||
return preprocessor(env.observation_space, options)
|
||||
|
||||
@staticmethod
|
||||
def get_preprocessor_as_wrapper(env, options={}):
|
||||
def get_preprocessor_as_wrapper(env, options=None):
|
||||
"""Returns a preprocessor as a gym observation wrapper.
|
||||
|
||||
Args:
|
||||
@@ -253,6 +276,7 @@ class ModelCatalog(object):
|
||||
wrapper (gym.ObservationWrapper): Preprocessor in wrapper form.
|
||||
"""
|
||||
|
||||
options = options or MODEL_DEFAULTS
|
||||
preprocessor = ModelCatalog.get_preprocessor(env, options)
|
||||
return _RLlibPreprocessorWrapper(env, preprocessor)
|
||||
|
||||
|
||||
@@ -13,8 +13,8 @@ class FullyConnectedNetwork(Model):
|
||||
"""Generic fully connected network."""
|
||||
|
||||
def _build_layers(self, inputs, num_outputs, options):
|
||||
hiddens = options.get("fcnet_hiddens", [256, 256])
|
||||
activation = get_activation_fn(options.get("fcnet_activation", "tanh"))
|
||||
hiddens = options.get("fcnet_hiddens")
|
||||
activation = get_activation_fn(options.get("fcnet_activation"))
|
||||
|
||||
with tf.name_scope("fc_net"):
|
||||
i = 1
|
||||
|
||||
@@ -135,7 +135,7 @@ class LSTM(Model):
|
||||
"""
|
||||
|
||||
def _build_layers(self, inputs, num_outputs, options):
|
||||
cell_size = options.get("lstm_cell_size", 256)
|
||||
cell_size = options.get("lstm_cell_size")
|
||||
last_layer = add_time_dimension(inputs, self.seq_lens)
|
||||
|
||||
# Setup the LSTM cell
|
||||
|
||||
@@ -55,7 +55,7 @@ class Model(object):
|
||||
self.seq_lens = tf.placeholder(
|
||||
dtype=tf.int32, shape=[None], name="seq_lens")
|
||||
|
||||
if options.get("free_log_std", False):
|
||||
if options.get("free_log_std"):
|
||||
assert num_outputs % 2 == 0
|
||||
num_outputs = num_outputs // 2
|
||||
self.outputs, self.last_layer = self._build_layers(
|
||||
|
||||
@@ -30,12 +30,18 @@ class Preprocessor(object):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class AtariPixelPreprocessor(Preprocessor):
|
||||
class GenericPixelPreprocessor(Preprocessor):
|
||||
"""Generic image preprocessor.
|
||||
|
||||
Note: for Atari games, use config {"preprocessor_pref": "deepmind"}
|
||||
instead for deepmind-style Atari preprocessing.
|
||||
"""
|
||||
|
||||
def _init(self):
|
||||
self._grayscale = self._options.get("grayscale", False)
|
||||
self._zero_mean = self._options.get("zero_mean", True)
|
||||
self._dim = self._options.get("dim", 84)
|
||||
self._channel_major = self._options.get("channel_major", False)
|
||||
self._grayscale = self._options.get("grayscale")
|
||||
self._zero_mean = self._options.get("zero_mean")
|
||||
self._dim = self._options.get("dim")
|
||||
self._channel_major = self._options.get("channel_major")
|
||||
if self._grayscale:
|
||||
self.shape = (self._dim, self._dim, 1)
|
||||
else:
|
||||
@@ -130,7 +136,7 @@ def get_preprocessor(space):
|
||||
if isinstance(space, gym.spaces.Discrete):
|
||||
preprocessor = OneHotPreprocessor
|
||||
elif obs_shape == ATARI_OBS_SHAPE:
|
||||
preprocessor = AtariPixelPreprocessor
|
||||
preprocessor = GenericPixelPreprocessor
|
||||
elif obs_shape == ATARI_RAM_OBS_SHAPE:
|
||||
preprocessor = AtariRamPreprocessor
|
||||
elif isinstance(space, gym.spaces.Tuple):
|
||||
|
||||
@@ -18,11 +18,11 @@ class VisionNetwork(Model):
|
||||
inputs (tuple): (channels, rows/height, cols/width)
|
||||
num_outputs (int): logits size
|
||||
"""
|
||||
filters = options.get("conv_filters", [
|
||||
filters = options.get("conv_filters") or [
|
||||
[16, [8, 8], 4],
|
||||
[32, [4, 4], 2],
|
||||
[512, [11, 11], 1],
|
||||
])
|
||||
]
|
||||
layers = []
|
||||
in_channels, in_size = inputs[0], inputs[1:]
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ class VisionNetwork(Model):
|
||||
if not filters:
|
||||
filters = get_filter_config(options)
|
||||
|
||||
activation = get_activation_fn(options.get("conv_activation", "relu"))
|
||||
activation = get_activation_fn(options.get("conv_activation"))
|
||||
|
||||
with tf.name_scope("vision_net"):
|
||||
for i, (out_size, kernel, stride) in enumerate(filters[:-1], 1):
|
||||
@@ -57,7 +57,7 @@ def get_filter_config(options):
|
||||
[32, [4, 4], 2],
|
||||
[256, [11, 11], 1],
|
||||
]
|
||||
dim = options.get("dim", 84)
|
||||
dim = options.get("dim")
|
||||
if dim == 84:
|
||||
return filters_84x84
|
||||
elif dim == 42:
|
||||
|
||||
@@ -19,10 +19,6 @@ ACTION_SPACES_TO_TEST = {
|
||||
Box(0.0, 1.0, (5, ), dtype=np.float32),
|
||||
Box(0.0, 1.0, (5, ), dtype=np.float32)
|
||||
]),
|
||||
"implicit_tuple": [
|
||||
Box(0.0, 1.0, (5, ), dtype=np.float32),
|
||||
Box(0.0, 1.0, (5, ), dtype=np.float32)
|
||||
],
|
||||
"mixed_tuple": Tuple(
|
||||
[Discrete(2),
|
||||
Discrete(3),
|
||||
|
||||
Reference in New Issue
Block a user