[rllib] [RFC] Dynamic definition of loss functions and modularization support (#4795)

* dynamic graph

* wip

* clean up

* fix

* document trainer

* wip

* initialize the graph using a fake batch

* clean up dynamic init

* wip

* spelling

* use builder for ppo pol graph

* add ppo graph

* fix naming

* order

* docs

* set class name correctly

* add torch builder

* add custom model support in builder

* cleanup

* remove underscores

* fix py2 compat

* Update dynamic_tf_policy_graph.py

* Update tracking_dict.py

* wip

* rename

* debug level

* rename policy_graph -> policy in new classes

* fix test

* rename ppo tf policy

* port appo too

* forgot grads

* default policy optimizer

* make default config optional

* add config to optimizer

* use lr by default in optimizer

* update

* comments

* remove optimizer

* fix tuple actions support in dynamic tf graph
This commit is contained in:
Eric Liang
2019-05-18 00:23:11 -07:00
committed by GitHub
parent 1ef9c0729d
commit 6cb5b90bd6
25 changed files with 1353 additions and 1006 deletions
+2 -2
View File
@@ -49,8 +49,8 @@ class A3CTrainer(Trainer):
def _init(self, config, env_creator):
if config["use_pytorch"]:
from ray.rllib.agents.a3c.a3c_torch_policy_graph import \
A3CTorchPolicyGraph
policy_cls = A3CTorchPolicyGraph
A3CTorchPolicy
policy_cls = A3CTorchPolicy
else:
policy_cls = self._policy_graph
@@ -7,109 +7,84 @@ import torch.nn.functional as F
from torch import nn
import ray
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.evaluation.postprocessing import compute_advantages, \
Postprocessing
from ray.rllib.evaluation.policy_graph import PolicyGraph
from ray.rllib.evaluation.sample_batch import SampleBatch
from ray.rllib.evaluation.torch_policy_graph import TorchPolicyGraph
from ray.rllib.utils.annotations import override
from ray.rllib.evaluation.torch_policy_template import build_torch_policy
class A3CLoss(nn.Module):
def __init__(self, dist_class, vf_loss_coeff=0.5, entropy_coeff=0.01):
nn.Module.__init__(self)
self.dist_class = dist_class
self.vf_loss_coeff = vf_loss_coeff
self.entropy_coeff = entropy_coeff
def forward(self, policy_model, observations, actions, advantages,
value_targets):
logits, _, values, _ = policy_model({
SampleBatch.CUR_OBS: observations
}, [])
dist = self.dist_class(logits)
log_probs = dist.logp(actions)
self.entropy = dist.entropy().mean()
self.pi_err = -advantages.dot(log_probs.reshape(-1))
self.value_err = F.mse_loss(values.reshape(-1), value_targets)
overall_err = sum([
self.pi_err,
self.vf_loss_coeff * self.value_err,
-self.entropy_coeff * self.entropy,
])
return overall_err
def actor_critic_loss(policy, batch_tensors):
    """Compute the combined actor-critic loss for a batch.

    Runs ``policy.model`` on the batch observations, then stores three
    intermediate terms on ``policy`` (``entropy``, ``pi_err`` and
    ``value_err``) so they can be read back after the loss is computed.

    Arguments:
        policy: Policy object providing ``model``, ``dist_class`` and
            ``config`` (keys ``vf_loss_coeff`` and ``entropy_coeff``).
        batch_tensors: Mapping with the observation, action, advantage
            and value-target columns.

    Returns:
        The sum of the policy error, the weighted value error and the
        negated weighted entropy.
    """
    logits, _, values, _ = policy.model({
        SampleBatch.CUR_OBS: batch_tensors[SampleBatch.CUR_OBS]
    }, [])
    action_dist = policy.dist_class(logits)
    action_logp = action_dist.logp(batch_tensors[SampleBatch.ACTIONS])

    # The individual terms are kept on the policy object.
    policy.entropy = action_dist.entropy().mean()
    advantages = batch_tensors[Postprocessing.ADVANTAGES]
    policy.pi_err = -advantages.dot(action_logp.reshape(-1))
    flat_values = values.reshape(-1)
    policy.value_err = F.mse_loss(
        flat_values, batch_tensors[Postprocessing.VALUE_TARGETS])

    weighted_terms = [
        policy.pi_err,
        policy.config["vf_loss_coeff"] * policy.value_err,
        -policy.config["entropy_coeff"] * policy.entropy,
    ]
    return sum(weighted_terms)
class A3CPostprocessing(object):
"""Adds the VF preds and advantages fields to the trajectory."""
@override(TorchPolicyGraph)
def extra_action_out(self, model_out):
return {SampleBatch.VF_PREDS: model_out[2].cpu().numpy()}
@override(PolicyGraph)
def postprocess_trajectory(self,
sample_batch,
other_agent_batches=None,
episode=None):
completed = sample_batch[SampleBatch.DONES][-1]
if completed:
last_r = 0.0
else:
last_r = self._value(sample_batch[SampleBatch.NEXT_OBS][-1])
return compute_advantages(sample_batch, last_r, self.config["gamma"],
self.config["lambda"])
def loss_and_entropy_stats(policy, batch_tensors):
    """Report the loss terms stored on ``policy`` as Python scalars.

    Arguments:
        policy: Object exposing ``entropy``, ``pi_err`` and ``value_err``,
            each supporting ``.item()``.
        batch_tensors: Unused.

    Returns:
        Dict with the keys ``policy_entropy``, ``policy_loss`` and
        ``vf_loss``.
    """
    attr_by_stat = (
        ("policy_entropy", "entropy"),
        ("policy_loss", "pi_err"),
        ("vf_loss", "value_err"),
    )
    return {
        stat: getattr(policy, attr).item()
        for stat, attr in attr_by_stat
    }
class A3CTorchPolicyGraph(A3CPostprocessing, TorchPolicyGraph):
"""A simple, non-recurrent PyTorch policy example."""
def add_advantages(policy,
                   sample_batch,
                   other_agent_batches=None,
                   episode=None):
    """Run ``compute_advantages`` on a trajectory with a bootstrap value.

    Arguments:
        policy: Policy providing ``_value`` and ``config`` (keys ``gamma``
            and ``lambda``).
        sample_batch: Trajectory batch to post-process.
        other_agent_batches: Unused.
        episode: Unused.

    Returns:
        The result of ``compute_advantages`` for ``sample_batch``.
    """
    episode_finished = sample_batch[SampleBatch.DONES][-1]
    # Bootstrap with 0.0 when the last step is marked done; otherwise use
    # the policy's value estimate for the final next-observation.
    bootstrap_value = (0.0 if episode_finished else policy._value(
        sample_batch[SampleBatch.NEXT_OBS][-1]))
    return compute_advantages(sample_batch, bootstrap_value,
                              policy.config["gamma"],
                              policy.config["lambda"])
def __init__(self, obs_space, action_space, config):
config = dict(ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG, **config)
self.config = config
dist_class, self.logit_dim = ModelCatalog.get_action_dist(
action_space, self.config["model"], torch=True)
model = ModelCatalog.get_torch_model(obs_space, self.logit_dim,
self.config["model"])
loss = A3CLoss(dist_class, self.config["vf_loss_coeff"],
self.config["entropy_coeff"])
TorchPolicyGraph.__init__(
self,
obs_space,
action_space,
model,
loss,
loss_inputs=[
SampleBatch.CUR_OBS, SampleBatch.ACTIONS,
Postprocessing.ADVANTAGES, Postprocessing.VALUE_TARGETS
],
action_distribution_cls=dist_class)
@override(TorchPolicyGraph)
def optimizer(self):
return torch.optim.Adam(self._model.parameters(), lr=self.config["lr"])
def model_value_predictions(policy, model_out):
    """Return the third model output as the value-prediction column.

    Arguments:
        policy: Unused.
        model_out: Indexable model output; element 2 is moved to CPU and
            converted to a numpy array.

    Returns:
        Dict mapping ``SampleBatch.VF_PREDS`` to that array.
    """
    value_out = model_out[2]
    return {SampleBatch.VF_PREDS: value_out.cpu().numpy()}
@override(TorchPolicyGraph)
def extra_grad_process(self):
info = {}
if self.config["grad_clip"]:
total_norm = nn.utils.clip_grad_norm_(self._model.parameters(),
self.config["grad_clip"])
info["grad_gnorm"] = total_norm
return info
@override(TorchPolicyGraph)
def extra_grad_info(self):
return {
"policy_entropy": self._loss.entropy.item(),
"policy_loss": self._loss.pi_err.item(),
"vf_loss": self._loss.value_err.item()
}
def apply_grad_clipping(policy):
    """Clip the model's gradient norm in place when clipping is configured.

    Arguments:
        policy: Object providing ``config["grad_clip"]`` and ``model``.

    Returns:
        ``{"grad_gnorm": total_norm}`` when ``config["grad_clip"]`` is
        truthy, where ``total_norm`` is the value returned by
        ``clip_grad_norm_``; otherwise an empty dict.
    """
    if not policy.config["grad_clip"]:
        return {}
    total_norm = nn.utils.clip_grad_norm_(policy.model.parameters(),
                                          policy.config["grad_clip"])
    return {"grad_gnorm": total_norm}
def torch_optimizer(policy, config):
    """Create an Adam optimizer over the policy model's parameters.

    Arguments:
        policy: Object whose ``model`` exposes ``parameters()``.
        config: Mapping providing the learning rate under ``"lr"``.

    Returns:
        A ``torch.optim.Adam`` instance.
    """
    trainable_params = policy.model.parameters()
    learning_rate = config["lr"]
    return torch.optim.Adam(trainable_params, lr=learning_rate)
class ValueNetworkMixin(object):
    """Mixin that adds a single-observation value estimate to a policy.

    The host object is expected to provide ``lock``, ``device`` and
    ``model`` attributes.
    """

    def _value(self, obs):
        """Return the model's value output for one observation.

        Arguments:
            obs: Numpy array holding a single observation; it is converted
                to a float tensor and given a leading batch dimension.

        Returns:
            Numpy array with the value output, squeezed of size-1 dims.
        """
        with self.lock:
            obs = torch.from_numpy(obs).float().unsqueeze(0).to(self.device)
            # Use ``self.model`` only: the stale ``self._model`` call that
            # sat next to it ran the forward pass a second time and relied
            # on the pre-rename attribute name.
            _, _, vf, _ = self.model({"obs": obs}, [])
            return vf.detach().cpu().numpy().squeeze()
# A3C torch policy class, assembled by build_torch_policy from the
# module-level functions defined in this file (loss, stats, trajectory
# postprocessing, extra action outputs, gradient processing, optimizer)
# plus ValueNetworkMixin, which supplies the `_value` method that
# add_advantages calls.
# The default config is wrapped in a lambda, so the attribute lookup on
# ray.rllib.agents.a3c.a3c only happens when the lambda is called.
# NOTE(review): presumably this defers the lookup to avoid a circular
# import with the trainer module -- confirm.
A3CTorchPolicy = build_torch_policy(
name="A3CTorchPolicy",
get_default_config=lambda: ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG,
loss_fn=actor_critic_loss,
stats_fn=loss_and_entropy_stats,
postprocess_fn=add_advantages,
extra_action_out_fn=model_value_predictions,
extra_grad_process_fn=apply_grad_clipping,
optimizer_fn=torch_optimizer,
mixins=[ValueNetworkMixin])
+14 -40
View File
@@ -2,11 +2,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ray.rllib.agents.trainer import Trainer, with_common_config
from ray.rllib.agents.pg.pg_policy_graph import PGPolicyGraph
from ray.rllib.optimizers import SyncSamplesOptimizer
from ray.rllib.utils.annotations import override
from ray.rllib.agents.trainer import with_common_config
from ray.rllib.agents.trainer_template import build_trainer
from ray.rllib.agents.pg.pg_policy_graph import PGTFPolicy
# yapf: disable
# __sphinx_doc_begin__
@@ -22,40 +20,16 @@ DEFAULT_CONFIG = with_common_config({
# yapf: enable
class PGTrainer(Trainer):
"""Simple policy gradient agent.
def get_policy_class(config):
    """Select the PG policy class for the configured framework.

    Arguments:
        config: Trainer config; only ``config["use_pytorch"]`` is read.

    Returns:
        ``PGTorchPolicy`` when ``use_pytorch`` is truthy, else
        ``PGTFPolicy``.
    """
    if not config["use_pytorch"]:
        return PGTFPolicy
    # Imported here so the torch policy module is only loaded when it is
    # actually requested.
    from ray.rllib.agents.pg.torch_pg_policy_graph import PGTorchPolicy
    return PGTorchPolicy
This is an example agent to show how to implement algorithms in RLlib.
In most cases, you will probably want to use the PPO agent instead.
"""
_name = "PG"
_default_config = DEFAULT_CONFIG
_policy_graph = PGPolicyGraph
@override(Trainer)
def _init(self, config, env_creator):
if config["use_pytorch"]:
from ray.rllib.agents.pg.torch_pg_policy_graph import \
PGTorchPolicyGraph
policy_cls = PGTorchPolicyGraph
else:
policy_cls = self._policy_graph
self.local_evaluator = self.make_local_evaluator(
env_creator, policy_cls)
self.remote_evaluators = self.make_remote_evaluators(
env_creator, policy_cls, config["num_workers"])
optimizer_config = dict(
config["optimizer"],
**{"train_batch_size": config["train_batch_size"]})
self.optimizer = SyncSamplesOptimizer(
self.local_evaluator, self.remote_evaluators, **optimizer_config)
@override(Trainer)
def _train(self):
prev_steps = self.optimizer.num_steps_sampled
self.optimizer.step()
result = self.collect_metrics()
result.update(timesteps_this_iter=self.optimizer.num_steps_sampled -
prev_steps)
return result
# PG trainer, produced by build_trainer. PGTFPolicy is passed as the
# default policy, and get_policy_class (defined above) is passed as the
# hook that returns PGTorchPolicy when config["use_pytorch"] is set.
PGTrainer = build_trainer(
name="PG",
default_config=DEFAULT_CONFIG,
default_policy=PGTFPolicy,
get_policy_class=get_policy_class)
+18 -87
View File
@@ -3,102 +3,33 @@ from __future__ import division
from __future__ import print_function
import ray
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.evaluation.postprocessing import compute_advantages, \
Postprocessing
from ray.rllib.evaluation.policy_graph import PolicyGraph
from ray.rllib.evaluation.tf_policy_template import build_tf_policy
from ray.rllib.evaluation.sample_batch import SampleBatch
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
from ray.rllib.utils.annotations import override
from ray.rllib.utils import try_import_tf
tf = try_import_tf()
class PGLoss(object):
"""The basic policy gradient loss."""
def __init__(self, action_dist, actions, advantages):
self.loss = -tf.reduce_mean(action_dist.logp(actions) * advantages)
# The basic policy gradients loss
def policy_gradient_loss(policy, batch_tensors):
    """Return the negated mean of action log-probability times advantage.

    Arguments:
        policy: Object whose ``action_dist`` provides ``logp``.
        batch_tensors: Mapping with the action and advantage columns.

    Returns:
        Scalar loss tensor.
    """
    actions = batch_tensors[SampleBatch.ACTIONS]
    advantages = batch_tensors[Postprocessing.ADVANTAGES]
    weighted_logp = policy.action_dist.logp(actions) * advantages
    return -tf.reduce_mean(weighted_logp)
class PGPostprocessing(object):
"""Adds the advantages field to the trajectory."""
@override(PolicyGraph)
def postprocess_trajectory(self,
sample_batch,
other_agent_batches=None,
episode=None):
# This adds the "advantages" column to the sample batch
return compute_advantages(
sample_batch, 0.0, self.config["gamma"], use_gae=False)
# This adds the "advantages" column to the sample batch.
def postprocess_advantages(policy,
                           sample_batch,
                           other_agent_batches=None,
                           episode=None):
    """Run ``compute_advantages`` on a trajectory with GAE disabled.

    Arguments:
        policy: Object providing ``config["gamma"]``.
        sample_batch: Trajectory batch to post-process.
        other_agent_batches: Unused.
        episode: Unused.

    Returns:
        The result of ``compute_advantages``, called with a last reward
        of 0.0 and ``use_gae=False``.
    """
    gamma = policy.config["gamma"]
    return compute_advantages(sample_batch, 0.0, gamma, use_gae=False)
class PGPolicyGraph(PGPostprocessing, TFPolicyGraph):
"""Simple policy gradient example of defining a policy graph."""
def __init__(self, obs_space, action_space, config):
config = dict(ray.rllib.agents.pg.pg.DEFAULT_CONFIG, **config)
self.config = config
# Setup placeholders
obs = tf.placeholder(tf.float32, shape=[None] + list(obs_space.shape))
dist_class, self.logit_dim = ModelCatalog.get_action_dist(
action_space, self.config["model"])
prev_actions = ModelCatalog.get_action_placeholder(action_space)
prev_rewards = tf.placeholder(tf.float32, [None], name="prev_reward")
# Create the model network and action outputs
self.model = ModelCatalog.get_model({
"obs": obs,
"prev_actions": prev_actions,
"prev_rewards": prev_rewards,
"is_training": self._get_is_training_placeholder(),
}, obs_space, action_space, self.logit_dim, self.config["model"])
action_dist = dist_class(self.model.outputs) # logit for each action
# Setup policy loss
actions = ModelCatalog.get_action_placeholder(action_space)
advantages = tf.placeholder(tf.float32, [None], name="adv")
loss = PGLoss(action_dist, actions, advantages).loss
# Mapping from sample batch keys to placeholders. These keys will be
# read from postprocessed sample batches and fed into the specified
# placeholders during loss computation.
loss_in = [
(SampleBatch.CUR_OBS, obs),
(SampleBatch.ACTIONS, actions),
(SampleBatch.PREV_ACTIONS, prev_actions),
(SampleBatch.PREV_REWARDS, prev_rewards),
(Postprocessing.ADVANTAGES, advantages),
]
# Initialize TFPolicyGraph
sess = tf.get_default_session()
TFPolicyGraph.__init__(
self,
obs_space,
action_space,
sess,
obs_input=obs,
action_sampler=action_dist.sample(),
action_prob=action_dist.sampled_action_prob(),
loss=loss,
loss_inputs=loss_in,
model=self.model,
state_inputs=self.model.state_in,
state_outputs=self.model.state_out,
prev_action_input=prev_actions,
prev_reward_input=prev_rewards,
seq_lens=self.model.seq_lens,
max_seq_len=config["model"]["max_seq_len"])
sess.run(tf.global_variables_initializer())
@override(PolicyGraph)
def get_initial_state(self):
return self.model.state_init
@override(TFPolicyGraph)
def optimizer(self):
return tf.train.AdamOptimizer(learning_rate=self.config["lr"])
# TF policy for PG, produced by build_tf_policy from the two functions
# defined above: postprocess_advantages (trajectory postprocessing) and
# policy_gradient_loss (loss). The default config is looked up inside a
# lambda, so ray.rllib.agents.pg.pg is only dereferenced when it is called.
PGTFPolicy = build_tf_policy(
name="PGTFPolicy",
get_default_config=lambda: ray.rllib.agents.pg.pg.DEFAULT_CONFIG,
postprocess_fn=postprocess_advantages,
loss_fn=policy_gradient_loss)
@@ -2,82 +2,41 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import torch
from torch import nn
import ray
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.evaluation.postprocessing import compute_advantages, \
Postprocessing
from ray.rllib.evaluation.policy_graph import PolicyGraph
from ray.rllib.evaluation.sample_batch import SampleBatch
from ray.rllib.evaluation.torch_policy_graph import TorchPolicyGraph
from ray.rllib.utils.annotations import override
from ray.rllib.evaluation.torch_policy_template import build_torch_policy
class PGLoss(nn.Module):
def __init__(self, dist_class):
nn.Module.__init__(self)
self.dist_class = dist_class
def forward(self, policy_model, observations, actions, advantages):
logits, _, values, _ = policy_model({
SampleBatch.CUR_OBS: observations
}, [])
dist = self.dist_class(logits)
log_probs = dist.logp(actions)
self.pi_err = -advantages.dot(log_probs.reshape(-1))
return self.pi_err
def pg_torch_loss(policy, batch_tensors):
    """Compute the policy-gradient loss and store it on the policy.

    Arguments:
        policy: Object providing ``model`` and ``dist_class``.
        batch_tensors: Mapping with the observation, action and advantage
            columns.

    Returns:
        The negated dot product of the advantages with the flattened
        action log-probabilities; also saved as ``policy.pi_err``.
    """
    # Only the logits (first of the four model outputs) are used here.
    logits, _, _, _ = policy.model({
        SampleBatch.CUR_OBS: batch_tensors[SampleBatch.CUR_OBS]
    }, [])
    dist = policy.dist_class(logits)
    flat_logp = dist.logp(batch_tensors[SampleBatch.ACTIONS]).reshape(-1)
    advantages = batch_tensors[Postprocessing.ADVANTAGES]
    # save the error in the policy object
    policy.pi_err = -advantages.dot(flat_logp)
    return policy.pi_err
class PGPostprocessing(object):
"""Adds the value func output and advantages field to the trajectory."""
@override(TorchPolicyGraph)
def extra_action_out(self, model_out):
return {SampleBatch.VF_PREDS: model_out[2].cpu().numpy()}
@override(PolicyGraph)
def postprocess_trajectory(self,
sample_batch,
other_agent_batches=None,
episode=None):
return compute_advantages(
sample_batch, 0.0, self.config["gamma"], use_gae=False)
def postprocess_advantages(policy,
                           sample_batch,
                           other_agent_batches=None,
                           episode=None):
    """Run ``compute_advantages`` on a trajectory with GAE disabled.

    Arguments:
        policy: Object providing ``config["gamma"]``.
        sample_batch: Trajectory batch to post-process.
        other_agent_batches: Unused.
        episode: Unused.

    Returns:
        The result of ``compute_advantages``, called with a last reward
        of 0.0 and ``use_gae=False``.
    """
    discount = policy.config["gamma"]
    return compute_advantages(sample_batch, 0.0, discount, use_gae=False)
class PGTorchPolicyGraph(PGPostprocessing, TorchPolicyGraph):
def __init__(self, obs_space, action_space, config):
config = dict(ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG, **config)
self.config = config
dist_class, self.logit_dim = ModelCatalog.get_action_dist(
action_space, self.config["model"], torch=True)
model = ModelCatalog.get_torch_model(obs_space, self.logit_dim,
self.config["model"])
loss = PGLoss(dist_class)
def pg_loss_stats(policy, batch_tensors):
    """Report the stored policy error as a Python scalar.

    Arguments:
        policy: Object exposing ``pi_err`` with an ``.item()`` method.
        batch_tensors: Unused.

    Returns:
        Dict with the single key ``policy_loss``.
    """
    # the error is recorded when computing the loss
    policy_loss = policy.pi_err.item()
    return {"policy_loss": policy_loss}
TorchPolicyGraph.__init__(
self,
obs_space,
action_space,
model,
loss,
loss_inputs=[
SampleBatch.CUR_OBS, SampleBatch.ACTIONS,
Postprocessing.ADVANTAGES
],
action_distribution_cls=dist_class)
@override(TorchPolicyGraph)
def optimizer(self):
return torch.optim.Adam(self._model.parameters(), lr=self.config["lr"])
@override(TorchPolicyGraph)
def extra_grad_info(self):
return {"policy_loss": self._loss.pi_err.item()}
def _value(self, obs):
with self.lock:
obs = torch.from_numpy(obs).float().unsqueeze(0).to(self.device)
_, _, vf, _ = self.model({"obs": obs}, [])
return vf.detach().cpu().numpy().squeeze()
# Torch policy for PG, produced by build_torch_policy from the functions
# defined above (loss, stats and trajectory postprocessing).
# Defaults are taken from the PG trainer config, matching PGTFPolicy. The
# previous reference to ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG was copied
# from the A3C policy and pulled A3C defaults into PG.
# NOTE(review): assumes ray.rllib.agents.pg.pg is already imported by the
# time the lambda runs, as it is when this class is reached through
# pg.get_policy_class -- confirm for standalone use.
PGTorchPolicy = build_torch_policy(
    name="PGTorchPolicy",
    get_default_config=lambda: ray.rllib.agents.pg.pg.DEFAULT_CONFIG,
    loss_fn=pg_torch_loss,
    stats_fn=pg_loss_stats,
    postprocess_fn=postprocess_advantages)
+3 -3
View File
@@ -2,7 +2,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ray.rllib.agents.ppo.appo_policy_graph import AsyncPPOPolicyGraph
from ray.rllib.agents.ppo.appo_policy_graph import AsyncPPOTFPolicy
from ray.rllib.agents.trainer import with_base_config
from ray.rllib.agents import impala
from ray.rllib.utils.annotations import override
@@ -57,8 +57,8 @@ class APPOTrainer(impala.ImpalaTrainer):
_name = "APPO"
_default_config = DEFAULT_CONFIG
_policy_graph = AsyncPPOPolicyGraph
_policy_graph = AsyncPPOTFPolicy
@override(impala.ImpalaTrainer)
def _get_policy_graph(self):
return AsyncPPOPolicyGraph
return AsyncPPOTFPolicy
+214 -313
View File
@@ -12,14 +12,11 @@ import gym
import ray
from ray.rllib.agents.impala import vtrace
from ray.rllib.evaluation.policy_graph import PolicyGraph
from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph, \
LearningRateSchedule
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.utils.annotations import override
from ray.rllib.evaluation.postprocessing import Postprocessing
from ray.rllib.evaluation.sample_batch import SampleBatch
from ray.rllib.evaluation.tf_policy_template import build_tf_policy
from ray.rllib.evaluation.tf_policy_graph import LearningRateSchedule
from ray.rllib.utils.explained_variance import explained_variance
from ray.rllib.models.action_dist import MultiCategorical
from ray.rllib.evaluation.postprocessing import compute_advantages
from ray.rllib.utils import try_import_tf
@@ -27,6 +24,8 @@ tf = try_import_tf()
logger = logging.getLogger(__name__)
BEHAVIOUR_LOGITS = "behaviour_logits"
class PPOSurrogateLoss(object):
"""Loss used when V-trace is disabled.
@@ -163,333 +162,235 @@ class VTraceSurrogateLoss(object):
self.entropy * entropy_coeff)
class APPOPostprocessing(object):
"""Adds the policy logits, VF preds, and advantages to the trajectory."""
def _make_time_major(policy, tensor, drop_last=False):
"""Swaps batch and trajectory axis.
@override(TFPolicyGraph)
def extra_compute_action_fetches(self):
out = {"behaviour_logits": self.model.outputs}
if not self.config["vtrace"]:
out["vf_preds"] = self.value_function
return dict(TFPolicyGraph.extra_compute_action_fetches(self), **out)
Arguments:
policy: Policy reference
tensor: A tensor or list of tensors to reshape.
drop_last: A bool indicating whether to drop the last
trajectory item.
@override(PolicyGraph)
def postprocess_trajectory(self,
sample_batch,
other_agent_batches=None,
episode=None):
if not self.config["vtrace"]:
completed = sample_batch["dones"][-1]
if completed:
last_r = 0.0
else:
next_state = []
for i in range(len(self.model.state_in)):
next_state.append(
[sample_batch["state_out_{}".format(i)][-1]])
last_r = self.value(sample_batch["new_obs"][-1], *next_state)
batch = compute_advantages(
sample_batch,
last_r,
self.config["gamma"],
self.config["lambda"],
use_gae=self.config["use_gae"])
Returns:
res: A tensor with swapped axes or a list of tensors with
swapped axes.
"""
if isinstance(tensor, list):
return [_make_time_major(policy, t, drop_last) for t in tensor]
if policy.model.state_init:
B = tf.shape(policy.model.seq_lens)[0]
T = tf.shape(tensor)[0] // B
else:
# Important: chop the tensor into batches at known episode cut
# boundaries. TODO(ekl) this is kind of a hack
T = policy.config["sample_batch_size"]
B = tf.shape(tensor)[0] // T
rs = tf.reshape(tensor, tf.concat([[B, T], tf.shape(tensor)[1:]], axis=0))
# swap B and T axes
res = tf.transpose(
rs, [1, 0] + list(range(2, 1 + int(tf.shape(tensor).shape[0]))))
if drop_last:
return res[:-1]
return res
def build_appo_surrogate_loss(policy, batch_tensors):
    """Build the APPO loss object and return its total loss.

    Depending on ``policy.config["vtrace"]`` this constructs either a
    ``VTraceSurrogateLoss`` or a ``PPOSurrogateLoss`` and stores it on
    ``policy.loss``. Tensors are passed through ``_make_time_major``
    first; the V-trace branch also drops the last time step of each
    input except the bootstrap value.

    Arguments:
        policy: Policy providing ``action_space``, ``model``,
            ``action_dist``, ``dist_class``, ``value_function`` and
            ``config``.
        batch_tensors: Mapping with the action, done, reward and
            behaviour-logit columns (plus advantages and value targets
            when V-trace is disabled).

    Returns:
        ``policy.loss.total_loss``.
    """
    config = policy.config
    space = policy.action_space

    # Work out how the flat logits are split into per-action components.
    if isinstance(space, gym.spaces.Discrete):
        is_multidiscrete = False
        output_hidden_shape = [space.n]
    elif isinstance(space, gym.spaces.multi_discrete.MultiDiscrete):
        is_multidiscrete = True
        output_hidden_shape = space.nvec.astype(np.int32)
    else:
        is_multidiscrete = False
        output_hidden_shape = 1

    def make_time_major(*args, **kw):
        return _make_time_major(policy, *args, **kw)

    def time_major_trimmed(tensor):
        # Time-major view with the last trajectory item dropped.
        return make_time_major(tensor, drop_last=True)

    actions = batch_tensors[SampleBatch.ACTIONS]
    dones = batch_tensors[SampleBatch.DONES]
    rewards = batch_tensors[SampleBatch.REWARDS]
    behaviour_logits = batch_tensors[BEHAVIOUR_LOGITS]

    unpacked_behaviour_logits = tf.split(
        behaviour_logits, output_hidden_shape, axis=1)
    unpacked_outputs = tf.split(
        policy.model.outputs, output_hidden_shape, axis=1)
    if is_multidiscrete:
        prev_dist_inputs = unpacked_behaviour_logits
    else:
        prev_dist_inputs = behaviour_logits
    action_dist = policy.action_dist
    prev_action_dist = policy.dist_class(prev_dist_inputs)
    values = policy.value_function

    # Mask derived from sequence lengths when the model has state inputs;
    # otherwise every step counts.
    if policy.model.state_in:
        max_seq_len = tf.reduce_max(policy.model.seq_lens) - 1
        mask = tf.sequence_mask(policy.model.seq_lens, max_seq_len)
        mask = tf.reshape(mask, [-1])
    else:
        mask = tf.ones_like(rewards)

    if config["vtrace"]:
        logger.info("Using V-Trace surrogate loss (vtrace=True)")

        # Prepare actions for loss
        if is_multidiscrete:
            loss_actions = actions
        else:
            loss_actions = tf.expand_dims(actions, axis=1)

        policy.loss = VTraceSurrogateLoss(
            actions=time_major_trimmed(loss_actions),
            prev_actions_logp=time_major_trimmed(
                prev_action_dist.logp(actions)),
            actions_logp=time_major_trimmed(action_dist.logp(actions)),
            action_kl=prev_action_dist.kl(action_dist),
            actions_entropy=time_major_trimmed(action_dist.entropy()),
            dones=time_major_trimmed(dones),
            behaviour_logits=time_major_trimmed(unpacked_behaviour_logits),
            target_logits=time_major_trimmed(unpacked_outputs),
            discount=config["gamma"],
            rewards=time_major_trimmed(rewards),
            values=time_major_trimmed(values),
            bootstrap_value=make_time_major(values)[-1],
            dist_class=policy.dist_class,
            valid_mask=time_major_trimmed(mask),
            vf_loss_coeff=config["vf_loss_coeff"],
            entropy_coeff=config["entropy_coeff"],
            clip_rho_threshold=config["vtrace_clip_rho_threshold"],
            clip_pg_rho_threshold=config["vtrace_clip_pg_rho_threshold"],
            clip_param=config["clip_param"])
    else:
        logger.info("Using PPO surrogate loss (vtrace=False)")
        policy.loss = PPOSurrogateLoss(
            prev_actions_logp=make_time_major(prev_action_dist.logp(actions)),
            actions_logp=make_time_major(action_dist.logp(actions)),
            action_kl=prev_action_dist.kl(action_dist),
            actions_entropy=make_time_major(action_dist.entropy()),
            values=make_time_major(values),
            valid_mask=make_time_major(mask),
            advantages=make_time_major(
                batch_tensors[Postprocessing.ADVANTAGES]),
            value_targets=make_time_major(
                batch_tensors[Postprocessing.VALUE_TARGETS]),
            vf_loss_coeff=config["vf_loss_coeff"],
            entropy_coeff=config["entropy_coeff"],
            clip_param=config["clip_param"])

    return policy.loss.total_loss
def stats(policy, batch_tensors):
    """Collect learner statistics from the policy and its loss object.

    Arguments:
        policy: Policy providing ``value_function``, ``config``,
            ``cur_lr``, ``var_list`` and ``loss``.
        batch_tensors: Unused.

    Returns:
        Dict with the keys ``cur_lr``, ``policy_loss``, ``entropy``,
        ``var_gnorm``, ``vf_loss`` and ``vf_explained_var``.
    """
    # The last step is dropped exactly when V-trace is enabled.
    values_batched = _make_time_major(
        policy, policy.value_function, drop_last=policy.config["vtrace"])

    learner_stats = {}
    learner_stats["cur_lr"] = tf.cast(policy.cur_lr, tf.float64)
    learner_stats["policy_loss"] = policy.loss.pi_loss
    learner_stats["entropy"] = policy.loss.entropy
    learner_stats["var_gnorm"] = tf.global_norm(policy.var_list)
    learner_stats["vf_loss"] = policy.loss.vf_loss
    learner_stats["vf_explained_var"] = explained_variance(
        tf.reshape(policy.loss.value_targets, [-1]),
        tf.reshape(values_batched, [-1]))
    return learner_stats
def grad_stats(policy, grads):
    """Report the global norm of the gradients.

    Arguments:
        policy: Unused.
        grads: Gradient tensors passed to ``tf.global_norm``.

    Returns:
        Dict with the single key ``grad_gnorm``.
    """
    gradient_norm = tf.global_norm(grads)
    return {"grad_gnorm": gradient_norm}
def postprocess_trajectory(policy,
sample_batch,
other_agent_batches=None,
episode=None):
if not policy.config["vtrace"]:
completed = sample_batch["dones"][-1]
if completed:
last_r = 0.0
else:
batch = sample_batch
del batch.data["new_obs"] # not used, so save some bandwidth
return batch
next_state = []
for i in range(len(policy.model.state_in)):
next_state.append([sample_batch["state_out_{}".format(i)][-1]])
last_r = policy.value(sample_batch["new_obs"][-1], *next_state)
batch = compute_advantages(
sample_batch,
last_r,
policy.config["gamma"],
policy.config["lambda"],
use_gae=policy.config["use_gae"])
else:
batch = sample_batch
del batch.data["new_obs"] # not used, so save some bandwidth
return batch
class AsyncPPOPolicyGraph(LearningRateSchedule, APPOPostprocessing,
TFPolicyGraph):
def __init__(self,
observation_space,
action_space,
config,
existing_inputs=None):
config = dict(ray.rllib.agents.impala.impala.DEFAULT_CONFIG, **config)
assert config["batch_mode"] == "truncate_episodes", \
"Must use `truncate_episodes` batch mode with V-trace."
self.config = config
self.sess = tf.get_default_session()
self.grads = None
def add_values_and_logits(policy):
    """Return the extra per-action fetches for this policy.

    Arguments:
        policy: Policy providing ``model.outputs``, ``config["vtrace"]``
            and ``value_function``.

    Returns:
        Dict holding the model outputs under ``BEHAVIOUR_LOGITS``, plus
        the value function under ``SampleBatch.VF_PREDS`` when V-trace is
        disabled.
    """
    fetches = {BEHAVIOUR_LOGITS: policy.model.outputs}
    if policy.config["vtrace"]:
        return fetches
    fetches[SampleBatch.VF_PREDS] = policy.value_function
    return fetches
if isinstance(action_space, gym.spaces.Discrete):
is_multidiscrete = False
output_hidden_shape = [action_space.n]
elif isinstance(action_space, gym.spaces.multi_discrete.MultiDiscrete):
is_multidiscrete = True
output_hidden_shape = action_space.nvec.astype(np.int32)
else:
is_multidiscrete = False
output_hidden_shape = 1
# Policy network model
dist_class, logit_dim = ModelCatalog.get_action_dist(
action_space, self.config["model"])
def validate_config(policy, obs_space, action_space, config):
    """Check that the config uses the ``truncate_episodes`` batch mode.

    Arguments:
        policy: Unused.
        obs_space: Unused.
        action_space: Unused.
        config: Mapping providing ``"batch_mode"``.

    Raises:
        AssertionError: If ``config["batch_mode"]`` is anything other
            than ``"truncate_episodes"``.
    """
    # NOTE: this is an ``assert``, so the check is skipped under ``-O``.
    batch_mode = config["batch_mode"]
    assert batch_mode == "truncate_episodes", \
        "Must use `truncate_episodes` batch mode with V-trace."
# Create input placeholders
if existing_inputs:
if self.config["vtrace"]:
actions, dones, behaviour_logits, rewards, observations, \
prev_actions, prev_rewards = existing_inputs[:7]
existing_state_in = existing_inputs[7:-1]
existing_seq_lens = existing_inputs[-1]
else:
actions, dones, behaviour_logits, rewards, observations, \
prev_actions, prev_rewards, adv_ph, value_targets = \
existing_inputs[:9]
existing_state_in = existing_inputs[9:-1]
existing_seq_lens = existing_inputs[-1]
else:
actions = ModelCatalog.get_action_placeholder(action_space)
dones = tf.placeholder(tf.bool, [None], name="dones")
rewards = tf.placeholder(tf.float32, [None], name="rewards")
behaviour_logits = tf.placeholder(
tf.float32, [None, logit_dim], name="behaviour_logits")
observations = tf.placeholder(
tf.float32, [None] + list(observation_space.shape))
existing_state_in = None
existing_seq_lens = None
if not self.config["vtrace"]:
adv_ph = tf.placeholder(
tf.float32, name="advantages", shape=(None, ))
value_targets = tf.placeholder(
tf.float32, name="value_targets", shape=(None, ))
self.observations = observations
def choose_optimizer(policy, config):
    """Create the TF optimizer selected by ``opt_type``.

    Arguments:
        policy: Policy providing ``config["opt_type"]`` and ``cur_lr``.
        config: Mapping providing ``decay``, ``momentum`` and ``epsilon``
            for the RMSProp case.

    Returns:
        ``tf.train.AdamOptimizer`` when ``opt_type`` is ``"adam"``,
        otherwise ``tf.train.RMSPropOptimizer``.
    """
    if policy.config["opt_type"] != "adam":
        return tf.train.RMSPropOptimizer(policy.cur_lr, config["decay"],
                                         config["momentum"],
                                         config["epsilon"])
    return tf.train.AdamOptimizer(policy.cur_lr)
# Unpack behaviour logits
unpacked_behaviour_logits = tf.split(
behaviour_logits, output_hidden_shape, axis=1)
# Setup the policy
dist_class, logit_dim = ModelCatalog.get_action_dist(
action_space, self.config["model"])
prev_actions = ModelCatalog.get_action_placeholder(action_space)
prev_rewards = tf.placeholder(tf.float32, [None], name="prev_reward")
self.model = ModelCatalog.get_model(
{
"obs": observations,
"prev_actions": prev_actions,
"prev_rewards": prev_rewards,
"is_training": self._get_is_training_placeholder(),
},
observation_space,
action_space,
logit_dim,
self.config["model"],
state_in=existing_state_in,
seq_lens=existing_seq_lens)
unpacked_outputs = tf.split(
self.model.outputs, output_hidden_shape, axis=1)
def clip_gradients(policy, optimizer, loss):
    """Compute gradients of ``loss`` and clip them by global norm.

    The clipped gradients are also stored on ``policy.grads``.

    Arguments:
        policy: Policy providing ``var_list`` and ``config["grad_clip"]``.
        optimizer: Unused.
        loss: Tensor to differentiate with respect to ``policy.var_list``.

    Returns:
        List of ``(clipped_gradient, variable)`` pairs.
    """
    raw_grads = tf.gradients(loss, policy.var_list)
    policy.grads, _ = tf.clip_by_global_norm(raw_grads,
                                             policy.config["grad_clip"])
    return list(zip(policy.grads, policy.var_list))
dist_inputs = unpacked_outputs if is_multidiscrete else \
self.model.outputs
prev_dist_inputs = unpacked_behaviour_logits if is_multidiscrete else \
behaviour_logits
action_dist = dist_class(dist_inputs)
prev_action_dist = dist_class(prev_dist_inputs)
values = self.model.value_function()
self.value_function = values
class ValueNetworkMixin(object):
def __init__(self):
self.value_function = self.model.value_function()
self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
tf.get_variable_scope().name)
def make_time_major(tensor, drop_last=False):
"""Swaps batch and trajectory axis.
Args:
tensor: A tensor or list of tensors to reshape.
drop_last: A bool indicating whether to drop the last
trajectory item.
Returns:
res: A tensor with swapped axes or a list of tensors with
swapped axes.
"""
if isinstance(tensor, list):
return [make_time_major(t, drop_last) for t in tensor]
if self.model.state_init:
B = tf.shape(self.model.seq_lens)[0]
T = tf.shape(tensor)[0] // B
else:
# Important: chop the tensor into batches at known episode cut
# boundaries. TODO(ekl) this is kind of a hack
T = self.config["sample_batch_size"]
B = tf.shape(tensor)[0] // T
rs = tf.reshape(tensor,
tf.concat([[B, T], tf.shape(tensor)[1:]], axis=0))
# swap B and T axes
res = tf.transpose(
rs,
[1, 0] + list(range(2, 1 + int(tf.shape(tensor).shape[0]))))
if drop_last:
return res[:-1]
return res
if self.model.state_in:
max_seq_len = tf.reduce_max(self.model.seq_lens) - 1
mask = tf.sequence_mask(self.model.seq_lens, max_seq_len)
mask = tf.reshape(mask, [-1])
else:
mask = tf.ones_like(rewards)
# Inputs are reshaped from [B * T] => [T - 1, B] for V-trace calc.
if self.config["vtrace"]:
logger.info("Using V-Trace surrogate loss (vtrace=True)")
# Prepare actions for loss
loss_actions = actions if is_multidiscrete else tf.expand_dims(
actions, axis=1)
self.loss = VTraceSurrogateLoss(
actions=make_time_major(loss_actions, drop_last=True),
prev_actions_logp=make_time_major(
prev_action_dist.logp(actions), drop_last=True),
actions_logp=make_time_major(
action_dist.logp(actions), drop_last=True),
action_kl=prev_action_dist.kl(action_dist),
actions_entropy=make_time_major(
action_dist.entropy(), drop_last=True),
dones=make_time_major(dones, drop_last=True),
behaviour_logits=make_time_major(
unpacked_behaviour_logits, drop_last=True),
target_logits=make_time_major(
unpacked_outputs, drop_last=True),
discount=config["gamma"],
rewards=make_time_major(rewards, drop_last=True),
values=make_time_major(values, drop_last=True),
bootstrap_value=make_time_major(values)[-1],
dist_class=dist_class,
valid_mask=make_time_major(mask, drop_last=True),
vf_loss_coeff=self.config["vf_loss_coeff"],
entropy_coeff=self.config["entropy_coeff"],
clip_rho_threshold=self.config["vtrace_clip_rho_threshold"],
clip_pg_rho_threshold=self.config[
"vtrace_clip_pg_rho_threshold"],
clip_param=self.config["clip_param"])
else:
logger.info("Using PPO surrogate loss (vtrace=False)")
self.loss = PPOSurrogateLoss(
prev_actions_logp=make_time_major(
prev_action_dist.logp(actions)),
actions_logp=make_time_major(action_dist.logp(actions)),
action_kl=prev_action_dist.kl(action_dist),
actions_entropy=make_time_major(action_dist.entropy()),
values=make_time_major(values),
valid_mask=make_time_major(mask),
advantages=make_time_major(adv_ph),
value_targets=make_time_major(value_targets),
vf_loss_coeff=self.config["vf_loss_coeff"],
entropy_coeff=self.config["entropy_coeff"],
clip_param=self.config["clip_param"])
# KL divergence between worker and learner logits for debugging
model_dist = MultiCategorical(unpacked_outputs)
behaviour_dist = MultiCategorical(unpacked_behaviour_logits)
kls = model_dist.kl(behaviour_dist)
if len(kls) > 1:
self.KL_stats = {}
for i, kl in enumerate(kls):
self.KL_stats.update({
"mean_KL_{}".format(i): tf.reduce_mean(kl),
"max_KL_{}".format(i): tf.reduce_max(kl),
})
else:
self.KL_stats = {
"mean_KL": tf.reduce_mean(kls[0]),
"max_KL": tf.reduce_max(kls[0]),
}
# Initialize TFPolicyGraph
loss_in = [
("actions", actions),
("dones", dones),
("behaviour_logits", behaviour_logits),
("rewards", rewards),
("obs", observations),
("prev_actions", prev_actions),
("prev_rewards", prev_rewards),
]
if not self.config["vtrace"]:
loss_in.append(("advantages", adv_ph))
loss_in.append(("value_targets", value_targets))
LearningRateSchedule.__init__(self, self.config["lr"],
self.config["lr_schedule"])
TFPolicyGraph.__init__(
self,
observation_space,
action_space,
self.sess,
obs_input=observations,
action_sampler=action_dist.sample(),
action_prob=action_dist.sampled_action_prob(),
loss=self.loss.total_loss,
model=self.model,
loss_inputs=loss_in,
state_inputs=self.model.state_in,
state_outputs=self.model.state_out,
prev_action_input=prev_actions,
prev_reward_input=prev_rewards,
seq_lens=self.model.seq_lens,
max_seq_len=self.config["model"]["max_seq_len"],
batch_divisibility_req=self.config["sample_batch_size"])
self.sess.run(tf.global_variables_initializer())
values_batched = make_time_major(
values, drop_last=self.config["vtrace"])
self.stats_fetches = {
LEARNER_STATS_KEY: dict({
"cur_lr": tf.cast(self.cur_lr, tf.float64),
"policy_loss": self.loss.pi_loss,
"entropy": self.loss.entropy,
"grad_gnorm": tf.global_norm(self._grads),
"var_gnorm": tf.global_norm(self.var_list),
"vf_loss": self.loss.vf_loss,
"vf_explained_var": explained_variance(
tf.reshape(self.loss.value_targets, [-1]),
tf.reshape(values_batched, [-1])),
}, **self.KL_stats),
}
def optimizer(self):
    """Return the TF optimizer selected by ``config["opt_type"]``.

    ``"adam"`` yields an Adam optimizer; any other value yields RMSProp
    configured from the ``decay``, ``momentum`` and ``epsilon`` config
    entries. Both use ``self.cur_lr`` as the learning rate.
    """
    cfg = self.config
    if cfg["opt_type"] != "adam":
        return tf.train.RMSPropOptimizer(self.cur_lr, cfg["decay"],
                                         cfg["momentum"], cfg["epsilon"])
    return tf.train.AdamOptimizer(self.cur_lr)
def gradients(self, optimizer, loss):
    """Return (gradient, variable) pairs clipped by global norm.

    Gradients of ``loss`` w.r.t. ``self.var_list`` are clipped to
    ``config["grad_clip"]``; the clipped gradients are also stored on
    ``self.grads``. The ``optimizer`` argument is not used here.
    """
    raw_grads = tf.gradients(loss, self.var_list)
    clip_norm = self.config["grad_clip"]
    self.grads, _ = tf.clip_by_global_norm(raw_grads, clip_norm)
    return list(zip(self.grads, self.var_list))
def extra_compute_grad_fetches(self):
    """Return the stats fetch dict stored on ``self.stats_fetches``."""
    fetches = self.stats_fetches
    return fetches
def value(self, ob, *args):
    """Evaluate the value function for a single observation.

    Arguments:
        ob: one observation; wrapped in a list to form a batch of one.
        *args: one value per entry of ``self.model.state_in``, fed to the
            matching state input tensor.

    Returns:
        The first element of the evaluated ``self.value_function``.

    Raises:
        AssertionError: if the number of state values does not match the
            number of model state inputs.
    """
    # Only the `_obs_input` / `_sess` attributes are used: the former
    # `self.observations` feed dict was overwritten before use and the
    # former `self.sess.run` result was discarded, so both were dead work.
    feed_dict = {self._obs_input: [ob], self.model.seq_lens: [1]}
    assert len(args) == len(self.model.state_in), \
        (args, self.model.state_in)
    feed_dict.update(zip(self.model.state_in, args))
    vf = self._sess.run(self.value_function, feed_dict)
    return vf[0]
def get_initial_state(self):
    """Return the model's ``state_init`` attribute."""
    model = self.model
    return model.state_init
def copy(self, existing_inputs):
    """Build a new AsyncPPOPolicyGraph on top of ``existing_inputs``.

    The new graph reuses this policy's observation space, action space and
    config.
    """
    ctor_args = (self.observation_space, self.action_space, self.config)
    return AsyncPPOPolicyGraph(*ctor_args, existing_inputs=existing_inputs)
def setup_mixins(policy, obs_space, action_space, config):
    """Initialize the mixins on ``policy`` before the loss is built.

    The learning rate schedule is set up first from ``config["lr"]`` and
    ``config["lr_schedule"]``, then the value network mixin. ``obs_space``
    and ``action_space`` are accepted but not used here.
    """
    lr, lr_schedule = config["lr"], config["lr_schedule"]
    LearningRateSchedule.__init__(policy, lr, lr_schedule)
    ValueNetworkMixin.__init__(policy)
# Assemble the APPO TF policy class from the callbacks named below.
# The default config is looked up lazily from the IMPALA trainer module,
# `setup_mixins` runs before the loss is initialized, and the batch
# divisibility requirement is read from config["sample_batch_size"].
AsyncPPOTFPolicy = build_tf_policy(
name="AsyncPPOTFPolicy",
get_default_config=lambda: ray.rllib.agents.impala.impala.DEFAULT_CONFIG,
loss_fn=build_appo_surrogate_loss,
stats_fn=stats,
grad_stats_fn=grad_stats,
postprocess_fn=postprocess_trajectory,
optimizer_fn=choose_optimizer,
gradients_fn=clip_gradients,
extra_action_fetches_fn=add_values_and_logits,
before_init=validate_config,
before_loss_init=setup_mixins,
mixins=[LearningRateSchedule, ValueNetworkMixin],
get_batch_divisibility_req=lambda p: p.config["sample_batch_size"])
+97 -103
View File
@@ -4,10 +4,10 @@ from __future__ import print_function
import logging
from ray.rllib.agents import Trainer, with_common_config
from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph
from ray.rllib.agents import with_common_config
from ray.rllib.agents.ppo.ppo_policy_graph import PPOTFPolicy
from ray.rllib.agents.trainer_template import build_trainer
from ray.rllib.optimizers import SyncSamplesOptimizer, LocalMultiGPUOptimizer
from ray.rllib.utils.annotations import override
logger = logging.getLogger(__name__)
@@ -63,110 +63,104 @@ DEFAULT_CONFIG = with_common_config({
# yapf: enable
class PPOTrainer(Trainer):
"""Multi-GPU optimized implementation of PPO in TensorFlow."""
def make_optimizer(local_evaluator, remote_evaluators, config):
if config["simple_optimizer"]:
return SyncSamplesOptimizer(
local_evaluator,
remote_evaluators,
num_sgd_iter=config["num_sgd_iter"],
train_batch_size=config["train_batch_size"])
_name = "PPO"
_default_config = DEFAULT_CONFIG
_policy_graph = PPOPolicyGraph
return LocalMultiGPUOptimizer(
local_evaluator,
remote_evaluators,
sgd_batch_size=config["sgd_minibatch_size"],
num_sgd_iter=config["num_sgd_iter"],
num_gpus=config["num_gpus"],
sample_batch_size=config["sample_batch_size"],
num_envs_per_worker=config["num_envs_per_worker"],
train_batch_size=config["train_batch_size"],
standardize_fields=["advantages"],
straggler_mitigation=config["straggler_mitigation"])
@override(Trainer)
def _init(self, config, env_creator):
self._validate_config()
self.local_evaluator = self.make_local_evaluator(
env_creator, self._policy_graph)
self.remote_evaluators = self.make_remote_evaluators(
env_creator, self._policy_graph, config["num_workers"])
if config["simple_optimizer"]:
self.optimizer = SyncSamplesOptimizer(
self.local_evaluator,
self.remote_evaluators,
num_sgd_iter=config["num_sgd_iter"],
train_batch_size=config["train_batch_size"])
else:
self.optimizer = LocalMultiGPUOptimizer(
self.local_evaluator,
self.remote_evaluators,
sgd_batch_size=config["sgd_minibatch_size"],
num_sgd_iter=config["num_sgd_iter"],
num_gpus=config["num_gpus"],
sample_batch_size=config["sample_batch_size"],
num_envs_per_worker=config["num_envs_per_worker"],
train_batch_size=config["train_batch_size"],
standardize_fields=["advantages"],
straggler_mitigation=config["straggler_mitigation"])
@override(Trainer)
def _train(self):
if "observation_filter" not in self.raw_user_config:
# TODO(ekl) remove this message after a few releases
logger.info(
"Important! Since 0.7.0, observation normalization is no "
"longer enabled by default. To enable running-mean "
"normalization, set 'observation_filter': 'MeanStdFilter'. "
"You can ignore this message if your environment doesn't "
"require observation normalization.")
prev_steps = self.optimizer.num_steps_sampled
fetches = self.optimizer.step()
if "kl" in fetches:
# single-agent
self.local_evaluator.for_policy(
lambda pi: pi.update_kl(fetches["kl"]))
else:
def update_kl(trainer, fetches):
if "kl" in fetches:
# single-agent
trainer.local_evaluator.for_policy(
lambda pi: pi.update_kl(fetches["kl"]))
else:
def update(pi, pi_id):
if pi_id in fetches:
pi.update_kl(fetches[pi_id]["kl"])
else:
logger.debug(
"No data for {}, not updating kl".format(pi_id))
def update(pi, pi_id):
if pi_id in fetches:
pi.update_kl(fetches[pi_id]["kl"])
else:
logger.debug("No data for {}, not updating kl".format(pi_id))
# multi-agent
self.local_evaluator.foreach_trainable_policy(update)
res = self.collect_metrics()
res.update(
timesteps_this_iter=self.optimizer.num_steps_sampled - prev_steps,
info=res.get("info", {}))
# multi-agent
trainer.local_evaluator.foreach_trainable_policy(update)
# Warn about bad clipping configs
if self.config["vf_clip_param"] <= 0:
rew_scale = float("inf")
elif res["policy_reward_mean"]:
rew_scale = 0 # punt on handling multiagent case
else:
rew_scale = round(
abs(res["episode_reward_mean"]) / self.config["vf_clip_param"],
0)
if rew_scale > 200:
logger.warning(
"The magnitude of your environment rewards are more than "
"{}x the scale of `vf_clip_param`. ".format(rew_scale) +
"This means that it will take more than "
"{} iterations for your value ".format(rew_scale) +
"function to converge. If this is not intended, consider "
"increasing `vf_clip_param`.")
return res
def _validate_config(self):
if self.config["entropy_coeff"] < 0:
raise DeprecationWarning("entropy_coeff must be >= 0")
if self.config["sgd_minibatch_size"] > self.config["train_batch_size"]:
raise ValueError(
"Minibatch size {} must be <= train batch size {}.".format(
self.config["sgd_minibatch_size"],
self.config["train_batch_size"]))
if (self.config["batch_mode"] == "truncate_episodes"
and not self.config["use_gae"]):
raise ValueError(
"Episode truncation is not supported without a value "
"function. Consider setting batch_mode=complete_episodes.")
if (self.config["multiagent"]["policy_graphs"]
and not self.config["simple_optimizer"]):
logger.info(
"In multi-agent mode, policies will be optimized sequentially "
"by the multi-GPU optimizer. Consider setting "
"simple_optimizer=True if this doesn't work for you.")
if not self.config["vf_share_layers"]:
logger.warning(
"FYI: By default, the value function will not share layers "
"with the policy model ('vf_share_layers': False).")
def warn_about_obs_filter(trainer):
    """Log a notice when the user config does not set an observation filter.

    Arguments:
        trainer: object whose ``raw_user_config`` is checked for the
            "observation_filter" key.
    """
    if "observation_filter" in trainer.raw_user_config:
        return
    # TODO(ekl) remove this message after a few releases
    notice = (
        "Important! Since 0.7.0, observation normalization is no "
        "longer enabled by default. To enable running-mean "
        "normalization, set 'observation_filter': 'MeanStdFilter'. "
        "You can ignore this message if your environment doesn't "
        "require observation normalization.")
    logger.info(notice)
def warn_about_bad_reward_scales(trainer, result):
    """Warn when episode rewards dwarf the configured ``vf_clip_param``.

    Arguments:
        trainer: object whose ``config["vf_clip_param"]`` is read.
        result (dict): train result; ``policy_reward_mean`` and
            ``episode_reward_mean`` are read.

    The reward scale is infinite when ``vf_clip_param`` <= 0, zero when
    ``policy_reward_mean`` is non-empty (multi-agent is not handled), and
    otherwise |episode_reward_mean| / vf_clip_param rounded to 0 decimals.
    A warning is logged when the scale exceeds 200.
    """
    vf_clip = trainer.config["vf_clip_param"]
    # Warn about bad clipping configs
    if vf_clip <= 0:
        rew_scale = float("inf")
    elif result["policy_reward_mean"]:
        rew_scale = 0  # punt on handling multiagent case
    else:
        rew_scale = round(abs(result["episode_reward_mean"]) / vf_clip, 0)

    # Written as `not >` so a NaN scale stays silent, like a plain `> 200`.
    if not rew_scale > 200:
        return
    message = (
        "The magnitude of your environment rewards are more than "
        "{}x the scale of `vf_clip_param`. ".format(rew_scale) +
        "This means that it will take more than "
        "{} iterations for your value ".format(rew_scale) +
        "function to converge. If this is not intended, consider "
        "increasing `vf_clip_param`.")
    logger.warning(message)
def validate_config(config):
    """Check a PPO config dict, raising on invalid combinations.

    Arguments:
        config (dict): trainer config to check.

    Raises:
        DeprecationWarning: if ``entropy_coeff`` is negative.
        ValueError: if ``sgd_minibatch_size`` exceeds ``train_batch_size``,
            or if ``batch_mode`` is "truncate_episodes" while ``use_gae``
            is disabled.

    Informational messages are logged for multi-agent configs that do not
    use the simple optimizer, and when ``vf_share_layers`` is off.
    """
    if config["entropy_coeff"] < 0:
        raise DeprecationWarning("entropy_coeff must be >= 0")

    if config["sgd_minibatch_size"] > config["train_batch_size"]:
        size_error = "Minibatch size {} must be <= train batch size {}."
        raise ValueError(
            size_error.format(config["sgd_minibatch_size"],
                              config["train_batch_size"]))

    if config["batch_mode"] == "truncate_episodes":
        if not config["use_gae"]:
            raise ValueError(
                "Episode truncation is not supported without a value "
                "function. Consider setting batch_mode=complete_episodes.")

    if config["multiagent"]["policy_graphs"]:
        if not config["simple_optimizer"]:
            logger.info(
                "In multi-agent mode, policies will be optimized sequentially "
                "by the multi-GPU optimizer. Consider setting "
                "simple_optimizer=True if this doesn't work for you.")

    if not config["vf_share_layers"]:
        logger.warning(
            "FYI: By default, the value function will not share layers "
            "with the policy model ('vf_share_layers': False).")
# Assemble the PPO trainer class via build_trainer: the config is checked by
# validate_config, make_optimizer builds the policy optimizer, update_kl runs
# after each optimizer step, and the two warn_* callbacks run before each
# train step and on each train result respectively.
PPOTrainer = build_trainer(
name="PPO",
default_config=DEFAULT_CONFIG,
default_policy=PPOTFPolicy,
make_policy_optimizer=make_optimizer,
validate_config=validate_config,
after_optimizer_step=update_kl,
before_train_step=warn_about_obs_filter,
after_train_result=warn_about_bad_reward_scales)
+136 -223
View File
@@ -7,13 +7,10 @@ import logging
import ray
from ray.rllib.evaluation.postprocessing import compute_advantages, \
Postprocessing
from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
from ray.rllib.evaluation.policy_graph import PolicyGraph
from ray.rllib.evaluation.sample_batch import SampleBatch
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph, \
LearningRateSchedule
from ray.rllib.evaluation.tf_policy_graph import LearningRateSchedule
from ray.rllib.evaluation.tf_policy_template import build_tf_policy
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.utils.annotations import override
from ray.rllib.utils.explained_variance import explained_variance
from ray.rllib.utils import try_import_tf
@@ -107,119 +104,106 @@ class PPOLoss(object):
self.loss = loss
class PPOPostprocessing(object):
def ppo_surrogate_loss(policy, batch_tensors):
    """Build the PPO loss object on ``policy`` and return its loss tensor.

    Arguments:
        policy: policy providing ``model``, ``action_space``,
            ``action_dist``, ``value_function``, ``kl_coeff`` and
            ``config``.
        batch_tensors: mapping from sample batch keys to tensors.

    Returns:
        ``policy.loss_obj.loss``, where ``policy.loss_obj`` is the newly
        created PPOLoss.
    """
    model = policy.model
    if not model.state_in:
        # No state inputs: every row of the batch is valid.
        mask = tf.ones_like(
            batch_tensors[Postprocessing.ADVANTAGES], dtype=tf.bool)
    else:
        # With state inputs, derive a flat validity mask from seq_lens.
        max_seq_len = tf.reduce_max(model.seq_lens)
        mask = tf.reshape(
            tf.sequence_mask(model.seq_lens, max_seq_len), [-1])

    cfg = policy.config
    policy.loss_obj = PPOLoss(
        policy.action_space,
        batch_tensors[Postprocessing.VALUE_TARGETS],
        batch_tensors[Postprocessing.ADVANTAGES],
        batch_tensors[SampleBatch.ACTIONS],
        batch_tensors[BEHAVIOUR_LOGITS],
        batch_tensors[SampleBatch.VF_PREDS],
        policy.action_dist,
        policy.value_function,
        policy.kl_coeff,
        mask,
        entropy_coeff=cfg["entropy_coeff"],
        clip_param=cfg["clip_param"],
        vf_clip_param=cfg["vf_clip_param"],
        vf_loss_coeff=cfg["vf_loss_coeff"],
        use_gae=cfg["use_gae"])

    return policy.loss_obj.loss
def kl_and_loss_stats(policy, batch_tensors):
    """Return the dict of PPO learner stats tensors for ``policy``.

    Also stores the explained variance of the value function against the
    value targets on ``policy.explained_variance``.
    """
    value_targets = batch_tensors[Postprocessing.VALUE_TARGETS]
    policy.explained_variance = explained_variance(value_targets,
                                                   policy.value_function)

    loss_obj = policy.loss_obj
    return {
        "cur_kl_coeff": policy.kl_coeff,
        "cur_lr": tf.cast(policy.cur_lr, tf.float64),
        "total_loss": loss_obj.loss,
        "policy_loss": loss_obj.mean_policy_loss,
        "vf_loss": loss_obj.mean_vf_loss,
        "vf_explained_var": policy.explained_variance,
        "kl": loss_obj.mean_kl,
        "entropy": loss_obj.mean_entropy,
    }
def vf_preds_and_logits_fetches(policy):
    """Return the value function and model outputs as extra action fetches.

    The value function is keyed by ``SampleBatch.VF_PREDS`` and the model
    outputs by ``BEHAVIOUR_LOGITS``.
    """
    fetches = {}
    fetches[SampleBatch.VF_PREDS] = policy.value_function
    fetches[BEHAVIOUR_LOGITS] = policy.model.outputs
    return fetches
def postprocess_ppo_gae(policy,
sample_batch,
other_agent_batches=None,
episode=None):
"""Adds the policy logits, VF preds, and advantages to the trajectory."""
@override(TFPolicyGraph)
def extra_compute_action_fetches(self):
return dict(
TFPolicyGraph.extra_compute_action_fetches(self), **{
SampleBatch.VF_PREDS: self.value_function,
BEHAVIOUR_LOGITS: self.logits
})
@override(PolicyGraph)
def postprocess_trajectory(self,
sample_batch,
other_agent_batches=None,
episode=None):
completed = sample_batch["dones"][-1]
if completed:
last_r = 0.0
else:
next_state = []
for i in range(len(self.model.state_in)):
next_state.append([sample_batch["state_out_{}".format(i)][-1]])
last_r = self._value(sample_batch[SampleBatch.NEXT_OBS][-1],
sample_batch[SampleBatch.ACTIONS][-1],
sample_batch[SampleBatch.REWARDS][-1],
*next_state)
batch = compute_advantages(
sample_batch,
last_r,
self.config["gamma"],
self.config["lambda"],
use_gae=self.config["use_gae"])
return batch
completed = sample_batch["dones"][-1]
if completed:
last_r = 0.0
else:
next_state = []
for i in range(len(policy.model.state_in)):
next_state.append([sample_batch["state_out_{}".format(i)][-1]])
last_r = policy._value(sample_batch[SampleBatch.NEXT_OBS][-1],
sample_batch[SampleBatch.ACTIONS][-1],
sample_batch[SampleBatch.REWARDS][-1],
*next_state)
batch = compute_advantages(
sample_batch,
last_r,
policy.config["gamma"],
policy.config["lambda"],
use_gae=policy.config["use_gae"])
return batch
class PPOPolicyGraph(LearningRateSchedule, PPOPostprocessing, TFPolicyGraph):
def __init__(self,
observation_space,
action_space,
config,
existing_inputs=None):
"""
Arguments:
observation_space: Environment observation space specification.
action_space: Environment action space specification.
config (dict): Configuration values for PPO graph.
existing_inputs (list): Optional list of tuples that specify the
placeholders upon which the graph should be built upon.
"""
config = dict(ray.rllib.agents.ppo.ppo.DEFAULT_CONFIG, **config)
self.sess = tf.get_default_session()
self.action_space = action_space
self.config = config
self.kl_coeff_val = self.config["kl_coeff"]
self.kl_target = self.config["kl_target"]
dist_cls, logit_dim = ModelCatalog.get_action_dist(
action_space, self.config["model"])
def clip_gradients(policy, optimizer, loss):
    """Compute (gradient, variable) pairs, clipping when configured.

    When ``config["grad_clip"]`` is None the optimizer's own
    ``compute_gradients`` is used. Otherwise gradients are taken w.r.t. the
    trainable variables of the current variable scope (stored on
    ``policy.var_list``), clipped by global norm, and stored on
    ``policy.grads``.
    """
    if policy.config["grad_clip"] is None:
        return optimizer.compute_gradients(
            loss, colocate_gradients_with_ops=True)

    scope_name = tf.get_variable_scope().name
    policy.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                        scope_name)
    raw_grads = tf.gradients(loss, policy.var_list)
    policy.grads, _ = tf.clip_by_global_norm(raw_grads,
                                             policy.config["grad_clip"])
    return list(zip(policy.grads, policy.var_list))
if existing_inputs:
obs_ph, value_targets_ph, adv_ph, act_ph, \
logits_ph, vf_preds_ph, prev_actions_ph, prev_rewards_ph = \
existing_inputs[:8]
existing_state_in = existing_inputs[8:-1]
existing_seq_lens = existing_inputs[-1]
else:
obs_ph = tf.placeholder(
tf.float32,
name="obs",
shape=(None, ) + observation_space.shape)
adv_ph = tf.placeholder(
tf.float32, name="advantages", shape=(None, ))
act_ph = ModelCatalog.get_action_placeholder(action_space)
logits_ph = tf.placeholder(
tf.float32, name="logits", shape=(None, logit_dim))
vf_preds_ph = tf.placeholder(
tf.float32, name="vf_preds", shape=(None, ))
value_targets_ph = tf.placeholder(
tf.float32, name="value_targets", shape=(None, ))
prev_actions_ph = ModelCatalog.get_action_placeholder(action_space)
prev_rewards_ph = tf.placeholder(
tf.float32, [None], name="prev_reward")
existing_state_in = None
existing_seq_lens = None
self.observations = obs_ph
self.prev_actions = prev_actions_ph
self.prev_rewards = prev_rewards_ph
self.loss_in = [
(SampleBatch.CUR_OBS, obs_ph),
(Postprocessing.VALUE_TARGETS, value_targets_ph),
(Postprocessing.ADVANTAGES, adv_ph),
(SampleBatch.ACTIONS, act_ph),
(BEHAVIOUR_LOGITS, logits_ph),
(SampleBatch.VF_PREDS, vf_preds_ph),
(SampleBatch.PREV_ACTIONS, prev_actions_ph),
(SampleBatch.PREV_REWARDS, prev_rewards_ph),
]
self.model = ModelCatalog.get_model(
{
"obs": obs_ph,
"prev_actions": prev_actions_ph,
"prev_rewards": prev_rewards_ph,
"is_training": self._get_is_training_placeholder(),
},
observation_space,
action_space,
logit_dim,
self.config["model"],
state_in=existing_state_in,
seq_lens=existing_seq_lens)
class KLCoeffMixin(object):
def __init__(self, config):
# KL Coefficient
self.kl_coeff_val = config["kl_coeff"]
self.kl_target = config["kl_target"]
self.kl_coeff = tf.get_variable(
initializer=tf.constant_initializer(self.kl_coeff_val),
name="kl_coeff",
@@ -227,14 +211,22 @@ class PPOPolicyGraph(LearningRateSchedule, PPOPostprocessing, TFPolicyGraph):
trainable=False,
dtype=tf.float32)
self.logits = self.model.outputs
curr_action_dist = dist_cls(self.logits)
self.sampler = curr_action_dist.sample()
if self.config["use_gae"]:
if self.config["vf_share_layers"]:
def update_kl(self, sampled_kl):
    """Adapt the KL coefficient to the sampled KL and load it into TF.

    The coefficient is multiplied by 1.5 when ``sampled_kl`` exceeds twice
    the KL target, halved when it is below half the target, and otherwise
    left unchanged. The resulting value is loaded into ``self.kl_coeff``
    using ``self._sess`` and returned.
    """
    target = self.kl_target
    if sampled_kl > 2.0 * target:
        self.kl_coeff_val = self.kl_coeff_val * 1.5
    else:
        if sampled_kl < 0.5 * target:
            self.kl_coeff_val = self.kl_coeff_val * 0.5

    new_value = self.kl_coeff_val
    self.kl_coeff.load(new_value, session=self._sess)
    return self.kl_coeff_val
class ValueNetworkMixin(object):
def __init__(self, obs_space, action_space, config):
if config["use_gae"]:
if config["vf_share_layers"]:
self.value_function = self.model.value_function()
else:
vf_config = self.config["model"].copy()
vf_config = config["model"].copy()
# Do not split the last layer of the value function into
# mean parameters and standard deviation parameters and
# do not make the standard deviations free variables.
@@ -249,122 +241,43 @@ class PPOPolicyGraph(LearningRateSchedule, PPOPostprocessing, TFPolicyGraph):
"value_function() method.")
with tf.variable_scope("value_function"):
self.value_function = ModelCatalog.get_model({
"obs": obs_ph,
"prev_actions": prev_actions_ph,
"prev_rewards": prev_rewards_ph,
"obs": self._obs_input,
"prev_actions": self._prev_action_input,
"prev_rewards": self._prev_reward_input,
"is_training": self._get_is_training_placeholder(),
}, observation_space, action_space, 1, vf_config).outputs
}, obs_space, action_space, 1, vf_config).outputs
self.value_function = tf.reshape(self.value_function, [-1])
else:
self.value_function = tf.zeros(shape=tf.shape(obs_ph)[:1])
if self.model.state_in:
max_seq_len = tf.reduce_max(self.model.seq_lens)
mask = tf.sequence_mask(self.model.seq_lens, max_seq_len)
mask = tf.reshape(mask, [-1])
else:
mask = tf.ones_like(adv_ph, dtype=tf.bool)
self.loss_obj = PPOLoss(
action_space,
value_targets_ph,
adv_ph,
act_ph,
logits_ph,
vf_preds_ph,
curr_action_dist,
self.value_function,
self.kl_coeff,
mask,
entropy_coeff=self.config["entropy_coeff"],
clip_param=self.config["clip_param"],
vf_clip_param=self.config["vf_clip_param"],
vf_loss_coeff=self.config["vf_loss_coeff"],
use_gae=self.config["use_gae"])
LearningRateSchedule.__init__(self, self.config["lr"],
self.config["lr_schedule"])
TFPolicyGraph.__init__(
self,
observation_space,
action_space,
self.sess,
obs_input=obs_ph,
action_sampler=self.sampler,
action_prob=curr_action_dist.sampled_action_prob(),
loss=self.loss_obj.loss,
model=self.model,
loss_inputs=self.loss_in,
state_inputs=self.model.state_in,
state_outputs=self.model.state_out,
prev_action_input=prev_actions_ph,
prev_reward_input=prev_rewards_ph,
seq_lens=self.model.seq_lens,
max_seq_len=config["model"]["max_seq_len"])
self.sess.run(tf.global_variables_initializer())
self.explained_variance = explained_variance(value_targets_ph,
self.value_function)
self.stats_fetches = {
"cur_kl_coeff": self.kl_coeff,
"cur_lr": tf.cast(self.cur_lr, tf.float64),
"total_loss": self.loss_obj.loss,
"policy_loss": self.loss_obj.mean_policy_loss,
"vf_loss": self.loss_obj.mean_vf_loss,
"vf_explained_var": self.explained_variance,
"kl": self.loss_obj.mean_kl,
"entropy": self.loss_obj.mean_entropy
}
@override(TFPolicyGraph)
def copy(self, existing_inputs):
"""Creates a copy of self using existing input placeholders."""
return PPOPolicyGraph(
self.observation_space,
self.action_space,
self.config,
existing_inputs=existing_inputs)
@override(TFPolicyGraph)
def gradients(self, optimizer, loss):
if self.config["grad_clip"] is not None:
self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
tf.get_variable_scope().name)
grads = tf.gradients(loss, self.var_list)
self.grads, _ = tf.clip_by_global_norm(grads,
self.config["grad_clip"])
clipped_grads = list(zip(self.grads, self.var_list))
return clipped_grads
else:
return optimizer.compute_gradients(
loss, colocate_gradients_with_ops=True)
@override(PolicyGraph)
def get_initial_state(self):
return self.model.state_init
@override(TFPolicyGraph)
def extra_compute_grad_fetches(self):
return {LEARNER_STATS_KEY: self.stats_fetches}
def update_kl(self, sampled_kl):
if sampled_kl > 2.0 * self.kl_target:
self.kl_coeff_val *= 1.5
elif sampled_kl < 0.5 * self.kl_target:
self.kl_coeff_val *= 0.5
self.kl_coeff.load(self.kl_coeff_val, session=self.sess)
return self.kl_coeff_val
self.value_function = tf.zeros(shape=tf.shape(self._obs_input)[:1])
def _value(self, ob, prev_action, prev_reward, *args):
feed_dict = {
self.observations: [ob],
self.prev_actions: [prev_action],
self.prev_rewards: [prev_reward],
self._obs_input: [ob],
self._prev_action_input: [prev_action],
self._prev_reward_input: [prev_reward],
self.model.seq_lens: [1]
}
assert len(args) == len(self.model.state_in), \
(args, self.model.state_in)
for k, v in zip(self.model.state_in, args):
feed_dict[k] = v
vf = self.sess.run(self.value_function, feed_dict)
vf = self._sess.run(self.value_function, feed_dict)
return vf[0]
def setup_mixins(policy, obs_space, action_space, config):
    """Initialize the PPO mixins on ``policy`` before the loss is built.

    Runs, in order: the value network mixin, the KL coefficient mixin, and
    the learning rate schedule (from ``config["lr"]`` and
    ``config["lr_schedule"]``).
    """
    ValueNetworkMixin.__init__(policy, obs_space, action_space, config)
    KLCoeffMixin.__init__(policy, config)
    lr, lr_schedule = config["lr"], config["lr_schedule"]
    LearningRateSchedule.__init__(policy, lr, lr_schedule)
