[rllib] Rough port of DQN to build_tf_policy() pattern (#4823)

This commit is contained in:
Eric Liang
2019-06-02 14:14:31 +08:00
committed by GitHub
parent c2ade075a3
commit 665d081fe9
8 changed files with 541 additions and 620 deletions
+82 -136
View File
@@ -4,20 +4,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import gym
import ray
from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.error import UnsupportedSpaceException
from ray.rllib.utils.explained_variance import explained_variance
from ray.rllib.policy.policy import Policy
from ray.rllib.evaluation.postprocessing import compute_advantages, \
Postprocessing
from ray.rllib.policy.tf_policy import TFPolicy, \
LearningRateSchedule
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.utils.annotations import override
from ray.rllib.policy.tf_policy_template import build_tf_policy
from ray.rllib.policy.tf_policy import LearningRateSchedule
from ray.rllib.utils import try_import_tf
tf = try_import_tf()
@@ -44,144 +37,97 @@ class A3CLoss(object):
self.entropy * entropy_coeff)
class A3CPostprocessing(object):
    """Adds the VF preds and advantages fields to the trajectory."""

    @override(TFPolicy)
    def extra_compute_action_fetches(self):
        """Return the base TFPolicy fetches plus the value prediction."""
        fetches = dict(TFPolicy.extra_compute_action_fetches(self))
        fetches[SampleBatch.VF_PREDS] = self.vf
        return fetches

    @override(Policy)
    def postprocess_trajectory(self,
                               sample_batch,
                               other_agent_batches=None,
                               episode=None):
        """Compute advantages for `sample_batch`.

        When the last step is not marked done, the return is bootstrapped
        from `self._value()` evaluated on the final transition and the
        final recurrent state outputs; otherwise the bootstrap is 0.0.
        `other_agent_batches` and `episode` are not used here.
        """
        if sample_batch[SampleBatch.DONES][-1]:
            last_r = 0.0
        else:
            # One single-element batch per recurrent state input.
            next_state = [
                [sample_batch["state_out_{}".format(i)][-1]]
                for i in range(len(self.model.state_in))
            ]
            last_r = self._value(sample_batch[SampleBatch.NEXT_OBS][-1],
                                 sample_batch[SampleBatch.ACTIONS][-1],
                                 sample_batch[SampleBatch.REWARDS][-1],
                                 *next_state)
        gamma = self.config["gamma"]
        lambda_ = self.config["lambda"]
        return compute_advantages(sample_batch, last_r, gamma, lambda_)
def actor_critic_loss(policy, batch_tensors):
    """Build the A3C loss for `policy` and return its total loss tensor.

    The A3CLoss object is stored on `policy.loss` so that its individual
    terms remain accessible after this function returns.

    Arguments:
        policy: policy object providing `action_dist`, `vf` and `config`.
        batch_tensors: mapping from sample-batch column name to tensor.
    """
    actions = batch_tensors[SampleBatch.ACTIONS]
    advantages = batch_tensors[Postprocessing.ADVANTAGES]
    value_targets = batch_tensors[Postprocessing.VALUE_TARGETS]
    policy.loss = A3CLoss(
        policy.action_dist,
        actions,
        advantages,
        value_targets,
        policy.vf,
        policy.config["vf_loss_coeff"],
        policy.config["entropy_coeff"])
    return policy.loss.total_loss
class A3CTFPolicy(LearningRateSchedule, A3CPostprocessing, TFPolicy):
def __init__(self, observation_space, action_space, config):
config = dict(ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG, **config)
self.config = config
self.sess = tf.get_default_session()
def postprocess_advantages(policy,
                           sample_batch,
                           other_agent_batches=None,
                           episode=None):
    """Compute advantages for a sampled trajectory.

    If the last step of `sample_batch` is marked done the bootstrap value
    is 0.0. Otherwise it is obtained from `policy._value()` on the final
    next-observation, action and reward together with the last recurrent
    state outputs. `other_agent_batches` and `episode` are not used here.
    """
    if sample_batch[SampleBatch.DONES][-1]:
        last_r = 0.0
    else:
        # One single-element batch per recurrent state input of the model.
        next_state = [
            [sample_batch["state_out_{}".format(i)][-1]]
            for i in range(len(policy.model.state_in))
        ]
        last_r = policy._value(sample_batch[SampleBatch.NEXT_OBS][-1],
                               sample_batch[SampleBatch.ACTIONS][-1],
                               sample_batch[SampleBatch.REWARDS][-1],
                               *next_state)
    gamma = policy.config["gamma"]
    lambda_ = policy.config["lambda"]
    return compute_advantages(sample_batch, last_r, gamma, lambda_)
# Setup the policy
self.observations = tf.placeholder(
tf.float32, [None] + list(observation_space.shape))
dist_class, logit_dim = ModelCatalog.get_action_dist(
action_space, self.config["model"])
self.prev_actions = ModelCatalog.get_action_placeholder(action_space)
self.prev_rewards = tf.placeholder(
tf.float32, [None], name="prev_reward")
self.model = ModelCatalog.get_model({
"obs": self.observations,
"prev_actions": self.prev_actions,
"prev_rewards": self.prev_rewards,
"is_training": self._get_is_training_placeholder(),
}, observation_space, action_space, logit_dim, self.config["model"])
action_dist = dist_class(self.model.outputs)
def add_value_function_fetch(policy):
    """Return the extra action fetches: the policy's value prediction."""
    extra_fetches = {SampleBatch.VF_PREDS: policy.vf}
    return extra_fetches
class ValueNetworkMixin(object):
def __init__(self):
self.vf = self.model.value_function()
self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
tf.get_variable_scope().name)
# Setup the policy loss
if isinstance(action_space, gym.spaces.Box):
ac_size = action_space.shape[0]
actions = tf.placeholder(tf.float32, [None, ac_size], name="ac")
elif isinstance(action_space, gym.spaces.Discrete):
actions = tf.placeholder(tf.int64, [None], name="ac")
else:
raise UnsupportedSpaceException(
"Action space {} is not supported for A3C.".format(
action_space))
advantages = tf.placeholder(tf.float32, [None], name="advantages")
self.v_target = tf.placeholder(tf.float32, [None], name="v_target")
self.loss = A3CLoss(action_dist, actions, advantages, self.v_target,
self.vf, self.config["vf_loss_coeff"],
self.config["entropy_coeff"])
# Initialize TFPolicy
loss_in = [
(SampleBatch.CUR_OBS, self.observations),
(SampleBatch.ACTIONS, actions),
(SampleBatch.PREV_ACTIONS, self.prev_actions),
(SampleBatch.PREV_REWARDS, self.prev_rewards),
(Postprocessing.ADVANTAGES, advantages),
(Postprocessing.VALUE_TARGETS, self.v_target),
]
LearningRateSchedule.__init__(self, self.config["lr"],
self.config["lr_schedule"])
TFPolicy.__init__(
self,
observation_space,
action_space,
self.sess,
obs_input=self.observations,
action_sampler=action_dist.sample(),
action_prob=action_dist.sampled_action_prob(),
loss=self.loss.total_loss,
model=self.model,
loss_inputs=loss_in,
state_inputs=self.model.state_in,
state_outputs=self.model.state_out,
prev_action_input=self.prev_actions,
prev_reward_input=self.prev_rewards,
seq_lens=self.model.seq_lens,
max_seq_len=self.config["model"]["max_seq_len"])
self.stats_fetches = {
LEARNER_STATS_KEY: {
"cur_lr": tf.cast(self.cur_lr, tf.float64),
"policy_loss": self.loss.pi_loss,
"policy_entropy": self.loss.entropy,
"grad_gnorm": tf.global_norm(self._grads),
"var_gnorm": tf.global_norm(self.var_list),
"vf_loss": self.loss.vf_loss,
"vf_explained_var": explained_variance(self.v_target, self.vf),
},
}
self.sess.run(tf.global_variables_initializer())
@override(Policy)
def get_initial_state(self):
return self.model.state_init
@override(TFPolicy)
def gradients(self, optimizer, loss):
grads = tf.gradients(loss, self.var_list)
self.grads, _ = tf.clip_by_global_norm(grads, self.config["grad_clip"])
clipped_grads = list(zip(self.grads, self.var_list))
return clipped_grads
@override(TFPolicy)
def extra_compute_grad_fetches(self):
return self.stats_fetches
def _value(self, ob, prev_action, prev_reward, *args):
feed_dict = {
self.observations: [ob],
self.prev_actions: [prev_action],
self.prev_rewards: [prev_reward],
self.get_placeholder(SampleBatch.CUR_OBS): [ob],
self.get_placeholder(SampleBatch.PREV_ACTIONS): [prev_action],
self.get_placeholder(SampleBatch.PREV_REWARDS): [prev_reward],
self.model.seq_lens: [1]
}
assert len(args) == len(self.model.state_in), \
(args, self.model.state_in)
for k, v in zip(self.model.state_in, args):
feed_dict[k] = v
vf = self.sess.run(self.vf, feed_dict)
vf = self.get_session().run(self.vf, feed_dict)
return vf[0]
def stats(policy, batch_tensors):
    """Return a dict of stat tensors read from the policy and its loss.

    `batch_tensors` is accepted as part of the stats function signature
    but is not used here.
    """
    learner_stats = {}
    learner_stats["cur_lr"] = tf.cast(policy.cur_lr, tf.float64)
    learner_stats["policy_loss"] = policy.loss.pi_loss
    learner_stats["policy_entropy"] = policy.loss.entropy
    learner_stats["var_gnorm"] = tf.global_norm(policy.var_list)
    learner_stats["vf_loss"] = policy.loss.vf_loss
    return learner_stats
