add large data warning (#10957)

This commit is contained in:
Eric Liang
2020-09-23 15:46:06 -07:00
committed by GitHub
parent 567009d5fd
commit ecdaaffc67
6 changed files with 35 additions and 8 deletions
+2 -2
View File
@@ -29,8 +29,8 @@ DEFAULT_CONFIG = with_common_config({
"lr": 1e-4,
# Number of timesteps collected for each SGD round.
"train_batch_size": 2000,
# Number of steps max to keep in the batch replay buffer.
"replay_buffer_size": 100000,
# Size of the replay buffer in batches (not timesteps!).
"replay_buffer_size": 1000,
# Number of steps to read before learning starts.
"learning_starts": 0,
# === Parallelism ===
+2 -2
View File
@@ -56,8 +56,8 @@ DEFAULT_CONFIG = with_common_config({
"target_network_update_freq": 500,
# === Replay buffer ===
# Size of the replay buffer in steps.
"buffer_size": 10000,
# Size of the replay buffer in batches (not timesteps!).
"buffer_size": 1000,
# === Optimization ===
# Learning rate for RMSProp optimizer
@@ -53,7 +53,8 @@ DEFAULT_CONFIG = with_common_config({
"num_sgd_iter": 30,
# In case a buffer optimizer is used
"learning_starts": 1000,
"buffer_size": 10000,
# Size of the replay buffer in batches (not timesteps!).
"buffer_size": 1000,
# Stepsize of SGD
"lr": 5e-5,
# Learning rate schedule
+25 -1
View File
@@ -5,12 +5,16 @@ import platform
import random
from typing import List
import ray
# Importing ray before psutil makes sure we use psutil's bundled version
import ray # noqa F401
import psutil # noqa E402
from ray.rllib.execution.segment_tree import SumSegmentTree, MinSegmentTree
from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch, \
DEFAULT_POLICY_ID
from ray.rllib.utils.annotations import DeveloperAPI
from ray.util.iter import ParallelIteratorWorker
from ray.util.debug import log_once
from ray.rllib.utils.timer import TimerStat
from ray.rllib.utils.window_stat import WindowStat
from ray.rllib.utils.typing import SampleBatchType
@@ -21,6 +25,25 @@ _ALL_POLICIES = "__all__"
logger = logging.getLogger(__name__)
def warn_replay_buffer_size(*, item: SampleBatchType, num_items: int) -> None:
    """Report the estimated worst-case memory footprint of a replay buffer.

    The estimate is ``num_items * item.size_bytes()`` expressed in GB
    (1e9 bytes) and is compared against the *total* system memory that
    ``psutil.virtual_memory()`` reports. The whole check is skipped unless
    ``log_once("replay_buffer_size")`` returns a truthy value (presumably
    only on the first call -- confirm against ``ray.util.debug.log_once``).

    Args:
        item: A batch whose ``size_bytes()`` is taken as the per-slot size.
        num_items: Number of slots (batches) the buffer is configured for.

    Raises:
        ValueError: If the estimate exceeds the total system memory.
    """
    if not log_once("replay_buffer_size"):
        return

    bytes_per_item = item.size_bytes()
    system_total_gb = psutil.virtual_memory().total / 1e9
    estimated_gb = num_items * bytes_per_item / 1e9

    template = ("Estimated max memory usage for replay buffer is {} GB "
                "({} batches of {} bytes each), "
                "available system memory is {} GB")
    msg = template.format(
        estimated_gb, num_items, bytes_per_item, system_total_gb)

    # A buffer that cannot fit in memory at all is treated as a hard error.
    if estimated_gb > system_total_gb:
        raise ValueError(msg)

    # More than 20% of total memory is worth a warning; otherwise just inform.
    if estimated_gb > 0.2 * system_total_gb:
        emit = logger.warning
    else:
        emit = logger.info
    emit(msg)
@DeveloperAPI
class ReplayBuffer:
@DeveloperAPI
@@ -45,6 +68,7 @@ class ReplayBuffer:
@DeveloperAPI
def add(self, item: SampleBatchType, weight: float):
warn_replay_buffer_size(item=item, num_items=self._maxsize)
assert item.count > 0, item
self._num_added += 1
+3 -1
View File
@@ -3,7 +3,8 @@ import random
from ray.util.iter import from_actors, LocalIterator, _NextValueNotReady
from ray.util.iter_metrics import SharedMetrics
from ray.rllib.execution.replay_buffer import LocalReplayBuffer
from ray.rllib.execution.replay_buffer import LocalReplayBuffer, \
warn_replay_buffer_size
from ray.rllib.execution.common import \
STEPS_SAMPLED_COUNTER, _get_shared_metrics
from ray.rllib.utils.typing import SampleBatchType
@@ -122,6 +123,7 @@ class SimpleReplayBuffer:
self.replay_index = 0
def add_batch(self, sample_batch):
warn_replay_buffer_size(item=sample_batch, num_items=self.num_slots)
if self.num_slots > 0:
if len(self.replay_batches) < self.num_slots:
self.replay_batches.append(sample_batch)
+1 -1
View File
@@ -313,7 +313,7 @@ class SampleBatch:
Returns:
int: The overall size in bytes of the data buffer (all columns).
"""
return sum(sys.getsizeof(d) for d in self.data)
return sum(sys.getsizeof(d) for d in self.data.values())
@PublicAPI
def __getitem__(self, key: str) -> TensorType: