From 5ff00e0e81f563b9b45fccd6b50902b062107ba3 Mon Sep 17 00:00:00 2001 From: mehrdadn Date: Sun, 24 Jul 2016 12:53:55 -0700 Subject: [PATCH] Fix pickling (#289) --- lib/python/ray/pickling.py | 77 +++++++++++++++++++++++++++++++++----- lib/python/ray/worker.py | 41 +++++++++++++------- 2 files changed, 95 insertions(+), 23 deletions(-) diff --git a/lib/python/ray/pickling.py b/lib/python/ray/pickling.py index 1d3c27afd..4fdf05a33 100644 --- a/lib/python/ray/pickling.py +++ b/lib/python/ray/pickling.py @@ -1,13 +1,72 @@ -import cloudpickle +# Note that a little bit of code here is taken and slightly modified from the pickler because it was not possible to change its behavior otherwise. -def serialize(function): - return cloudpickle.dumps(function) +import sys +import typing +from ctypes import c_void_p +from cloudpickle import pickle, cloudpickle, CloudPickler, load, loads -def deserialize(serialized_function): - return cloudpickle.loads(serialized_function) +try: + from ctypes import pythonapi + pythonapi.PyCell_Set # Make sure this exists +except: + pythonapi = None -def dumps(func, arg_types, return_types): - return cloudpickle.dumps((func, arg_types, return_types)) +def dump(obj, file, protocol=2): + return BetterPickler(file, protocol).dump(obj) -def loads(function): - return cloudpickle.loads(function) +def dumps(obj): + stringio = cloudpickle.StringIO() + dump(obj, stringio) + return stringio.getvalue() + +def _make_skel_func(code, closure, base_globals = None): + """ Creates a skeleton function object that contains just the provided + code and the correct number of cells in func_closure. All other + func attributes (e.g. func_globals) are empty. + """ + if base_globals is None: base_globals = {} + base_globals['__builtins__'] = __builtins__ + return _make_skel_func.__class__(code, base_globals, None, None, tuple(closure)) + +def _fill_function(func, globals, defaults, closure, dict): + """ Fills in the rest of function data into the skeleton function object + that were created via _make_skel_func(), including closures. + """ + result = cloudpickle._fill_function(func, globals, defaults, dict) + if pythonapi is not None: + for i, v in enumerate(closure): + pythonapi.PyCell_Set(c_void_p(id(result.__closure__[i])), c_void_p(id(v))) + return result + +def _create_type(type_repr): + return eval(type_repr.replace("~", ""), None, (lambda d: d.setdefault("typing", typing) and None or d)(dict(typing.__dict__))) + +class BetterPickler(CloudPickler): + def save_function_tuple(self, func): + code, f_globals, defaults, closure, dct, base_globals = self.extract_func_data(func) + + self.save(_fill_function) + self.write(pickle.MARK) + + self.save(_make_skel_func if pythonapi else cloudpickle._make_skel_func) + self.save((code, map(lambda _: cloudpickle._make_cell(None), closure) if closure and pythonapi is not None else closure, base_globals)) + self.write(pickle.REDUCE) + self.memoize(func) + + self.save(f_globals) + self.save(defaults) + self.save(closure) + self.save(dct) + self.write(pickle.TUPLE) + self.write(pickle.REDUCE) + def save_cell(self, obj): + self.save(cloudpickle._make_cell) + self.save((obj.cell_contents,)) + self.write(pickle.REDUCE) + def save_type(self, obj): + self.save(_create_type) + self.save((repr(obj),)) + self.write(pickle.REDUCE) + dispatch = CloudPickler.dispatch.copy() + dispatch[(lambda _: lambda: _)(0).__closure__[0].__class__] = save_cell + dispatch[typing.GenericMeta] = save_type diff --git a/lib/python/ray/worker.py b/lib/python/ray/worker.py index d5e676e0f..19fd68338 100644 --- a/lib/python/ray/worker.py +++ b/lib/python/ray/worker.py @@ -669,8 +669,8 @@ def main_loop(worker=global_worker): worker.register_function(remote(arg_types, return_types, worker)(function)) if reusable_variable is not None: name, initializer_str, reinitializer_str = reusable_variable - initializer = pickling.deserialize(initializer_str) - reinitializer = pickling.deserialize(reinitializer_str) + initializer = pickling.loads(initializer_str) + reinitializer = pickling.loads(reinitializer_str) reusables.__setattr__(name, Reusable(initializer, reinitializer)) if task is not None: process_task(task) @@ -710,7 +710,7 @@ def _export_reusable_variable(name, reusable, worker=global_worker): """ if _mode(worker) not in [ray.SHELL_MODE, ray.SCRIPT_MODE]: raise Exception("_export_reusable_variable can only be called on a driver.") - ray.lib.export_reusable_variable(worker.handle, name, pickling.serialize(reusable.initializer), pickling.serialize(reusable.reinitializer)) + ray.lib.export_reusable_variable(worker.handle, name, pickling.dumps(reusable.initializer), pickling.dumps(reusable.reinitializer)) def remote(arg_types, return_types, worker=global_worker): """This decorator is used to create remote functions. @@ -720,16 +720,6 @@ def remote(arg_types, return_types, worker=global_worker): return_types (List[type]): List of Python types of the return values. """ def remote_decorator(func): - to_export = pickling.dumps(func, arg_types, return_types) if worker.mode in [ray.SHELL_MODE, ray.SCRIPT_MODE] else None - def func_executor(arguments): - """This gets run when the remote function is executed.""" - logging.info("Calling function {}".format(func.__name__)) - start_time = time.time() - result = func(*arguments) - end_time = time.time() - check_return_values(func_call, result) # throws an exception if result is invalid - logging.info("Finished executing function {}, it took {} seconds".format(func.__name__, end_time - start_time)) - return result def func_call(*args, **kwargs): """This gets run immediately when a worker calls a remote function.""" args = list(args) @@ -745,6 +735,15 @@ def remote(arg_types, return_types, worker=global_worker): return objrefs[0] elif len(objrefs) > 1: return objrefs + def func_executor(arguments): + """This gets run when the remote function is executed.""" + logging.info("Calling function {}".format(func.__name__)) + start_time = time.time() + result = func(*arguments) + end_time = time.time() + check_return_values(func_call, result) # throws an exception if result is invalid + logging.info("Finished executing function {}, it took {} seconds".format(func.__name__, end_time - start_time)) + return result func_call.executor = func_executor func_call.arg_types = arg_types func_call.return_types = return_types @@ -758,7 +757,21 @@ def remote(arg_types, return_types, worker=global_worker): func_call.has_vararg_param = has_vararg_param has_kwargs_param = any([v.kind == v.VAR_KEYWORD for k, v in sig_params]) check_signature_supported(has_kwargs_param, has_vararg_param, keyword_defaults, func_name) - if to_export is not None: + + # Everything ready - export the function + to_export = None + if worker.mode in [ray.SHELL_MODE, ray.SCRIPT_MODE]: + func_name_global_valid = func.__name__ in func.__globals__ + func_name_global_value = func.__globals__.get(func.__name__) + # Set the function globally to make it refer to itself + func.__globals__[func.__name__] = func_call # Allow the function to reference itself as a global variable + try: + to_export = pickling.dumps((func, arg_types, return_types)) + finally: + # Undo our changes + if func_name_global_valid: func.__globals__[func.__name__] = func_name_global_value + else: del func.__globals__[func.__name__] + if to_export: ray.lib.export_function(worker.handle, to_export) return func_call return remote_decorator