[rllib] use ray.wait to speed up parallel simulations for policy gradients (#754)

* use ray.wait to speed up parallel simulations for policy gradients

* linting
This commit is contained in:
Philipp Moritz
2017-07-19 23:09:15 +00:00
committed by Robert Nishihara
parent 2b3190ad13
commit ade6d80820
+16 -8
View File
@@ -87,17 +87,25 @@ def collect_samples(agents, num_timesteps, gamma, lam, horizon,
trajectories = []
total_rewards = []
traj_len_means = []
# This dictionary maps the object ID of each trajectory that is currently
# being computed to the agent computing it; we launch one initial task per
# agent here.
agent_dict = {agent.compute_trajectory.remote(gamma, lam, horizon):
agent for agent in agents}
while num_timesteps_so_far < num_timesteps:
trajectory_batch = ray.get(
[agent.compute_trajectory.remote(gamma, lam, horizon)
for agent in agents])
trajectory = concatenate(trajectory_batch)
trajectory = flatten(trajectory)
# TODO(pcm): Make wait support arbitrary iterators and remove the
# conversion to list here.
[next_trajectory], waiting_trajectories = ray.wait(
list(agent_dict.keys()))
agent = agent_dict.pop(next_trajectory)
# Start a new trajectory task on the agent that just finished and record
# the new object ID in the dictionary.
agent_dict[agent.compute_trajectory.remote(gamma, lam, horizon)] = (
agent)
trajectory = flatten(ray.get(next_trajectory))
not_done = np.logical_not(trajectory["dones"])
total_rewards.append(
trajectory["raw_rewards"][not_done].sum(axis=0).mean() /
len(agents))
traj_len_means.append(not_done.sum(axis=0).mean() / len(agents))
trajectory["raw_rewards"][not_done].sum(axis=0).mean())
traj_len_means.append(not_done.sum(axis=0).mean())
trajectory = {key: val[not_done] for key, val in trajectory.items()}
num_timesteps_so_far += len(trajectory["dones"])
trajectories.append(trajectory)