diff --git a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java index e80581c1f..d8de9a086 100644 --- a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java +++ b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java @@ -88,6 +88,8 @@ public abstract class AbstractRayRuntime implements RayRuntime { @Override public List get(List objectIds) { boolean wasBlocked = false; + // TODO(swang): If we are not on the main thread, then we should generate a + // random task ID to pass to the backend. UniqueId taskId = workerContext.getCurrentTask().taskId; try { @@ -97,7 +99,7 @@ public abstract class AbstractRayRuntime implements RayRuntime { List> fetchBatches = splitIntoBatches(objectIds, FETCH_BATCH_SIZE); for (List batch : fetchBatches) { - rayletClient.reconstructObjects(batch, true); + rayletClient.fetchOrReconstruct(batch, true, taskId); } // Get the objects. We initially try to get the objects immediately. @@ -122,7 +124,7 @@ public abstract class AbstractRayRuntime implements RayRuntime { splitIntoBatches(unreadyList, FETCH_BATCH_SIZE); for (List batch : reconstructBatches) { - rayletClient.reconstructObjects(batch, false); + rayletClient.fetchOrReconstruct(batch, false, taskId); } List> results = objectStoreProxy @@ -157,7 +159,7 @@ public abstract class AbstractRayRuntime implements RayRuntime { // If there were objects that we weren't able to get locally, let the local // scheduler know that we're now unblocked. if (wasBlocked) { - rayletClient.notifyUnblocked(); + rayletClient.notifyUnblocked(taskId); } } } @@ -185,7 +187,10 @@ public abstract class AbstractRayRuntime implements RayRuntime { @Override public WaitResult wait(List> waitList, int numReturns, int timeoutMs) { - return rayletClient.wait(waitList, numReturns, timeoutMs); + // TODO(swang): If we are not on the main thread, then we should generate a + // random task ID to pass to the backend. + return rayletClient.wait(waitList, numReturns, timeoutMs, + workerContext.getCurrentTask().taskId); } @Override diff --git a/java/runtime/src/main/java/org/ray/runtime/raylet/MockRayletClient.java b/java/runtime/src/main/java/org/ray/runtime/raylet/MockRayletClient.java index 95a8abdf4..0da3dbe80 100644 --- a/java/runtime/src/main/java/org/ray/runtime/raylet/MockRayletClient.java +++ b/java/runtime/src/main/java/org/ray/runtime/raylet/MockRayletClient.java @@ -66,12 +66,13 @@ public class MockRayletClient implements RayletClient { } @Override - public void reconstructObjects(List objectIds, boolean fetchOnly) { + public void fetchOrReconstruct(List objectIds, boolean fetchOnly, + UniqueId currentTaskId) { } @Override - public void notifyUnblocked() { + public void notifyUnblocked(UniqueId currentTaskId) { } @@ -81,7 +82,8 @@ public class MockRayletClient implements RayletClient { } @Override - public WaitResult wait(List> waitFor, int numReturns, int timeoutMs) { + public WaitResult wait(List> waitFor, int numReturns, int + timeoutMs, UniqueId currentTaskId) { return new WaitResult( waitFor, ImmutableList.of() diff --git a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClient.java b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClient.java index baa32a142..b68fe0182 100644 --- a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClient.java +++ b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClient.java @@ -15,13 +15,14 @@ public interface RayletClient { TaskSpec getTask(); - void reconstructObjects(List objectIds, boolean fetchOnly); + void fetchOrReconstruct(List objectIds, boolean fetchOnly, UniqueId currentTaskId); - void notifyUnblocked(); + void notifyUnblocked(UniqueId currentTaskId); UniqueId generateTaskId(UniqueId driverId, UniqueId parentTaskId, int taskIndex); - WaitResult wait(List> waitFor, int numReturns, int timeoutMs); + WaitResult wait(List> waitFor, int numReturns, int + timeoutMs, UniqueId currentTaskId); void freePlasmaObjects(List objectIds, boolean localOnly); } diff --git a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java index 28f0cd97c..9cf70c348 100644 --- a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java +++ b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java @@ -44,14 +44,15 @@ public class RayletClientImpl implements RayletClient { } @Override - public WaitResult wait(List> waitFor, int numReturns, int timeoutMs) { + public WaitResult wait(List> waitFor, int numReturns, int + timeoutMs, UniqueId currentTaskId) { List ids = new ArrayList<>(); for (RayObject element : waitFor) { ids.add(element.getId()); } boolean[] ready = nativeWaitObject(client, UniqueIdUtil.getIdBytes(ids), - numReturns, timeoutMs, false); + numReturns, timeoutMs, false, currentTaskId.getBytes()); List> readyList = new ArrayList<>(); List> unreadyList = new ArrayList<>(); @@ -87,12 +88,14 @@ public class RayletClientImpl implements RayletClient { } @Override - public void reconstructObjects(List objectIds, boolean fetchOnly) { + public void fetchOrReconstruct(List objectIds, boolean fetchOnly, + UniqueId currentTaskId) { if (RayLog.core.isInfoEnabled()) { - RayLog.core.info("Reconstructing objects for task {}, object IDs are {}", + RayLog.core.info("Blocked on objects for task {}, object IDs are {}", UniqueIdUtil.computeTaskId(objectIds.get(0)), objectIds); } - nativeReconstructObjects(client, UniqueIdUtil.getIdBytes(objectIds), fetchOnly); + nativeFetchOrReconstruct(client, UniqueIdUtil.getIdBytes(objectIds), + fetchOnly, currentTaskId.getBytes()); } @Override @@ -102,8 +105,8 @@ public class RayletClientImpl implements RayletClient { } @Override - public void notifyUnblocked() { - nativeNotifyUnblocked(client); + public void notifyUnblocked(UniqueId currentTaskId) { + nativeNotifyUnblocked(client, currentTaskId.getBytes()); } @Override @@ -271,15 +274,15 @@ public class RayletClientImpl implements RayletClient { private static native void nativeDestroy(long client); - private static native void nativeReconstructObjects(long client, byte[][] objectIds, - boolean fetchOnly); + private static native void nativeFetchOrReconstruct(long client, byte[][] objectIds, + boolean fetchOnly, byte[] currentTaskId); - private static native void nativeNotifyUnblocked(long client); + private static native void nativeNotifyUnblocked(long client, byte[] currentTaskId); private static native void nativePutObject(long client, byte[] taskId, byte[] objectId); private static native boolean[] nativeWaitObject(long conn, byte[][] objectIds, - int numReturns, int timeout, boolean waitLocal); + int numReturns, int timeout, boolean waitLocal, byte[] currentTaskId); private static native byte[] nativeGenerateTaskId(byte[] driverId, byte[] parentTaskId, int taskIndex); diff --git a/python/ray/experimental/sgd/util.py b/python/ray/experimental/sgd/util.py index 82bc60a25..c8df01cb3 100644 --- a/python/ray/experimental/sgd/util.py +++ b/python/ray/experimental/sgd/util.py @@ -36,7 +36,7 @@ def fetch(oids): local_sched_client = ray.worker.global_worker.local_scheduler_client for o in oids: ray_obj_id = ray.ObjectID(o) - local_sched_client.reconstruct_objects([ray_obj_id], True) + local_sched_client.fetch_or_reconstruct([ray_obj_id], True) def run_timeline(sess, ops, feed_dict=None, write_timeline=False, name=""): diff --git a/python/ray/rllib/utils/actors.py b/python/ray/rllib/utils/actors.py index 1e19b703d..701807331 100644 --- a/python/ray/rllib/utils/actors.py +++ b/python/ray/rllib/utils/actors.py @@ -40,7 +40,7 @@ class TaskPool(object): for worker, obj_id in self.completed(): plasma_id = ray.pyarrow.plasma.ObjectID(obj_id.id()) (ray.worker.global_worker.local_scheduler_client. - reconstruct_objects([obj_id], True)) + fetch_or_reconstruct([obj_id], True)) self._fetching.append((worker, obj_id)) remaining = [] diff --git a/python/ray/utils.py b/python/ray/utils.py index a568cb9a2..e75e00672 100644 --- a/python/ray/utils.py +++ b/python/ray/utils.py @@ -423,3 +423,7 @@ def thread_safe_client(client, lock=None): if lock is None: lock = threading.Lock() return _ThreadSafeProxy(client, lock) + + +def is_main_thread(): + return threading.current_thread().getName() == "MainThread" diff --git a/python/ray/worker.py b/python/ray/worker.py index 1fd4043c2..105332a72 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -217,9 +217,38 @@ class Worker(object): # A dictionary that maps from driver id to SerializationContext # TODO: clean up the SerializationContext once the job finished. self.serialization_context_map = {} - # Identity of the driver that this worker is processing. - self.task_driver_id = None self.function_actor_manager = FunctionActorManager(self) + # Reads/writes to the following fields must be protected by + # self.state_lock. + # Identity of the driver that this worker is processing. + self.task_driver_id = ray.ObjectID(NIL_ID) + self.current_task_id = ray.ObjectID(NIL_ID) + self.task_index = 0 + self.put_index = 1 + + def get_current_thread_task_id(self): + """Get the current thread's task ID. + + This returns the assigned task ID if called on the main thread, else a + random task ID. This method is not thread-safe and must be called with + self.state_lock acquired. + """ + current_task_id = self.current_task_id + if not ray.utils.is_main_thread(): + # If this is running on a separate thread, then the mapping + # to the current task ID may not be correct. Generate a + # random task ID so that the backend can differentiate + # between different threads. + current_task_id = ray.ObjectID(random_string()) + if not self.multithreading_warned: + logger.warning( + "Calling ray.get or ray.wait in a separate thread " + "may lead to deadlock if the main thread blocks on this " + "thread and there are not enough resources to execute " + "more tasks") + self.multithreading_warned = True + assert not current_task_id.is_nil() + return current_task_id def mark_actor_init_failed(self, error): """Called to mark this actor as failed during initialization.""" @@ -456,7 +485,7 @@ class Worker(object): ] for i in range(0, len(object_ids), ray._config.worker_fetch_request_size()): - self.local_scheduler_client.reconstruct_objects( + self.local_scheduler_client.fetch_or_reconstruct( object_ids[i:(i + ray._config.worker_fetch_request_size())], True) @@ -472,6 +501,9 @@ class Worker(object): if len(unready_ids) > 0: with self.state_lock: + # Get the task ID, to notify the backend which task is blocked. + current_task_id = self.get_current_thread_task_id() + # Try reconstructing any objects we haven't gotten yet. Try to # get them until at least get_timeout_milliseconds # milliseconds passes, then repeat. @@ -488,9 +520,10 @@ class Worker(object): ray._config.worker_fetch_request_size()) for i in range(0, len(object_ids_to_fetch), fetch_request_size): - self.local_scheduler_client.reconstruct_objects( + self.local_scheduler_client.fetch_or_reconstruct( ray_object_ids_to_fetch[i:( - i + fetch_request_size)], False) + i + fetch_request_size)], False, + current_task_id) results = self.retrieve_and_deserialize( object_ids_to_fetch, max([ @@ -508,7 +541,7 @@ class Worker(object): # If there were objects that we weren't able to get locally, # let the local scheduler know that we're now unblocked. - self.local_scheduler_client.notify_unblocked() + self.local_scheduler_client.notify_unblocked(current_task_id) assert len(final_results) == len(object_ids) return final_results @@ -615,6 +648,8 @@ class Worker(object): # have been submitted by the current task so far. task_index = self.task_index self.task_index += 1 + # The parent task must be set for the submitted task. + assert not self.current_task_id.is_nil() # Submit the task to local scheduler. task = ray.raylet.Task( driver_id, ray.ObjectID( @@ -762,13 +797,18 @@ class Worker(object): (these will be retrieved by calls to get or by subsequent tasks that use the outputs of this task). """ - # The ID of the driver that this task belongs to. This is needed so - # that if the task throws an exception, we propagate the error - # message to the correct driver. - self.task_driver_id = task.driver_id() - self.current_task_id = task.task_id() - self.task_index = 0 - self.put_index = 1 + with self.state_lock: + assert self.task_driver_id.is_nil() + assert self.current_task_id.is_nil() + assert self.task_index == 0 + assert self.put_index == 1 + + # The ID of the driver that this task belongs to. This is needed so + # that if the task throws an exception, we propagate the error + # message to the correct driver. + self.task_driver_id = task.driver_id() + self.current_task_id = task.task_id() + function_id = task.function_id() args = task.arguments() return_object_ids = task.returns() @@ -912,6 +952,12 @@ class Worker(object): with profiling.profile("task", extra_data=extra_data, worker=self): with _changeproctitle(title): self._process_task(task, execution_info) + # Reset the state fields so the next task can run. + with self.state_lock: + self.task_driver_id = ray.ObjectID(NIL_ID) + self.current_task_id = ray.ObjectID(NIL_ID) + self.task_index = 0 + self.put_index = 1 # Increase the task execution counter. self.function_actor_manager.increase_task_counter( @@ -2044,6 +2090,9 @@ def connect(info, else: # A non-driver worker begins without an assigned task. worker.current_task_id = ray.ObjectID(NIL_ID) + # A flag for making sure that we only print one warning message about + # multithreading per worker. + worker.multithreading_warned = False worker.local_scheduler_client = ray.raylet.LocalSchedulerClient( local_scheduler_socket, worker.worker_id, is_worker, @@ -2376,6 +2425,7 @@ def wait(object_ids, num_returns=1, timeout=None, worker=global_worker): type(object_id))) worker.check_connected() + # TODO(swang): Check main thread. with profiling.profile("ray.wait", worker=worker): # When Ray is run in LOCAL_MODE, all functions are run immediately, # so all objects in object_id are ready. @@ -2396,9 +2446,14 @@ def wait(object_ids, num_returns=1, timeout=None, worker=global_worker): if num_returns > len(object_ids): raise Exception("num_returns cannot be greater than the number " "of objects provided to ray.wait.") + + # Get the task ID, to notify the backend which task is blocked. + with worker.state_lock: + current_task_id = worker.get_current_thread_task_id() + timeout = timeout if timeout is not None else 2**30 ready_ids, remaining_ids = worker.local_scheduler_client.wait( - object_ids, num_returns, timeout, False) + object_ids, num_returns, timeout, False, current_task_id) return ready_ids, remaining_ids diff --git a/src/ray/raylet/format/node_manager.fbs b/src/ray/raylet/format/node_manager.fbs index cc1e9938a..1e62202d7 100644 --- a/src/ray/raylet/format/node_manager.fbs +++ b/src/ray/raylet/format/node_manager.fbs @@ -37,7 +37,7 @@ enum MessageType:int { ExecuteTask, // Reconstruct or fetch possibly lost objects. This is sent from a worker to // a local scheduler. - ReconstructObjects, + FetchOrReconstruct, // For a worker that was blocked on some object(s), tell the local scheduler // that the worker is now unblocked. This is sent from a worker to a local // scheduler. @@ -150,11 +150,18 @@ table ForwardTaskRequest { uncommitted_tasks: [Task]; } -table ReconstructObjects { +table FetchOrReconstruct { // List of object IDs of the objects that we want to reconstruct or fetch. object_ids: [string]; // Do we only want to fetch the objects or also reconstruct them? fetch_only: bool; + // The current task ID. If fetch_only is false, then this task is blocked. + task_id: string; +} + +table NotifyUnblocked { + // The current task ID. This task is no longer blocked. + task_id: string; } table WaitRequest { @@ -166,6 +173,9 @@ table WaitRequest { timeout: long; // Whether to wait until objects appear locally. wait_local: bool; + // The current task ID. If there are less than num_ready_objects local, then + // this task is blocked. + task_id: string; } table WaitReply { diff --git a/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc b/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc index 7446b3af9..12388d181 100644 --- a/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc +++ b/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc @@ -118,12 +118,13 @@ JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeDestro /* * Class: org_ray_runtime_raylet_RayletClientImpl - * Method: nativeReconstructObjects + * Method: nativeFetchOrReconstruct * Signature: (J[[BZ)V */ JNIEXPORT void JNICALL -Java_org_ray_runtime_raylet_RayletClientImpl_nativeReconstructObjects( - JNIEnv *env, jclass, jlong client, jobjectArray objectIds, jboolean fetchOnly) { +Java_org_ray_runtime_raylet_RayletClientImpl_nativeFetchOrReconstruct( + JNIEnv *env, jclass, jlong client, jobjectArray objectIds, jboolean fetchOnly, + jbyteArray currentTaskId) { std::vector object_ids; auto len = env->GetArrayLength(objectIds); for (int i = 0; i < len; i++) { @@ -133,8 +134,9 @@ Java_org_ray_runtime_raylet_RayletClientImpl_nativeReconstructObjects( object_ids.push_back(*object_id.PID); env->DeleteLocalRef(object_id_bytes); } + UniqueIdFromJByteArray current_task_id(env, currentTaskId); auto conn = reinterpret_cast(client); - local_scheduler_reconstruct_objects(conn, object_ids, fetchOnly); + local_scheduler_fetch_or_reconstruct(conn, object_ids, fetchOnly, *current_task_id.PID); } /* @@ -143,9 +145,10 @@ Java_org_ray_runtime_raylet_RayletClientImpl_nativeReconstructObjects( * Signature: (J)V */ JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeNotifyUnblocked( - JNIEnv *, jclass, jlong client) { + JNIEnv *env, jclass, jlong client, jbyteArray currentTaskId) { + UniqueIdFromJByteArray current_task_id(env, currentTaskId); auto conn = reinterpret_cast(client); - local_scheduler_notify_unblocked(conn); + local_scheduler_notify_unblocked(conn, *current_task_id.PID); } /* @@ -156,7 +159,7 @@ JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeNotify JNIEXPORT jbooleanArray JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeWaitObject( JNIEnv *env, jclass, jlong client, jobjectArray objectIds, jint numReturns, - jint timeoutMillis, jboolean isWaitLocal) { + jint timeoutMillis, jboolean isWaitLocal, jbyteArray currentTaskId) { std::vector object_ids; auto len = env->GetArrayLength(objectIds); for (int i = 0; i < len; i++) { @@ -166,12 +169,14 @@ Java_org_ray_runtime_raylet_RayletClientImpl_nativeWaitObject( object_ids.push_back(*object_id.PID); env->DeleteLocalRef(object_id_bytes); } + UniqueIdFromJByteArray current_task_id(env, currentTaskId); auto conn = reinterpret_cast(client); // Invoke wait. - std::pair, std::vector> result = local_scheduler_wait( - conn, object_ids, numReturns, timeoutMillis, static_cast(isWaitLocal)); + std::pair, std::vector> result = + local_scheduler_wait(conn, object_ids, numReturns, timeoutMillis, + static_cast(isWaitLocal), *current_task_id.PID); // Convert result to java object. jboolean put_value = true; diff --git a/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.h b/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.h index abe412cf3..8940046ce 100644 --- a/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.h +++ b/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.h @@ -41,13 +41,14 @@ Java_org_ray_runtime_raylet_RayletClientImpl_nativeDestroy(JNIEnv *, jclass, jlo /* * Class: org_ray_runtime_raylet_RayletClientImpl - * Method: nativeReconstructObjects + * Method: nativeFetchOrReconstruct * Signature: (J[[BZ)V */ JNIEXPORT void JNICALL -Java_org_ray_runtime_raylet_RayletClientImpl_nativeReconstructObjects(JNIEnv *, jclass, +Java_org_ray_runtime_raylet_RayletClientImpl_nativeFetchOrReconstruct(JNIEnv *, jclass, jlong, jobjectArray, - jboolean); + jboolean, + jbyteArray); /* * Class: org_ray_runtime_raylet_RayletClientImpl @@ -55,7 +56,7 @@ Java_org_ray_runtime_raylet_RayletClientImpl_nativeReconstructObjects(JNIEnv *, * Signature: (J)V */ JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeNotifyUnblocked( - JNIEnv *, jclass, jlong); + JNIEnv *, jclass, jlong, jbyteArray); /* * Class: org_ray_runtime_raylet_RayletClientImpl @@ -65,7 +66,7 @@ JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeNotify JNIEXPORT jbooleanArray JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeWaitObject(JNIEnv *, jclass, jlong, jobjectArray, jint, jint, - jboolean); + jboolean, jbyteArray); /* * Class: org_ray_runtime_raylet_RayletClientImpl diff --git a/src/ray/raylet/lib/python/common_extension.cc b/src/ray/raylet/lib/python/common_extension.cc index 1ca9d5d8f..f4979620c 100644 --- a/src/ray/raylet/lib/python/common_extension.cc +++ b/src/ray/raylet/lib/python/common_extension.cc @@ -176,6 +176,16 @@ static PyObject *PyObjectID_hex(PyObject *self) { return result; } +static PyObject *PyObjectID_is_nil(PyObject *self) { + ObjectID object_id; + PyObjectToUniqueID(self, &object_id); + if (object_id.is_nil()) { + Py_RETURN_TRUE; + } else { + Py_RETURN_FALSE; + } +} + static PyObject *PyObjectID_richcompare(PyObjectID *self, PyObject *other, int op) { PyObject *result = NULL; if (Py_TYPE(self)->tp_richcompare != Py_TYPE(other)->tp_richcompare) { @@ -245,6 +255,8 @@ static PyMethodDef PyObjectID_methods[] = { "Return the redis shard that this ObjectID is associated with"}, {"hex", (PyCFunction)PyObjectID_hex, METH_NOARGS, "Return the object ID as a string in hex."}, + {"is_nil", (PyCFunction)PyObjectID_is_nil, METH_NOARGS, + "Return whether the ObjectID is nil"}, {"__reduce__", (PyCFunction)PyObjectID___reduce__, METH_NOARGS, "Say how to pickle this ObjectID. This raises an exception to prevent" "object IDs from being serialized."}, diff --git a/src/ray/raylet/lib/python/local_scheduler_extension.cc b/src/ray/raylet/lib/python/local_scheduler_extension.cc index 2bde78ab2..05d24cdb4 100644 --- a/src/ray/raylet/lib/python/local_scheduler_extension.cc +++ b/src/ray/raylet/lib/python/local_scheduler_extension.cc @@ -72,12 +72,14 @@ static PyObject *PyLocalSchedulerClient_get_task(PyObject *self) { } // clang-format on -static PyObject *PyLocalSchedulerClient_reconstruct_objects(PyObject *self, - PyObject *args) { +static PyObject *PyLocalSchedulerClient_fetch_or_reconstruct(PyObject *self, + PyObject *args) { PyObject *py_object_ids; PyObject *py_fetch_only; std::vector object_ids; - if (!PyArg_ParseTuple(args, "OO", &py_object_ids, &py_fetch_only)) { + TaskID current_task_id; + if (!PyArg_ParseTuple(args, "OO|O&", &py_object_ids, &py_fetch_only, + &PyObjectToUniqueID, ¤t_task_id)) { return NULL; } bool fetch_only = PyObject_IsTrue(py_fetch_only); @@ -90,15 +92,19 @@ static PyObject *PyLocalSchedulerClient_reconstruct_objects(PyObject *self, } object_ids.push_back(object_id); } - local_scheduler_reconstruct_objects( + local_scheduler_fetch_or_reconstruct( reinterpret_cast(self)->local_scheduler_connection, - object_ids, fetch_only); + object_ids, fetch_only, current_task_id); Py_RETURN_NONE; } -static PyObject *PyLocalSchedulerClient_notify_unblocked(PyObject *self) { +static PyObject *PyLocalSchedulerClient_notify_unblocked(PyObject *self, PyObject *args) { + TaskID current_task_id; + if (!PyArg_ParseTuple(args, "O&", &PyObjectToUniqueID, ¤t_task_id)) { + return NULL; + } local_scheduler_notify_unblocked( - ((PyLocalSchedulerClient *)self)->local_scheduler_connection); + ((PyLocalSchedulerClient *)self)->local_scheduler_connection, current_task_id); Py_RETURN_NONE; } @@ -160,9 +166,10 @@ static PyObject *PyLocalSchedulerClient_wait(PyObject *self, PyObject *args) { int num_returns; int64_t timeout_ms; PyObject *py_wait_local; + TaskID current_task_id; - if (!PyArg_ParseTuple(args, "OilO", &py_object_ids, &num_returns, &timeout_ms, - &py_wait_local)) { + if (!PyArg_ParseTuple(args, "OilOO&", &py_object_ids, &num_returns, &timeout_ms, + &py_wait_local, &PyObjectToUniqueID, ¤t_task_id)) { return NULL; } @@ -190,7 +197,7 @@ static PyObject *PyLocalSchedulerClient_wait(PyObject *self, PyObject *args) { // Invoke wait. std::pair, std::vector> result = local_scheduler_wait( reinterpret_cast(self)->local_scheduler_connection, - object_ids, num_returns, timeout_ms, static_cast(wait_local)); + object_ids, num_returns, timeout_ms, wait_local, current_task_id); // Convert result to py object. PyObject *py_found = PyList_New(static_cast(result.first.size())); @@ -364,10 +371,10 @@ static PyMethodDef PyLocalSchedulerClient_methods[] = { "Submit a task to the local scheduler."}, {"get_task", (PyCFunction)PyLocalSchedulerClient_get_task, METH_NOARGS, "Get a task from the local scheduler."}, - {"reconstruct_objects", (PyCFunction)PyLocalSchedulerClient_reconstruct_objects, + {"fetch_or_reconstruct", (PyCFunction)PyLocalSchedulerClient_fetch_or_reconstruct, METH_VARARGS, "Ask the local scheduler to reconstruct an object."}, {"notify_unblocked", (PyCFunction)PyLocalSchedulerClient_notify_unblocked, - METH_NOARGS, "Notify the local scheduler that we are unblocked."}, + METH_VARARGS, "Notify the local scheduler that we are unblocked."}, {"compute_put_id", (PyCFunction)PyLocalSchedulerClient_compute_put_id, METH_VARARGS, "Return the object ID for a put call within a task."}, {"gpu_ids", (PyCFunction)PyLocalSchedulerClient_gpu_ids, METH_NOARGS, diff --git a/src/ray/raylet/local_scheduler_client.cc b/src/ray/raylet/local_scheduler_client.cc index 1f6c59300..1f481a79e 100644 --- a/src/ray/raylet/local_scheduler_client.cc +++ b/src/ray/raylet/local_scheduler_client.cc @@ -304,31 +304,39 @@ void local_scheduler_task_done(LocalSchedulerConnection *conn) { &conn->write_mutex); } -void local_scheduler_reconstruct_objects(LocalSchedulerConnection *conn, - const std::vector &object_ids, - bool fetch_only) { +void local_scheduler_fetch_or_reconstruct(LocalSchedulerConnection *conn, + const std::vector &object_ids, + bool fetch_only, + const TaskID ¤t_task_id) { flatbuffers::FlatBufferBuilder fbb; auto object_ids_message = to_flatbuf(fbb, object_ids); - auto message = - ray::protocol::CreateReconstructObjects(fbb, object_ids_message, fetch_only); + auto message = ray::protocol::CreateFetchOrReconstruct( + fbb, object_ids_message, fetch_only, to_flatbuf(fbb, current_task_id)); fbb.Finish(message); - write_message(conn->conn, static_cast(MessageType::ReconstructObjects), + write_message(conn->conn, static_cast(MessageType::FetchOrReconstruct), fbb.GetSize(), fbb.GetBufferPointer(), &conn->write_mutex); /* TODO(swang): Propagate the error. */ } -void local_scheduler_notify_unblocked(LocalSchedulerConnection *conn) { - write_message(conn->conn, static_cast(MessageType::NotifyUnblocked), 0, NULL, - &conn->write_mutex); +void local_scheduler_notify_unblocked(LocalSchedulerConnection *conn, + const TaskID ¤t_task_id) { + flatbuffers::FlatBufferBuilder fbb; + auto message = + ray::protocol::CreateNotifyUnblocked(fbb, to_flatbuf(fbb, current_task_id)); + fbb.Finish(message); + write_message(conn->conn, static_cast(MessageType::NotifyUnblocked), + fbb.GetSize(), fbb.GetBufferPointer(), &conn->write_mutex); } std::pair, std::vector> local_scheduler_wait( LocalSchedulerConnection *conn, const std::vector &object_ids, - int num_returns, int64_t timeout_milliseconds, bool wait_local) { + int num_returns, int64_t timeout_milliseconds, bool wait_local, + const TaskID ¤t_task_id) { // Write request. flatbuffers::FlatBufferBuilder fbb; auto message = ray::protocol::CreateWaitRequest( - fbb, to_flatbuf(fbb, object_ids), num_returns, timeout_milliseconds, wait_local); + fbb, to_flatbuf(fbb, object_ids), num_returns, timeout_milliseconds, wait_local, + to_flatbuf(fbb, current_task_id)); fbb.Finish(message); int64_t type; int64_t reply_size; diff --git a/src/ray/raylet/local_scheduler_client.h b/src/ray/raylet/local_scheduler_client.h index c1f2511a9..66c76f37a 100644 --- a/src/ray/raylet/local_scheduler_client.h +++ b/src/ray/raylet/local_scheduler_client.h @@ -96,19 +96,22 @@ void local_scheduler_task_done(LocalSchedulerConnection *conn); * @param conn The connection information. * @param object_ids The IDs of the objects to reconstruct. * @param fetch_only Only fetch objects, do not reconstruct them. + * @param current_task_id The task that needs the objects. * @return Void. */ -void local_scheduler_reconstruct_objects(LocalSchedulerConnection *conn, - const std::vector &object_ids, - bool fetch_only = false); +void local_scheduler_fetch_or_reconstruct(LocalSchedulerConnection *conn, + const std::vector &object_ids, + bool fetch_only, const TaskID ¤t_task_id); /** * Notify the local scheduler that this client (worker) is no longer blocked. * * @param conn The connection information. + * @param current_task_id The task that is no longer blocked. * @return Void. */ -void local_scheduler_notify_unblocked(LocalSchedulerConnection *conn); +void local_scheduler_notify_unblocked(LocalSchedulerConnection *conn, + const TaskID ¤t_task_id); // /** // * Get an actor's current task frontier. @@ -140,11 +143,13 @@ void local_scheduler_notify_unblocked(LocalSchedulerConnection *conn); /// \param timeout_milliseconds Duration, in milliseconds, to wait before /// returning. /// \param wait_local Whether to wait for objects to appear on this node. +/// \param current_task_id The task that called wait. /// \return A pair with the first element containing the object ids that were /// found, and the second element the objects that were not found. std::pair, std::vector> local_scheduler_wait( LocalSchedulerConnection *conn, const std::vector &object_ids, - int num_returns, int64_t timeout_milliseconds, bool wait_local); + int num_returns, int64_t timeout_milliseconds, bool wait_local, + const TaskID ¤t_task_id); /// Push an error to the relevant driver. /// diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index e2d86df75..fd606b388 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -572,11 +572,12 @@ void NodeManager::ProcessClientMessage( case protocol::MessageType::SubmitTask: { ProcessSubmitTaskMessage(message_data); } break; - case protocol::MessageType::ReconstructObjects: { - ProcessReconstructObjectsMessage(client, message_data); + case protocol::MessageType::FetchOrReconstruct: { + ProcessFetchOrReconstructMessage(client, message_data); } break; case protocol::MessageType::NotifyUnblocked: { - HandleClientUnblocked(client); + auto message = flatbuffers::GetRoot(message_data); + HandleTaskUnblocked(client, from_flatbuf(*message->task_id())); } break; case protocol::MessageType::WaitRequest: { ProcessWaitRequestMessage(client, message_data); @@ -645,27 +646,36 @@ void NodeManager::ProcessGetTaskMessage( void NodeManager::ProcessDisconnectClientMessage( const std::shared_ptr &client, bool push_warning) { - const std::shared_ptr worker = worker_pool_.GetRegisteredWorker(client); - const std::shared_ptr driver = worker_pool_.GetRegisteredDriver(client); - // This client can't be a worker and a driver. - RAY_CHECK(worker == nullptr || driver == nullptr); + std::shared_ptr worker = worker_pool_.GetRegisteredWorker(client); + bool is_worker = false, is_driver = false; + if (worker) { + // The client is a worker. + is_worker = true; + } else { + worker = worker_pool_.GetRegisteredDriver(client); + if (worker) { + // The client is a driver. + is_driver = true; + } else { + RAY_LOG(INFO) << "Ignoring client disconnect because the client has already " + << "been disconnected."; + } + } + RAY_CHECK(!(is_worker && is_driver)); - // If both worker and driver are null, then this method has already been - // called, so just return. - if (worker == nullptr && driver == nullptr) { - RAY_LOG(INFO) << "Ignoring client disconnect because the client has already " - << "been disconnected."; - return; + // If the client has any blocked tasks, mark them as unblocked. In + // particular, we are no longer waiting for their dependencies. + if (worker) { + while (!worker->GetBlockedTaskIds().empty()) { + // NOTE(swang): HandleTaskUnblocked will modify the worker, so it is + // not safe to pass in the iterator directly. + const TaskID task_id = *worker->GetBlockedTaskIds().begin(); + HandleTaskUnblocked(client, task_id); + } } - // If the client is blocked, we need to treat it as unblocked. In particular, - // we are no longer waiting for its dependencies. If the client is not - // blocked, this won't do anything. - HandleClientUnblocked(client); - // Remove the dead client from the pool and stop listening for messages. - - if (worker) { + if (is_worker) { // The client is a worker. Handle the case where the worker is killed // while executing a task. Clean up the assigned task's resources, push // an error to the driver. @@ -676,7 +686,6 @@ void NodeManager::ProcessDisconnectClientMessage( // the task that this worker is currently executing exits, the task for this // worker has already been removed from queue, so the following are skipped. task_dependency_manager_.TaskCanceled(task_id); - // task_dependency_manager_.UnsubscribeDependencies(current_task_id); const Task &task = local_queues_.RemoveTask(task_id); const TaskSpecification &spec = task.GetTaskSpecification(); // Handle the task failure in order to raise an exception in the @@ -730,18 +739,17 @@ void NodeManager::ProcessDisconnectClientMessage( // Since some resources may have been released, we can try to dispatch more tasks. DispatchTasks(); - } else { + } else if (is_driver) { // The client is a driver. RAY_CHECK_OK(gcs_client_->driver_table().AppendDriverData(client->GetClientID(), /*is_dead=*/true)); - RAY_CHECK(driver); - auto driver_id = driver->GetAssignedTaskId(); + auto driver_id = worker->GetAssignedTaskId(); RAY_CHECK(!driver_id.is_nil()); local_queues_.RemoveDriverTaskId(driver_id); - worker_pool_.DisconnectDriver(driver); + worker_pool_.DisconnectDriver(worker); - RAY_LOG(DEBUG) << "Driver (pid=" << driver->Pid() << ") is disconnected. " - << "driver_id: " << driver->GetAssignedDriverId(); + RAY_LOG(DEBUG) << "Driver (pid=" << worker->Pid() << ") is disconnected. " + << "driver_id: " << worker->GetAssignedDriverId(); } // TODO(rkn): Tell the object manager that this client has disconnected so @@ -761,29 +769,31 @@ void NodeManager::ProcessSubmitTaskMessage(const uint8_t *message_data) { SubmitTask(task, Lineage()); } -void NodeManager::ProcessReconstructObjectsMessage( +void NodeManager::ProcessFetchOrReconstructMessage( const std::shared_ptr &client, const uint8_t *message_data) { - auto message = flatbuffers::GetRoot(message_data); + auto message = flatbuffers::GetRoot(message_data); std::vector required_object_ids; for (size_t i = 0; i < message->object_ids()->size(); ++i) { ObjectID object_id = from_flatbuf(*message->object_ids()->Get(i)); - if (!task_dependency_manager_.CheckObjectLocal(object_id)) { - if (message->fetch_only()) { - // If only a fetch is required, then do not subscribe to the - // dependencies to the task dependency manager. + if (message->fetch_only()) { + // If only a fetch is required, then do not subscribe to the + // dependencies to the task dependency manager. + if (!task_dependency_manager_.CheckObjectLocal(object_id)) { + // Fetch the object if it's not already local. RAY_CHECK_OK(object_manager_.Pull(object_id)); - } else { - // If reconstruction is also required, then add any missing objects - // to the list to subscribe to in the task dependency manager. These - // objects will be pulled from remote node managers and reconstructed - // if necessary. - required_object_ids.push_back(object_id); } + } else { + // If reconstruction is also required, then add any requested objects to + // the list to subscribe to in the task dependency manager. These objects + // will be pulled from remote node managers and reconstructed if + // necessary. + required_object_ids.push_back(object_id); } } if (!required_object_ids.empty()) { - HandleClientBlocked(client, required_object_ids); + const TaskID task_id = from_flatbuf(*message->task_id()); + HandleTaskBlocked(client, required_object_ids, task_id); } } @@ -806,15 +816,16 @@ void NodeManager::ProcessWaitRequestMessage( } } + const TaskID ¤t_task_id = from_flatbuf(*message->task_id()); bool client_blocked = !required_object_ids.empty(); if (client_blocked) { - HandleClientBlocked(client, required_object_ids); + HandleTaskBlocked(client, required_object_ids, current_task_id); } ray::Status status = object_manager_.Wait( object_ids, wait_ms, num_required_objects, wait_local, - [this, client_blocked, client](std::vector found, - std::vector remaining) { + [this, client_blocked, client, current_task_id](std::vector found, + std::vector remaining) { // Write the data. flatbuffers::FlatBufferBuilder fbb; flatbuffers::Offset wait_reply = protocol::CreateWaitReply( @@ -827,7 +838,7 @@ void NodeManager::ProcessWaitRequestMessage( if (status.ok()) { // The client is unblocked now because the wait call has returned. if (client_blocked) { - HandleClientUnblocked(client); + HandleTaskUnblocked(client, current_task_id); } } else { // We failed to write to the client, so disconnect the client. @@ -1098,133 +1109,114 @@ void NodeManager::SubmitTask(const Task &task, const Lineage &uncommitted_lineag } } -void NodeManager::HandleWorkerBlocked(std::shared_ptr worker) { - RAY_CHECK(worker); - if (worker->IsBlocked()) { - return; - } - // If the worker isn't already blocked, then release any CPU resources that - // it acquired for its assigned task while it is blocked. The resources will - // be acquired again once the worker is unblocked. - RAY_CHECK(!worker->GetAssignedTaskId().is_nil()); - // (See design_docs/task_states.rst for the state transition diagram.) - const auto task = local_queues_.RemoveTask(worker->GetAssignedTaskId()); - // Get the CPU resources required by the running task. - const auto required_resources = task.GetTaskSpecification().GetRequiredResources(); - double required_cpus = required_resources.GetNumCpus(); - const std::unordered_map cpu_resources = { - {kCPU_ResourceLabel, required_cpus}}; - - // Release the CPU resources. - auto const cpu_resource_ids = worker->ReleaseTaskCpuResources(); - local_available_resources_.Release(cpu_resource_ids); - RAY_CHECK(cluster_resource_map_[gcs_client_->client_table().GetLocalClientId()].Release( - ResourceSet(cpu_resources))); - - // Mark the task as blocked. - local_queues_.QueueBlockedTasks({task}); - worker->MarkBlocked(); - - DispatchTasks(); -} - -void NodeManager::HandleWorkerUnblocked(std::shared_ptr worker) { - RAY_CHECK(worker); - if (!worker->IsBlocked()) { - return; - } - - // (See design_docs/task_states.rst for the state transition diagram.) - const auto task = local_queues_.RemoveTask(worker->GetAssignedTaskId()); - // Get the CPU resources required by the running task. - const auto required_resources = task.GetTaskSpecification().GetRequiredResources(); - double required_cpus = required_resources.GetNumCpus(); - const ResourceSet cpu_resources( - std::unordered_map({{kCPU_ResourceLabel, required_cpus}})); - - // Check if we can reacquire the CPU resources. - bool oversubscribed = !local_available_resources_.Contains(cpu_resources); - - if (!oversubscribed) { - // Reacquire the CPU resources for the worker. Note that care needs to be - // taken if the user is using the specific CPU IDs since the IDs that we - // reacquire here may be different from the ones that the task started with. - auto const resource_ids = local_available_resources_.Acquire(cpu_resources); - worker->AcquireTaskCpuResources(resource_ids); - RAY_CHECK( - cluster_resource_map_[gcs_client_->client_table().GetLocalClientId()].Acquire( - cpu_resources)); - } else { - // In this case, we simply don't reacquire the CPU resources for the worker. - // The worker can keep running and when the task finishes, it will simply - // not have any CPU resources to release. - RAY_LOG(WARNING) - << "Resources oversubscribed: " - << cluster_resource_map_[gcs_client_->client_table().GetLocalClientId()] - .GetAvailableResources() - .ToString(); - } - - // Mark the task as running again. - // (See design_docs/task_states.rst for the state transition diagram.) - local_queues_.QueueRunningTasks({task}); - worker->MarkUnblocked(); -} - -void NodeManager::HandleClientBlocked( - const std::shared_ptr &client, - const std::vector &required_object_ids) { +void NodeManager::HandleTaskBlocked(const std::shared_ptr &client, + const std::vector &required_object_ids, + const TaskID ¤t_task_id) { std::shared_ptr worker = worker_pool_.GetRegisteredWorker(client); if (worker) { - // The client is a worker. Mark the worker as blocked. This - // temporarily releases any resources that the worker holds while it is - // blocked. - HandleWorkerBlocked(worker); + // The client is a worker. If the worker is not already blocked and the + // blocked task matches the one assigned to the worker, then mark the + // worker as blocked. This temporarily releases any resources that the + // worker holds while it is blocked. + if (!worker->IsBlocked() && current_task_id == worker->GetAssignedTaskId()) { + const auto task = local_queues_.RemoveTask(current_task_id); + local_queues_.QueueRunningTasks({task}); + // Get the CPU resources required by the running task. + const auto required_resources = task.GetTaskSpecification().GetRequiredResources(); + double required_cpus = required_resources.GetNumCpus(); + const std::unordered_map cpu_resources = { + {kCPU_ResourceLabel, required_cpus}}; + + // Release the CPU resources. + auto const cpu_resource_ids = worker->ReleaseTaskCpuResources(); + local_available_resources_.Release(cpu_resource_ids); + RAY_CHECK( + cluster_resource_map_[gcs_client_->client_table().GetLocalClientId()].Release( + ResourceSet(cpu_resources))); + worker->MarkBlocked(); + + // Try dispatching tasks since we may have released some resources. + DispatchTasks(); + } } else { - // The client is a driver. Drivers do not hold resources, so we simply - // mark the driver as blocked. + // The client is a driver. Drivers do not hold resources, so we simply mark + // the task as blocked. worker = worker_pool_.GetRegisteredDriver(client); - RAY_CHECK(worker); - worker->MarkBlocked(); } - const TaskID current_task_id = worker->GetAssignedTaskId(); - RAY_CHECK(!current_task_id.is_nil()); + + RAY_CHECK(worker); + // Mark the task as blocked. + worker->AddBlockedTaskId(current_task_id); + if (local_queues_.GetBlockedTaskIds().count(current_task_id) == 0) { + local_queues_.AddBlockedTaskId(current_task_id); + } + // Subscribe to the objects required by the ray.get. These objects will // be fetched and/or reconstructed as necessary, until the objects become // local or are unsubscribed. task_dependency_manager_.SubscribeDependencies(current_task_id, required_object_ids); } -void NodeManager::HandleClientUnblocked( - const std::shared_ptr &client) { +void NodeManager::HandleTaskUnblocked( + const std::shared_ptr &client, const TaskID ¤t_task_id) { std::shared_ptr worker = worker_pool_.GetRegisteredWorker(client); - // Re-acquire the CPU resources for the task that was assigned to the - // unblocked worker. // TODO(swang): Because the object dependencies are tracked in the task // dependency manager, we could actually remove this message entirely and // instead unblock the worker once all the objects become available. - bool was_blocked; if (worker) { - was_blocked = worker->IsBlocked(); - // Mark the worker as unblocked. This returns the temporarily released - // resources to the worker. - HandleWorkerUnblocked(worker); + // The client is a worker. If the worker is not already unblocked and the + // unblocked task matches the one assigned to the worker, then mark the + // worker as unblocked. This returns the temporarily released resources to + // the worker. + if (worker->IsBlocked() && current_task_id == worker->GetAssignedTaskId()) { + // (See design_docs/task_states.rst for the state transition diagram.) + const auto task = local_queues_.RemoveTask(current_task_id); + local_queues_.QueueRunningTasks({task}); + // Get the CPU resources required by the running task. + const auto required_resources = task.GetTaskSpecification().GetRequiredResources(); + double required_cpus = required_resources.GetNumCpus(); + const ResourceSet cpu_resources( + std::unordered_map({{kCPU_ResourceLabel, required_cpus}})); + + // Check if we can reacquire the CPU resources. + bool oversubscribed = !local_available_resources_.Contains(cpu_resources); + + if (!oversubscribed) { + // Reacquire the CPU resources for the worker. Note that care needs to be + // taken if the user is using the specific CPU IDs since the IDs that we + // reacquire here may be different from the ones that the task started with. + auto const resource_ids = local_available_resources_.Acquire(cpu_resources); + worker->AcquireTaskCpuResources(resource_ids); + RAY_CHECK( + cluster_resource_map_[gcs_client_->client_table().GetLocalClientId()].Acquire( + cpu_resources)); + } else { + // In this case, we simply don't reacquire the CPU resources for the worker. + // The worker can keep running and when the task finishes, it will simply + // not have any CPU resources to release. + RAY_LOG(WARNING) + << "Resources oversubscribed: " + << cluster_resource_map_[gcs_client_->client_table().GetLocalClientId()] + .GetAvailableResources() + .ToString(); + } + worker->MarkUnblocked(); + } } else { // The client is a driver. Drivers do not hold resources, so we simply // mark the driver as unblocked. worker = worker_pool_.GetRegisteredDriver(client); - RAY_CHECK(worker); - was_blocked = worker->IsBlocked(); - worker->MarkUnblocked(); } + + RAY_CHECK(worker); + // If the task was previously blocked, then stop waiting for its dependencies + // and mark the task as unblocked. + worker->RemoveBlockedTaskId(current_task_id); // Unsubscribe to the objects. Any fetch or reconstruction operations to // make the objects local are canceled. - if (was_blocked) { - const TaskID current_task_id = worker->GetAssignedTaskId(); - RAY_CHECK(!current_task_id.is_nil()); - task_dependency_manager_.UnsubscribeDependencies(current_task_id); - } + task_dependency_manager_.UnsubscribeDependencies(current_task_id); + local_queues_.RemoveBlockedTaskId(current_task_id); } void NodeManager::EnqueuePlaceableTask(const Task &task) { diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h index 9e45a8e20..10a3a14e3 100644 --- a/src/ray/raylet/node_manager.h +++ b/src/ray/raylet/node_manager.h @@ -196,33 +196,31 @@ class NodeManager { /// Dispatch locally scheduled tasks. This attempts the transition from "scheduled" to /// "running" task state. void DispatchTasks(); - /// Handle a worker becoming blocked in a `ray.get`. - /// - /// \param worker The worker that is blocked. - /// \return Void. - void HandleWorkerBlocked(std::shared_ptr worker); - /// Handle a worker exiting a `ray.get`. - /// - /// \param worker The worker that is unblocked. - /// \return Void. - void HandleWorkerUnblocked(std::shared_ptr worker); - /// Handle a client that is blocked. This could be a worker or a driver. This - /// can be triggered when a client starts a get call or a wait call. + /// Handle a task that is blocked. This could be a task assigned to a worker, + /// an out-of-band task (e.g., a thread created by the application), or a + /// driver task. This can be triggered when a client starts a get call or a + /// wait call. /// - /// \param client The client that is blocked. + /// \param client The client that is executing the blocked task. /// \param required_object_ids The IDs that the client is blocked waiting for. + /// \param current_task_id The task that is blocked. /// \return Void. - void HandleClientBlocked(const std::shared_ptr &client, - const std::vector &required_object_ids); + void HandleTaskBlocked(const std::shared_ptr &client, + const std::vector &required_object_ids, + const TaskID ¤t_task_id); - /// Handle a client that is unblocked. This could be a worker or a driver. - /// This can be triggered when a client is finished with a get call or a wait - /// call. It is ok to call this even if the client is not actually blocked. + /// Handle a task that is unblocked. This could be a task assigned to a + /// worker, an out-of-band task (e.g., a thread created by the application), + /// or a driver task. This can be triggered when a client finishes a get call + /// or a wait call. The given task must be blocked, via a previous call to + /// HandleTaskBlocked. /// - /// \param client The client that is unblocked. + /// \param client The client that is executing the unblocked task. + /// \param current_task_id The task that is unblocked. /// \return Void. - void HandleClientUnblocked(const std::shared_ptr &client); + void HandleTaskUnblocked(const std::shared_ptr &client, + const TaskID ¤t_task_id); /// Kill a worker. /// @@ -312,12 +310,12 @@ class NodeManager { /// \return Void. void ProcessSubmitTaskMessage(const uint8_t *message_data); - /// Process client message of ReconstructObjects + /// Process client message of FetchOrReconstruct /// /// \param client The client that sent the message. /// \param message_data A pointer to the message data. /// \return Void. - void ProcessReconstructObjectsMessage( + void ProcessFetchOrReconstructMessage( const std::shared_ptr &client, const uint8_t *message_data); /// Process client message of WaitRequest diff --git a/src/ray/raylet/scheduling_queue.cc b/src/ray/raylet/scheduling_queue.cc index 049f772a4..7f3a412e9 100644 --- a/src/ray/raylet/scheduling_queue.cc +++ b/src/ray/raylet/scheduling_queue.cc @@ -147,8 +147,8 @@ const std::list &SchedulingQueue::GetRunningTasks() const { return this->running_tasks_.GetTasks(); } -const std::list &SchedulingQueue::GetBlockedTasks() const { - return this->blocked_tasks_.GetTasks(); +const std::unordered_set &SchedulingQueue::GetBlockedTaskIds() const { + return blocked_task_ids_; } void SchedulingQueue::FilterState(std::unordered_set &task_ids, @@ -166,9 +166,16 @@ void SchedulingQueue::FilterState(std::unordered_set &task_ids, case TaskState::RUNNING: FilterStateFromQueue(running_tasks_, task_ids, filter_state); break; - case TaskState::BLOCKED: - FilterStateFromQueue(blocked_tasks_, task_ids, filter_state); - break; + case TaskState::BLOCKED: { + const auto blocked_ids = GetBlockedTaskIds(); + for (auto it = task_ids.begin(); it != task_ids.end();) { + if (blocked_ids.count(*it) == 1) { + it = task_ids.erase(it); + } else { + it++; + } + } + } break; case TaskState::INFEASIBLE: FilterStateFromQueue(infeasible_tasks_, task_ids, filter_state); break; @@ -198,7 +205,6 @@ std::vector SchedulingQueue::RemoveTasks(std::unordered_set &task_ RemoveTasksFromQueue(placeable_tasks_, task_ids, removed_tasks); RemoveTasksFromQueue(ready_tasks_, task_ids, removed_tasks); RemoveTasksFromQueue(running_tasks_, task_ids, removed_tasks); - RemoveTasksFromQueue(blocked_tasks_, task_ids, removed_tasks); RemoveTasksFromQueue(infeasible_tasks_, task_ids, removed_tasks); RAY_CHECK(task_ids.size() == 0); @@ -230,9 +236,6 @@ void SchedulingQueue::MoveTasks(std::unordered_set &task_ids, TaskState case TaskState::RUNNING: RemoveTasksFromQueue(running_tasks_, task_ids, removed_tasks); break; - case TaskState::BLOCKED: - RemoveTasksFromQueue(blocked_tasks_, task_ids, removed_tasks); - break; case TaskState::INFEASIBLE: RemoveTasksFromQueue(infeasible_tasks_, task_ids, removed_tasks); break; @@ -254,9 +257,6 @@ void SchedulingQueue::MoveTasks(std::unordered_set &task_ids, TaskState case TaskState::RUNNING: QueueTasks(running_tasks_, removed_tasks); break; - case TaskState::BLOCKED: - QueueTasks(blocked_tasks_, removed_tasks); - break; case TaskState::INFEASIBLE: QueueTasks(infeasible_tasks_, removed_tasks); break; @@ -275,7 +275,7 @@ bool SchedulingQueue::HasTask(const TaskID &task_id) const { return (methods_waiting_for_actor_creation_.HasTask(task_id) || waiting_tasks_.HasTask(task_id) || placeable_tasks_.HasTask(task_id) || ready_tasks_.HasTask(task_id) || running_tasks_.HasTask(task_id) || - blocked_tasks_.HasTask(task_id) || infeasible_tasks_.HasTask(task_id)); + infeasible_tasks_.HasTask(task_id)); } void SchedulingQueue::QueueWaitingTasks(const std::vector &tasks) { @@ -294,10 +294,6 @@ void SchedulingQueue::QueueRunningTasks(const std::vector &tasks) { QueueTasks(running_tasks_, tasks); } -void SchedulingQueue::QueueBlockedTasks(const std::vector &tasks) { - QueueTasks(blocked_tasks_, tasks); -} - std::unordered_set SchedulingQueue::GetTaskIdsForDriver( const DriverID &driver_id) const { std::unordered_set task_ids; @@ -307,7 +303,6 @@ std::unordered_set SchedulingQueue::GetTaskIdsForDriver( GetDriverTasksFromQueue(placeable_tasks_, driver_id, task_ids); GetDriverTasksFromQueue(ready_tasks_, driver_id, task_ids); GetDriverTasksFromQueue(running_tasks_, driver_id, task_ids); - GetDriverTasksFromQueue(blocked_tasks_, driver_id, task_ids); GetDriverTasksFromQueue(infeasible_tasks_, driver_id, task_ids); return task_ids; @@ -322,12 +317,21 @@ std::unordered_set SchedulingQueue::GetTaskIdsForActor( GetActorTasksFromQueue(placeable_tasks_, actor_id, task_ids); GetActorTasksFromQueue(ready_tasks_, actor_id, task_ids); GetActorTasksFromQueue(running_tasks_, actor_id, task_ids); - GetActorTasksFromQueue(blocked_tasks_, actor_id, task_ids); GetActorTasksFromQueue(infeasible_tasks_, actor_id, task_ids); return task_ids; } +void SchedulingQueue::AddBlockedTaskId(const TaskID &task_id) { + auto inserted = blocked_task_ids_.insert(task_id); + RAY_CHECK(inserted.second); +} + +void SchedulingQueue::RemoveBlockedTaskId(const TaskID &task_id) { + auto erased = blocked_task_ids_.erase(task_id); + RAY_CHECK(erased == 1); +} + void SchedulingQueue::AddDriverTaskId(const TaskID &driver_id) { auto inserted = driver_task_ids_.insert(driver_id); RAY_CHECK(inserted.second); @@ -353,8 +357,6 @@ const std::string SchedulingQueue::ToString() const { "ready_tasks_ size is " + std::to_string(ready_tasks_.GetTasks().size()) + "\n"; result += "running_tasks_ size is " + std::to_string(running_tasks_.GetTasks().size()) + "\n"; - result += - "blocked_tasks_ size is " + std::to_string(blocked_tasks_.GetTasks().size()) + "\n"; result += "infeasible_tasks_ size is " + std::to_string(infeasible_tasks_.GetTasks().size()) + "\n"; result += "methods_waiting_for_actor_creation_ size is " + diff --git a/src/ray/raylet/scheduling_queue.h b/src/ray/raylet/scheduling_queue.h index 14030426a..a360edf64 100644 --- a/src/ray/raylet/scheduling_queue.h +++ b/src/ray/raylet/scheduling_queue.h @@ -14,12 +14,26 @@ namespace raylet { enum class TaskState { INIT, + // The task may be placed on a node. PLACEABLE, + // The task has been placed on a node and is waiting for some object + // dependencies to become local. WAITING, + // The task has been placed on a node, all dependencies are satisfied, and is + // waiting for resources to run. READY, + // The task is running on a worker. The task may also be blocked in a ray.get + // or ray.wait call, in which case it also has state BLOCKED. RUNNING, + // The task is running but blocked in a ray.get or ray.wait call. Tasks that + // were explicitly assigned by us may be both BLOCKED and RUNNING, while + // tasks that were created out-of-band (e.g., the application created + // multiple threads) are only BLOCKED. BLOCKED, + // The task is a driver task. DRIVER, + // The task has resources that cannot be satisfied by any node, as far as we + // know. INFEASIBLE }; @@ -86,10 +100,12 @@ class SchedulingQueue { /// Get the tasks in the blocked state. /// - /// \return A const reference to the queue of tasks that have been dispatched - /// to a worker but are blocked on a data dependency discovered to be missing - /// at runtime. - const std::list &GetBlockedTasks() const; + /// \return A const reference to the tasks that are are blocked on a data + /// dependency discovered to be missing at runtime. These include RUNNING + /// tasks that were explicitly assigned to a worker by us, as well as tasks + /// that were created out-of-band (e.g., the application created + // multiple threads) are only BLOCKED. + const std::unordered_set &GetBlockedTaskIds() const; /// Get the set of driver task IDs. /// @@ -143,12 +159,19 @@ class SchedulingQueue { /// \param tasks The tasks to queue. void QueueRunningTasks(const std::vector &tasks); - /// Queue tasks in the blocked state. These are tasks that have been + /// Add a task ID in the blocked state. These are tasks that have been /// dispatched to a worker but are blocked on a data dependency that was /// discovered to be missing at runtime. /// - /// \param tasks The tasks to queue. - void QueueBlockedTasks(const std::vector &tasks); + /// \param task_id The task to mark as blocked. + void AddBlockedTaskId(const TaskID &task_id); + + /// Remove a task ID in the blocked state. These are tasks that have been + /// dispatched to a worker but were blocked on a data dependency that was + /// discovered to be missing at runtime. + /// + /// \param task_id The task to mark as unblocked. + void RemoveBlockedTaskId(const TaskID &task_id); /// Add a driver task ID. This is an empty task used to represent a driver. /// @@ -265,7 +288,7 @@ class SchedulingQueue { TaskQueue running_tasks_; /// Tasks that were dispatched to a worker but are blocked on a data /// dependency that was missing at runtime. - TaskQueue blocked_tasks_; + std::unordered_set blocked_task_ids_; /// Tasks that require resources that are not available on any of the nodes /// in the cluster. TaskQueue infeasible_tasks_; diff --git a/src/ray/raylet/worker.cc b/src/ray/raylet/worker.cc index 0db5bb406..5a2342514 100644 --- a/src/ray/raylet/worker.cc +++ b/src/ray/raylet/worker.cc @@ -38,6 +38,20 @@ void Worker::AssignTaskId(const TaskID &task_id) { assigned_task_id_ = task_id; const TaskID &Worker::GetAssignedTaskId() const { return assigned_task_id_; } +bool Worker::AddBlockedTaskId(const TaskID &task_id) { + auto inserted = blocked_task_ids_.insert(task_id); + return inserted.second; +} + +bool Worker::RemoveBlockedTaskId(const TaskID &task_id) { + auto erased = blocked_task_ids_.erase(task_id); + return erased == 1; +} + +const std::unordered_set &Worker::GetBlockedTaskIds() const { + return blocked_task_ids_; +} + void Worker::AssignDriverId(const DriverID &driver_id) { assigned_driver_id_ = driver_id; } diff --git a/src/ray/raylet/worker.h b/src/ray/raylet/worker.h index c6ec7bac8..4860342e3 100644 --- a/src/ray/raylet/worker.h +++ b/src/ray/raylet/worker.h @@ -31,6 +31,9 @@ class Worker { Language GetLanguage() const; void AssignTaskId(const TaskID &task_id); const TaskID &GetAssignedTaskId() const; + bool AddBlockedTaskId(const TaskID &task_id); + bool RemoveBlockedTaskId(const TaskID &task_id); + const std::unordered_set &GetBlockedTaskIds() const; void AssignDriverId(const DriverID &driver_id); const DriverID &GetAssignedDriverId() const; void AssignActorId(const ActorID &actor_id); @@ -72,6 +75,7 @@ class Worker { /// The specific resource IDs that this worker currently owns for the duration // of a task. ResourceIdSet task_resource_ids_; + std::unordered_set blocked_task_ids_; }; } // namespace raylet diff --git a/test/runtest.py b/test/runtest.py index 704d1145a..3e2faadf4 100644 --- a/test/runtest.py +++ b/test/runtest.py @@ -1171,7 +1171,9 @@ def test_illegal_api_calls(shutdown_only): def test_multithreading(shutdown_only): - ray.init(num_cpus=1) + # This test requires at least 2 CPUs to finish since the worker does not + # relase resources when joining the threads. + ray.init(num_cpus=2) @ray.remote def f(): @@ -1196,11 +1198,37 @@ def test_multithreading(shutdown_only): def test_multi_threading_in_worker(): test_multi_threading() + def block(args, n): + ray.wait(args, num_returns=n) + ray.get(args[:n]) + + @ray.remote + class MultithreadedActor(object): + def __init__(self): + pass + + def spawn(self): + objects = [f.remote() for _ in range(1000)] + self.threads = [ + threading.Thread(target=block, args=(objects, n)) + for n in [1, 5, 10, 100, 1000] + ] + + [thread.start() for thread in self.threads] + + def join(self): + [thread.join() for thread in self.threads] + # test multi-threading in the driver test_multi_threading() # test multi-threading in the worker ray.get(test_multi_threading_in_worker.remote()) + # test multi-threading in the actor + a = MultithreadedActor.remote() + ray.get(a.spawn.remote()) + ray.get(a.join.remote()) + def test_free_objects_multi_node(shutdown_only): ray.worker._init( diff --git a/test/stress_tests.py b/test/stress_tests.py index 284bdb436..3d4b0fb36 100644 --- a/test/stress_tests.py +++ b/test/stress_tests.py @@ -535,7 +535,7 @@ def test_driver_put_errors(ray_start_driver_put_errors): # were evicted and whose originating tasks are still running, this # for-loop should hang on its first iteration and push an error to the # driver. - ray.worker.global_worker.local_scheduler_client.reconstruct_objects( + ray.worker.global_worker.local_scheduler_client.fetch_or_reconstruct( [args[0]], False) def error_check(errors):