mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 04:55:04 +08:00
a1e4268d37
* catch errors in importing reusable variables and remote functions * updates
1261 lines
58 KiB
Python
1261 lines
58 KiB
Python
import os
|
|
import time
|
|
import traceback
|
|
import copy
|
|
import logging
|
|
from types import ModuleType
|
|
import typing
|
|
import funcsigs
|
|
import numpy as np
|
|
import colorama
|
|
import atexit
|
|
|
|
# Ray modules
|
|
import config
|
|
import pickling
|
|
import serialization
|
|
import internal.graph_pb2
|
|
import graph
|
|
import services
|
|
import libnumbuf
|
|
import libraylib as raylib
|
|
|
|
class RayTaskError(Exception):
  """An object used internally to represent a task that threw an exception.

  If a task throws an exception during execution, a RayTaskError is stored in
  the object store for each of the task's outputs. When an object is retrieved
  from the object store, the Python method that retrieved it checks to see if
  the object is a RayTaskError and if it is then an exception is thrown
  propagating the error message.

  Currently, we either use the exception attribute or the traceback attribute
  but not both.

  Attributes:
    function_name (str): The name of the function that failed and produced the
      RayTaskError.
    exception (Exception): The exception object thrown by the failed task. Only
      RayGetError, RayGetArgumentError and RayGetArgumentTypeError instances
      are kept (they provide serialize/deserialize); any other exception is
      replaced by None and the traceback string carries the information.
    traceback_str (str): The traceback from the exception.
  """

  def __init__(self, function_name, exception, traceback_str):
    """Initialize a RayTaskError.

    Args:
      function_name (str): The name of the remote function that failed.
      exception: The exception raised by the task, or None.
      traceback_str (Optional[str]): The formatted traceback, or None when
        getting the task arguments failed.
    """
    self.function_name = function_name
    if exception is not None and isinstance(exception, (RayGetError, RayGetArgumentError, RayGetArgumentTypeError)):
      self.exception = exception
    else:
      # Exceptions of other types cannot be serialized by serialize() below,
      # so they are dropped here.
      self.exception = None
    self.traceback_str = traceback_str

  @staticmethod
  def deserialize(primitives):
    """Create a RayTaskError from the primitive object returned by serialize.

    Args:
      primitives: A tuple (function_name, serialized_exception, traceback_str)
        where serialized_exception is ("None",) or (type_name, payload).

    Returns:
      The reconstructed RayTaskError.

    Raises:
      ValueError: If serialized_exception carries an unrecognized type tag.
    """
    function_name, serialized_exception, traceback_str = primitives
    tag = serialized_exception[0]
    if tag == "None":
      exception = None
    elif tag == "RayGetError":
      exception = RayGetError.deserialize(serialized_exception[1])
    elif tag == "RayGetArgumentError":
      exception = RayGetArgumentError.deserialize(serialized_exception[1])
    elif tag == "RayGetArgumentTypeError":
      exception = RayGetArgumentTypeError.deserialize(serialized_exception[1])
    else:
      # An explicit raise (rather than `assert False`) still fires under -O.
      raise ValueError("Unrecognized serialized exception type {}.".format(tag))
    return RayTaskError(function_name, exception, traceback_str)

  def serialize(self):
    """Turn a RayTaskError into a primitive object.

    Returns:
      A tuple (function_name, serialized_exception, traceback_str) that
      deserialize accepts.

    Raises:
      TypeError: If self.exception was replaced with an unsupported type.
    """
    if self.exception is None:
      serialized_exception = ("None",)
    elif isinstance(self.exception, RayGetError):
      serialized_exception = ("RayGetError", self.exception.serialize())
    elif isinstance(self.exception, RayGetArgumentError):
      serialized_exception = ("RayGetArgumentError", self.exception.serialize())
    elif isinstance(self.exception, RayGetArgumentTypeError):
      serialized_exception = ("RayGetArgumentTypeError", self.exception.serialize())
    else:
      raise TypeError("Cannot serialize exception of type {}.".format(type(self.exception)))
    return (self.function_name, serialized_exception, self.traceback_str)

  def __str__(self):
    """Format a RayTaskError as a string."""
    if self.traceback_str is None:
      # This path is taken if getting the task arguments failed.
      return "Remote function {}{}{} failed with:\n\n{}".format(colorama.Fore.RED, self.function_name, colorama.Fore.RESET, self.exception)
    else:
      # This path is taken if the task execution failed.
      return "Remote function {}{}{} failed with:\n\n{}".format(colorama.Fore.RED, self.function_name, colorama.Fore.RESET, self.traceback_str)
|
|
|
|
class RayGetError(Exception):
  """Raised when get is called on an object produced by a task that failed.

  Attributes:
    objectid (lib.ObjectID): The ObjectID that get was called on.
    task_error (RayTaskError): The error recorded by the failed task that was
      supposed to produce objectid.
  """

  def __init__(self, objectid, task_error):
    """Record the requested ObjectID and the producing task's error."""
    self.objectid, self.task_error = objectid, task_error

  @staticmethod
  def deserialize(primitives):
    """Rebuild a RayGetError from the tuple produced by serialize."""
    objectid, serialized_task_error = primitives
    task_error = RayTaskError.deserialize(serialized_task_error)
    return RayGetError(objectid, task_error)

  def serialize(self):
    """Return the (objectid, serialized task error) pair for this error."""
    serialized_task_error = self.task_error.serialize()
    return (self.objectid, serialized_task_error)

  def __str__(self):
    """Describe which object could not be retrieved and why."""
    template = "Could not get objectid {}. It was created by remote function {}{}{} which failed with:\n\n{}"
    red, reset = colorama.Fore.RED, colorama.Fore.RESET
    return template.format(self.objectid, red, self.task_error.function_name, reset, self.task_error)
|
|
|
|
class RayGetArgumentError(Exception):
  """Raised when a task argument was produced by a task that failed.

  Attributes:
    argument_index (int): Zero-based position of the failed argument in the
      current task's remote function call.
    function_name (str): Name of the remote function of the current task.
    objectid (lib.ObjectID): The ObjectID that was passed in as the argument.
    task_error (RayTaskError): The error recorded by the task that should have
      produced the argument.
  """

  def __init__(self, function_name, argument_index, objectid, task_error):
    """Store the details of the failed argument."""
    self.argument_index, self.function_name = argument_index, function_name
    self.objectid, self.task_error = objectid, task_error

  @staticmethod
  def deserialize(primitives):
    """Rebuild a RayGetArgumentError from the tuple produced by serialize."""
    name, index, argument_objectid, serialized_task_error = primitives
    task_error = RayTaskError.deserialize(serialized_task_error)
    return RayGetArgumentError(name, index, argument_objectid, task_error)

  def serialize(self):
    """Return (function_name, argument_index, objectid, serialized error)."""
    serialized_task_error = self.task_error.serialize()
    return (self.function_name, self.argument_index, self.objectid, serialized_task_error)

  def __str__(self):
    """Describe the failed argument and the error of the task behind it."""
    template = "Failed to get objectid {} as argument {} for remote function {}{}{}. It was created by remote function {}{}{} which failed with:\n{}"
    red, reset = colorama.Fore.RED, colorama.Fore.RESET
    return template.format(self.objectid, self.argument_index,
                           red, self.function_name, reset,
                           red, self.task_error.function_name, reset,
                           self.task_error)
