mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 19:16:19 +08:00
add large data warning (#10957)
This commit is contained in:
@@ -29,8 +29,8 @@ DEFAULT_CONFIG = with_common_config({
|
||||
"lr": 1e-4,
|
||||
# Number of timesteps collected for each SGD round.
|
||||
"train_batch_size": 2000,
|
||||
# Number of steps max to keep in the batch replay buffer.
|
||||
"replay_buffer_size": 100000,
|
||||
# Size of the replay buffer in batches (not timesteps!).
|
||||
"replay_buffer_size": 1000,
|
||||
# Number of steps to read before learning starts.
|
||||
"learning_starts": 0,
|
||||
# === Parallelism ===
|
||||
|
||||
@@ -56,8 +56,8 @@ DEFAULT_CONFIG = with_common_config({
|
||||
"target_network_update_freq": 500,
|
||||
|
||||
# === Replay buffer ===
|
||||
# Size of the replay buffer in steps.
|
||||
"buffer_size": 10000,
|
||||
# Size of the replay buffer in batches (not timesteps!).
|
||||
"buffer_size": 1000,
|
||||
|
||||
# === Optimization ===
|
||||
# Learning rate for RMSProp optimizer
|
||||
|
||||
@@ -53,7 +53,8 @@ DEFAULT_CONFIG = with_common_config({
|
||||
"num_sgd_iter": 30,
|
||||
# IN case a buffer optimizer is used
|
||||
"learning_starts": 1000,
|
||||
"buffer_size": 10000,
|
||||
# Size of the replay buffer in batches (not timesteps!).
|
||||
"buffer_size": 1000,
|
||||
# Stepsize of SGD
|
||||
"lr": 5e-5,
|
||||
# Learning rate schedule
|
||||
|
||||
@@ -5,12 +5,16 @@ import platform
|
||||
import random
|
||||
from typing import List
|
||||
|
||||
import ray
|
||||
# Import ray before psutil will make sure we use psutil's bundled version
|
||||
import ray # noqa F401
|
||||
import psutil # noqa E402
|
||||
|
||||
from ray.rllib.execution.segment_tree import SumSegmentTree, MinSegmentTree
|
||||
from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch, \
|
||||
DEFAULT_POLICY_ID
|
||||
from ray.rllib.utils.annotations import DeveloperAPI
|
||||
from ray.util.iter import ParallelIteratorWorker
|
||||
from ray.util.debug import log_once
|
||||
from ray.rllib.utils.timer import TimerStat
|
||||
from ray.rllib.utils.window_stat import WindowStat
|
||||
from ray.rllib.utils.typing import SampleBatchType
|
||||
@@ -21,6 +25,25 @@ _ALL_POLICIES = "__all__"
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def warn_replay_buffer_size(*, item: SampleBatchType, num_items: int) -> None:
|
||||
"""Warn if the configured replay buffer size is too large."""
|
||||
if log_once("replay_buffer_size"):
|
||||
item_size = item.size_bytes()
|
||||
psutil_mem = psutil.virtual_memory()
|
||||
total_gb = psutil_mem.total / 1e9
|
||||
mem_size = num_items * item_size / 1e9
|
||||
msg = ("Estimated max memory usage for replay buffer is {} GB "
|
||||
"({} batches of {} bytes each), "
|
||||
"available system memory is {} GB".format(
|
||||
mem_size, num_items, item_size, total_gb))
|
||||
if mem_size > total_gb:
|
||||
raise ValueError(msg)
|
||||
elif mem_size > 0.2 * total_gb:
|
||||
logger.warning(msg)
|
||||
else:
|
||||
logger.info(msg)
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
class ReplayBuffer:
|
||||
@DeveloperAPI
|
||||
@@ -45,6 +68,7 @@ class ReplayBuffer:
|
||||
|
||||
@DeveloperAPI
|
||||
def add(self, item: SampleBatchType, weight: float):
|
||||
warn_replay_buffer_size(item=item, num_items=self._maxsize)
|
||||
assert item.count > 0, item
|
||||
self._num_added += 1
|
||||
|
||||
|
||||
@@ -3,7 +3,8 @@ import random
|
||||
|
||||
from ray.util.iter import from_actors, LocalIterator, _NextValueNotReady
|
||||
from ray.util.iter_metrics import SharedMetrics
|
||||
from ray.rllib.execution.replay_buffer import LocalReplayBuffer
|
||||
from ray.rllib.execution.replay_buffer import LocalReplayBuffer, \
|
||||
warn_replay_buffer_size
|
||||
from ray.rllib.execution.common import \
|
||||
STEPS_SAMPLED_COUNTER, _get_shared_metrics
|
||||
from ray.rllib.utils.typing import SampleBatchType
|
||||
@@ -122,6 +123,7 @@ class SimpleReplayBuffer:
|
||||
self.replay_index = 0
|
||||
|
||||
def add_batch(self, sample_batch):
|
||||
warn_replay_buffer_size(item=sample_batch, num_items=self.num_slots)
|
||||
if self.num_slots > 0:
|
||||
if len(self.replay_batches) < self.num_slots:
|
||||
self.replay_batches.append(sample_batch)
|
||||
|
||||
@@ -313,7 +313,7 @@ class SampleBatch:
|
||||
Returns:
|
||||
int: The overall size in bytes of the data buffer (all columns).
|
||||
"""
|
||||
return sum(sys.getsizeof(d) for d in self.data)
|
||||
return sum(sys.getsizeof(d) for d in self.data.values())
|
||||
|
||||
@PublicAPI
|
||||
def __getitem__(self, key: str) -> TensorType:
|
||||
|
||||
Reference in New Issue
Block a user