mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 19:33:06 +08:00
Vendored
+2
-3
@@ -145,7 +145,7 @@ def tril(a):
|
||||
result.objrefs[i, j] = single.zeros_like(a.objrefs[i, j])
|
||||
return result
|
||||
|
||||
@halo.remote([np.ndarray, None], [np.ndarray])
|
||||
@halo.remote([np.ndarray], [np.ndarray])
|
||||
def blockwise_dot(*matrices):
|
||||
n = len(matrices)
|
||||
if n % 2 != 0:
|
||||
@@ -171,8 +171,7 @@ def dot(a, b):
|
||||
result.objrefs[i, j] = blockwise_dot(*args)
|
||||
return result
|
||||
|
||||
# This is not in numpy, should we expose this?
|
||||
@halo.remote([DistArray, List[int], None], [DistArray])
|
||||
@halo.remote([DistArray, List[int]], [DistArray])
|
||||
def subblocks(a, *ranges):
|
||||
"""
|
||||
This function produces a distributed array from a subset of the blocks in the `a`. The result and `a` will have the same number of dimensions.For example,
|
||||
|
||||
@@ -26,18 +26,15 @@ def eye(N, M=-1, k=0, dtype_name="float"):
|
||||
def dot(a, b):
|
||||
return np.dot(a, b)
|
||||
|
||||
# TODO(rkn): My preferred signature would have been
|
||||
# @halo.remote([List[np.ndarray]], [np.ndarray]) but that currently doesn't
|
||||
# work because that would expect a list of ndarrays not a list of ObjRefs
|
||||
@halo.remote([np.ndarray, None], [np.ndarray])
|
||||
@halo.remote([np.ndarray], [np.ndarray])
|
||||
def vstack(*xs):
|
||||
return np.vstack(xs)
|
||||
|
||||
@halo.remote([np.ndarray, None], [np.ndarray])
|
||||
@halo.remote([np.ndarray], [np.ndarray])
|
||||
def hstack(*xs):
|
||||
return np.hstack(xs)
|
||||
|
||||
# TODO(rkn): this doesn't parallel the numpy API, but we can't really slice an ObjRef, think about this
|
||||
# TODO(rkn): instead of this, consider implementing slicing
|
||||
@halo.remote([np.ndarray, List[int], List[int]], [np.ndarray])
|
||||
def subarray(a, lower_indices, upper_indices): # TODO(rkn): be consistent about using "index" versus "indices"
|
||||
return a[[slice(l, u) for (l, u) in zip(lower_indices, upper_indices)]]
|
||||
@@ -71,7 +68,7 @@ def add(x1, x2):
|
||||
def subtract(x1, x2):
|
||||
return np.subtract(x1, x2)
|
||||
|
||||
@halo.remote([int, np.ndarray, None], [np.ndarray])
|
||||
@halo.remote([int, np.ndarray], [np.ndarray])
|
||||
def sum(axis, *xs):
|
||||
return np.sum(xs, axis=axis)
|
||||
|
||||
|
||||
@@ -83,6 +83,6 @@ def cond(x):
|
||||
def matrix_rank(M):
|
||||
return np.linalg.matrix_rank(M)
|
||||
|
||||
@halo.remote([np.ndarray, None], [np.ndarray])
|
||||
def multi_dot(a):
|
||||
@halo.remote([np.ndarray], [np.ndarray])
|
||||
def multi_dot(*a):
|
||||
raise NotImplementedError
|
||||
|
||||
+23
-11
@@ -118,10 +118,26 @@ def remote(arg_types, return_types, worker=global_worker):
|
||||
func_call.arg_types = arg_types
|
||||
func_call.return_types = return_types
|
||||
func_call.is_remote = True
|
||||
func_call.keyword_defaults = [(k, v.default) for k, v in funcsigs.signature(func).parameters.iteritems()]
|
||||
func_call.sig_params = [(k, v) for k, v in funcsigs.signature(func).parameters.iteritems()]
|
||||
func_call.keyword_defaults = [(k, v.default) for k, v in func_call.sig_params]
|
||||
func_call.has_vararg_param = any([v.kind == v.VAR_POSITIONAL for k, v in func_call.sig_params])
|
||||
func_call.has_kwargs_param = any([v.kind == v.VAR_KEYWORD for k, v in func_call.sig_params])
|
||||
check_signature_supported(func_call)
|
||||
return func_call
|
||||
return remote_decorator
|
||||
|
||||
# helper method, this should not be called by the user
|
||||
# we currently do not support the functionality that we test for in this method,
|
||||
# but in the future we could
|
||||
def check_signature_supported(function):
|
||||
# check if the user specified kwargs
|
||||
if function.has_kwargs_param:
|
||||
raise "Function {} has a **kwargs argument, which is currently not supported.".format(function.__name__)
|
||||
# check if the user specified a variable number of arguments and any keyword arguments
|
||||
if function.has_vararg_param and any([d != funcsigs._empty for k, d in function.keyword_defaults]):
|
||||
raise "Function {} has a *args argument as well as a keyword argument, which is currently not supported.".format(function.__name__)
|
||||
|
||||
|
||||
# helper method, this should not be called by the user
|
||||
def check_return_values(function, result):
|
||||
if len(function.return_types) == 1:
|
||||
@@ -138,18 +154,16 @@ def check_return_values(function, result):
|
||||
# helper method, this should not be called by the user
|
||||
def check_arguments(function, args):
|
||||
# check the number of args
|
||||
if len(args) != len(function.arg_types) and function.arg_types[-1] is not None:
|
||||
if len(args) != len(function.arg_types) and not function.has_vararg_param:
|
||||
raise Exception("Function {} expects {} arguments, but received {}.".format(function.__name__, len(function.arg_types), len(args)))
|
||||
elif len(args) < len(function.arg_types) - 1 and function.arg_types[-1] is None:
|
||||
elif len(args) < len(function.arg_types) - 1 and function.has_vararg_param:
|
||||
raise Exception("Function {} expects at least {} arguments, but received {}.".format(function.__name__, len(function.arg_types) - 1, len(args)))
|
||||
|
||||
for (i, arg) in enumerate(args):
|
||||
if i < len(function.arg_types) - 1:
|
||||
if i <= len(function.arg_types) - 1:
|
||||
expected_type = function.arg_types[i]
|
||||
elif i == len(function.arg_types) - 1 and function.arg_types[-1] is not None:
|
||||
elif function.has_vararg_param:
|
||||
expected_type = function.arg_types[-1]
|
||||
elif function.arg_types[-1] is None and len(function.arg_types) > 1:
|
||||
expected_type = function.arg_types[-2]
|
||||
else:
|
||||
assert False, "This code should be unreachable."
|
||||
|
||||
@@ -173,12 +187,10 @@ def get_arguments_for_execution(function, args, worker=global_worker):
|
||||
"""
|
||||
|
||||
for (i, arg) in enumerate(args):
|
||||
if i < len(function.arg_types) - 1:
|
||||
if i <= len(function.arg_types) - 1:
|
||||
expected_type = function.arg_types[i]
|
||||
elif i == len(function.arg_types) - 1 and function.arg_types[-1] is not None:
|
||||
elif function.has_vararg_param and len(function.arg_types) >= 1:
|
||||
expected_type = function.arg_types[-1]
|
||||
elif function.arg_types[-1] is None and len(function.arg_types) > 1:
|
||||
expected_type = function.arg_types[-2]
|
||||
else:
|
||||
assert False, "This code should be unreachable."
|
||||
|
||||
|
||||
Reference in New Issue
Block a user