Fix pickling (#289)

This commit is contained in:
mehrdadn
2016-07-24 12:53:55 -07:00
committed by Philipp Moritz
parent a22b35a881
commit 5ff00e0e81
2 changed files with 95 additions and 23 deletions
+68 -9
View File
@@ -1,13 +1,72 @@
import cloudpickle
# Note: some of the code below is copied from cloudpickle's pickler and slightly modified, because its behavior could not be changed any other way.
def serialize(function):
    """Return the result of ``cloudpickle.dumps`` applied to ``function``."""
    payload = cloudpickle.dumps(function)
    return payload
import sys
import typing
from ctypes import c_void_p
from cloudpickle import pickle, cloudpickle, CloudPickler, load, loads
def deserialize(serialized_function):
    """Return the object obtained by ``cloudpickle.loads`` on ``serialized_function``."""
    restored = cloudpickle.loads(serialized_function)
    return restored
# Probe for the CPython C-API handle exposed by ctypes. ``pythonapi`` is left
# as None when it cannot be imported or does not provide ``PyCell_Set``; the
# rest of the module checks for None before using it.
try:
    from ctypes import pythonapi
    pythonapi.PyCell_Set  # Make sure this exists
except (ImportError, AttributeError):
    # ImportError: ctypes has no ``pythonapi`` here.
    # AttributeError: the ``PyCell_Set`` symbol could not be resolved.
    # A bare ``except:`` would also hide KeyboardInterrupt/SystemExit.
    pythonapi = None
def dumps(func, arg_types, return_types):
    """Pickle ``func`` with its argument and return types as a single tuple.

    The three values are packed as ``(func, arg_types, return_types)`` and
    handed to ``cloudpickle.dumps``; its result is returned unchanged.
    """
    bundle = (func, arg_types, return_types)
    return cloudpickle.dumps(bundle)
def dump(obj, file, protocol=2):
    """Pickle ``obj`` into ``file`` with a ``BetterPickler``.

    ``protocol`` is passed to the pickler constructor (default 2). Returns
    whatever the pickler's ``dump`` method returns.
    """
    pickler = BetterPickler(file, protocol)
    return pickler.dump(obj)
def loads(function):
    """Return the object obtained by ``cloudpickle.loads`` on ``function``."""
    restored = cloudpickle.loads(function)
    return restored
def dumps(obj):
    """Pickle ``obj`` via ``dump`` into an in-memory buffer and return its contents."""
    sink = cloudpickle.StringIO()
    dump(obj, sink)
    serialized = sink.getvalue()
    return serialized
def _make_skel_func(code, closure, base_globals = None):
    """ Creates a skeleton function object that contains just the provided
        code and the correct number of cells in func_closure.  All other
        func attributes (e.g. func_globals) are empty.

        code: the code object the new function will execute.
        closure: an iterable of cell objects, or None when the code has no
            free variables.
        base_globals: dict used as the function's globals; a fresh dict is
            created when None. The dict is modified in place: its
            '__builtins__' entry is (re)set.

        Returns the newly created function object.
    """
    if base_globals is None:
        base_globals = {}
    base_globals['__builtins__'] = __builtins__

    # The function constructor takes None for "no closure"; calling
    # tuple(None) would raise TypeError before we ever reach it.
    cells = tuple(closure) if closure is not None else None

    # The class of any plain function is the function type itself.
    function_type = _make_skel_func.__class__
    return function_type(code, base_globals, None, None, cells)
def _fill_function(func, globals, defaults, closure, dict):
    """ Fills in the rest of function data into the skeleton function object
        that were created via _make_skel_func(), including closures.

        The globals, defaults and dict are applied by delegating to
        cloudpickle._fill_function. When ``pythonapi`` is available, each
        value in ``closure`` is then stored into the matching cell of the
        resulting function's ``__closure__`` through ``PyCell_Set``.
        Returns the filled function.
    """
    filled = cloudpickle._fill_function(func, globals, defaults, dict)
    if pythonapi is None:
        # No C-API access: the cells cannot be rewritten here.
        return filled
    for index, value in enumerate(closure):
        cell = filled.__closure__[index]
        # PyCell_Set takes raw object pointers; id() supplies the addresses.
        pythonapi.PyCell_Set(c_void_p(id(cell)), c_void_p(id(value)))
    return filled
def _create_type(type_repr):
    """Rebuild a type object from its ``repr`` string.

    Every "~" is stripped from ``type_repr`` and the remaining text is
    evaluated with the ``typing`` module's namespace as locals (the name
    ``typing`` itself is added if that namespace lacks it).

    NOTE: this calls ``eval`` on the given string, so it must only be used
    on trusted input.
    """
    namespace = dict(typing.__dict__)
    namespace.setdefault("typing", typing)
    expression = type_repr.replace("~", "")
    return eval(expression, None, namespace)
class BetterPickler(CloudPickler):
    """CloudPickler subclass with custom handling for functions, closure
    cells and ``typing.GenericMeta`` types.
    """

    def save_function_tuple(self, func):
        """Pickle ``func`` as a skeleton function that is filled in afterwards.

        The emitted stream first rebuilds a skeleton via a ``_make_skel_func``
        call, memoizes it, and then applies ``_fill_function`` with the
        function's globals, defaults, closure values and ``__dict__``.

        When ``pythonapi`` is available, the skeleton is created with
        placeholder cells (each holding None) and the real closure values are
        only passed to ``_fill_function``.
        NOTE(review): presumably this lets closure values that refer back to
        the function hit the memo entry - confirm.
        """
        code, f_globals, defaults, closure, dct, base_globals = self.extract_func_data(func)

        self.save(_fill_function)
        self.write(pickle.MARK)

        use_placeholders = pythonapi is not None
        skel_factory = _make_skel_func if use_placeholders else cloudpickle._make_skel_func
        if closure and use_placeholders:
            # Build a real list: ``map`` is lazy on Python 3, which would make
            # the pickler try to save a map object instead of the cells.
            skel_closure = [cloudpickle._make_cell(None) for _ in closure]
        else:
            skel_closure = closure

        self.save(skel_factory)
        self.save((code, skel_closure, base_globals))
        self.write(pickle.REDUCE)
        # Memoize the skeleton before saving the remaining function data.
        self.memoize(func)

        self.save(f_globals)
        self.save(defaults)
        self.save(closure)
        self.save(dct)
        self.write(pickle.TUPLE)
        self.write(pickle.REDUCE)

    def save_cell(self, obj):
        """Pickle a closure cell as a ``cloudpickle._make_cell(contents)`` call."""
        self.save(cloudpickle._make_cell)
        self.save((obj.cell_contents,))
        self.write(pickle.REDUCE)

    def save_type(self, obj):
        """Pickle a type as a ``_create_type(repr(obj))`` call."""
        self.save(_create_type)
        self.save((repr(obj),))
        self.write(pickle.REDUCE)

    # Start from a copy of the parent's table and register the handlers above.
    dispatch = CloudPickler.dispatch.copy()
    # The class of a closure cell, obtained from a throwaway closure.
    dispatch[(lambda _: lambda: _)(0).__closure__[0].__class__] = save_cell
    dispatch[typing.GenericMeta] = save_type
+27 -14
View File
@@ -669,8 +669,8 @@ def main_loop(worker=global_worker):
worker.register_function(remote(arg_types, return_types, worker)(function))
if reusable_variable is not None:
name, initializer_str, reinitializer_str = reusable_variable
initializer = pickling.deserialize(initializer_str)
reinitializer = pickling.deserialize(reinitializer_str)
initializer = pickling.loads(initializer_str)
reinitializer = pickling.loads(reinitializer_str)
reusables.__setattr__(name, Reusable(initializer, reinitializer))
if task is not None:
process_task(task)
@@ -710,7 +710,7 @@ def _export_reusable_variable(name, reusable, worker=global_worker):
"""
if _mode(worker) not in [ray.SHELL_MODE, ray.SCRIPT_MODE]:
raise Exception("_export_reusable_variable can only be called on a driver.")
ray.lib.export_reusable_variable(worker.handle, name, pickling.serialize(reusable.initializer), pickling.serialize(reusable.reinitializer))
ray.lib.export_reusable_variable(worker.handle, name, pickling.dumps(reusable.initializer), pickling.dumps(reusable.reinitializer))
def remote(arg_types, return_types, worker=global_worker):
"""This decorator is used to create remote functions.
@@ -720,16 +720,6 @@ def remote(arg_types, return_types, worker=global_worker):
return_types (List[type]): List of Python types of the return values.
"""
def remote_decorator(func):
to_export = pickling.dumps(func, arg_types, return_types) if worker.mode in [ray.SHELL_MODE, ray.SCRIPT_MODE] else None
def func_executor(arguments):
    """Call the wrapped function with ``arguments``, validate and return its result.

    Logs a message before the call and another afterwards with the elapsed
    wall-clock time. The result is passed to ``check_return_values`` before
    it is returned.
    """
    logging.info("Calling function {}".format(func.__name__))
    began = time.time()
    outcome = func(*arguments)
    finished = time.time()
    check_return_values(func_call, outcome) # throws an exception if result is invalid
    elapsed = finished - began
    logging.info("Finished executing function {}, it took {} seconds".format(func.__name__, elapsed))
    return outcome
def func_call(*args, **kwargs):
"""This gets run immediately when a worker calls a remote function."""
args = list(args)
@@ -745,6 +735,15 @@ def remote(arg_types, return_types, worker=global_worker):
return objrefs[0]
elif len(objrefs) > 1:
return objrefs
def func_executor(arguments):
    """Call the wrapped function with ``arguments``, validate and return its result.

    Logs a message before the call and another afterwards with the elapsed
    wall-clock time. The result is passed to ``check_return_values`` before
    it is returned.
    """
    logging.info("Calling function {}".format(func.__name__))
    began = time.time()
    outcome = func(*arguments)
    finished = time.time()
    check_return_values(func_call, outcome) # throws an exception if result is invalid
    elapsed = finished - began
    logging.info("Finished executing function {}, it took {} seconds".format(func.__name__, elapsed))
    return outcome
func_call.executor = func_executor
func_call.arg_types = arg_types
func_call.return_types = return_types
@@ -758,7 +757,21 @@ def remote(arg_types, return_types, worker=global_worker):
func_call.has_vararg_param = has_vararg_param
has_kwargs_param = any([v.kind == v.VAR_KEYWORD for k, v in sig_params])
check_signature_supported(has_kwargs_param, has_vararg_param, keyword_defaults, func_name)
if to_export is not None:
# Everything ready - export the function
to_export = None
if worker.mode in [ray.SHELL_MODE, ray.SCRIPT_MODE]:
func_name_global_valid = func.__name__ in func.__globals__
func_name_global_value = func.__globals__.get(func.__name__)
# Set the function globally to make it refer to itself
func.__globals__[func.__name__] = func_call # Allow the function to reference itself as a global variable
try:
to_export = pickling.dumps((func, arg_types, return_types))
finally:
# Undo our changes
if func_name_global_valid: func.__globals__[func.__name__] = func_name_global_value
else: del func.__globals__[func.__name__]
if to_export:
ray.lib.export_function(worker.handle, to_export)
return func_call
return remote_decorator