From d58b986858463057c92e4977f02bfa0dbcc81959 Mon Sep 17 00:00:00 2001
From: Eric Liang
Date: Fri, 19 Jul 2019 12:12:04 -0700
Subject: [PATCH] [rllib] MultiCategorical shouldn't return array for kl or entropy (#5215)

* wip

* fix
---
Reviewer notes (text between the "---" line and the first "diff --git" line
is commentary, not part of the commit message):

- This patch had been flattened onto three physical lines. Line breaks and
  indentation were restored from the hunk headers and the code structure;
  hunk line counts match the "@@" headers. NOTE(review): context-line
  indentation is reconstructed, so confirm it against the tree before
  applying.
- MultiCategorical.kl(): multi_kl() returns a Python list holding one KL
  tensor per sub-distribution. When that list is converted to a tensor the
  sub-distributions end up on axis 0, so kl() now sums over axis 0, the same
  way logp() reduces its stacked per-component values. The previous axis=1
  reduced within each component's own result instead of across components.
  multi_kl() itself is unchanged, so its callers see the same return value.

 ci/travis/format.sh                          |  4 ---
 .../ray/rllib/agents/impala/vtrace_policy.py |  3 +-
 python/ray/rllib/agents/ppo/appo_policy.py   |  8 ++---
 python/ray/rllib/models/action_dist.py       | 32 +++++++++++++++++++++++++++++--
 4 files changed, 36 insertions(+), 11 deletions(-)

diff --git a/ci/travis/format.sh b/ci/travis/format.sh
index 0e61e1d7d..aec3b4bc3 100755
--- a/ci/travis/format.sh
+++ b/ci/travis/format.sh
@@ -29,10 +29,6 @@ YAPF_VERSION=$(yapf --version | awk '{print $2}')
 tool_version_check() {
     if [[ $2 != $3 ]]; then
         echo "WARNING: Ray uses $1 $3, You currently are using $2. This might generate different results."
-        read -p "Do you want to continue? [y/n] " answer
-        if ! [ $answer = 'y' ] && ! [ $answer = 'Y' ]; then
-            exit 1
-        fi
     fi
 }
 
diff --git a/python/ray/rllib/agents/impala/vtrace_policy.py b/python/ray/rllib/agents/impala/vtrace_policy.py
index a13764285..ec49ef0b1 100644
--- a/python/ray/rllib/agents/impala/vtrace_policy.py
+++ b/python/ray/rllib/agents/impala/vtrace_policy.py
@@ -184,7 +184,8 @@ def build_vtrace_loss(policy, batch_tensors):
         actions=make_time_major(loss_actions, drop_last=True),
         actions_logp=make_time_major(
             action_dist.logp(actions), drop_last=True),
-        actions_entropy=make_time_major(action_dist.entropy(), drop_last=True),
+        actions_entropy=make_time_major(
+            action_dist.multi_entropy(), drop_last=True),
         dones=make_time_major(dones, drop_last=True),
         behaviour_logits=make_time_major(
             unpacked_behaviour_logits, drop_last=True),
diff --git a/python/ray/rllib/agents/ppo/appo_policy.py b/python/ray/rllib/agents/ppo/appo_policy.py
index ad8452162..8fc991552 100644
--- a/python/ray/rllib/agents/ppo/appo_policy.py
+++ b/python/ray/rllib/agents/ppo/appo_policy.py
@@ -205,9 +205,9 @@ def build_appo_surrogate_loss(policy, batch_tensors):
                 prev_action_dist.logp(actions), drop_last=True),
             actions_logp=make_time_major(
                 action_dist.logp(actions), drop_last=True),
-            action_kl=prev_action_dist.kl(action_dist),
+            action_kl=prev_action_dist.multi_kl(action_dist),
             actions_entropy=make_time_major(
-                action_dist.entropy(), drop_last=True),
+                action_dist.multi_entropy(), drop_last=True),
             dones=make_time_major(dones, drop_last=True),
             behaviour_logits=make_time_major(
                 unpacked_behaviour_logits, drop_last=True),
@@ -229,8 +229,8 @@ def build_appo_surrogate_loss(policy, batch_tensors):
         policy.loss = PPOSurrogateLoss(
             prev_actions_logp=make_time_major(prev_action_dist.logp(actions)),
             actions_logp=make_time_major(action_dist.logp(actions)),
-            action_kl=prev_action_dist.kl(action_dist),
-            actions_entropy=make_time_major(action_dist.entropy()),
+            action_kl=prev_action_dist.multi_kl(action_dist),
+            actions_entropy=make_time_major(action_dist.multi_entropy()),
             values=make_time_major(values),
             valid_mask=make_time_major(mask),
             advantages=make_time_major(
diff --git a/python/ray/rllib/models/action_dist.py b/python/ray/rllib/models/action_dist.py
index 303f3bed2..b5a69ad75 100644
--- a/python/ray/rllib/models/action_dist.py
+++ b/python/ray/rllib/models/action_dist.py
@@ -69,6 +69,22 @@ class ActionDistribution(object):
         """Returns the log probability of the sampled action."""
         return tf.exp(self.logp(self.sample_op))
 
+    def multi_kl(self, other):
+        """The KL-divergence between two action distributions.
+
+        This differs from kl() in that it can return an array for
+        MultiDiscrete. TODO(ekl) consider removing this.
+        """
+        return self.kl(other)
+
+    def multi_entropy(self):
+        """The entropy of the action distribution.
+
+        This differs from entropy() in that it can return an array for
+        MultiDiscrete. TODO(ekl) consider removing this.
+        """
+        return self.entropy()
+
 
 class Categorical(ActionDistribution):
     """Categorical distribution for discrete action spaces."""
@@ -133,6 +149,7 @@ class MultiCategorical(ActionDistribution):
         ]
         self.sample_op = self._build_sample_op()
 
+    @override(ActionDistribution)
     def logp(self, actions):
         # If tensor is provided, unstack it into list
         if isinstance(actions, tf.Tensor):
@@ -141,12 +158,23 @@ class MultiCategorical(ActionDistribution):
             [cat.logp(act) for cat, act in zip(self.cats, actions)])
         return tf.reduce_sum(logps, axis=0)
 
-    def entropy(self):
+    @override(ActionDistribution)
+    def multi_entropy(self):
         return tf.stack([cat.entropy() for cat in self.cats], axis=1)
 
-    def kl(self, other):
+    @override(ActionDistribution)
+    def entropy(self):
+        return tf.reduce_sum(self.multi_entropy(), axis=1)
+
+    @override(ActionDistribution)
+    def multi_kl(self, other):
         return [cat.kl(oth_cat) for cat, oth_cat in zip(self.cats, other.cats)]
 
+    @override(ActionDistribution)
+    def kl(self, other):
+        return tf.reduce_sum(self.multi_kl(other), axis=0)
+
+    @override(ActionDistribution)
     def _build_sample_op(self):
         return tf.stack([cat.sample() for cat in self.cats], axis=1)
 