mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 10:49:16 +08:00
[rllib] Minor cleanups to TFPolicyGraph: add init args, constants for loss inputs (#4478)
This commit is contained in:
@@ -46,21 +46,21 @@ class A3CAgent(Agent):
|
||||
_policy_graph = A3CPolicyGraph
|
||||
|
||||
@override(Agent)
|
||||
def _init(self):
|
||||
if self.config["use_pytorch"]:
|
||||
def _init(self, config, env_creator):
|
||||
if config["use_pytorch"]:
|
||||
from ray.rllib.agents.a3c.a3c_torch_policy_graph import \
|
||||
A3CTorchPolicyGraph
|
||||
policy_cls = A3CTorchPolicyGraph
|
||||
else:
|
||||
policy_cls = self._policy_graph
|
||||
|
||||
if self.config["entropy_coeff"] < 0:
|
||||
if config["entropy_coeff"] < 0:
|
||||
raise DeprecationWarning("entropy_coeff must be >= 0")
|
||||
|
||||
self.local_evaluator = self.make_local_evaluator(
|
||||
self.env_creator, policy_cls)
|
||||
env_creator, policy_cls)
|
||||
self.remote_evaluators = self.make_remote_evaluators(
|
||||
self.env_creator, policy_cls, self.config["num_workers"])
|
||||
env_creator, policy_cls, config["num_workers"])
|
||||
self.optimizer = self._make_optimizer()
|
||||
|
||||
@override(Agent)
|
||||
|
||||
@@ -9,10 +9,12 @@ import gym
|
||||
|
||||
import ray
|
||||
from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
|
||||
from ray.rllib.evaluation.sample_batch import SampleBatch
|
||||
from ray.rllib.utils.error import UnsupportedSpaceException
|
||||
from ray.rllib.utils.explained_variance import explained_variance
|
||||
from ray.rllib.evaluation.policy_graph import PolicyGraph
|
||||
from ray.rllib.evaluation.postprocessing import compute_advantages
|
||||
from ray.rllib.evaluation.postprocessing import compute_advantages, \
|
||||
Postprocessing
|
||||
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph, \
|
||||
LearningRateSchedule
|
||||
from ray.rllib.models.catalog import ModelCatalog
|
||||
@@ -40,7 +42,36 @@ class A3CLoss(object):
|
||||
self.entropy * entropy_coeff)
|
||||
|
||||
|
||||
class A3CPolicyGraph(LearningRateSchedule, TFPolicyGraph):
|
||||
class A3CPostprocessing(object):
    """Mixin adding value-function predictions and advantages to rollouts.

    The host class is expected to provide ``self.vf``, ``self.model``,
    ``self.config`` and a ``_value`` method; all of them are read here but
    defined elsewhere.
    """

    @override(TFPolicyGraph)
    def extra_compute_action_fetches(self):
        """Return the base action fetches extended with the VF output.

        Returns:
            dict: a copy of ``TFPolicyGraph.extra_compute_action_fetches``
                with ``self.vf`` stored under ``SampleBatch.VF_PREDS``.
        """
        fetches = dict(TFPolicyGraph.extra_compute_action_fetches(self))
        fetches[SampleBatch.VF_PREDS] = self.vf
        return fetches

    @override(PolicyGraph)
    def postprocess_trajectory(self,
                               sample_batch,
                               other_agent_batches=None,
                               episode=None):
        """Run ``compute_advantages`` on a trajectory fragment.

        Arguments:
            sample_batch: batch of experiences for this policy.
            other_agent_batches: unused here.
            episode: unused here.

        Returns:
            Whatever ``compute_advantages`` returns for the batch, using
            ``gamma`` and ``lambda`` from ``self.config``.
        """
        if sample_batch[SampleBatch.DONES][-1]:
            # The fragment ends on a terminal step: nothing to bootstrap.
            bootstrap_value = 0.0
        else:
            # Gather the last emitted value of every recurrent state output
            # and evaluate the value function on the final transition.
            final_states = [
                [sample_batch["state_out_{}".format(idx)][-1]]
                for idx in range(len(self.model.state_in))
            ]
            bootstrap_value = self._value(
                sample_batch[SampleBatch.NEXT_OBS][-1],
                sample_batch[SampleBatch.ACTIONS][-1],
                sample_batch[SampleBatch.REWARDS][-1],
                *final_states)
        return compute_advantages(sample_batch, bootstrap_value,
                                  self.config["gamma"],
                                  self.config["lambda"])
|
||||
|
||||
|
||||
class A3CPolicyGraph(LearningRateSchedule, A3CPostprocessing, TFPolicyGraph):
|
||||
def __init__(self, observation_space, action_space, config):
|
||||
config = dict(ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG, **config)
|
||||
self.config = config
|
||||
@@ -83,12 +114,12 @@ class A3CPolicyGraph(LearningRateSchedule, TFPolicyGraph):
|
||||
|
||||
# Initialize TFPolicyGraph
|
||||
loss_in = [
|
||||
("obs", self.observations),
|
||||
("actions", actions),
|
||||
("prev_actions", self.prev_actions),
|
||||
("prev_rewards", self.prev_rewards),
|
||||
("advantages", advantages),
|
||||
("value_targets", self.v_target),
|
||||
(SampleBatch.CUR_OBS, self.observations),
|
||||
(SampleBatch.ACTIONS, actions),
|
||||
(SampleBatch.PREV_ACTIONS, self.prev_actions),
|
||||
(SampleBatch.PREV_REWARDS, self.prev_rewards),
|
||||
(Postprocessing.ADVANTAGES, advantages),
|
||||
(Postprocessing.VALUE_TARGETS, self.v_target),
|
||||
]
|
||||
LearningRateSchedule.__init__(self, self.config["lr"],
|
||||
self.config["lr_schedule"])
|
||||
@@ -128,24 +159,6 @@ class A3CPolicyGraph(LearningRateSchedule, TFPolicyGraph):
|
||||
def get_initial_state(self):
|
||||
return self.model.state_init
|
||||
|
||||
@override(PolicyGraph)
|
||||
def postprocess_trajectory(self,
|
||||
sample_batch,
|
||||
other_agent_batches=None,
|
||||
episode=None):
|
||||
completed = sample_batch["dones"][-1]
|
||||
if completed:
|
||||
last_r = 0.0
|
||||
else:
|
||||
next_state = []
|
||||
for i in range(len(self.model.state_in)):
|
||||
next_state.append([sample_batch["state_out_{}".format(i)][-1]])
|
||||
last_r = self._value(sample_batch["new_obs"][-1],
|
||||
sample_batch["actions"][-1],
|
||||
sample_batch["rewards"][-1], *next_state)
|
||||
return compute_advantages(sample_batch, last_r, self.config["gamma"],
|
||||
self.config["lambda"])
|
||||
|
||||
@override(TFPolicyGraph)
|
||||
def gradients(self, optimizer, loss):
|
||||
grads = tf.gradients(loss, self.var_list)
|
||||
@@ -157,12 +170,6 @@ class A3CPolicyGraph(LearningRateSchedule, TFPolicyGraph):
|
||||
def extra_compute_grad_fetches(self):
|
||||
return self.stats_fetches
|
||||
|
||||
@override(TFPolicyGraph)
|
||||
def extra_compute_action_fetches(self):
|
||||
return dict(
|
||||
TFPolicyGraph.extra_compute_action_fetches(self),
|
||||
**{"vf_preds": self.vf})
|
||||
|
||||
def _value(self, ob, prev_action, prev_reward, *args):
|
||||
feed_dict = {
|
||||
self.observations: [ob],
|
||||
|
||||
@@ -8,8 +8,10 @@ from torch import nn
|
||||
|
||||
import ray
|
||||
from ray.rllib.models.catalog import ModelCatalog
|
||||
from ray.rllib.evaluation.postprocessing import compute_advantages
|
||||
from ray.rllib.evaluation.postprocessing import compute_advantages, \
|
||||
Postprocessing
|
||||
from ray.rllib.evaluation.policy_graph import PolicyGraph
|
||||
from ray.rllib.evaluation.sample_batch import SampleBatch
|
||||
from ray.rllib.evaluation.torch_policy_graph import TorchPolicyGraph
|
||||
from ray.rllib.utils.annotations import override
|
||||
|
||||
@@ -37,7 +39,28 @@ class A3CLoss(nn.Module):
|
||||
return overall_err
|
||||
|
||||
|
||||
class A3CTorchPolicyGraph(TorchPolicyGraph):
|
||||
class A3CPostprocessing(object):
    """Mixin adding value-function predictions and advantages to rollouts.

    The host class is expected to provide ``self.config`` and a ``_value``
    method; both are read here but defined elsewhere.
    """

    @override(TorchPolicyGraph)
    def extra_action_out(self, model_out):
        """Expose the third model output as the value prediction.

        Arguments:
            model_out: indexable model output; element ``2`` must offer a
                ``numpy()`` method.

        Returns:
            dict: ``{SampleBatch.VF_PREDS: model_out[2].numpy()}``.
        """
        vf_preds = model_out[2].numpy()
        return {SampleBatch.VF_PREDS: vf_preds}

    @override(PolicyGraph)
    def postprocess_trajectory(self,
                               sample_batch,
                               other_agent_batches=None,
                               episode=None):
        """Run ``compute_advantages`` on a trajectory fragment.

        Arguments:
            sample_batch: batch of experiences for this policy.
            other_agent_batches: unused here.
            episode: unused here.

        Returns:
            Whatever ``compute_advantages`` returns for the batch, using
            ``gamma`` and ``lambda`` from ``self.config``.
        """
        terminal = sample_batch[SampleBatch.DONES][-1]
        # Bootstrap from the value of the last next-observation unless the
        # fragment ended on a terminal step.
        bootstrap_value = (
            0.0 if terminal else
            self._value(sample_batch[SampleBatch.NEXT_OBS][-1]))
        return compute_advantages(sample_batch, bootstrap_value,
                                  self.config["gamma"],
                                  self.config["lambda"])
|
||||
|
||||
|
||||
class A3CTorchPolicyGraph(A3CPostprocessing, TorchPolicyGraph):
|
||||
"""A simple, non-recurrent PyTorch policy example."""
|
||||
|
||||
def __init__(self, obs_space, action_space, config):
|
||||
@@ -55,29 +78,15 @@ class A3CTorchPolicyGraph(TorchPolicyGraph):
|
||||
action_space,
|
||||
self.model,
|
||||
loss,
|
||||
loss_inputs=["obs", "actions", "advantages", "value_targets"])
|
||||
|
||||
@override(TorchPolicyGraph)
|
||||
def extra_action_out(self, model_out):
|
||||
return {"vf_preds": model_out[2].numpy()}
|
||||
loss_inputs=[
|
||||
SampleBatch.CUR_OBS, SampleBatch.ACTIONS,
|
||||
Postprocessing.ADVANTAGES, Postprocessing.VALUE_TARGETS
|
||||
])
|
||||
|
||||
@override(TorchPolicyGraph)
|
||||
def optimizer(self):
|
||||
return torch.optim.Adam(self.model.parameters(), lr=self.config["lr"])
|
||||
|
||||
@override(PolicyGraph)
|
||||
def postprocess_trajectory(self,
|
||||
sample_batch,
|
||||
other_agent_batches=None,
|
||||
episode=None):
|
||||
completed = sample_batch["dones"][-1]
|
||||
if completed:
|
||||
last_r = 0.0
|
||||
else:
|
||||
last_r = self._value(sample_batch["new_obs"][-1])
|
||||
return compute_advantages(sample_batch, last_r, self.config["gamma"],
|
||||
self.config["lambda"])
|
||||
|
||||
def _value(self, obs):
|
||||
with self.lock:
|
||||
obs = torch.from_numpy(obs).float().unsqueeze(0)
|
||||
|
||||
@@ -374,7 +374,7 @@ class Agent(Trainable):
|
||||
|
||||
# TODO(ekl) setting the graph is unnecessary for PyTorch agents
|
||||
with tf.Graph().as_default():
|
||||
self._init()
|
||||
self._init(self.config, self.env_creator)
|
||||
|
||||
@override(Trainable)
|
||||
def _stop(self):
|
||||
@@ -398,7 +398,7 @@ class Agent(Trainable):
|
||||
self.__setstate__(extra_data)
|
||||
|
||||
@DeveloperAPI
|
||||
def _init(self):
|
||||
def _init(self, config, env_creator):
|
||||
"""Subclasses should override this for custom initialization."""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -164,32 +164,31 @@ class ARSAgent(Agent):
|
||||
_default_config = DEFAULT_CONFIG
|
||||
|
||||
@override(Agent)
|
||||
def _init(self):
|
||||
env = self.env_creator(self.config["env_config"])
|
||||
def _init(self, config, env_creator):
|
||||
env = env_creator(config["env_config"])
|
||||
from ray.rllib import models
|
||||
preprocessor = models.ModelCatalog.get_preprocessor(env)
|
||||
|
||||
self.sess = utils.make_session(single_threaded=False)
|
||||
self.policy = policies.GenericPolicy(
|
||||
self.sess, env.action_space, env.observation_space, preprocessor,
|
||||
self.config["observation_filter"], self.config["model"])
|
||||
self.optimizer = optimizers.SGD(self.policy,
|
||||
self.config["sgd_stepsize"])
|
||||
config["observation_filter"], config["model"])
|
||||
self.optimizer = optimizers.SGD(self.policy, config["sgd_stepsize"])
|
||||
|
||||
self.rollouts_used = self.config["rollouts_used"]
|
||||
self.num_rollouts = self.config["num_rollouts"]
|
||||
self.report_length = self.config["report_length"]
|
||||
self.rollouts_used = config["rollouts_used"]
|
||||
self.num_rollouts = config["num_rollouts"]
|
||||
self.report_length = config["report_length"]
|
||||
|
||||
# Create the shared noise table.
|
||||
logger.info("Creating shared noise table.")
|
||||
noise_id = create_shared_noise.remote(self.config["noise_size"])
|
||||
noise_id = create_shared_noise.remote(config["noise_size"])
|
||||
self.noise = SharedNoiseTable(ray.get(noise_id))
|
||||
|
||||
# Create the actors.
|
||||
logger.info("Creating actors.")
|
||||
self.workers = [
|
||||
Worker.remote(self.config, self.env_creator, noise_id)
|
||||
for _ in range(self.config["num_workers"])
|
||||
Worker.remote(config, env_creator, noise_id)
|
||||
for _ in range(config["num_workers"])
|
||||
]
|
||||
|
||||
self.episodes_so_far = 0
|
||||
|
||||
@@ -11,6 +11,7 @@ import ray
|
||||
import ray.experimental.tf_utils
|
||||
from ray.rllib.agents.dqn.dqn_policy_graph import (
|
||||
_huber_loss, _minimize_and_clip, _scope_vars, _postprocess_dqn)
|
||||
from ray.rllib.evaluation.sample_batch import SampleBatch
|
||||
from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
|
||||
from ray.rllib.models import ModelCatalog
|
||||
from ray.rllib.utils.annotations import override
|
||||
@@ -18,16 +19,116 @@ from ray.rllib.utils.error import UnsupportedSpaceException
|
||||
from ray.rllib.evaluation.policy_graph import PolicyGraph
|
||||
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
|
||||
|
||||
A_SCOPE = "a_func"
|
||||
P_SCOPE = "p_func"
|
||||
P_TARGET_SCOPE = "target_p_func"
|
||||
ACTION_SCOPE = "a_func"
|
||||
POLICY_SCOPE = "p_func"
|
||||
POLICY_TARGET_SCOPE = "target_p_func"
|
||||
Q_SCOPE = "q_func"
|
||||
Q_TARGET_SCOPE = "target_q_func"
|
||||
TWIN_Q_SCOPE = "twin_q_func"
|
||||
TWIN_Q_TARGET_SCOPE = "twin_target_q_func"
|
||||
|
||||
# Importance sampling weights for prioritized replay
|
||||
PRIO_WEIGHTS = "weights"
|
||||
|
||||
class PNetwork(object):
|
||||
|
||||
class ActorCriticLoss(object):
    """Builds the critic loss, actor loss and TD error tensors.

    Attributes:
        td_error: TD error tensor; with ``twin_q`` it is the sum of the
            errors of both Q estimates.
        critic_loss: ``critic_loss_coeff`` times the mean of the
            importance-weighted per-sample errors.
        actor_loss: negated, ``actor_loss_coeff``-scaled mean of ``q_tp0``,
            multiplied by a mask that is 1 only when the global step is a
            multiple of ``policy_delay``.
    """

    def __init__(self,
                 q_t,
                 q_tp1,
                 q_tp0,
                 importance_weights,
                 rewards,
                 done_mask,
                 twin_q_t,
                 twin_q_tp1,
                 actor_loss_coeff=0.1,
                 critic_loss_coeff=1.0,
                 gamma=0.99,
                 n_step=1,
                 use_huber=False,
                 huber_threshold=1.0,
                 twin_q=False,
                 policy_delay=1):
        """Assemble the loss tensors.

        Arguments:
            q_t: Q estimate being regressed; its last axis is squeezed.
            q_tp1: Q estimate used for the bootstrap target; its last axis
                is squeezed.
            q_tp0: tensor whose mean drives the actor loss.
            importance_weights: per-sample weights applied to the critic
                errors.
            rewards: reward tensor added to the discounted bootstrap.
            done_mask: tensor that zeroes the bootstrap where it equals 1.
            twin_q_t: second Q estimate, only read when ``twin_q`` is set.
            twin_q_tp1: second bootstrap estimate, only read when
                ``twin_q`` is set.
            actor_loss_coeff: scale of the actor loss.
            critic_loss_coeff: scale of the critic loss.
            gamma: discount factor, raised to ``n_step``.
            n_step: exponent applied to ``gamma``.
            use_huber: use ``_huber_loss`` instead of half squared error.
            huber_threshold: threshold forwarded to ``_huber_loss``.
            twin_q: enable the twin-Q terms.
            policy_delay: the actor loss is only non-zero every
                ``policy_delay`` global steps.
        """

        def _per_sample_loss(err):
            # Huber or half squared error, depending on `use_huber`.
            if use_huber:
                return _huber_loss(err, huber_threshold)
            return 0.5 * tf.square(err)

        q_t_selected = tf.squeeze(q_t, axis=len(q_t.shape) - 1)
        if twin_q:
            # NOTE(review): the squeeze axis is derived from q_t's rank, not
            # twin_q_t's -- presumably both have the same rank; confirm.
            twin_q_t_selected = tf.squeeze(twin_q_t, axis=len(q_t.shape) - 1)
            # Use the smaller of the two bootstrap estimates.
            q_tp1 = tf.minimum(q_tp1, twin_q_tp1)

        q_tp1_best = tf.squeeze(input=q_tp1, axis=len(q_tp1.shape) - 1)
        bootstrap = (1.0 - done_mask) * q_tp1_best

        # compute RHS of bellman equation
        bellman_target = rewards + gamma**n_step * bootstrap

        # compute the error (potentially clipped)
        if twin_q:
            primary_err = q_t_selected - tf.stop_gradient(bellman_target)
            twin_err = twin_q_t_selected - tf.stop_gradient(bellman_target)
            self.td_error = primary_err + twin_err
            errors = _per_sample_loss(primary_err) + _per_sample_loss(
                twin_err)
        else:
            self.td_error = (
                q_t_selected - tf.stop_gradient(bellman_target))
            errors = _per_sample_loss(self.td_error)

        self.critic_loss = critic_loss_coeff * tf.reduce_mean(
            importance_weights * errors)

        # for policy gradient, update policy net one time v.s.
        # update critic net `policy_delay` time(s)
        step = tf.train.get_or_create_global_step()
        on_actor_step = tf.equal(tf.mod(step, policy_delay), 0)
        policy_delay_mask = tf.to_float(on_actor_step)
        self.actor_loss = (-1.0 * actor_loss_coeff * policy_delay_mask *
                           tf.reduce_mean(q_tp0))
