From 57c7b4238ef02e61de256feaa77f8028a28d630d Mon Sep 17 00:00:00 2001 From: andrewztan Date: Tue, 13 Nov 2018 23:12:35 -0800 Subject: [PATCH] KL Divergence Metrics (#3300) * added KL divergence metrics * fix --- .../ray/rllib/agents/impala/vtrace_policy_graph.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/python/ray/rllib/agents/impala/vtrace_policy_graph.py b/python/ray/rllib/agents/impala/vtrace_policy_graph.py index a08836e93..abb6efed5 100644 --- a/python/ray/rllib/agents/impala/vtrace_policy_graph.py +++ b/python/ray/rllib/agents/impala/vtrace_policy_graph.py @@ -16,6 +16,7 @@ from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph, \ from ray.rllib.models.catalog import ModelCatalog from ray.rllib.utils.error import UnsupportedSpaceException from ray.rllib.utils.explained_variance import explained_variance +from ray.rllib.models.action_dist import Categorical class VTraceLoss(object): @@ -184,6 +185,14 @@ class VTracePolicyGraph(LearningRateSchedule, TFPolicyGraph): clip_rho_threshold=self.config["vtrace_clip_rho_threshold"], clip_pg_rho_threshold=self.config["vtrace_clip_pg_rho_threshold"]) + # KL divergence between worker and learner logits for debugging + model_dist = Categorical(self.model.outputs) + behaviour_dist = Categorical(behaviour_logits) + self.KLs = model_dist.kl(behaviour_dist) + self.mean_KL = tf.reduce_mean(self.KLs) + self.max_KL = tf.reduce_max(self.KLs) + self.median_KL = tf.contrib.distributions.percentile(self.KLs, 50.0) + # Initialize TFPolicyGraph loss_in = [ ("actions", actions), @@ -225,6 +234,9 @@ class VTracePolicyGraph(LearningRateSchedule, TFPolicyGraph): "vf_explained_var": explained_variance( tf.reshape(self.loss.vtrace_returns.vs, [-1]), tf.reshape(to_batches(values)[:-1], [-1])), + "mean_KL": self.mean_KL, + "max_KL": self.max_KL, + "median_KL": self.median_KL, }, }