diff --git a/python/ray/experimental/__init__.py b/python/ray/experimental/__init__.py
index 8f50b5b8a..2d942efe3 100644
--- a/python/ray/experimental/__init__.py
+++ b/python/ray/experimental/__init__.py
@@ -3,6 +3,7 @@ from __future__ import division
 from __future__ import print_function
 
 from .tfutils import TensorFlowVariables
-from .features import flush_redis_unsafe
+from .features import flush_redis_unsafe, flush_task_and_object_metadata_unsafe
 
-__all__ = ["TensorFlowVariables", "flush_redis_unsafe"]
+__all__ = ["TensorFlowVariables", "flush_redis_unsafe",
+           "flush_task_and_object_metadata_unsafe"]
diff --git a/python/ray/experimental/features.py b/python/ray/experimental/features.py
index b6c3e77b3..ca12fad68 100644
--- a/python/ray/experimental/features.py
+++ b/python/ray/experimental/features.py
@@ -4,6 +4,10 @@ from __future__ import print_function
 
 import ray
 
+OBJECT_INFO_PREFIX = b"OI:"
+OBJECT_LOCATION_PREFIX = b"OL:"
+TASK_TABLE_PREFIX = b"TT:"
+
 
 def flush_redis_unsafe():
     """This removes some non-critical state from the primary Redis shard.
@@ -35,3 +39,51 @@ def flush_redis_unsafe():
     else:
         num_deleted = 0
     print("Deleted {} event logs from Redis.".format(num_deleted))
+
+
+def flush_task_and_object_metadata_unsafe():
+    """This removes some critical state from the Redis shards.
+
+    In a multitenant environment, this will flush metadata for all jobs, which
+    may be undesirable.
+
+    This removes all of the object and task metadata. This can be used to try
+    to address out-of-memory errors caused by the accumulation of metadata in
+    Redis. However, after running this command, fault tolerance will most
+    likely not work.
+
+    Raises:
+        Exception: An exception is raised if the global worker has no Redis
+            client, i.e., if ray.init() has not been called yet.
+    """
+    if not hasattr(ray.worker.global_worker, "redis_client"):
+        raise Exception(
+            "ray.experimental.flush_task_and_object_metadata_unsafe cannot "
+            "be called before ray.init() has been called.")
+
+    def flush_shard(redis_client):
+        """Delete the task and object metadata keys on one Redis shard."""
+        # Flush the task table. Note that this also flushes the driver tasks
+        # which may be undesirable.
+        num_task_keys_deleted = 0
+        for key in redis_client.scan_iter(match=TASK_TABLE_PREFIX + b"*"):
+            num_task_keys_deleted += redis_client.delete(key)
+        print("Deleted {} task keys from Redis.".format(num_task_keys_deleted))
+
+        # Flush the object information.
+        num_object_keys_deleted = 0
+        for key in redis_client.scan_iter(match=OBJECT_INFO_PREFIX + b"*"):
+            num_object_keys_deleted += redis_client.delete(key)
+        print("Deleted {} object info keys from Redis.".format(
+            num_object_keys_deleted))
+
+        # Flush the object locations.
+        num_object_location_keys_deleted = 0
+        for key in redis_client.scan_iter(match=OBJECT_LOCATION_PREFIX + b"*"):
+            num_object_location_keys_deleted += redis_client.delete(key)
+        print("Deleted {} object location keys from Redis.".format(
+            num_object_location_keys_deleted))
+
+    # Loop over the shards and flush all of them.
+    for redis_client in ray.worker.global_state.redis_clients:
+        flush_shard(redis_client)
diff --git a/test/runtest.py b/test/runtest.py
index 186c9c23d..fb326a12d 100644
--- a/test/runtest.py
+++ b/test/runtest.py
@@ -2068,6 +2068,60 @@ class GlobalStateAPI(unittest.TestCase):
         # the visualization actually renders (e.g., the context of the dumped
         # trace could be malformed).
 
+    def testFlushAPI(self):
+        ray.init(num_cpus=1)
+
+        @ray.remote
+        def f():
+            return 1
+
+        [ray.put(1) for _ in range(10)]
+        ray.get([f.remote() for _ in range(10)])
+
+        # Wait until all of the task and object information has been stored in
+        # Redis. Note that since a given key may be updated multiple times
+        # (e.g., multiple calls to TaskTableUpdate), this is an attempt to wait
+        # until all updates have happened. Note that in a real application we
+        # could encounter this kind of issue as well.
+        while True:
+            object_table = ray.global_state.object_table()
+            task_table = ray.global_state.task_table()
+
+            tables_ready = True
+
+            if len(object_table) != 20:
+                tables_ready = False
+
+            for object_info in object_table.values():
+                if len(object_info) != 5:
+                    tables_ready = False
+                if (object_info["ManagerIDs"] is None or
+                        object_info["DataSize"] == -1 or
+                        object_info["Hash"] == ""):
+                    tables_ready = False
+
+            if len(task_table) != 10 + 1:
+                tables_ready = False
+
+            driver_task_id = ray.utils.binary_to_hex(
+                ray.worker.global_worker.current_task_id.id())
+
+            for info in task_table.values():
+                if info["State"] != ray.experimental.state.TASK_STATUS_DONE:
+                    if info["TaskSpec"]["TaskID"] != driver_task_id:
+                        tables_ready = False
+
+            if tables_ready:
+                break
+
+        # Flush the tables.
+        ray.experimental.flush_redis_unsafe()
+        ray.experimental.flush_task_and_object_metadata_unsafe()
+
+        # Make sure the tables are empty.
+        assert len(ray.global_state.object_table()) == 0
+        assert len(ray.global_state.task_table()) == 0
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)