Remove type information from remote decorator.

This commit is contained in:
Robert Nishihara
2016-08-29 22:05:59 -07:00
parent 93e6c9947b
commit d7f313a026
23 changed files with 147 additions and 305 deletions
+14 -14
View File
@@ -66,12 +66,12 @@ class DistArray(object):
a = self.assemble()
return a[sliced]
@ray.remote([DistArray], [np.ndarray])
@ray.remote()
def assemble(a):
return a.assemble()
# TODO(rkn): what should we call this method
@ray.remote([np.ndarray], [DistArray])
@ray.remote()
def numpy_to_dist(a):
result = DistArray(a.shape)
for index in np.ndindex(*result.num_blocks):
@@ -80,28 +80,28 @@ def numpy_to_dist(a):
result.objectids[index] = ray.put(a[[slice(l, u) for (l, u) in zip(lower, upper)]])
return result
@ray.remote([List, str], [DistArray])
@ray.remote()
def zeros(shape, dtype_name="float"):
result = DistArray(shape)
for index in np.ndindex(*result.num_blocks):
result.objectids[index] = ra.zeros.remote(DistArray.compute_block_shape(index, shape), dtype_name=dtype_name)
return result
@ray.remote([List, str], [DistArray])
@ray.remote()
def ones(shape, dtype_name="float"):
result = DistArray(shape)
for index in np.ndindex(*result.num_blocks):
result.objectids[index] = ra.ones.remote(DistArray.compute_block_shape(index, shape), dtype_name=dtype_name)
return result
@ray.remote([DistArray], [DistArray])
@ray.remote()
def copy(a):
result = DistArray(a.shape)
for index in np.ndindex(*result.num_blocks):
result.objectids[index] = a.objectids[index] # We don't need to actually copy the objects because cluster-level objects are assumed to be immutable.
return result
@ray.remote([int, int, str], [DistArray])
@ray.remote()
def eye(dim1, dim2=-1, dtype_name="float"):
dim2 = dim1 if dim2 == -1 else dim2
shape = [dim1, dim2]
@@ -114,7 +114,7 @@ def eye(dim1, dim2=-1, dtype_name="float"):
result.objectids[i, j] = ra.zeros.remote(block_shape, dtype_name=dtype_name)
return result
@ray.remote([DistArray], [DistArray])
@ray.remote()
def triu(a):
if a.ndim != 2:
raise Exception("Input must have 2 dimensions, but a.ndim is " + str(a.ndim))
@@ -128,7 +128,7 @@ def triu(a):
result.objectids[i, j] = ra.zeros_like.remote(a.objectids[i, j])
return result
@ray.remote([DistArray], [DistArray])
@ray.remote()
def tril(a):
if a.ndim != 2:
raise Exception("Input must have 2 dimensions, but a.ndim is " + str(a.ndim))
@@ -142,7 +142,7 @@ def tril(a):
result.objectids[i, j] = ra.zeros_like.remote(a.objectids[i, j])
return result
@ray.remote([np.ndarray], [np.ndarray])
@ray.remote()
def blockwise_dot(*matrices):
n = len(matrices)
if n % 2 != 0:
@@ -153,7 +153,7 @@ def blockwise_dot(*matrices):
result += np.dot(matrices[i], matrices[n / 2 + i])
return result
@ray.remote([DistArray, DistArray], [DistArray])
@ray.remote()
def dot(a, b):
if a.ndim != 2:
raise Exception("dot expects its arguments to be 2-dimensional, but a.ndim = {}.".format(a.ndim))
@@ -168,7 +168,7 @@ def dot(a, b):
result.objectids[i, j] = blockwise_dot.remote(*args)
return result
@ray.remote([DistArray, List], [DistArray])
@ray.remote()
def subblocks(a, *ranges):
"""
This function produces a distributed array from a subset of the blocks in the `a`. The result and `a` will have the same number of dimensions.For example,
@@ -198,7 +198,7 @@ def subblocks(a, *ranges):
result.objectids[index] = a.objectids[tuple([ranges[i][index[i]] for i in range(a.ndim)])]
return result
@ray.remote([DistArray], [DistArray])
@ray.remote()
def transpose(a):
if a.ndim != 2:
raise Exception("transpose expects its argument to be 2-dimensional, but a.ndim = {}, a.shape = {}.".format(a.ndim, a.shape))
@@ -209,7 +209,7 @@ def transpose(a):
return result
# TODO(rkn): support broadcasting?
@ray.remote([DistArray, DistArray], [DistArray])
@ray.remote()
def add(x1, x2):
if x1.shape != x2.shape:
raise Exception("add expects arguments `x1` and `x2` to have the same shape, but x1.shape = {}, and x2.shape = {}.".format(x1.shape, x2.shape))
@@ -219,7 +219,7 @@ def add(x1, x2):
return result
# TODO(rkn): support broadcasting?
@ray.remote([DistArray, DistArray], [DistArray])
@ray.remote()
def subtract(x1, x2):
if x1.shape != x2.shape:
raise Exception("subtract expects arguments `x1` and `x2` to have the same shape, but x1.shape = {}, and x2.shape = {}.".format(x1.shape, x2.shape))
+8 -8
View File
@@ -6,7 +6,7 @@ from core import *
__all__ = ["tsqr", "modified_lu", "tsqr_hr", "qr"]
@ray.remote([DistArray], [DistArray, np.ndarray])
@ray.remote(num_return_vals=2)
def tsqr(a):
"""
arguments:
@@ -75,7 +75,7 @@ def tsqr(a):
return q_result, r
# TODO(rkn): This is unoptimized, we really want a block version of this.
@ray.remote([DistArray], [DistArray, np.ndarray, np.ndarray])
@ray.remote(num_return_vals=3)
def modified_lu(q):
"""
Algorithm 5 from http://www.eecs.berkeley.edu/Pubs/TechRpts/2013/EECS-2013-175.pdf
@@ -105,19 +105,19 @@ def modified_lu(q):
U = np.triu(q_work)[:b, :]
return numpy_to_dist.remote(ray.put(L)), U, S # TODO(rkn): get rid of put
@ray.remote([np.ndarray, np.ndarray, np.ndarray, int], [np.ndarray, np.ndarray])
@ray.remote(num_return_vals=2)
def tsqr_hr_helper1(u, s, y_top_block, b):
y_top = y_top_block[:b, :b]
s_full = np.diag(s)
t = -1 * np.dot(u, np.dot(s_full, np.linalg.inv(y_top).T))
return t, y_top
@ray.remote([np.ndarray, np.ndarray], [np.ndarray])
@ray.remote()
def tsqr_hr_helper2(s, r_temp):
s_full = np.diag(s)
return np.dot(s_full, r_temp)
@ray.remote([DistArray], [DistArray, np.ndarray, np.ndarray, np.ndarray])
@ray.remote(num_return_vals=4)
def tsqr_hr(a):
"""Algorithm 6 from http://www.eecs.berkeley.edu/Pubs/TechRpts/2013/EECS-2013-175.pdf"""
q, r_temp = tsqr.remote(a)
@@ -127,15 +127,15 @@ def tsqr_hr(a):
r = tsqr_hr_helper2.remote(s, r_temp)
return y, t, y_top, r
@ray.remote([np.ndarray, np.ndarray, np.ndarray, np.ndarray], [np.ndarray])
@ray.remote()
def qr_helper1(a_rc, y_ri, t, W_c):
return a_rc - np.dot(y_ri, np.dot(t.T, W_c))
@ray.remote([np.ndarray, np.ndarray], [np.ndarray])
@ray.remote()
def qr_helper2(y_ri, a_rc):
return np.dot(y_ri.T, a_rc)
@ray.remote([DistArray], [DistArray, DistArray])
@ray.remote(num_return_vals=2)
def qr(a):
"""Algorithm 7 from http://www.eecs.berkeley.edu/Pubs/TechRpts/2013/EECS-2013-175.pdf"""
m, n = a.shape[0], a.shape[1]
+1 -1
View File
@@ -6,7 +6,7 @@ import ray
from core import *
@ray.remote([List], [DistArray])
@ray.remote()
def normal(shape):
num_blocks = DistArray.compute_num_blocks(shape)
objectids = np.empty(num_blocks, dtype=object)
+18 -18
View File
@@ -4,80 +4,80 @@ import ray
__all__ = ["zeros", "zeros_like", "ones", "eye", "dot", "vstack", "hstack", "subarray", "copy", "tril", "triu", "diag", "transpose", "add", "subtract", "sum", "shape", "sum_list"]
@ray.remote([List, str, str], [np.ndarray])
@ray.remote()
def zeros(shape, dtype_name="float", order="C"):
return np.zeros(shape, dtype=np.dtype(dtype_name), order=order)
@ray.remote([np.ndarray, str, str, bool], [np.ndarray])
@ray.remote()
def zeros_like(a, dtype_name="None", order="K", subok=True):
dtype_val = None if dtype_name == "None" else np.dtype(dtype_name)
return np.zeros_like(a, dtype=dtype_val, order=order, subok=subok)
@ray.remote([List, str, str], [np.ndarray])
@ray.remote()
def ones(shape, dtype_name="float", order="C"):
return np.ones(shape, dtype=np.dtype(dtype_name), order=order)
@ray.remote([int, int, int, str], [np.ndarray])
@ray.remote()
def eye(N, M=-1, k=0, dtype_name="float"):
M = N if M == -1 else M
return np.eye(N, M=M, k=k, dtype=np.dtype(dtype_name))
@ray.remote([np.ndarray, np.ndarray], [np.ndarray])
@ray.remote()
def dot(a, b):
return np.dot(a, b)
@ray.remote([np.ndarray], [np.ndarray])
@ray.remote()
def vstack(*xs):
return np.vstack(xs)
@ray.remote([np.ndarray], [np.ndarray])
@ray.remote()
def hstack(*xs):
return np.hstack(xs)
# TODO(rkn): instead of this, consider implementing slicing
@ray.remote([np.ndarray, List, List], [np.ndarray])
@ray.remote()
def subarray(a, lower_indices, upper_indices): # TODO(rkn): be consistent about using "index" versus "indices"
return a[[slice(l, u) for (l, u) in zip(lower_indices, upper_indices)]]
@ray.remote([np.ndarray, str], [np.ndarray])
@ray.remote()
def copy(a, order="K"):
return np.copy(a, order=order)
@ray.remote([np.ndarray, int], [np.ndarray])
@ray.remote()
def tril(m, k=0):
return np.tril(m, k=k)
@ray.remote([np.ndarray, int], [np.ndarray])
@ray.remote()
def triu(m, k=0):
return np.triu(m, k=k)
@ray.remote([np.ndarray, int], [np.ndarray])
@ray.remote()
def diag(v, k=0):
return np.diag(v, k=k)
@ray.remote([np.ndarray, List], [np.ndarray])
@ray.remote()
def transpose(a, axes=[]):
axes = None if axes == [] else axes
return np.transpose(a, axes=axes)
@ray.remote([np.ndarray, np.ndarray], [np.ndarray])
@ray.remote()
def add(x1, x2):
return np.add(x1, x2)
@ray.remote([np.ndarray, np.ndarray], [np.ndarray])
@ray.remote()
def subtract(x1, x2):
return np.subtract(x1, x2)
@ray.remote([np.ndarray, int], [np.ndarray])
@ray.remote()
def sum(x, axis=-1):
return np.sum(x, axis=axis if axis != -1 else None)
@ray.remote([np.ndarray], [tuple])
@ray.remote()
def shape(a):
return np.shape(a)
# We use Any to allow different numerical types as well as numpy arrays.
# TODO(rkn):this isn't in the numpy API, so be careful about exposing this.
@ray.remote([Any], [Any])
@ray.remote()
def sum_list(*xs):
return np.sum(xs, axis=0)
+20 -20
View File
@@ -7,82 +7,82 @@ __all__ = ["matrix_power", "solve", "tensorsolve", "tensorinv", "inv",
"svd", "eig", "eigh", "lstsq", "norm", "qr", "cond", "matrix_rank",
"LinAlgError", "multi_dot"]
@ray.remote([np.ndarray, int], [np.ndarray])
@ray.remote()
def matrix_power(M, n):
return np.linalg.matrix_power(M, n)
@ray.remote([np.ndarray, np.ndarray], [np.ndarray])
@ray.remote()
def solve(a, b):
return np.linalg.solve(a, b)
@ray.remote([np.ndarray], [np.ndarray, np.ndarray])
@ray.remote(num_return_vals=2)
def tensorsolve(a):
raise NotImplementedError
@ray.remote([np.ndarray], [np.ndarray, np.ndarray])
@ray.remote(num_return_vals=2)
def tensorinv(a):
raise NotImplementedError
@ray.remote([np.ndarray], [np.ndarray])
@ray.remote()
def inv(a):
return np.linalg.inv(a)
@ray.remote([np.ndarray], [np.ndarray])
@ray.remote()
def cholesky(a):
return np.linalg.cholesky(a)
@ray.remote([np.ndarray], [np.ndarray])
@ray.remote()
def eigvals(a):
return np.linalg.eigvals(a)
@ray.remote([np.ndarray], [np.ndarray])
@ray.remote()
def eigvalsh(a):
raise NotImplementedError
@ray.remote([np.ndarray], [np.ndarray])
@ray.remote()
def pinv(a):
return np.linalg.pinv(a)
@ray.remote([np.ndarray], [int])
@ray.remote()
def slogdet(a):
raise NotImplementedError
@ray.remote([np.ndarray], [float])
@ray.remote()
def det(a):
return np.linalg.det(a)
@ray.remote([np.ndarray], [np.ndarray, np.ndarray, np.ndarray])
@ray.remote(num_return_vals=3)
def svd(a):
return np.linalg.svd(a)
@ray.remote([np.ndarray], [np.ndarray, np.ndarray])
@ray.remote(num_return_vals=2)
def eig(a):
return np.linalg.eig(a)
@ray.remote([np.ndarray], [np.ndarray, np.ndarray])
@ray.remote(num_return_vals=2)
def eigh(a):
return np.linalg.eigh(a)
@ray.remote([np.ndarray], [np.ndarray, np.ndarray, int, np.ndarray])
@ray.remote(num_return_vals=4)
def lstsq(a, b):
return np.linalg.lstsq(a)
@ray.remote([np.ndarray], [float])
@ray.remote()
def norm(x):
return np.linalg.norm(x)
@ray.remote([np.ndarray], [np.ndarray, np.ndarray])
@ray.remote(num_return_vals=2)
def qr(a):
return np.linalg.qr(a)
@ray.remote([np.ndarray], [float])
@ray.remote()
def cond(x):
return np.linalg.cond(x)
@ray.remote([np.ndarray], [int])
@ray.remote()
def matrix_rank(M):
return np.linalg.matrix_rank(M)
@ray.remote([np.ndarray], [np.ndarray])
@ray.remote()
def multi_dot(*a):
raise NotImplementedError
+1 -1
View File
@@ -2,6 +2,6 @@ from typing import List
import numpy as np
import ray
@ray.remote([List], [np.ndarray])
@ray.remote()
def normal(shape):
return np.random.normal(size=shape)
+9 -167
View File
@@ -4,7 +4,6 @@ import time
import traceback
import copy
import logging
from types import ModuleType
import typing
import funcsigs
import numpy as np
@@ -45,7 +44,7 @@ class RayTaskError(Exception):
def __init__(self, function_name, exception, traceback_str):
"""Initialize a RayTaskError."""
self.function_name = function_name
if isinstance(exception, RayGetError) or isinstance(exception, RayGetArgumentError) or isinstance(exception, RayGetArgumentTypeError):
if isinstance(exception, RayGetError) or isinstance(exception, RayGetArgumentError):
self.exception = exception
else:
self.exception = None
@@ -59,8 +58,6 @@ class RayTaskError(Exception):
exception = RayGetError.deserialize(exception[1])
elif exception[0] == "RayGetArgumentError":
exception = RayGetArgumentError.deserialize(exception[1])
elif exception[0] == "RayGetArgumentTypeError":
exception = RayGetArgumentTypeError.deserialize(exception[1])
elif exception[0] == "None":
exception = None
else:
@@ -73,8 +70,6 @@ class RayTaskError(Exception):
serialized_exception = ("RayGetError", self.exception.serialize())
elif isinstance(self.exception, RayGetArgumentError):
serialized_exception = ("RayGetArgumentError", self.exception.serialize())
elif isinstance(self.exception, RayGetArgumentTypeError):
serialized_exception = ("RayGetArgumentTypeError", self.exception.serialize())
elif self.exception is None:
serialized_exception = ("None",)
else:
@@ -151,41 +146,6 @@ class RayGetArgumentError(Exception):
"""Format a RayGetArgumentError as a string."""
return "Failed to get objectid {} as argument {} for remote function {}{}{}. It was created by remote function {}{}{} which failed with:\n{}".format(self.objectid, self.argument_index, colorama.Fore.RED, self.function_name, colorama.Fore.RESET, colorama.Fore.RED, self.task_error.function_name, colorama.Fore.RESET, self.task_error)
class RayGetArgumentTypeError(Exception):
"""An exception used when a task's argument doesn't type check.
Attributes:
function_name (str): The name of the function for the current task.
argument_index (int): The index (zero indexed) of the argument in the
present task's remote function call.
received_type: The type of the argument that was passed in.
expected_type: The type that was expected. This is determined by the remote
decorator.
"""
def __init__(self, function_name, argument_index, received_type, expected_type):
"""Initialize a RayGetArgumentTypeError object."""
self.function_name = function_name
self.argument_index = argument_index
# TODO(rkn): when we support the serialization of types, then we should
# remove the string conversions below.
self.received_type = str(received_type)
self.expected_type = str(expected_type)
@staticmethod
def deserialize(primitives):
"""Create a RayGetArgumentTypeError from a primitive object."""
function_name, argument_index, received_type, expected_type = primitives
return RayGetArgumentTypeError(function_name, argument_index, received_type, expected_type)
def serialize(self):
"""Turn a RayGetArgumentTypeError into a primitive object."""
return (self.function_name, self.argument_index, self.received_type, self.expected_type)
def __str__(self):
"""Format a RayGetArgumentTypeError as a string."""
return "Argument {} for remote function {}{}{} has type {} but an argument of type {} was expected.".format(self.argument_index, colorama.Fore.RED, self.function_name, colorama.Fore.RESET, self.received_type, self.expected_type)
class RayDealloc(object):
"""An object used internally to properly implement reference counting.
@@ -993,7 +953,7 @@ def main_loop(worker=global_worker):
def process_remote_function(function_name, serialized_function):
"""Import a remote function."""
try:
(function, arg_types, return_types, module) = pickling.loads(serialized_function)
function, num_return_vals, module = pickling.loads(serialized_function)
except:
# If an exception was thrown when the remote function was imported, we
# record the traceback and notify the scheduler of the failure.
@@ -1005,11 +965,11 @@ def main_loop(worker=global_worker):
# TODO(rkn): Why is the below line necessary?
function.__module__ = module
assert function_name == "{}.{}".format(function.__module__, function.__name__), "The remote function name does not match the name that was passed in."
worker.functions[function_name] = remote(arg_types, return_types, worker)(function)
worker.functions[function_name] = remote(num_return_vals, worker)(function)
_logger().info("Successfully imported remote function {}.".format(function_name))
# Noify the scheduler that the remote function imported successfully.
# We pass an empty error message string because the import succeeded.
raylib.register_remote_function(worker.handle, function_name, len(return_types))
raylib.register_remote_function(worker.handle, function_name, num_return_vals)
def process_reusable_variable(reusable_variable_name, initializer_str, reinitializer_str):
"""Import a reusable variable."""
@@ -1110,12 +1070,12 @@ def _export_reusable_variable(name, reusable, worker=global_worker):
raise Exception("_export_reusable_variable can only be called on a driver.")
raylib.export_reusable_variable(worker.handle, name, pickling.dumps(reusable.initializer), pickling.dumps(reusable.reinitializer))
def remote(arg_types, return_types, worker=global_worker):
def remote(num_return_vals=1, worker=global_worker):
"""This decorator is used to create remote functions.
Args:
arg_types (List[type]): List of Python types of the function arguments.
return_types (List[type]): List of Python types of the return values.
num_return_vals (int): The number of object IDs that a call to this function
should return.
"""
def remote_decorator(func):
def func_call(*args, **kwargs):
@@ -1128,7 +1088,6 @@ def remote(arg_types, return_types, worker=global_worker):
# arguments to prevent the function call from mutating them and to match
# the usual behavior of immutable remote objects.
return func(*copy.deepcopy(args))
check_arguments(arg_types, has_vararg_param, func_name, args) # throws an exception if args are invalid
objectids = _submit_task(func_name, args)
if len(objectids) == 1:
return objectids[0]
@@ -1140,7 +1099,6 @@ def remote(arg_types, return_types, worker=global_worker):
start_time = time.time()
result = func(*arguments)
end_time = time.time()
check_return_values(func_invoker, result) # throws an exception if result is invalid
_logger().info("Finished executing function {}, it took {} seconds".format(func.__name__, end_time - start_time))
return result
def func_invoker(*args, **kwargs):
@@ -1148,8 +1106,6 @@ def remote(arg_types, return_types, worker=global_worker):
raise Exception("Remote functions cannot be called directly. Instead of running '{}()', try '{}.remote()'.".format(func_name, func_name))
func_invoker.remote = func_call
func_invoker.executor = func_executor
func_invoker.arg_types = arg_types
func_invoker.return_types = return_types
func_invoker.is_remote = True
func_name = "{}.{}".format(func.__module__, func.__name__)
func_invoker.func_name = func_name
@@ -1168,7 +1124,7 @@ def remote(arg_types, return_types, worker=global_worker):
# Set the function globally to make it refer to itself
func.__globals__[func.__name__] = func_invoker # Allow the function to reference itself as a global variable
try:
to_export = pickling.dumps((func, arg_types, return_types, func.__module__))
to_export = pickling.dumps((func, num_return_vals, func.__module__))
finally:
# Undo our changes
if func_name_global_valid: func.__globals__[func.__name__] = func_name_global_value
@@ -1205,109 +1161,12 @@ def check_signature_supported(has_kwargs_param, has_vararg_param, keyword_defaul
if has_vararg_param and any([d != funcsigs._empty for _, d in keyword_defaults]):
raise "Function {} has a *args argument as well as a keyword argument, which is currently not supported.".format(name)
def check_return_values(function, result):
"""Check the types and number of return values.
Args:
function (Callable): The remote function whose outputs are being checked.
result: The value returned by an invocation of the remote function. The
expected types and number are defined in the remote decorator.
Raises:
Exception: An exception is raised if the return values have incorrect types
or the function returned the wrong number of return values.
"""
# If the @remote decorator declares that the function has no return values,
# then all we do is check that there were in fact no return values.
if len(function.return_types) == 0:
if result is not None:
raise Exception("The @remote decorator for function {} has 0 return values, but {} returned more than 0 values.".format(function.__name__, function.__name__))
return
# If a function has multiple return values, Python returns a tuple of the
# values. If there is a single return value, then Python does not return a
# tuple, it simply returns the value. That is why we place result with
# (result,) when there is only one return value, so we can treat these two
# cases similarly.
if len(function.return_types) == 1:
result = (result,)
# Below we check that the number of values returned by the function match the
# number of return values declared in the @remote decorator.
if len(result) != len(function.return_types):
raise Exception("The @remote decorator for function {} has {} return values with types {}, but {} returned {} values.".format(function.__name__, len(function.return_types), function.return_types, function.__name__, len(result)))
# Here we do some limited type checking to make sure the return values have
# the right types.
for i in range(len(result)):
if (not issubclass(type(result[i]), function.return_types[i])) and (not isinstance(result[i], raylib.ObjectID)):
raise Exception("The {}th return value for function {} has type {}, but the @remote decorator expected a return value of type {} or an ObjectID.".format(i, function.__name__, type(result[i]), function.return_types[i]))
def typecheck_arg(arg, expected_type, i, name):
"""Check that an argument has the expected type.
Args:
arg: An argument to function.
expected_type (type): The expected type of arg.
i (int): The position of the argument to the function.
name (str): The name of the function.
Raises:
RayGetArgumentTypeError: An exception is raised if arg does not have the
expected type.
"""
if issubclass(type(arg), expected_type):
# Passed the type-checck
# TODO(rkn): This check doesn't really work, e.g., issubclass(type([1, 2, 3]), typing.List[str]) == True
pass
elif isinstance(arg, long) and issubclass(int, expected_type):
# TODO(mehrdadn): Should long really be convertible to int?
pass
else:
raise RayGetArgumentTypeError(name, i, type(arg), expected_type)
def check_arguments(arg_types, has_vararg_param, name, args):
"""Check that the arguments to the remote function have the right types.
This is called by the worker that calls the remote function (not the worker
that executes the remote function).
Args:
arg_types (List[type]): A list of the types of the arguments to the function
being checked.
has_vararg_param (bool): True if the function being checked has a *args
argument.
name (str): The name of the function.
args (List): The arguments to the function.
Raises:
Exception: An exception is raised the args do not all have the right types.
"""
# check the number of args
if len(args) != len(arg_types) and not has_vararg_param:
raise Exception("Function {} expects {} arguments, but received {}.".format(name, len(arg_types), len(args)))
elif len(args) < len(arg_types) - 1 and has_vararg_param:
raise Exception("Function {} expects at least {} arguments, but received {}.".format(name, len(arg_types) - 1, len(args)))
for (i, arg) in enumerate(args):
if i <= len(arg_types) - 1:
expected_type = arg_types[i]
elif has_vararg_param:
expected_type = arg_types[-1]
else:
assert False, "This code should be unreachable."
if isinstance(arg, raylib.ObjectID):
# TODO(rkn): When we have type information in the ObjectID, do type checking here.
pass
else:
typecheck_arg(arg, expected_type, i, name)
def get_arguments_for_execution(function, args, worker=global_worker):
"""Retrieve the arguments for the remote function.
This retrieves the values for the arguments to the remote function that were
passed in as object IDs. Argumens that were passed by value are not changed.
This also does some type checking. This is called by the worker that is
executing the remote function.
This is called by the worker that is executing the remote function.
Args:
function (Callable): The remote function whose arguments are being
@@ -1321,25 +1180,9 @@ def get_arguments_for_execution(function, args, worker=global_worker):
Raises:
RayGetArgumentError: This exception is raised if a task that created one of
the arguments failed.
RayGetArgumentTypeError: This exception is raised (via typecheck_arg) if one
of the arguments does not have the expected type.
"""
# TODO(rkn): Eventually, all of the type checking can be put in `check_arguments` above so that the error will happen immediately when calling a remote function.
arguments = []
# # check the number of args
# if len(args) != len(function.arg_types) and function.arg_types[-1] is not None:
# raise Exception("Function {} expects {} arguments, but received {}.".format(function.__name__, len(function.arg_types), len(args)))
# elif len(args) < len(function.arg_types) - 1 and function.arg_types[-1] is None:
# raise Exception("Function {} expects at least {} arguments, but received {}.".format(function.__name__, len(function.arg_types) - 1, len(args)))
for (i, arg) in enumerate(args):
if i <= len(function.arg_types) - 1:
expected_type = function.arg_types[i]
elif function.has_vararg_param and len(function.arg_types) >= 1:
expected_type = function.arg_types[-1]
else:
assert False, "This code should be unreachable."
if isinstance(arg, raylib.ObjectID):
# get the object from the local object store
_logger().info("Getting argument {} for function {}.".format(i, function.__name__))
@@ -1353,7 +1196,6 @@ def get_arguments_for_execution(function, args, worker=global_worker):
# pass the argument by value
argument = arg
typecheck_arg(argument, expected_type, i, function.__name__)
arguments.append(argument)
return arguments