Separate thread locks for worker and function manager. (#4499)

* Separate lock for function manager and worker

* Lint

* Add test case

* Remove print in remote function.

* Remove test and add ray.exit_actor

* Update python/ray/worker.py

Co-Authored-By: guoyuhong <guoyuhong1985@outlook.com>

* Move exit_actor from worker.py to actor.py

* Update actor.py

* Update actor.py
This commit is contained in:
Yuhong Guo
2019-04-29 14:55:37 +08:00
committed by GitHub
parent c578be23a5
commit 4eade036a0
4 changed files with 131 additions and 113 deletions
+25 -6
View File
@@ -708,12 +708,7 @@ def make_actor(cls, num_cpus, num_gpus, resources, max_reconstructions):
def __ray_terminate__(self):
worker = ray.worker.get_global_worker()
if worker.mode != ray.LOCAL_MODE:
# Disconnect the worker from the raylet. The point of
# this is so that when the worker kills itself below, the
# raylet won't push an error message to the driver.
worker.raylet_client.disconnect()
sys.exit(0)
assert False, "This process should have terminated."
ray.actor.exit_actor()
def __ray_checkpoint__(self):
"""Save a checkpoint.
@@ -738,6 +733,30 @@ def make_actor(cls, num_cpus, num_gpus, resources, max_reconstructions):
resources)
def exit_actor():
    """Disconnect the current actor from the cluster and exit its process.

    The global worker is disconnected from the raylet, Ray is disconnected,
    the global state is disconnected from GCS, and the process then exits
    with status 0.

    Raises:
        Exception: If the global worker is not in worker mode or has a nil
            actor ID (i.e. this is a driver or a non-actor worker).
    """
    current = ray.worker.global_worker
    is_actor_worker = (current.mode == ray.WORKER_MODE
                       and not current.actor_id.is_nil())
    if not is_actor_worker:
        raise Exception("exit_actor called on a non-actor worker.")

    # Disconnect from the raylet before exiting so that, when this process
    # kills itself below, the raylet won't push an error message to the
    # driver.
    current.raylet_client.disconnect()
    ray.disconnect()
    # Drop the global state's connection to GCS as well.
    ray.global_state.disconnect()
    sys.exit(0)
    # sys.exit raises SystemExit, so control should never reach this line.
    assert False, "This process should have terminated."
ray.worker.global_worker.make_actor = make_actor
CheckpointContext = namedtuple(
+54 -45
View File
@@ -9,6 +9,7 @@ import json
import logging
import sys
import time
import threading
import traceback
from collections import (
namedtuple,
@@ -300,6 +301,7 @@ class FunctionActorManager(object):
# these types.
self.imported_actor_classes = set()
self._loaded_actor_classes = {}
self.lock = threading.Lock()
def increase_task_counter(self, driver_id, function_descriptor):
function_id = function_descriptor.function_id
@@ -407,41 +409,48 @@ class FunctionActorManager(object):
def f():
raise Exception("This function was not imported properly.")
self._function_execution_info[driver_id][function_id] = (
FunctionExecutionInfo(
function=f, function_name=function_name, max_calls=max_calls))
self._num_task_executions[driver_id][function_id] = 0
try:
function = pickle.loads(serialized_function)
except Exception:
# If an exception was thrown when the remote function was imported,
# we record the traceback and notify the scheduler of the failure.
traceback_str = format_error_message(traceback.format_exc())
# Log the error message.
push_error_to_driver(
self._worker,
ray_constants.REGISTER_REMOTE_FUNCTION_PUSH_ERROR,
"Failed to unpickle the remote function '{}' with function ID "
"{}. Traceback:\n{}".format(function_name, function_id.hex(),
traceback_str),
driver_id=driver_id)
else:
# The below line is necessary. Because in the driver process,
# if the function is defined in the file where the python script
# was started from, its module is `__main__`.
# However in the worker process, the `__main__` module is a
# different module, which is `default_worker.py`
function.__module__ = module
# This function is called by ImportThread. This operation needs to be
# atomic. Otherwise, there is a race condition: another thread may use
# the temporary function above before the real function is ready.
with self.lock:
self._function_execution_info[driver_id][function_id] = (
FunctionExecutionInfo(
function=function,
function=f,
function_name=function_name,
max_calls=max_calls))
# Add the function to the function table.
self._worker.redis_client.rpush(
b"FunctionTable:" + function_id.binary(),
self._worker.worker_id)
self._num_task_executions[driver_id][function_id] = 0
try:
function = pickle.loads(serialized_function)
except Exception:
# If an exception was thrown when the remote function was
# imported, we record the traceback and notify the scheduler
# of the failure.
traceback_str = format_error_message(traceback.format_exc())
# Log the error message.
push_error_to_driver(
self._worker,
ray_constants.REGISTER_REMOTE_FUNCTION_PUSH_ERROR,
"Failed to unpickle the remote function '{}' with "
"function ID {}. Traceback:\n{}".format(
function_name, function_id.hex(), traceback_str),
driver_id=driver_id)
else:
# The below line is necessary. Because in the driver process,
# if the function is defined in the file where the python
# script was started from, its module is `__main__`.
# However in the worker process, the `__main__` module is a
# different module, which is `default_worker.py`
function.__module__ = module
self._function_execution_info[driver_id][function_id] = (
FunctionExecutionInfo(
function=function,
function_name=function_name,
max_calls=max_calls))
# Add the function to the function table.
self._worker.redis_client.rpush(
b"FunctionTable:" + function_id.binary(),
self._worker.worker_id)
def get_execution_info(self, driver_id, function_descriptor):
"""Get the FunctionExecutionInfo of a remote function.
@@ -526,7 +535,7 @@ class FunctionActorManager(object):
# Only send the warning once.
warning_sent = False
while True:
with self._worker.lock:
with self.lock:
if (self._worker.actor_id.is_nil()
and (function_descriptor.function_id in
self._function_execution_info[driver_id])):
@@ -534,18 +543,18 @@ class FunctionActorManager(object):
elif not self._worker.actor_id.is_nil() and (
self._worker.actor_id in self._worker.actors):
break
if time.time() - start_time > timeout:
warning_message = ("This worker was asked to execute a "
"function that it does not have "
"registered. You may have to restart "
"Ray.")
if not warning_sent:
ray.utils.push_error_to_driver(
self._worker,
ray_constants.WAIT_FOR_FUNCTION_PUSH_ERROR,
warning_message,
driver_id=driver_id)
warning_sent = True
if time.time() - start_time > timeout:
warning_message = ("This worker was asked to execute a "
"function that it does not have "
"registered. You may have to restart "
"Ray.")
if not warning_sent:
ray.utils.push_error_to_driver(
self._worker,
ray_constants.WAIT_FOR_FUNCTION_PUSH_ERROR,
warning_message,
driver_id=driver_id)
warning_sent = True
time.sleep(0.001)
def _publish_actor_class_to_key(self, key, actor_class_info):
@@ -716,7 +725,7 @@ class FunctionActorManager(object):
actor_class = None
try:
with self._worker.lock:
with self.lock:
actor_class = pickle.loads(pickled_class)
except Exception:
logger.exception(
+13 -15
View File
@@ -56,11 +56,10 @@ class ImportThread(object):
try:
# Get the exports that occurred before the call to subscribe.
with self.worker.lock:
export_keys = self.redis_client.lrange("Exports", 0, -1)
for key in export_keys:
num_imported += 1
self._process_key(key)
export_keys = self.redis_client.lrange("Exports", 0, -1)
for key in export_keys:
num_imported += 1
self._process_key(key)
while True:
# Exit if we received a signal that we should stop.
@@ -72,16 +71,15 @@ class ImportThread(object):
self.threads_stopped.wait(timeout=0.01)
continue
with self.worker.lock:
if msg["type"] == "subscribe":
continue
assert msg["data"] == b"rpush"
num_imports = self.redis_client.llen("Exports")
assert num_imports >= num_imported
for i in range(num_imported, num_imports):
num_imported += 1
key = self.redis_client.lindex("Exports", i)
self._process_key(key)
if msg["type"] == "subscribe":
continue
assert msg["data"] == b"rpush"
num_imports = self.redis_client.llen("Exports")
assert num_imports >= num_imported
for i in range(num_imported, num_imports):
num_imported += 1
key = self.redis_client.lindex("Exports", i)
self._process_key(key)
finally:
# Close the pubsub client to avoid leaking file descriptors.
import_pubsub_client.close()
+39 -47
View File
@@ -234,9 +234,14 @@ class Worker(object):
Returns:
The serialization context of the given driver.
"""
if driver_id not in self.serialization_context_map:
_initialize_serialization(driver_id)
return self.serialization_context_map[driver_id]
# This function needs to be protected by a lock, because it will be
# called by `register_class_for_serialization`, as well as the import
# thread, from different threads. Also, this function will recursively
# call itself, so we use RLock here.
with self.lock:
if driver_id not in self.serialization_context_map:
_initialize_serialization(driver_id)
return self.serialization_context_map[driver_id]
def check_connected(self):
"""Check if the worker is connected.
@@ -428,11 +433,7 @@ class Worker(object):
# Wait a little bit for the import thread to import the class.
# If we currently have the worker lock, we need to release it
# so that the import thread can acquire it.
if self.mode == WORKER_MODE:
self.lock.release()
time.sleep(0.01)
if self.mode == WORKER_MODE:
self.lock.acquire()
if time.time() - start_time > error_timeout:
warning_message = ("This worker or driver is waiting to "
@@ -968,45 +969,37 @@ class Worker(object):
driver_id, function_descriptor)
# Execute the task.
# TODO(rkn): Consider acquiring this lock with a timeout and pushing a
# warning to the user if we are waiting too long to acquire the lock
# because that may indicate that the system is hanging, and it'd be
# good to know where the system is hanging.
with self.lock:
function_name = execution_info.function_name
extra_data = {
"name": function_name,
"task_id": task.task_id().hex()
}
if task.actor_id().is_nil():
if task.actor_creation_id().is_nil():
title = "ray_worker:{}()".format(function_name)
next_title = "ray_worker"
else:
actor = self.actors[task.actor_creation_id()]
title = "ray_{}:{}()".format(actor.__class__.__name__,
function_name)
next_title = "ray_{}".format(actor.__class__.__name__)
function_name = execution_info.function_name
extra_data = {"name": function_name, "task_id": task.task_id().hex()}
if task.actor_id().is_nil():
if task.actor_creation_id().is_nil():
title = "ray_worker:{}()".format(function_name)
next_title = "ray_worker"
else:
actor = self.actors[task.actor_id()]
actor = self.actors[task.actor_creation_id()]
title = "ray_{}:{}()".format(actor.__class__.__name__,
function_name)
next_title = "ray_{}".format(actor.__class__.__name__)
with profiling.profile("task", extra_data=extra_data):
with _changeproctitle(title, next_title):
self._process_task(task, execution_info)
# Reset the state fields so the next task can run.
self.task_context.current_task_id = TaskID.nil()
self.task_context.task_index = 0
self.task_context.put_index = 1
if self.actor_id.is_nil():
# Don't need to reset task_driver_id if the worker is an
# actor. Because the following tasks should all have the
# same driver id.
self.task_driver_id = DriverID.nil()
# Reset signal counters so that the next task can get
# all past signals.
ray_signal.reset()
else:
actor = self.actors[task.actor_id()]
title = "ray_{}:{}()".format(actor.__class__.__name__,
function_name)
next_title = "ray_{}".format(actor.__class__.__name__)
with profiling.profile("task", extra_data=extra_data):
with _changeproctitle(title, next_title):
self._process_task(task, execution_info)
# Reset the state fields so the next task can run.
self.task_context.current_task_id = TaskID.nil()
self.task_context.task_index = 0
self.task_context.put_index = 1
if self.actor_id.is_nil():
# Don't need to reset task_driver_id if the worker is an
# actor. Because the following tasks should all have the
# same driver id.
self.task_driver_id = DriverID.nil()
# Reset signal counters so that the next task can get
# all past signals.
ray_signal.reset()
# Increase the task execution counter.
self.function_actor_manager.increase_task_counter(
@@ -1645,10 +1638,9 @@ def listen_error_messages_raylet(worker, task_error_queue, threads_stopped):
try:
# Get the error messages that occurred before the call to subscribe.
with worker.lock:
error_messages = global_state.error_messages(worker.task_driver_id)
for error_message in error_messages:
logger.error(error_message)
error_messages = global_state.error_messages(worker.task_driver_id)
for error_message in error_messages:
logger.error(error_message)
while True:
# Exit if we received a signal that we should stop.
@@ -1774,7 +1766,7 @@ def connect(node,
traceback_str,
driver_id=None)
worker.lock = threading.Lock()
worker.lock = threading.RLock()
# Create an object for interfacing with the global state.
global_state._initialize_global_state(