# Assemble the PPO TF policy class from the callbacks named below.
# The default config is looked up lazily from the PPO trainer module, and
# `setup_mixins` runs before the loss is initialized so that the value
# function, KL coefficient and learning rate exist when the loss is built.
PPOTFPolicy = build_tf_policy(
name="PPOTFPolicy",
get_default_config=lambda: ray.rllib.agents.ppo.ppo.DEFAULT_CONFIG,
loss_fn=ppo_surrogate_loss,
stats_fn=kl_and_loss_stats,
extra_action_fetches_fn=vf_preds_and_logits_fetches,
postprocess_fn=postprocess_ppo_gae,
gradients_fn=clip_gradients,
before_loss_init=setup_mixins,
mixins=[LearningRateSchedule, KLCoeffMixin, ValueNetworkMixin])
@@ -0,0 +1,97 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ray.rllib.agents.trainer import Trainer
from ray.rllib.optimizers import SyncSamplesOptimizer
from ray.rllib.utils.annotations import override, DeveloperAPI
@DeveloperAPI
def build_trainer(name,
                  default_policy,
                  default_config=None,
                  make_policy_optimizer=None,
                  validate_config=None,
                  get_policy_class=None,
                  before_train_step=None,
                  after_optimizer_step=None,
                  after_train_result=None):
    """Helper function for defining a custom trainer.

    Arguments:
        name (str): name of the trainer (e.g., "PPO"). It must not end in
            "Trainer"; that suffix is appended to the class name here.
        default_policy (cls): the default PolicyGraph class to use
        default_config (dict): the default config dict of the algorithm,
            otherwise uses the Trainer default config
        make_policy_optimizer (func): optional function that returns a
            PolicyOptimizer instance given
            (local_evaluator, remote_evaluators, config). When omitted, a
            SyncSamplesOptimizer is built from config["optimizer"] plus
            config["train_batch_size"].
        validate_config (func): optional callback that checks a given config
            for correctness. It may mutate the config as needed.
        get_policy_class (func): optional callback that takes a config and
            returns the policy graph class to override the default with
        before_train_step (func): optional callback to run before each train()
            call. It takes the trainer instance as an argument.
        after_optimizer_step (func): optional callback to run after each
            step() call to the policy optimizer. It takes the trainer instance
            and the policy gradient fetches as arguments.
        after_train_result (func): optional callback to run at the end of each
            train() call. It takes the trainer instance and result dict as
            arguments, and may mutate the result dict as needed.

    Returns:
        a Trainer subclass (not an instance) named ``name + "Trainer"`` that
        uses the specified args.

    Raises:
        ValueError: if ``name`` already ends with "Trainer".
    """

    if name.endswith("Trainer"):
        raise ValueError("Algorithm name should not include *Trainer suffix",
                         name)

    # The common config is taken from the trainer module rather than from a
    # `Trainer.COMMON_CONFIG` attribute, so the fallback works whenever
    # default_config is omitted. Imported locally to keep this block
    # self-contained.
    from ray.rllib.agents.trainer import COMMON_CONFIG

    class trainer_cls(Trainer):
        _name = name
        _default_config = default_config or COMMON_CONFIG
        _policy_graph = default_policy

        @override(Trainer)
        def _init(self, config, env_creator):
            if validate_config:
                validate_config(config)

            # Allow the policy class to depend on the (validated) config.
            if get_policy_class is None:
                policy_graph = default_policy
            else:
                policy_graph = get_policy_class(config)

            self.local_evaluator = self.make_local_evaluator(
                env_creator, policy_graph)
            self.remote_evaluators = self.make_remote_evaluators(
                env_creator, policy_graph, config["num_workers"])

            if make_policy_optimizer:
                self.optimizer = make_policy_optimizer(
                    self.local_evaluator, self.remote_evaluators, config)
            else:
                # train_batch_size from the top-level config takes precedence
                # over any value inside config["optimizer"].
                optimizer_config = dict(
                    config["optimizer"],
                    **{"train_batch_size": config["train_batch_size"]})
                self.optimizer = SyncSamplesOptimizer(self.local_evaluator,
                                                      self.remote_evaluators,
                                                      **optimizer_config)

        @override(Trainer)
        def _train(self):
            if before_train_step:
                before_train_step(self)

            prev_steps = self.optimizer.num_steps_sampled
            fetches = self.optimizer.step()
            if after_optimizer_step:
                after_optimizer_step(self, fetches)

            res = self.collect_metrics()
            res.update(
                timesteps_this_iter=self.optimizer.num_steps_sampled -
                prev_steps,
                info=res.get("info", {}))
            if after_train_result:
                after_train_result(self, res)
            return res

    trainer_cls.__name__ = name + "Trainer"
    trainer_cls.__qualname__ = name + "Trainer"
    return trainer_cls
