diff --git a/lib/orchpy/orchpy/worker.py b/lib/orchpy/orchpy/worker.py index 0bcf85abc..37e24effd 100644 --- a/lib/orchpy/orchpy/worker.py +++ b/lib/orchpy/orchpy/worker.py @@ -1,5 +1,6 @@ from types import ModuleType import typing +import funcsigs import numpy as np import pynumbuf @@ -105,16 +106,19 @@ def distributed(arg_types, return_types, worker=global_worker): check_return_values(func_call, result) # throws an exception if result is invalid print "Finished executing function {}".format(func.__name__) return result - def func_call(*args): + def func_call(*args, **kwargs): """This is what gets run immediately when a worker calls a distributed function.""" - check_arguments(func_call, list(args)) # throws an exception if args are invalid - objrefs = worker.submit_task(func_call.func_name, list(args)) + args = list(args) + args.extend([kwargs[keyword] if kwargs.has_key(keyword) else default for keyword, default in func_call.keyword_defaults[len(args):]]) # fill in the remaining arguments + check_arguments(func_call, args) # throws an exception if args are invalid + objrefs = worker.submit_task(func_call.func_name, args) return objrefs[0] if len(objrefs) == 1 else objrefs func_call.func_name = "{}.{}".format(func.__module__, func.__name__) func_call.executor = func_executor func_call.arg_types = arg_types func_call.return_types = return_types func_call.is_distributed = True + func_call.keyword_defaults = [(k, v.default) for k, v in funcsigs.signature(func).parameters.iteritems()] return func_call return distributed_decorator diff --git a/requirements.txt b/requirements.txt index e3736ee5c..85f1652e5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ six >= 1.10 typing +funcsigs subprocess32 grpcio diff --git a/test/runtest.py b/test/runtest.py index 86908e59a..f6440327f 100644 --- a/test/runtest.py +++ b/test/runtest.py @@ -181,6 +181,43 @@ class APITest(unittest.TestCase): services.cleanup() + def testKeywordArgs(self): + test_dir = os.path.dirname(os.path.abspath(__file__)) + test_path = os.path.join(test_dir, "testrecv.py") + services.start_singlenode_cluster(return_drivers=False, num_workers_per_objstore=3, worker_path=test_path) + x = test_functions.keyword_fct1(1) + self.assertEqual(orchpy.pull(x), "1 hello") + x = test_functions.keyword_fct1(1, "hi") + self.assertEqual(orchpy.pull(x), "1 hi") + x = test_functions.keyword_fct1(1, b="world") + self.assertEqual(orchpy.pull(x), "1 world") + + x = test_functions.keyword_fct2(a="w", b="hi") + self.assertEqual(orchpy.pull(x), "w hi") + x = test_functions.keyword_fct2(b="hi", a="w") + self.assertEqual(orchpy.pull(x), "w hi") + x = test_functions.keyword_fct2(a="w") + self.assertEqual(orchpy.pull(x), "w world") + x = test_functions.keyword_fct2(b="hi") + self.assertEqual(orchpy.pull(x), "hello hi") + x = test_functions.keyword_fct2("w") + self.assertEqual(orchpy.pull(x), "w world") + x = test_functions.keyword_fct2("w", "hi") + self.assertEqual(orchpy.pull(x), "w hi") + + x = test_functions.keyword_fct3(0, 1, c="w", d="hi") + self.assertEqual(orchpy.pull(x), "0 1 w hi") + x = test_functions.keyword_fct3(0, 1, d="hi", c="w") + self.assertEqual(orchpy.pull(x), "0 1 w hi") + x = test_functions.keyword_fct3(0, 1, c="w") + self.assertEqual(orchpy.pull(x), "0 1 w world") + x = test_functions.keyword_fct3(0, 1, d="hi") + self.assertEqual(orchpy.pull(x), "0 1 hello hi") + x = test_functions.keyword_fct3(0, 1) + self.assertEqual(orchpy.pull(x), "0 1 hello world") + + services.cleanup() + class ReferenceCountingTest(unittest.TestCase): def testDeallocation(self): diff --git a/test/test_functions.py b/test/test_functions.py index 935fef283..6af075475 100644 --- a/test/test_functions.py +++ b/test/test_functions.py @@ -38,3 +38,17 @@ def empty_function(): @orchpy.distributed([], [int]) def trivial_function(): return 1 + +# Test keyword arguments + +@orchpy.distributed([int, str], [str]) +def keyword_fct1(a, b="hello"): + return "{} {}".format(a, b) + +@orchpy.distributed([str, str], [str]) +def keyword_fct2(a="hello", b="world"): + return "{} {}".format(a, b) + +@orchpy.distributed([int, int, str, str], [str]) +def keyword_fct3(a, b, c="hello", d="world"): + return "{} {} {} {}".format(a, b, c, d)