From ba030482542cb670f5169b0d13254ec68c927deb Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Sat, 2 Mar 2019 22:57:51 -0800 Subject: [PATCH] [rllib] TF model custom_loss() should actually allow access to full rollout data (#4220) --- doc/source/rllib-models.rst | 3 ++- .../ray/rllib/agents/ddpg/ddpg_policy_graph.py | 18 +++++++++++------- python/ray/rllib/evaluation/tf_policy_graph.py | 13 +++++++------ python/ray/rllib/examples/custom_loss.py | 6 ++++-- python/ray/rllib/models/model.py | 3 ++- python/ray/rllib/offline/input_reader.py | 2 +- 6 files changed, 27 insertions(+), 18 deletions(-) diff --git a/doc/source/rllib-models.rst b/doc/source/rllib-models.rst index 54b4613b1..d0c1a1082 100644 --- a/doc/source/rllib-models.rst +++ b/doc/source/rllib-models.rst @@ -90,7 +90,7 @@ Custom TF models should subclass the common RLlib `model class >> class MyModel(rllib.model.Model): - ... def custom_loss(self, policy_loss): + ... def custom_loss(self, policy_loss, loss_inputs): ... reader = JsonReader(...) ... input_ops = reader.tf_input_ops() ... with tf.variable_scope(