mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 22:23:17 +08:00
[rllib] Replace ray.get() with ray_get_and_free() to optimize memory usage (#4586)
This commit is contained in:
@@ -108,7 +108,7 @@ This is how the example in the previous section looks when written using a polic
|
||||
Trainers
|
||||
--------
|
||||
|
||||
Trainers are the boilerplate classes that put the above components together. Trainer make algorithms accessible via Python API and the command line. They manage algorithm configuration, setup of the policy evaluators and optimizer, and collection of training metrics. Trainers also implement the `Trainable API <https://ray.readthedocs.io/en/latest/tune-usage.html#training-api>`__ for easy experiment management.
|
||||
Trainers are the boilerplate classes that put the above components together, making algorithms accessible via Python API and the command line. They manage algorithm configuration, setup of the policy evaluators and optimizer, and collection of training metrics. Trainers also implement the `Trainable API <https://ray.readthedocs.io/en/latest/tune-usage.html#training-api>`__ for easy experiment management.
|
||||
|
||||
Example of two equivalent ways of interacting with the PPO trainer:
|
||||
|
||||
|
||||
@@ -30,6 +30,9 @@ def free(object_ids, local_only=False, delete_creating_tasks=False):
|
||||
"""
|
||||
worker = ray.worker.get_global_worker()
|
||||
|
||||
if ray.worker._mode() == ray.worker.LOCAL_MODE:
|
||||
return
|
||||
|
||||
if isinstance(object_ids, ray.ObjectID):
|
||||
object_ids = [object_ids]
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ from ray.rllib.agents.ars import policies
|
||||
from ray.rllib.agents.ars import utils
|
||||
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.memory import ray_get_and_free
|
||||
from ray.rllib.utils import FilterManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -312,7 +313,7 @@ class ARSTrainer(Trainer):
|
||||
worker.do_rollouts.remote(theta_id) for worker in self.workers
|
||||
]
|
||||
# Get the results of the rollouts.
|
||||
for result in ray.get(rollout_ids):
|
||||
for result in ray_get_and_free(rollout_ids):
|
||||
results.append(result)
|
||||
# Update the number of episodes and the number of timesteps
|
||||
# keeping in mind that result.noisy_lengths is a list of lists,
|
||||
|
||||
@@ -18,6 +18,7 @@ from ray.rllib.agents.es import policies
|
||||
from ray.rllib.agents.es import utils
|
||||
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.memory import ray_get_and_free
|
||||
from ray.rllib.utils import FilterManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -309,7 +310,7 @@ class ESTrainer(Trainer):
|
||||
worker.do_rollouts.remote(theta_id) for worker in self.workers
|
||||
]
|
||||
# Get the results of the rollouts.
|
||||
for result in ray.get(rollout_ids):
|
||||
for result in ray_get_and_free(rollout_ids):
|
||||
results.append(result)
|
||||
# Update the number of episodes and the number of timesteps
|
||||
# keeping in mind that result.noisy_lengths is a list of lists,
|
||||
|
||||
@@ -24,6 +24,7 @@ from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
|
||||
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
|
||||
from ray.rllib.utils.annotations import override, PublicAPI, DeveloperAPI
|
||||
from ray.rllib.utils import FilterManager, deep_update, merge_dicts
|
||||
from ray.rllib.utils.memory import ray_get_and_free
|
||||
from ray.tune.registry import ENV_CREATOR, register_env, _global_registry
|
||||
from ray.tune.trainable import Trainable
|
||||
from ray.tune.trial import Resources, ExportFormat
|
||||
@@ -668,7 +669,7 @@ class Trainer(Trainable):
|
||||
for i, obj_id in enumerate(checks):
|
||||
ev = self.optimizer.remote_evaluators[i]
|
||||
try:
|
||||
ray.get(obj_id)
|
||||
ray_get_and_free(obj_id)
|
||||
healthy_evaluators.append(ev)
|
||||
logger.info("Worker {} looks healthy".format(i + 1))
|
||||
except RayError:
|
||||
|
||||
+2
-1
@@ -6,6 +6,7 @@ import logging
|
||||
|
||||
import ray
|
||||
from ray.rllib.env.base_env import BaseEnv, _DUMMY_AGENT_ID, ASYNC_RESET_RETURN
|
||||
from ray.rllib.utils.memory import ray_get_and_free
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -60,7 +61,7 @@ class RemoteVectorEnv(BaseEnv):
|
||||
actor = self.pending.pop(obj_id)
|
||||
env_id = self.actors.index(actor)
|
||||
env_ids.add(env_id)
|
||||
ob, rew, done, info = ray.get(obj_id)
|
||||
ob, rew, done, info = ray_get_and_free(obj_id)
|
||||
obs[env_id] = ob
|
||||
rewards[env_id] = rew
|
||||
dones[env_id] = done
|
||||
|
||||
@@ -10,6 +10,7 @@ import ray
|
||||
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
|
||||
from ray.rllib.offline.off_policy_estimator import OffPolicyEstimate
|
||||
from ray.rllib.utils.annotations import DeveloperAPI
|
||||
from ray.rllib.utils.memory import ray_get_and_free
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -69,7 +70,7 @@ def collect_episodes(local_evaluator=None,
|
||||
"Timed out waiting for metrics from workers. You can configure "
|
||||
"this timeout with `collect_metrics_timeout`.")
|
||||
|
||||
metric_lists = ray.get(collected)
|
||||
metric_lists = ray_get_and_free(collected)
|
||||
if local_evaluator:
|
||||
metric_lists.append(local_evaluator.get_metrics())
|
||||
episodes = []
|
||||
|
||||
@@ -10,6 +10,7 @@ import random
|
||||
import ray
|
||||
from ray.rllib.utils.actors import TaskPool
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.memory import ray_get_and_free
|
||||
|
||||
|
||||
class Aggregator(object):
|
||||
@@ -143,7 +144,7 @@ class AggregationWorkerBase(object):
|
||||
return len(self.replay_batches) > num_needed
|
||||
|
||||
for ev, sample_batch in sample_futures:
|
||||
sample_batch = ray.get(sample_batch)
|
||||
sample_batch = ray_get_and_free(sample_batch)
|
||||
yield ev, sample_batch
|
||||
|
||||
if can_replay():
|
||||
|
||||
@@ -14,6 +14,7 @@ from ray.rllib.utils.actors import TaskPool, create_colocated
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.optimizers.aso_aggregator import Aggregator, \
|
||||
AggregationWorkerBase
|
||||
from ray.rllib.utils.memory import ray_get_and_free
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -86,7 +87,7 @@ class TreeAggregator(Aggregator):
|
||||
def iter_train_batches(self):
|
||||
assert self.initialized, "Must call init() before using this class."
|
||||
for agg, batches in self.agg_tasks.completed_prefetch():
|
||||
for b in ray.get(batches):
|
||||
for b in ray_get_and_free(batches):
|
||||
self.num_sent_since_broadcast += 1
|
||||
yield b
|
||||
agg.set_weights.remote(self.broadcasted_weights)
|
||||
|
||||
@@ -7,6 +7,7 @@ from ray.rllib.evaluation.metrics import get_learner_stats
|
||||
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.timer import TimerStat
|
||||
from ray.rllib.utils.memory import ray_get_and_free
|
||||
|
||||
|
||||
class AsyncGradientsOptimizer(PolicyOptimizer):
|
||||
@@ -49,7 +50,7 @@ class AsyncGradientsOptimizer(PolicyOptimizer):
|
||||
ready_list = wait_results[0]
|
||||
future = ready_list[0]
|
||||
|
||||
gradient, info = ray.get(future)
|
||||
gradient, info = ray_get_and_free(future)
|
||||
e = pending_gradients.pop(future)
|
||||
self.learner_stats = get_learner_stats(info)
|
||||
|
||||
|
||||
@@ -23,6 +23,7 @@ from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
|
||||
from ray.rllib.optimizers.replay_buffer import PrioritizedReplayBuffer
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.actors import TaskPool, create_colocated
|
||||
from ray.rllib.utils.memory import ray_get_and_free
|
||||
from ray.rllib.utils.timer import TimerStat
|
||||
from ray.rllib.utils.window_stat import WindowStat
|
||||
|
||||
@@ -143,7 +144,8 @@ class AsyncReplayOptimizer(PolicyOptimizer):
|
||||
|
||||
@override(PolicyOptimizer)
|
||||
def stats(self):
|
||||
replay_stats = ray.get(self.replay_actors[0].stats.remote(self.debug))
|
||||
replay_stats = ray_get_and_free(self.replay_actors[0].stats.remote(
|
||||
self.debug))
|
||||
timing = {
|
||||
"{}_time_ms".format(k): round(1000 * self.timers[k].mean, 3)
|
||||
for k in self.timers
|
||||
@@ -188,7 +190,7 @@ class AsyncReplayOptimizer(PolicyOptimizer):
|
||||
|
||||
with self.timers["sample_processing"]:
|
||||
completed = list(self.sample_tasks.completed())
|
||||
counts = ray.get([c[1][1] for c in completed])
|
||||
counts = ray_get_and_free([c[1][1] for c in completed])
|
||||
for i, (ev, (sample_batch, count)) in enumerate(completed):
|
||||
sample_timesteps += counts[i]
|
||||
|
||||
@@ -220,7 +222,7 @@ class AsyncReplayOptimizer(PolicyOptimizer):
|
||||
self.num_samples_dropped += 1
|
||||
else:
|
||||
with self.timers["get_samples"]:
|
||||
samples = ray.get(replay)
|
||||
samples = ray_get_and_free(replay)
|
||||
# Defensive copy against plasma crashes, see #2610 #3452
|
||||
self.learner.inqueue.put((ra, samples and samples.copy()))
|
||||
|
||||
|
||||
@@ -4,9 +4,9 @@ from __future__ import print_function
|
||||
|
||||
import logging
|
||||
|
||||
import ray
|
||||
from ray.rllib.utils.annotations import DeveloperAPI
|
||||
from ray.rllib.evaluation.metrics import collect_episodes, summarize_episodes
|
||||
from ray.rllib.utils.memory import ray_get_and_free
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -140,7 +140,7 @@ class PolicyOptimizer(object):
|
||||
"""Apply the given function to each evaluator instance."""
|
||||
|
||||
local_result = [func(self.local_evaluator)]
|
||||
remote_results = ray.get(
|
||||
remote_results = ray_get_and_free(
|
||||
[ev.apply.remote(func) for ev in self.remote_evaluators])
|
||||
return local_result + remote_results
|
||||
|
||||
@@ -152,7 +152,7 @@ class PolicyOptimizer(object):
|
||||
"""
|
||||
|
||||
local_result = [func(self.local_evaluator, 0)]
|
||||
remote_results = ray.get([
|
||||
remote_results = ray_get_and_free([
|
||||
ev.apply.remote(func, i + 1)
|
||||
for i, ev in enumerate(self.remote_evaluators)
|
||||
])
|
||||
|
||||
@@ -6,6 +6,7 @@ import logging
|
||||
|
||||
import ray
|
||||
from ray.rllib.evaluation.sample_batch import SampleBatch
|
||||
from ray.rllib.utils.memory import ray_get_and_free
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -25,7 +26,7 @@ def collect_samples(agents, sample_batch_size, num_envs_per_worker,
|
||||
while agent_dict:
|
||||
[fut_sample], _ = ray.wait(list(agent_dict))
|
||||
agent = agent_dict.pop(fut_sample)
|
||||
next_sample = ray.get(fut_sample)
|
||||
next_sample = ray_get_and_free(fut_sample)
|
||||
assert next_sample.count >= sample_batch_size * num_envs_per_worker
|
||||
num_timesteps_so_far += next_sample.count
|
||||
trajectories.append(next_sample)
|
||||
@@ -63,7 +64,7 @@ def collect_samples_straggler_mitigation(agents, train_batch_size):
|
||||
fut_sample2 = agent.sample.remote()
|
||||
agent_dict[fut_sample2] = agent
|
||||
|
||||
next_sample = ray.get(fut_sample)
|
||||
next_sample = ray_get_and_free(fut_sample)
|
||||
num_timesteps_so_far += next_sample.count
|
||||
trajectories.append(next_sample)
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ from ray.rllib.evaluation.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \
|
||||
MultiAgentBatch
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.timer import TimerStat
|
||||
from ray.rllib.utils.memory import ray_get_and_free
|
||||
|
||||
|
||||
class SyncBatchReplayOptimizer(PolicyOptimizer):
|
||||
@@ -51,7 +52,7 @@ class SyncBatchReplayOptimizer(PolicyOptimizer):
|
||||
|
||||
with self.sample_timer:
|
||||
if self.remote_evaluators:
|
||||
batches = ray.get(
|
||||
batches = ray_get_and_free(
|
||||
[e.sample.remote() for e in self.remote_evaluators])
|
||||
else:
|
||||
batches = [self.local_evaluator.sample()]
|
||||
|
||||
@@ -17,6 +17,7 @@ from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.compression import pack_if_needed
|
||||
from ray.rllib.utils.timer import TimerStat
|
||||
from ray.rllib.utils.schedules import LinearSchedule
|
||||
from ray.rllib.utils.memory import ray_get_and_free
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -89,7 +90,7 @@ class SyncReplayOptimizer(PolicyOptimizer):
|
||||
with self.sample_timer:
|
||||
if self.remote_evaluators:
|
||||
batch = SampleBatch.concat_samples(
|
||||
ray.get(
|
||||
ray_get_and_free(
|
||||
[e.sample.remote() for e in self.remote_evaluators]))
|
||||
else:
|
||||
batch = self.local_evaluator.sample()
|
||||
|
||||
@@ -10,6 +10,7 @@ from ray.rllib.evaluation.sample_batch import SampleBatch
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.filter import RunningStat
|
||||
from ray.rllib.utils.timer import TimerStat
|
||||
from ray.rllib.utils.memory import ray_get_and_free
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -50,7 +51,7 @@ class SyncSamplesOptimizer(PolicyOptimizer):
|
||||
while sum(s.count for s in samples) < self.train_batch_size:
|
||||
if self.remote_evaluators:
|
||||
samples.extend(
|
||||
ray.get([
|
||||
ray_get_and_free([
|
||||
e.sample.remote() for e in self.remote_evaluators
|
||||
]))
|
||||
else:
|
||||
|
||||
@@ -47,6 +47,8 @@ if __name__ == "__main__":
|
||||
do_link("tune", force=args.yes)
|
||||
do_link("autoscaler", force=args.yes)
|
||||
do_link("scripts", force=args.yes)
|
||||
do_link("internal", force=args.yes)
|
||||
do_link("experimental", force=args.yes)
|
||||
print("Created links.\n\nIf you run into issues initializing Ray, please "
|
||||
"ensure that your local repo and the installed Ray are in sync "
|
||||
"(pip install -U the latest wheels at "
|
||||
|
||||
@@ -4,6 +4,7 @@ from __future__ import print_function
|
||||
|
||||
import ray
|
||||
from ray.rllib.utils.annotations import DeveloperAPI
|
||||
from ray.rllib.utils.memory import ray_get_and_free
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
@@ -24,7 +25,7 @@ class FilterManager(object):
|
||||
remotes (list): Remote evaluators with filters.
|
||||
update_remote (bool): Whether to push updates to remote filters.
|
||||
"""
|
||||
remote_filters = ray.get(
|
||||
remote_filters = ray_get_and_free(
|
||||
[r.get_filters.remote(flush_after=True) for r in remotes])
|
||||
for rf in remote_filters:
|
||||
for k in local_filters:
|
||||
|
||||
@@ -3,6 +3,47 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import time
|
||||
|
||||
import ray
|
||||
|
||||
FREE_DELAY_S = 10.0
|
||||
MAX_FREE_QUEUE_SIZE = 100
|
||||
_last_free_time = 0.0
|
||||
_to_free = []
|
||||
|
||||
|
||||
def ray_get_and_free(object_ids):
|
||||
"""Call ray.get and then queue the object ids for deletion.
|
||||
|
||||
This function should be used whenever possible in RLlib, to optimize
|
||||
memory usage. The only exception is when an object_id is shared among
|
||||
multiple readers.
|
||||
|
||||
Args:
|
||||
object_ids (ObjectID|List[ObjectID]): Object ids to fetch and free.
|
||||
|
||||
Returns:
|
||||
The result of ray.get(object_ids).
|
||||
"""
|
||||
|
||||
global _last_free_time
|
||||
global _to_free
|
||||
|
||||
result = ray.get(object_ids)
|
||||
if type(object_ids) is not list:
|
||||
object_ids = [object_ids]
|
||||
_to_free.extend(object_ids)
|
||||
|
||||
# batch calls to free to reduce overheads
|
||||
now = time.time()
|
||||
if (len(_to_free) > MAX_FREE_QUEUE_SIZE
|
||||
or now - _last_free_time > FREE_DELAY_S):
|
||||
ray.internal.free(_to_free)
|
||||
_to_free = []
|
||||
_last_free_time = now
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def aligned_array(size, dtype, align=64):
|
||||
|
||||
Reference in New Issue
Block a user