implement varargs (#83)

* implement varargs

* clean up varargs
This commit is contained in:
Philipp Moritz
2016-06-04 16:22:10 -07:00
parent 2b52b91acb
commit f9aeb5d018
6 changed files with 74 additions and 24 deletions
+2 -3
View File
@@ -145,7 +145,7 @@ def tril(a):
result.objrefs[i, j] = single.zeros_like(a.objrefs[i, j])
return result
@halo.remote([np.ndarray, None], [np.ndarray])
@halo.remote([np.ndarray], [np.ndarray])
def blockwise_dot(*matrices):
n = len(matrices)
if n % 2 != 0:
@@ -171,8 +171,7 @@ def dot(a, b):
result.objrefs[i, j] = blockwise_dot(*args)
return result
# This is not in numpy, should we expose this?
@halo.remote([DistArray, List[int], None], [DistArray])
@halo.remote([DistArray, List[int]], [DistArray])
def subblocks(a, *ranges):
"""
This function produces a distributed array from a subset of the blocks in the `a`. The result and `a` will have the same number of dimensions.For example,
+4 -7
View File
@@ -26,18 +26,15 @@ def eye(N, M=-1, k=0, dtype_name="float"):
def dot(a, b):
return np.dot(a, b)
# TODO(rkn): My preferred signature would have been
# @halo.remote([List[np.ndarray]], [np.ndarray]) but that currently doesn't
# work because that would expect a list of ndarrays not a list of ObjRefs
@halo.remote([np.ndarray, None], [np.ndarray])
@halo.remote([np.ndarray], [np.ndarray])
def vstack(*xs):
return np.vstack(xs)
@halo.remote([np.ndarray, None], [np.ndarray])
@halo.remote([np.ndarray], [np.ndarray])
def hstack(*xs):
return np.hstack(xs)
# TODO(rkn): this doesn't parallel the numpy API, but we can't really slice an ObjRef, think about this
# TODO(rkn): instead of this, consider implementing slicing
@halo.remote([np.ndarray, List[int], List[int]], [np.ndarray])
def subarray(a, lower_indices, upper_indices): # TODO(rkn): be consistent about using "index" versus "indices"
return a[[slice(l, u) for (l, u) in zip(lower_indices, upper_indices)]]
@@ -71,7 +68,7 @@ def add(x1, x2):
def subtract(x1, x2):
return np.subtract(x1, x2)
@halo.remote([int, np.ndarray, None], [np.ndarray])
@halo.remote([int, np.ndarray], [np.ndarray])
def sum(axis, *xs):
return np.sum(xs, axis=axis)
+2 -2
View File
@@ -83,6 +83,6 @@ def cond(x):
def matrix_rank(M):
return np.linalg.matrix_rank(M)
@halo.remote([np.ndarray, None], [np.ndarray])
def multi_dot(a):
@halo.remote([np.ndarray], [np.ndarray])
def multi_dot(*a):
raise NotImplementedError
+23 -11
View File
@@ -118,10 +118,26 @@ def remote(arg_types, return_types, worker=global_worker):
func_call.arg_types = arg_types
func_call.return_types = return_types
func_call.is_remote = True
func_call.keyword_defaults = [(k, v.default) for k, v in funcsigs.signature(func).parameters.iteritems()]
func_call.sig_params = [(k, v) for k, v in funcsigs.signature(func).parameters.iteritems()]
func_call.keyword_defaults = [(k, v.default) for k, v in func_call.sig_params]
func_call.has_vararg_param = any([v.kind == v.VAR_POSITIONAL for k, v in func_call.sig_params])
func_call.has_kwargs_param = any([v.kind == v.VAR_KEYWORD for k, v in func_call.sig_params])
check_signature_supported(func_call)
return func_call
return remote_decorator
# helper method, this should not be called by the user
# we currently do not support the functionality that we test for in this method,
# but in the future we could
def check_signature_supported(function):
# check if the user specified kwargs
if function.has_kwargs_param:
raise "Function {} has a **kwargs argument, which is currently not supported.".format(function.__name__)
# check if the user specified a variable number of arguments and any keyword arguments
if function.has_vararg_param and any([d != funcsigs._empty for k, d in function.keyword_defaults]):
raise "Function {} has a *args argument as well as a keyword argument, which is currently not supported.".format(function.__name__)
# helper method, this should not be called by the user
def check_return_values(function, result):
if len(function.return_types) == 1:
@@ -138,18 +154,16 @@ def check_return_values(function, result):
# helper method, this should not be called by the user
def check_arguments(function, args):
# check the number of args
if len(args) != len(function.arg_types) and function.arg_types[-1] is not None:
if len(args) != len(function.arg_types) and not function.has_vararg_param:
raise Exception("Function {} expects {} arguments, but received {}.".format(function.__name__, len(function.arg_types), len(args)))
elif len(args) < len(function.arg_types) - 1 and function.arg_types[-1] is None:
elif len(args) < len(function.arg_types) - 1 and function.has_vararg_param:
raise Exception("Function {} expects at least {} arguments, but received {}.".format(function.__name__, len(function.arg_types) - 1, len(args)))
for (i, arg) in enumerate(args):
if i < len(function.arg_types) - 1:
if i <= len(function.arg_types) - 1:
expected_type = function.arg_types[i]
elif i == len(function.arg_types) - 1 and function.arg_types[-1] is not None:
elif function.has_vararg_param:
expected_type = function.arg_types[-1]
elif function.arg_types[-1] is None and len(function.arg_types) > 1:
expected_type = function.arg_types[-2]
else:
assert False, "This code should be unreachable."
@@ -173,12 +187,10 @@ def get_arguments_for_execution(function, args, worker=global_worker):
"""
for (i, arg) in enumerate(args):
if i < len(function.arg_types) - 1:
if i <= len(function.arg_types) - 1:
expected_type = function.arg_types[i]
elif i == len(function.arg_types) - 1 and function.arg_types[-1] is not None:
elif function.has_vararg_param and len(function.arg_types) >= 1:
expected_type = function.arg_types[-1]
elif function.arg_types[-1] is None and len(function.arg_types) > 1:
expected_type = function.arg_types[-2]
else:
assert False, "This code should be unreachable."