mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 14:48:54 +08:00
[rllib] Use model.value_function() in MARWIL (#4036)
* fix marwil * add ph * fix
This commit is contained in:
@@ -10,6 +10,11 @@ from ray.rllib.utils.annotations import override
|
||||
# yapf: disable
|
||||
# __sphinx_doc_begin__
|
||||
DEFAULT_CONFIG = with_common_config({
|
||||
# You should override this to point to an offline dataset (see agent.py).
|
||||
"input": "sampler",
|
||||
# Use importance sampling estimators for reward
|
||||
"input_evaluation": ["is", "wis"],
|
||||
|
||||
# Scaling of advantages in exponential terms
|
||||
# When beta is 0, MARWIL is reduced to imitation learning
|
||||
"beta": 1.0,
|
||||
@@ -19,8 +24,6 @@ DEFAULT_CONFIG = with_common_config({
|
||||
"postprocess_inputs": True,
|
||||
# Whether to roll out "complete_episodes" or "truncate_episodes"
|
||||
"batch_mode": "complete_episodes",
|
||||
# Use importance sampling estimators for reward
|
||||
"input_evaluation": ["is", "wis"],
|
||||
# Learning rate for the Adam optimizer
|
||||
"lr": 1e-4,
|
||||
# Number of timesteps collected for each SGD round
|
||||
|
||||
@@ -63,10 +63,17 @@ class MARWILPolicyGraph(TFPolicyGraph):
|
||||
# Action inputs
|
||||
self.obs_t = tf.placeholder(
|
||||
tf.float32, shape=(None, ) + observation_space.shape)
|
||||
prev_actions_ph = ModelCatalog.get_action_placeholder(action_space)
|
||||
prev_rewards_ph = tf.placeholder(
|
||||
tf.float32, [None], name="prev_reward")
|
||||
|
||||
with tf.variable_scope(P_SCOPE) as scope:
|
||||
self.model = self._build_policy_network(
|
||||
self.obs_t, observation_space, logit_dim)
|
||||
self.model = ModelCatalog.get_model({
|
||||
"obs": self.obs_t,
|
||||
"prev_actions": prev_actions_ph,
|
||||
"prev_rewards": prev_rewards_ph,
|
||||
"is_training": self._get_is_training_placeholder(),
|
||||
}, observation_space, logit_dim, self.config["model"])
|
||||
logits = self.model.outputs
|
||||
self.p_func_vars = _scope_vars(scope.name)
|
||||
|
||||
@@ -80,8 +87,7 @@ class MARWILPolicyGraph(TFPolicyGraph):
|
||||
|
||||
# v network evaluation
|
||||
with tf.variable_scope(V_SCOPE) as scope:
|
||||
state_values = self._build_value_network(self.obs_t,
|
||||
observation_space)
|
||||
state_values = self.model.value_function()
|
||||
self.v_func_vars = _scope_vars(scope.name)
|
||||
self.v_loss = self._build_value_loss(state_values, self.cum_rew_t)
|
||||
self.p_loss = self._build_policy_loss(state_values, self.cum_rew_t,
|
||||
@@ -111,7 +117,9 @@ class MARWILPolicyGraph(TFPolicyGraph):
|
||||
loss=self.model.loss() + objective,
|
||||
loss_inputs=self.loss_inputs,
|
||||
state_inputs=self.model.state_in,
|
||||
state_outputs=self.model.state_out)
|
||||
state_outputs=self.model.state_out,
|
||||
prev_action_input=prev_actions_ph,
|
||||
prev_reward_input=prev_rewards_ph)
|
||||
self.sess.run(tf.global_variables_initializer())
|
||||
|
||||
self.stats_fetches = {
|
||||
@@ -121,19 +129,6 @@ class MARWILPolicyGraph(TFPolicyGraph):
|
||||
"vf_loss": self.v_loss.loss
|
||||
}
|
||||
|
||||
def _build_policy_network(self, obs, obs_space, logit_dim):
|
||||
return ModelCatalog.get_model({
|
||||
"obs": obs,
|
||||
"is_training": self._get_is_training_placeholder(),
|
||||
}, obs_space, logit_dim, self.config["model"])
|
||||
|
||||
def _build_value_network(self, obs, obs_space):
|
||||
value_model = ModelCatalog.get_model({
|
||||
"obs": obs,
|
||||
"is_training": self._get_is_training_placeholder(),
|
||||
}, obs_space, 1, self.config["model"])
|
||||
return value_model.outputs
|
||||
|
||||
def _build_value_loss(self, state_values, cum_rwds):
|
||||
return ValueLoss(state_values, cum_rwds)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user