@@ -0,0 +1,275 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import OrderedDict
import logging
import numpy as np
from ray.rllib.evaluation.policy_graph import PolicyGraph
from ray.rllib.evaluation.sample_batch import SampleBatch
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.utils.annotations import override
from ray.rllib.utils import try_import_tf
from ray.rllib.utils.debug import log_once, summarize
from ray.rllib.utils.tracking_dict import UsageTrackingDict
tf = try_import_tf()
logger = logging.getLogger(__name__)
class DynamicTFPolicyGraph(TFPolicyGraph):
"""A TFPolicyGraph that auto-defines placeholders dynamically at runtime.
Initialization of this class occurs in two phases.
* Phase 1: the model is created and model variables are initialized.
* Phase 2: a fake batch of data is created, sent to the trajectory
postprocessor, and then used to create placeholders for the loss
function. The loss and stats functions are initialized with these
placeholders.
"""
def __init__(self,
             obs_space,
             action_space,
             config,
             loss_fn,
             stats_fn=None,
             grad_stats_fn=None,
             before_loss_init=None,
             make_action_sampler=None,
             existing_inputs=None,
             get_batch_divisibility_req=None):
    """Initialize a dynamic TF policy graph.

    Arguments:
        obs_space (gym.Space): Observation space of the policy.
        action_space (gym.Space): Action space of the policy.
        config (dict): Policy-specific configuration data.
        loss_fn (func): function that returns a loss tensor given the
            policy graph and a dict of experience tensor placeholders
        stats_fn (func): optional function that returns a dict of
            TF fetches given the policy graph and batch input tensors
        grad_stats_fn (func): optional function that returns a dict of
            TF fetches given the policy graph and loss gradient tensors
        before_loss_init (func): optional function to run prior to loss
            init; it is called with
            (policy, obs_space, action_space, config)
        make_action_sampler (func): optional function that returns a
            tuple of action and action prob tensors. The function takes
            (policy, input_dict, obs_space, action_space, config) as its
            arguments
        existing_inputs (OrderedDict): when copying a policy graph, this
            specifies an existing dict of placeholders to use instead of
            defining new ones
        get_batch_divisibility_req (func): optional function that returns
            the divisibility requirement for sample batches
    """
    self.config = config
    self._loss_fn = loss_fn
    self._stats_fn = stats_fn
    self._grad_stats_fn = grad_stats_fn

    # Setup standard placeholders
    if existing_inputs is not None:
        obs = existing_inputs[SampleBatch.CUR_OBS]
        prev_actions = existing_inputs[SampleBatch.PREV_ACTIONS]
        prev_rewards = existing_inputs[SampleBatch.PREV_REWARDS]
    else:
        obs = tf.placeholder(
            tf.float32,
            shape=[None] + list(obs_space.shape),
            name="observation")
        prev_actions = ModelCatalog.get_action_placeholder(action_space)
        prev_rewards = tf.placeholder(
            tf.float32, [None], name="prev_reward")

    input_dict = {
        "obs": obs,
        "prev_actions": prev_actions,
        "prev_rewards": prev_rewards,
        "is_training": self._get_is_training_placeholder(),
    }

    # Create the model network and action outputs
    if make_action_sampler:
        assert not existing_inputs, \
            "Cloning not supported with custom action sampler"
        # A custom sampler bypasses the catalog model entirely.
        self.model = None
        self.dist_class = None
        self.action_dist = None
        action_sampler, action_prob = make_action_sampler(
            self, input_dict, obs_space, action_space, config)
    else:
        self.dist_class, logit_dim = ModelCatalog.get_action_dist(
            action_space, self.config["model"])
        if existing_inputs:
            existing_state_in = [
                v for k, v in existing_inputs.items()
                if k.startswith("state_in_")
            ]
            # seq_lens is only looked up when state inputs are present.
            if existing_state_in:
                existing_seq_lens = existing_inputs["seq_lens"]
            else:
                existing_seq_lens = None
        else:
            existing_state_in = []
            existing_seq_lens = None
        self.model = ModelCatalog.get_model(
            input_dict,
            obs_space,
            action_space,
            logit_dim,
            self.config["model"],
            state_in=existing_state_in,
            seq_lens=existing_seq_lens)
        self.action_dist = self.dist_class(self.model.outputs)
        action_sampler = self.action_dist.sample()
        action_prob = self.action_dist.sampled_action_prob()

    # Phase 1 init
    sess = tf.get_default_session()
    if get_batch_divisibility_req:
        batch_divisibility_req = get_batch_divisibility_req(self)
    else:
        batch_divisibility_req = 1
    TFPolicyGraph.__init__(
        self,
        obs_space,
        action_space,
        sess,
        obs_input=obs,
        action_sampler=action_sampler,
        action_prob=action_prob,
        loss=None,  # dynamically initialized on run
        loss_inputs=[],
        model=self.model,
        state_inputs=self.model and self.model.state_in,
        state_outputs=self.model and self.model.state_out,
        prev_action_input=prev_actions,
        prev_reward_input=prev_rewards,
        seq_lens=self.model and self.model.seq_lens,
        max_seq_len=config["model"]["max_seq_len"],
        batch_divisibility_req=batch_divisibility_req)

    # Phase 2 init
    # before_loss_init is optional (defaults to None), so only call it when
    # a callback was actually supplied.
    if before_loss_init:
        before_loss_init(self, obs_space, action_space, config)
    # When copying, the caller initializes the loss with existing inputs.
    if not existing_inputs:
        self._initialize_loss()
