Allow multiple threads to call ray.get and ray.wait (#3244)

* Handle multiple threads calling ray.get

* Multithreaded ray.wait

* Pass in current task ID in java backend

* Add multithreaded actor to tests, add warning messages to worker for multithreaded ray.get

* Fix test

* Some cleanups

* Improve error message

* Add assertion

* Cleanup, throw error in HandleTaskUnblocked if task not actually blocked

* lint

* Fix python worker reset

* Fix references to reconstruct_objects

* Linting

* java lint

* Fix java

* Fix iterator
This commit is contained in:
Stephanie Wang
2018-11-07 22:39:28 -08:00
committed by GitHub
parent 0bab8ed95c
commit d950e92f63
23 changed files with 460 additions and 281 deletions
+1 -1
View File
@@ -36,7 +36,7 @@ def fetch(oids):
local_sched_client = ray.worker.global_worker.local_scheduler_client
for o in oids:
ray_obj_id = ray.ObjectID(o)
local_sched_client.reconstruct_objects([ray_obj_id], True)
local_sched_client.fetch_or_reconstruct([ray_obj_id], True)
def run_timeline(sess, ops, feed_dict=None, write_timeline=False, name=""):
+1 -1
View File
@@ -40,7 +40,7 @@ class TaskPool(object):
for worker, obj_id in self.completed():
plasma_id = ray.pyarrow.plasma.ObjectID(obj_id.id())
(ray.worker.global_worker.local_scheduler_client.
reconstruct_objects([obj_id], True))
fetch_or_reconstruct([obj_id], True))
self._fetching.append((worker, obj_id))
remaining = []
+4
View File
@@ -423,3 +423,7 @@ def thread_safe_client(client, lock=None):
if lock is None:
lock = threading.Lock()
return _ThreadSafeProxy(client, lock)
def is_main_thread():
    """Return True if the calling thread is the interpreter's main thread.

    The check is done by object identity against ``threading.main_thread()``
    rather than by thread name, because thread names are not unique: a user
    thread may be named "MainThread", and the main thread may be renamed.

    Returns:
        bool: True when called from the main thread, False otherwise.
    """
    current = threading.current_thread()
    main_thread = getattr(threading, "main_thread", None)
    if main_thread is not None:
        return current is main_thread()
    # Interpreters without threading.main_thread() (e.g. Python 2): fall
    # back to the name-based heuristic.
    return current.name == "MainThread"
+69 -14
View File
@@ -217,9 +217,38 @@ class Worker(object):
# A dictionary that maps from driver id to SerializationContext
# TODO: clean up the SerializationContext once the job finished.
self.serialization_context_map = {}
# Identity of the driver that this worker is processing.
self.task_driver_id = None
self.function_actor_manager = FunctionActorManager(self)
# Reads/writes to the following fields must be protected by
# self.state_lock.
# Identity of the driver that this worker is processing.
self.task_driver_id = ray.ObjectID(NIL_ID)
self.current_task_id = ray.ObjectID(NIL_ID)
self.task_index = 0
self.put_index = 1
def get_current_thread_task_id(self):
    """Return the task ID to report to the backend for the calling thread.

    On the main thread this is the worker's assigned task ID
    (``self.current_task_id``). On any other thread a fresh random task ID
    is generated instead, and a one-time warning about possible deadlock is
    logged.

    This method is not thread-safe and must be called with
    self.state_lock acquired.

    Returns:
        The non-nil task ID to use for the calling thread.
    """
    if ray.utils.is_main_thread():
        task_id = self.current_task_id
    else:
        # Off the main thread the worker's task ID mapping may not be
        # correct, so hand the backend a random ID that lets it tell
        # different threads apart.
        task_id = ray.ObjectID(random_string())
        already_warned = self.multithreading_warned
        if not already_warned:
            logger.warning(
                "Calling ray.get or ray.wait in a separate thread "
                "may lead to deadlock if the main thread blocks on this "
                "thread and there are not enough resources to execute "
                "more tasks")
            self.multithreading_warned = True

    # Whichever branch ran, the backend must receive a real task ID.
    assert not task_id.is_nil()
    return task_id
