[rllib] Allow access to batches prior to postprocessing (#4871)

This commit is contained in:
Eric Liang
2019-05-29 18:17:14 -07:00
committed by GitHub
parent a218a14c92
commit 2dd0beb5bd
5 changed files with 24 additions and 4 deletions
+8 -2
View File
@@ -54,14 +54,20 @@ COMMON_CONFIG = {
# Callbacks that will be run during various phases of training. These all
# take a single "info" dict as an argument. For episode callbacks, custom
# metrics can be attached to the episode by updating the episode object's
# custom metrics dict (see examples/custom_metrics_and_callbacks.py).
# custom metrics dict (see examples/custom_metrics_and_callbacks.py). You
# may also mutate the passed in batch data in your callback.
"callbacks": {
"on_episode_start": None, # arg: {"env": .., "episode": ...}
"on_episode_step": None, # arg: {"env": .., "episode": ...}
"on_episode_end": None, # arg: {"env": .., "episode": ...}
"on_sample_end": None, # arg: {"samples": .., "evaluator": ...}
"on_train_result": None, # arg: {"trainer": ..., "result": ...}
"on_postprocess_traj": None, # arg: {"batch": ..., "episode": ...}
"on_postprocess_traj": None, # arg: {
# "agent_id": ..., "episode": ...,
# "pre_batch": (before processing),
# "post_batch": (after processing),
# "all_pre_batches": (other agent ids),
# }
},
# Whether to attempt to continue training if a worker crashes.
"ignore_worker_failures": False,
@@ -165,7 +165,13 @@ class MultiAgentSampleBatchBuilder(object):
self.policy_builders[self.agent_to_policy[agent_id]].add_batch(
post_batch)
if self.postp_callback:
self.postp_callback({"episode": episode, "batch": post_batch})
self.postp_callback({
"episode": episode,
"agent_id": agent_id,
"pre_batch": pre_batches[agent_id],
"post_batch": post_batch,
"all_pre_batches": pre_batches,
})
self.agent_builders.clear()
self.agent_to_policy.clear()
@@ -46,7 +46,7 @@ def on_train_result(info):
def on_postprocess_traj(info):
episode = info["episode"]
batch = info["batch"]
batch = info["post_batch"]
print("postprocessed {} steps".format(batch.count))
if "num_batches" not in episode.custom_metrics:
episode.custom_metrics["num_batches"] = 0