From a644060daa4c22aabedcfd9b891d43f9afa88f97 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Sat, 7 Mar 2020 14:47:58 -0800 Subject: [PATCH] [rllib] First pass at pipeline implementation of DQN (#7433) * wip iters * add test * speed up * update docs * document it * support serial sampling * add test * spacing * annotate it * update * rename to pipeline * comment * iter2 wip * update * update * context test * update * fix * fix * a3c pipeline * doc * update * move timer * comment * add pipeline test * fix * clean up * document * iter s * wip dqn * wip * wip * metrics * metrics rename * metrics ctx * wip * constants * add todo * support .union * wip * support union * remove prints * add todo * remove auto timer * fix up * fix pipeline test * typing * fix breakage * remove bad assert * wip * fix multiagent example * fixapply * update a3c * remove a2c pl * 0 workers * wip * wip * share metrics * wip * wip * doc * fix weight sync and global var updates * mode * fix * fix * doc * fix --- python/ray/tune/progress_reporter.py | 9 +- python/ray/util/iter.py | 45 +++--- rllib/agents/dqn/dqn.py | 31 +++- rllib/agents/trainer_template.py | 4 +- rllib/evaluation/rollout_worker.py | 4 +- rllib/evaluation/worker_set.py | 7 + rllib/tests/test_supported_spaces.py | 28 ++-- rllib/utils/experimental_dsl.py | 204 ++++++++++++++++++++++++--- 8 files changed, 258 insertions(+), 74 deletions(-) diff --git a/python/ray/tune/progress_reporter.py b/python/ray/tune/progress_reporter.py index 36b4592f3..e57eb5804 100644 --- a/python/ray/tune/progress_reporter.py +++ b/python/ray/tune/progress_reporter.py @@ -47,14 +47,14 @@ class TuneReporterBase(ProgressReporter): """Abstract base class for the default Tune reporters.""" # Truncated representations of column names (to accommodate small screens). 
- DEFAULT_COLUMNS = { - EPISODE_REWARD_MEAN: "reward", + DEFAULT_COLUMNS = collections.OrderedDict({ MEAN_ACCURACY: "acc", MEAN_LOSS: "loss", + TRAINING_ITERATION: "iter", TIME_TOTAL_S: "total time (s)", TIMESTEPS_TOTAL: "ts", - TRAINING_ITERATION: "iter", - } + EPISODE_REWARD_MEAN: "reward", + }) def __init__(self, metric_columns=None, @@ -301,7 +301,6 @@ def trial_progress_str(trials, metric_columns, fmt="psql", max_rows=None): k for k in keys if any( t.last_result.get(k) is not None for t in trials) ] - keys = sorted(keys) # Build trial rows. params = sorted(set().union(*[t.evaluated_params for t in trials])) trial_table = [_get_trial_info(trial, params, keys) for trial in trials] diff --git a/python/ray/util/iter.py b/python/ray/util/iter.py index 6c9583c88..6595c4326 100644 --- a/python/ray/util/iter.py +++ b/python/ray/util/iter.py @@ -776,36 +776,35 @@ class LocalIterator(Generic[T]): if i >= n: break - def union(self, other: "LocalIterator[T]", + def union(self, *others: "LocalIterator[T]", deterministic: bool = False) -> "LocalIterator[T]": - """Return an iterator that is the union of this and the other. + """Return an iterator that is the union of this and the others. If deterministic=True, we alternate between reading from one iterator - and the other. Otherwise we return items from iterators as they + and the others. Otherwise we return items from iterators as they become ready. 
""" - if not isinstance(other, LocalIterator): - raise ValueError( - "other must be of type LocalIterator, got {}".format( - type(other))) + for it in others: + if not isinstance(it, LocalIterator): + raise ValueError( + "other must be of type LocalIterator, got {}".format( + type(it))) if deterministic: timeout = None else: timeout = 0 - it1 = LocalIterator( - self.base_iterator, - self.metrics, - self.local_transforms, - timeout=timeout) - it2 = LocalIterator( - other.base_iterator, - other.metrics, - other.local_transforms, - timeout=timeout) - active = [it1, it2] + active = [] + shared_metrics = MetricsContext() + for it in [self] + list(others): + active.append( + LocalIterator( + it.base_iterator, + shared_metrics, + it.local_transforms, + timeout=timeout)) def build_union(timeout=None): while True: @@ -826,15 +825,11 @@ class LocalIterator(Generic[T]): if not active: break - # TODO(ekl) is this the best way to represent union() of metrics? - new_ctx = MetricsContext() - new_ctx.parent_metrics.append(self.metrics) - new_ctx.parent_metrics.append(other.metrics) - return LocalIterator( build_union, - new_ctx, [], - name="LocalUnion[{}, {}]".format(self, other)) + shared_metrics, [], + name="LocalUnion[{}, {}]".format(self, ", ".join(map(str, + others)))) class ParallelIteratorWorker(object): diff --git a/rllib/agents/dqn/dqn.py b/rllib/agents/dqn/dqn.py index 1b7330c44..8a98a6cb1 100644 --- a/rllib/agents/dqn/dqn.py +++ b/rllib/agents/dqn/dqn.py @@ -5,9 +5,13 @@ from ray.rllib.agents.trainer_template import build_trainer from ray.rllib.agents.dqn.dqn_policy import DQNTFPolicy from ray.rllib.agents.dqn.simple_q_policy import SimpleQPolicy from ray.rllib.optimizers import SyncReplayOptimizer +from ray.rllib.optimizers.replay_buffer import ReplayBuffer from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID from ray.rllib.utils.deprecation import deprecation_warning, DEPRECATED_VALUE from ray.rllib.utils.exploration import PerWorkerEpsilonGreedy +from 
ray.rllib.utils.experimental_dsl import ( + ParallelRollouts, Concurrently, StoreToReplayBuffer, LocalReplay, + TrainOneStep, StandardMetricsReporting, UpdateTargetNetwork) logger = logging.getLogger(__name__) @@ -308,6 +312,30 @@ def update_target_if_needed(trainer, fetches): trainer.state["num_target_updates"] += 1 +# Experimental pipeline-based impl; enable with "use_pipeline_impl": True. +def training_pipeline(workers, config): + local_replay_buffer = ReplayBuffer(config["buffer_size"]) + rollouts = ParallelRollouts(workers, mode="bulk_sync") + + # We execute the following steps concurrently: + # (1) Generate rollouts and store them in our local replay buffer. Calling + # next() on store_op drives this. + store_op = rollouts.for_each(StoreToReplayBuffer(local_replay_buffer)) + + # (2) Read and train on experiences from the replay buffer. Every batch + # returned from the LocalReplay() iterator is passed to TrainOneStep to + # take a SGD step, and then we decide whether to update the target network. + replay_op = LocalReplay(local_replay_buffer, config["train_batch_size"]) \ + .for_each(TrainOneStep(workers)) \ + .for_each(UpdateTargetNetwork( + workers, config["target_network_update_freq"])) + + # Alternate deterministically between (1) and (2). 
+ train_op = Concurrently([store_op, replay_op], mode="round_robin") + + return StandardMetricsReporting(train_op, workers, config) + + GenericOffPolicyTrainer = build_trainer( name="GenericOffPolicyAlgorithm", default_policy=None, @@ -317,7 +345,8 @@ GenericOffPolicyTrainer = build_trainer( make_policy_optimizer=make_policy_optimizer, before_train_step=update_worker_exploration, after_optimizer_step=update_target_if_needed, - after_train_result=after_train_result) + after_train_result=after_train_result, + training_pipeline=training_pipeline) DQNTrainer = GenericOffPolicyTrainer.with_updates( name="DQN", default_policy=DQNTFPolicy, default_config=DEFAULT_CONFIG) diff --git a/rllib/agents/trainer_template.py b/rllib/agents/trainer_template.py index 10aa2b4c7..87e2a14a6 100644 --- a/rllib/agents/trainer_template.py +++ b/rllib/agents/trainer_template.py @@ -168,10 +168,10 @@ def build_trainer(name, def _train_pipeline(self): if before_train_step: - before_train_step(self) + logger.warning("Ignoring before_train_step callback") res = next(self.train_pipeline) if after_train_result: - after_train_result(self, res) + logger.warning("Ignoring after_train_result callback") return res @override(Trainer) diff --git a/rllib/evaluation/rollout_worker.py b/rllib/evaluation/rollout_worker.py index 597dd9fc6..b55a6778f 100644 --- a/rllib/evaluation/rollout_worker.py +++ b/rllib/evaluation/rollout_worker.py @@ -546,9 +546,11 @@ class RolloutWorker(EvaluatorInterface, ParallelIteratorWorker): } @override(EvaluatorInterface) - def set_weights(self, weights): + def set_weights(self, weights, global_vars=None): for pid, w in weights.items(): self.policy_map[pid].set_weights(w) + if global_vars: + self.set_global_vars(global_vars) @override(EvaluatorInterface) def compute_gradients(self, samples): diff --git a/rllib/evaluation/worker_set.py b/rllib/evaluation/worker_set.py index 83a224297..c4922d2eb 100644 --- a/rllib/evaluation/worker_set.py +++ b/rllib/evaluation/worker_set.py @@ 
-1,6 +1,7 @@ import logging from types import FunctionType +import ray from ray.rllib.utils.annotations import DeveloperAPI from ray.rllib.evaluation.rollout_worker import RolloutWorker, \ _validate_multiagent_config @@ -71,6 +72,12 @@ class WorkerSet: """Return a list of remote rollout workers.""" return self._remote_workers + def sync_weights(self): + """Syncs weights of remote workers with the local worker.""" + weights = ray.put(self.local_worker().get_weights()) + for e in self.remote_workers(): + e.set_weights.remote(weights) + def add_workers(self, num_workers): """Creates and add a number of remote workers to this worker set. diff --git a/rllib/tests/test_supported_spaces.py b/rllib/tests/test_supported_spaces.py index 6a732042b..e76268d05 100644 --- a/rllib/tests/test_supported_spaces.py +++ b/rllib/tests/test_supported_spaces.py @@ -97,18 +97,18 @@ def check_support(alg, config, stats, check_bounds=False, name=None): if alg not in ["DDPG", "ES", "ARS", "SAC"]: if o_name in ["atari", "image"]: if torch: - assert isinstance( - a.get_policy().model, TorchVisionNetV2) + assert isinstance(a.get_policy().model, + TorchVisionNetV2) else: - assert isinstance( - a.get_policy().model, VisionNetV2) + assert isinstance(a.get_policy().model, + VisionNetV2) elif o_name in ["vector", "vector2"]: if torch: - assert isinstance( - a.get_policy().model, TorchFCNetV2) + assert isinstance(a.get_policy().model, + TorchFCNetV2) else: - assert isinstance( - a.get_policy().model, FCNetV2) + assert isinstance(a.get_policy().model, + FCNetV2) a.train() covered_a.add(a_name) covered_o.add(o_name) @@ -159,12 +159,7 @@ class ModelSupportedSpaces(unittest.TestCase): ray.shutdown() def test_a3c(self): - config = { - "num_workers": 1, - "optimizer": { - "grads_per_step": 1 - } - } + config = {"num_workers": 1, "optimizer": {"grads_per_step": 1}} check_support("A3C", config, self.stats, check_bounds=True) config["use_pytorch"] = True check_support("A3C", config, self.stats, 
check_bounds=True) @@ -228,10 +223,7 @@ class ModelSupportedSpaces(unittest.TestCase): check_support("PPO", config, self.stats, check_bounds=True) def test_pg(self): - config = { - "num_workers": 1, - "optimizer": {} - } + config = {"num_workers": 1, "optimizer": {}} check_support("PG", config, self.stats, check_bounds=True) config["use_pytorch"] = True check_support("PG", config, self.stats, check_bounds=True) diff --git a/rllib/utils/experimental_dsl.py b/rllib/utils/experimental_dsl.py index 637fba842..f2e581798 100644 --- a/rllib/utils/experimental_dsl.py +++ b/rllib/utils/experimental_dsl.py @@ -4,28 +4,40 @@ TODO(ekl): describe the concepts.""" import logging from typing import List, Any, Tuple, Union +import numpy as np import time import ray from ray.util.iter import from_actors, LocalIterator from ray.util.iter_metrics import MetricsContext +from ray.rllib.optimizers.replay_buffer import PrioritizedReplayBuffer from ray.rllib.evaluation.metrics import collect_episodes, \ summarize_episodes, get_learner_stats from ray.rllib.evaluation.rollout_worker import get_global_worker from ray.rllib.evaluation.worker_set import WorkerSet -from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch +from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch, \ + DEFAULT_POLICY_ID +from ray.rllib.utils.compression import pack_if_needed logger = logging.getLogger(__name__) -# Metrics context key definitions. +# Counters for training progress (keys for metrics.counters). STEPS_SAMPLED_COUNTER = "num_steps_sampled" STEPS_TRAINED_COUNTER = "num_steps_trained" + +# Counters to track target network updates. +LAST_TARGET_UPDATE_TS = "last_target_update_ts" +NUM_TARGET_UPDATES = "num_target_updates" + +# Performance timers (keys for metrics.timers). 
APPLY_GRADS_TIMER = "apply_grad" COMPUTE_GRADS_TIMER = "compute_grads" WORKER_UPDATE_TIMER = "update" GRAD_WAIT_TIMER = "grad_wait" SAMPLE_TIMER = "sample" LEARN_ON_BATCH_TIMER = "learn" + +# Instant metrics (keys for metrics.info). LEARNER_INFO = "learner" # Type aliases. @@ -33,12 +45,19 @@ GradientType = dict SampleBatchType = Union[SampleBatch, MultiAgentBatch] +# Asserts that an object is a type of SampleBatch. def _check_sample_batch_type(batch): if not isinstance(batch, SampleBatchType.__args__): raise ValueError("Expected either SampleBatch or MultiAgentBatch, " "got {}: {}".format(type(batch), batch)) +# Returns pipeline global vars that should be periodically sent to each worker. +def _get_global_vars(): + metrics = LocalIterator.get_metrics() + return {"timestep": metrics.counters[STEPS_SAMPLED_COUNTER]} + + def ParallelRollouts(workers: WorkerSet, mode="bulk_sync") -> LocalIterator[SampleBatch]: """Operator to collect experiences in parallel from rollout workers. @@ -71,6 +90,9 @@ def ParallelRollouts(workers: WorkerSet, Updates the STEPS_SAMPLED_COUNTER counter in the local iterator context. """ + # Ensure workers are initially in sync. + workers.sync_weights() + def report_timesteps(batch): metrics = LocalIterator.get_metrics() metrics.counters[STEPS_SAMPLED_COUNTER] += batch.count @@ -119,6 +141,9 @@ def AsyncGradients( local iterator context. """ + # Ensure workers are initially in sync. + workers.sync_weights() + # This function will be applied remotely on the workers. def samples_to_grads(samples): return get_global_worker().compute_gradients(samples), samples.count @@ -240,7 +265,9 @@ class TrainOneStep: with metrics.timers[WORKER_UPDATE_TIMER]: weights = ray.put(self.workers.local_worker().get_weights()) for e in self.workers.remote_workers(): - e.set_weights.remote(weights) + e.set_weights.remote(weights, _get_global_vars()) + # Also update global vars of the local worker. 
+ self.workers.local_worker().set_global_vars(_get_global_vars()) return info @@ -266,9 +293,7 @@ class CollectMetrics: self.timeout_seconds = timeout_seconds def __call__(self, _): - metrics = LocalIterator.get_metrics() - if metrics.parent_metrics: - raise ValueError("TODO: support nested metrics") + # Collect worker metrics. episodes, self.to_be_collected = collect_episodes( self.workers.local_worker(), self.workers.remote_workers(), @@ -282,22 +307,31 @@ class CollectMetrics: self.episode_history.extend(orig_episodes) self.episode_history = self.episode_history[-self.min_history:] res = summarize_episodes(episodes, orig_episodes) - res.update(info=metrics.info) - res["info"].update({ - STEPS_SAMPLED_COUNTER: metrics.counters[STEPS_SAMPLED_COUNTER], - STEPS_TRAINED_COUNTER: metrics.counters[STEPS_TRAINED_COUNTER], - }) + + # Add in iterator metrics. + metrics = LocalIterator.get_metrics() + if metrics.parent_metrics: + print("TODO: support nested metrics better") + all_metrics = [metrics] + metrics.parent_metrics timers = {} - for k, timer in metrics.timers.items(): - timers["{}_time_ms".format(k)] = round(timer.mean * 1000, 3) - if timer.has_units_processed(): - timers["{}_throughput".format(k)] = round( - timer.mean_throughput, 3) + counters = {} + info = {} + for metrics in all_metrics: + info.update(metrics.info) + for k, counter in metrics.counters.items(): + counters[k] = counter + for k, timer in metrics.timers.items(): + timers["{}_time_ms".format(k)] = round(timer.mean * 1000, 3) + if timer.has_units_processed(): + timers["{}_throughput".format(k)] = round( + timer.mean_throughput, 3) + res.update({ + "num_healthy_workers": len(self.workers.remote_workers()), + "timesteps_total": metrics.counters[STEPS_SAMPLED_COUNTER], + }) res["timers"] = timers - res.update({ - "num_healthy_workers": len(self.workers.remote_workers()), - "timesteps_total": metrics.counters[STEPS_SAMPLED_COUNTER], - }) + res["info"] = info + res["info"].update(counters) return res @@ 
-392,13 +426,16 @@ class ApplyGradients: self.workers.local_worker().apply_gradients(gradients) apply_timer.push_units_processed(count) + # Also update global vars of the local worker. + self.workers.local_worker().set_global_vars(_get_global_vars()) + if self.update_all: if self.workers.remote_workers(): with metrics.timers[WORKER_UPDATE_TIMER]: weights = ray.put( self.workers.local_worker().get_weights()) for e in self.workers.remote_workers(): - e.set_weights.remote(weights) + e.set_weights.remote(weights, _get_global_vars()) else: if metrics.cur_actor is None: raise ValueError("Could not find actor to update. When " @@ -406,7 +443,8 @@ class ApplyGradients: "in the iterator context.") with metrics.timers[WORKER_UPDATE_TIMER]: weights = self.workers.local_worker().get_weights() - metrics.cur_actor.set_weights.remote(weights) + metrics.cur_actor.set_weights.remote(weights, + _get_global_vars()) class AverageGradients: @@ -434,3 +472,125 @@ class AverageGradients: logger.info("Computing average of {} microbatch gradients " "({} samples total)".format(len(gradients), sum_count)) return acc, sum_count + + +class StoreToReplayBuffer: + def __init__(self, replay_buffer): + self.replay_buffers = {DEFAULT_POLICY_ID: replay_buffer} + + def __call__(self, batch: SampleBatchType): + # Handle everything as if multiagent + if isinstance(batch, SampleBatch): + batch = MultiAgentBatch({DEFAULT_POLICY_ID: batch}, batch.count) + + for policy_id, s in batch.policy_batches.items(): + for row in s.rows(): + self.replay_buffers[policy_id].add( + pack_if_needed(row["obs"]), + row["actions"], + row["rewards"], + pack_if_needed(row["new_obs"]), + row["dones"], + weight=None) + + +def LocalReplay(replay_buffer, train_batch_size): + replay_buffers = {DEFAULT_POLICY_ID: replay_buffer} + # TODO(ekl) support more options + synchronize_sampling = False + prioritized_replay_beta = None + + def gen_replay(timeout): + while True: + samples = {} + idxes = None + for policy_id, replay_buffer in 
replay_buffers.items(): + if synchronize_sampling: + if idxes is None: + idxes = replay_buffer.sample_idxes(train_batch_size) + else: + idxes = replay_buffer.sample_idxes(train_batch_size) + + if isinstance(replay_buffer, PrioritizedReplayBuffer): + metrics = LocalIterator.get_metrics() + num_steps_trained = metrics.counters[STEPS_TRAINED_COUNTER] + (obses_t, actions, rewards, obses_tp1, dones, weights, + batch_indexes) = replay_buffer.sample_with_idxes( + idxes, + beta=prioritized_replay_beta.value(num_steps_trained)) + else: + (obses_t, actions, rewards, obses_tp1, + dones) = replay_buffer.sample_with_idxes(idxes) + weights = np.ones_like(rewards) + batch_indexes = -np.ones_like(rewards) + samples[policy_id] = SampleBatch({ + "obs": obses_t, + "actions": actions, + "rewards": rewards, + "new_obs": obses_tp1, + "dones": dones, + "weights": weights, + "batch_indexes": batch_indexes + }) + yield MultiAgentBatch(samples, train_batch_size) + + return LocalIterator(gen_replay, MetricsContext()) + + +def Concurrently(ops: List[LocalIterator], mode="round_robin"): + """Operator that runs the given parent iterators concurrently. + + Arguments: + mode (str): One of {'round_robin', 'async'}. + - In 'round_robin' mode, we alternate between pulling items from + each parent iterator in order deterministically. + - In 'async' mode, we pull from each parent iterator as fast as + they are produced. This is non-deterministic. + + >>> sim_op = ParallelRollouts(...).for_each(...) + >>> replay_op = LocalReplay(...).for_each(...) + >>> combined_op = Concurrently([sim_op, replay_op]) + """ + + if len(ops) < 2: + raise ValueError("Should specify at least 2 ops.") + if mode == "round_robin": + deterministic = True + elif mode == "async": + deterministic = False + else: + raise ValueError("Unknown mode {}".format(mode)) + return ops[0].union(*ops[1:], deterministic=deterministic) + + +class UpdateTargetNetwork: + """Periodically call policy.update_target() on all trainable policies. 
+ + This should be used with the .for_each() operator after a training step + has been taken. + + Examples: + >>> train_op = ParallelRollouts(...).for_each(TrainOneStep(...)) + >>> update_op = train_op.for_each( + ... UpdateTargetNetwork(workers, target_update_freq=500)) + >>> print(next(update_op)) + None + + Updates the LAST_TARGET_UPDATE_TS and NUM_TARGET_UPDATES counters in the + local iterator context. The value of the last update counter is used to + track when we should update the target next. + """ + + def __init__(self, workers, target_update_freq): + self.workers = workers + self.target_update_freq = target_update_freq + + def __call__(self, _): + metrics = LocalIterator.get_metrics() + cur_ts = metrics.counters[STEPS_SAMPLED_COUNTER] + last_update = metrics.counters[LAST_TARGET_UPDATE_TS] + if cur_ts - last_update > self.target_update_freq: + self.workers.local_worker().foreach_trainable_policy( + lambda p, _: p.update_target()) + metrics.counters[NUM_TARGET_UPDATES] += 1 + metrics.counters[LAST_TARGET_UPDATE_TS] = cur_ts