[rllib] Replace ray.get() with ray_get_and_free() to optimize memory usage (#4586)

This commit is contained in:
Eric Liang
2019-04-17 20:30:03 -04:00
committed by GitHub
parent 20c4c16891
commit 6848dfd179
19 changed files with 82 additions and 21 deletions
+1 -1
View File
@@ -108,7 +108,7 @@ This is how the example in the previous section looks when written using a polic
Trainers
--------
Trainers are the boilerplate classes that put the above components together. Trainer make algorithms accessible via Python API and the command line. They manage algorithm configuration, setup of the policy evaluators and optimizer, and collection of training metrics. Trainers also implement the `Trainable API <https://ray.readthedocs.io/en/latest/tune-usage.html#training-api>`__ for easy experiment management.
Trainers are the boilerplate classes that put the above components together, making algorithms accessible via Python API and the command line. They manage algorithm configuration, setup of the policy evaluators and optimizer, and collection of training metrics. Trainers also implement the `Trainable API <https://ray.readthedocs.io/en/latest/tune-usage.html#training-api>`__ for easy experiment management.
Example of two equivalent ways of interacting with the PPO trainer:
+3
View File
@@ -30,6 +30,9 @@ def free(object_ids, local_only=False, delete_creating_tasks=False):
"""
worker = ray.worker.get_global_worker()
if ray.worker._mode() == ray.worker.LOCAL_MODE:
return
if isinstance(object_ids, ray.ObjectID):
object_ids = [object_ids]
+2 -1
View File
@@ -19,6 +19,7 @@ from ray.rllib.agents.ars import policies
from ray.rllib.agents.ars import utils
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils.annotations import override
from ray.rllib.utils.memory import ray_get_and_free
from ray.rllib.utils import FilterManager
logger = logging.getLogger(__name__)
@@ -312,7 +313,7 @@ class ARSTrainer(Trainer):
worker.do_rollouts.remote(theta_id) for worker in self.workers
]
# Get the results of the rollouts.
for result in ray.get(rollout_ids):
for result in ray_get_and_free(rollout_ids):
results.append(result)
# Update the number of episodes and the number of timesteps
# keeping in mind that result.noisy_lengths is a list of lists,
+2 -1
View File
@@ -18,6 +18,7 @@ from ray.rllib.agents.es import policies
from ray.rllib.agents.es import utils
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils.annotations import override
from ray.rllib.utils.memory import ray_get_and_free
from ray.rllib.utils import FilterManager
logger = logging.getLogger(__name__)
@@ -309,7 +310,7 @@ class ESTrainer(Trainer):
worker.do_rollouts.remote(theta_id) for worker in self.workers
]
# Get the results of the rollouts.
for result in ray.get(rollout_ids):
for result in ray_get_and_free(rollout_ids):
results.append(result)
# Update the number of episodes and the number of timesteps
# keeping in mind that result.noisy_lengths is a list of lists,
+2 -1
View File
@@ -24,6 +24,7 @@ from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
from ray.rllib.utils.annotations import override, PublicAPI, DeveloperAPI
from ray.rllib.utils import FilterManager, deep_update, merge_dicts
from ray.rllib.utils.memory import ray_get_and_free
from ray.tune.registry import ENV_CREATOR, register_env, _global_registry
from ray.tune.trainable import Trainable
from ray.tune.trial import Resources, ExportFormat
@@ -668,7 +669,7 @@ class Trainer(Trainable):
for i, obj_id in enumerate(checks):
ev = self.optimizer.remote_evaluators[i]
try:
ray.get(obj_id)
ray_get_and_free(obj_id)
healthy_evaluators.append(ev)
logger.info("Worker {} looks healthy".format(i + 1))
except RayError:
+2 -1
View File
@@ -6,6 +6,7 @@ import logging
import ray
from ray.rllib.env.base_env import BaseEnv, _DUMMY_AGENT_ID, ASYNC_RESET_RETURN
from ray.rllib.utils.memory import ray_get_and_free
logger = logging.getLogger(__name__)
@@ -60,7 +61,7 @@ class RemoteVectorEnv(BaseEnv):
actor = self.pending.pop(obj_id)
env_id = self.actors.index(actor)
env_ids.add(env_id)
ob, rew, done, info = ray.get(obj_id)
ob, rew, done, info = ray_get_and_free(obj_id)
obs[env_id] = ob
rewards[env_id] = rew
dones[env_id] = done
+2 -1
View File
@@ -10,6 +10,7 @@ import ray
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.offline.off_policy_estimator import OffPolicyEstimate
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.memory import ray_get_and_free
logger = logging.getLogger(__name__)
@@ -69,7 +70,7 @@ def collect_episodes(local_evaluator=None,
"Timed out waiting for metrics from workers. You can configure "
"this timeout with `collect_metrics_timeout`.")
metric_lists = ray.get(collected)
metric_lists = ray_get_and_free(collected)
if local_evaluator:
metric_lists.append(local_evaluator.get_metrics())
episodes = []
@@ -10,6 +10,7 @@ import random
import ray
from ray.rllib.utils.actors import TaskPool
from ray.rllib.utils.annotations import override
from ray.rllib.utils.memory import ray_get_and_free
class Aggregator(object):
@@ -143,7 +144,7 @@ class AggregationWorkerBase(object):
return len(self.replay_batches) > num_needed
for ev, sample_batch in sample_futures:
sample_batch = ray.get(sample_batch)
sample_batch = ray_get_and_free(sample_batch)
yield ev, sample_batch
if can_replay():
@@ -14,6 +14,7 @@ from ray.rllib.utils.actors import TaskPool, create_colocated
from ray.rllib.utils.annotations import override
from ray.rllib.optimizers.aso_aggregator import Aggregator, \
AggregationWorkerBase
from ray.rllib.utils.memory import ray_get_and_free
logger = logging.getLogger(__name__)
@@ -86,7 +87,7 @@ class TreeAggregator(Aggregator):
def iter_train_batches(self):
assert self.initialized, "Must call init() before using this class."
for agg, batches in self.agg_tasks.completed_prefetch():
for b in ray.get(batches):
for b in ray_get_and_free(batches):
self.num_sent_since_broadcast += 1
yield b
agg.set_weights.remote(self.broadcasted_weights)
@@ -7,6 +7,7 @@ from ray.rllib.evaluation.metrics import get_learner_stats
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
from ray.rllib.utils.annotations import override
from ray.rllib.utils.timer import TimerStat
from ray.rllib.utils.memory import ray_get_and_free
class AsyncGradientsOptimizer(PolicyOptimizer):
@@ -49,7 +50,7 @@ class AsyncGradientsOptimizer(PolicyOptimizer):
ready_list = wait_results[0]
future = ready_list[0]
gradient, info = ray.get(future)
gradient, info = ray_get_and_free(future)
e = pending_gradients.pop(future)
self.learner_stats = get_learner_stats(info)
@@ -23,6 +23,7 @@ from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
from ray.rllib.optimizers.replay_buffer import PrioritizedReplayBuffer
from ray.rllib.utils.annotations import override
from ray.rllib.utils.actors import TaskPool, create_colocated
from ray.rllib.utils.memory import ray_get_and_free
from ray.rllib.utils.timer import TimerStat
from ray.rllib.utils.window_stat import WindowStat
@@ -143,7 +144,8 @@ class AsyncReplayOptimizer(PolicyOptimizer):
@override(PolicyOptimizer)
def stats(self):
replay_stats = ray.get(self.replay_actors[0].stats.remote(self.debug))
replay_stats = ray_get_and_free(self.replay_actors[0].stats.remote(
self.debug))
timing = {
"{}_time_ms".format(k): round(1000 * self.timers[k].mean, 3)
for k in self.timers
@@ -188,7 +190,7 @@ class AsyncReplayOptimizer(PolicyOptimizer):
with self.timers["sample_processing"]:
completed = list(self.sample_tasks.completed())
counts = ray.get([c[1][1] for c in completed])
counts = ray_get_and_free([c[1][1] for c in completed])
for i, (ev, (sample_batch, count)) in enumerate(completed):
sample_timesteps += counts[i]
@@ -220,7 +222,7 @@ class AsyncReplayOptimizer(PolicyOptimizer):
self.num_samples_dropped += 1
else:
with self.timers["get_samples"]:
samples = ray.get(replay)
samples = ray_get_and_free(replay)
# Defensive copy against plasma crashes, see #2610 #3452
self.learner.inqueue.put((ra, samples and samples.copy()))
@@ -4,9 +4,9 @@ from __future__ import print_function
import logging
import ray
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.evaluation.metrics import collect_episodes, summarize_episodes
from ray.rllib.utils.memory import ray_get_and_free
logger = logging.getLogger(__name__)
@@ -140,7 +140,7 @@ class PolicyOptimizer(object):
"""Apply the given function to each evaluator instance."""
local_result = [func(self.local_evaluator)]
remote_results = ray.get(
remote_results = ray_get_and_free(
[ev.apply.remote(func) for ev in self.remote_evaluators])
return local_result + remote_results
@@ -152,7 +152,7 @@ class PolicyOptimizer(object):
"""
local_result = [func(self.local_evaluator, 0)]
remote_results = ray.get([
remote_results = ray_get_and_free([
ev.apply.remote(func, i + 1)
for i, ev in enumerate(self.remote_evaluators)
])
+3 -2
View File
@@ -6,6 +6,7 @@ import logging
import ray
from ray.rllib.evaluation.sample_batch import SampleBatch
from ray.rllib.utils.memory import ray_get_and_free
logger = logging.getLogger(__name__)
@@ -25,7 +26,7 @@ def collect_samples(agents, sample_batch_size, num_envs_per_worker,
while agent_dict:
[fut_sample], _ = ray.wait(list(agent_dict))
agent = agent_dict.pop(fut_sample)
next_sample = ray.get(fut_sample)
next_sample = ray_get_and_free(fut_sample)
assert next_sample.count >= sample_batch_size * num_envs_per_worker
num_timesteps_so_far += next_sample.count
trajectories.append(next_sample)
@@ -63,7 +64,7 @@ def collect_samples_straggler_mitigation(agents, train_batch_size):
fut_sample2 = agent.sample.remote()
agent_dict[fut_sample2] = agent
next_sample = ray.get(fut_sample)
next_sample = ray_get_and_free(fut_sample)
num_timesteps_so_far += next_sample.count
trajectories.append(next_sample)
@@ -11,6 +11,7 @@ from ray.rllib.evaluation.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \
MultiAgentBatch
from ray.rllib.utils.annotations import override
from ray.rllib.utils.timer import TimerStat
from ray.rllib.utils.memory import ray_get_and_free
class SyncBatchReplayOptimizer(PolicyOptimizer):
@@ -51,7 +52,7 @@ class SyncBatchReplayOptimizer(PolicyOptimizer):
with self.sample_timer:
if self.remote_evaluators:
batches = ray.get(
batches = ray_get_and_free(
[e.sample.remote() for e in self.remote_evaluators])
else:
batches = [self.local_evaluator.sample()]
@@ -17,6 +17,7 @@ from ray.rllib.utils.annotations import override
from ray.rllib.utils.compression import pack_if_needed
from ray.rllib.utils.timer import TimerStat
from ray.rllib.utils.schedules import LinearSchedule
from ray.rllib.utils.memory import ray_get_and_free
logger = logging.getLogger(__name__)
@@ -89,7 +90,7 @@ class SyncReplayOptimizer(PolicyOptimizer):
with self.sample_timer:
if self.remote_evaluators:
batch = SampleBatch.concat_samples(
ray.get(
ray_get_and_free(
[e.sample.remote() for e in self.remote_evaluators]))
else:
batch = self.local_evaluator.sample()
@@ -10,6 +10,7 @@ from ray.rllib.evaluation.sample_batch import SampleBatch
from ray.rllib.utils.annotations import override
from ray.rllib.utils.filter import RunningStat
from ray.rllib.utils.timer import TimerStat
from ray.rllib.utils.memory import ray_get_and_free
logger = logging.getLogger(__name__)
@@ -50,7 +51,7 @@ class SyncSamplesOptimizer(PolicyOptimizer):
while sum(s.count for s in samples) < self.train_batch_size:
if self.remote_evaluators:
samples.extend(
ray.get([
ray_get_and_free([
e.sample.remote() for e in self.remote_evaluators
]))
else:
+2
View File
@@ -47,6 +47,8 @@ if __name__ == "__main__":
do_link("tune", force=args.yes)
do_link("autoscaler", force=args.yes)
do_link("scripts", force=args.yes)
do_link("internal", force=args.yes)
do_link("experimental", force=args.yes)
print("Created links.\n\nIf you run into issues initializing Ray, please "
"ensure that your local repo and the installed Ray are in sync "
"(pip install -U the latest wheels at "
+2 -1
View File
@@ -4,6 +4,7 @@ from __future__ import print_function
import ray
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.memory import ray_get_and_free
@DeveloperAPI
@@ -24,7 +25,7 @@ class FilterManager(object):
remotes (list): Remote evaluators with filters.
update_remote (bool): Whether to push updates to remote filters.
"""
remote_filters = ray.get(
remote_filters = ray_get_and_free(
[r.get_filters.remote(flush_after=True) for r in remotes])
for rf in remote_filters:
for k in local_filters:
+41
View File
@@ -3,6 +3,47 @@ from __future__ import division
from __future__ import print_function
import numpy as np
import time
import ray
# Flush the queue of pending frees when more than this many seconds have
# elapsed since the previous flush.
FREE_DELAY_S = 10.0
# Flush the queue of pending frees once it holds more than this many ids.
MAX_FREE_QUEUE_SIZE = 100
# time.time() value recorded at the most recent flush; starts at 0.0.
_last_free_time = 0.0
# Object ids already fetched by ray_get_and_free, awaiting a batched free.
_to_free = []
def ray_get_and_free(object_ids):
    """Call ray.get and then queue the object ids for deletion.

    This function should be used whenever possible in RLlib, to optimize
    memory usage. The only exception is when an object_id is shared among
    multiple readers.

    Freeing is batched: fetched ids are appended to a module-level queue,
    which is flushed with a single ``ray.internal.free`` call once it holds
    more than ``MAX_FREE_QUEUE_SIZE`` ids or more than ``FREE_DELAY_S``
    seconds have passed since the previous flush.

    Args:
        object_ids (ObjectID|List[ObjectID]): Object ids to fetch and free.

    Returns:
        The result of ray.get(object_ids).
    """

    global _last_free_time
    global _to_free

    result = ray.get(object_ids)

    # isinstance (rather than an exact type comparison) so that a list
    # subclass is queued element-wise instead of as one nested list entry.
    if not isinstance(object_ids, list):
        object_ids = [object_ids]
    _to_free.extend(object_ids)

    # batch calls to free to reduce overheads
    now = time.time()
    if (len(_to_free) > MAX_FREE_QUEUE_SIZE
            or now - _last_free_time > FREE_DELAY_S):
        ray.internal.free(_to_free)
        # Rebind rather than clear in place, so the list just handed to
        # free() is not mutated afterwards.
        _to_free = []
        _last_free_time = now

    return result
def aligned_array(size, dtype, align=64):