def grad_stats(policy, grads):
    """Return gradient-related stat tensors.

    Reports the global norm of `grads` and the explained variance of the
    value predictions against the value-target placeholder.
    """
    grad_norm = tf.global_norm(grads)
    value_targets = policy.get_placeholder(Postprocessing.VALUE_TARGETS)
    return {
        "grad_gnorm": grad_norm,
        "vf_explained_var": explained_variance(value_targets, policy.vf),
    }
def clip_gradients(policy, optimizer, loss):
    """Return (gradient, variable) pairs clipped by global norm.

    Gradients of `loss` are taken with respect to `policy.var_list` and
    clipped to `policy.config["grad_clip"]`. `optimizer` is not used.
    """
    variables = policy.var_list
    raw_grads = tf.gradients(loss, variables)
    clipped, _ = tf.clip_by_global_norm(raw_grads, policy.config["grad_clip"])
    return list(zip(clipped, variables))
def setup_mixins(policy, obs_space, action_space, config):
    """Initialize the policy's mixins and record its trainable variables.

    `obs_space` and `action_space` are not used here.
    """
    ValueNetworkMixin.__init__(policy)
    LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"])
    # Collected after the mixins are initialized, restricted to the
    # variable scope that is current at call time.
    scope_name = tf.get_variable_scope().name
    policy.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                        scope_name)
# Assemble the A3C TF policy class from the module-level functions and
# mixins above using the build_tf_policy() template.
A3CTFPolicy = build_tf_policy(
name="A3CTFPolicy",
# Default config is looked up lazily, when the lambda is called.
get_default_config=lambda: ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG,
loss_fn=actor_critic_loss,
stats_fn=stats,
grad_stats_fn=grad_stats,
gradients_fn=clip_gradients,
postprocess_fn=postprocess_advantages,
extra_action_fetches_fn=add_value_function_fetch,
# setup_mixins creates policy.vf and policy.var_list, which the loss,
# stats and gradient functions above read.
before_loss_init=setup_mixins,
mixins=[ValueNetworkMixin, LearningRateSchedule])
+275 -274
View File
@@ -7,14 +7,14 @@ import numpy as np
from scipy.stats import entropy
import ray
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
from ray.rllib.models import ModelCatalog, Categorical
from ray.rllib.utils.annotations import override
from ray.rllib.utils.error import UnsupportedSpaceException
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.tf_policy import TFPolicy, \
LearningRateSchedule
from ray.rllib.policy.tf_policy_template import build_tf_policy
from ray.rllib.utils import try_import_tf
tf = try_import_tf()
@@ -102,46 +102,6 @@ class QLoss(object):
}
class DQNPostprocessing(object):
    """Implements n-step learning and param noise adjustments."""

    @override(TFPolicy)
    def extra_compute_action_fetches(self):
        """Return the base TFPolicy fetches plus the Q values."""
        base_fetches = TFPolicy.extra_compute_action_fetches(self)
        return dict(base_fetches, q_values=self.q_values)

    @override(Policy)
    def postprocess_trajectory(self,
                               sample_batch,
                               other_agent_batches=None,
                               episode=None):
        """Adapt the parameter-noise sigma, then post-process the batch.

        The sigma adjustment only runs when `config["parameter_noise"]` is
        set. `other_agent_batches` and `episode` are not used here.
        """
        if self.config["parameter_noise"]:
            # adjust the sigma of parameter space noise
            session = self.sess
            states = [list(x) for x in sample_batch.columns(["obs"])][0]
            feed = {self.cur_observations: states}

            # Action distribution with the noise applied, then again after
            # the noise has been removed from the variables.
            noisy_action_distribution = session.run(
                self.action_probs, feed_dict=feed)
            session.run(self.remove_noise_op)
            clean_action_distribution = session.run(
                self.action_probs, feed_dict=feed)

            distance_in_action_space = np.mean(
                entropy(clean_action_distribution.T,
                        noisy_action_distribution.T))
            self.pi_distance = distance_in_action_space

            threshold = -np.log(1 - self.cur_epsilon +
                                self.cur_epsilon / self.num_actions)
            # Grow sigma while the distance is below the threshold,
            # shrink it otherwise.
            if distance_in_action_space < threshold:
                self.parameter_noise_sigma_val *= 1.01
            else:
                self.parameter_noise_sigma_val /= 1.01
            self.parameter_noise_sigma.load(
                self.parameter_noise_sigma_val, session=self.sess)

        return _postprocess_dqn(self, sample_batch)
