mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 07:23:55 +08:00
Refine multi-threading support (#3672)
* [Python] refine multi-threading support fix * [java] refine multithreading code fix java * format
This commit is contained in:
+28
-35
@@ -7,6 +7,7 @@ import hashlib
|
||||
import inspect
|
||||
import logging
|
||||
import sys
|
||||
import threading
|
||||
import traceback
|
||||
|
||||
import ray.cloudpickle as pickle
|
||||
@@ -225,8 +226,7 @@ class ActorMethod(object):
|
||||
self._method_name,
|
||||
args=args,
|
||||
kwargs=kwargs,
|
||||
num_return_vals=num_return_vals,
|
||||
dependency=self._actor._ray_actor_cursor)
|
||||
num_return_vals=num_return_vals)
|
||||
|
||||
|
||||
class ActorClass(object):
|
||||
@@ -525,13 +525,13 @@ class ActorHandle(object):
|
||||
self._ray_actor_method_cpus = actor_method_cpus
|
||||
self._ray_actor_driver_id = actor_driver_id
|
||||
self._ray_new_actor_handles = []
|
||||
self._ray_actor_lock = threading.Lock()
|
||||
|
||||
def _actor_method_call(self,
|
||||
method_name,
|
||||
args=None,
|
||||
kwargs=None,
|
||||
num_return_vals=None,
|
||||
dependency=None):
|
||||
num_return_vals=None):
|
||||
"""Method execution stub for an actor handle.
|
||||
|
||||
This is the function that executes when
|
||||
@@ -570,41 +570,34 @@ class ActorHandle(object):
|
||||
return getattr(worker.actors[self._ray_actor_id],
|
||||
method_name)(*copy.deepcopy(args))
|
||||
|
||||
# Add the execution dependency.
|
||||
if dependency is None:
|
||||
execution_dependencies = []
|
||||
else:
|
||||
execution_dependencies = [dependency]
|
||||
|
||||
is_actor_checkpoint_method = (method_name == "__ray_checkpoint__")
|
||||
|
||||
function_descriptor = FunctionDescriptor(
|
||||
self._ray_module_name, method_name, self._ray_class_name)
|
||||
object_ids = worker.submit_task(
|
||||
function_descriptor,
|
||||
args,
|
||||
actor_id=self._ray_actor_id,
|
||||
actor_handle_id=self._ray_actor_handle_id,
|
||||
actor_counter=self._ray_actor_counter,
|
||||
is_actor_checkpoint_method=is_actor_checkpoint_method,
|
||||
actor_creation_dummy_object_id=(
|
||||
self._ray_actor_creation_dummy_object_id),
|
||||
execution_dependencies=execution_dependencies,
|
||||
new_actor_handles=self._ray_new_actor_handles,
|
||||
# We add one for the dummy return ID.
|
||||
num_return_vals=num_return_vals + 1,
|
||||
resources={"CPU": self._ray_actor_method_cpus},
|
||||
placement_resources={},
|
||||
driver_id=self._ray_actor_driver_id)
|
||||
# Update the actor counter and cursor to reflect the most recent
|
||||
# invocation.
|
||||
self._ray_actor_counter += 1
|
||||
# The last object returned is the dummy object that should be
|
||||
# passed in to the next actor method. Do not return it to the user.
|
||||
self._ray_actor_cursor = object_ids.pop()
|
||||
# We have notified the backend of the new actor handles to expect since
|
||||
# the last task was submitted, so clear the list.
|
||||
self._ray_new_actor_handles = []
|
||||
with self._ray_actor_lock:
|
||||
object_ids = worker.submit_task(
|
||||
function_descriptor,
|
||||
args,
|
||||
actor_id=self._ray_actor_id,
|
||||
actor_handle_id=self._ray_actor_handle_id,
|
||||
actor_counter=self._ray_actor_counter,
|
||||
is_actor_checkpoint_method=is_actor_checkpoint_method,
|
||||
actor_creation_dummy_object_id=(
|
||||
self._ray_actor_creation_dummy_object_id),
|
||||
execution_dependencies=[self._ray_actor_cursor],
|
||||
new_actor_handles=self._ray_new_actor_handles,
|
||||
# We add one for the dummy return ID.
|
||||
num_return_vals=num_return_vals + 1,
|
||||
resources={"CPU": self._ray_actor_method_cpus},
|
||||
placement_resources={},
|
||||
driver_id=self._ray_actor_driver_id,
|
||||
)
|
||||
# Update the actor counter and cursor to reflect the most recent
|
||||
# invocation.
|
||||
self._ray_actor_counter += 1
|
||||
# The last object returned is the dummy object that should be
|
||||
# passed in to the next actor method. Do not return it to the user.
|
||||
self._ray_actor_cursor = object_ids.pop()
|
||||
|
||||
if len(object_ids) == 1:
|
||||
object_ids = object_ids[0]
|
||||
|
||||
+157
-139
@@ -143,13 +143,6 @@ class Worker(object):
|
||||
cached_functions_to_run (List): A list of functions to run on all of
|
||||
the workers that should be exported as soon as connect is called.
|
||||
profiler: the profiler used to aggregate profiling information.
|
||||
state_lock (Lock):
|
||||
Used to lock worker's non-thread-safe internal states:
|
||||
1) task_index increment: make sure we generate unique task ids;
|
||||
2) Object reconstruction: because the node manager will
|
||||
recycle/return the worker's resources before/after reconstruction,
|
||||
it's unsafe for multiple threads to call object
|
||||
reconstruction simultaneously.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
@@ -169,42 +162,56 @@ class Worker(object):
|
||||
self.original_gpu_ids = ray.utils.get_cuda_visible_devices()
|
||||
self.profiler = None
|
||||
self.memory_monitor = memory_monitor.MemoryMonitor()
|
||||
self.state_lock = threading.Lock()
|
||||
# A dictionary that maps from driver id to SerializationContext
|
||||
# TODO: clean up the SerializationContext once the job has finished.
|
||||
self.serialization_context_map = {}
|
||||
self.function_actor_manager = FunctionActorManager(self)
|
||||
# Reads/writes to the following fields must be protected by
|
||||
# self.state_lock.
|
||||
# Identity of the driver that this worker is processing.
|
||||
self.task_driver_id = ray.ObjectID(NIL_ID)
|
||||
self.current_task_id = ray.ObjectID(NIL_ID)
|
||||
self.task_index = 0
|
||||
self.put_index = 1
|
||||
self._task_context = threading.local()
|
||||
|
||||
def get_current_thread_task_id(self):
|
||||
"""Get the current thread's task ID.
|
||||
@property
|
||||
def task_context(self):
|
||||
"""A thread-local that contains the following attributes.
|
||||
|
||||
This returns the assigned task ID if called on the main thread, else a
|
||||
random task ID. This method is not thread-safe and must be called with
|
||||
self.state_lock acquired.
|
||||
current_task_id: For the main thread, this field is the ID of this
|
||||
worker's current running task; for other threads, this field is a
|
||||
fake random ID.
|
||||
task_index: The number of tasks that have been submitted from the
|
||||
current task.
|
||||
put_index: The number of objects that have been put from the current
|
||||
task.
|
||||
"""
|
||||
current_task_id = self.current_task_id
|
||||
if not ray.utils.is_main_thread():
|
||||
# If this is running on a separate thread, then the mapping
|
||||
# to the current task ID may not be correct. Generate a
|
||||
# random task ID so that the backend can differentiate
|
||||
# between different threads.
|
||||
current_task_id = ray.ObjectID(random_string())
|
||||
if not self.multithreading_warned:
|
||||
logger.warning(
|
||||
"Calling ray.get or ray.wait in a separate thread "
|
||||
"may lead to deadlock if the main thread blocks on this "
|
||||
"thread and there are not enough resources to execute "
|
||||
"more tasks")
|
||||
self.multithreading_warned = True
|
||||
assert not current_task_id.is_nil()
|
||||
return current_task_id
|
||||
if not hasattr(self._task_context, 'initialized'):
|
||||
# Initialize task_context for the current thread.
|
||||
if ray.utils.is_main_thread():
|
||||
# If this is running on the main thread, initialize it to
|
||||
# NIL. The actual value will be set when the worker receives
|
||||
# a task from raylet backend.
|
||||
self._task_context.current_task_id = ray.ObjectID(NIL_ID)
|
||||
else:
|
||||
# If this is running on a separate thread, then the mapping
|
||||
# to the current task ID may not be correct. Generate a
|
||||
# random task ID so that the backend can differentiate
|
||||
# between different threads.
|
||||
self._task_context.current_task_id = ray.ObjectID(
|
||||
random_string())
|
||||
if getattr(self, '_multithreading_warned', False) is not True:
|
||||
logger.warning(
|
||||
"Calling ray.get or ray.wait in a separate thread "
|
||||
"may lead to deadlock if the main thread blocks on "
|
||||
"this thread and there are not enough resources to "
|
||||
"execute more tasks")
|
||||
self._multithreading_warned = True
|
||||
|
||||
self._task_context.task_index = 0
|
||||
self._task_context.put_index = 1
|
||||
self._task_context.initialized = True
|
||||
return self._task_context
|
||||
|
||||
@property
|
||||
def current_task_id(self):
|
||||
return self.task_context.current_task_id
|
||||
|
||||
def mark_actor_init_failed(self, error):
|
||||
"""Called to mark this actor as failed during initialization."""
|
||||
@@ -467,48 +474,45 @@ class Worker(object):
|
||||
}
|
||||
|
||||
if len(unready_ids) > 0:
|
||||
with self.state_lock:
|
||||
# Get the task ID, to notify the backend which task is blocked.
|
||||
current_task_id = self.get_current_thread_task_id()
|
||||
# Try reconstructing any objects we haven't gotten yet. Try to
|
||||
# get them until at least get_timeout_milliseconds
|
||||
# milliseconds passes, then repeat.
|
||||
while len(unready_ids) > 0:
|
||||
object_ids_to_fetch = [
|
||||
plasma.ObjectID(unready_id)
|
||||
for unready_id in unready_ids.keys()
|
||||
]
|
||||
ray_object_ids_to_fetch = [
|
||||
ray.ObjectID(unready_id)
|
||||
for unready_id in unready_ids.keys()
|
||||
]
|
||||
fetch_request_size = ray._config.worker_fetch_request_size()
|
||||
for i in range(0, len(object_ids_to_fetch),
|
||||
fetch_request_size):
|
||||
self.raylet_client.fetch_or_reconstruct(
|
||||
ray_object_ids_to_fetch[i:(i + fetch_request_size)],
|
||||
False,
|
||||
self.current_task_id,
|
||||
)
|
||||
results = self.retrieve_and_deserialize(
|
||||
object_ids_to_fetch,
|
||||
max([
|
||||
ray._config.get_timeout_milliseconds(),
|
||||
int(0.01 * len(unready_ids)),
|
||||
]),
|
||||
)
|
||||
# Remove any entries for objects we received during this
|
||||
# iteration so we don't retrieve the same object twice.
|
||||
for i, val in enumerate(results):
|
||||
if val is not plasma.ObjectNotAvailable:
|
||||
object_id = object_ids_to_fetch[i].binary()
|
||||
index = unready_ids[object_id]
|
||||
final_results[index] = val
|
||||
unready_ids.pop(object_id)
|
||||
|
||||
# Try reconstructing any objects we haven't gotten yet. Try to
|
||||
# get them until at least get_timeout_milliseconds
|
||||
# milliseconds passes, then repeat.
|
||||
while len(unready_ids) > 0:
|
||||
object_ids_to_fetch = [
|
||||
plasma.ObjectID(unready_id)
|
||||
for unready_id in unready_ids.keys()
|
||||
]
|
||||
ray_object_ids_to_fetch = [
|
||||
ray.ObjectID(unready_id)
|
||||
for unready_id in unready_ids.keys()
|
||||
]
|
||||
fetch_request_size = (
|
||||
ray._config.worker_fetch_request_size())
|
||||
for i in range(0, len(object_ids_to_fetch),
|
||||
fetch_request_size):
|
||||
self.raylet_client.fetch_or_reconstruct(
|
||||
ray_object_ids_to_fetch[i:(
|
||||
i + fetch_request_size)], False,
|
||||
current_task_id)
|
||||
results = self.retrieve_and_deserialize(
|
||||
object_ids_to_fetch,
|
||||
max([
|
||||
ray._config.get_timeout_milliseconds(),
|
||||
int(0.01 * len(unready_ids))
|
||||
]))
|
||||
# Remove any entries for objects we received during this
|
||||
# iteration so we don't retrieve the same object twice.
|
||||
for i, val in enumerate(results):
|
||||
if val is not plasma.ObjectNotAvailable:
|
||||
object_id = object_ids_to_fetch[i].binary()
|
||||
index = unready_ids[object_id]
|
||||
final_results[index] = val
|
||||
unready_ids.pop(object_id)
|
||||
|
||||
# If there were objects that we weren't able to get locally,
|
||||
# let the local scheduler know that we're now unblocked.
|
||||
self.raylet_client.notify_unblocked(current_task_id)
|
||||
# If there were objects that we weren't able to get locally,
|
||||
# let the local scheduler know that we're now unblocked.
|
||||
self.raylet_client.notify_unblocked(self.current_task_id)
|
||||
|
||||
assert len(final_results) == len(object_ids)
|
||||
return final_results
|
||||
@@ -616,24 +620,32 @@ class Worker(object):
|
||||
if placement_resources is None:
|
||||
placement_resources = {}
|
||||
|
||||
with self.state_lock:
|
||||
# Increment the worker's task index to track how many tasks
|
||||
# have been submitted by the current task so far.
|
||||
task_index = self.task_index
|
||||
self.task_index += 1
|
||||
# The parent task must be set for the submitted task.
|
||||
if self.actor_id == NIL_ACTOR_ID:
|
||||
assert not self.current_task_id.is_nil()
|
||||
# Increment the worker's task index to track how many tasks
|
||||
# have been submitted by the current task so far.
|
||||
self.task_context.task_index += 1
|
||||
# The parent task must be set for the submitted task.
|
||||
assert not self.current_task_id.is_nil()
|
||||
# Submit the task to local scheduler.
|
||||
function_descriptor_list = (
|
||||
function_descriptor.get_function_descriptor_list())
|
||||
task = ray.raylet.Task(
|
||||
driver_id, function_descriptor_list, args_for_local_scheduler,
|
||||
num_return_vals, self.current_task_id, task_index,
|
||||
actor_creation_id, actor_creation_dummy_object_id,
|
||||
max_actor_reconstructions, actor_id, actor_handle_id,
|
||||
actor_counter, new_actor_handles, execution_dependencies,
|
||||
resources, placement_resources)
|
||||
driver_id,
|
||||
function_descriptor_list,
|
||||
args_for_local_scheduler,
|
||||
num_return_vals,
|
||||
self.current_task_id,
|
||||
self.task_context.task_index,
|
||||
actor_creation_id,
|
||||
actor_creation_dummy_object_id,
|
||||
max_actor_reconstructions,
|
||||
actor_id,
|
||||
actor_handle_id,
|
||||
actor_counter,
|
||||
new_actor_handles,
|
||||
execution_dependencies,
|
||||
resources,
|
||||
placement_resources,
|
||||
)
|
||||
self.raylet_client.submit_task(task)
|
||||
|
||||
return task.returns()
|
||||
@@ -770,24 +782,23 @@ class Worker(object):
|
||||
(these will be retrieved by calls to get or by subsequent tasks that
|
||||
use the outputs of this task).
|
||||
"""
|
||||
with self.state_lock:
|
||||
assert self.current_task_id.is_nil()
|
||||
assert self.task_index == 0
|
||||
assert self.put_index == 1
|
||||
if task.actor_id().is_nil():
|
||||
# If this worker is not an actor, check that `task_driver_id`
|
||||
# was reset when the worker finished the previous task.
|
||||
assert self.task_driver_id.is_nil()
|
||||
# Set the driver ID of the current running task. This is
|
||||
# needed so that if the task throws an exception, we propagate
|
||||
# the error message to the correct driver.
|
||||
self.task_driver_id = task.driver_id()
|
||||
else:
|
||||
# If this worker is an actor, task_driver_id wasn't reset.
|
||||
# Check that current task's driver ID equals the previous one.
|
||||
assert self.task_driver_id == task.driver_id()
|
||||
assert self.current_task_id.is_nil()
|
||||
assert self.task_context.task_index == 0
|
||||
assert self.task_context.put_index == 1
|
||||
if task.actor_id().is_nil():
|
||||
# If this worker is not an actor, check that `task_driver_id`
|
||||
# was reset when the worker finished the previous task.
|
||||
assert self.task_driver_id.is_nil()
|
||||
# Set the driver ID of the current running task. This is
|
||||
# needed so that if the task throws an exception, we propagate
|
||||
# the error message to the correct driver.
|
||||
self.task_driver_id = task.driver_id()
|
||||
else:
|
||||
# If this worker is an actor, task_driver_id wasn't reset.
|
||||
# Check that current task's driver ID equals the previous one.
|
||||
assert self.task_driver_id == task.driver_id()
|
||||
|
||||
self.current_task_id = task.task_id()
|
||||
self.task_context.current_task_id = task.task_id()
|
||||
|
||||
function_descriptor = FunctionDescriptor.from_bytes_list(
|
||||
task.function_descriptor_list())
|
||||
@@ -931,13 +942,14 @@ class Worker(object):
|
||||
with _changeproctitle(title, next_title):
|
||||
self._process_task(task, execution_info)
|
||||
# Reset the state fields so the next task can run.
|
||||
with self.state_lock:
|
||||
if self.actor_id == NIL_ACTOR_ID:
|
||||
# We will keep task_driver_id unchanged for actor.
|
||||
self.task_driver_id = ray.ObjectID(NIL_ID)
|
||||
self.current_task_id = ray.ObjectID(NIL_ID)
|
||||
self.task_index = 0
|
||||
self.put_index = 1
|
||||
self.task_context.current_task_id = ray.ObjectID(NIL_ID)
|
||||
self.task_context.task_index = 0
|
||||
self.task_context.put_index = 1
|
||||
if self.actor_id == NIL_ACTOR_ID:
|
||||
# Don't need to reset task_driver_id if the worker is an
|
||||
# actor, because the following tasks should all have the
|
||||
# same driver id.
|
||||
self.task_driver_id = ray.ObjectID(NIL_ID)
|
||||
|
||||
# Increase the task execution counter.
|
||||
self.function_actor_manager.increase_task_counter(
|
||||
@@ -1925,13 +1937,8 @@ def connect(ray_params,
|
||||
else:
|
||||
# Try to use true randomness.
|
||||
np.random.seed(None)
|
||||
worker.current_task_id = ray.ObjectID(
|
||||
np.random.bytes(ray_constants.ID_SIZE))
|
||||
# Reset the state of the numpy random number generator.
|
||||
np.random.set_state(numpy_state)
|
||||
# Set other fields needed for computing task IDs.
|
||||
worker.task_index = 0
|
||||
worker.put_index = 1
|
||||
|
||||
# Create an entry for the driver task in the task table. This task is
|
||||
# added immediately with status RUNNING. This allows us to push errors
|
||||
@@ -1944,11 +1951,22 @@ def connect(ray_params,
|
||||
function_descriptor = FunctionDescriptor.for_driver_task()
|
||||
driver_task = ray.raylet.Task(
|
||||
worker.task_driver_id,
|
||||
function_descriptor.get_function_descriptor_list(), [], 0,
|
||||
worker.current_task_id, worker.task_index,
|
||||
ray.ObjectID(NIL_ACTOR_ID), ray.ObjectID(NIL_ACTOR_ID), 0,
|
||||
ray.ObjectID(NIL_ACTOR_ID), ray.ObjectID(NIL_ACTOR_ID),
|
||||
nil_actor_counter, [], [], {"CPU": 0}, {})
|
||||
function_descriptor.get_function_descriptor_list(),
|
||||
[], # arguments.
|
||||
0, # num_returns.
|
||||
ray.ObjectID(random_string()), # parent_task_id.
|
||||
0, # parent_counter.
|
||||
ray.ObjectID(NIL_ACTOR_ID), # actor_creation_id.
|
||||
ray.ObjectID(NIL_ACTOR_ID), # actor_creation_dummy_object_id.
|
||||
0, # max_actor_reconstructions.
|
||||
ray.ObjectID(NIL_ACTOR_ID), # actor_id.
|
||||
ray.ObjectID(NIL_ACTOR_ID), # actor_handle_id.
|
||||
nil_actor_counter, # actor_counter.
|
||||
[], # new_actor_handles.
|
||||
[], # execution_dependencies.
|
||||
{"CPU": 0}, # resource_map.
|
||||
{}, # placement_resource_map.
|
||||
)
|
||||
|
||||
# Add the driver task to the task table.
|
||||
global_state._execute_command(driver_task.task_id(), "RAY.TABLE_ADD",
|
||||
@@ -1959,16 +1977,14 @@ def connect(ray_params,
|
||||
|
||||
# Set the driver's current task ID to the task ID assigned to the
|
||||
# driver task.
|
||||
worker.current_task_id = driver_task.task_id()
|
||||
else:
|
||||
# A non-driver worker begins without an assigned task.
|
||||
worker.current_task_id = ray.ObjectID(NIL_ID)
|
||||
# A flag for making sure that we only print one warning message about
|
||||
# multithreading per worker.
|
||||
worker.multithreading_warned = False
|
||||
worker.task_context.current_task_id = driver_task.task_id()
|
||||
|
||||
worker.raylet_client = ray.raylet.RayletClient(
|
||||
raylet_socket, worker.worker_id, is_worker, worker.current_task_id)
|
||||
raylet_socket,
|
||||
worker.worker_id,
|
||||
is_worker,
|
||||
worker.current_task_id,
|
||||
)
|
||||
|
||||
# Start the import thread
|
||||
import_thread.ImportThread(worker, mode).start()
|
||||
@@ -2254,9 +2270,11 @@ def put(value, worker=global_worker):
|
||||
# In LOCAL_MODE, ray.put is the identity operation.
|
||||
return value
|
||||
object_id = worker.raylet_client.compute_put_id(
|
||||
worker.current_task_id, worker.put_index)
|
||||
worker.current_task_id,
|
||||
worker.task_context.put_index,
|
||||
)
|
||||
worker.put_object(object_id, value)
|
||||
worker.put_index += 1
|
||||
worker.task_context.put_index += 1
|
||||
return object_id
|
||||
|
||||
|
||||
@@ -2342,15 +2360,15 @@ def wait(object_ids, num_returns=1, timeout=None, worker=global_worker):
|
||||
raise Exception("num_returns cannot be greater than the number "
|
||||
"of objects provided to ray.wait.")
|
||||
|
||||
# Get the task ID, to notify the backend which task is blocked.
|
||||
with worker.state_lock:
|
||||
current_task_id = worker.get_current_thread_task_id()
|
||||
|
||||
timeout = timeout if timeout is not None else 10**6
|
||||
timeout_milliseconds = int(timeout * 1000)
|
||||
ready_ids, remaining_ids = worker.raylet_client.wait(
|
||||
object_ids, num_returns, timeout_milliseconds, False,
|
||||
current_task_id)
|
||||
object_ids,
|
||||
num_returns,
|
||||
timeout_milliseconds,
|
||||
False,
|
||||
worker.current_task_id,
|
||||
)
|
||||
return ready_ids, remaining_ids
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user