mirror of
https://github.com/wassname/ray.git
synced 2026-07-03 08:18:49 +08:00
Create RemoteFunction class, remove FunctionProperties, simplify worker Python code. (#2052)
* Cleaning up worker and actor code. Create remote function class. Remove FunctionProperties object. * Remove register_actor_signatures function. * Small cleanups. * Fix linting. * Support @ray.method syntax for actor methods. * Fix pickling bug. * Fix linting. * Shorten testBlockingTasks. * Small fixes. * Call get_global_worker().
This commit is contained in:
committed by
Philipp Moritz
parent
ad48e47120
commit
8fbb88485b
+230
-411
@@ -5,7 +5,6 @@ from __future__ import print_function
|
||||
import atexit
|
||||
import collections
|
||||
import colorama
|
||||
import copy
|
||||
import hashlib
|
||||
import inspect
|
||||
import json
|
||||
@@ -23,13 +22,13 @@ import pyarrow
|
||||
import pyarrow.plasma as plasma
|
||||
import ray.cloudpickle as pickle
|
||||
import ray.experimental.state as state
|
||||
import ray.remote_function
|
||||
import ray.serialization as serialization
|
||||
import ray.services as services
|
||||
import ray.signature as signature
|
||||
import ray.signature
|
||||
import ray.local_scheduler
|
||||
import ray.plasma
|
||||
from ray.utils import (FunctionProperties, random_string, binary_to_hex,
|
||||
is_cython)
|
||||
from ray.utils import random_string, binary_to_hex, is_cython
|
||||
|
||||
# Import flatbuffer bindings.
|
||||
from ray.core.generated.ClientTableData import ClientTableData
|
||||
@@ -63,9 +62,6 @@ PUT_RECONSTRUCTION_ERROR_TYPE = b"put_reconstruction"
|
||||
# This must be kept in sync with the `scheduling_state` enum in common/task.h.
|
||||
TASK_STATUS_RUNNING = 8
|
||||
|
||||
# Default resource requirements for remote functions.
|
||||
DEFAULT_REMOTE_FUNCTION_CPUS = 1
|
||||
DEFAULT_REMOTE_FUNCTION_GPUS = 0
|
||||
# Default resource requirements for actors when no resource requirements are
|
||||
# specified.
|
||||
DEFAULT_ACTOR_METHOD_CPUS_SIMPLE_CASE = 1
|
||||
@@ -74,15 +70,6 @@ DEFAULT_ACTOR_CREATION_CPUS_SIMPLE_CASE = 0
|
||||
# specified.
|
||||
DEFAULT_ACTOR_METHOD_CPUS_SPECIFIED_CASE = 0
|
||||
DEFAULT_ACTOR_CREATION_CPUS_SPECIFIED_CASE = 1
|
||||
DEFAULT_ACTOR_CREATION_GPUS_SPECIFIED_CASE = 0
|
||||
|
||||
|
||||
class FunctionID(object):
|
||||
def __init__(self, function_id):
|
||||
self.function_id = function_id
|
||||
|
||||
def id(self):
|
||||
return self.function_id
|
||||
|
||||
|
||||
class RayTaskError(Exception):
|
||||
@@ -182,6 +169,11 @@ class RayGetArgumentError(Exception):
|
||||
self.task_error))
|
||||
|
||||
|
||||
FunctionExecutionInfo = collections.namedtuple(
|
||||
"FunctionExecutionInfo", ["function", "function_name", "max_calls"])
|
||||
"""FunctionExecutionInfo: A named tuple storing remote function information."""
|
||||
|
||||
|
||||
class Worker(object):
|
||||
"""A class used to define the control flow of a worker process.
|
||||
|
||||
@@ -190,9 +182,10 @@ class Worker(object):
|
||||
functions outside of this class are considered exposed.
|
||||
|
||||
Attributes:
|
||||
functions (Dict[str, Callable]): A dictionary mapping the name of a
|
||||
remote function to the remote function itself. This is the set of
|
||||
remote functions that can be executed by this worker.
|
||||
function_execution_info (Dict[str, Dict[str, FunctionExecutionInfo]]): A
|
||||
dictionary mapping a driver ID to a dictionary that maps each remote
|
||||
function ID to its FunctionExecutionInfo. This is the set of remote
|
||||
functions that can be executed by this worker.
|
||||
connected (bool): True if Ray has been started and False otherwise.
|
||||
mode: The mode of the worker. One of SCRIPT_MODE, PYTHON_MODE,
|
||||
SILENT_MODE, and WORKER_MODE.
|
||||
@@ -208,20 +201,12 @@ class Worker(object):
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize a Worker object."""
|
||||
# The functions field is a dictionary that maps a driver ID to a
|
||||
# dictionary of functions that have been registered for that driver
|
||||
# (this inner dictionary maps function IDs to a tuple of the function
|
||||
# name and the function itself). This should only be used on workers
|
||||
# that execute remote functions.
|
||||
self.functions = collections.defaultdict(lambda: {})
|
||||
# The function_properties field is a dictionary that maps a driver ID
|
||||
# to a dictionary of functions that have been registered for that
|
||||
# driver (this inner dictionary maps function IDs to a tuple of the
|
||||
# number of values returned by that function, the number of CPUs
|
||||
# required by that function, and the number of GPUs required by that
|
||||
# function). This is used when submitting a function (which can be done
|
||||
# both on workers and on drivers).
|
||||
self.function_properties = collections.defaultdict(lambda: {})
|
||||
# This field is a dictionary that maps a driver ID to a dictionary of
|
||||
# functions (and information about those functions) that have been
|
||||
# registered for that driver (this inner dictionary maps function IDs
|
||||
# to a FunctionExecutionInfo object). This should only be used on
|
||||
# workers that execute remote functions.
|
||||
self.function_execution_info = collections.defaultdict(lambda: {})
|
||||
# This is a dictionary mapping driver ID to a dictionary that maps
|
||||
# remote function IDs for that driver to a counter of the number of
|
||||
# times that remote function has been executed on this worker. The
|
||||
@@ -248,6 +233,16 @@ class Worker(object):
|
||||
# CUDA_VISIBLE_DEVICES environment variable.
|
||||
self.original_gpu_ids = ray.utils.get_cuda_visible_devices()
|
||||
|
||||
def check_connected(self):
|
||||
"""Check if the worker is connected.
|
||||
|
||||
Raises:
|
||||
Exception: An exception is raised if the worker is not connected.
|
||||
"""
|
||||
if not self.connected:
|
||||
raise RayConnectionError("Ray has not been started yet. You can "
|
||||
"start Ray with 'ray.init()'.")
|
||||
|
||||
def set_mode(self, mode):
|
||||
"""Set the mode of the worker.
|
||||
|
||||
@@ -356,7 +351,7 @@ class Worker(object):
|
||||
full.
|
||||
"""
|
||||
# Make sure that the value is not an object ID.
|
||||
if isinstance(value, ray.local_scheduler.ObjectID):
|
||||
if isinstance(value, ray.ObjectID):
|
||||
raise Exception("Calling 'put' on an ObjectID is not allowed "
|
||||
"(similarly, returning an ObjectID from a remote "
|
||||
"function is not allowed). If you really want to "
|
||||
@@ -438,7 +433,7 @@ class Worker(object):
|
||||
"""
|
||||
# Make sure that the values are object IDs.
|
||||
for object_id in object_ids:
|
||||
if not isinstance(object_id, ray.local_scheduler.ObjectID):
|
||||
if not isinstance(object_id, ray.ObjectID):
|
||||
raise Exception("Attempting to call `get` on the value {}, "
|
||||
"which is not an ObjectID.".format(object_id))
|
||||
# Do an initial fetch for remote objects. We divide the fetch into
|
||||
@@ -518,8 +513,6 @@ class Worker(object):
|
||||
actor_creation_dummy_object_id=None,
|
||||
execution_dependencies=None,
|
||||
num_return_vals=None,
|
||||
num_cpus=None,
|
||||
num_gpus=None,
|
||||
resources=None,
|
||||
driver_id=None):
|
||||
"""Submit a remote task to the scheduler.
|
||||
@@ -545,8 +538,6 @@ class Worker(object):
|
||||
execution_dependencies: The execution dependencies for this task.
|
||||
num_return_vals: The number of return values this function should
|
||||
have.
|
||||
num_cpus: The number of CPUs required by this task.
|
||||
num_gpus: The number of GPUs required by this task.
|
||||
resources: The resource requirements for this task.
|
||||
driver_id: The ID of the relevant driver. This is almost always the
|
||||
driver ID of the driver that is currently running. However, in
|
||||
@@ -561,24 +552,22 @@ class Worker(object):
|
||||
check_main_thread()
|
||||
if actor_id is None:
|
||||
assert actor_handle_id is None
|
||||
actor_id = ray.local_scheduler.ObjectID(NIL_ACTOR_ID)
|
||||
actor_handle_id = ray.local_scheduler.ObjectID(
|
||||
NIL_ACTOR_HANDLE_ID)
|
||||
actor_id = ray.ObjectID(NIL_ACTOR_ID)
|
||||
actor_handle_id = ray.ObjectID(NIL_ACTOR_HANDLE_ID)
|
||||
else:
|
||||
assert actor_handle_id is not None
|
||||
|
||||
if actor_creation_id is None:
|
||||
actor_creation_id = ray.local_scheduler.ObjectID(NIL_ACTOR_ID)
|
||||
actor_creation_id = ray.ObjectID(NIL_ACTOR_ID)
|
||||
|
||||
if actor_creation_dummy_object_id is None:
|
||||
actor_creation_dummy_object_id = (
|
||||
ray.local_scheduler.ObjectID(NIL_ID))
|
||||
actor_creation_dummy_object_id = (ray.ObjectID(NIL_ID))
|
||||
|
||||
# Put large or complex arguments that are passed by value in the
|
||||
# object store first.
|
||||
args_for_local_scheduler = []
|
||||
for arg in args:
|
||||
if isinstance(arg, ray.local_scheduler.ObjectID):
|
||||
if isinstance(arg, ray.ObjectID):
|
||||
args_for_local_scheduler.append(arg)
|
||||
elif ray.local_scheduler.check_simple_value(arg):
|
||||
args_for_local_scheduler.append(arg)
|
||||
@@ -592,26 +581,12 @@ class Worker(object):
|
||||
if driver_id is None:
|
||||
driver_id = self.task_driver_id
|
||||
|
||||
# Look up the various function properties.
|
||||
function_properties = self.function_properties[driver_id.id()][
|
||||
function_id.id()]
|
||||
|
||||
if num_return_vals is None:
|
||||
num_return_vals = function_properties.num_return_vals
|
||||
|
||||
if resources is None and num_cpus is None and num_gpus is None:
|
||||
resources = function_properties.resources
|
||||
else:
|
||||
resources = {} if resources is None else resources
|
||||
if "CPU" in resources or "GPU" in resources:
|
||||
raise ValueError("The resources dictionary must not "
|
||||
"contain the keys 'CPU' or 'GPU'")
|
||||
resources["CPU"] = num_cpus
|
||||
resources["GPU"] = num_gpus
|
||||
if resources is None:
|
||||
raise ValueError("The resources dictionary is required.")
|
||||
|
||||
# Submit the task to local scheduler.
|
||||
task = ray.local_scheduler.Task(
|
||||
driver_id, ray.local_scheduler.ObjectID(
|
||||
driver_id, ray.ObjectID(
|
||||
function_id.id()), args_for_local_scheduler,
|
||||
num_return_vals, self.current_task_id, self.task_index,
|
||||
actor_creation_id, actor_creation_dummy_object_id, actor_id,
|
||||
@@ -624,6 +599,55 @@ class Worker(object):
|
||||
|
||||
return task.returns()
|
||||
|
||||
def export_remote_function(self, function_id, function_name, function,
|
||||
max_calls, decorated_function):
|
||||
"""Export a remote function.
|
||||
|
||||
Args:
|
||||
function_id: The ID of the function.
|
||||
function_name: The name of the function.
|
||||
function: The raw undecorated function to export.
|
||||
max_calls: The maximum number of times a given worker can execute
|
||||
this function before exiting.
|
||||
decorated_function: The decorated function (this is used to enable
|
||||
the remote function to recursively call itself).
|
||||
"""
|
||||
check_main_thread()
|
||||
if self.mode not in [SCRIPT_MODE, SILENT_MODE]:
|
||||
raise Exception("export_remote_function can only be called on a "
|
||||
"driver.")
|
||||
|
||||
key = (b"RemoteFunction:" + self.task_driver_id.id() + b":" +
|
||||
function_id.id())
|
||||
|
||||
# Work around limitations of Python pickling.
|
||||
function_name_global_valid = function.__name__ in function.__globals__
|
||||
function_name_global_value = function.__globals__.get(
|
||||
function.__name__)
|
||||
# Allow the function to reference itself as a global variable
|
||||
if not is_cython(function):
|
||||
function.__globals__[function.__name__] = decorated_function
|
||||
try:
|
||||
pickled_function = pickle.dumps(function)
|
||||
finally:
|
||||
# Undo our changes
|
||||
if function_name_global_valid:
|
||||
function.__globals__[function.__name__] = (
|
||||
function_name_global_value)
|
||||
else:
|
||||
del function.__globals__[function.__name__]
|
||||
|
||||
self.redis_client.hmset(
|
||||
key, {
|
||||
"driver_id": self.task_driver_id.id(),
|
||||
"function_id": function_id.id(),
|
||||
"name": function_name,
|
||||
"module": function.__module__,
|
||||
"function": pickled_function,
|
||||
"max_calls": max_calls
|
||||
})
|
||||
self.redis_client.rpush("Exports", key)
|
||||
|
||||
def run_function_on_all_workers(self, function):
|
||||
"""Run arbitrary code on all of the workers.
|
||||
|
||||
@@ -697,7 +721,8 @@ class Worker(object):
|
||||
while True:
|
||||
with self.lock:
|
||||
if (self.actor_id == NIL_ACTOR_ID
|
||||
and (function_id.id() in self.functions[driver_id])):
|
||||
and (function_id.id() in
|
||||
self.function_execution_info[driver_id])):
|
||||
break
|
||||
elif self.actor_id != NIL_ACTOR_ID and (
|
||||
self.actor_id in self.actors):
|
||||
@@ -741,7 +766,7 @@ class Worker(object):
|
||||
"""
|
||||
arguments = []
|
||||
for (i, arg) in enumerate(serialized_args):
|
||||
if isinstance(arg, ray.local_scheduler.ObjectID):
|
||||
if isinstance(arg, ray.ObjectID):
|
||||
# get the object from the local object store
|
||||
argument = self.get_object([arg])[0]
|
||||
if isinstance(argument, RayTaskError):
|
||||
@@ -798,7 +823,6 @@ class Worker(object):
|
||||
# message to the correct driver.
|
||||
self.task_driver_id = task.driver_id()
|
||||
self.current_task_id = task.task_id()
|
||||
self.current_function_id = task.function_id().id()
|
||||
self.task_index = 0
|
||||
self.put_index = 1
|
||||
function_id = task.function_id()
|
||||
@@ -806,8 +830,10 @@ class Worker(object):
|
||||
return_object_ids = task.returns()
|
||||
if task.actor_id().id() != NIL_ACTOR_ID:
|
||||
dummy_return_id = return_object_ids.pop()
|
||||
function_name, function_executor = (
|
||||
self.functions[self.task_driver_id.id()][function_id.id()])
|
||||
function_executor = self.function_execution_info[
|
||||
self.task_driver_id.id()][function_id.id()].function
|
||||
function_name = self.function_execution_info[self.task_driver_id.id()][
|
||||
function_id.id()].function_name
|
||||
|
||||
# Get task arguments from the object store.
|
||||
try:
|
||||
@@ -829,7 +855,7 @@ class Worker(object):
|
||||
try:
|
||||
with log_span("ray:task:execute", worker=self):
|
||||
if task.actor_id().id() == NIL_ACTOR_ID:
|
||||
outputs = function_executor.executor(arguments)
|
||||
outputs = function_executor(*arguments)
|
||||
else:
|
||||
outputs = function_executor(
|
||||
dummy_return_id, self.actors[task.actor_id().id()],
|
||||
@@ -862,8 +888,8 @@ class Worker(object):
|
||||
|
||||
def _handle_process_task_failure(self, function_id, return_object_ids,
|
||||
error, backtrace):
|
||||
function_name, _ = self.functions[self.task_driver_id.id()][
|
||||
function_id.id()]
|
||||
function_name = self.function_execution_info[self.task_driver_id.id()][
|
||||
function_id.id()].function_name
|
||||
failure_object = RayTaskError(function_name, error, backtrace)
|
||||
failure_objects = [
|
||||
failure_object for _ in range(len(return_object_ids))
|
||||
@@ -902,7 +928,7 @@ class Worker(object):
|
||||
time.sleep(0.001)
|
||||
|
||||
with self.lock:
|
||||
self.fetch_and_register_actor(key, task.required_resources(), self)
|
||||
self.fetch_and_register_actor(key, self)
|
||||
|
||||
def _wait_for_and_process_task(self, task):
|
||||
"""Wait for a task to be ready and process the task.
|
||||
@@ -911,11 +937,11 @@ class Worker(object):
|
||||
task: The task to execute.
|
||||
"""
|
||||
function_id = task.function_id()
|
||||
driver_id = task.driver_id().id()
|
||||
|
||||
# TODO(rkn): It would be preferable for actor creation tasks to share
|
||||
# more of the code path with regular task execution.
|
||||
if (task.actor_creation_id() !=
|
||||
ray.local_scheduler.ObjectID(NIL_ACTOR_ID)):
|
||||
if (task.actor_creation_id() != ray.ObjectID(NIL_ACTOR_ID)):
|
||||
self._become_actor(task)
|
||||
return
|
||||
|
||||
@@ -923,7 +949,7 @@ class Worker(object):
|
||||
# on this worker. We will push warnings to the user if we spend too
|
||||
# long in this loop.
|
||||
with log_span("ray:wait_for_function", worker=self):
|
||||
self._wait_for_function(function_id, task.driver_id().id())
|
||||
self._wait_for_function(function_id, driver_id)
|
||||
|
||||
# Execute the task.
|
||||
# TODO(rkn): Consider acquiring this lock with a timeout and pushing a
|
||||
@@ -934,8 +960,8 @@ class Worker(object):
|
||||
with self.lock:
|
||||
log(event_type="ray:acquire_lock", kind=LOG_SPAN_END, worker=self)
|
||||
|
||||
function_name, _ = (
|
||||
self.functions[task.driver_id().id()][function_id.id()])
|
||||
function_name = (self.function_execution_info[driver_id][
|
||||
function_id.id()]).function_name
|
||||
contents = {
|
||||
"function_name": function_name,
|
||||
"task_id": task.task_id().hex(),
|
||||
@@ -948,14 +974,13 @@ class Worker(object):
|
||||
flush_log()
|
||||
|
||||
# Increase the task execution counter.
|
||||
(self.num_task_executions[task.driver_id().id()][function_id.id()]
|
||||
) += 1
|
||||
self.num_task_executions[driver_id][function_id.id()] += 1
|
||||
|
||||
reached_max_executions = (self.num_task_executions[task.driver_id().id(
|
||||
)][function_id.id()] == self.function_properties[task.driver_id().id()]
|
||||
[function_id.id()].max_calls)
|
||||
reached_max_executions = (
|
||||
self.num_task_executions[driver_id][function_id.id()] == self.
|
||||
function_execution_info[driver_id][function_id.id()].max_calls)
|
||||
if reached_max_executions:
|
||||
ray.worker.global_worker.local_scheduler_client.disconnect()
|
||||
self.local_scheduler_client.disconnect()
|
||||
os._exit(0)
|
||||
|
||||
def _get_next_task_from_local_scheduler(self):
|
||||
@@ -1069,18 +1094,6 @@ def check_main_thread():
|
||||
.format(threading.current_thread().getName()))
|
||||
|
||||
|
||||
def check_connected(worker=global_worker):
|
||||
"""Check if the worker is connected.
|
||||
|
||||
Raises:
|
||||
Exception: An exception is raised if the worker is not connected.
|
||||
"""
|
||||
if not worker.connected:
|
||||
raise RayConnectionError("This command cannot be called before Ray "
|
||||
"has been started. You can start Ray with "
|
||||
"'ray.init()'.")
|
||||
|
||||
|
||||
def print_failed_task(task_status):
|
||||
"""Print information about failed tasks.
|
||||
|
||||
@@ -1114,7 +1127,7 @@ def error_applies_to_driver(error_key, worker=global_worker):
|
||||
|
||||
def error_info(worker=global_worker):
|
||||
"""Return information about failed tasks."""
|
||||
check_connected(worker)
|
||||
worker.check_connected()
|
||||
check_main_thread()
|
||||
error_keys = worker.redis_client.lrange("ErrorKeys", 0, -1)
|
||||
errors = []
|
||||
@@ -1143,13 +1156,13 @@ def _initialize_serialization(worker=global_worker):
|
||||
return obj.id()
|
||||
|
||||
def object_id_custom_deserializer(serialized_obj):
|
||||
return ray.local_scheduler.ObjectID(serialized_obj)
|
||||
return ray.ObjectID(serialized_obj)
|
||||
|
||||
# We register this serializer on each worker instead of calling
|
||||
# register_custom_serializer from the driver so that isinstance still
|
||||
# works.
|
||||
worker.serialization_context.register_type(
|
||||
ray.local_scheduler.ObjectID,
|
||||
ray.ObjectID,
|
||||
"ray.ObjectID",
|
||||
pickle=False,
|
||||
custom_serializer=object_id_custom_serializer,
|
||||
@@ -1786,12 +1799,9 @@ def fetch_and_register_remote_function(key, worker=global_worker):
|
||||
"driver_id", "function_id", "name", "function", "num_return_vals",
|
||||
"module", "resources", "max_calls"
|
||||
])
|
||||
function_id = ray.local_scheduler.ObjectID(function_id_str)
|
||||
function_id = ray.ObjectID(function_id_str)
|
||||
function_name = function_name.decode("ascii")
|
||||
function_properties = FunctionProperties(
|
||||
num_return_vals=int(num_return_vals),
|
||||
resources=json.loads(resources.decode("ascii")),
|
||||
max_calls=int(max_calls))
|
||||
max_calls = int(max_calls)
|
||||
module = module.decode("ascii")
|
||||
|
||||
# This is a placeholder in case the function can't be unpickled. This will
|
||||
@@ -1799,11 +1809,9 @@ def fetch_and_register_remote_function(key, worker=global_worker):
|
||||
def f():
|
||||
raise Exception("This function was not imported properly.")
|
||||
|
||||
remote_f_placeholder = remote(function_id=function_id)(lambda *xs: f())
|
||||
worker.functions[driver_id][function_id.id()] = (function_name,
|
||||
remote_f_placeholder)
|
||||
worker.function_properties[driver_id][function_id.id()] = (
|
||||
function_properties)
|
||||
worker.function_execution_info[driver_id][function_id.id()] = (
|
||||
FunctionExecutionInfo(
|
||||
function=f, function_name=function_name, max_calls=max_calls))
|
||||
worker.num_task_executions[driver_id][function_id.id()] = 0
|
||||
|
||||
try:
|
||||
@@ -1825,8 +1833,11 @@ def fetch_and_register_remote_function(key, worker=global_worker):
|
||||
else:
|
||||
# TODO(rkn): Why is the below line necessary?
|
||||
function.__module__ = module
|
||||
worker.functions[driver_id][function_id.id()] = (
|
||||
function_name, remote(function_id=function_id)(function))
|
||||
worker.function_execution_info[driver_id][function_id.id()] = (
|
||||
FunctionExecutionInfo(
|
||||
function=function,
|
||||
function_name=function_name,
|
||||
max_calls=max_calls))
|
||||
# Add the function to the function table.
|
||||
worker.redis_client.rpush(b"FunctionTable:" + function_id.id(),
|
||||
worker.worker_id)
|
||||
@@ -1973,6 +1984,14 @@ def connect(info,
|
||||
assert worker.cached_remote_functions_and_actors is not None, error_message
|
||||
# Initialize some fields.
|
||||
worker.worker_id = random_string()
|
||||
|
||||
# When tasks are executed on remote workers in the context of multiple
|
||||
# drivers, the task driver ID is used to keep track of which driver is
|
||||
# responsible for the task so that error messages will be propagated to
|
||||
# the correct driver.
|
||||
if mode != WORKER_MODE:
|
||||
worker.task_driver_id = ray.ObjectID(worker.worker_id)
|
||||
|
||||
# All workers start out as non-actors. A worker can be turned into an actor
|
||||
# after it is created.
|
||||
worker.actor_id = NIL_ACTOR_ID
|
||||
@@ -2102,13 +2121,7 @@ def connect(info,
|
||||
else:
|
||||
# Try to use true randomness.
|
||||
np.random.seed(None)
|
||||
worker.current_task_id = ray.local_scheduler.ObjectID(
|
||||
np.random.bytes(20))
|
||||
# When tasks are executed on remote workers in the context of multiple
|
||||
# drivers, the task driver ID is used to keep track of which driver is
|
||||
# responsible for the task so that error messages will be propagated to
|
||||
# the correct driver.
|
||||
worker.task_driver_id = ray.local_scheduler.ObjectID(worker.worker_id)
|
||||
worker.current_task_id = ray.ObjectID(np.random.bytes(20))
|
||||
# Reset the state of the numpy random number generator.
|
||||
np.random.set_state(numpy_state)
|
||||
# Set other fields needed for computing task IDs.
|
||||
@@ -2124,14 +2137,11 @@ def connect(info,
|
||||
nil_actor_counter = 0
|
||||
|
||||
driver_task = ray.local_scheduler.Task(
|
||||
worker.task_driver_id,
|
||||
ray.local_scheduler.ObjectID(NIL_FUNCTION_ID), [], 0,
|
||||
worker.task_driver_id, ray.ObjectID(NIL_FUNCTION_ID), [], 0,
|
||||
worker.current_task_id, worker.task_index,
|
||||
ray.local_scheduler.ObjectID(NIL_ACTOR_ID),
|
||||
ray.local_scheduler.ObjectID(NIL_ACTOR_ID),
|
||||
ray.local_scheduler.ObjectID(NIL_ACTOR_ID),
|
||||
ray.local_scheduler.ObjectID(NIL_ACTOR_ID), nil_actor_counter,
|
||||
False, [], {"CPU": 0}, worker.use_raylet)
|
||||
ray.ObjectID(NIL_ACTOR_ID), ray.ObjectID(NIL_ACTOR_ID),
|
||||
ray.ObjectID(NIL_ACTOR_ID), ray.ObjectID(NIL_ACTOR_ID),
|
||||
nil_actor_counter, False, [], {"CPU": 0}, worker.use_raylet)
|
||||
global_state._execute_command(
|
||||
driver_task.task_id(), "RAY.TASK_TABLE_ADD",
|
||||
driver_task.task_id().id(),
|
||||
@@ -2194,11 +2204,7 @@ def connect(info,
|
||||
# Export cached remote functions to the workers.
|
||||
for cached_type, info in worker.cached_remote_functions_and_actors:
|
||||
if cached_type == "remote_function":
|
||||
(function_id, func_name, func, func_invoker,
|
||||
function_properties) = info
|
||||
export_remote_function(function_id, func_name, func,
|
||||
func_invoker, function_properties,
|
||||
worker)
|
||||
info._export()
|
||||
elif cached_type == "actor":
|
||||
(key, actor_class_info) = info
|
||||
ray.actor.publish_actor_class_to_key(key, actor_class_info,
|
||||
@@ -2450,7 +2456,7 @@ def get(object_ids, worker=global_worker):
|
||||
Returns:
|
||||
A Python object or a list of Python objects.
|
||||
"""
|
||||
check_connected(worker)
|
||||
worker.check_connected()
|
||||
with log_span("ray:get", worker=worker):
|
||||
check_main_thread()
|
||||
|
||||
@@ -2483,7 +2489,7 @@ def put(value, worker=global_worker):
|
||||
Returns:
|
||||
The object ID assigned to this value.
|
||||
"""
|
||||
check_connected(worker)
|
||||
worker.check_connected()
|
||||
with log_span("ray:put", worker=worker):
|
||||
check_main_thread()
|
||||
|
||||
@@ -2524,7 +2530,7 @@ def wait(object_ids, num_returns=1, timeout=None, worker=global_worker):
|
||||
print("plasma_client.wait has not been implemented yet")
|
||||
return
|
||||
|
||||
if isinstance(object_ids, ray.local_scheduler.ObjectID):
|
||||
if isinstance(object_ids, ray.ObjectID):
|
||||
raise TypeError(
|
||||
"wait() expected a list of ObjectID, got a single ObjectID")
|
||||
|
||||
@@ -2534,12 +2540,12 @@ def wait(object_ids, num_returns=1, timeout=None, worker=global_worker):
|
||||
|
||||
if worker.mode != PYTHON_MODE:
|
||||
for object_id in object_ids:
|
||||
if not isinstance(object_id, ray.local_scheduler.ObjectID):
|
||||
if not isinstance(object_id, ray.ObjectID):
|
||||
raise TypeError("wait() expected a list of ObjectID, "
|
||||
"got list containing {}".format(
|
||||
type(object_id)))
|
||||
|
||||
check_connected(worker)
|
||||
worker.check_connected()
|
||||
with log_span("ray:wait", worker=worker):
|
||||
check_main_thread()
|
||||
|
||||
@@ -2561,27 +2567,14 @@ def wait(object_ids, num_returns=1, timeout=None, worker=global_worker):
|
||||
ready_ids, remaining_ids = worker.plasma_client.wait(
|
||||
object_id_strs, timeout, num_returns)
|
||||
ready_ids = [
|
||||
ray.local_scheduler.ObjectID(object_id.binary())
|
||||
for object_id in ready_ids
|
||||
ray.ObjectID(object_id.binary()) for object_id in ready_ids
|
||||
]
|
||||
remaining_ids = [
|
||||
ray.local_scheduler.ObjectID(object_id.binary())
|
||||
for object_id in remaining_ids
|
||||
ray.ObjectID(object_id.binary()) for object_id in remaining_ids
|
||||
]
|
||||
return ready_ids, remaining_ids
|
||||
|
||||
|
||||
def _submit_task(function_id, *args, **kwargs):
|
||||
"""This is a wrapper around worker.submit_task.
|
||||
|
||||
We use this wrapper so that in the remote decorator, we can call
|
||||
_submit_task instead of worker.submit_task. The difference is that when we
|
||||
attempt to serialize remote functions, we don't attempt to serialize the
|
||||
worker object, which cannot be serialized.
|
||||
"""
|
||||
return global_worker.submit_task(function_id, *args, **kwargs)
|
||||
|
||||
|
||||
def _mode(worker=global_worker):
|
||||
"""This is a wrapper around worker.mode.
|
||||
|
||||
@@ -2593,278 +2586,104 @@ def _mode(worker=global_worker):
|
||||
return worker.mode
|
||||
|
||||
|
||||
def export_remote_function(function_id,
|
||||
func_name,
|
||||
func,
|
||||
func_invoker,
|
||||
function_properties,
|
||||
worker=global_worker):
|
||||
check_main_thread()
|
||||
if _mode(worker) not in [SCRIPT_MODE, SILENT_MODE]:
|
||||
raise Exception("export_remote_function can only be called on a "
|
||||
"driver.")
|
||||
|
||||
worker.function_properties[worker.task_driver_id.id()][
|
||||
function_id.id()] = function_properties
|
||||
task_driver_id = worker.task_driver_id
|
||||
key = b"RemoteFunction:" + task_driver_id.id() + b":" + function_id.id()
|
||||
|
||||
# Work around limitations of Python pickling.
|
||||
func_name_global_valid = func.__name__ in func.__globals__
|
||||
func_name_global_value = func.__globals__.get(func.__name__)
|
||||
# Allow the function to reference itself as a global variable
|
||||
if not is_cython(func):
|
||||
func.__globals__[func.__name__] = func_invoker
|
||||
try:
|
||||
pickled_func = pickle.dumps(func)
|
||||
finally:
|
||||
# Undo our changes
|
||||
if func_name_global_valid:
|
||||
func.__globals__[func.__name__] = func_name_global_value
|
||||
else:
|
||||
del func.__globals__[func.__name__]
|
||||
|
||||
worker.redis_client.hmset(
|
||||
key, {
|
||||
"driver_id": worker.task_driver_id.id(),
|
||||
"function_id": function_id.id(),
|
||||
"name": func_name,
|
||||
"module": func.__module__,
|
||||
"function": pickled_func,
|
||||
"num_return_vals": function_properties.num_return_vals,
|
||||
"resources": json.dumps(function_properties.resources),
|
||||
"max_calls": function_properties.max_calls
|
||||
})
|
||||
worker.redis_client.rpush("Exports", key)
|
||||
def get_global_worker():
|
||||
return global_worker
|
||||
|
||||
|
||||
def in_ipython():
|
||||
"""Return true if we are in an IPython interpreter and false otherwise."""
|
||||
try:
|
||||
__IPYTHON__
|
||||
return True
|
||||
except NameError:
|
||||
return False
|
||||
def make_decorator(num_return_vals=None,
|
||||
num_cpus=None,
|
||||
num_gpus=None,
|
||||
resources=None,
|
||||
max_calls=None,
|
||||
checkpoint_interval=None,
|
||||
worker=None):
|
||||
def decorator(function_or_class):
|
||||
if (inspect.isfunction(function_or_class)
|
||||
or is_cython(function_or_class)):
|
||||
# Set the remote function default resources.
|
||||
if checkpoint_interval is not None:
|
||||
raise Exception("The keyword 'checkpoint_interval' is not "
|
||||
"allowed for remote functions.")
|
||||
|
||||
return ray.remote_function.RemoteFunction(
|
||||
function_or_class, num_cpus, num_gpus, resources,
|
||||
num_return_vals, max_calls)
|
||||
|
||||
def compute_function_id(func_name, func):
|
||||
"""Compute an function ID for a function.
|
||||
if inspect.isclass(function_or_class):
|
||||
if num_return_vals is not None:
|
||||
raise Exception("The keyword 'num_return_vals' is not allowed "
|
||||
"for actors.")
|
||||
if max_calls is not None:
|
||||
raise Exception("The keyword 'max_calls' is not allowed for "
|
||||
"actors.")
|
||||
|
||||
Args:
|
||||
func_name: The name of the function (this includes the module name plus
|
||||
the function name).
|
||||
func: The actual function.
|
||||
# Set the actor default resources.
|
||||
if num_cpus is None and num_gpus is None and resources is None:
|
||||
# In the default case, actors acquire no resources for
|
||||
# their lifetime, and actor methods will require 1 CPU.
|
||||
cpus_to_use = DEFAULT_ACTOR_CREATION_CPUS_SIMPLE_CASE
|
||||
actor_method_cpus = DEFAULT_ACTOR_METHOD_CPUS_SIMPLE_CASE
|
||||
else:
|
||||
# If any resources are specified, then all resources are
|
||||
# acquired for the actor's lifetime and no resources are
|
||||
# associated with methods.
|
||||
cpus_to_use = (DEFAULT_ACTOR_CREATION_CPUS_SPECIFIED_CASE
|
||||
if num_cpus is None else num_cpus)
|
||||
actor_method_cpus = DEFAULT_ACTOR_METHOD_CPUS_SPECIFIED_CASE
|
||||
|
||||
Returns:
|
||||
This returns the function ID.
|
||||
"""
|
||||
function_id_hash = hashlib.sha1()
|
||||
# Include the function name in the hash.
|
||||
function_id_hash.update(func_name.encode("ascii"))
|
||||
# If we are running a script or are in IPython, include the source code in
|
||||
# the hash. If we are in a regular Python interpreter we skip this part
|
||||
# because the source code is not accessible. If the function is a built-in
|
||||
# (e.g., Cython), the source code is not accessible.
|
||||
import __main__ as main
|
||||
if (hasattr(main, "__file__") or in_ipython()) \
|
||||
and inspect.isfunction(func):
|
||||
function_id_hash.update(inspect.getsource(func).encode("ascii"))
|
||||
# Compute the function ID.
|
||||
function_id = function_id_hash.digest()
|
||||
assert len(function_id) == 20
|
||||
function_id = FunctionID(function_id)
|
||||
return worker.make_actor(function_or_class, cpus_to_use, num_gpus,
|
||||
resources, actor_method_cpus,
|
||||
checkpoint_interval)
|
||||
|
||||
return function_id
|
||||
raise Exception("The @ray.remote decorator must be applied to "
|
||||
"either a function or to a class.")
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def remote(*args, **kwargs):
|
||||
"""This decorator is used to define remote functions and to define actors.
|
||||
|
||||
Args:
|
||||
num_return_vals (int): The number of object IDs that a call to this
|
||||
function should return.
|
||||
num_cpus (int): The number of CPUs needed to execute this function.
|
||||
num_gpus (int): The number of GPUs needed to execute this function.
|
||||
resources: A dictionary mapping resource name to the required quantity
|
||||
of that resource.
|
||||
max_calls (int): The maximum number of tasks of this kind that can be
|
||||
run on a worker before the worker needs to be restarted.
|
||||
checkpoint_interval (int): The number of tasks to run between
|
||||
checkpoints of the actor state.
|
||||
"""
|
||||
worker = global_worker
|
||||
|
||||
def make_remote_decorator(num_return_vals,
|
||||
num_cpus,
|
||||
num_gpus,
|
||||
resources,
|
||||
max_calls,
|
||||
checkpoint_interval,
|
||||
func_id=None):
|
||||
def remote_decorator(func_or_class):
|
||||
if inspect.isfunction(func_or_class) or is_cython(func_or_class):
|
||||
# Set the remote function default resources.
|
||||
resources["CPU"] = (DEFAULT_REMOTE_FUNCTION_CPUS
|
||||
if num_cpus is None else num_cpus)
|
||||
resources["GPU"] = (DEFAULT_REMOTE_FUNCTION_GPUS
|
||||
if num_gpus is None else num_gpus)
|
||||
|
||||
function_properties = FunctionProperties(
|
||||
num_return_vals=num_return_vals,
|
||||
resources=resources,
|
||||
max_calls=max_calls)
|
||||
return remote_function_decorator(func_or_class,
|
||||
function_properties)
|
||||
if inspect.isclass(func_or_class):
|
||||
# Set the actor default resources.
|
||||
if num_cpus is None and num_gpus is None and resources == {}:
|
||||
# In the default case, actors acquire no resources for
|
||||
# their lifetime, and actor methods will require 1 CPU.
|
||||
resources["CPU"] = DEFAULT_ACTOR_CREATION_CPUS_SIMPLE_CASE
|
||||
actor_method_cpus = DEFAULT_ACTOR_METHOD_CPUS_SIMPLE_CASE
|
||||
else:
|
||||
# If any resources are specified, then all resources are
|
||||
# acquired for the actor's lifetime and no resources are
|
||||
# associated with methods.
|
||||
resources["CPU"] = (
|
||||
DEFAULT_ACTOR_CREATION_CPUS_SPECIFIED_CASE
|
||||
if num_cpus is None else num_cpus)
|
||||
resources["GPU"] = (
|
||||
DEFAULT_ACTOR_CREATION_GPUS_SPECIFIED_CASE
|
||||
if num_gpus is None else num_gpus)
|
||||
actor_method_cpus = (
|
||||
DEFAULT_ACTOR_METHOD_CPUS_SPECIFIED_CASE)
|
||||
|
||||
return worker.make_actor(func_or_class, resources,
|
||||
checkpoint_interval,
|
||||
actor_method_cpus)
|
||||
raise Exception("The @ray.remote decorator must be applied to "
|
||||
"either a function or to a class.")
|
||||
|
||||
def remote_function_decorator(func, function_properties):
|
||||
func_name = "{}.{}".format(func.__module__, func.__name__)
|
||||
if func_id is None:
|
||||
function_id = compute_function_id(func_name, func)
|
||||
else:
|
||||
function_id = func_id
|
||||
|
||||
def func_call(*args, **kwargs):
|
||||
"""This runs immediately when a remote function is called."""
|
||||
return _submit(args=args, kwargs=kwargs)
|
||||
|
||||
def _submit(args=None,
|
||||
kwargs=None,
|
||||
num_return_vals=None,
|
||||
num_cpus=None,
|
||||
num_gpus=None,
|
||||
resources=None):
|
||||
"""An experimental alternate way to submit remote functions."""
|
||||
check_connected()
|
||||
check_main_thread()
|
||||
kwargs = {} if kwargs is None else kwargs
|
||||
args = signature.extend_args(function_signature, args, kwargs)
|
||||
|
||||
if _mode() == PYTHON_MODE:
|
||||
# In PYTHON_MODE, remote calls simply execute the function.
|
||||
# We copy the arguments to prevent the function call from
|
||||
# mutating them and to match the usual behavior of
|
||||
# immutable remote objects.
|
||||
result = func(*copy.deepcopy(args))
|
||||
return result
|
||||
object_ids = _submit_task(
|
||||
function_id,
|
||||
args,
|
||||
num_return_vals=num_return_vals,
|
||||
num_cpus=num_cpus,
|
||||
num_gpus=num_gpus,
|
||||
resources=resources)
|
||||
if len(object_ids) == 1:
|
||||
return object_ids[0]
|
||||
elif len(object_ids) > 1:
|
||||
return object_ids
|
||||
|
||||
def func_executor(arguments):
|
||||
"""This gets run when the remote function is executed."""
|
||||
result = func(*arguments)
|
||||
return result
|
||||
|
||||
def func_invoker(*args, **kwargs):
|
||||
"""This is used to invoke the function."""
|
||||
raise Exception("Remote functions cannot be called directly. "
|
||||
"Instead of running '{}()', try '{}.remote()'."
|
||||
.format(func_name, func_name))
|
||||
|
||||
func_invoker.remote = func_call
|
||||
func_invoker._submit = _submit
|
||||
func_invoker.executor = func_executor
|
||||
func_invoker.is_remote = True
|
||||
func_name = "{}.{}".format(func.__module__, func.__name__)
|
||||
func_invoker.func_name = func_name
|
||||
if sys.version_info >= (3, 0) or is_cython(func):
|
||||
func_invoker.__doc__ = func.__doc__
|
||||
else:
|
||||
func_invoker.func_doc = func.func_doc
|
||||
|
||||
signature.check_signature_supported(func)
|
||||
function_signature = signature.extract_signature(func)
|
||||
|
||||
# Everything ready - export the function
|
||||
if worker.mode in [SCRIPT_MODE, SILENT_MODE]:
|
||||
export_remote_function(function_id, func_name, func,
|
||||
func_invoker, function_properties)
|
||||
elif worker.mode is None:
|
||||
worker.cached_remote_functions_and_actors.append(
|
||||
("remote_function", (function_id, func_name, func,
|
||||
func_invoker, function_properties)))
|
||||
return func_invoker
|
||||
|
||||
return remote_decorator
|
||||
|
||||
# Handle resource arguments
|
||||
num_cpus = kwargs["num_cpus"] if "num_cpus" in kwargs else None
|
||||
num_gpus = kwargs["num_gpus"] if "num_gpus" in kwargs else None
|
||||
resources = kwargs.get("resources", {})
|
||||
if not isinstance(resources, dict):
|
||||
raise Exception("The 'resources' keyword argument must be a "
|
||||
"dictionary, but received type {}.".format(
|
||||
type(resources)))
|
||||
assert "CPU" not in resources, "Use the 'num_cpus' argument."
|
||||
assert "GPU" not in resources, "Use the 'num_gpus' argument."
|
||||
# Handle other arguments.
|
||||
num_return_vals = (kwargs["num_return_vals"]
|
||||
if "num_return_vals" in kwargs else 1)
|
||||
max_calls = kwargs["max_calls"] if "max_calls" in kwargs else 0
|
||||
checkpoint_interval = (kwargs["checkpoint_interval"]
|
||||
if "checkpoint_interval" in kwargs else -1)
|
||||
|
||||
if _mode() == WORKER_MODE:
|
||||
if "function_id" in kwargs:
|
||||
function_id = kwargs["function_id"]
|
||||
return make_remote_decorator(num_return_vals, num_cpus, num_gpus,
|
||||
resources, max_calls,
|
||||
checkpoint_interval, function_id)
|
||||
worker = get_global_worker()
|
||||
|
||||
if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
|
||||
# This is the case where the decorator is just @ray.remote.
|
||||
return make_remote_decorator(num_return_vals, num_cpus, num_gpus,
|
||||
resources, max_calls,
|
||||
checkpoint_interval)(args[0])
|
||||
else:
|
||||
# This is the case where the decorator is something like
|
||||
# @ray.remote(num_return_vals=2).
|
||||
error_string = ("The @ray.remote decorator must be applied either "
|
||||
"with no arguments and no parentheses, for example "
|
||||
"'@ray.remote', or it must be applied using some of "
|
||||
"the arguments 'num_return_vals', 'resources', "
|
||||
"or 'max_calls', like "
|
||||
"'@ray.remote(num_return_vals=2, "
|
||||
"resources={\"GPU\": 1})'.")
|
||||
assert len(args) == 0 and len(kwargs) > 0, error_string
|
||||
for key in kwargs:
|
||||
assert key in [
|
||||
"num_return_vals", "num_cpus", "num_gpus", "resources",
|
||||
"max_calls", "checkpoint_interval"
|
||||
], error_string
|
||||
assert "function_id" not in kwargs
|
||||
return make_remote_decorator(num_return_vals, num_cpus, num_gpus,
|
||||
resources, max_calls, checkpoint_interval)
|
||||
return make_decorator(worker=worker)(args[0])
|
||||
|
||||
# Parse the keyword arguments from the decorator.
|
||||
error_string = ("The @ray.remote decorator must be applied either "
|
||||
"with no arguments and no parentheses, for example "
|
||||
"'@ray.remote', or it must be applied using some of "
|
||||
"the arguments 'num_return_vals', 'num_cpus', 'num_gpus', "
|
||||
"'resources', 'max_calls', or 'checkpoint_interval', like "
|
||||
"'@ray.remote(num_return_vals=2, "
|
||||
"resources={\"CustomResource\": 1})'.")
|
||||
assert len(args) == 0 and len(kwargs) > 0, error_string
|
||||
for key in kwargs:
|
||||
assert key in [
|
||||
"num_return_vals", "num_cpus", "num_gpus", "resources",
|
||||
"max_calls", "checkpoint_interval"
|
||||
], error_string
|
||||
|
||||
num_cpus = kwargs["num_cpus"] if "num_cpus" in kwargs else None
|
||||
num_gpus = kwargs["num_gpus"] if "num_gpus" in kwargs else None
|
||||
resources = kwargs.get("resources")
|
||||
if not isinstance(resources, dict) and resources is not None:
|
||||
raise Exception("The 'resources' keyword argument must be a "
|
||||
"dictionary, but received type {}.".format(
|
||||
type(resources)))
|
||||
if resources is not None:
|
||||
assert "CPU" not in resources, "Use the 'num_cpus' argument."
|
||||
assert "GPU" not in resources, "Use the 'num_gpus' argument."
|
||||
|
||||
# Handle other arguments.
|
||||
num_return_vals = kwargs.get("num_return_vals")
|
||||
max_calls = kwargs.get("max_calls")
|
||||
checkpoint_interval = kwargs.get("checkpoint_interval")
|
||||
|
||||
return make_decorator(
|
||||
num_return_vals=num_return_vals,
|
||||
num_cpus=num_cpus,
|
||||
num_gpus=num_gpus,
|
||||
resources=resources,
|
||||
max_calls=max_calls,
|
||||
checkpoint_interval=checkpoint_interval,
|
||||
worker=worker)
|
||||
|
||||
Reference in New Issue
Block a user