@override(TFPolicyGraph)
def copy(self, existing_inputs):
    """Creates a copy of self using existing input placeholders.

    Arguments:
        existing_inputs (list): tensors ordered as the loss inputs,
            followed (when there are state inputs) by the state input
            tensors and finally the seq len tensor.

    Returns:
        A new instance of this class whose loss and stats are built on the
        given tensors.

    Raises:
        ValueError: if the number of tensors or the shape of any loss
            input does not match this policy's loss inputs.
    """
    num_loss_inputs = len(self._loss_inputs)

    # Note that there might be RNN state inputs at the end of the list
    num_state_inputs = (len(self._state_inputs) + 1
                        if self._state_inputs else 0)
    if num_loss_inputs + num_state_inputs != len(existing_inputs):
        raise ValueError("Tensor list mismatch", self._loss_inputs,
                         self._state_inputs, existing_inputs)

    for idx, (key, tensor) in enumerate(self._loss_inputs):
        if tensor.shape.as_list() != existing_inputs[idx].shape.as_list():
            raise ValueError("Tensor shape mismatch", idx, key,
                             tensor.shape, existing_inputs[idx].shape)

    # By convention, the loss inputs are followed by state inputs and then
    # the seq len tensor
    rnn_pairs = [("state_in_{}".format(idx),
                  existing_inputs[num_loss_inputs + idx])
                 for idx in range(len(self._state_inputs))]
    if rnn_pairs:
        rnn_pairs.append(("seq_lens", existing_inputs[-1]))

    loss_pairs = [(key, existing_inputs[idx])
                  for idx, (key, _) in enumerate(self._loss_inputs)]
    input_dict = OrderedDict(loss_pairs + rnn_pairs)

    clone = self.__class__(
        self.observation_space,
        self.action_space,
        self.config,
        existing_inputs=input_dict)

    # The clone skipped its own loss init, so build loss and stats here.
    loss = clone._loss_fn(clone, input_dict)
    if clone._stats_fn:
        clone._stats_fetches.update(clone._stats_fn(clone, input_dict))
    TFPolicyGraph._initialize_loss(clone, loss, loss_pairs)
    if clone._grad_stats_fn:
        clone._stats_fetches.update(
            clone._grad_stats_fn(clone, clone._grads))
    return clone
