From f9aeb5d018b034e33ac4cbec1bcb88e03d3c1451 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sat, 4 Jun 2016 16:22:10 -0700 Subject: [PATCH] implement varargs (#83) * implement varargs * clean up varargs --- lib/python/arrays/dist/core.py | 5 ++--- lib/python/arrays/single/core.py | 11 ++++------ lib/python/arrays/single/linalg.py | 4 ++-- lib/python/halo/worker.py | 34 ++++++++++++++++++++---------- test/runtest.py | 18 +++++++++++++++- test/test_functions.py | 26 +++++++++++++++++++++++ 6 files changed, 74 insertions(+), 24 deletions(-) diff --git a/lib/python/arrays/dist/core.py b/lib/python/arrays/dist/core.py index f0265ccc9..7ecef0a44 100644 --- a/lib/python/arrays/dist/core.py +++ b/lib/python/arrays/dist/core.py @@ -145,7 +145,7 @@ def tril(a): result.objrefs[i, j] = single.zeros_like(a.objrefs[i, j]) return result -@halo.remote([np.ndarray, None], [np.ndarray]) +@halo.remote([np.ndarray], [np.ndarray]) def blockwise_dot(*matrices): n = len(matrices) if n % 2 != 0: @@ -171,8 +171,7 @@ def dot(a, b): result.objrefs[i, j] = blockwise_dot(*args) return result -# This is not in numpy, should we expose this? -@halo.remote([DistArray, List[int], None], [DistArray]) +@halo.remote([DistArray, List[int]], [DistArray]) def subblocks(a, *ranges): """ This function produces a distributed array from a subset of the blocks in the `a`. The result and `a` will have the same number of dimensions.For example, diff --git a/lib/python/arrays/single/core.py b/lib/python/arrays/single/core.py index a409e7091..4487f72c9 100644 --- a/lib/python/arrays/single/core.py +++ b/lib/python/arrays/single/core.py @@ -26,18 +26,15 @@ def eye(N, M=-1, k=0, dtype_name="float"): def dot(a, b): return np.dot(a, b) -# TODO(rkn): My preferred signature would have been -# @halo.remote([List[np.ndarray]], [np.ndarray]) but that currently doesn't -# work because that would expect a list of ndarrays not a list of ObjRefs -@halo.remote([np.ndarray, None], [np.ndarray]) +@halo.remote([np.ndarray], [np.ndarray]) def vstack(*xs): return np.vstack(xs) -@halo.remote([np.ndarray, None], [np.ndarray]) +@halo.remote([np.ndarray], [np.ndarray]) def hstack(*xs): return np.hstack(xs) -# TODO(rkn): this doesn't parallel the numpy API, but we can't really slice an ObjRef, think about this +# TODO(rkn): instead of this, consider implementing slicing @halo.remote([np.ndarray, List[int], List[int]], [np.ndarray]) def subarray(a, lower_indices, upper_indices): # TODO(rkn): be consistent about using "index" versus "indices" return a[[slice(l, u) for (l, u) in zip(lower_indices, upper_indices)]] @@ -71,7 +68,7 @@ def add(x1, x2): def subtract(x1, x2): return np.subtract(x1, x2) -@halo.remote([int, np.ndarray, None], [np.ndarray]) +@halo.remote([int, np.ndarray], [np.ndarray]) def sum(axis, *xs): return np.sum(xs, axis=axis) diff --git a/lib/python/arrays/single/linalg.py b/lib/python/arrays/single/linalg.py index 1721cc4a0..f7d7fa6b2 100644 --- a/lib/python/arrays/single/linalg.py +++ b/lib/python/arrays/single/linalg.py @@ -83,6 +83,6 @@ def cond(x): def matrix_rank(M): return np.linalg.matrix_rank(M) -@halo.remote([np.ndarray, None], [np.ndarray]) -def multi_dot(a): +@halo.remote([np.ndarray], [np.ndarray]) +def multi_dot(*a): raise NotImplementedError diff --git a/lib/python/halo/worker.py b/lib/python/halo/worker.py index c0a275ff1..334d99bbe 100644 --- a/lib/python/halo/worker.py +++ b/lib/python/halo/worker.py @@ -118,10 +118,26 @@ def remote(arg_types, return_types, worker=global_worker): func_call.arg_types = arg_types func_call.return_types = return_types func_call.is_remote = True - func_call.keyword_defaults = [(k, v.default) for k, v in funcsigs.signature(func).parameters.iteritems()] + func_call.sig_params = [(k, v) for k, v in funcsigs.signature(func).parameters.iteritems()] + func_call.keyword_defaults = [(k, v.default) for k, v in func_call.sig_params] + func_call.has_vararg_param = any([v.kind == v.VAR_POSITIONAL for k, v in func_call.sig_params]) + func_call.has_kwargs_param = any([v.kind == v.VAR_KEYWORD for k, v in func_call.sig_params]) + check_signature_supported(func_call) return func_call return remote_decorator +# helper method, this should not be called by the user +# we currently do not support the functionality that we test for in this method, +# but in the future we could +def check_signature_supported(function): + # check if the user specified kwargs + if function.has_kwargs_param: + raise "Function {} has a **kwargs argument, which is currently not supported.".format(function.__name__) + # check if the user specified a variable number of arguments and any keyword arguments + if function.has_vararg_param and any([d != funcsigs._empty for k, d in function.keyword_defaults]): + raise "Function {} has a *args argument as well as a keyword argument, which is currently not supported.".format(function.__name__) + + # helper method, this should not be called by the user def check_return_values(function, result): if len(function.return_types) == 1: @@ -138,18 +154,16 @@ def check_return_values(function, result): # helper method, this should not be called by the user def check_arguments(function, args): # check the number of args - if len(args) != len(function.arg_types) and function.arg_types[-1] is not None: + if len(args) != len(function.arg_types) and not function.has_vararg_param: raise Exception("Function {} expects {} arguments, but received {}.".format(function.__name__, len(function.arg_types), len(args))) - elif len(args) < len(function.arg_types) - 1 and function.arg_types[-1] is None: + elif len(args) < len(function.arg_types) - 1 and function.has_vararg_param: raise Exception("Function {} expects at least {} arguments, but received {}.".format(function.__name__, len(function.arg_types) - 1, len(args))) for (i, arg) in enumerate(args): - if i < len(function.arg_types) - 1: + if i <= len(function.arg_types) - 1: expected_type = function.arg_types[i] - elif i == len(function.arg_types) - 1 and function.arg_types[-1] is not None: + elif function.has_vararg_param: expected_type = function.arg_types[-1] - elif function.arg_types[-1] is None and len(function.arg_types) > 1: - expected_type = function.arg_types[-2] else: assert False, "This code should be unreachable." @@ -173,12 +187,10 @@ def get_arguments_for_execution(function, args, worker=global_worker): """ for (i, arg) in enumerate(args): - if i < len(function.arg_types) - 1: + if i <= len(function.arg_types) - 1: expected_type = function.arg_types[i] - elif i == len(function.arg_types) - 1 and function.arg_types[-1] is not None: + elif function.has_vararg_param and len(function.arg_types) >= 1: expected_type = function.arg_types[-1] - elif function.arg_types[-1] is None and len(function.arg_types) > 1: - expected_type = function.arg_types[-2] else: assert False, "This code should be unreachable." diff --git a/test/runtest.py b/test/runtest.py index c0f49bc4a..58a597e50 100644 --- a/test/runtest.py +++ b/test/runtest.py @@ -191,7 +191,8 @@ class APITest(unittest.TestCase): def testKeywordArgs(self): test_dir = os.path.dirname(os.path.abspath(__file__)) test_path = os.path.join(test_dir, "testrecv.py") - services.start_singlenode_cluster(return_drivers=False, num_workers_per_objstore=3, worker_path=test_path) + services.start_singlenode_cluster(return_drivers=False, num_workers_per_objstore=1, worker_path=test_path) + x = test_functions.keyword_fct1(1) self.assertEqual(halo.pull(x), "1 hello") x = test_functions.keyword_fct1(1, "hi") @@ -225,6 +226,21 @@ class APITest(unittest.TestCase): services.cleanup() + def testVariableNumberOfArgs(self): + test_dir = os.path.dirname(os.path.abspath(__file__)) + test_path = os.path.join(test_dir, "testrecv.py") + services.start_singlenode_cluster(return_drivers=False, num_workers_per_objstore=1, worker_path=test_path) + + x = test_functions.varargs_fct1(0, 1, 2) + self.assertEqual(halo.pull(x), "0 1 2") + x = test_functions.varargs_fct2(0, 1, 2) + self.assertEqual(halo.pull(x), "1 2") + + self.assertTrue(test_functions.kwargs_exception_thrown) + self.assertTrue(test_functions.varargs_and_kwargs_exception_thrown) + + services.cleanup() + class ReferenceCountingTest(unittest.TestCase): def testDeallocation(self): diff --git a/test/test_functions.py b/test/test_functions.py index 10dea8da8..47ae396d9 100644 --- a/test/test_functions.py +++ b/test/test_functions.py @@ -52,3 +52,29 @@ def keyword_fct2(a="hello", b="world"): @halo.remote([int, int, str, str], [str]) def keyword_fct3(a, b, c="hello", d="world"): return "{} {} {} {}".format(a, b, c, d) + +# Test variable numbers of arguments + +@halo.remote([int], [str]) +def varargs_fct1(*a): + return " ".join(map(str, a)) + +@halo.remote([int, int], [str]) +def varargs_fct2(a, *b): + return " ".join(map(str, b)) + +try: + @halo.remote([int], []) + def kwargs_throw_exception(**c): + return () + kwargs_exception_thrown = False +except: + kwargs_exception_thrown = True + +try: + @halo.remote([int, str, int], [str]) + def varargs_and_kwargs_throw_exception(a, b="hi", *c): + return "{} {} {}".format(a, b, c) + varargs_and_kwargs_exception_thrown = False +except: + varargs_and_kwargs_exception_thrown = True