diff --git a/python/ray/function_manager.py b/python/ray/function_manager.py index 4220feb82..ad20f0737 100644 --- a/python/ray/function_manager.py +++ b/python/ray/function_manager.py @@ -22,7 +22,6 @@ from ray import ray_constants from ray import cloudpickle as pickle from ray.utils import ( binary_to_hex, - is_cython, is_function_or_method, is_class_method, check_oversized_pickle, @@ -355,23 +354,8 @@ class FunctionActorManager(object): """ if self._worker.load_code_from_local: return - # Work around limitations of Python pickling. function = remote_function._function - function_name_global_valid = function.__name__ in function.__globals__ - function_name_global_value = function.__globals__.get( - function.__name__) - # Allow the function to reference itself as a global variable - if not is_cython(function): - function.__globals__[function.__name__] = remote_function - try: - pickled_function = pickle.dumps(function) - finally: - # Undo our changes - if function_name_global_valid: - function.__globals__[function.__name__] = ( - function_name_global_value) - else: - del function.__globals__[function.__name__] + pickled_function = pickle.dumps(function) check_oversized_pickle(pickled_function, remote_function._function_name, diff --git a/python/ray/tests/test_basic.py b/python/ray/tests/test_basic.py index 7c1488120..92eb2d640 100644 --- a/python/ray/tests/test_basic.py +++ b/python/ray/tests/test_basic.py @@ -319,6 +319,38 @@ def test_nested_functions(ray_start_regular): assert ray.get(f.remote()) == (1, 2) + # Test a remote function that recursively calls itself. + + @ray.remote + def factorial(n): + if n == 0: + return 1 + return n * ray.get(factorial.remote(n - 1)) + + assert ray.get(factorial.remote(0)) == 1 + assert ray.get(factorial.remote(1)) == 1 + assert ray.get(factorial.remote(2)) == 2 + assert ray.get(factorial.remote(3)) == 6 + assert ray.get(factorial.remote(4)) == 24 + assert ray.get(factorial.remote(5)) == 120 + + # Test remote functions that recursively call each other. + + @ray.remote + def factorial_even(n): + assert n % 2 == 0 + if n == 0: + return 1 + return n * ray.get(factorial_odd.remote(n - 1)) + + @ray.remote + def factorial_odd(n): + assert n % 2 == 1 + return n * ray.get(factorial_even.remote(n - 1)) + + assert ray.get(factorial_even.remote(4)) == 24 + assert ray.get(factorial_odd.remote(5)) == 120 + def test_ray_recursive_objects(ray_start_regular): class ClassA(object): diff --git a/python/ray/tests/test_recursion.py b/python/ray/tests/test_recursion.py deleted file mode 100644 index 56b28f5e6..000000000 --- a/python/ray/tests/test_recursion.py +++ /dev/null @@ -1,25 +0,0 @@ -# This test is not inside of test_basic.py because when a recursive remote -# function is defined inside of another function, we currently can't handle -# that. - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import ray - - -@ray.remote -def factorial(n): - if n == 0: - return 1 - return n * ray.get(factorial.remote(n - 1)) - - -def test_recursion(ray_start_regular): - assert ray.get(factorial.remote(0)) == 1 - assert ray.get(factorial.remote(1)) == 1 - assert ray.get(factorial.remote(2)) == 2 - assert ray.get(factorial.remote(3)) == 6 - assert ray.get(factorial.remote(4)) == 24 - assert ray.get(factorial.remote(5)) == 120