@override(PolicyGraph)
def get_initial_state(self):
    """Return the model's initial state, or an empty list without a model."""
    return self.model.state_init if self.model else []
def _initialize_loss(self):
    """Builds the loss using placeholders inferred from a fake batch.

    A one-row dummy batch is passed through postprocess_trajectory() to
    discover which columns postprocessing produces. A placeholder is
    created for each new column, the loss (and optional stats) functions
    are called on the placeholders, and every key that was read from the
    placeholder dict is registered as a loss input.
    """

    def fake_array(tensor):
        # Zeros with the tensor's shape and dtype, first dim forced to 1.
        shape = tensor.shape.as_list()
        shape[0] = 1
        return np.zeros(shape, dtype=tensor.dtype.as_numpy_dtype)

    # NOTE: the np.bool / np.object aliases were removed in NumPy 1.24.
    # np.bool_ / np.object_ are equivalent here and also exist in older
    # NumPy releases.
    dummy_batch = {
        SampleBatch.PREV_ACTIONS: fake_array(self._prev_action_input),
        SampleBatch.PREV_REWARDS: fake_array(self._prev_reward_input),
        SampleBatch.CUR_OBS: fake_array(self._obs_input),
        SampleBatch.NEXT_OBS: fake_array(self._obs_input),
        SampleBatch.ACTIONS: fake_array(self._prev_action_input),
        SampleBatch.REWARDS: np.array([0], dtype=np.float32),
        SampleBatch.DONES: np.array([False], dtype=np.bool_),
    }
    state_init = self.get_initial_state()
    for i, h in enumerate(state_init):
        dummy_batch["state_in_{}".format(i)] = np.expand_dims(h, 0)
        dummy_batch["state_out_{}".format(i)] = np.expand_dims(h, 0)
    if state_init:
        dummy_batch["seq_lens"] = np.array([1], dtype=np.int32)
    for k, v in self.extra_compute_action_fetches().items():
        dummy_batch[k] = fake_array(v)

    # postprocessing might depend on variable init, so run it first here
    self._sess.run(tf.global_variables_initializer())
    postprocessed_batch = self.postprocess_trajectory(
        SampleBatch(dummy_batch))

    batch_tensors = UsageTrackingDict({
        SampleBatch.PREV_ACTIONS: self._prev_action_input,
        SampleBatch.PREV_REWARDS: self._prev_reward_input,
        SampleBatch.CUR_OBS: self._obs_input,
    })
    loss_inputs = [
        (SampleBatch.PREV_ACTIONS, self._prev_action_input),
        (SampleBatch.PREV_REWARDS, self._prev_reward_input),
        (SampleBatch.CUR_OBS, self._obs_input),
    ]

    # Create a placeholder for every postprocessed column that does not
    # already have an input tensor.
    for k, v in postprocessed_batch.items():
        if k in batch_tensors:
            continue
        elif v.dtype == np.object_:
            continue  # can't handle arbitrary objects in TF
        shape = (None, ) + v.shape[1:]
        dtype = np.float32 if v.dtype == np.float64 else v.dtype
        placeholder = tf.placeholder(dtype, shape=shape, name=k)
        batch_tensors[k] = placeholder

    if log_once("loss_init"):
        logger.info(
            "Initializing loss function with dummy input:\n\n{}\n".format(
                summarize(batch_tensors)))

    loss = self._loss_fn(self, batch_tensors)
    if self._stats_fn:
        self._stats_fetches.update(self._stats_fn(self, batch_tensors))
    # Register the keys that were read from batch_tensors above.
    for k in sorted(batch_tensors.accessed_keys):
        loss_inputs.append((k, batch_tensors[k]))

    TFPolicyGraph._initialize_loss(self, loss, loss_inputs)
    if self._grad_stats_fn:
        self._stats_fetches.update(self._grad_stats_fn(self, self._grads))
    self._sess.run(tf.global_variables_initializer())
