[core worker] Python core worker task execution (#5783)

Executes tasks via the event loop in the C++ core worker. Also properly handles signals (including KeyboardInterrupt), so ctrl-C in a python interactive shell works now (if connecting to an existing cluster).
This commit is contained in:
Edward Oakes
2019-10-22 20:15:59 -07:00
committed by GitHub
parent 95241f6686
commit 02931e08f3
38 changed files with 830 additions and 678 deletions
+30 -287
View File
@@ -26,7 +26,6 @@ import random
import pyarrow
import pyarrow.plasma as plasma
import ray.cloudpickle as pickle
import ray.experimental.signal as ray_signal
import ray.experimental.no_return
import ray.gcs_utils
import ray.memory_monitor as memory_monitor
@@ -41,7 +40,6 @@ import ray.state
from ray import (
ActorID,
WorkerID,
JobID,
ObjectID,
TaskID,
@@ -60,10 +58,7 @@ from ray.exceptions import (
UnreconstructableError,
RAY_EXCEPTION_TYPES,
)
from ray.function_manager import (
FunctionActorManager,
FunctionDescriptor,
)
from ray.function_manager import FunctionActorManager
from ray.utils import (
_random_string,
check_oversized_pickle,
@@ -156,7 +151,6 @@ class Worker(object):
# Index of the current session. This number will
# increment every time when `ray.shutdown` is called.
self._session_index = 0
self._current_task = None
# Functions to run to process the values returned by ray.get. Each
# postprocessor must take two arguments ("object_ids", and "values").
self._post_get_hooks = []
@@ -473,9 +467,10 @@ class Worker(object):
logger.warning(warning_message)
self.store_and_register(object_id, value)
def retrieve_and_deserialize(self, object_ids, error_timeout=10):
data_metadata_pairs = self.core_worker.get_objects(
object_ids, self.current_task_id)
def deserialize_objects(self,
data_metadata_pairs,
object_ids,
error_timeout=10):
assert len(data_metadata_pairs) == len(object_ids)
start_time = time.time()
@@ -571,9 +566,9 @@ class Worker(object):
if self.mode == LOCAL_MODE:
return self.local_mode_manager.get_objects(object_ids)
results = self.retrieve_and_deserialize(object_ids)
assert len(results) == len(object_ids)
return results
data_metadata_pairs = self.core_worker.get_objects(
object_ids, self.current_task_id)
return self.deserialize_objects(data_metadata_pairs, object_ids)
def run_function_on_all_workers(self, function,
run_on_other_drivers=False):
@@ -679,149 +674,6 @@ class Worker(object):
return ray.signature.recover_args(arguments)
def _store_outputs_in_object_store(self, object_ids, outputs):
"""Store the outputs of a remote function in the local object store.
This stores the values that were returned by a remote function in the
local object store. If any of the return values are object IDs, then
these object IDs are aliased with the object IDs that the scheduler
assigned for the return values. This is called by the worker that
executes the remote function.
Note:
The arguments object_ids and outputs should have the same length.
Args:
object_ids (List[ObjectID]): The object IDs that were assigned to
the outputs of the remote function call.
outputs (Tuple): The value returned by the remote function. If the
remote function was supposed to only return one value, then its
output was wrapped in a tuple with one element prior to being
passed into this function.
"""
for i in range(len(object_ids)):
if isinstance(outputs[i], ray.actor.ActorHandle):
raise Exception("Returning an actor handle from a remote "
"function is not allowed).")
if outputs[i] is ray.experimental.no_return.NoReturn:
if not self.core_worker.object_exists(object_ids[i]):
raise RuntimeError(
"Attempting to return 'ray.experimental.NoReturn' "
"from a remote function, but the corresponding "
"ObjectID does not exist in the local object store.")
else:
self.put_object(object_ids[i], outputs[i])
def _process_task(self, task, function_execution_info):
"""Execute a task assigned to this worker.
This method deserializes a task from the scheduler, and attempts to
execute the task. If the task succeeds, the outputs are stored in the
local object store. If the task throws an exception, RayTaskError
objects are stored in the object store to represent the failed task
(these will be retrieved by calls to get or by subsequent tasks that
use the outputs of this task).
"""
assert self.current_task_id.is_nil()
assert self.task_context.task_index == 0
assert self.task_context.put_index == 1
if not task.is_actor_task():
# If this worker is not an actor, check that `current_job_id`
# was reset when the worker finished the previous task.
assert self.current_job_id.is_nil()
# Set the driver ID of the current running task. This is
# needed so that if the task throws an exception, we propagate
# the error message to the correct driver.
self.current_job_id = task.job_id()
self.core_worker.set_current_job_id(task.job_id())
else:
# If this worker is an actor, current_job_id wasn't reset.
# Check that current task's driver ID equals the previous one.
assert self.current_job_id == task.job_id()
self.task_context.current_task_id = task.task_id()
self.core_worker.set_current_task_id(task.task_id())
function_descriptor = FunctionDescriptor.from_bytes_list(
task.function_descriptor_list())
serialized_args = task.arguments()
return_object_ids = task.returns()
if task.is_actor_task() or task.is_actor_creation_task():
dummy_return_id = return_object_ids.pop()
function_executor = function_execution_info.function
function_name = function_execution_info.function_name
# Get task arguments from the object store.
try:
if function_name != "__ray_terminate__":
self.reraise_actor_init_error()
self.memory_monitor.raise_if_low_memory()
with profiling.profile("task:deserialize_arguments"):
function_args, function_kwargs = (
self._get_arguments_for_execution(function_name,
serialized_args))
except Exception as e:
self._handle_process_task_failure(
function_descriptor, return_object_ids, e,
ray.utils.format_error_message(traceback.format_exc()))
return
# Execute the task.
try:
self._current_task = task
with profiling.profile("task:execute"):
if task.is_normal_task():
outputs = function_executor(*function_args,
**function_kwargs)
else:
if task.is_actor_task():
key = task.actor_id()
else:
key = task.actor_creation_id()
worker_name = "ray_{}_{}".format(
self.actors[key].__class__.__name__, os.getpid())
if "memory" in task.required_resources():
self.memory_monitor.set_heap_limit(
worker_name,
ray_constants.from_memory_units(
task.required_resources()["memory"]))
if "object_store_memory" in task.required_resources():
self._set_object_store_client_options(
worker_name,
int(
ray_constants.from_memory_units(
task.required_resources()[
"object_store_memory"])))
outputs = function_executor(
dummy_return_id, self.actors[key], *function_args,
**function_kwargs)
except Exception as e:
# Determine whether the exception occured during a task, not an
# actor method.
task_exception = not task.is_actor_task()
traceback_str = ray.utils.format_error_message(
traceback.format_exc(), task_exception=task_exception)
self._handle_process_task_failure(
function_descriptor, return_object_ids, e, traceback_str)
return
finally:
self._current_task = None
# Store the outputs in the local object store.
try:
with profiling.profile("task:store_outputs"):
# If this is an actor task, then the last object ID returned by
# the task is a dummy output, not returned by the function
# itself. Decrement to get the correct number of return values.
num_returns = len(return_object_ids)
if num_returns == 1:
outputs = (outputs, )
self._store_outputs_in_object_store(return_object_ids, outputs)
except Exception as e:
self._handle_process_task_failure(
function_descriptor, return_object_ids, e,
ray.utils.format_error_message(traceback.format_exc()))
def _set_object_store_client_options(self, name, object_store_memory):
try:
logger.debug("Setting plasma memory limit to {} for {}".format(
@@ -838,133 +690,15 @@ class Worker(object):
"object store memory status is:\n\n{}".format(
object_store_memory, name, e))
def _handle_process_task_failure(self, function_descriptor,
return_object_ids, error, backtrace):
function_name = function_descriptor.function_name
if isinstance(error, RayTaskError):
# avoid recursively nesting of RayTaskError
failure_object = RayTaskError(function_name, backtrace,
error.cause_cls)
else:
failure_object = RayTaskError(function_name, backtrace,
error.__class__)
failure_objects = [
failure_object for _ in range(len(return_object_ids))
]
self._store_outputs_in_object_store(return_object_ids, failure_objects)
# Log the error message.
ray.utils.push_error_to_driver(
self,
ray_constants.TASK_PUSH_ERROR,
str(failure_object),
job_id=self.current_job_id)
# Mark the actor init as failed
if not self.actor_id.is_nil() and function_name == "__init__":
self.mark_actor_init_failed(error)
# Send signal with the error.
ray_signal.send(ray_signal.ErrorSignal(str(failure_object)))
def _wait_for_and_process_task(self, task):
"""Wait for a task to be ready and process the task.
Args:
task: The task to execute.
"""
function_descriptor = FunctionDescriptor.from_bytes_list(
task.function_descriptor_list())
job_id = task.job_id()
# TODO(rkn): It would be preferable for actor creation tasks to share
# more of the code path with regular task execution.
if task.is_actor_creation_task():
# TODO: Remove Worker.actor_id and just use CoreWorker.GetActorId.
self.actor_id = task.actor_creation_id()
self.core_worker.set_actor_id(task.actor_creation_id())
self.actor_creation_task_id = task.task_id()
actor_class = self.function_actor_manager.load_actor_class(
job_id, function_descriptor)
self.actors[self.actor_id] = actor_class.__new__(actor_class)
self.actor_checkpoint_info[self.actor_id] = ActorCheckpointInfo(
num_tasks_since_last_checkpoint=0,
last_checkpoint_timestamp=int(1000 * time.time()),
checkpoint_ids=[],
)
execution_info = self.function_actor_manager.get_execution_info(
job_id, function_descriptor)
# Execute the task.
function_name = execution_info.function_name
extra_data = {"name": function_name, "task_id": task.task_id().hex()}
if not task.is_actor_task():
if not task.is_actor_creation_task():
title = "ray_worker:{}()".format(function_name)
next_title = "ray_worker"
else:
actor = self.actors[task.actor_creation_id()]
title = "ray_{}:{}()".format(actor.__class__.__name__,
function_name)
next_title = "ray_{}".format(actor.__class__.__name__)
else:
actor = self.actors[task.actor_id()]
title = "ray_{}:{}()".format(actor.__class__.__name__,
function_name)
next_title = "ray_{}".format(actor.__class__.__name__)
with profiling.profile("task", extra_data=extra_data):
with _changeproctitle(title, next_title):
self._process_task(task, execution_info)
# Reset the state fields so the next task can run.
self.task_context.current_task_id = TaskID.nil()
self.core_worker.set_current_task_id(TaskID.nil())
self.task_context.task_index = 0
self.task_context.put_index = 1
if self.actor_id.is_nil():
# Don't need to reset `current_job_id` if the worker is an
# actor. Because the following tasks should all have the
# same driver id.
self.current_job_id = WorkerID.nil()
self.core_worker.set_current_job_id(JobID.nil())
# Reset signal counters so that the next task can get
# all past signals.
ray_signal.reset()
# Increase the task execution counter.
self.function_actor_manager.increase_task_counter(
job_id, function_descriptor)
reached_max_executions = (self.function_actor_manager.get_task_counter(
job_id, function_descriptor) == execution_info.max_calls)
if reached_max_executions:
self.core_worker.disconnect()
sys.exit(0)
def _get_next_task_from_raylet(self):
"""Get the next task from the raylet.
Returns:
A task from the raylet.
"""
with profiling.profile("worker_idle"):
task = self.raylet_client.get_task()
# Automatically restrict the GPUs available to this task.
ray.utils.set_cuda_visible_devices(ray.get_gpu_ids())
return task
def main_loop(self):
"""The main loop a worker runs to receive and execute tasks."""
def exit(signum, frame):
shutdown()
sys.exit(0)
def sigterm_handler(signum, frame):
shutdown(True)
sys.exit(1)
signal.signal(signal.SIGTERM, exit)
while True:
task = self._get_next_task_from_raylet()
self._wait_for_and_process_task(task)
signal.signal(signal.SIGTERM, sigterm_handler)
self.core_worker.run_task_loop()
def get_gpu_ids():
@@ -982,7 +716,7 @@ def get_gpu_ids():
raise Exception("ray.get_gpu_ids() currently does not work in LOCAL "
"MODE.")
all_resource_ids = global_worker.raylet_client.resource_ids()
all_resource_ids = global_worker.core_worker.resource_ids()
assigned_ids = [
resource_id for resource_id, _ in all_resource_ids.get("GPU", [])
]
@@ -1010,7 +744,7 @@ def get_resource_ids():
"ray.get_resource_ids() currently does not work in LOCAL "
"MODE.")
return global_worker.raylet_client.resource_ids()
return global_worker.core_worker.resource_ids()
def get_webui_url():
@@ -1437,7 +1171,7 @@ def shutdown(exiting_interpreter=False):
# to make sure that log messages finish printing.
time.sleep(0.5)
disconnect()
disconnect(exiting_interpreter)
# Disconnect global state from GCS.
ray.state.state.disconnect()
@@ -1456,6 +1190,13 @@ def shutdown(exiting_interpreter=False):
atexit.register(shutdown, True)
def sigterm_handler(signum, frame):
sys.exit(signal.SIGTERM)
signal.signal(signal.SIGTERM, sigterm_handler)
# Define a custom excepthook so that if the driver exits with an exception, we
# can push that exception to Redis.
normal_excepthook = sys.excepthook
@@ -1900,7 +1641,7 @@ def connect(node,
worker.cached_functions_to_run = None
def disconnect():
def disconnect(exiting_interpreter=False):
"""Disconnect this worker from the raylet and object store."""
# Reset the list of cached remote functions and actors so that if more
# remote functions or actors are defined and then connect is called again,
@@ -1928,10 +1669,12 @@ def disconnect():
worker.function_actor_manager.reset_cache()
worker.serialization_context_map.clear()
if hasattr(worker, "raylet_client"):
del worker.raylet_client
if hasattr(worker, "core_worker"):
del worker.core_worker
if not exiting_interpreter:
if hasattr(worker, "raylet_client"):
del worker.raylet_client
if hasattr(worker, "core_worker"):
del worker.core_worker
@contextmanager