From fb88f7efe6a5158c77c31ec0df3ac603b6794b3f Mon Sep 17 00:00:00 2001 From: Federico Fontana Date: Thu, 4 Apr 2019 22:33:09 +0100 Subject: [PATCH] Fixed bug in Dirichlet (#4440) (#4560) --- python/ray/rllib/models/action_dist.py | 24 +++++++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/python/ray/rllib/models/action_dist.py b/python/ray/rllib/models/action_dist.py index 138fd9f8a..026a6c493 100644 --- a/python/ray/rllib/models/action_dist.py +++ b/python/ray/rllib/models/action_dist.py @@ -261,17 +261,35 @@ TupleActions = namedtuple("TupleActions", ["batches"]) class Dirichlet(ActionDistribution): - """Dirichlet distribution for countinuous actions that are between + """Dirichlet distribution for continuous actions that are between [0,1] and sum to 1. e.g. actions that represent resource allocation.""" def __init__(self, inputs): - self.dist = tf.distributions.Dirichlet(concentration=inputs) - ActionDistribution.__init__(self, inputs) + """Input is a tensor of logits. The exponential of logits is used to + parametrize the Dirichlet distribution as all parameters need to be + positive. An arbitrary small epsilon is added to the concentration + parameters to keep them from becoming zero due to numerical error. + + See issue #4440 for more details. + """ + self.epsilon = 1e-7 + concentration = tf.exp(inputs) + self.epsilon + self.dist = tf.distributions.Dirichlet( + concentration=concentration, + validate_args=True, + allow_nan_stats=False, + ) + ActionDistribution.__init__(self, concentration) @override(ActionDistribution) def logp(self, x): + # Support of Dirichlet are positive real numbers. x is already + # an array of positive numbers, but we clip to avoid zeros due to + # numerical errors. + x = tf.maximum(x, self.epsilon) + x = x / tf.reduce_sum(x, axis=-1, keepdims=True) return self.dist.log_prob(x) @override(ActionDistribution)