mirror of
https://github.com/wassname/ray.git
synced 2026-07-03 00:52:10 +08:00
Allow multiple threads to call ray.get and ray.wait (#3244)
* Handle multiple threads calling ray.get * Multithreaded ray.wait * Pass in current task ID in java backend * Add multithreaded actor to tests, add warning messages to worker for multithreaded ray.get * Fix test * Some cleanups * Improve error message * Add assertion * Cleanup, throw error in HandleTaskUnblocked if task not actually blocked * lint * Fix python worker reset * Fix references to reconstruct_objects * Linting * java lint * Fix java * Fix iterator
This commit is contained in:
@@ -36,7 +36,7 @@ def fetch(oids):
|
||||
local_sched_client = ray.worker.global_worker.local_scheduler_client
|
||||
for o in oids:
|
||||
ray_obj_id = ray.ObjectID(o)
|
||||
local_sched_client.reconstruct_objects([ray_obj_id], True)
|
||||
local_sched_client.fetch_or_reconstruct([ray_obj_id], True)
|
||||
|
||||
|
||||
def run_timeline(sess, ops, feed_dict=None, write_timeline=False, name=""):
|
||||
|
||||
@@ -40,7 +40,7 @@ class TaskPool(object):
|
||||
for worker, obj_id in self.completed():
|
||||
plasma_id = ray.pyarrow.plasma.ObjectID(obj_id.id())
|
||||
(ray.worker.global_worker.local_scheduler_client.
|
||||
reconstruct_objects([obj_id], True))
|
||||
fetch_or_reconstruct([obj_id], True))
|
||||
self._fetching.append((worker, obj_id))
|
||||
|
||||
remaining = []
|
||||
|
||||
@@ -423,3 +423,7 @@ def thread_safe_client(client, lock=None):
|
||||
if lock is None:
|
||||
lock = threading.Lock()
|
||||
return _ThreadSafeProxy(client, lock)
|
||||
|
||||
|
||||
def is_main_thread():
|
||||
return threading.current_thread().getName() == "MainThread"
|
||||
|
||||
+69
-14
@@ -217,9 +217,38 @@ class Worker(object):
|
||||
# A dictionary that maps from driver id to SerializationContext
|
||||
# TODO: clean up the SerializationContext once the job finished.
|
||||
self.serialization_context_map = {}
|
||||
# Identity of the driver that this worker is processing.
|
||||
self.task_driver_id = None
|
||||
self.function_actor_manager = FunctionActorManager(self)
|
||||
# Reads/writes to the following fields must be protected by
|
||||
# self.state_lock.
|
||||
# Identity of the driver that this worker is processing.
|
||||
self.task_driver_id = ray.ObjectID(NIL_ID)
|
||||
self.current_task_id = ray.ObjectID(NIL_ID)
|
||||
self.task_index = 0
|
||||
self.put_index = 1
|
||||
|
||||
def get_current_thread_task_id(self):
|
||||
"""Get the current thread's task ID.
|
||||
|
||||
This returns the assigned task ID if called on the main thread, else a
|
||||
random task ID. This method is not thread-safe and must be called with
|
||||
self.state_lock acquired.
|
||||
"""
|
||||
current_task_id = self.current_task_id
|
||||
if not ray.utils.is_main_thread():
|
||||
# If this is running on a separate thread, then the mapping
|
||||
# to the current task ID may not be correct. Generate a
|
||||
# random task ID so that the backend can differentiate
|
||||
# between different threads.
|
||||
current_task_id = ray.ObjectID(random_string())
|
||||
if not self.multithreading_warned:
|
||||
logger.warning(
|
||||
"Calling ray.get or ray.wait in a separate thread "
|
||||
"may lead to deadlock if the main thread blocks on this "
|
||||
"thread and there are not enough resources to execute "
|
||||
"more tasks")
|
||||
self.multithreading_warned = True
|
||||
assert not current_task_id.is_nil()
|
||||
return current_task_id
|
||||
|
||||
def mark_actor_init_failed(self, error):
|
||||
"""Called to mark this actor as failed during initialization."""
|
||||
@@ -456,7 +485,7 @@ class Worker(object):
|
||||
]
|
||||
for i in range(0, len(object_ids),
|
||||
ray._config.worker_fetch_request_size()):
|
||||
self.local_scheduler_client.reconstruct_objects(
|
||||
self.local_scheduler_client.fetch_or_reconstruct(
|
||||
object_ids[i:(i + ray._config.worker_fetch_request_size())],
|
||||
True)
|
||||
|
||||
@@ -472,6 +501,9 @@ class Worker(object):
|
||||
|
||||
if len(unready_ids) > 0:
|
||||
with self.state_lock:
|
||||
# Get the task ID, to notify the backend which task is blocked.
|
||||
current_task_id = self.get_current_thread_task_id()
|
||||
|
||||
# Try reconstructing any objects we haven't gotten yet. Try to
|
||||
# get them until at least get_timeout_milliseconds
|
||||
# milliseconds passes, then repeat.
|
||||
@@ -488,9 +520,10 @@ class Worker(object):
|
||||
ray._config.worker_fetch_request_size())
|
||||
for i in range(0, len(object_ids_to_fetch),
|
||||
fetch_request_size):
|
||||
self.local_scheduler_client.reconstruct_objects(
|
||||
self.local_scheduler_client.fetch_or_reconstruct(
|
||||
ray_object_ids_to_fetch[i:(
|
||||
i + fetch_request_size)], False)
|
||||
i + fetch_request_size)], False,
|
||||
current_task_id)
|
||||
results = self.retrieve_and_deserialize(
|
||||
object_ids_to_fetch,
|
||||
max([
|
||||
@@ -508,7 +541,7 @@ class Worker(object):
|
||||
|
||||
# If there were objects that we weren't able to get locally,
|
||||
# let the local scheduler know that we're now unblocked.
|
||||
self.local_scheduler_client.notify_unblocked()
|
||||
self.local_scheduler_client.notify_unblocked(current_task_id)
|
||||
|
||||
assert len(final_results) == len(object_ids)
|
||||
return final_results
|
||||
@@ -615,6 +648,8 @@ class Worker(object):
|
||||
# have been submitted by the current task so far.
|
||||
task_index = self.task_index
|
||||
self.task_index += 1
|
||||
# The parent task must be set for the submitted task.
|
||||
assert not self.current_task_id.is_nil()
|
||||
# Submit the task to local scheduler.
|
||||
task = ray.raylet.Task(
|
||||
driver_id, ray.ObjectID(
|
||||
@@ -762,13 +797,18 @@ class Worker(object):
|
||||
(these will be retrieved by calls to get or by subsequent tasks that
|
||||
use the outputs of this task).
|
||||
"""
|
||||
# The ID of the driver that this task belongs to. This is needed so
|
||||
# that if the task throws an exception, we propagate the error
|
||||
# message to the correct driver.
|
||||
self.task_driver_id = task.driver_id()
|
||||
self.current_task_id = task.task_id()
|
||||
self.task_index = 0
|
||||
self.put_index = 1
|
||||
with self.state_lock:
|
||||
assert self.task_driver_id.is_nil()
|
||||
assert self.current_task_id.is_nil()
|
||||
assert self.task_index == 0
|
||||
assert self.put_index == 1
|
||||
|
||||
# The ID of the driver that this task belongs to. This is needed so
|
||||
# that if the task throws an exception, we propagate the error
|
||||
# message to the correct driver.
|
||||
self.task_driver_id = task.driver_id()
|
||||
self.current_task_id = task.task_id()
|
||||
|
||||
function_id = task.function_id()
|
||||
args = task.arguments()
|
||||
return_object_ids = task.returns()
|
||||
@@ -912,6 +952,12 @@ class Worker(object):
|
||||
with profiling.profile("task", extra_data=extra_data, worker=self):
|
||||
with _changeproctitle(title):
|
||||
self._process_task(task, execution_info)
|
||||
# Reset the state fields so the next task can run.
|
||||
with self.state_lock:
|
||||
self.task_driver_id = ray.ObjectID(NIL_ID)
|
||||
self.current_task_id = ray.ObjectID(NIL_ID)
|
||||
self.task_index = 0
|
||||
self.put_index = 1
|
||||
|
||||
# Increase the task execution counter.
|
||||
self.function_actor_manager.increase_task_counter(
|
||||
@@ -2044,6 +2090,9 @@ def connect(info,
|
||||
else:
|
||||
# A non-driver worker begins without an assigned task.
|
||||
worker.current_task_id = ray.ObjectID(NIL_ID)
|
||||
# A flag for making sure that we only print one warning message about
|
||||
# multithreading per worker.
|
||||
worker.multithreading_warned = False
|
||||
|
||||
worker.local_scheduler_client = ray.raylet.LocalSchedulerClient(
|
||||
local_scheduler_socket, worker.worker_id, is_worker,
|
||||
@@ -2376,6 +2425,7 @@ def wait(object_ids, num_returns=1, timeout=None, worker=global_worker):
|
||||
type(object_id)))
|
||||
|
||||
worker.check_connected()
|
||||
# TODO(swang): Check main thread.
|
||||
with profiling.profile("ray.wait", worker=worker):
|
||||
# When Ray is run in LOCAL_MODE, all functions are run immediately,
|
||||
# so all objects in object_id are ready.
|
||||
@@ -2396,9 +2446,14 @@ def wait(object_ids, num_returns=1, timeout=None, worker=global_worker):
|
||||
if num_returns > len(object_ids):
|
||||
raise Exception("num_returns cannot be greater than the number "
|
||||
"of objects provided to ray.wait.")
|
||||
|
||||
# Get the task ID, to notify the backend which task is blocked.
|
||||
with worker.state_lock:
|
||||
current_task_id = worker.get_current_thread_task_id()
|
||||
|
||||
timeout = timeout if timeout is not None else 2**30
|
||||
ready_ids, remaining_ids = worker.local_scheduler_client.wait(
|
||||
object_ids, num_returns, timeout, False)
|
||||
object_ids, num_returns, timeout, False, current_task_id)
|
||||
return ready_ids, remaining_ids
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user