diff --git a/doc/source/rllib-models.rst b/doc/source/rllib-models.rst index 54b4613b1..d0c1a1082 100644 --- a/doc/source/rllib-models.rst +++ b/doc/source/rllib-models.rst @@ -90,7 +90,7 @@ Custom TF models should subclass the common RLlib `model class >> class MyModel(rllib.model.Model): - ... def custom_loss(self, policy_loss): + ... def custom_loss(self, policy_loss, loss_inputs): ... reader = JsonReader(...) ... input_ops = reader.tf_input_ops() ... with tf.variable_scope(