def mark_actor_init_failed(self, error):
"""Called to mark this actor as failed during initialization."""
@@ -456,7 +485,7 @@ class Worker(object):
]
for i in range(0, len(object_ids),
ray._config.worker_fetch_request_size()):
self.local_scheduler_client.reconstruct_objects(
self.local_scheduler_client.fetch_or_reconstruct(
object_ids[i:(i + ray._config.worker_fetch_request_size())],
True)
@@ -472,6 +501,9 @@ class Worker(object):
if len(unready_ids) > 0:
with self.state_lock:
# Get the task ID, to notify the backend which task is blocked.
current_task_id = self.get_current_thread_task_id()
# Try reconstructing any objects we haven't gotten yet. Try to
# get them until at least get_timeout_milliseconds
# milliseconds passes, then repeat.
@@ -488,9 +520,10 @@ class Worker(object):
ray._config.worker_fetch_request_size())
for i in range(0, len(object_ids_to_fetch),
fetch_request_size):
self.local_scheduler_client.reconstruct_objects(
self.local_scheduler_client.fetch_or_reconstruct(
ray_object_ids_to_fetch[i:(
i + fetch_request_size)], False)
i + fetch_request_size)], False,
current_task_id)
results = self.retrieve_and_deserialize(
object_ids_to_fetch,
max([
@@ -508,7 +541,7 @@ class Worker(object):
# If there were objects that we weren't able to get locally,
# let the local scheduler know that we're now unblocked.
self.local_scheduler_client.notify_unblocked()
self.local_scheduler_client.notify_unblocked(current_task_id)
assert len(final_results) == len(object_ids)
return final_results
@@ -615,6 +648,8 @@ class Worker(object):
# have been submitted by the current task so far.
task_index = self.task_index
self.task_index += 1
# The parent task must be set for the submitted task.
assert not self.current_task_id.is_nil()
# Submit the task to local scheduler.
task = ray.raylet.Task(
driver_id, ray.ObjectID(
@@ -762,13 +797,18 @@ class Worker(object):
(these will be retrieved by calls to get or by subsequent tasks that
use the outputs of this task).
"""
# The ID of the driver that this task belongs to. This is needed so
# that if the task throws an exception, we propagate the error
# message to the correct driver.
self.task_driver_id = task.driver_id()
self.current_task_id = task.task_id()
self.task_index = 0
self.put_index = 1
with self.state_lock:
assert self.task_driver_id.is_nil()
assert self.current_task_id.is_nil()
assert self.task_index == 0
assert self.put_index == 1
# The ID of the driver that this task belongs to. This is needed so
# that if the task throws an exception, we propagate the error
# message to the correct driver.
self.task_driver_id = task.driver_id()
self.current_task_id = task.task_id()
function_id = task.function_id()
args = task.arguments()
return_object_ids = task.returns()
@@ -912,6 +952,12 @@ class Worker(object):
with profiling.profile("task", extra_data=extra_data, worker=self):
with _changeproctitle(title):
self._process_task(task, execution_info)
# Reset the state fields so the next task can run.
with self.state_lock:
self.task_driver_id = ray.ObjectID(NIL_ID)
self.current_task_id = ray.ObjectID(NIL_ID)
self.task_index = 0
self.put_index = 1
# Increase the task execution counter.
self.function_actor_manager.increase_task_counter(
@@ -2044,6 +2090,9 @@ def connect(info,
else:
# A non-driver worker begins without an assigned task.
worker.current_task_id = ray.ObjectID(NIL_ID)
# A flag for making sure that we only print one warning message about
# multithreading per worker.
worker.multithreading_warned = False
worker.local_scheduler_client = ray.raylet.LocalSchedulerClient(
local_scheduler_socket, worker.worker_id, is_worker,
@@ -2376,6 +2425,7 @@ def wait(object_ids, num_returns=1, timeout=None, worker=global_worker):
type(object_id)))
worker.check_connected()
# TODO(swang): Check main thread.
with profiling.profile("ray.wait", worker=worker):
# When Ray is run in LOCAL_MODE, all functions are run immediately,
# so all objects in object_id are ready.
@@ -2396,9 +2446,14 @@ def wait(object_ids, num_returns=1, timeout=None, worker=global_worker):
if num_returns > len(object_ids):
raise Exception("num_returns cannot be greater than the number "
"of objects provided to ray.wait.")
# Get the task ID, to notify the backend which task is blocked.
with worker.state_lock:
current_task_id = worker.get_current_thread_task_id()
timeout = timeout if timeout is not None else 2**30
ready_ids, remaining_ids = worker.local_scheduler_client.wait(
object_ids, num_returns, timeout, False)
object_ids, num_returns, timeout, False, current_task_id)
return ready_ids, remaining_ids