mirror of
https://github.com/wassname/ray.git
synced 2026-06-30 09:05:47 +08:00
Re-enable recursive remote functions in a limited form. (#453)
* Re-enable recursive remote functions in a limited form. * Fix linting.
This commit is contained in:
committed by
Philipp Moritz
parent
dad57e3b62
commit
c802e51d36
@@ -92,3 +92,4 @@ script:
|
||||
- python test/stress_tests.py
|
||||
- python test/component_failures_test.py
|
||||
- python test/multi_node_test.py
|
||||
- python test/recursion_test.py
|
||||
|
||||
+26
-10
@@ -1508,9 +1508,10 @@ def connect(info, object_id_seed=None, mode=WORKER_MODE, worker=global_worker,
|
||||
env.__setattr__(name, environment_variable)
|
||||
# Export cached remote functions to the workers.
|
||||
for info in worker.cached_remote_functions:
|
||||
function_id, func_name, func, num_return_vals, num_cpus, num_gpus = info
|
||||
export_remote_function(function_id, func_name, func, num_return_vals,
|
||||
num_cpus, num_gpus, worker)
|
||||
(function_id, func_name, func,
|
||||
func_invoker, num_return_vals, num_cpus, num_gpus) = info
|
||||
export_remote_function(function_id, func_name, func, func_invoker,
|
||||
num_return_vals, num_cpus, num_gpus, worker)
|
||||
worker.cached_functions_to_run = None
|
||||
worker.cached_remote_functions = None
|
||||
env._cached_environment_variables = None
|
||||
@@ -1984,8 +1985,9 @@ def _export_environment_variable(name, environment_variable,
|
||||
worker.redis_client.rpush("Exports", key)
|
||||
|
||||
|
||||
def export_remote_function(function_id, func_name, func, num_return_vals,
|
||||
num_cpus, num_gpus, worker=global_worker):
|
||||
def export_remote_function(function_id, func_name, func, func_invoker,
|
||||
num_return_vals, num_cpus, num_gpus,
|
||||
worker=global_worker):
|
||||
check_main_thread()
|
||||
if _mode(worker) not in [SCRIPT_MODE, SILENT_MODE]:
|
||||
raise Exception("export_remote_function can only be called on a driver.")
|
||||
@@ -1993,7 +1995,21 @@ def export_remote_function(function_id, func_name, func, num_return_vals,
|
||||
worker.function_properties[worker.task_driver_id.id()][function_id.id()] = (
|
||||
num_return_vals, num_cpus, num_gpus)
|
||||
key = "RemoteFunction:{}".format(function_id.id())
|
||||
pickled_func = pickling.dumps(func)
|
||||
|
||||
# Work around limitations of Python pickling.
|
||||
func_name_global_valid = func.__name__ in func.__globals__
|
||||
func_name_global_value = func.__globals__.get(func.__name__)
|
||||
# Allow the function to reference itself as a global variable
|
||||
func.__globals__[func.__name__] = func_invoker
|
||||
try:
|
||||
pickled_func = pickling.dumps(func)
|
||||
finally:
|
||||
# Undo our changes
|
||||
if func_name_global_valid:
|
||||
func.__globals__[func.__name__] = func_name_global_value
|
||||
else:
|
||||
del func.__globals__[func.__name__]
|
||||
|
||||
worker.redis_client.hmset(key, {"driver_id": worker.task_driver_id.id(),
|
||||
"function_id": function_id.id(),
|
||||
"name": func_name,
|
||||
@@ -2111,12 +2127,12 @@ def remote(*args, **kwargs):
|
||||
|
||||
# Everything ready - export the function
|
||||
if worker.mode in [SCRIPT_MODE, SILENT_MODE]:
|
||||
export_remote_function(function_id, func_name, func, num_return_vals,
|
||||
num_cpus, num_gpus)
|
||||
export_remote_function(function_id, func_name, func, func_invoker,
|
||||
num_return_vals, num_cpus, num_gpus)
|
||||
elif worker.mode is None:
|
||||
worker.cached_remote_functions.append((function_id, func_name, func,
|
||||
num_return_vals, num_cpus,
|
||||
num_gpus))
|
||||
func_invoker, num_return_vals,
|
||||
num_cpus, num_gpus))
|
||||
return func_invoker
|
||||
|
||||
return remote_decorator
|
||||
|
||||
@@ -0,0 +1,26 @@
|
||||
# This test is not inside of runtest.py because when a recursive remote
|
||||
# function is defined inside of another function, we currently can't handle
|
||||
# that.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import ray
|
||||
|
||||
ray.init()
|
||||
|
||||
|
||||
@ray.remote
|
||||
def factorial(n):
|
||||
if n == 0:
|
||||
return 1
|
||||
return n * ray.get(factorial.remote(n - 1))
|
||||
|
||||
|
||||
assert ray.get(factorial.remote(0)) == 1
|
||||
assert ray.get(factorial.remote(1)) == 1
|
||||
assert ray.get(factorial.remote(2)) == 2
|
||||
assert ray.get(factorial.remote(3)) == 6
|
||||
assert ray.get(factorial.remote(4)) == 24
|
||||
assert ray.get(factorial.remote(5)) == 120
|
||||
Reference in New Issue
Block a user