|
|
|
|
class RayGetArgumentTypeError(Exception):
  """Raised when a task argument fails the type check.

  Attributes:
    function_name (str): Name of the remote function of the current task.
    argument_index (int): Zero-based position of the offending argument in the
      current task's remote function call.
    received_type (str): String form of the type that was passed in.
    expected_type (str): String form of the type the remote decorator declared.
  """

  def __init__(self, function_name, argument_index, received_type, expected_type):
    """Store the argument position and both types (as strings)."""
    self.function_name = function_name
    self.argument_index = argument_index
    # TODO(rkn): Types are stringified because they cannot be serialized yet;
    # drop the str() calls once type serialization is supported.
    self.received_type, self.expected_type = str(received_type), str(expected_type)

  @staticmethod
  def deserialize(primitives):
    """Rebuild a RayGetArgumentTypeError from the tuple produced by serialize."""
    name, index, got, wanted = primitives
    return RayGetArgumentTypeError(name, index, got, wanted)

  def serialize(self):
    """Return (function_name, argument_index, received_type, expected_type)."""
    fields = (self.function_name, self.argument_index, self.received_type, self.expected_type)
    return fields

  def __str__(self):
    """Describe the mismatched argument type."""
    template = "Argument {} for remote function {}{}{} has type {} but an argument of type {} was expected."
    red, reset = colorama.Fore.RED, colorama.Fore.RESET
    return template.format(self.argument_index, red, self.function_name, reset, self.received_type, self.expected_type)
|
|
|
|
class RayDealloc(object):
  """Helper that unmaps an object's memory segment when it is garbage collected.

  get_object creates one of these for each object it returns and stores it on
  that object as the ray_deallocator attribute. The destructor therefore runs
  only once the worker no longer references the returned object, and it then
  closes the relevant memory segment.

  Attributes:
    handle (worker capsule): A Python object wrapping a C++ Worker object.
    segmentid (int): The id of the segment that contains the object that holds
      this RayDealloc object.
  """

  def __init__(self, handle, segmentid):
    """Remember which worker and which segment to unmap later.

    Args:
      handle (worker capsule): A Python object wrapping a C++ Worker object.
      segmentid (int): The id of the segment that contains the object that
        holds this RayDealloc object.
    """
    self.handle, self.segmentid = handle, segmentid

  def __del__(self):
    """Unmap the segment so that its memory is not leaked."""
    worker_handle, segment = self.handle, self.segmentid
    raylib.unmap_object(worker_handle, segment)
|
|
|
|
class Reusable(object):
  """A Python object that can be shared between tasks.

  Attributes:
    initializer (Callable[[], object]): A function used to create and initialize
      the reusable variable.
    reinitializer (Callable[[object], object]): A function that receives the
      current value of the reusable variable after it has been used and returns
      the reinitialized value. If none is passed to the constructor, a function
      that ignores the current value and reruns the initializer is used.
      Passing a custom one can be an optimization if there is a fast way to
      reinitialize the state of the variable other than rerunning the
      initializer.
  """

  def __init__(self, initializer, reinitializer=None):
    """Initialize a Reusable object.

    Args:
      initializer (Callable[[], object]): Creates the initial value.
      reinitializer (Optional[Callable[[object], object]]): Resets a used
        value. Defaults to rerunning initializer.

    Raises:
      Exception: If initializer or reinitializer is not callable.
    """
    # The builtin callable() is the canonical callability test; isinstance
    # checks against typing generics such as typing.Callable are discouraged.
    if not callable(initializer):
      raise Exception("When creating a RayReusable, initializer must be a function.")
    self.initializer = initializer
    if reinitializer is None:
      # If no reinitializer is passed in, use a wrapped version of the
      # initializer that ignores the current value.
      def _rerun_initializer(value):
        return initializer()
      reinitializer = _rerun_initializer
    if not callable(reinitializer):
      raise Exception("When creating a RayReusable, reinitializer must be a function.")
    self.reinitializer = reinitializer
|
|
|
|
class RayReusables(object):
  """An object used to store Python variables that are shared between tasks.

  Each worker process will have a single RayReusables object. This class serves
  two purposes. First, some objects are not serializable, and so the code that
  creates those objects must be run on the worker that uses them. This class is
  responsible for running the code that creates those objects. Second, some of
  these objects are expensive to create, and so they should be shared between
  tasks. However, if a task mutates a variable that is shared between tasks,
  then the behavior of the overall program may be nondeterministic (it could
  depend on scheduling decisions). To fix this, if a task uses one of these
  shared objects, then that shared object will be reinitialized after the task
  finishes. Since the initialization may be expensive, the user can pass in
  custom reinitialization code that resets the state of the shared variable to
  the way it was after initialization. If the reinitialization code does not do
  this, then the behavior of the overall program is undefined.

  Attributes:
    _names (Set[str]): The names of all the reusable variables.
    _reusables (Dict[str, Reusable]): A dictionary mapping the name of each
      reusable variable to the corresponding Reusable object.
    _cached_reusables (List[Tuple[str, Reusable]]): (name, Reusable) pairs for
      reusable variables that were defined while _mode() returned None, i.e.
      before the driver was connected. Once the driver is connected, these
      variables will be exported.
    _used (Set[str]): The names of the reusable variables that have been
      accessed within the scope of the current task. _reinitialize empties it
      after each task.
    _slots (Tuple[str, ...]): Attribute names that bypass the reusable-variable
      handling in __getattribute__ and __setattr__.
  """

  def __init__(self):
    """Initialize a RayReusables object."""
    self._names = set()
    self._reusables = {}
    self._cached_reusables = []
    self._used = set()
    self._slots = ("_names", "_reusables", "_cached_reusables", "_used", "_slots", "_reinitialize", "__getattribute__", "__setattr__", "__delattr__")
    # CHECKPOINT: _slots must be assigned last. Until it exists, __setattr__
    # stores every attribute directly; afterwards any name not listed in it is
    # treated as a reusable variable.

  def _reinitialize(self):
    """Reinitialize the reusable variables that the current task used."""
    for name in self._used:
      current_value = getattr(self, name)
      new_value = self._reusables[name].reinitializer(current_value)
      # Bypass __setattr__, which would treat the new value as a fresh Reusable.
      object.__setattr__(self, name, new_value)
    self._used.clear()  # Reset the _used set for the next task.

  def __getattribute__(self, name):
    """Get an attribute. This handles reusable variables as a special case.

    When __getattribute__ is called with the name of a reusable variable, that
    name is added to the set of variables that were used in the current task.

    Args:
      name (str): The name of the attribute to get.
    """
    # _slots is special-cased so that looking it up does not recurse.
    if name == "_slots":
      return object.__getattribute__(self, name)
    if name in self._slots:
      return object.__getattribute__(self, name)
    if name in self._names and name not in self._used:
      self._used.add(name)
    return object.__getattribute__(self, name)

  def __setattr__(self, name, value):
    """Set an attribute. This handles reusable variables as a special case.

    Names listed in _slots, and every name set before _slots exists (i.e.
    during __init__), are stored as plain attributes. Any other name creates a
    reusable variable: the Reusable is registered under name, exported via
    _export_reusable_variable when _mode() is SCRIPT_MODE or SILENT_MODE,
    cached in _cached_reusables when _mode() is None, and the attribute is set
    to the result of running its initializer.

    Args:
      name (str): The name of the attribute to set. This is either a whitelisted
        name or it is treated as the name of a reusable variable.
      value: If name is a whitelisted name, then value can be any value.
        Otherwise it must be a Reusable object.

    Raises:
      Exception: If name is not whitelisted and value is not a Reusable.
    """
    try:
      slots = self._slots
    except AttributeError:
      # _slots does not exist yet: we are still inside __init__.
      slots = ()
    if slots == ():
      return object.__setattr__(self, name, value)
    if name in slots:
      return object.__setattr__(self, name, value)
    reusable = value
    if not isinstance(reusable, Reusable):
      raise Exception("To set a reusable variable, you must pass in a Reusable object")
    self._names.add(name)
    self._reusables[name] = reusable
    if _mode() in [raylib.SCRIPT_MODE, raylib.SILENT_MODE]:
      _export_reusable_variable(name, reusable)
    elif _mode() is None:
      self._cached_reusables.append((name, reusable))
    object.__setattr__(self, name, reusable.initializer())

  def __delattr__(self, name):
    """We do not allow attributes of RayReusables to be deleted.

    Args:
      name (str): The name of the attribute to delete.

    Raises:
      Exception: Always.
    """
    raise Exception("Attempted deletion of attribute {}. Attributes of a RayReusable object may not be deleted.".format(name))
