[rllib] Allow access to batches prior to postprocessing (#4871)

This commit is contained in:
Eric Liang
2019-05-29 18:17:14 -07:00
committed by GitHub
parent a218a14c92
commit 2dd0beb5bd
5 changed files with 24 additions and 4 deletions
+4
View File
@@ -101,6 +101,10 @@ Tuned examples: `PongNoFrameskip-v4 <https://github.com/ray-project/ray/blob/mas
**APPO-specific configs** (see also `common configs <rllib-training.html#common-parameters>`__):
.. warning::
Keras custom models are not compatible with multi-GPU (this includes PPO in single-GPU mode). This is because the multi-GPU implementation in RLlib relies on variable scopes to implement cross-GPU support.
.. literalinclude:: ../../python/ray/rllib/agents/ppo/appo.py
:language: python
:start-after: __sphinx_doc_begin__
+4
View File
@@ -35,6 +35,10 @@ Custom Models (TensorFlow)
Custom TF models should subclass the common RLlib `model class <https://github.com/ray-project/ray/blob/master/python/ray/rllib/models/model.py>`__ and override the ``_build_layers_v2`` method. This method takes in a dict of tensor inputs (the observation ``obs``, ``prev_action``, and ``prev_reward``, ``is_training``), and returns a feature layer and float vector of the specified output size. You can also override the ``value_function`` method to implement a custom value branch. Additional supervised / self-supervised losses can be added via the ``custom_loss`` method. The model can then be registered and used in place of a built-in model:
.. warning::
Keras custom models are not compatible with multi-GPU (this includes PPO in single-GPU mode). This is because the multi-GPU implementation in RLlib relies on variable scopes to implement cross-GPU support.
.. code-block:: python
import ray
+8 -2
View File
@@ -54,14 +54,20 @@ COMMON_CONFIG = {
# Callbacks that will be run during various phases of training. These all
# take a single "info" dict as an argument. For episode callbacks, custom
# metrics can be attached to the episode by updating the episode object's
# custom metrics dict (see examples/custom_metrics_and_callbacks.py).
# custom metrics dict (see examples/custom_metrics_and_callbacks.py). You
# may also mutate the passed in batch data in your callback.
"callbacks": {
"on_episode_start": None, # arg: {"env": .., "episode": ...}
"on_episode_step": None, # arg: {"env": .., "episode": ...}
"on_episode_end": None, # arg: {"env": .., "episode": ...}
"on_sample_end": None, # arg: {"samples": .., "evaluator": ...}
"on_train_result": None, # arg: {"trainer": ..., "result": ...}
"on_postprocess_traj": None, # arg: {"batch": ..., "episode": ...}
"on_postprocess_traj": None, # arg: {
# "agent_id": ..., "episode": ...,
# "pre_batch": (before processing),
# "post_batch": (after processing),
# "all_pre_batches": (other agent ids),
# }
},
# Whether to attempt to continue training if a worker crashes.
"ignore_worker_failures": False,
@@ -165,7 +165,13 @@ class MultiAgentSampleBatchBuilder(object):
self.policy_builders[self.agent_to_policy[agent_id]].add_batch(
post_batch)
if self.postp_callback:
self.postp_callback({"episode": episode, "batch": post_batch})
self.postp_callback({
"episode": episode,
"agent_id": agent_id,
"pre_batch": pre_batches[agent_id],
"post_batch": post_batch,
"all_pre_batches": pre_batches,
})
self.agent_builders.clear()
self.agent_to_policy.clear()
@@ -46,7 +46,7 @@ def on_train_result(info):
def on_postprocess_traj(info):
episode = info["episode"]
batch = info["batch"]
batch = info["post_batch"]
print("postprocessed {} steps".format(batch.count))
if "num_batches" not in episode.custom_metrics:
episode.custom_metrics["num_batches"] = 0