From ecdaaffc67a906c10fed6ac0767dab2a636cbd88 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Wed, 23 Sep 2020 15:46:06 -0700 Subject: [PATCH] add large data warning (#10957) --- rllib/agents/marwil/marwil.py | 4 +-- rllib/agents/qmix/qmix.py | 4 +-- .../alpha_zero/core/alpha_zero_trainer.py | 3 ++- rllib/execution/replay_buffer.py | 26 ++++++++++++++++++- rllib/execution/replay_ops.py | 4 ++- rllib/policy/sample_batch.py | 2 +- 6 files changed, 35 insertions(+), 8 deletions(-) diff --git a/rllib/agents/marwil/marwil.py b/rllib/agents/marwil/marwil.py index e68c61dc9..6aeb373c5 100644 --- a/rllib/agents/marwil/marwil.py +++ b/rllib/agents/marwil/marwil.py @@ -29,8 +29,8 @@ DEFAULT_CONFIG = with_common_config({ "lr": 1e-4, # Number of timesteps collected for each SGD round. "train_batch_size": 2000, - # Number of steps max to keep in the batch replay buffer. - "replay_buffer_size": 100000, + # Size of the replay buffer in batches (not timesteps!). + "replay_buffer_size": 1000, # Number of steps to read before learning starts. "learning_starts": 0, # === Parallelism === diff --git a/rllib/agents/qmix/qmix.py b/rllib/agents/qmix/qmix.py index 3f50a3d42..7d64680f5 100644 --- a/rllib/agents/qmix/qmix.py +++ b/rllib/agents/qmix/qmix.py @@ -56,8 +56,8 @@ DEFAULT_CONFIG = with_common_config({ "target_network_update_freq": 500, # === Replay buffer === - # Size of the replay buffer in steps. - "buffer_size": 10000, + # Size of the replay buffer in batches (not timesteps!). + "buffer_size": 1000, # === Optimization === # Learning rate for RMSProp optimizer diff --git a/rllib/contrib/alpha_zero/core/alpha_zero_trainer.py b/rllib/contrib/alpha_zero/core/alpha_zero_trainer.py index 199c752b5..27315108b 100644 --- a/rllib/contrib/alpha_zero/core/alpha_zero_trainer.py +++ b/rllib/contrib/alpha_zero/core/alpha_zero_trainer.py @@ -53,7 +53,8 @@ DEFAULT_CONFIG = with_common_config({ "num_sgd_iter": 30, # IN case a buffer optimizer is used "learning_starts": 1000, - "buffer_size": 10000, + # Size of the replay buffer in batches (not timesteps!). + "buffer_size": 1000, # Stepsize of SGD "lr": 5e-5, # Learning rate schedule diff --git a/rllib/execution/replay_buffer.py b/rllib/execution/replay_buffer.py index 15a0719cd..a7355b85b 100644 --- a/rllib/execution/replay_buffer.py +++ b/rllib/execution/replay_buffer.py @@ -5,12 +5,16 @@ import platform import random from typing import List -import ray +# Import ray before psutil will make sure we use psutil's bundled version +import ray # noqa F401 +import psutil # noqa E402 + from ray.rllib.execution.segment_tree import SumSegmentTree, MinSegmentTree from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch, \ DEFAULT_POLICY_ID from ray.rllib.utils.annotations import DeveloperAPI from ray.util.iter import ParallelIteratorWorker +from ray.util.debug import log_once from ray.rllib.utils.timer import TimerStat from ray.rllib.utils.window_stat import WindowStat from ray.rllib.utils.typing import SampleBatchType @@ -21,6 +25,25 @@ _ALL_POLICIES = "__all__" logger = logging.getLogger(__name__) +def warn_replay_buffer_size(*, item: SampleBatchType, num_items: int) -> None: + """Warn if the configured replay buffer size is too large.""" + if log_once("replay_buffer_size"): + item_size = item.size_bytes() + psutil_mem = psutil.virtual_memory() + total_gb = psutil_mem.total / 1e9 + mem_size = num_items * item_size / 1e9 + msg = ("Estimated max memory usage for replay buffer is {} GB " + "({} batches of {} bytes each), " + "available system memory is {} GB".format( + mem_size, num_items, item_size, total_gb)) + if mem_size > total_gb: + raise ValueError(msg) + elif mem_size > 0.2 * total_gb: + logger.warning(msg) + else: + logger.info(msg) + + @DeveloperAPI class ReplayBuffer: @DeveloperAPI @@ -45,6 +68,7 @@ class ReplayBuffer: @DeveloperAPI def add(self, item: SampleBatchType, weight: float): + warn_replay_buffer_size(item=item, num_items=self._maxsize) assert item.count > 0, item self._num_added += 1 diff --git a/rllib/execution/replay_ops.py b/rllib/execution/replay_ops.py index 9be190a74..9ed25e9e9 100644 --- a/rllib/execution/replay_ops.py +++ b/rllib/execution/replay_ops.py @@ -3,7 +3,8 @@ import random from ray.util.iter import from_actors, LocalIterator, _NextValueNotReady from ray.util.iter_metrics import SharedMetrics -from ray.rllib.execution.replay_buffer import LocalReplayBuffer +from ray.rllib.execution.replay_buffer import LocalReplayBuffer, \ + warn_replay_buffer_size from ray.rllib.execution.common import \ STEPS_SAMPLED_COUNTER, _get_shared_metrics from ray.rllib.utils.typing import SampleBatchType @@ -122,6 +123,7 @@ class SimpleReplayBuffer: self.replay_index = 0 def add_batch(self, sample_batch): + warn_replay_buffer_size(item=sample_batch, num_items=self.num_slots) if self.num_slots > 0: if len(self.replay_batches) < self.num_slots: self.replay_batches.append(sample_batch) diff --git a/rllib/policy/sample_batch.py b/rllib/policy/sample_batch.py index 6f7c24b53..33d943abc 100644 --- a/rllib/policy/sample_batch.py +++ b/rllib/policy/sample_batch.py @@ -313,7 +313,7 @@ class SampleBatch: Returns: int: The overall size in bytes of the data buffer (all columns). """ - return sum(sys.getsizeof(d) for d in self.data) + return sum(sys.getsizeof(d) for d in self.data.values()) @PublicAPI def __getitem__(self, key: str) -> TensorType: