diff --git a/doc/source/package-ref.rst b/doc/source/package-ref.rst index b7dcb165a..acd55eae5 100644 --- a/doc/source/package-ref.rst +++ b/doc/source/package-ref.rst @@ -34,8 +34,6 @@ Inspect the Cluster State .. autofunction:: ray.nodes -.. autofunction:: ray.tasks - .. autofunction:: ray.objects .. autofunction:: ray.timeline diff --git a/python/ray/__init__.py b/python/ray/__init__.py index d4664ff5e..3a2a69f82 100644 --- a/python/ray/__init__.py +++ b/python/ray/__init__.py @@ -59,7 +59,7 @@ from ray._raylet import ( _config = _Config() from ray.profiling import profile # noqa: E402 -from ray.state import (jobs, nodes, actors, tasks, objects, timeline, +from ray.state import (jobs, nodes, actors, objects, timeline, object_transfer_timeline, cluster_resources, available_resources, errors) # noqa: E402 from ray.worker import ( @@ -99,7 +99,6 @@ __all__ = [ "jobs", "nodes", "actors", - "tasks", "objects", "timeline", "object_transfer_timeline", diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index bc0595949..63400ff47 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -75,7 +75,6 @@ from ray.includes.libcoreworker cimport ( CFiberEvent, CActorHandle, ) -from ray.includes.task cimport CTaskSpec from ray.includes.ray_config cimport RayConfig import ray @@ -98,7 +97,6 @@ cimport cpython include "includes/unique_ids.pxi" include "includes/ray_config.pxi" include "includes/function_descriptor.pxi" -include "includes/task.pxi" include "includes/buffer.pxi" include "includes/common.pxi" include "includes/serialization.pxi" diff --git a/python/ray/includes/libcoreworker.pxd b/python/ray/includes/libcoreworker.pxd index 9daf0dee8..053a968fb 100644 --- a/python/ray/includes/libcoreworker.pxd +++ b/python/ray/includes/libcoreworker.pxd @@ -36,7 +36,6 @@ from ray.includes.common cimport ( from ray.includes.function_descriptor cimport ( CFunctionDescriptor, ) -from ray.includes.task cimport CTaskSpec ctypedef unordered_map[c_string, c_vector[pair[int64_t, double]]] \ ResourceMappingType diff --git a/python/ray/includes/task.pxd b/python/ray/includes/task.pxd deleted file mode 100644 index d19387279..000000000 --- a/python/ray/includes/task.pxd +++ /dev/null @@ -1,89 +0,0 @@ -from libc.stdint cimport uint8_t, uint64_t -from libcpp cimport bool as c_bool -from libcpp.memory cimport unique_ptr, shared_ptr -from libcpp.string cimport string as c_string -from libcpp.unordered_map cimport unordered_map -from libcpp.vector cimport vector as c_vector - -from ray.includes.common cimport ( - CLanguage, - ResourceSet, -) -from ray.includes.unique_ids cimport ( - CActorID, - CJobID, - CObjectID, - CTaskID, -) -from ray.includes.function_descriptor cimport ( - CFunctionDescriptor, -) - -cdef extern from "ray/protobuf/common.pb.h" nogil: - cdef cppclass RpcTaskSpec "ray::rpc::TaskSpec": - void CopyFrom(const RpcTaskSpec &value) - - cdef cppclass RpcTaskExecutionSpec "ray::rpc::TaskExecutionSpec": - void CopyFrom(const RpcTaskExecutionSpec &value) - void add_dependencies(const c_string &value) - - cdef cppclass RpcTask "ray::rpc::Task": - RpcTaskSpec *mutable_task_spec() - -cdef extern from "ray/protobuf/gcs.pb.h" nogil: - cdef cppclass TaskTableData "ray::rpc::TaskTableData": - RpcTask *mutable_task() - const c_string &SerializeAsString() - - -cdef extern from "ray/common/task/task_spec.h" nogil: - cdef cppclass CTaskSpec "ray::TaskSpecification": - CTaskSpec(const RpcTaskSpec message) - CTaskSpec(const c_string &serialized_binary) - const RpcTaskSpec &GetMessage() - c_string Serialize() const - - CTaskID TaskId() const - CJobID JobId() const - CTaskID ParentTaskId() const - uint64_t ParentCounter() const - CFunctionDescriptor FunctionDescriptor() const - c_string FunctionDescriptorString() const - uint64_t NumArgs() const - uint64_t NumReturns() const - c_bool ArgByRef(uint64_t arg_index) const - int ArgIdCount(uint64_t arg_index) const - CObjectID ArgId(uint64_t arg_index, uint64_t id_index) const - CObjectID ReturnIdForPlasma(uint64_t return_index) const - const uint8_t *ArgData(uint64_t arg_index) const - size_t ArgDataSize(uint64_t arg_index) const - const uint8_t *ArgMetadata(uint64_t arg_index) const - size_t ArgMetadataSize(uint64_t arg_index) const - double GetRequiredResource(const c_string &resource_name) const - const ResourceSet GetRequiredResources() const - const ResourceSet GetRequiredPlacementResources() const - c_bool IsDriverTask() const - CLanguage GetLanguage() const - c_bool IsNormalTask() const - c_bool IsActorCreationTask() const - c_bool IsActorTask() const - CActorID ActorCreationId() const - CObjectID ActorCreationDummyObjectId() const - CObjectID PreviousActorTaskDummyObjectId() const - uint64_t MaxActorReconstructions() const - CActorID ActorId() const - uint64_t ActorCounter() const - CObjectID ActorDummyObject() const - - -cdef extern from "ray/common/task/task_execution_spec.h" nogil: - cdef cppclass CTaskExecutionSpec "ray::TaskExecutionSpecification": - CTaskExecutionSpec(RpcTaskExecutionSpec message) - CTaskExecutionSpec(const c_string &serialized_binary) - const RpcTaskExecutionSpec &GetMessage() - c_vector[CObjectID] ExecutionDependencies() - uint64_t NumForwards() - -cdef extern from "ray/common/task/task.h" nogil: - cdef cppclass CTask "ray::Task": - CTask(CTaskSpec task_spec, CTaskExecutionSpec task_execution_spec) diff --git a/python/ray/includes/task.pxi b/python/ray/includes/task.pxi deleted file mode 100644 index 3512b7716..000000000 --- a/python/ray/includes/task.pxi +++ /dev/null @@ -1,209 +0,0 @@ -from ray.includes.task cimport ( - CTask, - CTaskExecutionSpec, - CTaskSpec, - RpcTaskExecutionSpec, - TaskTableData, -) -from ray.ray_constants import RAW_BUFFER_METADATA -from ray.utils import decode - - -cdef class TaskSpec: - """Cython wrapper class of C++ `ray::TaskSpecification`.""" - cdef: - unique_ptr[CTaskSpec] task_spec - - @staticmethod - def from_string(const c_string& task_spec_str): - """Convert a string to a Ray task specification Python object. - - Args: - task_spec_str: String representation of the task specification. - - Returns: - Python task specification object. - """ - cdef TaskSpec self = TaskSpec.__new__(TaskSpec) - self.task_spec.reset(new CTaskSpec(task_spec_str)) - return self - - def to_string(self): - """Convert a Ray task specification Python object to a string. - - Returns: - String representing the task specification. - """ - return self.task_spec.get().Serialize() - - def is_normal_task(self): - """Whether this task is a normal task.""" - return self.task_spec.get().IsNormalTask() - - def is_actor_task(self): - """Whether this task is an actor task.""" - return self.task_spec.get().IsActorTask() - - def is_actor_creation_task(self): - """Whether this task is an actor creation task.""" - return self.task_spec.get().IsActorCreationTask() - - def job_id(self): - """Return the job ID for this task.""" - return JobID(self.task_spec.get().JobId().Binary()) - - def task_id(self): - """Return the task ID for this task.""" - return TaskID(self.task_spec.get().TaskId().Binary()) - - def parent_task_id(self): - """Return the task ID of the parent task.""" - return TaskID(self.task_spec.get().ParentTaskId().Binary()) - - def parent_counter(self): - """Return the parent counter of this task.""" - return self.task_spec.get().ParentCounter() - - def function_descriptor(self): - """Return the function descriptor for this task.""" - return CFunctionDescriptorToPython( - self.task_spec.get().FunctionDescriptor()) - - def arguments(self): - """Return the arguments for the task.""" - cdef: - int64_t num_args = self.task_spec.get().NumArgs() - int32_t lang = self.task_spec.get().GetLanguage() - int count - arg_list = [] - - if lang == LANGUAGE_PYTHON: - for i in range(num_args): - count = self.task_spec.get().ArgIdCount(i) - if count > 0: - assert count == 1 - arg_list.append( - ObjectID(self.task_spec.get().ArgId(i, 0).Binary())) - else: - data = self.task_spec.get().ArgData(i)[ - :self.task_spec.get().ArgDataSize(i)] - metadata = self.task_spec.get().ArgMetadata(i)[ - :self.task_spec.get().ArgMetadataSize(i)] - if metadata == RAW_BUFFER_METADATA: - obj = data - else: - obj = data - arg_list.append(obj) - elif lang == LANGUAGE_JAVA: - arg_list = num_args * [""] - - return arg_list - - def returns(self): - """Return the object IDs for the return values of the task.""" - return_id_list = [] - for i in range(self.task_spec.get().NumReturns()): - return_id_list.append( - ObjectID(self.task_spec.get().ReturnIdForPlasma(i).Binary())) - return return_id_list - - def required_resources(self): - """Return the resource dictionary of the task.""" - cdef: - unordered_map[c_string, double] resource_map = ( - self.task_spec.get().GetRequiredResources().GetResourceMap()) - c_string resource_name - double resource_value - unordered_map[c_string, double].iterator iterator = ( - resource_map.begin()) - - required_resources = {} - while iterator != resource_map.end(): - resource_name = dereference(iterator).first - # bytes for Py2, unicode for Py3 - py_resource_name = decode(resource_name) - resource_value = dereference(iterator).second - required_resources[py_resource_name] = resource_value - postincrement(iterator) - return required_resources - - def language(self): - """Return the language of the task.""" - return Language.from_native(self.task_spec.get().GetLanguage()) - - def actor_creation_id(self): - """Return the actor creation ID for the task.""" - if not self.is_actor_creation_task(): - return ActorID.nil() - return ActorID(self.task_spec.get().ActorCreationId().Binary()) - - def actor_creation_dummy_object_id(self): - """Return the actor creation dummy object ID for the task.""" - if not self.is_actor_task(): - return ObjectID.nil() - return ObjectID( - self.task_spec.get().ActorCreationDummyObjectId().Binary()) - - def previous_actor_task_dummy_object_id(self): - """Return the object ID of the previously executed actor task.""" - if not self.is_actor_task(): - return ObjectID.nil() - return ObjectID( - self.task_spec.get().PreviousActorTaskDummyObjectId().Binary()) - - def actor_id(self): - """Return the actor ID for this task.""" - if not self.is_actor_task(): - return ActorID.nil() - return ActorID(self.task_spec.get().ActorId().Binary()) - - def actor_counter(self): - """Return the actor counter for this task.""" - if not self.is_actor_task(): - return 0 - return self.task_spec.get().ActorCounter() - - -cdef class TaskExecutionSpec: - """Cython wrapper class of C++ `ray::TaskExecutionSpecification`.""" - cdef: - unique_ptr[CTaskExecutionSpec] c_spec - - def __init__(self): - cdef: - RpcTaskExecutionSpec message - - self.c_spec.reset(new CTaskExecutionSpec(message)) - - @staticmethod - def from_string(const c_string& string): - """Convert a string to a Ray `TaskExecutionSpec` Python object. - """ - cdef TaskExecutionSpec self = TaskExecutionSpec.__new__( - TaskExecutionSpec) - self.c_spec.reset(new CTaskExecutionSpec(string)) - return self - - def num_forwards(self): - return self.c_spec.get().NumForwards() - - -cdef class Task: - """Cython wrapper class of C++ `ray::Task`.""" - cdef: - unique_ptr[CTask] c_task - - def __init__( - self, TaskSpec task_spec, TaskExecutionSpec task_execution_spec): - self.c_task.reset(new CTask(task_spec.task_spec.get()[0], - task_execution_spec.c_spec.get()[0])) - - -def generate_gcs_task_table_data(TaskSpec task_spec): - """Converts a Python `TaskSpec` object to serialized GCS `TaskTableData`. - """ - cdef: - TaskTableData task_table_data - task_table_data.mutable_task().mutable_task_spec().CopyFrom( - task_spec.task_spec.get().GetMessage()) - return task_table_data.SerializeAsString() diff --git a/python/ray/includes/unique_ids.pxi b/python/ray/includes/unique_ids.pxi index ea9f298d6..fc87b49dd 100644 --- a/python/ray/includes/unique_ids.pxi +++ b/python/ray/includes/unique_ids.pxi @@ -134,8 +134,9 @@ cdef class ObjectID(BaseID): self.in_core_worker = False worker = ray.worker.global_worker - # TODO(edoakes): there are dummy object IDs being created in - # includes/task.pxi before the core worker is initialized. + # TODO(edoakes): We should be able to remove the in_core_worker flag. + # But there are still some dummy object IDs being created outside the + # context of a core worker. if hasattr(worker, "core_worker"): worker.core_worker.add_object_id_reference(self) self.in_core_worker = True diff --git a/python/ray/monitor.py b/python/ray/monitor.py index 7c6675847..a1197cf69 100644 --- a/python/ray/monitor.py +++ b/python/ray/monitor.py @@ -10,8 +10,7 @@ from ray.autoscaler.autoscaler import LoadMetrics, StandardAutoscaler import ray.gcs_utils import ray.utils import ray.ray_constants as ray_constants -from ray.utils import (binary_to_hex, binary_to_object_id, binary_to_task_id, - hex_to_binary, setup_logger) +from ray.utils import binary_to_hex, setup_logger from ray.autoscaler.commands import teardown_cluster logger = logging.getLogger(__name__) @@ -99,74 +98,6 @@ class Monitor: "Monitor: " "could not find ip for client {}".format(client_id)) - def _xray_clean_up_entries_for_job(self, job_id): - """Remove this job's object/task entries from redis. - - Removes control-state entries of all tasks and task return - objects belonging to the driver. - - Args: - job_id: The job id. - """ - - xray_task_table_prefix = ( - ray.gcs_utils.TablePrefix_RAYLET_TASK_string.encode("ascii")) - xray_object_table_prefix = ( - ray.gcs_utils.TablePrefix_OBJECT_string.encode("ascii")) - - task_table_objects = ray.tasks() - job_id_hex = binary_to_hex(job_id) - job_task_id_bins = set() - for task_id_hex, task_info in task_table_objects.items(): - task_table_object = task_info["TaskSpec"] - task_job_id_hex = task_table_object["JobID"] - if job_id_hex != task_job_id_hex: - # Ignore tasks that aren't from this driver. - continue - job_task_id_bins.add(hex_to_binary(task_id_hex)) - - # Get objects associated with the driver. - object_table_objects = ray.objects() - job_object_id_bins = set() - for object_id, _ in object_table_objects.items(): - task_id_bin = ray._raylet.compute_task_id(object_id).binary() - if task_id_bin in job_task_id_bins: - job_object_id_bins.add(object_id.binary()) - - def to_shard_index(id_bin): - if len(id_bin) == ray.TaskID.size(): - return binary_to_task_id(id_bin).redis_shard_hash() % len( - ray.state.state.redis_clients) - else: - return binary_to_object_id(id_bin).redis_shard_hash() % len( - ray.state.state.redis_clients) - - # Form the redis keys to delete. - sharded_keys = [[] for _ in range(len(ray.state.state.redis_clients))] - for task_id_bin in job_task_id_bins: - sharded_keys[to_shard_index(task_id_bin)].append( - xray_task_table_prefix + task_id_bin) - for object_id_bin in job_object_id_bins: - sharded_keys[to_shard_index(object_id_bin)].append( - xray_object_table_prefix + object_id_bin) - - # Remove with best effort. - for shard_index in range(len(sharded_keys)): - keys = sharded_keys[shard_index] - if len(keys) == 0: - continue - redis = ray.state.state.redis_clients[shard_index] - num_deleted = redis.delete(*keys) - logger.info("Monitor: " - "Removed {} dead redis entries of the " - "driver from redis shard {}.".format( - num_deleted, shard_index)) - if num_deleted != len(keys): - logger.warning("Monitor: " - "Failed to remove {} relevant redis " - "entries from redis shard {}.".format( - len(keys) - num_deleted, shard_index)) - def xray_job_notification_handler(self, unused_channel, data): """Handle a notification that a job has been added or removed. @@ -182,7 +113,6 @@ class Monitor: logger.info("Monitor: " "XRay Driver {} has been removed.".format( binary_to_hex(job_id))) - self._xray_clean_up_entries_for_job(job_id) def autoscaler_resource_request_handler(self, _, data): """Handle a notification of a resource request for the autoscaler. diff --git a/python/ray/state.py b/python/ray/state.py index 5985d0561..0bc97656e 100644 --- a/python/ray/state.py +++ b/python/ray/state.py @@ -369,86 +369,6 @@ class GlobalState: ray.ActorID(actor_id_binary)) return results - def _task_table(self, task_id): - """Fetch and parse the task table information for a single task ID. - - Args: - task_id: A task ID to get information about. - - Returns: - A dictionary with information about the task ID in question. - """ - assert isinstance(task_id, ray.TaskID) - message = self._execute_command( - task_id, "RAY.TABLE_LOOKUP", - gcs_utils.TablePrefix.Value("RAYLET_TASK"), "", task_id.binary()) - if message is None: - return {} - gcs_entries = gcs_utils.GcsEntry.FromString(message) - - assert len(gcs_entries.entries) == 1 - task_table_data = gcs_utils.TaskTableData.FromString( - gcs_entries.entries[0]) - - task = ray._raylet.TaskSpec.from_string( - task_table_data.task.task_spec.SerializeToString()) - function_descriptor = task.function_descriptor() - - task_spec_info = { - "JobID": task.job_id().hex(), - "TaskID": task.task_id().hex(), - "ParentTaskID": task.parent_task_id().hex(), - "ParentCounter": task.parent_counter(), - "ActorID": (task.actor_id().hex()), - "ActorCreationID": task.actor_creation_id().hex(), - "ActorCreationDummyObjectID": ( - task.actor_creation_dummy_object_id().hex()), - "PreviousActorTaskDummyObjectID": ( - task.previous_actor_task_dummy_object_id().hex()), - "ActorCounter": task.actor_counter(), - "Args": task.arguments(), - "ReturnObjectIDs": task.returns(), - "RequiredResources": task.required_resources(), - "FunctionDescriptor": function_descriptor.to_dict(), - } - - execution_spec = ray._raylet.TaskExecutionSpec.from_string( - task_table_data.task.task_execution_spec.SerializeToString()) - return { - "ExecutionSpec": { - "NumForwards": execution_spec.num_forwards(), - }, - "TaskSpec": task_spec_info - } - - def task_table(self, task_id=None): - """Fetch and parse the task table information for one or more task IDs. - - Args: - task_id: A hex string of the task ID to fetch information about. If - this is None, then the task object table is fetched. - - Returns: - Information from the task table. - """ - self._check_connected() - if task_id is not None: - task_id = ray.TaskID(hex_to_binary(task_id)) - return self._task_table(task_id) - else: - task_table_keys = self._keys( - gcs_utils.TablePrefix_RAYLET_TASK_string + "*") - task_ids_binary = [ - key[len(gcs_utils.TablePrefix_RAYLET_TASK_string):] - for key in task_table_keys - ] - - results = {} - for task_id_binary in task_ids_binary: - results[binary_to_hex(task_id_binary)] = self._task_table( - ray.TaskID(task_id_binary)) - return results - def client_table(self): """Fetch and parse the Redis DB client table. @@ -1133,19 +1053,6 @@ def actors(actor_id=None): return state.actor_table(actor_id=actor_id) -def tasks(task_id=None): - """Fetch and parse the task table information for one or more task IDs. - - Args: - task_id: A hex string of the task ID to fetch information about. If - this is None, then the task object table is fetched. - - Returns: - Information from the task table. - """ - return state.task_table(task_id=task_id) - - def objects(object_id=None): """Fetch and parse the object table info for one or more object IDs. diff --git a/python/ray/tests/BUILD b/python/ray/tests/BUILD index 6cabf40de..db16c7881 100644 --- a/python/ray/tests/BUILD +++ b/python/ray/tests/BUILD @@ -282,15 +282,6 @@ py_test( deps = ["//:ray_lib"], ) -py_test( - name = "test_monitors", - size = "small", - srcs = ["test_monitors.py"], - # TODO(ekl) tasks() and objects() are different in direct call mode. - tags = ["exclusive", "manual"], - deps = ["//:ray_lib"], -) - py_test( name = "test_multiprocessing", size = "medium", diff --git a/python/ray/tests/test_advanced_3.py b/python/ray/tests/test_advanced_3.py index 5bca918e1..e4bb178bc 100644 --- a/python/ray/tests/test_advanced_3.py +++ b/python/ray/tests/test_advanced_3.py @@ -94,15 +94,6 @@ def wait_for_num_actors(num_actors, timeout=10): raise RayTestTimeoutException("Timed out while waiting for global state.") -def wait_for_num_tasks(num_tasks, timeout=10): - start_time = time.time() - while time.time() - start_time < timeout: - if len(ray.tasks()) >= num_tasks: - return - time.sleep(0.1) - raise RayTestTimeoutException("Timed out while waiting for global state.") - - def wait_for_num_objects(num_objects, timeout=10): start_time = time.time() while time.time() - start_time < timeout: @@ -123,9 +114,6 @@ def test_global_state_api(shutdown_only): with pytest.raises(Exception, match=error_message): ray.actors() - with pytest.raises(Exception, match=error_message): - ray.tasks() - with pytest.raises(Exception, match=error_message): ray.nodes() @@ -142,22 +130,6 @@ def test_global_state_api(shutdown_only): job_id = ray.utils.compute_job_id_from_driver( ray.WorkerID(ray.worker.global_worker.worker_id)) - driver_task_id = ray.worker.global_worker.current_task_id.hex() - - # One task is put in the task table which corresponds to this driver. - wait_for_num_tasks(1) - task_table = ray.tasks() - assert len(task_table) == 1 - assert driver_task_id == list(task_table.keys())[0] - task_spec = task_table[driver_task_id]["TaskSpec"] - nil_actor_id_hex = ray.ActorID.nil().hex() - - assert task_spec["TaskID"] == driver_task_id - assert task_spec["ActorID"] == nil_actor_id_hex - assert task_spec["Args"] == [] - assert task_spec["JobID"] == job_id.hex() - assert task_spec["FunctionDescriptor"]["type"] == "EmptyFunctionDescriptor" - assert task_spec["ReturnObjectIDs"] == [] client_table = ray.nodes() node_ip_address = ray.worker.global_worker.node_ip_address diff --git a/python/ray/tests/test_monitors.py b/python/ray/tests/test_monitors.py deleted file mode 100644 index f6296c8a5..000000000 --- a/python/ray/tests/test_monitors.py +++ /dev/null @@ -1,114 +0,0 @@ -import multiprocessing -import os -import pytest -import subprocess -import time - -import ray - - -def _test_cleanup_on_driver_exit(num_redis_shards): - output = ray.utils.decode( - subprocess.check_output( - [ - "ray", - "start", - "--head", - "--num-redis-shards", - str(num_redis_shards), - ], - stderr=subprocess.STDOUT)) - lines = [m.strip() for m in output.split("\n")] - init_cmd = [m for m in lines if m.startswith("ray.init")] - assert 1 == len(init_cmd) - address = init_cmd[0].split("address=\"")[-1][:-2] - max_attempts_before_failing = 100 - # Wait for monitor.py to start working. - time.sleep(2) - - def StateSummary(): - obj_tbl_len = len(ray.objects()) - task_tbl_len = len(ray.tasks()) - return obj_tbl_len, task_tbl_len - - def Driver(success): - success.value = True - # Start driver. - ray.init(address=address) - summary_start = StateSummary() - if (0, 1) != summary_start: - success.value = False - - # Two new objects. - ray.get(ray.put(1111)) - ray.get(ray.put(1111)) - - @ray.remote - def f(): - ray.put(1111) # Yet another object. - return 1111 # A returned object as well. - - # 1 new function. - attempts = 0 - while (2, 1) != StateSummary(): - time.sleep(0.1) - attempts += 1 - if attempts == max_attempts_before_failing: - success.value = False - break - - ray.get(f.remote()) - attempts = 0 - while (4, 2) != StateSummary(): - time.sleep(0.1) - attempts += 1 - if attempts == max_attempts_before_failing: - success.value = False - break - - ray.shutdown() - - success = multiprocessing.Value("b", False) - driver = multiprocessing.Process(target=Driver, args=(success, )) - driver.start() - # Wait for client to exit. - driver.join() - - # Just make sure Driver() is run and succeeded. - assert success.value - # Check that objects, tasks, and functions are cleaned up. - ray.init(address=address) - attempts = 0 - while (0, 1) != StateSummary(): - time.sleep(0.1) - attempts += 1 - if attempts == max_attempts_before_failing: - break - assert (0, 1) == StateSummary() - - ray.shutdown() - subprocess.check_output(["ray", "stop"]) - - -@pytest.mark.skipif( - os.environ.get("RAY_USE_NEW_GCS") == "on", - reason="Hanging with the new GCS API.") -def test_cleanup_on_driver_exit_single_redis_shard(): - _test_cleanup_on_driver_exit(num_redis_shards=1) - - -@pytest.mark.skipif( - os.environ.get("RAY_USE_NEW_GCS") == "on", - reason="Hanging with the new GCS API.") -def test_cleanup_on_driver_exit_many_redis_shards(): - _test_cleanup_on_driver_exit(num_redis_shards=5) - _test_cleanup_on_driver_exit(num_redis_shards=31) - - -if __name__ == "__main__": - import pytest - import sys - # Make subprocess happy in bazel. - os.environ["LC_ALL"] = "en_US.UTF-8" - os.environ["LANG"] = "en_US.UTF-8" - sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/tests/test_multi_node.py b/python/ray/tests/test_multi_node.py index 22ce94c02..a97637a8c 100644 --- a/python/ray/tests/test_multi_node.py +++ b/python/ray/tests/test_multi_node.py @@ -167,6 +167,39 @@ print("success") assert "success" in out +def test_cleanup_on_driver_exit(call_ray_start): + # This test will create a driver that creates a bunch of objects and then + # exits. The entries in the object table should be cleaned up. + address = call_ray_start + + ray.init(address=address) + + # Define a driver that creates a bunch of objects and exits. + driver_script = """ +import time +import ray +ray.init(address="{}") +object_ids = [ray.put(i) for i in range(1000)] +start_time = time.time() +while time.time() - start_time < 30: + if len(ray.objects()) == 1000: + break +else: + raise Exception("Objects did not appear in object table.") +print("success") +""".format(address) + + run_string_as_driver(driver_script) + + # Make sure the objects are removed from the object table. + start_time = time.time() + while time.time() - start_time < 30: + if len(ray.objects()) == 0: + break + else: + raise Exception("Objects were not all removed from object table.") + + def test_drivers_named_actors(call_ray_start): # This test will create some drivers that submit some tasks to the same # named actor.