Custom action distributions (#5164)

* custom action dist wip

* Test case for custom action dist

* ActionDistribution.get_parameter_shape_for_action_space pattern

* Edit exception message to also suggest using a custom action distribution

* Clean up ModelCatalog.get_action_dist

* Pass model config to ActionDistribution constructors

* Update custom action distribution test case

* Name fix

* Autoformatter

* parameter shape static methods for torch distributions

* Fix docstring

* Generalize fake array for graph initialization

* Fix action dist constructors

* Correct parameter shape static methods for multicategorical and gaussian

* Make suggested changes to custom action dists

* Correct instances of not passing model config to action dist

* Autoformatter

* fix tuple distribution constructor

* bugfix
This commit is contained in:
Matthew A. Wright
2019-08-06 18:13:16 +00:00
committed by Eric Liang
parent 94bff244e4
commit e3c9f7e83a
22 changed files with 252 additions and 73 deletions
+3 -1
View File
@@ -14,8 +14,10 @@ TRAINABLE_CLASS = "trainable_class"
ENV_CREATOR = "env_creator"
RLLIB_MODEL = "rllib_model"
RLLIB_PREPROCESSOR = "rllib_preprocessor"
RLLIB_ACTION_DIST = "rllib_action_dist"
KNOWN_CATEGORIES = [
TRAINABLE_CLASS, ENV_CREATOR, RLLIB_MODEL, RLLIB_PREPROCESSOR
TRAINABLE_CLASS, ENV_CREATOR, RLLIB_MODEL, RLLIB_PREPROCESSOR,
RLLIB_ACTION_DIST
]
logger = logging.getLogger(__name__)
+1 -1
View File
@@ -18,7 +18,7 @@ def actor_critic_loss(policy, batch_tensors):
SampleBatch.CUR_OBS: batch_tensors[SampleBatch.CUR_OBS]
}) # TODO(ekl) seq lens shouldn't be None
values = policy.model.value_function()
dist = policy.dist_class(logits)
dist = policy.dist_class(logits, policy.config["model"])
log_probs = dist.logp(batch_tensors[SampleBatch.ACTIONS])
policy.entropy = dist.entropy().mean()
policy.pi_err = -batch_tensors[Postprocessing.ADVANTAGES].dot(
+1 -1
View File
@@ -81,7 +81,7 @@ class GenericPolicy(object):
model = ModelCatalog.get_model({
"obs": self.inputs
}, obs_space, action_space, dist_dim, model_config)
dist = dist_class(model.outputs)
dist = dist_class(model.outputs, model_config=model_config)
self.sampler = dist.sample()
self.variables = ray.experimental.tf_utils.TensorFlowVariables(
+5 -3
View File
@@ -107,9 +107,10 @@ class QLoss(object):
class QValuePolicy(object):
def __init__(self, q_values, observations, num_actions, stochastic, eps,
softmax, softmax_temp):
softmax, softmax_temp, model_config):
if softmax:
action_dist = Categorical(q_values / softmax_temp)
action_dist = Categorical(
q_values / softmax_temp, model_config=model_config)
self.action = action_dist.sample()
self.action_prob = action_dist.sampled_action_prob()
return
@@ -255,7 +256,8 @@ def build_q_networks(policy, q_model, input_dict, obs_space, action_space,
# Action outputs
qvp = QValuePolicy(q_values, input_dict[SampleBatch.CUR_OBS],
action_space.n, policy.stochastic, policy.eps,
config["soft_q"], config["softmax_temp"])
config["soft_q"], config["softmax_temp"],
config["model"])
policy.output_actions, policy.action_prob = qvp.action, qvp.action_prob
return policy.output_actions, policy.action_prob
+1 -1
View File
@@ -59,7 +59,7 @@ class GenericPolicy(object):
model = ModelCatalog.get_model({
"obs": self.inputs
}, obs_space, action_space, dist_dim, model_options)
dist = dist_class(model.outputs)
dist = dist_class(model.outputs, model_config=model_options)
self.sampler = dist.sample()
self.variables = ray.experimental.tf_utils.TensorFlowVariables(
+13 -6
View File
@@ -49,13 +49,14 @@ VTraceReturns = collections.namedtuple("VTraceReturns", "vs pg_advantages")
def log_probs_from_logits_and_actions(policy_logits,
actions,
config,
dist_class=Categorical):
return multi_log_probs_from_logits_and_actions([policy_logits], [actions],
dist_class)[0]
dist_class, config)[0]
def multi_log_probs_from_logits_and_actions(policy_logits, actions,
dist_class):
def multi_log_probs_from_logits_and_actions(policy_logits, actions, dist_class,
config):
"""Computes action log-probs from policy logits and actions.
In the notation used throughout documentation and comments, T refers to the
@@ -76,6 +77,8 @@ def multi_log_probs_from_logits_and_actions(policy_logits, actions,
...,
[T, B, ...]
with actions.
dist_class: Python class of the action distribution
config: Trainer config dict
Returns:
A list with length of ACTION_SPACE of float32
@@ -97,7 +100,8 @@ def multi_log_probs_from_logits_and_actions(policy_logits, actions,
tf.concat([[-1], a_shape[2:]], axis=0))
log_probs.append(
tf.reshape(
dist_class(policy_logits_flat).logp(actions_flat),
dist_class(policy_logits_flat,
model_config=config["model"]).logp(actions_flat),
a_shape[:2]))
return log_probs
@@ -110,6 +114,7 @@ def from_logits(behaviour_policy_logits,
rewards,
values,
bootstrap_value,
config,
dist_class=Categorical,
clip_rho_threshold=1.0,
clip_pg_rho_threshold=1.0,
@@ -122,6 +127,7 @@ def from_logits(behaviour_policy_logits,
rewards,
values,
bootstrap_value,
config,
dist_class,
clip_rho_threshold=clip_rho_threshold,
clip_pg_rho_threshold=clip_pg_rho_threshold,
@@ -145,6 +151,7 @@ def multi_from_logits(behaviour_policy_logits,
rewards,
values,
bootstrap_value,
config,
dist_class,
clip_rho_threshold=1.0,
clip_pg_rho_threshold=1.0,
@@ -235,9 +242,9 @@ def multi_from_logits(behaviour_policy_logits,
discounts, rewards, values, bootstrap_value
]):
target_action_log_probs = multi_log_probs_from_logits_and_actions(
target_policy_logits, actions, dist_class)
target_policy_logits, actions, dist_class, config)
behaviour_action_log_probs = multi_log_probs_from_logits_and_actions(
behaviour_policy_logits, actions, dist_class)
behaviour_policy_logits, actions, dist_class, config)
log_rhos = get_log_rhos(target_action_log_probs,
behaviour_action_log_probs)
+5 -1
View File
@@ -41,6 +41,7 @@ class VTraceLoss(object):
bootstrap_value,
dist_class,
valid_mask,
config,
vf_loss_coeff=0.5,
entropy_coeff=0.01,
clip_rho_threshold=1.0,
@@ -72,6 +73,7 @@ class VTraceLoss(object):
bootstrap_value: A float32 tensor of shape [B].
dist_class: action distribution class for logits.
valid_mask: A bool tensor of valid RNN input elements (#2992).
config: Trainer config dict.
"""
# Compute vtrace on the CPU for better perf.
@@ -87,7 +89,8 @@ class VTraceLoss(object):
dist_class=dist_class,
clip_rho_threshold=tf.cast(clip_rho_threshold, tf.float32),
clip_pg_rho_threshold=tf.cast(clip_pg_rho_threshold,
tf.float32))
tf.float32),
config=config)
self.value_targets = self.vtrace_returns.vs
# The policy gradients loss
@@ -196,6 +199,7 @@ def build_vtrace_loss(policy, batch_tensors):
bootstrap_value=make_time_major(values)[-1],
dist_class=Categorical if is_multidiscrete else policy.dist_class,
valid_mask=make_time_major(mask, drop_last=True),
config=policy.config,
vf_loss_coeff=policy.config["vf_loss_coeff"],
entropy_coeff=policy.entropy_coeff,
clip_rho_threshold=policy.config["vtrace_clip_rho_threshold"],
+8 -3
View File
@@ -98,7 +98,7 @@ class LogProbsFromLogitsAndActionsTest(tf.test.TestCase,
0, num_actions - 1, size=(seq_len, batch_size), dtype=np.int32)
action_log_probs_tensor = vtrace.log_probs_from_logits_and_actions(
policy_logits, actions)
policy_logits, actions, {"model": None}) # dummy config dict
# Ground Truth
# Using broadcasting to create a mask that indexes action logits
@@ -159,6 +159,8 @@ class VtraceTest(tf.test.TestCase, parameterized.TestCase):
clip_rho_threshold = None # No clipping.
clip_pg_rho_threshold = None # No clipping.
dummy_config = {"model": None}
# Intentionally leaving shapes unspecified to test if V-trace can
# deal with that.
placeholders = {
@@ -178,12 +180,15 @@ class VtraceTest(tf.test.TestCase, parameterized.TestCase):
from_logits_output = vtrace.from_logits(
clip_rho_threshold=clip_rho_threshold,
clip_pg_rho_threshold=clip_pg_rho_threshold,
config=dummy_config,
**placeholders)
target_log_probs = vtrace.log_probs_from_logits_and_actions(
placeholders["target_policy_logits"], placeholders["actions"])
placeholders["target_policy_logits"], placeholders["actions"],
dummy_config)
behaviour_log_probs = vtrace.log_probs_from_logits_and_actions(
placeholders["behaviour_policy_logits"], placeholders["actions"])
placeholders["behaviour_policy_logits"], placeholders["actions"],
dummy_config)
log_rhos = target_log_probs - behaviour_log_probs
ground_truth = (log_rhos, behaviour_log_probs, target_log_probs)
+6 -5
View File
@@ -29,7 +29,7 @@ class ValueLoss(object):
class ReweightedImitationLoss(object):
def __init__(self, state_values, cumulative_rewards, logits, actions,
action_space, beta):
action_space, beta, model_config):
ma_adv_norm = tf.get_variable(
name="moving_average_of_advantage_norm",
dtype=tf.float32,
@@ -48,8 +48,8 @@ class ReweightedImitationLoss(object):
beta * tf.divide(adv, 1e-8 + tf.sqrt(ma_adv_norm)))
# log\pi_\theta(a|s)
dist_cls, _ = ModelCatalog.get_action_dist(action_space, {})
action_dist = dist_cls(logits)
dist_cls, _ = ModelCatalog.get_action_dist(action_space, model_config)
action_dist = dist_cls(logits, model_config=model_config)
logprobs = action_dist.logp(actions)
self.loss = -1.0 * tf.reduce_mean(
@@ -106,7 +106,7 @@ class MARWILPolicy(MARWILPostprocessing, TFPolicy):
self.p_func_vars = scope_vars(scope.name)
# Action outputs
action_dist = dist_cls(logits)
action_dist = dist_cls(logits, model_config=self.config["model"])
self.output_actions = action_dist.sample()
# Training inputs
@@ -164,7 +164,8 @@ class MARWILPolicy(MARWILPostprocessing, TFPolicy):
def _build_policy_loss(self, state_values, cum_rwds, logits, actions,
action_space):
return ReweightedImitationLoss(state_values, cum_rwds, logits, actions,
action_space, self.config["beta"])
action_space, self.config["beta"],
self.config["model"])
@override(TFPolicy)
def extra_compute_grad_fetches(self):
+2 -1
View File
@@ -13,7 +13,8 @@ def pg_torch_loss(policy, batch_tensors):
logits, _ = policy.model({
SampleBatch.CUR_OBS: batch_tensors[SampleBatch.CUR_OBS]
})
action_dist = policy.dist_class(logits)
action_dist = policy.dist_class(
logits, model_config=policy.config["model"])
log_probs = action_dist.logp(batch_tensors[SampleBatch.ACTIONS])
# save the error in the policy object
policy.pi_err = -batch_tensors[Postprocessing.ADVANTAGES].dot(
+8 -2
View File
@@ -112,6 +112,7 @@ class VTraceSurrogateLoss(object):
rewards,
values,
bootstrap_value,
config,
dist_class,
valid_mask,
vf_loss_coeff=0.5,
@@ -143,6 +144,7 @@ class VTraceSurrogateLoss(object):
rewards: A float32 tensor of shape [T, B].
values: A float32 tensor of shape [T, B].
bootstrap_value: A float32 tensor of shape [B].
config: Trainer config dict.
dist_class: action distribution class for logits.
valid_mask: A bool tensor of valid RNN input elements (#2992).
vf_loss_coeff (float): Coefficient of the value function loss.
@@ -165,6 +167,7 @@ class VTraceSurrogateLoss(object):
rewards=rewards,
values=values,
bootstrap_value=bootstrap_value,
config=config,
dist_class=dist_class,
clip_rho_threshold=tf.cast(clip_rho_threshold, tf.float32),
clip_pg_rho_threshold=tf.cast(clip_pg_rho_threshold,
@@ -251,8 +254,10 @@ def build_appo_surrogate_loss(policy, batch_tensors):
old_policy_behaviour_logits, output_hidden_shape, axis=1)
unpacked_outputs = tf.split(policy.model_out, output_hidden_shape, axis=1)
action_dist = policy.action_dist
old_policy_action_dist = policy.dist_class(old_policy_behaviour_logits)
prev_action_dist = policy.dist_class(behaviour_logits)
old_policy_action_dist = policy.dist_class(
old_policy_behaviour_logits, model_config=policy.config["model"])
prev_action_dist = policy.dist_class(
behaviour_logits, model_config=policy.config["model"])
values = policy.value_function
policy.model_vars = policy.model.variables()
@@ -298,6 +303,7 @@ def build_appo_surrogate_loss(policy, batch_tensors):
rewards=make_time_major(rewards, drop_last=True),
values=make_time_major(values, drop_last=True),
bootstrap_value=make_time_major(values)[-1],
config=policy.config,
dist_class=Categorical if is_multidiscrete else policy.dist_class,
valid_mask=make_time_major(mask, drop_last=True),
vf_loss_coeff=policy.config["vf_loss_coeff"],
+8 -4
View File
@@ -39,7 +39,8 @@ class PPOLoss(object):
clip_param=0.1,
vf_clip_param=0.1,
vf_loss_coeff=1.0,
use_gae=True):
use_gae=True,
model_config=None):
"""Constructs the loss for Proximal Policy Objective.
Arguments:
@@ -65,13 +66,15 @@ class PPOLoss(object):
vf_clip_param (float): Clip parameter for the value function
vf_loss_coeff (float): Coefficient of the value function loss
use_gae (bool): If true, use the Generalized Advantage Estimator.
model_config (dict): (Optional) model config for use in specifying
action distributions.
"""
def reduce_mean_valid(t):
return tf.reduce_mean(tf.boolean_mask(t, valid_mask))
dist_cls, _ = ModelCatalog.get_action_dist(action_space, {})
prev_dist = dist_cls(logits)
dist_cls, _ = ModelCatalog.get_action_dist(action_space, model_config)
prev_dist = dist_cls(logits, model_config=model_config)
# Make loss functions.
logp_ratio = tf.exp(
curr_action_dist.logp(actions) - prev_dist.logp(actions))
@@ -129,7 +132,8 @@ def ppo_surrogate_loss(policy, batch_tensors):
clip_param=policy.config["clip_param"],
vf_clip_param=policy.config["vf_clip_param"],
vf_loss_coeff=policy.config["vf_loss_coeff"],
use_gae=policy.config["use_gae"])
use_gae=policy.config["use_gae"],
model_config=policy.config["model"])
return policy.loss_obj.loss
+1 -1
View File
@@ -20,7 +20,7 @@ class DistributionsTest(unittest.TestCase):
logits = tf.placeholder(tf.float32, shape=(None, 10))
z = 8 * (np.random.rand(10) - 0.5)
data = np.tile(z, (num_samples, 1))
c = Categorical(logits)
c = Categorical(logits, {}) # dummy config dict
sample_op = c.sample()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
+1 -1
View File
@@ -67,7 +67,7 @@ class CustomLossModel(Model):
print("FYI: You can also use these tensors: {}, ".format(loss_inputs))
# compute the IL loss
action_dist = Categorical(logits)
action_dist = Categorical(logits, self.options)
self.policy_loss = policy_loss
self.imitation_loss = tf.reduce_mean(
-action_dist.logp(input_ops["actions"]))
+1 -1
View File
@@ -18,7 +18,7 @@ def policy_gradient_loss(policy, batch_tensors):
logits, _ = policy.model({
SampleBatch.CUR_OBS: batch_tensors[SampleBatch.CUR_OBS]
})
action_dist = policy.dist_class(logits)
action_dist = policy.dist_class(logits, policy.config["model"])
log_probs = action_dist.logp(batch_tensors[SampleBatch.ACTIONS])
return -batch_tensors[SampleBatch.REWARDS].dot(log_probs)
+23 -1
View File
@@ -11,11 +11,14 @@ class ActionDistribution(object):
Args:
inputs (Tensor): The input vector to compute samples from.
model_config (dict): Model config dict (as defined in catalog.py).
Must always be passed; may be empty or None when unused.
"""
@DeveloperAPI
def __init__(self, inputs):
def __init__(self, inputs, model_config):
self.inputs = inputs
self.model_config = model_config
@DeveloperAPI
def sample(self):
@@ -52,3 +55,22 @@ class ActionDistribution(object):
MultiDiscrete. TODO(ekl) consider removing this.
"""
return self.entropy()
@DeveloperAPI
@staticmethod
def required_model_output_shape(action_space, model_config):
"""Returns the required shape of an input parameter tensor for a
particular action space and an optional dict of distribution-specific
options.
Args:
action_space (gym.Space): The action space this distribution will
be used for, whose shape attributes will be used to determine
the required shape of the input parameter tensor.
model_config (dict): Model's config dict (as defined in catalog.py)
Returns:
model_output_shape (int or np.ndarray of ints): size of the
required input vector (minus leading batch dimension).
"""
raise NotImplementedError
+47 -23
View File
@@ -8,7 +8,7 @@ import numpy as np
from functools import partial
from ray.tune.registry import RLLIB_MODEL, RLLIB_PREPROCESSOR, \
_global_registry
RLLIB_ACTION_DIST, _global_registry
from ray.rllib.models.extra_spaces import Simplex
from ray.rllib.models.torch.torch_action_dist import (TorchCategorical,
@@ -80,6 +80,8 @@ MODEL_DEFAULTS = {
"custom_preprocessor": None,
# Name of a custom model to use
"custom_model": None,
# Name of a custom action distribution to use
"custom_action_dist": None,
# Extra options to pass to the custom classes
"custom_options": {},
}
@@ -119,22 +121,30 @@ class ModelCatalog(object):
"""
config = config or MODEL_DEFAULTS
if isinstance(action_space, gym.spaces.Box):
if config.get("custom_action_dist"):
action_dist_name = config["custom_action_dist"]
logger.debug(
"Using custom action distribution {}".format(action_dist_name))
dist = _global_registry.get(RLLIB_ACTION_DIST, action_dist_name)
elif isinstance(action_space, gym.spaces.Box):
if len(action_space.shape) > 1:
raise UnsupportedSpaceException(
"Action space has multiple dimensions "
"{}. ".format(action_space.shape) +
"Consider reshaping this into a single dimension, "
"using a custom action distribution, "
"using a Tuple action space, or the multi-agent API.")
if dist_type is None:
dist = TorchDiagGaussian if torch else DiagGaussian
return dist, action_space.shape[0] * 2
elif dist_type == "deterministic":
return Deterministic, action_space.shape[0]
dist = Deterministic
elif isinstance(action_space, gym.spaces.Discrete):
dist = TorchCategorical if torch else Categorical
return dist, action_space.n
elif isinstance(action_space, gym.spaces.Tuple):
if torch:
raise NotImplementedError("Tuple action spaces not supported "
"for Pytorch.")
child_dist = []
input_lens = []
for action in action_space.spaces:
@@ -142,8 +152,6 @@ class ModelCatalog(object):
action, config)
child_dist.append(dist)
input_lens.append(action_size)
if torch:
raise NotImplementedError
return partial(
MultiActionDistribution,
child_distributions=child_dist,
@@ -151,14 +159,18 @@ class ModelCatalog(object):
input_lens=input_lens), sum(input_lens)
elif isinstance(action_space, Simplex):
if torch:
raise NotImplementedError
return Dirichlet, action_space.shape[0]
elif isinstance(action_space, gym.spaces.multi_discrete.MultiDiscrete):
raise NotImplementedError("Simplex action spaces not "
"supported for Pytorch.")
dist = Dirichlet
elif isinstance(action_space, gym.spaces.MultiDiscrete):
if torch:
raise NotImplementedError
raise NotImplementedError("MultiDiscrete action spaces not "
"supported for Pytorch.")
return partial(MultiCategorical, input_lens=action_space.nvec), \
int(sum(action_space.nvec))
return dist, dist.required_model_output_shape(action_space, config)
raise NotImplementedError("Unsupported args: {} {}".format(
action_space, dist_type))
@@ -173,11 +185,16 @@ class ModelCatalog(object):
action_placeholder (Tensor): A placeholder for the actions
"""
if isinstance(action_space, gym.spaces.Box):
return tf.placeholder(
tf.float32, shape=(None, action_space.shape[0]), name="action")
elif isinstance(action_space, gym.spaces.Discrete):
if isinstance(action_space, gym.spaces.Discrete):
return tf.placeholder(tf.int64, shape=(None, ), name="action")
elif isinstance(action_space, (gym.spaces.Box, Simplex)):
return tf.placeholder(
tf.float32, shape=(None, ) + action_space.shape, name="action")
elif isinstance(action_space, gym.spaces.MultiDiscrete):
return tf.placeholder(
tf.as_dtype(action_space.dtype),
shape=(None, ) + action_space.shape,
name="action")
elif isinstance(action_space, gym.spaces.Tuple):
size = 0
all_discrete = True
@@ -191,14 +208,6 @@ class ModelCatalog(object):
tf.int64 if all_discrete else tf.float32,
shape=(None, size),
name="action")
elif isinstance(action_space, Simplex):
return tf.placeholder(
tf.float32, shape=(None, action_space.shape[0]), name="action")
elif isinstance(action_space, gym.spaces.multi_discrete.MultiDiscrete):
return tf.placeholder(
tf.as_dtype(action_space.dtype),
shape=(None, len(action_space.nvec)),
name="action")
else:
raise NotImplementedError("action space {}"
" not supported".format(action_space))
@@ -362,6 +371,21 @@ class ModelCatalog(object):
"""
_global_registry.register(RLLIB_MODEL, model_name, model_class)
@staticmethod
@PublicAPI
def register_custom_action_dist(action_dist_name, action_dist_class):
"""Register a custom action distribution class by name.
The action distribution can be later used by specifying
{"custom_action_dist": action_dist_name} in the model config.
Args:
action_dist_name (str): Name to register the action distribution under.
action_dist_class (type): Python class of the action distribution.
"""
_global_registry.register(RLLIB_ACTION_DIST, action_dist_name,
action_dist_class)
@staticmethod
def _wrap_if_needed(model_cls, model_interface):
assert issubclass(model_cls, TFModelV2)
+41 -11
View File
@@ -17,8 +17,9 @@ class TFActionDistribution(ActionDistribution):
"""TF-specific extensions for building action distributions."""
@DeveloperAPI
def __init__(self, inputs):
super(TFActionDistribution, self).__init__(inputs)
def __init__(self, inputs, model_config):
super(TFActionDistribution, self).__init__(
inputs, model_config=model_config)
self.sample_op = self._build_sample_op()
@DeveloperAPI
@@ -76,16 +77,22 @@ class Categorical(TFActionDistribution):
def _build_sample_op(self):
return tf.squeeze(tf.multinomial(self.inputs, 1), axis=1)
@staticmethod
@override(ActionDistribution)
def required_model_output_shape(action_space, model_config):
return action_space.n
class MultiCategorical(TFActionDistribution):
"""Categorical distribution for discrete action spaces."""
"""MultiCategorical distribution for MultiDiscrete action spaces."""
def __init__(self, inputs, input_lens):
def __init__(self, inputs, input_lens, model_config):
self.cats = [
Categorical(input_)
Categorical(input_, model_config=model_config)
for input_ in tf.split(inputs, input_lens, axis=1)
]
self.sample_op = self._build_sample_op()
self.model_config = model_config
@override(ActionDistribution)
def logp(self, actions):
@@ -116,6 +123,11 @@ class MultiCategorical(TFActionDistribution):
def _build_sample_op(self):
return tf.stack([cat.sample() for cat in self.cats], axis=1)
@staticmethod
@override(ActionDistribution)
def required_model_output_shape(action_space, model_config):
return np.sum(action_space.nvec)
class DiagGaussian(TFActionDistribution):
"""Action distribution where each vector element is a gaussian.
@@ -124,12 +136,12 @@ class DiagGaussian(TFActionDistribution):
second half the gaussian standard deviations.
"""
def __init__(self, inputs):
def __init__(self, inputs, model_config):
mean, log_std = tf.split(inputs, 2, axis=1)
self.mean = mean
self.log_std = log_std
self.std = tf.exp(log_std)
TFActionDistribution.__init__(self, inputs)
super(DiagGaussian, self).__init__(inputs, model_config)
@override(ActionDistribution)
def logp(self, x):
@@ -157,6 +169,11 @@ class DiagGaussian(TFActionDistribution):
def _build_sample_op(self):
return self.mean + self.std * tf.random_normal(tf.shape(self.mean))
@staticmethod
@override(ActionDistribution)
def required_model_output_shape(action_space, model_config):
return np.prod(action_space.shape) * 2
class Deterministic(TFActionDistribution):
"""Action distribution that returns the input values directly.
@@ -172,6 +189,11 @@ class Deterministic(TFActionDistribution):
def _build_sample_op(self):
return self.inputs
@staticmethod
@override(ActionDistribution)
def required_model_output_shape(action_space, model_config):
return np.prod(action_space.shape)
class MultiActionDistribution(TFActionDistribution):
"""Action distribution that operates for list of actions.
@@ -180,12 +202,14 @@ class MultiActionDistribution(TFActionDistribution):
inputs (Tensor list): A list of tensors from which to compute samples.
"""
def __init__(self, inputs, action_space, child_distributions, input_lens):
def __init__(self, inputs, action_space, child_distributions, input_lens,
model_config):
self.input_lens = input_lens
split_inputs = tf.split(inputs, self.input_lens, axis=1)
child_list = []
for i, distribution in enumerate(child_distributions):
child_list.append(distribution(split_inputs[i]))
child_list.append(
distribution(split_inputs[i], model_config=model_config))
self.child_distributions = child_list
@override(ActionDistribution)
@@ -241,7 +265,7 @@ class Dirichlet(TFActionDistribution):
e.g. actions that represent resource allocation."""
def __init__(self, inputs):
def __init__(self, inputs, model_config):
"""Input is a tensor of logits. The exponential of logits is used to
parametrize the Dirichlet distribution as all parameters need to be
positive. An arbitrary small epsilon is added to the concentration
@@ -256,7 +280,8 @@ class Dirichlet(TFActionDistribution):
validate_args=True,
allow_nan_stats=False,
)
TFActionDistribution.__init__(self, concentration)
super(Dirichlet, self).__init__(
concentration, model_config=model_config)
@override(ActionDistribution)
def logp(self, x):
@@ -278,3 +303,8 @@ class Dirichlet(TFActionDistribution):
@override(TFActionDistribution)
def _build_sample_op(self):
return self.dist.sample()
@staticmethod
@override(ActionDistribution)
def required_model_output_shape(action_space, model_config):
return np.prod(action_space.shape)
+16 -2
View File
@@ -7,6 +7,8 @@ try:
except ImportError:
pass # soft dep
import numpy as np
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.utils.annotations import override
@@ -35,18 +37,30 @@ class TorchCategorical(TorchDistributionWrapper):
"""Wrapper class for PyTorch Categorical distribution."""
@override(ActionDistribution)
def __init__(self, inputs):
def __init__(self, inputs, model_config):
self.dist = torch.distributions.categorical.Categorical(logits=inputs)
self.model_config = model_config
@staticmethod
@override(ActionDistribution)
def required_model_output_shape(action_space, model_config):
return action_space.n
class TorchDiagGaussian(TorchDistributionWrapper):
"""Wrapper class for PyTorch Normal distribution."""
@override(ActionDistribution)
def __init__(self, inputs):
def __init__(self, inputs, model_config):
mean, log_std = torch.chunk(inputs, 2, dim=1)
self.dist = torch.distributions.normal.Normal(mean, torch.exp(log_std))
self.model_config = model_config
@override(TorchDistributionWrapper)
def logp(self, actions):
return TorchDistributionWrapper.logp(self, actions).sum(-1)
@staticmethod
@override(ActionDistribution)
def required_model_output_shape(action_space, model_config):
return np.prod(action_space.shape) * 2
+3 -2
View File
@@ -166,7 +166,8 @@ class DynamicTFPolicy(TFPolicy):
self, self.model, self.input_dict, obs_space, action_space,
config)
else:
self.action_dist = self.dist_class(self.model_out)
self.action_dist = self.dist_class(
self.model_out, model_config=self.config["model"])
action_sampler = self.action_dist.sample()
action_prob = self.action_dist.sampled_action_prob()
@@ -261,7 +262,7 @@ class DynamicTFPolicy(TFPolicy):
def _initialize_loss(self):
def fake_array(tensor):
shape = tensor.shape.as_list()
shape[0] = 1
shape = [s if s is not None else 1 for s in shape]
return np.zeros(shape, dtype=tensor.dtype.as_numpy_dtype)
dummy_batch = {
+2 -1
View File
@@ -78,7 +78,8 @@ class TorchPolicy(Policy):
input_dict["prev_rewards"] = prev_reward_batch
model_out = self._model(input_dict, state_batches, [1])
logits, state = model_out
action_dist = self._action_dist_cls(logits)
action_dist = self._action_dist_cls(
logits, model_config=self.config["model"])
actions = action_dist.sample()
return (actions.cpu().numpy(),
[h.cpu().numpy() for h in state],
+56 -1
View File
@@ -5,8 +5,9 @@ from gym.spaces import Box, Discrete, Tuple
import ray
from ray.rllib.models import ModelCatalog
from ray.rllib.models import ModelCatalog, MODEL_DEFAULTS
from ray.rllib.models.model import Model
from ray.rllib.models.tf.tf_action_dist import TFActionDistribution
from ray.rllib.models.preprocessors import (NoPreprocessor, OneHotPreprocessor,
Preprocessor)
from ray.rllib.models.tf.fcnet_v1 import FullyConnectedNetwork
@@ -31,6 +32,25 @@ class CustomModel(Model):
return tf.constant([[0] * 5]), None
class CustomActionDistribution(TFActionDistribution):
@staticmethod
def required_model_output_shape(action_space, model_config=None):
custom_options = model_config["custom_options"] or {}
if custom_options is not None and custom_options.get("output_dim"):
return custom_options.get("output_dim")
return action_space.shape
def _build_sample_op(self):
custom_options = self.model_config["custom_options"]
if "output_dim" in custom_options:
output_shape = tf.concat(
[tf.shape(self.inputs)[:1], custom_options["output_dim"]],
axis=0)
else:
output_shape = tf.shape(self.inputs)
return tf.random_uniform(output_shape)
class ModelCatalogTest(unittest.TestCase):
def tearDown(self):
ray.shutdown()
@@ -94,6 +114,41 @@ class ModelCatalogTest(unittest.TestCase):
{"custom_model": "foo"})
self.assertEqual(str(type(p1)), str(CustomModel))
def testCustomActionDistribution(self):
ray.init()
# registration
ModelCatalog.register_custom_action_dist("test",
CustomActionDistribution)
action_space = Box(0, 1, shape=(5, 3), dtype=np.float32)
# test retrieving it
model_config = MODEL_DEFAULTS.copy()
model_config["custom_action_dist"] = "test"
dist_cls, param_shape = ModelCatalog.get_action_dist(
action_space, model_config)
self.assertEqual(str(dist_cls), str(CustomActionDistribution))
self.assertEqual(param_shape, action_space.shape)
# test the class works as a distribution
dist_input = tf.placeholder(tf.float32, (None, ) + param_shape)
dist = dist_cls(dist_input, model_config=model_config)
self.assertEqual(dist.sample().shape[1:], dist_input.shape[1:])
self.assertIsInstance(dist.sample(), tf.Tensor)
with self.assertRaises(NotImplementedError):
dist.entropy()
# test passing the options to it
model_config["custom_options"].update({"output_dim": (3, )})
dist_cls, param_shape = ModelCatalog.get_action_dist(
action_space, model_config)
self.assertEqual(param_shape, (3, ))
dist_input = tf.placeholder(tf.float32, (None, ) + param_shape)
dist = dist_cls(dist_input, model_config=model_config)
self.assertEqual(dist.sample().shape[1:], dist_input.shape[1:])
self.assertIsInstance(dist.sample(), tf.Tensor)
with self.assertRaises(NotImplementedError):
dist.entropy()
if __name__ == "__main__":
unittest.main(verbosity=2)