mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 15:22:56 +08:00
[rllib] Rough port of DQN to build_tf_policy() pattern (#4823)
This commit is contained in:
@@ -4,20 +4,13 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import gym
|
||||
|
||||
import ray
|
||||
from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.utils.error import UnsupportedSpaceException
|
||||
from ray.rllib.utils.explained_variance import explained_variance
|
||||
from ray.rllib.policy.policy import Policy
|
||||
from ray.rllib.evaluation.postprocessing import compute_advantages, \
|
||||
Postprocessing
|
||||
from ray.rllib.policy.tf_policy import TFPolicy, \
|
||||
LearningRateSchedule
|
||||
from ray.rllib.models.catalog import ModelCatalog
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.policy.tf_policy_template import build_tf_policy
|
||||
from ray.rllib.policy.tf_policy import LearningRateSchedule
|
||||
from ray.rllib.utils import try_import_tf
|
||||
|
||||
tf = try_import_tf()
|
||||
@@ -44,144 +37,97 @@ class A3CLoss(object):
|
||||
self.entropy * entropy_coeff)
|
||||
|
||||
|
||||
class A3CPostprocessing(object):
|
||||
"""Adds the VF preds and advantages fields to the trajectory."""
|
||||
|
||||
@override(TFPolicy)
|
||||
def extra_compute_action_fetches(self):
|
||||
return dict(
|
||||
TFPolicy.extra_compute_action_fetches(self),
|
||||
**{SampleBatch.VF_PREDS: self.vf})
|
||||
|
||||
@override(Policy)
|
||||
def postprocess_trajectory(self,
|
||||
sample_batch,
|
||||
other_agent_batches=None,
|
||||
episode=None):
|
||||
completed = sample_batch[SampleBatch.DONES][-1]
|
||||
if completed:
|
||||
last_r = 0.0
|
||||
else:
|
||||
next_state = []
|
||||
for i in range(len(self.model.state_in)):
|
||||
next_state.append([sample_batch["state_out_{}".format(i)][-1]])
|
||||
last_r = self._value(sample_batch[SampleBatch.NEXT_OBS][-1],
|
||||
sample_batch[SampleBatch.ACTIONS][-1],
|
||||
sample_batch[SampleBatch.REWARDS][-1],
|
||||
*next_state)
|
||||
return compute_advantages(sample_batch, last_r, self.config["gamma"],
|
||||
self.config["lambda"])
|
||||
def actor_critic_loss(policy, batch_tensors):
|
||||
policy.loss = A3CLoss(
|
||||
policy.action_dist, batch_tensors[SampleBatch.ACTIONS],
|
||||
batch_tensors[Postprocessing.ADVANTAGES],
|
||||
batch_tensors[Postprocessing.VALUE_TARGETS], policy.vf,
|
||||
policy.config["vf_loss_coeff"], policy.config["entropy_coeff"])
|
||||
return policy.loss.total_loss
|
||||
|
||||
|
||||
class A3CTFPolicy(LearningRateSchedule, A3CPostprocessing, TFPolicy):
|
||||
def __init__(self, observation_space, action_space, config):
|
||||
config = dict(ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG, **config)
|
||||
self.config = config
|
||||
self.sess = tf.get_default_session()
|
||||
def postprocess_advantages(policy,
|
||||
sample_batch,
|
||||
other_agent_batches=None,
|
||||
episode=None):
|
||||
completed = sample_batch[SampleBatch.DONES][-1]
|
||||
if completed:
|
||||
last_r = 0.0
|
||||
else:
|
||||
next_state = []
|
||||
for i in range(len(policy.model.state_in)):
|
||||
next_state.append([sample_batch["state_out_{}".format(i)][-1]])
|
||||
last_r = policy._value(sample_batch[SampleBatch.NEXT_OBS][-1],
|
||||
sample_batch[SampleBatch.ACTIONS][-1],
|
||||
sample_batch[SampleBatch.REWARDS][-1],
|
||||
*next_state)
|
||||
return compute_advantages(sample_batch, last_r, policy.config["gamma"],
|
||||
policy.config["lambda"])
|
||||
|
||||
# Setup the policy
|
||||
self.observations = tf.placeholder(
|
||||
tf.float32, [None] + list(observation_space.shape))
|
||||
dist_class, logit_dim = ModelCatalog.get_action_dist(
|
||||
action_space, self.config["model"])
|
||||
self.prev_actions = ModelCatalog.get_action_placeholder(action_space)
|
||||
self.prev_rewards = tf.placeholder(
|
||||
tf.float32, [None], name="prev_reward")
|
||||
self.model = ModelCatalog.get_model({
|
||||
"obs": self.observations,
|
||||
"prev_actions": self.prev_actions,
|
||||
"prev_rewards": self.prev_rewards,
|
||||
"is_training": self._get_is_training_placeholder(),
|
||||
}, observation_space, action_space, logit_dim, self.config["model"])
|
||||
action_dist = dist_class(self.model.outputs)
|
||||
|
||||
def add_value_function_fetch(policy):
|
||||
return {SampleBatch.VF_PREDS: policy.vf}
|
||||
|
||||
|
||||
class ValueNetworkMixin(object):
|
||||
def __init__(self):
|
||||
self.vf = self.model.value_function()
|
||||
self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
|
||||
tf.get_variable_scope().name)
|
||||
|
||||
# Setup the policy loss
|
||||
if isinstance(action_space, gym.spaces.Box):
|
||||
ac_size = action_space.shape[0]
|
||||
actions = tf.placeholder(tf.float32, [None, ac_size], name="ac")
|
||||
elif isinstance(action_space, gym.spaces.Discrete):
|
||||
actions = tf.placeholder(tf.int64, [None], name="ac")
|
||||
else:
|
||||
raise UnsupportedSpaceException(
|
||||
"Action space {} is not supported for A3C.".format(
|
||||
action_space))
|
||||
advantages = tf.placeholder(tf.float32, [None], name="advantages")
|
||||
self.v_target = tf.placeholder(tf.float32, [None], name="v_target")
|
||||
self.loss = A3CLoss(action_dist, actions, advantages, self.v_target,
|
||||
self.vf, self.config["vf_loss_coeff"],
|
||||
self.config["entropy_coeff"])
|
||||
|
||||
# Initialize TFPolicy
|
||||
loss_in = [
|
||||
(SampleBatch.CUR_OBS, self.observations),
|
||||
(SampleBatch.ACTIONS, actions),
|
||||
(SampleBatch.PREV_ACTIONS, self.prev_actions),
|
||||
(SampleBatch.PREV_REWARDS, self.prev_rewards),
|
||||
(Postprocessing.ADVANTAGES, advantages),
|
||||
(Postprocessing.VALUE_TARGETS, self.v_target),
|
||||
]
|
||||
LearningRateSchedule.__init__(self, self.config["lr"],
|
||||
self.config["lr_schedule"])
|
||||
TFPolicy.__init__(
|
||||
self,
|
||||
observation_space,
|
||||
action_space,
|
||||
self.sess,
|
||||
obs_input=self.observations,
|
||||
action_sampler=action_dist.sample(),
|
||||
action_prob=action_dist.sampled_action_prob(),
|
||||
loss=self.loss.total_loss,
|
||||
model=self.model,
|
||||
loss_inputs=loss_in,
|
||||
state_inputs=self.model.state_in,
|
||||
state_outputs=self.model.state_out,
|
||||
prev_action_input=self.prev_actions,
|
||||
prev_reward_input=self.prev_rewards,
|
||||
seq_lens=self.model.seq_lens,
|
||||
max_seq_len=self.config["model"]["max_seq_len"])
|
||||
|
||||
self.stats_fetches = {
|
||||
LEARNER_STATS_KEY: {
|
||||
"cur_lr": tf.cast(self.cur_lr, tf.float64),
|
||||
"policy_loss": self.loss.pi_loss,
|
||||
"policy_entropy": self.loss.entropy,
|
||||
"grad_gnorm": tf.global_norm(self._grads),
|
||||
"var_gnorm": tf.global_norm(self.var_list),
|
||||
"vf_loss": self.loss.vf_loss,
|
||||
"vf_explained_var": explained_variance(self.v_target, self.vf),
|
||||
},
|
||||
}
|
||||
|
||||
self.sess.run(tf.global_variables_initializer())
|
||||
|
||||
@override(Policy)
|
||||
def get_initial_state(self):
|
||||
return self.model.state_init
|
||||
|
||||
@override(TFPolicy)
|
||||
def gradients(self, optimizer, loss):
|
||||
grads = tf.gradients(loss, self.var_list)
|
||||
self.grads, _ = tf.clip_by_global_norm(grads, self.config["grad_clip"])
|
||||
clipped_grads = list(zip(self.grads, self.var_list))
|
||||
return clipped_grads
|
||||
|
||||
@override(TFPolicy)
|
||||
def extra_compute_grad_fetches(self):
|
||||
return self.stats_fetches
|
||||
|
||||
def _value(self, ob, prev_action, prev_reward, *args):
|
||||
feed_dict = {
|
||||
self.observations: [ob],
|
||||
self.prev_actions: [prev_action],
|
||||
self.prev_rewards: [prev_reward],
|
||||
self.get_placeholder(SampleBatch.CUR_OBS): [ob],
|
||||
self.get_placeholder(SampleBatch.PREV_ACTIONS): [prev_action],
|
||||
self.get_placeholder(SampleBatch.PREV_REWARDS): [prev_reward],
|
||||
self.model.seq_lens: [1]
|
||||
}
|
||||
assert len(args) == len(self.model.state_in), \
|
||||
(args, self.model.state_in)
|
||||
for k, v in zip(self.model.state_in, args):
|
||||
feed_dict[k] = v
|
||||
vf = self.sess.run(self.vf, feed_dict)
|
||||
vf = self.get_session().run(self.vf, feed_dict)
|
||||
return vf[0]
|
||||
|
||||
|
||||
def stats(policy, batch_tensors):
|
||||
return {
|
||||
"cur_lr": tf.cast(policy.cur_lr, tf.float64),
|
||||
"policy_loss": policy.loss.pi_loss,
|
||||
"policy_entropy": policy.loss.entropy,
|
||||
"var_gnorm": tf.global_norm(policy.var_list),
|
||||
"vf_loss": policy.loss.vf_loss,
|
||||
}
|
||||
|
||||
|
||||
def grad_stats(policy, grads):
|
||||
return {
|
||||
"grad_gnorm": tf.global_norm(grads),
|
||||
"vf_explained_var": explained_variance(
|
||||
policy.get_placeholder(Postprocessing.VALUE_TARGETS), policy.vf),
|
||||
}
|
||||
|
||||
|
||||
def clip_gradients(policy, optimizer, loss):
|
||||
grads = tf.gradients(loss, policy.var_list)
|
||||
grads, _ = tf.clip_by_global_norm(grads, policy.config["grad_clip"])
|
||||
clipped_grads = list(zip(grads, policy.var_list))
|
||||
return clipped_grads
|
||||
|
||||
|
||||
def setup_mixins(policy, obs_space, action_space, config):
|
||||
ValueNetworkMixin.__init__(policy)
|
||||
LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"])
|
||||
policy.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
|
||||
tf.get_variable_scope().name)
|
||||
|
||||
|
||||
A3CTFPolicy = build_tf_policy(
|
||||
name="A3CTFPolicy",
|
||||
get_default_config=lambda: ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG,
|
||||
loss_fn=actor_critic_loss,
|
||||
stats_fn=stats,
|
||||
grad_stats_fn=grad_stats,
|
||||
gradients_fn=clip_gradients,
|
||||
postprocess_fn=postprocess_advantages,
|
||||
extra_action_fetches_fn=add_value_function_fetch,
|
||||
before_loss_init=setup_mixins,
|
||||
mixins=[ValueNetworkMixin, LearningRateSchedule])
|
||||
|
||||
@@ -7,14 +7,14 @@ import numpy as np
|
||||
from scipy.stats import entropy
|
||||
|
||||
import ray
|
||||
from ray.rllib.policy.policy import Policy
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
|
||||
from ray.rllib.models import ModelCatalog, Categorical
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.error import UnsupportedSpaceException
|
||||
from ray.rllib.policy.policy import Policy
|
||||
from ray.rllib.policy.tf_policy import TFPolicy, \
|
||||
LearningRateSchedule
|
||||
from ray.rllib.policy.tf_policy_template import build_tf_policy
|
||||
from ray.rllib.utils import try_import_tf
|
||||
|
||||
tf = try_import_tf()
|
||||
@@ -102,46 +102,6 @@ class QLoss(object):
|
||||
}
|
||||
|
||||
|
||||
class DQNPostprocessing(object):
|
||||
"""Implements n-step learning and param noise adjustments."""
|
||||
|
||||
@override(TFPolicy)
|
||||
def extra_compute_action_fetches(self):
|
||||
return dict(
|
||||
TFPolicy.extra_compute_action_fetches(self), **{
|
||||
"q_values": self.q_values,
|
||||
})
|
||||
|
||||
@override(Policy)
|
||||
def postprocess_trajectory(self,
|
||||
sample_batch,
|
||||
other_agent_batches=None,
|
||||
episode=None):
|
||||
if self.config["parameter_noise"]:
|
||||
# adjust the sigma of parameter space noise
|
||||
states = [list(x) for x in sample_batch.columns(["obs"])][0]
|
||||
|
||||
noisy_action_distribution = self.sess.run(
|
||||
self.action_probs, feed_dict={self.cur_observations: states})
|
||||
self.sess.run(self.remove_noise_op)
|
||||
clean_action_distribution = self.sess.run(
|
||||
self.action_probs, feed_dict={self.cur_observations: states})
|
||||
distance_in_action_space = np.mean(
|
||||
entropy(clean_action_distribution.T,
|
||||
noisy_action_distribution.T))
|
||||
self.pi_distance = distance_in_action_space
|
||||
if (distance_in_action_space <
|
||||
-np.log(1 - self.cur_epsilon +
|
||||
self.cur_epsilon / self.num_actions)):
|
||||
self.parameter_noise_sigma_val *= 1.01
|
||||
else:
|
||||
self.parameter_noise_sigma_val /= 1.01
|
||||
self.parameter_noise_sigma.load(
|
||||
self.parameter_noise_sigma_val, session=self.sess)
|
||||
|
||||
return _postprocess_dqn(self, sample_batch)
|
||||
|
||||
|
||||
class QNetwork(object):
|
||||
def __init__(self,
|
||||
model,
|
||||
@@ -345,170 +305,18 @@ class QValuePolicy(object):
|
||||
self.action_prob = None
|
||||
|
||||
|
||||
class DQNTFPolicy(LearningRateSchedule, DQNPostprocessing, TFPolicy):
|
||||
def __init__(self, observation_space, action_space, config):
|
||||
config = dict(ray.rllib.agents.dqn.dqn.DEFAULT_CONFIG, **config)
|
||||
if not isinstance(action_space, Discrete):
|
||||
raise UnsupportedSpaceException(
|
||||
"Action space {} is not supported for DQN.".format(
|
||||
action_space))
|
||||
|
||||
self.config = config
|
||||
class ExplorationStateMixin(object):
|
||||
def __init__(self, obs_space, action_space, config):
|
||||
self.cur_epsilon = 1.0
|
||||
self.num_actions = action_space.n
|
||||
|
||||
# Action inputs
|
||||
self.stochastic = tf.placeholder(tf.bool, (), name="stochastic")
|
||||
self.eps = tf.placeholder(tf.float32, (), name="eps")
|
||||
self.cur_observations = tf.placeholder(
|
||||
tf.float32, shape=(None, ) + observation_space.shape)
|
||||
|
||||
# Action Q network
|
||||
with tf.variable_scope(Q_SCOPE) as scope:
|
||||
q_values, q_logits, q_dist, _ = self._build_q_network(
|
||||
self.cur_observations, observation_space, action_space)
|
||||
self.q_values = q_values
|
||||
self.q_func_vars = _scope_vars(scope.name)
|
||||
|
||||
# Noise vars for Q network except for layer normalization vars
|
||||
def add_parameter_noise(self):
|
||||
if self.config["parameter_noise"]:
|
||||
self._build_parameter_noise([
|
||||
var for var in self.q_func_vars if "LayerNorm" not in var.name
|
||||
])
|
||||
self.action_probs = tf.nn.softmax(self.q_values)
|
||||
self.sess.run(self.add_noise_op)
|
||||
|
||||
# Action outputs
|
||||
self.output_actions, self.action_prob = self._build_q_value_policy(
|
||||
q_values)
|
||||
|
||||
# Replay inputs
|
||||
self.obs_t = tf.placeholder(
|
||||
tf.float32, shape=(None, ) + observation_space.shape)
|
||||
self.act_t = tf.placeholder(tf.int32, [None], name="action")
|
||||
self.rew_t = tf.placeholder(tf.float32, [None], name="reward")
|
||||
self.obs_tp1 = tf.placeholder(
|
||||
tf.float32, shape=(None, ) + observation_space.shape)
|
||||
self.done_mask = tf.placeholder(tf.float32, [None], name="done")
|
||||
self.importance_weights = tf.placeholder(
|
||||
tf.float32, [None], name="weight")
|
||||
|
||||
# q network evaluation
|
||||
with tf.variable_scope(Q_SCOPE, reuse=True):
|
||||
prev_update_ops = set(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
|
||||
q_t, q_logits_t, q_dist_t, model = self._build_q_network(
|
||||
self.obs_t, observation_space, action_space)
|
||||
q_batchnorm_update_ops = list(
|
||||
set(tf.get_collection(tf.GraphKeys.UPDATE_OPS)) -
|
||||
prev_update_ops)
|
||||
|
||||
# target q network evalution
|
||||
with tf.variable_scope(Q_TARGET_SCOPE) as scope:
|
||||
q_tp1, q_logits_tp1, q_dist_tp1, _ = self._build_q_network(
|
||||
self.obs_tp1, observation_space, action_space)
|
||||
self.target_q_func_vars = _scope_vars(scope.name)
|
||||
|
||||
# q scores for actions which we know were selected in the given state.
|
||||
one_hot_selection = tf.one_hot(self.act_t, self.num_actions)
|
||||
q_t_selected = tf.reduce_sum(q_t * one_hot_selection, 1)
|
||||
q_logits_t_selected = tf.reduce_sum(
|
||||
q_logits_t * tf.expand_dims(one_hot_selection, -1), 1)
|
||||
|
||||
# compute estimate of best possible value starting from state at t + 1
|
||||
if config["double_q"]:
|
||||
with tf.variable_scope(Q_SCOPE, reuse=True):
|
||||
q_tp1_using_online_net, q_logits_tp1_using_online_net, \
|
||||
q_dist_tp1_using_online_net, _ = self._build_q_network(
|
||||
self.obs_tp1, observation_space, action_space)
|
||||
q_tp1_best_using_online_net = tf.argmax(q_tp1_using_online_net, 1)
|
||||
q_tp1_best_one_hot_selection = tf.one_hot(
|
||||
q_tp1_best_using_online_net, self.num_actions)
|
||||
q_tp1_best = tf.reduce_sum(q_tp1 * q_tp1_best_one_hot_selection, 1)
|
||||
q_dist_tp1_best = tf.reduce_sum(
|
||||
q_dist_tp1 * tf.expand_dims(q_tp1_best_one_hot_selection, -1),
|
||||
1)
|
||||
else:
|
||||
q_tp1_best_one_hot_selection = tf.one_hot(
|
||||
tf.argmax(q_tp1, 1), self.num_actions)
|
||||
q_tp1_best = tf.reduce_sum(q_tp1 * q_tp1_best_one_hot_selection, 1)
|
||||
q_dist_tp1_best = tf.reduce_sum(
|
||||
q_dist_tp1 * tf.expand_dims(q_tp1_best_one_hot_selection, -1),
|
||||
1)
|
||||
|
||||
self.loss = self._build_q_loss(q_t_selected, q_logits_t_selected,
|
||||
q_tp1_best, q_dist_tp1_best)
|
||||
|
||||
# update_target_fn will be called periodically to copy Q network to
|
||||
# target Q network
|
||||
update_target_expr = []
|
||||
assert len(self.q_func_vars) == len(self.target_q_func_vars), \
|
||||
(self.q_func_vars, self.target_q_func_vars)
|
||||
for var, var_target in zip(self.q_func_vars, self.target_q_func_vars):
|
||||
update_target_expr.append(var_target.assign(var))
|
||||
self.update_target_expr = tf.group(*update_target_expr)
|
||||
|
||||
# initialize TFPolicy
|
||||
self.sess = tf.get_default_session()
|
||||
self.loss_inputs = [
|
||||
(SampleBatch.CUR_OBS, self.obs_t),
|
||||
(SampleBatch.ACTIONS, self.act_t),
|
||||
(SampleBatch.REWARDS, self.rew_t),
|
||||
(SampleBatch.NEXT_OBS, self.obs_tp1),
|
||||
(SampleBatch.DONES, self.done_mask),
|
||||
(PRIO_WEIGHTS, self.importance_weights),
|
||||
]
|
||||
|
||||
LearningRateSchedule.__init__(self, self.config["lr"],
|
||||
self.config["lr_schedule"])
|
||||
TFPolicy.__init__(
|
||||
self,
|
||||
observation_space,
|
||||
action_space,
|
||||
self.sess,
|
||||
obs_input=self.cur_observations,
|
||||
action_sampler=self.output_actions,
|
||||
action_prob=self.action_prob,
|
||||
loss=self.loss.loss,
|
||||
model=model,
|
||||
loss_inputs=self.loss_inputs,
|
||||
update_ops=q_batchnorm_update_ops)
|
||||
self.sess.run(tf.global_variables_initializer())
|
||||
|
||||
self.stats_fetches = dict({
|
||||
"cur_lr": tf.cast(self.cur_lr, tf.float64),
|
||||
}, **self.loss.stats)
|
||||
|
||||
@override(TFPolicy)
|
||||
def optimizer(self):
|
||||
return tf.train.AdamOptimizer(
|
||||
learning_rate=self.cur_lr, epsilon=self.config["adam_epsilon"])
|
||||
|
||||
@override(TFPolicy)
|
||||
def gradients(self, optimizer, loss):
|
||||
if self.config["grad_norm_clipping"] is not None:
|
||||
grads_and_vars = _minimize_and_clip(
|
||||
optimizer,
|
||||
loss,
|
||||
var_list=self.q_func_vars,
|
||||
clip_val=self.config["grad_norm_clipping"])
|
||||
else:
|
||||
grads_and_vars = optimizer.compute_gradients(
|
||||
loss, var_list=self.q_func_vars)
|
||||
grads_and_vars = [(g, v) for (g, v) in grads_and_vars if g is not None]
|
||||
return grads_and_vars
|
||||
|
||||
@override(TFPolicy)
|
||||
def extra_compute_action_feed_dict(self):
|
||||
return {
|
||||
self.stochastic: True,
|
||||
self.eps: self.cur_epsilon,
|
||||
}
|
||||
|
||||
@override(TFPolicy)
|
||||
def extra_compute_grad_fetches(self):
|
||||
return {
|
||||
"td_error": self.loss.td_error,
|
||||
LEARNER_STATS_KEY: self.stats_fetches,
|
||||
}
|
||||
def set_epsilon(self, epsilon):
|
||||
self.cur_epsilon = epsilon
|
||||
|
||||
@override(Policy)
|
||||
def get_state(self):
|
||||
@@ -519,93 +327,262 @@ class DQNTFPolicy(LearningRateSchedule, DQNPostprocessing, TFPolicy):
|
||||
TFPolicy.set_state(self, state[0])
|
||||
self.set_epsilon(state[1])
|
||||
|
||||
def _build_parameter_noise(self, pnet_params):
|
||||
self.parameter_noise_sigma_val = 1.0
|
||||
self.parameter_noise_sigma = tf.get_variable(
|
||||
initializer=tf.constant_initializer(
|
||||
self.parameter_noise_sigma_val),
|
||||
name="parameter_noise_sigma",
|
||||
shape=(),
|
||||
trainable=False,
|
||||
dtype=tf.float32)
|
||||
self.parameter_noise = list()
|
||||
# No need to add any noise on LayerNorm parameters
|
||||
for var in pnet_params:
|
||||
noise_var = tf.get_variable(
|
||||
name=var.name.split(":")[0] + "_noise",
|
||||
shape=var.shape,
|
||||
initializer=tf.constant_initializer(.0),
|
||||
trainable=False)
|
||||
self.parameter_noise.append(noise_var)
|
||||
remove_noise_ops = list()
|
||||
for var, var_noise in zip(pnet_params, self.parameter_noise):
|
||||
remove_noise_ops.append(tf.assign_add(var, -var_noise))
|
||||
self.remove_noise_op = tf.group(*tuple(remove_noise_ops))
|
||||
generate_noise_ops = list()
|
||||
for var_noise in self.parameter_noise:
|
||||
generate_noise_ops.append(
|
||||
tf.assign(
|
||||
var_noise,
|
||||
tf.random_normal(
|
||||
shape=var_noise.shape,
|
||||
stddev=self.parameter_noise_sigma)))
|
||||
with tf.control_dependencies(generate_noise_ops):
|
||||
add_noise_ops = list()
|
||||
for var, var_noise in zip(pnet_params, self.parameter_noise):
|
||||
add_noise_ops.append(tf.assign_add(var, var_noise))
|
||||
self.add_noise_op = tf.group(*tuple(add_noise_ops))
|
||||
self.pi_distance = None
|
||||
|
||||
class TargetNetworkMixin(object):
|
||||
def __init__(self, obs_space, action_space, config):
|
||||
# update_target_fn will be called periodically to copy Q network to
|
||||
# target Q network
|
||||
update_target_expr = []
|
||||
assert len(self.q_func_vars) == len(self.target_q_func_vars), \
|
||||
(self.q_func_vars, self.target_q_func_vars)
|
||||
for var, var_target in zip(self.q_func_vars, self.target_q_func_vars):
|
||||
update_target_expr.append(var_target.assign(var))
|
||||
self.update_target_expr = tf.group(*update_target_expr)
|
||||
|
||||
def update_target(self):
|
||||
return self.get_session().run(self.update_target_expr)
|
||||
|
||||
|
||||
class ComputeTDErrorMixin(object):
|
||||
def compute_td_error(self, obs_t, act_t, rew_t, obs_tp1, done_mask,
|
||||
importance_weights):
|
||||
td_err = self.sess.run(
|
||||
if not self.loss_initialized():
|
||||
return np.zeros_like(rew_t)
|
||||
|
||||
td_err = self.get_session().run(
|
||||
self.loss.td_error,
|
||||
feed_dict={
|
||||
self.obs_t: [np.array(ob) for ob in obs_t],
|
||||
self.act_t: act_t,
|
||||
self.rew_t: rew_t,
|
||||
self.obs_tp1: [np.array(ob) for ob in obs_tp1],
|
||||
self.done_mask: done_mask,
|
||||
self.importance_weights: importance_weights
|
||||
self.get_placeholder(SampleBatch.CUR_OBS): [
|
||||
np.array(ob) for ob in obs_t
|
||||
],
|
||||
self.get_placeholder(SampleBatch.ACTIONS): act_t,
|
||||
self.get_placeholder(SampleBatch.REWARDS): rew_t,
|
||||
self.get_placeholder(SampleBatch.NEXT_OBS): [
|
||||
np.array(ob) for ob in obs_tp1
|
||||
],
|
||||
self.get_placeholder(SampleBatch.DONES): done_mask,
|
||||
self.get_placeholder(PRIO_WEIGHTS): importance_weights,
|
||||
})
|
||||
return td_err
|
||||
|
||||
def add_parameter_noise(self):
|
||||
if self.config["parameter_noise"]:
|
||||
self.sess.run(self.add_noise_op)
|
||||
|
||||
def update_target(self):
|
||||
return self.sess.run(self.update_target_expr)
|
||||
def postprocess_trajectory(policy,
|
||||
sample_batch,
|
||||
other_agent_batches=None,
|
||||
episode=None):
|
||||
if policy.config["parameter_noise"]:
|
||||
# adjust the sigma of parameter space noise
|
||||
states = [list(x) for x in sample_batch.columns(["obs"])][0]
|
||||
|
||||
def set_epsilon(self, epsilon):
|
||||
self.cur_epsilon = epsilon
|
||||
noisy_action_distribution = policy.get_session().run(
|
||||
policy.action_probs, feed_dict={policy.cur_observations: states})
|
||||
policy.get_session().run(policy.remove_noise_op)
|
||||
clean_action_distribution = policy.get_session().run(
|
||||
policy.action_probs, feed_dict={policy.cur_observations: states})
|
||||
distance_in_action_space = np.mean(
|
||||
entropy(clean_action_distribution.T, noisy_action_distribution.T))
|
||||
policy.pi_distance = distance_in_action_space
|
||||
if (distance_in_action_space <
|
||||
-np.log(1 - policy.cur_epsilon +
|
||||
policy.cur_epsilon / policy.num_actions)):
|
||||
policy.parameter_noise_sigma_val *= 1.01
|
||||
else:
|
||||
policy.parameter_noise_sigma_val /= 1.01
|
||||
policy.parameter_noise_sigma.load(
|
||||
policy.parameter_noise_sigma_val, session=policy.get_session())
|
||||
|
||||
def _build_q_network(self, obs, obs_space, action_space):
|
||||
qnet = QNetwork(
|
||||
ModelCatalog.get_model({
|
||||
"obs": obs,
|
||||
"is_training": self._get_is_training_placeholder(),
|
||||
}, obs_space, action_space, self.num_actions,
|
||||
self.config["model"]), self.num_actions,
|
||||
self.config["dueling"], self.config["hiddens"],
|
||||
self.config["noisy"], self.config["num_atoms"],
|
||||
self.config["v_min"], self.config["v_max"], self.config["sigma0"],
|
||||
self.config["parameter_noise"])
|
||||
return qnet.value, qnet.logits, qnet.dist, qnet.model
|
||||
return _postprocess_dqn(policy, sample_batch)
|
||||
|
||||
def _build_q_value_policy(self, q_values):
|
||||
policy = QValuePolicy(
|
||||
q_values, self.cur_observations, self.num_actions, self.stochastic,
|
||||
self.eps, self.config["soft_q"], self.config["softmax_temp"])
|
||||
return policy.action, policy.action_prob
|
||||
|
||||
def _build_q_loss(self, q_t_selected, q_logits_t_selected, q_tp1_best,
|
||||
q_dist_tp1_best):
|
||||
return QLoss(q_t_selected, q_logits_t_selected, q_tp1_best,
|
||||
q_dist_tp1_best, self.importance_weights, self.rew_t,
|
||||
self.done_mask, self.config["gamma"],
|
||||
self.config["n_step"], self.config["num_atoms"],
|
||||
self.config["v_min"], self.config["v_max"])
|
||||
def build_q_networks(policy, input_dict, observation_space, action_space,
|
||||
config):
|
||||
|
||||
if not isinstance(action_space, Discrete):
|
||||
raise UnsupportedSpaceException(
|
||||
"Action space {} is not supported for DQN.".format(action_space))
|
||||
|
||||
# Action Q network
|
||||
with tf.variable_scope(Q_SCOPE) as scope:
|
||||
q_values, q_logits, q_dist, _ = _build_q_network(
|
||||
policy, input_dict[SampleBatch.CUR_OBS], observation_space,
|
||||
action_space)
|
||||
policy.q_values = q_values
|
||||
policy.q_func_vars = _scope_vars(scope.name)
|
||||
|
||||
# Noise vars for Q network except for layer normalization vars
|
||||
if config["parameter_noise"]:
|
||||
_build_parameter_noise(
|
||||
policy,
|
||||
[var for var in policy.q_func_vars if "LayerNorm" not in var.name])
|
||||
policy.action_probs = tf.nn.softmax(policy.q_values)
|
||||
|
||||
# Action outputs
|
||||
qvp = QValuePolicy(q_values, input_dict[SampleBatch.CUR_OBS],
|
||||
action_space.n, policy.stochastic, policy.eps,
|
||||
policy.config["soft_q"], policy.config["softmax_temp"])
|
||||
policy.output_actions, policy.action_prob = qvp.action, qvp.action_prob
|
||||
|
||||
return policy.output_actions, policy.action_prob
|
||||
|
||||
|
||||
def _build_parameter_noise(policy, pnet_params):
|
||||
policy.parameter_noise_sigma_val = 1.0
|
||||
policy.parameter_noise_sigma = tf.get_variable(
|
||||
initializer=tf.constant_initializer(policy.parameter_noise_sigma_val),
|
||||
name="parameter_noise_sigma",
|
||||
shape=(),
|
||||
trainable=False,
|
||||
dtype=tf.float32)
|
||||
policy.parameter_noise = list()
|
||||
# No need to add any noise on LayerNorm parameters
|
||||
for var in pnet_params:
|
||||
noise_var = tf.get_variable(
|
||||
name=var.name.split(":")[0] + "_noise",
|
||||
shape=var.shape,
|
||||
initializer=tf.constant_initializer(.0),
|
||||
trainable=False)
|
||||
policy.parameter_noise.append(noise_var)
|
||||
remove_noise_ops = list()
|
||||
for var, var_noise in zip(pnet_params, policy.parameter_noise):
|
||||
remove_noise_ops.append(tf.assign_add(var, -var_noise))
|
||||
policy.remove_noise_op = tf.group(*tuple(remove_noise_ops))
|
||||
generate_noise_ops = list()
|
||||
for var_noise in policy.parameter_noise:
|
||||
generate_noise_ops.append(
|
||||
tf.assign(
|
||||
var_noise,
|
||||
tf.random_normal(
|
||||
shape=var_noise.shape,
|
||||
stddev=policy.parameter_noise_sigma)))
|
||||
with tf.control_dependencies(generate_noise_ops):
|
||||
add_noise_ops = list()
|
||||
for var, var_noise in zip(pnet_params, policy.parameter_noise):
|
||||
add_noise_ops.append(tf.assign_add(var, var_noise))
|
||||
policy.add_noise_op = tf.group(*tuple(add_noise_ops))
|
||||
policy.pi_distance = None
|
||||
|
||||
|
||||
def build_q_losses(policy, batch_tensors):
|
||||
# q network evaluation
|
||||
with tf.variable_scope(Q_SCOPE, reuse=True):
|
||||
prev_update_ops = set(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
|
||||
q_t, q_logits_t, q_dist_t, model = _build_q_network(
|
||||
policy, batch_tensors[SampleBatch.CUR_OBS],
|
||||
policy.observation_space, policy.action_space)
|
||||
policy.q_batchnorm_update_ops = list(
|
||||
set(tf.get_collection(tf.GraphKeys.UPDATE_OPS)) - prev_update_ops)
|
||||
|
||||
# target q network evalution
|
||||
with tf.variable_scope(Q_TARGET_SCOPE) as scope:
|
||||
q_tp1, q_logits_tp1, q_dist_tp1, _ = _build_q_network(
|
||||
policy, batch_tensors[SampleBatch.NEXT_OBS],
|
||||
policy.observation_space, policy.action_space)
|
||||
policy.target_q_func_vars = _scope_vars(scope.name)
|
||||
|
||||
# q scores for actions which we know were selected in the given state.
|
||||
one_hot_selection = tf.one_hot(batch_tensors[SampleBatch.ACTIONS],
|
||||
policy.action_space.n)
|
||||
q_t_selected = tf.reduce_sum(q_t * one_hot_selection, 1)
|
||||
q_logits_t_selected = tf.reduce_sum(
|
||||
q_logits_t * tf.expand_dims(one_hot_selection, -1), 1)
|
||||
|
||||
# compute estimate of best possible value starting from state at t + 1
|
||||
if policy.config["double_q"]:
|
||||
with tf.variable_scope(Q_SCOPE, reuse=True):
|
||||
q_tp1_using_online_net, q_logits_tp1_using_online_net, \
|
||||
q_dist_tp1_using_online_net, _ = _build_q_network(
|
||||
policy,
|
||||
batch_tensors[SampleBatch.NEXT_OBS],
|
||||
policy.observation_space, policy.action_space)
|
||||
q_tp1_best_using_online_net = tf.argmax(q_tp1_using_online_net, 1)
|
||||
q_tp1_best_one_hot_selection = tf.one_hot(q_tp1_best_using_online_net,
|
||||
policy.action_space.n)
|
||||
q_tp1_best = tf.reduce_sum(q_tp1 * q_tp1_best_one_hot_selection, 1)
|
||||
q_dist_tp1_best = tf.reduce_sum(
|
||||
q_dist_tp1 * tf.expand_dims(q_tp1_best_one_hot_selection, -1), 1)
|
||||
else:
|
||||
q_tp1_best_one_hot_selection = tf.one_hot(
|
||||
tf.argmax(q_tp1, 1), policy.action_space.n)
|
||||
q_tp1_best = tf.reduce_sum(q_tp1 * q_tp1_best_one_hot_selection, 1)
|
||||
q_dist_tp1_best = tf.reduce_sum(
|
||||
q_dist_tp1 * tf.expand_dims(q_tp1_best_one_hot_selection, -1), 1)
|
||||
|
||||
policy.loss = _build_q_loss(
|
||||
q_t_selected, q_logits_t_selected, q_tp1_best, q_dist_tp1_best,
|
||||
batch_tensors[SampleBatch.REWARDS], batch_tensors[SampleBatch.DONES],
|
||||
batch_tensors[PRIO_WEIGHTS], policy.config)
|
||||
|
||||
return policy.loss.loss
|
||||
|
||||
|
||||
def adam_optimizer(policy, config):
|
||||
return tf.train.AdamOptimizer(
|
||||
learning_rate=policy.cur_lr, epsilon=config["adam_epsilon"])
|
||||
|
||||
|
||||
def clip_gradients(policy, optimizer, loss):
|
||||
if policy.config["grad_norm_clipping"] is not None:
|
||||
grads_and_vars = _minimize_and_clip(
|
||||
optimizer,
|
||||
loss,
|
||||
var_list=policy.q_func_vars,
|
||||
clip_val=policy.config["grad_norm_clipping"])
|
||||
else:
|
||||
grads_and_vars = optimizer.compute_gradients(
|
||||
loss, var_list=policy.q_func_vars)
|
||||
grads_and_vars = [(g, v) for (g, v) in grads_and_vars if g is not None]
|
||||
return grads_and_vars
|
||||
|
||||
|
||||
def exploration_setting_inputs(policy):
|
||||
return {
|
||||
policy.stochastic: True,
|
||||
policy.eps: policy.cur_epsilon,
|
||||
}
|
||||
|
||||
|
||||
def build_q_stats(policy, batch_tensors):
|
||||
return dict({
|
||||
"cur_lr": tf.cast(policy.cur_lr, tf.float64),
|
||||
}, **policy.loss.stats)
|
||||
|
||||
|
||||
def setup_early_mixins(policy, obs_space, action_space, config):
|
||||
LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"])
|
||||
ExplorationStateMixin.__init__(policy, obs_space, action_space, config)
|
||||
|
||||
|
||||
def setup_late_mixins(policy, obs_space, action_space, config):
|
||||
TargetNetworkMixin.__init__(policy, obs_space, action_space, config)
|
||||
|
||||
|
||||
def _build_q_network(policy, obs, obs_space, action_space):
|
||||
config = policy.config
|
||||
qnet = QNetwork(
|
||||
ModelCatalog.get_model({
|
||||
"obs": obs,
|
||||
"is_training": policy._get_is_training_placeholder(),
|
||||
}, obs_space, action_space, action_space.n, config["model"]),
|
||||
action_space.n, config["dueling"], config["hiddens"], config["noisy"],
|
||||
config["num_atoms"], config["v_min"], config["v_max"],
|
||||
config["sigma0"], config["parameter_noise"])
|
||||
return qnet.value, qnet.logits, qnet.dist, qnet.model
|
||||
|
||||
|
||||
def _build_q_value_policy(policy, q_values):
|
||||
policy = QValuePolicy(q_values, policy.cur_observations,
|
||||
policy.num_actions, policy.stochastic, policy.eps,
|
||||
policy.config["soft_q"],
|
||||
policy.config["softmax_temp"])
|
||||
return policy.action, policy.action_prob
|
||||
|
||||
|
||||
def _build_q_loss(q_t_selected, q_logits_t_selected, q_tp1_best,
|
||||
q_dist_tp1_best, rewards, dones, importance_weights, config):
|
||||
return QLoss(q_t_selected, q_logits_t_selected, q_tp1_best,
|
||||
q_dist_tp1_best, importance_weights, rewards,
|
||||
tf.cast(dones, tf.float32), config["gamma"], config["n_step"],
|
||||
config["num_atoms"], config["v_min"], config["v_max"])
|
||||
|
||||
|
||||
def _adjust_nstep(n_step, gamma, obs, actions, rewards, new_obs, dones):
|
||||
@@ -706,3 +683,27 @@ def _scope_vars(scope, trainable_only=False):
|
||||
tf.GraphKeys.TRAINABLE_VARIABLES
|
||||
if trainable_only else tf.GraphKeys.VARIABLES,
|
||||
scope=scope if isinstance(scope, str) else scope.name)
|
||||
|
||||
|
||||
DQNTFPolicy = build_tf_policy(
|
||||
name="DQNTFPolicy",
|
||||
get_default_config=lambda: ray.rllib.agents.dqn.dqn.DEFAULT_CONFIG,
|
||||
make_action_sampler=build_q_networks,
|
||||
loss_fn=build_q_losses,
|
||||
stats_fn=build_q_stats,
|
||||
postprocess_fn=postprocess_trajectory,
|
||||
optimizer_fn=adam_optimizer,
|
||||
gradients_fn=clip_gradients,
|
||||
extra_action_feed_fn=exploration_setting_inputs,
|
||||
extra_action_fetches_fn=lambda policy: {"q_values": policy.q_values},
|
||||
extra_learn_fetches_fn=lambda policy: {"td_error": policy.loss.td_error},
|
||||
update_ops_fn=lambda policy: policy.q_batchnorm_update_ops,
|
||||
before_init=setup_early_mixins,
|
||||
after_init=setup_late_mixins,
|
||||
obs_include_prev_action_reward=False,
|
||||
mixins=[
|
||||
ExplorationStateMixin,
|
||||
TargetNetworkMixin,
|
||||
ComputeTDErrorMixin,
|
||||
LearningRateSchedule,
|
||||
])
|
||||
|
||||
@@ -365,12 +365,15 @@ class ValueNetworkMixin(object):
|
||||
tf.get_variable_scope().name)
|
||||
|
||||
def value(self, ob, *args):
|
||||
feed_dict = {self._obs_input: [ob], self.model.seq_lens: [1]}
|
||||
feed_dict = {
|
||||
self.get_placeholder(SampleBatch.CUR_OBS): [ob],
|
||||
self.model.seq_lens: [1]
|
||||
}
|
||||
assert len(args) == len(self.model.state_in), \
|
||||
(args, self.model.state_in)
|
||||
for k, v in zip(self.model.state_in, args):
|
||||
feed_dict[k] = v
|
||||
vf = self._sess.run(self.value_function, feed_dict)
|
||||
vf = self.get_session().run(self.value_function, feed_dict)
|
||||
return vf[0]
|
||||
|
||||
|
||||
|
||||
@@ -216,7 +216,7 @@ class KLCoeffMixin(object):
|
||||
self.kl_coeff_val *= 1.5
|
||||
elif sampled_kl < 0.5 * self.kl_target:
|
||||
self.kl_coeff_val *= 0.5
|
||||
self.kl_coeff.load(self.kl_coeff_val, session=self._sess)
|
||||
self.kl_coeff.load(self.kl_coeff_val, session=self.get_session())
|
||||
return self.kl_coeff_val
|
||||
|
||||
|
||||
@@ -240,28 +240,26 @@ class ValueNetworkMixin(object):
|
||||
"a custom LSTM model that overrides the "
|
||||
"value_function() method.")
|
||||
with tf.variable_scope("value_function"):
|
||||
self.value_function = ModelCatalog.get_model({
|
||||
"obs": self._obs_input,
|
||||
"prev_actions": self._prev_action_input,
|
||||
"prev_rewards": self._prev_reward_input,
|
||||
"is_training": self._get_is_training_placeholder(),
|
||||
}, obs_space, action_space, 1, vf_config).outputs
|
||||
self.value_function = ModelCatalog.get_model(
|
||||
self.get_obs_input_dict(), obs_space, action_space, 1,
|
||||
vf_config).outputs
|
||||
self.value_function = tf.reshape(self.value_function, [-1])
|
||||
else:
|
||||
self.value_function = tf.zeros(shape=tf.shape(self._obs_input)[:1])
|
||||
self.value_function = tf.zeros(
|
||||
shape=tf.shape(self.get_placeholder(SampleBatch.CUR_OBS))[:1])
|
||||
|
||||
def _value(self, ob, prev_action, prev_reward, *args):
|
||||
feed_dict = {
|
||||
self._obs_input: [ob],
|
||||
self._prev_action_input: [prev_action],
|
||||
self._prev_reward_input: [prev_reward],
|
||||
self.get_placeholder(SampleBatch.CUR_OBS): [ob],
|
||||
self.get_placeholder(SampleBatch.PREV_ACTIONS): [prev_action],
|
||||
self.get_placeholder(SampleBatch.PREV_REWARDS): [prev_reward],
|
||||
self.model.seq_lens: [1]
|
||||
}
|
||||
assert len(args) == len(self.model.state_in), \
|
||||
(args, self.model.state_in)
|
||||
for k, v in zip(self.model.state_in, args):
|
||||
feed_dict[k] = v
|
||||
vf = self._sess.run(self.value_function, feed_dict)
|
||||
vf = self.get_session().run(self.value_function, feed_dict)
|
||||
return vf[0]
|
||||
|
||||
|
||||
|
||||
@@ -1,146 +0,0 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from ray.rllib.policy.dynamic_tf_policy import DynamicTFPolicy
|
||||
from ray.rllib.policy.policy import Policy
|
||||
from ray.rllib.policy.tf_policy import TFPolicy
|
||||
from ray.rllib.utils.annotations import override, DeveloperAPI
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
def build_tf_policy(name,
|
||||
loss_fn,
|
||||
get_default_config=None,
|
||||
stats_fn=None,
|
||||
grad_stats_fn=None,
|
||||
extra_action_fetches_fn=None,
|
||||
postprocess_fn=None,
|
||||
optimizer_fn=None,
|
||||
gradients_fn=None,
|
||||
before_init=None,
|
||||
before_loss_init=None,
|
||||
after_init=None,
|
||||
make_action_sampler=None,
|
||||
mixins=None,
|
||||
get_batch_divisibility_req=None):
|
||||
"""Helper function for creating a dynamic tf policy at runtime.
|
||||
|
||||
Arguments:
|
||||
name (str): name of the policy (e.g., "PPOTFPolicy")
|
||||
loss_fn (func): function that returns a loss tensor the policy,
|
||||
and dict of experience tensor placeholders
|
||||
get_default_config (func): optional function that returns the default
|
||||
config to merge with any overrides
|
||||
stats_fn (func): optional function that returns a dict of
|
||||
TF fetches given the policy and batch input tensors
|
||||
grad_stats_fn (func): optional function that returns a dict of
|
||||
TF fetches given the policy and loss gradient tensors
|
||||
extra_action_fetches_fn (func): optional function that returns
|
||||
a dict of TF fetches given the policy object
|
||||
postprocess_fn (func): optional experience postprocessing function
|
||||
that takes the same args as Policy.postprocess_trajectory()
|
||||
optimizer_fn (func): optional function that returns a tf.Optimizer
|
||||
given the policy and config
|
||||
gradients_fn (func): optional function that returns a list of gradients
|
||||
given a tf optimizer and loss tensor. If not specified, this
|
||||
defaults to optimizer.compute_gradients(loss)
|
||||
before_init (func): optional function to run at the beginning of
|
||||
policy init that takes the same arguments as the policy constructor
|
||||
before_loss_init (func): optional function to run prior to loss
|
||||
init that takes the same arguments as the policy constructor
|
||||
after_init (func): optional function to run at the end of policy init
|
||||
that takes the same arguments as the policy constructor
|
||||
make_action_sampler (func): optional function that returns a
|
||||
tuple of action and action prob tensors. The function takes
|
||||
(policy, input_dict, obs_space, action_space, config) as its
|
||||
arguments
|
||||
mixins (list): list of any class mixins for the returned policy class.
|
||||
These mixins will be applied in order and will have higher
|
||||
precedence than the DynamicTFPolicy class
|
||||
get_batch_divisibility_req (func): optional function that returns
|
||||
the divisibility requirement for sample batches
|
||||
|
||||
Returns:
|
||||
a DynamicTFPolicy instance that uses the specified args
|
||||
"""
|
||||
|
||||
if not name.endswith("TFPolicy"):
|
||||
raise ValueError("Name should match *TFPolicy", name)
|
||||
|
||||
base = DynamicTFPolicy
|
||||
while mixins:
|
||||
|
||||
class new_base(mixins.pop(), base):
|
||||
pass
|
||||
|
||||
base = new_base
|
||||
|
||||
class policy_cls(base):
|
||||
def __init__(self,
|
||||
obs_space,
|
||||
action_space,
|
||||
config,
|
||||
existing_inputs=None):
|
||||
if get_default_config:
|
||||
config = dict(get_default_config(), **config)
|
||||
|
||||
if before_init:
|
||||
before_init(self, obs_space, action_space, config)
|
||||
|
||||
def before_loss_init_wrapper(policy, obs_space, action_space,
|
||||
config):
|
||||
if before_loss_init:
|
||||
before_loss_init(policy, obs_space, action_space, config)
|
||||
if extra_action_fetches_fn is None:
|
||||
self._extra_action_fetches = {}
|
||||
else:
|
||||
self._extra_action_fetches = extra_action_fetches_fn(self)
|
||||
|
||||
DynamicTFPolicy.__init__(
|
||||
self,
|
||||
obs_space,
|
||||
action_space,
|
||||
config,
|
||||
loss_fn,
|
||||
stats_fn=stats_fn,
|
||||
grad_stats_fn=grad_stats_fn,
|
||||
before_loss_init=before_loss_init_wrapper,
|
||||
existing_inputs=existing_inputs)
|
||||
|
||||
if after_init:
|
||||
after_init(self, obs_space, action_space, config)
|
||||
|
||||
@override(Policy)
|
||||
def postprocess_trajectory(self,
|
||||
sample_batch,
|
||||
other_agent_batches=None,
|
||||
episode=None):
|
||||
if not postprocess_fn:
|
||||
return sample_batch
|
||||
return postprocess_fn(self, sample_batch, other_agent_batches,
|
||||
episode)
|
||||
|
||||
@override(TFPolicy)
|
||||
def optimizer(self):
|
||||
if optimizer_fn:
|
||||
return optimizer_fn(self, self.config)
|
||||
else:
|
||||
return TFPolicy.optimizer(self)
|
||||
|
||||
@override(TFPolicy)
|
||||
def gradients(self, optimizer, loss):
|
||||
if gradients_fn:
|
||||
return gradients_fn(self, optimizer, loss)
|
||||
else:
|
||||
return TFPolicy.gradients(self, optimizer, loss)
|
||||
|
||||
@override(TFPolicy)
|
||||
def extra_compute_action_fetches(self):
|
||||
return dict(
|
||||
TFPolicy.extra_compute_action_fetches(self),
|
||||
**self._extra_action_fetches)
|
||||
|
||||
policy_cls.__name__ = name
|
||||
policy_cls.__qualname__ = name
|
||||
return policy_cls
|
||||
@@ -37,11 +37,13 @@ class DynamicTFPolicy(TFPolicy):
|
||||
config,
|
||||
loss_fn,
|
||||
stats_fn=None,
|
||||
update_ops_fn=None,
|
||||
grad_stats_fn=None,
|
||||
before_loss_init=None,
|
||||
make_action_sampler=None,
|
||||
existing_inputs=None,
|
||||
get_batch_divisibility_req=None):
|
||||
get_batch_divisibility_req=None,
|
||||
obs_include_prev_action_reward=True):
|
||||
"""Initialize a dynamic TF policy.
|
||||
|
||||
Arguments:
|
||||
@@ -54,6 +56,8 @@ class DynamicTFPolicy(TFPolicy):
|
||||
TF fetches given the policy and batch input tensors
|
||||
grad_stats_fn (func): optional function that returns a dict of
|
||||
TF fetches given the policy and loss gradient tensors
|
||||
update_ops_fn (func): optional function that returns a list
|
||||
overriding the update ops to run when applying gradients
|
||||
before_loss_init (func): optional function to run prior to loss
|
||||
init that takes the same arguments as __init__
|
||||
make_action_sampler (func): optional function that returns a
|
||||
@@ -65,30 +69,39 @@ class DynamicTFPolicy(TFPolicy):
|
||||
defining new ones
|
||||
get_batch_divisibility_req (func): optional function that returns
|
||||
the divisibility requirement for sample batches
|
||||
obs_include_prev_action_reward (bool): whether to include the
|
||||
previous action and reward in the model input
|
||||
"""
|
||||
self.config = config
|
||||
self._loss_fn = loss_fn
|
||||
self._stats_fn = stats_fn
|
||||
self._grad_stats_fn = grad_stats_fn
|
||||
self._update_ops_fn = update_ops_fn
|
||||
self._obs_include_prev_action_reward = obs_include_prev_action_reward
|
||||
|
||||
# Setup standard placeholders
|
||||
prev_actions = None
|
||||
prev_rewards = None
|
||||
if existing_inputs is not None:
|
||||
obs = existing_inputs[SampleBatch.CUR_OBS]
|
||||
prev_actions = existing_inputs[SampleBatch.PREV_ACTIONS]
|
||||
prev_rewards = existing_inputs[SampleBatch.PREV_REWARDS]
|
||||
if self._obs_include_prev_action_reward:
|
||||
prev_actions = existing_inputs[SampleBatch.PREV_ACTIONS]
|
||||
prev_rewards = existing_inputs[SampleBatch.PREV_REWARDS]
|
||||
else:
|
||||
obs = tf.placeholder(
|
||||
tf.float32,
|
||||
shape=[None] + list(obs_space.shape),
|
||||
name="observation")
|
||||
prev_actions = ModelCatalog.get_action_placeholder(action_space)
|
||||
prev_rewards = tf.placeholder(
|
||||
tf.float32, [None], name="prev_reward")
|
||||
if self._obs_include_prev_action_reward:
|
||||
prev_actions = ModelCatalog.get_action_placeholder(
|
||||
action_space)
|
||||
prev_rewards = tf.placeholder(
|
||||
tf.float32, [None], name="prev_reward")
|
||||
|
||||
input_dict = {
|
||||
"obs": obs,
|
||||
"prev_actions": prev_actions,
|
||||
"prev_rewards": prev_rewards,
|
||||
self.input_dict = {
|
||||
SampleBatch.CUR_OBS: obs,
|
||||
SampleBatch.PREV_ACTIONS: prev_actions,
|
||||
SampleBatch.PREV_REWARDS: prev_rewards,
|
||||
"is_training": self._get_is_training_placeholder(),
|
||||
}
|
||||
|
||||
@@ -100,7 +113,7 @@ class DynamicTFPolicy(TFPolicy):
|
||||
self.dist_class = None
|
||||
self.action_dist = None
|
||||
action_sampler, action_prob = make_action_sampler(
|
||||
self, input_dict, obs_space, action_space, config)
|
||||
self, self.input_dict, obs_space, action_space, config)
|
||||
else:
|
||||
self.dist_class, logit_dim = ModelCatalog.get_action_dist(
|
||||
action_space, self.config["model"])
|
||||
@@ -117,7 +130,7 @@ class DynamicTFPolicy(TFPolicy):
|
||||
existing_state_in = []
|
||||
existing_seq_lens = None
|
||||
self.model = ModelCatalog.get_model(
|
||||
input_dict,
|
||||
self.input_dict,
|
||||
obs_space,
|
||||
action_space,
|
||||
logit_dim,
|
||||
@@ -158,6 +171,13 @@ class DynamicTFPolicy(TFPolicy):
|
||||
if not existing_inputs:
|
||||
self._initialize_loss()
|
||||
|
||||
def get_obs_input_dict(self):
|
||||
"""Returns the obs input dict used to build policy models.
|
||||
|
||||
This dict includes the obs, prev actions, prev rewards, etc. tensors.
|
||||
"""
|
||||
return self.input_dict
|
||||
|
||||
@override(TFPolicy)
|
||||
def copy(self, existing_inputs):
|
||||
"""Creates a copy of self using existing input placeholders."""
|
||||
@@ -190,10 +210,8 @@ class DynamicTFPolicy(TFPolicy):
|
||||
self.action_space,
|
||||
self.config,
|
||||
existing_inputs=input_dict)
|
||||
loss = instance._loss_fn(instance, input_dict)
|
||||
if instance._stats_fn:
|
||||
instance._stats_fetches.update(
|
||||
instance._stats_fn(instance, input_dict))
|
||||
|
||||
loss = instance._do_loss_init(input_dict)
|
||||
TFPolicy._initialize_loss(
|
||||
instance, loss, [(k, existing_inputs[i])
|
||||
for i, (k, _) in enumerate(self._loss_inputs)])
|
||||
@@ -216,14 +234,18 @@ class DynamicTFPolicy(TFPolicy):
|
||||
return np.zeros(shape, dtype=tensor.dtype.as_numpy_dtype)
|
||||
|
||||
dummy_batch = {
|
||||
SampleBatch.PREV_ACTIONS: fake_array(self._prev_action_input),
|
||||
SampleBatch.PREV_REWARDS: fake_array(self._prev_reward_input),
|
||||
SampleBatch.CUR_OBS: fake_array(self._obs_input),
|
||||
SampleBatch.NEXT_OBS: fake_array(self._obs_input),
|
||||
SampleBatch.ACTIONS: fake_array(self._prev_action_input),
|
||||
SampleBatch.REWARDS: np.array([0], dtype=np.float32),
|
||||
SampleBatch.DONES: np.array([False], dtype=np.bool),
|
||||
SampleBatch.ACTIONS: fake_array(
|
||||
ModelCatalog.get_action_placeholder(self.action_space)),
|
||||
SampleBatch.REWARDS: np.array([0], dtype=np.float32),
|
||||
}
|
||||
if self._obs_include_prev_action_reward:
|
||||
dummy_batch.update({
|
||||
SampleBatch.PREV_ACTIONS: fake_array(self._prev_action_input),
|
||||
SampleBatch.PREV_REWARDS: fake_array(self._prev_reward_input),
|
||||
})
|
||||
state_init = self.get_initial_state()
|
||||
for i, h in enumerate(state_init):
|
||||
dummy_batch["state_in_{}".format(i)] = np.expand_dims(h, 0)
|
||||
@@ -238,16 +260,24 @@ class DynamicTFPolicy(TFPolicy):
|
||||
postprocessed_batch = self.postprocess_trajectory(
|
||||
SampleBatch(dummy_batch))
|
||||
|
||||
batch_tensors = UsageTrackingDict({
|
||||
SampleBatch.PREV_ACTIONS: self._prev_action_input,
|
||||
SampleBatch.PREV_REWARDS: self._prev_reward_input,
|
||||
SampleBatch.CUR_OBS: self._obs_input,
|
||||
})
|
||||
loss_inputs = [
|
||||
(SampleBatch.PREV_ACTIONS, self._prev_action_input),
|
||||
(SampleBatch.PREV_REWARDS, self._prev_reward_input),
|
||||
(SampleBatch.CUR_OBS, self._obs_input),
|
||||
]
|
||||
if self._obs_include_prev_action_reward:
|
||||
batch_tensors = UsageTrackingDict({
|
||||
SampleBatch.PREV_ACTIONS: self._prev_action_input,
|
||||
SampleBatch.PREV_REWARDS: self._prev_reward_input,
|
||||
SampleBatch.CUR_OBS: self._obs_input,
|
||||
})
|
||||
loss_inputs = [
|
||||
(SampleBatch.PREV_ACTIONS, self._prev_action_input),
|
||||
(SampleBatch.PREV_REWARDS, self._prev_reward_input),
|
||||
(SampleBatch.CUR_OBS, self._obs_input),
|
||||
]
|
||||
else:
|
||||
batch_tensors = UsageTrackingDict({
|
||||
SampleBatch.CUR_OBS: self._obs_input,
|
||||
})
|
||||
loss_inputs = [
|
||||
(SampleBatch.CUR_OBS, self._obs_input),
|
||||
]
|
||||
|
||||
for k, v in postprocessed_batch.items():
|
||||
if k in batch_tensors:
|
||||
@@ -264,12 +294,18 @@ class DynamicTFPolicy(TFPolicy):
|
||||
"Initializing loss function with dummy input:\n\n{}\n".format(
|
||||
summarize(batch_tensors)))
|
||||
|
||||
loss = self._loss_fn(self, batch_tensors)
|
||||
if self._stats_fn:
|
||||
self._stats_fetches.update(self._stats_fn(self, batch_tensors))
|
||||
loss = self._do_loss_init(batch_tensors)
|
||||
for k in sorted(batch_tensors.accessed_keys):
|
||||
loss_inputs.append((k, batch_tensors[k]))
|
||||
TFPolicy._initialize_loss(self, loss, loss_inputs)
|
||||
if self._grad_stats_fn:
|
||||
self._stats_fetches.update(self._grad_stats_fn(self, self._grads))
|
||||
self._sess.run(tf.global_variables_initializer())
|
||||
|
||||
def _do_loss_init(self, batch_tensors):
|
||||
loss = self._loss_fn(self, batch_tensors)
|
||||
if self._stats_fn:
|
||||
self._stats_fetches.update(self._stats_fn(self, batch_tensors))
|
||||
if self._update_ops_fn:
|
||||
self._update_ops = self._update_ops_fn(self)
|
||||
return loss
|
||||
|
||||
@@ -139,6 +139,39 @@ class TFPolicy(Policy):
|
||||
raise ValueError(
|
||||
"seq_lens tensor must be given if state inputs are defined")
|
||||
|
||||
def get_placeholder(self, name):
|
||||
"""Returns the given action or loss input placeholder by name.
|
||||
|
||||
If the loss has not been initialized and a loss input placeholder is
|
||||
requested, an error is raised.
|
||||
"""
|
||||
|
||||
obs_inputs = {
|
||||
SampleBatch.CUR_OBS: self._obs_input,
|
||||
SampleBatch.PREV_ACTIONS: self._prev_action_input,
|
||||
SampleBatch.PREV_REWARDS: self._prev_reward_input,
|
||||
}
|
||||
if name in obs_inputs:
|
||||
return obs_inputs[name]
|
||||
|
||||
if not self.loss_initialized():
|
||||
raise RuntimeError(
|
||||
"You cannot call policy.get_placeholder() for non-obs inputs "
|
||||
"before the loss has been initialized. To avoid this, use "
|
||||
"policy.loss_initialized() to check whether this is the "
|
||||
"case, or move the call to later (e.g., from stats_fn to "
|
||||
"grad_stats_fn).")
|
||||
|
||||
return self._loss_input_dict[name]
|
||||
|
||||
def get_session(self):
|
||||
"""Returns a reference to the TF session for this policy."""
|
||||
return self._sess
|
||||
|
||||
def loss_initialized(self):
|
||||
"""Returns whether the loss function has been initialized."""
|
||||
return self._loss is not None
|
||||
|
||||
def _initialize_loss(self, loss, loss_inputs):
|
||||
self._loss_inputs = loss_inputs
|
||||
self._loss_input_dict = dict(self._loss_inputs)
|
||||
@@ -172,7 +205,7 @@ class TFPolicy(Policy):
|
||||
self._grads_and_vars)
|
||||
|
||||
if log_once("loss_used"):
|
||||
logger.debug(
|
||||
logger.info(
|
||||
"These tensors were used in the loss_fn:\n\n{}\n".format(
|
||||
summarize(self._loss_input_dict)))
|
||||
|
||||
@@ -195,21 +228,21 @@ class TFPolicy(Policy):
|
||||
|
||||
@override(Policy)
|
||||
def compute_gradients(self, postprocessed_batch):
|
||||
assert self._loss is not None, "Loss not initialized"
|
||||
assert self.loss_initialized()
|
||||
builder = TFRunBuilder(self._sess, "compute_gradients")
|
||||
fetches = self._build_compute_gradients(builder, postprocessed_batch)
|
||||
return builder.get(fetches)
|
||||
|
||||
@override(Policy)
|
||||
def apply_gradients(self, gradients):
|
||||
assert self._loss is not None, "Loss not initialized"
|
||||
assert self.loss_initialized()
|
||||
builder = TFRunBuilder(self._sess, "apply_gradients")
|
||||
fetches = self._build_apply_gradients(builder, gradients)
|
||||
builder.get(fetches)
|
||||
|
||||
@override(Policy)
|
||||
def learn_on_batch(self, postprocessed_batch):
|
||||
assert self._loss is not None, "Loss not initialized"
|
||||
assert self.loss_initialized()
|
||||
builder = TFRunBuilder(self._sess, "learn_on_batch")
|
||||
fetches = self._build_learn_on_batch(builder, postprocessed_batch)
|
||||
return builder.get(fetches)
|
||||
|
||||
@@ -3,7 +3,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from ray.rllib.policy.dynamic_tf_policy import DynamicTFPolicy
|
||||
from ray.rllib.policy.policy import Policy
|
||||
from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY
|
||||
from ray.rllib.policy.tf_policy import TFPolicy
|
||||
from ray.rllib.utils.annotations import override, DeveloperAPI
|
||||
|
||||
@@ -12,39 +12,60 @@ from ray.rllib.utils.annotations import override, DeveloperAPI
|
||||
def build_tf_policy(name,
|
||||
loss_fn,
|
||||
get_default_config=None,
|
||||
stats_fn=None,
|
||||
grad_stats_fn=None,
|
||||
extra_action_fetches_fn=None,
|
||||
postprocess_fn=None,
|
||||
stats_fn=None,
|
||||
update_ops_fn=None,
|
||||
optimizer_fn=None,
|
||||
gradients_fn=None,
|
||||
grad_stats_fn=None,
|
||||
extra_action_fetches_fn=None,
|
||||
extra_action_feed_fn=None,
|
||||
extra_learn_fetches_fn=None,
|
||||
extra_learn_feed_fn=None,
|
||||
before_init=None,
|
||||
before_loss_init=None,
|
||||
after_init=None,
|
||||
make_action_sampler=None,
|
||||
mixins=None,
|
||||
get_batch_divisibility_req=None):
|
||||
get_batch_divisibility_req=None,
|
||||
obs_include_prev_action_reward=True):
|
||||
"""Helper function for creating a dynamic tf policy at runtime.
|
||||
|
||||
Functions will be run in this order to initialize the policy:
|
||||
1. Placeholder setup: postprocess_fn
|
||||
2. Loss init: loss_fn, stats_fn, update_ops_fn
|
||||
3. Optimizer init: optimizer_fn, gradients_fn, grad_stats_fn
|
||||
|
||||
This means that you can e.g., depend on any policy attributes created in
|
||||
the running of `loss_fn` in later functions such as `stats_fn`.
|
||||
|
||||
Arguments:
|
||||
name (str): name of the policy (e.g., "PPOTFPolicy")
|
||||
loss_fn (func): function that returns a loss tensor the policy,
|
||||
and dict of experience tensor placeholders
|
||||
and dict of experience tensor placeholdes
|
||||
get_default_config (func): optional function that returns the default
|
||||
config to merge with any overrides
|
||||
stats_fn (func): optional function that returns a dict of
|
||||
TF fetches given the policy and batch input tensors
|
||||
grad_stats_fn (func): optional function that returns a dict of
|
||||
TF fetches given the policy and loss gradient tensors
|
||||
extra_action_fetches_fn (func): optional function that returns
|
||||
a dict of TF fetches given the policy object
|
||||
postprocess_fn (func): optional experience postprocessing function
|
||||
that takes the same args as Policy.postprocess_trajectory()
|
||||
stats_fn (func): optional function that returns a dict of
|
||||
TF fetches given the policy and batch input tensors
|
||||
update_ops_fn (func): optional function that returns a list overriding
|
||||
the update ops to run when applying gradients
|
||||
optimizer_fn (func): optional function that returns a tf.Optimizer
|
||||
given the policy and config
|
||||
gradients_fn (func): optional function that returns a list of gradients
|
||||
given a tf optimizer and loss tensor. If not specified, this
|
||||
defaults to optimizer.compute_gradients(loss)
|
||||
grad_stats_fn (func): optional function that returns a dict of
|
||||
TF fetches given the policy and loss gradient tensors
|
||||
extra_action_fetches_fn (func): optional function that returns
|
||||
a dict of TF fetches given the policy object
|
||||
extra_action_feed_fn (func): optional function that returns a feed dict
|
||||
to also feed to TF when computing actions
|
||||
extra_learn_fetches_fn (func): optional function that returns a dict of
|
||||
extra values to fetch and return when learning on a batch
|
||||
extra_learn_feed_fn (func): optional function that returns a feed dict
|
||||
to also feed to TF when learning on a batch
|
||||
before_init (func): optional function to run at the beginning of
|
||||
policy init that takes the same arguments as the policy constructor
|
||||
before_loss_init (func): optional function to run prior to loss
|
||||
@@ -60,6 +81,8 @@ def build_tf_policy(name,
|
||||
precedence than the DynamicTFPolicy class
|
||||
get_batch_divisibility_req (func): optional function that returns
|
||||
the divisibility requirement for sample batches
|
||||
obs_include_prev_action_reward (bool): whether to include the
|
||||
previous action and reward in the model input
|
||||
|
||||
Returns:
|
||||
a DynamicTFPolicy instance that uses the specified args
|
||||
@@ -105,8 +128,11 @@ def build_tf_policy(name,
|
||||
loss_fn,
|
||||
stats_fn=stats_fn,
|
||||
grad_stats_fn=grad_stats_fn,
|
||||
update_ops_fn=update_ops_fn,
|
||||
before_loss_init=before_loss_init_wrapper,
|
||||
existing_inputs=existing_inputs)
|
||||
make_action_sampler=make_action_sampler,
|
||||
existing_inputs=existing_inputs,
|
||||
obs_include_prev_action_reward=obs_include_prev_action_reward)
|
||||
|
||||
if after_init:
|
||||
after_init(self, obs_space, action_space, config)
|
||||
@@ -141,6 +167,30 @@ def build_tf_policy(name,
|
||||
TFPolicy.extra_compute_action_fetches(self),
|
||||
**self._extra_action_fetches)
|
||||
|
||||
@override(TFPolicy)
|
||||
def extra_compute_action_feed_dict(self):
|
||||
if extra_action_feed_fn:
|
||||
return extra_action_feed_fn(self)
|
||||
else:
|
||||
return TFPolicy.extra_compute_action_feed_dict(self)
|
||||
|
||||
@override(TFPolicy)
|
||||
def extra_compute_grad_fetches(self):
|
||||
if extra_learn_fetches_fn:
|
||||
# auto-add empty learner stats dict if needed
|
||||
return dict({
|
||||
LEARNER_STATS_KEY: {}
|
||||
}, **extra_learn_fetches_fn(self))
|
||||
else:
|
||||
return TFPolicy.extra_compute_grad_fetches(self)
|
||||
|
||||
@override(TFPolicy)
|
||||
def extra_compute_grad_feed_dict(self):
|
||||
if extra_learn_feed_fn:
|
||||
return extra_learn_feed_fn(self)
|
||||
else:
|
||||
return TFPolicy.extra_compute_grad_feed_dict(self)
|
||||
|
||||
policy_cls.__name__ = name
|
||||
policy_cls.__qualname__ = name
|
||||
return policy_cls
|
||||
|
||||
Reference in New Issue
Block a user