|
|
|
|
class Worker(object):
  """A class used to define the control flow of a worker process.

  Note:
    The methods in this class are considered unexposed to the user. The
    functions outside of this class are considered exposed.

  Attributes:
    functions (Dict[str, Callable]): A dictionary mapping the name of a remote
      function to the remote function itself. This is the set of remote
      functions that can be executed by this worker.
    handle (worker capsule): A Python object wrapping a C++ Worker object.
    mode: The mode of the worker. One of SCRIPT_MODE, PYTHON_MODE, SILENT_MODE,
      and WORKER_MODE.
    cached_remote_functions (List[Tuple[str, str]]): A list of pairs
      representing the remote functions that were defined before the worker
      called connect. The first element is the name of the remote function, and
      the second element is the serialized remote function. When the worker
      eventually does call connect, if it is a driver, it will export these
      functions to the scheduler. If cached_remote_functions is None, that means
      that connect has been called already.
    num_failed_tasks (int): The number of tasks that have failed and whose error
      messages have been displayed to the user. We use this value to know when
      a failed task hasn't been seen by the user and should be displayed.
  """

  def __init__(self):
    """Initialize a Worker object."""
    self.functions = {}
    self.handle = None
    self.mode = None
    self.cached_remote_functions = []
    self.num_failed_tasks = 0

  def set_mode(self, mode):
    """Set the mode of the worker.

    The mode SCRIPT_MODE should be used if this Worker is a driver that is being
    run as a Python script or interactively in a shell. It will print
    information about task failures.

    The mode WORKER_MODE should be used if this Worker is not a driver. It will
    not print information about tasks.

    The mode PYTHON_MODE should be used if this Worker is a driver and if you
    want to run the driver in a manner equivalent to serial Python for debugging
    purposes. It will not send remote function calls to the scheduler and will
    instead execute them in a blocking fashion.

    The mode SILENT_MODE should be used only during testing. It does not print
    any information about errors because some of the tests intentionally fail.

    Args:
      mode: One of SCRIPT_MODE, WORKER_MODE, PYTHON_MODE, and SILENT_MODE.
    """
    self.mode = mode
    colorama.init()

  def put_object(self, objectid, value):
    """Put value in the local object store with object id objectid.

    This assumes that the value for objectid has not yet been placed in the
    local object store. Values that libnumbuf can serialize are written in the
    arrow format; anything that raises along that path falls back to protobuf
    serialization.

    Args:
      objectid (raylib.ObjectID): The object ID of the value to be put.
      value (serializable object): The value to put in the object store.
    """
    try:
      # We put the value into a list here because in arrow the concept of
      # "serializing a single object" does not exist.
      schema, size, serialized = libnumbuf.serialize_list([value])
      # TODO(pcm): Right now, metadata is serialized twice, change that in the future.
      # In the following line, the "8" is for storing the metadata size,
      # the len(schema) is for storing the metadata and the 4096 is for storing
      # the metadata in the batch (see INITIAL_METADATA_SIZE in arrow).
      size = size + 8 + len(schema) + 4096
      buff, segmentid = raylib.allocate_buffer(self.handle, objectid, size)
      # Write the metadata length into the first 8 bytes.
      np.frombuffer(buff, dtype="int64", count=1)[0] = len(schema)
      # Write the metadata right after the length field.
      metadata = np.frombuffer(buff, dtype="byte", offset=8, count=len(schema))
      metadata[:] = schema
      data = np.frombuffer(buff, dtype="byte")[8 + len(schema):]
      metadata_offset = libnumbuf.write_to_buffer(serialized, memoryview(data))
      raylib.finish_buffer(self.handle, objectid, segmentid, metadata_offset)
    except Exception:
      # At the moment, custom objects and objects that contain object IDs take this path.
      # TODO(pcm): Make sure that these are the only objects getting serialized to protobuf.
      # NOTE(review): if the failure happens after allocate_buffer succeeded,
      # this fallback puts the same objectid a second time; confirm that raylib
      # tolerates this.
      # contained_objectids is a list of the objectids contained in object_capsule.
      object_capsule, contained_objectids = serialization.serialize(self.handle, value)
      raylib.put_object(self.handle, objectid, object_capsule, contained_objectids)

  def get_object(self, objectid):
    """Get the value in the local object store associated with objectid.

    Return the value from the local object store for objectid. This will block
    until the value for objectid has been written to the local object store.

    Builtin values are wrapped in the corresponding serialization.* subclass so
    that ray_objectid and ray_deallocator can be attached to them. bool and
    None cannot be subclassed and are returned as-is after unmapping their
    segment.

    Args:
      objectid (raylib.ObjectID): The object ID of the value to retrieve.
    """
    if raylib.is_arrow(self.handle, objectid):
      buff, segmentid, metadata_offset = raylib.get_buffer(self.handle, objectid)
      # The layout mirrors put_object: int64 metadata length, metadata, data.
      metadata_size = np.frombuffer(buff, dtype="int64", count=1)[0]
      metadata = np.frombuffer(buff, dtype="byte", offset=8, count=metadata_size)
      data = np.frombuffer(buff, dtype="byte")[8 + metadata_size:]
      serialized = libnumbuf.read_from_buffer(memoryview(data), bytearray(metadata), metadata_offset)
      deserialized = libnumbuf.deserialize_list(serialized)
      # Unwrap the object from the one-element list put_object wrapped it in.
      assert len(deserialized) == 1
      result = deserialized[0]
    else:
      object_capsule, segmentid = raylib.get_object(self.handle, objectid)
      result = serialization.deserialize(self.handle, object_capsule)

    # bool must be tested before int: bool is a subclass of int, so testing int
    # first would wrap True/False in serialization.Int and skip this branch.
    if isinstance(result, bool):
      # Need to unmap here because result is passed back "by value" and we have no reference to unmap later.
      raylib.unmap_object(self.handle, segmentid)
      return result  # Can't subclass bool, and don't need to because there is a global True/False.
    elif isinstance(result, int):
      result = serialization.Int(result)
    elif isinstance(result, long):
      result = serialization.Long(result)
    elif isinstance(result, float):
      result = serialization.Float(result)
    elif isinstance(result, list):
      result = serialization.List(result)
    elif isinstance(result, dict):
      result = serialization.Dict(result)
    elif isinstance(result, tuple):
      result = serialization.Tuple(result)
    elif isinstance(result, str):
      result = serialization.Str(result)
    elif isinstance(result, unicode):
      result = serialization.Unicode(result)
    elif isinstance(result, np.ndarray):
      result = result.view(serialization.NDArray)
    elif isinstance(result, np.generic):
      # TODO(pcm): close the associated memory segment; if we don't, this leaks memory (but very little, so it is ok for now)
      return result
    elif result is None:
      # Need to unmap here because result is passed back "by value" and we have no reference to unmap later.
      raylib.unmap_object(self.handle, segmentid)
      return None  # Can't subclass None and don't need to because there is a global None.
    result.ray_objectid = objectid  # TODO(pcm): This could be done only for the "get" case in the future if we want to increase performance
    result.ray_deallocator = RayDealloc(self.handle, segmentid)
    return result

  def alias_objectids(self, alias_objectid, target_objectid):
    """Make two object IDs refer to the same object."""
    raylib.alias_objectids(self.handle, alias_objectid, target_objectid)

  def submit_task(self, func_name, args):
    """Submit a remote task to the scheduler.

    Tell the scheduler to schedule the execution of the function with name
    func_name with arguments args. Retrieve object IDs for the outputs of
    the function from the scheduler and immediately return them.

    Args:
      func_name (str): The name of the function to be executed.
      args (List[Any]): The arguments to pass into the function. Arguments can
        be object IDs or they can be values. If they are values, they
        must be serializable objects.
    """
    task_capsule = serialization.serialize_task(self.handle, func_name, args)
    objectids = raylib.submit_task(self.handle, task_capsule)
    return objectids
