From 8ae82180b4e1e5eddfb44ff8b96837fdd539d869 Mon Sep 17 00:00:00 2001 From: Melih Elibol Date: Thu, 9 Aug 2018 02:41:40 -0400 Subject: [PATCH] [xray] Adds a driver table. (#2289) This PR adds a driver table for the new GCS, which enables cleanup functionality associated with monitoring driver death. Some testing in `monitor_test.py` is restored, but redis sharding for xray is needed to enable remaining tests. --- .gitignore | 2 +- .travis.yml | 2 +- doc/source/conf.py | 1 + python/ray/experimental/state.py | 2 +- python/ray/gcs_utils.py | 7 +- python/ray/local_scheduler/__init__.py | 8 +- python/ray/monitor.py | 90 ++++++++++++++++++- src/common/lib/python/common_extension.cc | 9 ++ src/common/lib/python/common_extension.h | 1 + .../lib/python/local_scheduler_extension.cc | 2 + src/ray/gcs/client.cc | 3 + src/ray/gcs/client.h | 2 + src/ray/gcs/format/gcs.fbs | 10 +++ src/ray/gcs/tables.cc | 12 +++ src/ray/gcs/tables.h | 17 ++++ src/ray/object_manager/object_manager.cc | 16 +++- src/ray/object_manager/object_manager.h | 4 + src/ray/raylet/node_manager.cc | 24 +++++ src/ray/raylet/node_manager.h | 4 + test/monitor_test.py | 38 ++++++-- 20 files changed, 230 insertions(+), 24 deletions(-) diff --git a/.gitignore b/.gitignore index cd28233ac..1bd48415a 100644 --- a/.gitignore +++ b/.gitignore @@ -18,7 +18,7 @@ # Files generated by flatc should be ignored /src/common/format/*.py /src/common/format/*_generated.h -/src/plasma/format/*_generated.h +/src/plasma/format/ /src/local_scheduler/format/*_generated.h /src/ray/gcs/format/*_generated.h /src/ray/object_manager/format/*_generated.h diff --git a/.travis.yml b/.travis.yml index be1ef5306..83d051f71 100644 --- a/.travis.yml +++ b/.travis.yml @@ -148,7 +148,7 @@ matrix: # - pytest test/component_failures_test.py - python test/multi_node_test.py - python -m pytest test/recursion_test.py - # - pytest test/monitor_test.py + - pytest test/monitor_test.py - python -m pytest test/cython_test.py - python -m pytest test/credis_test.py diff --git a/doc/source/conf.py b/doc/source/conf.py index 3178715a5..1b113b71a 100644 --- a/doc/source/conf.py +++ b/doc/source/conf.py @@ -47,6 +47,7 @@ MOCK_MODULES = ["gym", "ray.core.generated.ClientTableData", "ray.core.generated.GcsTableEntry", "ray.core.generated.HeartbeatTableData", + "ray.core.generated.DriverTableData", "ray.core.generated.ErrorTableData", "ray.core.generated.ProfileTableData", "ray.core.generated.ObjectTableData", diff --git a/python/ray/experimental/state.py b/python/ray/experimental/state.py index 299fa981a..687fbdae5 100644 --- a/python/ray/experimental/state.py +++ b/python/ray/experimental/state.py @@ -169,7 +169,7 @@ class GlobalState(object): """ result = [] for client in self.redis_clients: - result.extend(client.keys(pattern)) + result.extend(list(client.scan_iter(match=pattern))) return result def _object_table(self, object_id): diff --git a/python/ray/gcs_utils.py b/python/ray/gcs_utils.py index c9fa5e2c6..53fa9d8d0 100644 --- a/python/ray/gcs_utils.py +++ b/python/ray/gcs_utils.py @@ -24,6 +24,7 @@ from ray.core.generated.ClientTableData import ClientTableData from ray.core.generated.ErrorTableData import ErrorTableData from ray.core.generated.ProfileTableData import ProfileTableData from ray.core.generated.HeartbeatTableData import HeartbeatTableData +from ray.core.generated.DriverTableData import DriverTableData from ray.core.generated.ObjectTableData import ObjectTableData from ray.core.generated.ray.protocol.Task import Task @@ -34,9 +35,9 @@ __all__ = [ "SubscribeToNotificationsReply", "ResultTableReply", "TaskExecutionDependencies", "TaskReply", "DriverTableMessage", "LocalSchedulerInfoMessage", "SubscribeToDBClientTableReply", "TaskInfo", - "GcsTableEntry", "ClientTableData", "ErrorTableData", "ProfileTableData", - "HeartbeatTableData", "ObjectTableData", "Task", "TablePrefix", - "TablePubsub", "construct_error_message" + "GcsTableEntry", "ClientTableData", "ErrorTableData", "HeartbeatTableData", + "DriverTableData", "ProfileTableData", "ObjectTableData", "Task", + "TablePrefix", "TablePubsub", "construct_error_message" ] # These prefixes must be kept up-to-date with the definitions in diff --git a/python/ray/local_scheduler/__init__.py b/python/ray/local_scheduler/__init__.py index 0f9c455bb..a469776f1 100644 --- a/python/ray/local_scheduler/__init__.py +++ b/python/ray/local_scheduler/__init__.py @@ -3,12 +3,12 @@ from __future__ import division from __future__ import print_function from ray.core.src.local_scheduler.liblocal_scheduler_library_python import ( - Task, LocalSchedulerClient, ObjectID, check_simple_value, task_from_string, - task_to_string, _config, common_error) + Task, LocalSchedulerClient, ObjectID, check_simple_value, compute_task_id, + task_from_string, task_to_string, _config, common_error) from .local_scheduler_services import start_local_scheduler __all__ = [ "Task", "LocalSchedulerClient", "ObjectID", "check_simple_value", - "task_from_string", "task_to_string", "start_local_scheduler", "_config", - "common_error" + "compute_task_id", "task_from_string", "task_to_string", + "start_local_scheduler", "_config", "common_error" ] diff --git a/python/ray/monitor.py b/python/ray/monitor.py index d659afd81..be9d0f252 100644 --- a/python/ray/monitor.py +++ b/python/ray/monitor.py @@ -37,6 +37,9 @@ DRIVER_DEATH_CHANNEL = b"driver_deaths" XRAY_HEARTBEAT_CHANNEL = str( ray.gcs_utils.TablePubsub.HEARTBEAT).encode("ascii") +# xray driver updates +XRAY_DRIVER_CHANNEL = str(ray.gcs_utils.TablePubsub.DRIVER).encode("ascii") + # common/redis_module/ray_redis_module.cc OBJECT_INFO_PREFIX = b"OI:" OBJECT_LOCATION_PREFIX = b"OL:" @@ -496,6 +499,87 @@ class Monitor(object): self._clean_up_entries_for_driver(driver_id) + def _xray_clean_up_entries_for_driver(self, driver_id): + """Remove this driver's object/task entries from redis. + + Removes control-state entries of all tasks and task return + objects belonging to the driver. + + Args: + driver_id: The driver id. + """ + + xray_task_table_prefix = ( + ray.gcs_utils.TablePrefix_RAYLET_TASK_string.encode("ascii")) + xray_object_table_prefix = ( + ray.gcs_utils.TablePrefix_OBJECT_string.encode("ascii")) + + task_table_objects = self.state.task_table() + driver_id_hex = binary_to_hex(driver_id) + driver_task_id_bins = set() + for task_id_hex in task_table_objects: + if len(task_table_objects[task_id_hex]) == 0: + continue + task_table_object = task_table_objects[task_id_hex][0]["TaskSpec"] + task_driver_id_hex = task_table_object["DriverID"] + if driver_id_hex != task_driver_id_hex: + # Ignore tasks that aren't from this driver. + continue + driver_task_id_bins.add(hex_to_binary(task_id_hex)) + + # Get objects associated with the driver. + object_table_objects = self.state.object_table() + driver_object_id_bins = set() + for object_id, object_table_object in object_table_objects.items(): + assert len(object_table_object) > 0 + task_id_bin = ray.local_scheduler.compute_task_id(object_id).id() + if task_id_bin in driver_task_id_bins: + driver_object_id_bins.add(object_id.id()) + + def to_shard_index(id_bin): + return binary_to_object_id(id_bin).redis_shard_hash() % len( + self.state.redis_clients) + + # Form the redis keys to delete. + sharded_keys = [[] for _ in range(len(self.state.redis_clients))] + for task_id_bin in driver_task_id_bins: + sharded_keys[to_shard_index(task_id_bin)].append( + xray_task_table_prefix + task_id_bin) + for object_id_bin in driver_object_id_bins: + sharded_keys[to_shard_index(object_id_bin)].append( + xray_object_table_prefix + object_id_bin) + + # Remove with best effort. + for shard_index in range(len(sharded_keys)): + keys = sharded_keys[shard_index] + if len(keys) == 0: + continue + redis = self.state.redis_clients[shard_index] + num_deleted = redis.delete(*keys) + log.info("Removed {} dead redis entries of the driver" + " from redis shard {}.".format(num_deleted, shard_index)) + if num_deleted != len(keys): + log.warning("Failed to remove {} relevant redis entries" + " from redis shard {}.".format( + len(keys) - num_deleted, shard_index)) + + def xray_driver_removed_handler(self, unused_channel, data): + """Handle a notification that a driver has been removed. + + Args: + unused_channel: The message channel. + data: The message data. + """ + gcs_entries = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry( + data, 0) + driver_data = gcs_entries.Entries(0) + message = ray.gcs_utils.DriverTableData.GetRootAsDriverTableData( + driver_data, 0) + driver_id = message.DriverId() + log.info("XRay Driver {} has been removed.".format( + binary_to_hex(driver_id))) + self._xray_clean_up_entries_for_driver(driver_id) + def process_messages(self, max_messages=10000): """Process all messages ready in the subscription channels. @@ -537,6 +621,9 @@ class Monitor(object): elif channel == XRAY_HEARTBEAT_CHANNEL: # Similar functionality as local scheduler info channel message_handler = self.xray_heartbeat_handler + elif channel == XRAY_DRIVER_CHANNEL: + # Handles driver death. + message_handler = self.xray_driver_removed_handler else: raise Exception("This code should be unreachable.") @@ -582,7 +669,7 @@ class Monitor(object): max_entries_to_flush = self.gcs_flush_policy.num_entries_to_flush() num_flushed = self.redis_shard.execute_command( "HEAD.FLUSH {}".format(max_entries_to_flush)) - log.info('num_flushed {}'.format(num_flushed)) + log.info("num_flushed {}".format(num_flushed)) # This flushes event log and log files. ray.experimental.flush_redis_unsafe(self.redis) @@ -601,6 +688,7 @@ class Monitor(object): self.subscribe(PLASMA_MANAGER_HEARTBEAT_CHANNEL) self.subscribe(DRIVER_DEATH_CHANNEL) self.subscribe(XRAY_HEARTBEAT_CHANNEL, primary=False) + self.subscribe(XRAY_DRIVER_CHANNEL) # Scan the database table for dead database clients. NOTE: This must be # called before reading any messages from the subscription channel. diff --git a/src/common/lib/python/common_extension.cc b/src/common/lib/python/common_extension.cc index 31178160d..2bc379c37 100644 --- a/src/common/lib/python/common_extension.cc +++ b/src/common/lib/python/common_extension.cc @@ -907,3 +907,12 @@ PyObject *check_simple_value(PyObject *self, PyObject *args) { } Py_RETURN_FALSE; } + +PyObject *compute_task_id(PyObject *self, PyObject *args) { + ObjectID object_id; + if (!PyArg_ParseTuple(args, "O&", &PyObjectToUniqueID, &object_id)) { + return NULL; + } + TaskID task_id = ray::ComputeTaskId(object_id); + return PyObjectID_make(task_id); +} diff --git a/src/common/lib/python/common_extension.h b/src/common/lib/python/common_extension.h index b24e45a1f..0172135f1 100644 --- a/src/common/lib/python/common_extension.h +++ b/src/common/lib/python/common_extension.h @@ -56,6 +56,7 @@ int PyObjectToUniqueID(PyObject *object, ray::ObjectID *object_id); PyObject *PyObjectID_make(ray::ObjectID object_id); PyObject *check_simple_value(PyObject *self, PyObject *args); +PyObject *compute_task_id(PyObject *self, PyObject *args); PyObject *PyTask_to_string(PyObject *, PyObject *args); PyObject *PyTask_from_string(PyObject *, PyObject *args); diff --git a/src/local_scheduler/lib/python/local_scheduler_extension.cc b/src/local_scheduler/lib/python/local_scheduler_extension.cc index e2883b6ff..d5f9d1a7a 100644 --- a/src/local_scheduler/lib/python/local_scheduler_extension.cc +++ b/src/local_scheduler/lib/python/local_scheduler_extension.cc @@ -493,6 +493,8 @@ static PyTypeObject PyLocalSchedulerClientType = { static PyMethodDef local_scheduler_methods[] = { {"check_simple_value", check_simple_value, METH_VARARGS, "Should the object be passed by value?"}, + {"compute_task_id", compute_task_id, METH_VARARGS, + "Return the task ID of an object ID."}, {"task_from_string", PyTask_from_string, METH_VARARGS, "Creates a Python PyTask object from a string representation of " "TaskSpec."}, diff --git a/src/ray/gcs/client.cc b/src/ray/gcs/client.cc index d7041d328..88eda1a5d 100644 --- a/src/ray/gcs/client.cc +++ b/src/ray/gcs/client.cc @@ -17,6 +17,7 @@ AsyncGcsClient::AsyncGcsClient(const ClientID &client_id, CommandType command_ty task_reconstruction_log_.reset(new TaskReconstructionLog(context_, this)); task_lease_table_.reset(new TaskLeaseTable(context_, this)); heartbeat_table_.reset(new HeartbeatTable(context_, this)); + driver_table_.reset(new DriverTable(primary_context_, this)); error_table_.reset(new ErrorTable(primary_context_, this)); profile_table_.reset(new ProfileTable(context_, this)); command_type_ = command_type; @@ -88,6 +89,8 @@ HeartbeatTable &AsyncGcsClient::heartbeat_table() { return *heartbeat_table_; } ErrorTable &AsyncGcsClient::error_table() { return *error_table_; } +DriverTable &AsyncGcsClient::driver_table() { return *driver_table_; } + ProfileTable &AsyncGcsClient::profile_table() { return *profile_table_; } } // namespace gcs diff --git a/src/ray/gcs/client.h b/src/ray/gcs/client.h index 8a72dbed9..9da272e64 100644 --- a/src/ray/gcs/client.h +++ b/src/ray/gcs/client.h @@ -60,6 +60,7 @@ class RAY_EXPORT AsyncGcsClient { ClientTable &client_table(); HeartbeatTable &heartbeat_table(); ErrorTable &error_table(); + DriverTable &driver_table(); ProfileTable &profile_table(); // We also need something to export generic code to run on workers from the @@ -92,6 +93,7 @@ class RAY_EXPORT AsyncGcsClient { std::unique_ptr asio_subscribe_client_; // The following context writes everything to the primary shard std::shared_ptr primary_context_; + std::unique_ptr driver_table_; std::unique_ptr asio_async_auxiliary_client_; std::unique_ptr asio_subscribe_auxiliary_client_; CommandType command_type_; diff --git a/src/ray/gcs/format/gcs.fbs b/src/ray/gcs/format/gcs.fbs index 72b2a62bd..f0cbddae0 100644 --- a/src/ray/gcs/format/gcs.fbs +++ b/src/ray/gcs/format/gcs.fbs @@ -4,6 +4,7 @@ enum Language:int { JAVA = 2 } +// These indexes are mapped to strings in ray_redis_module.cc. enum TablePrefix:int { UNUSED = 0, TASK, @@ -15,6 +16,7 @@ enum TablePrefix:int { TASK_RECONSTRUCTION, HEARTBEAT, ERROR_INFO, + DRIVER, PROFILE, TASK_LEASE, } @@ -30,6 +32,7 @@ enum TablePubsub:int { HEARTBEAT, ERROR_INFO, TASK_LEASE, + DRIVER, } table GcsTableEntry { @@ -202,3 +205,10 @@ table TaskLeaseData { // The period that the lease is active for. timeout: long; } + +table DriverTableData { + // The driver ID. + driver_id: string; + // Whether it's dead. + is_dead: bool; +} diff --git a/src/ray/gcs/tables.cc b/src/ray/gcs/tables.cc index 8b21d9e59..dd09a2a3b 100644 --- a/src/ray/gcs/tables.cc +++ b/src/ray/gcs/tables.cc @@ -266,6 +266,17 @@ Status ProfileTable::AddProfileEventBatch(const ProfileTableData &profile_events }); } +Status DriverTable::AppendDriverData(const JobID &driver_id, bool is_dead) { + auto data = std::make_shared(); + data->driver_id = driver_id.binary(); + data->is_dead = is_dead; + return Append(driver_id, driver_id, data, + [](ray::gcs::AsyncGcsClient *client, const JobID &id, + const DriverTableDataT &data) { + RAY_LOG(DEBUG) << "Driver entry added callback"; + }); +} + void ClientTable::RegisterClientAddedCallback(const ClientTableCallback &callback) { client_added_callback_ = callback; // Call the callback for any added clients that are cached. @@ -425,6 +436,7 @@ template class Table; template class Table; template class Log; template class Log; +template class Log; template class Log; } // namespace gcs diff --git a/src/ray/gcs/tables.h b/src/ray/gcs/tables.h index 50424d661..d5ccc7a99 100644 --- a/src/ray/gcs/tables.h +++ b/src/ray/gcs/tables.h @@ -317,6 +317,23 @@ class HeartbeatTable : public Table { virtual ~HeartbeatTable() {} }; +class DriverTable : public Log { + public: + DriverTable(const std::shared_ptr &context, AsyncGcsClient *client) + : Log(context, client) { + pubsub_channel_ = TablePubsub::DRIVER; + prefix_ = TablePrefix::DRIVER; + }; + virtual ~DriverTable() {} + + /// Appends driver data to the driver table. + /// + /// \param driver_id The driver id. + /// \param is_dead Whether the driver is dead. + /// \return The return status. + Status AppendDriverData(const JobID &driver_id, bool is_dead); +}; + class FunctionTable : public Table { public: FunctionTable(const std::shared_ptr &context, AsyncGcsClient *client) diff --git a/src/ray/object_manager/object_manager.cc b/src/ray/object_manager/object_manager.cc index 32099d67b..aeaa528c3 100644 --- a/src/ray/object_manager/object_manager.cc +++ b/src/ray/object_manager/object_manager.cc @@ -25,8 +25,10 @@ ObjectManager::ObjectManager(asio::io_service &main_service, RAY_CHECK(config_.max_sends > 0); RAY_CHECK(config_.max_receives > 0); main_service_ = &main_service; - store_notification_.SubscribeObjAdded( - [this](const ObjectInfoT &object_info) { NotifyDirectoryObjectAdd(object_info); }); + store_notification_.SubscribeObjAdded([this](const ObjectInfoT &object_info) { + NotifyDirectoryObjectAdd(object_info); + HandleUnfulfilledPushRequests(object_info); + }); store_notification_.SubscribeObjDeleted( [this](const ObjectID &oid) { NotifyDirectoryObjectDeleted(oid); }); StartIOService(); @@ -49,8 +51,10 @@ ObjectManager::ObjectManager(asio::io_service &main_service, RAY_CHECK(config_.max_receives > 0); // TODO(hme) Client ID is never set with this constructor. main_service_ = &main_service; - store_notification_.SubscribeObjAdded( - [this](const ObjectInfoT &object_info) { NotifyDirectoryObjectAdd(object_info); }); + store_notification_.SubscribeObjAdded([this](const ObjectInfoT &object_info) { + NotifyDirectoryObjectAdd(object_info); + HandleUnfulfilledPushRequests(object_info); + }); store_notification_.SubscribeObjDeleted( [this](const ObjectID &oid) { NotifyDirectoryObjectDeleted(oid); }); StartIOService(); @@ -89,6 +93,10 @@ void ObjectManager::NotifyDirectoryObjectAdd(const ObjectInfoT &object_info) { local_objects_[object_id] = object_info; ray::Status status = object_directory_->ReportObjectAdded(object_id, client_id_, object_info); +} + +void ObjectManager::HandleUnfulfilledPushRequests(const ObjectInfoT &object_info) { + ObjectID object_id = ObjectID::from_binary(object_info.object_id); // Handle the unfulfilled_push_requests_ which contains the push request that is not // completed due to unsatisfied local objects. auto iter = unfulfilled_push_requests_.find(object_id); diff --git a/src/ray/object_manager/object_manager.h b/src/ray/object_manager/object_manager.h index 1ccd2454e..4f3d707d0 100644 --- a/src/ray/object_manager/object_manager.h +++ b/src/ray/object_manager/object_manager.h @@ -266,6 +266,10 @@ class ObjectManager : public ObjectManagerInterface { /// Register object remove with directory. void NotifyDirectoryObjectDeleted(const ObjectID &object_id); + /// Handle any push requests that were made before an object was available. + /// This is invoked when an "object added" notification is received from the store. + void HandleUnfulfilledPushRequests(const ObjectInfoT &object_info); + /// Part of an asynchronous sequence of Pull methods. /// Uses an existing connection or creates a connection to ClientID. /// Executes on main_service_ thread. diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index b2c6f91e4..bdfbacf7b 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -193,12 +193,33 @@ ray::Status NodeManager::RegisterGcs() { RAY_LOG(DEBUG) << "heartbeat table subscription done callback called."; })); + // Subscribe to driver table updates. + const auto driver_table_handler = [this]( + gcs::AsyncGcsClient *client, const ClientID &client_id, + const std::vector &driver_data) { + HandleDriverTableUpdate(client_id, driver_data); + }; + RAY_RETURN_NOT_OK(gcs_client_->driver_table().Subscribe(JobID::nil(), UniqueID::nil(), + driver_table_handler, nullptr)); + // Start sending heartbeats to the GCS. Heartbeat(); return ray::Status::OK(); } +void NodeManager::HandleDriverTableUpdate( + const ClientID &id, const std::vector &driver_data) { + for (const auto &entry : driver_data) { + RAY_LOG(DEBUG) << "HandleDriverTableUpdate " << UniqueID::from_binary(entry.driver_id) + << " " << entry.is_dead; + if (entry.is_dead) { + // TODO: Implement cleanup on driver death. For reference, + // see handle_driver_removed_callback in local_scheduler.cc + } + } +} + void NodeManager::Heartbeat() { RAY_LOG(DEBUG) << "[Heartbeat] sending heartbeat."; auto &heartbeat_table = gcs_client_->heartbeat_table(); @@ -449,6 +470,7 @@ void NodeManager::ProcessClientMessage( switch (static_cast(message_type)) { case protocol::MessageType::RegisterClientRequest: { auto message = flatbuffers::GetRoot(message_data); + client->SetClientID(from_flatbuf(*message->client_id())); auto worker = std::make_shared(message->worker_pid(), client); if (message->is_worker()) { // Register the new worker. @@ -543,6 +565,8 @@ void NodeManager::ProcessClientMessage( DispatchTasks(); } else { // The client is a driver. + RAY_CHECK_OK(gcs_client_->driver_table().AppendDriverData(client->GetClientID(), + /*is_dead=*/true)); const std::shared_ptr driver = worker_pool_.GetRegisteredDriver(client); RAY_CHECK(driver); auto driver_id = driver->GetAssignedTaskId(); diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h index dd7d23a91..171aad929 100644 --- a/src/ray/raylet/node_manager.h +++ b/src/ray/raylet/node_manager.h @@ -144,6 +144,10 @@ class NodeManager { /// accounting, but does not write to any global accounting in the GCS. void HandleObjectMissing(const ObjectID &object_id); + /// Handles updates to driver table. + void HandleDriverTableUpdate(const ClientID &id, + const std::vector &driver_data); + boost::asio::io_service &io_service_; ObjectManager &object_manager_; /// A Plasma object store client. This is used exclusively for creating new diff --git a/test/monitor_test.py b/test/monitor_test.py index b4680a575..588676c1c 100644 --- a/test/monitor_test.py +++ b/test/monitor_test.py @@ -41,11 +41,18 @@ class MonitorTest(unittest.TestCase): if (0, 1) != summary_start[:2]: success.value = False + max_attempts_before_failing = 100 + # Two new objects. ray.get(ray.put(1111)) ray.get(ray.put(1111)) - if (2, 1, summary_start[2]) != StateSummary(): - success.value = False + attempts = 0 + while (2, 1, summary_start[2]) != StateSummary(): + time.sleep(0.1) + attempts += 1 + if attempts == max_attempts_before_failing: + success.value = False + break @ray.remote def f(): @@ -53,12 +60,22 @@ class MonitorTest(unittest.TestCase): return 1111 # A returned object as well. # 1 new function. - if (2, 1, summary_start[2] + 1) != StateSummary(): - success.value = False + attempts = 0 + while (2, 1, summary_start[2] + 1) != StateSummary(): + time.sleep(0.1) + attempts += 1 + if attempts == max_attempts_before_failing: + success.value = False + break ray.get(f.remote()) - if (4, 2, summary_start[2] + 1) != StateSummary(): - success.value = False + attempts = 0 + while (4, 2, summary_start[2] + 1) != StateSummary(): + time.sleep(0.1) + attempts += 1 + if attempts == max_attempts_before_failing: + success.value = False + break ray.shutdown() @@ -67,7 +84,7 @@ class MonitorTest(unittest.TestCase): driver.start() # Wait for client to exit. driver.join() - time.sleep(5) + time.sleep(3) # Just make sure Driver() is run and succeeded. Note(rkn), if the below # assertion starts failing, then the issue may be that the summary @@ -85,13 +102,16 @@ class MonitorTest(unittest.TestCase): subprocess.Popen(["ray", "stop"]).wait() @unittest.skipIf( - os.environ.get('RAY_USE_NEW_GCS', False), + os.environ.get("RAY_USE_NEW_GCS", False), "Failing with the new GCS API.") def testCleanupOnDriverExitSingleRedisShard(self): self._testCleanupOnDriverExit(num_redis_shards=1) @unittest.skipIf( - os.environ.get('RAY_USE_NEW_GCS', False), + os.environ.get("RAY_USE_XRAY") == "1", + "This test does not work with xray yet.") + @unittest.skipIf( + os.environ.get("RAY_USE_NEW_GCS", False), "Hanging with the new GCS API.") def testCleanupOnDriverExitManyRedisShards(self): self._testCleanupOnDriverExit(num_redis_shards=5)