mirror of
https://github.com/wassname/ray.git
synced 2026-07-01 08:53:44 +08:00
[rllib] Port IMPALA to ModelV2/build_tf_policy (#5130)
* port vtrace * fix vf * fix vs * fix the example * wip ddpg * fix tests * fix tests * remove ddpg model * comments * set vf share layers True by default * typo * fix test
This commit is contained in:
@@ -75,6 +75,12 @@ class DDPGTFPolicy(DDPGPostprocessing, TFPolicy):
|
||||
raise UnsupportedSpaceException(
|
||||
"Action space {} is not supported for DDPG.".format(
|
||||
action_space))
|
||||
if len(action_space.shape) > 1:
|
||||
raise UnsupportedSpaceException(
|
||||
"Action space has multiple dimensions "
|
||||
"{}. ".format(action_space.shape) +
|
||||
"Consider reshaping this into a single dimension, "
|
||||
"using a Tuple action space, or the multi-agent API.")
|
||||
|
||||
self.config = config
|
||||
self.cur_noise_scale = 1.0
|
||||
|
||||
@@ -166,6 +166,12 @@ class DistributionalQModel(TFModelV2):
|
||||
self.state_value_head = tf.keras.Model(self.model_out, state_out)
|
||||
self.register_variables(self.state_value_head.variables)
|
||||
|
||||
def forward(self, input_dict, state, seq_lens):
|
||||
"""This generates the model_out tensor input.
|
||||
|
||||
You must implement this as documented in modelv2.py."""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_q_value_distributions(self, model_out):
|
||||
"""Returns distributional values for Q(s, a) given a state embedding.
|
||||
|
||||
|
||||
@@ -191,7 +191,8 @@ def build_q_model(policy, obs_space, action_space, config):
|
||||
"Action space {} is not supported for DQN.".format(action_space))
|
||||
|
||||
if config["hiddens"]:
|
||||
num_outputs = 256
|
||||
# try to infer the last layer size, otherwise fall back to 256
|
||||
num_outputs = ([256] + config["model"]["fcnet_hiddens"])[-1]
|
||||
config["model"]["no_final_linear"] = True
|
||||
else:
|
||||
num_outputs = action_space.n
|
||||
|
||||
@@ -54,6 +54,12 @@ class SimpleQModel(TFModelV2):
|
||||
self.q_value_head = tf.keras.Model(self.model_out, q_out)
|
||||
self.register_variables(self.q_value_head.variables)
|
||||
|
||||
def forward(self, input_dict, state, seq_lens):
|
||||
"""This generates the model_out tensor input.
|
||||
|
||||
You must implement this as documented in modelv2.py."""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_q_values(self, model_out):
|
||||
"""Returns Q(s, a) given a feature tensor for the state.
|
||||
|
||||
|
||||
@@ -6,24 +6,23 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import gym
|
||||
import ray
|
||||
import numpy as np
|
||||
import logging
|
||||
import gym
|
||||
|
||||
import ray
|
||||
from ray.rllib.agents.impala import vtrace
|
||||
from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
|
||||
from ray.rllib.policy.policy import Policy
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.policy.tf_policy import TFPolicy, \
|
||||
LearningRateSchedule
|
||||
from ray.rllib.models.action_dist import Categorical
|
||||
from ray.rllib.models.catalog import ModelCatalog
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.policy.tf_policy_template import build_tf_policy
|
||||
from ray.rllib.policy.tf_policy import LearningRateSchedule
|
||||
from ray.rllib.utils.explained_variance import explained_variance
|
||||
from ray.rllib.utils import try_import_tf
|
||||
|
||||
tf = try_import_tf()
|
||||
|
||||
# Frozen logits of the policy that computed the action
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
BEHAVIOUR_LOGITS = "behaviour_logits"
|
||||
|
||||
|
||||
@@ -88,6 +87,7 @@ class VTraceLoss(object):
|
||||
clip_rho_threshold=tf.cast(clip_rho_threshold, tf.float32),
|
||||
clip_pg_rho_threshold=tf.cast(clip_pg_rho_threshold,
|
||||
tf.float32))
|
||||
self.value_targets = self.vtrace_returns.vs
|
||||
|
||||
# The policy gradients loss
|
||||
self.pi_loss = -tf.reduce_sum(
|
||||
@@ -107,237 +107,191 @@ class VTraceLoss(object):
|
||||
self.entropy * entropy_coeff)
|
||||
|
||||
|
||||
class VTracePostprocessing(object):
|
||||
"""Adds the policy logits to the trajectory."""
|
||||
def _make_time_major(policy, tensor, drop_last=False):
|
||||
"""Swaps batch and trajectory axis.
|
||||
|
||||
@override(TFPolicy)
|
||||
def extra_compute_action_fetches(self):
|
||||
return dict(
|
||||
TFPolicy.extra_compute_action_fetches(self),
|
||||
**{BEHAVIOUR_LOGITS: self.model.outputs})
|
||||
Arguments:
|
||||
policy: Policy reference
|
||||
tensor: A tensor or list of tensors to reshape.
|
||||
drop_last: A bool indicating whether to drop the last
|
||||
trajectory item.
|
||||
|
||||
@override(Policy)
|
||||
def postprocess_trajectory(self,
|
||||
sample_batch,
|
||||
other_agent_batches=None,
|
||||
episode=None):
|
||||
# not used, so save some bandwidth
|
||||
del sample_batch.data[SampleBatch.NEXT_OBS]
|
||||
return sample_batch
|
||||
Returns:
|
||||
res: A tensor with swapped axes or a list of tensors with
|
||||
swapped axes.
|
||||
"""
|
||||
if isinstance(tensor, list):
|
||||
return [_make_time_major(policy, t, drop_last) for t in tensor]
|
||||
|
||||
if policy.state_in:
|
||||
B = tf.shape(policy.seq_lens)[0]
|
||||
T = tf.shape(tensor)[0] // B
|
||||
else:
|
||||
# Important: chop the tensor into batches at known episode cut
|
||||
# boundaries. TODO(ekl) this is kind of a hack
|
||||
T = policy.config["sample_batch_size"]
|
||||
B = tf.shape(tensor)[0] // T
|
||||
rs = tf.reshape(tensor, tf.concat([[B, T], tf.shape(tensor)[1:]], axis=0))
|
||||
|
||||
# swap B and T axes
|
||||
res = tf.transpose(
|
||||
rs, [1, 0] + list(range(2, 1 + int(tf.shape(tensor).shape[0]))))
|
||||
|
||||
if drop_last:
|
||||
return res[:-1]
|
||||
return res
|
||||
|
||||
|
||||
class VTraceTFPolicy(LearningRateSchedule, VTracePostprocessing, TFPolicy):
|
||||
def __init__(self,
|
||||
observation_space,
|
||||
action_space,
|
||||
config,
|
||||
existing_inputs=None):
|
||||
config = dict(ray.rllib.agents.impala.impala.DEFAULT_CONFIG, **config)
|
||||
assert config["batch_mode"] == "truncate_episodes", \
|
||||
"Must use `truncate_episodes` batch mode with V-trace."
|
||||
self.config = config
|
||||
self.sess = tf.get_default_session()
|
||||
self.grads = None
|
||||
def build_vtrace_loss(policy, batch_tensors):
|
||||
if isinstance(policy.action_space, gym.spaces.Discrete):
|
||||
is_multidiscrete = False
|
||||
output_hidden_shape = [policy.action_space.n]
|
||||
elif isinstance(policy.action_space,
|
||||
gym.spaces.multi_discrete.MultiDiscrete):
|
||||
is_multidiscrete = True
|
||||
output_hidden_shape = policy.action_space.nvec.astype(np.int32)
|
||||
else:
|
||||
is_multidiscrete = False
|
||||
output_hidden_shape = 1
|
||||
|
||||
if isinstance(action_space, gym.spaces.Discrete):
|
||||
is_multidiscrete = False
|
||||
output_hidden_shape = [action_space.n]
|
||||
elif isinstance(action_space, gym.spaces.multi_discrete.MultiDiscrete):
|
||||
is_multidiscrete = True
|
||||
output_hidden_shape = action_space.nvec.astype(np.int32)
|
||||
else:
|
||||
is_multidiscrete = False
|
||||
output_hidden_shape = 1
|
||||
def make_time_major(*args, **kw):
|
||||
return _make_time_major(policy, *args, **kw)
|
||||
|
||||
# Create input placeholders
|
||||
dist_class, logit_dim = ModelCatalog.get_action_dist(
|
||||
action_space, self.config["model"])
|
||||
if existing_inputs:
|
||||
actions, dones, behaviour_logits, rewards, observations, \
|
||||
prev_actions, prev_rewards = existing_inputs[:7]
|
||||
existing_state_in = existing_inputs[7:-1]
|
||||
existing_seq_lens = existing_inputs[-1]
|
||||
else:
|
||||
actions = ModelCatalog.get_action_placeholder(action_space)
|
||||
dones = tf.placeholder(tf.bool, [None], name="dones")
|
||||
rewards = tf.placeholder(tf.float32, [None], name="rewards")
|
||||
behaviour_logits = tf.placeholder(
|
||||
tf.float32, [None, logit_dim], name="behaviour_logits")
|
||||
observations = tf.placeholder(
|
||||
tf.float32, [None] + list(observation_space.shape))
|
||||
existing_state_in = None
|
||||
existing_seq_lens = None
|
||||
actions = batch_tensors[SampleBatch.ACTIONS]
|
||||
dones = batch_tensors[SampleBatch.DONES]
|
||||
rewards = batch_tensors[SampleBatch.REWARDS]
|
||||
behaviour_logits = batch_tensors[BEHAVIOUR_LOGITS]
|
||||
unpacked_behaviour_logits = tf.split(
|
||||
behaviour_logits, output_hidden_shape, axis=1)
|
||||
unpacked_outputs = tf.split(policy.model_out, output_hidden_shape, axis=1)
|
||||
action_dist = policy.action_dist
|
||||
values = policy.value_function
|
||||
|
||||
# Unpack behaviour logits
|
||||
unpacked_behaviour_logits = tf.split(
|
||||
behaviour_logits, output_hidden_shape, axis=1)
|
||||
if policy.state_in:
|
||||
max_seq_len = tf.reduce_max(policy.seq_lens) - 1
|
||||
mask = tf.sequence_mask(policy.seq_lens, max_seq_len)
|
||||
mask = tf.reshape(mask, [-1])
|
||||
else:
|
||||
mask = tf.ones_like(rewards)
|
||||
|
||||
# Setup the policy
|
||||
prev_actions = ModelCatalog.get_action_placeholder(action_space)
|
||||
prev_rewards = tf.placeholder(tf.float32, [None], name="prev_reward")
|
||||
self.model = ModelCatalog.get_model(
|
||||
{
|
||||
"obs": observations,
|
||||
"prev_actions": prev_actions,
|
||||
"prev_rewards": prev_rewards,
|
||||
"is_training": self._get_is_training_placeholder(),
|
||||
},
|
||||
observation_space,
|
||||
action_space,
|
||||
logit_dim,
|
||||
self.config["model"],
|
||||
state_in=existing_state_in,
|
||||
seq_lens=existing_seq_lens)
|
||||
unpacked_outputs = tf.split(
|
||||
self.model.outputs, output_hidden_shape, axis=1)
|
||||
# Prepare actions for loss
|
||||
loss_actions = actions if is_multidiscrete else tf.expand_dims(
|
||||
actions, axis=1)
|
||||
|
||||
action_dist = dist_class(self.model.outputs)
|
||||
# Inputs are reshaped from [B * T] => [T - 1, B] for V-trace calc.
|
||||
policy.loss = VTraceLoss(
|
||||
actions=make_time_major(loss_actions, drop_last=True),
|
||||
actions_logp=make_time_major(
|
||||
action_dist.logp(actions), drop_last=True),
|
||||
actions_entropy=make_time_major(action_dist.entropy(), drop_last=True),
|
||||
dones=make_time_major(dones, drop_last=True),
|
||||
behaviour_logits=make_time_major(
|
||||
unpacked_behaviour_logits, drop_last=True),
|
||||
target_logits=make_time_major(unpacked_outputs, drop_last=True),
|
||||
discount=policy.config["gamma"],
|
||||
rewards=make_time_major(rewards, drop_last=True),
|
||||
values=make_time_major(values, drop_last=True),
|
||||
bootstrap_value=make_time_major(values)[-1],
|
||||
dist_class=Categorical if is_multidiscrete else policy.dist_class,
|
||||
valid_mask=make_time_major(mask, drop_last=True),
|
||||
vf_loss_coeff=policy.config["vf_loss_coeff"],
|
||||
entropy_coeff=policy.config["entropy_coeff"],
|
||||
clip_rho_threshold=policy.config["vtrace_clip_rho_threshold"],
|
||||
clip_pg_rho_threshold=policy.config["vtrace_clip_pg_rho_threshold"])
|
||||
|
||||
values = self.model.value_function()
|
||||
return policy.loss.total_loss
|
||||
|
||||
|
||||
def stats(policy, batch_tensors):
|
||||
values_batched = _make_time_major(
|
||||
policy, policy.value_function, drop_last=policy.config["vtrace"])
|
||||
|
||||
return {
|
||||
"cur_lr": tf.cast(policy.cur_lr, tf.float64),
|
||||
"policy_loss": policy.loss.pi_loss,
|
||||
"entropy": policy.loss.entropy,
|
||||
"var_gnorm": tf.global_norm(policy.var_list),
|
||||
"vf_loss": policy.loss.vf_loss,
|
||||
"vf_explained_var": explained_variance(
|
||||
tf.reshape(policy.loss.value_targets, [-1]),
|
||||
tf.reshape(values_batched, [-1])),
|
||||
}
|
||||
|
||||
|
||||
def grad_stats(policy, grads):
|
||||
return {
|
||||
"grad_gnorm": tf.global_norm(grads),
|
||||
}
|
||||
|
||||
|
||||
def postprocess_trajectory(policy,
|
||||
sample_batch,
|
||||
other_agent_batches=None,
|
||||
episode=None):
|
||||
# not used, so save some bandwidth
|
||||
del sample_batch.data[SampleBatch.NEXT_OBS]
|
||||
return sample_batch
|
||||
|
||||
|
||||
def add_behaviour_logits(policy):
|
||||
return {BEHAVIOUR_LOGITS: policy.model_out}
|
||||
|
||||
|
||||
def validate_config(policy, obs_space, action_space, config):
|
||||
assert config["batch_mode"] == "truncate_episodes", \
|
||||
"Must use `truncate_episodes` batch mode with V-trace."
|
||||
|
||||
|
||||
def choose_optimizer(policy, config):
|
||||
if policy.config["opt_type"] == "adam":
|
||||
return tf.train.AdamOptimizer(policy.cur_lr)
|
||||
else:
|
||||
return tf.train.RMSPropOptimizer(policy.cur_lr, config["decay"],
|
||||
config["momentum"], config["epsilon"])
|
||||
|
||||
|
||||
def clip_gradients(policy, optimizer, loss):
|
||||
grads = tf.gradients(loss, policy.var_list)
|
||||
policy.grads, _ = tf.clip_by_global_norm(grads, policy.config["grad_clip"])
|
||||
clipped_grads = list(zip(policy.grads, policy.var_list))
|
||||
return clipped_grads
|
||||
|
||||
|
||||
class ValueNetworkMixin(object):
|
||||
def __init__(self):
|
||||
self.value_function = self.model.value_function()
|
||||
self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
|
||||
tf.get_variable_scope().name)
|
||||
|
||||
def make_time_major(tensor, drop_last=False):
|
||||
"""Swaps batch and trajectory axis.
|
||||
Args:
|
||||
tensor: A tensor or list of tensors to reshape.
|
||||
drop_last: A bool indicating whether to drop the last
|
||||
trajectory item.
|
||||
Returns:
|
||||
res: A tensor with swapped axes or a list of tensors with
|
||||
swapped axes.
|
||||
"""
|
||||
if isinstance(tensor, list):
|
||||
return [make_time_major(t, drop_last) for t in tensor]
|
||||
|
||||
if self.model.state_init:
|
||||
B = tf.shape(self.model.seq_lens)[0]
|
||||
T = tf.shape(tensor)[0] // B
|
||||
else:
|
||||
# Important: chop the tensor into batches at known episode cut
|
||||
# boundaries. TODO(ekl) this is kind of a hack
|
||||
T = self.config["sample_batch_size"]
|
||||
B = tf.shape(tensor)[0] // T
|
||||
rs = tf.reshape(tensor,
|
||||
tf.concat([[B, T], tf.shape(tensor)[1:]], axis=0))
|
||||
|
||||
# swap B and T axes
|
||||
res = tf.transpose(
|
||||
rs,
|
||||
[1, 0] + list(range(2, 1 + int(tf.shape(tensor).shape[0]))))
|
||||
|
||||
if drop_last:
|
||||
return res[:-1]
|
||||
return res
|
||||
|
||||
if self.model.state_in:
|
||||
max_seq_len = tf.reduce_max(self.model.seq_lens) - 1
|
||||
mask = tf.sequence_mask(self.model.seq_lens, max_seq_len)
|
||||
mask = tf.reshape(mask, [-1])
|
||||
else:
|
||||
mask = tf.ones_like(rewards, dtype=tf.bool)
|
||||
|
||||
# Prepare actions for loss
|
||||
loss_actions = actions if is_multidiscrete else tf.expand_dims(
|
||||
actions, axis=1)
|
||||
|
||||
# Inputs are reshaped from [B * T] => [T - 1, B] for V-trace calc.
|
||||
self.loss = VTraceLoss(
|
||||
actions=make_time_major(loss_actions, drop_last=True),
|
||||
actions_logp=make_time_major(
|
||||
action_dist.logp(actions), drop_last=True),
|
||||
actions_entropy=make_time_major(
|
||||
action_dist.entropy(), drop_last=True),
|
||||
dones=make_time_major(dones, drop_last=True),
|
||||
behaviour_logits=make_time_major(
|
||||
unpacked_behaviour_logits, drop_last=True),
|
||||
target_logits=make_time_major(unpacked_outputs, drop_last=True),
|
||||
discount=config["gamma"],
|
||||
rewards=make_time_major(rewards, drop_last=True),
|
||||
values=make_time_major(values, drop_last=True),
|
||||
bootstrap_value=make_time_major(values)[-1],
|
||||
dist_class=Categorical if is_multidiscrete else dist_class,
|
||||
valid_mask=make_time_major(mask, drop_last=True),
|
||||
vf_loss_coeff=self.config["vf_loss_coeff"],
|
||||
entropy_coeff=self.config["entropy_coeff"],
|
||||
clip_rho_threshold=self.config["vtrace_clip_rho_threshold"],
|
||||
clip_pg_rho_threshold=self.config["vtrace_clip_pg_rho_threshold"])
|
||||
|
||||
# Initialize TFPolicy
|
||||
loss_in = [
|
||||
(SampleBatch.ACTIONS, actions),
|
||||
(SampleBatch.DONES, dones),
|
||||
(BEHAVIOUR_LOGITS, behaviour_logits),
|
||||
(SampleBatch.REWARDS, rewards),
|
||||
(SampleBatch.CUR_OBS, observations),
|
||||
(SampleBatch.PREV_ACTIONS, prev_actions),
|
||||
(SampleBatch.PREV_REWARDS, prev_rewards),
|
||||
]
|
||||
LearningRateSchedule.__init__(self, self.config["lr"],
|
||||
self.config["lr_schedule"])
|
||||
TFPolicy.__init__(
|
||||
self,
|
||||
observation_space,
|
||||
action_space,
|
||||
self.sess,
|
||||
obs_input=observations,
|
||||
action_sampler=action_dist.sample(),
|
||||
action_prob=action_dist.sampled_action_prob(),
|
||||
loss=self.loss.total_loss,
|
||||
model=self.model,
|
||||
loss_inputs=loss_in,
|
||||
state_inputs=self.model.state_in,
|
||||
state_outputs=self.model.state_out,
|
||||
prev_action_input=prev_actions,
|
||||
prev_reward_input=prev_rewards,
|
||||
seq_lens=self.model.seq_lens,
|
||||
max_seq_len=self.config["model"]["max_seq_len"],
|
||||
batch_divisibility_req=self.config["sample_batch_size"])
|
||||
|
||||
self.sess.run(tf.global_variables_initializer())
|
||||
|
||||
self.stats_fetches = {
|
||||
LEARNER_STATS_KEY: {
|
||||
"cur_lr": tf.cast(self.cur_lr, tf.float64),
|
||||
"policy_loss": self.loss.pi_loss,
|
||||
"entropy": self.loss.entropy,
|
||||
"grad_gnorm": tf.global_norm(self._grads),
|
||||
"var_gnorm": tf.global_norm(self.var_list),
|
||||
"vf_loss": self.loss.vf_loss,
|
||||
"vf_explained_var": explained_variance(
|
||||
tf.reshape(self.loss.vtrace_returns.vs, [-1]),
|
||||
tf.reshape(make_time_major(values, drop_last=True), [-1])),
|
||||
},
|
||||
def value(self, ob, *args):
|
||||
feed_dict = {
|
||||
self.get_placeholder(SampleBatch.CUR_OBS): [ob],
|
||||
self.seq_lens: [1]
|
||||
}
|
||||
assert len(args) == len(self.state_in), \
|
||||
(args, self.state_in)
|
||||
for k, v in zip(self.state_in, args):
|
||||
feed_dict[k] = v
|
||||
vf = self.get_session().run(self.value_function, feed_dict)
|
||||
return vf[0]
|
||||
|
||||
@override(TFPolicy)
|
||||
def copy(self, existing_inputs):
|
||||
return VTraceTFPolicy(
|
||||
self.observation_space,
|
||||
self.action_space,
|
||||
self.config,
|
||||
existing_inputs=existing_inputs)
|
||||
|
||||
@override(TFPolicy)
|
||||
def optimizer(self):
|
||||
if self.config["opt_type"] == "adam":
|
||||
return tf.train.AdamOptimizer(self.cur_lr)
|
||||
else:
|
||||
return tf.train.RMSPropOptimizer(self.cur_lr, self.config["decay"],
|
||||
self.config["momentum"],
|
||||
self.config["epsilon"])
|
||||
def setup_mixins(policy, obs_space, action_space, config):
|
||||
LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"])
|
||||
ValueNetworkMixin.__init__(policy)
|
||||
|
||||
@override(TFPolicy)
|
||||
def gradients(self, optimizer, loss):
|
||||
grads = tf.gradients(loss, self.var_list)
|
||||
self.grads, _ = tf.clip_by_global_norm(grads, self.config["grad_clip"])
|
||||
clipped_grads = list(zip(self.grads, self.var_list))
|
||||
return clipped_grads
|
||||
|
||||
@override(TFPolicy)
|
||||
def extra_compute_grad_fetches(self):
|
||||
return self.stats_fetches
|
||||
|
||||
@override(Policy)
|
||||
def get_initial_state(self):
|
||||
return self.model.state_init
|
||||
VTraceTFPolicy = build_tf_policy(
|
||||
name="VTraceTFPolicy",
|
||||
get_default_config=lambda: ray.rllib.agents.impala.impala.DEFAULT_CONFIG,
|
||||
loss_fn=build_vtrace_loss,
|
||||
stats_fn=stats,
|
||||
grad_stats_fn=grad_stats,
|
||||
postprocess_fn=postprocess_trajectory,
|
||||
optimizer_fn=choose_optimizer,
|
||||
gradients_fn=clip_gradients,
|
||||
extra_action_fetches_fn=add_behaviour_logits,
|
||||
before_init=validate_config,
|
||||
before_loss_init=setup_mixins,
|
||||
mixins=[LearningRateSchedule, ValueNetworkMixin],
|
||||
get_batch_divisibility_req=lambda p: p.config["sample_batch_size"])
|
||||
|
||||
@@ -10,14 +10,12 @@ import numpy as np
|
||||
import logging
|
||||
import gym
|
||||
|
||||
import ray
|
||||
from ray.rllib.agents.impala import vtrace
|
||||
from ray.rllib.agents.impala.vtrace_policy import _make_time_major, \
|
||||
BEHAVIOUR_LOGITS, VTraceTFPolicy
|
||||
from ray.rllib.evaluation.postprocessing import Postprocessing
|
||||
from ray.rllib.models.action_dist import Categorical
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.policy.tf_policy_template import build_tf_policy
|
||||
from ray.rllib.policy.tf_policy import LearningRateSchedule
|
||||
from ray.rllib.utils.explained_variance import explained_variance
|
||||
from ray.rllib.evaluation.postprocessing import compute_advantages
|
||||
from ray.rllib.utils import try_import_tf
|
||||
|
||||
@@ -25,8 +23,6 @@ tf = try_import_tf()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
BEHAVIOUR_LOGITS = "behaviour_logits"
|
||||
|
||||
|
||||
class PPOSurrogateLoss(object):
|
||||
"""Loss used when V-trace is disabled.
|
||||
@@ -163,41 +159,6 @@ class VTraceSurrogateLoss(object):
|
||||
self.entropy * entropy_coeff)
|
||||
|
||||
|
||||
def _make_time_major(policy, tensor, drop_last=False):
|
||||
"""Swaps batch and trajectory axis.
|
||||
|
||||
Arguments:
|
||||
policy: Policy reference
|
||||
tensor: A tensor or list of tensors to reshape.
|
||||
drop_last: A bool indicating whether to drop the last
|
||||
trajectory item.
|
||||
|
||||
Returns:
|
||||
res: A tensor with swapped axes or a list of tensors with
|
||||
swapped axes.
|
||||
"""
|
||||
if isinstance(tensor, list):
|
||||
return [_make_time_major(policy, t, drop_last) for t in tensor]
|
||||
|
||||
if policy.state_in:
|
||||
B = tf.shape(policy.seq_lens)[0]
|
||||
T = tf.shape(tensor)[0] // B
|
||||
else:
|
||||
# Important: chop the tensor into batches at known episode cut
|
||||
# boundaries. TODO(ekl) this is kind of a hack
|
||||
T = policy.config["sample_batch_size"]
|
||||
B = tf.shape(tensor)[0] // T
|
||||
rs = tf.reshape(tensor, tf.concat([[B, T], tf.shape(tensor)[1:]], axis=0))
|
||||
|
||||
# swap B and T axes
|
||||
res = tf.transpose(
|
||||
rs, [1, 0] + list(range(2, 1 + int(tf.shape(tensor).shape[0]))))
|
||||
|
||||
if drop_last:
|
||||
return res[:-1]
|
||||
return res
|
||||
|
||||
|
||||
def build_appo_surrogate_loss(policy, batch_tensors):
|
||||
if isinstance(policy.action_space, gym.spaces.Discrete):
|
||||
is_multidiscrete = False
|
||||
@@ -283,28 +244,6 @@ def build_appo_surrogate_loss(policy, batch_tensors):
|
||||
return policy.loss.total_loss
|
||||
|
||||
|
||||
def stats(policy, batch_tensors):
|
||||
values_batched = _make_time_major(
|
||||
policy, policy.value_function, drop_last=policy.config["vtrace"])
|
||||
|
||||
return {
|
||||
"cur_lr": tf.cast(policy.cur_lr, tf.float64),
|
||||
"policy_loss": policy.loss.pi_loss,
|
||||
"entropy": policy.loss.entropy,
|
||||
"var_gnorm": tf.global_norm(policy.var_list),
|
||||
"vf_loss": policy.loss.vf_loss,
|
||||
"vf_explained_var": explained_variance(
|
||||
tf.reshape(policy.loss.value_targets, [-1]),
|
||||
tf.reshape(values_batched, [-1])),
|
||||
}
|
||||
|
||||
|
||||
def grad_stats(policy, grads):
|
||||
return {
|
||||
"grad_gnorm": tf.global_norm(grads),
|
||||
}
|
||||
|
||||
|
||||
def postprocess_trajectory(policy,
|
||||
sample_batch,
|
||||
other_agent_batches=None,
|
||||
@@ -337,61 +276,8 @@ def add_values_and_logits(policy):
|
||||
return out
|
||||
|
||||
|
||||
def validate_config(policy, obs_space, action_space, config):
|
||||
assert config["batch_mode"] == "truncate_episodes", \
|
||||
"Must use `truncate_episodes` batch mode with V-trace."
|
||||
|
||||
|
||||
def choose_optimizer(policy, config):
|
||||
if policy.config["opt_type"] == "adam":
|
||||
return tf.train.AdamOptimizer(policy.cur_lr)
|
||||
else:
|
||||
return tf.train.RMSPropOptimizer(policy.cur_lr, config["decay"],
|
||||
config["momentum"], config["epsilon"])
|
||||
|
||||
|
||||
def clip_gradients(policy, optimizer, loss):
|
||||
grads = tf.gradients(loss, policy.var_list)
|
||||
policy.grads, _ = tf.clip_by_global_norm(grads, policy.config["grad_clip"])
|
||||
clipped_grads = list(zip(policy.grads, policy.var_list))
|
||||
return clipped_grads
|
||||
|
||||
|
||||
class ValueNetworkMixin(object):
|
||||
def __init__(self):
|
||||
self.value_function = self.model.value_function()
|
||||
self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
|
||||
tf.get_variable_scope().name)
|
||||
|
||||
def value(self, ob, *args):
|
||||
feed_dict = {
|
||||
self.get_placeholder(SampleBatch.CUR_OBS): [ob],
|
||||
self.seq_lens: [1]
|
||||
}
|
||||
assert len(args) == len(self.state_in), \
|
||||
(args, self.state_in)
|
||||
for k, v in zip(self.state_in, args):
|
||||
feed_dict[k] = v
|
||||
vf = self.get_session().run(self.value_function, feed_dict)
|
||||
return vf[0]
|
||||
|
||||
|
||||
def setup_mixins(policy, obs_space, action_space, config):
|
||||
LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"])
|
||||
ValueNetworkMixin.__init__(policy)
|
||||
|
||||
|
||||
AsyncPPOTFPolicy = build_tf_policy(
|
||||
AsyncPPOTFPolicy = VTraceTFPolicy.with_updates(
|
||||
name="AsyncPPOTFPolicy",
|
||||
get_default_config=lambda: ray.rllib.agents.impala.impala.DEFAULT_CONFIG,
|
||||
loss_fn=build_appo_surrogate_loss,
|
||||
stats_fn=stats,
|
||||
grad_stats_fn=grad_stats,
|
||||
postprocess_fn=postprocess_trajectory,
|
||||
optimizer_fn=choose_optimizer,
|
||||
gradients_fn=clip_gradients,
|
||||
extra_action_fetches_fn=add_values_and_logits,
|
||||
before_init=validate_config,
|
||||
before_loss_init=setup_mixins,
|
||||
mixins=[LearningRateSchedule, ValueNetworkMixin],
|
||||
get_batch_divisibility_req=lambda p: p.config["sample_batch_size"])
|
||||
extra_action_fetches_fn=add_values_and_logits)
|
||||
|
||||
@@ -154,9 +154,6 @@ def validate_config(config):
|
||||
"FYI: By default, the value function will not share layers "
|
||||
"with the policy model ('vf_share_layers': False).")
|
||||
|
||||
# auto set the model option for layer sharing
|
||||
config["model"]["vf_share_layers"] = config["vf_share_layers"]
|
||||
|
||||
|
||||
PPOTrainer = build_trainer(
|
||||
name="PPO",
|
||||
|
||||
@@ -241,6 +241,11 @@ class ValueNetworkMixin(object):
|
||||
return vf[0]
|
||||
|
||||
|
||||
def setup_config(policy, obs_space, action_space, config):
|
||||
# auto set the model option for layer sharing
|
||||
config["model"]["vf_share_layers"] = config["vf_share_layers"]
|
||||
|
||||
|
||||
def setup_mixins(policy, obs_space, action_space, config):
|
||||
ValueNetworkMixin.__init__(policy, obs_space, action_space, config)
|
||||
KLCoeffMixin.__init__(policy, config)
|
||||
@@ -255,5 +260,6 @@ PPOTFPolicy = build_tf_policy(
|
||||
extra_action_fetches_fn=vf_preds_and_logits_fetches,
|
||||
postprocess_fn=postprocess_ppo_gae,
|
||||
gradients_fn=clip_gradients,
|
||||
before_init=setup_config,
|
||||
before_loss_init=setup_mixins,
|
||||
mixins=[LearningRateSchedule, KLCoeffMixin, ValueNetworkMixin])
|
||||
|
||||
@@ -93,7 +93,7 @@ class MyKerasQModel(DistributionalQModel):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
ray.init(local_mode=True)
|
||||
ray.init()
|
||||
args = parser.parse_args()
|
||||
ModelCatalog.register_custom_model("keras_model", MyKerasModel)
|
||||
ModelCatalog.register_custom_model("keras_q_model", MyKerasQModel)
|
||||
@@ -102,6 +102,7 @@ if __name__ == "__main__":
|
||||
stop={"episode_reward_mean": args.stop},
|
||||
config={
|
||||
"env": "CartPole-v0",
|
||||
"num_gpus": 0,
|
||||
"model": {
|
||||
"custom_model": "keras_q_model"
|
||||
if args.run == "DQN" else "keras_model"
|
||||
|
||||
@@ -23,8 +23,9 @@ from ray.rllib.models.visionnet import VisionNetwork
|
||||
from ray.rllib.models.lstm import LSTM
|
||||
from ray.rllib.models.modelv2 import ModelV2
|
||||
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
|
||||
from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI
|
||||
from ray.rllib.utils import try_import_tf
|
||||
from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI
|
||||
from ray.rllib.utils.error import UnsupportedSpaceException
|
||||
|
||||
tf = try_import_tf()
|
||||
|
||||
@@ -49,7 +50,7 @@ MODEL_DEFAULTS = {
|
||||
# should already match num_outputs.
|
||||
"no_final_linear": False,
|
||||
# Whether layers should be shared for the value function.
|
||||
"vf_share_layers": False,
|
||||
"vf_share_layers": True,
|
||||
|
||||
# == LSTM ==
|
||||
# Whether to wrap the model with a LSTM
|
||||
@@ -120,7 +121,7 @@ class ModelCatalog(object):
|
||||
config = config or MODEL_DEFAULTS
|
||||
if isinstance(action_space, gym.spaces.Box):
|
||||
if len(action_space.shape) > 1:
|
||||
raise ValueError(
|
||||
raise UnsupportedSpaceException(
|
||||
"Action space has multiple dimensions "
|
||||
"{}. ".format(action_space.shape) +
|
||||
"Consider reshaping this into a single dimension, "
|
||||
|
||||
@@ -140,6 +140,7 @@ def build_tf_policy(name,
|
||||
action_sampler_fn=action_sampler_fn,
|
||||
existing_model=existing_model,
|
||||
existing_inputs=existing_inputs,
|
||||
get_batch_divisibility_req=get_batch_divisibility_req,
|
||||
obs_include_prev_action_reward=obs_include_prev_action_reward)
|
||||
|
||||
if after_init:
|
||||
|
||||
@@ -17,6 +17,10 @@ from ray.tune.registry import register_env
|
||||
ACTION_SPACES_TO_TEST = {
|
||||
"discrete": Discrete(5),
|
||||
"vector": Box(-1.0, 1.0, (5, ), dtype=np.float32),
|
||||
"vector2": Box(-1.0, 1.0, (
|
||||
5,
|
||||
5,
|
||||
), dtype=np.float32),
|
||||
"multidiscrete": MultiDiscrete([1, 2, 3, 4]),
|
||||
"tuple": Tuple(
|
||||
[Discrete(2),
|
||||
@@ -63,6 +67,8 @@ def make_stub_env(action_space, obs_space, check_action_bounds):
|
||||
|
||||
|
||||
def check_support(alg, config, stats, check_bounds=False, name=None):
|
||||
covered_a = set()
|
||||
covered_o = set()
|
||||
for a_name, action_space in ACTION_SPACES_TO_TEST.items():
|
||||
for o_name, obs_space in OBSERVATION_SPACES_TO_TEST.items():
|
||||
print("=== Testing", alg, action_space, obs_space, "===")
|
||||
@@ -71,8 +77,13 @@ def check_support(alg, config, stats, check_bounds=False, name=None):
|
||||
stat = "ok"
|
||||
a = None
|
||||
try:
|
||||
a = get_agent_class(alg)(config=config, env="stub_env")
|
||||
a.train()
|
||||
if a_name in covered_a and o_name in covered_o:
|
||||
stat = "skip" # speed up tests by avoiding full grid
|
||||
else:
|
||||
a = get_agent_class(alg)(config=config, env="stub_env")
|
||||
a.train()
|
||||
covered_a.add(a_name)
|
||||
covered_o.add(o_name)
|
||||
except UnsupportedSpaceException:
|
||||
stat = "unsupported"
|
||||
except Exception as e:
|
||||
@@ -171,7 +182,7 @@ class ModelSupportedSpaces(unittest.TestCase):
|
||||
check_bounds=True)
|
||||
num_unexpected_errors = 0
|
||||
for (alg, a_name, o_name), stat in sorted(stats.items()):
|
||||
if stat not in ["ok", "unsupported"]:
|
||||
if stat not in ["ok", "unsupported", "skip"]:
|
||||
num_unexpected_errors += 1
|
||||
print(alg, "action_space", a_name, "obs_space", o_name, "result",
|
||||
stat)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
basic-dqn:
|
||||
atari-dist-dqn:
|
||||
env:
|
||||
grid_search:
|
||||
- BreakoutNoFrameskip-v4
|
||||
|
||||
Reference in New Issue
Block a user