From ade6d80820a6db265179a0247f1adfc4d180c5bc Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Wed, 19 Jul 2017 23:09:15 +0000 Subject: [PATCH] [rllib] use ray.wait to speed up parallel simulations for policy gradients (#754) * use ray.wait to speed up parallel simulations for policy gradients * linting --- python/ray/rllib/policy_gradient/rollout.py | 24 ++++++++++++++------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/python/ray/rllib/policy_gradient/rollout.py b/python/ray/rllib/policy_gradient/rollout.py index 356a49710..6c4f0804c 100644 --- a/python/ray/rllib/policy_gradient/rollout.py +++ b/python/ray/rllib/policy_gradient/rollout.py @@ -87,17 +87,25 @@ def collect_samples(agents, num_timesteps, gamma, lam, horizon, trajectories = [] total_rewards = [] traj_len_means = [] + # This variable maps the object IDs of trajectories that are currently + # computed to the agent that they are computed on; we start some initial + # tasks here. + agent_dict = {agent.compute_trajectory.remote(gamma, lam, horizon): + agent for agent in agents} while num_timesteps_so_far < num_timesteps: - trajectory_batch = ray.get( - [agent.compute_trajectory.remote(gamma, lam, horizon) - for agent in agents]) - trajectory = concatenate(trajectory_batch) - trajectory = flatten(trajectory) + # TODO(pcm): Make wait support arbitrary iterators and remove the + # conversion to list here. + [next_trajectory], waiting_trajectories = ray.wait( + list(agent_dict.keys())) + agent = agent_dict.pop(next_trajectory) + # Start task with next trajectory and record it in the dictionary. + agent_dict[agent.compute_trajectory.remote(gamma, lam, horizon)] = ( + agent) + trajectory = flatten(ray.get(next_trajectory)) not_done = np.logical_not(trajectory["dones"]) total_rewards.append( - trajectory["raw_rewards"][not_done].sum(axis=0).mean() / - len(agents)) - traj_len_means.append(not_done.sum(axis=0).mean() / len(agents)) + trajectory["raw_rewards"][not_done].sum(axis=0).mean()) + traj_len_means.append(not_done.sum(axis=0).mean()) trajectory = {key: val[not_done] for key, val in trajectory.items()} num_timesteps_so_far += len(trajectory["dones"]) trajectories.append(trajectory)