From 0c0bd4d41cf3776c4d1b312f093e5f29db00c142 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Thu, 14 Feb 2019 19:35:21 -0800 Subject: [PATCH] [rllib] Use model.value_function() in MARWIL (#4036) * fix marwil * add ph * fix --- python/ray/rllib/agents/marwil/marwil.py | 7 +++-- .../agents/marwil/marwil_policy_graph.py | 31 ++++++++----------- 2 files changed, 18 insertions(+), 20 deletions(-) diff --git a/python/ray/rllib/agents/marwil/marwil.py b/python/ray/rllib/agents/marwil/marwil.py index d0be3fbaa..e371efa4b 100644 --- a/python/ray/rllib/agents/marwil/marwil.py +++ b/python/ray/rllib/agents/marwil/marwil.py @@ -10,6 +10,11 @@ from ray.rllib.utils.annotations import override # yapf: disable # __sphinx_doc_begin__ DEFAULT_CONFIG = with_common_config({ + # You should override this to point to an offline dataset (see agent.py). + "input": "sampler", + # Use importance sampling estimators for reward + "input_evaluation": ["is", "wis"], + # Scaling of advantages in exponential terms # When beta is 0, MARWIL is reduced to imitation learning "beta": 1.0, @@ -19,8 +24,6 @@ DEFAULT_CONFIG = with_common_config({ "postprocess_inputs": True, # Whether to rollout "complete_episodes" or "truncate_episodes" "batch_mode": "complete_episodes", - # Use importance sampling estimators for reward - "input_evaluation": ["is", "wis"], # Learning rate for adam optimizer "lr": 1e-4, # Number of timesteps collected for each SGD round diff --git a/python/ray/rllib/agents/marwil/marwil_policy_graph.py b/python/ray/rllib/agents/marwil/marwil_policy_graph.py index 03f87ca2a..c1a155340 100644 --- a/python/ray/rllib/agents/marwil/marwil_policy_graph.py +++ b/python/ray/rllib/agents/marwil/marwil_policy_graph.py @@ -63,10 +63,17 @@ class MARWILPolicyGraph(TFPolicyGraph): # Action inputs self.obs_t = tf.placeholder( tf.float32, shape=(None, ) + observation_space.shape) + prev_actions_ph = ModelCatalog.get_action_placeholder(action_space) + prev_rewards_ph = tf.placeholder( + tf.float32, [None], name="prev_reward") with tf.variable_scope(P_SCOPE) as scope: - self.model = self._build_policy_network( - self.obs_t, observation_space, logit_dim) + self.model = ModelCatalog.get_model({ + "obs": self.obs_t, + "prev_actions": prev_actions_ph, + "prev_rewards": prev_rewards_ph, + "is_training": self._get_is_training_placeholder(), + }, observation_space, logit_dim, self.config["model"]) logits = self.model.outputs self.p_func_vars = _scope_vars(scope.name) @@ -80,8 +87,7 @@ class MARWILPolicyGraph(TFPolicyGraph): # v network evaluation with tf.variable_scope(V_SCOPE) as scope: - state_values = self._build_value_network(self.obs_t, - observation_space) + state_values = self.model.value_function() self.v_func_vars = _scope_vars(scope.name) self.v_loss = self._build_value_loss(state_values, self.cum_rew_t) self.p_loss = self._build_policy_loss(state_values, self.cum_rew_t, @@ -111,7 +117,9 @@ class MARWILPolicyGraph(TFPolicyGraph): loss=self.model.loss() + objective, loss_inputs=self.loss_inputs, state_inputs=self.model.state_in, - state_outputs=self.model.state_out) + state_outputs=self.model.state_out, + prev_action_input=prev_actions_ph, + prev_reward_input=prev_rewards_ph) self.sess.run(tf.global_variables_initializer()) self.stats_fetches = { @@ -121,19 +129,6 @@ class MARWILPolicyGraph(TFPolicyGraph): "vf_loss": self.v_loss.loss } - def _build_policy_network(self, obs, obs_space, logit_dim): - return ModelCatalog.get_model({ - "obs": obs, - "is_training": self._get_is_training_placeholder(), - }, obs_space, logit_dim, self.config["model"]) - - def _build_value_network(self, obs, obs_space): - value_model = ModelCatalog.get_model({ - "obs": obs, - "is_training": self._get_is_training_placeholder(), - }, obs_space, 1, self.config["model"]) - return value_model.outputs - def _build_value_loss(self, state_values, cum_rwds): return ValueLoss(state_values, cum_rwds)