@@ -65,7 +65,7 @@ class PolicyEvaluator(EvaluatorInterface):
>>> # Create a policy evaluator and use it to collect experiences.
>>> evaluator = PolicyEvaluator(
... env_creator=lambda _: gym.make("CartPole-v0"),
... policy_graph=PGPolicyGraph)
... policy_graph=PGTFPolicy)
>>> print(evaluator.sample())
SampleBatch({
"obs": [[...]], "actions": [[...]], "rewards": [[...]],
@@ -76,7 +76,7 @@ class PolicyEvaluator(EvaluatorInterface):
... evaluator_cls=PolicyEvaluator,
... evaluator_args={
... "env_creator": lambda _: gym.make("CartPole-v0"),
... "policy_graph": PGPolicyGraph,
... "policy_graph": PGTFPolicy,
... },
... num_workers=10)
>>> for _ in range(10): optimizer.step()
@@ -87,12 +87,12 @@ class PolicyEvaluator(EvaluatorInterface):
... policy_graphs={
... # Use an ensemble of two policies for car agents
... "car_policy1":
... (PGPolicyGraph, Box(...), Discrete(...), {"gamma": 0.99}),
... (PGTFPolicy, Box(...), Discrete(...), {"gamma": 0.99}),
... "car_policy2":
... (PGPolicyGraph, Box(...), Discrete(...), {"gamma": 0.95}),
... (PGTFPolicy, Box(...), Discrete(...), {"gamma": 0.95}),
... # Use a single shared policy for all traffic lights
... "traffic_light_policy":
... (PGPolicyGraph, Box(...), Discrete(...), {}),
... (PGTFPolicy, Box(...), Discrete(...), {}),
... },
... policy_mapping_fn=lambda agent_id:
... random.choice(["car_policy1", "car_policy2"])
+50 -34
View File
@@ -112,46 +112,20 @@ class TFPolicyGraph(PolicyGraph):
self._prev_action_input = prev_action_input
self._prev_reward_input = prev_reward_input
self._sampler = action_sampler
self._loss_inputs = loss_inputs
self._loss_input_dict = dict(self._loss_inputs)
self._is_training = self._get_is_training_placeholder()
self._action_prob = action_prob
self._state_inputs = state_inputs or []
self._state_outputs = state_outputs or []
for i, ph in enumerate(self._state_inputs):
self._loss_input_dict["state_in_{}".format(i)] = ph
self._seq_lens = seq_lens
self._max_seq_len = max_seq_len
self._batch_divisibility_req = batch_divisibility_req
self._update_ops = update_ops
self._stats_fetches = {}
if self.model:
self._loss = self.model.custom_loss(loss, self._loss_input_dict)
self._stats_fetches = {"model": self.model.custom_stats()}
if loss is not None:
self._initialize_loss(loss, loss_inputs)
else:
self._loss = loss
self._stats_fetches = {}
self._optimizer = self.optimizer()
self._grads_and_vars = [
(g, v) for (g, v) in self.gradients(self._optimizer, self._loss)
if g is not None
]
self._grads = [g for (g, v) in self._grads_and_vars]
self._variables = ray.experimental.tf_utils.TensorFlowVariables(
self._loss, self._sess)
# gather update ops for any batch norm layers
if update_ops:
self._update_ops = update_ops
else:
self._update_ops = tf.get_collection(
tf.GraphKeys.UPDATE_OPS, scope=tf.get_variable_scope().name)
if self._update_ops:
logger.debug("Update ops to run on apply gradient: {}".format(
self._update_ops))
with tf.control_dependencies(self._update_ops):
self._apply_op = self.build_apply_op(self._optimizer,
self._grads_and_vars)
self._loss = None
if len(self._state_inputs) != len(self._state_outputs):
raise ValueError(
@@ -166,8 +140,44 @@ class TFPolicyGraph(PolicyGraph):
raise ValueError(
"seq_lens tensor must be given if state inputs are defined")
logger.debug("Created {} with loss inputs: {}".format(
self, self._loss_input_dict))
def _initialize_loss(self, loss, loss_inputs):
    """Sets up the loss, optimizer, gradients and apply op for this graph.

    Arguments:
        loss: the loss to optimize; if a model is set, it is first passed
            through the model's custom_loss().
        loss_inputs (list): (name, tensor) pairs that feed the loss.
    """
    self._loss_inputs = loss_inputs
    self._loss_input_dict = dict(loss_inputs)
    self._loss_input_dict.update(
        ("state_in_{}".format(idx), state_ph)
        for idx, state_ph in enumerate(self._state_inputs))

    if not self.model:
        self._loss = loss
    else:
        self._loss = self.model.custom_loss(loss, self._loss_input_dict)
        self._stats_fetches["model"] = self.model.custom_stats()

    self._optimizer = self.optimizer()
    # Keep only the variables that actually received a gradient.
    all_grads_and_vars = self.gradients(self._optimizer, self._loss)
    self._grads_and_vars = [(grad, var)
                            for (grad, var) in all_grads_and_vars
                            if grad is not None]
    self._grads = [grad for (grad, _) in self._grads_and_vars]
    self._variables = ray.experimental.tf_utils.TensorFlowVariables(
        self._loss, self._sess)

    # gather update ops for any batch norm layers
    if not self._update_ops:
        scope_name = tf.get_variable_scope().name
        self._update_ops = tf.get_collection(
            tf.GraphKeys.UPDATE_OPS, scope=scope_name)
    if self._update_ops:
        logger.debug("Update ops to run on apply gradient: {}".format(
            self._update_ops))
    with tf.control_dependencies(self._update_ops):
        self._apply_op = self.build_apply_op(self._optimizer,
                                             self._grads_and_vars)

    if log_once("loss_used"):
        logger.debug(
            "These tensors were used in the loss_fn:\n\n{}\n".format(
                summarize(self._loss_input_dict)))

    self._sess.run(tf.global_variables_initializer())
@override(PolicyGraph)
def compute_actions(self,
@@ -186,18 +196,21 @@ class TFPolicyGraph(PolicyGraph):
@override(PolicyGraph)
def compute_gradients(self, postprocessed_batch):
    """Runs the compute-gradients fetches for the given batch.

    Requires the loss to have been initialized first.
    """
    assert self._loss is not None, "Loss not initialized"
    run_builder = TFRunBuilder(self._sess, "compute_gradients")
    return run_builder.get(
        self._build_compute_gradients(run_builder, postprocessed_batch))
@override(PolicyGraph)
def apply_gradients(self, gradients):
    """Runs the apply-gradients fetches for the given gradients.

    Requires the loss to have been initialized first. The fetched values
    are discarded; nothing is returned.
    """
    assert self._loss is not None, "Loss not initialized"
    run_builder = TFRunBuilder(self._sess, "apply_gradients")
    run_builder.get(self._build_apply_gradients(run_builder, gradients))
@override(PolicyGraph)
def learn_on_batch(self, postprocessed_batch):
    """Runs the learn-on-batch fetches for the given batch.

    Requires the loss to have been initialized first.
    """
    assert self._loss is not None, "Loss not initialized"
    run_builder = TFRunBuilder(self._sess, "learn_on_batch")
    return run_builder.get(
        self._build_learn_on_batch(run_builder, postprocessed_batch))
@@ -271,7 +284,10 @@ class TFPolicyGraph(PolicyGraph):
@DeveloperAPI
def optimizer(self):
"""TF optimizer to use for policy optimization."""
return tf.train.AdamOptimizer()
if hasattr(self, "config"):
return tf.train.AdamOptimizer(self.config["lr"])
else:
return tf.train.AdamOptimizer()
@DeveloperAPI
def gradients(self, optimizer, loss):
@@ -0,0 +1,146 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ray.rllib.evaluation.dynamic_tf_policy_graph import DynamicTFPolicyGraph
from ray.rllib.evaluation.policy_graph import PolicyGraph
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
from ray.rllib.utils.annotations import override, DeveloperAPI
@DeveloperAPI
def build_tf_policy(name,
                    loss_fn,
                    get_default_config=None,
                    stats_fn=None,
                    grad_stats_fn=None,
                    extra_action_fetches_fn=None,
                    postprocess_fn=None,
                    optimizer_fn=None,
                    gradients_fn=None,
                    before_init=None,
                    before_loss_init=None,
                    after_init=None,
                    make_action_sampler=None,
                    mixins=None,
                    get_batch_divisibility_req=None):
    """Helper function for creating a dynamic tf policy at runtime.

    Arguments:
        name (str): name of the graph; must end with "TFPolicy"
            (e.g., "PPOTFPolicy")
        loss_fn (func): function that returns a loss tensor given the
            policy and a dict of experience tensor placeholders
        get_default_config (func): optional function that returns the default
            config to merge with any overrides
        stats_fn (func): optional function that returns a dict of
            TF fetches given the policy and batch input tensors
        grad_stats_fn (func): optional function that returns a dict of
            TF fetches given the policy and loss gradient tensors
        extra_action_fetches_fn (func): optional function that returns
            a dict of TF fetches given the policy object
        postprocess_fn (func): optional experience postprocessing function
            that takes the same args as PolicyGraph.postprocess_trajectory()
        optimizer_fn (func): optional function that returns a tf.Optimizer
            given the policy and config
        gradients_fn (func): optional function that returns a list of
            gradients given the policy, a tf optimizer and the loss tensor.
            If not specified, TFPolicyGraph.gradients() is used
        before_init (func): optional function to run at the beginning of
            policy init that takes the same arguments as the policy constructor
        before_loss_init (func): optional function to run prior to loss
            init that takes the same arguments as the policy constructor
        after_init (func): optional function to run at the end of policy init
            that takes the same arguments as the policy constructor
        make_action_sampler (func): optional function that returns a
            tuple of action and action prob tensors. The function takes
            (policy, input_dict, obs_space, action_space, config) as its
            arguments
        mixins (list): list of any class mixins for the returned policy class.
            These mixins will be applied in order and will have higher
            precedence than the DynamicTFPolicyGraph class. The given
            sequence is not modified
        get_batch_divisibility_req (func): optional function that returns
            the divisibility requirement for sample batches

    Returns:
        a DynamicTFPolicyGraph subclass (a class, not an instance) that
        uses the specified args

    Raises:
        ValueError: if name does not end with "TFPolicy"
    """

    if not name.endswith("TFPolicy"):
        raise ValueError("Name should match *TFPolicy", name)

    # NOTE(review): make_action_sampler and get_batch_divisibility_req are
    # accepted but never forwarded to DynamicTFPolicyGraph.__init__ below,
    # so they currently have no effect. Confirm the base constructor's
    # signature before wiring them through.

    # Wrap from the last mixin inwards so that earlier mixins end up earlier
    # in the MRO. Iterate over a copy instead of pop()-ing, so the caller's
    # list is left intact and can be reused for another policy.
    base = DynamicTFPolicyGraph
    for mixin in reversed(list(mixins or ())):

        class new_base(mixin, base):
            pass

        base = new_base

    class graph_cls(base):
        def __init__(self,
                     obs_space,
                     action_space,
                     config,
                     existing_inputs=None):
            # User-supplied config values override the defaults.
            if get_default_config:
                config = dict(get_default_config(), **config)

            if before_init:
                before_init(self, obs_space, action_space, config)

            def before_loss_init_wrapper(policy, obs_space, action_space,
                                         config):
                if before_loss_init:
                    before_loss_init(policy, obs_space, action_space, config)
                # Computed here so that extra_compute_action_fetches() is
                # usable by the time the loss gets initialized.
                if extra_action_fetches_fn is None:
                    self._extra_action_fetches = {}
                else:
                    self._extra_action_fetches = extra_action_fetches_fn(self)

            DynamicTFPolicyGraph.__init__(
                self,
                obs_space,
                action_space,
                config,
                loss_fn,
                stats_fn=stats_fn,
                grad_stats_fn=grad_stats_fn,
                before_loss_init=before_loss_init_wrapper,
                existing_inputs=existing_inputs)

            if after_init:
                after_init(self, obs_space, action_space, config)

        @override(PolicyGraph)
        def postprocess_trajectory(self,
                                   sample_batch,
                                   other_agent_batches=None,
                                   episode=None):
            if not postprocess_fn:
                return sample_batch
            return postprocess_fn(self, sample_batch, other_agent_batches,
                                  episode)

        @override(TFPolicyGraph)
        def optimizer(self):
            if optimizer_fn:
                return optimizer_fn(self, self.config)
            else:
                return TFPolicyGraph.optimizer(self)

        @override(TFPolicyGraph)
        def gradients(self, optimizer, loss):
            if gradients_fn:
                return gradients_fn(self, optimizer, loss)
            else:
                return TFPolicyGraph.gradients(self, optimizer, loss)

        @override(TFPolicyGraph)
        def extra_compute_action_fetches(self):
            return dict(
                TFPolicyGraph.extra_compute_action_fetches(self),
                **self._extra_action_fetches)

    graph_cls.__name__ = name
    graph_cls.__qualname__ = name
    return graph_cls
@@ -15,6 +15,7 @@ except ImportError:
from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
from ray.rllib.evaluation.policy_graph import PolicyGraph
from ray.rllib.utils.annotations import override
from ray.rllib.utils.tracking_dict import UsageTrackingDict
class TorchPolicyGraph(PolicyGraph):
@@ -30,7 +31,7 @@ class TorchPolicyGraph(PolicyGraph):
"""
def __init__(self, observation_space, action_space, model, loss,
loss_inputs, action_distribution_cls):
action_distribution_cls):
"""Build a policy graph from policy and loss torch modules.
Note that model will be placed on GPU device if CUDA_VISIBLE_DEVICES
@@ -42,13 +43,8 @@ class TorchPolicyGraph(PolicyGraph):
model (nn.Module): PyTorch policy module. Given observations as
input, this module must return a list of outputs where the
first item is action logits, and the rest can be any value.
loss (nn.Module): Loss defined as a PyTorch module. The inputs for
this module are defined by the `loss_inputs` param. This module
returns a single scalar loss. Note that this module should
internally be using the model module.
loss_inputs (list): List of SampleBatch columns that will be
passed to the loss module's forward() function when computing
the loss. For example, ["obs", "action", "advantages"].
loss (func): Function that takes (policy_graph, batch_tensors)
and returns a single scalar loss.
action_distribution_cls (ActionDistribution): Class for action
distribution.
"""
@@ -60,7 +56,6 @@ class TorchPolicyGraph(PolicyGraph):
else torch.device("cpu"))
self._model = model.to(self.device)
self._loss = loss
self._loss_inputs = loss_inputs
self._optimizer = self.optimizer()
self._action_dist_cls = action_distribution_cls
@@ -87,30 +82,26 @@ class TorchPolicyGraph(PolicyGraph):
@override(PolicyGraph)
def learn_on_batch(self, postprocessed_batch):
batch_tensors = self._lazy_tensor_dict(postprocessed_batch)
with self.lock:
loss_in = []
for key in self._loss_inputs:
loss_in.append(
torch.from_numpy(postprocessed_batch[key]).to(self.device))
loss_out = self._loss(self._model, *loss_in)
loss_out = self._loss(self, batch_tensors)
self._optimizer.zero_grad()
loss_out.backward()
grad_process_info = self.extra_grad_process()
self._optimizer.step()
grad_info = self.extra_grad_info()
grad_info = self.extra_grad_info(batch_tensors)
grad_info.update(grad_process_info)
return {LEARNER_STATS_KEY: grad_info}
@override(PolicyGraph)
def compute_gradients(self, postprocessed_batch):
batch_tensors = self._lazy_tensor_dict(postprocessed_batch)
with self.lock:
loss_in = []
for key in self._loss_inputs:
loss_in.append(
torch.from_numpy(postprocessed_batch[key]).to(self.device))
loss_out = self._loss(self._model, *loss_in)
loss_out = self._loss(self, batch_tensors)
self._optimizer.zero_grad()
loss_out.backward()
@@ -125,7 +116,7 @@ class TorchPolicyGraph(PolicyGraph):
else:
grads.append(None)
grad_info = self.extra_grad_info()
grad_info = self.extra_grad_info(batch_tensors)
grad_info.update(grad_process_info)
return grads, {LEARNER_STATS_KEY: grad_info}
@@ -163,11 +154,21 @@ class TorchPolicyGraph(PolicyGraph):
model_out (list): Outputs of the policy model module."""
return {}
def extra_grad_info(self):
def extra_grad_info(self, batch_tensors):
"""Return dict of extra grad info."""
return {}
def optimizer(self):
"""Custom PyTorch optimizer to use."""
return torch.optim.Adam(self._model.parameters())
if hasattr(self, "config"):
return torch.optim.Adam(
self._model.parameters(), lr=self.config["lr"])
else:
return torch.optim.Adam(self._model.parameters())
def _lazy_tensor_dict(self, postprocessed_batch):
    """Wraps the batch in a dict that converts values to tensors on read.

    Each value read from the returned dict is passed through
    torch.from_numpy() and moved to this policy's device.
    """

    def to_device_tensor(arr):
        return torch.from_numpy(arr).to(self.device)

    lazy_batch = UsageTrackingDict(postprocessed_batch)
    lazy_batch.set_get_interceptor(to_device_tensor)
    return lazy_batch
@@ -0,0 +1,133 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ray.rllib.evaluation.policy_graph import PolicyGraph
from ray.rllib.evaluation.torch_policy_graph import TorchPolicyGraph
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.utils.annotations import override, DeveloperAPI
@DeveloperAPI
def build_torch_policy(name,
                       loss_fn,
                       get_default_config=None,
                       stats_fn=None,
                       postprocess_fn=None,
                       extra_action_out_fn=None,
                       extra_grad_process_fn=None,
                       optimizer_fn=None,
                       before_init=None,
                       after_init=None,
                       make_model_and_action_dist=None,
                       mixins=None):
    """Helper function for creating a torch policy at runtime.

    Arguments:
        name (str): name of the graph; must end with "TorchPolicy"
            (e.g., "PPOTorchPolicy")
        loss_fn (func): function that returns a loss tensor given the
            policy and a dict of experience tensors
        get_default_config (func): optional function that returns the default
            config to merge with any overrides
        stats_fn (func): optional function that returns a dict of
            values given the policy and batch input tensors
        postprocess_fn (func): optional experience postprocessing function
            that takes the same args as PolicyGraph.postprocess_trajectory()
        extra_action_out_fn (func): optional function that returns
            a dict of extra values to include in experiences
        extra_grad_process_fn (func): optional function that is called after
            gradients are computed and returns processing info
        optimizer_fn (func): optional function that returns a torch optimizer
            given the policy and config
        before_init (func): optional function to run at the beginning of
            policy init that takes the same arguments as the policy constructor
        after_init (func): optional function to run at the end of policy init
            that takes the same arguments as the policy constructor
        make_model_and_action_dist (func): optional func that takes the same
            arguments as policy init and returns a tuple of model instance and
            torch action distribution class. If not specified, the default
            model and action dist from the catalog will be used
        mixins (list): list of any class mixins for the returned policy class.
            These mixins will be applied in order and will have higher
            precedence than the TorchPolicyGraph class. The given sequence
            is not modified

    Returns:
        a TorchPolicyGraph subclass (a class, not an instance) that uses
        the specified args

    Raises:
        ValueError: if name does not end with "TorchPolicy"
    """

    if not name.endswith("TorchPolicy"):
        raise ValueError("Name should match *TorchPolicy", name)

    # Wrap from the last mixin inwards so that earlier mixins end up earlier
    # in the MRO. Iterate over a copy instead of pop()-ing, so the caller's
    # list is left intact and can be reused for another policy.
    base = TorchPolicyGraph
    for mixin in reversed(list(mixins or ())):

        class new_base(mixin, base):
            pass

        base = new_base

    class graph_cls(base):
        def __init__(self, obs_space, action_space, config):
            # User-supplied config values override the defaults.
            if get_default_config:
                config = dict(get_default_config(), **config)
            # Set before anything else: before_init, the model factory and
            # optimizer() all may read self.config.
            self.config = config

            if before_init:
                before_init(self, obs_space, action_space, config)

            if make_model_and_action_dist:
                self.model, self.dist_class = make_model_and_action_dist(
                    self, obs_space, action_space, config)
            else:
                self.dist_class, logit_dim = ModelCatalog.get_action_dist(
                    action_space, self.config["model"], torch=True)
                self.model = ModelCatalog.get_torch_model(
                    obs_space, logit_dim, self.config["model"])

            TorchPolicyGraph.__init__(self, obs_space, action_space,
                                      self.model, loss_fn, self.dist_class)

            if after_init:
                after_init(self, obs_space, action_space, config)

        @override(PolicyGraph)
        def postprocess_trajectory(self,
                                   sample_batch,
                                   other_agent_batches=None,
                                   episode=None):
            if not postprocess_fn:
                return sample_batch
            return postprocess_fn(self, sample_batch, other_agent_batches,
                                  episode)

        @override(TorchPolicyGraph)
        def extra_grad_process(self):
            if extra_grad_process_fn:
                return extra_grad_process_fn(self)
            else:
                return TorchPolicyGraph.extra_grad_process(self)

        @override(TorchPolicyGraph)
        def extra_action_out(self, model_out):
            if extra_action_out_fn:
                return extra_action_out_fn(self, model_out)
            else:
                return TorchPolicyGraph.extra_action_out(self, model_out)

        @override(TorchPolicyGraph)
        def optimizer(self):
            if optimizer_fn:
                return optimizer_fn(self, self.config)
            else:
                return TorchPolicyGraph.optimizer(self)

        @override(TorchPolicyGraph)
        def extra_grad_info(self, batch_tensors):
            if stats_fn:
                return stats_fn(self, batch_tensors)
            else:
                return TorchPolicyGraph.extra_grad_info(self, batch_tensors)

    graph_cls.__name__ = name
    graph_cls.__qualname__ = name
    return graph_cls
@@ -18,7 +18,7 @@ import ray
from ray.rllib.agents.dqn.dqn import DQNTrainer
from ray.rllib.agents.dqn.dqn_policy_graph import DQNPolicyGraph
from ray.rllib.agents.ppo.ppo import PPOTrainer
from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph
from ray.rllib.agents.ppo.ppo_policy_graph import PPOTFPolicy
from ray.rllib.tests.test_multi_agent_env import MultiCartpole
from ray.tune.logger import pretty_print
from ray.tune.registry import register_env
@@ -39,7 +39,7 @@ if __name__ == "__main__":
# You can also have multiple policy graphs per trainer, but here we just
# show one each for PPO and DQN.
policy_graphs = {
"ppo_policy": (PPOPolicyGraph, obs_space, act_space, {}),
"ppo_policy": (PPOTFPolicy, obs_space, act_space, {}),
"dqn_policy": (DQNPolicyGraph, obs_space, act_space, {}),
}
@@ -255,7 +255,7 @@ class LocalSyncParallelOptimizer(object):
fetches = {"train": self._train_op}
for tower in self._towers:
fetches.update(tower.loss_graph.extra_compute_grad_fetches())
fetches.update(tower.loss_graph._get_grad_and_stats_fetches())
return sess.run(fetches, feed_dict=feed_dict)
@@ -222,6 +222,6 @@ class LocalMultiGPUOptimizer(PolicyOptimizer):
def _averaged(kv):
out = {}
for k, v in kv.items():
if v[0] is not None:
if v[0] is not None and not isinstance(v[0], dict):
out[k] = np.mean(v)
return out
@@ -8,7 +8,7 @@ import random
import unittest
import ray
from ray.rllib.agents.pg.pg_policy_graph import PGPolicyGraph
from ray.rllib.agents.pg.pg_policy_graph import PGTFPolicy
from ray.rllib.optimizers import SyncSamplesOptimizer
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
from ray.rllib.env.external_multi_agent_env import ExternalMultiAgentEnv
@@ -67,7 +67,7 @@ class TestExternalMultiAgentEnv(unittest.TestCase):
obs_space = single_env.observation_space
policies = {}
for i in range(20):
policies["pg_{}".format(i)] = (PGPolicyGraph, obs_space, act_space,
policies["pg_{}".format(i)] = (PGTFPolicy, obs_space, act_space,
{})
policy_ids = list(policies.keys())
ev = PolicyEvaluator(
+2 -2
View File
@@ -15,7 +15,7 @@ import unittest
import ray
from ray.rllib.agents.pg import PGTrainer
from ray.rllib.agents.pg.pg_policy_graph import PGPolicyGraph
from ray.rllib.agents.pg.pg_policy_graph import PGTFPolicy
from ray.rllib.evaluation import SampleBatch
from ray.rllib.offline import IOContext, JsonWriter, JsonReader
from ray.rllib.offline.json_writer import _to_json
@@ -159,7 +159,7 @@ class AgentIOTest(unittest.TestCase):
def gen_policy():
obs_space = single_env.observation_space
act_space = single_env.action_space
return (PGPolicyGraph, obs_space, act_space, {})
return (PGTFPolicy, obs_space, act_space, {})
pg = PGTrainer(
env="multi_cartpole",
@@ -8,7 +8,7 @@ import unittest
import ray
from ray.rllib.agents.pg import PGTrainer
from ray.rllib.agents.pg.pg_policy_graph import PGPolicyGraph
from ray.rllib.agents.pg.pg_policy_graph import PGTFPolicy
from ray.rllib.agents.dqn.dqn_policy_graph import DQNPolicyGraph
from ray.rllib.optimizers import (SyncSamplesOptimizer, SyncReplayOptimizer,
AsyncGradientsOptimizer)
@@ -470,7 +470,7 @@ class TestMultiAgentEnv(unittest.TestCase):
self.assertEqual(batch["state_out_0"][1], h)
def testReturningModelBasedRolloutsData(self):
class ModelBasedPolicyGraph(PGPolicyGraph):
class ModelBasedPolicyGraph(PGTFPolicy):
def compute_actions(self,
obs_batch,
state_batches,
@@ -584,7 +584,7 @@ class TestMultiAgentEnv(unittest.TestCase):
}
else:
policies = {
"p1": (PGPolicyGraph, obs_space, act_space, {}),
"p1": (PGTFPolicy, obs_space, act_space, {}),
"p2": (DQNPolicyGraph, obs_space, act_space, dqn_config),
}
ev = PolicyEvaluator(
@@ -640,7 +640,7 @@ class TestMultiAgentEnv(unittest.TestCase):
obs_space = env.observation_space
policies = {}
for i in range(20):
policies["pg_{}".format(i)] = (PGPolicyGraph, obs_space, act_space,
policies["pg_{}".format(i)] = (PGTFPolicy, obs_space, act_space,
{})
policy_ids = list(policies.keys())
ev = PolicyEvaluator(
+3 -3
View File
@@ -12,7 +12,7 @@ import unittest
import ray
from ray.rllib.agents.a3c import A2CTrainer
from ray.rllib.agents.pg import PGTrainer
from ray.rllib.agents.pg.pg_policy_graph import PGPolicyGraph
from ray.rllib.agents.pg.pg_policy_graph import PGTFPolicy
from ray.rllib.env import MultiAgentEnv
from ray.rllib.env.base_env import BaseEnv
from ray.rllib.env.vector_env import VectorEnv
@@ -333,10 +333,10 @@ class NestedSpacesTest(unittest.TestCase):
"multiagent": {
"policy_graphs": {
"tuple_policy": (
PGPolicyGraph, TUPLE_SPACE, act_space,
PGTFPolicy, TUPLE_SPACE, act_space,
{"model": {"custom_model": "tuple_spy"}}),
"dict_policy": (
PGPolicyGraph, DICT_SPACE, act_space,
PGTFPolicy, DICT_SPACE, act_space,
{"model": {"custom_model": "dict_spy"}}),
},
"policy_mapping_fn": lambda a: {
+3 -3
View File
@@ -9,7 +9,7 @@ import unittest
import ray
from ray.rllib.agents.ppo import PPOTrainer
from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph
from ray.rllib.agents.ppo.ppo_policy_graph import PPOTFPolicy
from ray.rllib.evaluation import SampleBatch
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
from ray.rllib.optimizers import AsyncGradientsOptimizer, AsyncSamplesOptimizer
@@ -240,12 +240,12 @@ class AsyncSamplesOptimizerTest(unittest.TestCase):
local = PolicyEvaluator(
env_creator=lambda _: gym.make("CartPole-v0"),
policy_graph=PPOPolicyGraph,
policy_graph=PPOTFPolicy,
tf_session_creator=make_sess)
remotes = [
PolicyEvaluator.as_remote().remote(
env_creator=lambda _: gym.make("CartPole-v0"),
policy_graph=PPOPolicyGraph,
policy_graph=PPOTFPolicy,
tf_session_creator=make_sess)
]
return local, remotes
+32
View File
@@ -0,0 +1,32 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
class UsageTrackingDict(dict):
    """Dict that tracks which keys have been accessed.

    It can also intercept gets and allow an arbitrary callback to be applied
    (i.e., to lazily convert numpy arrays to Tensors). The result of the
    callback is cached per key, so it runs at most once for each stored
    value.

    We make the simplifying assumption only __getitem__ is used to access
    values. Likewise, only __setitem__ and __delitem__ invalidate a cached
    intercepted value; other mutators (update(), pop(), ...) do not.
    """

    def __init__(self, *args, **kwargs):
        dict.__init__(self, *args, **kwargs)
        # Keys that were successfully read through __getitem__.
        self.accessed_keys = set()
        # Cache of interceptor outputs, keyed like the dict itself.
        self.intercepted_values = {}
        # Optional callback applied to values on read.
        self.get_interceptor = None

    def set_get_interceptor(self, fn):
        """Sets the callback applied to values returned by __getitem__.

        Arguments:
            fn (func): function that takes a stored value and returns the
                value to hand out instead, or None to disable interception.
        """
        if fn is not self.get_interceptor:
            # Values cached from a previous interceptor are no longer valid.
            self.intercepted_values.clear()
        self.get_interceptor = fn

    def __getitem__(self, key):
        # Look the key up first so that a failed access (KeyError) is not
        # recorded as a used key.
        value = dict.__getitem__(self, key)
        self.accessed_keys.add(key)
        if self.get_interceptor:
            if key not in self.intercepted_values:
                self.intercepted_values[key] = self.get_interceptor(value)
            value = self.intercepted_values[key]
        return value

    def __setitem__(self, key, value):
        dict.__setitem__(self, key, value)
        # Drop any cached conversion of the old value so that the next get
        # reflects the newly stored one.
        self.intercepted_values.pop(key, None)

    def __delitem__(self, key):
        dict.__delitem__(self, key)
        self.intercepted_values.pop(key, None)