Refine multi-threading support (#3672)

* [Python] refine multi-threading support

fix

* [java] refine multithreading code

fix java

* format
This commit is contained in:
Hao Chen
2019-01-11 05:58:11 +08:00
committed by Stephanie Wang
parent 71243203a4
commit 597abb24ea
9 changed files with 394 additions and 313 deletions
+28 -35
View File
@@ -7,6 +7,7 @@ import hashlib
import inspect
import logging
import sys
import threading
import traceback
import ray.cloudpickle as pickle
@@ -225,8 +226,7 @@ class ActorMethod(object):
self._method_name,
args=args,
kwargs=kwargs,
num_return_vals=num_return_vals,
dependency=self._actor._ray_actor_cursor)
num_return_vals=num_return_vals)
class ActorClass(object):
@@ -525,13 +525,13 @@ class ActorHandle(object):
self._ray_actor_method_cpus = actor_method_cpus
self._ray_actor_driver_id = actor_driver_id
self._ray_new_actor_handles = []
self._ray_actor_lock = threading.Lock()
def _actor_method_call(self,
method_name,
args=None,
kwargs=None,
num_return_vals=None,
dependency=None):
num_return_vals=None):
"""Method execution stub for an actor handle.
This is the function that executes when
@@ -570,41 +570,34 @@ class ActorHandle(object):
return getattr(worker.actors[self._ray_actor_id],
method_name)(*copy.deepcopy(args))
# Add the execution dependency.
if dependency is None:
execution_dependencies = []
else:
execution_dependencies = [dependency]
is_actor_checkpoint_method = (method_name == "__ray_checkpoint__")
function_descriptor = FunctionDescriptor(
self._ray_module_name, method_name, self._ray_class_name)
object_ids = worker.submit_task(
function_descriptor,
args,
actor_id=self._ray_actor_id,
actor_handle_id=self._ray_actor_handle_id,
actor_counter=self._ray_actor_counter,
is_actor_checkpoint_method=is_actor_checkpoint_method,
actor_creation_dummy_object_id=(
self._ray_actor_creation_dummy_object_id),
execution_dependencies=execution_dependencies,
new_actor_handles=self._ray_new_actor_handles,
# We add one for the dummy return ID.
num_return_vals=num_return_vals + 1,
resources={"CPU": self._ray_actor_method_cpus},
placement_resources={},
driver_id=self._ray_actor_driver_id)
# Update the actor counter and cursor to reflect the most recent
# invocation.
self._ray_actor_counter += 1
# The last object returned is the dummy object that should be
# passed in to the next actor method. Do not return it to the user.
self._ray_actor_cursor = object_ids.pop()
# We have notified the backend of the new actor handles to expect since
# the last task was submitted, so clear the list.
self._ray_new_actor_handles = []
with self._ray_actor_lock:
object_ids = worker.submit_task(
function_descriptor,
args,
actor_id=self._ray_actor_id,
actor_handle_id=self._ray_actor_handle_id,
actor_counter=self._ray_actor_counter,
is_actor_checkpoint_method=is_actor_checkpoint_method,
actor_creation_dummy_object_id=(
self._ray_actor_creation_dummy_object_id),
execution_dependencies=[self._ray_actor_cursor],
new_actor_handles=self._ray_new_actor_handles,
# We add one for the dummy return ID.
num_return_vals=num_return_vals + 1,
resources={"CPU": self._ray_actor_method_cpus},
placement_resources={},
driver_id=self._ray_actor_driver_id,
)
# Update the actor counter and cursor to reflect the most recent
# invocation.
self._ray_actor_counter += 1
# The last object returned is the dummy object that should be
# passed in to the next actor method. Do not return it to the user.
self._ray_actor_cursor = object_ids.pop()
if len(object_ids) == 1:
object_ids = object_ids[0]
+157 -139
View File
@@ -143,13 +143,6 @@ class Worker(object):
cached_functions_to_run (List): A list of functions to run on all of
the workers that should be exported as soon as connect is called.
profiler: the profiler used to aggregate profiling information.
state_lock (Lock):
Used to lock worker's non-thread-safe internal states:
1) task_index increment: make sure we generate unique task ids;
2) Object reconstruction: because the node manager will
recycle/return the worker's resources before/after reconstruction,
it's unsafe for multiple threads to call object
reconstruction simultaneously.
"""
def __init__(self):
@@ -169,42 +162,56 @@ class Worker(object):
self.original_gpu_ids = ray.utils.get_cuda_visible_devices()
self.profiler = None
self.memory_monitor = memory_monitor.MemoryMonitor()
self.state_lock = threading.Lock()
# A dictionary that maps from driver id to SerializationContext
# TODO: clean up the SerializationContext once the job has finished.
self.serialization_context_map = {}
self.function_actor_manager = FunctionActorManager(self)
# Reads/writes to the following fields must be protected by
# self.state_lock.
# Identity of the driver that this worker is processing.
self.task_driver_id = ray.ObjectID(NIL_ID)
self.current_task_id = ray.ObjectID(NIL_ID)
self.task_index = 0
self.put_index = 1
self._task_context = threading.local()
def get_current_thread_task_id(self):
"""Get the current thread's task ID.
@property
def task_context(self):
"""A thread-local that contains the following attributes.
This returns the assigned task ID if called on the main thread, else a
random task ID. This method is not thread-safe and must be called with
self.state_lock acquired.
current_task_id: For the main thread, this field is the ID of this
worker's current running task; for other threads, this field is a
fake random ID.
task_index: The number of tasks that have been submitted from the
current task.
put_index: The number of objects that have been put from the current
task.
"""
current_task_id = self.current_task_id
if not ray.utils.is_main_thread():
# If this is running on a separate thread, then the mapping
# to the current task ID may not be correct. Generate a
# random task ID so that the backend can differentiate
# between different threads.
current_task_id = ray.ObjectID(random_string())
if not self.multithreading_warned:
logger.warning(
"Calling ray.get or ray.wait in a separate thread "
"may lead to deadlock if the main thread blocks on this "
"thread and there are not enough resources to execute "
"more tasks")
self.multithreading_warned = True
assert not current_task_id.is_nil()
return current_task_id
if not hasattr(self._task_context, 'initialized'):
# Initialize task_context for the current thread.
if ray.utils.is_main_thread():
# If this is running on the main thread, initialize it to
# NIL. The actual value will be set when the worker receives
# a task from the raylet backend.
self._task_context.current_task_id = ray.ObjectID(NIL_ID)
else:
# If this is running on a separate thread, then the mapping
# to the current task ID may not be correct. Generate a
# random task ID so that the backend can differentiate
# between different threads.
self._task_context.current_task_id = ray.ObjectID(
random_string())
if getattr(self, '_multithreading_warned', False) is not True:
logger.warning(
"Calling ray.get or ray.wait in a separate thread "
"may lead to deadlock if the main thread blocks on "
"this thread and there are not enough resources to "
"execute more tasks")
self._multithreading_warned = True
self._task_context.task_index = 0
self._task_context.put_index = 1
self._task_context.initialized = True
return self._task_context
@property
def current_task_id(self):
return self.task_context.current_task_id
def mark_actor_init_failed(self, error):
"""Called to mark this actor as failed during initialization."""
@@ -467,48 +474,45 @@ class Worker(object):
}
if len(unready_ids) > 0:
with self.state_lock:
# Get the task ID, to notify the backend which task is blocked.
current_task_id = self.get_current_thread_task_id()
# Try reconstructing any objects we haven't gotten yet. Try to
# get them until at least get_timeout_milliseconds
# milliseconds passes, then repeat.
while len(unready_ids) > 0:
object_ids_to_fetch = [
plasma.ObjectID(unready_id)
for unready_id in unready_ids.keys()
]
ray_object_ids_to_fetch = [
ray.ObjectID(unready_id)
for unready_id in unready_ids.keys()
]
fetch_request_size = ray._config.worker_fetch_request_size()
for i in range(0, len(object_ids_to_fetch),
fetch_request_size):
self.raylet_client.fetch_or_reconstruct(
ray_object_ids_to_fetch[i:(i + fetch_request_size)],
False,
self.current_task_id,
)
results = self.retrieve_and_deserialize(
object_ids_to_fetch,
max([
ray._config.get_timeout_milliseconds(),
int(0.01 * len(unready_ids)),
]),
)
# Remove any entries for objects we received during this
# iteration so we don't retrieve the same object twice.
for i, val in enumerate(results):
if val is not plasma.ObjectNotAvailable:
object_id = object_ids_to_fetch[i].binary()
index = unready_ids[object_id]
final_results[index] = val
unready_ids.pop(object_id)
# Try reconstructing any objects we haven't gotten yet. Try to
# get them until at least get_timeout_milliseconds
# milliseconds passes, then repeat.
while len(unready_ids) > 0:
object_ids_to_fetch = [
plasma.ObjectID(unready_id)
for unready_id in unready_ids.keys()
]
ray_object_ids_to_fetch = [
ray.ObjectID(unready_id)
for unready_id in unready_ids.keys()
]
fetch_request_size = (
ray._config.worker_fetch_request_size())
for i in range(0, len(object_ids_to_fetch),
fetch_request_size):
self.raylet_client.fetch_or_reconstruct(
ray_object_ids_to_fetch[i:(
i + fetch_request_size)], False,
current_task_id)
results = self.retrieve_and_deserialize(
object_ids_to_fetch,
max([
ray._config.get_timeout_milliseconds(),
int(0.01 * len(unready_ids))
]))
# Remove any entries for objects we received during this
# iteration so we don't retrieve the same object twice.
for i, val in enumerate(results):
if val is not plasma.ObjectNotAvailable:
object_id = object_ids_to_fetch[i].binary()
index = unready_ids[object_id]
final_results[index] = val
unready_ids.pop(object_id)
# If there were objects that we weren't able to get locally,
# let the local scheduler know that we're now unblocked.
self.raylet_client.notify_unblocked(current_task_id)
# If there were objects that we weren't able to get locally,
# let the local scheduler know that we're now unblocked.
self.raylet_client.notify_unblocked(self.current_task_id)
assert len(final_results) == len(object_ids)
return final_results
@@ -616,24 +620,32 @@ class Worker(object):
if placement_resources is None:
placement_resources = {}
with self.state_lock:
# Increment the worker's task index to track how many tasks
# have been submitted by the current task so far.
task_index = self.task_index
self.task_index += 1
# The parent task must be set for the submitted task.
if self.actor_id == NIL_ACTOR_ID:
assert not self.current_task_id.is_nil()
# Increment the worker's task index to track how many tasks
# have been submitted by the current task so far.
self.task_context.task_index += 1
# The parent task must be set for the submitted task.
assert not self.current_task_id.is_nil()
# Submit the task to local scheduler.
function_descriptor_list = (
function_descriptor.get_function_descriptor_list())
task = ray.raylet.Task(
driver_id, function_descriptor_list, args_for_local_scheduler,
num_return_vals, self.current_task_id, task_index,
actor_creation_id, actor_creation_dummy_object_id,
max_actor_reconstructions, actor_id, actor_handle_id,
actor_counter, new_actor_handles, execution_dependencies,
resources, placement_resources)
driver_id,
function_descriptor_list,
args_for_local_scheduler,
num_return_vals,
self.current_task_id,
self.task_context.task_index,
actor_creation_id,
actor_creation_dummy_object_id,
max_actor_reconstructions,
actor_id,
actor_handle_id,
actor_counter,
new_actor_handles,
execution_dependencies,
resources,
placement_resources,
)
self.raylet_client.submit_task(task)
return task.returns()
@@ -770,24 +782,23 @@ class Worker(object):
(these will be retrieved by calls to get or by subsequent tasks that
use the outputs of this task).
"""
with self.state_lock:
assert self.current_task_id.is_nil()
assert self.task_index == 0
assert self.put_index == 1
if task.actor_id().is_nil():
# If this worker is not an actor, check that `task_driver_id`
# was reset when the worker finished the previous task.
assert self.task_driver_id.is_nil()
# Set the driver ID of the current running task. This is
# needed so that if the task throws an exception, we propagate
# the error message to the correct driver.
self.task_driver_id = task.driver_id()
else:
# If this worker is an actor, task_driver_id wasn't reset.
# Check that current task's driver ID equals the previous one.
assert self.task_driver_id == task.driver_id()
assert self.current_task_id.is_nil()
assert self.task_context.task_index == 0
assert self.task_context.put_index == 1
if task.actor_id().is_nil():
# If this worker is not an actor, check that `task_driver_id`
# was reset when the worker finished the previous task.
assert self.task_driver_id.is_nil()
# Set the driver ID of the current running task. This is
# needed so that if the task throws an exception, we propagate
# the error message to the correct driver.
self.task_driver_id = task.driver_id()
else:
# If this worker is an actor, task_driver_id wasn't reset.
# Check that current task's driver ID equals the previous one.
assert self.task_driver_id == task.driver_id()
self.current_task_id = task.task_id()
self.task_context.current_task_id = task.task_id()
function_descriptor = FunctionDescriptor.from_bytes_list(
task.function_descriptor_list())
@@ -931,13 +942,14 @@ class Worker(object):
with _changeproctitle(title, next_title):
self._process_task(task, execution_info)
# Reset the state fields so the next task can run.
with self.state_lock:
if self.actor_id == NIL_ACTOR_ID:
# We will keep task_driver_id unchanged for actor.
self.task_driver_id = ray.ObjectID(NIL_ID)
self.current_task_id = ray.ObjectID(NIL_ID)
self.task_index = 0
self.put_index = 1
self.task_context.current_task_id = ray.ObjectID(NIL_ID)
self.task_context.task_index = 0
self.task_context.put_index = 1
if self.actor_id == NIL_ACTOR_ID:
# There is no need to reset task_driver_id if the worker is an
# actor, because all of the following tasks should have the
# same driver id.
self.task_driver_id = ray.ObjectID(NIL_ID)
# Increase the task execution counter.
self.function_actor_manager.increase_task_counter(
@@ -1925,13 +1937,8 @@ def connect(ray_params,
else:
# Try to use true randomness.
np.random.seed(None)
worker.current_task_id = ray.ObjectID(
np.random.bytes(ray_constants.ID_SIZE))
# Reset the state of the numpy random number generator.
np.random.set_state(numpy_state)
# Set other fields needed for computing task IDs.
worker.task_index = 0
worker.put_index = 1
# Create an entry for the driver task in the task table. This task is
# added immediately with status RUNNING. This allows us to push errors
@@ -1944,11 +1951,22 @@ def connect(ray_params,
function_descriptor = FunctionDescriptor.for_driver_task()
driver_task = ray.raylet.Task(
worker.task_driver_id,
function_descriptor.get_function_descriptor_list(), [], 0,
worker.current_task_id, worker.task_index,
ray.ObjectID(NIL_ACTOR_ID), ray.ObjectID(NIL_ACTOR_ID), 0,
ray.ObjectID(NIL_ACTOR_ID), ray.ObjectID(NIL_ACTOR_ID),
nil_actor_counter, [], [], {"CPU": 0}, {})
function_descriptor.get_function_descriptor_list(),
[], # arguments.
0, # num_returns.
ray.ObjectID(random_string()), # parent_task_id.
0, # parent_counter.
ray.ObjectID(NIL_ACTOR_ID), # actor_creation_id.
ray.ObjectID(NIL_ACTOR_ID), # actor_creation_dummy_object_id.
0, # max_actor_reconstructions.
ray.ObjectID(NIL_ACTOR_ID), # actor_id.
ray.ObjectID(NIL_ACTOR_ID), # actor_handle_id.
nil_actor_counter, # actor_counter.
[], # new_actor_handles.
[], # execution_dependencies.
{"CPU": 0}, # resource_map.
{}, # placement_resource_map.
)
# Add the driver task to the task table.
global_state._execute_command(driver_task.task_id(), "RAY.TABLE_ADD",
@@ -1959,16 +1977,14 @@ def connect(ray_params,
# Set the driver's current task ID to the task ID assigned to the
# driver task.
worker.current_task_id = driver_task.task_id()
else:
# A non-driver worker begins without an assigned task.
worker.current_task_id = ray.ObjectID(NIL_ID)
# A flag for making sure that we only print one warning message about
# multithreading per worker.
worker.multithreading_warned = False
worker.task_context.current_task_id = driver_task.task_id()
worker.raylet_client = ray.raylet.RayletClient(
raylet_socket, worker.worker_id, is_worker, worker.current_task_id)
raylet_socket,
worker.worker_id,
is_worker,
worker.current_task_id,
)
# Start the import thread
import_thread.ImportThread(worker, mode).start()
@@ -2254,9 +2270,11 @@ def put(value, worker=global_worker):
# In LOCAL_MODE, ray.put is the identity operation.
return value
object_id = worker.raylet_client.compute_put_id(
worker.current_task_id, worker.put_index)
worker.current_task_id,
worker.task_context.put_index,
)
worker.put_object(object_id, value)
worker.put_index += 1
worker.task_context.put_index += 1
return object_id
@@ -2342,15 +2360,15 @@ def wait(object_ids, num_returns=1, timeout=None, worker=global_worker):
raise Exception("num_returns cannot be greater than the number "
"of objects provided to ray.wait.")
# Get the task ID, to notify the backend which task is blocked.
with worker.state_lock:
current_task_id = worker.get_current_thread_task_id()
timeout = timeout if timeout is not None else 10**6
timeout_milliseconds = int(timeout * 1000)
ready_ids, remaining_ids = worker.raylet_client.wait(
object_ids, num_returns, timeout_milliseconds, False,
current_task_id)
object_ids,
num_returns,
timeout_milliseconds,
False,
worker.current_task_id,
)
return ready_ids, remaining_ids