mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 17:21:06 +08:00
[rllib] use ray.wait to speed up parallel simulations for policy gradients (#754)
* use ray.wait to speed up parallel simulations for policy gradients * linting
This commit is contained in:
committed by
Robert Nishihara
parent
2b3190ad13
commit
ade6d80820
@@ -87,17 +87,25 @@ def collect_samples(agents, num_timesteps, gamma, lam, horizon,
|
||||
trajectories = []
|
||||
total_rewards = []
|
||||
traj_len_means = []
|
||||
# This variable maps the object IDs of trajectories that are currently
|
||||
# computed to the agent that they are computed on; we start some initial
|
||||
# tasks here.
|
||||
agent_dict = {agent.compute_trajectory.remote(gamma, lam, horizon):
|
||||
agent for agent in agents}
|
||||
while num_timesteps_so_far < num_timesteps:
|
||||
trajectory_batch = ray.get(
|
||||
[agent.compute_trajectory.remote(gamma, lam, horizon)
|
||||
for agent in agents])
|
||||
trajectory = concatenate(trajectory_batch)
|
||||
trajectory = flatten(trajectory)
|
||||
# TODO(pcm): Make wait support arbitrary iterators and remove the
|
||||
# conversion to list here.
|
||||
[next_trajectory], waiting_trajectories = ray.wait(
|
||||
list(agent_dict.keys()))
|
||||
agent = agent_dict.pop(next_trajectory)
|
||||
# Start task with next trajectory and record it in the dictionary.
|
||||
agent_dict[agent.compute_trajectory.remote(gamma, lam, horizon)] = (
|
||||
agent)
|
||||
trajectory = flatten(ray.get(next_trajectory))
|
||||
not_done = np.logical_not(trajectory["dones"])
|
||||
total_rewards.append(
|
||||
trajectory["raw_rewards"][not_done].sum(axis=0).mean() /
|
||||
len(agents))
|
||||
traj_len_means.append(not_done.sum(axis=0).mean() / len(agents))
|
||||
trajectory["raw_rewards"][not_done].sum(axis=0).mean())
|
||||
traj_len_means.append(not_done.sum(axis=0).mean())
|
||||
trajectory = {key: val[not_done] for key, val in trajectory.items()}
|
||||
num_timesteps_so_far += len(trajectory["dones"])
|
||||
trajectories.append(trajectory)
|
||||
|
||||
Reference in New Issue
Block a user