|
||||
|
||||
|
||||
class DDPGPostprocessing(object):
    """Implements n-step learning and param noise adjustments.

    Mixin for a policy graph that provides ``self.config``, ``self.sess``
    and the noise-related tensors and attributes read below.
    """

    @override(PolicyGraph)
    def postprocess_trajectory(self,
                               sample_batch,
                               other_agent_batches=None,
                               episode=None):
        """Optionally adapt the parameter-noise sigma, then postprocess.

        Arguments:
            sample_batch: batch of experiences for this policy.
            other_agent_batches: unused here.
            episode: unused here.

        Returns:
            The result of ``_postprocess_dqn(self, sample_batch)``.
        """
        if self.config["parameter_noise"]:
            # adjust the sigma of parameter space noise
            observations, noisy_actions = [
                list(column) for column in sample_batch.columns(
                    [SampleBatch.CUR_OBS, SampleBatch.ACTIONS])
            ]
            self.sess.run(self.remove_noise_op)
            # Recompute the actions for the same observations with the
            # noise removed, non-stochastically and with eps set to zero.
            noise_free_feed = {
                self.cur_observations: observations,
                self.stochastic: False,
                self.eps: .0
            }
            clean_actions = self.sess.run(
                self.output_actions, feed_dict=noise_free_feed)
            # Root-mean-square difference between clean and noisy actions.
            squared_diff = np.square(clean_actions - noisy_actions)
            rms_distance = np.sqrt(np.mean(squared_diff))
            self.pi_distance = rms_distance
            # Grow sigma while the perturbation is smaller than the target
            # exploration sigma, shrink it otherwise.
            if rms_distance < self.config["exploration_sigma"]:
                self.parameter_noise_sigma_val *= 1.01
            else:
                self.parameter_noise_sigma_val /= 1.01
            self.parameter_noise_sigma.load(
                self.parameter_noise_sigma_val, session=self.sess)

        return _postprocess_dqn(self, sample_batch)
