[rllib] Replace ray.get() with ray_get_and_free() to optimize memory usage (#4586)

This commit is contained in:
Eric Liang
2019-04-17 20:30:03 -04:00
committed by GitHub
parent 20c4c16891
commit 6848dfd179
19 changed files with 82 additions and 21 deletions
+3
View File
@@ -30,6 +30,9 @@ def free(object_ids, local_only=False, delete_creating_tasks=False):
"""
worker = ray.worker.get_global_worker()
if ray.worker._mode() == ray.worker.LOCAL_MODE:
return
if isinstance(object_ids, ray.ObjectID):
object_ids = [object_ids]
+2 -1
View File
@@ -19,6 +19,7 @@ from ray.rllib.agents.ars import policies
from ray.rllib.agents.ars import utils
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils.annotations import override
from ray.rllib.utils.memory import ray_get_and_free
from ray.rllib.utils import FilterManager
logger = logging.getLogger(__name__)
@@ -312,7 +313,7 @@ class ARSTrainer(Trainer):
worker.do_rollouts.remote(theta_id) for worker in self.workers
]
# Get the results of the rollouts.
for result in ray.get(rollout_ids):
for result in ray_get_and_free(rollout_ids):
results.append(result)
# Update the number of episodes and the number of timesteps
# keeping in mind that result.noisy_lengths is a list of lists,
+2 -1
View File
@@ -18,6 +18,7 @@ from ray.rllib.agents.es import policies
from ray.rllib.agents.es import utils
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils.annotations import override
from ray.rllib.utils.memory import ray_get_and_free
from ray.rllib.utils import FilterManager
logger = logging.getLogger(__name__)
@@ -309,7 +310,7 @@ class ESTrainer(Trainer):
worker.do_rollouts.remote(theta_id) for worker in self.workers
]
# Get the results of the rollouts.
for result in ray.get(rollout_ids):
for result in ray_get_and_free(rollout_ids):
results.append(result)
# Update the number of episodes and the number of timesteps
# keeping in mind that result.noisy_lengths is a list of lists,
+2 -1
View File
@@ -24,6 +24,7 @@ from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
from ray.rllib.utils.annotations import override, PublicAPI, DeveloperAPI
from ray.rllib.utils import FilterManager, deep_update, merge_dicts
from ray.rllib.utils.memory import ray_get_and_free
from ray.tune.registry import ENV_CREATOR, register_env, _global_registry
from ray.tune.trainable import Trainable
from ray.tune.trial import Resources, ExportFormat
@@ -668,7 +669,7 @@ class Trainer(Trainable):
for i, obj_id in enumerate(checks):
ev = self.optimizer.remote_evaluators[i]
try:
ray.get(obj_id)
ray_get_and_free(obj_id)
healthy_evaluators.append(ev)
logger.info("Worker {} looks healthy".format(i + 1))
except RayError:
+2 -1
View File
@@ -6,6 +6,7 @@ import logging
import ray
from ray.rllib.env.base_env import BaseEnv, _DUMMY_AGENT_ID, ASYNC_RESET_RETURN
from ray.rllib.utils.memory import ray_get_and_free
logger = logging.getLogger(__name__)
@@ -60,7 +61,7 @@ class RemoteVectorEnv(BaseEnv):
actor = self.pending.pop(obj_id)
env_id = self.actors.index(actor)
env_ids.add(env_id)
ob, rew, done, info = ray.get(obj_id)
ob, rew, done, info = ray_get_and_free(obj_id)
obs[env_id] = ob
rewards[env_id] = rew
dones[env_id] = done
+2 -1
View File
@@ -10,6 +10,7 @@ import ray
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.offline.off_policy_estimator import OffPolicyEstimate
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.memory import ray_get_and_free
logger = logging.getLogger(__name__)
@@ -69,7 +70,7 @@ def collect_episodes(local_evaluator=None,
"Timed out waiting for metrics from workers. You can configure "
"this timeout with `collect_metrics_timeout`.")
metric_lists = ray.get(collected)
metric_lists = ray_get_and_free(collected)
if local_evaluator:
metric_lists.append(local_evaluator.get_metrics())
episodes = []
@@ -10,6 +10,7 @@ import random
import ray
from ray.rllib.utils.actors import TaskPool
from ray.rllib.utils.annotations import override
from ray.rllib.utils.memory import ray_get_and_free
class Aggregator(object):
@@ -143,7 +144,7 @@ class AggregationWorkerBase(object):
return len(self.replay_batches) > num_needed
for ev, sample_batch in sample_futures:
sample_batch = ray.get(sample_batch)
sample_batch = ray_get_and_free(sample_batch)
yield ev, sample_batch
if can_replay():
@@ -14,6 +14,7 @@ from ray.rllib.utils.actors import TaskPool, create_colocated
from ray.rllib.utils.annotations import override
from ray.rllib.optimizers.aso_aggregator import Aggregator, \
AggregationWorkerBase
from ray.rllib.utils.memory import ray_get_and_free
logger = logging.getLogger(__name__)
@@ -86,7 +87,7 @@ class TreeAggregator(Aggregator):
def iter_train_batches(self):
assert self.initialized, "Must call init() before using this class."
for agg, batches in self.agg_tasks.completed_prefetch():
for b in ray.get(batches):
for b in ray_get_and_free(batches):
self.num_sent_since_broadcast += 1
yield b
agg.set_weights.remote(self.broadcasted_weights)
@@ -7,6 +7,7 @@ from ray.rllib.evaluation.metrics import get_learner_stats
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
from ray.rllib.utils.annotations import override
from ray.rllib.utils.timer import TimerStat
from ray.rllib.utils.memory import ray_get_and_free
class AsyncGradientsOptimizer(PolicyOptimizer):
@@ -49,7 +50,7 @@ class AsyncGradientsOptimizer(PolicyOptimizer):
ready_list = wait_results[0]
future = ready_list[0]
gradient, info = ray.get(future)
gradient, info = ray_get_and_free(future)
e = pending_gradients.pop(future)
self.learner_stats = get_learner_stats(info)
@@ -23,6 +23,7 @@ from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
from ray.rllib.optimizers.replay_buffer import PrioritizedReplayBuffer
from ray.rllib.utils.annotations import override
from ray.rllib.utils.actors import TaskPool, create_colocated
from ray.rllib.utils.memory import ray_get_and_free
from ray.rllib.utils.timer import TimerStat
from ray.rllib.utils.window_stat import WindowStat
@@ -143,7 +144,8 @@ class AsyncReplayOptimizer(PolicyOptimizer):
@override(PolicyOptimizer)
def stats(self):
replay_stats = ray.get(self.replay_actors[0].stats.remote(self.debug))
replay_stats = ray_get_and_free(self.replay_actors[0].stats.remote(
self.debug))
timing = {
"{}_time_ms".format(k): round(1000 * self.timers[k].mean, 3)
for k in self.timers
@@ -188,7 +190,7 @@ class AsyncReplayOptimizer(PolicyOptimizer):
with self.timers["sample_processing"]:
completed = list(self.sample_tasks.completed())
counts = ray.get([c[1][1] for c in completed])
counts = ray_get_and_free([c[1][1] for c in completed])
for i, (ev, (sample_batch, count)) in enumerate(completed):
sample_timesteps += counts[i]
@@ -220,7 +222,7 @@ class AsyncReplayOptimizer(PolicyOptimizer):
self.num_samples_dropped += 1
else:
with self.timers["get_samples"]:
samples = ray.get(replay)
samples = ray_get_and_free(replay)
# Defensive copy against plasma crashes, see #2610 #3452
self.learner.inqueue.put((ra, samples and samples.copy()))
@@ -4,9 +4,9 @@ from __future__ import print_function
import logging
import ray
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.evaluation.metrics import collect_episodes, summarize_episodes
from ray.rllib.utils.memory import ray_get_and_free
logger = logging.getLogger(__name__)
@@ -140,7 +140,7 @@ class PolicyOptimizer(object):
"""Apply the given function to each evaluator instance."""
local_result = [func(self.local_evaluator)]
remote_results = ray.get(
remote_results = ray_get_and_free(
[ev.apply.remote(func) for ev in self.remote_evaluators])
return local_result + remote_results
@@ -152,7 +152,7 @@ class PolicyOptimizer(object):
"""
local_result = [func(self.local_evaluator, 0)]
remote_results = ray.get([
remote_results = ray_get_and_free([
ev.apply.remote(func, i + 1)
for i, ev in enumerate(self.remote_evaluators)
])
+3 -2
View File
@@ -6,6 +6,7 @@ import logging
import ray
from ray.rllib.evaluation.sample_batch import SampleBatch
from ray.rllib.utils.memory import ray_get_and_free
logger = logging.getLogger(__name__)
@@ -25,7 +26,7 @@ def collect_samples(agents, sample_batch_size, num_envs_per_worker,
while agent_dict:
[fut_sample], _ = ray.wait(list(agent_dict))
agent = agent_dict.pop(fut_sample)
next_sample = ray.get(fut_sample)
next_sample = ray_get_and_free(fut_sample)
assert next_sample.count >= sample_batch_size * num_envs_per_worker
num_timesteps_so_far += next_sample.count
trajectories.append(next_sample)
@@ -63,7 +64,7 @@ def collect_samples_straggler_mitigation(agents, train_batch_size):
fut_sample2 = agent.sample.remote()
agent_dict[fut_sample2] = agent
next_sample = ray.get(fut_sample)
next_sample = ray_get_and_free(fut_sample)
num_timesteps_so_far += next_sample.count
trajectories.append(next_sample)
@@ -11,6 +11,7 @@ from ray.rllib.evaluation.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \
MultiAgentBatch
from ray.rllib.utils.annotations import override
from ray.rllib.utils.timer import TimerStat
from ray.rllib.utils.memory import ray_get_and_free
class SyncBatchReplayOptimizer(PolicyOptimizer):
@@ -51,7 +52,7 @@ class SyncBatchReplayOptimizer(PolicyOptimizer):
with self.sample_timer:
if self.remote_evaluators:
batches = ray.get(
batches = ray_get_and_free(
[e.sample.remote() for e in self.remote_evaluators])
else:
batches = [self.local_evaluator.sample()]
@@ -17,6 +17,7 @@ from ray.rllib.utils.annotations import override
from ray.rllib.utils.compression import pack_if_needed
from ray.rllib.utils.timer import TimerStat
from ray.rllib.utils.schedules import LinearSchedule
from ray.rllib.utils.memory import ray_get_and_free
logger = logging.getLogger(__name__)
@@ -89,7 +90,7 @@ class SyncReplayOptimizer(PolicyOptimizer):
with self.sample_timer:
if self.remote_evaluators:
batch = SampleBatch.concat_samples(
ray.get(
ray_get_and_free(
[e.sample.remote() for e in self.remote_evaluators]))
else:
batch = self.local_evaluator.sample()
@@ -10,6 +10,7 @@ from ray.rllib.evaluation.sample_batch import SampleBatch
from ray.rllib.utils.annotations import override
from ray.rllib.utils.filter import RunningStat
from ray.rllib.utils.timer import TimerStat
from ray.rllib.utils.memory import ray_get_and_free
logger = logging.getLogger(__name__)
@@ -50,7 +51,7 @@ class SyncSamplesOptimizer(PolicyOptimizer):
while sum(s.count for s in samples) < self.train_batch_size:
if self.remote_evaluators:
samples.extend(
ray.get([
ray_get_and_free([
e.sample.remote() for e in self.remote_evaluators
]))
else:
+2
View File
@@ -47,6 +47,8 @@ if __name__ == "__main__":
do_link("tune", force=args.yes)
do_link("autoscaler", force=args.yes)
do_link("scripts", force=args.yes)
do_link("internal", force=args.yes)
do_link("experimental", force=args.yes)
print("Created links.\n\nIf you run into issues initializing Ray, please "
"ensure that your local repo and the installed Ray are in sync "
"(pip install -U the latest wheels at "
+2 -1
View File
@@ -4,6 +4,7 @@ from __future__ import print_function
import ray
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.memory import ray_get_and_free
@DeveloperAPI
@@ -24,7 +25,7 @@ class FilterManager(object):
remotes (list): Remote evaluators with filters.
update_remote (bool): Whether to push updates to remote filters.
"""
remote_filters = ray.get(
remote_filters = ray_get_and_free(
[r.get_filters.remote(flush_after=True) for r in remotes])
for rf in remote_filters:
for k in local_filters:
+41
View File
@@ -3,6 +3,47 @@ from __future__ import division
from __future__ import print_function
import numpy as np
import time
import ray
FREE_DELAY_S = 10.0
MAX_FREE_QUEUE_SIZE = 100
_last_free_time = 0.0
_to_free = []
def ray_get_and_free(object_ids):
"""Call ray.get and then queue the object ids for deletion.
This function should be used whenever possible in RLlib, to optimize
memory usage. The only exception is when an object_id is shared among
multiple readers.
Args:
object_ids (ObjectID|List[ObjectID]): Object ids to fetch and free.
Returns:
The result of ray.get(object_ids).
"""
global _last_free_time
global _to_free
result = ray.get(object_ids)
if type(object_ids) is not list:
object_ids = [object_ids]
_to_free.extend(object_ids)
# batch calls to free to reduce overheads
now = time.time()
if (len(_to_free) > MAX_FREE_QUEUE_SIZE
or now - _last_free_time > FREE_DELAY_S):
ray.internal.free(_to_free)
_to_free = []
_last_free_time = now
return result
def aligned_array(size, dtype, align=64):