From c802e51d36454828d07e7bfa8a2aad788a005e07 Mon Sep 17 00:00:00 2001 From: Robert Nishihara Date: Thu, 13 Apr 2017 01:47:33 -0700 Subject: [PATCH] Re-enable recursive remote functions in a limited form. (#453) * Re-enable recursive remote functions in a limited form. * Fix linting. --- .travis.yml | 1 + python/ray/worker.py | 36 ++++++++++++++++++++++++++---------- test/recursion_test.py | 26 ++++++++++++++++++++++++++ 3 files changed, 53 insertions(+), 10 deletions(-) create mode 100644 test/recursion_test.py diff --git a/.travis.yml b/.travis.yml index de3ad1090..89e5ac493 100644 --- a/.travis.yml +++ b/.travis.yml @@ -92,3 +92,4 @@ script: - python test/stress_tests.py - python test/component_failures_test.py - python test/multi_node_test.py + - python test/recursion_test.py diff --git a/python/ray/worker.py b/python/ray/worker.py index 33f3dc9a0..8cb114b0a 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -1508,9 +1508,10 @@ def connect(info, object_id_seed=None, mode=WORKER_MODE, worker=global_worker, env.__setattr__(name, environment_variable) # Export cached remote functions to the workers. for info in worker.cached_remote_functions: - function_id, func_name, func, num_return_vals, num_cpus, num_gpus = info - export_remote_function(function_id, func_name, func, num_return_vals, - num_cpus, num_gpus, worker) + (function_id, func_name, func, + func_invoker, num_return_vals, num_cpus, num_gpus) = info + export_remote_function(function_id, func_name, func, func_invoker, + num_return_vals, num_cpus, num_gpus, worker) worker.cached_functions_to_run = None worker.cached_remote_functions = None env._cached_environment_variables = None @@ -1984,8 +1985,9 @@ def _export_environment_variable(name, environment_variable, worker.redis_client.rpush("Exports", key) -def export_remote_function(function_id, func_name, func, num_return_vals, - num_cpus, num_gpus, worker=global_worker): +def export_remote_function(function_id, func_name, func, func_invoker, + num_return_vals, num_cpus, num_gpus, + worker=global_worker): check_main_thread() if _mode(worker) not in [SCRIPT_MODE, SILENT_MODE]: raise Exception("export_remote_function can only be called on a driver.") @@ -1993,7 +1995,21 @@ def export_remote_function(function_id, func_name, func, num_return_vals, worker.function_properties[worker.task_driver_id.id()][function_id.id()] = ( num_return_vals, num_cpus, num_gpus) key = "RemoteFunction:{}".format(function_id.id()) - pickled_func = pickling.dumps(func) + + # Work around limitations of Python pickling. + func_name_global_valid = func.__name__ in func.__globals__ + func_name_global_value = func.__globals__.get(func.__name__) + # Allow the function to reference itself as a global variable + func.__globals__[func.__name__] = func_invoker + try: + pickled_func = pickling.dumps(func) + finally: + # Undo our changes + if func_name_global_valid: + func.__globals__[func.__name__] = func_name_global_value + else: + del func.__globals__[func.__name__] + worker.redis_client.hmset(key, {"driver_id": worker.task_driver_id.id(), "function_id": function_id.id(), "name": func_name, @@ -2111,12 +2127,12 @@ def remote(*args, **kwargs): # Everything ready - export the function if worker.mode in [SCRIPT_MODE, SILENT_MODE]: - export_remote_function(function_id, func_name, func, num_return_vals, - num_cpus, num_gpus) + export_remote_function(function_id, func_name, func, func_invoker, + num_return_vals, num_cpus, num_gpus) elif worker.mode is None: worker.cached_remote_functions.append((function_id, func_name, func, - num_return_vals, num_cpus, - num_gpus)) + func_invoker, num_return_vals, + num_cpus, num_gpus)) return func_invoker return remote_decorator diff --git a/test/recursion_test.py b/test/recursion_test.py new file mode 100644 index 000000000..247781907 --- /dev/null +++ b/test/recursion_test.py @@ -0,0 +1,26 @@ +# This test is not inside of runtest.py because when a recursive remote +# function is defined inside of another function, we currently can't handle +# that. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import ray + +ray.init() + + +@ray.remote +def factorial(n): + if n == 0: + return 1 + return n * ray.get(factorial.remote(n - 1)) + + +assert ray.get(factorial.remote(0)) == 1 +assert ray.get(factorial.remote(1)) == 1 +assert ray.get(factorial.remote(2)) == 2 +assert ray.get(factorial.remote(3)) == 6 +assert ray.get(factorial.remote(4)) == 24 +assert ray.get(factorial.remote(5)) == 120