mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 03:34:48 +08:00
Separate thread locks for worker and function manager. (#4499)
* Separate lock for function manager and worker * Lint * Add test case * Remove print in remote function. * Remove test and add ray.exit_actor * Update python/ray/worker.py Co-Authored-By: guoyuhong <guoyuhong1985@outlook.com> * Move exit_actor from worker.py to actor.py * Update actor.py * Update actor.py
This commit is contained in:
+25
-6
@@ -708,12 +708,7 @@ def make_actor(cls, num_cpus, num_gpus, resources, max_reconstructions):
|
||||
def __ray_terminate__(self):
|
||||
worker = ray.worker.get_global_worker()
|
||||
if worker.mode != ray.LOCAL_MODE:
|
||||
# Disconnect the worker from the raylet. The point of
|
||||
# this is so that when the worker kills itself below, the
|
||||
# raylet won't push an error message to the driver.
|
||||
worker.raylet_client.disconnect()
|
||||
sys.exit(0)
|
||||
assert False, "This process should have terminated."
|
||||
ray.actor.exit_actor()
|
||||
|
||||
def __ray_checkpoint__(self):
|
||||
"""Save a checkpoint.
|
||||
@@ -738,6 +733,30 @@ def make_actor(cls, num_cpus, num_gpus, resources, max_reconstructions):
|
||||
resources)
|
||||
|
||||
|
||||
def exit_actor():
|
||||
"""Intentionally exit the current actor.
|
||||
|
||||
This function is used to disconnect an actor and exit the worker.
|
||||
|
||||
Raises:
|
||||
Exception: An exception is raised if this is a driver or this
|
||||
worker is not an actor.
|
||||
"""
|
||||
worker = ray.worker.global_worker
|
||||
if worker.mode == ray.WORKER_MODE and not worker.actor_id.is_nil():
|
||||
# Disconnect the worker from the raylet. The point of
|
||||
# this is so that when the worker kills itself below, the
|
||||
# raylet won't push an error message to the driver.
|
||||
worker.raylet_client.disconnect()
|
||||
ray.disconnect()
|
||||
# Disconnect global state from GCS.
|
||||
ray.global_state.disconnect()
|
||||
sys.exit(0)
|
||||
assert False, "This process should have terminated."
|
||||
else:
|
||||
raise Exception("exit_actor called on a non-actor worker.")
|
||||
|
||||
|
||||
ray.worker.global_worker.make_actor = make_actor
|
||||
|
||||
CheckpointContext = namedtuple(
|
||||
|
||||
@@ -9,6 +9,7 @@ import json
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
import threading
|
||||
import traceback
|
||||
from collections import (
|
||||
namedtuple,
|
||||
@@ -300,6 +301,7 @@ class FunctionActorManager(object):
|
||||
# these types.
|
||||
self.imported_actor_classes = set()
|
||||
self._loaded_actor_classes = {}
|
||||
self.lock = threading.Lock()
|
||||
|
||||
def increase_task_counter(self, driver_id, function_descriptor):
|
||||
function_id = function_descriptor.function_id
|
||||
@@ -407,41 +409,48 @@ class FunctionActorManager(object):
|
||||
def f():
|
||||
raise Exception("This function was not imported properly.")
|
||||
|
||||
self._function_execution_info[driver_id][function_id] = (
|
||||
FunctionExecutionInfo(
|
||||
function=f, function_name=function_name, max_calls=max_calls))
|
||||
self._num_task_executions[driver_id][function_id] = 0
|
||||
|
||||
try:
|
||||
function = pickle.loads(serialized_function)
|
||||
except Exception:
|
||||
# If an exception was thrown when the remote function was imported,
|
||||
# we record the traceback and notify the scheduler of the failure.
|
||||
traceback_str = format_error_message(traceback.format_exc())
|
||||
# Log the error message.
|
||||
push_error_to_driver(
|
||||
self._worker,
|
||||
ray_constants.REGISTER_REMOTE_FUNCTION_PUSH_ERROR,
|
||||
"Failed to unpickle the remote function '{}' with function ID "
|
||||
"{}. Traceback:\n{}".format(function_name, function_id.hex(),
|
||||
traceback_str),
|
||||
driver_id=driver_id)
|
||||
else:
|
||||
# The below line is necessary. Because in the driver process,
|
||||
# if the function is defined in the file where the python script
|
||||
# was started from, its module is `__main__`.
|
||||
# However in the worker process, the `__main__` module is a
|
||||
# different module, which is `default_worker.py`
|
||||
function.__module__ = module
|
||||
# This function is called by ImportThread. This operation needs to be
|
||||
# atomic. Otherwise, there is race condition. Another thread may use
|
||||
# the temporary function above before the real function is ready.
|
||||
with self.lock:
|
||||
self._function_execution_info[driver_id][function_id] = (
|
||||
FunctionExecutionInfo(
|
||||
function=function,
|
||||
function=f,
|
||||
function_name=function_name,
|
||||
max_calls=max_calls))
|
||||
# Add the function to the function table.
|
||||
self._worker.redis_client.rpush(
|
||||
b"FunctionTable:" + function_id.binary(),
|
||||
self._worker.worker_id)
|
||||
self._num_task_executions[driver_id][function_id] = 0
|
||||
|
||||
try:
|
||||
function = pickle.loads(serialized_function)
|
||||
except Exception:
|
||||
# If an exception was thrown when the remote function was
|
||||
# imported, we record the traceback and notify the scheduler
|
||||
# of the failure.
|
||||
traceback_str = format_error_message(traceback.format_exc())
|
||||
# Log the error message.
|
||||
push_error_to_driver(
|
||||
self._worker,
|
||||
ray_constants.REGISTER_REMOTE_FUNCTION_PUSH_ERROR,
|
||||
"Failed to unpickle the remote function '{}' with "
|
||||
"function ID {}. Traceback:\n{}".format(
|
||||
function_name, function_id.hex(), traceback_str),
|
||||
driver_id=driver_id)
|
||||
else:
|
||||
# The below line is necessary. Because in the driver process,
|
||||
# if the function is defined in the file where the python
|
||||
# script was started from, its module is `__main__`.
|
||||
# However in the worker process, the `__main__` module is a
|
||||
# different module, which is `default_worker.py`
|
||||
function.__module__ = module
|
||||
self._function_execution_info[driver_id][function_id] = (
|
||||
FunctionExecutionInfo(
|
||||
function=function,
|
||||
function_name=function_name,
|
||||
max_calls=max_calls))
|
||||
# Add the function to the function table.
|
||||
self._worker.redis_client.rpush(
|
||||
b"FunctionTable:" + function_id.binary(),
|
||||
self._worker.worker_id)
|
||||
|
||||
def get_execution_info(self, driver_id, function_descriptor):
|
||||
"""Get the FunctionExecutionInfo of a remote function.
|
||||
@@ -526,7 +535,7 @@ class FunctionActorManager(object):
|
||||
# Only send the warning once.
|
||||
warning_sent = False
|
||||
while True:
|
||||
with self._worker.lock:
|
||||
with self.lock:
|
||||
if (self._worker.actor_id.is_nil()
|
||||
and (function_descriptor.function_id in
|
||||
self._function_execution_info[driver_id])):
|
||||
@@ -534,18 +543,18 @@ class FunctionActorManager(object):
|
||||
elif not self._worker.actor_id.is_nil() and (
|
||||
self._worker.actor_id in self._worker.actors):
|
||||
break
|
||||
if time.time() - start_time > timeout:
|
||||
warning_message = ("This worker was asked to execute a "
|
||||
"function that it does not have "
|
||||
"registered. You may have to restart "
|
||||
"Ray.")
|
||||
if not warning_sent:
|
||||
ray.utils.push_error_to_driver(
|
||||
self._worker,
|
||||
ray_constants.WAIT_FOR_FUNCTION_PUSH_ERROR,
|
||||
warning_message,
|
||||
driver_id=driver_id)
|
||||
warning_sent = True
|
||||
if time.time() - start_time > timeout:
|
||||
warning_message = ("This worker was asked to execute a "
|
||||
"function that it does not have "
|
||||
"registered. You may have to restart "
|
||||
"Ray.")
|
||||
if not warning_sent:
|
||||
ray.utils.push_error_to_driver(
|
||||
self._worker,
|
||||
ray_constants.WAIT_FOR_FUNCTION_PUSH_ERROR,
|
||||
warning_message,
|
||||
driver_id=driver_id)
|
||||
warning_sent = True
|
||||
time.sleep(0.001)
|
||||
|
||||
def _publish_actor_class_to_key(self, key, actor_class_info):
|
||||
@@ -716,7 +725,7 @@ class FunctionActorManager(object):
|
||||
|
||||
actor_class = None
|
||||
try:
|
||||
with self._worker.lock:
|
||||
with self.lock:
|
||||
actor_class = pickle.loads(pickled_class)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
|
||||
+13
-15
@@ -56,11 +56,10 @@ class ImportThread(object):
|
||||
|
||||
try:
|
||||
# Get the exports that occurred before the call to subscribe.
|
||||
with self.worker.lock:
|
||||
export_keys = self.redis_client.lrange("Exports", 0, -1)
|
||||
for key in export_keys:
|
||||
num_imported += 1
|
||||
self._process_key(key)
|
||||
export_keys = self.redis_client.lrange("Exports", 0, -1)
|
||||
for key in export_keys:
|
||||
num_imported += 1
|
||||
self._process_key(key)
|
||||
|
||||
while True:
|
||||
# Exit if we received a signal that we should stop.
|
||||
@@ -72,16 +71,15 @@ class ImportThread(object):
|
||||
self.threads_stopped.wait(timeout=0.01)
|
||||
continue
|
||||
|
||||
with self.worker.lock:
|
||||
if msg["type"] == "subscribe":
|
||||
continue
|
||||
assert msg["data"] == b"rpush"
|
||||
num_imports = self.redis_client.llen("Exports")
|
||||
assert num_imports >= num_imported
|
||||
for i in range(num_imported, num_imports):
|
||||
num_imported += 1
|
||||
key = self.redis_client.lindex("Exports", i)
|
||||
self._process_key(key)
|
||||
if msg["type"] == "subscribe":
|
||||
continue
|
||||
assert msg["data"] == b"rpush"
|
||||
num_imports = self.redis_client.llen("Exports")
|
||||
assert num_imports >= num_imported
|
||||
for i in range(num_imported, num_imports):
|
||||
num_imported += 1
|
||||
key = self.redis_client.lindex("Exports", i)
|
||||
self._process_key(key)
|
||||
finally:
|
||||
# Close the pubsub client to avoid leaking file descriptors.
|
||||
import_pubsub_client.close()
|
||||
|
||||
+39
-47
@@ -234,9 +234,14 @@ class Worker(object):
|
||||
Returns:
|
||||
The serialization context of the given driver.
|
||||
"""
|
||||
if driver_id not in self.serialization_context_map:
|
||||
_initialize_serialization(driver_id)
|
||||
return self.serialization_context_map[driver_id]
|
||||
# This function needs to be proctected by a lock, because it will be
|
||||
# called by`register_class_for_serialization`, as well as the import
|
||||
# thread, from different threads. Also, this function will recursively
|
||||
# call itself, so we use RLock here.
|
||||
with self.lock:
|
||||
if driver_id not in self.serialization_context_map:
|
||||
_initialize_serialization(driver_id)
|
||||
return self.serialization_context_map[driver_id]
|
||||
|
||||
def check_connected(self):
|
||||
"""Check if the worker is connected.
|
||||
@@ -428,11 +433,7 @@ class Worker(object):
|
||||
# Wait a little bit for the import thread to import the class.
|
||||
# If we currently have the worker lock, we need to release it
|
||||
# so that the import thread can acquire it.
|
||||
if self.mode == WORKER_MODE:
|
||||
self.lock.release()
|
||||
time.sleep(0.01)
|
||||
if self.mode == WORKER_MODE:
|
||||
self.lock.acquire()
|
||||
|
||||
if time.time() - start_time > error_timeout:
|
||||
warning_message = ("This worker or driver is waiting to "
|
||||
@@ -968,45 +969,37 @@ class Worker(object):
|
||||
driver_id, function_descriptor)
|
||||
|
||||
# Execute the task.
|
||||
# TODO(rkn): Consider acquiring this lock with a timeout and pushing a
|
||||
# warning to the user if we are waiting too long to acquire the lock
|
||||
# because that may indicate that the system is hanging, and it'd be
|
||||
# good to know where the system is hanging.
|
||||
with self.lock:
|
||||
function_name = execution_info.function_name
|
||||
extra_data = {
|
||||
"name": function_name,
|
||||
"task_id": task.task_id().hex()
|
||||
}
|
||||
if task.actor_id().is_nil():
|
||||
if task.actor_creation_id().is_nil():
|
||||
title = "ray_worker:{}()".format(function_name)
|
||||
next_title = "ray_worker"
|
||||
else:
|
||||
actor = self.actors[task.actor_creation_id()]
|
||||
title = "ray_{}:{}()".format(actor.__class__.__name__,
|
||||
function_name)
|
||||
next_title = "ray_{}".format(actor.__class__.__name__)
|
||||
function_name = execution_info.function_name
|
||||
extra_data = {"name": function_name, "task_id": task.task_id().hex()}
|
||||
if task.actor_id().is_nil():
|
||||
if task.actor_creation_id().is_nil():
|
||||
title = "ray_worker:{}()".format(function_name)
|
||||
next_title = "ray_worker"
|
||||
else:
|
||||
actor = self.actors[task.actor_id()]
|
||||
actor = self.actors[task.actor_creation_id()]
|
||||
title = "ray_{}:{}()".format(actor.__class__.__name__,
|
||||
function_name)
|
||||
next_title = "ray_{}".format(actor.__class__.__name__)
|
||||
with profiling.profile("task", extra_data=extra_data):
|
||||
with _changeproctitle(title, next_title):
|
||||
self._process_task(task, execution_info)
|
||||
# Reset the state fields so the next task can run.
|
||||
self.task_context.current_task_id = TaskID.nil()
|
||||
self.task_context.task_index = 0
|
||||
self.task_context.put_index = 1
|
||||
if self.actor_id.is_nil():
|
||||
# Don't need to reset task_driver_id if the worker is an
|
||||
# actor. Because the following tasks should all have the
|
||||
# same driver id.
|
||||
self.task_driver_id = DriverID.nil()
|
||||
# Reset signal counters so that the next task can get
|
||||
# all past signals.
|
||||
ray_signal.reset()
|
||||
else:
|
||||
actor = self.actors[task.actor_id()]
|
||||
title = "ray_{}:{}()".format(actor.__class__.__name__,
|
||||
function_name)
|
||||
next_title = "ray_{}".format(actor.__class__.__name__)
|
||||
with profiling.profile("task", extra_data=extra_data):
|
||||
with _changeproctitle(title, next_title):
|
||||
self._process_task(task, execution_info)
|
||||
# Reset the state fields so the next task can run.
|
||||
self.task_context.current_task_id = TaskID.nil()
|
||||
self.task_context.task_index = 0
|
||||
self.task_context.put_index = 1
|
||||
if self.actor_id.is_nil():
|
||||
# Don't need to reset task_driver_id if the worker is an
|
||||
# actor. Because the following tasks should all have the
|
||||
# same driver id.
|
||||
self.task_driver_id = DriverID.nil()
|
||||
# Reset signal counters so that the next task can get
|
||||
# all past signals.
|
||||
ray_signal.reset()
|
||||
|
||||
# Increase the task execution counter.
|
||||
self.function_actor_manager.increase_task_counter(
|
||||
@@ -1645,10 +1638,9 @@ def listen_error_messages_raylet(worker, task_error_queue, threads_stopped):
|
||||
|
||||
try:
|
||||
# Get the exports that occurred before the call to subscribe.
|
||||
with worker.lock:
|
||||
error_messages = global_state.error_messages(worker.task_driver_id)
|
||||
for error_message in error_messages:
|
||||
logger.error(error_message)
|
||||
error_messages = global_state.error_messages(worker.task_driver_id)
|
||||
for error_message in error_messages:
|
||||
logger.error(error_message)
|
||||
|
||||
while True:
|
||||
# Exit if we received a signal that we should stop.
|
||||
@@ -1774,7 +1766,7 @@ def connect(node,
|
||||
traceback_str,
|
||||
driver_id=None)
|
||||
|
||||
worker.lock = threading.Lock()
|
||||
worker.lock = threading.RLock()
|
||||
|
||||
# Create an object for interfacing with the global state.
|
||||
global_state._initialize_global_state(
|
||||
|
||||
Reference in New Issue
Block a user