mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 23:54:34 +08:00
[rllib] Document env compatibility, Ape-X support for multi-agent (#3147)
This commit is contained in:
@@ -47,7 +47,7 @@ class ApexDDPGAgent(DDPGAgent):
|
||||
def default_resource_request(cls, config):
|
||||
cf = merge_dicts(cls._default_config, config)
|
||||
return Resources(
|
||||
cpu=1 + cf["optimizer"]["num_replay_buffer_shards"],
|
||||
cpu=1,
|
||||
gpu=cf["gpu"] and cf["gpu_fraction"] or 0,
|
||||
extra_cpu=cf["num_cpus_per_worker"] * cf["num_workers"],
|
||||
extra_gpu=cf["num_gpus_per_worker"] * cf["num_workers"])
|
||||
|
||||
@@ -50,7 +50,7 @@ class ApexAgent(DQNAgent):
|
||||
def default_resource_request(cls, config):
|
||||
cf = merge_dicts(cls._default_config, config)
|
||||
return Resources(
|
||||
cpu=1 + cf["optimizer"]["num_replay_buffer_shards"],
|
||||
cpu=1,
|
||||
gpu=cf["gpu"] and cf["gpu_fraction"] or 0,
|
||||
extra_cpu=cf["num_cpus_per_worker"] * cf["num_workers"],
|
||||
extra_gpu=cf["num_gpus_per_worker"] * cf["num_workers"])
|
||||
|
||||
@@ -85,7 +85,6 @@ if __name__ == "__main__":
|
||||
"custom_model": ["model1", "model2"][i % 2],
|
||||
},
|
||||
"gamma": random.choice([0.5, 0.8, 0.9, 0.95, 0.99]),
|
||||
"n_step": random.choice([1, 2, 3, 4, 5]),
|
||||
}
|
||||
return (PPOPolicyGraph, obs_space, act_space, config)
|
||||
|
||||
@@ -98,12 +97,13 @@ if __name__ == "__main__":
|
||||
|
||||
run_experiments({
|
||||
"test": {
|
||||
"run": "PG",
|
||||
"run": "PPO",
|
||||
"env": "multi_cartpole",
|
||||
"stop": {
|
||||
"training_iteration": args.num_iters
|
||||
},
|
||||
"config": {
|
||||
"simple_optimizer": True,
|
||||
"multiagent": {
|
||||
"policy_graphs": policy_graphs,
|
||||
"policy_mapping_fn": tune.function(
|
||||
|
||||
@@ -6,6 +6,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
@@ -15,9 +16,10 @@ import numpy as np
|
||||
from six.moves import queue
|
||||
|
||||
import ray
|
||||
from ray.rllib.evaluation.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \
|
||||
MultiAgentBatch
|
||||
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
|
||||
from ray.rllib.optimizers.replay_buffer import PrioritizedReplayBuffer
|
||||
from ray.rllib.evaluation.sample_batch import SampleBatch
|
||||
from ray.rllib.utils.actors import TaskPool, create_colocated
|
||||
from ray.rllib.utils.timer import TimerStat
|
||||
from ray.rllib.utils.window_stat import WindowStat
|
||||
@@ -43,49 +45,61 @@ class ReplayActor(object):
|
||||
self.prioritized_replay_beta = prioritized_replay_beta
|
||||
self.prioritized_replay_eps = prioritized_replay_eps
|
||||
|
||||
self.replay_buffer = PrioritizedReplayBuffer(
|
||||
self.buffer_size, alpha=prioritized_replay_alpha)
|
||||
def new_buffer():
|
||||
return PrioritizedReplayBuffer(
|
||||
self.buffer_size, alpha=prioritized_replay_alpha)
|
||||
|
||||
self.replay_buffers = collections.defaultdict(new_buffer)
|
||||
|
||||
# Metrics
|
||||
self.add_batch_timer = TimerStat()
|
||||
self.replay_timer = TimerStat()
|
||||
self.update_priorities_timer = TimerStat()
|
||||
self.num_added = 0
|
||||
|
||||
def get_host(self):
|
||||
return os.uname()[1]
|
||||
|
||||
def add_batch(self, batch):
|
||||
PolicyOptimizer._check_not_multiagent(batch)
|
||||
# Handle everything as if multiagent
|
||||
if isinstance(batch, SampleBatch):
|
||||
batch = MultiAgentBatch({DEFAULT_POLICY_ID: batch}, batch.count)
|
||||
with self.add_batch_timer:
|
||||
for row in batch.rows():
|
||||
self.replay_buffer.add(row["obs"], row["actions"],
|
||||
row["rewards"], row["new_obs"],
|
||||
row["dones"], row["weights"])
|
||||
for policy_id, s in batch.policy_batches.items():
|
||||
for row in s.rows():
|
||||
self.replay_buffers[policy_id].add(
|
||||
row["obs"], row["actions"], row["rewards"],
|
||||
row["new_obs"], row["dones"], row["weights"])
|
||||
self.num_added += batch.count
|
||||
|
||||
def replay(self):
|
||||
if self.num_added < self.replay_starts:
|
||||
return None
|
||||
|
||||
with self.replay_timer:
|
||||
if len(self.replay_buffer) < self.replay_starts:
|
||||
return None
|
||||
samples = {}
|
||||
for policy_id, replay_buffer in self.replay_buffers.items():
|
||||
(obses_t, actions, rewards, obses_tp1, dones, weights,
|
||||
batch_indexes) = replay_buffer.sample(
|
||||
self.train_batch_size, beta=self.prioritized_replay_beta)
|
||||
samples[policy_id] = SampleBatch({
|
||||
"obs": obses_t,
|
||||
"actions": actions,
|
||||
"rewards": rewards,
|
||||
"new_obs": obses_tp1,
|
||||
"dones": dones,
|
||||
"weights": weights,
|
||||
"batch_indexes": batch_indexes
|
||||
})
|
||||
return MultiAgentBatch(samples, self.train_batch_size)
|
||||
|
||||
(obses_t, actions, rewards, obses_tp1, dones, weights,
|
||||
batch_indexes) = self.replay_buffer.sample(
|
||||
self.train_batch_size, beta=self.prioritized_replay_beta)
|
||||
|
||||
batch = SampleBatch({
|
||||
"obs": obses_t,
|
||||
"actions": actions,
|
||||
"rewards": rewards,
|
||||
"new_obs": obses_tp1,
|
||||
"dones": dones,
|
||||
"weights": weights,
|
||||
"batch_indexes": batch_indexes
|
||||
})
|
||||
return batch
|
||||
|
||||
def update_priorities(self, batch_indexes, td_errors):
|
||||
def update_priorities(self, prio_dict):
|
||||
with self.update_priorities_timer:
|
||||
new_priorities = (np.abs(td_errors) + self.prioritized_replay_eps)
|
||||
self.replay_buffer.update_priorities(batch_indexes, new_priorities)
|
||||
for policy_id, (batch_indexes, td_errors) in prio_dict.items():
|
||||
new_priorities = (
|
||||
np.abs(td_errors) + self.prioritized_replay_eps)
|
||||
self.replay_buffers[policy_id].update_priorities(
|
||||
batch_indexes, new_priorities)
|
||||
|
||||
def stats(self, debug=False):
|
||||
stat = {
|
||||
@@ -94,7 +108,10 @@ class ReplayActor(object):
|
||||
"update_priorities_time_ms": round(
|
||||
1000 * self.update_priorities_timer.mean, 3),
|
||||
}
|
||||
stat.update(self.replay_buffer.stats(debug=debug))
|
||||
for policy_id, replay_buffer in self.replay_buffers.items():
|
||||
stat.update({
|
||||
"policy_{}".format(policy_id): replay_buffer.stats(debug=debug)
|
||||
})
|
||||
return stat
|
||||
|
||||
|
||||
@@ -126,10 +143,16 @@ class LearnerThread(threading.Thread):
|
||||
with self.queue_timer:
|
||||
ra, replay = self.inqueue.get()
|
||||
if replay is not None:
|
||||
prio_dict = {}
|
||||
with self.grad_timer:
|
||||
td_error = self.local_evaluator.compute_apply(replay)[
|
||||
"td_error"]
|
||||
self.outqueue.put((ra, replay, td_error, replay.count))
|
||||
grad_out = self.local_evaluator.compute_apply(replay)
|
||||
for pid, info in grad_out.items():
|
||||
prio_dict[pid] = (
|
||||
replay.policy_batches[pid]["batch_indexes"],
|
||||
info["td_error"])
|
||||
# send `replay` back also so that it gets released by the original
|
||||
# thread: https://github.com/ray-project/ray/issues/2610
|
||||
self.outqueue.put((ra, replay, prio_dict, replay.count))
|
||||
self.learner_queue_size.push(self.inqueue.qsize())
|
||||
self.weights_updated = True
|
||||
|
||||
@@ -267,8 +290,8 @@ class AsyncReplayOptimizer(PolicyOptimizer):
|
||||
|
||||
with self.timers["update_priorities"]:
|
||||
while not self.learner.outqueue.empty():
|
||||
ra, replay, td_error, count = self.learner.outqueue.get()
|
||||
ra.update_priorities.remote(replay["batch_indexes"], td_error)
|
||||
ra, _, prio_dict, count = self.learner.outqueue.get()
|
||||
ra.update_priorities.remote(prio_dict)
|
||||
train_timesteps += count
|
||||
|
||||
return sample_timesteps, train_timesteps
|
||||
|
||||
Reference in New Issue
Block a user