|
|
|
|
# The global Worker object for this worker process. A single module-level
# instance ensures there is exactly one Worker per worker process.
global_worker = Worker()

# The reusable variables that are shared between tasks. Each worker process has
# its own RayReusables object, and these objects should be the same in all
# workers. It holds variables that are not serializable but must be used by
# remote tasks, and reinitializes them after use so that changes to their state
# made by one task do not affect other tasks.
reusables = RayReusables()

# The logging object for the Python worker code.
logger = logging.getLogger("ray")
|
|
|
|
def check_connected(worker=global_worker):
  """Ensure that the worker has been connected to a Ray cluster.

  Args:
    worker: The worker to check. Defaults to this process's global worker.

  Raises:
    Exception: If worker.handle is None, i.e. the worker is not connected.
  """
  is_connected = worker.handle is not None
  if not is_connected:
    raise Exception("This command cannot be called before a Ray cluster has been started. You can start one with 'ray.init(start_ray_local=True, num_workers=1)'.")
|
|
|
|
def print_failed_task(task_status):
  """Print information about a failed task to stdout.

  Args:
    task_status (Dict): A dictionary containing the name, operationid, and
      error message for a failed task, under the keys "function_name",
      "operationid" and "error_message".
  """
  message = """
Error: Task failed
Function Name: {}
Task ID: {}
Error Message: \n{}
""".format(task_status["function_name"], task_status["operationid"], task_status["error_message"])
  # A single parenthesized argument prints identically under the Python 2
  # print statement and the Python 3 print function.
  print(message)
|
|
|
|
def scheduler_info(worker=global_worker):
  """Query the scheduler for information about its state.

  Args:
    worker: The worker whose handle is used for the query.

  Returns:
    Whatever raylib.scheduler_info reports for this worker.

  Raises:
    Exception: If the worker is not connected.
  """
  check_connected(worker)
  info = raylib.scheduler_info(worker.handle)
  return info
|
|
|
|
def visualize_computation_graph(file_path=None, view=False, worker=global_worker):
  """Write the computation graph to a pdf file.

  Alongside the pdf, a graphviz dot description is written to the path without
  its extension and a binary protobuf dump to "<path>.binaryproto".

  Args:
    file_path (str): The name of a pdf file that the rendered computation graph
      will be written to. If this argument is None, a temporary path will be
      used.
    view (bool): If true, the python graphviz package will try to open the
      result in a viewer.
    worker: The connected worker whose computation graph is dumped.

  Raises:
    Exception: If the worker is not connected or file_path does not end in
      ".pdf".

  Examples:
    Try the following code.

    >>> import ray.array.distributed as da
    >>> x = da.zeros([20, 20])
    >>> y = da.zeros([20, 20])
    >>> z = da.dot(x, y)
    >>> ray.visualize_computation_graph(view=True)
  """
  check_connected(worker)
  if file_path is None:
    file_path = config.get_log_file_path("computation-graph.pdf")

  base_path, extension = os.path.splitext(file_path)
  if extension != ".pdf":
    raise Exception("File path must be a .pdf file")
  proto_path = base_path + ".binaryproto"

  raylib.dump_computation_graph(worker.handle, proto_path)
  g = internal.graph_pb2.CompGraph()
  # The dump is a binary protobuf: read it in binary mode (text mode can mangle
  # bytes and yields str rather than bytes on Python 3) and close the file
  # deterministically.
  with open(proto_path, "rb") as proto_file:
    g.ParseFromString(proto_file.read())
  graph.graph_to_graphviz(g).render(base_path, view=view)

  print("Wrote graph dot description to file {}".format(base_path))
  print("Wrote graph protocol buffer description to file {}".format(proto_path))
  print("Wrote computation graph to file {}.pdf".format(base_path))
|
|
|
|
def task_info(worker=global_worker):
  """Return information about failed tasks, as reported by raylib.task_info.

  Args:
    worker: The worker whose handle is used for the query.

  Raises:
    Exception: If the worker is not connected.
  """
  check_connected(worker)
  info = raylib.task_info(worker.handle)
  return info
|
|
|
|
def init(start_ray_local=False, num_workers=None, num_objstores=None, scheduler_address=None, node_ip_address=None, driver_mode=raylib.SCRIPT_MODE):
  """Either connect to an existing Ray cluster or start one and connect to it.

  This method handles two cases. Either a Ray cluster already exists and we
  just attach this driver to it, or we start all of the processes associated
  with a Ray cluster and attach to the newly started cluster.

  Args:
    start_ray_local (Optional[bool]): If True then this will start a scheduler,
      an object store, and some workers. If False, this will attach to an
      existing Ray cluster.
    num_workers (Optional[int]): The number of workers to start if
      start_ray_local is True (defaults to 1). Must be omitted otherwise.
    num_objstores (Optional[int]): The number of object stores to start if
      start_ray_local is True (defaults to 1). Must be omitted otherwise.
    scheduler_address (Optional[str]): The address of the scheduler to connect
      to. Required if start_ray_local is False; must be omitted otherwise.
    node_ip_address (Optional[str]): The address of the node the worker is
      running on. Required if start_ray_local is False; must be omitted
      otherwise, in which case 127.0.0.1 is used.
    driver_mode: The mode in which to start the driver. When start_ray_local is
      True this must be one of SCRIPT_MODE, PYTHON_MODE, and SILENT_MODE.

  Raises:
    Exception: An exception is raised if an inappropriate combination of
      arguments is passed in.
  """
  if start_ray_local:
    # In this case, we launch a scheduler, a new object store, and some workers,
    # and we connect to them.
    if (scheduler_address is not None) or (node_ip_address is not None):
      raise Exception("If start_ray_local=True, then you cannot pass in a scheduler_address or a node_ip_address.")
    if driver_mode not in [raylib.SCRIPT_MODE, raylib.PYTHON_MODE, raylib.SILENT_MODE]:
      raise Exception("If start_ray_local=True, then driver_mode must be in [ray.SCRIPT_MODE, ray.PYTHON_MODE, ray.SILENT_MODE].")
    # Use the address 127.0.0.1 in local mode.
    node_ip_address = "127.0.0.1"
    num_workers = 1 if num_workers is None else num_workers
    num_objstores = 1 if num_objstores is None else num_objstores
    # Start the scheduler, object store, and some workers. These will be killed
    # by the call to cleanup(), which happens when the Python script exits.
    scheduler_address, _ = services.start_ray_local(num_objstores=num_objstores, num_workers=num_workers, worker_path=None)
  else:
    # In this case, there is an existing scheduler and object store, and we do
    # not need to start any processes.
    if (num_workers is not None) or (num_objstores is not None):
      raise Exception("The arguments num_workers and num_objstores must not be provided unless start_ray_local=True.")
    if node_ip_address is None:
      raise Exception("When start_ray_local=False, the node_ip_address of the current node must be provided.")
    # Without this check a missing address would be passed to connect() as None.
    if scheduler_address is None:
      raise Exception("When start_ray_local=False, the scheduler_address of the existing scheduler must be provided.")
  # Connect this driver to the scheduler and object store. The corresponding call
  # to disconnect will happen in the call to cleanup() when the Python script
  # exits.
  connect(node_ip_address, scheduler_address, is_driver=True, worker=global_worker, mode=driver_mode)
