mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 00:29:38 +08:00
Re-enable recursive remote functions in a limited form. (#453)
* Re-enable recursive remote functions in a limited form. * Fix linting.
This commit is contained in:
committed by
Philipp Moritz
parent
dad57e3b62
commit
c802e51d36
+26
-10
@@ -1508,9 +1508,10 @@ def connect(info, object_id_seed=None, mode=WORKER_MODE, worker=global_worker,
|
||||
env.__setattr__(name, environment_variable)
|
||||
# Export cached remote functions to the workers.
|
||||
for info in worker.cached_remote_functions:
|
||||
function_id, func_name, func, num_return_vals, num_cpus, num_gpus = info
|
||||
export_remote_function(function_id, func_name, func, num_return_vals,
|
||||
num_cpus, num_gpus, worker)
|
||||
(function_id, func_name, func,
|
||||
func_invoker, num_return_vals, num_cpus, num_gpus) = info
|
||||
export_remote_function(function_id, func_name, func, func_invoker,
|
||||
num_return_vals, num_cpus, num_gpus, worker)
|
||||
worker.cached_functions_to_run = None
|
||||
worker.cached_remote_functions = None
|
||||
env._cached_environment_variables = None
|
||||
@@ -1984,8 +1985,9 @@ def _export_environment_variable(name, environment_variable,
|
||||
worker.redis_client.rpush("Exports", key)
|
||||
|
||||
|
||||
def export_remote_function(function_id, func_name, func, num_return_vals,
|
||||
num_cpus, num_gpus, worker=global_worker):
|
||||
def export_remote_function(function_id, func_name, func, func_invoker,
|
||||
num_return_vals, num_cpus, num_gpus,
|
||||
worker=global_worker):
|
||||
check_main_thread()
|
||||
if _mode(worker) not in [SCRIPT_MODE, SILENT_MODE]:
|
||||
raise Exception("export_remote_function can only be called on a driver.")
|
||||
@@ -1993,7 +1995,21 @@ def export_remote_function(function_id, func_name, func, num_return_vals,
|
||||
worker.function_properties[worker.task_driver_id.id()][function_id.id()] = (
|
||||
num_return_vals, num_cpus, num_gpus)
|
||||
key = "RemoteFunction:{}".format(function_id.id())
|
||||
pickled_func = pickling.dumps(func)
|
||||
|
||||
# Work around limitations of Python pickling.
|
||||
func_name_global_valid = func.__name__ in func.__globals__
|
||||
func_name_global_value = func.__globals__.get(func.__name__)
|
||||
# Allow the function to reference itself as a global variable
|
||||
func.__globals__[func.__name__] = func_invoker
|
||||
try:
|
||||
pickled_func = pickling.dumps(func)
|
||||
finally:
|
||||
# Undo our changes
|
||||
if func_name_global_valid:
|
||||
func.__globals__[func.__name__] = func_name_global_value
|
||||
else:
|
||||
del func.__globals__[func.__name__]
|
||||
|
||||
worker.redis_client.hmset(key, {"driver_id": worker.task_driver_id.id(),
|
||||
"function_id": function_id.id(),
|
||||
"name": func_name,
|
||||
@@ -2111,12 +2127,12 @@ def remote(*args, **kwargs):
|
||||
|
||||
# Everything ready - export the function
|
||||
if worker.mode in [SCRIPT_MODE, SILENT_MODE]:
|
||||
export_remote_function(function_id, func_name, func, num_return_vals,
|
||||
num_cpus, num_gpus)
|
||||
export_remote_function(function_id, func_name, func, func_invoker,
|
||||
num_return_vals, num_cpus, num_gpus)
|
||||
elif worker.mode is None:
|
||||
worker.cached_remote_functions.append((function_id, func_name, func,
|
||||
num_return_vals, num_cpus,
|
||||
num_gpus))
|
||||
func_invoker, num_return_vals,
|
||||
num_cpus, num_gpus))
|
||||
return func_invoker
|
||||
|
||||
return remote_decorator
|
||||
|
||||
Reference in New Issue
Block a user