|
||||
|
||||
|
||||
class PolicyNetwork(object):
|
||||
"""Maps an observations (i.e., state) to an action where each entry takes
|
||||
value from (0, 1) due to the sigmoid function."""
|
||||
|
||||
@@ -128,69 +229,7 @@ class QNetwork(object):
|
||||
self.model = model
|
||||
|
||||
|
||||
class ActorCriticLoss(object):
|
||||
def __init__(self,
|
||||
q_t,
|
||||
q_tp1,
|
||||
q_tp0,
|
||||
importance_weights,
|
||||
rewards,
|
||||
done_mask,
|
||||
twin_q_t,
|
||||
twin_q_tp1,
|
||||
actor_loss_coeff=0.1,
|
||||
critic_loss_coeff=1.0,
|
||||
gamma=0.99,
|
||||
n_step=1,
|
||||
use_huber=False,
|
||||
huber_threshold=1.0,
|
||||
twin_q=False,
|
||||
policy_delay=1):
|
||||
|
||||
q_t_selected = tf.squeeze(q_t, axis=len(q_t.shape) - 1)
|
||||
if twin_q:
|
||||
twin_q_t_selected = tf.squeeze(twin_q_t, axis=len(q_t.shape) - 1)
|
||||
q_tp1 = tf.minimum(q_tp1, twin_q_tp1)
|
||||
|
||||
q_tp1_best = tf.squeeze(input=q_tp1, axis=len(q_tp1.shape) - 1)
|
||||
q_tp1_best_masked = (1.0 - done_mask) * q_tp1_best
|
||||
|
||||
# compute RHS of bellman equation
|
||||
q_t_selected_target = rewards + gamma**n_step * q_tp1_best_masked
|
||||
|
||||
# compute the error (potentially clipped)
|
||||
if twin_q:
|
||||
td_error = q_t_selected - tf.stop_gradient(q_t_selected_target)
|
||||
twin_td_error = twin_q_t_selected - tf.stop_gradient(
|
||||
q_t_selected_target)
|
||||
self.td_error = td_error + twin_td_error
|
||||
if use_huber:
|
||||
errors = _huber_loss(td_error, huber_threshold) + _huber_loss(
|
||||
twin_td_error, huber_threshold)
|
||||
else:
|
||||
errors = 0.5 * tf.square(td_error) + 0.5 * tf.square(
|
||||
twin_td_error)
|
||||
else:
|
||||
self.td_error = (
|
||||
q_t_selected - tf.stop_gradient(q_t_selected_target))
|
||||
if use_huber:
|
||||
errors = _huber_loss(self.td_error, huber_threshold)
|
||||
else:
|
||||
errors = 0.5 * tf.square(self.td_error)
|
||||
|
||||
self.critic_loss = critic_loss_coeff * tf.reduce_mean(
|
||||
importance_weights * errors)
|
||||
|
||||
# for policy gradient, update policy net one time v.s.
|
||||
# update critic net `policy_delay` time(s)
|
||||
global_step = tf.train.get_or_create_global_step()
|
||||
policy_delay_mask = tf.to_float(
|
||||
tf.equal(tf.mod(global_step, policy_delay), 0))
|
||||
self.actor_loss = (-1.0 * actor_loss_coeff * policy_delay_mask *
|
||||
tf.reduce_mean(q_tp0))
|
||||
|
||||
|
||||
class DDPGPolicyGraph(TFPolicyGraph):
|
||||
class DDPGPolicyGraph(DDPGPostprocessing, TFPolicyGraph):
|
||||
def __init__(self, observation_space, action_space, config):
|
||||
config = dict(ray.rllib.agents.ddpg.ddpg.DEFAULT_CONFIG, **config)
|
||||
if not isinstance(action_space, Box):
|
||||
@@ -216,7 +255,7 @@ class DDPGPolicyGraph(TFPolicyGraph):
|
||||
name="cur_obs")
|
||||
|
||||
# Actor: P (policy) network
|
||||
with tf.variable_scope(P_SCOPE) as scope:
|
||||
with tf.variable_scope(POLICY_SCOPE) as scope:
|
||||
p_values, self.p_model = self._build_p_network(
|
||||
self.cur_observations, observation_space, action_space)
|
||||
self.p_func_vars = _scope_vars(scope.name)
|
||||
@@ -228,14 +267,14 @@ class DDPGPolicyGraph(TFPolicyGraph):
|
||||
])
|
||||
|
||||
# Action outputs
|
||||
with tf.variable_scope(A_SCOPE):
|
||||
with tf.variable_scope(ACTION_SCOPE):
|
||||
self.output_actions = self._build_action_network(
|
||||
p_values, self.stochastic, self.eps)
|
||||
|
||||
if self.config["smooth_target_policy"]:
|
||||
self.reset_noise_op = tf.no_op()
|
||||
else:
|
||||
with tf.variable_scope(A_SCOPE, reuse=True):
|
||||
with tf.variable_scope(ACTION_SCOPE, reuse=True):
|
||||
exploration_sample = tf.get_variable(name="ornstein_uhlenbeck")
|
||||
self.reset_noise_op = tf.assign(exploration_sample,
|
||||
self.dim_actions * [.0])
|
||||
@@ -255,7 +294,7 @@ class DDPGPolicyGraph(TFPolicyGraph):
|
||||
tf.float32, [None], name="weight")
|
||||
|
||||
# p network evaluation
|
||||
with tf.variable_scope(P_SCOPE, reuse=True) as scope:
|
||||
with tf.variable_scope(POLICY_SCOPE, reuse=True) as scope:
|
||||
prev_update_ops = set(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
|
||||
self.p_t, _ = self._build_p_network(self.obs_t, observation_space,
|
||||
action_space)
|
||||
@@ -264,13 +303,13 @@ class DDPGPolicyGraph(TFPolicyGraph):
|
||||
prev_update_ops)
|
||||
|
||||
# target p network evaluation
|
||||
with tf.variable_scope(P_TARGET_SCOPE) as scope:
|
||||
with tf.variable_scope(POLICY_TARGET_SCOPE) as scope:
|
||||
p_tp1, _ = self._build_p_network(self.obs_tp1, observation_space,
|
||||
action_space)
|
||||
target_p_func_vars = _scope_vars(scope.name)
|
||||
|
||||
# Action outputs
|
||||
with tf.variable_scope(A_SCOPE, reuse=True):
|
||||
with tf.variable_scope(ACTION_SCOPE, reuse=True):
|
||||
output_actions = self._build_action_network(
|
||||
self.p_t,
|
||||
stochastic=tf.constant(value=False, dtype=tf.bool),
|
||||
@@ -366,12 +405,12 @@ class DDPGPolicyGraph(TFPolicyGraph):
|
||||
|
||||
self.sess = tf.get_default_session()
|
||||
self.loss_inputs = [
|
||||
("obs", self.obs_t),
|
||||
("actions", self.act_t),
|
||||
("rewards", self.rew_t),
|
||||
("new_obs", self.obs_tp1),
|
||||
("dones", self.done_mask),
|
||||
("weights", self.importance_weights),
|
||||
(SampleBatch.CUR_OBS, self.obs_t),
|
||||
(SampleBatch.ACTIONS, self.act_t),
|
||||
(SampleBatch.REWARDS, self.rew_t),
|
||||
(SampleBatch.NEXT_OBS, self.obs_tp1),
|
||||
(SampleBatch.DONES, self.done_mask),
|
||||
(PRIO_WEIGHTS, self.importance_weights),
|
||||
]
|
||||
input_dict = dict(self.loss_inputs)
|
||||
|
||||
@@ -450,36 +489,6 @@ class DDPGPolicyGraph(TFPolicyGraph):
|
||||
LEARNER_STATS_KEY: self.stats,
|
||||
}
|
||||
|
||||
@override(PolicyGraph)
|
||||
def postprocess_trajectory(self,
|
||||
sample_batch,
|
||||
other_agent_batches=None,
|
||||
episode=None):
|
||||
if self.config["parameter_noise"]:
|
||||
# adjust the sigma of parameter space noise
|
||||
states, noisy_actions = [
|
||||
list(x) for x in sample_batch.columns(["obs", "actions"])
|
||||
]
|
||||
self.sess.run(self.remove_noise_op)
|
||||
clean_actions = self.sess.run(
|
||||
self.output_actions,
|
||||
feed_dict={
|
||||
self.cur_observations: states,
|
||||
self.stochastic: False,
|
||||
self.eps: .0
|
||||
})
|
||||
distance_in_action_space = np.sqrt(
|
||||
np.mean(np.square(clean_actions - noisy_actions)))
|
||||
self.pi_distance = distance_in_action_space
|
||||
if distance_in_action_space < self.config["exploration_sigma"]:
|
||||
self.parameter_noise_sigma_val *= 1.01
|
||||
else:
|
||||
self.parameter_noise_sigma_val /= 1.01
|
||||
self.parameter_noise_sigma.load(
|
||||
self.parameter_noise_sigma_val, session=self.sess)
|
||||
|
||||
return _postprocess_dqn(self, sample_batch)
|
||||
|
||||
@override(TFPolicyGraph)
|
||||
def get_weights(self):
|
||||
return self.variables.get_weights()
|
||||
@@ -508,7 +517,7 @@ class DDPGPolicyGraph(TFPolicyGraph):
|
||||
return q_net.value, q_net.model
|
||||
|
||||
def _build_p_network(self, obs, obs_space, action_space):
|
||||
policy_net = PNetwork(
|
||||
policy_net = PolicyNetwork(
|
||||
ModelCatalog.get_model({
|
||||
"obs": obs,
|
||||
"is_training": self._get_is_training_placeholder(),
|
||||
|
||||
@@ -144,18 +144,18 @@ class DQNAgent(Agent):
|
||||
_optimizer_shared_configs = OPTIMIZER_SHARED_CONFIGS
|
||||
|
||||
@override(Agent)
|
||||
def _init(self):
|
||||
def _init(self, config, env_creator):
|
||||
self._validate_config()
|
||||
|
||||
# Update effective batch size to include n-step
|
||||
adjusted_batch_size = max(self.config["sample_batch_size"],
|
||||
self.config.get("n_step", 1))
|
||||
self.config["sample_batch_size"] = adjusted_batch_size
|
||||
adjusted_batch_size = max(config["sample_batch_size"],
|
||||
config.get("n_step", 1))
|
||||
config["sample_batch_size"] = adjusted_batch_size
|
||||
|
||||
self.exploration0 = self._make_exploration_schedule(-1)
|
||||
self.explorations = [
|
||||
self._make_exploration_schedule(i)
|
||||
for i in range(self.config["num_workers"])
|
||||
for i in range(config["num_workers"])
|
||||
]
|
||||
|
||||
for k in self._optimizer_shared_configs:
|
||||
@@ -165,12 +165,12 @@ class DQNAgent(Agent):
|
||||
]:
|
||||
# only Rainbow needs annealing prioritized_replay_beta
|
||||
continue
|
||||
if k not in self.config["optimizer"]:
|
||||
self.config["optimizer"][k] = self.config[k]
|
||||
if k not in config["optimizer"]:
|
||||
config["optimizer"][k] = config[k]
|
||||
|
||||
if self.config.get("parameter_noise", False):
|
||||
if self.config["callbacks"]["on_episode_start"]:
|
||||
start_callback = self.config["callbacks"]["on_episode_start"]
|
||||
if config.get("parameter_noise", False):
|
||||
if config["callbacks"]["on_episode_start"]:
|
||||
start_callback = config["callbacks"]["on_episode_start"]
|
||||
else:
|
||||
start_callback = None
|
||||
|
||||
@@ -183,10 +183,10 @@ class DQNAgent(Agent):
|
||||
if start_callback:
|
||||
start_callback(info)
|
||||
|
||||
self.config["callbacks"]["on_episode_start"] = tune.function(
|
||||
config["callbacks"]["on_episode_start"] = tune.function(
|
||||
on_episode_start)
|
||||
if self.config["callbacks"]["on_episode_end"]:
|
||||
end_callback = self.config["callbacks"]["on_episode_end"]
|
||||
if config["callbacks"]["on_episode_end"]:
|
||||
end_callback = config["callbacks"]["on_episode_end"]
|
||||
else:
|
||||
end_callback = None
|
||||
|
||||
@@ -200,15 +200,15 @@ class DQNAgent(Agent):
|
||||
if end_callback:
|
||||
end_callback(info)
|
||||
|
||||
self.config["callbacks"]["on_episode_end"] = tune.function(
|
||||
config["callbacks"]["on_episode_end"] = tune.function(
|
||||
on_episode_end)
|
||||
|
||||
self.local_evaluator = self.make_local_evaluator(
|
||||
self.env_creator, self._policy_graph)
|
||||
env_creator, self._policy_graph)
|
||||
|
||||
if self.config["evaluation_interval"]:
|
||||
if config["evaluation_interval"]:
|
||||
self.evaluation_ev = self.make_local_evaluator(
|
||||
self.env_creator,
|
||||
env_creator,
|
||||
self._policy_graph,
|
||||
extra_config={
|
||||
"batch_mode": "complete_episodes",
|
||||
@@ -217,19 +217,17 @@ class DQNAgent(Agent):
|
||||
self.evaluation_metrics = self._evaluate()
|
||||
|
||||
def create_remote_evaluators():
|
||||
return self.make_remote_evaluators(self.env_creator,
|
||||
self._policy_graph,
|
||||
self.config["num_workers"])
|
||||
return self.make_remote_evaluators(env_creator, self._policy_graph,
|
||||
config["num_workers"])
|
||||
|
||||
if self.config["optimizer_class"] != "AsyncReplayOptimizer":
|
||||
if config["optimizer_class"] != "AsyncReplayOptimizer":
|
||||
self.remote_evaluators = create_remote_evaluators()
|
||||
else:
|
||||
# Hack to workaround https://github.com/ray-project/ray/issues/2541
|
||||
self.remote_evaluators = None
|
||||
|
||||
self.optimizer = getattr(optimizers, self.config["optimizer_class"])(
|
||||
self.local_evaluator, self.remote_evaluators,
|
||||
self.config["optimizer"])
|
||||
self.optimizer = getattr(optimizers, config["optimizer_class"])(
|
||||
self.local_evaluator, self.remote_evaluators, config["optimizer"])
|
||||
# Create the remote evaluators *after* the replay actors
|
||||
if self.remote_evaluators is None:
|
||||
self.remote_evaluators = create_remote_evaluators()
|
||||
|
||||
@@ -9,6 +9,7 @@ import tensorflow as tf
|
||||
import tensorflow.contrib.layers as layers
|
||||
|
||||
import ray
|
||||
from ray.rllib.evaluation.sample_batch import SampleBatch
|
||||
from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
|
||||
from ray.rllib.models import ModelCatalog, Categorical
|
||||
from ray.rllib.utils.annotations import override
|
||||
@@ -19,6 +20,125 @@ from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
|
||||
Q_SCOPE = "q_func"
|
||||
Q_TARGET_SCOPE = "target_q_func"
|
||||
|
||||
# Importance sampling weights for prioritized replay
|
||||
PRIO_WEIGHTS = "weights"
|
||||
|
||||
|
||||
class QLoss(object):
    """Builds the Q-learning loss, TD error and summary statistics.

    With ``num_atoms > 1`` the loss is a softmax cross entropy between a
    projected target distribution and ``q_logits_t_selected``; otherwise
    it is the Huber loss of the TD error. In both cases the per-sample
    terms are multiplied by ``importance_weights`` before averaging.

    Attributes:
        td_error: error tensor (cross entropy or TD difference).
        loss: mean of the importance-weighted errors.
        stats: dict of summary tensors.
    """

    def __init__(self,
                 q_t_selected,
                 q_logits_t_selected,
                 q_tp1_best,
                 q_dist_tp1_best,
                 importance_weights,
                 rewards,
                 done_mask,
                 gamma=0.99,
                 n_step=1,
                 num_atoms=1,
                 v_min=-10.0,
                 v_max=10.0):
        """Assemble the loss tensors.

        Arguments:
            q_t_selected: Q estimate being regressed (scalar case only).
            q_logits_t_selected: logits compared against the projected
                target (distributional case only).
            q_tp1_best: bootstrap Q estimate (scalar case only).
            q_dist_tp1_best: bootstrap distribution over atoms
                (distributional case only).
            importance_weights: per-sample weights applied to the errors.
            rewards: reward tensor.
            done_mask: tensor that zeroes the bootstrap where it equals 1.
            gamma: discount factor, raised to ``n_step``.
            n_step: exponent applied to ``gamma``.
            num_atoms: number of support atoms; values above 1 select the
                distributional loss.
            v_min: lower end of the value support.
            v_max: upper end of the value support.
        """

        if num_atoms > 1:
            # Distributional Q-learning which corresponds to an entropy loss

            # Spacing between neighbouring atoms of the support.
            atom_width = (v_max - v_min) / float(num_atoms - 1)

            support = tf.range(num_atoms, dtype=tf.float32)
            support = v_min + support * (v_max - v_min) / float(
                num_atoms - 1)

            # (batch_size, 1) * (1, num_atoms) = (batch_size, num_atoms)
            reward_col = tf.expand_dims(rewards, -1)
            not_done_col = tf.expand_dims(1.0 - done_mask, -1)
            support_row = tf.expand_dims(support, 0)
            r_tau = reward_col + gamma**n_step * not_done_col * support_row
            r_tau = tf.clip_by_value(r_tau, v_min, v_max)

            # Fractional atom index of every shifted support point.
            b = (r_tau - v_min) / atom_width
            lb = tf.floor(b)
            ub = tf.ceil(b)
            # indispensable judgement which is missed in most implementations
            # when b happens to be an integer, lb == ub, so pr_j(s', a*) will
            # be discarded because (ub-b) == (b-lb) == 0
            floor_equal_ceil = tf.to_float(tf.less(ub - lb, 0.5))

            # (batch_size, num_atoms, num_atoms)
            lower_onehot = tf.one_hot(tf.cast(lb, dtype=tf.int32), num_atoms)
            # (batch_size, num_atoms, num_atoms)
            upper_onehot = tf.one_hot(tf.cast(ub, dtype=tf.int32), num_atoms)

            # Split each atom's probability between its two neighbours and
            # accumulate the shares onto the fixed support.
            lower_mass = q_dist_tp1_best * (ub - b + floor_equal_ceil)
            upper_mass = q_dist_tp1_best * (b - lb)
            lower_mass = tf.reduce_sum(
                lower_onehot * tf.expand_dims(lower_mass, -1), axis=1)
            upper_mass = tf.reduce_sum(
                upper_onehot * tf.expand_dims(upper_mass, -1), axis=1)
            projected_target = lower_mass + upper_mass

            # Rainbow paper claims that using this cross entropy loss for
            # priority is robust and insensitive to `prioritized_replay_alpha`
            self.td_error = tf.nn.softmax_cross_entropy_with_logits(
                labels=projected_target, logits=q_logits_t_selected)
            self.loss = tf.reduce_mean(self.td_error * importance_weights)
            self.stats = {
                # TODO: better Q stats for dist dqn
                "mean_td_error": tf.reduce_mean(self.td_error),
            }
        else:
            bootstrap = (1.0 - done_mask) * q_tp1_best

            # compute RHS of bellman equation
            bellman_target = rewards + gamma**n_step * bootstrap

            # compute the error (potentially clipped)
            self.td_error = (
                q_t_selected - tf.stop_gradient(bellman_target))
            weighted_errors = importance_weights * _huber_loss(self.td_error)
            self.loss = tf.reduce_mean(weighted_errors)
            self.stats = {
                "mean_q": tf.reduce_mean(q_t_selected),
                "min_q": tf.reduce_min(q_t_selected),
                "max_q": tf.reduce_max(q_t_selected),
                "mean_td_error": tf.reduce_mean(self.td_error),
            }
|
||||
|
||||
|
||||
class DQNPostprocessing(object):
    """Implements n-step learning and param noise adjustments.

    Mixin for a policy graph that provides ``self.config``, ``self.sess``,
    ``self.q_values`` and the noise-related tensors and attributes read
    below.
    """

    @override(TFPolicyGraph)
    def extra_compute_action_fetches(self):
        """Return the base action fetches extended with the Q-values.

        Returns:
            dict: a copy of ``TFPolicyGraph.extra_compute_action_fetches``
                with ``self.q_values`` stored under ``"q_values"``.
        """
        return dict(
            TFPolicyGraph.extra_compute_action_fetches(self), **{
                "q_values": self.q_values,
            })

    @override(PolicyGraph)
    def postprocess_trajectory(self,
                               sample_batch,
                               other_agent_batches=None,
                               episode=None):
        """Optionally adapt the parameter-noise sigma, then postprocess.

        Arguments:
            sample_batch: batch of experiences for this policy.
            other_agent_batches: unused here.
            episode: unused here.

        Returns:
            The result of ``_postprocess_dqn(self, sample_batch)``.
        """
        if self.config["parameter_noise"]:
            # adjust the sigma of parameter space noise
            # Use the shared column-name constant rather than the literal
            # "obs", matching the other policy graphs.
            states = [
                list(x)
                for x in sample_batch.columns([SampleBatch.CUR_OBS])
            ][0]

            # Action distribution with the current (noisy) parameters ...
            noisy_action_distribution = self.sess.run(
                self.action_probs, feed_dict={self.cur_observations: states})
            # ... and again after the noise has been removed.
            self.sess.run(self.remove_noise_op)
            clean_action_distribution = self.sess.run(
                self.action_probs, feed_dict={self.cur_observations: states})
            # Mean of entropy(clean, noisy) over the transposed
            # distributions.
            distance_in_action_space = np.mean(
                entropy(clean_action_distribution.T,
                        noisy_action_distribution.T))
            self.pi_distance = distance_in_action_space
            # Grow sigma while the distance stays below the threshold
            # derived from the current epsilon, shrink it otherwise.
            if (distance_in_action_space <
                    -np.log(1 - self.cur_epsilon +
                            self.cur_epsilon / self.num_actions)):
                self.parameter_noise_sigma_val *= 1.01
            else:
                self.parameter_noise_sigma_val /= 1.01
            self.parameter_noise_sigma.load(
                self.parameter_noise_sigma_val, session=self.sess)

        return _postprocess_dqn(self, sample_batch)
|
||||
|
||||
|
||||
class QNetwork(object):
|
||||
def __init__(self,
|
||||
@@ -216,83 +336,7 @@ class QValuePolicy(object):
|
||||
self.action_prob = None
|
||||
|
||||
|
||||
class QLoss(object):
    """Builds the DQN loss, either scalar or distributional.

    After construction the instance exposes:
        td_error: per-sample error (TD residual in the scalar case,
            cross entropy in the distributional case).
        loss: mean of the importance-weighted per-sample loss.
        stats: dict of summary tensors for reporting.
    """

    def __init__(self,
                 q_t_selected,
                 q_logits_t_selected,
                 q_tp1_best,
                 q_dist_tp1_best,
                 importance_weights,
                 rewards,
                 done_mask,
                 gamma=0.99,
                 n_step=1,
                 num_atoms=1,
                 v_min=-10.0,
                 v_max=10.0):
        # Discount applied to the bootstrap term after n steps.
        discount = gamma**n_step
        not_done = 1.0 - done_mask

        if num_atoms <= 1:
            # Bootstrap value is zeroed out for terminal transitions.
            bootstrap = not_done * q_tp1_best

            # compute RHS of bellman equation
            bellman_target = rewards + discount * bootstrap

            # compute the error (potentially clipped)
            self.td_error = q_t_selected - tf.stop_gradient(bellman_target)
            self.loss = tf.reduce_mean(
                importance_weights * _huber_loss(self.td_error))
            self.stats = {
                "mean_q": tf.reduce_mean(q_t_selected),
                "min_q": tf.reduce_min(q_t_selected),
                "max_q": tf.reduce_max(q_t_selected),
                "mean_td_error": tf.reduce_mean(self.td_error),
            }
        else:
            # Distributional Q-learning which corresponds to an entropy loss
            atom_index = tf.range(num_atoms, dtype=tf.float32)
            support = v_min + atom_index * (v_max - v_min) / float(
                num_atoms - 1)

            # (batch_size, 1) * (1, num_atoms) = (batch_size, num_atoms)
            shifted_support = tf.expand_dims(
                rewards, -1) + discount * tf.expand_dims(
                    not_done, -1) * tf.expand_dims(support, 0)
            shifted_support = tf.clip_by_value(shifted_support, v_min, v_max)

            # Fractional atom index of every shifted support point.
            atom_width = (v_max - v_min) / float(num_atoms - 1)
            position = (shifted_support - v_min) / atom_width
            lower = tf.floor(position)
            upper = tf.ceil(position)

            # indispensable judgement which is missed in most implementations
            # when position happens to be an integer, lower == upper, so
            # pr_j(s', a*) would be discarded because
            # (upper - position) == (position - lower) == 0
            on_atom = tf.to_float(tf.less(upper - lower, 0.5))

            lower_one_hot = tf.one_hot(
                tf.cast(lower, dtype=tf.int32),
                num_atoms)  # (batch_size, num_atoms, num_atoms)
            upper_one_hot = tf.one_hot(
                tf.cast(upper, dtype=tf.int32),
                num_atoms)  # (batch_size, num_atoms, num_atoms)

            # Split each atom's probability mass between its two neighbours.
            lower_mass = q_dist_tp1_best * (upper - position + on_atom)
            upper_mass = q_dist_tp1_best * (position - lower)
            lower_mass = tf.reduce_sum(
                lower_one_hot * tf.expand_dims(lower_mass, -1), axis=1)
            upper_mass = tf.reduce_sum(
                upper_one_hot * tf.expand_dims(upper_mass, -1), axis=1)
            target_dist = lower_mass + upper_mass

            # Rainbow paper claims that using this cross entropy loss for
            # priority is robust and insensitive to `prioritized_replay_alpha`
            self.td_error = tf.nn.softmax_cross_entropy_with_logits(
                labels=target_dist, logits=q_logits_t_selected)
            self.loss = tf.reduce_mean(self.td_error * importance_weights)
            self.stats = {
                # TODO: better Q stats for dist dqn
                "mean_td_error": tf.reduce_mean(self.td_error),
            }
|
||||
|
||||
|
||||
class DQNPolicyGraph(TFPolicyGraph):
|
||||
class DQNPolicyGraph(DQNPostprocessing, TFPolicyGraph):
|
||||
def __init__(self, observation_space, action_space, config):
|
||||
config = dict(ray.rllib.agents.dqn.dqn.DEFAULT_CONFIG, **config)
|
||||
if not isinstance(action_space, Discrete):
|
||||
@@ -396,12 +440,12 @@ class DQNPolicyGraph(TFPolicyGraph):
|
||||
# initialize TFPolicyGraph
|
||||
self.sess = tf.get_default_session()
|
||||
self.loss_inputs = [
|
||||
("obs", self.obs_t),
|
||||
("actions", self.act_t),
|
||||
("rewards", self.rew_t),
|
||||
("new_obs", self.obs_tp1),
|
||||
("dones", self.done_mask),
|
||||
("weights", self.importance_weights),
|
||||
(SampleBatch.CUR_OBS, self.obs_t),
|
||||
(SampleBatch.ACTIONS, self.act_t),
|
||||
(SampleBatch.REWARDS, self.rew_t),
|
||||
(SampleBatch.NEXT_OBS, self.obs_tp1),
|
||||
(SampleBatch.DONES, self.done_mask),
|
||||
(PRIO_WEIGHTS, self.importance_weights),
|
||||
]
|
||||
TFPolicyGraph.__init__(
|
||||
self,
|
||||
@@ -437,13 +481,6 @@ class DQNPolicyGraph(TFPolicyGraph):
|
||||
grads_and_vars = [(g, v) for (g, v) in grads_and_vars if g is not None]
|
||||
return grads_and_vars
|
||||
|
||||
@override(TFPolicyGraph)
|
||||
def extra_compute_action_fetches(self):
|
||||
return dict(
|
||||
TFPolicyGraph.extra_compute_action_fetches(self), **{
|
||||
"q_values": self.q_values,
|
||||
})
|
||||
|
||||
@override(TFPolicyGraph)
|
||||
def extra_compute_action_feed_dict(self):
|
||||
return {
|
||||
@@ -458,35 +495,6 @@ class DQNPolicyGraph(TFPolicyGraph):
|
||||
LEARNER_STATS_KEY: self.loss.stats,
|
||||
}
|
||||
|
||||
@override(PolicyGraph)
|
||||
def postprocess_trajectory(self,
|
||||
sample_batch,
|
||||
other_agent_batches=None,
|
||||
episode=None):
|
||||
if self.config["parameter_noise"]:
|
||||
# adjust the sigma of parameter space noise
|
||||
states = [list(x) for x in sample_batch.columns(["obs"])][0]
|
||||
|
||||
noisy_action_distribution = self.sess.run(
|
||||
self.action_probs, feed_dict={self.cur_observations: states})
|
||||
self.sess.run(self.remove_noise_op)
|
||||
clean_action_distribution = self.sess.run(
|
||||
self.action_probs, feed_dict={self.cur_observations: states})
|
||||
distance_in_action_space = np.mean(
|
||||
entropy(clean_action_distribution.T,
|
||||
noisy_action_distribution.T))
|
||||
self.pi_distance = distance_in_action_space
|
||||
if (distance_in_action_space <
|
||||
-np.log(1 - self.cur_epsilon +
|
||||
self.cur_epsilon / self.num_actions)):
|
||||
self.parameter_noise_sigma_val *= 1.01
|
||||
else:
|
||||
self.parameter_noise_sigma_val /= 1.01
|
||||
self.parameter_noise_sigma.load(
|
||||
self.parameter_noise_sigma_val, session=self.sess)
|
||||
|
||||
return _postprocess_dqn(self, sample_batch)
|
||||
|
||||
@override(PolicyGraph)
|
||||
def get_state(self):
|
||||
return [TFPolicyGraph.get_state(self), self.cur_epsilon]
|
||||
@@ -614,21 +622,22 @@ def _postprocess_dqn(policy_graph, batch):
|
||||
# N-step Q adjustments
|
||||
if policy_graph.config["n_step"] > 1:
|
||||
_adjust_nstep(policy_graph.config["n_step"],
|
||||
policy_graph.config["gamma"], batch["obs"],
|
||||
batch["actions"], batch["rewards"], batch["new_obs"],
|
||||
batch["dones"])
|
||||
policy_graph.config["gamma"], batch[SampleBatch.CUR_OBS],
|
||||
batch[SampleBatch.ACTIONS], batch[SampleBatch.REWARDS],
|
||||
batch[SampleBatch.NEXT_OBS], batch[SampleBatch.DONES])
|
||||
|
||||
if "weights" not in batch:
|
||||
batch["weights"] = np.ones_like(batch["rewards"])
|
||||
if PRIO_WEIGHTS not in batch:
|
||||
batch[PRIO_WEIGHTS] = np.ones_like(batch[SampleBatch.REWARDS])
|
||||
|
||||
# Prioritize on the worker side
|
||||
if batch.count > 0 and policy_graph.config["worker_side_prioritization"]:
|
||||
td_errors = policy_graph.compute_td_error(
|
||||
batch["obs"], batch["actions"], batch["rewards"], batch["new_obs"],
|
||||
batch["dones"], batch["weights"])
|
||||
batch[SampleBatch.CUR_OBS], batch[SampleBatch.ACTIONS],
|
||||
batch[SampleBatch.REWARDS], batch[SampleBatch.NEXT_OBS],
|
||||
batch[SampleBatch.DONES], batch[PRIO_WEIGHTS])
|
||||
new_priorities = (
|
||||
np.abs(td_errors) + policy_graph.config["prioritized_replay_eps"])
|
||||
batch.data["weights"] = new_priorities
|
||||
batch.data[PRIO_WEIGHTS] = new_priorities
|
||||
|
||||
return batch
|
||||
|
||||
|
||||
@@ -170,31 +170,30 @@ class ESAgent(Agent):
|
||||
_default_config = DEFAULT_CONFIG
|
||||
|
||||
@override(Agent)
|
||||
def _init(self):
|
||||
def _init(self, config, env_creator):
|
||||
policy_params = {"action_noise_std": 0.01}
|
||||
|
||||
env = self.env_creator(self.config["env_config"])
|
||||
env = env_creator(config["env_config"])
|
||||
from ray.rllib import models
|
||||
preprocessor = models.ModelCatalog.get_preprocessor(env)
|
||||
|
||||
self.sess = utils.make_session(single_threaded=False)
|
||||
self.policy = policies.GenericPolicy(
|
||||
self.sess, env.action_space, env.observation_space, preprocessor,
|
||||
self.config["observation_filter"], self.config["model"],
|
||||
**policy_params)
|
||||
self.optimizer = optimizers.Adam(self.policy, self.config["stepsize"])
|
||||
self.report_length = self.config["report_length"]
|
||||
config["observation_filter"], config["model"], **policy_params)
|
||||
self.optimizer = optimizers.Adam(self.policy, config["stepsize"])
|
||||
self.report_length = config["report_length"]
|
||||
|
||||
# Create the shared noise table.
|
||||
logger.info("Creating shared noise table.")
|
||||
noise_id = create_shared_noise.remote(self.config["noise_size"])
|
||||
noise_id = create_shared_noise.remote(config["noise_size"])
|
||||
self.noise = SharedNoiseTable(ray.get(noise_id))
|
||||
|
||||
# Create the actors.
|
||||
logger.info("Creating actors.")
|
||||
self.workers = [
|
||||
Worker.remote(self.config, policy_params, self.env_creator,
|
||||
noise_id) for _ in range(self.config["num_workers"])
|
||||
Worker.remote(config, policy_params, env_creator, noise_id)
|
||||
for _ in range(config["num_workers"])
|
||||
]
|
||||
|
||||
self.episodes_so_far = 0
|
||||
|
||||
@@ -98,19 +98,18 @@ class ImpalaAgent(Agent):
|
||||
_policy_graph = VTracePolicyGraph
|
||||
|
||||
@override(Agent)
|
||||
def _init(self):
|
||||
def _init(self, config, env_creator):
|
||||
for k in OPTIMIZER_SHARED_CONFIGS:
|
||||
if k not in self.config["optimizer"]:
|
||||
self.config["optimizer"][k] = self.config[k]
|
||||
if k not in config["optimizer"]:
|
||||
config["optimizer"][k] = config[k]
|
||||
policy_cls = self._get_policy_graph()
|
||||
self.local_evaluator = self.make_local_evaluator(
|
||||
self.env_creator, policy_cls)
|
||||
env_creator, policy_cls)
|
||||
self.remote_evaluators = self.make_remote_evaluators(
|
||||
self.env_creator, policy_cls, self.config["num_workers"])
|
||||
self.optimizer = AsyncSamplesOptimizer(self.local_evaluator,
|
||||
self.remote_evaluators,
|
||||
self.config["optimizer"])
|
||||
if self.config["entropy_coeff"] < 0:
|
||||
env_creator, policy_cls, config["num_workers"])
|
||||
self.optimizer = AsyncSamplesOptimizer(
|
||||
self.local_evaluator, self.remote_evaluators, config["optimizer"])
|
||||
if config["entropy_coeff"] < 0:
|
||||
raise DeprecationWarning("entropy_coeff must be >= 0")
|
||||
|
||||
@override(Agent)
|
||||
|
||||
@@ -13,6 +13,7 @@ import tensorflow as tf
|
||||
from ray.rllib.agents.impala import vtrace
|
||||
from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
|
||||
from ray.rllib.evaluation.policy_graph import PolicyGraph
|
||||
from ray.rllib.evaluation.sample_batch import SampleBatch
|
||||
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph, \
|
||||
LearningRateSchedule
|
||||
from ray.rllib.models.action_dist import MultiCategorical
|
||||
@@ -21,6 +22,9 @@ from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.error import UnsupportedSpaceException
|
||||
from ray.rllib.utils.explained_variance import explained_variance
|
||||
|
||||
# Frozen logits of the policy that computed the action
|
||||
BEHAVIOUR_LOGITS = "behaviour_logits"
|
||||
|
||||
|
||||
class VTraceLoss(object):
|
||||
def __init__(self,
|
||||
@@ -99,7 +103,27 @@ class VTraceLoss(object):
|
||||
self.entropy * entropy_coeff)
|
||||
|
||||
|
||||
class VTracePolicyGraph(LearningRateSchedule, TFPolicyGraph):
|
||||
class VTracePostprocessing(object):
    """Adds the policy logits to the trajectory.

    Mixin that records the model outputs under `BEHAVIOUR_LOGITS` at action
    time and strips the next-observation column from sampled batches.
    """

    @override(TFPolicyGraph)
    def extra_compute_action_fetches(self):
        """Return the base fetches plus the model outputs as logits."""
        fetches = dict(TFPolicyGraph.extra_compute_action_fetches(self))
        fetches[BEHAVIOUR_LOGITS] = self.model.outputs
        return fetches

    @override(PolicyGraph)
    def postprocess_trajectory(self,
                               sample_batch,
                               other_agent_batches=None,
                               episode=None):
        """Drop the next-observation column and return the same batch.

        The column is not used, so removing it saves some bandwidth.
        """
        batch_columns = sample_batch.data
        del batch_columns[SampleBatch.NEXT_OBS]
        return sample_batch
|
||||
|
||||
|
||||
class VTracePolicyGraph(LearningRateSchedule, VTracePostprocessing,
|
||||
TFPolicyGraph):
|
||||
def __init__(self,
|
||||
observation_space,
|
||||
action_space,
|
||||
@@ -265,13 +289,13 @@ class VTracePolicyGraph(LearningRateSchedule, TFPolicyGraph):
|
||||
|
||||
# Initialize TFPolicyGraph
|
||||
loss_in = [
|
||||
("actions", actions),
|
||||
("dones", dones),
|
||||
("behaviour_logits", behaviour_logits),
|
||||
("rewards", rewards),
|
||||
("obs", observations),
|
||||
("prev_actions", prev_actions),
|
||||
("prev_rewards", prev_rewards),
|
||||
(SampleBatch.ACTIONS, actions),
|
||||
(SampleBatch.DONES, dones),
|
||||
(BEHAVIOUR_LOGITS, behaviour_logits),
|
||||
(SampleBatch.REWARDS, rewards),
|
||||
(SampleBatch.CUR_OBS, observations),
|
||||
(SampleBatch.PREV_ACTIONS, prev_actions),
|
||||
(SampleBatch.PREV_REWARDS, prev_rewards),
|
||||
]
|
||||
LearningRateSchedule.__init__(self, self.config["lr"],
|
||||
self.config["lr_schedule"])
|
||||
@@ -334,24 +358,10 @@ class VTracePolicyGraph(LearningRateSchedule, TFPolicyGraph):
|
||||
clipped_grads = list(zip(self.grads, self.var_list))
|
||||
return clipped_grads
|
||||
|
||||
@override(TFPolicyGraph)
|
||||
def extra_compute_action_fetches(self):
|
||||
return dict(
|
||||
TFPolicyGraph.extra_compute_action_fetches(self),
|
||||
**{"behaviour_logits": self.model.outputs})
|
||||
|
||||
@override(TFPolicyGraph)
|
||||
def extra_compute_grad_fetches(self):
|
||||
return self.stats_fetches
|
||||
|
||||
@override(PolicyGraph)
|
||||
def postprocess_trajectory(self,
|
||||
sample_batch,
|
||||
other_agent_batches=None,
|
||||
episode=None):
|
||||
del sample_batch.data["new_obs"] # not used, so save some bandwidth
|
||||
return sample_batch
|
||||
|
||||
@override(PolicyGraph)
|
||||
def get_initial_state(self):
|
||||
return self.model.state_init
|
||||
|
||||
@@ -47,16 +47,16 @@ class MARWILAgent(Agent):
|
||||
_policy_graph = MARWILPolicyGraph
|
||||
|
||||
@override(Agent)
|
||||
def _init(self):
|
||||
def _init(self, config, env_creator):
|
||||
self.local_evaluator = self.make_local_evaluator(
|
||||
self.env_creator, self._policy_graph)
|
||||
env_creator, self._policy_graph)
|
||||
self.remote_evaluators = self.make_remote_evaluators(
|
||||
self.env_creator, self._policy_graph, self.config["num_workers"])
|
||||
env_creator, self._policy_graph, config["num_workers"])
|
||||
self.optimizer = SyncBatchReplayOptimizer(
|
||||
self.local_evaluator, self.remote_evaluators, {
|
||||
"learning_starts": self.config["learning_starts"],
|
||||
"buffer_size": self.config["replay_buffer_size"],
|
||||
"train_batch_size": self.config["train_batch_size"],
|
||||
"learning_starts": config["learning_starts"],
|
||||
"buffer_size": config["replay_buffer_size"],
|
||||
"train_batch_size": config["train_batch_size"],
|
||||
})
|
||||
|
||||
@override(Agent)
|
||||
|
||||
@@ -6,16 +6,18 @@ import tensorflow as tf
|
||||
|
||||
import ray
|
||||
from ray.rllib.models import ModelCatalog
|
||||
from ray.rllib.evaluation.postprocessing import compute_advantages, \
|
||||
Postprocessing
|
||||
from ray.rllib.evaluation.sample_batch import SampleBatch
|
||||
from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
|
||||
from ray.rllib.evaluation.postprocessing import compute_advantages
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.evaluation.policy_graph import PolicyGraph
|
||||
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
|
||||
from ray.rllib.agents.dqn.dqn_policy_graph import _scope_vars
|
||||
from ray.rllib.utils.explained_variance import explained_variance
|
||||
|
||||
P_SCOPE = "p_func"
|
||||
V_SCOPE = "v_func"
|
||||
POLICY_SCOPE = "p_func"
|
||||
VALUE_SCOPE = "v_func"
|
||||
|
||||
|
||||
class ValueLoss(object):
|
||||
@@ -53,7 +55,30 @@ class ReweightedImitationLoss(object):
|
||||
tf.stop_gradient(exp_advs) * logprobs)
|
||||
|
||||
|
||||
class MARWILPolicyGraph(TFPolicyGraph):
|
||||
class MARWILPostprocessing(object):
    """Adds the advantages field to the trajectory."""

    @override(PolicyGraph)
    def postprocess_trajectory(self,
                               sample_batch,
                               other_agent_batches=None,
                               episode=None):
        """Compute advantages for a batch that ends on a finished episode.

        Arguments:
            sample_batch: batch of experience; its last done flag must be
                true.
            other_agent_batches: unused here.
            episode: unused here.

        Returns:
            The result of `compute_advantages` on `sample_batch`, using a
            final reward of 0.0, `self.config["gamma"]` and `use_gae=False`.

        Raises:
            NotImplementedError: if the last done flag in the batch is not
                true.
        """
        # Read the done column through the SampleBatch constant everywhere
        # (the original mixed the "dones" literal with SampleBatch.DONES).
        dones = sample_batch[SampleBatch.DONES]
        if not dones[-1]:
            raise NotImplementedError(
                "last done mask in a batch should be True. "
                "For now, we only support reading experience batches produced "
                "with batch_mode='complete_episodes'.",
                len(dones),
                dones[-1])

        # The episode is complete, so there is no value to bootstrap from.
        last_r = 0.0
        batch = compute_advantages(
            sample_batch, last_r, gamma=self.config["gamma"], use_gae=False)
        return batch
|
||||
|
||||
|
||||
class MARWILPolicyGraph(MARWILPostprocessing, TFPolicyGraph):
|
||||
def __init__(self, observation_space, action_space, config):
|
||||
config = dict(ray.rllib.agents.dqn.dqn.DEFAULT_CONFIG, **config)
|
||||
self.config = config
|
||||
@@ -68,7 +93,7 @@ class MARWILPolicyGraph(TFPolicyGraph):
|
||||
prev_rewards_ph = tf.placeholder(
|
||||
tf.float32, [None], name="prev_reward")
|
||||
|
||||
with tf.variable_scope(P_SCOPE) as scope:
|
||||
with tf.variable_scope(POLICY_SCOPE) as scope:
|
||||
self.model = ModelCatalog.get_model({
|
||||
"obs": self.obs_t,
|
||||
"prev_actions": prev_actions_ph,
|
||||
@@ -88,7 +113,7 @@ class MARWILPolicyGraph(TFPolicyGraph):
|
||||
self.cum_rew_t = tf.placeholder(tf.float32, [None], name="reward")
|
||||
|
||||
# v network evaluation
|
||||
with tf.variable_scope(V_SCOPE) as scope:
|
||||
with tf.variable_scope(VALUE_SCOPE) as scope:
|
||||
state_values = self.model.value_function()
|
||||
self.v_func_vars = _scope_vars(scope.name)
|
||||
self.v_loss = self._build_value_loss(state_values, self.cum_rew_t)
|
||||
@@ -104,9 +129,9 @@ class MARWILPolicyGraph(TFPolicyGraph):
|
||||
# initialize TFPolicyGraph
|
||||
self.sess = tf.get_default_session()
|
||||
self.loss_inputs = [
|
||||
("obs", self.obs_t),
|
||||
("actions", self.act_t),
|
||||
("advantages", self.cum_rew_t),
|
||||
(SampleBatch.CUR_OBS, self.obs_t),
|
||||
(SampleBatch.ACTIONS, self.act_t),
|
||||
(Postprocessing.ADVANTAGES, self.cum_rew_t),
|
||||
]
|
||||
TFPolicyGraph.__init__(
|
||||
self,
|
||||
@@ -144,24 +169,6 @@ class MARWILPolicyGraph(TFPolicyGraph):
|
||||
def extra_compute_grad_fetches(self):
|
||||
return {LEARNER_STATS_KEY: self.stats_fetches}
|
||||
|
||||
@override(PolicyGraph)
|
||||
def postprocess_trajectory(self,
|
||||
sample_batch,
|
||||
other_agent_batches=None,
|
||||
episode=None):
|
||||
completed = sample_batch["dones"][-1]
|
||||
if completed:
|
||||
last_r = 0.0
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"last done mask in a batch should be True. "
|
||||
"For now, we only support reading experience batches produced "
|
||||
"with batch_mode='complete_episodes'.",
|
||||
len(sample_batch["dones"]), sample_batch["dones"][-1])
|
||||
batch = compute_advantages(
|
||||
sample_batch, last_r, gamma=self.config["gamma"], use_gae=False)
|
||||
return batch
|
||||
|
||||
@override(PolicyGraph)
|
||||
def get_initial_state(self):
|
||||
return self.model.state_init
|
||||
|
||||
@@ -20,7 +20,7 @@ class _MockAgent(Agent):
|
||||
"num_workers": 0,
|
||||
})
|
||||
|
||||
def _init(self):
|
||||
def _init(self, config, env_creator):
|
||||
self.info = None
|
||||
self.restored = False
|
||||
|
||||
|
||||
@@ -34,20 +34,20 @@ class PGAgent(Agent):
|
||||
_policy_graph = PGPolicyGraph
|
||||
|
||||
@override(Agent)
|
||||
def _init(self):
|
||||
if self.config["use_pytorch"]:
|
||||
def _init(self, config, env_creator):
|
||||
if config["use_pytorch"]:
|
||||
from ray.rllib.agents.pg.torch_pg_policy_graph import \
|
||||
PGTorchPolicyGraph
|
||||
policy_cls = PGTorchPolicyGraph
|
||||
else:
|
||||
policy_cls = self._policy_graph
|
||||
self.local_evaluator = self.make_local_evaluator(
|
||||
self.env_creator, policy_cls)
|
||||
env_creator, policy_cls)
|
||||
self.remote_evaluators = self.make_remote_evaluators(
|
||||
self.env_creator, policy_cls, self.config["num_workers"])
|
||||
env_creator, policy_cls, config["num_workers"])
|
||||
optimizer_config = dict(
|
||||
self.config["optimizer"],
|
||||
**{"train_batch_size": self.config["train_batch_size"]})
|
||||
config["optimizer"],
|
||||
**{"train_batch_size": config["train_batch_size"]})
|
||||
self.optimizer = SyncSamplesOptimizer(
|
||||
self.local_evaluator, self.remote_evaluators, optimizer_config)
|
||||
|
||||
|
||||
@@ -6,20 +6,35 @@ import tensorflow as tf
|
||||
|
||||
import ray
|
||||
from ray.rllib.models.catalog import ModelCatalog
|
||||
from ray.rllib.evaluation.postprocessing import compute_advantages
|
||||
from ray.rllib.evaluation.postprocessing import compute_advantages, \
|
||||
Postprocessing
|
||||
from ray.rllib.evaluation.policy_graph import PolicyGraph
|
||||
from ray.rllib.evaluation.sample_batch import SampleBatch
|
||||
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
|
||||
from ray.rllib.utils.annotations import override
|
||||
|
||||
|
||||
class PGLoss(object):
    """The basic policy gradient loss.

    Stores in `self.loss` the negated mean of the action log-probabilities
    weighted by the advantages.
    """

    def __init__(self, action_dist, actions, advantages):
        # Log-probability of each taken action under the distribution.
        weighted_logp = action_dist.logp(actions) * advantages
        self.loss = -tf.reduce_mean(weighted_logp)