|
|
|
|
def cleanup(worker=global_worker):
  """Disconnect the driver and terminate any processes started in init.

  Registered with atexit below, so it runs automatically when a Python process
  that uses Ray exits. Running it twice in a row is fine. The tests call
  services.cleanup() directly because they start and stop many clusters, while
  the import and the interpreter exit happen only once.

  Args:
    worker: The worker whose mode is reset to None.
  """
  # NOTE(review): disconnect() is called without `worker`, so it acts on its
  # own default rather than on the worker passed here; confirm this is intended.
  disconnect()
  worker.set_mode(None)
  services.cleanup()


atexit.register(cleanup)
|
|
|
|
def connect(node_ip_address, scheduler_address, objstore_address=None, is_driver=False, worker=global_worker, mode=raylib.WORKER_MODE):
  """Connect this worker to the scheduler and an object store.

  Args:
    node_ip_address (str): The ip address of the node the worker runs on.
    scheduler_address (str): The ip address and port of the scheduler.
    objstore_address (Optional[str]): The ip address and port of the local
      object store. Normally, this argument should be omitted and the scheduler
      will tell the worker what object store to connect to.
    is_driver (bool): True if this worker is a driver and false otherwise.
    worker: The worker object to connect. Defaults to the global worker.
    mode: The mode of the worker. One of SCRIPT_MODE, WORKER_MODE, PYTHON_MODE,
      and SILENT_MODE.

  Note:
    Calling connect repeatedly in one process (as the tests do) replaces the
    log handler installed by the previous call on this worker instead of
    stacking another one on the shared logger.
  """
  if hasattr(worker, "handle"):
    del worker.handle
  worker.scheduler_address = scheduler_address
  worker.handle, worker.worker_address = raylib.create_worker(node_ip_address, scheduler_address, objstore_address if objstore_address is not None else "", is_driver)
  worker.set_mode(mode)
  FORMAT = "%(asctime)-15s %(message)s"
  # Configure the Python logging module. Note that if we do not provide our own
  # logger, then our logging will interfere with other Python modules that also
  # use the logging module.
  # Detach and close the handler installed by a previous connect call (if any).
  # Otherwise every reconnect adds another handler to the same logger, which
  # duplicates log lines and leaks open file handles.
  previous_handler = getattr(worker, "_log_handler", None)
  if previous_handler is not None:
    _logger().removeHandler(previous_handler)
    previous_handler.close()
  log_handler = logging.FileHandler(config.get_log_file_path("-".join(["worker", worker.worker_address]) + ".log"))
  log_handler.setLevel(logging.DEBUG)
  log_handler.setFormatter(logging.Formatter(FORMAT))
  _logger().addHandler(log_handler)
  worker._log_handler = log_handler
  _logger().setLevel(logging.DEBUG)
  _logger().propagate = False
  # Configure the logging from the worker C++ code.
  raylib.set_log_config(config.get_log_file_path("-".join(["worker", worker.worker_address, "c++"]) + ".log"))
  if mode in [raylib.SCRIPT_MODE, raylib.SILENT_MODE]:
    # Export everything that was defined before this driver connected.
    for function_name, function_to_export in worker.cached_remote_functions:
      raylib.export_remote_function(worker.handle, function_name, function_to_export)
    for name, reusable_variable in reusables._cached_reusables:
      # Export through the worker being connected, not the default global one.
      _export_reusable_variable(name, reusable_variable, worker=worker)
    worker.cached_remote_functions = None
    reusables._cached_reusables = None
  # Start the driver's WorkerService (if this is a driver). This will receive
  # GRPC commands from the scheduler to print error messages. We pass in the
  # mode below. This tells the WorkerService whether it is operating for a
  # driver or a worker and whether it should suppress errors or not.
  if is_driver:
    raylib.start_worker_service(worker.handle, mode)
|
|
|
|
def disconnect(worker=global_worker):
  """Detach `worker` from the scheduler and object store.

  If the worker currently holds a handle, it is disconnected via raylib. The
  handle is then cleared, and the caches of remote functions and reusable
  variables are reset to empty lists. That way anything defined after this
  point is exported again on the next call to connect (mostly relevant for
  the tests).

  Args:
    worker: The worker to disconnect. Defaults to the global worker.
  """
  handle = worker.handle
  if handle is not None:
    raylib.disconnect(handle)
  worker.handle = None
  worker.cached_remote_functions = []
  reusables._cached_reusables = []
|
|
|
|
def get(objectid, worker=global_worker):
  """Fetch the value behind an object ID from the object store.

  This blocks until the object referred to by `objectid` is available in the
  local object store. If it is not there yet, it will be shipped from an
  object store that has it (once the object has been created).

  Args:
    objectid (raylib.ObjectID): Object ID of the object to fetch.
    worker: The worker performing the fetch. Defaults to the global worker.

  Returns:
    A Python object.

  Raises:
    RayGetError: If the stored value is a RayTaskError, i.e. the task that was
      supposed to create the object failed.
  """
  check_connected(worker)
  if worker.mode == raylib.PYTHON_MODE:
    # In PYTHON_MODE get is the identity: the argument is already the value
    # itself rather than an object ID.
    return objectid
  raylib.request_object(worker.handle, objectid)
  result = worker.get_object(objectid)
  if not isinstance(result, RayTaskError):
    return result
  # The producing task failed, so surface its error to the caller here.
  raise RayGetError(objectid, result)
|
|
|
|
def put(value, worker=global_worker):
  """Place a Python value into the object store.

  Args:
    value (serializable object): The Python object to be stored.
    worker: The worker performing the store. Defaults to the global worker.

  Returns:
    The object ID assigned to this value. In PYTHON_MODE the value itself is
    returned unchanged.
  """
  check_connected(worker)
  if worker.mode != raylib.PYTHON_MODE:
    new_objectid = raylib.get_objectid(worker.handle)
    worker.put_object(new_objectid, value)
    return new_objectid
  # In PYTHON_MODE, put is the identity operation.
  return value
|
|
|
|
def kill_workers(worker=global_worker):
  """Kill all of the workers in the cluster. This does not kill drivers.

  Note:
    Currently, we only support killing workers if all submitted tasks have been
    run. If some workers are still running tasks or if the scheduler still has
    tasks in its queue, then this method will not do anything.

  Args:
    worker: The worker whose handle is used to issue the request. Defaults to
      the global worker.

  Returns:
    True if workers were successfully killed. False otherwise.
  """
  success = raylib.kill_workers(worker.handle)
  if not success:
    # Single-argument call form: prints the same text under Python 2 and is
    # also valid syntax under Python 3 (the bare print statement is not).
    print("Could not kill all workers. We currently do not support killing workers when tasks are running.")
  return success
