[rllib] MultiCategorical shouldn't return array for kl or entropy (#5215)

* wip

* fix
This commit is contained in:
Eric Liang
2019-07-19 12:12:04 -07:00
committed by GitHub
parent da7676c925
commit d58b986858
4 changed files with 36 additions and 11 deletions
@@ -184,7 +184,8 @@ def build_vtrace_loss(policy, batch_tensors):
actions=make_time_major(loss_actions, drop_last=True),
actions_logp=make_time_major(
action_dist.logp(actions), drop_last=True),
actions_entropy=make_time_major(action_dist.entropy(), drop_last=True),
actions_entropy=make_time_major(
action_dist.multi_entropy(), drop_last=True),
dones=make_time_major(dones, drop_last=True),
behaviour_logits=make_time_major(
unpacked_behaviour_logits, drop_last=True),
+4 -4
View File
@@ -205,9 +205,9 @@ def build_appo_surrogate_loss(policy, batch_tensors):
prev_action_dist.logp(actions), drop_last=True),
actions_logp=make_time_major(
action_dist.logp(actions), drop_last=True),
action_kl=prev_action_dist.kl(action_dist),
action_kl=prev_action_dist.multi_kl(action_dist),
actions_entropy=make_time_major(
action_dist.entropy(), drop_last=True),
action_dist.multi_entropy(), drop_last=True),
dones=make_time_major(dones, drop_last=True),
behaviour_logits=make_time_major(
unpacked_behaviour_logits, drop_last=True),
@@ -229,8 +229,8 @@ def build_appo_surrogate_loss(policy, batch_tensors):
policy.loss = PPOSurrogateLoss(
prev_actions_logp=make_time_major(prev_action_dist.logp(actions)),
actions_logp=make_time_major(action_dist.logp(actions)),
action_kl=prev_action_dist.kl(action_dist),
actions_entropy=make_time_major(action_dist.entropy()),
action_kl=prev_action_dist.multi_kl(action_dist),
actions_entropy=make_time_major(action_dist.multi_entropy()),
values=make_time_major(values),
valid_mask=make_time_major(mask),
advantages=make_time_major(
+30 -2
View File
@@ -69,6 +69,22 @@ class ActionDistribution(object):
"""Returns the log probability of the sampled action."""
return tf.exp(self.logp(self.sample_op))
def multi_kl(self, other):
"""The KL-divergence between two action distributions.
This differs from kl() in that it can return an array for
MultiDiscrete. TODO(ekl) consider removing this.
"""
return self.kl(other)
def multi_entropy(self):
"""The entropy of the action distribution.
This differs from entropy() in that it can return an array for
MultiDiscrete. TODO(ekl) consider removing this.
"""
return self.entropy()
class Categorical(ActionDistribution):
"""Categorical distribution for discrete action spaces."""
@@ -133,6 +149,7 @@ class MultiCategorical(ActionDistribution):
]
self.sample_op = self._build_sample_op()
@override(ActionDistribution)
def logp(self, actions):
# If tensor is provided, unstack it into list
if isinstance(actions, tf.Tensor):
@@ -141,12 +158,23 @@ class MultiCategorical(ActionDistribution):
[cat.logp(act) for cat, act in zip(self.cats, actions)])
return tf.reduce_sum(logps, axis=0)
def entropy(self):
@override(ActionDistribution)
def multi_entropy(self):
return tf.stack([cat.entropy() for cat in self.cats], axis=1)
def kl(self, other):
@override(ActionDistribution)
def entropy(self):
return tf.reduce_sum(self.multi_entropy(), axis=1)
@override(ActionDistribution)
def multi_kl(self, other):
return [cat.kl(oth_cat) for cat, oth_cat in zip(self.cats, other.cats)]
@override(ActionDistribution)
def kl(self, other):
return tf.reduce_sum(self.multi_kl(other), axis=1)
@override(ActionDistribution)
def _build_sample_op(self):
return tf.stack([cat.sample() for cat in self.cats], axis=1)