mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 16:30:41 +08:00
Allow remote decorator to be used with no parentheses.
This commit is contained in:
@@ -65,12 +65,12 @@ class DistArray(object):
|
||||
a = self.assemble()
|
||||
return a[sliced]
|
||||
|
||||
@ray.remote()
|
||||
@ray.remote
|
||||
def assemble(a):
|
||||
return a.assemble()
|
||||
|
||||
# TODO(rkn): what should we call this method
|
||||
@ray.remote()
|
||||
@ray.remote
|
||||
def numpy_to_dist(a):
|
||||
result = DistArray(a.shape)
|
||||
for index in np.ndindex(*result.num_blocks):
|
||||
@@ -79,28 +79,28 @@ def numpy_to_dist(a):
|
||||
result.objectids[index] = ray.put(a[[slice(l, u) for (l, u) in zip(lower, upper)]])
|
||||
return result
|
||||
|
||||
@ray.remote()
|
||||
@ray.remote
|
||||
def zeros(shape, dtype_name="float"):
|
||||
result = DistArray(shape)
|
||||
for index in np.ndindex(*result.num_blocks):
|
||||
result.objectids[index] = ra.zeros.remote(DistArray.compute_block_shape(index, shape), dtype_name=dtype_name)
|
||||
return result
|
||||
|
||||
@ray.remote()
|
||||
@ray.remote
|
||||
def ones(shape, dtype_name="float"):
|
||||
result = DistArray(shape)
|
||||
for index in np.ndindex(*result.num_blocks):
|
||||
result.objectids[index] = ra.ones.remote(DistArray.compute_block_shape(index, shape), dtype_name=dtype_name)
|
||||
return result
|
||||
|
||||
@ray.remote()
|
||||
@ray.remote
|
||||
def copy(a):
|
||||
result = DistArray(a.shape)
|
||||
for index in np.ndindex(*result.num_blocks):
|
||||
result.objectids[index] = a.objectids[index] # We don't need to actually copy the objects because cluster-level objects are assumed to be immutable.
|
||||
return result
|
||||
|
||||
@ray.remote()
|
||||
@ray.remote
|
||||
def eye(dim1, dim2=-1, dtype_name="float"):
|
||||
dim2 = dim1 if dim2 == -1 else dim2
|
||||
shape = [dim1, dim2]
|
||||
@@ -113,7 +113,7 @@ def eye(dim1, dim2=-1, dtype_name="float"):
|
||||
result.objectids[i, j] = ra.zeros.remote(block_shape, dtype_name=dtype_name)
|
||||
return result
|
||||
|
||||
@ray.remote()
|
||||
@ray.remote
|
||||
def triu(a):
|
||||
if a.ndim != 2:
|
||||
raise Exception("Input must have 2 dimensions, but a.ndim is " + str(a.ndim))
|
||||
@@ -127,7 +127,7 @@ def triu(a):
|
||||
result.objectids[i, j] = ra.zeros_like.remote(a.objectids[i, j])
|
||||
return result
|
||||
|
||||
@ray.remote()
|
||||
@ray.remote
|
||||
def tril(a):
|
||||
if a.ndim != 2:
|
||||
raise Exception("Input must have 2 dimensions, but a.ndim is " + str(a.ndim))
|
||||
@@ -141,7 +141,7 @@ def tril(a):
|
||||
result.objectids[i, j] = ra.zeros_like.remote(a.objectids[i, j])
|
||||
return result
|
||||
|
||||
@ray.remote()
|
||||
@ray.remote
|
||||
def blockwise_dot(*matrices):
|
||||
n = len(matrices)
|
||||
if n % 2 != 0:
|
||||
@@ -152,7 +152,7 @@ def blockwise_dot(*matrices):
|
||||
result += np.dot(matrices[i], matrices[n / 2 + i])
|
||||
return result
|
||||
|
||||
@ray.remote()
|
||||
@ray.remote
|
||||
def dot(a, b):
|
||||
if a.ndim != 2:
|
||||
raise Exception("dot expects its arguments to be 2-dimensional, but a.ndim = {}.".format(a.ndim))
|
||||
@@ -167,7 +167,7 @@ def dot(a, b):
|
||||
result.objectids[i, j] = blockwise_dot.remote(*args)
|
||||
return result
|
||||
|
||||
@ray.remote()
|
||||
@ray.remote
|
||||
def subblocks(a, *ranges):
|
||||
"""
|
||||
This function produces a distributed array from a subset of the blocks in the `a`. The result and `a` will have the same number of dimensions.For example,
|
||||
@@ -197,7 +197,7 @@ def subblocks(a, *ranges):
|
||||
result.objectids[index] = a.objectids[tuple([ranges[i][index[i]] for i in range(a.ndim)])]
|
||||
return result
|
||||
|
||||
@ray.remote()
|
||||
@ray.remote
|
||||
def transpose(a):
|
||||
if a.ndim != 2:
|
||||
raise Exception("transpose expects its argument to be 2-dimensional, but a.ndim = {}, a.shape = {}.".format(a.ndim, a.shape))
|
||||
@@ -208,7 +208,7 @@ def transpose(a):
|
||||
return result
|
||||
|
||||
# TODO(rkn): support broadcasting?
|
||||
@ray.remote()
|
||||
@ray.remote
|
||||
def add(x1, x2):
|
||||
if x1.shape != x2.shape:
|
||||
raise Exception("add expects arguments `x1` and `x2` to have the same shape, but x1.shape = {}, and x2.shape = {}.".format(x1.shape, x2.shape))
|
||||
@@ -218,7 +218,7 @@ def add(x1, x2):
|
||||
return result
|
||||
|
||||
# TODO(rkn): support broadcasting?
|
||||
@ray.remote()
|
||||
@ray.remote
|
||||
def subtract(x1, x2):
|
||||
if x1.shape != x2.shape:
|
||||
raise Exception("subtract expects arguments `x1` and `x2` to have the same shape, but x1.shape = {}, and x2.shape = {}.".format(x1.shape, x2.shape))
|
||||
|
||||
@@ -112,7 +112,7 @@ def tsqr_hr_helper1(u, s, y_top_block, b):
|
||||
t = -1 * np.dot(u, np.dot(s_full, np.linalg.inv(y_top).T))
|
||||
return t, y_top
|
||||
|
||||
@ray.remote()
|
||||
@ray.remote
|
||||
def tsqr_hr_helper2(s, r_temp):
|
||||
s_full = np.diag(s)
|
||||
return np.dot(s_full, r_temp)
|
||||
@@ -127,11 +127,11 @@ def tsqr_hr(a):
|
||||
r = tsqr_hr_helper2.remote(s, r_temp)
|
||||
return y, t, y_top, r
|
||||
|
||||
@ray.remote()
|
||||
@ray.remote
|
||||
def qr_helper1(a_rc, y_ri, t, W_c):
|
||||
return a_rc - np.dot(y_ri, np.dot(t.T, W_c))
|
||||
|
||||
@ray.remote()
|
||||
@ray.remote
|
||||
def qr_helper2(y_ri, a_rc):
|
||||
return np.dot(y_ri.T, a_rc)
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ import ray
|
||||
|
||||
from core import *
|
||||
|
||||
@ray.remote()
|
||||
@ray.remote
|
||||
def normal(shape):
|
||||
num_blocks = DistArray.compute_num_blocks(shape)
|
||||
objectids = np.empty(num_blocks, dtype=object)
|
||||
|
||||
@@ -3,80 +3,80 @@ import ray
|
||||
|
||||
__all__ = ["zeros", "zeros_like", "ones", "eye", "dot", "vstack", "hstack", "subarray", "copy", "tril", "triu", "diag", "transpose", "add", "subtract", "sum", "shape", "sum_list"]
|
||||
|
||||
@ray.remote()
|
||||
@ray.remote
|
||||
def zeros(shape, dtype_name="float", order="C"):
|
||||
return np.zeros(shape, dtype=np.dtype(dtype_name), order=order)
|
||||
|
||||
@ray.remote()
|
||||
@ray.remote
|
||||
def zeros_like(a, dtype_name="None", order="K", subok=True):
|
||||
dtype_val = None if dtype_name == "None" else np.dtype(dtype_name)
|
||||
return np.zeros_like(a, dtype=dtype_val, order=order, subok=subok)
|
||||
|
||||
@ray.remote()
|
||||
@ray.remote
|
||||
def ones(shape, dtype_name="float", order="C"):
|
||||
return np.ones(shape, dtype=np.dtype(dtype_name), order=order)
|
||||
|
||||
@ray.remote()
|
||||
@ray.remote
|
||||
def eye(N, M=-1, k=0, dtype_name="float"):
|
||||
M = N if M == -1 else M
|
||||
return np.eye(N, M=M, k=k, dtype=np.dtype(dtype_name))
|
||||
|
||||
@ray.remote()
|
||||
@ray.remote
|
||||
def dot(a, b):
|
||||
return np.dot(a, b)
|
||||
|
||||
@ray.remote()
|
||||
@ray.remote
|
||||
def vstack(*xs):
|
||||
return np.vstack(xs)
|
||||
|
||||
@ray.remote()
|
||||
@ray.remote
|
||||
def hstack(*xs):
|
||||
return np.hstack(xs)
|
||||
|
||||
# TODO(rkn): instead of this, consider implementing slicing
|
||||
@ray.remote()
|
||||
@ray.remote
|
||||
def subarray(a, lower_indices, upper_indices): # TODO(rkn): be consistent about using "index" versus "indices"
|
||||
return a[[slice(l, u) for (l, u) in zip(lower_indices, upper_indices)]]
|
||||
|
||||
@ray.remote()
|
||||
@ray.remote
|
||||
def copy(a, order="K"):
|
||||
return np.copy(a, order=order)
|
||||
|
||||
@ray.remote()
|
||||
@ray.remote
|
||||
def tril(m, k=0):
|
||||
return np.tril(m, k=k)
|
||||
|
||||
@ray.remote()
|
||||
@ray.remote
|
||||
def triu(m, k=0):
|
||||
return np.triu(m, k=k)
|
||||
|
||||
@ray.remote()
|
||||
@ray.remote
|
||||
def diag(v, k=0):
|
||||
return np.diag(v, k=k)
|
||||
|
||||
@ray.remote()
|
||||
@ray.remote
|
||||
def transpose(a, axes=[]):
|
||||
axes = None if axes == [] else axes
|
||||
return np.transpose(a, axes=axes)
|
||||
|
||||
@ray.remote()
|
||||
@ray.remote
|
||||
def add(x1, x2):
|
||||
return np.add(x1, x2)
|
||||
|
||||
@ray.remote()
|
||||
@ray.remote
|
||||
def subtract(x1, x2):
|
||||
return np.subtract(x1, x2)
|
||||
|
||||
@ray.remote()
|
||||
@ray.remote
|
||||
def sum(x, axis=-1):
|
||||
return np.sum(x, axis=axis if axis != -1 else None)
|
||||
|
||||
@ray.remote()
|
||||
@ray.remote
|
||||
def shape(a):
|
||||
return np.shape(a)
|
||||
|
||||
# We use Any to allow different numerical types as well as numpy arrays.
|
||||
# TODO(rkn):this isn't in the numpy API, so be careful about exposing this.
|
||||
@ray.remote()
|
||||
@ray.remote
|
||||
def sum_list(*xs):
|
||||
return np.sum(xs, axis=0)
|
||||
|
||||
@@ -6,11 +6,11 @@ __all__ = ["matrix_power", "solve", "tensorsolve", "tensorinv", "inv",
|
||||
"svd", "eig", "eigh", "lstsq", "norm", "qr", "cond", "matrix_rank",
|
||||
"LinAlgError", "multi_dot"]
|
||||
|
||||
@ray.remote()
|
||||
@ray.remote
|
||||
def matrix_power(M, n):
|
||||
return np.linalg.matrix_power(M, n)
|
||||
|
||||
@ray.remote()
|
||||
@ray.remote
|
||||
def solve(a, b):
|
||||
return np.linalg.solve(a, b)
|
||||
|
||||
@@ -22,31 +22,31 @@ def tensorsolve(a):
|
||||
def tensorinv(a):
|
||||
raise NotImplementedError
|
||||
|
||||
@ray.remote()
|
||||
@ray.remote
|
||||
def inv(a):
|
||||
return np.linalg.inv(a)
|
||||
|
||||
@ray.remote()
|
||||
@ray.remote
|
||||
def cholesky(a):
|
||||
return np.linalg.cholesky(a)
|
||||
|
||||
@ray.remote()
|
||||
@ray.remote
|
||||
def eigvals(a):
|
||||
return np.linalg.eigvals(a)
|
||||
|
||||
@ray.remote()
|
||||
@ray.remote
|
||||
def eigvalsh(a):
|
||||
raise NotImplementedError
|
||||
|
||||
@ray.remote()
|
||||
@ray.remote
|
||||
def pinv(a):
|
||||
return np.linalg.pinv(a)
|
||||
|
||||
@ray.remote()
|
||||
@ray.remote
|
||||
def slogdet(a):
|
||||
raise NotImplementedError
|
||||
|
||||
@ray.remote()
|
||||
@ray.remote
|
||||
def det(a):
|
||||
return np.linalg.det(a)
|
||||
|
||||
@@ -66,7 +66,7 @@ def eigh(a):
|
||||
def lstsq(a, b):
|
||||
return np.linalg.lstsq(a)
|
||||
|
||||
@ray.remote()
|
||||
@ray.remote
|
||||
def norm(x):
|
||||
return np.linalg.norm(x)
|
||||
|
||||
@@ -74,14 +74,14 @@ def norm(x):
|
||||
def qr(a):
|
||||
return np.linalg.qr(a)
|
||||
|
||||
@ray.remote()
|
||||
@ray.remote
|
||||
def cond(x):
|
||||
return np.linalg.cond(x)
|
||||
|
||||
@ray.remote()
|
||||
@ray.remote
|
||||
def matrix_rank(M):
|
||||
return np.linalg.matrix_rank(M)
|
||||
|
||||
@ray.remote()
|
||||
@ray.remote
|
||||
def multi_dot(*a):
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import numpy as np
|
||||
import ray
|
||||
|
||||
@ray.remote()
|
||||
@ray.remote
|
||||
def normal(shape):
|
||||
return np.random.normal(size=shape)
|
||||
|
||||
@@ -28,7 +28,7 @@ class Tuple(tuple):
|
||||
|
||||
class Str(str):
|
||||
pass
|
||||
|
||||
|
||||
class Unicode(unicode):
|
||||
pass
|
||||
|
||||
|
||||
+77
-59
@@ -964,7 +964,7 @@ def main_loop(worker=global_worker):
|
||||
# TODO(rkn): Why is the below line necessary?
|
||||
function.__module__ = module
|
||||
assert function_name == "{}.{}".format(function.__module__, function.__name__), "The remote function name does not match the name that was passed in."
|
||||
worker.functions[function_name] = remote(num_return_vals, worker)(function)
|
||||
worker.functions[function_name] = remote(num_return_vals=num_return_vals)(function)
|
||||
_logger().info("Successfully imported remote function {}.".format(function_name))
|
||||
# Noify the scheduler that the remote function imported successfully.
|
||||
# We pass an empty error message string because the import succeeded.
|
||||
@@ -1069,71 +1069,89 @@ def _export_reusable_variable(name, reusable, worker=global_worker):
|
||||
raise Exception("_export_reusable_variable can only be called on a driver.")
|
||||
raylib.export_reusable_variable(worker.handle, name, pickling.dumps(reusable.initializer), pickling.dumps(reusable.reinitializer))
|
||||
|
||||
def remote(num_return_vals=1, worker=global_worker):
|
||||
def remote(*args, **kwargs):
|
||||
"""This decorator is used to create remote functions.
|
||||
|
||||
Args:
|
||||
num_return_vals (int): The number of object IDs that a call to this function
|
||||
should return.
|
||||
"""
|
||||
def remote_decorator(func):
|
||||
def func_call(*args, **kwargs):
|
||||
"""This gets run immediately when a worker calls a remote function."""
|
||||
check_connected()
|
||||
args = list(args)
|
||||
args.extend([kwargs[keyword] if kwargs.has_key(keyword) else default for keyword, default in keyword_defaults[len(args):]]) # fill in the remaining arguments
|
||||
if _mode() == raylib.PYTHON_MODE:
|
||||
# In raylib.PYTHON_MODE, remote calls simply execute the function. We copy the
|
||||
# arguments to prevent the function call from mutating them and to match
|
||||
# the usual behavior of immutable remote objects.
|
||||
return func(*copy.deepcopy(args))
|
||||
objectids = _submit_task(func_name, args)
|
||||
if len(objectids) == 1:
|
||||
return objectids[0]
|
||||
elif len(objectids) > 1:
|
||||
return objectids
|
||||
def func_executor(arguments):
|
||||
"""This gets run when the remote function is executed."""
|
||||
_logger().info("Calling function {}".format(func.__name__))
|
||||
start_time = time.time()
|
||||
result = func(*arguments)
|
||||
end_time = time.time()
|
||||
_logger().info("Finished executing function {}, it took {} seconds".format(func.__name__, end_time - start_time))
|
||||
return result
|
||||
def func_invoker(*args, **kwargs):
|
||||
"""This is returned by the decorator and used to invoke the function."""
|
||||
raise Exception("Remote functions cannot be called directly. Instead of running '{}()', try '{}.remote()'.".format(func_name, func_name))
|
||||
func_invoker.remote = func_call
|
||||
func_invoker.executor = func_executor
|
||||
func_invoker.is_remote = True
|
||||
func_name = "{}.{}".format(func.__module__, func.__name__)
|
||||
func_invoker.func_name = func_name
|
||||
func_invoker.func_doc = func.func_doc
|
||||
sig_params = [(k, v) for k, v in funcsigs.signature(func).parameters.iteritems()]
|
||||
keyword_defaults = [(k, v.default) for k, v in sig_params]
|
||||
has_vararg_param = any([v.kind == v.VAR_POSITIONAL for k, v in sig_params])
|
||||
func_invoker.has_vararg_param = has_vararg_param
|
||||
has_kwargs_param = any([v.kind == v.VAR_KEYWORD for k, v in sig_params])
|
||||
check_signature_supported(has_kwargs_param, has_vararg_param, keyword_defaults, func_name)
|
||||
worker = global_worker
|
||||
def make_remote_decorator(num_return_vals):
|
||||
def remote_decorator(func):
|
||||
def func_call(*args, **kwargs):
|
||||
"""This gets run immediately when a worker calls a remote function."""
|
||||
check_connected()
|
||||
args = list(args)
|
||||
args.extend([kwargs[keyword] if kwargs.has_key(keyword) else default for keyword, default in keyword_defaults[len(args):]]) # fill in the remaining arguments
|
||||
if any([arg is funcsigs._empty for arg in args]):
|
||||
raise Exception("Not enough arguments were provided to {}.".format(func_name))
|
||||
if _mode() == raylib.PYTHON_MODE:
|
||||
# In raylib.PYTHON_MODE, remote calls simply execute the function. We copy the
|
||||
# arguments to prevent the function call from mutating them and to match
|
||||
# the usual behavior of immutable remote objects.
|
||||
return func(*copy.deepcopy(args))
|
||||
objectids = _submit_task(func_name, args)
|
||||
if len(objectids) == 1:
|
||||
return objectids[0]
|
||||
elif len(objectids) > 1:
|
||||
return objectids
|
||||
def func_executor(arguments):
|
||||
"""This gets run when the remote function is executed."""
|
||||
_logger().info("Calling function {}".format(func.__name__))
|
||||
start_time = time.time()
|
||||
result = func(*arguments)
|
||||
end_time = time.time()
|
||||
_logger().info("Finished executing function {}, it took {} seconds".format(func.__name__, end_time - start_time))
|
||||
return result
|
||||
def func_invoker(*args, **kwargs):
|
||||
"""This is returned by the decorator and used to invoke the function."""
|
||||
raise Exception("Remote functions cannot be called directly. Instead of running '{}()', try '{}.remote()'.".format(func_name, func_name))
|
||||
func_invoker.remote = func_call
|
||||
func_invoker.executor = func_executor
|
||||
func_invoker.is_remote = True
|
||||
func_name = "{}.{}".format(func.__module__, func.__name__)
|
||||
func_invoker.func_name = func_name
|
||||
func_invoker.func_doc = func.func_doc
|
||||
|
||||
# Everything ready - export the function
|
||||
if worker.mode in [None, raylib.SCRIPT_MODE, raylib.SILENT_MODE]:
|
||||
func_name_global_valid = func.__name__ in func.__globals__
|
||||
func_name_global_value = func.__globals__.get(func.__name__)
|
||||
# Set the function globally to make it refer to itself
|
||||
func.__globals__[func.__name__] = func_invoker # Allow the function to reference itself as a global variable
|
||||
try:
|
||||
to_export = pickling.dumps((func, num_return_vals, func.__module__))
|
||||
finally:
|
||||
# Undo our changes
|
||||
if func_name_global_valid: func.__globals__[func.__name__] = func_name_global_value
|
||||
else: del func.__globals__[func.__name__]
|
||||
if worker.mode in [raylib.SCRIPT_MODE, raylib.SILENT_MODE]:
|
||||
raylib.export_remote_function(worker.handle, func_name, to_export)
|
||||
elif worker.mode is None:
|
||||
worker.cached_remote_functions.append((func_name, to_export))
|
||||
return func_invoker
|
||||
return remote_decorator
|
||||
sig_params = [(k, v) for k, v in funcsigs.signature(func).parameters.iteritems()]
|
||||
keyword_defaults = [(k, v.default) for k, v in sig_params]
|
||||
has_vararg_param = any([v.kind == v.VAR_POSITIONAL for k, v in sig_params])
|
||||
func_invoker.has_vararg_param = has_vararg_param
|
||||
has_kwargs_param = any([v.kind == v.VAR_KEYWORD for k, v in sig_params])
|
||||
check_signature_supported(has_kwargs_param, has_vararg_param, keyword_defaults, func_name)
|
||||
|
||||
# Everything ready - export the function
|
||||
if worker.mode in [None, raylib.SCRIPT_MODE, raylib.SILENT_MODE]:
|
||||
func_name_global_valid = func.__name__ in func.__globals__
|
||||
func_name_global_value = func.__globals__.get(func.__name__)
|
||||
# Set the function globally to make it refer to itself
|
||||
func.__globals__[func.__name__] = func_invoker # Allow the function to reference itself as a global variable
|
||||
try:
|
||||
to_export = pickling.dumps((func, num_return_vals, func.__module__))
|
||||
finally:
|
||||
# Undo our changes
|
||||
if func_name_global_valid: func.__globals__[func.__name__] = func_name_global_value
|
||||
else: del func.__globals__[func.__name__]
|
||||
if worker.mode in [raylib.SCRIPT_MODE, raylib.SILENT_MODE]:
|
||||
raylib.export_remote_function(worker.handle, func_name, to_export)
|
||||
elif worker.mode is None:
|
||||
worker.cached_remote_functions.append((func_name, to_export))
|
||||
return func_invoker
|
||||
|
||||
return remote_decorator
|
||||
|
||||
if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
|
||||
# This is the case where the decorator is just @ray.remote.
|
||||
num_return_vals = 1
|
||||
func = args[0]
|
||||
return make_remote_decorator(num_return_vals)(func)
|
||||
else:
|
||||
# This is the case where the decorator is something like
|
||||
# @ray.remote(num_return_vals=2).
|
||||
assert len(args) == 0 and "num_return_vals" in kwargs.keys(), "The @ray.remote decorator must be applied either with no arguments and no parentheses, for example '@ray.remote', or it must be applied with only the argument num_return_vals, like '@ray.remote(num_return_vals=2)'."
|
||||
num_return_vals = kwargs["num_return_vals"]
|
||||
return make_remote_decorator(num_return_vals)
|
||||
|
||||
def check_signature_supported(has_kwargs_param, has_vararg_param, keyword_defaults, name):
|
||||
"""Check if we support the signature of this function.
|
||||
|
||||
Reference in New Issue
Block a user