|
|
|
|
def restart_workers_local(num_workers, worker_path, worker=global_worker):
  """Replace all workers with freshly started ones on this node.

  All existing workers are killed via kill_workers. New workers are then
  started locally on the same node as the driver, which is meant for running
  Ray on a single machine.

  Args:
    num_workers (int): The number of workers to be started.
    worker_path (str): The path of the source code that workers will run.
    worker: The driver issuing the restart. Defaults to the global worker.
      NOTE(review): this reads `worker.objstore_address`, which connect in this
      file does not set; presumably the Worker class provides it — confirm.

  Returns:
    True if workers were successfully restarted. False otherwise.
  """
  if kill_workers(worker):
    services.start_workers(worker.scheduler_address, worker.objstore_address, num_workers, worker_path)
    return True
  return False
|
|
|
|
def format_error_message(exception_message, num_lines_to_skip=4):
  """Improve the formatting of an exception thrown by a remote function.

  This method takes a traceback from an exception and makes it nicer by
  removing a few uninformative lines right after the first line. By default
  it removes the four lines after the first line. For tracebacks raised inside
  the worker's main loop, those lines only describe the main loop itself.

  Args:
    exception_message (str): A message generated by traceback.format_exc().
    num_lines_to_skip (int): How many lines directly after the first line to
      drop. Defaults to 4, which matches the previous fixed behavior. Negative
      values are treated as 0.

  Returns:
    A string of the formatted exception message.
  """
  lines = exception_message.split("\n")
  # Clamp so a negative count cannot duplicate the header line.
  num_lines_to_skip = max(num_lines_to_skip, 0)
  # Keep the header line, then drop the next `num_lines_to_skip` lines.
  return "\n".join(lines[:1] + lines[1 + num_lines_to_skip:])
|
|
|
|
def main_loop(worker=global_worker):
  """The main loop a worker runs to receive and execute tasks.

  This method is an infinite loop. It waits to receive commands from the
  scheduler. A command may consist of a task to execute, a remote function to
  import, a reusable variable to import, or an order to terminate the worker
  process. The worker executes the command, notifies the scheduler of any errors
  that occurred while executing the command, and waits for the next command.

  Args:
    worker: The worker that receives and executes the commands. Defaults to the
      global worker.

  Raises:
    Exception: If the worker has not been connected yet.
  """
  if not raylib.connected(worker.handle):
    raise Exception("Worker is attempting to enter main_loop but has not been connected yet.")
  # We pass in raylib.WORKER_MODE below to indicate that the WorkerService is
  # operating for a worker and not a driver.
  raylib.start_worker_service(worker.handle, raylib.WORKER_MODE)

  def process_task(task): # wrapping these lines in a function should cause the local variables to go out of scope more quickly, which is useful for inspecting reference counts
    """Execute a task assigned to this worker.

    This method deserializes a task from the scheduler, and attempts to execute
    the task. If the task succeeds, the outputs are stored in the local object
    store. If the task throws an exception, RayTaskError objects are stored in
    the object store to represent the failed task (these will be retrieved by
    calls to get or by subsequent tasks that use the outputs of this task).
    After the task executes, the worker resets any reusable variables that were
    accessed by the task.
    """
    function_name, args, return_objectids = serialization.deserialize_task(worker.handle, task)
    try:
      arguments = get_arguments_for_execution(worker.functions[function_name], args, worker) # get args from objstore
      outputs = worker.functions[function_name].executor(arguments) # execute the function
      if len(return_objectids) == 1:
        outputs = (outputs,)
    except Exception as e:
      # If the task threw an exception, then record the traceback. We determine
      # whether the exception was thrown in the task execution by whether the
      # variable "arguments" is defined.
      traceback_str = format_error_message(traceback.format_exc()) if "arguments" in locals() else None
      failure_object = RayTaskError(function_name, e, traceback_str)
      # Every output of the failed task is represented by the same error object.
      failure_objects = [failure_object] * len(return_objectids)
      store_outputs_in_objstore(return_objectids, failure_objects, worker)
      # Notify the scheduler that the task failed.
      raylib.notify_failure(worker.handle, function_name, str(failure_object), raylib.FailedTask)
      _logger().info("While running function {}, worker threw exception with message: \n\n{}\n".format(function_name, str(failure_object)))
    else:
      store_outputs_in_objstore(return_objectids, outputs, worker) # store output in local object store
    # Notify the scheduler that the task is done. This happens regardless of
    # whether the task succeeded or failed.
    raylib.notify_task_completed(worker.handle)
    try:
      # Reinitialize the values of reusable variables that were used in the task
      # above so that changes made to their state do not affect other tasks.
      reusables._reinitialize()
    except Exception:
      # The attempt to reinitialize the reusable variables threw an exception.
      # We record the traceback and notify the scheduler.
      traceback_str = format_error_message(traceback.format_exc())
      raylib.notify_failure(worker.handle, function_name, traceback_str, raylib.FailedReinitializeReusableVariable)
      _logger().info("While attempting to reinitialize the reusable variables after running function {}, the worker threw exception with message: \n\n{}\n".format(function_name, traceback_str))

  def process_remote_function(function_name, serialized_function):
    """Import a remote function.

    Any failure while unpickling, validating, or wrapping the function is
    logged and reported to the scheduler as a failed import. Such failures do
    not propagate out and terminate the worker.
    """
    try:
      (function, arg_types, return_types, module) = pickling.loads(serialized_function)
      # TODO(rkn): Why is the below line necessary?
      function.__module__ = module
      assert function_name == "{}.{}".format(function.__module__, function.__name__), "The remote function name does not match the name that was passed in."
      remote_function = remote(arg_types, return_types, worker)(function)
    except Exception:
      # If an exception was thrown when the remote function was imported, we
      # record the traceback and notify the scheduler of the failure.
      traceback_str = format_error_message(traceback.format_exc())
      _logger().info("Failed to import remote function {}. Failed with message: \n\n{}\n".format(function_name, traceback_str))
      # Notify the scheduler that the remote function failed to import.
      raylib.notify_failure(worker.handle, function_name, traceback_str, raylib.FailedRemoteFunctionImport)
    else:
      worker.functions[function_name] = remote_function
      _logger().info("Successfully imported remote function {}.".format(function_name))
      # Notify the scheduler that the remote function imported successfully,
      # passing along how many values it returns.
      raylib.register_remote_function(worker.handle, function_name, len(return_types))

  def process_reusable_variable(reusable_variable_name, initializer_str, reinitializer_str):
    """Import a reusable variable."""
    try:
      initializer = pickling.loads(initializer_str)
      reinitializer = pickling.loads(reinitializer_str)
      reusables.__setattr__(reusable_variable_name, Reusable(initializer, reinitializer))
    except Exception:
      # If an exception was thrown when the reusable variable was imported, we
      # record the traceback and notify the scheduler of the failure.
      traceback_str = format_error_message(traceback.format_exc())
      _logger().info("Failed to import reusable variable {}. Failed with message: \n\n{}\n".format(reusable_variable_name, traceback_str))
      # Notify the scheduler that the reusable variable failed to import.
      raylib.notify_failure(worker.handle, reusable_variable_name, traceback_str, raylib.FailedReusableVariableImport)
    else:
      _logger().info("Successfully imported reusable variable {}.".format(reusable_variable_name))

  while True:
    command, command_args = raylib.wait_for_next_message(worker.handle)
    try:
      if command == "die":
        # We use this as a mechanism to allow the scheduler to kill workers.
        _logger().info("Received a 'die' command, and will exit now.")
        break
      elif command == "task":
        process_task(command_args)
      elif command == "function":
        function_name, serialized_function = command_args
        process_remote_function(function_name, serialized_function)
      elif command == "reusable_variable":
        name, initializer_str, reinitializer_str = command_args
        process_reusable_variable(name, initializer_str, reinitializer_str)
      else:
        _logger().info("Reached the end of the if-else loop in the main loop. This should be unreachable.")
        assert False, "This code should be unreachable."
    finally:
      # Allow releasing the variables BEFORE we wait for the next message or exit the block
      del command_args
