mirror of
https://github.com/wassname/ray.git
synced 2026-07-05 13:44:47 +08:00
Custom action distributions (#5164)
* custom action dist wip * Test case for custom action dist * ActionDistribution.get_parameter_shape_for_action_space pattern * Edit exception message to also suggest using a custom action distribution * Clean up ModelCatalog.get_action_dist * Pass model config to ActionDistribution constructors * Update custom action distribution test case * Name fix * Autoformatter * parameter shape static methods for torch distributions * Fix docstring * Generalize fake array for graph initialization * Fix action dist constructors * Correct parameter shape static methods for multicategorical and gaussian * Make suggested changes to custom action dist's * Correct instances of not passing model config to action dist * Autoformatter * fix tuple distribution constructor * bugfix
This commit is contained in:
committed by
Eric Liang
parent
94bff244e4
commit
e3c9f7e83a
@@ -14,8 +14,10 @@ TRAINABLE_CLASS = "trainable_class"
|
||||
ENV_CREATOR = "env_creator"
|
||||
RLLIB_MODEL = "rllib_model"
|
||||
RLLIB_PREPROCESSOR = "rllib_preprocessor"
|
||||
RLLIB_ACTION_DIST = "rllib_action_dist"
|
||||
KNOWN_CATEGORIES = [
|
||||
TRAINABLE_CLASS, ENV_CREATOR, RLLIB_MODEL, RLLIB_PREPROCESSOR
|
||||
TRAINABLE_CLASS, ENV_CREATOR, RLLIB_MODEL, RLLIB_PREPROCESSOR,
|
||||
RLLIB_ACTION_DIST
|
||||
]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -18,7 +18,7 @@ def actor_critic_loss(policy, batch_tensors):
|
||||
SampleBatch.CUR_OBS: batch_tensors[SampleBatch.CUR_OBS]
|
||||
}) # TODO(ekl) seq lens shouldn't be None
|
||||
values = policy.model.value_function()
|
||||
dist = policy.dist_class(logits)
|
||||
dist = policy.dist_class(logits, policy.config["model"])
|
||||
log_probs = dist.logp(batch_tensors[SampleBatch.ACTIONS])
|
||||
policy.entropy = dist.entropy().mean()
|
||||
policy.pi_err = -batch_tensors[Postprocessing.ADVANTAGES].dot(
|
||||
|
||||
@@ -81,7 +81,7 @@ class GenericPolicy(object):
|
||||
model = ModelCatalog.get_model({
|
||||
"obs": self.inputs
|
||||
}, obs_space, action_space, dist_dim, model_config)
|
||||
dist = dist_class(model.outputs)
|
||||
dist = dist_class(model.outputs, model_config=model_config)
|
||||
self.sampler = dist.sample()
|
||||
|
||||
self.variables = ray.experimental.tf_utils.TensorFlowVariables(
|
||||
|
||||
@@ -107,9 +107,10 @@ class QLoss(object):
|
||||
|
||||
class QValuePolicy(object):
|
||||
def __init__(self, q_values, observations, num_actions, stochastic, eps,
|
||||
softmax, softmax_temp):
|
||||
softmax, softmax_temp, model_config):
|
||||
if softmax:
|
||||
action_dist = Categorical(q_values / softmax_temp)
|
||||
action_dist = Categorical(
|
||||
q_values / softmax_temp, model_config=model_config)
|
||||
self.action = action_dist.sample()
|
||||
self.action_prob = action_dist.sampled_action_prob()
|
||||
return
|
||||
@@ -255,7 +256,8 @@ def build_q_networks(policy, q_model, input_dict, obs_space, action_space,
|
||||
# Action outputs
|
||||
qvp = QValuePolicy(q_values, input_dict[SampleBatch.CUR_OBS],
|
||||
action_space.n, policy.stochastic, policy.eps,
|
||||
config["soft_q"], config["softmax_temp"])
|
||||
config["soft_q"], config["softmax_temp"],
|
||||
config["model"])
|
||||
policy.output_actions, policy.action_prob = qvp.action, qvp.action_prob
|
||||
|
||||
return policy.output_actions, policy.action_prob
|
||||
|
||||
@@ -59,7 +59,7 @@ class GenericPolicy(object):
|
||||
model = ModelCatalog.get_model({
|
||||
"obs": self.inputs
|
||||
}, obs_space, action_space, dist_dim, model_options)
|
||||
dist = dist_class(model.outputs)
|
||||
dist = dist_class(model.outputs, model_config=model_options)
|
||||
self.sampler = dist.sample()
|
||||
|
||||
self.variables = ray.experimental.tf_utils.TensorFlowVariables(
|
||||
|
||||
@@ -49,13 +49,14 @@ VTraceReturns = collections.namedtuple("VTraceReturns", "vs pg_advantages")
|
||||
|
||||
def log_probs_from_logits_and_actions(policy_logits,
|
||||
actions,
|
||||
config,
|
||||
dist_class=Categorical):
|
||||
return multi_log_probs_from_logits_and_actions([policy_logits], [actions],
|
||||
dist_class)[0]
|
||||
dist_class, config)[0]
|
||||
|
||||
|
||||
def multi_log_probs_from_logits_and_actions(policy_logits, actions,
|
||||
dist_class):
|
||||
def multi_log_probs_from_logits_and_actions(policy_logits, actions, dist_class,
|
||||
config):
|
||||
"""Computes action log-probs from policy logits and actions.
|
||||
|
||||
In the notation used throughout documentation and comments, T refers to the
|
||||
@@ -76,6 +77,8 @@ def multi_log_probs_from_logits_and_actions(policy_logits, actions,
|
||||
...,
|
||||
[T, B, ...]
|
||||
with actions.
|
||||
dist_class: Python class of the action distribution
|
||||
config: Trainer config dict
|
||||
|
||||
Returns:
|
||||
A list with length of ACTION_SPACE of float32
|
||||
@@ -97,7 +100,8 @@ def multi_log_probs_from_logits_and_actions(policy_logits, actions,
|
||||
tf.concat([[-1], a_shape[2:]], axis=0))
|
||||
log_probs.append(
|
||||
tf.reshape(
|
||||
dist_class(policy_logits_flat).logp(actions_flat),
|
||||
dist_class(policy_logits_flat,
|
||||
model_config=config["model"]).logp(actions_flat),
|
||||
a_shape[:2]))
|
||||
|
||||
return log_probs
|
||||
@@ -110,6 +114,7 @@ def from_logits(behaviour_policy_logits,
|
||||
rewards,
|
||||
values,
|
||||
bootstrap_value,
|
||||
config,
|
||||
dist_class=Categorical,
|
||||
clip_rho_threshold=1.0,
|
||||
clip_pg_rho_threshold=1.0,
|
||||
@@ -122,6 +127,7 @@ def from_logits(behaviour_policy_logits,
|
||||
rewards,
|
||||
values,
|
||||
bootstrap_value,
|
||||
config,
|
||||
dist_class,
|
||||
clip_rho_threshold=clip_rho_threshold,
|
||||
clip_pg_rho_threshold=clip_pg_rho_threshold,
|
||||
@@ -145,6 +151,7 @@ def multi_from_logits(behaviour_policy_logits,
|
||||
rewards,
|
||||
values,
|
||||
bootstrap_value,
|
||||
config,
|
||||
dist_class,
|
||||
clip_rho_threshold=1.0,
|
||||
clip_pg_rho_threshold=1.0,
|
||||
@@ -235,9 +242,9 @@ def multi_from_logits(behaviour_policy_logits,
|
||||
discounts, rewards, values, bootstrap_value
|
||||
]):
|
||||
target_action_log_probs = multi_log_probs_from_logits_and_actions(
|
||||
target_policy_logits, actions, dist_class)
|
||||
target_policy_logits, actions, dist_class, config)
|
||||
behaviour_action_log_probs = multi_log_probs_from_logits_and_actions(
|
||||
behaviour_policy_logits, actions, dist_class)
|
||||
behaviour_policy_logits, actions, dist_class, config)
|
||||
|
||||
log_rhos = get_log_rhos(target_action_log_probs,
|
||||
behaviour_action_log_probs)
|
||||
|
||||
@@ -41,6 +41,7 @@ class VTraceLoss(object):
|
||||
bootstrap_value,
|
||||
dist_class,
|
||||
valid_mask,
|
||||
config,
|
||||
vf_loss_coeff=0.5,
|
||||
entropy_coeff=0.01,
|
||||
clip_rho_threshold=1.0,
|
||||
@@ -72,6 +73,7 @@ class VTraceLoss(object):
|
||||
bootstrap_value: A float32 tensor of shape [B].
|
||||
dist_class: action distribution class for logits.
|
||||
valid_mask: A bool tensor of valid RNN input elements (#2992).
|
||||
config: Trainer config dict.
|
||||
"""
|
||||
|
||||
# Compute vtrace on the CPU for better perf.
|
||||
@@ -87,7 +89,8 @@ class VTraceLoss(object):
|
||||
dist_class=dist_class,
|
||||
clip_rho_threshold=tf.cast(clip_rho_threshold, tf.float32),
|
||||
clip_pg_rho_threshold=tf.cast(clip_pg_rho_threshold,
|
||||
tf.float32))
|
||||
tf.float32),
|
||||
config=config)
|
||||
self.value_targets = self.vtrace_returns.vs
|
||||
|
||||
# The policy gradients loss
|
||||
@@ -196,6 +199,7 @@ def build_vtrace_loss(policy, batch_tensors):
|
||||
bootstrap_value=make_time_major(values)[-1],
|
||||
dist_class=Categorical if is_multidiscrete else policy.dist_class,
|
||||
valid_mask=make_time_major(mask, drop_last=True),
|
||||
config=policy.config,
|
||||
vf_loss_coeff=policy.config["vf_loss_coeff"],
|
||||
entropy_coeff=policy.entropy_coeff,
|
||||
clip_rho_threshold=policy.config["vtrace_clip_rho_threshold"],
|
||||
|
||||
@@ -98,7 +98,7 @@ class LogProbsFromLogitsAndActionsTest(tf.test.TestCase,
|
||||
0, num_actions - 1, size=(seq_len, batch_size), dtype=np.int32)
|
||||
|
||||
action_log_probs_tensor = vtrace.log_probs_from_logits_and_actions(
|
||||
policy_logits, actions)
|
||||
policy_logits, actions, {"model": None}) # dummy config dict
|
||||
|
||||
# Ground Truth
|
||||
# Using broadcasting to create a mask that indexes action logits
|
||||
@@ -159,6 +159,8 @@ class VtraceTest(tf.test.TestCase, parameterized.TestCase):
|
||||
clip_rho_threshold = None # No clipping.
|
||||
clip_pg_rho_threshold = None # No clipping.
|
||||
|
||||
dummy_config = {"model": None}
|
||||
|
||||
# Intentionally leaving shapes unspecified to test if V-trace can
|
||||
# deal with that.
|
||||
placeholders = {
|
||||
@@ -178,12 +180,15 @@ class VtraceTest(tf.test.TestCase, parameterized.TestCase):
|
||||
from_logits_output = vtrace.from_logits(
|
||||
clip_rho_threshold=clip_rho_threshold,
|
||||
clip_pg_rho_threshold=clip_pg_rho_threshold,
|
||||
config=dummy_config,
|
||||
**placeholders)
|
||||
|
||||
target_log_probs = vtrace.log_probs_from_logits_and_actions(
|
||||
placeholders["target_policy_logits"], placeholders["actions"])
|
||||
placeholders["target_policy_logits"], placeholders["actions"],
|
||||
dummy_config)
|
||||
behaviour_log_probs = vtrace.log_probs_from_logits_and_actions(
|
||||
placeholders["behaviour_policy_logits"], placeholders["actions"])
|
||||
placeholders["behaviour_policy_logits"], placeholders["actions"],
|
||||
dummy_config)
|
||||
log_rhos = target_log_probs - behaviour_log_probs
|
||||
ground_truth = (log_rhos, behaviour_log_probs, target_log_probs)
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ class ValueLoss(object):
|
||||
|
||||
class ReweightedImitationLoss(object):
|
||||
def __init__(self, state_values, cumulative_rewards, logits, actions,
|
||||
action_space, beta):
|
||||
action_space, beta, model_config):
|
||||
ma_adv_norm = tf.get_variable(
|
||||
name="moving_average_of_advantage_norm",
|
||||
dtype=tf.float32,
|
||||
@@ -48,8 +48,8 @@ class ReweightedImitationLoss(object):
|
||||
beta * tf.divide(adv, 1e-8 + tf.sqrt(ma_adv_norm)))
|
||||
|
||||
# log\pi_\theta(a|s)
|
||||
dist_cls, _ = ModelCatalog.get_action_dist(action_space, {})
|
||||
action_dist = dist_cls(logits)
|
||||
dist_cls, _ = ModelCatalog.get_action_dist(action_space, model_config)
|
||||
action_dist = dist_cls(logits, model_config=model_config)
|
||||
logprobs = action_dist.logp(actions)
|
||||
|
||||
self.loss = -1.0 * tf.reduce_mean(
|
||||
@@ -106,7 +106,7 @@ class MARWILPolicy(MARWILPostprocessing, TFPolicy):
|
||||
self.p_func_vars = scope_vars(scope.name)
|
||||
|
||||
# Action outputs
|
||||
action_dist = dist_cls(logits)
|
||||
action_dist = dist_cls(logits, model_config=self.config["model"])
|
||||
self.output_actions = action_dist.sample()
|
||||
|
||||
# Training inputs
|
||||
@@ -164,7 +164,8 @@ class MARWILPolicy(MARWILPostprocessing, TFPolicy):
|
||||
def _build_policy_loss(self, state_values, cum_rwds, logits, actions,
|
||||
action_space):
|
||||
return ReweightedImitationLoss(state_values, cum_rwds, logits, actions,
|
||||
action_space, self.config["beta"])
|
||||
action_space, self.config["beta"],
|
||||
self.config["model"])
|
||||
|
||||
@override(TFPolicy)
|
||||
def extra_compute_grad_fetches(self):
|
||||
|
||||
@@ -13,7 +13,8 @@ def pg_torch_loss(policy, batch_tensors):
|
||||
logits, _ = policy.model({
|
||||
SampleBatch.CUR_OBS: batch_tensors[SampleBatch.CUR_OBS]
|
||||
})
|
||||
action_dist = policy.dist_class(logits)
|
||||
action_dist = policy.dist_class(
|
||||
logits, model_config=policy.config["model"])
|
||||
log_probs = action_dist.logp(batch_tensors[SampleBatch.ACTIONS])
|
||||
# save the error in the policy object
|
||||
policy.pi_err = -batch_tensors[Postprocessing.ADVANTAGES].dot(
|
||||
|
||||
@@ -112,6 +112,7 @@ class VTraceSurrogateLoss(object):
|
||||
rewards,
|
||||
values,
|
||||
bootstrap_value,
|
||||
config,
|
||||
dist_class,
|
||||
valid_mask,
|
||||
vf_loss_coeff=0.5,
|
||||
@@ -143,6 +144,7 @@ class VTraceSurrogateLoss(object):
|
||||
rewards: A float32 tensor of shape [T, B].
|
||||
values: A float32 tensor of shape [T, B].
|
||||
bootstrap_value: A float32 tensor of shape [B].
|
||||
config: Trainer config dict.
|
||||
dist_class: action distribution class for logits.
|
||||
valid_mask: A bool tensor of valid RNN input elements (#2992).
|
||||
vf_loss_coeff (float): Coefficient of the value function loss.
|
||||
@@ -165,6 +167,7 @@ class VTraceSurrogateLoss(object):
|
||||
rewards=rewards,
|
||||
values=values,
|
||||
bootstrap_value=bootstrap_value,
|
||||
config=config,
|
||||
dist_class=dist_class,
|
||||
clip_rho_threshold=tf.cast(clip_rho_threshold, tf.float32),
|
||||
clip_pg_rho_threshold=tf.cast(clip_pg_rho_threshold,
|
||||
@@ -251,8 +254,10 @@ def build_appo_surrogate_loss(policy, batch_tensors):
|
||||
old_policy_behaviour_logits, output_hidden_shape, axis=1)
|
||||
unpacked_outputs = tf.split(policy.model_out, output_hidden_shape, axis=1)
|
||||
action_dist = policy.action_dist
|
||||
old_policy_action_dist = policy.dist_class(old_policy_behaviour_logits)
|
||||
prev_action_dist = policy.dist_class(behaviour_logits)
|
||||
old_policy_action_dist = policy.dist_class(
|
||||
old_policy_behaviour_logits, model_config=policy.config["model"])
|
||||
prev_action_dist = policy.dist_class(
|
||||
behaviour_logits, model_config=policy.config["model"])
|
||||
values = policy.value_function
|
||||
|
||||
policy.model_vars = policy.model.variables()
|
||||
@@ -298,6 +303,7 @@ def build_appo_surrogate_loss(policy, batch_tensors):
|
||||
rewards=make_time_major(rewards, drop_last=True),
|
||||
values=make_time_major(values, drop_last=True),
|
||||
bootstrap_value=make_time_major(values)[-1],
|
||||
config=policy.config,
|
||||
dist_class=Categorical if is_multidiscrete else policy.dist_class,
|
||||
valid_mask=make_time_major(mask, drop_last=True),
|
||||
vf_loss_coeff=policy.config["vf_loss_coeff"],
|
||||
|
||||
@@ -39,7 +39,8 @@ class PPOLoss(object):
|
||||
clip_param=0.1,
|
||||
vf_clip_param=0.1,
|
||||
vf_loss_coeff=1.0,
|
||||
use_gae=True):
|
||||
use_gae=True,
|
||||
model_config=None):
|
||||
"""Constructs the loss for Proximal Policy Objective.
|
||||
|
||||
Arguments:
|
||||
@@ -65,13 +66,15 @@ class PPOLoss(object):
|
||||
vf_clip_param (float): Clip parameter for the value function
|
||||
vf_loss_coeff (float): Coefficient of the value function loss
|
||||
use_gae (bool): If true, use the Generalized Advantage Estimator.
|
||||
model_config (dict): (Optional) model config for use in specifying
|
||||
action distributions.
|
||||
"""
|
||||
|
||||
def reduce_mean_valid(t):
|
||||
return tf.reduce_mean(tf.boolean_mask(t, valid_mask))
|
||||
|
||||
dist_cls, _ = ModelCatalog.get_action_dist(action_space, {})
|
||||
prev_dist = dist_cls(logits)
|
||||
dist_cls, _ = ModelCatalog.get_action_dist(action_space, model_config)
|
||||
prev_dist = dist_cls(logits, model_config=model_config)
|
||||
# Make loss functions.
|
||||
logp_ratio = tf.exp(
|
||||
curr_action_dist.logp(actions) - prev_dist.logp(actions))
|
||||
@@ -129,7 +132,8 @@ def ppo_surrogate_loss(policy, batch_tensors):
|
||||
clip_param=policy.config["clip_param"],
|
||||
vf_clip_param=policy.config["vf_clip_param"],
|
||||
vf_loss_coeff=policy.config["vf_loss_coeff"],
|
||||
use_gae=policy.config["use_gae"])
|
||||
use_gae=policy.config["use_gae"],
|
||||
model_config=policy.config["model"])
|
||||
|
||||
return policy.loss_obj.loss
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ class DistributionsTest(unittest.TestCase):
|
||||
logits = tf.placeholder(tf.float32, shape=(None, 10))
|
||||
z = 8 * (np.random.rand(10) - 0.5)
|
||||
data = np.tile(z, (num_samples, 1))
|
||||
c = Categorical(logits)
|
||||
c = Categorical(logits, {}) # dummy config dict
|
||||
sample_op = c.sample()
|
||||
sess = tf.Session()
|
||||
sess.run(tf.global_variables_initializer())
|
||||
|
||||
@@ -67,7 +67,7 @@ class CustomLossModel(Model):
|
||||
print("FYI: You can also use these tensors: {}, ".format(loss_inputs))
|
||||
|
||||
# compute the IL loss
|
||||
action_dist = Categorical(logits)
|
||||
action_dist = Categorical(logits, self.options)
|
||||
self.policy_loss = policy_loss
|
||||
self.imitation_loss = tf.reduce_mean(
|
||||
-action_dist.logp(input_ops["actions"]))
|
||||
|
||||
@@ -18,7 +18,7 @@ def policy_gradient_loss(policy, batch_tensors):
|
||||
logits, _ = policy.model({
|
||||
SampleBatch.CUR_OBS: batch_tensors[SampleBatch.CUR_OBS]
|
||||
})
|
||||
action_dist = policy.dist_class(logits)
|
||||
action_dist = policy.dist_class(logits, policy.config["model"])
|
||||
log_probs = action_dist.logp(batch_tensors[SampleBatch.ACTIONS])
|
||||
return -batch_tensors[SampleBatch.REWARDS].dot(log_probs)
|
||||
|
||||
|
||||
@@ -11,11 +11,14 @@ class ActionDistribution(object):
|
||||
|
||||
Args:
|
||||
inputs (Tensor): The input vector to compute samples from.
|
||||
model_config (dict): Optional model config dict
|
||||
(as defined in catalog.py)
|
||||
"""
|
||||
|
||||
@DeveloperAPI
|
||||
def __init__(self, inputs):
|
||||
def __init__(self, inputs, model_config):
|
||||
self.inputs = inputs
|
||||
self.model_config = model_config
|
||||
|
||||
@DeveloperAPI
|
||||
def sample(self):
|
||||
@@ -52,3 +55,22 @@ class ActionDistribution(object):
|
||||
MultiDiscrete. TODO(ekl) consider removing this.
|
||||
"""
|
||||
return self.entropy()
|
||||
|
||||
@DeveloperAPI
|
||||
@staticmethod
|
||||
def required_model_output_shape(action_space, model_config):
|
||||
"""Returns the required shape of an input parameter tensor for a
|
||||
particular action space and an optional dict of distribution-specific
|
||||
options.
|
||||
|
||||
Args:
|
||||
action_space (gym.Space): The action space this distribution will
|
||||
be used for, whose shape attributes will be used to determine
|
||||
the required shape of the input parameter tensor.
|
||||
model_config (dict): Model's config dict (as defined in catalog.py)
|
||||
|
||||
Returns:
|
||||
model_output_shape (int or np.ndarray of ints): size of the
|
||||
required input vector (minus leading batch dimension).
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
+47
-23
@@ -8,7 +8,7 @@ import numpy as np
|
||||
from functools import partial
|
||||
|
||||
from ray.tune.registry import RLLIB_MODEL, RLLIB_PREPROCESSOR, \
|
||||
_global_registry
|
||||
RLLIB_ACTION_DIST, _global_registry
|
||||
|
||||
from ray.rllib.models.extra_spaces import Simplex
|
||||
from ray.rllib.models.torch.torch_action_dist import (TorchCategorical,
|
||||
@@ -80,6 +80,8 @@ MODEL_DEFAULTS = {
|
||||
"custom_preprocessor": None,
|
||||
# Name of a custom model to use
|
||||
"custom_model": None,
|
||||
# Name of a custom action distribution to use
|
||||
"custom_action_dist": None,
|
||||
# Extra options to pass to the custom classes
|
||||
"custom_options": {},
|
||||
}
|
||||
@@ -119,22 +121,30 @@ class ModelCatalog(object):
|
||||
"""
|
||||
|
||||
config = config or MODEL_DEFAULTS
|
||||
if isinstance(action_space, gym.spaces.Box):
|
||||
if config.get("custom_action_dist"):
|
||||
action_dist_name = config["custom_action_dist"]
|
||||
logger.debug(
|
||||
"Using custom action distribution {}".format(action_dist_name))
|
||||
dist = _global_registry.get(RLLIB_ACTION_DIST, action_dist_name)
|
||||
|
||||
elif isinstance(action_space, gym.spaces.Box):
|
||||
if len(action_space.shape) > 1:
|
||||
raise UnsupportedSpaceException(
|
||||
"Action space has multiple dimensions "
|
||||
"{}. ".format(action_space.shape) +
|
||||
"Consider reshaping this into a single dimension, "
|
||||
"using a custom action distribution, "
|
||||
"using a Tuple action space, or the multi-agent API.")
|
||||
if dist_type is None:
|
||||
dist = TorchDiagGaussian if torch else DiagGaussian
|
||||
return dist, action_space.shape[0] * 2
|
||||
elif dist_type == "deterministic":
|
||||
return Deterministic, action_space.shape[0]
|
||||
dist = Deterministic
|
||||
elif isinstance(action_space, gym.spaces.Discrete):
|
||||
dist = TorchCategorical if torch else Categorical
|
||||
return dist, action_space.n
|
||||
elif isinstance(action_space, gym.spaces.Tuple):
|
||||
if torch:
|
||||
raise NotImplementedError("Tuple action spaces not supported "
|
||||
"for Pytorch.")
|
||||
child_dist = []
|
||||
input_lens = []
|
||||
for action in action_space.spaces:
|
||||
@@ -142,8 +152,6 @@ class ModelCatalog(object):
|
||||
action, config)
|
||||
child_dist.append(dist)
|
||||
input_lens.append(action_size)
|
||||
if torch:
|
||||
raise NotImplementedError
|
||||
return partial(
|
||||
MultiActionDistribution,
|
||||
child_distributions=child_dist,
|
||||
@@ -151,14 +159,18 @@ class ModelCatalog(object):
|
||||
input_lens=input_lens), sum(input_lens)
|
||||
elif isinstance(action_space, Simplex):
|
||||
if torch:
|
||||
raise NotImplementedError
|
||||
return Dirichlet, action_space.shape[0]
|
||||
elif isinstance(action_space, gym.spaces.multi_discrete.MultiDiscrete):
|
||||
raise NotImplementedError("Simplex action spaces not "
|
||||
"supported for Pytorch.")
|
||||
dist = Dirichlet
|
||||
elif isinstance(action_space, gym.spaces.MultiDiscrete):
|
||||
if torch:
|
||||
raise NotImplementedError
|
||||
raise NotImplementedError("MultiDiscrete action spaces not "
|
||||
"supported for Pytorch.")
|
||||
return partial(MultiCategorical, input_lens=action_space.nvec), \
|
||||
int(sum(action_space.nvec))
|
||||
|
||||
return dist, dist.required_model_output_shape(action_space, config)
|
||||
|
||||
raise NotImplementedError("Unsupported args: {} {}".format(
|
||||
action_space, dist_type))
|
||||
|
||||
@@ -173,11 +185,16 @@ class ModelCatalog(object):
|
||||
action_placeholder (Tensor): A placeholder for the actions
|
||||
"""
|
||||
|
||||
if isinstance(action_space, gym.spaces.Box):
|
||||
return tf.placeholder(
|
||||
tf.float32, shape=(None, action_space.shape[0]), name="action")
|
||||
elif isinstance(action_space, gym.spaces.Discrete):
|
||||
if isinstance(action_space, gym.spaces.Discrete):
|
||||
return tf.placeholder(tf.int64, shape=(None, ), name="action")
|
||||
elif isinstance(action_space, (gym.spaces.Box, Simplex)):
|
||||
return tf.placeholder(
|
||||
tf.float32, shape=(None, ) + action_space.shape, name="action")
|
||||
elif isinstance(action_space, gym.spaces.MultiDiscrete):
|
||||
return tf.placeholder(
|
||||
tf.as_dtype(action_space.dtype),
|
||||
shape=(None, ) + action_space.shape,
|
||||
name="action")
|
||||
elif isinstance(action_space, gym.spaces.Tuple):
|
||||
size = 0
|
||||
all_discrete = True
|
||||
@@ -191,14 +208,6 @@ class ModelCatalog(object):
|
||||
tf.int64 if all_discrete else tf.float32,
|
||||
shape=(None, size),
|
||||
name="action")
|
||||
elif isinstance(action_space, Simplex):
|
||||
return tf.placeholder(
|
||||
tf.float32, shape=(None, action_space.shape[0]), name="action")
|
||||
elif isinstance(action_space, gym.spaces.multi_discrete.MultiDiscrete):
|
||||
return tf.placeholder(
|
||||
tf.as_dtype(action_space.dtype),
|
||||
shape=(None, len(action_space.nvec)),
|
||||
name="action")
|
||||
else:
|
||||
raise NotImplementedError("action space {}"
|
||||
" not supported".format(action_space))
|
||||
@@ -362,6 +371,21 @@ class ModelCatalog(object):
|
||||
"""
|
||||
_global_registry.register(RLLIB_MODEL, model_name, model_class)
|
||||
|
||||
@staticmethod
|
||||
@PublicAPI
|
||||
def register_custom_action_dist(action_dist_name, action_dist_class):
|
||||
"""Register a custom action distribution class by name.
|
||||
|
||||
The action distribution can later be used by specifying
|
||||
{"custom_action_dist": action_dist_name} in the model config.
|
||||
|
||||
Args:
|
||||
action_dist_name (str): Name to register the action distribution under.
|
||||
action_dist_class (type): Python class of the action distribution.
|
||||
"""
|
||||
_global_registry.register(RLLIB_ACTION_DIST, action_dist_name,
|
||||
action_dist_class)
|
||||
|
||||
@staticmethod
|
||||
def _wrap_if_needed(model_cls, model_interface):
|
||||
assert issubclass(model_cls, TFModelV2)
|
||||
|
||||
@@ -17,8 +17,9 @@ class TFActionDistribution(ActionDistribution):
|
||||
"""TF-specific extensions for building action distributions."""
|
||||
|
||||
@DeveloperAPI
|
||||
def __init__(self, inputs):
|
||||
super(TFActionDistribution, self).__init__(inputs)
|
||||
def __init__(self, inputs, model_config):
|
||||
super(TFActionDistribution, self).__init__(
|
||||
inputs, model_config=model_config)
|
||||
self.sample_op = self._build_sample_op()
|
||||
|
||||
@DeveloperAPI
|
||||
@@ -76,16 +77,22 @@ class Categorical(TFActionDistribution):
|
||||
def _build_sample_op(self):
|
||||
return tf.squeeze(tf.multinomial(self.inputs, 1), axis=1)
|
||||
|
||||
@staticmethod
|
||||
@override(ActionDistribution)
|
||||
def required_model_output_shape(action_space, model_config):
|
||||
return action_space.n
|
||||
|
||||
|
||||
class MultiCategorical(TFActionDistribution):
|
||||
"""Categorical distribution for discrete action spaces."""
|
||||
"""MultiCategorical distribution for MultiDiscrete action spaces."""
|
||||
|
||||
def __init__(self, inputs, input_lens):
|
||||
def __init__(self, inputs, input_lens, model_config):
|
||||
self.cats = [
|
||||
Categorical(input_)
|
||||
Categorical(input_, model_config=model_config)
|
||||
for input_ in tf.split(inputs, input_lens, axis=1)
|
||||
]
|
||||
self.sample_op = self._build_sample_op()
|
||||
self.model_config = model_config
|
||||
|
||||
@override(ActionDistribution)
|
||||
def logp(self, actions):
|
||||
@@ -116,6 +123,11 @@ class MultiCategorical(TFActionDistribution):
|
||||
def _build_sample_op(self):
|
||||
return tf.stack([cat.sample() for cat in self.cats], axis=1)
|
||||
|
||||
@staticmethod
|
||||
@override(ActionDistribution)
|
||||
def required_model_output_shape(action_space, model_config):
|
||||
return np.sum(action_space.nvec)
|
||||
|
||||
|
||||
class DiagGaussian(TFActionDistribution):
|
||||
"""Action distribution where each vector element is a gaussian.
|
||||
@@ -124,12 +136,12 @@ class DiagGaussian(TFActionDistribution):
|
||||
second half the gaussian standard deviations.
|
||||
"""
|
||||
|
||||
def __init__(self, inputs):
|
||||
def __init__(self, inputs, model_config):
|
||||
mean, log_std = tf.split(inputs, 2, axis=1)
|
||||
self.mean = mean
|
||||
self.log_std = log_std
|
||||
self.std = tf.exp(log_std)
|
||||
TFActionDistribution.__init__(self, inputs)
|
||||
super(DiagGaussian, self).__init__(inputs, model_config)
|
||||
|
||||
@override(ActionDistribution)
|
||||
def logp(self, x):
|
||||
@@ -157,6 +169,11 @@ class DiagGaussian(TFActionDistribution):
|
||||
def _build_sample_op(self):
|
||||
return self.mean + self.std * tf.random_normal(tf.shape(self.mean))
|
||||
|
||||
@staticmethod
|
||||
@override(ActionDistribution)
|
||||
def required_model_output_shape(action_space, model_config):
|
||||
return np.prod(action_space.shape) * 2
|
||||
|
||||
|
||||
class Deterministic(TFActionDistribution):
|
||||
"""Action distribution that returns the input values directly.
|
||||
@@ -172,6 +189,11 @@ class Deterministic(TFActionDistribution):
|
||||
def _build_sample_op(self):
|
||||
return self.inputs
|
||||
|
||||
@staticmethod
|
||||
@override(ActionDistribution)
|
||||
def required_model_output_shape(action_space, model_config):
|
||||
return np.prod(action_space.shape)
|
||||
|
||||
|
||||
class MultiActionDistribution(TFActionDistribution):
|
||||
"""Action distribution that operates for list of actions.
|
||||
@@ -180,12 +202,14 @@ class MultiActionDistribution(TFActionDistribution):
|
||||
inputs (Tensor list): A list of tensors from which to compute samples.
|
||||
"""
|
||||
|
||||
def __init__(self, inputs, action_space, child_distributions, input_lens):
|
||||
def __init__(self, inputs, action_space, child_distributions, input_lens,
|
||||
model_config):
|
||||
self.input_lens = input_lens
|
||||
split_inputs = tf.split(inputs, self.input_lens, axis=1)
|
||||
child_list = []
|
||||
for i, distribution in enumerate(child_distributions):
|
||||
child_list.append(distribution(split_inputs[i]))
|
||||
child_list.append(
|
||||
distribution(split_inputs[i], model_config=model_config))
|
||||
self.child_distributions = child_list
|
||||
|
||||
@override(ActionDistribution)
|
||||
@@ -241,7 +265,7 @@ class Dirichlet(TFActionDistribution):
|
||||
|
||||
e.g. actions that represent resource allocation."""
|
||||
|
||||
def __init__(self, inputs):
|
||||
def __init__(self, inputs, model_config):
|
||||
"""Input is a tensor of logits. The exponential of logits is used to
|
||||
parametrize the Dirichlet distribution as all parameters need to be
|
||||
positive. An arbitrary small epsilon is added to the concentration
|
||||
@@ -256,7 +280,8 @@ class Dirichlet(TFActionDistribution):
|
||||
validate_args=True,
|
||||
allow_nan_stats=False,
|
||||
)
|
||||
TFActionDistribution.__init__(self, concentration)
|
||||
super(Dirichlet, self).__init__(
|
||||
concentration, model_config=model_config)
|
||||
|
||||
@override(ActionDistribution)
|
||||
def logp(self, x):
|
||||
@@ -278,3 +303,8 @@ class Dirichlet(TFActionDistribution):
|
||||
@override(TFActionDistribution)
|
||||
def _build_sample_op(self):
|
||||
return self.dist.sample()
|
||||
|
||||
@staticmethod
|
||||
@override(ActionDistribution)
|
||||
def required_model_output_shape(action_space, model_config):
|
||||
return np.prod(action_space.shape)
|
||||
|
||||
@@ -7,6 +7,8 @@ try:
|
||||
except ImportError:
|
||||
pass # soft dep
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ray.rllib.models.action_dist import ActionDistribution
|
||||
from ray.rllib.utils.annotations import override
|
||||
|
||||
@@ -35,18 +37,30 @@ class TorchCategorical(TorchDistributionWrapper):
|
||||
"""Wrapper class for PyTorch Categorical distribution."""
|
||||
|
||||
@override(ActionDistribution)
|
||||
def __init__(self, inputs):
|
||||
def __init__(self, inputs, model_config):
|
||||
self.dist = torch.distributions.categorical.Categorical(logits=inputs)
|
||||
self.model_config = model_config
|
||||
|
||||
@staticmethod
|
||||
@override(ActionDistribution)
|
||||
def required_model_output_shape(action_space, model_config):
|
||||
return action_space.n
|
||||
|
||||
|
||||
class TorchDiagGaussian(TorchDistributionWrapper):
|
||||
"""Wrapper class for PyTorch Normal distribution."""
|
||||
|
||||
@override(ActionDistribution)
|
||||
def __init__(self, inputs):
|
||||
def __init__(self, inputs, model_config):
|
||||
mean, log_std = torch.chunk(inputs, 2, dim=1)
|
||||
self.dist = torch.distributions.normal.Normal(mean, torch.exp(log_std))
|
||||
self.model_config = model_config
|
||||
|
||||
@override(TorchDistributionWrapper)
|
||||
def logp(self, actions):
|
||||
return TorchDistributionWrapper.logp(self, actions).sum(-1)
|
||||
|
||||
@staticmethod
|
||||
@override(ActionDistribution)
|
||||
def required_model_output_shape(action_space, model_config):
|
||||
return np.prod(action_space.shape) * 2
|
||||
|
||||
@@ -166,7 +166,8 @@ class DynamicTFPolicy(TFPolicy):
|
||||
self, self.model, self.input_dict, obs_space, action_space,
|
||||
config)
|
||||
else:
|
||||
self.action_dist = self.dist_class(self.model_out)
|
||||
self.action_dist = self.dist_class(
|
||||
self.model_out, model_config=self.config["model"])
|
||||
action_sampler = self.action_dist.sample()
|
||||
action_prob = self.action_dist.sampled_action_prob()
|
||||
|
||||
@@ -261,7 +262,7 @@ class DynamicTFPolicy(TFPolicy):
|
||||
def _initialize_loss(self):
|
||||
def fake_array(tensor):
|
||||
shape = tensor.shape.as_list()
|
||||
shape[0] = 1
|
||||
shape = [s if s is not None else 1 for s in shape]
|
||||
return np.zeros(shape, dtype=tensor.dtype.as_numpy_dtype)
|
||||
|
||||
dummy_batch = {
|
||||
|
||||
@@ -78,7 +78,8 @@ class TorchPolicy(Policy):
|
||||
input_dict["prev_rewards"] = prev_reward_batch
|
||||
model_out = self._model(input_dict, state_batches, [1])
|
||||
logits, state = model_out
|
||||
action_dist = self._action_dist_cls(logits)
|
||||
action_dist = self._action_dist_cls(
|
||||
logits, model_config=self.config["model"])
|
||||
actions = action_dist.sample()
|
||||
return (actions.cpu().numpy(),
|
||||
[h.cpu().numpy() for h in state],
|
||||
|
||||
@@ -5,8 +5,9 @@ from gym.spaces import Box, Discrete, Tuple
|
||||
|
||||
import ray
|
||||
|
||||
from ray.rllib.models import ModelCatalog
|
||||
from ray.rllib.models import ModelCatalog, MODEL_DEFAULTS
|
||||
from ray.rllib.models.model import Model
|
||||
from ray.rllib.models.tf.tf_action_dist import TFActionDistribution
|
||||
from ray.rllib.models.preprocessors import (NoPreprocessor, OneHotPreprocessor,
|
||||
Preprocessor)
|
||||
from ray.rllib.models.tf.fcnet_v1 import FullyConnectedNetwork
|
||||
@@ -31,6 +32,25 @@ class CustomModel(Model):
|
||||
return tf.constant([[0] * 5]), None
|
||||
|
||||
|
||||
class CustomActionDistribution(TFActionDistribution):
|
||||
@staticmethod
|
||||
def required_model_output_shape(action_space, model_config=None):
|
||||
custom_options = model_config["custom_options"] or {}
|
||||
if custom_options is not None and custom_options.get("output_dim"):
|
||||
return custom_options.get("output_dim")
|
||||
return action_space.shape
|
||||
|
||||
def _build_sample_op(self):
|
||||
custom_options = self.model_config["custom_options"]
|
||||
if "output_dim" in custom_options:
|
||||
output_shape = tf.concat(
|
||||
[tf.shape(self.inputs)[:1], custom_options["output_dim"]],
|
||||
axis=0)
|
||||
else:
|
||||
output_shape = tf.shape(self.inputs)
|
||||
return tf.random_uniform(output_shape)
|
||||
|
||||
|
||||
class ModelCatalogTest(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
ray.shutdown()
|
||||
@@ -94,6 +114,41 @@ class ModelCatalogTest(unittest.TestCase):
|
||||
{"custom_model": "foo"})
|
||||
self.assertEqual(str(type(p1)), str(CustomModel))
|
||||
|
||||
def testCustomActionDistribution(self):
|
||||
ray.init()
|
||||
# registration
|
||||
ModelCatalog.register_custom_action_dist("test",
|
||||
CustomActionDistribution)
|
||||
action_space = Box(0, 1, shape=(5, 3), dtype=np.float32)
|
||||
|
||||
# test retrieving it
|
||||
model_config = MODEL_DEFAULTS.copy()
|
||||
model_config["custom_action_dist"] = "test"
|
||||
dist_cls, param_shape = ModelCatalog.get_action_dist(
|
||||
action_space, model_config)
|
||||
self.assertEqual(str(dist_cls), str(CustomActionDistribution))
|
||||
self.assertEqual(param_shape, action_space.shape)
|
||||
|
||||
# test the class works as a distribution
|
||||
dist_input = tf.placeholder(tf.float32, (None, ) + param_shape)
|
||||
dist = dist_cls(dist_input, model_config=model_config)
|
||||
self.assertEqual(dist.sample().shape[1:], dist_input.shape[1:])
|
||||
self.assertIsInstance(dist.sample(), tf.Tensor)
|
||||
with self.assertRaises(NotImplementedError):
|
||||
dist.entropy()
|
||||
|
||||
# test passing the options to it
|
||||
model_config["custom_options"].update({"output_dim": (3, )})
|
||||
dist_cls, param_shape = ModelCatalog.get_action_dist(
|
||||
action_space, model_config)
|
||||
self.assertEqual(param_shape, (3, ))
|
||||
dist_input = tf.placeholder(tf.float32, (None, ) + param_shape)
|
||||
dist = dist_cls(dist_input, model_config=model_config)
|
||||
self.assertEqual(dist.sample().shape[1:], dist_input.shape[1:])
|
||||
self.assertIsInstance(dist.sample(), tf.Tensor)
|
||||
with self.assertRaises(NotImplementedError):
|
||||
dist.entropy()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=2)
|
||||
|
||||
Reference in New Issue
Block a user