|
||||
|
||||
|
||||
class PGPolicyGraph(TFPolicyGraph):
|
||||
class PGPostprocessing(object):
    """Adds the advantages field to the trajectory."""

    @override(PolicyGraph)
    def postprocess_trajectory(self,
                               sample_batch,
                               other_agent_batches=None,
                               episode=None):
        """Return `sample_batch` run through `compute_advantages`.

        Uses a final reward of 0.0, the configured "gamma" and no GAE.
        """
        # This adds the "advantages" column to the sample batch
        final_reward = 0.0
        discount = self.config["gamma"]
        return compute_advantages(
            sample_batch, final_reward, discount, use_gae=False)
|
||||
|
||||
|
||||
class PGPolicyGraph(PGPostprocessing, TFPolicyGraph):
|
||||
"""Simple policy gradient example of defining a policy graph."""
|
||||
|
||||
def __init__(self, obs_space, action_space, config):
|
||||
@@ -51,11 +66,11 @@ class PGPolicyGraph(TFPolicyGraph):
|
||||
# read from postprocessed sample batches and fed into the specified
|
||||
# placeholders during loss computation.
|
||||
loss_in = [
|
||||
("obs", obs),
|
||||
("actions", actions),
|
||||
("prev_actions", prev_actions),
|
||||
("prev_rewards", prev_rewards),
|
||||
("advantages", advantages), # added during postprocessing
|
||||
(SampleBatch.CUR_OBS, obs),
|
||||
(SampleBatch.ACTIONS, actions),
|
||||
(SampleBatch.PREV_ACTIONS, prev_actions),
|
||||
(SampleBatch.PREV_REWARDS, prev_rewards),
|
||||
(Postprocessing.ADVANTAGES, advantages),
|
||||
]
|
||||
|
||||
# Initialize TFPolicyGraph
|
||||
@@ -79,15 +94,6 @@ class PGPolicyGraph(TFPolicyGraph):
|
||||
max_seq_len=config["model"]["max_seq_len"])
|
||||
sess.run(tf.global_variables_initializer())
|
||||
|
||||
@override(PolicyGraph)
|
||||
def postprocess_trajectory(self,
|
||||
sample_batch,
|
||||
other_agent_batches=None,
|
||||
episode=None):
|
||||
# This adds the "advantages" column to the sample batch
|
||||
return compute_advantages(
|
||||
sample_batch, 0.0, self.config["gamma"], use_gae=False)
|
||||
|
||||
@override(PolicyGraph)
|
||||
def get_initial_state(self):
|
||||
return self.model.state_init
|
||||
|
||||
@@ -8,8 +8,10 @@ from torch import nn
|
||||
|
||||
import ray
|
||||
from ray.rllib.models.catalog import ModelCatalog
|
||||
from ray.rllib.evaluation.postprocessing import compute_advantages
|
||||
from ray.rllib.evaluation.postprocessing import compute_advantages, \
|
||||
Postprocessing
|
||||
from ray.rllib.evaluation.policy_graph import PolicyGraph
|
||||
from ray.rllib.evaluation.sample_batch import SampleBatch
|
||||
from ray.rllib.evaluation.torch_policy_graph import TorchPolicyGraph
|
||||
from ray.rllib.utils.annotations import override
|
||||
|
||||
@@ -27,7 +29,23 @@ class PGLoss(nn.Module):
|
||||
return pi_err
|
||||
|
||||
|
||||
class PGTorchPolicyGraph(TorchPolicyGraph):
|
||||
class PGPostprocessing(object):
    """Adds the value func output and advantages field to the trajectory."""

    @override(TorchPolicyGraph)
    def extra_action_out(self, model_out):
        """Expose the third model output as the value-function predictions."""
        value_predictions = model_out[2].numpy()
        return {SampleBatch.VF_PREDS: value_predictions}

    @override(PolicyGraph)
    def postprocess_trajectory(self,
                               sample_batch,
                               other_agent_batches=None,
                               episode=None):
        """Return `sample_batch` run through `compute_advantages`.

        Uses a final reward of 0.0, the configured "gamma" and no GAE.
        """
        final_reward = 0.0
        discount = self.config["gamma"]
        return compute_advantages(
            sample_batch, final_reward, discount, use_gae=False)