|
|
|
|
def _submit_task(func_name, args, worker=global_worker):
  """Forward a task submission to `worker.submit_task`.

  The remote decorator calls this module-level function instead of
  worker.submit_task. That way, when a remote function is pickled, the worker
  object (which cannot be serialized) is not dragged along with it.

  Args:
    func_name (str): Fully qualified name of the remote function.
    args (List): The arguments for the task.
    worker: The worker that submits the task. Defaults to the global worker.

  Returns:
    Whatever worker.submit_task returns.
  """
  submit = worker.submit_task
  return submit(func_name, args)
|
|
|
|
def _mode(worker=global_worker):
  """Return the current mode of `worker`.

  The remote decorator calls _mode() instead of reading worker.mode directly.
  That way, when a remote function is pickled, the worker object (which cannot
  be serialized) is not captured.

  Args:
    worker: The worker whose mode is read. Defaults to the global worker.
  """
  current_mode = worker.mode
  return current_mode
|
|
|
|
def _logger():
  """Return the module-level logger object.

  Functions that log go through this accessor rather than referencing the
  logger directly, so that those functions can still be pickled. A logger is
  normally tied to the machine it lives on (it opens a local file), so it
  cannot be pickled itself.
  """
  module_logger = logger
  return module_logger
|
|
|
|
def _export_reusable_variable(name, reusable, worker=global_worker):
  """Ship a reusable variable's code to the workers (driver only).

  Args:
    name (str): The name under which the variable is exported.
    reusable (Reusable): Holds the `initializer` and `reinitializer` callables.
      Both are pickled and handed to raylib.export_reusable_variable.
    worker: The driver whose handle performs the export. Defaults to the
      global worker.

  Raises:
    Exception: If `worker` is not in SCRIPT_MODE or SILENT_MODE.
  """
  is_driver = _mode(worker) in [raylib.SCRIPT_MODE, raylib.SILENT_MODE]
  if not is_driver:
    raise Exception("_export_reusable_variable can only be called on a driver.")
  export = raylib.export_reusable_variable
  export(worker.handle, name, pickling.dumps(reusable.initializer), pickling.dumps(reusable.reinitializer))
|
|
|
|
def remote(arg_types, return_types, worker=global_worker):
  """This decorator is used to create remote functions.

  Args:
    arg_types (List[type]): List of Python types of the function arguments.
    return_types (List[type]): List of Python types of the return values.
    worker: The worker whose mode decides whether the decorated function is
      exported right away, cached for export at connect time, or neither.
      Defaults to the global worker.

  Returns:
    A decorator. It turns a Python function into an object that raises when
    called directly and is invoked through its `.remote(...)` attribute.
  """
  def remote_decorator(func):
    def func_call(*args, **kwargs):
      """This gets run immediately when a worker calls a remote function."""
      check_connected()
      args = list(args)
      # Fill in the remaining named parameters from the keyword arguments,
      # falling back on the defaults from the signature. The *args parameter is
      # deliberately not part of `positional_defaults`. It has no default, and
      # filling it in would pass the signature's "empty" placeholder to the
      # function as a bogus extra argument. Keywords for parameters that were
      # already supplied positionally, and unknown keywords, are ignored.
      for keyword, default in positional_defaults[len(args):]:
        if keyword in kwargs:
          args.append(kwargs[keyword])
        elif default is not funcsigs.Parameter.empty:
          args.append(default)
        else:
          raise Exception("Function {} was called without a value for its argument '{}'.".format(func_name, keyword))
      if _mode() == raylib.PYTHON_MODE:
        # In raylib.PYTHON_MODE, remote calls simply execute the function. We copy the
        # arguments to prevent the function call from mutating them and to match
        # the usual behavior of immutable remote objects.
        return func(*copy.deepcopy(args))
      check_arguments(arg_types, has_vararg_param, func_name, args) # throws an exception if args are invalid
      objectids = _submit_task(func_name, args)
      if len(objectids) == 1:
        return objectids[0]
      elif len(objectids) > 1:
        return objectids

    def func_executor(arguments):
      """This gets run when the remote function is executed."""
      _logger().info("Calling function {}".format(func.__name__))
      start_time = time.time()
      result = func(*arguments)
      end_time = time.time()
      check_return_values(func_invoker, result) # throws an exception if result is invalid
      _logger().info("Finished executing function {}, it took {} seconds".format(func.__name__, end_time - start_time))
      return result

    def func_invoker(*args, **kwargs):
      """This is returned by the decorator and used to invoke the function."""
      raise Exception("Remote functions cannot be called directly. Instead of running '{}()', try '{}.remote()'.".format(func_name, func_name))

    func_invoker.remote = func_call
    func_invoker.executor = func_executor
    func_invoker.arg_types = arg_types
    func_invoker.return_types = return_types
    func_invoker.is_remote = True
    func_name = "{}.{}".format(func.__module__, func.__name__)
    func_invoker.func_name = func_name
    # __doc__ is the same attribute as func_doc on Python 2 and also exists on
    # Python 3.
    func_invoker.__doc__ = func.__doc__
    sig_params = list(funcsigs.signature(func).parameters.items())
    keyword_defaults = [(k, v.default) for k, v in sig_params]
    # Parameters that func_call may fill in by position (everything except
    # *args and **kwargs).
    positional_defaults = [(k, v.default) for k, v in sig_params if v.kind not in (v.VAR_POSITIONAL, v.VAR_KEYWORD)]
    has_vararg_param = any(v.kind == v.VAR_POSITIONAL for _, v in sig_params)
    func_invoker.has_vararg_param = has_vararg_param
    has_kwargs_param = any(v.kind == v.VAR_KEYWORD for _, v in sig_params)
    check_signature_supported(has_kwargs_param, has_vararg_param, keyword_defaults, func_name)

    # Everything ready - export the function
    if worker.mode in [None, raylib.SCRIPT_MODE, raylib.SILENT_MODE]:
      func_name_global_valid = func.__name__ in func.__globals__
      func_name_global_value = func.__globals__.get(func.__name__)
      # Set the function globally to make it refer to itself
      func.__globals__[func.__name__] = func_invoker # Allow the function to reference itself as a global variable
      try:
        to_export = pickling.dumps((func, arg_types, return_types, func.__module__))
      finally:
        # Undo our changes
        if func_name_global_valid:
          func.__globals__[func.__name__] = func_name_global_value
        else:
          del func.__globals__[func.__name__]
      if worker.mode in [raylib.SCRIPT_MODE, raylib.SILENT_MODE]:
        raylib.export_remote_function(worker.handle, func_name, to_export)
      elif worker.mode is None:
        # Not connected yet: connect() exports the cached functions later.
        worker.cached_remote_functions.append((func_name, to_export))
    return func_invoker

  return remote_decorator
|
|
|
|
def check_signature_supported(has_kwargs_param, has_vararg_param, keyword_defaults, name):
  """Check if we support the signature of this function.

  We currently do not allow remote functions to have **kwargs. We also do not
  support keyword arguments in conjunction with a *args argument.

  Args:
    has_kwargs_param (bool): True if the function being checked has a **kwargs
      argument.
    has_vararg_param (bool): True if the function being checked has a *args
      argument.
    keyword_defaults (List): (parameter name, default value) pairs for the
      parameters of the function being checked. Parameters without a default
      carry funcsigs.Parameter.empty as their value.
    name (str): The name of the function to check.

  Raises:
    Exception: An exception is raised if the signature is not supported.
  """
  # check if the user specified kwargs
  if has_kwargs_param:
    # Raising a bare string is a TypeError in Python 2.6+, which would hide
    # this message; raise a real exception instead.
    raise Exception("Function {} has a **kwargs argument, which is currently not supported.".format(name))
  # check if the user specified a variable number of arguments and any keyword arguments
  # Use an identity check against the "no default" sentinel: `!=` could invoke
  # a default value's own comparison (e.g. elementwise for arrays).
  if has_vararg_param and any(default is not funcsigs.Parameter.empty for _, default in keyword_defaults):
    raise Exception("Function {} has a *args argument as well as a keyword argument, which is currently not supported.".format(name))
