mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 12:45:44 +08:00
@@ -16,6 +16,7 @@ from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph, \
|
||||
from ray.rllib.models.catalog import ModelCatalog
|
||||
from ray.rllib.utils.error import UnsupportedSpaceException
|
||||
from ray.rllib.utils.explained_variance import explained_variance
|
||||
from ray.rllib.models.action_dist import Categorical
|
||||
|
||||
|
||||
class VTraceLoss(object):
|
||||
@@ -184,6 +185,14 @@ class VTracePolicyGraph(LearningRateSchedule, TFPolicyGraph):
|
||||
clip_rho_threshold=self.config["vtrace_clip_rho_threshold"],
|
||||
clip_pg_rho_threshold=self.config["vtrace_clip_pg_rho_threshold"])
|
||||
|
||||
# KL divergence between the learner (model) and worker (behaviour) action distributions, for debugging
|
||||
model_dist = Categorical(self.model.outputs)
|
||||
behaviour_dist = Categorical(behaviour_logits)
|
||||
self.KLs = model_dist.kl(behaviour_dist)
|
||||
self.mean_KL = tf.reduce_mean(self.KLs)
|
||||
self.max_KL = tf.reduce_max(self.KLs)
|
||||
self.median_KL = tf.contrib.distributions.percentile(self.KLs, 50.0)
|
||||
|
||||
# Initialize TFPolicyGraph
|
||||
loss_in = [
|
||||
("actions", actions),
|
||||
@@ -225,6 +234,9 @@ class VTracePolicyGraph(LearningRateSchedule, TFPolicyGraph):
|
||||
"vf_explained_var": explained_variance(
|
||||
tf.reshape(self.loss.vtrace_returns.vs, [-1]),
|
||||
tf.reshape(to_batches(values)[:-1], [-1])),
|
||||
"mean_KL": self.mean_KL,
|
||||
"max_KL": self.max_KL,
|
||||
"median_KL": self.median_KL,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user