|
||||
|
||||
|
||||
class PGTorchPolicyGraph(PGPostprocessing, TorchPolicyGraph):
|
||||
def __init__(self, obs_space, action_space, config):
|
||||
config = dict(ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG, **config)
|
||||
self.config = config
|
||||
@@ -43,24 +61,15 @@ class PGTorchPolicyGraph(TorchPolicyGraph):
|
||||
action_space,
|
||||
self.model,
|
||||
loss,
|
||||
loss_inputs=["obs", "actions", "advantages"])
|
||||
|
||||
@override(TorchPolicyGraph)
|
||||
def extra_action_out(self, model_out):
|
||||
return {"vf_preds": model_out[2].numpy()}
|
||||
loss_inputs=[
|
||||
SampleBatch.CUR_OBS, SampleBatch.ACTIONS,
|
||||
Postprocessing.ADVANTAGES
|
||||
])
|
||||
|
||||
@override(TorchPolicyGraph)
|
||||
def optimizer(self):
|
||||
return torch.optim.Adam(self.model.parameters(), lr=self.config["lr"])
|
||||
|
||||
@override(PolicyGraph)
|
||||
def postprocess_trajectory(self,
|
||||
sample_batch,
|
||||
other_agent_batches=None,
|
||||
episode=None):
|
||||
return compute_advantages(
|
||||
sample_batch, 0.0, self.config["gamma"], use_gae=False)
|
||||
|
||||
def _value(self, obs):
|
||||
with self.lock:
|
||||
obs = torch.from_numpy(obs).float().unsqueeze(0)
|
||||
|
||||
@@ -1,488 +1,497 @@
|
||||
"""Adapted from VTracePolicyGraph to use the PPO surrogate loss.
|
||||
|
||||
Keep in sync with changes to VTracePolicyGraph."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
import logging
|
||||
import gym
|
||||
|
||||
import ray
|
||||
from ray.rllib.agents.impala import vtrace
|
||||
from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
|
||||
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph, \
|
||||
LearningRateSchedule
|
||||
from ray.rllib.models.catalog import ModelCatalog
|
||||
from ray.rllib.utils.error import UnsupportedSpaceException
|
||||
from ray.rllib.utils.explained_variance import explained_variance
|
||||
from ray.rllib.models.action_dist import MultiCategorical
|
||||
from ray.rllib.evaluation.postprocessing import compute_advantages
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PPOSurrogateLoss(object):
    """Loss used when V-trace is disabled.

    Arguments:
        prev_actions_logp: A float32 tensor of shape [T, B].
        actions_logp: A float32 tensor of shape [T, B].
        action_kl: A float32 tensor of shape [T, B].
        actions_entropy: A float32 tensor of shape [T, B].
        values: A float32 tensor of shape [T, B].
        valid_mask: A bool tensor of valid RNN input elements (#2992).
        advantages: A float32 tensor of shape [T, B].
        value_targets: A float32 tensor of shape [T, B].
        vf_loss_coeff: Weight of the value-function loss in `total_loss`.
        entropy_coeff: Weight of the entropy bonus in `total_loss`.
        clip_param: Half-width of the clipping interval applied to the
            probability ratio around 1.

    Attributes set on the instance: `mean_kl`, `pi_loss`, `value_targets`,
    `vf_loss`, `entropy` and `total_loss`.
    """

    def __init__(self,
                 prev_actions_logp,
                 actions_logp,
                 action_kl,
                 actions_entropy,
                 values,
                 valid_mask,
                 advantages,
                 value_targets,
                 vf_loss_coeff=0.5,
                 entropy_coeff=0.01,
                 clip_param=0.3):

        # Probability ratio between the current and the behaviour policy.
        logp_ratio = tf.exp(actions_logp - prev_actions_logp)

        # Clipped surrogate objective: element-wise minimum of the plain
        # and the ratio-clipped advantage terms.
        surrogate_loss = tf.minimum(
            advantages * logp_ratio,
            advantages * tf.clip_by_value(logp_ratio, 1 - clip_param,
                                          1 + clip_param))

        self.mean_kl = tf.reduce_mean(action_kl)

        # Only valid (non-padded) elements may contribute to the policy
        # loss, matching how the value and entropy terms below are masked.
        self.pi_loss = -tf.reduce_sum(
            tf.boolean_mask(surrogate_loss, valid_mask))

        # The baseline loss
        delta = tf.boolean_mask(values - value_targets, valid_mask)
        self.value_targets = value_targets
        self.vf_loss = 0.5 * tf.reduce_sum(tf.square(delta))

        # The entropy loss
        self.entropy = tf.reduce_sum(
            tf.boolean_mask(actions_entropy, valid_mask))

        # The summed weighted loss
        self.total_loss = (self.pi_loss + self.vf_loss * vf_loss_coeff -
                           self.entropy * entropy_coeff)
|
||||
|
||||
|
||||
class VTraceSurrogateLoss(object):
    """PPO surrogate loss whose advantages and value targets use V-trace.

    Attributes set on the instance: `vtrace_returns`, `mean_kl`, `pi_loss`,
    `value_targets`, `vf_loss`, `entropy` and `total_loss`.
    """

    def __init__(self,
                 actions,
                 prev_actions_logp,
                 actions_logp,
                 action_kl,
                 actions_entropy,
                 dones,
                 behaviour_logits,
                 target_logits,
                 discount,
                 rewards,
                 values,
                 bootstrap_value,
                 valid_mask,
                 vf_loss_coeff=0.5,
                 entropy_coeff=0.01,
                 clip_rho_threshold=1.0,
                 clip_pg_rho_threshold=1.0,
                 clip_param=0.3):
        """PPO surrogate loss with vtrace importance weighting.

        VTraceLoss takes tensors of shape [T, B, ...], where `B` is the
        batch_size. The reason we need to know `B` is for V-trace to properly
        handle episode cut boundaries.

        Arguments:
            actions: An int32 tensor of shape [T, B, NUM_ACTIONS].
            prev_actions_logp: A float32 tensor of shape [T, B].
            actions_logp: A float32 tensor of shape [T, B].
            action_kl: A float32 tensor of shape [T, B].
            actions_entropy: A float32 tensor of shape [T, B].
            dones: A bool tensor of shape [T, B].
            behaviour_logits: A float32 tensor of shape [T, B, NUM_ACTIONS].
            target_logits: A float32 tensor of shape [T, B, NUM_ACTIONS].
            discount: A float32 scalar.
            rewards: A float32 tensor of shape [T, B].
            values: A float32 tensor of shape [T, B].
            bootstrap_value: A float32 tensor of shape [B].
            valid_mask: A bool tensor of valid RNN input elements (#2992).
            vf_loss_coeff: Weight of the value-function loss in
                `total_loss`.
            entropy_coeff: Weight of the entropy bonus in `total_loss`.
            clip_rho_threshold: Passed to V-trace as `clip_rho_threshold`.
            clip_pg_rho_threshold: Passed to V-trace as
                `clip_pg_rho_threshold`.
            clip_param: Half-width of the clipping interval applied to the
                probability ratio around 1.
        """

        # Compute vtrace on the CPU for better perf.
        with tf.device("/cpu:0"):
            self.vtrace_returns = vtrace.multi_from_logits(
                behaviour_policy_logits=behaviour_logits,
                target_policy_logits=target_logits,
                actions=tf.unstack(tf.cast(actions, tf.int32), axis=2),
                # Discount is zeroed wherever the episode has terminated.
                discounts=tf.to_float(~dones) * discount,
                rewards=rewards,
                values=values,
                bootstrap_value=bootstrap_value,
                clip_rho_threshold=tf.cast(clip_rho_threshold, tf.float32),
                clip_pg_rho_threshold=tf.cast(clip_pg_rho_threshold,
                                              tf.float32))

        # Probability ratio between the current and the behaviour policy.
        logp_ratio = tf.exp(actions_logp - prev_actions_logp)

        advantages = self.vtrace_returns.pg_advantages
        surrogate_loss = tf.minimum(
            advantages * logp_ratio,
            advantages * tf.clip_by_value(logp_ratio, 1 - clip_param,
                                          1 + clip_param))

        self.mean_kl = tf.reduce_mean(action_kl)

        # Only valid (non-padded) elements may contribute to the policy
        # loss, matching how the value and entropy terms below are masked.
        self.pi_loss = -tf.reduce_sum(
            tf.boolean_mask(surrogate_loss, valid_mask))

        # The baseline loss
        delta = tf.boolean_mask(values - self.vtrace_returns.vs, valid_mask)
        self.value_targets = self.vtrace_returns.vs
        self.vf_loss = 0.5 * tf.reduce_sum(tf.square(delta))

        # The entropy loss
        self.entropy = tf.reduce_sum(
            tf.boolean_mask(actions_entropy, valid_mask))

        # The summed weighted loss
        self.total_loss = (self.pi_loss + self.vf_loss * vf_loss_coeff -
                           self.entropy * entropy_coeff)
|
||||
|
||||
|
||||
class AsyncPPOPolicyGraph(LearningRateSchedule, TFPolicyGraph):
    """TF policy graph for asynchronous PPO (APPO).

    Builds the model, the input placeholders and one of two losses
    depending on ``config["vtrace"]``: ``VTraceSurrogateLoss`` when it is
    truthy, ``PPOSurrogateLoss`` otherwise.
    """

    def __init__(self,
                 observation_space,
                 action_space,
                 config,
                 existing_inputs=None):
        """Build the graph and initialize the parent classes.

        Args:
            observation_space: Observation space; its ``shape`` sizes the
                observation placeholder.
            action_space: A gym action space. With V-trace enabled only
                ``Discrete`` and ``MultiDiscrete`` are accepted.
            config: Dict of overrides merged on top of the IMPALA
                ``DEFAULT_CONFIG``.
            existing_inputs: Optional sequence of tensors to reuse instead
                of creating placeholders: the loss inputs (7 with V-trace,
                9 without), then the state inputs, then ``seq_lens`` last.

        Raises:
            AssertionError: If ``batch_mode`` is not ``truncate_episodes``.
            UnsupportedSpaceException: If V-trace is enabled and the action
                space is neither ``Discrete`` nor ``MultiDiscrete``.
        """
        config = dict(ray.rllib.agents.impala.impala.DEFAULT_CONFIG, **config)
        assert config["batch_mode"] == "truncate_episodes", \
            "Must use `truncate_episodes` batch mode with V-trace."
        self.config = config
        self.sess = tf.get_default_session()
        self.grads = None

        if isinstance(action_space, gym.spaces.Discrete):
            is_multidiscrete = False
            output_hidden_shape = [action_space.n]
        elif isinstance(action_space, gym.spaces.multi_discrete.MultiDiscrete):
            is_multidiscrete = True
            output_hidden_shape = action_space.nvec.astype(np.int32)
        elif self.config["vtrace"]:
            # Fixed: the template was previously passed unformatted, with
            # builtin format(action_space) as a second exception argument.
            raise UnsupportedSpaceException(
                "Action space {} is not supported for APPO + VTrace.".format(
                    action_space))
        else:
            is_multidiscrete = False
            output_hidden_shape = 1

        # Policy network model
        dist_class, logit_dim = ModelCatalog.get_action_dist(
            action_space, self.config["model"])

        # Create input placeholders
        if existing_inputs:
            if self.config["vtrace"]:
                actions, dones, behaviour_logits, rewards, observations, \
                    prev_actions, prev_rewards = existing_inputs[:7]
                existing_state_in = existing_inputs[7:-1]
                existing_seq_lens = existing_inputs[-1]
            else:
                actions, dones, behaviour_logits, rewards, observations, \
                    prev_actions, prev_rewards, adv_ph, value_targets = \
                    existing_inputs[:9]
                existing_state_in = existing_inputs[9:-1]
                existing_seq_lens = existing_inputs[-1]
        else:
            actions = ModelCatalog.get_action_placeholder(action_space)
            dones = tf.placeholder(tf.bool, [None], name="dones")
            rewards = tf.placeholder(tf.float32, [None], name="rewards")
            behaviour_logits = tf.placeholder(
                tf.float32, [None, logit_dim], name="behaviour_logits")
            observations = tf.placeholder(
                tf.float32, [None] + list(observation_space.shape))
            existing_state_in = None
            existing_seq_lens = None

            # Advantages and value targets are only fed when V-trace is off.
            if not self.config["vtrace"]:
                adv_ph = tf.placeholder(
                    tf.float32, name="advantages", shape=(None, ))
                value_targets = tf.placeholder(
                    tf.float32, name="value_targets", shape=(None, ))
        self.observations = observations

        # Unpack behaviour logits
        unpacked_behaviour_logits = tf.split(
            behaviour_logits, output_hidden_shape, axis=1)

        # Setup the policy
        # (same call and arguments as the get_action_dist call above)
        dist_class, logit_dim = ModelCatalog.get_action_dist(
            action_space, self.config["model"])
        # NOTE(review): these fresh placeholders replace any prev_actions /
        # prev_rewards unpacked from existing_inputs above -- confirm that
        # this is intended.
        prev_actions = ModelCatalog.get_action_placeholder(action_space)
        prev_rewards = tf.placeholder(tf.float32, [None], name="prev_reward")
        self.model = ModelCatalog.get_model(
            {
                "obs": observations,
                "prev_actions": prev_actions,
                "prev_rewards": prev_rewards,
                "is_training": self._get_is_training_placeholder(),
            },
            observation_space,
            action_space,
            logit_dim,
            self.config["model"],
            state_in=existing_state_in,
            seq_lens=existing_seq_lens)
        unpacked_outputs = tf.split(
            self.model.outputs, output_hidden_shape, axis=1)

        dist_inputs = unpacked_outputs if is_multidiscrete else \
            self.model.outputs
        prev_dist_inputs = unpacked_behaviour_logits if is_multidiscrete else \
            behaviour_logits

        action_dist = dist_class(dist_inputs)
        prev_action_dist = dist_class(prev_dist_inputs)

        values = self.model.value_function()
        self.value_function = values
        self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                          tf.get_variable_scope().name)

        def make_time_major(tensor, drop_last=False):
            """Swaps batch and trajectory axis.

            Args:
                tensor: A tensor or list of tensors to reshape.
                drop_last: A bool indicating whether to drop the last
                    trajectory item.

            Returns:
                res: A tensor with swapped axes or a list of tensors with
                    swapped axes.
            """
            if isinstance(tensor, list):
                return [make_time_major(t, drop_last) for t in tensor]

            if self.model.state_init:
                B = tf.shape(self.model.seq_lens)[0]
                T = tf.shape(tensor)[0] // B
            else:
                # Important: chop the tensor into batches at known episode cut
                # boundaries. TODO(ekl) this is kind of a hack
                T = self.config["sample_batch_size"]
                B = tf.shape(tensor)[0] // T
            rs = tf.reshape(tensor,
                            tf.concat([[B, T], tf.shape(tensor)[1:]], axis=0))

            # swap B and T axes
            res = tf.transpose(
                rs,
                [1, 0] + list(range(2, 1 + int(tf.shape(tensor).shape[0]))))

            if drop_last:
                return res[:-1]
            return res

        if self.model.state_in:
            max_seq_len = tf.reduce_max(self.model.seq_lens) - 1
            mask = tf.sequence_mask(self.model.seq_lens, max_seq_len)
            mask = tf.reshape(mask, [-1])
        else:
            # Bool, like the sequence_mask branch above and as the loss
            # classes document for valid_mask (was float32 before).
            mask = tf.ones_like(rewards, dtype=tf.bool)

        # Inputs are reshaped from [B * T] => [T - 1, B] for V-trace calc.
        if self.config["vtrace"]:
            logger.info("Using V-Trace surrogate loss (vtrace=True)")

            # Prepare actions for loss
            loss_actions = actions if is_multidiscrete else tf.expand_dims(
                actions, axis=1)

            self.loss = VTraceSurrogateLoss(
                actions=make_time_major(loss_actions, drop_last=True),
                prev_actions_logp=make_time_major(
                    prev_action_dist.logp(actions), drop_last=True),
                actions_logp=make_time_major(
                    action_dist.logp(actions), drop_last=True),
                action_kl=prev_action_dist.kl(action_dist),
                actions_entropy=make_time_major(
                    action_dist.entropy(), drop_last=True),
                dones=make_time_major(dones, drop_last=True),
                behaviour_logits=make_time_major(
                    unpacked_behaviour_logits, drop_last=True),
                target_logits=make_time_major(
                    unpacked_outputs, drop_last=True),
                discount=config["gamma"],
                rewards=make_time_major(rewards, drop_last=True),
                values=make_time_major(values, drop_last=True),
                bootstrap_value=make_time_major(values)[-1],
                valid_mask=make_time_major(mask, drop_last=True),
                vf_loss_coeff=self.config["vf_loss_coeff"],
                entropy_coeff=self.config["entropy_coeff"],
                clip_rho_threshold=self.config["vtrace_clip_rho_threshold"],
                clip_pg_rho_threshold=self.config[
                    "vtrace_clip_pg_rho_threshold"],
                clip_param=self.config["clip_param"])
        else:
            logger.info("Using PPO surrogate loss (vtrace=False)")
            self.loss = PPOSurrogateLoss(
                prev_actions_logp=make_time_major(
                    prev_action_dist.logp(actions)),
                actions_logp=make_time_major(action_dist.logp(actions)),
                action_kl=prev_action_dist.kl(action_dist),
                actions_entropy=make_time_major(action_dist.entropy()),
                values=make_time_major(values),
                valid_mask=make_time_major(mask),
                advantages=make_time_major(adv_ph),
                value_targets=make_time_major(value_targets),
                vf_loss_coeff=self.config["vf_loss_coeff"],
                entropy_coeff=self.config["entropy_coeff"],
                clip_param=self.config["clip_param"])

        # KL divergence between worker and learner logits for debugging
        model_dist = MultiCategorical(unpacked_outputs)
        behaviour_dist = MultiCategorical(unpacked_behaviour_logits)

        kls = model_dist.kl(behaviour_dist)
        if len(kls) > 1:
            # One set of stats per KL term, keyed by its index.
            self.KL_stats = {}

            for i, kl in enumerate(kls):
                self.KL_stats.update({
                    "mean_KL_{}".format(i): tf.reduce_mean(kl),
                    "max_KL_{}".format(i): tf.reduce_max(kl),
                    "median_KL_{}".format(i): tf.contrib.distributions.
                    percentile(kl, 50.0),
                })
        else:
            self.KL_stats = {
                "mean_KL": tf.reduce_mean(kls[0]),
                "max_KL": tf.reduce_max(kls[0]),
                "median_KL": tf.contrib.distributions.percentile(
                    kls[0], 50.0),
            }

        # Initialize TFPolicyGraph
        # The order here must match the unpacking of existing_inputs above.
        loss_in = [
            ("actions", actions),
            ("dones", dones),
            ("behaviour_logits", behaviour_logits),
            ("rewards", rewards),
            ("obs", observations),
            ("prev_actions", prev_actions),
            ("prev_rewards", prev_rewards),
        ]
        if not self.config["vtrace"]:
            loss_in.append(("advantages", adv_ph))
            loss_in.append(("value_targets", value_targets))
        LearningRateSchedule.__init__(self, self.config["lr"],
                                      self.config["lr_schedule"])
        TFPolicyGraph.__init__(
            self,
            observation_space,
            action_space,
            self.sess,
            obs_input=observations,
            action_sampler=action_dist.sample(),
            action_prob=action_dist.sampled_action_prob(),
            loss=self.loss.total_loss,
            model=self.model,
            loss_inputs=loss_in,
            state_inputs=self.model.state_in,
            state_outputs=self.model.state_out,
            prev_action_input=prev_actions,
            prev_reward_input=prev_rewards,
            seq_lens=self.model.seq_lens,
            max_seq_len=self.config["model"]["max_seq_len"],
            batch_divisibility_req=self.config["sample_batch_size"])

        self.sess.run(tf.global_variables_initializer())

        values_batched = make_time_major(
            values, drop_last=self.config["vtrace"])
        # NOTE(review): self._grads is not assigned in this class;
        # presumably TFPolicyGraph.__init__ sets it -- verify.
        self.stats_fetches = {
            LEARNER_STATS_KEY: dict({
                "cur_lr": tf.cast(self.cur_lr, tf.float64),
                "policy_loss": self.loss.pi_loss,
                "entropy": self.loss.entropy,
                "grad_gnorm": tf.global_norm(self._grads),
                "var_gnorm": tf.global_norm(self.var_list),
                "vf_loss": self.loss.vf_loss,
                "vf_explained_var": explained_variance(
                    tf.reshape(self.loss.value_targets, [-1]),
                    tf.reshape(values_batched, [-1])),
            }, **self.KL_stats),
        }

    def optimizer(self):
        """Return Adam if ``opt_type`` is "adam", otherwise RMSProp."""
        if self.config["opt_type"] == "adam":
            return tf.train.AdamOptimizer(self.cur_lr)
        else:
            return tf.train.RMSPropOptimizer(self.cur_lr, self.config["decay"],
                                             self.config["momentum"],
                                             self.config["epsilon"])

    def gradients(self, optimizer, loss):
        """Return (gradient, variable) pairs clipped by global norm.

        The clipped gradients are also stored on ``self.grads``. The
        ``optimizer`` argument is not used.
        """
        grads = tf.gradients(loss, self.var_list)
        self.grads, _ = tf.clip_by_global_norm(grads, self.config["grad_clip"])
        clipped_grads = list(zip(self.grads, self.var_list))
        return clipped_grads

    def extra_compute_action_fetches(self):
        """Add the model logits (and VF preds without V-trace) to fetches."""
        out = {"behaviour_logits": self.model.outputs}
        if not self.config["vtrace"]:
            out["vf_preds"] = self.value_function
        return dict(TFPolicyGraph.extra_compute_action_fetches(self), **out)

    def extra_compute_grad_fetches(self):
        """Return the learner stats dict built in ``__init__``."""
        return self.stats_fetches

    def value(self, ob, *args):
        """Run the value function for a single observation.

        Args:
            ob: One observation; it is fed as a batch of size one.
            *args: One value per entry of ``self.model.state_in``.

        Returns:
            The first element of the value-function output.
        """
        feed_dict = {self.observations: [ob], self.model.seq_lens: [1]}
        assert len(args) == len(self.model.state_in), \
            (args, self.model.state_in)
        for k, v in zip(self.model.state_in, args):
            feed_dict[k] = v
        vf = self.sess.run(self.value_function, feed_dict)
        return vf[0]

    def postprocess_trajectory(self,
                               sample_batch,
                               other_agent_batches=None,
                               episode=None):
        """Compute advantages when V-trace is off, then drop ``new_obs``.

        ``other_agent_batches`` and ``episode`` are not used.
        """
        if not self.config["vtrace"]:
            completed = sample_batch["dones"][-1]
            if completed:
                last_r = 0.0
            else:
                # Bootstrap from the value of the last next-observation.
                next_state = []
                for i in range(len(self.model.state_in)):
                    next_state.append(
                        [sample_batch["state_out_{}".format(i)][-1]])
                last_r = self.value(sample_batch["new_obs"][-1], *next_state)
            batch = compute_advantages(
                sample_batch,
                last_r,
                self.config["gamma"],
                self.config["lambda"],
                use_gae=self.config["use_gae"])
        else:
            batch = sample_batch
        del batch.data["new_obs"]  # not used, so save some bandwidth
        return batch

    def get_initial_state(self):
        """Return the model's initial state (``self.model.state_init``)."""
        return self.model.state_init

    def copy(self, existing_inputs):
        """Build a new policy graph of this class on ``existing_inputs``."""
        return AsyncPPOPolicyGraph(
            self.observation_space,
            self.action_space,
            self.config,
            existing_inputs=existing_inputs)
|
||||
"""Adapted from VTracePolicyGraph to use the PPO surrogate loss.
|
||||
|
||||
Keep in sync with changes to VTracePolicyGraph."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
import logging
|
||||
import gym
|
||||
|
||||
import ray
|
||||
from ray.rllib.agents.impala import vtrace
|
||||
from ray.rllib.evaluation.policy_graph import PolicyGraph
|
||||
from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
|
||||
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph, \
|
||||
LearningRateSchedule
|
||||
from ray.rllib.models.catalog import ModelCatalog
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.error import UnsupportedSpaceException
|
||||
from ray.rllib.utils.explained_variance import explained_variance
|
||||
from ray.rllib.models.action_dist import MultiCategorical
|
||||
from ray.rllib.evaluation.postprocessing import compute_advantages
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PPOSurrogateLoss(object):
    """Loss used when V-trace is disabled.

    Combines a clipped surrogate policy loss, a squared-error value loss
    and an entropy term into ``total_loss``. The results are exposed as
    the attributes ``mean_kl``, ``pi_loss``, ``value_targets``,
    ``vf_loss``, ``entropy`` and ``total_loss``.

    Arguments:
        prev_actions_logp: A float32 tensor of shape [T, B].
        actions_logp: A float32 tensor of shape [T, B].
        action_kl: A float32 tensor of shape [T, B].
        actions_entropy: A float32 tensor of shape [T, B].
        values: A float32 tensor of shape [T, B].
        valid_mask: A bool tensor of valid RNN input elements (#2992).
        advantages: A float32 tensor of shape [T, B].
        value_targets: A float32 tensor of shape [T, B].
        vf_loss_coeff: Weight of the value loss in ``total_loss``.
        entropy_coeff: Weight of the entropy term in ``total_loss``.
        clip_param: The probability ratio is clipped to
            [1 - clip_param, 1 + clip_param].
    """

    def __init__(self,
                 prev_actions_logp,
                 actions_logp,
                 action_kl,
                 actions_entropy,
                 values,
                 valid_mask,
                 advantages,
                 value_targets,
                 vf_loss_coeff=0.5,
                 entropy_coeff=0.01,
                 clip_param=0.3):

        # Probability ratio between the current and the behaviour policy.
        logp_ratio = tf.exp(actions_logp - prev_actions_logp)

        surrogate_loss = tf.minimum(
            advantages * logp_ratio,
            advantages * tf.clip_by_value(logp_ratio, 1 - clip_param,
                                          1 + clip_param))

        self.mean_kl = tf.reduce_mean(action_kl)
        # Only sum over elements marked valid, consistent with the value
        # and entropy terms below (previously summed over every element).
        self.pi_loss = -tf.reduce_sum(
            tf.boolean_mask(surrogate_loss, valid_mask))

        # The baseline loss
        delta = tf.boolean_mask(values - value_targets, valid_mask)
        self.value_targets = value_targets
        self.vf_loss = 0.5 * tf.reduce_sum(tf.square(delta))

        # The entropy loss
        self.entropy = tf.reduce_sum(
            tf.boolean_mask(actions_entropy, valid_mask))

        # The summed weighted loss
        self.total_loss = (self.pi_loss + self.vf_loss * vf_loss_coeff -
                           self.entropy * entropy_coeff)
|
||||
|
||||
|
||||
class VTraceSurrogateLoss(object):
    """Clipped PPO surrogate loss using V-trace advantages and targets.

    The results are exposed as the attributes ``vtrace_returns``,
    ``mean_kl``, ``pi_loss``, ``value_targets``, ``vf_loss``, ``entropy``
    and ``total_loss``.
    """

    def __init__(self,
                 actions,
                 prev_actions_logp,
                 actions_logp,
                 action_kl,
                 actions_entropy,
                 dones,
                 behaviour_logits,
                 target_logits,
                 discount,
                 rewards,
                 values,
                 bootstrap_value,
                 valid_mask,
                 vf_loss_coeff=0.5,
                 entropy_coeff=0.01,
                 clip_rho_threshold=1.0,
                 clip_pg_rho_threshold=1.0,
                 clip_param=0.3):
        """PPO surrogate loss with vtrace importance weighting.

        VTraceLoss takes tensors of shape [T, B, ...], where `B` is the
        batch_size. The reason we need to know `B` is for V-trace to properly
        handle episode cut boundaries.

        Arguments:
            actions: An int32 tensor of shape [T, B, NUM_ACTIONS].
            prev_actions_logp: A float32 tensor of shape [T, B].
            actions_logp: A float32 tensor of shape [T, B].
            action_kl: A float32 tensor of shape [T, B].
            actions_entropy: A float32 tensor of shape [T, B].
            dones: A bool tensor of shape [T, B].
            behaviour_logits: A float32 tensor of shape [T, B, NUM_ACTIONS].
            target_logits: A float32 tensor of shape [T, B, NUM_ACTIONS].
            discount: A float32 scalar.
            rewards: A float32 tensor of shape [T, B].
            values: A float32 tensor of shape [T, B].
            bootstrap_value: A float32 tensor of shape [B].
            valid_mask: A bool tensor of valid RNN input elements (#2992).
            vf_loss_coeff: Weight of the value loss in ``total_loss``.
            entropy_coeff: Weight of the entropy term in ``total_loss``.
            clip_rho_threshold: Passed to V-trace, cast to float32.
            clip_pg_rho_threshold: Passed to V-trace, cast to float32.
            clip_param: The probability ratio is clipped to
                [1 - clip_param, 1 + clip_param].
        """

        # Compute vtrace on the CPU for better perf.
        with tf.device("/cpu:0"):
            self.vtrace_returns = vtrace.multi_from_logits(
                behaviour_policy_logits=behaviour_logits,
                target_policy_logits=target_logits,
                actions=tf.unstack(tf.cast(actions, tf.int32), axis=2),
                # Discount is zeroed wherever `dones` is true.
                discounts=tf.to_float(~dones) * discount,
                rewards=rewards,
                values=values,
                bootstrap_value=bootstrap_value,
                clip_rho_threshold=tf.cast(clip_rho_threshold, tf.float32),
                clip_pg_rho_threshold=tf.cast(clip_pg_rho_threshold,
                                              tf.float32))

        # Probability ratio between the current and the behaviour policy.
        logp_ratio = tf.exp(actions_logp - prev_actions_logp)

        advantages = self.vtrace_returns.pg_advantages
        surrogate_loss = tf.minimum(
            advantages * logp_ratio,
            advantages * tf.clip_by_value(logp_ratio, 1 - clip_param,
                                          1 + clip_param))

        self.mean_kl = tf.reduce_mean(action_kl)
        # Only sum over elements marked valid, consistent with the value
        # and entropy terms below (previously summed over every element).
        self.pi_loss = -tf.reduce_sum(
            tf.boolean_mask(surrogate_loss, valid_mask))

        # The baseline loss
        delta = tf.boolean_mask(values - self.vtrace_returns.vs, valid_mask)
        self.value_targets = self.vtrace_returns.vs
        self.vf_loss = 0.5 * tf.reduce_sum(tf.square(delta))

        # The entropy loss
        self.entropy = tf.reduce_sum(
            tf.boolean_mask(actions_entropy, valid_mask))

        # The summed weighted loss
        self.total_loss = (self.pi_loss + self.vf_loss * vf_loss_coeff -
                           self.entropy * entropy_coeff)
|
||||
|
||||
|
||||
class APPOPostprocessing(object):
    """Adds the policy logits, VF preds, and advantages to the trajectory."""

    @override(TFPolicyGraph)
    def extra_compute_action_fetches(self):
        """Extend the base action fetches with logits and, optionally, VF.

        The value-function predictions are only added when V-trace is
        disabled.
        """
        extra = {"behaviour_logits": self.model.outputs}
        if not self.config["vtrace"]:
            extra["vf_preds"] = self.value_function
        fetches = dict(TFPolicyGraph.extra_compute_action_fetches(self))
        fetches.update(extra)
        return fetches

    @override(PolicyGraph)
    def postprocess_trajectory(self,
                               sample_batch,
                               other_agent_batches=None,
                               episode=None):
        """Attach advantages when V-trace is off, then drop ``new_obs``.

        ``other_agent_batches`` and ``episode`` are accepted but not used.
        """
        if self.config["vtrace"]:
            # V-trace needs no advantage postprocessing here.
            batch = sample_batch
        else:
            if sample_batch["dones"][-1]:
                last_r = 0.0
            else:
                # Bootstrap from the value of the last next-observation,
                # using the last emitted state of every state input.
                next_state = [
                    [sample_batch["state_out_{}".format(idx)][-1]]
                    for idx in range(len(self.model.state_in))
                ]
                last_r = self.value(sample_batch["new_obs"][-1], *next_state)
            batch = compute_advantages(
                sample_batch,
                last_r,
                self.config["gamma"],
                self.config["lambda"],
                use_gae=self.config["use_gae"])
        del batch.data["new_obs"]  # not used, so save some bandwidth
        return batch
|
||||
|
||||
|
||||
class AsyncPPOPolicyGraph(LearningRateSchedule, APPOPostprocessing,
                          TFPolicyGraph):
    """TF policy graph for asynchronous PPO (APPO).

    Builds the model, the input placeholders and one of two losses
    depending on ``config["vtrace"]``: ``VTraceSurrogateLoss`` when it is
    truthy, ``PPOSurrogateLoss`` otherwise.
    """

    def __init__(self,
                 observation_space,
                 action_space,
                 config,
                 existing_inputs=None):
        """Build the graph and initialize the parent classes.

        Args:
            observation_space: Observation space; its ``shape`` sizes the
                observation placeholder.
            action_space: A gym action space. With V-trace enabled only
                ``Discrete`` and ``MultiDiscrete`` are accepted.
            config: Dict of overrides merged on top of the IMPALA
                ``DEFAULT_CONFIG``.
            existing_inputs: Optional sequence of tensors to reuse instead
                of creating placeholders: the loss inputs (7 with V-trace,
                9 without), then the state inputs, then ``seq_lens`` last.

        Raises:
            AssertionError: If ``batch_mode`` is not ``truncate_episodes``.
            UnsupportedSpaceException: If V-trace is enabled and the action
                space is neither ``Discrete`` nor ``MultiDiscrete``.
        """
        config = dict(ray.rllib.agents.impala.impala.DEFAULT_CONFIG, **config)
        assert config["batch_mode"] == "truncate_episodes", \
            "Must use `truncate_episodes` batch mode with V-trace."
        self.config = config
        self.sess = tf.get_default_session()
        self.grads = None

        if isinstance(action_space, gym.spaces.Discrete):
            is_multidiscrete = False
            output_hidden_shape = [action_space.n]
        elif isinstance(action_space, gym.spaces.multi_discrete.MultiDiscrete):
            is_multidiscrete = True
            output_hidden_shape = action_space.nvec.astype(np.int32)
        elif self.config["vtrace"]:
            # Fixed: the template was previously passed unformatted, with
            # builtin format(action_space) as a second exception argument.
            raise UnsupportedSpaceException(
                "Action space {} is not supported for APPO + VTrace.".format(
                    action_space))
        else:
            is_multidiscrete = False
            output_hidden_shape = 1

        # Policy network model
        dist_class, logit_dim = ModelCatalog.get_action_dist(
            action_space, self.config["model"])

        # Create input placeholders
        if existing_inputs:
            if self.config["vtrace"]:
                actions, dones, behaviour_logits, rewards, observations, \
                    prev_actions, prev_rewards = existing_inputs[:7]
                existing_state_in = existing_inputs[7:-1]
                existing_seq_lens = existing_inputs[-1]
            else:
                actions, dones, behaviour_logits, rewards, observations, \
                    prev_actions, prev_rewards, adv_ph, value_targets = \
                    existing_inputs[:9]
                existing_state_in = existing_inputs[9:-1]
                existing_seq_lens = existing_inputs[-1]
        else:
            actions = ModelCatalog.get_action_placeholder(action_space)
            dones = tf.placeholder(tf.bool, [None], name="dones")
            rewards = tf.placeholder(tf.float32, [None], name="rewards")
            behaviour_logits = tf.placeholder(
                tf.float32, [None, logit_dim], name="behaviour_logits")
            observations = tf.placeholder(
                tf.float32, [None] + list(observation_space.shape))
            existing_state_in = None
            existing_seq_lens = None

            # Advantages and value targets are only fed when V-trace is off.
            if not self.config["vtrace"]:
                adv_ph = tf.placeholder(
                    tf.float32, name="advantages", shape=(None, ))
                value_targets = tf.placeholder(
                    tf.float32, name="value_targets", shape=(None, ))
        self.observations = observations

        # Unpack behaviour logits
        unpacked_behaviour_logits = tf.split(
            behaviour_logits, output_hidden_shape, axis=1)

        # Setup the policy
        # (same call and arguments as the get_action_dist call above)
        dist_class, logit_dim = ModelCatalog.get_action_dist(
            action_space, self.config["model"])
        # NOTE(review): these fresh placeholders replace any prev_actions /
        # prev_rewards unpacked from existing_inputs above -- confirm that
        # this is intended.
        prev_actions = ModelCatalog.get_action_placeholder(action_space)
        prev_rewards = tf.placeholder(tf.float32, [None], name="prev_reward")
        self.model = ModelCatalog.get_model(
            {
                "obs": observations,
                "prev_actions": prev_actions,
                "prev_rewards": prev_rewards,
                "is_training": self._get_is_training_placeholder(),
            },
            observation_space,
            action_space,
            logit_dim,
            self.config["model"],
            state_in=existing_state_in,
            seq_lens=existing_seq_lens)
        unpacked_outputs = tf.split(
            self.model.outputs, output_hidden_shape, axis=1)

        dist_inputs = unpacked_outputs if is_multidiscrete else \
            self.model.outputs
        prev_dist_inputs = unpacked_behaviour_logits if is_multidiscrete else \
            behaviour_logits

        action_dist = dist_class(dist_inputs)
        prev_action_dist = dist_class(prev_dist_inputs)

        values = self.model.value_function()
        self.value_function = values
        self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                          tf.get_variable_scope().name)

        def make_time_major(tensor, drop_last=False):
            """Swaps batch and trajectory axis.

            Args:
                tensor: A tensor or list of tensors to reshape.
                drop_last: A bool indicating whether to drop the last
                    trajectory item.

            Returns:
                res: A tensor with swapped axes or a list of tensors with
                    swapped axes.
            """
            if isinstance(tensor, list):
                return [make_time_major(t, drop_last) for t in tensor]

            if self.model.state_init:
                B = tf.shape(self.model.seq_lens)[0]
                T = tf.shape(tensor)[0] // B
            else:
                # Important: chop the tensor into batches at known episode cut
                # boundaries. TODO(ekl) this is kind of a hack
                T = self.config["sample_batch_size"]
                B = tf.shape(tensor)[0] // T
            rs = tf.reshape(tensor,
                            tf.concat([[B, T], tf.shape(tensor)[1:]], axis=0))

            # swap B and T axes
            res = tf.transpose(
                rs,
                [1, 0] + list(range(2, 1 + int(tf.shape(tensor).shape[0]))))

            if drop_last:
                return res[:-1]
            return res

        if self.model.state_in:
            max_seq_len = tf.reduce_max(self.model.seq_lens) - 1
            mask = tf.sequence_mask(self.model.seq_lens, max_seq_len)
            mask = tf.reshape(mask, [-1])
        else:
            # Bool, like the sequence_mask branch above and as the loss
            # classes document for valid_mask (was float32 before).
            mask = tf.ones_like(rewards, dtype=tf.bool)

        # Inputs are reshaped from [B * T] => [T - 1, B] for V-trace calc.
        if self.config["vtrace"]:
            logger.info("Using V-Trace surrogate loss (vtrace=True)")

            # Prepare actions for loss
            loss_actions = actions if is_multidiscrete else tf.expand_dims(
                actions, axis=1)

            self.loss = VTraceSurrogateLoss(
                actions=make_time_major(loss_actions, drop_last=True),
                prev_actions_logp=make_time_major(
                    prev_action_dist.logp(actions), drop_last=True),
                actions_logp=make_time_major(
                    action_dist.logp(actions), drop_last=True),
                action_kl=prev_action_dist.kl(action_dist),
                actions_entropy=make_time_major(
                    action_dist.entropy(), drop_last=True),
                dones=make_time_major(dones, drop_last=True),
                behaviour_logits=make_time_major(
                    unpacked_behaviour_logits, drop_last=True),
                target_logits=make_time_major(
                    unpacked_outputs, drop_last=True),
                discount=config["gamma"],
                rewards=make_time_major(rewards, drop_last=True),
                values=make_time_major(values, drop_last=True),
                bootstrap_value=make_time_major(values)[-1],
                valid_mask=make_time_major(mask, drop_last=True),
                vf_loss_coeff=self.config["vf_loss_coeff"],
                entropy_coeff=self.config["entropy_coeff"],
                clip_rho_threshold=self.config["vtrace_clip_rho_threshold"],
                clip_pg_rho_threshold=self.config[
                    "vtrace_clip_pg_rho_threshold"],
                clip_param=self.config["clip_param"])
        else:
            logger.info("Using PPO surrogate loss (vtrace=False)")
            self.loss = PPOSurrogateLoss(
                prev_actions_logp=make_time_major(
                    prev_action_dist.logp(actions)),
                actions_logp=make_time_major(action_dist.logp(actions)),
                action_kl=prev_action_dist.kl(action_dist),
                actions_entropy=make_time_major(action_dist.entropy()),
                values=make_time_major(values),
                valid_mask=make_time_major(mask),
                advantages=make_time_major(adv_ph),
                value_targets=make_time_major(value_targets),
                vf_loss_coeff=self.config["vf_loss_coeff"],
                entropy_coeff=self.config["entropy_coeff"],
                clip_param=self.config["clip_param"])

        # KL divergence between worker and learner logits for debugging
        model_dist = MultiCategorical(unpacked_outputs)
        behaviour_dist = MultiCategorical(unpacked_behaviour_logits)

        kls = model_dist.kl(behaviour_dist)
        if len(kls) > 1:
            # One set of stats per KL term, keyed by its index.
            self.KL_stats = {}

            for i, kl in enumerate(kls):
                self.KL_stats.update({
                    "mean_KL_{}".format(i): tf.reduce_mean(kl),
                    "max_KL_{}".format(i): tf.reduce_max(kl),
                    "median_KL_{}".format(i): tf.contrib.distributions.
                    percentile(kl, 50.0),
                })
        else:
            self.KL_stats = {
                "mean_KL": tf.reduce_mean(kls[0]),
                "max_KL": tf.reduce_max(kls[0]),
                "median_KL": tf.contrib.distributions.percentile(
                    kls[0], 50.0),
            }

        # Initialize TFPolicyGraph
        # The order here must match the unpacking of existing_inputs above.
        loss_in = [
            ("actions", actions),
            ("dones", dones),
            ("behaviour_logits", behaviour_logits),
            ("rewards", rewards),
            ("obs", observations),
            ("prev_actions", prev_actions),
            ("prev_rewards", prev_rewards),
        ]
        if not self.config["vtrace"]:
            loss_in.append(("advantages", adv_ph))
            loss_in.append(("value_targets", value_targets))
        LearningRateSchedule.__init__(self, self.config["lr"],
                                      self.config["lr_schedule"])
        TFPolicyGraph.__init__(
            self,
            observation_space,
            action_space,
            self.sess,
            obs_input=observations,
            action_sampler=action_dist.sample(),
            action_prob=action_dist.sampled_action_prob(),
            loss=self.loss.total_loss,
            model=self.model,
            loss_inputs=loss_in,
            state_inputs=self.model.state_in,
            state_outputs=self.model.state_out,
            prev_action_input=prev_actions,
            prev_reward_input=prev_rewards,
            seq_lens=self.model.seq_lens,
            max_seq_len=self.config["model"]["max_seq_len"],
            batch_divisibility_req=self.config["sample_batch_size"])

        self.sess.run(tf.global_variables_initializer())

        values_batched = make_time_major(
            values, drop_last=self.config["vtrace"])
        # NOTE(review): self._grads is not assigned in this class;
        # presumably TFPolicyGraph.__init__ sets it -- verify.
        self.stats_fetches = {
            LEARNER_STATS_KEY: dict({
                "cur_lr": tf.cast(self.cur_lr, tf.float64),
                "policy_loss": self.loss.pi_loss,
                "entropy": self.loss.entropy,
                "grad_gnorm": tf.global_norm(self._grads),
                "var_gnorm": tf.global_norm(self.var_list),
                "vf_loss": self.loss.vf_loss,
                "vf_explained_var": explained_variance(
                    tf.reshape(self.loss.value_targets, [-1]),
                    tf.reshape(values_batched, [-1])),
            }, **self.KL_stats),
        }

    def optimizer(self):
        """Return Adam if ``opt_type`` is "adam", otherwise RMSProp."""
        if self.config["opt_type"] == "adam":
            return tf.train.AdamOptimizer(self.cur_lr)
        else:
            return tf.train.RMSPropOptimizer(self.cur_lr, self.config["decay"],
                                             self.config["momentum"],
                                             self.config["epsilon"])

    def gradients(self, optimizer, loss):
        """Return (gradient, variable) pairs clipped by global norm.

        The clipped gradients are also stored on ``self.grads``. The
        ``optimizer`` argument is not used.
        """
        grads = tf.gradients(loss, self.var_list)
        self.grads, _ = tf.clip_by_global_norm(grads, self.config["grad_clip"])
        clipped_grads = list(zip(self.grads, self.var_list))
        return clipped_grads

    def extra_compute_grad_fetches(self):
        """Return the learner stats dict built in ``__init__``."""
        return self.stats_fetches

    def value(self, ob, *args):
        """Run the value function for a single observation.

        Args:
            ob: One observation; it is fed as a batch of size one.
            *args: One value per entry of ``self.model.state_in``.

        Returns:
            The first element of the value-function output.
        """
        feed_dict = {self.observations: [ob], self.model.seq_lens: [1]}
        assert len(args) == len(self.model.state_in), \
            (args, self.model.state_in)
        for k, v in zip(self.model.state_in, args):
            feed_dict[k] = v
        vf = self.sess.run(self.value_function, feed_dict)
        return vf[0]

    def get_initial_state(self):
        """Return the model's initial state (``self.model.state_init``)."""
        return self.model.state_init

    def copy(self, existing_inputs):
        """Build a new policy graph of this class on ``existing_inputs``."""
        return AsyncPPOPolicyGraph(
            self.observation_space,
            self.action_space,
            self.config,
            existing_inputs=existing_inputs)
|
||||
|
||||
@@ -71,30 +71,29 @@ class PPOAgent(Agent):
|
||||
_policy_graph = PPOPolicyGraph
|
||||
|
||||
@override(Agent)
|
||||
def _init(self):
|
||||
def _init(self, config, env_creator):
|
||||
self._validate_config()
|
||||
self.local_evaluator = self.make_local_evaluator(
|
||||
self.env_creator, self._policy_graph)
|
||||
env_creator, self._policy_graph)
|
||||
self.remote_evaluators = self.make_remote_evaluators(
|
||||
self.env_creator, self._policy_graph, self.config["num_workers"])
|
||||
if self.config["simple_optimizer"]:
|
||||
env_creator, self._policy_graph, config["num_workers"])
|
||||
if config["simple_optimizer"]:
|
||||
self.optimizer = SyncSamplesOptimizer(
|
||||
self.local_evaluator, self.remote_evaluators, {
|
||||
"num_sgd_iter": self.config["num_sgd_iter"],
|
||||
"train_batch_size": self.config["train_batch_size"],
|
||||
"num_sgd_iter": config["num_sgd_iter"],
|
||||
"train_batch_size": config["train_batch_size"],
|
||||
})
|
||||
else:
|
||||
self.optimizer = LocalMultiGPUOptimizer(
|
||||
self.local_evaluator, self.remote_evaluators, {
|
||||
"sgd_batch_size": self.config["sgd_minibatch_size"],
|
||||
"num_sgd_iter": self.config["num_sgd_iter"],
|
||||
"num_gpus": self.config["num_gpus"],
|
||||
"sample_batch_size": self.config["sample_batch_size"],
|
||||
"num_envs_per_worker": self.config["num_envs_per_worker"],
|
||||
"train_batch_size": self.config["train_batch_size"],
|
||||
"sgd_batch_size": config["sgd_minibatch_size"],
|
||||
"num_sgd_iter": config["num_sgd_iter"],
|
||||
"num_gpus": config["num_gpus"],
|
||||
"sample_batch_size": config["sample_batch_size"],
|
||||
"num_envs_per_worker": config["num_envs_per_worker"],
|
||||
"train_batch_size": config["train_batch_size"],
|
||||
"standardize_fields": ["advantages"],
|
||||
"straggler_mitigation": (
|
||||
self.config["straggler_mitigation"]),
|
||||
"straggler_mitigation": config["straggler_mitigation"],
|
||||
})
|
||||
|
||||
@override(Agent)
|
||||
|
||||
@@ -6,9 +6,11 @@ import logging
|
||||
import tensorflow as tf
|
||||
|
||||
import ray
|
||||
from ray.rllib.evaluation.postprocessing import compute_advantages, \
|
||||
Postprocessing
|
||||
from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
|
||||
from ray.rllib.evaluation.postprocessing import compute_advantages
|
||||
from ray.rllib.evaluation.policy_graph import PolicyGraph
|
||||
from ray.rllib.evaluation.sample_batch import SampleBatch
|
||||
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph, \
|
||||
LearningRateSchedule
|
||||
from ray.rllib.models.catalog import ModelCatalog
|
||||
@@ -17,6 +19,9 @@ from ray.rllib.utils.explained_variance import explained_variance
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Frozen logits of the policy that computed the action
|
||||
BEHAVIOUR_LOGITS = "behaviour_logits"
|
||||
|
||||
|
||||
class PPOLoss(object):
|
||||
def __init__(self,
|
||||
@@ -100,7 +105,43 @@ class PPOLoss(object):
|
||||
self.loss = loss
|
||||
|
||||
|
||||
class PPOPolicyGraph(LearningRateSchedule, TFPolicyGraph):
|
||||
class PPOPostprocessing(object):
|
||||
"""Adds the policy logits, VF preds, and advantages to the trajectory."""
|
||||
|
||||
@override(TFPolicyGraph)
|
||||
def extra_compute_action_fetches(self):
|
||||
return dict(
|
||||
TFPolicyGraph.extra_compute_action_fetches(self), **{
|
||||
SampleBatch.VF_PREDS: self.value_function,
|
||||
BEHAVIOUR_LOGITS: self.logits
|
||||
})
|
||||
|
||||
@override(PolicyGraph)
|
||||
def postprocess_trajectory(self,
|
||||
sample_batch,
|
||||
other_agent_batches=None,
|
||||
episode=None):
|
||||
completed = sample_batch["dones"][-1]
|
||||
if completed:
|
||||
last_r = 0.0
|
||||
else:
|
||||
next_state = []
|
||||
for i in range(len(self.model.state_in)):
|
||||
next_state.append([sample_batch["state_out_{}".format(i)][-1]])
|
||||
last_r = self._value(sample_batch[SampleBatch.NEXT_OBS][-1],
|
||||
sample_batch[SampleBatch.ACTIONS][-1],
|
||||
sample_batch[SampleBatch.REWARDS][-1],
|
||||
*next_state)
|
||||
batch = compute_advantages(
|
||||
sample_batch,
|
||||
last_r,
|
||||
self.config["gamma"],
|
||||
self.config["lambda"],
|
||||
use_gae=self.config["use_gae"])
|
||||
return batch
|
||||
|
||||
|
||||
class PPOPolicyGraph(LearningRateSchedule, PPOPostprocessing, TFPolicyGraph):
|
||||
def __init__(self,
|
||||
observation_space,
|
||||
action_space,
|
||||
@@ -153,14 +194,14 @@ class PPOPolicyGraph(LearningRateSchedule, TFPolicyGraph):
|
||||
self.prev_rewards = prev_rewards_ph
|
||||
|
||||
self.loss_in = [
|
||||
("obs", obs_ph),
|
||||
("value_targets", value_targets_ph),
|
||||
("advantages", adv_ph),
|
||||
("actions", act_ph),
|
||||
("logits", logits_ph),
|
||||
("vf_preds", vf_preds_ph),
|
||||
("prev_actions", prev_actions_ph),
|
||||
("prev_rewards", prev_rewards_ph),
|
||||
(SampleBatch.CUR_OBS, obs_ph),
|
||||
(Postprocessing.VALUE_TARGETS, value_targets_ph),
|
||||
(Postprocessing.ADVANTAGES, adv_ph),
|
||||
(SampleBatch.ACTIONS, act_ph),
|
||||
(BEHAVIOUR_LOGITS, logits_ph),
|
||||
(SampleBatch.VF_PREDS, vf_preds_ph),
|
||||
(SampleBatch.PREV_ACTIONS, prev_actions_ph),
|
||||
(SampleBatch.PREV_REWARDS, prev_rewards_ph),
|
||||
]
|
||||
self.model = ModelCatalog.get_model(
|
||||
{
|
||||
@@ -282,29 +323,6 @@ class PPOPolicyGraph(LearningRateSchedule, TFPolicyGraph):
|
||||
self.config,
|
||||
existing_inputs=existing_inputs)
|
||||
|
||||
@override(PolicyGraph)
|
||||
def postprocess_trajectory(self,
|
||||
sample_batch,
|
||||
other_agent_batches=None,
|
||||
episode=None):
|
||||
completed = sample_batch["dones"][-1]
|
||||
if completed:
|
||||
last_r = 0.0
|
||||
else:
|
||||
next_state = []
|
||||
for i in range(len(self.model.state_in)):
|
||||
next_state.append([sample_batch["state_out_{}".format(i)][-1]])
|
||||
last_r = self._value(sample_batch["new_obs"][-1],
|
||||
sample_batch["actions"][-1],
|
||||
sample_batch["rewards"][-1], *next_state)
|
||||
batch = compute_advantages(
|
||||
sample_batch,
|
||||
last_r,
|
||||
self.config["gamma"],
|
||||
self.config["lambda"],
|
||||
use_gae=self.config["use_gae"])
|
||||
return batch
|
||||
|
||||
@override(TFPolicyGraph)
|
||||
def gradients(self, optimizer, loss):
|
||||
if self.config["grad_clip"] is not None:
|
||||
@@ -323,14 +341,6 @@ class PPOPolicyGraph(LearningRateSchedule, TFPolicyGraph):
|
||||
def get_initial_state(self):
|
||||
return self.model.state_init
|
||||
|
||||
@override(TFPolicyGraph)
|
||||
def extra_compute_action_fetches(self):
|
||||
return dict(
|
||||
TFPolicyGraph.extra_compute_action_fetches(self), **{
|
||||
"vf_preds": self.value_function,
|
||||
"logits": self.logits
|
||||
})
|
||||
|
||||
@override(TFPolicyGraph)
|
||||
def extra_compute_grad_fetches(self):
|
||||
return {LEARNER_STATS_KEY: self.stats_fetches}
|
||||
|
||||
@@ -15,6 +15,7 @@ from ray.rllib.agents.qmix.mixers import VDNMixer, QMixer
|
||||
from ray.rllib.agents.qmix.model import RNNModel, _get_size
|
||||
from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
|
||||
from ray.rllib.evaluation.policy_graph import PolicyGraph
|
||||
from ray.rllib.evaluation.sample_batch import SampleBatch
|
||||
from ray.rllib.models.action_dist import TupleActions
|
||||
from ray.rllib.models.catalog import ModelCatalog
|
||||
from ray.rllib.models.lstm import chop_into_sequences
|
||||
@@ -236,16 +237,17 @@ class QMixPolicyGraph(PolicyGraph):
|
||||
|
||||
@override(PolicyGraph)
|
||||
def learn_on_batch(self, samples):
|
||||
obs_batch, action_mask = self._unpack_observation(samples["obs"])
|
||||
group_rewards = self._get_group_rewards(samples["infos"])
|
||||
obs_batch, action_mask = self._unpack_observation(
|
||||
samples[SampleBatch.CUR_OBS])
|
||||
group_rewards = self._get_group_rewards(samples[SampleBatch.INFOS])
|
||||
|
||||
# These will be padded to shape [B * T, ...]
|
||||
[rew, action_mask, act, dones, obs], initial_states, seq_lens = \
|
||||
chop_into_sequences(
|
||||
samples["eps_id"],
|
||||
samples["agent_index"], [
|
||||
group_rewards, action_mask, samples["actions"],
|
||||
samples["dones"], obs_batch
|
||||
samples[SampleBatch.EPS_ID],
|
||||
samples[SampleBatch.AGENT_INDEX], [
|
||||
group_rewards, action_mask, samples[SampleBatch.ACTIONS],
|
||||
samples[SampleBatch.DONES], obs_batch
|
||||
],
|
||||
[samples["state_in_{}".format(k)]
|
||||
for k in range(len(self.get_initial_state()))],
|
||||
|
||||
@@ -19,8 +19,8 @@ class RandomAgent(Agent):
|
||||
})
|
||||
|
||||
@override(Agent)
|
||||
def _init(self):
|
||||
self.env = self.env_creator(self.config["env_config"])
|
||||
def _init(self, config, env_creator):
|
||||
self.env = env_creator(config["env_config"])
|
||||
|
||||
@override(Agent)
|
||||
def _train(self):
|
||||
|
||||
Vendored
+1
-1
@@ -186,7 +186,7 @@ class BaseEnv(object):
|
||||
|
||||
|
||||
# Fixed agent identifier used when there is only a single agent in the env
|
||||
_DUMMY_AGENT_ID = "singleton_agent"
|
||||
_DUMMY_AGENT_ID = "agent0"
|
||||
|
||||
|
||||
def _with_dummy_agent_id(env_id_to_values, dummy_id=_DUMMY_AGENT_ID):
|
||||
|
||||
@@ -8,14 +8,12 @@ from ray.rllib.evaluation.sample_batch import SampleBatch, MultiAgentBatch
|
||||
from ray.rllib.evaluation.sample_batch_builder import (
|
||||
SampleBatchBuilder, MultiAgentSampleBatchBuilder)
|
||||
from ray.rllib.evaluation.sampler import SyncSampler, AsyncSampler
|
||||
from ray.rllib.evaluation.postprocessing import (compute_advantages,
|
||||
compute_targets)
|
||||
from ray.rllib.evaluation.postprocessing import compute_advantages
|
||||
from ray.rllib.evaluation.metrics import collect_metrics
|
||||
|
||||
__all__ = [
|
||||
"EvaluatorInterface", "PolicyEvaluator", "PolicyGraph", "TFPolicyGraph",
|
||||
"TorchPolicyGraph", "SampleBatch", "MultiAgentBatch", "SampleBatchBuilder",
|
||||
"MultiAgentSampleBatchBuilder", "SyncSampler", "AsyncSampler",
|
||||
"compute_advantages", "compute_targets", "collect_metrics",
|
||||
"MultiAgentEpisode"
|
||||
"compute_advantages", "collect_metrics", "MultiAgentEpisode"
|
||||
]
|
||||
|
||||
@@ -12,6 +12,13 @@ def discount(x, gamma):
|
||||
return scipy.signal.lfilter([1], [1, -gamma], x[::-1], axis=0)[::-1]
|
||||
|
||||
|
||||
class Postprocessing(object):
|
||||
"""Constant definitions for postprocessing."""
|
||||
|
||||
ADVANTAGES = "advantages"
|
||||
VALUE_TARGETS = "value_targets"
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
def compute_advantages(rollout, last_r, gamma=0.9, lambda_=1.0, use_gae=True):
|
||||
"""Given a rollout, compute its value targets and the advantage.
|
||||
@@ -29,52 +36,35 @@ def compute_advantages(rollout, last_r, gamma=0.9, lambda_=1.0, use_gae=True):
|
||||
"""
|
||||
|
||||
traj = {}
|
||||
trajsize = len(rollout["actions"])
|
||||
trajsize = len(rollout[SampleBatch.ACTIONS])
|
||||
for key in rollout:
|
||||
traj[key] = np.stack(rollout[key])
|
||||
|
||||
if use_gae:
|
||||
assert "vf_preds" in rollout, "Values not found!"
|
||||
vpred_t = np.concatenate([rollout["vf_preds"], np.array([last_r])])
|
||||
delta_t = traj["rewards"] + gamma * vpred_t[1:] - vpred_t[:-1]
|
||||
assert SampleBatch.VF_PREDS in rollout, "Values not found!"
|
||||
vpred_t = np.concatenate(
|
||||
[rollout[SampleBatch.VF_PREDS],
|
||||
np.array([last_r])])
|
||||
delta_t = (
|
||||
traj[SampleBatch.REWARDS] + gamma * vpred_t[1:] - vpred_t[:-1])
|
||||
# This formula for the advantage comes from
|
||||
# "Generalized Advantage Estimation": https://arxiv.org/abs/1506.02438
|
||||
traj["advantages"] = discount(delta_t, gamma * lambda_)
|
||||
traj["value_targets"] = (
|
||||
traj["advantages"] + traj["vf_preds"]).copy().astype(np.float32)
|
||||
traj[Postprocessing.ADVANTAGES] = discount(delta_t, gamma * lambda_)
|
||||
traj[Postprocessing.VALUE_TARGETS] = (
|
||||
traj[Postprocessing.ADVANTAGES] +
|
||||
traj[SampleBatch.VF_PREDS]).copy().astype(np.float32)
|
||||
else:
|
||||
rewards_plus_v = np.concatenate(
|
||||
[rollout["rewards"], np.array([last_r])])
|
||||
traj["advantages"] = discount(rewards_plus_v, gamma)[:-1]
|
||||
[rollout[SampleBatch.REWARDS],
|
||||
np.array([last_r])])
|
||||
traj[Postprocessing.ADVANTAGES] = discount(rewards_plus_v, gamma)[:-1]
|
||||
# TODO(ekl): support using a critic without GAE
|
||||
traj["value_targets"] = np.zeros_like(traj["advantages"])
|
||||
traj[Postprocessing.VALUE_TARGETS] = np.zeros_like(
|
||||
traj[Postprocessing.ADVANTAGES])
|
||||
|
||||
traj["advantages"] = traj["advantages"].copy().astype(np.float32)
|
||||
traj[Postprocessing.ADVANTAGES] = traj[
|
||||
Postprocessing.ADVANTAGES].copy().astype(np.float32)
|
||||
|
||||
assert all(val.shape[0] == trajsize for val in traj.values()), \
|
||||
"Rollout stacked incorrectly!"
|
||||
return SampleBatch(traj)
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
def compute_targets(rollout, action_space, last_r=0.0, gamma=0.9, lambda_=1.0):
|
||||
"""Given a rollout, compute targets.
|
||||
|
||||
Used for categorical crossentropy loss on the policy. Also assumes
|
||||
there is a value function. Uses GAE to calculate advantages.
|
||||
|
||||
Args:
|
||||
rollout (SampleBatch): SampleBatch of a single trajectory
|
||||
action_space (gym.Space): Dimensions of the advantage targets.
|
||||
last_r (float): Value estimation for last observation
|
||||
gamma (float): Discount factor.
|
||||
lambda_ (float): Parameter for GAE
|
||||
"""
|
||||
|
||||
rollout = compute_advantages(rollout, last_r, gamma=gamma, lambda_=lambda_)
|
||||
rollout["adv_targets"] = np.zeros((rollout.count, action_space.n))
|
||||
rollout["adv_targets"][np.arange(rollout.count), rollout["actions"]] = \
|
||||
rollout["advantages"]
|
||||
rollout["value_targets"] = rollout["rewards"].copy()
|
||||
rollout["value_targets"][:-1] += gamma * rollout["vf_preds"][1:]
|
||||
return rollout
|
||||
|
||||
@@ -82,6 +82,25 @@ class SampleBatch(object):
|
||||
samples, each with an "obs" and "reward" attribute.
|
||||
"""
|
||||
|
||||
# Outputs from interacting with the environment
|
||||
CUR_OBS = "obs"
|
||||
NEXT_OBS = "new_obs"
|
||||
ACTIONS = "actions"
|
||||
REWARDS = "rewards"
|
||||
PREV_ACTIONS = "prev_actions"
|
||||
PREV_REWARDS = "prev_rewards"
|
||||
DONES = "dones"
|
||||
INFOS = "infos"
|
||||
|
||||
# Uniquely identifies an episode
|
||||
EPS_ID = "eps_id"
|
||||
|
||||
# Uniquely identifies an agent within an episode
|
||||
AGENT_INDEX = "agent_index"
|
||||
|
||||
# Value function predictions emitted by the behaviour policy
|
||||
VF_PREDS = "vf_preds"
|
||||
|
||||
@PublicAPI
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""Constructs a sample batch (same params as dict constructor)."""
|
||||
|
||||
@@ -529,7 +529,7 @@ def _do_policy_eval(tf_sess, to_eval, policies, active_episodes):
|
||||
builder = None
|
||||
|
||||
if log_once("compute_actions_input"):
|
||||
logger.info("Example compute_actions() input:\n\n{}\n".format(
|
||||
logger.info("Inputs to compute_actions():\n\n{}\n".format(
|
||||
summarize(to_eval)))
|
||||
|
||||
for policy_id, eval_data in to_eval.items():
|
||||
@@ -556,7 +556,7 @@ def _do_policy_eval(tf_sess, to_eval, policies, active_episodes):
|
||||
eval_results[k] = builder.get(v)
|
||||
|
||||
if log_once("compute_actions_result"):
|
||||
logger.info("Example compute_actions() result:\n\n{}\n".format(
|
||||
logger.info("Outputs of compute_actions():\n\n{}\n".format(
|
||||
summarize(eval_results)))
|
||||
|
||||
return eval_results
|
||||
|
||||
@@ -12,6 +12,7 @@ import ray
|
||||
import ray.experimental.tf_utils
|
||||
from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
|
||||
from ray.rllib.evaluation.policy_graph import PolicyGraph
|
||||
from ray.rllib.evaluation.sample_batch import SampleBatch
|
||||
from ray.rllib.models.lstm import chop_into_sequences
|
||||
from ray.rllib.utils.annotations import override, DeveloperAPI
|
||||
from ray.rllib.utils.debug import log_once, summarize
|
||||
@@ -437,8 +438,9 @@ class TFPolicyGraph(PolicyGraph):
|
||||
feed_dict = {}
|
||||
if self._batch_divisibility_req > 1:
|
||||
meets_divisibility_reqs = (
|
||||
len(batch["obs"]) % self._batch_divisibility_req == 0
|
||||
and max(batch["agent_index"]) == 0) # not multiagent
|
||||
len(batch[SampleBatch.CUR_OBS]) %
|
||||
self._batch_divisibility_req == 0
|
||||
and max(batch[SampleBatch.AGENT_INDEX]) == 0) # not multiagent
|
||||
else:
|
||||
meets_divisibility_reqs = True
|
||||
|
||||
@@ -461,8 +463,8 @@ class TFPolicyGraph(PolicyGraph):
|
||||
"state_in_{}".format(i) for i in range(len(self._state_inputs))
|
||||
]
|
||||
feature_sequences, initial_states, seq_lens = chop_into_sequences(
|
||||
batch["eps_id"],
|
||||
batch["agent_index"], [batch[k] for k in feature_keys],
|
||||
batch[SampleBatch.EPS_ID],
|
||||
batch[SampleBatch.AGENT_INDEX], [batch[k] for k in feature_keys],
|
||||
[batch[k] for k in state_keys],
|
||||
max_seq_len,
|
||||
dynamic_max=dynamic_max)
|
||||
|
||||
@@ -63,7 +63,7 @@ class PolicyOptimizer(object):
|
||||
config, self))
|
||||
|
||||
@DeveloperAPI
|
||||
def _init(self):
|
||||
def _init(self, **config):
|
||||
"""Subclasses should prefer overriding this instead of __init__."""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -9,6 +9,8 @@ import time
|
||||
import tensorflow as tf
|
||||
from tensorflow.python.client import timeline
|
||||
|
||||
from ray.rllib.utils.debug import log_once
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -75,12 +77,16 @@ def run_timeline(sess, ops, debug_name, feed_dict={}, timeline_dir=None):
|
||||
global _count
|
||||
outf = os.path.join(
|
||||
timeline_dir, "timeline-{}-{}-{}.json".format(
|
||||
debug_name, os.getpid(), _count))
|
||||
debug_name, os.getpid(), _count % 10))
|
||||
_count += 1
|
||||
trace_file = open(outf, "w")
|
||||
logger.info("Wrote tf timeline ({} s) to {}".format(
|
||||
time.time() - start, os.path.abspath(outf)))
|
||||
trace_file.write(trace.generate_chrome_trace_format())
|
||||
else:
|
||||
if log_once("tf_timeline"):
|
||||
logger.info(
|
||||
"Executing TF run without tracing. To dump TF timeline traces "
|
||||
"to disk, set the TF_TIMELINE_DIR environment variable.")
|
||||
fetches = sess.run(ops, feed_dict=feed_dict)
|
||||
return fetches
|
||||
|
||||
Reference in New Issue
Block a user