mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 13:15:35 +08:00
[rllib] [RFC] Dynamic definition of loss functions and modularization support (#4795)
* dynamic graph
* wip
* clean up
* fix
* document trainer
* wip
* initialize the graph using a fake batch
* clean up dynamic init
* wip
* spelling
* use builder for ppo pol graph
* add ppo graph
* fix naming
* order
* docs
* set class name correctly
* add torch builder
* add custom model support in builder
* cleanup
* remove underscores
* fix py2 compat
* Update dynamic_tf_policy_graph.py
* Update tracking_dict.py
* wip
* rename
* debug level
* rename policy_graph -> policy in new classes
* fix test
* rename ppo tf policy
* port appo too
* forgot grads
* default policy optimizer
* make default config optional
* add config to optimizer
* use lr by default in optimizer
* update
* comments
* remove optimizer
* fix tuple actions support in dynamic tf graph
This commit is contained in:
@@ -49,8 +49,8 @@ class A3CTrainer(Trainer):
|
||||
def _init(self, config, env_creator):
|
||||
if config["use_pytorch"]:
|
||||
from ray.rllib.agents.a3c.a3c_torch_policy_graph import \
|
||||
A3CTorchPolicyGraph
|
||||
policy_cls = A3CTorchPolicyGraph
|
||||
A3CTorchPolicy
|
||||
policy_cls = A3CTorchPolicy
|
||||
else:
|
||||
policy_cls = self._policy_graph
|
||||
|
||||
|
||||
@@ -7,109 +7,84 @@ import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
import ray
|
||||
from ray.rllib.models.catalog import ModelCatalog
|
||||
from ray.rllib.evaluation.postprocessing import compute_advantages, \
|
||||
Postprocessing
|
||||
from ray.rllib.evaluation.policy_graph import PolicyGraph
|
||||
from ray.rllib.evaluation.sample_batch import SampleBatch
|
||||
from ray.rllib.evaluation.torch_policy_graph import TorchPolicyGraph
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.evaluation.torch_policy_template import build_torch_policy
|
||||
|
||||
|
||||
class A3CLoss(nn.Module):
    """Actor-critic loss: policy term + weighted value term - entropy bonus.

    After each ``forward`` call the individual terms are kept on the module
    as ``entropy``, ``pi_err`` and ``value_err``.
    """

    def __init__(self, dist_class, vf_loss_coeff=0.5, entropy_coeff=0.01):
        """Store the action distribution class and the two loss weights."""
        nn.Module.__init__(self)
        self.dist_class = dist_class
        self.vf_loss_coeff = vf_loss_coeff
        self.entropy_coeff = entropy_coeff

    def forward(self, policy_model, observations, actions, advantages,
                value_targets):
        """Run ``policy_model`` on the observations and return the total loss.

        The model is called with ``{SampleBatch.CUR_OBS: observations}`` and
        an empty state list; its first output is used as logits and its
        third output as value predictions.
        """
        logits, _, values, _ = policy_model({
            SampleBatch.CUR_OBS: observations
        }, [])
        dist = self.dist_class(logits)
        log_probs = dist.logp(actions)
        self.entropy = dist.entropy().mean()
        # Dot product of advantages with the flattened log-probabilities.
        self.pi_err = -advantages.dot(log_probs.reshape(-1))
        self.value_err = F.mse_loss(values.reshape(-1), value_targets)
        overall_err = sum([
            self.pi_err,
            self.vf_loss_coeff * self.value_err,
            -self.entropy_coeff * self.entropy,
        ])

        return overall_err
|
||||
def actor_critic_loss(policy, batch_tensors):
    """Compute the A3C loss and keep its components on ``policy``.

    Args:
        policy: Object exposing ``model``, ``dist_class`` and ``config``
            (keys ``vf_loss_coeff`` and ``entropy_coeff`` are read).
        batch_tensors: Mapping read at ``SampleBatch.CUR_OBS``,
            ``SampleBatch.ACTIONS``, ``Postprocessing.ADVANTAGES`` and
            ``Postprocessing.VALUE_TARGETS``.

    Returns:
        The sum of the policy term, the weighted value error and the
        negative weighted entropy.

    Side effects:
        Sets ``policy.entropy``, ``policy.pi_err`` and ``policy.value_err``.
    """
    logits, _, values, _ = policy.model({
        SampleBatch.CUR_OBS: batch_tensors[SampleBatch.CUR_OBS]
    }, [])
    dist = policy.dist_class(logits)
    log_probs = dist.logp(batch_tensors[SampleBatch.ACTIONS])
    policy.entropy = dist.entropy().mean()
    policy.pi_err = -batch_tensors[Postprocessing.ADVANTAGES].dot(
        log_probs.reshape(-1))
    policy.value_err = F.mse_loss(
        values.reshape(-1), batch_tensors[Postprocessing.VALUE_TARGETS])
    overall_err = sum([
        policy.pi_err,
        policy.config["vf_loss_coeff"] * policy.value_err,
        -policy.config["entropy_coeff"] * policy.entropy,
    ])
    return overall_err
|
||||
|
||||
|
||||
class A3CPostprocessing(object):
    """Adds the VF preds and advantages fields to the trajectory."""

    @override(TorchPolicyGraph)
    def extra_action_out(self, model_out):
        """Return the third model output as a numpy array under VF_PREDS."""
        return {SampleBatch.VF_PREDS: model_out[2].cpu().numpy()}

    @override(PolicyGraph)
    def postprocess_trajectory(self,
                               sample_batch,
                               other_agent_batches=None,
                               episode=None):
        """Run ``compute_advantages`` on the batch with a bootstrap value.

        The bootstrap value is 0.0 when the last ``DONES`` entry is truthy,
        otherwise ``self._value`` of the last ``NEXT_OBS`` entry.
        """
        completed = sample_batch[SampleBatch.DONES][-1]
        if completed:
            last_r = 0.0
        else:
            last_r = self._value(sample_batch[SampleBatch.NEXT_OBS][-1])
        return compute_advantages(sample_batch, last_r, self.config["gamma"],
                                  self.config["lambda"])
|
||||
def loss_and_entropy_stats(policy, batch_tensors):
    """Report the loss components stored on ``policy`` as Python scalars.

    Args:
        policy: Object whose ``entropy``, ``pi_err`` and ``value_err``
            attributes each provide ``.item()``.
        batch_tensors: Unused here.

    Returns:
        Dict with keys ``policy_entropy``, ``policy_loss`` and ``vf_loss``.
    """
    return {
        "policy_entropy": policy.entropy.item(),
        "policy_loss": policy.pi_err.item(),
        "vf_loss": policy.value_err.item(),
    }
|
||||
|
||||
|
||||
class A3CTorchPolicyGraph(A3CPostprocessing, TorchPolicyGraph):
|
||||
"""A simple, non-recurrent PyTorch policy example."""
|
||||
def add_advantages(policy,
                   sample_batch,
                   other_agent_batches=None,
                   episode=None):
    """Run ``compute_advantages`` on ``sample_batch`` with a bootstrap value.

    The bootstrap value is 0.0 when the last ``DONES`` entry is truthy,
    otherwise ``policy._value`` of the last ``NEXT_OBS`` entry.
    ``other_agent_batches`` and ``episode`` are accepted but not used.
    """
    completed = sample_batch[SampleBatch.DONES][-1]
    if completed:
        last_r = 0.0
    else:
        last_r = policy._value(sample_batch[SampleBatch.NEXT_OBS][-1])
    return compute_advantages(sample_batch, last_r, policy.config["gamma"],
                              policy.config["lambda"])
|
||||
|
||||
def __init__(self, obs_space, action_space, config):
|
||||
config = dict(ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG, **config)
|
||||
self.config = config
|
||||
dist_class, self.logit_dim = ModelCatalog.get_action_dist(
|
||||
action_space, self.config["model"], torch=True)
|
||||
model = ModelCatalog.get_torch_model(obs_space, self.logit_dim,
|
||||
self.config["model"])
|
||||
loss = A3CLoss(dist_class, self.config["vf_loss_coeff"],
|
||||
self.config["entropy_coeff"])
|
||||
TorchPolicyGraph.__init__(
|
||||
self,
|
||||
obs_space,
|
||||
action_space,
|
||||
model,
|
||||
loss,
|
||||
loss_inputs=[
|
||||
SampleBatch.CUR_OBS, SampleBatch.ACTIONS,
|
||||
Postprocessing.ADVANTAGES, Postprocessing.VALUE_TARGETS
|
||||
],
|
||||
action_distribution_cls=dist_class)
|
||||
|
||||
@override(TorchPolicyGraph)
|
||||
def optimizer(self):
|
||||
return torch.optim.Adam(self._model.parameters(), lr=self.config["lr"])
|
||||
def model_value_predictions(policy, model_out):
    """Return the third model output as a numpy array under VF_PREDS.

    ``policy`` is accepted but not used.
    """
    return {SampleBatch.VF_PREDS: model_out[2].cpu().numpy()}
|
||||
|
||||
@override(TorchPolicyGraph)
|
||||
def extra_grad_process(self):
|
||||
info = {}
|
||||
if self.config["grad_clip"]:
|
||||
total_norm = nn.utils.clip_grad_norm_(self._model.parameters(),
|
||||
self.config["grad_clip"])
|
||||
info["grad_gnorm"] = total_norm
|
||||
return info
|
||||
|
||||
@override(TorchPolicyGraph)
|
||||
def extra_grad_info(self):
|
||||
return {
|
||||
"policy_entropy": self._loss.entropy.item(),
|
||||
"policy_loss": self._loss.pi_err.item(),
|
||||
"vf_loss": self._loss.value_err.item()
|
||||
}
|
||||
def apply_grad_clipping(policy):
    """Clip the gradients of ``policy.model`` when ``grad_clip`` is set.

    Args:
        policy: Object exposing ``model`` (with ``parameters()``) and
            ``config`` (key ``grad_clip`` is read).

    Returns:
        ``{"grad_gnorm": <norm returned by clip_grad_norm_>}`` when
        ``config["grad_clip"]`` is truthy, otherwise an empty dict.
    """
    info = {}
    if policy.config["grad_clip"]:
        # clip_grad_norm_ rescales the gradients in place and returns the
        # total norm it measured.
        total_norm = nn.utils.clip_grad_norm_(policy.model.parameters(),
                                              policy.config["grad_clip"])
        info["grad_gnorm"] = total_norm
    return info
|
||||
|
||||
|
||||
def torch_optimizer(policy, config):
    """Build an Adam optimizer over ``policy.model`` parameters.

    The learning rate is read from ``config["lr"]``.
    """
    return torch.optim.Adam(policy.model.parameters(), lr=config["lr"])
|
||||
|
||||
|
||||
class ValueNetworkMixin(object):
    """Mixin providing ``_value``, a single-observation value estimate.

    Expects the host object to provide ``lock``, ``device`` and ``model``.
    """

    def _value(self, obs):
        """Return the model's value output for one numpy observation.

        The observation is converted to a float tensor, given a leading
        dimension, moved to ``self.device`` and passed to ``self.model`` as
        ``{"obs": obs}`` with an empty state list. The third model output
        is returned as a squeezed numpy array.
        """
        with self.lock:
            obs = torch.from_numpy(obs).float().unsqueeze(0).to(self.device)
            # Only ``self.model`` is used: the earlier ``self._model`` call
            # duplicated this forward pass and relied on an attribute the
            # other functions in this file do not use.
            _, _, vf, _ = self.model({"obs": obs}, [])
            return vf.detach().cpu().numpy().squeeze()
|
||||
|
||||
|
||||
# Torch A3C policy assembled by build_torch_policy from the functions and
# mixin defined in this file. The default config is wrapped in a lambda so
# the attribute lookup happens only when the lambda is called.
A3CTorchPolicy = build_torch_policy(
    name="A3CTorchPolicy",
    get_default_config=lambda: ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG,
    loss_fn=actor_critic_loss,
    stats_fn=loss_and_entropy_stats,
    postprocess_fn=add_advantages,
    extra_action_out_fn=model_value_predictions,
    extra_grad_process_fn=apply_grad_clipping,
    optimizer_fn=torch_optimizer,
    mixins=[ValueNetworkMixin])
|
||||
|
||||
@@ -2,11 +2,9 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from ray.rllib.agents.trainer import Trainer, with_common_config
|
||||
from ray.rllib.agents.pg.pg_policy_graph import PGPolicyGraph
|
||||
|
||||
from ray.rllib.optimizers import SyncSamplesOptimizer
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.agents.trainer import with_common_config
|
||||
from ray.rllib.agents.trainer_template import build_trainer
|
||||
from ray.rllib.agents.pg.pg_policy_graph import PGTFPolicy
|
||||
|
||||
# yapf: disable
|
||||
# __sphinx_doc_begin__
|
||||
@@ -22,40 +20,16 @@ DEFAULT_CONFIG = with_common_config({
|
||||
# yapf: enable
|
||||
|
||||
|
||||
class PGTrainer(Trainer):
|
||||
"""Simple policy gradient agent.
|
||||
def get_policy_class(config):
    """Select the PG policy class from ``config["use_pytorch"]``.

    Returns ``PGTorchPolicy`` when the flag is truthy, else ``PGTFPolicy``.
    """
    if config["use_pytorch"]:
        # Imported here rather than at module level, so the torch policy
        # module is only loaded when it is requested.
        from ray.rllib.agents.pg.torch_pg_policy_graph import PGTorchPolicy
        return PGTorchPolicy
    else:
        return PGTFPolicy
|
||||
|
||||
This is an example agent to show how to implement algorithms in RLlib.
|
||||
In most cases, you will probably want to use the PPO agent instead.
|
||||
"""
|
||||
|
||||
_name = "PG"
|
||||
_default_config = DEFAULT_CONFIG
|
||||
_policy_graph = PGPolicyGraph
|
||||
|
||||
@override(Trainer)
|
||||
def _init(self, config, env_creator):
|
||||
if config["use_pytorch"]:
|
||||
from ray.rllib.agents.pg.torch_pg_policy_graph import \
|
||||
PGTorchPolicyGraph
|
||||
policy_cls = PGTorchPolicyGraph
|
||||
else:
|
||||
policy_cls = self._policy_graph
|
||||
self.local_evaluator = self.make_local_evaluator(
|
||||
env_creator, policy_cls)
|
||||
self.remote_evaluators = self.make_remote_evaluators(
|
||||
env_creator, policy_cls, config["num_workers"])
|
||||
optimizer_config = dict(
|
||||
config["optimizer"],
|
||||
**{"train_batch_size": config["train_batch_size"]})
|
||||
self.optimizer = SyncSamplesOptimizer(
|
||||
self.local_evaluator, self.remote_evaluators, **optimizer_config)
|
||||
|
||||
@override(Trainer)
|
||||
def _train(self):
|
||||
prev_steps = self.optimizer.num_steps_sampled
|
||||
self.optimizer.step()
|
||||
result = self.collect_metrics()
|
||||
result.update(timesteps_this_iter=self.optimizer.num_steps_sampled -
|
||||
prev_steps)
|
||||
return result
|
||||
# PG trainer assembled by build_trainer; get_policy_class picks the torch
# or TF policy from the config.
PGTrainer = build_trainer(
    name="PG",
    default_config=DEFAULT_CONFIG,
    default_policy=PGTFPolicy,
    get_policy_class=get_policy_class)
|
||||
|
||||
@@ -3,102 +3,33 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import ray
|
||||
from ray.rllib.models.catalog import ModelCatalog
|
||||
from ray.rllib.evaluation.postprocessing import compute_advantages, \
|
||||
Postprocessing
|
||||
from ray.rllib.evaluation.policy_graph import PolicyGraph
|
||||
from ray.rllib.evaluation.tf_policy_template import build_tf_policy
|
||||
from ray.rllib.evaluation.sample_batch import SampleBatch
|
||||
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils import try_import_tf
|
||||
|
||||
tf = try_import_tf()
|
||||
|
||||
|
||||
class PGLoss(object):
    """The basic policy gradient loss."""

    def __init__(self, action_dist, actions, advantages):
        """Store the loss on ``self.loss``.

        The loss is the negated mean of ``action_dist.logp(actions)``
        multiplied by ``advantages``.
        """
        self.loss = -tf.reduce_mean(action_dist.logp(actions) * advantages)
|
||||
# The basic policy gradients loss
def policy_gradient_loss(policy, batch_tensors):
    """Return the negated mean of action log-probability times advantage.

    Actions are read from ``batch_tensors[SampleBatch.ACTIONS]`` and
    advantages from ``batch_tensors[Postprocessing.ADVANTAGES]``; the
    log-probabilities come from ``policy.action_dist``.
    """
    actions = batch_tensors[SampleBatch.ACTIONS]
    advantages = batch_tensors[Postprocessing.ADVANTAGES]
    return -tf.reduce_mean(policy.action_dist.logp(actions) * advantages)
|
||||
|
||||
|
||||
class PGPostprocessing(object):
    """Adds the advantages field to the trajectory."""

    @override(PolicyGraph)
    def postprocess_trajectory(self,
                               sample_batch,
                               other_agent_batches=None,
                               episode=None):
        """Call ``compute_advantages`` with last_r 0.0 and GAE disabled."""
        # This adds the "advantages" column to the sample batch
        return compute_advantages(
            sample_batch, 0.0, self.config["gamma"], use_gae=False)
|
||||
# This adds the "advantages" column to the sample batch.
def postprocess_advantages(policy,
                           sample_batch,
                           other_agent_batches=None,
                           episode=None):
    """Call ``compute_advantages`` with last_r 0.0 and GAE disabled.

    The discount is read from ``policy.config["gamma"]``.
    ``other_agent_batches`` and ``episode`` are accepted but not used.
    """
    return compute_advantages(
        sample_batch, 0.0, policy.config["gamma"], use_gae=False)
|
||||
|
||||
|
||||
class PGPolicyGraph(PGPostprocessing, TFPolicyGraph):
    """Simple policy gradient example of defining a policy graph."""

    def __init__(self, obs_space, action_space, config):
        """Build placeholders, model and loss, then init TFPolicyGraph.

        ``config`` is overlaid on the PG ``DEFAULT_CONFIG``; the default TF
        session is used and global variables are initialized at the end.
        """
        config = dict(ray.rllib.agents.pg.pg.DEFAULT_CONFIG, **config)
        self.config = config

        # Setup placeholders
        obs = tf.placeholder(tf.float32, shape=[None] + list(obs_space.shape))
        dist_class, self.logit_dim = ModelCatalog.get_action_dist(
            action_space, self.config["model"])
        prev_actions = ModelCatalog.get_action_placeholder(action_space)
        prev_rewards = tf.placeholder(tf.float32, [None], name="prev_reward")

        # Create the model network and action outputs
        self.model = ModelCatalog.get_model({
            "obs": obs,
            "prev_actions": prev_actions,
            "prev_rewards": prev_rewards,
            "is_training": self._get_is_training_placeholder(),
        }, obs_space, action_space, self.logit_dim, self.config["model"])
        action_dist = dist_class(self.model.outputs)  # logit for each action

        # Setup policy loss
        actions = ModelCatalog.get_action_placeholder(action_space)
        advantages = tf.placeholder(tf.float32, [None], name="adv")
        loss = PGLoss(action_dist, actions, advantages).loss

        # Mapping from sample batch keys to placeholders. These keys will be
        # read from postprocessed sample batches and fed into the specified
        # placeholders during loss computation.
        loss_in = [
            (SampleBatch.CUR_OBS, obs),
            (SampleBatch.ACTIONS, actions),
            (SampleBatch.PREV_ACTIONS, prev_actions),
            (SampleBatch.PREV_REWARDS, prev_rewards),
            (Postprocessing.ADVANTAGES, advantages),
        ]

        # Initialize TFPolicyGraph
        sess = tf.get_default_session()
        TFPolicyGraph.__init__(
            self,
            obs_space,
            action_space,
            sess,
            obs_input=obs,
            action_sampler=action_dist.sample(),
            action_prob=action_dist.sampled_action_prob(),
            loss=loss,
            loss_inputs=loss_in,
            model=self.model,
            state_inputs=self.model.state_in,
            state_outputs=self.model.state_out,
            prev_action_input=prev_actions,
            prev_reward_input=prev_rewards,
            seq_lens=self.model.seq_lens,
            max_seq_len=config["model"]["max_seq_len"])
        sess.run(tf.global_variables_initializer())

    @override(PolicyGraph)
    def get_initial_state(self):
        """Return the model's ``state_init``."""
        return self.model.state_init

    @override(TFPolicyGraph)
    def optimizer(self):
        """Return an Adam optimizer using ``config["lr"]``."""
        return tf.train.AdamOptimizer(learning_rate=self.config["lr"])
|
||||
# TF policy-gradient policy assembled by build_tf_policy. The default
# config is wrapped in a lambda so the attribute lookup happens only when
# the lambda is called.
PGTFPolicy = build_tf_policy(
    name="PGTFPolicy",
    get_default_config=lambda: ray.rllib.agents.pg.pg.DEFAULT_CONFIG,
    postprocess_fn=postprocess_advantages,
    loss_fn=policy_gradient_loss)
|
||||
|
||||
@@ -2,82 +2,41 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
import ray
|
||||
from ray.rllib.models.catalog import ModelCatalog
|
||||
from ray.rllib.evaluation.postprocessing import compute_advantages, \
|
||||
Postprocessing
|
||||
from ray.rllib.evaluation.policy_graph import PolicyGraph
|
||||
from ray.rllib.evaluation.sample_batch import SampleBatch
|
||||
from ray.rllib.evaluation.torch_policy_graph import TorchPolicyGraph
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.evaluation.torch_policy_template import build_torch_policy
|
||||
|
||||
|
||||
class PGLoss(nn.Module):
    """Policy-gradient loss module; keeps the last loss on ``pi_err``."""

    def __init__(self, dist_class):
        """Store the action distribution class."""
        nn.Module.__init__(self)
        self.dist_class = dist_class

    def forward(self, policy_model, observations, actions, advantages):
        """Return the negated dot product of advantages and log-probs.

        The model is called with ``{SampleBatch.CUR_OBS: observations}`` and
        an empty state list; only its first output (logits) is used.
        """
        logits, _, _, _ = policy_model({
            SampleBatch.CUR_OBS: observations
        }, [])
        dist = self.dist_class(logits)
        log_probs = dist.logp(actions)
        self.pi_err = -advantages.dot(log_probs.reshape(-1))
        return self.pi_err
|
||||
def pg_torch_loss(policy, batch_tensors):
    """Compute the policy-gradient loss and keep it on ``policy.pi_err``.

    Args:
        policy: Object exposing ``model`` and ``dist_class``.
        batch_tensors: Mapping read at ``SampleBatch.CUR_OBS``,
            ``SampleBatch.ACTIONS`` and ``Postprocessing.ADVANTAGES``.

    Returns:
        The negated dot product of the advantages with the flattened
        action log-probabilities.
    """
    # Only the logits (first model output) are needed here.
    logits, _, _, _ = policy.model({
        SampleBatch.CUR_OBS: batch_tensors[SampleBatch.CUR_OBS]
    }, [])
    action_dist = policy.dist_class(logits)
    log_probs = action_dist.logp(batch_tensors[SampleBatch.ACTIONS])
    # save the error in the policy object
    policy.pi_err = -batch_tensors[Postprocessing.ADVANTAGES].dot(
        log_probs.reshape(-1))
    return policy.pi_err
|
||||
|
||||
|
||||
class PGPostprocessing(object):
    """Adds the value func output and advantages field to the trajectory."""

    @override(TorchPolicyGraph)
    def extra_action_out(self, model_out):
        """Return the third model output as a numpy array under VF_PREDS."""
        return {SampleBatch.VF_PREDS: model_out[2].cpu().numpy()}

    @override(PolicyGraph)
    def postprocess_trajectory(self,
                               sample_batch,
                               other_agent_batches=None,
                               episode=None):
        """Call ``compute_advantages`` with last_r 0.0 and GAE disabled."""
        return compute_advantages(
            sample_batch, 0.0, self.config["gamma"], use_gae=False)
|
||||
def postprocess_advantages(policy,
                           sample_batch,
                           other_agent_batches=None,
                           episode=None):
    """Call ``compute_advantages`` with last_r 0.0 and GAE disabled.

    The discount is read from ``policy.config["gamma"]``.
    ``other_agent_batches`` and ``episode`` are accepted but not used.
    """
    return compute_advantages(
        sample_batch, 0.0, policy.config["gamma"], use_gae=False)
|
||||
|
||||
|
||||
class PGTorchPolicyGraph(PGPostprocessing, TorchPolicyGraph):
|
||||
def __init__(self, obs_space, action_space, config):
|
||||
config = dict(ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG, **config)
|
||||
self.config = config
|
||||
dist_class, self.logit_dim = ModelCatalog.get_action_dist(
|
||||
action_space, self.config["model"], torch=True)
|
||||
model = ModelCatalog.get_torch_model(obs_space, self.logit_dim,
|
||||
self.config["model"])
|
||||
loss = PGLoss(dist_class)
|
||||
def pg_loss_stats(policy, batch_tensors):
    """Report ``policy.pi_err`` as a Python scalar under ``policy_loss``.

    ``batch_tensors`` is accepted but not used.
    """
    # the error is recorded when computing the loss
    return {"policy_loss": policy.pi_err.item()}
|
||||
|
||||
TorchPolicyGraph.__init__(
|
||||
self,
|
||||
obs_space,
|
||||
action_space,
|
||||
model,
|
||||
loss,
|
||||
loss_inputs=[
|
||||
SampleBatch.CUR_OBS, SampleBatch.ACTIONS,
|
||||
Postprocessing.ADVANTAGES
|
||||
],
|
||||
action_distribution_cls=dist_class)
|
||||
|
||||
@override(TorchPolicyGraph)
|
||||
def optimizer(self):
|
||||
return torch.optim.Adam(self._model.parameters(), lr=self.config["lr"])
|
||||
|
||||
@override(TorchPolicyGraph)
|
||||
def extra_grad_info(self):
|
||||
return {"policy_loss": self._loss.pi_err.item()}
|
||||
|
||||
def _value(self, obs):
|
||||
with self.lock:
|
||||
obs = torch.from_numpy(obs).float().unsqueeze(0).to(self.device)
|
||||
_, _, vf, _ = self.model({"obs": obs}, [])
|
||||
return vf.detach().cpu().numpy().squeeze()
|
||||
# Torch policy-gradient policy assembled by build_torch_policy.
# The default config now points at the PG DEFAULT_CONFIG (it previously
# referenced the A3C one), matching the TF PG policy definition.
PGTorchPolicy = build_torch_policy(
    name="PGTorchPolicy",
    get_default_config=lambda: ray.rllib.agents.pg.pg.DEFAULT_CONFIG,
    loss_fn=pg_torch_loss,
    stats_fn=pg_loss_stats,
    postprocess_fn=postprocess_advantages)
|
||||
|
||||
@@ -2,7 +2,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from ray.rllib.agents.ppo.appo_policy_graph import AsyncPPOPolicyGraph
|
||||
from ray.rllib.agents.ppo.appo_policy_graph import AsyncPPOTFPolicy
|
||||
from ray.rllib.agents.trainer import with_base_config
|
||||
from ray.rllib.agents import impala
|
||||
from ray.rllib.utils.annotations import override
|
||||
@@ -57,8 +57,8 @@ class APPOTrainer(impala.ImpalaTrainer):
|
||||
|
||||
_name = "APPO"
|
||||
_default_config = DEFAULT_CONFIG
|
||||
_policy_graph = AsyncPPOPolicyGraph
|
||||
_policy_graph = AsyncPPOTFPolicy
|
||||
|
||||
@override(impala.ImpalaTrainer)
def _get_policy_graph(self):
    """Return the policy class used by this trainer."""
    # A preceding ``return AsyncPPOPolicyGraph`` made this line unreachable
    # and returned the superseded class; only the current name is returned.
    return AsyncPPOTFPolicy
|
||||
|
||||
@@ -12,14 +12,11 @@ import gym
|
||||
|
||||
import ray
|
||||
from ray.rllib.agents.impala import vtrace
|
||||
from ray.rllib.evaluation.policy_graph import PolicyGraph
|
||||
from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
|
||||
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph, \
|
||||
LearningRateSchedule
|
||||
from ray.rllib.models.catalog import ModelCatalog
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.evaluation.postprocessing import Postprocessing
|
||||
from ray.rllib.evaluation.sample_batch import SampleBatch
|
||||
from ray.rllib.evaluation.tf_policy_template import build_tf_policy
|
||||
from ray.rllib.evaluation.tf_policy_graph import LearningRateSchedule
|
||||
from ray.rllib.utils.explained_variance import explained_variance
|
||||
from ray.rllib.models.action_dist import MultiCategorical
|
||||
from ray.rllib.evaluation.postprocessing import compute_advantages
|
||||
from ray.rllib.utils import try_import_tf
|
||||
|
||||
@@ -27,6 +24,8 @@ tf = try_import_tf()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
BEHAVIOUR_LOGITS = "behaviour_logits"
|
||||
|
||||
|
||||
class PPOSurrogateLoss(object):
|
||||
"""Loss used when V-trace is disabled.
|
||||
@@ -163,333 +162,235 @@ class VTraceSurrogateLoss(object):
|
||||
self.entropy * entropy_coeff)
|
||||
|
||||
|
||||
class APPOPostprocessing(object):
|
||||
"""Adds the policy logits, VF preds, and advantages to the trajectory."""
|
||||
def _make_time_major(policy, tensor, drop_last=False):
    """Swaps batch and trajectory axis.

    Arguments:
        policy: Policy reference
        tensor: A tensor or list of tensors to reshape.
        drop_last: A bool indicating whether to drop the last
            trajectory item.

    Returns:
        res: A tensor with swapped axes or a list of tensors with
            swapped axes.
    """
    if isinstance(tensor, list):
        return [_make_time_major(policy, t, drop_last) for t in tensor]

    if policy.model.state_init:
        B = tf.shape(policy.model.seq_lens)[0]
        T = tf.shape(tensor)[0] // B
    else:
        # Important: chop the tensor into batches at known episode cut
        # boundaries. TODO(ekl) this is kind of a hack
        T = policy.config["sample_batch_size"]
        B = tf.shape(tensor)[0] // T
    rs = tf.reshape(tensor, tf.concat([[B, T], tf.shape(tensor)[1:]], axis=0))

    # swap B and T axes
    res = tf.transpose(
        rs, [1, 0] + list(range(2, 1 + int(tf.shape(tensor).shape[0]))))

    if drop_last:
        return res[:-1]
    return res
|
||||
|
||||
|
||||
def build_appo_surrogate_loss(policy, batch_tensors):
    """Build the APPO loss object on ``policy.loss`` and return its total.

    Uses ``VTraceSurrogateLoss`` when ``policy.config["vtrace"]`` is truthy
    and ``PPOSurrogateLoss`` otherwise.

    Args:
        policy: Object exposing ``action_space``, ``model``, ``action_dist``,
            ``dist_class``, ``value_function`` and ``config``.
        batch_tensors: Mapping read at ``SampleBatch.ACTIONS``, ``DONES``,
            ``REWARDS`` and ``BEHAVIOUR_LOGITS``; on the non-V-trace path
            also at ``Postprocessing.ADVANTAGES`` and ``VALUE_TARGETS``.

    Returns:
        ``policy.loss.total_loss``.
    """
    if isinstance(policy.action_space, gym.spaces.Discrete):
        is_multidiscrete = False
        output_hidden_shape = [policy.action_space.n]
    elif isinstance(policy.action_space,
                    gym.spaces.multi_discrete.MultiDiscrete):
        is_multidiscrete = True
        output_hidden_shape = policy.action_space.nvec.astype(np.int32)
    else:
        is_multidiscrete = False
        output_hidden_shape = 1

    def make_time_major(*args, **kw):
        return _make_time_major(policy, *args, **kw)

    actions = batch_tensors[SampleBatch.ACTIONS]
    dones = batch_tensors[SampleBatch.DONES]
    rewards = batch_tensors[SampleBatch.REWARDS]
    behaviour_logits = batch_tensors[BEHAVIOUR_LOGITS]
    unpacked_behaviour_logits = tf.split(
        behaviour_logits, output_hidden_shape, axis=1)
    unpacked_outputs = tf.split(
        policy.model.outputs, output_hidden_shape, axis=1)
    prev_dist_inputs = unpacked_behaviour_logits if is_multidiscrete else \
        behaviour_logits
    action_dist = policy.action_dist
    prev_action_dist = policy.dist_class(prev_dist_inputs)
    values = policy.value_function

    if policy.model.state_in:
        max_seq_len = tf.reduce_max(policy.model.seq_lens) - 1
        mask = tf.sequence_mask(policy.model.seq_lens, max_seq_len)
        mask = tf.reshape(mask, [-1])
    else:
        mask = tf.ones_like(rewards)

    if policy.config["vtrace"]:
        logger.info("Using V-Trace surrogate loss (vtrace=True)")

        # Prepare actions for loss
        loss_actions = actions if is_multidiscrete else tf.expand_dims(
            actions, axis=1)

        policy.loss = VTraceSurrogateLoss(
            actions=make_time_major(loss_actions, drop_last=True),
            prev_actions_logp=make_time_major(
                prev_action_dist.logp(actions), drop_last=True),
            actions_logp=make_time_major(
                action_dist.logp(actions), drop_last=True),
            action_kl=prev_action_dist.kl(action_dist),
            actions_entropy=make_time_major(
                action_dist.entropy(), drop_last=True),
            dones=make_time_major(dones, drop_last=True),
            behaviour_logits=make_time_major(
                unpacked_behaviour_logits, drop_last=True),
            target_logits=make_time_major(unpacked_outputs, drop_last=True),
            discount=policy.config["gamma"],
            rewards=make_time_major(rewards, drop_last=True),
            values=make_time_major(values, drop_last=True),
            bootstrap_value=make_time_major(values)[-1],
            dist_class=policy.dist_class,
            valid_mask=make_time_major(mask, drop_last=True),
            vf_loss_coeff=policy.config["vf_loss_coeff"],
            entropy_coeff=policy.config["entropy_coeff"],
            clip_rho_threshold=policy.config["vtrace_clip_rho_threshold"],
            clip_pg_rho_threshold=policy.config[
                "vtrace_clip_pg_rho_threshold"],
            clip_param=policy.config["clip_param"])
    else:
        logger.info("Using PPO surrogate loss (vtrace=False)")
        policy.loss = PPOSurrogateLoss(
            prev_actions_logp=make_time_major(prev_action_dist.logp(actions)),
            actions_logp=make_time_major(action_dist.logp(actions)),
            action_kl=prev_action_dist.kl(action_dist),
            actions_entropy=make_time_major(action_dist.entropy()),
            values=make_time_major(values),
            valid_mask=make_time_major(mask),
            advantages=make_time_major(
                batch_tensors[Postprocessing.ADVANTAGES]),
            value_targets=make_time_major(
                batch_tensors[Postprocessing.VALUE_TARGETS]),
            vf_loss_coeff=policy.config["vf_loss_coeff"],
            entropy_coeff=policy.config["entropy_coeff"],
            clip_param=policy.config["clip_param"])

    return policy.loss.total_loss
|
||||
|
||||
|
||||
def stats(policy, batch_tensors):
    """Return a dict of learner statistics built from ``policy``.

    The value function is made time-major (dropping the last step when
    ``policy.config["vtrace"]`` is truthy) before being compared with
    ``policy.loss.value_targets`` for the explained-variance entry.
    ``batch_tensors`` is accepted but not used.
    """
    values_batched = _make_time_major(
        policy, policy.value_function, drop_last=policy.config["vtrace"])

    return {
        "cur_lr": tf.cast(policy.cur_lr, tf.float64),
        "policy_loss": policy.loss.pi_loss,
        "entropy": policy.loss.entropy,
        "var_gnorm": tf.global_norm(policy.var_list),
        "vf_loss": policy.loss.vf_loss,
        "vf_explained_var": explained_variance(
            tf.reshape(policy.loss.value_targets, [-1]),
            tf.reshape(values_batched, [-1])),
    }
|
||||
|
||||
|
||||
def grad_stats(policy, grads):
    """Return the global norm of ``grads`` under the key ``grad_gnorm``.

    ``policy`` is accepted but not used.
    """
    return {
        "grad_gnorm": tf.global_norm(grads),
    }
|
||||
|
||||
|
||||
def postprocess_trajectory(policy,
                           sample_batch,
                           other_agent_batches=None,
                           episode=None):
    """Add advantages when V-trace is off, then drop ``new_obs``.

    When ``policy.config["vtrace"]`` is falsy, ``compute_advantages`` is
    run with a bootstrap value of 0.0 if the last ``dones`` entry is truthy,
    otherwise ``policy.value`` of the last ``new_obs`` entry and the last
    ``state_out_<i>`` entries. When V-trace is on, the batch is passed
    through unchanged. ``other_agent_batches`` and ``episode`` are unused.

    Returns:
        The resulting batch with ``"new_obs"`` deleted from ``batch.data``.
    """
    if not policy.config["vtrace"]:
        completed = sample_batch["dones"][-1]
        if completed:
            last_r = 0.0
        else:
            next_state = []
            for i in range(len(policy.model.state_in)):
                next_state.append([sample_batch["state_out_{}".format(i)][-1]])
            last_r = policy.value(sample_batch["new_obs"][-1], *next_state)
        batch = compute_advantages(
            sample_batch,
            last_r,
            policy.config["gamma"],
            policy.config["lambda"],
            use_gae=policy.config["use_gae"])
    else:
        batch = sample_batch
    # NOTE(review): indentation was lost in this listing; the delete and
    # return are placed at function level so both branches return a batch.
    del batch.data["new_obs"]  # not used, so save some bandwidth
    return batch
|
||||
|
||||
|
||||
class AsyncPPOPolicyGraph(LearningRateSchedule, APPOPostprocessing,
|
||||
TFPolicyGraph):
|
||||
def __init__(self,
|
||||
observation_space,
|
||||
action_space,
|
||||
config,
|
||||
existing_inputs=None):
|
||||
config = dict(ray.rllib.agents.impala.impala.DEFAULT_CONFIG, **config)
|
||||
assert config["batch_mode"] == "truncate_episodes", \
|
||||
"Must use `truncate_episodes` batch mode with V-trace."
|
||||
self.config = config
|
||||
self.sess = tf.get_default_session()
|
||||
self.grads = None
|
||||
def add_values_and_logits(policy):
    """Return the extra fetches: behaviour logits and, optionally, VF preds.

    ``policy.model.outputs`` is always included under ``BEHAVIOUR_LOGITS``;
    ``policy.value_function`` is added under ``SampleBatch.VF_PREDS`` only
    when ``policy.config["vtrace"]`` is falsy.
    """
    out = {BEHAVIOUR_LOGITS: policy.model.outputs}
    if not policy.config["vtrace"]:
        out[SampleBatch.VF_PREDS] = policy.value_function
    return out
|
||||
|
||||
if isinstance(action_space, gym.spaces.Discrete):
|
||||
is_multidiscrete = False
|
||||
output_hidden_shape = [action_space.n]
|
||||
elif isinstance(action_space, gym.spaces.multi_discrete.MultiDiscrete):
|
||||
is_multidiscrete = True
|
||||
output_hidden_shape = action_space.nvec.astype(np.int32)
|
||||
else:
|
||||
is_multidiscrete = False
|
||||
output_hidden_shape = 1
|
||||
|
||||
# Policy network model
|
||||
dist_class, logit_dim = ModelCatalog.get_action_dist(
|
||||
action_space, self.config["model"])
|
||||
def validate_config(policy, obs_space, action_space, config):
    """Assert that ``config["batch_mode"]`` is ``"truncate_episodes"``.

    ``policy``, ``obs_space`` and ``action_space`` are accepted but unused.

    Raises:
        AssertionError: If another batch mode is configured. Kept as an
            ``assert`` so the exception type is unchanged for callers; note
            that asserts are skipped when Python runs with ``-O``.
    """
    assert config["batch_mode"] == "truncate_episodes", \
        "Must use `truncate_episodes` batch mode with V-trace."
|
||||
|
||||
# Create input placeholders
|
||||
if existing_inputs:
|
||||
if self.config["vtrace"]:
|
||||
actions, dones, behaviour_logits, rewards, observations, \
|
||||
prev_actions, prev_rewards = existing_inputs[:7]
|
||||
existing_state_in = existing_inputs[7:-1]
|
||||
existing_seq_lens = existing_inputs[-1]
|
||||
else:
|
||||
actions, dones, behaviour_logits, rewards, observations, \
|
||||
prev_actions, prev_rewards, adv_ph, value_targets = \
|
||||
existing_inputs[:9]
|
||||
existing_state_in = existing_inputs[9:-1]
|
||||
existing_seq_lens = existing_inputs[-1]
|
||||
else:
|
||||
actions = ModelCatalog.get_action_placeholder(action_space)
|
||||
dones = tf.placeholder(tf.bool, [None], name="dones")
|
||||
rewards = tf.placeholder(tf.float32, [None], name="rewards")
|
||||
behaviour_logits = tf.placeholder(
|
||||
tf.float32, [None, logit_dim], name="behaviour_logits")
|
||||
observations = tf.placeholder(
|
||||
tf.float32, [None] + list(observation_space.shape))
|
||||
existing_state_in = None
|
||||
existing_seq_lens = None
|
||||
|
||||
if not self.config["vtrace"]:
|
||||
adv_ph = tf.placeholder(
|
||||
tf.float32, name="advantages", shape=(None, ))
|
||||
value_targets = tf.placeholder(
|
||||
tf.float32, name="value_targets", shape=(None, ))
|
||||
self.observations = observations
|
||||
def choose_optimizer(policy, config):
|
||||
if policy.config["opt_type"] == "adam":
|
||||
return tf.train.AdamOptimizer(policy.cur_lr)
|
||||
else:
|
||||
return tf.train.RMSPropOptimizer(policy.cur_lr, config["decay"],
|
||||
config["momentum"], config["epsilon"])
|
||||
|
||||
# Unpack behaviour logits
|
||||
unpacked_behaviour_logits = tf.split(
|
||||
behaviour_logits, output_hidden_shape, axis=1)
|
||||
|
||||
# Setup the policy
|
||||
dist_class, logit_dim = ModelCatalog.get_action_dist(
|
||||
action_space, self.config["model"])
|
||||
prev_actions = ModelCatalog.get_action_placeholder(action_space)
|
||||
prev_rewards = tf.placeholder(tf.float32, [None], name="prev_reward")
|
||||
self.model = ModelCatalog.get_model(
|
||||
{
|
||||
"obs": observations,
|
||||
"prev_actions": prev_actions,
|
||||
"prev_rewards": prev_rewards,
|
||||
"is_training": self._get_is_training_placeholder(),
|
||||
},
|
||||
observation_space,
|
||||
action_space,
|
||||
logit_dim,
|
||||
self.config["model"],
|
||||
state_in=existing_state_in,
|
||||
seq_lens=existing_seq_lens)
|
||||
unpacked_outputs = tf.split(
|
||||
self.model.outputs, output_hidden_shape, axis=1)
|
||||
def clip_gradients(policy, optimizer, loss):
|
||||
grads = tf.gradients(loss, policy.var_list)
|
||||
policy.grads, _ = tf.clip_by_global_norm(grads, policy.config["grad_clip"])
|
||||
clipped_grads = list(zip(policy.grads, policy.var_list))
|
||||
return clipped_grads
|
||||
|
||||
dist_inputs = unpacked_outputs if is_multidiscrete else \
|
||||
self.model.outputs
|
||||
prev_dist_inputs = unpacked_behaviour_logits if is_multidiscrete else \
|
||||
behaviour_logits
|
||||
|
||||
action_dist = dist_class(dist_inputs)
|
||||
prev_action_dist = dist_class(prev_dist_inputs)
|
||||
|
||||
values = self.model.value_function()
|
||||
self.value_function = values
|
||||
class ValueNetworkMixin(object):
|
||||
def __init__(self):
|
||||
self.value_function = self.model.value_function()
|
||||
self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
|
||||
tf.get_variable_scope().name)
|
||||
|
||||
def make_time_major(tensor, drop_last=False):
|
||||
"""Swaps batch and trajectory axis.
|
||||
Args:
|
||||
tensor: A tensor or list of tensors to reshape.
|
||||
drop_last: A bool indicating whether to drop the last
|
||||
trajectory item.
|
||||
Returns:
|
||||
res: A tensor with swapped axes or a list of tensors with
|
||||
swapped axes.
|
||||
"""
|
||||
if isinstance(tensor, list):
|
||||
return [make_time_major(t, drop_last) for t in tensor]
|
||||
|
||||
if self.model.state_init:
|
||||
B = tf.shape(self.model.seq_lens)[0]
|
||||
T = tf.shape(tensor)[0] // B
|
||||
else:
|
||||
# Important: chop the tensor into batches at known episode cut
|
||||
# boundaries. TODO(ekl) this is kind of a hack
|
||||
T = self.config["sample_batch_size"]
|
||||
B = tf.shape(tensor)[0] // T
|
||||
rs = tf.reshape(tensor,
|
||||
tf.concat([[B, T], tf.shape(tensor)[1:]], axis=0))
|
||||
|
||||
# swap B and T axes
|
||||
res = tf.transpose(
|
||||
rs,
|
||||
[1, 0] + list(range(2, 1 + int(tf.shape(tensor).shape[0]))))
|
||||
|
||||
if drop_last:
|
||||
return res[:-1]
|
||||
return res
|
||||
|
||||
if self.model.state_in:
|
||||
max_seq_len = tf.reduce_max(self.model.seq_lens) - 1
|
||||
mask = tf.sequence_mask(self.model.seq_lens, max_seq_len)
|
||||
mask = tf.reshape(mask, [-1])
|
||||
else:
|
||||
mask = tf.ones_like(rewards)
|
||||
|
||||
# Inputs are reshaped from [B * T] => [T - 1, B] for V-trace calc.
|
||||
if self.config["vtrace"]:
|
||||
logger.info("Using V-Trace surrogate loss (vtrace=True)")
|
||||
|
||||
# Prepare actions for loss
|
||||
loss_actions = actions if is_multidiscrete else tf.expand_dims(
|
||||
actions, axis=1)
|
||||
|
||||
self.loss = VTraceSurrogateLoss(
|
||||
actions=make_time_major(loss_actions, drop_last=True),
|
||||
prev_actions_logp=make_time_major(
|
||||
prev_action_dist.logp(actions), drop_last=True),
|
||||
actions_logp=make_time_major(
|
||||
action_dist.logp(actions), drop_last=True),
|
||||
action_kl=prev_action_dist.kl(action_dist),
|
||||
actions_entropy=make_time_major(
|
||||
action_dist.entropy(), drop_last=True),
|
||||
dones=make_time_major(dones, drop_last=True),
|
||||
behaviour_logits=make_time_major(
|
||||
unpacked_behaviour_logits, drop_last=True),
|
||||
target_logits=make_time_major(
|
||||
unpacked_outputs, drop_last=True),
|
||||
discount=config["gamma"],
|
||||
rewards=make_time_major(rewards, drop_last=True),
|
||||
values=make_time_major(values, drop_last=True),
|
||||
bootstrap_value=make_time_major(values)[-1],
|
||||
dist_class=dist_class,
|
||||
valid_mask=make_time_major(mask, drop_last=True),
|
||||
vf_loss_coeff=self.config["vf_loss_coeff"],
|
||||
entropy_coeff=self.config["entropy_coeff"],
|
||||
clip_rho_threshold=self.config["vtrace_clip_rho_threshold"],
|
||||
clip_pg_rho_threshold=self.config[
|
||||
"vtrace_clip_pg_rho_threshold"],
|
||||
clip_param=self.config["clip_param"])
|
||||
else:
|
||||
logger.info("Using PPO surrogate loss (vtrace=False)")
|
||||
self.loss = PPOSurrogateLoss(
|
||||
prev_actions_logp=make_time_major(
|
||||
prev_action_dist.logp(actions)),
|
||||
actions_logp=make_time_major(action_dist.logp(actions)),
|
||||
action_kl=prev_action_dist.kl(action_dist),
|
||||
actions_entropy=make_time_major(action_dist.entropy()),
|
||||
values=make_time_major(values),
|
||||
valid_mask=make_time_major(mask),
|
||||
advantages=make_time_major(adv_ph),
|
||||
value_targets=make_time_major(value_targets),
|
||||
vf_loss_coeff=self.config["vf_loss_coeff"],
|
||||
entropy_coeff=self.config["entropy_coeff"],
|
||||
clip_param=self.config["clip_param"])
|
||||
|
||||
# KL divergence between worker and learner logits for debugging
|
||||
model_dist = MultiCategorical(unpacked_outputs)
|
||||
behaviour_dist = MultiCategorical(unpacked_behaviour_logits)
|
||||
|
||||
kls = model_dist.kl(behaviour_dist)
|
||||
if len(kls) > 1:
|
||||
self.KL_stats = {}
|
||||
|
||||
for i, kl in enumerate(kls):
|
||||
self.KL_stats.update({
|
||||
"mean_KL_{}".format(i): tf.reduce_mean(kl),
|
||||
"max_KL_{}".format(i): tf.reduce_max(kl),
|
||||
})
|
||||
else:
|
||||
self.KL_stats = {
|
||||
"mean_KL": tf.reduce_mean(kls[0]),
|
||||
"max_KL": tf.reduce_max(kls[0]),
|
||||
}
|
||||
|
||||
# Initialize TFPolicyGraph
|
||||
loss_in = [
|
||||
("actions", actions),
|
||||
("dones", dones),
|
||||
("behaviour_logits", behaviour_logits),
|
||||
("rewards", rewards),
|
||||
("obs", observations),
|
||||
("prev_actions", prev_actions),
|
||||
("prev_rewards", prev_rewards),
|
||||
]
|
||||
if not self.config["vtrace"]:
|
||||
loss_in.append(("advantages", adv_ph))
|
||||
loss_in.append(("value_targets", value_targets))
|
||||
LearningRateSchedule.__init__(self, self.config["lr"],
|
||||
self.config["lr_schedule"])
|
||||
TFPolicyGraph.__init__(
|
||||
self,
|
||||
observation_space,
|
||||
action_space,
|
||||
self.sess,
|
||||
obs_input=observations,
|
||||
action_sampler=action_dist.sample(),
|
||||
action_prob=action_dist.sampled_action_prob(),
|
||||
loss=self.loss.total_loss,
|
||||
model=self.model,
|
||||
loss_inputs=loss_in,
|
||||
state_inputs=self.model.state_in,
|
||||
state_outputs=self.model.state_out,
|
||||
prev_action_input=prev_actions,
|
||||
prev_reward_input=prev_rewards,
|
||||
seq_lens=self.model.seq_lens,
|
||||
max_seq_len=self.config["model"]["max_seq_len"],
|
||||
batch_divisibility_req=self.config["sample_batch_size"])
|
||||
|
||||
self.sess.run(tf.global_variables_initializer())
|
||||
|
||||
values_batched = make_time_major(
|
||||
values, drop_last=self.config["vtrace"])
|
||||
self.stats_fetches = {
|
||||
LEARNER_STATS_KEY: dict({
|
||||
"cur_lr": tf.cast(self.cur_lr, tf.float64),
|
||||
"policy_loss": self.loss.pi_loss,
|
||||
"entropy": self.loss.entropy,
|
||||
"grad_gnorm": tf.global_norm(self._grads),
|
||||
"var_gnorm": tf.global_norm(self.var_list),
|
||||
"vf_loss": self.loss.vf_loss,
|
||||
"vf_explained_var": explained_variance(
|
||||
tf.reshape(self.loss.value_targets, [-1]),
|
||||
tf.reshape(values_batched, [-1])),
|
||||
}, **self.KL_stats),
|
||||
}
|
||||
|
||||
def optimizer(self):
|
||||
if self.config["opt_type"] == "adam":
|
||||
return tf.train.AdamOptimizer(self.cur_lr)
|
||||
else:
|
||||
return tf.train.RMSPropOptimizer(self.cur_lr, self.config["decay"],
|
||||
self.config["momentum"],
|
||||
self.config["epsilon"])
|
||||
|
||||
def gradients(self, optimizer, loss):
|
||||
grads = tf.gradients(loss, self.var_list)
|
||||
self.grads, _ = tf.clip_by_global_norm(grads, self.config["grad_clip"])
|
||||
clipped_grads = list(zip(self.grads, self.var_list))
|
||||
return clipped_grads
|
||||
|
||||
def extra_compute_grad_fetches(self):
|
||||
return self.stats_fetches
|
||||
|
||||
def value(self, ob, *args):
|
||||
feed_dict = {self.observations: [ob], self.model.seq_lens: [1]}
|
||||
feed_dict = {self._obs_input: [ob], self.model.seq_lens: [1]}
|
||||
assert len(args) == len(self.model.state_in), \
|
||||
(args, self.model.state_in)
|
||||
for k, v in zip(self.model.state_in, args):
|
||||
feed_dict[k] = v
|
||||
vf = self.sess.run(self.value_function, feed_dict)
|
||||
vf = self._sess.run(self.value_function, feed_dict)
|
||||
return vf[0]
|
||||
|
||||
def get_initial_state(self):
|
||||
return self.model.state_init
|
||||
|
||||
def copy(self, existing_inputs):
|
||||
return AsyncPPOPolicyGraph(
|
||||
self.observation_space,
|
||||
self.action_space,
|
||||
self.config,
|
||||
existing_inputs=existing_inputs)
|
||||
def setup_mixins(policy, obs_space, action_space, config):
|
||||
LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"])
|
||||
ValueNetworkMixin.__init__(policy)
|
||||
|
||||
|
||||
AsyncPPOTFPolicy = build_tf_policy(
|
||||
name="AsyncPPOTFPolicy",
|
||||
get_default_config=lambda: ray.rllib.agents.impala.impala.DEFAULT_CONFIG,
|
||||
loss_fn=build_appo_surrogate_loss,
|
||||
stats_fn=stats,
|
||||
grad_stats_fn=grad_stats,
|
||||
postprocess_fn=postprocess_trajectory,
|
||||
optimizer_fn=choose_optimizer,
|
||||
gradients_fn=clip_gradients,
|
||||
extra_action_fetches_fn=add_values_and_logits,
|
||||
before_init=validate_config,
|
||||
before_loss_init=setup_mixins,
|
||||
mixins=[LearningRateSchedule, ValueNetworkMixin],
|
||||
get_batch_divisibility_req=lambda p: p.config["sample_batch_size"])
|
||||
|
||||
@@ -4,10 +4,10 @@ from __future__ import print_function
|
||||
|
||||
import logging
|
||||
|
||||
from ray.rllib.agents import Trainer, with_common_config
|
||||
from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph
|
||||
from ray.rllib.agents import with_common_config
|
||||
from ray.rllib.agents.ppo.ppo_policy_graph import PPOTFPolicy
|
||||
from ray.rllib.agents.trainer_template import build_trainer
|
||||
from ray.rllib.optimizers import SyncSamplesOptimizer, LocalMultiGPUOptimizer
|
||||
from ray.rllib.utils.annotations import override
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -63,110 +63,104 @@ DEFAULT_CONFIG = with_common_config({
|
||||
# yapf: enable
|
||||
|
||||
|
||||
class PPOTrainer(Trainer):
|
||||
"""Multi-GPU optimized implementation of PPO in TensorFlow."""
|
||||
def make_optimizer(local_evaluator, remote_evaluators, config):
|
||||
if config["simple_optimizer"]:
|
||||
return SyncSamplesOptimizer(
|
||||
local_evaluator,
|
||||
remote_evaluators,
|
||||
num_sgd_iter=config["num_sgd_iter"],
|
||||
train_batch_size=config["train_batch_size"])
|
||||
|
||||
_name = "PPO"
|
||||
_default_config = DEFAULT_CONFIG
|
||||
_policy_graph = PPOPolicyGraph
|
||||
return LocalMultiGPUOptimizer(
|
||||
local_evaluator,
|
||||
remote_evaluators,
|
||||
sgd_batch_size=config["sgd_minibatch_size"],
|
||||
num_sgd_iter=config["num_sgd_iter"],
|
||||
num_gpus=config["num_gpus"],
|
||||
sample_batch_size=config["sample_batch_size"],
|
||||
num_envs_per_worker=config["num_envs_per_worker"],
|
||||
train_batch_size=config["train_batch_size"],
|
||||
standardize_fields=["advantages"],
|
||||
straggler_mitigation=config["straggler_mitigation"])
|
||||
|
||||
@override(Trainer)
|
||||
def _init(self, config, env_creator):
|
||||
self._validate_config()
|
||||
self.local_evaluator = self.make_local_evaluator(
|
||||
env_creator, self._policy_graph)
|
||||
self.remote_evaluators = self.make_remote_evaluators(
|
||||
env_creator, self._policy_graph, config["num_workers"])
|
||||
if config["simple_optimizer"]:
|
||||
self.optimizer = SyncSamplesOptimizer(
|
||||
self.local_evaluator,
|
||||
self.remote_evaluators,
|
||||
num_sgd_iter=config["num_sgd_iter"],
|
||||
train_batch_size=config["train_batch_size"])
|
||||
else:
|
||||
self.optimizer = LocalMultiGPUOptimizer(
|
||||
self.local_evaluator,
|
||||
self.remote_evaluators,
|
||||
sgd_batch_size=config["sgd_minibatch_size"],
|
||||
num_sgd_iter=config["num_sgd_iter"],
|
||||
num_gpus=config["num_gpus"],
|
||||
sample_batch_size=config["sample_batch_size"],
|
||||
num_envs_per_worker=config["num_envs_per_worker"],
|
||||
train_batch_size=config["train_batch_size"],
|
||||
standardize_fields=["advantages"],
|
||||
straggler_mitigation=config["straggler_mitigation"])
|
||||
|
||||
@override(Trainer)
|
||||
def _train(self):
|
||||
if "observation_filter" not in self.raw_user_config:
|
||||
# TODO(ekl) remove this message after a few releases
|
||||
logger.info(
|
||||
"Important! Since 0.7.0, observation normalization is no "
|
||||
"longer enabled by default. To enable running-mean "
|
||||
"normalization, set 'observation_filter': 'MeanStdFilter'. "
|
||||
"You can ignore this message if your environment doesn't "
|
||||
"require observation normalization.")
|
||||
prev_steps = self.optimizer.num_steps_sampled
|
||||
fetches = self.optimizer.step()
|
||||
if "kl" in fetches:
|
||||
# single-agent
|
||||
self.local_evaluator.for_policy(
|
||||
lambda pi: pi.update_kl(fetches["kl"]))
|
||||
else:
|
||||
def update_kl(trainer, fetches):
|
||||
if "kl" in fetches:
|
||||
# single-agent
|
||||
trainer.local_evaluator.for_policy(
|
||||
lambda pi: pi.update_kl(fetches["kl"]))
|
||||
else:
|
||||
|
||||
def update(pi, pi_id):
|
||||
if pi_id in fetches:
|
||||
pi.update_kl(fetches[pi_id]["kl"])
|
||||
else:
|
||||
logger.debug(
|
||||
"No data for {}, not updating kl".format(pi_id))
|
||||
def update(pi, pi_id):
|
||||
if pi_id in fetches:
|
||||
pi.update_kl(fetches[pi_id]["kl"])
|
||||
else:
|
||||
logger.debug("No data for {}, not updating kl".format(pi_id))
|
||||
|
||||
# multi-agent
|
||||
self.local_evaluator.foreach_trainable_policy(update)
|
||||
res = self.collect_metrics()
|
||||
res.update(
|
||||
timesteps_this_iter=self.optimizer.num_steps_sampled - prev_steps,
|
||||
info=res.get("info", {}))
|
||||
# multi-agent
|
||||
trainer.local_evaluator.foreach_trainable_policy(update)
|
||||
|
||||
# Warn about bad clipping configs
|
||||
if self.config["vf_clip_param"] <= 0:
|
||||
rew_scale = float("inf")
|
||||
elif res["policy_reward_mean"]:
|
||||
rew_scale = 0 # punt on handling multiagent case
|
||||
else:
|
||||
rew_scale = round(
|
||||
abs(res["episode_reward_mean"]) / self.config["vf_clip_param"],
|
||||
0)
|
||||
if rew_scale > 200:
|
||||
logger.warning(
|
||||
"The magnitude of your environment rewards are more than "
|
||||
"{}x the scale of `vf_clip_param`. ".format(rew_scale) +
|
||||
"This means that it will take more than "
|
||||
"{} iterations for your value ".format(rew_scale) +
|
||||
"function to converge. If this is not intended, consider "
|
||||
"increasing `vf_clip_param`.")
|
||||
return res
|
||||
|
||||
def _validate_config(self):
|
||||
if self.config["entropy_coeff"] < 0:
|
||||
raise DeprecationWarning("entropy_coeff must be >= 0")
|
||||
if self.config["sgd_minibatch_size"] > self.config["train_batch_size"]:
|
||||
raise ValueError(
|
||||
"Minibatch size {} must be <= train batch size {}.".format(
|
||||
self.config["sgd_minibatch_size"],
|
||||
self.config["train_batch_size"]))
|
||||
if (self.config["batch_mode"] == "truncate_episodes"
|
||||
and not self.config["use_gae"]):
|
||||
raise ValueError(
|
||||
"Episode truncation is not supported without a value "
|
||||
"function. Consider setting batch_mode=complete_episodes.")
|
||||
if (self.config["multiagent"]["policy_graphs"]
|
||||
and not self.config["simple_optimizer"]):
|
||||
logger.info(
|
||||
"In multi-agent mode, policies will be optimized sequentially "
|
||||
"by the multi-GPU optimizer. Consider setting "
|
||||
"simple_optimizer=True if this doesn't work for you.")
|
||||
if not self.config["vf_share_layers"]:
|
||||
logger.warning(
|
||||
"FYI: By default, the value function will not share layers "
|
||||
"with the policy model ('vf_share_layers': False).")
|
||||
def warn_about_obs_filter(trainer):
|
||||
if "observation_filter" not in trainer.raw_user_config:
|
||||
# TODO(ekl) remove this message after a few releases
|
||||
logger.info(
|
||||
"Important! Since 0.7.0, observation normalization is no "
|
||||
"longer enabled by default. To enable running-mean "
|
||||
"normalization, set 'observation_filter': 'MeanStdFilter'. "
|
||||
"You can ignore this message if your environment doesn't "
|
||||
"require observation normalization.")
|
||||
|
||||
|
||||
def warn_about_bad_reward_scales(trainer, result):
|
||||
# Warn about bad clipping configs
|
||||
if trainer.config["vf_clip_param"] <= 0:
|
||||
rew_scale = float("inf")
|
||||
elif result["policy_reward_mean"]:
|
||||
rew_scale = 0 # punt on handling multiagent case
|
||||
else:
|
||||
rew_scale = round(
|
||||
abs(result["episode_reward_mean"]) /
|
||||
trainer.config["vf_clip_param"], 0)
|
||||
if rew_scale > 200:
|
||||
logger.warning(
|
||||
"The magnitude of your environment rewards are more than "
|
||||
"{}x the scale of `vf_clip_param`. ".format(rew_scale) +
|
||||
"This means that it will take more than "
|
||||
"{} iterations for your value ".format(rew_scale) +
|
||||
"function to converge. If this is not intended, consider "
|
||||
"increasing `vf_clip_param`.")
|
||||
|
||||
|
||||
def validate_config(config):
|
||||
if config["entropy_coeff"] < 0:
|
||||
raise DeprecationWarning("entropy_coeff must be >= 0")
|
||||
if config["sgd_minibatch_size"] > config["train_batch_size"]:
|
||||
raise ValueError(
|
||||
"Minibatch size {} must be <= train batch size {}.".format(
|
||||
config["sgd_minibatch_size"], config["train_batch_size"]))
|
||||
if (config["batch_mode"] == "truncate_episodes" and not config["use_gae"]):
|
||||
raise ValueError(
|
||||
"Episode truncation is not supported without a value "
|
||||
"function. Consider setting batch_mode=complete_episodes.")
|
||||
if (config["multiagent"]["policy_graphs"]
|
||||
and not config["simple_optimizer"]):
|
||||
logger.info(
|
||||
"In multi-agent mode, policies will be optimized sequentially "
|
||||
"by the multi-GPU optimizer. Consider setting "
|
||||
"simple_optimizer=True if this doesn't work for you.")
|
||||
if not config["vf_share_layers"]:
|
||||
logger.warning(
|
||||
"FYI: By default, the value function will not share layers "
|
||||
"with the policy model ('vf_share_layers': False).")
|
||||
|
||||
|
||||
PPOTrainer = build_trainer(
|
||||
name="PPO",
|
||||
default_config=DEFAULT_CONFIG,
|
||||
default_policy=PPOTFPolicy,
|
||||
make_policy_optimizer=make_optimizer,
|
||||
validate_config=validate_config,
|
||||
after_optimizer_step=update_kl,
|
||||
before_train_step=warn_about_obs_filter,
|
||||
after_train_result=warn_about_bad_reward_scales)
|
||||
|
||||
@@ -7,13 +7,10 @@ import logging
|
||||
import ray
|
||||
from ray.rllib.evaluation.postprocessing import compute_advantages, \
|
||||
Postprocessing
|
||||
from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
|
||||
from ray.rllib.evaluation.policy_graph import PolicyGraph
|
||||
from ray.rllib.evaluation.sample_batch import SampleBatch
|
||||
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph, \
|
||||
LearningRateSchedule
|
||||
from ray.rllib.evaluation.tf_policy_graph import LearningRateSchedule
|
||||
from ray.rllib.evaluation.tf_policy_template import build_tf_policy
|
||||
from ray.rllib.models.catalog import ModelCatalog
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.explained_variance import explained_variance
|
||||
from ray.rllib.utils import try_import_tf
|
||||
|
||||
@@ -107,119 +104,106 @@ class PPOLoss(object):
|
||||
self.loss = loss
|
||||
|
||||
|
||||
class PPOPostprocessing(object):
|
||||
def ppo_surrogate_loss(policy, batch_tensors):
|
||||
if policy.model.state_in:
|
||||
max_seq_len = tf.reduce_max(policy.model.seq_lens)
|
||||
mask = tf.sequence_mask(policy.model.seq_lens, max_seq_len)
|
||||
mask = tf.reshape(mask, [-1])
|
||||
else:
|
||||
mask = tf.ones_like(
|
||||
batch_tensors[Postprocessing.ADVANTAGES], dtype=tf.bool)
|
||||
|
||||
policy.loss_obj = PPOLoss(
|
||||
policy.action_space,
|
||||
batch_tensors[Postprocessing.VALUE_TARGETS],
|
||||
batch_tensors[Postprocessing.ADVANTAGES],
|
||||
batch_tensors[SampleBatch.ACTIONS],
|
||||
batch_tensors[BEHAVIOUR_LOGITS],
|
||||
batch_tensors[SampleBatch.VF_PREDS],
|
||||
policy.action_dist,
|
||||
policy.value_function,
|
||||
policy.kl_coeff,
|
||||
mask,
|
||||
entropy_coeff=policy.config["entropy_coeff"],
|
||||
clip_param=policy.config["clip_param"],
|
||||
vf_clip_param=policy.config["vf_clip_param"],
|
||||
vf_loss_coeff=policy.config["vf_loss_coeff"],
|
||||
use_gae=policy.config["use_gae"])
|
||||
|
||||
return policy.loss_obj.loss
|
||||
|
||||
|
||||
def kl_and_loss_stats(policy, batch_tensors):
|
||||
policy.explained_variance = explained_variance(
|
||||
batch_tensors[Postprocessing.VALUE_TARGETS], policy.value_function)
|
||||
|
||||
stats_fetches = {
|
||||
"cur_kl_coeff": policy.kl_coeff,
|
||||
"cur_lr": tf.cast(policy.cur_lr, tf.float64),
|
||||
"total_loss": policy.loss_obj.loss,
|
||||
"policy_loss": policy.loss_obj.mean_policy_loss,
|
||||
"vf_loss": policy.loss_obj.mean_vf_loss,
|
||||
"vf_explained_var": policy.explained_variance,
|
||||
"kl": policy.loss_obj.mean_kl,
|
||||
"entropy": policy.loss_obj.mean_entropy,
|
||||
}
|
||||
|
||||
return stats_fetches
|
||||
|
||||
|
||||
def vf_preds_and_logits_fetches(policy):
|
||||
"""Adds value function and logits outputs to experience batches."""
|
||||
return {
|
||||
SampleBatch.VF_PREDS: policy.value_function,
|
||||
BEHAVIOUR_LOGITS: policy.model.outputs,
|
||||
}
|
||||
|
||||
|
||||
def postprocess_ppo_gae(policy,
|
||||
sample_batch,
|
||||
other_agent_batches=None,
|
||||
episode=None):
|
||||
"""Adds the policy logits, VF preds, and advantages to the trajectory."""
|
||||
|
||||
@override(TFPolicyGraph)
|
||||
def extra_compute_action_fetches(self):
|
||||
return dict(
|
||||
TFPolicyGraph.extra_compute_action_fetches(self), **{
|
||||
SampleBatch.VF_PREDS: self.value_function,
|
||||
BEHAVIOUR_LOGITS: self.logits
|
||||
})
|
||||
|
||||
@override(PolicyGraph)
|
||||
def postprocess_trajectory(self,
|
||||
sample_batch,
|
||||
other_agent_batches=None,
|
||||
episode=None):
|
||||
completed = sample_batch["dones"][-1]
|
||||
if completed:
|
||||
last_r = 0.0
|
||||
else:
|
||||
next_state = []
|
||||
for i in range(len(self.model.state_in)):
|
||||
next_state.append([sample_batch["state_out_{}".format(i)][-1]])
|
||||
last_r = self._value(sample_batch[SampleBatch.NEXT_OBS][-1],
|
||||
sample_batch[SampleBatch.ACTIONS][-1],
|
||||
sample_batch[SampleBatch.REWARDS][-1],
|
||||
*next_state)
|
||||
batch = compute_advantages(
|
||||
sample_batch,
|
||||
last_r,
|
||||
self.config["gamma"],
|
||||
self.config["lambda"],
|
||||
use_gae=self.config["use_gae"])
|
||||
return batch
|
||||
completed = sample_batch["dones"][-1]
|
||||
if completed:
|
||||
last_r = 0.0
|
||||
else:
|
||||
next_state = []
|
||||
for i in range(len(policy.model.state_in)):
|
||||
next_state.append([sample_batch["state_out_{}".format(i)][-1]])
|
||||
last_r = policy._value(sample_batch[SampleBatch.NEXT_OBS][-1],
|
||||
sample_batch[SampleBatch.ACTIONS][-1],
|
||||
sample_batch[SampleBatch.REWARDS][-1],
|
||||
*next_state)
|
||||
batch = compute_advantages(
|
||||
sample_batch,
|
||||
last_r,
|
||||
policy.config["gamma"],
|
||||
policy.config["lambda"],
|
||||
use_gae=policy.config["use_gae"])
|
||||
return batch
|
||||
|
||||
|
||||
class PPOPolicyGraph(LearningRateSchedule, PPOPostprocessing, TFPolicyGraph):
|
||||
def __init__(self,
|
||||
observation_space,
|
||||
action_space,
|
||||
config,
|
||||
existing_inputs=None):
|
||||
"""
|
||||
Arguments:
|
||||
observation_space: Environment observation space specification.
|
||||
action_space: Environment action space specification.
|
||||
config (dict): Configuration values for PPO graph.
|
||||
existing_inputs (list): Optional list of tuples that specify the
|
||||
placeholders upon which the graph should be built upon.
|
||||
"""
|
||||
config = dict(ray.rllib.agents.ppo.ppo.DEFAULT_CONFIG, **config)
|
||||
self.sess = tf.get_default_session()
|
||||
self.action_space = action_space
|
||||
self.config = config
|
||||
self.kl_coeff_val = self.config["kl_coeff"]
|
||||
self.kl_target = self.config["kl_target"]
|
||||
dist_cls, logit_dim = ModelCatalog.get_action_dist(
|
||||
action_space, self.config["model"])
|
||||
def clip_gradients(policy, optimizer, loss):
|
||||
if policy.config["grad_clip"] is not None:
|
||||
policy.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
|
||||
tf.get_variable_scope().name)
|
||||
grads = tf.gradients(loss, policy.var_list)
|
||||
policy.grads, _ = tf.clip_by_global_norm(grads,
|
||||
policy.config["grad_clip"])
|
||||
clipped_grads = list(zip(policy.grads, policy.var_list))
|
||||
return clipped_grads
|
||||
else:
|
||||
return optimizer.compute_gradients(
|
||||
loss, colocate_gradients_with_ops=True)
|
||||
|
||||
if existing_inputs:
|
||||
obs_ph, value_targets_ph, adv_ph, act_ph, \
|
||||
logits_ph, vf_preds_ph, prev_actions_ph, prev_rewards_ph = \
|
||||
existing_inputs[:8]
|
||||
existing_state_in = existing_inputs[8:-1]
|
||||
existing_seq_lens = existing_inputs[-1]
|
||||
else:
|
||||
obs_ph = tf.placeholder(
|
||||
tf.float32,
|
||||
name="obs",
|
||||
shape=(None, ) + observation_space.shape)
|
||||
adv_ph = tf.placeholder(
|
||||
tf.float32, name="advantages", shape=(None, ))
|
||||
act_ph = ModelCatalog.get_action_placeholder(action_space)
|
||||
logits_ph = tf.placeholder(
|
||||
tf.float32, name="logits", shape=(None, logit_dim))
|
||||
vf_preds_ph = tf.placeholder(
|
||||
tf.float32, name="vf_preds", shape=(None, ))
|
||||
value_targets_ph = tf.placeholder(
|
||||
tf.float32, name="value_targets", shape=(None, ))
|
||||
prev_actions_ph = ModelCatalog.get_action_placeholder(action_space)
|
||||
prev_rewards_ph = tf.placeholder(
|
||||
tf.float32, [None], name="prev_reward")
|
||||
existing_state_in = None
|
||||
existing_seq_lens = None
|
||||
self.observations = obs_ph
|
||||
self.prev_actions = prev_actions_ph
|
||||
self.prev_rewards = prev_rewards_ph
|
||||
|
||||
self.loss_in = [
|
||||
(SampleBatch.CUR_OBS, obs_ph),
|
||||
(Postprocessing.VALUE_TARGETS, value_targets_ph),
|
||||
(Postprocessing.ADVANTAGES, adv_ph),
|
||||
(SampleBatch.ACTIONS, act_ph),
|
||||
(BEHAVIOUR_LOGITS, logits_ph),
|
||||
(SampleBatch.VF_PREDS, vf_preds_ph),
|
||||
(SampleBatch.PREV_ACTIONS, prev_actions_ph),
|
||||
(SampleBatch.PREV_REWARDS, prev_rewards_ph),
|
||||
]
|
||||
self.model = ModelCatalog.get_model(
|
||||
{
|
||||
"obs": obs_ph,
|
||||
"prev_actions": prev_actions_ph,
|
||||
"prev_rewards": prev_rewards_ph,
|
||||
"is_training": self._get_is_training_placeholder(),
|
||||
},
|
||||
observation_space,
|
||||
action_space,
|
||||
logit_dim,
|
||||
self.config["model"],
|
||||
state_in=existing_state_in,
|
||||
seq_lens=existing_seq_lens)
|
||||
|
||||
class KLCoeffMixin(object):
|
||||
def __init__(self, config):
|
||||
# KL Coefficient
|
||||
self.kl_coeff_val = config["kl_coeff"]
|
||||
self.kl_target = config["kl_target"]
|
||||
self.kl_coeff = tf.get_variable(
|
||||
initializer=tf.constant_initializer(self.kl_coeff_val),
|
||||
name="kl_coeff",
|
||||
@@ -227,14 +211,22 @@ class PPOPolicyGraph(LearningRateSchedule, PPOPostprocessing, TFPolicyGraph):
|
||||
trainable=False,
|
||||
dtype=tf.float32)
|
||||
|
||||
self.logits = self.model.outputs
|
||||
curr_action_dist = dist_cls(self.logits)
|
||||
self.sampler = curr_action_dist.sample()
|
||||
if self.config["use_gae"]:
|
||||
if self.config["vf_share_layers"]:
|
||||
def update_kl(self, sampled_kl):
|
||||
if sampled_kl > 2.0 * self.kl_target:
|
||||
self.kl_coeff_val *= 1.5
|
||||
elif sampled_kl < 0.5 * self.kl_target:
|
||||
self.kl_coeff_val *= 0.5
|
||||
self.kl_coeff.load(self.kl_coeff_val, session=self._sess)
|
||||
return self.kl_coeff_val
|
||||
|
||||
|
||||
class ValueNetworkMixin(object):
|
||||
def __init__(self, obs_space, action_space, config):
|
||||
if config["use_gae"]:
|
||||
if config["vf_share_layers"]:
|
||||
self.value_function = self.model.value_function()
|
||||
else:
|
||||
vf_config = self.config["model"].copy()
|
||||
vf_config = config["model"].copy()
|
||||
# Do not split the last layer of the value function into
|
||||
# mean parameters and standard deviation parameters and
|
||||
# do not make the standard deviations free variables.
|
||||
@@ -249,122 +241,43 @@ class PPOPolicyGraph(LearningRateSchedule, PPOPostprocessing, TFPolicyGraph):
|
||||
"value_function() method.")
|
||||
with tf.variable_scope("value_function"):
|
||||
self.value_function = ModelCatalog.get_model({
|
||||
"obs": obs_ph,
|
||||
"prev_actions": prev_actions_ph,
|
||||
"prev_rewards": prev_rewards_ph,
|
||||
"obs": self._obs_input,
|
||||
"prev_actions": self._prev_action_input,
|
||||
"prev_rewards": self._prev_reward_input,
|
||||
"is_training": self._get_is_training_placeholder(),
|
||||
}, observation_space, action_space, 1, vf_config).outputs
|
||||
}, obs_space, action_space, 1, vf_config).outputs
|
||||
self.value_function = tf.reshape(self.value_function, [-1])
|
||||
else:
|
||||
self.value_function = tf.zeros(shape=tf.shape(obs_ph)[:1])
|
||||
|
||||
if self.model.state_in:
|
||||
max_seq_len = tf.reduce_max(self.model.seq_lens)
|
||||
mask = tf.sequence_mask(self.model.seq_lens, max_seq_len)
|
||||
mask = tf.reshape(mask, [-1])
|
||||
else:
|
||||
mask = tf.ones_like(adv_ph, dtype=tf.bool)
|
||||
|
||||
self.loss_obj = PPOLoss(
|
||||
action_space,
|
||||
value_targets_ph,
|
||||
adv_ph,
|
||||
act_ph,
|
||||
logits_ph,
|
||||
vf_preds_ph,
|
||||
curr_action_dist,
|
||||
self.value_function,
|
||||
self.kl_coeff,
|
||||
mask,
|
||||
entropy_coeff=self.config["entropy_coeff"],
|
||||
clip_param=self.config["clip_param"],
|
||||
vf_clip_param=self.config["vf_clip_param"],
|
||||
vf_loss_coeff=self.config["vf_loss_coeff"],
|
||||
use_gae=self.config["use_gae"])
|
||||
|
||||
LearningRateSchedule.__init__(self, self.config["lr"],
|
||||
self.config["lr_schedule"])
|
||||
TFPolicyGraph.__init__(
|
||||
self,
|
||||
observation_space,
|
||||
action_space,
|
||||
self.sess,
|
||||
obs_input=obs_ph,
|
||||
action_sampler=self.sampler,
|
||||
action_prob=curr_action_dist.sampled_action_prob(),
|
||||
loss=self.loss_obj.loss,
|
||||
model=self.model,
|
||||
loss_inputs=self.loss_in,
|
||||
state_inputs=self.model.state_in,
|
||||
state_outputs=self.model.state_out,
|
||||
prev_action_input=prev_actions_ph,
|
||||
prev_reward_input=prev_rewards_ph,
|
||||
seq_lens=self.model.seq_lens,
|
||||
max_seq_len=config["model"]["max_seq_len"])
|
||||
|
||||
self.sess.run(tf.global_variables_initializer())
|
||||
self.explained_variance = explained_variance(value_targets_ph,
|
||||
self.value_function)
|
||||
self.stats_fetches = {
|
||||
"cur_kl_coeff": self.kl_coeff,
|
||||
"cur_lr": tf.cast(self.cur_lr, tf.float64),
|
||||
"total_loss": self.loss_obj.loss,
|
||||
"policy_loss": self.loss_obj.mean_policy_loss,
|
||||
"vf_loss": self.loss_obj.mean_vf_loss,
|
||||
"vf_explained_var": self.explained_variance,
|
||||
"kl": self.loss_obj.mean_kl,
|
||||
"entropy": self.loss_obj.mean_entropy
|
||||
}
|
||||
|
||||
@override(TFPolicyGraph)
def copy(self, existing_inputs):
    """Build a new PPOPolicyGraph on top of already-created input tensors.

    The clone is constructed with this policy's observation space, action
    space and config; only the input placeholders come from the caller.

    Arguments:
        existing_inputs: placeholders to hand to the new policy graph
            instead of letting it define its own.

    Returns:
        the newly constructed PPOPolicyGraph.
    """
    clone_kwargs = {"existing_inputs": existing_inputs}
    return PPOPolicyGraph(self.observation_space, self.action_space,
                          self.config, **clone_kwargs)
|
||||
|
||||
@override(TFPolicyGraph)
def gradients(self, optimizer, loss):
    """Return (gradient, variable) pairs for `loss`.

    When config["grad_clip"] is None the optimizer computes the gradients
    itself. Otherwise gradients are taken w.r.t. the trainable variables of
    the current variable scope and clipped by global norm; the variables
    and clipped gradients are also stored on `self.var_list` / `self.grads`.
    """
    clip_norm = self.config["grad_clip"]
    if clip_norm is None:
        return optimizer.compute_gradients(
            loss, colocate_gradients_with_ops=True)

    scope_name = tf.get_variable_scope().name
    self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                      scope_name)
    raw_grads = tf.gradients(loss, self.var_list)
    self.grads, _ = tf.clip_by_global_norm(raw_grads, clip_norm)
    return list(zip(self.grads, self.var_list))
|
||||
|
||||
@override(PolicyGraph)
def get_initial_state(self):
    """Return the initial state exposed by the model (`state_init`)."""
    initial_state = self.model.state_init
    return initial_state
|
||||
|
||||
@override(TFPolicyGraph)
def extra_compute_grad_fetches(self):
    """Return the stats fetches keyed under LEARNER_STATS_KEY."""
    fetches = {}
    fetches[LEARNER_STATS_KEY] = self.stats_fetches
    return fetches
|
||||
|
||||
def update_kl(self, sampled_kl):
    """Adapt the KL coefficient to the most recently sampled KL.

    The coefficient is multiplied by 1.5 when the sampled KL exceeds twice
    the target, halved when it is below half the target, and left unchanged
    otherwise. The resulting value is always loaded into `self.kl_coeff`.

    Arguments:
        sampled_kl: the KL value to compare against `self.kl_target`.

    Returns:
        the (possibly updated) coefficient value.
    """
    if sampled_kl > 2.0 * self.kl_target:
        scale = 1.5
    elif sampled_kl < 0.5 * self.kl_target:
        scale = 0.5
    else:
        scale = None

    if scale is not None:
        self.kl_coeff_val *= scale

    # Push the Python-side value into the graph-side coefficient.
    self.kl_coeff.load(self.kl_coeff_val, session=self.sess)
    return self.kl_coeff_val
|
||||
self.value_function = tf.zeros(shape=tf.shape(self._obs_input)[:1])
|
||||
|
||||
def _value(self, ob, prev_action, prev_reward, *args):
    """Evaluate `self.value_function` for a single timestep.

    Arguments:
        ob: the observation to feed to the observation input.
        prev_action: value fed to the previous-action input.
        prev_reward: value fed to the previous-reward input.
        *args: one value per entry of `self.model.state_in`, in order.

    Returns:
        the first element of the fetched value function output.
    """
    # Only the underscore-prefixed inputs are fed: the legacy
    # `self.observations` / `self.prev_actions` / `self.prev_rewards`
    # entries and the extra `self.sess.run` call were leftovers that
    # duplicated these feeds and evaluated the graph twice.
    feed_dict = {
        self._obs_input: [ob],
        self._prev_action_input: [prev_action],
        self._prev_reward_input: [prev_reward],
        self.model.seq_lens: [1]
    }
    assert len(args) == len(self.model.state_in), \
        (args, self.model.state_in)
    for k, v in zip(self.model.state_in, args):
        feed_dict[k] = v
    vf = self._sess.run(self.value_function, feed_dict)
    return vf[0]
|
||||
|
||||
|
||||
def setup_mixins(policy, obs_space, action_space, config):
    """Run the mixin initializers on `policy`.

    Calls, in this order, the ValueNetworkMixin, KLCoeffMixin and
    LearningRateSchedule constructors with the given policy as `self`.
    """
    ValueNetworkMixin.__init__(policy, obs_space, action_space, config)
    KLCoeffMixin.__init__(policy, config)
    lr, lr_schedule = config["lr"], config["lr_schedule"]
    LearningRateSchedule.__init__(policy, lr, lr_schedule)
|
||||
|
||||
|
||||
# PPO policy class assembled from the module's loss, stats, postprocessing
# and gradient functions plus the three mixins initialized by setup_mixins.
PPOTFPolicy = build_tf_policy(
    name="PPOTFPolicy",
    get_default_config=lambda: ray.rllib.agents.ppo.ppo.DEFAULT_CONFIG,
    # graph construction hooks
    before_loss_init=setup_mixins,
    mixins=[LearningRateSchedule, KLCoeffMixin, ValueNetworkMixin],
    # loss, gradients and reporting
    loss_fn=ppo_surrogate_loss,
    gradients_fn=clip_gradients,
    stats_fn=kl_and_loss_stats,
    # sampling-side hooks
    extra_action_fetches_fn=vf_preds_and_logits_fetches,
    postprocess_fn=postprocess_ppo_gae)
|
||||
|
||||
@@ -0,0 +1,97 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from ray.rllib.agents.trainer import Trainer
|
||||
from ray.rllib.optimizers import SyncSamplesOptimizer
|
||||
from ray.rllib.utils.annotations import override, DeveloperAPI
|
||||
|
||||
|
||||
@DeveloperAPI
def build_trainer(name,
                  default_policy,
                  default_config=None,
                  make_policy_optimizer=None,
                  validate_config=None,
                  get_policy_class=None,
                  before_train_step=None,
                  after_optimizer_step=None,
                  after_train_result=None):
    """Assemble a custom Trainer subclass from plain functions.

    Arguments:
        name (str): name of the trainer (e.g., "PPO"); must not already end
            in "Trainer", since that suffix is appended to the class name
        default_policy (cls): the default PolicyGraph class to use
        default_config (dict): the default config dict of the algorithm;
            when falsy, Trainer.COMMON_CONFIG is used instead
        make_policy_optimizer (func): optional function that returns a
            PolicyOptimizer instance given
            (local_evaluator, remote_evaluators, config). When omitted, a
            SyncSamplesOptimizer is built from config["optimizer"] and
            config["train_batch_size"]
        validate_config (func): optional callback that checks a given config
            for correctness. It may mutate the config as needed.
        get_policy_class (func): optional callback that takes a config and
            returns the policy graph class to override the default with
        before_train_step (func): optional callback to run before each
            train() call. It takes the trainer instance as an argument.
        after_optimizer_step (func): optional callback to run after each
            step() call to the policy optimizer. It takes the trainer
            instance and the policy gradient fetches as arguments.
        after_train_result (func): optional callback to run at the end of
            each train() call. It takes the trainer instance and result dict
            as arguments, and may mutate the result dict as needed.

    Returns:
        the new Trainer subclass, named ``name + "Trainer"``.

    Raises:
        ValueError: if `name` ends with "Trainer".
    """

    if name.endswith("Trainer"):
        raise ValueError("Algorithm name should not include *Trainer suffix",
                         name)

    class trainer_cls(Trainer):
        _name = name
        _default_config = default_config or Trainer.COMMON_CONFIG
        _policy_graph = default_policy

        def _init(self, config, env_creator):
            if validate_config:
                validate_config(config)

            # Let the caller swap the policy class based on the config.
            policy_cls = (default_policy if get_policy_class is None else
                          get_policy_class(config))

            self.local_evaluator = self.make_local_evaluator(
                env_creator, policy_cls)
            self.remote_evaluators = self.make_remote_evaluators(
                env_creator, policy_cls, config["num_workers"])

            if make_policy_optimizer:
                self.optimizer = make_policy_optimizer(
                    self.local_evaluator, self.remote_evaluators, config)
                return

            # Default optimizer: the optimizer sub-config with the
            # train batch size added to (or overriding) it.
            optimizer_kwargs = dict(config["optimizer"])
            optimizer_kwargs["train_batch_size"] = config["train_batch_size"]
            self.optimizer = SyncSamplesOptimizer(
                self.local_evaluator, self.remote_evaluators,
                **optimizer_kwargs)

        @override(Trainer)
        def _train(self):
            if before_train_step:
                before_train_step(self)

            steps_before = self.optimizer.num_steps_sampled
            fetches = self.optimizer.step()
            if after_optimizer_step:
                after_optimizer_step(self, fetches)

            result = self.collect_metrics()
            steps_this_iter = self.optimizer.num_steps_sampled - steps_before
            result.update(
                timesteps_this_iter=steps_this_iter,
                info=result.get("info", {}))

            if after_train_result:
                after_train_result(self, result)
            return result

    cls_name = name + "Trainer"
    trainer_cls.__name__ = cls_name
    trainer_cls.__qualname__ = cls_name
    return trainer_cls
|
||||
@@ -0,0 +1,275 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from collections import OrderedDict
|
||||
import logging
|
||||
import numpy as np
|
||||
|
||||
from ray.rllib.evaluation.policy_graph import PolicyGraph
|
||||
from ray.rllib.evaluation.sample_batch import SampleBatch
|
||||
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
|
||||
from ray.rllib.models.catalog import ModelCatalog
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils import try_import_tf
|
||||
from ray.rllib.utils.debug import log_once, summarize
|
||||
from ray.rllib.utils.tracking_dict import UsageTrackingDict
|
||||
|
||||
tf = try_import_tf()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DynamicTFPolicyGraph(TFPolicyGraph):
    """A TFPolicyGraph that auto-defines placeholders dynamically at runtime.

    Initialization of this class occurs in two phases.
      * Phase 1: the model is created and model variables are initialized.
      * Phase 2: a fake batch of data is created, sent to the trajectory
        postprocessor, and then used to create placeholders for the loss
        function. The loss and stats functions are initialized with these
        placeholders.
    """

    def __init__(self,
                 obs_space,
                 action_space,
                 config,
                 loss_fn,
                 stats_fn=None,
                 grad_stats_fn=None,
                 before_loss_init=None,
                 make_action_sampler=None,
                 existing_inputs=None,
                 get_batch_divisibility_req=None):
        """Initialize a dynamic TF policy graph.

        Arguments:
            obs_space (gym.Space): Observation space of the policy.
            action_space (gym.Space): Action space of the policy.
            config (dict): Policy-specific configuration data.
            loss_fn (func): function that returns a loss tensor given the
                policy graph and a dict of experience tensor placeholders
            stats_fn (func): optional function that returns a dict of
                TF fetches given the policy graph and batch input tensors
            grad_stats_fn (func): optional function that returns a dict of
                TF fetches given the policy graph and loss gradient tensors
            before_loss_init (func): optional function to run prior to loss
                init that takes (policy, obs_space, action_space, config)
            make_action_sampler (func): optional function that returns a
                tuple of action and action prob tensors. The function takes
                (policy, input_dict, obs_space, action_space, config) as its
                arguments
            existing_inputs (OrderedDict): when copying a policy graph, this
                specifies an existing dict of placeholders to use instead of
                defining new ones
            get_batch_divisibility_req (func): optional function that returns
                the divisibility requirement for sample batches
        """
        self.config = config
        self._loss_fn = loss_fn
        self._stats_fn = stats_fn
        self._grad_stats_fn = grad_stats_fn

        # Setup standard placeholders
        if existing_inputs is not None:
            obs = existing_inputs[SampleBatch.CUR_OBS]
            prev_actions = existing_inputs[SampleBatch.PREV_ACTIONS]
            prev_rewards = existing_inputs[SampleBatch.PREV_REWARDS]
        else:
            obs = tf.placeholder(
                tf.float32,
                shape=[None] + list(obs_space.shape),
                name="observation")
            prev_actions = ModelCatalog.get_action_placeholder(action_space)
            prev_rewards = tf.placeholder(
                tf.float32, [None], name="prev_reward")

        input_dict = {
            "obs": obs,
            "prev_actions": prev_actions,
            "prev_rewards": prev_rewards,
            "is_training": self._get_is_training_placeholder(),
        }

        # Create the model network and action outputs
        if make_action_sampler:
            assert not existing_inputs, \
                "Cloning not supported with custom action sampler"
            # With a custom sampler there is no catalog model/distribution.
            self.model = None
            self.dist_class = None
            self.action_dist = None
            action_sampler, action_prob = make_action_sampler(
                self, input_dict, obs_space, action_space, config)
        else:
            self.dist_class, logit_dim = ModelCatalog.get_action_dist(
                action_space, self.config["model"])
            if existing_inputs:
                # Reuse the RNN state placeholders (and seq_lens, if any
                # state inputs exist) handed over by copy().
                existing_state_in = [
                    v for k, v in existing_inputs.items()
                    if k.startswith("state_in_")
                ]
                if existing_state_in:
                    existing_seq_lens = existing_inputs["seq_lens"]
                else:
                    existing_seq_lens = None
            else:
                existing_state_in = []
                existing_seq_lens = None
            self.model = ModelCatalog.get_model(
                input_dict,
                obs_space,
                action_space,
                logit_dim,
                self.config["model"],
                state_in=existing_state_in,
                seq_lens=existing_seq_lens)
            self.action_dist = self.dist_class(self.model.outputs)
            action_sampler = self.action_dist.sample()
            action_prob = self.action_dist.sampled_action_prob()

        # Phase 1 init
        sess = tf.get_default_session()
        if get_batch_divisibility_req:
            batch_divisibility_req = get_batch_divisibility_req(self)
        else:
            batch_divisibility_req = 1
        TFPolicyGraph.__init__(
            self,
            obs_space,
            action_space,
            sess,
            obs_input=obs,
            action_sampler=action_sampler,
            action_prob=action_prob,
            loss=None,  # dynamically initialized on run
            loss_inputs=[],
            model=self.model,
            state_inputs=self.model and self.model.state_in,
            state_outputs=self.model and self.model.state_out,
            prev_action_input=prev_actions,
            prev_reward_input=prev_rewards,
            seq_lens=self.model and self.model.seq_lens,
            max_seq_len=config["model"]["max_seq_len"],
            batch_divisibility_req=batch_divisibility_req)

        # Phase 2 init
        # before_loss_init is optional (defaults to None), so only call it
        # when a hook was actually supplied.
        if before_loss_init:
            before_loss_init(self, obs_space, action_space, config)
        if not existing_inputs:
            self._initialize_loss()

    @override(TFPolicyGraph)
    def copy(self, existing_inputs):
        """Creates a copy of self using existing input placeholders.

        Arguments:
            existing_inputs (list): tensors ordered like this policy's loss
                inputs, followed by the state inputs and then the seq_lens
                tensor when the policy has state inputs.

        Returns:
            a new instance of this policy's class whose loss is built on
            the given tensors.

        Raises:
            ValueError: if the number of tensors or the shape of any loss
                input does not match this policy's loss inputs.
        """

        # Note that there might be RNN state inputs at the end of the list
        if self._state_inputs:
            num_state_inputs = len(self._state_inputs) + 1
        else:
            num_state_inputs = 0
        if len(self._loss_inputs) + num_state_inputs != len(existing_inputs):
            raise ValueError("Tensor list mismatch", self._loss_inputs,
                             self._state_inputs, existing_inputs)
        for i, (k, v) in enumerate(self._loss_inputs):
            if v.shape.as_list() != existing_inputs[i].shape.as_list():
                raise ValueError("Tensor shape mismatch", i, k, v.shape,
                                 existing_inputs[i].shape)
        # By convention, the loss inputs are followed by state inputs and then
        # the seq len tensor
        num_loss_inputs = len(self._loss_inputs)
        rnn_inputs = [("state_in_{}".format(i),
                       existing_inputs[num_loss_inputs + i])
                      for i, _ in enumerate(self._state_inputs)]
        if rnn_inputs:
            rnn_inputs.append(("seq_lens", existing_inputs[-1]))
        # Pair each loss input name with the tensor supplied by the caller.
        new_loss_inputs = [(k, existing_inputs[i])
                           for i, (k, _) in enumerate(self._loss_inputs)]
        input_dict = OrderedDict(new_loss_inputs + rnn_inputs)
        instance = self.__class__(
            self.observation_space,
            self.action_space,
            self.config,
            existing_inputs=input_dict)
        loss = instance._loss_fn(instance, input_dict)
        if instance._stats_fn:
            instance._stats_fetches.update(
                instance._stats_fn(instance, input_dict))
        TFPolicyGraph._initialize_loss(instance, loss, new_loss_inputs)
        if instance._grad_stats_fn:
            instance._stats_fetches.update(
                instance._grad_stats_fn(instance, instance._grads))
        return instance

    @override(PolicyGraph)
    def get_initial_state(self):
        """Return the model's initial state, or [] when there is no model."""
        if self.model:
            return self.model.state_init
        else:
            return []

    def _initialize_loss(self):
        """Build the loss from placeholders derived from a dummy batch.

        A one-row fake batch is run through postprocess_trajectory(); every
        resulting column that is not already a standard input gets its own
        placeholder. Only the placeholders the loss/stats functions actually
        accessed are registered as loss inputs.
        """

        def fake_array(tensor):
            # Zero array shaped like `tensor` with a batch dimension of 1.
            shape = tensor.shape.as_list()
            shape[0] = 1
            return np.zeros(shape, dtype=tensor.dtype.as_numpy_dtype)

        dummy_batch = {
            SampleBatch.PREV_ACTIONS: fake_array(self._prev_action_input),
            SampleBatch.PREV_REWARDS: fake_array(self._prev_reward_input),
            SampleBatch.CUR_OBS: fake_array(self._obs_input),
            SampleBatch.NEXT_OBS: fake_array(self._obs_input),
            SampleBatch.ACTIONS: fake_array(self._prev_action_input),
            SampleBatch.REWARDS: np.array([0], dtype=np.float32),
            # np.bool_ rather than the removed `np.bool` alias.
            SampleBatch.DONES: np.array([False], dtype=np.bool_),
        }
        state_init = self.get_initial_state()
        for i, h in enumerate(state_init):
            dummy_batch["state_in_{}".format(i)] = np.expand_dims(h, 0)
            dummy_batch["state_out_{}".format(i)] = np.expand_dims(h, 0)
        if state_init:
            dummy_batch["seq_lens"] = np.array([1], dtype=np.int32)
        for k, v in self.extra_compute_action_fetches().items():
            dummy_batch[k] = fake_array(v)

        # postprocessing might depend on variable init, so run it first here
        self._sess.run(tf.global_variables_initializer())
        postprocessed_batch = self.postprocess_trajectory(
            SampleBatch(dummy_batch))

        batch_tensors = UsageTrackingDict({
            SampleBatch.PREV_ACTIONS: self._prev_action_input,
            SampleBatch.PREV_REWARDS: self._prev_reward_input,
            SampleBatch.CUR_OBS: self._obs_input,
        })
        loss_inputs = [
            (SampleBatch.PREV_ACTIONS, self._prev_action_input),
            (SampleBatch.PREV_REWARDS, self._prev_reward_input),
            (SampleBatch.CUR_OBS, self._obs_input),
        ]

        for k, v in postprocessed_batch.items():
            if k in batch_tensors:
                continue
            elif v.dtype == np.object_:
                # np.object_ rather than the removed `np.object` alias.
                continue  # can't handle arbitrary objects in TF
            shape = (None, ) + v.shape[1:]
            # float64 columns are fed through float32 placeholders.
            dtype = np.float32 if v.dtype == np.float64 else v.dtype
            placeholder = tf.placeholder(dtype, shape=shape, name=k)
            batch_tensors[k] = placeholder

        if log_once("loss_init"):
            logger.info(
                "Initializing loss function with dummy input:\n\n{}\n".format(
                    summarize(batch_tensors)))

        loss = self._loss_fn(self, batch_tensors)
        if self._stats_fn:
            self._stats_fetches.update(self._stats_fn(self, batch_tensors))
        # Register every key the loss/stats functions read as a loss input.
        for k in sorted(batch_tensors.accessed_keys):
            loss_inputs.append((k, batch_tensors[k]))
        TFPolicyGraph._initialize_loss(self, loss, loss_inputs)
        if self._grad_stats_fn:
            self._stats_fetches.update(self._grad_stats_fn(self, self._grads))
        self._sess.run(tf.global_variables_initializer())
|
||||
@@ -65,7 +65,7 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
>>> # Create a policy evaluator and using it to collect experiences.
|
||||
>>> evaluator = PolicyEvaluator(
|
||||
... env_creator=lambda _: gym.make("CartPole-v0"),
|
||||
... policy_graph=PGPolicyGraph)
|
||||
... policy_graph=PGTFPolicy)
|
||||
>>> print(evaluator.sample())
|
||||
SampleBatch({
|
||||
"obs": [[...]], "actions": [[...]], "rewards": [[...]],
|
||||
@@ -76,7 +76,7 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
... evaluator_cls=PolicyEvaluator,
|
||||
... evaluator_args={
|
||||
... "env_creator": lambda _: gym.make("CartPole-v0"),
|
||||
... "policy_graph": PGPolicyGraph,
|
||||
... "policy_graph": PGTFPolicy,
|
||||
... },
|
||||
... num_workers=10)
|
||||
>>> for _ in range(10): optimizer.step()
|
||||
@@ -87,12 +87,12 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
... policy_graphs={
|
||||
... # Use an ensemble of two policies for car agents
|
||||
... "car_policy1":
|
||||
... (PGPolicyGraph, Box(...), Discrete(...), {"gamma": 0.99}),
|
||||
... (PGTFPolicy, Box(...), Discrete(...), {"gamma": 0.99}),
|
||||
... "car_policy2":
|
||||
... (PGPolicyGraph, Box(...), Discrete(...), {"gamma": 0.95}),
|
||||
... (PGTFPolicy, Box(...), Discrete(...), {"gamma": 0.95}),
|
||||
... # Use a single shared policy for all traffic lights
|
||||
... "traffic_light_policy":
|
||||
... (PGPolicyGraph, Box(...), Discrete(...), {}),
|
||||
... (PGTFPolicy, Box(...), Discrete(...), {}),
|
||||
... },
|
||||
... policy_mapping_fn=lambda agent_id:
|
||||
... random.choice(["car_policy1", "car_policy2"])
|
||||
|
||||
@@ -112,46 +112,20 @@ class TFPolicyGraph(PolicyGraph):
|
||||
self._prev_action_input = prev_action_input
|
||||
self._prev_reward_input = prev_reward_input
|
||||
self._sampler = action_sampler
|
||||
self._loss_inputs = loss_inputs
|
||||
self._loss_input_dict = dict(self._loss_inputs)
|
||||
self._is_training = self._get_is_training_placeholder()
|
||||
self._action_prob = action_prob
|
||||
self._state_inputs = state_inputs or []
|
||||
self._state_outputs = state_outputs or []
|
||||
for i, ph in enumerate(self._state_inputs):
|
||||
self._loss_input_dict["state_in_{}".format(i)] = ph
|
||||
self._seq_lens = seq_lens
|
||||
self._max_seq_len = max_seq_len
|
||||
self._batch_divisibility_req = batch_divisibility_req
|
||||
self._update_ops = update_ops
|
||||
self._stats_fetches = {}
|
||||
|
||||
if self.model:
|
||||
self._loss = self.model.custom_loss(loss, self._loss_input_dict)
|
||||
self._stats_fetches = {"model": self.model.custom_stats()}
|
||||
if loss is not None:
|
||||
self._initialize_loss(loss, loss_inputs)
|
||||
else:
|
||||
self._loss = loss
|
||||
self._stats_fetches = {}
|
||||
|
||||
self._optimizer = self.optimizer()
|
||||
self._grads_and_vars = [
|
||||
(g, v) for (g, v) in self.gradients(self._optimizer, self._loss)
|
||||
if g is not None
|
||||
]
|
||||
self._grads = [g for (g, v) in self._grads_and_vars]
|
||||
self._variables = ray.experimental.tf_utils.TensorFlowVariables(
|
||||
self._loss, self._sess)
|
||||
|
||||
# gather update ops for any batch norm layers
|
||||
if update_ops:
|
||||
self._update_ops = update_ops
|
||||
else:
|
||||
self._update_ops = tf.get_collection(
|
||||
tf.GraphKeys.UPDATE_OPS, scope=tf.get_variable_scope().name)
|
||||
if self._update_ops:
|
||||
logger.debug("Update ops to run on apply gradient: {}".format(
|
||||
self._update_ops))
|
||||
with tf.control_dependencies(self._update_ops):
|
||||
self._apply_op = self.build_apply_op(self._optimizer,
|
||||
self._grads_and_vars)
|
||||
self._loss = None
|
||||
|
||||
if len(self._state_inputs) != len(self._state_outputs):
|
||||
raise ValueError(
|
||||
@@ -166,8 +140,44 @@ class TFPolicyGraph(PolicyGraph):
|
||||
raise ValueError(
|
||||
"seq_lens tensor must be given if state inputs are defined")
|
||||
|
||||
logger.debug("Created {} with loss inputs: {}".format(
|
||||
self, self._loss_input_dict))
|
||||
def _initialize_loss(self, loss, loss_inputs):
    """Attach a loss to the policy and build its training ops.

    Stores the loss inputs (plus the state input placeholders under
    "state_in_<i>"), lets the model wrap the loss when a model exists,
    creates the optimizer, the non-None gradients, the variables helper
    and the apply op, and finally runs the global variables initializer.

    Arguments:
        loss: the loss tensor.
        loss_inputs: (name, tensor) pairs consumed by the loss.
    """
    self._loss_inputs = loss_inputs
    self._loss_input_dict = dict(self._loss_inputs)
    for idx, state_ph in enumerate(self._state_inputs):
        state_key = "state_in_{}".format(idx)
        self._loss_input_dict[state_key] = state_ph

    if not self.model:
        self._loss = loss
    else:
        self._loss = self.model.custom_loss(loss, self._loss_input_dict)
        self._stats_fetches.update({"model": self.model.custom_stats()})

    self._optimizer = self.optimizer()
    # Drop the pairs for which no gradient was produced.
    self._grads_and_vars = [
        pair for pair in self.gradients(self._optimizer, self._loss)
        if pair[0] is not None
    ]
    self._grads = [grad for (grad, _) in self._grads_and_vars]
    self._variables = ray.experimental.tf_utils.TensorFlowVariables(
        self._loss, self._sess)

    # gather update ops for any batch norm layers
    if not self._update_ops:
        scope_name = tf.get_variable_scope().name
        self._update_ops = tf.get_collection(
            tf.GraphKeys.UPDATE_OPS, scope=scope_name)
    if self._update_ops:
        logger.debug("Update ops to run on apply gradient: {}".format(
            self._update_ops))
    # The apply op is built under a control dependency on the update ops.
    with tf.control_dependencies(self._update_ops):
        self._apply_op = self.build_apply_op(self._optimizer,
                                             self._grads_and_vars)

    if log_once("loss_used"):
        logger.debug(
            "These tensors were used in the loss_fn:\n\n{}\n".format(
                summarize(self._loss_input_dict)))

    self._sess.run(tf.global_variables_initializer())
|
||||
|
||||
@override(PolicyGraph)
|
||||
def compute_actions(self,
|
||||
@@ -186,18 +196,21 @@ class TFPolicyGraph(PolicyGraph):
|
||||
|
||||
@override(PolicyGraph)
def compute_gradients(self, postprocessed_batch):
    """Compute gradients for a postprocessed batch.

    Requires the loss to be initialized; the fetches are assembled and
    resolved through a TFRunBuilder bound to this policy's session.
    """
    assert self._loss is not None, "Loss not initialized"
    run_builder = TFRunBuilder(self._sess, "compute_gradients")
    pending = self._build_compute_gradients(run_builder, postprocessed_batch)
    return run_builder.get(pending)
|
||||
|
||||
@override(PolicyGraph)
def apply_gradients(self, gradients):
    """Apply previously computed gradients.

    Requires the loss to be initialized. The result of the run is
    discarded; nothing is returned.
    """
    assert self._loss is not None, "Loss not initialized"
    run_builder = TFRunBuilder(self._sess, "apply_gradients")
    pending = self._build_apply_gradients(run_builder, gradients)
    run_builder.get(pending)
|
||||
|
||||
@override(PolicyGraph)
def learn_on_batch(self, postprocessed_batch):
    """Run one learning step on a postprocessed batch.

    Requires the loss to be initialized; returns whatever the
    TFRunBuilder resolves for the learn-on-batch fetches.
    """
    assert self._loss is not None, "Loss not initialized"
    run_builder = TFRunBuilder(self._sess, "learn_on_batch")
    pending = self._build_learn_on_batch(run_builder, postprocessed_batch)
    return run_builder.get(pending)
|
||||
@@ -271,7 +284,10 @@ class TFPolicyGraph(PolicyGraph):
|
||||
@DeveloperAPI
def optimizer(self):
    """TF optimizer to use for policy optimization.

    Returns:
        an AdamOptimizer constructed with config["lr"] when the policy has
        a `config` attribute, otherwise one with the optimizer's defaults.
    """
    # The unconditional `return tf.train.AdamOptimizer()` that used to sit
    # here made the config-aware branch below unreachable.
    if hasattr(self, "config"):
        return tf.train.AdamOptimizer(self.config["lr"])
    else:
        return tf.train.AdamOptimizer()
|
||||
|
||||
@DeveloperAPI
|
||||
def gradients(self, optimizer, loss):
|
||||
|
||||
@@ -0,0 +1,146 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from ray.rllib.evaluation.dynamic_tf_policy_graph import DynamicTFPolicyGraph
|
||||
from ray.rllib.evaluation.policy_graph import PolicyGraph
|
||||
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
|
||||
from ray.rllib.utils.annotations import override, DeveloperAPI
|
||||
|
||||
|
||||
@DeveloperAPI
def build_tf_policy(name,
                    loss_fn,
                    get_default_config=None,
                    stats_fn=None,
                    grad_stats_fn=None,
                    extra_action_fetches_fn=None,
                    postprocess_fn=None,
                    optimizer_fn=None,
                    gradients_fn=None,
                    before_init=None,
                    before_loss_init=None,
                    after_init=None,
                    make_action_sampler=None,
                    mixins=None,
                    get_batch_divisibility_req=None):
    """Helper function for creating a dynamic tf policy at runtime.

    Arguments:
        name (str): name of the graph (e.g., "PPOTFPolicy"); must end with
            "TFPolicy"
        loss_fn (func): function that returns a loss tensor given the
            policy and a dict of experience tensor placeholders
        get_default_config (func): optional function that returns the default
            config to merge with any overrides
        stats_fn (func): optional function that returns a dict of
            TF fetches given the policy and batch input tensors
        grad_stats_fn (func): optional function that returns a dict of
            TF fetches given the policy and loss gradient tensors
        extra_action_fetches_fn (func): optional function that returns
            a dict of TF fetches given the policy object
        postprocess_fn (func): optional experience postprocessing function
            that takes the same args as PolicyGraph.postprocess_trajectory()
        optimizer_fn (func): optional function that returns a tf.Optimizer
            given the policy and config
        gradients_fn (func): optional function that returns a list of
            gradients given the policy, a tf optimizer and loss tensor. If
            not specified, TFPolicyGraph.gradients() is used
        before_init (func): optional function to run at the beginning of
            policy init that takes the same arguments as the policy constructor
        before_loss_init (func): optional function to run prior to loss
            init that takes the same arguments as the policy constructor
        after_init (func): optional function to run at the end of policy init
            that takes the same arguments as the policy constructor
        make_action_sampler (func): optional function that returns a
            tuple of action and action prob tensors. The function takes
            (policy, input_dict, obs_space, action_space, config) as its
            arguments
        mixins (list): list of any class mixins for the returned policy class.
            These mixins will be applied in order and will have higher
            precedence than the DynamicTFPolicyGraph class. The list is not
            modified.
        get_batch_divisibility_req (func): optional function that returns
            the divisibility requirement for sample batches

    Returns:
        a DynamicTFPolicyGraph subclass that uses the specified args

    Raises:
        ValueError: if `name` does not end with "TFPolicy".
    """

    if not name.endswith("TFPolicy"):
        raise ValueError("Name should match *TFPolicy", name)

    base = DynamicTFPolicyGraph
    # Wrap from the last mixin to the first so the first listed mixin ends
    # up outermost (highest precedence). Iterating instead of pop()-ing
    # leaves the caller's list intact.
    for mixin in reversed(mixins or []):

        class new_base(mixin, base):
            pass

        base = new_base

    class graph_cls(base):
        def __init__(self,
                     obs_space,
                     action_space,
                     config,
                     existing_inputs=None):
            if get_default_config:
                # User-supplied keys override the defaults.
                config = dict(get_default_config(), **config)

            if before_init:
                before_init(self, obs_space, action_space, config)

            def before_loss_init_wrapper(policy, obs_space, action_space,
                                         config):
                if before_loss_init:
                    before_loss_init(policy, obs_space, action_space, config)
                # The extra fetches are computed here, after the user's
                # before_loss_init hook has run.
                if extra_action_fetches_fn is None:
                    self._extra_action_fetches = {}
                else:
                    self._extra_action_fetches = extra_action_fetches_fn(self)

            # make_action_sampler and get_batch_divisibility_req are
            # forwarded so that they actually take effect.
            DynamicTFPolicyGraph.__init__(
                self,
                obs_space,
                action_space,
                config,
                loss_fn,
                stats_fn=stats_fn,
                grad_stats_fn=grad_stats_fn,
                before_loss_init=before_loss_init_wrapper,
                make_action_sampler=make_action_sampler,
                existing_inputs=existing_inputs,
                get_batch_divisibility_req=get_batch_divisibility_req)

            if after_init:
                after_init(self, obs_space, action_space, config)

        @override(PolicyGraph)
        def postprocess_trajectory(self,
                                   sample_batch,
                                   other_agent_batches=None,
                                   episode=None):
            if not postprocess_fn:
                return sample_batch
            return postprocess_fn(self, sample_batch, other_agent_batches,
                                  episode)

        @override(TFPolicyGraph)
        def optimizer(self):
            if optimizer_fn:
                return optimizer_fn(self, self.config)
            else:
                return TFPolicyGraph.optimizer(self)

        @override(TFPolicyGraph)
        def gradients(self, optimizer, loss):
            if gradients_fn:
                return gradients_fn(self, optimizer, loss)
            else:
                return TFPolicyGraph.gradients(self, optimizer, loss)

        @override(TFPolicyGraph)
        def extra_compute_action_fetches(self):
            # Base-class fetches, extended/overridden by the custom ones.
            return dict(
                TFPolicyGraph.extra_compute_action_fetches(self),
                **self._extra_action_fetches)

    graph_cls.__name__ = name
    graph_cls.__qualname__ = name
    return graph_cls
|
||||
@@ -15,6 +15,7 @@ except ImportError:
|
||||
from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
|
||||
from ray.rllib.evaluation.policy_graph import PolicyGraph
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.tracking_dict import UsageTrackingDict
|
||||
|
||||
|
||||
class TorchPolicyGraph(PolicyGraph):
|
||||
@@ -30,7 +31,7 @@ class TorchPolicyGraph(PolicyGraph):
|
||||
"""
|
||||
|
||||
def __init__(self, observation_space, action_space, model, loss,
|
||||
loss_inputs, action_distribution_cls):
|
||||
action_distribution_cls):
|
||||
"""Build a policy graph from policy and loss torch modules.
|
||||
|
||||
Note that model will be placed on GPU device if CUDA_VISIBLE_DEVICES
|
||||
@@ -42,13 +43,8 @@ class TorchPolicyGraph(PolicyGraph):
|
||||
model (nn.Module): PyTorch policy module. Given observations as
|
||||
input, this module must return a list of outputs where the
|
||||
first item is action logits, and the rest can be any value.
|
||||
loss (nn.Module): Loss defined as a PyTorch module. The inputs for
|
||||
this module are defined by the `loss_inputs` param. This module
|
||||
returns a single scalar loss. Note that this module should
|
||||
internally be using the model module.
|
||||
loss_inputs (list): List of SampleBatch columns that will be
|
||||
passed to the loss module's forward() function when computing
|
||||
the loss. For example, ["obs", "action", "advantages"].
|
||||
loss (func): Function that takes (policy_graph, batch_tensors)
|
||||
and returns a single scalar loss.
|
||||
action_distribution_cls (ActionDistribution): Class for action
|
||||
distribution.
|
||||
"""
|
||||
@@ -60,7 +56,6 @@ class TorchPolicyGraph(PolicyGraph):
|
||||
else torch.device("cpu"))
|
||||
self._model = model.to(self.device)
|
||||
self._loss = loss
|
||||
self._loss_inputs = loss_inputs
|
||||
self._optimizer = self.optimizer()
|
||||
self._action_dist_cls = action_distribution_cls
|
||||
|
||||
@@ -87,30 +82,26 @@ class TorchPolicyGraph(PolicyGraph):
|
||||
|
||||
@override(PolicyGraph)
|
||||
def learn_on_batch(self, postprocessed_batch):
|
||||
batch_tensors = self._lazy_tensor_dict(postprocessed_batch)
|
||||
|
||||
with self.lock:
|
||||
loss_in = []
|
||||
for key in self._loss_inputs:
|
||||
loss_in.append(
|
||||
torch.from_numpy(postprocessed_batch[key]).to(self.device))
|
||||
loss_out = self._loss(self._model, *loss_in)
|
||||
loss_out = self._loss(self, batch_tensors)
|
||||
self._optimizer.zero_grad()
|
||||
loss_out.backward()
|
||||
|
||||
grad_process_info = self.extra_grad_process()
|
||||
self._optimizer.step()
|
||||
|
||||
grad_info = self.extra_grad_info()
|
||||
grad_info = self.extra_grad_info(batch_tensors)
|
||||
grad_info.update(grad_process_info)
|
||||
return {LEARNER_STATS_KEY: grad_info}
|
||||
|
||||
@override(PolicyGraph)
|
||||
def compute_gradients(self, postprocessed_batch):
|
||||
batch_tensors = self._lazy_tensor_dict(postprocessed_batch)
|
||||
|
||||
with self.lock:
|
||||
loss_in = []
|
||||
for key in self._loss_inputs:
|
||||
loss_in.append(
|
||||
torch.from_numpy(postprocessed_batch[key]).to(self.device))
|
||||
loss_out = self._loss(self._model, *loss_in)
|
||||
loss_out = self._loss(self, batch_tensors)
|
||||
self._optimizer.zero_grad()
|
||||
loss_out.backward()
|
||||
|
||||
@@ -125,7 +116,7 @@ class TorchPolicyGraph(PolicyGraph):
|
||||
else:
|
||||
grads.append(None)
|
||||
|
||||
grad_info = self.extra_grad_info()
|
||||
grad_info = self.extra_grad_info(batch_tensors)
|
||||
grad_info.update(grad_process_info)
|
||||
return grads, {LEARNER_STATS_KEY: grad_info}
|
||||
|
||||
@@ -163,11 +154,21 @@ class TorchPolicyGraph(PolicyGraph):
|
||||
model_out (list): Outputs of the policy model module."""
|
||||
return {}
|
||||
|
||||
def extra_grad_info(self):
|
||||
def extra_grad_info(self, batch_tensors):
|
||||
"""Return dict of extra grad info."""
|
||||
|
||||
return {}
|
||||
|
||||
def optimizer(self):
|
||||
"""Custom PyTorch optimizer to use."""
|
||||
return torch.optim.Adam(self._model.parameters())
|
||||
if hasattr(self, "config"):
|
||||
return torch.optim.Adam(
|
||||
self._model.parameters(), lr=self.config["lr"])
|
||||
else:
|
||||
return torch.optim.Adam(self._model.parameters())
|
||||
|
||||
def _lazy_tensor_dict(self, postprocessed_batch):
    """Wrap a batch so its values are converted to tensors when read.

    The batch is copied into a UsageTrackingDict whose get-interceptor
    passes each looked-up value through torch.from_numpy() and moves the
    result to self.device.

    Arguments:
        postprocessed_batch: mapping used to populate the tracking dict.
            Presumably column name -> numpy array, since every value read
            is handed to torch.from_numpy() -- confirm against callers.

    Returns:
        UsageTrackingDict with the tensor-converting interceptor installed.
    """

    def to_device_tensor(arr):
        # self.device is looked up at call time, not when the dict is built.
        return torch.from_numpy(arr).to(self.device)

    tracked = UsageTrackingDict(postprocessed_batch)
    tracked.set_get_interceptor(to_device_tensor)
    return tracked
|
||||
|
||||
@@ -0,0 +1,133 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from ray.rllib.evaluation.policy_graph import PolicyGraph
|
||||
from ray.rllib.evaluation.torch_policy_graph import TorchPolicyGraph
|
||||
from ray.rllib.models.catalog import ModelCatalog
|
||||
from ray.rllib.utils.annotations import override, DeveloperAPI
|
||||
|
||||
|
||||
@DeveloperAPI
def build_torch_policy(name,
                       loss_fn,
                       get_default_config=None,
                       stats_fn=None,
                       postprocess_fn=None,
                       extra_action_out_fn=None,
                       extra_grad_process_fn=None,
                       optimizer_fn=None,
                       before_init=None,
                       after_init=None,
                       make_model_and_action_dist=None,
                       mixins=None):
    """Helper function for creating a torch policy class at runtime.

    Arguments:
        name (str): name of the policy class; must end with "TorchPolicy"
            (e.g., "PPOTorchPolicy")
        loss_fn (func): function that is handed to TorchPolicyGraph as its
            loss; it takes the policy and a dict of batch tensors and
            returns a single scalar loss
        get_default_config (func): optional function that returns the default
            config to merge with any overrides
        stats_fn (func): optional function that returns a dict of
            values given the policy and batch input tensors
        postprocess_fn (func): optional experience postprocessing function
            that takes the same args as PolicyGraph.postprocess_trajectory()
        extra_action_out_fn (func): optional function that takes the policy
            and the model outputs and returns a dict of extra values to
            include in experiences
        extra_grad_process_fn (func): optional function that is called with
            the policy after gradients are computed and returns processing
            info
        optimizer_fn (func): optional function that returns a torch optimizer
            given the policy and config
        before_init (func): optional function to run at the beginning of
            policy init that takes the same arguments as the policy constructor
        after_init (func): optional function to run at the end of policy init
            that takes the same arguments as the policy constructor
        make_model_and_action_dist (func): optional func that takes the same
            arguments as policy init and returns a tuple of model instance and
            torch action distribution class. If not specified, the default
            model and action dist from the catalog will be used
        mixins (list): list of any class mixins for the returned policy class.
            These mixins will be applied in order and will have higher
            precedence than the TorchPolicyGraph class. The list is only
            read, never modified.

    Returns:
        a TorchPolicyGraph subclass (not an instance) that uses the
        specified args

    Raises:
        ValueError: if name does not end with "TorchPolicy"
    """

    if not name.endswith("TorchPolicy"):
        raise ValueError("Name should match *TorchPolicy", name)

    base = TorchPolicyGraph
    # Wrap from the last mixin inwards so that the first mixin ends up
    # outermost, i.e. with the highest precedence in the MRO. Iterating over
    # a reversed copy (instead of popping) leaves the caller's list intact,
    # so the same mixin list can be reused for several policy classes.
    for mixin in reversed(list(mixins or [])):

        class new_base(mixin, base):
            pass

        base = new_base

    class graph_cls(base):
        def __init__(self, obs_space, action_space, config):
            # User-supplied config entries override the defaults.
            if get_default_config:
                config = dict(get_default_config(), **config)
            self.config = config

            if before_init:
                before_init(self, obs_space, action_space, config)

            if make_model_and_action_dist:
                self.model, self.dist_class = make_model_and_action_dist(
                    self, obs_space, action_space, config)
            else:
                # Fall back to the catalog's default model and distribution.
                self.dist_class, logit_dim = ModelCatalog.get_action_dist(
                    action_space, self.config["model"], torch=True)
                self.model = ModelCatalog.get_torch_model(
                    obs_space, logit_dim, self.config["model"])

            # self.config is already set here, so optimizer() (invoked by the
            # base constructor) can read it.
            TorchPolicyGraph.__init__(self, obs_space, action_space,
                                      self.model, loss_fn, self.dist_class)

            if after_init:
                after_init(self, obs_space, action_space, config)

        @override(PolicyGraph)
        def postprocess_trajectory(self,
                                   sample_batch,
                                   other_agent_batches=None,
                                   episode=None):
            if not postprocess_fn:
                return sample_batch
            return postprocess_fn(self, sample_batch, other_agent_batches,
                                  episode)

        @override(TorchPolicyGraph)
        def extra_grad_process(self):
            if extra_grad_process_fn:
                return extra_grad_process_fn(self)
            else:
                return TorchPolicyGraph.extra_grad_process(self)

        @override(TorchPolicyGraph)
        def extra_action_out(self, model_out):
            if extra_action_out_fn:
                return extra_action_out_fn(self, model_out)
            else:
                return TorchPolicyGraph.extra_action_out(self, model_out)

        @override(TorchPolicyGraph)
        def optimizer(self):
            if optimizer_fn:
                return optimizer_fn(self, self.config)
            else:
                return TorchPolicyGraph.optimizer(self)

        @override(TorchPolicyGraph)
        def extra_grad_info(self, batch_tensors):
            if stats_fn:
                return stats_fn(self, batch_tensors)
            else:
                return TorchPolicyGraph.extra_grad_info(self, batch_tensors)

    # Make the generated class report the requested name in reprs and logs.
    graph_cls.__name__ = name
    graph_cls.__qualname__ = name
    return graph_cls
|
||||
@@ -18,7 +18,7 @@ import ray
|
||||
from ray.rllib.agents.dqn.dqn import DQNTrainer
|
||||
from ray.rllib.agents.dqn.dqn_policy_graph import DQNPolicyGraph
|
||||
from ray.rllib.agents.ppo.ppo import PPOTrainer
|
||||
from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph
|
||||
from ray.rllib.agents.ppo.ppo_policy_graph import PPOTFPolicy
|
||||
from ray.rllib.tests.test_multi_agent_env import MultiCartpole
|
||||
from ray.tune.logger import pretty_print
|
||||
from ray.tune.registry import register_env
|
||||
@@ -39,7 +39,7 @@ if __name__ == "__main__":
|
||||
# You can also have multiple policy graphs per trainer, but here we just
|
||||
# show one each for PPO and DQN.
|
||||
policy_graphs = {
|
||||
"ppo_policy": (PPOPolicyGraph, obs_space, act_space, {}),
|
||||
"ppo_policy": (PPOTFPolicy, obs_space, act_space, {}),
|
||||
"dqn_policy": (DQNPolicyGraph, obs_space, act_space, {}),
|
||||
}
|
||||
|
||||
|
||||
@@ -255,7 +255,7 @@ class LocalSyncParallelOptimizer(object):
|
||||
|
||||
fetches = {"train": self._train_op}
|
||||
for tower in self._towers:
|
||||
fetches.update(tower.loss_graph.extra_compute_grad_fetches())
|
||||
fetches.update(tower.loss_graph._get_grad_and_stats_fetches())
|
||||
|
||||
return sess.run(fetches, feed_dict=feed_dict)
|
||||
|
||||
|
||||
@@ -222,6 +222,6 @@ class LocalMultiGPUOptimizer(PolicyOptimizer):
|
||||
def _averaged(kv):
|
||||
out = {}
|
||||
for k, v in kv.items():
|
||||
if v[0] is not None:
|
||||
if v[0] is not None and not isinstance(v[0], dict):
|
||||
out[k] = np.mean(v)
|
||||
return out
|
||||
|
||||
@@ -8,7 +8,7 @@ import random
|
||||
import unittest
|
||||
|
||||
import ray
|
||||
from ray.rllib.agents.pg.pg_policy_graph import PGPolicyGraph
|
||||
from ray.rllib.agents.pg.pg_policy_graph import PGTFPolicy
|
||||
from ray.rllib.optimizers import SyncSamplesOptimizer
|
||||
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
|
||||
from ray.rllib.env.external_multi_agent_env import ExternalMultiAgentEnv
|
||||
@@ -67,7 +67,7 @@ class TestExternalMultiAgentEnv(unittest.TestCase):
|
||||
obs_space = single_env.observation_space
|
||||
policies = {}
|
||||
for i in range(20):
|
||||
policies["pg_{}".format(i)] = (PGPolicyGraph, obs_space, act_space,
|
||||
policies["pg_{}".format(i)] = (PGTFPolicy, obs_space, act_space,
|
||||
{})
|
||||
policy_ids = list(policies.keys())
|
||||
ev = PolicyEvaluator(
|
||||
|
||||
@@ -15,7 +15,7 @@ import unittest
|
||||
|
||||
import ray
|
||||
from ray.rllib.agents.pg import PGTrainer
|
||||
from ray.rllib.agents.pg.pg_policy_graph import PGPolicyGraph
|
||||
from ray.rllib.agents.pg.pg_policy_graph import PGTFPolicy
|
||||
from ray.rllib.evaluation import SampleBatch
|
||||
from ray.rllib.offline import IOContext, JsonWriter, JsonReader
|
||||
from ray.rllib.offline.json_writer import _to_json
|
||||
@@ -159,7 +159,7 @@ class AgentIOTest(unittest.TestCase):
|
||||
def gen_policy():
|
||||
obs_space = single_env.observation_space
|
||||
act_space = single_env.action_space
|
||||
return (PGPolicyGraph, obs_space, act_space, {})
|
||||
return (PGTFPolicy, obs_space, act_space, {})
|
||||
|
||||
pg = PGTrainer(
|
||||
env="multi_cartpole",
|
||||
|
||||
@@ -8,7 +8,7 @@ import unittest
|
||||
|
||||
import ray
|
||||
from ray.rllib.agents.pg import PGTrainer
|
||||
from ray.rllib.agents.pg.pg_policy_graph import PGPolicyGraph
|
||||
from ray.rllib.agents.pg.pg_policy_graph import PGTFPolicy
|
||||
from ray.rllib.agents.dqn.dqn_policy_graph import DQNPolicyGraph
|
||||
from ray.rllib.optimizers import (SyncSamplesOptimizer, SyncReplayOptimizer,
|
||||
AsyncGradientsOptimizer)
|
||||
@@ -470,7 +470,7 @@ class TestMultiAgentEnv(unittest.TestCase):
|
||||
self.assertEqual(batch["state_out_0"][1], h)
|
||||
|
||||
def testReturningModelBasedRolloutsData(self):
|
||||
class ModelBasedPolicyGraph(PGPolicyGraph):
|
||||
class ModelBasedPolicyGraph(PGTFPolicy):
|
||||
def compute_actions(self,
|
||||
obs_batch,
|
||||
state_batches,
|
||||
@@ -584,7 +584,7 @@ class TestMultiAgentEnv(unittest.TestCase):
|
||||
}
|
||||
else:
|
||||
policies = {
|
||||
"p1": (PGPolicyGraph, obs_space, act_space, {}),
|
||||
"p1": (PGTFPolicy, obs_space, act_space, {}),
|
||||
"p2": (DQNPolicyGraph, obs_space, act_space, dqn_config),
|
||||
}
|
||||
ev = PolicyEvaluator(
|
||||
@@ -640,7 +640,7 @@ class TestMultiAgentEnv(unittest.TestCase):
|
||||
obs_space = env.observation_space
|
||||
policies = {}
|
||||
for i in range(20):
|
||||
policies["pg_{}".format(i)] = (PGPolicyGraph, obs_space, act_space,
|
||||
policies["pg_{}".format(i)] = (PGTFPolicy, obs_space, act_space,
|
||||
{})
|
||||
policy_ids = list(policies.keys())
|
||||
ev = PolicyEvaluator(
|
||||
|
||||
@@ -12,7 +12,7 @@ import unittest
|
||||
import ray
|
||||
from ray.rllib.agents.a3c import A2CTrainer
|
||||
from ray.rllib.agents.pg import PGTrainer
|
||||
from ray.rllib.agents.pg.pg_policy_graph import PGPolicyGraph
|
||||
from ray.rllib.agents.pg.pg_policy_graph import PGTFPolicy
|
||||
from ray.rllib.env import MultiAgentEnv
|
||||
from ray.rllib.env.base_env import BaseEnv
|
||||
from ray.rllib.env.vector_env import VectorEnv
|
||||
@@ -333,10 +333,10 @@ class NestedSpacesTest(unittest.TestCase):
|
||||
"multiagent": {
|
||||
"policy_graphs": {
|
||||
"tuple_policy": (
|
||||
PGPolicyGraph, TUPLE_SPACE, act_space,
|
||||
PGTFPolicy, TUPLE_SPACE, act_space,
|
||||
{"model": {"custom_model": "tuple_spy"}}),
|
||||
"dict_policy": (
|
||||
PGPolicyGraph, DICT_SPACE, act_space,
|
||||
PGTFPolicy, DICT_SPACE, act_space,
|
||||
{"model": {"custom_model": "dict_spy"}}),
|
||||
},
|
||||
"policy_mapping_fn": lambda a: {
|
||||
|
||||
@@ -9,7 +9,7 @@ import unittest
|
||||
|
||||
import ray
|
||||
from ray.rllib.agents.ppo import PPOTrainer
|
||||
from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph
|
||||
from ray.rllib.agents.ppo.ppo_policy_graph import PPOTFPolicy
|
||||
from ray.rllib.evaluation import SampleBatch
|
||||
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
|
||||
from ray.rllib.optimizers import AsyncGradientsOptimizer, AsyncSamplesOptimizer
|
||||
@@ -240,12 +240,12 @@ class AsyncSamplesOptimizerTest(unittest.TestCase):
|
||||
|
||||
local = PolicyEvaluator(
|
||||
env_creator=lambda _: gym.make("CartPole-v0"),
|
||||
policy_graph=PPOPolicyGraph,
|
||||
policy_graph=PPOTFPolicy,
|
||||
tf_session_creator=make_sess)
|
||||
remotes = [
|
||||
PolicyEvaluator.as_remote().remote(
|
||||
env_creator=lambda _: gym.make("CartPole-v0"),
|
||||
policy_graph=PPOPolicyGraph,
|
||||
policy_graph=PPOTFPolicy,
|
||||
tf_session_creator=make_sess)
|
||||
]
|
||||
return local, remotes
|
||||
|
||||
@@ -0,0 +1,32 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
|
||||
class UsageTrackingDict(dict):
    """Dict that tracks which keys have been accessed.

    It can also intercept gets and allow an arbitrary callback to be applied
    (i.e., to lazily convert numpy arrays to Tensors). The result of the
    callback is cached per key, so the callback runs at most once for each
    stored value.

    We make the simplifying assumption only __getitem__ is used to access
    values, and only __setitem__ / __delitem__ are used to change them
    (e.g., dict.update() bypasses the cache invalidation below).
    """

    def __init__(self, *args, **kwargs):
        dict.__init__(self, *args, **kwargs)
        # Keys that have been successfully read through __getitem__.
        self.accessed_keys = set()
        # Cache of interceptor outputs, keyed like the dict itself.
        self.intercepted_values = {}
        # Optional callback applied to values on read; None disables it.
        self.get_interceptor = None

    def set_get_interceptor(self, fn):
        """Install fn as the callback applied to values read via d[key].

        Any values cached for a previously installed interceptor are
        dropped so that reads never return results of the old callback.
        """
        self.get_interceptor = fn
        self.intercepted_values.clear()

    def __getitem__(self, key):
        # Look up first so that a missing key raises KeyError without being
        # recorded as accessed.
        value = dict.__getitem__(self, key)
        self.accessed_keys.add(key)
        if self.get_interceptor is not None:
            if key not in self.intercepted_values:
                self.intercepted_values[key] = self.get_interceptor(value)
            value = self.intercepted_values[key]
        return value

    def __setitem__(self, key, value):
        dict.__setitem__(self, key, value)
        # The cached intercepted value belongs to the old raw value.
        self.intercepted_values.pop(key, None)

    def __delitem__(self, key):
        dict.__delitem__(self, key)
        self.intercepted_values.pop(key, None)
|
||||
Reference in New Issue
Block a user