mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 23:39:37 +08:00
[rllib] Replace ray.get() with ray_get_and_free() to optimize memory usage (#4586)
This commit is contained in:
@@ -30,6 +30,9 @@ def free(object_ids, local_only=False, delete_creating_tasks=False):
|
||||
"""
|
||||
worker = ray.worker.get_global_worker()
|
||||
|
||||
if ray.worker._mode() == ray.worker.LOCAL_MODE:
|
||||
return
|
||||
|
||||
if isinstance(object_ids, ray.ObjectID):
|
||||
object_ids = [object_ids]
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ from ray.rllib.agents.ars import policies
|
||||
from ray.rllib.agents.ars import utils
|
||||
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.memory import ray_get_and_free
|
||||
from ray.rllib.utils import FilterManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -312,7 +313,7 @@ class ARSTrainer(Trainer):
|
||||
worker.do_rollouts.remote(theta_id) for worker in self.workers
|
||||
]
|
||||
# Get the results of the rollouts.
|
||||
for result in ray.get(rollout_ids):
|
||||
for result in ray_get_and_free(rollout_ids):
|
||||
results.append(result)
|
||||
# Update the number of episodes and the number of timesteps
|
||||
# keeping in mind that result.noisy_lengths is a list of lists,
|
||||
|
||||
@@ -18,6 +18,7 @@ from ray.rllib.agents.es import policies
|
||||
from ray.rllib.agents.es import utils
|
||||
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.memory import ray_get_and_free
|
||||
from ray.rllib.utils import FilterManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -309,7 +310,7 @@ class ESTrainer(Trainer):
|
||||
worker.do_rollouts.remote(theta_id) for worker in self.workers
|
||||
]
|
||||
# Get the results of the rollouts.
|
||||
for result in ray.get(rollout_ids):
|
||||
for result in ray_get_and_free(rollout_ids):
|
||||
results.append(result)
|
||||
# Update the number of episodes and the number of timesteps
|
||||
# keeping in mind that result.noisy_lengths is a list of lists,
|
||||
|
||||
@@ -24,6 +24,7 @@ from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
|
||||
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
|
||||
from ray.rllib.utils.annotations import override, PublicAPI, DeveloperAPI
|
||||
from ray.rllib.utils import FilterManager, deep_update, merge_dicts
|
||||
from ray.rllib.utils.memory import ray_get_and_free
|
||||
from ray.tune.registry import ENV_CREATOR, register_env, _global_registry
|
||||
from ray.tune.trainable import Trainable
|
||||
from ray.tune.trial import Resources, ExportFormat
|
||||
@@ -668,7 +669,7 @@ class Trainer(Trainable):
|
||||
for i, obj_id in enumerate(checks):
|
||||
ev = self.optimizer.remote_evaluators[i]
|
||||
try:
|
||||
ray.get(obj_id)
|
||||
ray_get_and_free(obj_id)
|
||||
healthy_evaluators.append(ev)
|
||||
logger.info("Worker {} looks healthy".format(i + 1))
|
||||
except RayError:
|
||||
|
||||
+2
-1
@@ -6,6 +6,7 @@ import logging
|
||||
|
||||
import ray
|
||||
from ray.rllib.env.base_env import BaseEnv, _DUMMY_AGENT_ID, ASYNC_RESET_RETURN
|
||||
from ray.rllib.utils.memory import ray_get_and_free
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -60,7 +61,7 @@ class RemoteVectorEnv(BaseEnv):
|
||||
actor = self.pending.pop(obj_id)
|
||||
env_id = self.actors.index(actor)
|
||||
env_ids.add(env_id)
|
||||
ob, rew, done, info = ray.get(obj_id)
|
||||
ob, rew, done, info = ray_get_and_free(obj_id)
|
||||
obs[env_id] = ob
|
||||
rewards[env_id] = rew
|
||||
dones[env_id] = done
|
||||
|
||||
@@ -10,6 +10,7 @@ import ray
|
||||
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
|
||||
from ray.rllib.offline.off_policy_estimator import OffPolicyEstimate
|
||||
from ray.rllib.utils.annotations import DeveloperAPI
|
||||
from ray.rllib.utils.memory import ray_get_and_free
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -69,7 +70,7 @@ def collect_episodes(local_evaluator=None,
|
||||
"Timed out waiting for metrics from workers. You can configure "
|
||||
"this timeout with `collect_metrics_timeout`.")
|
||||
|
||||
metric_lists = ray.get(collected)
|
||||
metric_lists = ray_get_and_free(collected)
|
||||
if local_evaluator:
|
||||
metric_lists.append(local_evaluator.get_metrics())
|
||||
episodes = []
|
||||
|
||||
@@ -10,6 +10,7 @@ import random
|
||||
import ray
|
||||
from ray.rllib.utils.actors import TaskPool
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.memory import ray_get_and_free
|
||||
|
||||
|
||||
class Aggregator(object):
|
||||
@@ -143,7 +144,7 @@ class AggregationWorkerBase(object):
|
||||
return len(self.replay_batches) > num_needed
|
||||
|
||||
for ev, sample_batch in sample_futures:
|
||||
sample_batch = ray.get(sample_batch)
|
||||
sample_batch = ray_get_and_free(sample_batch)
|
||||
yield ev, sample_batch
|
||||
|
||||
if can_replay():
|
||||
|
||||
@@ -14,6 +14,7 @@ from ray.rllib.utils.actors import TaskPool, create_colocated
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.optimizers.aso_aggregator import Aggregator, \
|
||||
AggregationWorkerBase
|
||||
from ray.rllib.utils.memory import ray_get_and_free
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -86,7 +87,7 @@ class TreeAggregator(Aggregator):
|
||||
def iter_train_batches(self):
|
||||
assert self.initialized, "Must call init() before using this class."
|
||||
for agg, batches in self.agg_tasks.completed_prefetch():
|
||||
for b in ray.get(batches):
|
||||
for b in ray_get_and_free(batches):
|
||||
self.num_sent_since_broadcast += 1
|
||||
yield b
|
||||
agg.set_weights.remote(self.broadcasted_weights)
|
||||
|
||||
@@ -7,6 +7,7 @@ from ray.rllib.evaluation.metrics import get_learner_stats
|
||||
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.timer import TimerStat
|
||||
from ray.rllib.utils.memory import ray_get_and_free
|
||||
|
||||
|
||||
class AsyncGradientsOptimizer(PolicyOptimizer):
|
||||
@@ -49,7 +50,7 @@ class AsyncGradientsOptimizer(PolicyOptimizer):
|
||||
ready_list = wait_results[0]
|
||||
future = ready_list[0]
|
||||
|
||||
gradient, info = ray.get(future)
|
||||
gradient, info = ray_get_and_free(future)
|
||||
e = pending_gradients.pop(future)
|
||||
self.learner_stats = get_learner_stats(info)
|
||||
|
||||
|
||||
@@ -23,6 +23,7 @@ from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
|
||||
from ray.rllib.optimizers.replay_buffer import PrioritizedReplayBuffer
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.actors import TaskPool, create_colocated
|
||||
from ray.rllib.utils.memory import ray_get_and_free
|
||||
from ray.rllib.utils.timer import TimerStat
|
||||
from ray.rllib.utils.window_stat import WindowStat
|
||||
|
||||
@@ -143,7 +144,8 @@ class AsyncReplayOptimizer(PolicyOptimizer):
|
||||
|
||||
@override(PolicyOptimizer)
|
||||
def stats(self):
|
||||
replay_stats = ray.get(self.replay_actors[0].stats.remote(self.debug))
|
||||
replay_stats = ray_get_and_free(self.replay_actors[0].stats.remote(
|
||||
self.debug))
|
||||
timing = {
|
||||
"{}_time_ms".format(k): round(1000 * self.timers[k].mean, 3)
|
||||
for k in self.timers
|
||||
@@ -188,7 +190,7 @@ class AsyncReplayOptimizer(PolicyOptimizer):
|
||||
|
||||
with self.timers["sample_processing"]:
|
||||
completed = list(self.sample_tasks.completed())
|
||||
counts = ray.get([c[1][1] for c in completed])
|
||||
counts = ray_get_and_free([c[1][1] for c in completed])
|
||||
for i, (ev, (sample_batch, count)) in enumerate(completed):
|
||||
sample_timesteps += counts[i]
|
||||
|
||||
@@ -220,7 +222,7 @@ class AsyncReplayOptimizer(PolicyOptimizer):
|
||||
self.num_samples_dropped += 1
|
||||
else:
|
||||
with self.timers["get_samples"]:
|
||||
samples = ray.get(replay)
|
||||
samples = ray_get_and_free(replay)
|
||||
# Defensive copy against plasma crashes, see #2610 #3452
|
||||
self.learner.inqueue.put((ra, samples and samples.copy()))
|
||||
|
||||
|
||||
@@ -4,9 +4,9 @@ from __future__ import print_function
|
||||
|
||||
import logging
|
||||
|
||||
import ray
|
||||
from ray.rllib.utils.annotations import DeveloperAPI
|
||||
from ray.rllib.evaluation.metrics import collect_episodes, summarize_episodes
|
||||
from ray.rllib.utils.memory import ray_get_and_free
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -140,7 +140,7 @@ class PolicyOptimizer(object):
|
||||
"""Apply the given function to each evaluator instance."""
|
||||
|
||||
local_result = [func(self.local_evaluator)]
|
||||
remote_results = ray.get(
|
||||
remote_results = ray_get_and_free(
|
||||
[ev.apply.remote(func) for ev in self.remote_evaluators])
|
||||
return local_result + remote_results
|
||||
|
||||
@@ -152,7 +152,7 @@ class PolicyOptimizer(object):
|
||||
"""
|
||||
|
||||
local_result = [func(self.local_evaluator, 0)]
|
||||
remote_results = ray.get([
|
||||
remote_results = ray_get_and_free([
|
||||
ev.apply.remote(func, i + 1)
|
||||
for i, ev in enumerate(self.remote_evaluators)
|
||||
])
|
||||
|
||||
@@ -6,6 +6,7 @@ import logging
|
||||
|
||||
import ray
|
||||
from ray.rllib.evaluation.sample_batch import SampleBatch
|
||||
from ray.rllib.utils.memory import ray_get_and_free
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -25,7 +26,7 @@ def collect_samples(agents, sample_batch_size, num_envs_per_worker,
|
||||
while agent_dict:
|
||||
[fut_sample], _ = ray.wait(list(agent_dict))
|
||||
agent = agent_dict.pop(fut_sample)
|
||||
next_sample = ray.get(fut_sample)
|
||||
next_sample = ray_get_and_free(fut_sample)
|
||||
assert next_sample.count >= sample_batch_size * num_envs_per_worker
|
||||
num_timesteps_so_far += next_sample.count
|
||||
trajectories.append(next_sample)
|
||||
@@ -63,7 +64,7 @@ def collect_samples_straggler_mitigation(agents, train_batch_size):
|
||||
fut_sample2 = agent.sample.remote()
|
||||
agent_dict[fut_sample2] = agent
|
||||
|
||||
next_sample = ray.get(fut_sample)
|
||||
next_sample = ray_get_and_free(fut_sample)
|
||||
num_timesteps_so_far += next_sample.count
|
||||
trajectories.append(next_sample)
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ from ray.rllib.evaluation.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \
|
||||
MultiAgentBatch
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.timer import TimerStat
|
||||
from ray.rllib.utils.memory import ray_get_and_free
|
||||
|
||||
|
||||
class SyncBatchReplayOptimizer(PolicyOptimizer):
|
||||
@@ -51,7 +52,7 @@ class SyncBatchReplayOptimizer(PolicyOptimizer):
|
||||
|
||||
with self.sample_timer:
|
||||
if self.remote_evaluators:
|
||||
batches = ray.get(
|
||||
batches = ray_get_and_free(
|
||||
[e.sample.remote() for e in self.remote_evaluators])
|
||||
else:
|
||||
batches = [self.local_evaluator.sample()]
|
||||
|
||||
@@ -17,6 +17,7 @@ from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.compression import pack_if_needed
|
||||
from ray.rllib.utils.timer import TimerStat
|
||||
from ray.rllib.utils.schedules import LinearSchedule
|
||||
from ray.rllib.utils.memory import ray_get_and_free
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -89,7 +90,7 @@ class SyncReplayOptimizer(PolicyOptimizer):
|
||||
with self.sample_timer:
|
||||
if self.remote_evaluators:
|
||||
batch = SampleBatch.concat_samples(
|
||||
ray.get(
|
||||
ray_get_and_free(
|
||||
[e.sample.remote() for e in self.remote_evaluators]))
|
||||
else:
|
||||
batch = self.local_evaluator.sample()
|
||||
|
||||
@@ -10,6 +10,7 @@ from ray.rllib.evaluation.sample_batch import SampleBatch
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.filter import RunningStat
|
||||
from ray.rllib.utils.timer import TimerStat
|
||||
from ray.rllib.utils.memory import ray_get_and_free
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -50,7 +51,7 @@ class SyncSamplesOptimizer(PolicyOptimizer):
|
||||
while sum(s.count for s in samples) < self.train_batch_size:
|
||||
if self.remote_evaluators:
|
||||
samples.extend(
|
||||
ray.get([
|
||||
ray_get_and_free([
|
||||
e.sample.remote() for e in self.remote_evaluators
|
||||
]))
|
||||
else:
|
||||
|
||||
@@ -47,6 +47,8 @@ if __name__ == "__main__":
|
||||
do_link("tune", force=args.yes)
|
||||
do_link("autoscaler", force=args.yes)
|
||||
do_link("scripts", force=args.yes)
|
||||
do_link("internal", force=args.yes)
|
||||
do_link("experimental", force=args.yes)
|
||||
print("Created links.\n\nIf you run into issues initializing Ray, please "
|
||||
"ensure that your local repo and the installed Ray are in sync "
|
||||
"(pip install -U the latest wheels at "
|
||||
|
||||
@@ -4,6 +4,7 @@ from __future__ import print_function
|
||||
|
||||
import ray
|
||||
from ray.rllib.utils.annotations import DeveloperAPI
|
||||
from ray.rllib.utils.memory import ray_get_and_free
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
@@ -24,7 +25,7 @@ class FilterManager(object):
|
||||
remotes (list): Remote evaluators with filters.
|
||||
update_remote (bool): Whether to push updates to remote filters.
|
||||
"""
|
||||
remote_filters = ray.get(
|
||||
remote_filters = ray_get_and_free(
|
||||
[r.get_filters.remote(flush_after=True) for r in remotes])
|
||||
for rf in remote_filters:
|
||||
for k in local_filters:
|
||||
|
||||
@@ -3,6 +3,47 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import time
|
||||
|
||||
import ray
|
||||
|
||||
FREE_DELAY_S = 10.0
|
||||
MAX_FREE_QUEUE_SIZE = 100
|
||||
_last_free_time = 0.0
|
||||
_to_free = []
|
||||
|
||||
|
||||
def ray_get_and_free(object_ids):
|
||||
"""Call ray.get and then queue the object ids for deletion.
|
||||
|
||||
This function should be used whenever possible in RLlib, to optimize
|
||||
memory usage. The only exception is when an object_id is shared among
|
||||
multiple readers.
|
||||
|
||||
Args:
|
||||
object_ids (ObjectID|List[ObjectID]): Object ids to fetch and free.
|
||||
|
||||
Returns:
|
||||
The result of ray.get(object_ids).
|
||||
"""
|
||||
|
||||
global _last_free_time
|
||||
global _to_free
|
||||
|
||||
result = ray.get(object_ids)
|
||||
if type(object_ids) is not list:
|
||||
object_ids = [object_ids]
|
||||
_to_free.extend(object_ids)
|
||||
|
||||
# batch calls to free to reduce overheads
|
||||
now = time.time()
|
||||
if (len(_to_free) > MAX_FREE_QUEUE_SIZE
|
||||
or now - _last_free_time > FREE_DELAY_S):
|
||||
ray.internal.free(_to_free)
|
||||
_to_free = []
|
||||
_last_free_time = now
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def aligned_array(size, dtype, align=64):
|
||||
|
||||
Reference in New Issue
Block a user