From d90f3653946c0be3edaad68fbb2660d164099bf1 Mon Sep 17 00:00:00 2001
From: Eric Liang
Date: Mon, 12 Nov 2018 18:55:24 -0800
Subject: [PATCH] [rllib] Add self-supervised loss to model (#3291)

# What do these changes do?

Allow self-supervised losses to be easily defined in custom models. Add this to the reference policy graphs.
---
 doc/source/rllib-models.rst                   | 14 ++++++++++++-
 .../rllib/agents/a3c/a3c_tf_policy_graph.py   |  2 +-
 .../rllib/agents/ddpg/ddpg_policy_graph.py    | 20 ++++++++++---------
 .../ray/rllib/agents/dqn/dqn_policy_graph.py  | 13 ++++++------
 .../agents/impala/vtrace_policy_graph.py      |  2 +-
 python/ray/rllib/agents/pg/pg_policy_graph.py |  2 +-
 .../ray/rllib/agents/ppo/ppo_policy_graph.py  |  2 +-
 python/ray/rllib/models/model.py              | 12 +++++++++++
 8 files changed, 47 insertions(+), 20 deletions(-)

diff --git a/doc/source/rllib-models.rst b/doc/source/rllib-models.rst
index 79df3ab5c..5fde37f53 100644
--- a/doc/source/rllib-models.rst
+++ b/doc/source/rllib-models.rst
@@ -30,7 +30,7 @@ The following is a list of the built-in model hyperparameters:
 Custom Models
 -------------
 
-Custom models should subclass the common RLlib `model class `__ and override the ``_build_layers_v2`` method. This method takes in a dict of tensor inputs (the observation ``obs``, ``prev_action``, and ``prev_reward``), and returns a feature layer and float vector of the specified output size. You can also override the ``value_function`` method to implement a custom value branch. The model can then be registered and used in place of a built-in model:
+Custom models should subclass the common RLlib `model class `__ and override the ``_build_layers_v2`` method. This method takes in a dict of tensor inputs (the observation ``obs``, ``prev_action``, and ``prev_reward``), and returns a feature layer and float vector of the specified output size. You can also override the ``value_function`` method to implement a custom value branch. A self-supervised loss can be defined via the ``loss`` method. The model can then be registered and used in place of a built-in model:
 
 .. code-block:: python
 
@@ -86,6 +86,18 @@ Custom models should subclass the common RLlib `model class