class QNetwork(object):
def __init__(self,
model,
@@ -345,170 +305,18 @@ class QValuePolicy(object):
self.action_prob = None
class DQNTFPolicy(LearningRateSchedule, DQNPostprocessing, TFPolicy):
def __init__(self, observation_space, action_space, config):
config = dict(ray.rllib.agents.dqn.dqn.DEFAULT_CONFIG, **config)
if not isinstance(action_space, Discrete):
raise UnsupportedSpaceException(
"Action space {} is not supported for DQN.".format(
action_space))
self.config = config
class ExplorationStateMixin(object):
def __init__(self, obs_space, action_space, config):
self.cur_epsilon = 1.0
self.num_actions = action_space.n
# Action inputs
self.stochastic = tf.placeholder(tf.bool, (), name="stochastic")
self.eps = tf.placeholder(tf.float32, (), name="eps")
self.cur_observations = tf.placeholder(
tf.float32, shape=(None, ) + observation_space.shape)
# Action Q network
with tf.variable_scope(Q_SCOPE) as scope:
q_values, q_logits, q_dist, _ = self._build_q_network(
self.cur_observations, observation_space, action_space)
self.q_values = q_values
self.q_func_vars = _scope_vars(scope.name)
# Noise vars for Q network except for layer normalization vars
def add_parameter_noise(self):
if self.config["parameter_noise"]:
self._build_parameter_noise([
var for var in self.q_func_vars if "LayerNorm" not in var.name
])
self.action_probs = tf.nn.softmax(self.q_values)
self.sess.run(self.add_noise_op)
# Action outputs
self.output_actions, self.action_prob = self._build_q_value_policy(
q_values)
# Replay inputs
self.obs_t = tf.placeholder(
tf.float32, shape=(None, ) + observation_space.shape)
self.act_t = tf.placeholder(tf.int32, [None], name="action")
self.rew_t = tf.placeholder(tf.float32, [None], name="reward")
self.obs_tp1 = tf.placeholder(
tf.float32, shape=(None, ) + observation_space.shape)
self.done_mask = tf.placeholder(tf.float32, [None], name="done")
self.importance_weights = tf.placeholder(
tf.float32, [None], name="weight")
# q network evaluation
with tf.variable_scope(Q_SCOPE, reuse=True):
prev_update_ops = set(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
q_t, q_logits_t, q_dist_t, model = self._build_q_network(
self.obs_t, observation_space, action_space)
q_batchnorm_update_ops = list(
set(tf.get_collection(tf.GraphKeys.UPDATE_OPS)) -
prev_update_ops)
# target q network evalution
with tf.variable_scope(Q_TARGET_SCOPE) as scope:
q_tp1, q_logits_tp1, q_dist_tp1, _ = self._build_q_network(
self.obs_tp1, observation_space, action_space)
self.target_q_func_vars = _scope_vars(scope.name)
# q scores for actions which we know were selected in the given state.
one_hot_selection = tf.one_hot(self.act_t, self.num_actions)
q_t_selected = tf.reduce_sum(q_t * one_hot_selection, 1)
q_logits_t_selected = tf.reduce_sum(
q_logits_t * tf.expand_dims(one_hot_selection, -1), 1)
# compute estimate of best possible value starting from state at t + 1
if config["double_q"]:
with tf.variable_scope(Q_SCOPE, reuse=True):
q_tp1_using_online_net, q_logits_tp1_using_online_net, \
q_dist_tp1_using_online_net, _ = self._build_q_network(
self.obs_tp1, observation_space, action_space)
q_tp1_best_using_online_net = tf.argmax(q_tp1_using_online_net, 1)
q_tp1_best_one_hot_selection = tf.one_hot(
q_tp1_best_using_online_net, self.num_actions)
q_tp1_best = tf.reduce_sum(q_tp1 * q_tp1_best_one_hot_selection, 1)
q_dist_tp1_best = tf.reduce_sum(
q_dist_tp1 * tf.expand_dims(q_tp1_best_one_hot_selection, -1),
1)
else:
q_tp1_best_one_hot_selection = tf.one_hot(
tf.argmax(q_tp1, 1), self.num_actions)
q_tp1_best = tf.reduce_sum(q_tp1 * q_tp1_best_one_hot_selection, 1)
q_dist_tp1_best = tf.reduce_sum(
q_dist_tp1 * tf.expand_dims(q_tp1_best_one_hot_selection, -1),
1)
self.loss = self._build_q_loss(q_t_selected, q_logits_t_selected,
q_tp1_best, q_dist_tp1_best)
# update_target_fn will be called periodically to copy Q network to
# target Q network
update_target_expr = []
assert len(self.q_func_vars) == len(self.target_q_func_vars), \
(self.q_func_vars, self.target_q_func_vars)
for var, var_target in zip(self.q_func_vars, self.target_q_func_vars):
update_target_expr.append(var_target.assign(var))
self.update_target_expr = tf.group(*update_target_expr)
# initialize TFPolicy
self.sess = tf.get_default_session()
self.loss_inputs = [
(SampleBatch.CUR_OBS, self.obs_t),
(SampleBatch.ACTIONS, self.act_t),
(SampleBatch.REWARDS, self.rew_t),
(SampleBatch.NEXT_OBS, self.obs_tp1),
(SampleBatch.DONES, self.done_mask),
(PRIO_WEIGHTS, self.importance_weights),
]
LearningRateSchedule.__init__(self, self.config["lr"],
self.config["lr_schedule"])
TFPolicy.__init__(
self,
observation_space,
action_space,
self.sess,
obs_input=self.cur_observations,
action_sampler=self.output_actions,
action_prob=self.action_prob,
loss=self.loss.loss,
model=model,
loss_inputs=self.loss_inputs,
update_ops=q_batchnorm_update_ops)
self.sess.run(tf.global_variables_initializer())
self.stats_fetches = dict({
"cur_lr": tf.cast(self.cur_lr, tf.float64),
}, **self.loss.stats)
@override(TFPolicy)
def optimizer(self):
return tf.train.AdamOptimizer(
learning_rate=self.cur_lr, epsilon=self.config["adam_epsilon"])
@override(TFPolicy)
def gradients(self, optimizer, loss):
if self.config["grad_norm_clipping"] is not None:
grads_and_vars = _minimize_and_clip(
optimizer,
loss,
var_list=self.q_func_vars,
clip_val=self.config["grad_norm_clipping"])
else:
grads_and_vars = optimizer.compute_gradients(
loss, var_list=self.q_func_vars)
grads_and_vars = [(g, v) for (g, v) in grads_and_vars if g is not None]
return grads_and_vars
@override(TFPolicy)
def extra_compute_action_feed_dict(self):
return {
self.stochastic: True,
self.eps: self.cur_epsilon,
}
@override(TFPolicy)
def extra_compute_grad_fetches(self):
return {
"td_error": self.loss.td_error,
LEARNER_STATS_KEY: self.stats_fetches,
}
def set_epsilon(self, epsilon):
self.cur_epsilon = epsilon
@override(Policy)
def get_state(self):
@@ -519,93 +327,262 @@ class DQNTFPolicy(LearningRateSchedule, DQNPostprocessing, TFPolicy):
TFPolicy.set_state(self, state[0])
self.set_epsilon(state[1])
def _build_parameter_noise(self, pnet_params):
self.parameter_noise_sigma_val = 1.0
self.parameter_noise_sigma = tf.get_variable(
initializer=tf.constant_initializer(
self.parameter_noise_sigma_val),
name="parameter_noise_sigma",
shape=(),
trainable=False,
dtype=tf.float32)
self.parameter_noise = list()
# No need to add any noise on LayerNorm parameters
for var in pnet_params:
noise_var = tf.get_variable(
name=var.name.split(":")[0] + "_noise",
shape=var.shape,
initializer=tf.constant_initializer(.0),
trainable=False)
self.parameter_noise.append(noise_var)
remove_noise_ops = list()
for var, var_noise in zip(pnet_params, self.parameter_noise):
remove_noise_ops.append(tf.assign_add(var, -var_noise))
self.remove_noise_op = tf.group(*tuple(remove_noise_ops))
generate_noise_ops = list()
for var_noise in self.parameter_noise:
generate_noise_ops.append(
tf.assign(
var_noise,
tf.random_normal(
shape=var_noise.shape,
stddev=self.parameter_noise_sigma)))
with tf.control_dependencies(generate_noise_ops):
add_noise_ops = list()
for var, var_noise in zip(pnet_params, self.parameter_noise):
add_noise_ops.append(tf.assign_add(var, var_noise))
self.add_noise_op = tf.group(*tuple(add_noise_ops))
self.pi_distance = None
class TargetNetworkMixin(object):
    """Mixin providing an op and method to sync the target Q network.

    `__init__` reads `self.q_func_vars` and `self.target_q_func_vars`, so
    both must already be set on the instance when it is called.
    """

    def __init__(self, obs_space, action_space, config):
        # update_target_fn will be called periodically to copy Q network to
        # target Q network
        online_vars = self.q_func_vars
        target_vars = self.target_q_func_vars
        assert len(online_vars) == len(target_vars), \
            (online_vars, target_vars)
        assign_ops = [
            target.assign(online)
            for online, target in zip(online_vars, target_vars)
        ]
        self.update_target_expr = tf.group(*assign_ops)

    def update_target(self):
        """Run the grouped assign op in the policy's session."""
        session = self.get_session()
        return session.run(self.update_target_expr)
class ComputeTDErrorMixin(object):
def compute_td_error(self, obs_t, act_t, rew_t, obs_tp1, done_mask,
importance_weights):
td_err = self.sess.run(
if not self.loss_initialized():
return np.zeros_like(rew_t)
td_err = self.get_session().run(
self.loss.td_error,
feed_dict={
self.obs_t: [np.array(ob) for ob in obs_t],
self.act_t: act_t,
self.rew_t: rew_t,
self.obs_tp1: [np.array(ob) for ob in obs_tp1],
self.done_mask: done_mask,
self.importance_weights: importance_weights
self.get_placeholder(SampleBatch.CUR_OBS): [
np.array(ob) for ob in obs_t
],
self.get_placeholder(SampleBatch.ACTIONS): act_t,
self.get_placeholder(SampleBatch.REWARDS): rew_t,
self.get_placeholder(SampleBatch.NEXT_OBS): [
np.array(ob) for ob in obs_tp1
],
self.get_placeholder(SampleBatch.DONES): done_mask,
self.get_placeholder(PRIO_WEIGHTS): importance_weights,
})
return td_err
def add_parameter_noise(self):
if self.config["parameter_noise"]:
self.sess.run(self.add_noise_op)
def update_target(self):
return self.sess.run(self.update_target_expr)
def postprocess_trajectory(policy,
sample_batch,
other_agent_batches=None,
episode=None):
if policy.config["parameter_noise"]:
# adjust the sigma of parameter space noise
states = [list(x) for x in sample_batch.columns(["obs"])][0]
def set_epsilon(self, epsilon):
self.cur_epsilon = epsilon
noisy_action_distribution = policy.get_session().run(
policy.action_probs, feed_dict={policy.cur_observations: states})
policy.get_session().run(policy.remove_noise_op)
clean_action_distribution = policy.get_session().run(
policy.action_probs, feed_dict={policy.cur_observations: states})
distance_in_action_space = np.mean(
entropy(clean_action_distribution.T, noisy_action_distribution.T))
policy.pi_distance = distance_in_action_space
if (distance_in_action_space <
-np.log(1 - policy.cur_epsilon +
policy.cur_epsilon / policy.num_actions)):
policy.parameter_noise_sigma_val *= 1.01
else:
policy.parameter_noise_sigma_val /= 1.01
policy.parameter_noise_sigma.load(
policy.parameter_noise_sigma_val, session=policy.get_session())
def _build_q_network(self, obs, obs_space, action_space):
qnet = QNetwork(
ModelCatalog.get_model({
"obs": obs,
"is_training": self._get_is_training_placeholder(),
}, obs_space, action_space, self.num_actions,
self.config["model"]), self.num_actions,
self.config["dueling"], self.config["hiddens"],
self.config["noisy"], self.config["num_atoms"],
self.config["v_min"], self.config["v_max"], self.config["sigma0"],
self.config["parameter_noise"])
return qnet.value, qnet.logits, qnet.dist, qnet.model
return _postprocess_dqn(policy, sample_batch)
def _build_q_value_policy(self, q_values):
policy = QValuePolicy(
q_values, self.cur_observations, self.num_actions, self.stochastic,
self.eps, self.config["soft_q"], self.config["softmax_temp"])
return policy.action, policy.action_prob
def _build_q_loss(self, q_t_selected, q_logits_t_selected, q_tp1_best,
q_dist_tp1_best):
return QLoss(q_t_selected, q_logits_t_selected, q_tp1_best,
q_dist_tp1_best, self.importance_weights, self.rew_t,
self.done_mask, self.config["gamma"],
self.config["n_step"], self.config["num_atoms"],
self.config["v_min"], self.config["v_max"])
def build_q_networks(policy, input_dict, observation_space, action_space,
                     config):
    """Build the action Q network and return the action sampler tensors.

    Sets `q_values`, `q_func_vars`, `output_actions` and `action_prob` on
    `policy`. When `config["parameter_noise"]` is set, the parameter-noise
    ops and `policy.action_probs` are created as well.

    Returns:
        Tuple of (action tensor, action prob tensor).

    Raises:
        UnsupportedSpaceException: if `action_space` is not Discrete.
    """
    if not isinstance(action_space, Discrete):
        raise UnsupportedSpaceException(
            "Action space {} is not supported for DQN.".format(action_space))

    cur_obs = input_dict[SampleBatch.CUR_OBS]

    # Action Q network
    with tf.variable_scope(Q_SCOPE) as scope:
        q_values, _, _, _ = _build_q_network(
            policy, cur_obs, observation_space, action_space)
        policy.q_values = q_values
        policy.q_func_vars = _scope_vars(scope.name)

    # Noise vars for Q network except for layer normalization vars
    if config["parameter_noise"]:
        noisy_vars = [
            var for var in policy.q_func_vars if "LayerNorm" not in var.name
        ]
        _build_parameter_noise(policy, noisy_vars)
        policy.action_probs = tf.nn.softmax(policy.q_values)

    # Action outputs
    qvp = QValuePolicy(q_values, cur_obs, action_space.n, policy.stochastic,
                       policy.eps, policy.config["soft_q"],
                       policy.config["softmax_temp"])
    policy.output_actions = qvp.action
    policy.action_prob = qvp.action_prob

    return policy.output_actions, policy.action_prob
def _build_parameter_noise(policy, pnet_params):
    """Create the parameter-space noise variables and ops on `policy`.

    For each variable in `pnet_params` a non-trainable, zero-initialized
    noise variable of the same shape is created. Three grouped ops are
    built around them:

    * `policy.remove_noise_op` subtracts each noise variable from its
      parameter.
    * the noise-generation assigns resample every noise variable from a
      normal distribution with stddev `policy.parameter_noise_sigma`.
    * `policy.add_noise_op` adds the noise to the parameters, with a
      control dependency on the resampling.

    Also initializes `parameter_noise_sigma_val` to 1.0 and `pi_distance`
    to None.
    """
    policy.parameter_noise_sigma_val = 1.0
    policy.parameter_noise_sigma = tf.get_variable(
        initializer=tf.constant_initializer(policy.parameter_noise_sigma_val),
        name="parameter_noise_sigma",
        shape=(),
        trainable=False,
        dtype=tf.float32)

    # No need to add any noise on LayerNorm parameters
    policy.parameter_noise = [
        tf.get_variable(
            name=var.name.split(":")[0] + "_noise",
            shape=var.shape,
            initializer=tf.constant_initializer(.0),
            trainable=False) for var in pnet_params
    ]

    remove_noise_ops = [
        tf.assign_add(var, -noise)
        for var, noise in zip(pnet_params, policy.parameter_noise)
    ]
    policy.remove_noise_op = tf.group(*remove_noise_ops)

    generate_noise_ops = [
        tf.assign(
            noise,
            tf.random_normal(
                shape=noise.shape, stddev=policy.parameter_noise_sigma))
        for noise in policy.parameter_noise
    ]

    # Fresh noise must be sampled before it is added to the parameters.
    with tf.control_dependencies(generate_noise_ops):
        add_noise_ops = [
            tf.assign_add(var, noise)
            for var, noise in zip(pnet_params, policy.parameter_noise)
        ]
        policy.add_noise_op = tf.group(*add_noise_ops)

    policy.pi_distance = None
def build_q_losses(policy, batch_tensors):
    """Build the DQN loss for `policy` and return its loss tensor.

    Re-evaluates the online Q network on the current observations (with
    variable reuse), builds the target Q network on the next observations,
    and picks the best next action with either the online network
    (`config["double_q"]`) or the target network.

    Sets `q_batchnorm_update_ops`, `target_q_func_vars` and `loss` on
    `policy`.
    """
    obs_space = policy.observation_space
    act_space = policy.action_space
    next_obs = batch_tensors[SampleBatch.NEXT_OBS]

    # q network evaluation
    with tf.variable_scope(Q_SCOPE, reuse=True):
        prev_update_ops = set(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
        q_t, q_logits_t, _, _ = _build_q_network(
            policy, batch_tensors[SampleBatch.CUR_OBS], obs_space, act_space)
        # Only the update ops added while building this network.
        policy.q_batchnorm_update_ops = list(
            set(tf.get_collection(tf.GraphKeys.UPDATE_OPS)) - prev_update_ops)

    # target q network evaluation
    with tf.variable_scope(Q_TARGET_SCOPE) as scope:
        q_tp1, _, q_dist_tp1, _ = _build_q_network(policy, next_obs,
                                                   obs_space, act_space)
        policy.target_q_func_vars = _scope_vars(scope.name)

    # q scores for actions which we know were selected in the given state.
    one_hot_selection = tf.one_hot(batch_tensors[SampleBatch.ACTIONS],
                                   act_space.n)
    q_t_selected = tf.reduce_sum(q_t * one_hot_selection, 1)
    q_logits_t_selected = tf.reduce_sum(
        q_logits_t * tf.expand_dims(one_hot_selection, -1), 1)

    # compute estimate of best possible value starting from state at t + 1
    if policy.config["double_q"]:
        # The online network chooses the action, the target network
        # scores it.
        with tf.variable_scope(Q_SCOPE, reuse=True):
            q_tp1_using_online_net, _, _, _ = _build_q_network(
                policy, next_obs, obs_space, act_space)
        best_action_tp1 = tf.argmax(q_tp1_using_online_net, 1)
    else:
        best_action_tp1 = tf.argmax(q_tp1, 1)
    q_tp1_best_one_hot_selection = tf.one_hot(best_action_tp1, act_space.n)
    q_tp1_best = tf.reduce_sum(q_tp1 * q_tp1_best_one_hot_selection, 1)
    q_dist_tp1_best = tf.reduce_sum(
        q_dist_tp1 * tf.expand_dims(q_tp1_best_one_hot_selection, -1), 1)

    policy.loss = _build_q_loss(
        q_t_selected, q_logits_t_selected, q_tp1_best, q_dist_tp1_best,
        batch_tensors[SampleBatch.REWARDS], batch_tensors[SampleBatch.DONES],
        batch_tensors[PRIO_WEIGHTS], policy.config)

    return policy.loss.loss
def adam_optimizer(policy, config):
    """Return an Adam optimizer using the policy's current learning rate.

    The optimizer's epsilon is taken from `config["adam_epsilon"]`.
    """
    optimizer = tf.train.AdamOptimizer(
        learning_rate=policy.cur_lr, epsilon=config["adam_epsilon"])
    return optimizer
def clip_gradients(policy, optimizer, loss):
    """Return (gradient, variable) pairs for the online Q network variables.

    When `config["grad_norm_clipping"]` is not None, gradients are computed
    and clipped via `_minimize_and_clip`; otherwise the optimizer's plain
    `compute_gradients` is used. Pairs whose gradient is None are dropped.
    """
    clip_val = policy.config["grad_norm_clipping"]
    if clip_val is None:
        grads_and_vars = optimizer.compute_gradients(
            loss, var_list=policy.q_func_vars)
    else:
        grads_and_vars = _minimize_and_clip(
            optimizer,
            loss,
            var_list=policy.q_func_vars,
            clip_val=clip_val)
    return [(g, v) for (g, v) in grads_and_vars if g is not None]
def exploration_setting_inputs(policy):
    """Return the extra feed dict entries used when computing actions.

    Feeds True for `policy.stochastic` and the policy's current epsilon
    for `policy.eps`.
    """
    feed = {}
    feed[policy.stochastic] = True
    feed[policy.eps] = policy.cur_epsilon
    return feed
def build_q_stats(policy, batch_tensors):
    """Return the current learning rate merged with the loss's stats.

    `batch_tensors` is accepted as part of the stats function signature
    but is not used here.
    """
    lr_stats = {"cur_lr": tf.cast(policy.cur_lr, tf.float64)}
    return dict(lr_stats, **policy.loss.stats)
def setup_early_mixins(policy, obs_space, action_space, config):
    """Initialize the learning-rate schedule and exploration state mixins."""
    lr, lr_schedule = config["lr"], config["lr_schedule"]
    LearningRateSchedule.__init__(policy, lr, lr_schedule)
    ExplorationStateMixin.__init__(policy, obs_space, action_space, config)
def setup_late_mixins(policy, obs_space, action_space, config):
    """Initialize the target-network mixin on `policy`."""
    init_args = (obs_space, action_space, config)
    TargetNetworkMixin.__init__(policy, *init_args)
def _build_q_network(policy, obs, obs_space, action_space):
    """Build a QNetwork on `obs`, configured from `policy.config`.

    Returns:
        Tuple of the network's (value, logits, dist, model) attributes.
    """
    config = policy.config
    model_inputs = {
        "obs": obs,
        "is_training": policy._get_is_training_placeholder(),
    }
    num_actions = action_space.n
    base_model = ModelCatalog.get_model(model_inputs, obs_space, action_space,
                                        num_actions, config["model"])
    qnet = QNetwork(base_model, num_actions, config["dueling"],
                    config["hiddens"], config["noisy"], config["num_atoms"],
                    config["v_min"], config["v_max"], config["sigma0"],
                    config["parameter_noise"])
    return qnet.value, qnet.logits, qnet.dist, qnet.model
def _build_q_value_policy(policy, q_values):
    """Build a QValuePolicy over `q_values` from the policy's settings.

    Returns:
        Tuple of the QValuePolicy's (action, action_prob) attributes.
    """
    # Bound to its own name so the `policy` argument is not rebound.
    qvp = QValuePolicy(q_values, policy.cur_observations,
                       policy.num_actions, policy.stochastic, policy.eps,
                       policy.config["soft_q"],
                       policy.config["softmax_temp"])
    return qvp.action, qvp.action_prob
def _build_q_loss(q_t_selected, q_logits_t_selected, q_tp1_best,
                  q_dist_tp1_best, rewards, dones, importance_weights, config):
    """Construct the QLoss from the selected Q tensors and batch columns.

    `dones` is cast to float32 before being passed to QLoss. Gamma, n-step
    and the atom/value-range settings are read from `config`.
    """
    done_mask = tf.cast(dones, tf.float32)
    return QLoss(q_t_selected, q_logits_t_selected, q_tp1_best,
                 q_dist_tp1_best, importance_weights, rewards, done_mask,
                 config["gamma"], config["n_step"], config["num_atoms"],
                 config["v_min"], config["v_max"])
def _adjust_nstep(n_step, gamma, obs, actions, rewards, new_obs, dones):
@@ -706,3 +683,27 @@ def _scope_vars(scope, trainable_only=False):
tf.GraphKeys.TRAINABLE_VARIABLES
if trainable_only else tf.GraphKeys.VARIABLES,
scope=scope if isinstance(scope, str) else scope.name)
# Assemble the DQN TF policy class from the module-level functions and
# mixins above using the build_tf_policy() template.
DQNTFPolicy = build_tf_policy(
name="DQNTFPolicy",
# Default config is looked up lazily, when the lambda is called.
get_default_config=lambda: ray.rllib.agents.dqn.dqn.DEFAULT_CONFIG,
# build_q_networks returns the (action, action prob) tensors and sets
# policy.q_values / policy.q_func_vars, which later hooks read.
make_action_sampler=build_q_networks,
loss_fn=build_q_losses,
stats_fn=build_q_stats,
postprocess_fn=postprocess_trajectory,
optimizer_fn=adam_optimizer,
gradients_fn=clip_gradients,
extra_action_feed_fn=exploration_setting_inputs,
extra_action_fetches_fn=lambda policy: {"q_values": policy.q_values},
extra_learn_fetches_fn=lambda policy: {"td_error": policy.loss.td_error},
# policy.q_batchnorm_update_ops is set by build_q_losses.
update_ops_fn=lambda policy: policy.q_batchnorm_update_ops,
# setup_early_mixins initializes LearningRateSchedule and
# ExplorationStateMixin; setup_late_mixins initializes TargetNetworkMixin,
# which reads policy.target_q_func_vars set by build_q_losses.
before_init=setup_early_mixins,
after_init=setup_late_mixins,
obs_include_prev_action_reward=False,
mixins=[
ExplorationStateMixin,
TargetNetworkMixin,
ComputeTDErrorMixin,
LearningRateSchedule,
])
+5 -2
View File
@@ -365,12 +365,15 @@ class ValueNetworkMixin(object):
tf.get_variable_scope().name)
def value(self, ob, *args):
feed_dict = {self._obs_input: [ob], self.model.seq_lens: [1]}
feed_dict = {
self.get_placeholder(SampleBatch.CUR_OBS): [ob],
self.model.seq_lens: [1]
}
assert len(args) == len(self.model.state_in), \
(args, self.model.state_in)
for k, v in zip(self.model.state_in, args):
feed_dict[k] = v
vf = self._sess.run(self.value_function, feed_dict)
vf = self.get_session().run(self.value_function, feed_dict)
return vf[0]
+10 -12
View File
@@ -216,7 +216,7 @@ class KLCoeffMixin(object):
self.kl_coeff_val *= 1.5
elif sampled_kl < 0.5 * self.kl_target:
self.kl_coeff_val *= 0.5
self.kl_coeff.load(self.kl_coeff_val, session=self._sess)
self.kl_coeff.load(self.kl_coeff_val, session=self.get_session())
return self.kl_coeff_val
@@ -240,28 +240,26 @@ class ValueNetworkMixin(object):
"a custom LSTM model that overrides the "
"value_function() method.")
with tf.variable_scope("value_function"):
self.value_function = ModelCatalog.get_model({
"obs": self._obs_input,
"prev_actions": self._prev_action_input,
"prev_rewards": self._prev_reward_input,
"is_training": self._get_is_training_placeholder(),
}, obs_space, action_space, 1, vf_config).outputs
self.value_function = ModelCatalog.get_model(
self.get_obs_input_dict(), obs_space, action_space, 1,
vf_config).outputs
self.value_function = tf.reshape(self.value_function, [-1])
else:
self.value_function = tf.zeros(shape=tf.shape(self._obs_input)[:1])
self.value_function = tf.zeros(
shape=tf.shape(self.get_placeholder(SampleBatch.CUR_OBS))[:1])
def _value(self, ob, prev_action, prev_reward, *args):
feed_dict = {
self._obs_input: [ob],
self._prev_action_input: [prev_action],
self._prev_reward_input: [prev_reward],
self.get_placeholder(SampleBatch.CUR_OBS): [ob],
self.get_placeholder(SampleBatch.PREV_ACTIONS): [prev_action],
self.get_placeholder(SampleBatch.PREV_REWARDS): [prev_reward],
self.model.seq_lens: [1]
}
assert len(args) == len(self.model.state_in), \
(args, self.model.state_in)
for k, v in zip(self.model.state_in, args):
feed_dict[k] = v
vf = self._sess.run(self.value_function, feed_dict)
vf = self.get_session().run(self.value_function, feed_dict)
return vf[0]
@@ -1,146 +0,0 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ray.rllib.policy.dynamic_tf_policy import DynamicTFPolicy
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.tf_policy import TFPolicy
from ray.rllib.utils.annotations import override, DeveloperAPI
@DeveloperAPI
def build_tf_policy(name,
                    loss_fn,
                    get_default_config=None,
                    stats_fn=None,
                    grad_stats_fn=None,
                    extra_action_fetches_fn=None,
                    postprocess_fn=None,
                    optimizer_fn=None,
                    gradients_fn=None,
                    before_init=None,
                    before_loss_init=None,
                    after_init=None,
                    make_action_sampler=None,
                    mixins=None,
                    get_batch_divisibility_req=None):
    """Helper function for creating a dynamic tf policy class at runtime.

    Arguments:
        name (str): name of the policy (e.g., "PPOTFPolicy")
        loss_fn (func): function that returns a loss tensor given the
            policy and a dict of experience tensor placeholders
        get_default_config (func): optional function that returns the default
            config to merge with any overrides
        stats_fn (func): optional function that returns a dict of
            TF fetches given the policy and batch input tensors
        grad_stats_fn (func): optional function that returns a dict of
            TF fetches given the policy and loss gradient tensors
        extra_action_fetches_fn (func): optional function that returns
            a dict of TF fetches given the policy object
        postprocess_fn (func): optional experience postprocessing function
            that takes the same args as Policy.postprocess_trajectory()
        optimizer_fn (func): optional function that returns a tf.Optimizer
            given the policy and config
        gradients_fn (func): optional function that returns a list of gradients
            given a tf optimizer and loss tensor. If not specified, this
            defaults to optimizer.compute_gradients(loss)
        before_init (func): optional function to run at the beginning of
            policy init that takes the same arguments as the policy constructor
        before_loss_init (func): optional function to run prior to loss
            init that takes the same arguments as the policy constructor
        after_init (func): optional function to run at the end of policy init
            that takes the same arguments as the policy constructor
        make_action_sampler (func): optional function that returns a
            tuple of action and action prob tensors. The function takes
            (policy, input_dict, obs_space, action_space, config) as its
            arguments
        mixins (list): list of any class mixins for the returned policy class.
            These mixins will be applied in order and will have higher
            precedence than the DynamicTFPolicy class. The list is not
            modified
        get_batch_divisibility_req (func): optional function that returns
            the divisibility requirement for sample batches

    Returns:
        a DynamicTFPolicy subclass that uses the specified args

    Raises:
        ValueError: if `name` does not end with "TFPolicy"
    """

    if not name.endswith("TFPolicy"):
        raise ValueError("Name should match *TFPolicy", name)

    # Stack the mixins on top of DynamicTFPolicy, last mixin first, so the
    # first mixin ends up with the highest precedence. Iterate over a copy
    # instead of pop()-ing so the caller's list can be reused.
    base = DynamicTFPolicy
    for mixin in reversed(list(mixins or ())):

        class new_base(mixin, base):
            pass

        base = new_base

    class policy_cls(base):
        def __init__(self,
                     obs_space,
                     action_space,
                     config,
                     existing_inputs=None):
            # User-supplied config values override the defaults.
            if get_default_config:
                config = dict(get_default_config(), **config)

            if before_init:
                before_init(self, obs_space, action_space, config)

            def before_loss_init_wrapper(policy, obs_space, action_space,
                                         config):
                if before_loss_init:
                    before_loss_init(policy, obs_space, action_space, config)
                # Computed after before_loss_init so the fetches function
                # can rely on anything that hook set up on the policy.
                if extra_action_fetches_fn is None:
                    self._extra_action_fetches = {}
                else:
                    self._extra_action_fetches = extra_action_fetches_fn(self)

            DynamicTFPolicy.__init__(
                self,
                obs_space,
                action_space,
                config,
                loss_fn,
                stats_fn=stats_fn,
                grad_stats_fn=grad_stats_fn,
                before_loss_init=before_loss_init_wrapper,
                # Forward the documented options that were previously
                # accepted here but never passed on.
                make_action_sampler=make_action_sampler,
                existing_inputs=existing_inputs,
                get_batch_divisibility_req=get_batch_divisibility_req)

            if after_init:
                after_init(self, obs_space, action_space, config)

        @override(Policy)
        def postprocess_trajectory(self,
                                   sample_batch,
                                   other_agent_batches=None,
                                   episode=None):
            if not postprocess_fn:
                return sample_batch
            return postprocess_fn(self, sample_batch, other_agent_batches,
                                  episode)

        @override(TFPolicy)
        def optimizer(self):
            if optimizer_fn:
                return optimizer_fn(self, self.config)
            else:
                return TFPolicy.optimizer(self)

        @override(TFPolicy)
        def gradients(self, optimizer, loss):
            if gradients_fn:
                return gradients_fn(self, optimizer, loss)
            else:
                return TFPolicy.gradients(self, optimizer, loss)

        @override(TFPolicy)
        def extra_compute_action_fetches(self):
            return dict(
                TFPolicy.extra_compute_action_fetches(self),
                **self._extra_action_fetches)

    policy_cls.__name__ = name
    policy_cls.__qualname__ = name
    return policy_cls
+69 -33
View File
@@ -37,11 +37,13 @@ class DynamicTFPolicy(TFPolicy):
config,
loss_fn,
stats_fn=None,
update_ops_fn=None,
grad_stats_fn=None,
before_loss_init=None,
make_action_sampler=None,
existing_inputs=None,
get_batch_divisibility_req=None):
get_batch_divisibility_req=None,
obs_include_prev_action_reward=True):
"""Initialize a dynamic TF policy.
Arguments:
@@ -54,6 +56,8 @@ class DynamicTFPolicy(TFPolicy):
TF fetches given the policy and batch input tensors
grad_stats_fn (func): optional function that returns a dict of
TF fetches given the policy and loss gradient tensors
update_ops_fn (func): optional function that returns a list
overriding the update ops to run when applying gradients
before_loss_init (func): optional function to run prior to loss
init that takes the same arguments as __init__
make_action_sampler (func): optional function that returns a
@@ -65,30 +69,39 @@ class DynamicTFPolicy(TFPolicy):
defining new ones
get_batch_divisibility_req (func): optional function that returns
the divisibility requirement for sample batches
obs_include_prev_action_reward (bool): whether to include the
previous action and reward in the model input
"""
self.config = config
self._loss_fn = loss_fn
self._stats_fn = stats_fn
self._grad_stats_fn = grad_stats_fn
self._update_ops_fn = update_ops_fn
self._obs_include_prev_action_reward = obs_include_prev_action_reward
# Setup standard placeholders
prev_actions = None
prev_rewards = None
if existing_inputs is not None:
obs = existing_inputs[SampleBatch.CUR_OBS]
prev_actions = existing_inputs[SampleBatch.PREV_ACTIONS]
prev_rewards = existing_inputs[SampleBatch.PREV_REWARDS]
if self._obs_include_prev_action_reward:
prev_actions = existing_inputs[SampleBatch.PREV_ACTIONS]
prev_rewards = existing_inputs[SampleBatch.PREV_REWARDS]
else:
obs = tf.placeholder(
tf.float32,
shape=[None] + list(obs_space.shape),
name="observation")
prev_actions = ModelCatalog.get_action_placeholder(action_space)
prev_rewards = tf.placeholder(
tf.float32, [None], name="prev_reward")
if self._obs_include_prev_action_reward:
prev_actions = ModelCatalog.get_action_placeholder(
action_space)
prev_rewards = tf.placeholder(
tf.float32, [None], name="prev_reward")
input_dict = {
"obs": obs,
"prev_actions": prev_actions,
"prev_rewards": prev_rewards,
self.input_dict = {
SampleBatch.CUR_OBS: obs,
SampleBatch.PREV_ACTIONS: prev_actions,
SampleBatch.PREV_REWARDS: prev_rewards,
"is_training": self._get_is_training_placeholder(),
}
@@ -100,7 +113,7 @@ class DynamicTFPolicy(TFPolicy):
self.dist_class = None
self.action_dist = None
action_sampler, action_prob = make_action_sampler(
self, input_dict, obs_space, action_space, config)
self, self.input_dict, obs_space, action_space, config)
else:
self.dist_class, logit_dim = ModelCatalog.get_action_dist(
action_space, self.config["model"])
@@ -117,7 +130,7 @@ class DynamicTFPolicy(TFPolicy):
existing_state_in = []
existing_seq_lens = None
self.model = ModelCatalog.get_model(
input_dict,
self.input_dict,
obs_space,
action_space,
logit_dim,
@@ -158,6 +171,13 @@ class DynamicTFPolicy(TFPolicy):
if not existing_inputs:
self._initialize_loss()
def get_obs_input_dict(self):
    """Return the observation input dict the policy models are built from.

    The returned object is `self.input_dict` itself (not a copy), which
    holds the obs, prev actions, prev rewards, etc. tensors.
    """
    obs_input_dict = self.input_dict
    return obs_input_dict
@override(TFPolicy)
def copy(self, existing_inputs):
"""Creates a copy of self using existing input placeholders."""
@@ -190,10 +210,8 @@ class DynamicTFPolicy(TFPolicy):
self.action_space,
self.config,
existing_inputs=input_dict)
loss = instance._loss_fn(instance, input_dict)
if instance._stats_fn:
instance._stats_fetches.update(
instance._stats_fn(instance, input_dict))
loss = instance._do_loss_init(input_dict)
TFPolicy._initialize_loss(
instance, loss, [(k, existing_inputs[i])
for i, (k, _) in enumerate(self._loss_inputs)])
@@ -216,14 +234,18 @@ class DynamicTFPolicy(TFPolicy):
return np.zeros(shape, dtype=tensor.dtype.as_numpy_dtype)
dummy_batch = {
SampleBatch.PREV_ACTIONS: fake_array(self._prev_action_input),
SampleBatch.PREV_REWARDS: fake_array(self._prev_reward_input),
SampleBatch.CUR_OBS: fake_array(self._obs_input),
SampleBatch.NEXT_OBS: fake_array(self._obs_input),
SampleBatch.ACTIONS: fake_array(self._prev_action_input),
SampleBatch.REWARDS: np.array([0], dtype=np.float32),
SampleBatch.DONES: np.array([False], dtype=np.bool),
SampleBatch.ACTIONS: fake_array(
ModelCatalog.get_action_placeholder(self.action_space)),
SampleBatch.REWARDS: np.array([0], dtype=np.float32),
}
if self._obs_include_prev_action_reward:
dummy_batch.update({
SampleBatch.PREV_ACTIONS: fake_array(self._prev_action_input),
SampleBatch.PREV_REWARDS: fake_array(self._prev_reward_input),
})
state_init = self.get_initial_state()
for i, h in enumerate(state_init):
dummy_batch["state_in_{}".format(i)] = np.expand_dims(h, 0)
@@ -238,16 +260,24 @@ class DynamicTFPolicy(TFPolicy):
postprocessed_batch = self.postprocess_trajectory(
SampleBatch(dummy_batch))
batch_tensors = UsageTrackingDict({
SampleBatch.PREV_ACTIONS: self._prev_action_input,
SampleBatch.PREV_REWARDS: self._prev_reward_input,
SampleBatch.CUR_OBS: self._obs_input,
})
loss_inputs = [
(SampleBatch.PREV_ACTIONS, self._prev_action_input),
(SampleBatch.PREV_REWARDS, self._prev_reward_input),
(SampleBatch.CUR_OBS, self._obs_input),
]
if self._obs_include_prev_action_reward:
batch_tensors = UsageTrackingDict({
SampleBatch.PREV_ACTIONS: self._prev_action_input,
SampleBatch.PREV_REWARDS: self._prev_reward_input,
SampleBatch.CUR_OBS: self._obs_input,
})
loss_inputs = [
(SampleBatch.PREV_ACTIONS, self._prev_action_input),
(SampleBatch.PREV_REWARDS, self._prev_reward_input),
(SampleBatch.CUR_OBS, self._obs_input),
]
else:
batch_tensors = UsageTrackingDict({
SampleBatch.CUR_OBS: self._obs_input,
})
loss_inputs = [
(SampleBatch.CUR_OBS, self._obs_input),
]
for k, v in postprocessed_batch.items():
if k in batch_tensors:
@@ -264,12 +294,18 @@ class DynamicTFPolicy(TFPolicy):
"Initializing loss function with dummy input:\n\n{}\n".format(
summarize(batch_tensors)))
loss = self._loss_fn(self, batch_tensors)
if self._stats_fn:
self._stats_fetches.update(self._stats_fn(self, batch_tensors))
loss = self._do_loss_init(batch_tensors)
for k in sorted(batch_tensors.accessed_keys):
loss_inputs.append((k, batch_tensors[k]))
TFPolicy._initialize_loss(self, loss, loss_inputs)
if self._grad_stats_fn:
self._stats_fetches.update(self._grad_stats_fn(self, self._grads))
self._sess.run(tf.global_variables_initializer())
def _do_loss_init(self, batch_tensors):
    """Build the loss, then collect optional stats and update ops.

    Calls `self._loss_fn` first, then `self._stats_fn` (merging its result
    into `self._stats_fetches`), then `self._update_ops_fn` (storing its
    result as `self._update_ops`). The two optional hooks are skipped
    when falsy.

    Arguments:
        batch_tensors: the batch input tensors handed to the loss and
            stats functions

    Returns:
        whatever `self._loss_fn(self, batch_tensors)` returned
    """
    loss_tensor = self._loss_fn(self, batch_tensors)

    if self._stats_fn:
        merge_stats = self._stats_fetches.update
        merge_stats(self._stats_fn(self, batch_tensors))

    if self._update_ops_fn:
        ops = self._update_ops_fn(self)
        self._update_ops = ops

    return loss_tensor
