mirror of
https://github.com/wassname/ray.git
synced 2026-07-01 08:19:30 +08:00
[core worker] Python core worker task execution (#5783)
Executes tasks via the event loop in the C++ core worker. Also handles signals properly (including KeyboardInterrupt), so Ctrl-C now works in an interactive Python shell when connected to an existing cluster.
This commit is contained in:
+30
-287
@@ -26,7 +26,6 @@ import random
|
||||
import pyarrow
|
||||
import pyarrow.plasma as plasma
|
||||
import ray.cloudpickle as pickle
|
||||
import ray.experimental.signal as ray_signal
|
||||
import ray.experimental.no_return
|
||||
import ray.gcs_utils
|
||||
import ray.memory_monitor as memory_monitor
|
||||
@@ -41,7 +40,6 @@ import ray.state
|
||||
|
||||
from ray import (
|
||||
ActorID,
|
||||
WorkerID,
|
||||
JobID,
|
||||
ObjectID,
|
||||
TaskID,
|
||||
@@ -60,10 +58,7 @@ from ray.exceptions import (
|
||||
UnreconstructableError,
|
||||
RAY_EXCEPTION_TYPES,
|
||||
)
|
||||
from ray.function_manager import (
|
||||
FunctionActorManager,
|
||||
FunctionDescriptor,
|
||||
)
|
||||
from ray.function_manager import FunctionActorManager
|
||||
from ray.utils import (
|
||||
_random_string,
|
||||
check_oversized_pickle,
|
||||
@@ -156,7 +151,6 @@ class Worker(object):
|
||||
# Index of the current session. This number will
|
||||
# increment every time when `ray.shutdown` is called.
|
||||
self._session_index = 0
|
||||
self._current_task = None
|
||||
# Functions to run to process the values returned by ray.get. Each
|
||||
# postprocessor must take two arguments ("object_ids", and "values").
|
||||
self._post_get_hooks = []
|
||||
@@ -473,9 +467,10 @@ class Worker(object):
|
||||
logger.warning(warning_message)
|
||||
self.store_and_register(object_id, value)
|
||||
|
||||
def retrieve_and_deserialize(self, object_ids, error_timeout=10):
|
||||
data_metadata_pairs = self.core_worker.get_objects(
|
||||
object_ids, self.current_task_id)
|
||||
def deserialize_objects(self,
|
||||
data_metadata_pairs,
|
||||
object_ids,
|
||||
error_timeout=10):
|
||||
assert len(data_metadata_pairs) == len(object_ids)
|
||||
|
||||
start_time = time.time()
|
||||
@@ -571,9 +566,9 @@ class Worker(object):
|
||||
if self.mode == LOCAL_MODE:
|
||||
return self.local_mode_manager.get_objects(object_ids)
|
||||
|
||||
results = self.retrieve_and_deserialize(object_ids)
|
||||
assert len(results) == len(object_ids)
|
||||
return results
|
||||
data_metadata_pairs = self.core_worker.get_objects(
|
||||
object_ids, self.current_task_id)
|
||||
return self.deserialize_objects(data_metadata_pairs, object_ids)
|
||||
|
||||
def run_function_on_all_workers(self, function,
|
||||
run_on_other_drivers=False):
|
||||
@@ -679,149 +674,6 @@ class Worker(object):
|
||||
|
||||
return ray.signature.recover_args(arguments)
|
||||
|
||||
def _store_outputs_in_object_store(self, object_ids, outputs):
    """Store the outputs of a remote function in the local object store.

    Each output is written with `self.put_object` under the object ID at
    the same position in `object_ids`. This is called by the worker that
    executes the remote function.

    Note:
        The arguments object_ids and outputs should have the same length.
        Outputs are stored one at a time, so if an error is raised part
        way through, the outputs before the failing one have already
        been stored.

    Args:
        object_ids (List[ObjectID]): The object IDs that were assigned to
            the outputs of the remote function call.
        outputs (Tuple): The value returned by the remote function. If the
            remote function was supposed to only return one value, then its
            output was wrapped in a tuple with one element prior to being
            passed into this function.

    Raises:
        Exception: If an output is an actor handle.
        RuntimeError: If an output is `ray.experimental.no_return.NoReturn`
            but the core worker reports that the corresponding object ID
            does not exist.
        IndexError: If `outputs` has fewer elements than `object_ids`.
    """
    for i, object_id in enumerate(object_ids):
        # Index explicitly (rather than zip) so that a too-short `outputs`
        # raises instead of silently leaving return objects unstored.
        output = outputs[i]
        if isinstance(output, ray.actor.ActorHandle):
            raise Exception("Returning an actor handle from a remote "
                            "function is not allowed.")
        if output is ray.experimental.no_return.NoReturn:
            # NoReturn stores nothing here; only verify that the object
            # already exists.
            if not self.core_worker.object_exists(object_id):
                raise RuntimeError(
                    "Attempting to return 'ray.experimental.NoReturn' "
                    "from a remote function, but the corresponding "
                    "ObjectID does not exist in the local object store.")
        else:
            self.put_object(object_id, output)
|
||||
|
||||
def _process_task(self, task, function_execution_info):
    """Execute a task assigned to this worker.

    The task's arguments are fetched and deserialized, the function is
    executed, and its outputs are stored in the local object store. If
    any of these three stages raises, the failure is handed to
    `_handle_process_task_failure` (which stores RayTaskError objects for
    the return IDs) and no later stage runs.

    Args:
        task: The task to execute.
        function_execution_info: Execution info for the task's function;
            its `function` and `function_name` attributes are read here.
    """
    # The per-task state must have been reset after the previous task.
    assert self.current_task_id.is_nil()
    assert self.task_context.task_index == 0
    assert self.task_context.put_index == 1
    if not task.is_actor_task():
        # If this worker is not an actor, check that `current_job_id`
        # was reset when the worker finished the previous task.
        assert self.current_job_id.is_nil()
        # Set the job ID of the current running task. This is
        # needed so that if the task throws an exception, we propagate
        # the error message to the correct driver.
        self.current_job_id = task.job_id()
        self.core_worker.set_current_job_id(task.job_id())
    else:
        # If this worker is an actor, current_job_id wasn't reset.
        # Check that the current task's job ID equals the previous one.
        assert self.current_job_id == task.job_id()

    self.task_context.current_task_id = task.task_id()
    self.core_worker.set_current_task_id(task.task_id())

    function_descriptor = FunctionDescriptor.from_bytes_list(
        task.function_descriptor_list())
    serialized_args = task.arguments()
    return_object_ids = task.returns()
    if task.is_actor_task() or task.is_actor_creation_task():
        # The last return ID of actor (creation) tasks is a dummy that is
        # passed to the executor rather than filled with a return value.
        dummy_return_id = return_object_ids.pop()
    function_executor = function_execution_info.function
    function_name = function_execution_info.function_name

    # Get task arguments from the object store.
    try:
        # `__ray_terminate__` skips the actor-init-error check.
        if function_name != "__ray_terminate__":
            self.reraise_actor_init_error()
        self.memory_monitor.raise_if_low_memory()
        with profiling.profile("task:deserialize_arguments"):
            function_args, function_kwargs = (
                self._get_arguments_for_execution(function_name,
                                                  serialized_args))
    except Exception as e:
        self._handle_process_task_failure(
            function_descriptor, return_object_ids, e,
            ray.utils.format_error_message(traceback.format_exc()))
        return

    # Execute the task.
    try:
        self._current_task = task
        with profiling.profile("task:execute"):
            if task.is_normal_task():
                outputs = function_executor(*function_args,
                                            **function_kwargs)
            else:
                # Actor tasks and actor creation tasks look up the actor
                # instance under different IDs.
                if task.is_actor_task():
                    key = task.actor_id()
                else:
                    key = task.actor_creation_id()
                worker_name = "ray_{}_{}".format(
                    self.actors[key].__class__.__name__, os.getpid())
                # Apply the task's memory-related resource requirements
                # before running the actor method.
                if "memory" in task.required_resources():
                    self.memory_monitor.set_heap_limit(
                        worker_name,
                        ray_constants.from_memory_units(
                            task.required_resources()["memory"]))
                if "object_store_memory" in task.required_resources():
                    self._set_object_store_client_options(
                        worker_name,
                        int(
                            ray_constants.from_memory_units(
                                task.required_resources()[
                                    "object_store_memory"])))
                outputs = function_executor(
                    dummy_return_id, self.actors[key], *function_args,
                    **function_kwargs)
    except Exception as e:
        # Determine whether the exception occurred during a task, not an
        # actor method.
        task_exception = not task.is_actor_task()
        traceback_str = ray.utils.format_error_message(
            traceback.format_exc(), task_exception=task_exception)
        self._handle_process_task_failure(
            function_descriptor, return_object_ids, e, traceback_str)
        return
    finally:
        # Always clear the current task, including on the early return
        # above.
        self._current_task = None

    # Store the outputs in the local object store.
    try:
        with profiling.profile("task:store_outputs"):
            # For actor tasks the dummy return ID was already popped
            # above, so `return_object_ids` now holds only the IDs for
            # values returned by the function itself.
            num_returns = len(return_object_ids)
            if num_returns == 1:
                # A single return value is wrapped so it can be indexed
                # like a multi-value return.
                outputs = (outputs, )
            self._store_outputs_in_object_store(return_object_ids, outputs)
    except Exception as e:
        self._handle_process_task_failure(
            function_descriptor, return_object_ids, e,
            ray.utils.format_error_message(traceback.format_exc()))
|
||||
|
||||
def _set_object_store_client_options(self, name, object_store_memory):
|
||||
try:
|
||||
logger.debug("Setting plasma memory limit to {} for {}".format(
|
||||
@@ -838,133 +690,15 @@ class Worker(object):
|
||||
"object store memory status is:\n\n{}".format(
|
||||
object_store_memory, name, e))
|
||||
|
||||
def _handle_process_task_failure(self, function_descriptor,
                                 return_object_ids, error, backtrace):
    """Record a failed task: store, report and signal a RayTaskError.

    A single RayTaskError is built for the failure and stored under every
    ID in `return_object_ids`. Its string form is then pushed to the
    driver and sent as an error signal. If this worker is an actor and
    the failing function is `__init__`, the actor init is also marked as
    failed.

    Args:
        function_descriptor: Descriptor of the failed function; only its
            `function_name` is read.
        return_object_ids: Object IDs that should receive the failure
            object.
        error: The exception that caused the failure.
        backtrace: Formatted traceback text for the failure.
    """
    function_name = function_descriptor.function_name

    # Reuse the original cause of an existing RayTaskError so that task
    # errors are not nested inside each other.
    if isinstance(error, RayTaskError):
        cause_cls = error.cause_cls
    else:
        cause_cls = error.__class__
    failure_object = RayTaskError(function_name, backtrace, cause_cls)

    # Every return ID gets the same failure object.
    self._store_outputs_in_object_store(
        return_object_ids, [failure_object] * len(return_object_ids))

    # Log the error message.
    failure_text = str(failure_object)
    ray.utils.push_error_to_driver(
        self,
        ray_constants.TASK_PUSH_ERROR,
        failure_text,
        job_id=self.current_job_id)

    # Mark the actor init as failed.
    if not self.actor_id.is_nil() and function_name == "__init__":
        self.mark_actor_init_failed(error)

    # Send signal with the error.
    ray_signal.send(ray_signal.ErrorSignal(failure_text))
|
||||
|
||||
def _wait_for_and_process_task(self, task):
    """Process one task and reset the worker's per-task state afterwards.

    For an actor creation task, the actor class is loaded, an instance is
    created with `__new__` (so `__init__` is not called here), and fresh
    checkpoint info is registered for it. The task is then run through
    `_process_task` while the process title reflects the running
    function. Afterwards the task context is reset, and for non-actor
    workers the current job ID and the signal counters are reset too.
    Finally the function's task counter is increased; when it equals the
    function's `max_calls`, the worker disconnects and exits.

    Args:
        task: The task to execute.
    """
    function_descriptor = FunctionDescriptor.from_bytes_list(
        task.function_descriptor_list())
    job_id = task.job_id()

    # TODO(rkn): It would be preferable for actor creation tasks to share
    # more of the code path with regular task execution.
    if task.is_actor_creation_task():
        # TODO: Remove Worker.actor_id and just use CoreWorker.GetActorId.
        self.actor_id = task.actor_creation_id()
        self.core_worker.set_actor_id(task.actor_creation_id())
        self.actor_creation_task_id = task.task_id()
        actor_class = self.function_actor_manager.load_actor_class(
            job_id, function_descriptor)
        self.actors[self.actor_id] = actor_class.__new__(actor_class)
        self.actor_checkpoint_info[self.actor_id] = ActorCheckpointInfo(
            num_tasks_since_last_checkpoint=0,
            last_checkpoint_timestamp=int(1000 * time.time()),
            checkpoint_ids=[],
        )

    execution_info = self.function_actor_manager.get_execution_info(
        job_id, function_descriptor)

    # Execute the task.
    function_name = execution_info.function_name
    extra_data = {"name": function_name, "task_id": task.task_id().hex()}

    # `next_title` is the process title restored after the task; while
    # the task runs the title additionally names the function.
    if task.is_actor_task():
        actor = self.actors[task.actor_id()]
        next_title = "ray_{}".format(actor.__class__.__name__)
    elif task.is_actor_creation_task():
        actor = self.actors[task.actor_creation_id()]
        next_title = "ray_{}".format(actor.__class__.__name__)
    else:
        next_title = "ray_worker"
    title = "{}:{}()".format(next_title, function_name)

    with profiling.profile("task", extra_data=extra_data):
        with _changeproctitle(title, next_title):
            self._process_task(task, execution_info)
        # Reset the state fields so the next task can run.
        self.task_context.current_task_id = TaskID.nil()
        self.core_worker.set_current_task_id(TaskID.nil())
        self.task_context.task_index = 0
        self.task_context.put_index = 1
        if self.actor_id.is_nil():
            # Don't need to reset `current_job_id` if the worker is an
            # actor, because the following tasks should all have the
            # same job ID.
            # `current_job_id` holds a job ID, so reset it with the nil
            # JobID (not a nil ID of another type), matching the value
            # given to the core worker.
            self.current_job_id = JobID.nil()
            self.core_worker.set_current_job_id(JobID.nil())
            # Reset signal counters so that the next task can get
            # all past signals.
            ray_signal.reset()

    # Increase the task execution counter.
    self.function_actor_manager.increase_task_counter(
        job_id, function_descriptor)

    reached_max_executions = (self.function_actor_manager.get_task_counter(
        job_id, function_descriptor) == execution_info.max_calls)
    if reached_max_executions:
        # The function hit its call limit: leave the cluster and stop
        # this worker process.
        self.core_worker.disconnect()
        sys.exit(0)
|
||||
|
||||
def _get_next_task_from_raylet(self):
    """Fetch the next task from the raylet.

    The fetch is recorded under the "worker_idle" profiling event. Once a
    task has been obtained, the visible CUDA devices are set from
    `ray.get_gpu_ids()`.

    Returns:
        A task from the raylet.
    """
    idle_event = profiling.profile("worker_idle")
    with idle_event:
        next_task = self.raylet_client.get_task()

    # Automatically restrict the GPUs available to this task.
    gpu_ids = ray.get_gpu_ids()
    ray.utils.set_cuda_visible_devices(gpu_ids)

    return next_task
|
||||
|
||||
def main_loop(self):
|
||||
"""The main loop a worker runs to receive and execute tasks."""
|
||||
|
||||
def exit(signum, frame):
|
||||
shutdown()
|
||||
sys.exit(0)
|
||||
def sigterm_handler(signum, frame):
|
||||
shutdown(True)
|
||||
sys.exit(1)
|
||||
|
||||
signal.signal(signal.SIGTERM, exit)
|
||||
|
||||
while True:
|
||||
task = self._get_next_task_from_raylet()
|
||||
self._wait_for_and_process_task(task)
|
||||
signal.signal(signal.SIGTERM, sigterm_handler)
|
||||
self.core_worker.run_task_loop()
|
||||
|
||||
|
||||
def get_gpu_ids():
|
||||
@@ -982,7 +716,7 @@ def get_gpu_ids():
|
||||
raise Exception("ray.get_gpu_ids() currently does not work in LOCAL "
|
||||
"MODE.")
|
||||
|
||||
all_resource_ids = global_worker.raylet_client.resource_ids()
|
||||
all_resource_ids = global_worker.core_worker.resource_ids()
|
||||
assigned_ids = [
|
||||
resource_id for resource_id, _ in all_resource_ids.get("GPU", [])
|
||||
]
|
||||
@@ -1010,7 +744,7 @@ def get_resource_ids():
|
||||
"ray.get_resource_ids() currently does not work in LOCAL "
|
||||
"MODE.")
|
||||
|
||||
return global_worker.raylet_client.resource_ids()
|
||||
return global_worker.core_worker.resource_ids()
|
||||
|
||||
|
||||
def get_webui_url():
|
||||
@@ -1437,7 +1171,7 @@ def shutdown(exiting_interpreter=False):
|
||||
# to make sure that log messages finish printing.
|
||||
time.sleep(0.5)
|
||||
|
||||
disconnect()
|
||||
disconnect(exiting_interpreter)
|
||||
|
||||
# Disconnect global state from GCS.
|
||||
ray.state.state.disconnect()
|
||||
@@ -1456,6 +1190,13 @@ def shutdown(exiting_interpreter=False):
|
||||
|
||||
atexit.register(shutdown, True)
|
||||
|
||||
|
||||
def sigterm_handler(signum, frame):
    """Signal handler that exits the process via `sys.exit`.

    The exit status passed to `sys.exit` is the SIGTERM signal number.
    NOTE(review): presumably exiting through `sys.exit` (rather than
    dying from the signal) is what lets the `atexit`-registered
    `shutdown` run -- confirm.

    Args:
        signum: The received signal number (unused).
        frame: The interrupted stack frame (unused).
    """
    sys.exit(signal.SIGTERM)
|
||||
|
||||
|
||||
signal.signal(signal.SIGTERM, sigterm_handler)
|
||||
|
||||
# Define a custom excepthook so that if the driver exits with an exception, we
|
||||
# can push that exception to Redis.
|
||||
normal_excepthook = sys.excepthook
|
||||
@@ -1900,7 +1641,7 @@ def connect(node,
|
||||
worker.cached_functions_to_run = None
|
||||
|
||||
|
||||
def disconnect():
|
||||
def disconnect(exiting_interpreter=False):
|
||||
"""Disconnect this worker from the raylet and object store."""
|
||||
# Reset the list of cached remote functions and actors so that if more
|
||||
# remote functions or actors are defined and then connect is called again,
|
||||
@@ -1928,10 +1669,12 @@ def disconnect():
|
||||
worker.function_actor_manager.reset_cache()
|
||||
worker.serialization_context_map.clear()
|
||||
|
||||
if hasattr(worker, "raylet_client"):
|
||||
del worker.raylet_client
|
||||
if hasattr(worker, "core_worker"):
|
||||
del worker.core_worker
|
||||
if not exiting_interpreter:
|
||||
if hasattr(worker, "raylet_client"):
|
||||
del worker.raylet_client
|
||||
|
||||
if hasattr(worker, "core_worker"):
|
||||
del worker.core_worker
|
||||
|
||||
|
||||
@contextmanager
|
||||
|
||||
Reference in New Issue
Block a user