|
|
|
|
|
|
def check_return_values(function, result):
  """Validate the number and types of a remote function's return values.

  Args:
    function (Callable): The remote function whose outputs are being checked.
      Its `return_types` attribute lists the types declared in the remote
      decorator.
    result: The value returned by one invocation of the remote function.

  Raises:
    Exception: If the function returned the wrong number of values, or a value
      that is neither an instance of the declared type nor an ObjectID.
  """
  expected_types = function.return_types
  num_expected = len(expected_types)
  # With zero declared return values, the only thing to verify is that nothing
  # was returned.
  if num_expected == 0:
    if result is None:
      return
    raise Exception("The @remote decorator for function {} has 0 return values, but {} returned more than 0 values.".format(function.__name__, function.__name__))
  # Python returns a bare value for a single result and a tuple for several.
  # Wrap the single case so both can be handled the same way below.
  values = (result,) if num_expected == 1 else result
  if len(values) != num_expected:
    raise Exception("The @remote decorator for function {} has {} return values with types {}, but {} returned {} values.".format(function.__name__, num_expected, expected_types, function.__name__, len(values)))
  # Limited type checking: each value must match its declared type or be an
  # ObjectID.
  for index in range(len(values)):
    value = values[index]
    expected_type = expected_types[index]
    if issubclass(type(value), expected_type) or isinstance(value, raylib.ObjectID):
      continue
    raise Exception("The {}th return value for function {} has type {}, but the @remote decorator expected a return value of type {} or an ObjectID.".format(index, function.__name__, type(value), expected_type))
|
|
|
|
def typecheck_arg(arg, expected_type, i, name):
  """Verify that one argument matches the type declared for it.

  Args:
    arg: An argument to the function.
    expected_type (type): The type the argument is expected to have.
    i (int): The position of the argument to the function.
    name (str): The name of the function.

  Raises:
    RayGetArgumentTypeError: If arg does not have the expected type.
  """
  # TODO(rkn): This check doesn't really work, e.g., issubclass(type([1, 2, 3]), typing.List[str]) == True
  if not issubclass(type(arg), expected_type):
    # TODO(mehrdadn): Should long really be convertible to int?
    long_accepted_as_int = isinstance(arg, long) and issubclass(int, expected_type)
    if not long_accepted_as_int:
      raise RayGetArgumentTypeError(name, i, type(arg), expected_type)
|
|
|
|
def check_arguments(arg_types, has_vararg_param, name, args):
  """Validate the count and types of the arguments to a remote function.

  This runs on the worker that calls the remote function, not on the worker
  that executes it.

  Args:
    arg_types (List[type]): The declared types of the function's arguments.
      When the function has a *args argument, the last entry is the type
      applied to every extra argument.
    has_vararg_param (bool): True if the function being checked has a *args
      argument.
    name (str): The name of the function.
    args (List): The arguments to the function.

  Raises:
    Exception: If the number of arguments is wrong, or (via typecheck_arg) an
      argument passed by value has the wrong type.
  """
  num_declared = len(arg_types)
  num_given = len(args)
  if has_vararg_param:
    if num_given < num_declared - 1:
      raise Exception("Function {} expects at least {} arguments, but received {}.".format(name, num_declared - 1, num_given))
  elif num_given != num_declared:
    raise Exception("Function {} expects {} arguments, but received {}.".format(name, num_declared, num_given))

  for index, arg in enumerate(args):
    if index < num_declared:
      expected_type = arg_types[index]
    elif has_vararg_param:
      expected_type = arg_types[-1]
    else:
      assert False, "This code should be unreachable."
    # TODO(rkn): When we have type information in the ObjectID, do type checking here.
    if not isinstance(arg, raylib.ObjectID):
      typecheck_arg(arg, expected_type, index, name)
|
|
|
|
def get_arguments_for_execution(function, args, worker=global_worker):
  """Retrieve the arguments for the remote function.

  This retrieves the values for the arguments to the remote function that were
  passed in as object IDs. Arguments that were passed by value are not changed.
  This also does some type checking. This is called by the worker that is
  executing the remote function.

  Args:
    function (Callable): The remote function whose arguments are being
      retrieved.
    args (List): The arguments to the function.
    worker: The worker whose object store the arguments are read from.
      Defaults to the global worker.

  Returns:
    The retrieved arguments in addition to the arguments that were passed by
    value.

  Raises:
    RayGetArgumentError: This exception is raised if a task that created one of
      the arguments failed.
    RayGetArgumentTypeError: This exception is raised (via typecheck_arg) if one
      of the arguments does not have the expected type.
    Exception: If there are more arguments than function.arg_types can assign
      a type to.
  """
  # TODO(rkn): Eventually, all of the type checking can be put in `check_arguments` above so that the error will happen immediately when calling a remote function.
  arguments = []
  for (i, arg) in enumerate(args):
    if i < len(function.arg_types):
      expected_type = function.arg_types[i]
    elif function.has_vararg_param and len(function.arg_types) >= 1:
      expected_type = function.arg_types[-1]
    else:
      # An explicit raise rather than `assert False`: assertions are stripped
      # under `python -O`, which would leave expected_type unbound (or stale
      # from the previous iteration).
      raise Exception("Function {} received {} arguments, but its @remote decorator only declares types for {} of them.".format(function.__name__, len(args), len(function.arg_types)))

    if isinstance(arg, raylib.ObjectID):
      # get the object from the local object store
      _logger().info("Getting argument {} for function {}.".format(i, function.__name__))
      argument = worker.get_object(arg)
      if isinstance(argument, RayTaskError):
        # If the result is a RayTaskError, then the task that created this
        # object failed, and we should propagate the error message here.
        raise RayGetArgumentError(function.__name__, i, arg, argument)
      _logger().info("Successfully retrieved argument {} for function {}.".format(i, function.__name__))
    else:
      # pass the argument by value
      argument = arg

    typecheck_arg(argument, expected_type, i, function.__name__)
    arguments.append(argument)
  return arguments
|
|
|
|
def store_outputs_in_objstore(objectids, outputs, worker=global_worker):
  """Write a remote function's return values into the local object store.

  Each output is paired with the object ID the scheduler assigned to it. An
  output that is itself an ObjectID is not copied. Instead the assigned ID is
  aliased to it, so both IDs refer to the same object. Every other output is
  stored under its assigned ID. This runs on the worker that executed the
  remote function.

  Note:
    objectids and outputs are expected to have the same length.

  Args:
    objectids (List[raylib.ObjectID]): The object IDs that were assigned to the
      outputs of the remote function call.
    outputs (Tuple): The values returned by the remote function. A function
      declared to return a single value has its output wrapped in a
      one-element tuple before being passed in here.
    worker: The worker whose object store is written to. Defaults to the
      global worker.
  """
  for index, objectid in enumerate(objectids):
    output = outputs[index]
    if not isinstance(output, raylib.ObjectID):
      worker.put_object(objectid, output)
      continue
    # The output is already an ObjectID: make the assigned ID an alias of it.
    _logger().info("Aliasing objectids {} and {}".format(objectid.id, output.id))
    worker.alias_objectids(objectid, output)
|