mirror of
https://github.com/wassname/ray.git
synced 2026-07-01 05:24:48 +08:00
[rllib] MultiCategorical shouldn't return array for kl or entropy (#5215)
* wip * fix
This commit is contained in:
@@ -184,7 +184,8 @@ def build_vtrace_loss(policy, batch_tensors):
|
||||
actions=make_time_major(loss_actions, drop_last=True),
|
||||
actions_logp=make_time_major(
|
||||
action_dist.logp(actions), drop_last=True),
|
||||
actions_entropy=make_time_major(action_dist.entropy(), drop_last=True),
|
||||
actions_entropy=make_time_major(
|
||||
action_dist.multi_entropy(), drop_last=True),
|
||||
dones=make_time_major(dones, drop_last=True),
|
||||
behaviour_logits=make_time_major(
|
||||
unpacked_behaviour_logits, drop_last=True),
|
||||
|
||||
@@ -205,9 +205,9 @@ def build_appo_surrogate_loss(policy, batch_tensors):
|
||||
prev_action_dist.logp(actions), drop_last=True),
|
||||
actions_logp=make_time_major(
|
||||
action_dist.logp(actions), drop_last=True),
|
||||
action_kl=prev_action_dist.kl(action_dist),
|
||||
action_kl=prev_action_dist.multi_kl(action_dist),
|
||||
actions_entropy=make_time_major(
|
||||
action_dist.entropy(), drop_last=True),
|
||||
action_dist.multi_entropy(), drop_last=True),
|
||||
dones=make_time_major(dones, drop_last=True),
|
||||
behaviour_logits=make_time_major(
|
||||
unpacked_behaviour_logits, drop_last=True),
|
||||
@@ -229,8 +229,8 @@ def build_appo_surrogate_loss(policy, batch_tensors):
|
||||
policy.loss = PPOSurrogateLoss(
|
||||
prev_actions_logp=make_time_major(prev_action_dist.logp(actions)),
|
||||
actions_logp=make_time_major(action_dist.logp(actions)),
|
||||
action_kl=prev_action_dist.kl(action_dist),
|
||||
actions_entropy=make_time_major(action_dist.entropy()),
|
||||
action_kl=prev_action_dist.multi_kl(action_dist),
|
||||
actions_entropy=make_time_major(action_dist.multi_entropy()),
|
||||
values=make_time_major(values),
|
||||
valid_mask=make_time_major(mask),
|
||||
advantages=make_time_major(
|
||||
|
||||
@@ -69,6 +69,22 @@ class ActionDistribution(object):
|
||||
"""Returns the probability of the sampled action."""
|
||||
return tf.exp(self.logp(self.sample_op))
|
||||
|
||||
def multi_kl(self, other):
    """The KL-divergence between two action distributions.

    The base implementation forwards directly to kl(). It exists as a
    separate hook because subclasses may override it to return an array
    (e.g. for MultiDiscrete) where kl() returns a single reduced value.
    TODO(ekl) consider removing this.

    Arguments:
        other: the action distribution to compare against; passed
            unchanged to kl().

    Returns:
        Whatever self.kl(other) returns.
    """
    return self.kl(other)
|
||||
|
||||
def multi_entropy(self):
    """The entropy of the action distribution.

    The base implementation forwards directly to entropy(). It exists
    as a separate hook because subclasses may override it to return an
    array (e.g. for MultiDiscrete) where entropy() returns a single
    reduced value. TODO(ekl) consider removing this.

    Returns:
        Whatever self.entropy() returns.
    """
    return self.entropy()
|
||||
|
||||
|
||||
class Categorical(ActionDistribution):
|
||||
"""Categorical distribution for discrete action spaces."""
|
||||
@@ -133,6 +149,7 @@ class MultiCategorical(ActionDistribution):
|
||||
]
|
||||
self.sample_op = self._build_sample_op()
|
||||
|
||||
@override(ActionDistribution)
|
||||
def logp(self, actions):
|
||||
# If tensor is provided, unstack it into list
|
||||
if isinstance(actions, tf.Tensor):
|
||||
@@ -141,12 +158,23 @@ class MultiCategorical(ActionDistribution):
|
||||
[cat.logp(act) for cat, act in zip(self.cats, actions)])
|
||||
return tf.reduce_sum(logps, axis=0)
|
||||
|
||||
def entropy(self):
|
||||
@override(ActionDistribution)
|
||||
def multi_entropy(self):
    """Return the entropies of all sub-distributions, unreduced.

    Calls entropy() on every entry of self.cats and stacks the results
    along axis 1, so each sub-distribution gets its own column.
    """
    per_component = [cat.entropy() for cat in self.cats]
    return tf.stack(per_component, axis=1)
|
||||
|
||||
def kl(self, other):
|
||||
@override(ActionDistribution)
|
||||
def entropy(self):
    """Return the total entropy across all sub-distributions.

    Sums the stacked per-component entropies from multi_entropy() over
    axis 1, the axis along which multi_entropy() stacks them.
    """
    stacked = self.multi_entropy()
    return tf.reduce_sum(stacked, axis=1)
|
||||
|
||||
@override(ActionDistribution)
|
||||
def multi_kl(self, other):
    """Return the per-component KL divergences as a Python list.

    Pairs self.cats with other.cats positionally and calls kl() on
    each pair. The result is a plain list (one entry per pair), not a
    stacked tensor; zip() stops at the shorter of the two sequences.

    Arguments:
        other: a distribution exposing a `cats` sequence parallel to
            self.cats.
    """
    return [
        own_cat.kl(other_cat)
        for own_cat, other_cat in zip(self.cats, other.cats)
    ]
|
||||
|
||||
@override(ActionDistribution)
|
||||
def kl(self, other):
    """Return the total KL divergence summed over all sub-distributions.

    multi_kl() returns a Python list with one KL tensor per
    sub-distribution. When TensorFlow converts that list to a tensor it
    packs the entries along a new leading axis (axis 0), so the sum
    over components must reduce axis 0. Reducing axis 1 would collapse
    the per-component tensors' own first dimension instead.
    NOTE(review): presumably that first dimension is the batch --
    confirm against Categorical.kl.

    Arguments:
        other: a distribution exposing a `cats` sequence parallel to
            self.cats.
    """
    return tf.reduce_sum(self.multi_kl(other), axis=0)
|
||||
|
||||
@override(ActionDistribution)
|
||||
def _build_sample_op(self):
    """Build the sampling op for the combined distribution.

    Calls sample() on every entry of self.cats and stacks the results
    along axis 1, so each sub-distribution gets its own column.
    """
    samples = [cat.sample() for cat in self.cats]
    return tf.stack(samples, axis=1)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user