Re-enable recursive remote functions in a limited form. (#453)

* Re-enable recursive remote functions in a limited form.

* Fix linting.
This commit is contained in:
Robert Nishihara
2017-04-13 01:47:33 -07:00
committed by Philipp Moritz
parent dad57e3b62
commit c802e51d36
3 changed files with 53 additions and 10 deletions
+26 -10
View File
@@ -1508,9 +1508,10 @@ def connect(info, object_id_seed=None, mode=WORKER_MODE, worker=global_worker,
env.__setattr__(name, environment_variable)
# Export cached remote functions to the workers.
for info in worker.cached_remote_functions:
function_id, func_name, func, num_return_vals, num_cpus, num_gpus = info
export_remote_function(function_id, func_name, func, num_return_vals,
num_cpus, num_gpus, worker)
(function_id, func_name, func,
func_invoker, num_return_vals, num_cpus, num_gpus) = info
export_remote_function(function_id, func_name, func, func_invoker,
num_return_vals, num_cpus, num_gpus, worker)
worker.cached_functions_to_run = None
worker.cached_remote_functions = None
env._cached_environment_variables = None
@@ -1984,8 +1985,9 @@ def _export_environment_variable(name, environment_variable,
worker.redis_client.rpush("Exports", key)
def export_remote_function(function_id, func_name, func, num_return_vals,
num_cpus, num_gpus, worker=global_worker):
def export_remote_function(function_id, func_name, func, func_invoker,
num_return_vals, num_cpus, num_gpus,
worker=global_worker):
check_main_thread()
if _mode(worker) not in [SCRIPT_MODE, SILENT_MODE]:
raise Exception("export_remote_function can only be called on a driver.")
@@ -1993,7 +1995,21 @@ def export_remote_function(function_id, func_name, func, num_return_vals,
worker.function_properties[worker.task_driver_id.id()][function_id.id()] = (
num_return_vals, num_cpus, num_gpus)
key = "RemoteFunction:{}".format(function_id.id())
pickled_func = pickling.dumps(func)
# Work around limitations of Python pickling.
func_name_global_valid = func.__name__ in func.__globals__
func_name_global_value = func.__globals__.get(func.__name__)
# Allow the function to reference itself as a global variable
func.__globals__[func.__name__] = func_invoker
try:
pickled_func = pickling.dumps(func)
finally:
# Undo our changes
if func_name_global_valid:
func.__globals__[func.__name__] = func_name_global_value
else:
del func.__globals__[func.__name__]
worker.redis_client.hmset(key, {"driver_id": worker.task_driver_id.id(),
"function_id": function_id.id(),
"name": func_name,
@@ -2111,12 +2127,12 @@ def remote(*args, **kwargs):
# Everything ready - export the function
if worker.mode in [SCRIPT_MODE, SILENT_MODE]:
export_remote_function(function_id, func_name, func, num_return_vals,
num_cpus, num_gpus)
export_remote_function(function_id, func_name, func, func_invoker,
num_return_vals, num_cpus, num_gpus)
elif worker.mode is None:
worker.cached_remote_functions.append((function_id, func_name, func,
num_return_vals, num_cpus,
num_gpus))
func_invoker, num_return_vals,
num_cpus, num_gpus))
return func_invoker
return remote_decorator