+37 -4
View File
@@ -139,6 +139,39 @@ class TFPolicy(Policy):
raise ValueError(
"seq_lens tensor must be given if state inputs are defined")
def get_placeholder(self, name):
    """Look up an action or loss input placeholder by name.

    The obs, prev-action and prev-reward inputs can always be looked up.
    Any other name is resolved through the loss input dict, which only
    exists once the loss has been initialized.

    Arguments:
        name: key of the requested placeholder

    Raises:
        RuntimeError: if a non-obs input is requested before the loss
            has been initialized
    """
    always_available = dict([
        (SampleBatch.CUR_OBS, self._obs_input),
        (SampleBatch.PREV_ACTIONS, self._prev_action_input),
        (SampleBatch.PREV_REWARDS, self._prev_reward_input),
    ])
    try:
        return always_available[name]
    except KeyError:
        pass

    if self.loss_initialized():
        return self._loss_input_dict[name]

    raise RuntimeError(
        "You cannot call policy.get_placeholder() for non-obs inputs "
        "before the loss has been initialized. To avoid this, use "
        "policy.loss_initialized() to check whether this is the "
        "case, or move the call to later (e.g., from stats_fn to "
        "grad_stats_fn).")
def get_session(self):
    """Return the TF session held by this policy (`self._sess`)."""
    session = self._sess
    return session
