mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 11:37:28 +08:00
[rllib] Allow access to batches prior to postprocessing (#4871)
This commit is contained in:
@@ -54,14 +54,20 @@ COMMON_CONFIG = {
|
||||
# Callbacks that will be run during various phases of training. These all
|
||||
# take a single "info" dict as an argument. For episode callbacks, custom
|
||||
# metrics can be attached to the episode by updating the episode object's
|
||||
# custom metrics dict (see examples/custom_metrics_and_callbacks.py).
|
||||
# custom metrics dict (see examples/custom_metrics_and_callbacks.py). You
|
||||
# may also mutate the passed in batch data in your callback.
|
||||
"callbacks": {
|
||||
"on_episode_start": None, # arg: {"env": .., "episode": ...}
|
||||
"on_episode_step": None, # arg: {"env": .., "episode": ...}
|
||||
"on_episode_end": None, # arg: {"env": .., "episode": ...}
|
||||
"on_sample_end": None, # arg: {"samples": .., "evaluator": ...}
|
||||
"on_train_result": None, # arg: {"trainer": ..., "result": ...}
|
||||
"on_postprocess_traj": None, # arg: {"batch": ..., "episode": ...}
|
||||
"on_postprocess_traj": None, # arg: {
|
||||
# "agent_id": ..., "episode": ...,
|
||||
# "pre_batch": (before processing),
|
||||
# "post_batch": (after processing),
|
||||
# "all_pre_batches": (other agent ids),
|
||||
# }
|
||||
},
|
||||
# Whether to attempt to continue training if a worker crashes.
|
||||
"ignore_worker_failures": False,
|
||||
|
||||
@@ -165,7 +165,13 @@ class MultiAgentSampleBatchBuilder(object):
|
||||
self.policy_builders[self.agent_to_policy[agent_id]].add_batch(
|
||||
post_batch)
|
||||
if self.postp_callback:
|
||||
self.postp_callback({"episode": episode, "batch": post_batch})
|
||||
self.postp_callback({
|
||||
"episode": episode,
|
||||
"agent_id": agent_id,
|
||||
"pre_batch": pre_batches[agent_id],
|
||||
"post_batch": post_batch,
|
||||
"all_pre_batches": pre_batches,
|
||||
})
|
||||
|
||||
self.agent_builders.clear()
|
||||
self.agent_to_policy.clear()
|
||||
|
||||
@@ -46,7 +46,7 @@ def on_train_result(info):
|
||||
|
||||
def on_postprocess_traj(info):
|
||||
episode = info["episode"]
|
||||
batch = info["batch"]
|
||||
batch = info["post_batch"]
|
||||
print("postprocessed {} steps".format(batch.count))
|
||||
if "num_batches" not in episode.custom_metrics:
|
||||
episode.custom_metrics["num_batches"] = 0
|
||||
|
||||
Reference in New Issue
Block a user