mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 03:34:48 +08:00
[rllib] Port DDPG to the build_tf_policy pattern (#5242)
This commit is contained in:
@@ -41,7 +41,7 @@ DEFAULT_CONFIG = with_common_config({
|
||||
# === Model ===
|
||||
# Apply a state preprocessor with spec given by the "model" config option
|
||||
# (like other RL algorithms). This is mostly useful if you have a weird
|
||||
# observation shape, like an image. Disabled by default.
|
||||
# observation shape, like an image. Auto-enabled if a custom model is set.
|
||||
"use_state_preprocessor": False,
|
||||
# Postprocess the policy network model output with these hidden layers. If
|
||||
# use_state_preprocessor is False, then these will be the *only* hidden
|
||||
@@ -173,7 +173,7 @@ def make_exploration_schedule(config, worker_index):
|
||||
if config["per_worker_exploration"]:
|
||||
assert config["num_workers"] > 1, "This requires multiple workers"
|
||||
if worker_index >= 0:
|
||||
# FIXME: what do magic constants mean? (0.4, 7)
|
||||
# Exploration constants from the Ape-X paper
|
||||
max_index = float(config["num_workers"] - 1)
|
||||
exponent = 1 + worker_index / max_index * 7
|
||||
return ConstantSchedule(0.4**exponent)
|
||||
|
||||
@@ -0,0 +1,246 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
|
||||
from ray.rllib.utils import try_import_tf
|
||||
|
||||
tf = try_import_tf()
|
||||
|
||||
|
||||
class DDPGModel(TFModelV2):
|
||||
"""Extension of standard TFModel for DDPG.
|
||||
|
||||
Data flow:
|
||||
obs -> forward() -> model_out
|
||||
model_out -> get_policy_output() -> pi(s)
|
||||
model_out, actions -> get_q_values() -> Q(s, a)
|
||||
model_out, actions -> get_twin_q_values() -> Q_twin(s, a)
|
||||
|
||||
Note that this class by itself is not a valid model unless you
|
||||
implement forward() in a subclass."""
|
||||
|
||||
def __init__(self,
|
||||
obs_space,
|
||||
action_space,
|
||||
num_outputs,
|
||||
model_config,
|
||||
name,
|
||||
actor_hidden_activation="relu",
|
||||
actor_hiddens=(400, 300),
|
||||
critic_hidden_activation="relu",
|
||||
critic_hiddens=(400, 300),
|
||||
parameter_noise=False,
|
||||
twin_q=False,
|
||||
exploration_ou_sigma=0.2):
|
||||
"""Initialize variables of this model.
|
||||
|
||||
Extra model kwargs:
|
||||
actor_hidden_activation (str): activation for actor network
|
||||
actor_hiddens (list): hidden layers sizes for actor network
|
||||
critic_hidden_activation (str): activation for critic network
|
||||
critic_hiddens (list): hidden layers sizes for critic network
|
||||
parameter_noise (bool): use param noise exploration
|
||||
twin_q (bool): build twin Q networks
|
||||
exploration_ou_sigma (float): ou noise sigma for exploration
|
||||
|
||||
Note that the core layers for forward() are not defined here, this
|
||||
only defines the layers for the output heads. Those layers for
|
||||
forward() should be defined in subclasses of DDPGModel.
|
||||
"""
|
||||
|
||||
super(DDPGModel, self).__init__(obs_space, action_space, num_outputs,
|
||||
model_config, name)
|
||||
self.exploration_ou_sigma = exploration_ou_sigma
|
||||
|
||||
self.action_dim = np.product(action_space.shape)
|
||||
self.model_out = tf.keras.layers.Input(
|
||||
shape=(num_outputs, ), name="model_out")
|
||||
self.actions = tf.keras.layers.Input(
|
||||
shape=(self.action_dim, ), name="actions")
|
||||
|
||||
def build_action_net(action_out):
|
||||
activation = getattr(tf.nn, actor_hidden_activation)
|
||||
i = 0
|
||||
for hidden in actor_hiddens:
|
||||
if parameter_noise:
|
||||
import tensorflow.contrib.layers as layers
|
||||
action_out = layers.fully_connected(
|
||||
action_out,
|
||||
num_outputs=hidden,
|
||||
activation_fn=activation,
|
||||
normalizer_fn=layers.layer_norm)
|
||||
else:
|
||||
action_out = tf.layers.dense(
|
||||
action_out,
|
||||
units=hidden,
|
||||
activation=activation,
|
||||
name="action_hidden_{}".format(i))
|
||||
i += 1
|
||||
return tf.layers.dense(
|
||||
action_out,
|
||||
units=self.action_dim,
|
||||
activation=None,
|
||||
name="action_out")
|
||||
|
||||
action_scope = name + "/action_net"
|
||||
|
||||
# TODO(ekl) use keras layers instead of variable scopes
|
||||
def build_action_net_scope(model_out):
|
||||
with tf.variable_scope(action_scope, reuse=tf.AUTO_REUSE):
|
||||
return build_action_net(model_out)
|
||||
|
||||
pi_out = tf.keras.layers.Lambda(build_action_net_scope)(self.model_out)
|
||||
self.action_net = tf.keras.Model(self.model_out, pi_out)
|
||||
self.register_variables(self.action_net.variables)
|
||||
|
||||
# Noise vars for P network except for layer normalization vars
|
||||
if parameter_noise:
|
||||
with tf.variable_scope(action_scope, reuse=tf.AUTO_REUSE):
|
||||
self._build_parameter_noise([
|
||||
var for var in self.action_net.variables
|
||||
if "LayerNorm" not in var.name
|
||||
])
|
||||
|
||||
def build_q_net(name, model_out, actions):
|
||||
q_out = tf.keras.layers.Concatenate(axis=1)([model_out, actions])
|
||||
activation = getattr(tf.nn, critic_hidden_activation)
|
||||
for i, n in enumerate(critic_hiddens):
|
||||
q_out = tf.keras.layers.Dense(
|
||||
n,
|
||||
name="{}_hidden_{}".format(name, i),
|
||||
activation=activation)(q_out)
|
||||
q_out = tf.keras.layers.Dense(
|
||||
1, activation=None, name="{}_out".format(name))(q_out)
|
||||
return tf.keras.Model([model_out, actions], q_out)
|
||||
|
||||
self.q_net = build_q_net("q", self.model_out, self.actions)
|
||||
self.register_variables(self.q_net.variables)
|
||||
|
||||
if twin_q:
|
||||
self.twin_q_net = build_q_net("twin_q", self.model_out,
|
||||
self.actions)
|
||||
self.register_variables(self.twin_q_net.variables)
|
||||
else:
|
||||
self.twin_q_net = None
|
||||
|
||||
def forward(self, input_dict, state, seq_lens):
|
||||
"""This generates the model_out tensor input.
|
||||
|
||||
You must implement this as documented in modelv2.py."""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_policy_output(self, model_out):
|
||||
"""Return the (unscaled) output of the policy network.
|
||||
|
||||
This returns the unscaled outputs of pi(s).
|
||||
|
||||
Arguments:
|
||||
model_out (Tensor): obs embeddings from the model layers, of shape
|
||||
[BATCH_SIZE, num_outputs].
|
||||
|
||||
Returns:
|
||||
tensor of shape [BATCH_SIZE, action_dim] with range [-inf, inf].
|
||||
"""
|
||||
return self.action_net(model_out)
|
||||
|
||||
def get_q_values(self, model_out, actions):
|
||||
"""Return the Q estimates for the most recent forward pass.
|
||||
|
||||
This implements Q(s, a).
|
||||
|
||||
Arguments:
|
||||
model_out (Tensor): obs embeddings from the model layers, of shape
|
||||
[BATCH_SIZE, num_outputs].
|
||||
actions (Tensor): action values that correspond with the most
|
||||
recent batch of observations passed through forward(), of shape
|
||||
[BATCH_SIZE, action_dim].
|
||||
|
||||
Returns:
|
||||
tensor of shape [BATCH_SIZE].
|
||||
"""
|
||||
return self.q_net([model_out, actions])
|
||||
|
||||
def get_twin_q_values(self, model_out, actions):
|
||||
"""Same as get_q_values but using the twin Q net.
|
||||
|
||||
This implements the twin Q(s, a).
|
||||
|
||||
Arguments:
|
||||
model_out (Tensor): obs embeddings from the model layers, of shape
|
||||
[BATCH_SIZE, num_outputs].
|
||||
actions (Tensor): action values that correspond with the most
|
||||
recent batch of observations passed through forward(), of shape
|
||||
[BATCH_SIZE, action_dim].
|
||||
|
||||
Returns:
|
||||
tensor of shape [BATCH_SIZE].
|
||||
"""
|
||||
return self.twin_q_net([model_out, actions])
|
||||
|
||||
def policy_variables(self):
|
||||
"""Return the list of variables for the policy net."""
|
||||
|
||||
return list(self.action_net.variables)
|
||||
|
||||
def q_variables(self):
|
||||
"""Return the list of variables for Q / twin Q nets."""
|
||||
|
||||
return self.q_net.variables + (self.twin_q_net.variables
|
||||
if self.twin_q_net else [])
|
||||
|
||||
def update_action_noise(self, session, distance_in_action_space,
|
||||
exploration_ou_sigma, cur_noise_scale):
|
||||
"""Update the model action noise settings.
|
||||
|
||||
This is called internally by the DDPG policy."""
|
||||
|
||||
self.pi_distance = distance_in_action_space
|
||||
if (distance_in_action_space < exploration_ou_sigma * cur_noise_scale):
|
||||
# multiplying the sampled OU noise by noise scale is
|
||||
# equivalent to multiplying the sigma of OU by noise scale
|
||||
self.parameter_noise_sigma_val *= 1.01
|
||||
else:
|
||||
self.parameter_noise_sigma_val /= 1.01
|
||||
self.parameter_noise_sigma.load(
|
||||
self.parameter_noise_sigma_val, session=session)
|
||||
|
||||
def _build_parameter_noise(self, pnet_params):
|
||||
assert pnet_params
|
||||
self.parameter_noise_sigma_val = self.exploration_ou_sigma
|
||||
self.parameter_noise_sigma = tf.get_variable(
|
||||
initializer=tf.constant_initializer(
|
||||
self.parameter_noise_sigma_val),
|
||||
name="parameter_noise_sigma",
|
||||
shape=(),
|
||||
trainable=False,
|
||||
dtype=tf.float32)
|
||||
self.parameter_noise = []
|
||||
# No need to add any noise on LayerNorm parameters
|
||||
for var in pnet_params:
|
||||
noise_var = tf.get_variable(
|
||||
name=var.name.split(":")[0] + "_noise",
|
||||
shape=var.shape,
|
||||
initializer=tf.constant_initializer(.0),
|
||||
trainable=False)
|
||||
self.parameter_noise.append(noise_var)
|
||||
remove_noise_ops = list()
|
||||
for var, var_noise in zip(pnet_params, self.parameter_noise):
|
||||
remove_noise_ops.append(tf.assign_add(var, -var_noise))
|
||||
self.remove_noise_op = tf.group(*tuple(remove_noise_ops))
|
||||
generate_noise_ops = list()
|
||||
for var_noise in self.parameter_noise:
|
||||
generate_noise_ops.append(
|
||||
tf.assign(
|
||||
var_noise,
|
||||
tf.random_normal(
|
||||
shape=var_noise.shape,
|
||||
stddev=self.parameter_noise_sigma)))
|
||||
with tf.control_dependencies(generate_noise_ops):
|
||||
add_noise_ops = list()
|
||||
for var, var_noise in zip(pnet_params, self.parameter_noise):
|
||||
add_noise_ops.append(tf.assign_add(var, var_noise))
|
||||
self.add_noise_op = tf.group(*tuple(add_noise_ops))
|
||||
self.pi_distance = None
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,20 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from ray.rllib.models import Model
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils import try_import_tf
|
||||
|
||||
tf = try_import_tf()
|
||||
|
||||
|
||||
class NoopModel(Model):
|
||||
"""Trivial model that just returns the obs flattened.
|
||||
|
||||
This is the model used if use_state_preprocessor=False."""
|
||||
|
||||
@override(Model)
|
||||
def _build_layers_v2(self, input_dict, num_outputs, options):
|
||||
out = tf.reshape(input_dict["obs"], [-1, num_outputs])
|
||||
return out, out
|
||||
@@ -15,6 +15,11 @@ class DistributionalQModel(TFModelV2):
|
||||
|
||||
It also supports options for noisy nets and parameter space noise.
|
||||
|
||||
Data flow:
|
||||
obs -> forward() -> model_out
|
||||
model_out -> get_q_value_distributions() -> Q(s, a) atoms
|
||||
model_out -> get_state_value() -> V(s)
|
||||
|
||||
Note that this class by itself is not a valid model unless you
|
||||
implement forward() in a subclass."""
|
||||
|
||||
|
||||
@@ -187,7 +187,7 @@ def check_config_and_setup_param_noise(config):
|
||||
policies = info["policy"]
|
||||
episode = info["episode"]
|
||||
episode.custom_metrics["policy_distance"] = policies[
|
||||
DEFAULT_POLICY_ID].pi_distance
|
||||
DEFAULT_POLICY_ID].model.pi_distance
|
||||
if end_callback:
|
||||
end_callback(info)
|
||||
|
||||
@@ -207,6 +207,7 @@ def make_exploration_schedule(config, worker_index):
|
||||
assert config["num_workers"] > 1, \
|
||||
"This requires multiple workers"
|
||||
if worker_index >= 0:
|
||||
# Exploration constants from the Ape-X paper
|
||||
exponent = (
|
||||
1 + worker_index / float(config["num_workers"] - 1) * 7)
|
||||
return ConstantSchedule(0.4**exponent)
|
||||
|
||||
@@ -300,12 +300,9 @@ def _build_parameter_noise(policy, pnet_params):
|
||||
def build_q_losses(policy, batch_tensors):
|
||||
config = policy.config
|
||||
# q network evaluation
|
||||
prev_update_ops = set(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
|
||||
q_t, q_logits_t, q_dist_t = _compute_q_values(
|
||||
policy, policy.q_model, batch_tensors[SampleBatch.CUR_OBS],
|
||||
policy.observation_space, policy.action_space)
|
||||
policy.q_batchnorm_update_ops = list(
|
||||
set(tf.get_collection(tf.GraphKeys.UPDATE_OPS)) - prev_update_ops)
|
||||
|
||||
# target q network evalution
|
||||
q_tp1, q_logits_tp1, q_dist_tp1 = _compute_q_values(
|
||||
@@ -495,7 +492,6 @@ DQNTFPolicy = build_tf_policy(
|
||||
extra_action_feed_fn=exploration_setting_inputs,
|
||||
extra_action_fetches_fn=lambda policy: {"q_values": policy.q_values},
|
||||
extra_learn_fetches_fn=lambda policy: {"td_error": policy.q_loss.td_error},
|
||||
update_ops_fn=lambda policy: policy.q_batchnorm_update_ops,
|
||||
before_init=setup_early_mixins,
|
||||
after_init=setup_late_mixins,
|
||||
obs_include_prev_action_reward=False,
|
||||
|
||||
@@ -13,6 +13,10 @@ tf = try_import_tf()
|
||||
class SimpleQModel(TFModelV2):
|
||||
"""Extension of standard TFModel to provide Q values.
|
||||
|
||||
Data flow:
|
||||
obs -> forward() -> model_out
|
||||
model_out -> get_q_values() -> Q(s, a)
|
||||
|
||||
Note that this class by itself is not a valid model unless you
|
||||
implement forward() in a subclass."""
|
||||
|
||||
|
||||
@@ -301,6 +301,15 @@ class RolloutWorker(EvaluatorInterface):
|
||||
if seed is not None:
|
||||
np.random.seed(seed)
|
||||
random.seed(seed)
|
||||
if not hasattr(self.env, "seed"):
|
||||
raise ValueError("Env doesn't support env.seed(): {}".format(
|
||||
self.env))
|
||||
self.env.seed(seed)
|
||||
try:
|
||||
import torch
|
||||
torch.manual_seed(seed)
|
||||
except ImportError:
|
||||
logger.info("Could not seed torch")
|
||||
if _has_tensorflow_graph(policy_dict):
|
||||
if (ray.is_initialized()
|
||||
and ray.worker._mode() != ray.worker.LOCAL_MODE
|
||||
|
||||
@@ -211,6 +211,7 @@ class ModelCatalog(object):
|
||||
framework="tf",
|
||||
name=None,
|
||||
model_interface=None,
|
||||
default_model=None,
|
||||
**model_kwargs):
|
||||
"""Returns a suitable model compatible with given spaces and output.
|
||||
|
||||
@@ -223,6 +224,8 @@ class ModelCatalog(object):
|
||||
framework (str): Either "tf" or "torch".
|
||||
name (str): Name (scope) for the model.
|
||||
model_interface (cls): Interface required for the model
|
||||
default_model (cls): Override the default class for the model. This
|
||||
only has an effect when not using a custom model
|
||||
model_kwargs (dict): args to pass to the ModelV2 constructor
|
||||
|
||||
Returns:
|
||||
@@ -263,7 +266,7 @@ class ModelCatalog(object):
|
||||
return instance
|
||||
|
||||
if framework == "tf":
|
||||
legacy_model_cls = ModelCatalog.get_model
|
||||
legacy_model_cls = default_model or ModelCatalog.get_model
|
||||
wrapper = ModelCatalog._wrap_if_needed(
|
||||
make_v1_wrapper(legacy_model_cls), model_interface)
|
||||
return wrapper(obs_space, action_space, num_outputs, model_config,
|
||||
|
||||
@@ -9,7 +9,11 @@ class ModelV2(object):
|
||||
"""Defines a Keras-style abstract network model for use with RLlib.
|
||||
|
||||
Custom models should extend either TFModelV2 or TorchModelV2 instead of
|
||||
this class directly. Experimental.
|
||||
this class directly.
|
||||
|
||||
Data flow:
|
||||
obs -> forward() -> model_out
|
||||
value_function() -> V(s)
|
||||
|
||||
Attributes:
|
||||
obs_space (Space): observation space of the target gym env. This
|
||||
|
||||
@@ -50,6 +50,9 @@ def make_v1_wrapper(legacy_model_cls):
|
||||
# Tracks branches created so far
|
||||
self.branches_created = set()
|
||||
|
||||
# Tracks update ops
|
||||
self._update_ops = None
|
||||
|
||||
with tf.variable_scope(self.name) as scope:
|
||||
self.variable_scope = scope
|
||||
|
||||
@@ -68,9 +71,14 @@ def make_v1_wrapper(legacy_model_cls):
|
||||
else:
|
||||
# create a new model instance
|
||||
with tf.variable_scope(self.name):
|
||||
prev_update_ops = set(
|
||||
tf.get_collection(tf.GraphKeys.UPDATE_OPS))
|
||||
new_instance = self.legacy_model_cls(
|
||||
input_dict, self.obs_space, self.action_space,
|
||||
self.num_outputs, self.model_config, state, seq_lens)
|
||||
self._update_ops = list(
|
||||
set(tf.get_collection(tf.GraphKeys.UPDATE_OPS)) -
|
||||
prev_update_ops)
|
||||
if len(new_instance.state_init) != len(self.get_initial_state()):
|
||||
raise ValueError(
|
||||
"When using a custom recurrent ModelV1 model, you should "
|
||||
@@ -83,6 +91,13 @@ def make_v1_wrapper(legacy_model_cls):
|
||||
self.variable_scope = new_instance.scope
|
||||
return new_instance.outputs, new_instance.state_out
|
||||
|
||||
@override(TFModelV2)
|
||||
def update_ops(self):
|
||||
if self._update_ops is None:
|
||||
raise ValueError(
|
||||
"Cannot get update ops before wrapped v1 model init")
|
||||
return list(self._update_ops)
|
||||
|
||||
@override(ModelV2)
|
||||
def variables(self):
|
||||
var_list = super(ModelV1Wrapper, self).variables()
|
||||
|
||||
@@ -21,3 +21,9 @@ class TFModelV2(ModelV2):
|
||||
model_config,
|
||||
name,
|
||||
framework="tf")
|
||||
|
||||
def update_ops(self):
|
||||
"""Return the list of update ops for this model.
|
||||
|
||||
For example, this should include any BatchNorm update ops."""
|
||||
return []
|
||||
|
||||
@@ -39,7 +39,6 @@ class DynamicTFPolicy(TFPolicy):
|
||||
config,
|
||||
loss_fn,
|
||||
stats_fn=None,
|
||||
update_ops_fn=None,
|
||||
grad_stats_fn=None,
|
||||
before_loss_init=None,
|
||||
make_model=None,
|
||||
@@ -60,8 +59,6 @@ class DynamicTFPolicy(TFPolicy):
|
||||
TF fetches given the policy and batch input tensors
|
||||
grad_stats_fn (func): optional function that returns a dict of
|
||||
TF fetches given the policy and loss gradient tensors
|
||||
update_ops_fn (func): optional function that returns a list
|
||||
overriding the update ops to run when applying gradients
|
||||
before_loss_init (func): optional function to run prior to loss
|
||||
init that takes the same arguments as __init__
|
||||
make_model (func): optional function that returns a ModelV2 object
|
||||
@@ -95,7 +92,6 @@ class DynamicTFPolicy(TFPolicy):
|
||||
self._loss_fn = loss_fn
|
||||
self._stats_fn = stats_fn
|
||||
self._grad_stats_fn = grad_stats_fn
|
||||
self._update_ops_fn = update_ops_fn
|
||||
self._obs_include_prev_action_reward = obs_include_prev_action_reward
|
||||
|
||||
# Setup standard placeholders
|
||||
@@ -127,8 +123,14 @@ class DynamicTFPolicy(TFPolicy):
|
||||
dtype=tf.int32, shape=[None], name="seq_lens")
|
||||
|
||||
# Setup model
|
||||
self.dist_class, logit_dim = ModelCatalog.get_action_dist(
|
||||
action_space, self.config["model"])
|
||||
if action_sampler_fn:
|
||||
if not make_model:
|
||||
raise ValueError(
|
||||
"make_model is required if action_sampler_fn is given")
|
||||
self.dist_class = None
|
||||
else:
|
||||
self.dist_class, logit_dim = ModelCatalog.get_action_dist(
|
||||
action_space, self.config["model"])
|
||||
if existing_model:
|
||||
self.model = existing_model
|
||||
elif make_model:
|
||||
@@ -158,7 +160,6 @@ class DynamicTFPolicy(TFPolicy):
|
||||
# Setup action sampler
|
||||
if action_sampler_fn:
|
||||
self.action_dist = None
|
||||
self.dist_class = None
|
||||
action_sampler, action_prob = action_sampler_fn(
|
||||
self, self.model, self.input_dict, obs_space, action_space,
|
||||
config)
|
||||
@@ -335,6 +336,6 @@ class DynamicTFPolicy(TFPolicy):
|
||||
loss = self._loss_fn(self, batch_tensors)
|
||||
if self._stats_fn:
|
||||
self._stats_fetches.update(self._stats_fn(self, batch_tensors))
|
||||
if self._update_ops_fn:
|
||||
self._update_ops = self._update_ops_fn(self)
|
||||
# override the update ops to be those of the model
|
||||
self._update_ops = self.model.update_ops()
|
||||
return loss
|
||||
|
||||
@@ -202,7 +202,7 @@ class TFPolicy(Policy):
|
||||
self._update_ops = tf.get_collection(
|
||||
tf.GraphKeys.UPDATE_OPS, scope=tf.get_variable_scope().name)
|
||||
if self._update_ops:
|
||||
logger.debug("Update ops to run on apply gradient: {}".format(
|
||||
logger.info("Update ops to run on apply gradient: {}".format(
|
||||
self._update_ops))
|
||||
with tf.control_dependencies(self._update_ops):
|
||||
self._apply_op = self.build_apply_op(self._optimizer,
|
||||
|
||||
@@ -15,9 +15,9 @@ def build_tf_policy(name,
|
||||
get_default_config=None,
|
||||
postprocess_fn=None,
|
||||
stats_fn=None,
|
||||
update_ops_fn=None,
|
||||
optimizer_fn=None,
|
||||
gradients_fn=None,
|
||||
apply_gradients_fn=None,
|
||||
grad_stats_fn=None,
|
||||
extra_action_fetches_fn=None,
|
||||
extra_action_feed_fn=None,
|
||||
@@ -35,8 +35,9 @@ def build_tf_policy(name,
|
||||
|
||||
Functions will be run in this order to initialize the policy:
|
||||
1. Placeholder setup: postprocess_fn
|
||||
2. Loss init: loss_fn, stats_fn, update_ops_fn
|
||||
3. Optimizer init: optimizer_fn, gradients_fn, grad_stats_fn
|
||||
2. Loss init: loss_fn, stats_fn
|
||||
3. Optimizer init: optimizer_fn, gradients_fn, apply_gradients_fn,
|
||||
grad_stats_fn
|
||||
|
||||
This means that you can e.g., depend on any policy attributes created in
|
||||
the running of `loss_fn` in later functions such as `stats_fn`.
|
||||
@@ -58,13 +59,13 @@ def build_tf_policy(name,
|
||||
that takes the same args as Policy.postprocess_trajectory()
|
||||
stats_fn (func): optional function that returns a dict of
|
||||
TF fetches given the policy and batch input tensors
|
||||
update_ops_fn (func): optional function that returns a list overriding
|
||||
the update ops to run when applying gradients
|
||||
optimizer_fn (func): optional function that returns a tf.Optimizer
|
||||
given the policy and config
|
||||
gradients_fn (func): optional function that returns a list of gradients
|
||||
given a tf optimizer and loss tensor. If not specified, this
|
||||
given (policy, optimizer, loss). If not specified, this
|
||||
defaults to optimizer.compute_gradients(loss)
|
||||
apply_gradients_fn (func): optional function that returns an apply
|
||||
gradients op given (policy, optimizer, grads_and_vars)
|
||||
grad_stats_fn (func): optional function that returns a dict of
|
||||
TF fetches given the policy and loss gradient tensors
|
||||
extra_action_fetches_fn (func): optional function that returns
|
||||
@@ -134,7 +135,6 @@ def build_tf_policy(name,
|
||||
loss_fn,
|
||||
stats_fn=stats_fn,
|
||||
grad_stats_fn=grad_stats_fn,
|
||||
update_ops_fn=update_ops_fn,
|
||||
before_loss_init=before_loss_init_wrapper,
|
||||
make_model=make_model,
|
||||
action_sampler_fn=action_sampler_fn,
|
||||
@@ -170,6 +170,13 @@ def build_tf_policy(name,
|
||||
else:
|
||||
return TFPolicy.gradients(self, optimizer, loss)
|
||||
|
||||
@override(TFPolicy)
|
||||
def build_apply_op(self, optimizer, grads_and_vars):
|
||||
if apply_gradients_fn:
|
||||
return apply_gradients_fn(self, optimizer, grads_and_vars)
|
||||
else:
|
||||
return TFPolicy.build_apply_op(self, optimizer, grads_and_vars)
|
||||
|
||||
@override(TFPolicy)
|
||||
def extra_compute_action_fetches(self):
|
||||
return dict(
|
||||
|
||||
Reference in New Issue
Block a user