def loss_initialized(self):
    """Tell whether a loss has been set on this policy.

    Returns:
        False while `self._loss` is None, True otherwise
    """
    if self._loss is None:
        return False
    return True
def _initialize_loss(self, loss, loss_inputs):
self._loss_inputs = loss_inputs
self._loss_input_dict = dict(self._loss_inputs)
@@ -172,7 +205,7 @@ class TFPolicy(Policy):
self._grads_and_vars)
if log_once("loss_used"):
logger.debug(
logger.info(
"These tensors were used in the loss_fn:\n\n{}\n".format(
summarize(self._loss_input_dict)))
@@ -195,21 +228,21 @@ class TFPolicy(Policy):
@override(Policy)
def compute_gradients(self, postprocessed_batch):
assert self._loss is not None, "Loss not initialized"
assert self.loss_initialized()
builder = TFRunBuilder(self._sess, "compute_gradients")
fetches = self._build_compute_gradients(builder, postprocessed_batch)
return builder.get(fetches)
@override(Policy)
def apply_gradients(self, gradients):
assert self._loss is not None, "Loss not initialized"
assert self.loss_initialized()
builder = TFRunBuilder(self._sess, "apply_gradients")
fetches = self._build_apply_gradients(builder, gradients)
builder.get(fetches)
@override(Policy)
def learn_on_batch(self, postprocessed_batch):
assert self._loss is not None, "Loss not initialized"
assert self.loss_initialized()
builder = TFRunBuilder(self._sess, "learn_on_batch")
fetches = self._build_learn_on_batch(builder, postprocessed_batch)
return builder.get(fetches)
+63 -13
View File
@@ -3,7 +3,7 @@ from __future__ import division
from __future__ import print_function
from ray.rllib.policy.dynamic_tf_policy import DynamicTFPolicy
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY
from ray.rllib.policy.tf_policy import TFPolicy
from ray.rllib.utils.annotations import override, DeveloperAPI
@@ -12,39 +12,60 @@ from ray.rllib.utils.annotations import override, DeveloperAPI
def build_tf_policy(name,
loss_fn,
get_default_config=None,
stats_fn=None,
grad_stats_fn=None,
extra_action_fetches_fn=None,
postprocess_fn=None,
stats_fn=None,
update_ops_fn=None,
optimizer_fn=None,
gradients_fn=None,
grad_stats_fn=None,
extra_action_fetches_fn=None,
extra_action_feed_fn=None,
extra_learn_fetches_fn=None,
extra_learn_feed_fn=None,
before_init=None,
before_loss_init=None,
after_init=None,
make_action_sampler=None,
mixins=None,
get_batch_divisibility_req=None):
get_batch_divisibility_req=None,
obs_include_prev_action_reward=True):
"""Helper function for creating a dynamic tf policy at runtime.
Functions will be run in this order to initialize the policy:
1. Placeholder setup: postprocess_fn
2. Loss init: loss_fn, stats_fn, update_ops_fn
3. Optimizer init: optimizer_fn, gradients_fn, grad_stats_fn
This means that you can e.g., depend on any policy attributes created in
the running of `loss_fn` in later functions such as `stats_fn`.
Arguments:
name (str): name of the policy (e.g., "PPOTFPolicy")
loss_fn (func): function that returns a loss tensor the policy,
and dict of experience tensor placeholders
and dict of experience tensor placeholders
get_default_config (func): optional function that returns the default
config to merge with any overrides
stats_fn (func): optional function that returns a dict of
TF fetches given the policy and batch input tensors
grad_stats_fn (func): optional function that returns a dict of
TF fetches given the policy and loss gradient tensors
extra_action_fetches_fn (func): optional function that returns
a dict of TF fetches given the policy object
postprocess_fn (func): optional experience postprocessing function
that takes the same args as Policy.postprocess_trajectory()
stats_fn (func): optional function that returns a dict of
TF fetches given the policy and batch input tensors
update_ops_fn (func): optional function that returns a list overriding
the update ops to run when applying gradients
optimizer_fn (func): optional function that returns a tf.Optimizer
given the policy and config
gradients_fn (func): optional function that returns a list of gradients
given a tf optimizer and loss tensor. If not specified, this
defaults to optimizer.compute_gradients(loss)
grad_stats_fn (func): optional function that returns a dict of
TF fetches given the policy and loss gradient tensors
extra_action_fetches_fn (func): optional function that returns
a dict of TF fetches given the policy object
extra_action_feed_fn (func): optional function that returns a feed dict
to also feed to TF when computing actions
extra_learn_fetches_fn (func): optional function that returns a dict of
extra values to fetch and return when learning on a batch
extra_learn_feed_fn (func): optional function that returns a feed dict
to also feed to TF when learning on a batch
before_init (func): optional function to run at the beginning of
policy init that takes the same arguments as the policy constructor
before_loss_init (func): optional function to run prior to loss
@@ -60,6 +81,8 @@ def build_tf_policy(name,
precedence than the DynamicTFPolicy class
get_batch_divisibility_req (func): optional function that returns
the divisibility requirement for sample batches
obs_include_prev_action_reward (bool): whether to include the
previous action and reward in the model input
Returns:
a DynamicTFPolicy subclass that uses the specified args
@@ -105,8 +128,11 @@ def build_tf_policy(name,
loss_fn,
stats_fn=stats_fn,
grad_stats_fn=grad_stats_fn,
update_ops_fn=update_ops_fn,
before_loss_init=before_loss_init_wrapper,
existing_inputs=existing_inputs)
make_action_sampler=make_action_sampler,
existing_inputs=existing_inputs,
obs_include_prev_action_reward=obs_include_prev_action_reward)
if after_init:
after_init(self, obs_space, action_space, config)
@@ -141,6 +167,30 @@ def build_tf_policy(name,
TFPolicy.extra_compute_action_fetches(self),
**self._extra_action_fetches)
@override(TFPolicy)
def extra_compute_action_feed_dict(self):
    """Return the extra feed dict used when computing actions.

    Uses `extra_action_feed_fn(self)` when that hook was supplied,
    otherwise falls back to the TFPolicy implementation.
    """
    if not extra_action_feed_fn:
        return TFPolicy.extra_compute_action_feed_dict(self)
    return extra_action_feed_fn(self)
@override(TFPolicy)
def extra_compute_grad_fetches(self):
    """Return the extra fetches gathered when learning on a batch.

    Without an `extra_learn_fetches_fn` hook this falls back to the
    TFPolicy implementation. With one, the hook's dict is returned with an
    empty learner stats entry added unless the hook already provides
    that key.
    """
    if not extra_learn_fetches_fn:
        return TFPolicy.extra_compute_grad_fetches(self)
    # auto-add empty learner stats dict if needed
    defaults = {LEARNER_STATS_KEY: {}}
    user_fetches = extra_learn_fetches_fn(self)
    return dict(defaults, **user_fetches)
@override(TFPolicy)
def extra_compute_grad_feed_dict(self):
    """Return the extra feed dict used when learning on a batch.

    Uses `extra_learn_feed_fn(self)` when that hook was supplied,
    otherwise falls back to the TFPolicy implementation.
    """
    if not extra_learn_feed_fn:
        return TFPolicy.extra_compute_grad_feed_dict(self)
    return extra_learn_feed_fn(self)
policy_cls.__name__ = name
policy_cls.__qualname__ = name
return policy_cls