Create RemoteFunction class, remove FunctionProperties, simplify worker Python code. (#2052)

* Cleaning up worker and actor code. Create remote function class. Remove FunctionProperties object.

* Remove register_actor_signatures function.

* Small cleanups.

* Fix linting.

* Support @ray.method syntax for actor methods.

* Fix pickling bug.

* Fix linting.

* Shorten testBlockingTasks.

* Small fixes.

* Call get_global_worker().
This commit is contained in:
Robert Nishihara
2018-05-14 14:35:23 -07:00
committed by Philipp Moritz
parent ad48e47120
commit 8fbb88485b
9 changed files with 623 additions and 657 deletions
+230 -411
View File
@@ -5,7 +5,6 @@ from __future__ import print_function
import atexit
import collections
import colorama
import copy
import hashlib
import inspect
import json
@@ -23,13 +22,13 @@ import pyarrow
import pyarrow.plasma as plasma
import ray.cloudpickle as pickle
import ray.experimental.state as state
import ray.remote_function
import ray.serialization as serialization
import ray.services as services
import ray.signature as signature
import ray.signature
import ray.local_scheduler
import ray.plasma
from ray.utils import (FunctionProperties, random_string, binary_to_hex,
is_cython)
from ray.utils import random_string, binary_to_hex, is_cython
# Import flatbuffer bindings.
from ray.core.generated.ClientTableData import ClientTableData
@@ -63,9 +62,6 @@ PUT_RECONSTRUCTION_ERROR_TYPE = b"put_reconstruction"
# This must be kept in sync with the `scheduling_state` enum in common/task.h.
TASK_STATUS_RUNNING = 8
# Default resource requirements for remote functions.
DEFAULT_REMOTE_FUNCTION_CPUS = 1
DEFAULT_REMOTE_FUNCTION_GPUS = 0
# Default resource requirements for actors when no resource requirements are
# specified.
DEFAULT_ACTOR_METHOD_CPUS_SIMPLE_CASE = 1
@@ -74,15 +70,6 @@ DEFAULT_ACTOR_CREATION_CPUS_SIMPLE_CASE = 0
# specified.
DEFAULT_ACTOR_METHOD_CPUS_SPECIFIED_CASE = 0
DEFAULT_ACTOR_CREATION_CPUS_SPECIFIED_CASE = 1
DEFAULT_ACTOR_CREATION_GPUS_SPECIFIED_CASE = 0
class FunctionID(object):
    """A thin wrapper object that holds a function ID value."""

    def __init__(self, function_id):
        """Store the wrapped function ID.

        Args:
            function_id: The function ID value to wrap.
        """
        self.function_id = function_id

    def id(self):
        """Return the function ID value that was passed to the constructor."""
        return self.function_id
class RayTaskError(Exception):
@@ -182,6 +169,11 @@ class RayGetArgumentError(Exception):
self.task_error))
# Record type holding what a worker needs in order to run a remote function:
# the callable itself, its name, and its max_calls setting.
FunctionExecutionInfo = collections.namedtuple(
    "FunctionExecutionInfo",
    ("function", "function_name", "max_calls"),
)
"""FunctionExecutionInfo: A named tuple storing remote function information."""
class Worker(object):
"""A class used to define the control flow of a worker process.
@@ -190,9 +182,10 @@ class Worker(object):
functions outside of this class are considered exposed.
Attributes:
functions (Dict[str, Callable]): A dictionary mapping the name of a
remote function to the remote function itself. This is the set of
remote functions that can be executed by this worker.
function_execution_info (Dict[str, Dict[str, FunctionExecutionInfo]]):
A dictionary mapping a driver ID to a dictionary that maps the IDs
of the remote functions registered for that driver to
FunctionExecutionInfo objects. This is the set of remote functions
that can be executed by this worker.
connected (bool): True if Ray has been started and False otherwise.
mode: The mode of the worker. One of SCRIPT_MODE, PYTHON_MODE,
SILENT_MODE, and WORKER_MODE.
@@ -208,20 +201,12 @@ class Worker(object):
def __init__(self):
"""Initialize a Worker object."""
# The functions field is a dictionary that maps a driver ID to a
# dictionary of functions that have been registered for that driver
# (this inner dictionary maps function IDs to a tuple of the function
# name and the function itself). This should only be used on workers
# that execute remote functions.
self.functions = collections.defaultdict(lambda: {})
# The function_properties field is a dictionary that maps a driver ID
# to a dictionary of functions that have been registered for that
# driver (this inner dictionary maps function IDs to a tuple of the
# number of values returned by that function, the number of CPUs
# required by that function, and the number of GPUs required by that
# function). This is used when submitting a function (which can be done
# both on workers and on drivers).
self.function_properties = collections.defaultdict(lambda: {})
# This field is a dictionary that maps a driver ID to a dictionary of
# functions (and information about those functions) that have been
# registered for that driver (this inner dictionary maps function IDs
# to a FunctionExecutionInfo object). This should only be used on
# workers that execute remote functions.
self.function_execution_info = collections.defaultdict(lambda: {})
# This is a dictionary mapping driver ID to a dictionary that maps
# remote function IDs for that driver to a counter of the number of
# times that remote function has been executed on this worker. The
@@ -248,6 +233,16 @@ class Worker(object):
# CUDA_VISIBLE_DEVICES environment variable.
self.original_gpu_ids = ray.utils.get_cuda_visible_devices()
def check_connected(self):
    """Verify that this worker has been connected.

    Returns silently when self.connected is truthy.

    Raises:
        RayConnectionError: If the worker is not connected.
    """
    if self.connected:
        return
    message = ("Ray has not been started yet. You can "
               "start Ray with 'ray.init()'.")
    raise RayConnectionError(message)
def set_mode(self, mode):
"""Set the mode of the worker.
@@ -356,7 +351,7 @@ class Worker(object):
full.
"""
# Make sure that the value is not an object ID.
if isinstance(value, ray.local_scheduler.ObjectID):
if isinstance(value, ray.ObjectID):
raise Exception("Calling 'put' on an ObjectID is not allowed "
"(similarly, returning an ObjectID from a remote "
"function is not allowed). If you really want to "
@@ -438,7 +433,7 @@ class Worker(object):
"""
# Make sure that the values are object IDs.
for object_id in object_ids:
if not isinstance(object_id, ray.local_scheduler.ObjectID):
if not isinstance(object_id, ray.ObjectID):
raise Exception("Attempting to call `get` on the value {}, "
"which is not an ObjectID.".format(object_id))
# Do an initial fetch for remote objects. We divide the fetch into
@@ -518,8 +513,6 @@ class Worker(object):
actor_creation_dummy_object_id=None,
execution_dependencies=None,
num_return_vals=None,
num_cpus=None,
num_gpus=None,
resources=None,
driver_id=None):
"""Submit a remote task to the scheduler.
@@ -545,8 +538,6 @@ class Worker(object):
execution_dependencies: The execution dependencies for this task.
num_return_vals: The number of return values this function should
have.
num_cpus: The number of CPUs required by this task.
num_gpus: The number of GPUs required by this task.
resources: The resource requirements for this task.
driver_id: The ID of the relevant driver. This is almost always the
driver ID of the driver that is currently running. However, in
@@ -561,24 +552,22 @@ class Worker(object):
check_main_thread()
if actor_id is None:
assert actor_handle_id is None
actor_id = ray.local_scheduler.ObjectID(NIL_ACTOR_ID)
actor_handle_id = ray.local_scheduler.ObjectID(
NIL_ACTOR_HANDLE_ID)
actor_id = ray.ObjectID(NIL_ACTOR_ID)
actor_handle_id = ray.ObjectID(NIL_ACTOR_HANDLE_ID)
else:
assert actor_handle_id is not None
if actor_creation_id is None:
actor_creation_id = ray.local_scheduler.ObjectID(NIL_ACTOR_ID)
actor_creation_id = ray.ObjectID(NIL_ACTOR_ID)
if actor_creation_dummy_object_id is None:
actor_creation_dummy_object_id = (
ray.local_scheduler.ObjectID(NIL_ID))
actor_creation_dummy_object_id = (ray.ObjectID(NIL_ID))
# Put large or complex arguments that are passed by value in the
# object store first.
args_for_local_scheduler = []
for arg in args:
if isinstance(arg, ray.local_scheduler.ObjectID):
if isinstance(arg, ray.ObjectID):
args_for_local_scheduler.append(arg)
elif ray.local_scheduler.check_simple_value(arg):
args_for_local_scheduler.append(arg)
@@ -592,26 +581,12 @@ class Worker(object):
if driver_id is None:
driver_id = self.task_driver_id
# Look up the various function properties.
function_properties = self.function_properties[driver_id.id()][
function_id.id()]
if num_return_vals is None:
num_return_vals = function_properties.num_return_vals
if resources is None and num_cpus is None and num_gpus is None:
resources = function_properties.resources
else:
resources = {} if resources is None else resources
if "CPU" in resources or "GPU" in resources:
raise ValueError("The resources dictionary must not "
"contain the keys 'CPU' or 'GPU'")
resources["CPU"] = num_cpus
resources["GPU"] = num_gpus
if resources is None:
raise ValueError("The resources dictionary is required.")
# Submit the task to local scheduler.
task = ray.local_scheduler.Task(
driver_id, ray.local_scheduler.ObjectID(
driver_id, ray.ObjectID(
function_id.id()), args_for_local_scheduler,
num_return_vals, self.current_task_id, self.task_index,
actor_creation_id, actor_creation_dummy_object_id, actor_id,
@@ -624,6 +599,55 @@ class Worker(object):
return task.returns()
def export_remote_function(self, function_id, function_name, function,
                           max_calls, decorated_function):
    """Export a remote function.

    The function is pickled and stored in Redis as a hash under a key built
    from the current task driver ID and the function ID. The key is then
    appended to the "Exports" list.

    Args:
        function_id: The ID of the function.
        function_name: The name of the function.
        function: The raw undecorated function to export.
        max_calls: The maximum number of times a given worker can execute
            this function before exiting.
        decorated_function: The decorated function (this is used to enable
            the remote function to recursively call itself).

    Raises:
        Exception: If the worker mode is neither SCRIPT_MODE nor
            SILENT_MODE.
    """
    check_main_thread()
    if self.mode not in [SCRIPT_MODE, SILENT_MODE]:
        raise Exception("export_remote_function can only be called on a "
                        "driver.")

    key = (b"RemoteFunction:" + self.task_driver_id.id() + b":" +
           function_id.id())

    # Work around limitations of Python pickling: while pickling, bind the
    # function's own name in its globals to the decorated function so that
    # the function can reference itself as a global variable. Cython
    # functions are left alone, so their __globals__ is never read or
    # modified here.
    patch_globals = not is_cython(function)
    if patch_globals:
        function_name_global_valid = (
            function.__name__ in function.__globals__)
        function_name_global_value = function.__globals__.get(
            function.__name__)
        function.__globals__[function.__name__] = decorated_function
    try:
        pickled_function = pickle.dumps(function)
    finally:
        # Undo our changes, but only if we made any. Deleting a name that
        # was never inserted would raise a KeyError.
        if patch_globals:
            if function_name_global_valid:
                function.__globals__[function.__name__] = (
                    function_name_global_value)
            else:
                del function.__globals__[function.__name__]

    self.redis_client.hmset(
        key, {
            "driver_id": self.task_driver_id.id(),
            "function_id": function_id.id(),
            "name": function_name,
            "module": function.__module__,
            "function": pickled_function,
            "max_calls": max_calls
        })
    self.redis_client.rpush("Exports", key)
def run_function_on_all_workers(self, function):
"""Run arbitrary code on all of the workers.
@@ -697,7 +721,8 @@ class Worker(object):
while True:
with self.lock:
if (self.actor_id == NIL_ACTOR_ID
and (function_id.id() in self.functions[driver_id])):
and (function_id.id() in
self.function_execution_info[driver_id])):
break
elif self.actor_id != NIL_ACTOR_ID and (
self.actor_id in self.actors):
@@ -741,7 +766,7 @@ class Worker(object):
"""
arguments = []
for (i, arg) in enumerate(serialized_args):
if isinstance(arg, ray.local_scheduler.ObjectID):
if isinstance(arg, ray.ObjectID):
# get the object from the local object store
argument = self.get_object([arg])[0]
if isinstance(argument, RayTaskError):
@@ -798,7 +823,6 @@ class Worker(object):
# message to the correct driver.
self.task_driver_id = task.driver_id()
self.current_task_id = task.task_id()
self.current_function_id = task.function_id().id()
self.task_index = 0
self.put_index = 1
function_id = task.function_id()
@@ -806,8 +830,10 @@ class Worker(object):
return_object_ids = task.returns()
if task.actor_id().id() != NIL_ACTOR_ID:
dummy_return_id = return_object_ids.pop()
function_name, function_executor = (
self.functions[self.task_driver_id.id()][function_id.id()])
function_executor = self.function_execution_info[
self.task_driver_id.id()][function_id.id()].function
function_name = self.function_execution_info[self.task_driver_id.id()][
function_id.id()].function_name
# Get task arguments from the object store.
try:
@@ -829,7 +855,7 @@ class Worker(object):
try:
with log_span("ray:task:execute", worker=self):
if task.actor_id().id() == NIL_ACTOR_ID:
outputs = function_executor.executor(arguments)
outputs = function_executor(*arguments)
else:
outputs = function_executor(
dummy_return_id, self.actors[task.actor_id().id()],
@@ -862,8 +888,8 @@ class Worker(object):
def _handle_process_task_failure(self, function_id, return_object_ids,
error, backtrace):
function_name, _ = self.functions[self.task_driver_id.id()][
function_id.id()]
function_name = self.function_execution_info[self.task_driver_id.id()][
function_id.id()].function_name
failure_object = RayTaskError(function_name, error, backtrace)
failure_objects = [
failure_object for _ in range(len(return_object_ids))
@@ -902,7 +928,7 @@ class Worker(object):
time.sleep(0.001)
with self.lock:
self.fetch_and_register_actor(key, task.required_resources(), self)
self.fetch_and_register_actor(key, self)
def _wait_for_and_process_task(self, task):
"""Wait for a task to be ready and process the task.
@@ -911,11 +937,11 @@ class Worker(object):
task: The task to execute.
"""
function_id = task.function_id()
driver_id = task.driver_id().id()
# TODO(rkn): It would be preferable for actor creation tasks to share
# more of the code path with regular task execution.
if (task.actor_creation_id() !=
ray.local_scheduler.ObjectID(NIL_ACTOR_ID)):
if (task.actor_creation_id() != ray.ObjectID(NIL_ACTOR_ID)):
self._become_actor(task)
return
@@ -923,7 +949,7 @@ class Worker(object):
# on this worker. We will push warnings to the user if we spend too
# long in this loop.
with log_span("ray:wait_for_function", worker=self):
self._wait_for_function(function_id, task.driver_id().id())
self._wait_for_function(function_id, driver_id)
# Execute the task.
# TODO(rkn): Consider acquiring this lock with a timeout and pushing a
@@ -934,8 +960,8 @@ class Worker(object):
with self.lock:
log(event_type="ray:acquire_lock", kind=LOG_SPAN_END, worker=self)
function_name, _ = (
self.functions[task.driver_id().id()][function_id.id()])
function_name = (self.function_execution_info[driver_id][
function_id.id()]).function_name
contents = {
"function_name": function_name,
"task_id": task.task_id().hex(),
@@ -948,14 +974,13 @@ class Worker(object):
flush_log()
# Increase the task execution counter.
(self.num_task_executions[task.driver_id().id()][function_id.id()]
) += 1
self.num_task_executions[driver_id][function_id.id()] += 1
reached_max_executions = (self.num_task_executions[task.driver_id().id(
)][function_id.id()] == self.function_properties[task.driver_id().id()]
[function_id.id()].max_calls)
reached_max_executions = (
self.num_task_executions[driver_id][function_id.id()] == self.
function_execution_info[driver_id][function_id.id()].max_calls)
if reached_max_executions:
ray.worker.global_worker.local_scheduler_client.disconnect()
self.local_scheduler_client.disconnect()
os._exit(0)
def _get_next_task_from_local_scheduler(self):
@@ -1069,18 +1094,6 @@ def check_main_thread():
.format(threading.current_thread().getName()))
def check_connected(worker=global_worker):
    """Ensure that the given worker is connected.

    Args:
        worker: The worker whose `connected` attribute is checked.

    Raises:
        RayConnectionError: If the worker is not connected.
    """
    if worker.connected:
        return
    message = ("This command cannot be called before Ray "
               "has been started. You can start Ray with "
               "'ray.init()'.")
    raise RayConnectionError(message)
def print_failed_task(task_status):
"""Print information about failed tasks.
@@ -1114,7 +1127,7 @@ def error_applies_to_driver(error_key, worker=global_worker):
def error_info(worker=global_worker):
"""Return information about failed tasks."""
check_connected(worker)
worker.check_connected()
check_main_thread()
error_keys = worker.redis_client.lrange("ErrorKeys", 0, -1)
errors = []
@@ -1143,13 +1156,13 @@ def _initialize_serialization(worker=global_worker):
return obj.id()
def object_id_custom_deserializer(serialized_obj):
return ray.local_scheduler.ObjectID(serialized_obj)
return ray.ObjectID(serialized_obj)
# We register this serializer on each worker instead of calling
# register_custom_serializer from the driver so that isinstance still
# works.
worker.serialization_context.register_type(
ray.local_scheduler.ObjectID,
ray.ObjectID,
"ray.ObjectID",
pickle=False,
custom_serializer=object_id_custom_serializer,
@@ -1786,12 +1799,9 @@ def fetch_and_register_remote_function(key, worker=global_worker):
"driver_id", "function_id", "name", "function", "num_return_vals",
"module", "resources", "max_calls"
])
function_id = ray.local_scheduler.ObjectID(function_id_str)
function_id = ray.ObjectID(function_id_str)
function_name = function_name.decode("ascii")
function_properties = FunctionProperties(
num_return_vals=int(num_return_vals),
resources=json.loads(resources.decode("ascii")),
max_calls=int(max_calls))
max_calls = int(max_calls)
module = module.decode("ascii")
# This is a placeholder in case the function can't be unpickled. This will
@@ -1799,11 +1809,9 @@ def fetch_and_register_remote_function(key, worker=global_worker):
def f():
raise Exception("This function was not imported properly.")
remote_f_placeholder = remote(function_id=function_id)(lambda *xs: f())
worker.functions[driver_id][function_id.id()] = (function_name,
remote_f_placeholder)
worker.function_properties[driver_id][function_id.id()] = (
function_properties)
worker.function_execution_info[driver_id][function_id.id()] = (
FunctionExecutionInfo(
function=f, function_name=function_name, max_calls=max_calls))
worker.num_task_executions[driver_id][function_id.id()] = 0
try:
@@ -1825,8 +1833,11 @@ def fetch_and_register_remote_function(key, worker=global_worker):
else:
# TODO(rkn): Why is the below line necessary?
function.__module__ = module
worker.functions[driver_id][function_id.id()] = (
function_name, remote(function_id=function_id)(function))
worker.function_execution_info[driver_id][function_id.id()] = (
FunctionExecutionInfo(
function=function,
function_name=function_name,
max_calls=max_calls))
# Add the function to the function table.
worker.redis_client.rpush(b"FunctionTable:" + function_id.id(),
worker.worker_id)
@@ -1973,6 +1984,14 @@ def connect(info,
assert worker.cached_remote_functions_and_actors is not None, error_message
# Initialize some fields.
worker.worker_id = random_string()
# When tasks are executed on remote workers in the context of multiple
# drivers, the task driver ID is used to keep track of which driver is
# responsible for the task so that error messages will be propagated to
# the correct driver.
if mode != WORKER_MODE:
worker.task_driver_id = ray.ObjectID(worker.worker_id)
# All workers start out as non-actors. A worker can be turned into an actor
# after it is created.
worker.actor_id = NIL_ACTOR_ID
@@ -2102,13 +2121,7 @@ def connect(info,
else:
# Try to use true randomness.
np.random.seed(None)
worker.current_task_id = ray.local_scheduler.ObjectID(
np.random.bytes(20))
# When tasks are executed on remote workers in the context of multiple
# drivers, the task driver ID is used to keep track of which driver is
# responsible for the task so that error messages will be propagated to
# the correct driver.
worker.task_driver_id = ray.local_scheduler.ObjectID(worker.worker_id)
worker.current_task_id = ray.ObjectID(np.random.bytes(20))
# Reset the state of the numpy random number generator.
np.random.set_state(numpy_state)
# Set other fields needed for computing task IDs.
@@ -2124,14 +2137,11 @@ def connect(info,
nil_actor_counter = 0
driver_task = ray.local_scheduler.Task(
worker.task_driver_id,
ray.local_scheduler.ObjectID(NIL_FUNCTION_ID), [], 0,
worker.task_driver_id, ray.ObjectID(NIL_FUNCTION_ID), [], 0,
worker.current_task_id, worker.task_index,
ray.local_scheduler.ObjectID(NIL_ACTOR_ID),
ray.local_scheduler.ObjectID(NIL_ACTOR_ID),
ray.local_scheduler.ObjectID(NIL_ACTOR_ID),
ray.local_scheduler.ObjectID(NIL_ACTOR_ID), nil_actor_counter,
False, [], {"CPU": 0}, worker.use_raylet)
ray.ObjectID(NIL_ACTOR_ID), ray.ObjectID(NIL_ACTOR_ID),
ray.ObjectID(NIL_ACTOR_ID), ray.ObjectID(NIL_ACTOR_ID),
nil_actor_counter, False, [], {"CPU": 0}, worker.use_raylet)
global_state._execute_command(
driver_task.task_id(), "RAY.TASK_TABLE_ADD",
driver_task.task_id().id(),
@@ -2194,11 +2204,7 @@ def connect(info,
# Export cached remote functions to the workers.
for cached_type, info in worker.cached_remote_functions_and_actors:
if cached_type == "remote_function":
(function_id, func_name, func, func_invoker,
function_properties) = info
export_remote_function(function_id, func_name, func,
func_invoker, function_properties,
worker)
info._export()
elif cached_type == "actor":
(key, actor_class_info) = info
ray.actor.publish_actor_class_to_key(key, actor_class_info,
@@ -2450,7 +2456,7 @@ def get(object_ids, worker=global_worker):
Returns:
A Python object or a list of Python objects.
"""
check_connected(worker)
worker.check_connected()
with log_span("ray:get", worker=worker):
check_main_thread()
@@ -2483,7 +2489,7 @@ def put(value, worker=global_worker):
Returns:
The object ID assigned to this value.
"""
check_connected(worker)
worker.check_connected()
with log_span("ray:put", worker=worker):
check_main_thread()
@@ -2524,7 +2530,7 @@ def wait(object_ids, num_returns=1, timeout=None, worker=global_worker):
print("plasma_client.wait has not been implemented yet")
return
if isinstance(object_ids, ray.local_scheduler.ObjectID):
if isinstance(object_ids, ray.ObjectID):
raise TypeError(
"wait() expected a list of ObjectID, got a single ObjectID")
@@ -2534,12 +2540,12 @@ def wait(object_ids, num_returns=1, timeout=None, worker=global_worker):
if worker.mode != PYTHON_MODE:
for object_id in object_ids:
if not isinstance(object_id, ray.local_scheduler.ObjectID):
if not isinstance(object_id, ray.ObjectID):
raise TypeError("wait() expected a list of ObjectID, "
"got list containing {}".format(
type(object_id)))
check_connected(worker)
worker.check_connected()
with log_span("ray:wait", worker=worker):
check_main_thread()
@@ -2561,27 +2567,14 @@ def wait(object_ids, num_returns=1, timeout=None, worker=global_worker):
ready_ids, remaining_ids = worker.plasma_client.wait(
object_id_strs, timeout, num_returns)
ready_ids = [
ray.local_scheduler.ObjectID(object_id.binary())
for object_id in ready_ids
ray.ObjectID(object_id.binary()) for object_id in ready_ids
]
remaining_ids = [
ray.local_scheduler.ObjectID(object_id.binary())
for object_id in remaining_ids
ray.ObjectID(object_id.binary()) for object_id in remaining_ids
]
return ready_ids, remaining_ids
def _submit_task(function_id, *args, **kwargs):
    """Forward a task submission to global_worker.submit_task.

    The remote decorator calls this module-level wrapper rather than
    worker.submit_task so that serializing a remote function does not try
    to serialize the worker object, which cannot be serialized.

    Args:
        function_id: The function ID passed through to submit_task.
        *args: Positional arguments passed through to submit_task.
        **kwargs: Keyword arguments passed through to submit_task.

    Returns:
        Whatever global_worker.submit_task returns.
    """
    submit = global_worker.submit_task
    return submit(function_id, *args, **kwargs)
def _mode(worker=global_worker):
"""This is a wrapper around worker.mode.
@@ -2593,278 +2586,104 @@ def _mode(worker=global_worker):
return worker.mode
def export_remote_function(function_id,
                           func_name,
                           func,
                           func_invoker,
                           function_properties,
                           worker=global_worker):
    """Export a remote function so that workers can fetch it.

    The function properties are recorded on the worker. The function is
    then pickled and stored in Redis as a hash under a key built from the
    task driver ID and the function ID, and the key is appended to the
    "Exports" list.

    Args:
        function_id: The ID of the function.
        func_name: The name of the function.
        func: The raw undecorated function to export.
        func_invoker: The decorated function. It is temporarily bound to
            the function's own name in its globals while pickling.
        function_properties: An object with num_return_vals, resources and
            max_calls attributes describing the function.
        worker: The worker used to export the function.

    Raises:
        Exception: If the worker mode is neither SCRIPT_MODE nor
            SILENT_MODE.
    """
    check_main_thread()
    if _mode(worker) not in [SCRIPT_MODE, SILENT_MODE]:
        raise Exception("export_remote_function can only be called on a "
                        "driver.")

    worker.function_properties[worker.task_driver_id.id()][
        function_id.id()] = function_properties
    task_driver_id = worker.task_driver_id
    key = b"RemoteFunction:" + task_driver_id.id() + b":" + function_id.id()

    # Work around limitations of Python pickling: while pickling, bind the
    # function's own name in its globals to the invoker so that the function
    # can reference itself as a global variable. Cython functions are left
    # alone, so their __globals__ is never read or modified here.
    patch_globals = not is_cython(func)
    if patch_globals:
        func_name_global_valid = func.__name__ in func.__globals__
        func_name_global_value = func.__globals__.get(func.__name__)
        func.__globals__[func.__name__] = func_invoker
    try:
        pickled_func = pickle.dumps(func)
    finally:
        # Undo our changes, but only if we made any. Deleting a name that
        # was never inserted would raise a KeyError.
        if patch_globals:
            if func_name_global_valid:
                func.__globals__[func.__name__] = func_name_global_value
            else:
                del func.__globals__[func.__name__]

    worker.redis_client.hmset(
        key, {
            "driver_id": worker.task_driver_id.id(),
            "function_id": function_id.id(),
            "name": func_name,
            "module": func.__module__,
            "function": pickled_func,
            "num_return_vals": function_properties.num_return_vals,
            "resources": json.dumps(function_properties.resources),
            "max_calls": function_properties.max_calls
        })
    worker.redis_client.rpush("Exports", key)
def get_global_worker():
    """Return the module-level global_worker object."""
    return global_worker
def in_ipython():
    """Report whether we are running inside an IPython interpreter.

    Returns:
        True if the name __IPYTHON__ can be resolved and False otherwise.
    """
    try:
        # Evaluating the bare name raises NameError when it is undefined.
        __IPYTHON__
    except NameError:
        return False
    else:
        return True
def make_decorator(num_return_vals=None,
num_cpus=None,
num_gpus=None,
resources=None,
max_calls=None,
checkpoint_interval=None,
worker=None):
def decorator(function_or_class):
if (inspect.isfunction(function_or_class)
or is_cython(function_or_class)):
# Set the remote function default resources.
if checkpoint_interval is not None:
raise Exception("The keyword 'checkpoint_interval' is not "
"allowed for remote functions.")
return ray.remote_function.RemoteFunction(
function_or_class, num_cpus, num_gpus, resources,
num_return_vals, max_calls)
def compute_function_id(func_name, func):
"""Compute a function ID for a function.
if inspect.isclass(function_or_class):
if num_return_vals is not None:
raise Exception("The keyword 'num_return_vals' is not allowed "
"for actors.")
if max_calls is not None:
raise Exception("The keyword 'max_calls' is not allowed for "
"actors.")
Args:
func_name: The name of the function (this includes the module name plus
the function name).
func: The actual function.
# Set the actor default resources.
if num_cpus is None and num_gpus is None and resources is None:
# In the default case, actors acquire no resources for
# their lifetime, and actor methods will require 1 CPU.
cpus_to_use = DEFAULT_ACTOR_CREATION_CPUS_SIMPLE_CASE
actor_method_cpus = DEFAULT_ACTOR_METHOD_CPUS_SIMPLE_CASE
else:
# If any resources are specified, then all resources are
# acquired for the actor's lifetime and no resources are
# associated with methods.
cpus_to_use = (DEFAULT_ACTOR_CREATION_CPUS_SPECIFIED_CASE
if num_cpus is None else num_cpus)
actor_method_cpus = DEFAULT_ACTOR_METHOD_CPUS_SPECIFIED_CASE
Returns:
This returns the function ID.
"""
function_id_hash = hashlib.sha1()
# Include the function name in the hash.
function_id_hash.update(func_name.encode("ascii"))
# If we are running a script or are in IPython, include the source code in
# the hash. If we are in a regular Python interpreter we skip this part
# because the source code is not accessible. If the function is a built-in
# (e.g., Cython), the source code is not accessible.
import __main__ as main
if (hasattr(main, "__file__") or in_ipython()) \
and inspect.isfunction(func):
function_id_hash.update(inspect.getsource(func).encode("ascii"))
# Compute the function ID.
function_id = function_id_hash.digest()
assert len(function_id) == 20
function_id = FunctionID(function_id)
return worker.make_actor(function_or_class, cpus_to_use, num_gpus,
resources, actor_method_cpus,
checkpoint_interval)
return function_id
raise Exception("The @ray.remote decorator must be applied to "
"either a function or to a class.")
return decorator
def remote(*args, **kwargs):
"""This decorator is used to define remote functions and to define actors.
Args:
num_return_vals (int): The number of object IDs that a call to this
function should return.
num_cpus (int): The number of CPUs needed to execute this function.
num_gpus (int): The number of GPUs needed to execute this function.
resources: A dictionary mapping resource name to the required quantity
of that resource.
max_calls (int): The maximum number of tasks of this kind that can be
run on a worker before the worker needs to be restarted.
checkpoint_interval (int): The number of tasks to run between
checkpoints of the actor state.
"""
worker = global_worker
def make_remote_decorator(num_return_vals,
num_cpus,
num_gpus,
resources,
max_calls,
checkpoint_interval,
func_id=None):
def remote_decorator(func_or_class):
if inspect.isfunction(func_or_class) or is_cython(func_or_class):
# Set the remote function default resources.
resources["CPU"] = (DEFAULT_REMOTE_FUNCTION_CPUS
if num_cpus is None else num_cpus)
resources["GPU"] = (DEFAULT_REMOTE_FUNCTION_GPUS
if num_gpus is None else num_gpus)
function_properties = FunctionProperties(
num_return_vals=num_return_vals,
resources=resources,
max_calls=max_calls)
return remote_function_decorator(func_or_class,
function_properties)
if inspect.isclass(func_or_class):
# Set the actor default resources.
if num_cpus is None and num_gpus is None and resources == {}:
# In the default case, actors acquire no resources for
# their lifetime, and actor methods will require 1 CPU.
resources["CPU"] = DEFAULT_ACTOR_CREATION_CPUS_SIMPLE_CASE
actor_method_cpus = DEFAULT_ACTOR_METHOD_CPUS_SIMPLE_CASE
else:
# If any resources are specified, then all resources are
# acquired for the actor's lifetime and no resources are
# associated with methods.
resources["CPU"] = (
DEFAULT_ACTOR_CREATION_CPUS_SPECIFIED_CASE
if num_cpus is None else num_cpus)
resources["GPU"] = (
DEFAULT_ACTOR_CREATION_GPUS_SPECIFIED_CASE
if num_gpus is None else num_gpus)
actor_method_cpus = (
DEFAULT_ACTOR_METHOD_CPUS_SPECIFIED_CASE)
return worker.make_actor(func_or_class, resources,
checkpoint_interval,
actor_method_cpus)
raise Exception("The @ray.remote decorator must be applied to "
"either a function or to a class.")
def remote_function_decorator(func, function_properties):
func_name = "{}.{}".format(func.__module__, func.__name__)
if func_id is None:
function_id = compute_function_id(func_name, func)
else:
function_id = func_id
def func_call(*args, **kwargs):
"""This runs immediately when a remote function is called."""
return _submit(args=args, kwargs=kwargs)
def _submit(args=None,
kwargs=None,
num_return_vals=None,
num_cpus=None,
num_gpus=None,
resources=None):
"""An experimental alternate way to submit remote functions."""
check_connected()
check_main_thread()
kwargs = {} if kwargs is None else kwargs
args = signature.extend_args(function_signature, args, kwargs)
if _mode() == PYTHON_MODE:
# In PYTHON_MODE, remote calls simply execute the function.
# We copy the arguments to prevent the function call from
# mutating them and to match the usual behavior of
# immutable remote objects.
result = func(*copy.deepcopy(args))
return result
object_ids = _submit_task(
function_id,
args,
num_return_vals=num_return_vals,
num_cpus=num_cpus,
num_gpus=num_gpus,
resources=resources)
if len(object_ids) == 1:
return object_ids[0]
elif len(object_ids) > 1:
return object_ids
def func_executor(arguments):
"""This gets run when the remote function is executed."""
result = func(*arguments)
return result
def func_invoker(*args, **kwargs):
"""This is used to invoke the function."""
raise Exception("Remote functions cannot be called directly. "
"Instead of running '{}()', try '{}.remote()'."
.format(func_name, func_name))
func_invoker.remote = func_call
func_invoker._submit = _submit
func_invoker.executor = func_executor
func_invoker.is_remote = True
func_name = "{}.{}".format(func.__module__, func.__name__)
func_invoker.func_name = func_name
if sys.version_info >= (3, 0) or is_cython(func):
func_invoker.__doc__ = func.__doc__
else:
func_invoker.func_doc = func.func_doc
signature.check_signature_supported(func)
function_signature = signature.extract_signature(func)
# Everything ready - export the function
if worker.mode in [SCRIPT_MODE, SILENT_MODE]:
export_remote_function(function_id, func_name, func,
func_invoker, function_properties)
elif worker.mode is None:
worker.cached_remote_functions_and_actors.append(
("remote_function", (function_id, func_name, func,
func_invoker, function_properties)))
return func_invoker
return remote_decorator
# Handle resource arguments
num_cpus = kwargs["num_cpus"] if "num_cpus" in kwargs else None
num_gpus = kwargs["num_gpus"] if "num_gpus" in kwargs else None
resources = kwargs.get("resources", {})
if not isinstance(resources, dict):
raise Exception("The 'resources' keyword argument must be a "
"dictionary, but received type {}.".format(
type(resources)))
assert "CPU" not in resources, "Use the 'num_cpus' argument."
assert "GPU" not in resources, "Use the 'num_gpus' argument."
# Handle other arguments.
num_return_vals = (kwargs["num_return_vals"]
if "num_return_vals" in kwargs else 1)
max_calls = kwargs["max_calls"] if "max_calls" in kwargs else 0
checkpoint_interval = (kwargs["checkpoint_interval"]
if "checkpoint_interval" in kwargs else -1)
if _mode() == WORKER_MODE:
if "function_id" in kwargs:
function_id = kwargs["function_id"]
return make_remote_decorator(num_return_vals, num_cpus, num_gpus,
resources, max_calls,
checkpoint_interval, function_id)
worker = get_global_worker()
if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
# This is the case where the decorator is just @ray.remote.
return make_remote_decorator(num_return_vals, num_cpus, num_gpus,
resources, max_calls,
checkpoint_interval)(args[0])
else:
# This is the case where the decorator is something like
# @ray.remote(num_return_vals=2).
error_string = ("The @ray.remote decorator must be applied either "
"with no arguments and no parentheses, for example "
"'@ray.remote', or it must be applied using some of "
"the arguments 'num_return_vals', 'resources', "
"or 'max_calls', like "
"'@ray.remote(num_return_vals=2, "
"resources={\"GPU\": 1})'.")
assert len(args) == 0 and len(kwargs) > 0, error_string
for key in kwargs:
assert key in [
"num_return_vals", "num_cpus", "num_gpus", "resources",
"max_calls", "checkpoint_interval"
], error_string
assert "function_id" not in kwargs
return make_remote_decorator(num_return_vals, num_cpus, num_gpus,
resources, max_calls, checkpoint_interval)
return make_decorator(worker=worker)(args[0])
# Parse the keyword arguments from the decorator.
error_string = ("The @ray.remote decorator must be applied either "
"with no arguments and no parentheses, for example "
"'@ray.remote', or it must be applied using some of "
"the arguments 'num_return_vals', 'num_cpus', 'num_gpus', "
"'resources', 'max_calls', or 'checkpoint_interval', like "
"'@ray.remote(num_return_vals=2, "
"resources={\"CustomResource\": 1})'.")
assert len(args) == 0 and len(kwargs) > 0, error_string
for key in kwargs:
assert key in [
"num_return_vals", "num_cpus", "num_gpus", "resources",
"max_calls", "checkpoint_interval"
], error_string
num_cpus = kwargs["num_cpus"] if "num_cpus" in kwargs else None
num_gpus = kwargs["num_gpus"] if "num_gpus" in kwargs else None
resources = kwargs.get("resources")
if not isinstance(resources, dict) and resources is not None:
raise Exception("The 'resources' keyword argument must be a "
"dictionary, but received type {}.".format(
type(resources)))
if resources is not None:
assert "CPU" not in resources, "Use the 'num_cpus' argument."
assert "GPU" not in resources, "Use the 'num_gpus' argument."
# Handle other arguments.
num_return_vals = kwargs.get("num_return_vals")
max_calls = kwargs.get("max_calls")
checkpoint_interval = kwargs.get("checkpoint_interval")
return make_decorator(
num_return_vals=num_return_vals,
num_cpus=num_cpus,
num_gpus=num_gpus,
resources=resources,
max_calls=max_calls,
checkpoint_interval=checkpoint_interval,
worker=worker)