From e3c9f7e83a6007ded7ae7e99fcbe9fcaa371bad3 Mon Sep 17 00:00:00 2001 From: "Matthew A. Wright" Date: Tue, 6 Aug 2019 18:13:16 +0000 Subject: [PATCH] Custom action distributions (#5164) * custom action dist wip * Test case for custom action dist * ActionDistribution.get_parameter_shape_for_action_space pattern * Edit exception message to also suggest using a custom action distribution * Clean up ModelCatalog.get_action_dist * Pass model config to ActionDistribution constructors * Update custom action distribution test case * Name fix * Autoformatter * parameter shape static methods for torch distributions * Fix docstring * Generalize fake array for graph initialization * Fix action dist constructors * Correct parameter shape static methods for multicategorical and gaussian * Make suggested changes to custom action dist's * Correct instances of not passing model config to action dist * Autoformatter * fix tuple distribution constructor * bugfix --- python/ray/tune/registry.py | 4 +- rllib/agents/a3c/a3c_torch_policy.py | 2 +- rllib/agents/ars/policies.py | 2 +- rllib/agents/dqn/dqn_policy.py | 8 +-- rllib/agents/es/policies.py | 2 +- rllib/agents/impala/vtrace.py | 19 ++++--- rllib/agents/impala/vtrace_policy.py | 6 ++- rllib/agents/impala/vtrace_test.py | 11 ++-- rllib/agents/marwil/marwil_policy.py | 11 ++-- rllib/agents/pg/torch_pg_policy.py | 3 +- rllib/agents/ppo/appo_policy.py | 10 +++- rllib/agents/ppo/ppo_policy.py | 12 +++-- rllib/agents/ppo/test/test.py | 2 +- rllib/examples/custom_loss.py | 2 +- rllib/examples/custom_torch_policy.py | 2 +- rllib/models/action_dist.py | 24 ++++++++- rllib/models/catalog.py | 70 +++++++++++++++++-------- rllib/models/tf/tf_action_dist.py | 52 ++++++++++++++---- rllib/models/torch/torch_action_dist.py | 18 ++++++- rllib/policy/dynamic_tf_policy.py | 5 +- rllib/policy/torch_policy.py | 3 +- rllib/tests/test_catalog.py | 57 +++++++++++++++++++- 22 files changed, 252 insertions(+), 73 deletions(-) diff --git a/python/ray/tune/registry.py b/python/ray/tune/registry.py index 6202013c7..4c3c09b90 100644 --- a/python/ray/tune/registry.py +++ b/python/ray/tune/registry.py @@ -14,8 +14,10 @@ TRAINABLE_CLASS = "trainable_class" ENV_CREATOR = "env_creator" RLLIB_MODEL = "rllib_model" RLLIB_PREPROCESSOR = "rllib_preprocessor" +RLLIB_ACTION_DIST = "rllib_action_dist" KNOWN_CATEGORIES = [ - TRAINABLE_CLASS, ENV_CREATOR, RLLIB_MODEL, RLLIB_PREPROCESSOR + TRAINABLE_CLASS, ENV_CREATOR, RLLIB_MODEL, RLLIB_PREPROCESSOR, + RLLIB_ACTION_DIST ] logger = logging.getLogger(__name__) diff --git a/rllib/agents/a3c/a3c_torch_policy.py b/rllib/agents/a3c/a3c_torch_policy.py index 8045c397f..f14f1f16d 100644 --- a/rllib/agents/a3c/a3c_torch_policy.py +++ b/rllib/agents/a3c/a3c_torch_policy.py @@ -18,7 +18,7 @@ def actor_critic_loss(policy, batch_tensors): SampleBatch.CUR_OBS: batch_tensors[SampleBatch.CUR_OBS] }) # TODO(ekl) seq lens shouldn't be None values = policy.model.value_function() - dist = policy.dist_class(logits) + dist = policy.dist_class(logits, policy.config["model"]) log_probs = dist.logp(batch_tensors[SampleBatch.ACTIONS]) policy.entropy = dist.entropy().mean() policy.pi_err = -batch_tensors[Postprocessing.ADVANTAGES].dot( diff --git a/rllib/agents/ars/policies.py b/rllib/agents/ars/policies.py index 7fdb54b99..6029241c9 100644 --- a/rllib/agents/ars/policies.py +++ b/rllib/agents/ars/policies.py @@ -81,7 +81,7 @@ class GenericPolicy(object): model = ModelCatalog.get_model({ "obs": self.inputs }, obs_space, action_space, dist_dim, model_config) - dist = dist_class(model.outputs) + dist = dist_class(model.outputs, model_config=model_config) self.sampler = dist.sample() self.variables = ray.experimental.tf_utils.TensorFlowVariables( diff --git a/rllib/agents/dqn/dqn_policy.py b/rllib/agents/dqn/dqn_policy.py index 700c9085a..46c891f7d 100644 --- a/rllib/agents/dqn/dqn_policy.py +++ b/rllib/agents/dqn/dqn_policy.py @@ -107,9 +107,10 @@ class QLoss(object): class QValuePolicy(object): def __init__(self, q_values, observations, num_actions, stochastic, eps, - softmax, softmax_temp): + softmax, softmax_temp, model_config): if softmax: - action_dist = Categorical(q_values / softmax_temp) + action_dist = Categorical( + q_values / softmax_temp, model_config=model_config) self.action = action_dist.sample() self.action_prob = action_dist.sampled_action_prob() return @@ -255,7 +256,8 @@ def build_q_networks(policy, q_model, input_dict, obs_space, action_space, # Action outputs qvp = QValuePolicy(q_values, input_dict[SampleBatch.CUR_OBS], action_space.n, policy.stochastic, policy.eps, - config["soft_q"], config["softmax_temp"]) + config["soft_q"], config["softmax_temp"], + config["model"]) policy.output_actions, policy.action_prob = qvp.action, qvp.action_prob return policy.output_actions, policy.action_prob diff --git a/rllib/agents/es/policies.py b/rllib/agents/es/policies.py index dfc7e2dee..8b15cfca4 100644 --- a/rllib/agents/es/policies.py +++ b/rllib/agents/es/policies.py @@ -59,7 +59,7 @@ class GenericPolicy(object): model = ModelCatalog.get_model({ "obs": self.inputs }, obs_space, action_space, dist_dim, model_options) - dist = dist_class(model.outputs) + dist = dist_class(model.outputs, model_config=model_options) self.sampler = dist.sample() self.variables = ray.experimental.tf_utils.TensorFlowVariables( diff --git a/rllib/agents/impala/vtrace.py b/rllib/agents/impala/vtrace.py index 0064faa16..7062d5d6f 100644 --- a/rllib/agents/impala/vtrace.py +++ b/rllib/agents/impala/vtrace.py @@ -49,13 +49,14 @@ VTraceReturns = collections.namedtuple("VTraceReturns", "vs pg_advantages") def log_probs_from_logits_and_actions(policy_logits, actions, + config, dist_class=Categorical): return multi_log_probs_from_logits_and_actions([policy_logits], [actions], - dist_class)[0] + dist_class, config)[0] -def multi_log_probs_from_logits_and_actions(policy_logits, actions, - dist_class): +def multi_log_probs_from_logits_and_actions(policy_logits, actions, dist_class, + config): """Computes action log-probs from policy logits and actions. In the notation used throughout documentation and comments, T refers to the @@ -76,6 +77,8 @@ def multi_log_probs_from_logits_and_actions(policy_logits, actions, ..., [T, B, ...] with actions. + dist_class: Python class of the action distribution + config: Trainer config dict Returns: A list with length of ACTION_SPACE of float32 @@ -97,7 +100,8 @@ def multi_log_probs_from_logits_and_actions(policy_logits, actions, tf.concat([[-1], a_shape[2:]], axis=0)) log_probs.append( tf.reshape( - dist_class(policy_logits_flat).logp(actions_flat), + dist_class(policy_logits_flat, + model_config=config["model"]).logp(actions_flat), a_shape[:2])) return log_probs @@ -110,6 +114,7 @@ def from_logits(behaviour_policy_logits, rewards, values, bootstrap_value, + config, dist_class=Categorical, clip_rho_threshold=1.0, clip_pg_rho_threshold=1.0, @@ -122,6 +127,7 @@ def from_logits(behaviour_policy_logits, rewards, values, bootstrap_value, + config, dist_class, clip_rho_threshold=clip_rho_threshold, clip_pg_rho_threshold=clip_pg_rho_threshold, @@ -145,6 +151,7 @@ def multi_from_logits(behaviour_policy_logits, rewards, values, bootstrap_value, + config, dist_class, clip_rho_threshold=1.0, clip_pg_rho_threshold=1.0, @@ -235,9 +242,9 @@ def multi_from_logits(behaviour_policy_logits, discounts, rewards, values, bootstrap_value ]): target_action_log_probs = multi_log_probs_from_logits_and_actions( - target_policy_logits, actions, dist_class) + target_policy_logits, actions, dist_class, config) behaviour_action_log_probs = multi_log_probs_from_logits_and_actions( - behaviour_policy_logits, actions, dist_class) + behaviour_policy_logits, actions, dist_class, config) log_rhos = get_log_rhos(target_action_log_probs, behaviour_action_log_probs) diff --git a/rllib/agents/impala/vtrace_policy.py b/rllib/agents/impala/vtrace_policy.py index b92fb0cd9..a288fb5d1 100644 --- a/rllib/agents/impala/vtrace_policy.py +++ b/rllib/agents/impala/vtrace_policy.py @@ -41,6 +41,7 @@ class VTraceLoss(object): bootstrap_value, dist_class, valid_mask, + config, vf_loss_coeff=0.5, entropy_coeff=0.01, clip_rho_threshold=1.0, @@ -72,6 +73,7 @@ class VTraceLoss(object): bootstrap_value: A float32 tensor of shape [B]. dist_class: action distribution class for logits. valid_mask: A bool tensor of valid RNN input elements (#2992). + config: Trainer config dict. """ # Compute vtrace on the CPU for better perf. @@ -87,7 +89,8 @@ class VTraceLoss(object): dist_class=dist_class, clip_rho_threshold=tf.cast(clip_rho_threshold, tf.float32), clip_pg_rho_threshold=tf.cast(clip_pg_rho_threshold, - tf.float32)) + tf.float32), + config=config) self.value_targets = self.vtrace_returns.vs # The policy gradients loss @@ -196,6 +199,7 @@ def build_vtrace_loss(policy, batch_tensors): bootstrap_value=make_time_major(values)[-1], dist_class=Categorical if is_multidiscrete else policy.dist_class, valid_mask=make_time_major(mask, drop_last=True), + config=policy.config, vf_loss_coeff=policy.config["vf_loss_coeff"], entropy_coeff=policy.entropy_coeff, clip_rho_threshold=policy.config["vtrace_clip_rho_threshold"], diff --git a/rllib/agents/impala/vtrace_test.py b/rllib/agents/impala/vtrace_test.py index e1f39991b..9d88fefa9 100644 --- a/rllib/agents/impala/vtrace_test.py +++ b/rllib/agents/impala/vtrace_test.py @@ -98,7 +98,7 @@ class LogProbsFromLogitsAndActionsTest(tf.test.TestCase, 0, num_actions - 1, size=(seq_len, batch_size), dtype=np.int32) action_log_probs_tensor = vtrace.log_probs_from_logits_and_actions( - policy_logits, actions) + policy_logits, actions, {"model": None}) # dummy config dict # Ground Truth # Using broadcasting to create a mask that indexes action logits @@ -159,6 +159,8 @@ class VtraceTest(tf.test.TestCase, parameterized.TestCase): clip_rho_threshold = None # No clipping. clip_pg_rho_threshold = None # No clipping. + dummy_config = {"model": None} + # Intentionally leaving shapes unspecified to test if V-trace can # deal with that. placeholders = { @@ -178,12 +180,15 @@ class VtraceTest(tf.test.TestCase, parameterized.TestCase): from_logits_output = vtrace.from_logits( clip_rho_threshold=clip_rho_threshold, clip_pg_rho_threshold=clip_pg_rho_threshold, + config=dummy_config, **placeholders) target_log_probs = vtrace.log_probs_from_logits_and_actions( - placeholders["target_policy_logits"], placeholders["actions"]) + placeholders["target_policy_logits"], placeholders["actions"], + dummy_config) behaviour_log_probs = vtrace.log_probs_from_logits_and_actions( - placeholders["behaviour_policy_logits"], placeholders["actions"]) + placeholders["behaviour_policy_logits"], placeholders["actions"], + dummy_config) log_rhos = target_log_probs - behaviour_log_probs ground_truth = (log_rhos, behaviour_log_probs, target_log_probs) diff --git a/rllib/agents/marwil/marwil_policy.py b/rllib/agents/marwil/marwil_policy.py index 47ff12ebd..3ee1abfcb 100644 --- a/rllib/agents/marwil/marwil_policy.py +++ b/rllib/agents/marwil/marwil_policy.py @@ -29,7 +29,7 @@ class ValueLoss(object): class ReweightedImitationLoss(object): def __init__(self, state_values, cumulative_rewards, logits, actions, - action_space, beta): + action_space, beta, model_config): ma_adv_norm = tf.get_variable( name="moving_average_of_advantage_norm", dtype=tf.float32, @@ -48,8 +48,8 @@ class ReweightedImitationLoss(object): beta * tf.divide(adv, 1e-8 + tf.sqrt(ma_adv_norm))) # log\pi_\theta(a|s) - dist_cls, _ = ModelCatalog.get_action_dist(action_space, {}) - action_dist = dist_cls(logits) + dist_cls, _ = ModelCatalog.get_action_dist(action_space, model_config) + action_dist = dist_cls(logits, model_config=model_config) logprobs = action_dist.logp(actions) self.loss = -1.0 * tf.reduce_mean( @@ -106,7 +106,7 @@ class MARWILPolicy(MARWILPostprocessing, TFPolicy): self.p_func_vars = scope_vars(scope.name) # Action outputs - action_dist = dist_cls(logits) + action_dist = dist_cls(logits, model_config=self.config["model"]) self.output_actions = action_dist.sample() # Training inputs @@ -164,7 +164,8 @@ class MARWILPolicy(MARWILPostprocessing, TFPolicy): def _build_policy_loss(self, state_values, cum_rwds, logits, actions, action_space): return ReweightedImitationLoss(state_values, cum_rwds, logits, actions, - action_space, self.config["beta"]) + action_space, self.config["beta"], + self.config["model"]) @override(TFPolicy) def extra_compute_grad_fetches(self): diff --git a/rllib/agents/pg/torch_pg_policy.py b/rllib/agents/pg/torch_pg_policy.py index 442c57f48..2dc4a280f 100644 --- a/rllib/agents/pg/torch_pg_policy.py +++ b/rllib/agents/pg/torch_pg_policy.py @@ -13,7 +13,8 @@ def pg_torch_loss(policy, batch_tensors): logits, _ = policy.model({ SampleBatch.CUR_OBS: batch_tensors[SampleBatch.CUR_OBS] }) - action_dist = policy.dist_class(logits) + action_dist = policy.dist_class( + logits, model_config=policy.config["model"]) log_probs = action_dist.logp(batch_tensors[SampleBatch.ACTIONS]) # save the error in the policy object policy.pi_err = -batch_tensors[Postprocessing.ADVANTAGES].dot( diff --git a/rllib/agents/ppo/appo_policy.py b/rllib/agents/ppo/appo_policy.py index 95f61a4f6..604eeab96 100644 --- a/rllib/agents/ppo/appo_policy.py +++ b/rllib/agents/ppo/appo_policy.py @@ -112,6 +112,7 @@ class VTraceSurrogateLoss(object): rewards, values, bootstrap_value, + config, dist_class, valid_mask, vf_loss_coeff=0.5, @@ -143,6 +144,7 @@ class VTraceSurrogateLoss(object): rewards: A float32 tensor of shape [T, B]. values: A float32 tensor of shape [T, B]. bootstrap_value: A float32 tensor of shape [B]. + config: Trainer config dict. dist_class: action distribution class for logits. valid_mask: A bool tensor of valid RNN input elements (#2992). vf_loss_coeff (float): Coefficient of the value function loss. @@ -165,6 +167,7 @@ class VTraceSurrogateLoss(object): rewards=rewards, values=values, bootstrap_value=bootstrap_value, + config=config, dist_class=dist_class, clip_rho_threshold=tf.cast(clip_rho_threshold, tf.float32), clip_pg_rho_threshold=tf.cast(clip_pg_rho_threshold, @@ -251,8 +254,10 @@ def build_appo_surrogate_loss(policy, batch_tensors): old_policy_behaviour_logits, output_hidden_shape, axis=1) unpacked_outputs = tf.split(policy.model_out, output_hidden_shape, axis=1) action_dist = policy.action_dist - old_policy_action_dist = policy.dist_class(old_policy_behaviour_logits) - prev_action_dist = policy.dist_class(behaviour_logits) + old_policy_action_dist = policy.dist_class( + old_policy_behaviour_logits, model_config=policy.config["model"]) + prev_action_dist = policy.dist_class( + behaviour_logits, model_config=policy.config["model"]) values = policy.value_function policy.model_vars = policy.model.variables() @@ -298,6 +303,7 @@ def build_appo_surrogate_loss(policy, batch_tensors): rewards=make_time_major(rewards, drop_last=True), values=make_time_major(values, drop_last=True), bootstrap_value=make_time_major(values)[-1], + config=policy.config, dist_class=Categorical if is_multidiscrete else policy.dist_class, valid_mask=make_time_major(mask, drop_last=True), vf_loss_coeff=policy.config["vf_loss_coeff"], diff --git a/rllib/agents/ppo/ppo_policy.py b/rllib/agents/ppo/ppo_policy.py index e87b106fa..d41aeb900 100644 --- a/rllib/agents/ppo/ppo_policy.py +++ b/rllib/agents/ppo/ppo_policy.py @@ -39,7 +39,8 @@ class PPOLoss(object): clip_param=0.1, vf_clip_param=0.1, vf_loss_coeff=1.0, - use_gae=True): + use_gae=True, + model_config=None): """Constructs the loss for Proximal Policy Objective. Arguments: @@ -65,13 +66,15 @@ class PPOLoss(object): vf_clip_param (float): Clip parameter for the value function vf_loss_coeff (float): Coefficient of the value function loss use_gae (bool): If true, use the Generalized Advantage Estimator. + model_config (dict): (Optional) model config for use in specifying + action distributions. """ def reduce_mean_valid(t): return tf.reduce_mean(tf.boolean_mask(t, valid_mask)) - dist_cls, _ = ModelCatalog.get_action_dist(action_space, {}) - prev_dist = dist_cls(logits) + dist_cls, _ = ModelCatalog.get_action_dist(action_space, model_config) + prev_dist = dist_cls(logits, model_config=model_config) # Make loss functions. logp_ratio = tf.exp( curr_action_dist.logp(actions) - prev_dist.logp(actions)) @@ -129,7 +132,8 @@ def ppo_surrogate_loss(policy, batch_tensors): clip_param=policy.config["clip_param"], vf_clip_param=policy.config["vf_clip_param"], vf_loss_coeff=policy.config["vf_loss_coeff"], - use_gae=policy.config["use_gae"]) + use_gae=policy.config["use_gae"], + model_config=policy.config["model"]) return policy.loss_obj.loss diff --git a/rllib/agents/ppo/test/test.py b/rllib/agents/ppo/test/test.py index 978fe7c69..7acbab41d 100644 --- a/rllib/agents/ppo/test/test.py +++ b/rllib/agents/ppo/test/test.py @@ -20,7 +20,7 @@ class DistributionsTest(unittest.TestCase): logits = tf.placeholder(tf.float32, shape=(None, 10)) z = 8 * (np.random.rand(10) - 0.5) data = np.tile(z, (num_samples, 1)) - c = Categorical(logits) + c = Categorical(logits, {}) # dummy config dict sample_op = c.sample() sess = tf.Session() sess.run(tf.global_variables_initializer()) diff --git a/rllib/examples/custom_loss.py b/rllib/examples/custom_loss.py index 16cc79272..ee2f3b896 100644 --- a/rllib/examples/custom_loss.py +++ b/rllib/examples/custom_loss.py @@ -67,7 +67,7 @@ class CustomLossModel(Model): print("FYI: You can also use these tensors: {}, ".format(loss_inputs)) # compute the IL loss - action_dist = Categorical(logits) + action_dist = Categorical(logits, self.options) self.policy_loss = policy_loss self.imitation_loss = tf.reduce_mean( -action_dist.logp(input_ops["actions"])) diff --git a/rllib/examples/custom_torch_policy.py b/rllib/examples/custom_torch_policy.py index e9b30876d..8f6ef5444 100644 --- a/rllib/examples/custom_torch_policy.py +++ b/rllib/examples/custom_torch_policy.py @@ -18,7 +18,7 @@ def policy_gradient_loss(policy, batch_tensors): logits, _ = policy.model({ SampleBatch.CUR_OBS: batch_tensors[SampleBatch.CUR_OBS] }) - action_dist = policy.dist_class(logits) + action_dist = policy.dist_class(logits, policy.config["model"]) log_probs = action_dist.logp(batch_tensors[SampleBatch.ACTIONS]) return -batch_tensors[SampleBatch.REWARDS].dot(log_probs) diff --git a/rllib/models/action_dist.py b/rllib/models/action_dist.py index 78ff5bdf9..9bfeb32fb 100644 --- a/rllib/models/action_dist.py +++ b/rllib/models/action_dist.py @@ -11,11 +11,14 @@ class ActionDistribution(object): Args: inputs (Tensor): The input vector to compute samples from. + model_config (dict): Optional model config dict + (as defined in catalog.py) """ @DeveloperAPI - def __init__(self, inputs): + def __init__(self, inputs, model_config): self.inputs = inputs + self.model_config = model_config @DeveloperAPI def sample(self): @@ -52,3 +55,22 @@ class ActionDistribution(object): MultiDiscrete. TODO(ekl) consider removing this. """ return self.entropy() + + @DeveloperAPI + @staticmethod + def required_model_output_shape(action_space, model_config): + """Returns the required shape of an input parameter tensor for a + particular action space and an optional dict of distribution-specific + options. + + Args: + action_space (gym.Space): The action space this distribution will + be used for, whose shape attributes will be used to determine + the required shape of the input parameter tensor. + model_config (dict): Model's config dict (as defined in catalog.py) + + Returns: + model_output_shape (int or np.ndarray of ints): size of the + required input vector (minus leading batch dimension). + """ + raise NotImplementedError diff --git a/rllib/models/catalog.py b/rllib/models/catalog.py index 5c6b1cf67..243cfccc5 100644 --- a/rllib/models/catalog.py +++ b/rllib/models/catalog.py @@ -8,7 +8,7 @@ import numpy as np from functools import partial from ray.tune.registry import RLLIB_MODEL, RLLIB_PREPROCESSOR, \ - _global_registry + RLLIB_ACTION_DIST, _global_registry from ray.rllib.models.extra_spaces import Simplex from ray.rllib.models.torch.torch_action_dist import (TorchCategorical, @@ -80,6 +80,8 @@ MODEL_DEFAULTS = { "custom_preprocessor": None, # Name of a custom model to use "custom_model": None, + # Name of a custom action distribution to use + "custom_action_dist": None, # Extra options to pass to the custom classes "custom_options": {}, } @@ -119,22 +121,30 @@ class ModelCatalog(object): """ config = config or MODEL_DEFAULTS - if isinstance(action_space, gym.spaces.Box): + if config.get("custom_action_dist"): + action_dist_name = config["custom_action_dist"] + logger.debug( + "Using custom action distribution {}".format(action_dist_name)) + dist = _global_registry.get(RLLIB_ACTION_DIST, action_dist_name) + + elif isinstance(action_space, gym.spaces.Box): if len(action_space.shape) > 1: raise UnsupportedSpaceException( "Action space has multiple dimensions " "{}. ".format(action_space.shape) + "Consider reshaping this into a single dimension, " + "using a custom action distribution, " "using a Tuple action space, or the multi-agent API.") if dist_type is None: dist = TorchDiagGaussian if torch else DiagGaussian - return dist, action_space.shape[0] * 2 elif dist_type == "deterministic": - return Deterministic, action_space.shape[0] + dist = Deterministic elif isinstance(action_space, gym.spaces.Discrete): dist = TorchCategorical if torch else Categorical - return dist, action_space.n elif isinstance(action_space, gym.spaces.Tuple): + if torch: + raise NotImplementedError("Tuple action spaces not supported " + "for Pytorch.") child_dist = [] input_lens = [] for action in action_space.spaces: @@ -142,8 +152,6 @@ class ModelCatalog(object): action, config) child_dist.append(dist) input_lens.append(action_size) - if torch: - raise NotImplementedError return partial( MultiActionDistribution, child_distributions=child_dist, @@ -151,14 +159,18 @@ class ModelCatalog(object): input_lens=input_lens), sum(input_lens) elif isinstance(action_space, Simplex): if torch: - raise NotImplementedError - return Dirichlet, action_space.shape[0] - elif isinstance(action_space, gym.spaces.multi_discrete.MultiDiscrete): + raise NotImplementedError("Simplex action spaces not " + "supported for Pytorch.") + dist = Dirichlet + elif isinstance(action_space, gym.spaces.MultiDiscrete): if torch: - raise NotImplementedError + raise NotImplementedError("MultiDiscrete action spaces not " + "supported for Pytorch.") return partial(MultiCategorical, input_lens=action_space.nvec), \ int(sum(action_space.nvec)) + return dist, dist.required_model_output_shape(action_space, config) + raise NotImplementedError("Unsupported args: {} {}".format( action_space, dist_type)) @@ -173,11 +185,16 @@ class ModelCatalog(object): action_placeholder (Tensor): A placeholder for the actions """ - if isinstance(action_space, gym.spaces.Box): - return tf.placeholder( - tf.float32, shape=(None, action_space.shape[0]), name="action") - elif isinstance(action_space, gym.spaces.Discrete): + if isinstance(action_space, gym.spaces.Discrete): return tf.placeholder(tf.int64, shape=(None, ), name="action") + elif isinstance(action_space, (gym.spaces.Box, Simplex)): + return tf.placeholder( + tf.float32, shape=(None, ) + action_space.shape, name="action") + elif isinstance(action_space, gym.spaces.MultiDiscrete): + return tf.placeholder( + tf.as_dtype(action_space.dtype), + shape=(None, ) + action_space.shape, + name="action") elif isinstance(action_space, gym.spaces.Tuple): size = 0 all_discrete = True @@ -191,14 +208,6 @@ class ModelCatalog(object): tf.int64 if all_discrete else tf.float32, shape=(None, size), name="action") - elif isinstance(action_space, Simplex): - return tf.placeholder( - tf.float32, shape=(None, action_space.shape[0]), name="action") - elif isinstance(action_space, gym.spaces.multi_discrete.MultiDiscrete): - return tf.placeholder( - tf.as_dtype(action_space.dtype), - shape=(None, len(action_space.nvec)), - name="action") else: raise NotImplementedError("action space {}" " not supported".format(action_space)) @@ -362,6 +371,21 @@ class ModelCatalog(object): """ _global_registry.register(RLLIB_MODEL, model_name, model_class) + @staticmethod + @PublicAPI + def register_custom_action_dist(action_dist_name, action_dist_class): + """Register a custom action distribution class by name. + + The model can be later used by specifying + {"custom_action_dist": action_dist_name} in the model config. + + Args: + model_name (str): Name to register the action distribution under. + model_class (type): Python class of the action distribution. + """ + _global_registry.register(RLLIB_ACTION_DIST, action_dist_name, + action_dist_class) + @staticmethod def _wrap_if_needed(model_cls, model_interface): assert issubclass(model_cls, TFModelV2) diff --git a/rllib/models/tf/tf_action_dist.py b/rllib/models/tf/tf_action_dist.py index 530d4bddd..48b5b40eb 100644 --- a/rllib/models/tf/tf_action_dist.py +++ b/rllib/models/tf/tf_action_dist.py @@ -17,8 +17,9 @@ class TFActionDistribution(ActionDistribution): """TF-specific extensions for building action distributions.""" @DeveloperAPI - def __init__(self, inputs): - super(TFActionDistribution, self).__init__(inputs) + def __init__(self, inputs, model_config): + super(TFActionDistribution, self).__init__( + inputs, model_config=model_config) self.sample_op = self._build_sample_op() @DeveloperAPI @@ -76,16 +77,22 @@ class Categorical(TFActionDistribution): def _build_sample_op(self): return tf.squeeze(tf.multinomial(self.inputs, 1), axis=1) + @staticmethod + @override(ActionDistribution) + def required_model_output_shape(action_space, model_config): + return action_space.n + class MultiCategorical(TFActionDistribution): - """Categorical distribution for discrete action spaces.""" + """MultiCategorical distribution for MultiDiscrete action spaces.""" - def __init__(self, inputs, input_lens): + def __init__(self, inputs, input_lens, model_config): self.cats = [ - Categorical(input_) + Categorical(input_, model_config=model_config) for input_ in tf.split(inputs, input_lens, axis=1) ] self.sample_op = self._build_sample_op() + self.model_config = model_config @override(ActionDistribution) def logp(self, actions): @@ -116,6 +123,11 @@ class MultiCategorical(TFActionDistribution): def _build_sample_op(self): return tf.stack([cat.sample() for cat in self.cats], axis=1) + @staticmethod + @override(ActionDistribution) + def required_model_output_shape(action_space, model_config): + return np.sum(action_space.nvec) + class DiagGaussian(TFActionDistribution): """Action distribution where each vector element is a gaussian. @@ -124,12 +136,12 @@ class DiagGaussian(TFActionDistribution): second half the gaussian standard deviations. """ - def __init__(self, inputs): + def __init__(self, inputs, model_config): mean, log_std = tf.split(inputs, 2, axis=1) self.mean = mean self.log_std = log_std self.std = tf.exp(log_std) - TFActionDistribution.__init__(self, inputs) + super(DiagGaussian, self).__init__(inputs, model_config) @override(ActionDistribution) def logp(self, x): @@ -157,6 +169,11 @@ class DiagGaussian(TFActionDistribution): def _build_sample_op(self): return self.mean + self.std * tf.random_normal(tf.shape(self.mean)) + @staticmethod + @override(ActionDistribution) + def required_model_output_shape(action_space, model_config): + return np.prod(action_space.shape) * 2 + class Deterministic(TFActionDistribution): """Action distribution that returns the input values directly. @@ -172,6 +189,11 @@ class Deterministic(TFActionDistribution): def _build_sample_op(self): return self.inputs + @staticmethod + @override(ActionDistribution) + def required_model_output_shape(action_space, model_config): + return np.prod(action_space.shape) + class MultiActionDistribution(TFActionDistribution): """Action distribution that operates for list of actions. @@ -180,12 +202,14 @@ class MultiActionDistribution(TFActionDistribution): inputs (Tensor list): A list of tensors from which to compute samples. """ - def __init__(self, inputs, action_space, child_distributions, input_lens): + def __init__(self, inputs, action_space, child_distributions, input_lens, + model_config): self.input_lens = input_lens split_inputs = tf.split(inputs, self.input_lens, axis=1) child_list = [] for i, distribution in enumerate(child_distributions): - child_list.append(distribution(split_inputs[i])) + child_list.append( + distribution(split_inputs[i], model_config=model_config)) self.child_distributions = child_list @override(ActionDistribution) @@ -241,7 +265,7 @@ class Dirichlet(TFActionDistribution): e.g. actions that represent resource allocation.""" - def __init__(self, inputs): + def __init__(self, inputs, model_config): """Input is a tensor of logits. The exponential of logits is used to parametrize the Dirichlet distribution as all parameters need to be positive. An arbitrary small epsilon is added to the concentration @@ -256,7 +280,8 @@ class Dirichlet(TFActionDistribution): validate_args=True, allow_nan_stats=False, ) - TFActionDistribution.__init__(self, concentration) + super(Dirichlet, self).__init__( + concentration, model_config=model_config) @override(ActionDistribution) def logp(self, x): @@ -278,3 +303,8 @@ class Dirichlet(TFActionDistribution): @override(TFActionDistribution) def _build_sample_op(self): return self.dist.sample() + + @staticmethod + @override(ActionDistribution) + def required_model_output_shape(action_space, model_config): + return np.prod(action_space.shape) diff --git a/rllib/models/torch/torch_action_dist.py b/rllib/models/torch/torch_action_dist.py index b8becc9a3..b1a373f15 100644 --- a/rllib/models/torch/torch_action_dist.py +++ b/rllib/models/torch/torch_action_dist.py @@ -7,6 +7,8 @@ try: except ImportError: pass # soft dep +import numpy as np + from ray.rllib.models.action_dist import ActionDistribution from ray.rllib.utils.annotations import override @@ -35,18 +37,30 @@ class TorchCategorical(TorchDistributionWrapper): """Wrapper class for PyTorch Categorical distribution.""" @override(ActionDistribution) - def __init__(self, inputs): + def __init__(self, inputs, model_config): self.dist = torch.distributions.categorical.Categorical(logits=inputs) + self.model_config = model_config + + @staticmethod + @override(ActionDistribution) + def required_model_output_shape(action_space, model_config): + return action_space.n class TorchDiagGaussian(TorchDistributionWrapper): """Wrapper class for PyTorch Normal distribution.""" @override(ActionDistribution) - def __init__(self, inputs): + def __init__(self, inputs, model_config): mean, log_std = torch.chunk(inputs, 2, dim=1) self.dist = torch.distributions.normal.Normal(mean, torch.exp(log_std)) + self.model_config = model_config @override(TorchDistributionWrapper) def logp(self, actions): return TorchDistributionWrapper.logp(self, actions).sum(-1) + + @staticmethod + @override(ActionDistribution) + def required_model_output_shape(action_space, model_config): + return np.prod(action_space.shape) * 2 diff --git a/rllib/policy/dynamic_tf_policy.py b/rllib/policy/dynamic_tf_policy.py index 871e8acb4..75c3a7d91 100644 --- a/rllib/policy/dynamic_tf_policy.py +++ b/rllib/policy/dynamic_tf_policy.py @@ -166,7 +166,8 @@ class DynamicTFPolicy(TFPolicy): self, self.model, self.input_dict, obs_space, action_space, config) else: - self.action_dist = self.dist_class(self.model_out) + self.action_dist = self.dist_class( + self.model_out, model_config=self.config["model"]) action_sampler = self.action_dist.sample() action_prob = self.action_dist.sampled_action_prob() @@ -261,7 +262,7 @@ class DynamicTFPolicy(TFPolicy): def _initialize_loss(self): def fake_array(tensor): shape = tensor.shape.as_list() - shape[0] = 1 + shape = [s if s is not None else 1 for s in shape] return np.zeros(shape, dtype=tensor.dtype.as_numpy_dtype) dummy_batch = { diff --git a/rllib/policy/torch_policy.py b/rllib/policy/torch_policy.py index 16eeb6a0a..1a151625e 100644 --- a/rllib/policy/torch_policy.py +++ b/rllib/policy/torch_policy.py @@ -78,7 +78,8 @@ class TorchPolicy(Policy): input_dict["prev_rewards"] = prev_reward_batch model_out = self._model(input_dict, state_batches, [1]) logits, state = model_out - action_dist = self._action_dist_cls(logits) + action_dist = self._action_dist_cls( + logits, model_config=self.config["model"]) actions = action_dist.sample() return (actions.cpu().numpy(), [h.cpu().numpy() for h in state], diff --git a/rllib/tests/test_catalog.py b/rllib/tests/test_catalog.py index 30f6d95f0..c72f5f37d 100644 --- a/rllib/tests/test_catalog.py +++ b/rllib/tests/test_catalog.py @@ -5,8 +5,9 @@ from gym.spaces import Box, Discrete, Tuple import ray -from ray.rllib.models import ModelCatalog +from ray.rllib.models import ModelCatalog, MODEL_DEFAULTS from ray.rllib.models.model import Model +from ray.rllib.models.tf.tf_action_dist import TFActionDistribution from ray.rllib.models.preprocessors import (NoPreprocessor, OneHotPreprocessor, Preprocessor) from ray.rllib.models.tf.fcnet_v1 import FullyConnectedNetwork @@ -31,6 +32,25 @@ class CustomModel(Model): return tf.constant([[0] * 5]), None +class CustomActionDistribution(TFActionDistribution): + @staticmethod + def required_model_output_shape(action_space, model_config=None): + custom_options = model_config["custom_options"] or {} + if custom_options is not None and custom_options.get("output_dim"): + return custom_options.get("output_dim") + return action_space.shape + + def _build_sample_op(self): + custom_options = self.model_config["custom_options"] + if "output_dim" in custom_options: + output_shape = tf.concat( + [tf.shape(self.inputs)[:1], custom_options["output_dim"]], + axis=0) + else: + output_shape = tf.shape(self.inputs) + return tf.random_uniform(output_shape) + + class ModelCatalogTest(unittest.TestCase): def tearDown(self): ray.shutdown() @@ -94,6 +114,41 @@ class ModelCatalogTest(unittest.TestCase): {"custom_model": "foo"}) self.assertEqual(str(type(p1)), str(CustomModel)) + def testCustomActionDistribution(self): + ray.init() + # registration + ModelCatalog.register_custom_action_dist("test", + CustomActionDistribution) + action_space = Box(0, 1, shape=(5, 3), dtype=np.float32) + + # test retrieving it + model_config = MODEL_DEFAULTS.copy() + model_config["custom_action_dist"] = "test" + dist_cls, param_shape = ModelCatalog.get_action_dist( + action_space, model_config) + self.assertEqual(str(dist_cls), str(CustomActionDistribution)) + self.assertEqual(param_shape, action_space.shape) + + # test the class works as a distribution + dist_input = tf.placeholder(tf.float32, (None, ) + param_shape) + dist = dist_cls(dist_input, model_config=model_config) + self.assertEqual(dist.sample().shape[1:], dist_input.shape[1:]) + self.assertIsInstance(dist.sample(), tf.Tensor) + with self.assertRaises(NotImplementedError): + dist.entropy() + + # test passing the options to it + model_config["custom_options"].update({"output_dim": (3, )}) + dist_cls, param_shape = ModelCatalog.get_action_dist( + action_space, model_config) + self.assertEqual(param_shape, (3, )) + dist_input = tf.placeholder(tf.float32, (None, ) + param_shape) + dist = dist_cls(dist_input, model_config=model_config) + self.assertEqual(dist.sample().shape[1:], dist_input.shape[1:]) + self.assertIsInstance(dist.sample(), tf.Tensor) + with self.assertRaises(NotImplementedError): + dist.entropy() + if __name__ == "__main__": unittest.main(verbosity=2)