mirror of
https://github.com/wassname/ray.git
synced 2026-07-05 22:18:18 +08:00
[rllib] Allow access to batches prior to postprocessing (#4871)
This commit is contained in:
@@ -101,6 +101,10 @@ Tuned examples: `PongNoFrameskip-v4 <https://github.com/ray-project/ray/blob/mas
|
||||
|
||||
**APPO-specific configs** (see also `common configs <rllib-training.html#common-parameters>`__):
|
||||
|
||||
.. warning::
|
||||
|
||||
Keras custom models are not compatible with multi-GPU (this includes PPO in single-GPU mode). This is because the multi-GPU implementation in RLlib relies on variable scopes to implement cross-GPU support.
|
||||
|
||||
.. literalinclude:: ../../python/ray/rllib/agents/ppo/appo.py
|
||||
:language: python
|
||||
:start-after: __sphinx_doc_begin__
|
||||
|
||||
@@ -35,6 +35,10 @@ Custom Models (TensorFlow)
|
||||
|
||||
Custom TF models should subclass the common RLlib `model class <https://github.com/ray-project/ray/blob/master/python/ray/rllib/models/model.py>`__ and override the ``_build_layers_v2`` method. This method takes in a dict of tensor inputs (the observation ``obs``, ``prev_action``, and ``prev_reward``, ``is_training``), and returns a feature layer and float vector of the specified output size. You can also override the ``value_function`` method to implement a custom value branch. Additional supervised / self-supervised losses can be added via the ``custom_loss`` method. The model can then be registered and used in place of a built-in model:
|
||||
|
||||
.. warning::
|
||||
|
||||
Keras custom models are not compatible with multi-GPU (this includes PPO in single-GPU mode). This is because the multi-GPU implementation in RLlib relies on variable scopes to implement cross-GPU support.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import ray
|
||||
|
||||
@@ -54,14 +54,20 @@ COMMON_CONFIG = {
|
||||
# Callbacks that will be run during various phases of training. These all
|
||||
# take a single "info" dict as an argument. For episode callbacks, custom
|
||||
# metrics can be attached to the episode by updating the episode object's
|
||||
# custom metrics dict (see examples/custom_metrics_and_callbacks.py).
|
||||
# custom metrics dict (see examples/custom_metrics_and_callbacks.py). You
|
||||
# may also mutate the passed in batch data in your callback.
|
||||
"callbacks": {
|
||||
"on_episode_start": None, # arg: {"env": .., "episode": ...}
|
||||
"on_episode_step": None, # arg: {"env": .., "episode": ...}
|
||||
"on_episode_end": None, # arg: {"env": .., "episode": ...}
|
||||
"on_sample_end": None, # arg: {"samples": .., "evaluator": ...}
|
||||
"on_train_result": None, # arg: {"trainer": ..., "result": ...}
|
||||
"on_postprocess_traj": None, # arg: {"batch": ..., "episode": ...}
|
||||
"on_postprocess_traj": None, # arg: {
|
||||
# "agent_id": ..., "episode": ...,
|
||||
# "pre_batch": (before processing),
|
||||
# "post_batch": (after processing),
|
||||
# "all_pre_batches": (other agent ids),
|
||||
# }
|
||||
},
|
||||
# Whether to attempt to continue training if a worker crashes.
|
||||
"ignore_worker_failures": False,
|
||||
|
||||
@@ -165,7 +165,13 @@ class MultiAgentSampleBatchBuilder(object):
|
||||
self.policy_builders[self.agent_to_policy[agent_id]].add_batch(
|
||||
post_batch)
|
||||
if self.postp_callback:
|
||||
self.postp_callback({"episode": episode, "batch": post_batch})
|
||||
self.postp_callback({
|
||||
"episode": episode,
|
||||
"agent_id": agent_id,
|
||||
"pre_batch": pre_batches[agent_id],
|
||||
"post_batch": post_batch,
|
||||
"all_pre_batches": pre_batches,
|
||||
})
|
||||
|
||||
self.agent_builders.clear()
|
||||
self.agent_to_policy.clear()
|
||||
|
||||
@@ -46,7 +46,7 @@ def on_train_result(info):
|
||||
|
||||
def on_postprocess_traj(info):
|
||||
episode = info["episode"]
|
||||
batch = info["batch"]
|
||||
batch = info["post_batch"]
|
||||
print("postprocessed {} steps".format(batch.count))
|
||||
if "num_batches" not in episode.custom_metrics:
|
||||
episode.custom_metrics["num_batches"] = 0
|
||||
|
||||
Reference in New Issue
Block a user