mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 11:01:06 +08:00
Remove type information from remote decorator.
This commit is contained in:
@@ -66,12 +66,12 @@ class DistArray(object):
|
||||
a = self.assemble()
|
||||
return a[sliced]
|
||||
|
||||
@ray.remote([DistArray], [np.ndarray])
|
||||
@ray.remote()
|
||||
def assemble(a):
|
||||
return a.assemble()
|
||||
|
||||
# TODO(rkn): what should we call this method
|
||||
@ray.remote([np.ndarray], [DistArray])
|
||||
@ray.remote()
|
||||
def numpy_to_dist(a):
|
||||
result = DistArray(a.shape)
|
||||
for index in np.ndindex(*result.num_blocks):
|
||||
@@ -80,28 +80,28 @@ def numpy_to_dist(a):
|
||||
result.objectids[index] = ray.put(a[[slice(l, u) for (l, u) in zip(lower, upper)]])
|
||||
return result
|
||||
|
||||
@ray.remote([List, str], [DistArray])
|
||||
@ray.remote()
|
||||
def zeros(shape, dtype_name="float"):
|
||||
result = DistArray(shape)
|
||||
for index in np.ndindex(*result.num_blocks):
|
||||
result.objectids[index] = ra.zeros.remote(DistArray.compute_block_shape(index, shape), dtype_name=dtype_name)
|
||||
return result
|
||||
|
||||
@ray.remote([List, str], [DistArray])
|
||||
@ray.remote()
|
||||
def ones(shape, dtype_name="float"):
|
||||
result = DistArray(shape)
|
||||
for index in np.ndindex(*result.num_blocks):
|
||||
result.objectids[index] = ra.ones.remote(DistArray.compute_block_shape(index, shape), dtype_name=dtype_name)
|
||||
return result
|
||||
|
||||
@ray.remote([DistArray], [DistArray])
|
||||
@ray.remote()
|
||||
def copy(a):
|
||||
result = DistArray(a.shape)
|
||||
for index in np.ndindex(*result.num_blocks):
|
||||
result.objectids[index] = a.objectids[index] # We don't need to actually copy the objects because cluster-level objects are assumed to be immutable.
|
||||
return result
|
||||
|
||||
@ray.remote([int, int, str], [DistArray])
|
||||
@ray.remote()
|
||||
def eye(dim1, dim2=-1, dtype_name="float"):
|
||||
dim2 = dim1 if dim2 == -1 else dim2
|
||||
shape = [dim1, dim2]
|
||||
@@ -114,7 +114,7 @@ def eye(dim1, dim2=-1, dtype_name="float"):
|
||||
result.objectids[i, j] = ra.zeros.remote(block_shape, dtype_name=dtype_name)
|
||||
return result
|
||||
|
||||
@ray.remote([DistArray], [DistArray])
|
||||
@ray.remote()
|
||||
def triu(a):
|
||||
if a.ndim != 2:
|
||||
raise Exception("Input must have 2 dimensions, but a.ndim is " + str(a.ndim))
|
||||
@@ -128,7 +128,7 @@ def triu(a):
|
||||
result.objectids[i, j] = ra.zeros_like.remote(a.objectids[i, j])
|
||||
return result
|
||||
|
||||
@ray.remote([DistArray], [DistArray])
|
||||
@ray.remote()
|
||||
def tril(a):
|
||||
if a.ndim != 2:
|
||||
raise Exception("Input must have 2 dimensions, but a.ndim is " + str(a.ndim))
|
||||
@@ -142,7 +142,7 @@ def tril(a):
|
||||
result.objectids[i, j] = ra.zeros_like.remote(a.objectids[i, j])
|
||||
return result
|
||||
|
||||
@ray.remote([np.ndarray], [np.ndarray])
|
||||
@ray.remote()
|
||||
def blockwise_dot(*matrices):
|
||||
n = len(matrices)
|
||||
if n % 2 != 0:
|
||||
@@ -153,7 +153,7 @@ def blockwise_dot(*matrices):
|
||||
result += np.dot(matrices[i], matrices[n / 2 + i])
|
||||
return result
|
||||
|
||||
@ray.remote([DistArray, DistArray], [DistArray])
|
||||
@ray.remote()
|
||||
def dot(a, b):
|
||||
if a.ndim != 2:
|
||||
raise Exception("dot expects its arguments to be 2-dimensional, but a.ndim = {}.".format(a.ndim))
|
||||
@@ -168,7 +168,7 @@ def dot(a, b):
|
||||
result.objectids[i, j] = blockwise_dot.remote(*args)
|
||||
return result
|
||||
|
||||
@ray.remote([DistArray, List], [DistArray])
|
||||
@ray.remote()
|
||||
def subblocks(a, *ranges):
|
||||
"""
|
||||
This function produces a distributed array from a subset of the blocks in the `a`. The result and `a` will have the same number of dimensions.For example,
|
||||
@@ -198,7 +198,7 @@ def subblocks(a, *ranges):
|
||||
result.objectids[index] = a.objectids[tuple([ranges[i][index[i]] for i in range(a.ndim)])]
|
||||
return result
|
||||
|
||||
@ray.remote([DistArray], [DistArray])
|
||||
@ray.remote()
|
||||
def transpose(a):
|
||||
if a.ndim != 2:
|
||||
raise Exception("transpose expects its argument to be 2-dimensional, but a.ndim = {}, a.shape = {}.".format(a.ndim, a.shape))
|
||||
@@ -209,7 +209,7 @@ def transpose(a):
|
||||
return result
|
||||
|
||||
# TODO(rkn): support broadcasting?
|
||||
@ray.remote([DistArray, DistArray], [DistArray])
|
||||
@ray.remote()
|
||||
def add(x1, x2):
|
||||
if x1.shape != x2.shape:
|
||||
raise Exception("add expects arguments `x1` and `x2` to have the same shape, but x1.shape = {}, and x2.shape = {}.".format(x1.shape, x2.shape))
|
||||
@@ -219,7 +219,7 @@ def add(x1, x2):
|
||||
return result
|
||||
|
||||
# TODO(rkn): support broadcasting?
|
||||
@ray.remote([DistArray, DistArray], [DistArray])
|
||||
@ray.remote()
|
||||
def subtract(x1, x2):
|
||||
if x1.shape != x2.shape:
|
||||
raise Exception("subtract expects arguments `x1` and `x2` to have the same shape, but x1.shape = {}, and x2.shape = {}.".format(x1.shape, x2.shape))
|
||||
|
||||
@@ -6,7 +6,7 @@ from core import *
|
||||
|
||||
__all__ = ["tsqr", "modified_lu", "tsqr_hr", "qr"]
|
||||
|
||||
@ray.remote([DistArray], [DistArray, np.ndarray])
|
||||
@ray.remote(num_return_vals=2)
|
||||
def tsqr(a):
|
||||
"""
|
||||
arguments:
|
||||
@@ -75,7 +75,7 @@ def tsqr(a):
|
||||
return q_result, r
|
||||
|
||||
# TODO(rkn): This is unoptimized, we really want a block version of this.
|
||||
@ray.remote([DistArray], [DistArray, np.ndarray, np.ndarray])
|
||||
@ray.remote(num_return_vals=3)
|
||||
def modified_lu(q):
|
||||
"""
|
||||
Algorithm 5 from http://www.eecs.berkeley.edu/Pubs/TechRpts/2013/EECS-2013-175.pdf
|
||||
@@ -105,19 +105,19 @@ def modified_lu(q):
|
||||
U = np.triu(q_work)[:b, :]
|
||||
return numpy_to_dist.remote(ray.put(L)), U, S # TODO(rkn): get rid of put
|
||||
|
||||
@ray.remote([np.ndarray, np.ndarray, np.ndarray, int], [np.ndarray, np.ndarray])
|
||||
@ray.remote(num_return_vals=2)
|
||||
def tsqr_hr_helper1(u, s, y_top_block, b):
|
||||
y_top = y_top_block[:b, :b]
|
||||
s_full = np.diag(s)
|
||||
t = -1 * np.dot(u, np.dot(s_full, np.linalg.inv(y_top).T))
|
||||
return t, y_top
|
||||
|
||||
@ray.remote([np.ndarray, np.ndarray], [np.ndarray])
|
||||
@ray.remote()
|
||||
def tsqr_hr_helper2(s, r_temp):
|
||||
s_full = np.diag(s)
|
||||
return np.dot(s_full, r_temp)
|
||||
|
||||
@ray.remote([DistArray], [DistArray, np.ndarray, np.ndarray, np.ndarray])
|
||||
@ray.remote(num_return_vals=4)
|
||||
def tsqr_hr(a):
|
||||
"""Algorithm 6 from http://www.eecs.berkeley.edu/Pubs/TechRpts/2013/EECS-2013-175.pdf"""
|
||||
q, r_temp = tsqr.remote(a)
|
||||
@@ -127,15 +127,15 @@ def tsqr_hr(a):
|
||||
r = tsqr_hr_helper2.remote(s, r_temp)
|
||||
return y, t, y_top, r
|
||||
|
||||
@ray.remote([np.ndarray, np.ndarray, np.ndarray, np.ndarray], [np.ndarray])
|
||||
@ray.remote()
|
||||
def qr_helper1(a_rc, y_ri, t, W_c):
|
||||
return a_rc - np.dot(y_ri, np.dot(t.T, W_c))
|
||||
|
||||
@ray.remote([np.ndarray, np.ndarray], [np.ndarray])
|
||||
@ray.remote()
|
||||
def qr_helper2(y_ri, a_rc):
|
||||
return np.dot(y_ri.T, a_rc)
|
||||
|
||||
@ray.remote([DistArray], [DistArray, DistArray])
|
||||
@ray.remote(num_return_vals=2)
|
||||
def qr(a):
|
||||
"""Algorithm 7 from http://www.eecs.berkeley.edu/Pubs/TechRpts/2013/EECS-2013-175.pdf"""
|
||||
m, n = a.shape[0], a.shape[1]
|
||||
|
||||
@@ -6,7 +6,7 @@ import ray
|
||||
|
||||
from core import *
|
||||
|
||||
@ray.remote([List], [DistArray])
|
||||
@ray.remote()
|
||||
def normal(shape):
|
||||
num_blocks = DistArray.compute_num_blocks(shape)
|
||||
objectids = np.empty(num_blocks, dtype=object)
|
||||
|
||||
@@ -4,80 +4,80 @@ import ray
|
||||
|
||||
__all__ = ["zeros", "zeros_like", "ones", "eye", "dot", "vstack", "hstack", "subarray", "copy", "tril", "triu", "diag", "transpose", "add", "subtract", "sum", "shape", "sum_list"]
|
||||
|
||||
@ray.remote([List, str, str], [np.ndarray])
|
||||
@ray.remote()
|
||||
def zeros(shape, dtype_name="float", order="C"):
|
||||
return np.zeros(shape, dtype=np.dtype(dtype_name), order=order)
|
||||
|
||||
@ray.remote([np.ndarray, str, str, bool], [np.ndarray])
|
||||
@ray.remote()
|
||||
def zeros_like(a, dtype_name="None", order="K", subok=True):
|
||||
dtype_val = None if dtype_name == "None" else np.dtype(dtype_name)
|
||||
return np.zeros_like(a, dtype=dtype_val, order=order, subok=subok)
|
||||
|
||||
@ray.remote([List, str, str], [np.ndarray])
|
||||
@ray.remote()
|
||||
def ones(shape, dtype_name="float", order="C"):
|
||||
return np.ones(shape, dtype=np.dtype(dtype_name), order=order)
|
||||
|
||||
@ray.remote([int, int, int, str], [np.ndarray])
|
||||
@ray.remote()
|
||||
def eye(N, M=-1, k=0, dtype_name="float"):
|
||||
M = N if M == -1 else M
|
||||
return np.eye(N, M=M, k=k, dtype=np.dtype(dtype_name))
|
||||
|
||||
@ray.remote([np.ndarray, np.ndarray], [np.ndarray])
|
||||
@ray.remote()
|
||||
def dot(a, b):
|
||||
return np.dot(a, b)
|
||||
|
||||
@ray.remote([np.ndarray], [np.ndarray])
|
||||
@ray.remote()
|
||||
def vstack(*xs):
|
||||
return np.vstack(xs)
|
||||
|
||||
@ray.remote([np.ndarray], [np.ndarray])
|
||||
@ray.remote()
|
||||
def hstack(*xs):
|
||||
return np.hstack(xs)
|
||||
|
||||
# TODO(rkn): instead of this, consider implementing slicing
|
||||
@ray.remote([np.ndarray, List, List], [np.ndarray])
|
||||
@ray.remote()
|
||||
def subarray(a, lower_indices, upper_indices): # TODO(rkn): be consistent about using "index" versus "indices"
|
||||
return a[[slice(l, u) for (l, u) in zip(lower_indices, upper_indices)]]
|
||||
|
||||
@ray.remote([np.ndarray, str], [np.ndarray])
|
||||
@ray.remote()
|
||||
def copy(a, order="K"):
|
||||
return np.copy(a, order=order)
|
||||
|
||||
@ray.remote([np.ndarray, int], [np.ndarray])
|
||||
@ray.remote()
|
||||
def tril(m, k=0):
|
||||
return np.tril(m, k=k)
|
||||
|
||||
@ray.remote([np.ndarray, int], [np.ndarray])
|
||||
@ray.remote()
|
||||
def triu(m, k=0):
|
||||
return np.triu(m, k=k)
|
||||
|
||||
@ray.remote([np.ndarray, int], [np.ndarray])
|
||||
@ray.remote()
|
||||
def diag(v, k=0):
|
||||
return np.diag(v, k=k)
|
||||
|
||||
@ray.remote([np.ndarray, List], [np.ndarray])
|
||||
@ray.remote()
|
||||
def transpose(a, axes=[]):
|
||||
axes = None if axes == [] else axes
|
||||
return np.transpose(a, axes=axes)
|
||||
|
||||
@ray.remote([np.ndarray, np.ndarray], [np.ndarray])
|
||||
@ray.remote()
|
||||
def add(x1, x2):
|
||||
return np.add(x1, x2)
|
||||
|
||||
@ray.remote([np.ndarray, np.ndarray], [np.ndarray])
|
||||
@ray.remote()
|
||||
def subtract(x1, x2):
|
||||
return np.subtract(x1, x2)
|
||||
|
||||
@ray.remote([np.ndarray, int], [np.ndarray])
|
||||
@ray.remote()
|
||||
def sum(x, axis=-1):
|
||||
return np.sum(x, axis=axis if axis != -1 else None)
|
||||
|
||||
@ray.remote([np.ndarray], [tuple])
|
||||
@ray.remote()
|
||||
def shape(a):
|
||||
return np.shape(a)
|
||||
|
||||
# We use Any to allow different numerical types as well as numpy arrays.
|
||||
# TODO(rkn):this isn't in the numpy API, so be careful about exposing this.
|
||||
@ray.remote([Any], [Any])
|
||||
@ray.remote()
|
||||
def sum_list(*xs):
|
||||
return np.sum(xs, axis=0)
|
||||
|
||||
@@ -7,82 +7,82 @@ __all__ = ["matrix_power", "solve", "tensorsolve", "tensorinv", "inv",
|
||||
"svd", "eig", "eigh", "lstsq", "norm", "qr", "cond", "matrix_rank",
|
||||
"LinAlgError", "multi_dot"]
|
||||
|
||||
@ray.remote([np.ndarray, int], [np.ndarray])
|
||||
@ray.remote()
|
||||
def matrix_power(M, n):
|
||||
return np.linalg.matrix_power(M, n)
|
||||
|
||||
@ray.remote([np.ndarray, np.ndarray], [np.ndarray])
|
||||
@ray.remote()
|
||||
def solve(a, b):
|
||||
return np.linalg.solve(a, b)
|
||||
|
||||
@ray.remote([np.ndarray], [np.ndarray, np.ndarray])
|
||||
@ray.remote(num_return_vals=2)
|
||||
def tensorsolve(a):
|
||||
raise NotImplementedError
|
||||
|
||||
@ray.remote([np.ndarray], [np.ndarray, np.ndarray])
|
||||
@ray.remote(num_return_vals=2)
|
||||
def tensorinv(a):
|
||||
raise NotImplementedError
|
||||
|
||||
@ray.remote([np.ndarray], [np.ndarray])
|
||||
@ray.remote()
|
||||
def inv(a):
|
||||
return np.linalg.inv(a)
|
||||
|
||||
@ray.remote([np.ndarray], [np.ndarray])
|
||||
@ray.remote()
|
||||
def cholesky(a):
|
||||
return np.linalg.cholesky(a)
|
||||
|
||||
@ray.remote([np.ndarray], [np.ndarray])
|
||||
@ray.remote()
|
||||
def eigvals(a):
|
||||
return np.linalg.eigvals(a)
|
||||
|
||||
@ray.remote([np.ndarray], [np.ndarray])
|
||||
@ray.remote()
|
||||
def eigvalsh(a):
|
||||
raise NotImplementedError
|
||||
|
||||
@ray.remote([np.ndarray], [np.ndarray])
|
||||
@ray.remote()
|
||||
def pinv(a):
|
||||
return np.linalg.pinv(a)
|
||||
|
||||
@ray.remote([np.ndarray], [int])
|
||||
@ray.remote()
|
||||
def slogdet(a):
|
||||
raise NotImplementedError
|
||||
|
||||
@ray.remote([np.ndarray], [float])
|
||||
@ray.remote()
|
||||
def det(a):
|
||||
return np.linalg.det(a)
|
||||
|
||||
@ray.remote([np.ndarray], [np.ndarray, np.ndarray, np.ndarray])
|
||||
@ray.remote(num_return_vals=3)
|
||||
def svd(a):
|
||||
return np.linalg.svd(a)
|
||||
|
||||
@ray.remote([np.ndarray], [np.ndarray, np.ndarray])
|
||||
@ray.remote(num_return_vals=2)
|
||||
def eig(a):
|
||||
return np.linalg.eig(a)
|
||||
|
||||
@ray.remote([np.ndarray], [np.ndarray, np.ndarray])
|
||||
@ray.remote(num_return_vals=2)
|
||||
def eigh(a):
|
||||
return np.linalg.eigh(a)
|
||||
|
||||
@ray.remote([np.ndarray], [np.ndarray, np.ndarray, int, np.ndarray])
|
||||
@ray.remote(num_return_vals=4)
|
||||
def lstsq(a, b):
|
||||
return np.linalg.lstsq(a)
|
||||
|
||||
@ray.remote([np.ndarray], [float])
|
||||
@ray.remote()
|
||||
def norm(x):
|
||||
return np.linalg.norm(x)
|
||||
|
||||
@ray.remote([np.ndarray], [np.ndarray, np.ndarray])
|
||||
@ray.remote(num_return_vals=2)
|
||||
def qr(a):
|
||||
return np.linalg.qr(a)
|
||||
|
||||
@ray.remote([np.ndarray], [float])
|
||||
@ray.remote()
|
||||
def cond(x):
|
||||
return np.linalg.cond(x)
|
||||
|
||||
@ray.remote([np.ndarray], [int])
|
||||
@ray.remote()
|
||||
def matrix_rank(M):
|
||||
return np.linalg.matrix_rank(M)
|
||||
|
||||
@ray.remote([np.ndarray], [np.ndarray])
|
||||
@ray.remote()
|
||||
def multi_dot(*a):
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -2,6 +2,6 @@ from typing import List
|
||||
import numpy as np
|
||||
import ray
|
||||
|
||||
@ray.remote([List], [np.ndarray])
|
||||
@ray.remote()
|
||||
def normal(shape):
|
||||
return np.random.normal(size=shape)
|
||||
|
||||
+9
-167
@@ -4,7 +4,6 @@ import time
|
||||
import traceback
|
||||
import copy
|
||||
import logging
|
||||
from types import ModuleType
|
||||
import typing
|
||||
import funcsigs
|
||||
import numpy as np
|
||||
@@ -45,7 +44,7 @@ class RayTaskError(Exception):
|
||||
def __init__(self, function_name, exception, traceback_str):
|
||||
"""Initialize a RayTaskError."""
|
||||
self.function_name = function_name
|
||||
if isinstance(exception, RayGetError) or isinstance(exception, RayGetArgumentError) or isinstance(exception, RayGetArgumentTypeError):
|
||||
if isinstance(exception, RayGetError) or isinstance(exception, RayGetArgumentError):
|
||||
self.exception = exception
|
||||
else:
|
||||
self.exception = None
|
||||
@@ -59,8 +58,6 @@ class RayTaskError(Exception):
|
||||
exception = RayGetError.deserialize(exception[1])
|
||||
elif exception[0] == "RayGetArgumentError":
|
||||
exception = RayGetArgumentError.deserialize(exception[1])
|
||||
elif exception[0] == "RayGetArgumentTypeError":
|
||||
exception = RayGetArgumentTypeError.deserialize(exception[1])
|
||||
elif exception[0] == "None":
|
||||
exception = None
|
||||
else:
|
||||
@@ -73,8 +70,6 @@ class RayTaskError(Exception):
|
||||
serialized_exception = ("RayGetError", self.exception.serialize())
|
||||
elif isinstance(self.exception, RayGetArgumentError):
|
||||
serialized_exception = ("RayGetArgumentError", self.exception.serialize())
|
||||
elif isinstance(self.exception, RayGetArgumentTypeError):
|
||||
serialized_exception = ("RayGetArgumentTypeError", self.exception.serialize())
|
||||
elif self.exception is None:
|
||||
serialized_exception = ("None",)
|
||||
else:
|
||||
@@ -151,41 +146,6 @@ class RayGetArgumentError(Exception):
|
||||
"""Format a RayGetArgumentError as a string."""
|
||||
return "Failed to get objectid {} as argument {} for remote function {}{}{}. It was created by remote function {}{}{} which failed with:\n{}".format(self.objectid, self.argument_index, colorama.Fore.RED, self.function_name, colorama.Fore.RESET, colorama.Fore.RED, self.task_error.function_name, colorama.Fore.RESET, self.task_error)
|
||||
|
||||
class RayGetArgumentTypeError(Exception):
|
||||
"""An exception used when a task's argument doesn't type check.
|
||||
|
||||
Attributes:
|
||||
function_name (str): The name of the function for the current task.
|
||||
argument_index (int): The index (zero indexed) of the argument in the
|
||||
present task's remote function call.
|
||||
received_type: The type of the argument that was passed in.
|
||||
expected_type: The type that was expected. This is determined by the remote
|
||||
decorator.
|
||||
"""
|
||||
|
||||
def __init__(self, function_name, argument_index, received_type, expected_type):
|
||||
"""Initialize a RayGetArgumentTypeError object."""
|
||||
self.function_name = function_name
|
||||
self.argument_index = argument_index
|
||||
# TODO(rkn): when we support the serialization of types, then we should
|
||||
# remove the string conversions below.
|
||||
self.received_type = str(received_type)
|
||||
self.expected_type = str(expected_type)
|
||||
|
||||
@staticmethod
|
||||
def deserialize(primitives):
|
||||
"""Create a RayGetArgumentTypeError from a primitive object."""
|
||||
function_name, argument_index, received_type, expected_type = primitives
|
||||
return RayGetArgumentTypeError(function_name, argument_index, received_type, expected_type)
|
||||
|
||||
def serialize(self):
|
||||
"""Turn a RayGetArgumentTypeError into a primitive object."""
|
||||
return (self.function_name, self.argument_index, self.received_type, self.expected_type)
|
||||
|
||||
def __str__(self):
|
||||
"""Format a RayGetArgumentTypeError as a string."""
|
||||
return "Argument {} for remote function {}{}{} has type {} but an argument of type {} was expected.".format(self.argument_index, colorama.Fore.RED, self.function_name, colorama.Fore.RESET, self.received_type, self.expected_type)
|
||||
|
||||
class RayDealloc(object):
|
||||
"""An object used internally to properly implement reference counting.
|
||||
|
||||
@@ -993,7 +953,7 @@ def main_loop(worker=global_worker):
|
||||
def process_remote_function(function_name, serialized_function):
|
||||
"""Import a remote function."""
|
||||
try:
|
||||
(function, arg_types, return_types, module) = pickling.loads(serialized_function)
|
||||
function, num_return_vals, module = pickling.loads(serialized_function)
|
||||
except:
|
||||
# If an exception was thrown when the remote function was imported, we
|
||||
# record the traceback and notify the scheduler of the failure.
|
||||
@@ -1005,11 +965,11 @@ def main_loop(worker=global_worker):
|
||||
# TODO(rkn): Why is the below line necessary?
|
||||
function.__module__ = module
|
||||
assert function_name == "{}.{}".format(function.__module__, function.__name__), "The remote function name does not match the name that was passed in."
|
||||
worker.functions[function_name] = remote(arg_types, return_types, worker)(function)
|
||||
worker.functions[function_name] = remote(num_return_vals, worker)(function)
|
||||
_logger().info("Successfully imported remote function {}.".format(function_name))
|
||||
# Noify the scheduler that the remote function imported successfully.
|
||||
# We pass an empty error message string because the import succeeded.
|
||||
raylib.register_remote_function(worker.handle, function_name, len(return_types))
|
||||
raylib.register_remote_function(worker.handle, function_name, num_return_vals)
|
||||
|
||||
def process_reusable_variable(reusable_variable_name, initializer_str, reinitializer_str):
|
||||
"""Import a reusable variable."""
|
||||
@@ -1110,12 +1070,12 @@ def _export_reusable_variable(name, reusable, worker=global_worker):
|
||||
raise Exception("_export_reusable_variable can only be called on a driver.")
|
||||
raylib.export_reusable_variable(worker.handle, name, pickling.dumps(reusable.initializer), pickling.dumps(reusable.reinitializer))
|
||||
|
||||
def remote(arg_types, return_types, worker=global_worker):
|
||||
def remote(num_return_vals=1, worker=global_worker):
|
||||
"""This decorator is used to create remote functions.
|
||||
|
||||
Args:
|
||||
arg_types (List[type]): List of Python types of the function arguments.
|
||||
return_types (List[type]): List of Python types of the return values.
|
||||
num_return_vals (int): The number of object IDs that a call to this function
|
||||
should return.
|
||||
"""
|
||||
def remote_decorator(func):
|
||||
def func_call(*args, **kwargs):
|
||||
@@ -1128,7 +1088,6 @@ def remote(arg_types, return_types, worker=global_worker):
|
||||
# arguments to prevent the function call from mutating them and to match
|
||||
# the usual behavior of immutable remote objects.
|
||||
return func(*copy.deepcopy(args))
|
||||
check_arguments(arg_types, has_vararg_param, func_name, args) # throws an exception if args are invalid
|
||||
objectids = _submit_task(func_name, args)
|
||||
if len(objectids) == 1:
|
||||
return objectids[0]
|
||||
@@ -1140,7 +1099,6 @@ def remote(arg_types, return_types, worker=global_worker):
|
||||
start_time = time.time()
|
||||
result = func(*arguments)
|
||||
end_time = time.time()
|
||||
check_return_values(func_invoker, result) # throws an exception if result is invalid
|
||||
_logger().info("Finished executing function {}, it took {} seconds".format(func.__name__, end_time - start_time))
|
||||
return result
|
||||
def func_invoker(*args, **kwargs):
|
||||
@@ -1148,8 +1106,6 @@ def remote(arg_types, return_types, worker=global_worker):
|
||||
raise Exception("Remote functions cannot be called directly. Instead of running '{}()', try '{}.remote()'.".format(func_name, func_name))
|
||||
func_invoker.remote = func_call
|
||||
func_invoker.executor = func_executor
|
||||
func_invoker.arg_types = arg_types
|
||||
func_invoker.return_types = return_types
|
||||
func_invoker.is_remote = True
|
||||
func_name = "{}.{}".format(func.__module__, func.__name__)
|
||||
func_invoker.func_name = func_name
|
||||
@@ -1168,7 +1124,7 @@ def remote(arg_types, return_types, worker=global_worker):
|
||||
# Set the function globally to make it refer to itself
|
||||
func.__globals__[func.__name__] = func_invoker # Allow the function to reference itself as a global variable
|
||||
try:
|
||||
to_export = pickling.dumps((func, arg_types, return_types, func.__module__))
|
||||
to_export = pickling.dumps((func, num_return_vals, func.__module__))
|
||||
finally:
|
||||
# Undo our changes
|
||||
if func_name_global_valid: func.__globals__[func.__name__] = func_name_global_value
|
||||
@@ -1205,109 +1161,12 @@ def check_signature_supported(has_kwargs_param, has_vararg_param, keyword_defaul
|
||||
if has_vararg_param and any([d != funcsigs._empty for _, d in keyword_defaults]):
|
||||
raise "Function {} has a *args argument as well as a keyword argument, which is currently not supported.".format(name)
|
||||
|
||||
|
||||
def check_return_values(function, result):
|
||||
"""Check the types and number of return values.
|
||||
|
||||
Args:
|
||||
function (Callable): The remote function whose outputs are being checked.
|
||||
result: The value returned by an invocation of the remote function. The
|
||||
expected types and number are defined in the remote decorator.
|
||||
|
||||
Raises:
|
||||
Exception: An exception is raised if the return values have incorrect types
|
||||
or the function returned the wrong number of return values.
|
||||
"""
|
||||
# If the @remote decorator declares that the function has no return values,
|
||||
# then all we do is check that there were in fact no return values.
|
||||
if len(function.return_types) == 0:
|
||||
if result is not None:
|
||||
raise Exception("The @remote decorator for function {} has 0 return values, but {} returned more than 0 values.".format(function.__name__, function.__name__))
|
||||
return
|
||||
# If a function has multiple return values, Python returns a tuple of the
|
||||
# values. If there is a single return value, then Python does not return a
|
||||
# tuple, it simply returns the value. That is why we place result with
|
||||
# (result,) when there is only one return value, so we can treat these two
|
||||
# cases similarly.
|
||||
if len(function.return_types) == 1:
|
||||
result = (result,)
|
||||
# Below we check that the number of values returned by the function match the
|
||||
# number of return values declared in the @remote decorator.
|
||||
if len(result) != len(function.return_types):
|
||||
raise Exception("The @remote decorator for function {} has {} return values with types {}, but {} returned {} values.".format(function.__name__, len(function.return_types), function.return_types, function.__name__, len(result)))
|
||||
# Here we do some limited type checking to make sure the return values have
|
||||
# the right types.
|
||||
for i in range(len(result)):
|
||||
if (not issubclass(type(result[i]), function.return_types[i])) and (not isinstance(result[i], raylib.ObjectID)):
|
||||
raise Exception("The {}th return value for function {} has type {}, but the @remote decorator expected a return value of type {} or an ObjectID.".format(i, function.__name__, type(result[i]), function.return_types[i]))
|
||||
|
||||
def typecheck_arg(arg, expected_type, i, name):
|
||||
"""Check that an argument has the expected type.
|
||||
|
||||
Args:
|
||||
arg: An argument to function.
|
||||
expected_type (type): The expected type of arg.
|
||||
i (int): The position of the argument to the function.
|
||||
name (str): The name of the function.
|
||||
|
||||
Raises:
|
||||
RayGetArgumentTypeError: An exception is raised if arg does not have the
|
||||
expected type.
|
||||
"""
|
||||
if issubclass(type(arg), expected_type):
|
||||
# Passed the type-checck
|
||||
# TODO(rkn): This check doesn't really work, e.g., issubclass(type([1, 2, 3]), typing.List[str]) == True
|
||||
pass
|
||||
elif isinstance(arg, long) and issubclass(int, expected_type):
|
||||
# TODO(mehrdadn): Should long really be convertible to int?
|
||||
pass
|
||||
else:
|
||||
raise RayGetArgumentTypeError(name, i, type(arg), expected_type)
|
||||
|
||||
def check_arguments(arg_types, has_vararg_param, name, args):
|
||||
"""Check that the arguments to the remote function have the right types.
|
||||
|
||||
This is called by the worker that calls the remote function (not the worker
|
||||
that executes the remote function).
|
||||
|
||||
Args:
|
||||
arg_types (List[type]): A list of the types of the arguments to the function
|
||||
being checked.
|
||||
has_vararg_param (bool): True if the function being checked has a *args
|
||||
argument.
|
||||
name (str): The name of the function.
|
||||
args (List): The arguments to the function.
|
||||
|
||||
Raises:
|
||||
Exception: An exception is raised the args do not all have the right types.
|
||||
"""
|
||||
# check the number of args
|
||||
if len(args) != len(arg_types) and not has_vararg_param:
|
||||
raise Exception("Function {} expects {} arguments, but received {}.".format(name, len(arg_types), len(args)))
|
||||
elif len(args) < len(arg_types) - 1 and has_vararg_param:
|
||||
raise Exception("Function {} expects at least {} arguments, but received {}.".format(name, len(arg_types) - 1, len(args)))
|
||||
|
||||
for (i, arg) in enumerate(args):
|
||||
if i <= len(arg_types) - 1:
|
||||
expected_type = arg_types[i]
|
||||
elif has_vararg_param:
|
||||
expected_type = arg_types[-1]
|
||||
else:
|
||||
assert False, "This code should be unreachable."
|
||||
|
||||
if isinstance(arg, raylib.ObjectID):
|
||||
# TODO(rkn): When we have type information in the ObjectID, do type checking here.
|
||||
pass
|
||||
else:
|
||||
typecheck_arg(arg, expected_type, i, name)
|
||||
|
||||
def get_arguments_for_execution(function, args, worker=global_worker):
|
||||
"""Retrieve the arguments for the remote function.
|
||||
|
||||
This retrieves the values for the arguments to the remote function that were
|
||||
passed in as object IDs. Argumens that were passed by value are not changed.
|
||||
This also does some type checking. This is called by the worker that is
|
||||
executing the remote function.
|
||||
This is called by the worker that is executing the remote function.
|
||||
|
||||
Args:
|
||||
function (Callable): The remote function whose arguments are being
|
||||
@@ -1321,25 +1180,9 @@ def get_arguments_for_execution(function, args, worker=global_worker):
|
||||
Raises:
|
||||
RayGetArgumentError: This exception is raised if a task that created one of
|
||||
the arguments failed.
|
||||
RayGetArgumentTypeError: This exception is raised (via typecheck_arg) if one
|
||||
of the arguments does not have the expected type.
|
||||
"""
|
||||
# TODO(rkn): Eventually, all of the type checking can be put in `check_arguments` above so that the error will happen immediately when calling a remote function.
|
||||
arguments = []
|
||||
# # check the number of args
|
||||
# if len(args) != len(function.arg_types) and function.arg_types[-1] is not None:
|
||||
# raise Exception("Function {} expects {} arguments, but received {}.".format(function.__name__, len(function.arg_types), len(args)))
|
||||
# elif len(args) < len(function.arg_types) - 1 and function.arg_types[-1] is None:
|
||||
# raise Exception("Function {} expects at least {} arguments, but received {}.".format(function.__name__, len(function.arg_types) - 1, len(args)))
|
||||
|
||||
for (i, arg) in enumerate(args):
|
||||
if i <= len(function.arg_types) - 1:
|
||||
expected_type = function.arg_types[i]
|
||||
elif function.has_vararg_param and len(function.arg_types) >= 1:
|
||||
expected_type = function.arg_types[-1]
|
||||
else:
|
||||
assert False, "This code should be unreachable."
|
||||
|
||||
if isinstance(arg, raylib.ObjectID):
|
||||
# get the object from the local object store
|
||||
_logger().info("Getting argument {} for function {}.".format(i, function.__name__))
|
||||
@@ -1353,7 +1196,6 @@ def get_arguments_for_execution(function, args, worker=global_worker):
|
||||
# pass the argument by value
|
||||
argument = arg
|
||||
|
||||
typecheck_arg(argument, expected_type, i, function.__name__)
|
||||
arguments.append(argument)
|
||||
return arguments
|
||||
|
||||
|
||||
Reference in New Issue
Block a user