KL Divergence Metrics (#3300)

* added KL divergence metrics

* fix
This commit is contained in:
andrewztan
2018-11-13 23:12:35 -08:00
committed by Eric Liang
parent 1660c9d627
commit 57c7b4238e
@@ -16,6 +16,7 @@ from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph, \
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.utils.error import UnsupportedSpaceException
from ray.rllib.utils.explained_variance import explained_variance
from ray.rllib.models.action_dist import Categorical
class VTraceLoss(object):
@@ -184,6 +185,14 @@ class VTracePolicyGraph(LearningRateSchedule, TFPolicyGraph):
clip_rho_threshold=self.config["vtrace_clip_rho_threshold"],
clip_pg_rho_threshold=self.config["vtrace_clip_pg_rho_threshold"])
# KL divergence between worker and learner logits for debugging
model_dist = Categorical(self.model.outputs)
behaviour_dist = Categorical(behaviour_logits)
self.KLs = model_dist.kl(behaviour_dist)
self.mean_KL = tf.reduce_mean(self.KLs)
self.max_KL = tf.reduce_max(self.KLs)
self.median_KL = tf.contrib.distributions.percentile(self.KLs, 50.0)
# Initialize TFPolicyGraph
loss_in = [
("actions", actions),
@@ -225,6 +234,9 @@ class VTracePolicyGraph(LearningRateSchedule, TFPolicyGraph):
"vf_explained_var": explained_variance(
tf.reshape(self.loss.vtrace_returns.vs, [-1]),
tf.reshape(to_batches(values)[:-1], [-1])),
"mean_KL": self.mean_KL,
"max_KL": self.max_KL,
"median_KL": self.median_KL,
},
}