diff --git a/lib/python/ray/worker.py b/lib/python/ray/worker.py index e5f19537a..7a5865afb 100644 --- a/lib/python/ray/worker.py +++ b/lib/python/ray/worker.py @@ -355,16 +355,28 @@ def check_signature_supported(function): # helper method, this should not be called by the user def check_return_values(function, result): + # If the @remote decorator declares that the function has no return values, + # then all we do is check that there were in fact no return values. + if len(function.return_types) == 0: + if result is not None: + raise Exception("The @remote decorator for function {} has 0 return values, but {} returned more than 0 values.".format(function.__name__, function.__name__)) + return + # If a function has multiple return values, Python returns a tuple of the + # values. If there is a single return value, then Python does not return a + # tuple, it simply returns the value. That is why we place result with + # (result,) when there is only one return value, so we can treat these two + # cases similarly. if len(function.return_types) == 1: result = (result,) - # if not isinstance(result, function.return_types[0]): - # raise Exception("The @remote decorator for function {} expects one return value with type {}, but {} returned a {}.".format(function.__name__, function.return_types[0], function.__name__, type(result))) - else: - if len(result) != len(function.return_types): - raise Exception("The @remote decorator for function {} has {} return values with types {}, but {} returned {} values.".format(function.__name__, len(function.return_types), function.return_types, function.__name__, len(result))) - for i in range(len(result)): - if (not issubclass(type(result[i]), function.return_types[i])) and (not isinstance(result[i], ray.lib.ObjRef)): - raise Exception("The {}th return value for function {} has type {}, but the @remote decorator expected a return value of type {} or an ObjRef.".format(i, function.__name__, type(result[i]), function.return_types[i])) + # Below we check that the number of values returned by the function match the + # number of return values declared in the @remote decorator. + if len(result) != len(function.return_types): + raise Exception("The @remote decorator for function {} has {} return values with types {}, but {} returned {} values.".format(function.__name__, len(function.return_types), function.return_types, function.__name__, len(result))) + # Here we do some limited type checking to make sure the return values have + # the right types. + for i in range(len(result)): + if (not issubclass(type(result[i]), function.return_types[i])) and (not isinstance(result[i], ray.lib.ObjRef)): + raise Exception("The {}th return value for function {} has type {}, but the @remote decorator expected a return value of type {} or an ObjRef.".format(i, function.__name__, type(result[i]), function.return_types[i])) # helper method, this should not be called by the user def check_arguments(function, args): diff --git a/scripts/example_functions.py b/scripts/example_functions.py index e424e9b5e..e3e43f24f 100644 --- a/scripts/example_functions.py +++ b/scripts/example_functions.py @@ -23,3 +23,7 @@ def dot(a, b): @ray.remote([], []) def throw_exception(): raise Exception("This function intentionally failed.") + +@ray.remote([], []) +def no_op(): + pass diff --git a/test/runtest.py b/test/runtest.py index 1df64568e..414380ae6 100644 --- a/test/runtest.py +++ b/test/runtest.py @@ -238,6 +238,43 @@ class APITest(unittest.TestCase): services.cleanup() + def testNoArgs(self): + worker_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_worker.py") + services.start_ray_local(num_workers=1, worker_path=worker_path) + + test_functions.no_op() + time.sleep(0.2) + task_info = ray.task_info() + self.assertEqual(len(task_info["failed_tasks"]), 0) + self.assertEqual(len(task_info["running_tasks"]), 0) + self.assertEqual(task_info["num_succeeded"], 1) + + test_functions.no_op_fail() + time.sleep(0.2) + task_info = ray.task_info() + self.assertEqual(len(task_info["failed_tasks"]), 1) + self.assertEqual(len(task_info["running_tasks"]), 0) + self.assertEqual(task_info["num_succeeded"], 1) + self.assertEqual(task_info["failed_tasks"][0].get("error_message"), "The @remote decorator for function test_functions.no_op_fail has 0 return values, but test_functions.no_op_fail returned more than 0 values.") + + services.cleanup() + + def testTypeChecking(self): + worker_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_worker.py") + services.start_ray_local(num_workers=1, worker_path=worker_path, driver_mode=ray.WORKER_MODE) + + # Make sure that these functions throw exceptions because there return + # values do not type check. + test_functions.test_return1() + test_functions.test_return2() + time.sleep(0.2) + task_info = ray.task_info() + self.assertEqual(len(task_info["failed_tasks"]), 2) + self.assertEqual(len(task_info["running_tasks"]), 0) + self.assertEqual(task_info["num_succeeded"], 0) + + services.cleanup() + class TaskStatusTest(unittest.TestCase): def testFailedTask(self): worker_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_worker.py") diff --git a/test/test_functions.py b/test/test_functions.py index dfe8a6b2a..573bfb5bf 100644 --- a/test/test_functions.py +++ b/test/test_functions.py @@ -87,6 +87,7 @@ def throw_exception_fct3(x): raise Exception("Test function 3 intentionally failed.") # test Python mode + @ray.remote([], [np.ndarray]) def python_mode_f(): return np.array([0, 0]) @@ -95,3 +96,23 @@ def python_mode_f(): def python_mode_g(x): x[0] = 1 return x + +# test no return values + +@ray.remote([], []) +def no_op(): + pass + +@ray.remote([], []) +def no_op_fail(): + return 0 + +# test wrong return types + +@ray.remote([], [int]) +def test_return1(): + return 0.0 + +@ray.remote([], [int, float]) +def test_return2(): + return 2.0, 3.0