mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 11:01:06 +08:00
Fix pickling (#289)
This commit is contained in:
@@ -1,13 +1,72 @@
|
||||
import cloudpickle
|
||||
# Note: a small amount of the code here is copied, with slight modifications, from cloudpickle's pickler, because its behavior could not be changed any other way.
|
||||
|
||||
def serialize(function):
    """Return the result of ``cloudpickle.dumps`` applied to ``function``."""
    payload = cloudpickle.dumps(function)
    return payload
|
||||
import sys
|
||||
import typing
|
||||
from ctypes import c_void_p
|
||||
from cloudpickle import pickle, cloudpickle, CloudPickler, load, loads
|
||||
|
||||
def deserialize(serialized_function):
    """Return the object ``cloudpickle.loads`` rebuilds from ``serialized_function``."""
    restored = cloudpickle.loads(serialized_function)
    return restored
|
||||
# Probe for the CPython C-API entry point PyCell_Set.  When it cannot be
# obtained, ``pythonapi`` is set to None so the rest of the module can test
# for its availability.
try:
    from ctypes import pythonapi
    pythonapi.PyCell_Set  # Make sure this exists
except (ImportError, AttributeError):
    # ImportError: ctypes does not expose ``pythonapi`` on this interpreter.
    # AttributeError: the library has no ``PyCell_Set`` symbol.
    # Anything else (e.g. KeyboardInterrupt) is deliberately not swallowed.
    pythonapi = None
|
||||
|
||||
def dumps(func, arg_types, return_types):
    """Pickle the tuple ``(func, arg_types, return_types)`` with ``cloudpickle.dumps``."""
    bundle = (func, arg_types, return_types)
    return cloudpickle.dumps(bundle)
|
||||
def dump(obj, file, protocol=2):
    """Pickle ``obj`` into ``file`` using a ``BetterPickler``.

    ``file`` and ``protocol`` are passed straight to the ``BetterPickler``
    constructor; the value returned by its ``dump`` method is returned.
    """
    pickler = BetterPickler(file, protocol)
    return pickler.dump(obj)
|
||||
|
||||
def loads(function):
    """Return the object ``cloudpickle.loads`` rebuilds from ``function``."""
    restored = cloudpickle.loads(function)
    return restored
|
||||
def dumps(obj):
    """Pickle ``obj`` with ``dump`` and return the pickled data.

    The data is written to an in-memory ``cloudpickle.StringIO`` and the
    contents of that buffer are returned.
    """
    sink = cloudpickle.StringIO()
    dump(obj, sink)
    pickled = sink.getvalue()
    return pickled
|
||||
|
||||
def _make_skel_func(code, closure, base_globals=None):
    """Build a bare function object around ``code``.

    The new function uses ``base_globals`` (a fresh dict when None) as its
    globals and ``tuple(closure)`` as its closure cells; its name and
    defaults are passed as None.  ``base_globals`` is modified in place: its
    ``'__builtins__'`` entry is set to this module's ``__builtins__``.
    """
    namespace = {} if base_globals is None else base_globals
    namespace['__builtins__'] = __builtins__
    # The function type is obtained from an existing function object.
    function_type = type(_make_skel_func)
    cells = tuple(closure)
    return function_type(code, namespace, None, None, cells)
|
||||
|
||||
def _fill_function(func, globals, defaults, closure, dict):
    """Complete a skeleton function, including the contents of its cells.

    ``cloudpickle._fill_function`` is given ``func``, ``globals``,
    ``defaults`` and ``dict``.  Then, if ``pythonapi`` is available, the i-th
    cell of the resulting function's ``__closure__`` is set to the i-th item
    of ``closure`` through ``PyCell_Set``.  Returns what
    ``cloudpickle._fill_function`` returned.
    """
    filled = cloudpickle._fill_function(func, globals, defaults, dict)
    if pythonapi is None:
        return filled
    cells = filled.__closure__
    for position, value in enumerate(closure):
        # PyCell_Set is called with the raw object addresses (id()) of the
        # cell and of the value to store in it.
        cell_address = c_void_p(id(cells[position]))
        value_address = c_void_p(id(value))
        pythonapi.PyCell_Set(cell_address, value_address)
    return filled
|
||||
|
||||
def _create_type(type_repr):
    """Evaluate ``type_repr`` (with every ``~`` removed) and return the result.

    The expression is evaluated with a copy of the ``typing`` module's
    namespace as its local scope; the name ``typing`` itself is added to that
    scope when the namespace does not already define it.

    NOTE(review): this calls ``eval`` on its argument, so it must only ever
    receive trusted text.
    """
    expression = type_repr.replace("~", "")
    scope = dict(typing.__dict__)
    scope.setdefault("typing", typing)
    return eval(expression, None, scope)
|
||||
|
||||
class BetterPickler(CloudPickler):
    """CloudPickler with custom handling for functions, cells and typing generics.

    Overrides how function tuples are written and registers extra dispatch
    entries for closure cells and for ``typing.GenericMeta`` instances.
    """

    def save_function_tuple(self, func):
        """Write ``func`` as a two-stage reconstruction.

        The emitted stream calls, on load,
        ``_fill_function(skeleton, f_globals, defaults, closure, dct)`` where
        ``skeleton`` is produced by a skeleton-maker called with
        ``(code, cells, base_globals)``.  The order of the save/write calls
        below defines the pickle stream and must not be changed.
        """
        code, f_globals, defaults, closure, dct, base_globals = self.extract_func_data(func)

        # When PyCell_Set is usable, the skeleton is built by this module's
        # _make_skel_func with placeholder cells, and _fill_function later
        # stores the real closure values into those cells.
        patch_cells = pythonapi is not None

        self.save(_fill_function)
        self.write(pickle.MARK)

        self.save(_make_skel_func if patch_cells else cloudpickle._make_skel_func)
        if closure and patch_cells:
            # A concrete list, not map(): on Python 3 map() is a lazy
            # iterator and would be pickled as such instead of as the cells.
            skeleton_closure = [cloudpickle._make_cell(None) for _ in closure]
        else:
            skeleton_closure = closure
        self.save((code, skeleton_closure, base_globals))
        self.write(pickle.REDUCE)
        # Memoize the skeleton before saving the remaining state, so that
        # references to func met while saving that state can point back to it.
        self.memoize(func)

        self.save(f_globals)
        self.save(defaults)
        self.save(closure)
        self.save(dct)
        self.write(pickle.TUPLE)
        self.write(pickle.REDUCE)

    def save_cell(self, obj):
        """Write a closure cell as ``cloudpickle._make_cell(obj.cell_contents)``."""
        self.save(cloudpickle._make_cell)
        self.save((obj.cell_contents,))
        self.write(pickle.REDUCE)

    def save_type(self, obj):
        """Write ``obj`` as ``_create_type(repr(obj))``."""
        self.save(_create_type)
        self.save((repr(obj),))
        self.write(pickle.REDUCE)

    # Start from the parent's dispatch table so its handlers stay in effect.
    dispatch = CloudPickler.dispatch.copy()
    # The cell type is taken from the closure of a throwaway nested lambda.
    dispatch[(lambda _: lambda: _)(0).__closure__[0].__class__] = save_cell
    dispatch[typing.GenericMeta] = save_type
|
||||
|
||||
+27
-14
@@ -669,8 +669,8 @@ def main_loop(worker=global_worker):
|
||||
worker.register_function(remote(arg_types, return_types, worker)(function))
|
||||
if reusable_variable is not None:
|
||||
name, initializer_str, reinitializer_str = reusable_variable
|
||||
initializer = pickling.deserialize(initializer_str)
|
||||
reinitializer = pickling.deserialize(reinitializer_str)
|
||||
initializer = pickling.loads(initializer_str)
|
||||
reinitializer = pickling.loads(reinitializer_str)
|
||||
reusables.__setattr__(name, Reusable(initializer, reinitializer))
|
||||
if task is not None:
|
||||
process_task(task)
|
||||
@@ -710,7 +710,7 @@ def _export_reusable_variable(name, reusable, worker=global_worker):
|
||||
"""
|
||||
if _mode(worker) not in [ray.SHELL_MODE, ray.SCRIPT_MODE]:
|
||||
raise Exception("_export_reusable_variable can only be called on a driver.")
|
||||
ray.lib.export_reusable_variable(worker.handle, name, pickling.serialize(reusable.initializer), pickling.serialize(reusable.reinitializer))
|
||||
ray.lib.export_reusable_variable(worker.handle, name, pickling.dumps(reusable.initializer), pickling.dumps(reusable.reinitializer))
|
||||
|
||||
def remote(arg_types, return_types, worker=global_worker):
|
||||
"""This decorator is used to create remote functions.
|
||||
@@ -720,16 +720,6 @@ def remote(arg_types, return_types, worker=global_worker):
|
||||
return_types (List[type]): List of Python types of the return values.
|
||||
"""
|
||||
def remote_decorator(func):
|
||||
to_export = pickling.dumps(func, arg_types, return_types) if worker.mode in [ray.SHELL_MODE, ray.SCRIPT_MODE] else None
|
||||
def func_executor(arguments):
|
||||
"""This gets run when the remote function is executed."""
|
||||
logging.info("Calling function {}".format(func.__name__))
|
||||
start_time = time.time()
|
||||
result = func(*arguments)
|
||||
end_time = time.time()
|
||||
check_return_values(func_call, result) # throws an exception if result is invalid
|
||||
logging.info("Finished executing function {}, it took {} seconds".format(func.__name__, end_time - start_time))
|
||||
return result
|
||||
def func_call(*args, **kwargs):
|
||||
"""This gets run immediately when a worker calls a remote function."""
|
||||
args = list(args)
|
||||
@@ -745,6 +735,15 @@ def remote(arg_types, return_types, worker=global_worker):
|
||||
return objrefs[0]
|
||||
elif len(objrefs) > 1:
|
||||
return objrefs
|
||||
def func_executor(arguments):
|
||||
"""This gets run when the remote function is executed."""
|
||||
logging.info("Calling function {}".format(func.__name__))
|
||||
start_time = time.time()
|
||||
result = func(*arguments)
|
||||
end_time = time.time()
|
||||
check_return_values(func_call, result) # throws an exception if result is invalid
|
||||
logging.info("Finished executing function {}, it took {} seconds".format(func.__name__, end_time - start_time))
|
||||
return result
|
||||
func_call.executor = func_executor
|
||||
func_call.arg_types = arg_types
|
||||
func_call.return_types = return_types
|
||||
@@ -758,7 +757,21 @@ def remote(arg_types, return_types, worker=global_worker):
|
||||
func_call.has_vararg_param = has_vararg_param
|
||||
has_kwargs_param = any([v.kind == v.VAR_KEYWORD for k, v in sig_params])
|
||||
check_signature_supported(has_kwargs_param, has_vararg_param, keyword_defaults, func_name)
|
||||
if to_export is not None:
|
||||
|
||||
# Everything ready - export the function
|
||||
to_export = None
|
||||
if worker.mode in [ray.SHELL_MODE, ray.SCRIPT_MODE]:
|
||||
func_name_global_valid = func.__name__ in func.__globals__
|
||||
func_name_global_value = func.__globals__.get(func.__name__)
|
||||
# Set the function globally to make it refer to itself
|
||||
func.__globals__[func.__name__] = func_call # Allow the function to reference itself as a global variable
|
||||
try:
|
||||
to_export = pickling.dumps((func, arg_types, return_types))
|
||||
finally:
|
||||
# Undo our changes
|
||||
if func_name_global_valid: func.__globals__[func.__name__] = func_name_global_value
|
||||
else: del func.__globals__[func.__name__]
|
||||
if to_export:
|
||||
ray.lib.export_function(worker.handle, to_export)
|
||||
return func_call
|
||||
return remote_decorator
|
||||
|
||||
Reference in New Issue
Block a user