diff --git a/python/ray/cloudpickle/__init__.py b/python/ray/cloudpickle/__init__.py index 57a8a0b87..536199f2d 100644 --- a/python/ray/cloudpickle/__init__.py +++ b/python/ray/cloudpickle/__init__.py @@ -2,4 +2,4 @@ from __future__ import absolute_import from ray.cloudpickle.cloudpickle import * -__version__ = '0.8.0.dev0' +__version__ = '1.2.2.dev0' diff --git a/python/ray/cloudpickle/cloudpickle.py b/python/ray/cloudpickle/cloudpickle.py index 54d745cbc..c92b2eac4 100644 --- a/python/ray/cloudpickle/cloudpickle.py +++ b/python/ray/cloudpickle/cloudpickle.py @@ -44,24 +44,45 @@ from __future__ import print_function import dis from functools import partial -import importlib import io import itertools import logging import opcode import operator import pickle +import platform import struct import sys import traceback import types import weakref +import uuid +import threading + + +try: + from enum import Enum +except ImportError: + Enum = None # cloudpickle is meant for inter process communication: we expect all # communicating processes to run the same Python version hence we favor # communication speed over compatibility: DEFAULT_PROTOCOL = pickle.HIGHEST_PROTOCOL +# Track the provenance of reconstructed dynamic classes to make it possible to +# recontruct instances from the matching singleton class definition when +# appropriate and preserve the usual "isinstance" semantics of Python objects. +_DYNAMIC_CLASS_TRACKER_BY_CLASS = weakref.WeakKeyDictionary() +_DYNAMIC_CLASS_TRACKER_BY_ID = weakref.WeakValueDictionary() +_DYNAMIC_CLASS_TRACKER_LOCK = threading.Lock() + +PYPY = platform.python_implementation() == "PyPy" + +builtin_code_type = None +if PYPY: + # builtin-code objects only exist in pypy + builtin_code_type = type(float.__new__.__code__) if sys.version_info[0] < 3: # pragma: no branch from pickle import Pickler @@ -71,98 +92,300 @@ if sys.version_info[0] < 3: # pragma: no branch from StringIO import StringIO string_types = (basestring,) # noqa PY3 = False + PY2 = True else: types.ClassType = type from pickle import _Pickler as Pickler from io import BytesIO as StringIO string_types = (str,) PY3 = True + PY2 = False + from importlib._bootstrap import _find_spec + +_extract_code_globals_cache = weakref.WeakKeyDictionary() -def _make_cell_set_template_code(): - """Get the Python compiler to emit LOAD_FAST(arg); STORE_DEREF +def _ensure_tracking(class_def): + with _DYNAMIC_CLASS_TRACKER_LOCK: + class_tracker_id = _DYNAMIC_CLASS_TRACKER_BY_CLASS.get(class_def) + if class_tracker_id is None: + class_tracker_id = uuid.uuid4().hex + _DYNAMIC_CLASS_TRACKER_BY_CLASS[class_def] = class_tracker_id + _DYNAMIC_CLASS_TRACKER_BY_ID[class_tracker_id] = class_def + return class_tracker_id - Notes - ----- - In Python 3, we could use an easier function: - .. code-block:: python +def _lookup_class_or_track(class_tracker_id, class_def): + if class_tracker_id is not None: + with _DYNAMIC_CLASS_TRACKER_LOCK: + class_def = _DYNAMIC_CLASS_TRACKER_BY_ID.setdefault( + class_tracker_id, class_def) + _DYNAMIC_CLASS_TRACKER_BY_CLASS[class_def] = class_tracker_id + return class_def - def f(): - cell = None +if sys.version_info[:2] >= (3, 5): + from pickle import _getattribute +elif sys.version_info[:2] >= (3, 4): + from pickle import _getattribute as _py34_getattribute + # pickle._getattribute does not return the parent under Python 3.4 + def _getattribute(obj, name): + return _py34_getattribute(obj, name), None +else: + # pickle._getattribute is a python3 addition and enchancement of getattr, + # that can handle dotted attribute names. In cloudpickle for python2, + # handling dotted names is not needed, so we simply define _getattribute as + # a wrapper around getattr. + def _getattribute(obj, name): + return getattr(obj, name, None), None - def _stub(value): - nonlocal cell - cell = value - return _stub +def _whichmodule(obj, name): + """Find the module an object belongs to. - _cell_set_template_code = f().__code__ - - This function is _only_ a LOAD_FAST(arg); STORE_DEREF, but that is - invalid syntax on Python 2. If we use this function we also don't need - to do the weird freevars/cellvars swap below + This function differs from ``pickle.whichmodule`` in two ways: + - it does not mangle the cases where obj's module is __main__ and obj was + not found in any module. + - Errors arising during module introspection are ignored, as those errors + are considered unwanted side effects. """ - def inner(value): - lambda: cell # make ``cell`` a closure so that we get a STORE_DEREF - cell = value - - co = inner.__code__ - - # NOTE: we are marking the cell variable as a free variable intentionally - # so that we simulate an inner function instead of the outer function. This - # is what gives us the ``nonlocal`` behavior in a Python 2 compatible way. - if not PY3: # pragma: no branch - return types.CodeType( - co.co_argcount, - co.co_nlocals, - co.co_stacksize, - co.co_flags, - co.co_code, - co.co_consts, - co.co_names, - co.co_varnames, - co.co_filename, - co.co_name, - co.co_firstlineno, - co.co_lnotab, - co.co_cellvars, # this is the trickery - (), - ) - else: - return types.CodeType( - co.co_argcount, - co.co_kwonlyargcount, - co.co_nlocals, - co.co_stacksize, - co.co_flags, - co.co_code, - co.co_consts, - co.co_names, - co.co_varnames, - co.co_filename, - co.co_name, - co.co_firstlineno, - co.co_lnotab, - co.co_cellvars, # this is the trickery - (), - ) + module_name = getattr(obj, '__module__', None) + if module_name is not None: + return module_name + # Protect the iteration by using a list copy of sys.modules against dynamic + # modules that trigger imports of other modules upon calls to getattr. + for module_name, module in list(sys.modules.items()): + if module_name == '__main__' or module is None: + continue + try: + if _getattribute(module, name)[0] is obj: + return module_name + except Exception: + pass + return None -_cell_set_template_code = _make_cell_set_template_code() +def _is_global(obj, name=None): + """Determine if obj can be pickled as attribute of a file-backed module""" + if name is None: + name = getattr(obj, '__qualname__', None) + if name is None: + name = getattr(obj, '__name__', None) + + module_name = _whichmodule(obj, name) + + if module_name is None: + # In this case, obj.__module__ is None AND obj was not found in any + # imported module. obj is thus treated as dynamic. + return False + + if module_name == "__main__": + return False + + module = sys.modules.get(module_name, None) + if module is None: + # The main reason why obj's module would not be imported is that this + # module has been dynamically created, using for example + # types.ModuleType. The other possibility is that module was removed + # from sys.modules after obj was created/imported. But this case is not + # supported, as the standard pickle does not support it either. + return False + + # module has been added to sys.modules, but it can still be dynamic. + if _is_dynamic(module): + return False + + try: + obj2, parent = _getattribute(module, name) + except AttributeError: + # obj was not found inside the module it points to + return False + return obj2 is obj + + +def _extract_code_globals(co): + """ + Find all globals names read or written to by codeblock co + """ + out_names = _extract_code_globals_cache.get(co) + if out_names is None: + names = co.co_names + out_names = {names[oparg] for _, oparg in _walk_global_ops(co)} + + # Declaring a function inside another one using the "def ..." + # syntax generates a constant code object corresonding to the one + # of the nested function's As the nested function may itself need + # global variables, we need to introspect its code, extract its + # globals, (look for code object in it's co_consts attribute..) and + # add the result to code_globals + if co.co_consts: + for const in co.co_consts: + if isinstance(const, types.CodeType): + out_names |= _extract_code_globals(const) + + _extract_code_globals_cache[co] = out_names + + return out_names + + +def _find_imported_submodules(code, top_level_dependencies): + """ + Find currently imported submodules used by a function. + + Submodules used by a function need to be detected and referenced for the + function to work correctly at depickling time. Because submodules can be + referenced as attribute of their parent package (``package.submodule``), we + need a special introspection technique that does not rely on GLOBAL-related + opcodes to find references of them in a code object. + + Example: + ``` + import concurrent.futures + import cloudpickle + def func(): + x = concurrent.futures.ThreadPoolExecutor + if __name__ == '__main__': + cloudpickle.dumps(func) + ``` + The globals extracted by cloudpickle in the function's state include the + concurrent package, but not its submodule (here, concurrent.futures), which + is the module used by func. Find_imported_submodules will detect the usage + of concurrent.futures. Saving this module alongside with func will ensure + that calling func once depickled does not fail due to concurrent.futures + not being imported + """ + + subimports = [] + # check if any known dependency is an imported package + for x in top_level_dependencies: + if (isinstance(x, types.ModuleType) and + hasattr(x, '__package__') and x.__package__): + # check if the package has any currently loaded sub-imports + prefix = x.__name__ + '.' + # A concurrent thread could mutate sys.modules, + # make sure we iterate over a copy to avoid exceptions + for name in list(sys.modules): + # Older versions of pytest will add a "None" module to + # sys.modules. + if name is not None and name.startswith(prefix): + # check whether the function can address the sub-module + tokens = set(name[len(prefix):].split('.')) + if not tokens - set(code.co_names): + subimports.append(sys.modules[name]) + return subimports def cell_set(cell, value): """Set the value of a closure cell. - """ - return types.FunctionType( - _cell_set_template_code, - {}, - '_cell_set_inner', - (), - (cell,), - )(value) + The point of this function is to set the cell_contents attribute of a cell + after its creation. This operation is necessary in case the cell contains a + reference to the function the cell belongs to, as when calling the + function's constructor + ``f = types.FunctionType(code, globals, name, argdefs, closure)``, + closure will not be able to contain the yet-to-be-created f. + + In Python3.7, cell_contents is writeable, so setting the contents of a cell + can be done simply using + >>> cell.cell_contents = value + + In earlier Python3 versions, the cell_contents attribute of a cell is read + only, but this limitation can be worked around by leveraging the Python 3 + ``nonlocal`` keyword. + + In Python2 however, this attribute is read only, and there is no + ``nonlocal`` keyword. For this reason, we need to come up with more + complicated hacks to set this attribute. + + The chosen approach is to create a function with a STORE_DEREF opcode, + which sets the content of a closure variable. Typically: + + >>> def inner(value): + ... lambda: cell # the lambda makes cell a closure + ... cell = value # cell is a closure, so this triggers a STORE_DEREF + + (Note that in Python2, A STORE_DEREF can never be triggered from an inner + function. The function g for example here + >>> def f(var): + ... def g(): + ... var += 1 + ... return g + + will not modify the closure variable ``var```inplace, but instead try to + load a local variable var and increment it. As g does not assign the local + variable ``var`` any initial value, calling f(1)() will fail at runtime.) + + Our objective is to set the value of a given cell ``cell``. So we need to + somewhat reference our ``cell`` object into the ``inner`` function so that + this object (and not the smoke cell of the lambda function) gets affected + by the STORE_DEREF operation. + + In inner, ``cell`` is referenced as a cell variable (an enclosing variable + that is referenced by the inner function). If we create a new function + cell_set with the exact same code as ``inner``, but with ``cell`` marked as + a free variable instead, the STORE_DEREF will be applied on its closure - + ``cell``, which we can specify explicitly during construction! The new + cell_set variable thus actually sets the contents of a specified cell! + + Note: we do not make use of the ``nonlocal`` keyword to set the contents of + a cell in early python3 versions to limit possible syntax errors in case + test and checker libraries decide to parse the whole file. + """ + + if sys.version_info[:2] >= (3, 7): # pragma: no branch + cell.cell_contents = value + else: + _cell_set = types.FunctionType( + _cell_set_template_code, {}, '_cell_set', (), (cell,),) + _cell_set(value) + + +def _make_cell_set_template_code(): + def _cell_set_factory(value): + lambda: cell + cell = value + + co = _cell_set_factory.__code__ + + if PY2: # pragma: no branch + _cell_set_template_code = types.CodeType( + co.co_argcount, + co.co_nlocals, + co.co_stacksize, + co.co_flags, + co.co_code, + co.co_consts, + co.co_names, + co.co_varnames, + co.co_filename, + co.co_name, + co.co_firstlineno, + co.co_lnotab, + co.co_cellvars, # co_freevars is initialized with co_cellvars + (), # co_cellvars is made empty + ) + else: + _cell_set_template_code = types.CodeType( + co.co_argcount, + co.co_kwonlyargcount, # Python 3 only argument + co.co_nlocals, + co.co_stacksize, + co.co_flags, + co.co_code, + co.co_consts, + co.co_names, + co.co_varnames, + co.co_filename, + co.co_name, + co.co_firstlineno, + co.co_lnotab, + co.co_cellvars, # co_freevars is initialized with co_cellvars + (), # co_cellvars is made empty + ) + return _cell_set_template_code + + +if sys.version_info[:2] < (3, 7): + _cell_set_template_code = _make_cell_set_template_code() # relevant opcodes STORE_GLOBAL = opcode.opmap['STORE_GLOBAL'] @@ -173,10 +396,6 @@ HAVE_ARGUMENT = dis.HAVE_ARGUMENT EXTENDED_ARG = dis.EXTENDED_ARG -def islambda(func): - return getattr(func, '__name__') == '' - - _BUILTIN_TYPE_NAMES = {} for k, v in types.__dict__.items(): if type(v) is type: @@ -187,32 +406,6 @@ def _builtin_type(name): return getattr(types, name) -def _make__new__factory(type_): - def _factory(): - return type_.__new__ - return _factory - - -# NOTE: These need to be module globals so that they're pickleable as globals. -_get_dict_new = _make__new__factory(dict) -_get_frozenset_new = _make__new__factory(frozenset) -_get_list_new = _make__new__factory(list) -_get_set_new = _make__new__factory(set) -_get_tuple_new = _make__new__factory(tuple) -_get_object_new = _make__new__factory(object) - -# Pre-defined set of builtin_function_or_method instances that can be -# serialized. -_BUILTIN_TYPE_CONSTRUCTORS = { - dict.__new__: _get_dict_new, - frozenset.__new__: _get_frozenset_new, - set.__new__: _get_set_new, - list.__new__: _get_list_new, - tuple.__new__: _get_tuple_new, - object.__new__: _get_object_new, -} - - if sys.version_info < (3, 4): # pragma: no branch def _walk_global_ops(code): """ @@ -220,7 +413,7 @@ if sys.version_info < (3, 4): # pragma: no branch global-referencing instructions in *code*. """ code = getattr(code, 'co_code', b'') - if not PY3: # pragma: no branch + if PY2: # pragma: no branch code = map(ord, code) n = len(code) @@ -250,6 +443,28 @@ else: yield op, instr.arg +def _extract_class_dict(cls): + """Retrieve a copy of the dict of a class without the inherited methods""" + clsdict = dict(cls.__dict__) # copy dict proxy to a dict + if len(cls.__bases__) == 1: + inherited_dict = cls.__bases__[0].__dict__ + else: + inherited_dict = {} + for base in reversed(cls.__bases__): + inherited_dict.update(base.__dict__) + to_remove = [] + for name, value in clsdict.items(): + try: + base_value = inherited_dict[name] + if value is base_value: + to_remove.append(name) + except KeyError: + pass + for name in to_remove: + clsdict.pop(name) + return clsdict + + class CloudPickler(Pickler): dispatch = Pickler.dispatch.copy() @@ -277,7 +492,7 @@ class CloudPickler(Pickler): dispatch[memoryview] = save_memoryview - if not PY3: # pragma: no branch + if PY2: # pragma: no branch def save_buffer(self, obj): self.save(str(obj)) @@ -300,12 +515,23 @@ class CloudPickler(Pickler): Save a code object """ if PY3: # pragma: no branch - args = ( - obj.co_argcount, obj.co_kwonlyargcount, obj.co_nlocals, obj.co_stacksize, - obj.co_flags, obj.co_code, obj.co_consts, obj.co_names, obj.co_varnames, - obj.co_filename, obj.co_name, obj.co_firstlineno, obj.co_lnotab, obj.co_freevars, - obj.co_cellvars - ) + if hasattr(obj, "co_posonlyargcount"): # pragma: no branch + args = ( + obj.co_argcount, obj.co_posonlyargcount, + obj.co_kwonlyargcount, obj.co_nlocals, obj.co_stacksize, + obj.co_flags, obj.co_code, obj.co_consts, obj.co_names, + obj.co_varnames, obj.co_filename, obj.co_name, + obj.co_firstlineno, obj.co_lnotab, obj.co_freevars, + obj.co_cellvars + ) + else: + args = ( + obj.co_argcount, obj.co_kwonlyargcount, obj.co_nlocals, + obj.co_stacksize, obj.co_flags, obj.co_code, obj.co_consts, + obj.co_names, obj.co_varnames, obj.co_filename, + obj.co_name, obj.co_firstlineno, obj.co_lnotab, + obj.co_freevars, obj.co_cellvars + ) else: args = ( obj.co_argcount, obj.co_nlocals, obj.co_stacksize, obj.co_flags, obj.co_code, @@ -322,153 +548,73 @@ class CloudPickler(Pickler): Determines what kind of function obj is (e.g. lambda, defined at interactive prompt, etc) and handles the pickling appropriately. """ - try: - should_special_case = obj in _BUILTIN_TYPE_CONSTRUCTORS - except TypeError: - # Methods of builtin types aren't hashable in python 2. - should_special_case = False - - if should_special_case: - # We keep a special-cased cache of built-in type constructors at - # global scope, because these functions are structured very - # differently in different python versions and implementations (for - # example, they're instances of types.BuiltinFunctionType in - # CPython, but they're ordinary types.FunctionType instances in - # PyPy). - # - # If the function we've received is in that cache, we just - # serialize it as a lookup into the cache. - return self.save_reduce(_BUILTIN_TYPE_CONSTRUCTORS[obj], (), obj=obj) - - write = self.write - - if name is None: - name = obj.__name__ - try: - # whichmodule() could fail, see - # https://bitbucket.org/gutworth/six/issues/63/importing-six-breaks-pickling - modname = pickle.whichmodule(obj, name) - except Exception: - modname = None - # print('which gives %s %s %s' % (modname, obj, name)) - try: - themodule = sys.modules[modname] - except KeyError: - # eval'd items such as namedtuple give invalid items for their function __module__ - modname = '__main__' - - if modname == '__main__': - themodule = None - - try: - lookedup_by_name = getattr(themodule, name, None) - except Exception: - lookedup_by_name = None - - if themodule: - if lookedup_by_name is obj: - return self.save_global(obj, name) - - # a builtin_function_or_method which comes in as an attribute of some - # object (e.g., itertools.chain.from_iterable) will end - # up with modname "__main__" and so end up here. But these functions - # have no __code__ attribute in CPython, so the handling for - # user-defined functions below will fail. - # So we pickle them here using save_reduce; have to do it differently - # for different python versions. - if not hasattr(obj, '__code__'): - if PY3: # pragma: no branch - rv = obj.__reduce_ex__(self.proto) - else: - if hasattr(obj, '__self__'): - rv = (getattr, (obj.__self__, name)) - else: - raise pickle.PicklingError("Can't pickle %r" % obj) - return self.save_reduce(obj=obj, *rv) - - # if func is lambda, def'ed at prompt, is in main, or is nested, then - # we'll pickle the actual function object rather than simply saving a - # reference (as is done in default pickler), via save_function_tuple. - if (islambda(obj) - or getattr(obj.__code__, 'co_filename', None) == '' - or themodule is None): - self.save_function_tuple(obj) - return + if _is_global(obj, name=name): + return Pickler.save_global(self, obj, name=name) + elif PYPY and isinstance(obj.__code__, builtin_code_type): + return self.save_pypy_builtin_func(obj) else: - # func is nested - if lookedup_by_name is None or lookedup_by_name is not obj: - self.save_function_tuple(obj) - return - - if obj.__dict__: - # essentially save_reduce, but workaround needed to avoid recursion - self.save(_restore_attr) - write(pickle.MARK + pickle.GLOBAL + modname + '\n' + name + '\n') - self.memoize(obj) - self.save(obj.__dict__) - write(pickle.TUPLE + pickle.REDUCE) - else: - write(pickle.GLOBAL + modname + '\n' + name + '\n') - self.memoize(obj) + return self.save_function_tuple(obj) dispatch[types.FunctionType] = save_function - def _save_subimports(self, code, top_level_dependencies): + def save_pypy_builtin_func(self, obj): + """Save pypy equivalent of builtin functions. + + PyPy does not have the concept of builtin-functions. Instead, + builtin-functions are simple function instances, but with a + builtin-code attribute. + Most of the time, builtin functions should be pickled by attribute. But + PyPy has flaky support for __qualname__, so some builtin functions such + as float.__new__ will be classified as dynamic. For this reason only, + we created this special routine. Because builtin-functions are not + expected to have closure or globals, there is no additional hack + (compared the one already implemented in pickle) to protect ourselves + from reference cycles. A simple (reconstructor, newargs, obj.__dict__) + tuple is save_reduced. + + Note also that PyPy improved their support for __qualname__ in v3.6, so + this routing should be removed when cloudpickle supports only PyPy 3.6 + and later. """ - Save submodules used by a function but not listed in its globals. + rv = (types.FunctionType, (obj.__code__, {}, obj.__name__, + obj.__defaults__, obj.__closure__), + obj.__dict__) + self.save_reduce(*rv, obj=obj) - In the example below: + def _save_dynamic_enum(self, obj, clsdict): + """Special handling for dynamic Enum subclasses - ``` - import concurrent.futures - import cloudpickle - - - def func(): - x = concurrent.futures.ThreadPoolExecutor - - - if __name__ == '__main__': - cloudpickle.dumps(func) - ``` - - the globals extracted by cloudpickle in the function's state include - the concurrent module, but not its submodule (here, - concurrent.futures), which is the module used by func. - - To ensure that calling the depickled function does not raise an - AttributeError, this function looks for any currently loaded submodule - that the function uses and whose parent is present in the function - globals, and saves it before saving the function. + Use a dedicated Enum constructor (inspired by EnumMeta.__call__) as the + EnumMeta metaclass has complex initialization that makes the Enum + subclasses hold references to their own instances. """ + members = dict((e.name, e.value) for e in obj) - # check if any known dependency is an imported package - for x in top_level_dependencies: - if isinstance(x, types.ModuleType) and hasattr(x, '__package__') and x.__package__: - # check if the package has any currently loaded sub-imports - prefix = x.__name__ + '.' - # A concurrent thread could mutate sys.modules, - # make sure we iterate over a copy to avoid exceptions - for name in list(sys.modules): - # Older versions of pytest will add a "None" module to sys.modules. - if name is not None and name.startswith(prefix): - # check whether the function can address the sub-module - tokens = set(name[len(prefix):].split('.')) - if not tokens - set(code.co_names): - # ensure unpickler executes this import - self.save(sys.modules[name]) - # then discards the reference to it - self.write(pickle.POP) + # Python 2.7 with enum34 can have no qualname: + qualname = getattr(obj, "__qualname__", None) + + self.save_reduce(_make_skeleton_enum, + (obj.__bases__, obj.__name__, qualname, members, + obj.__module__, _ensure_tracking(obj), None), + obj=obj) + + # Cleanup the clsdict that will be passed to _rehydrate_skeleton_class: + # Those attributes are already handled by the metaclass. + for attrname in ["_generate_next_value_", "_member_names_", + "_member_map_", "_member_type_", + "_value2member_map_"]: + clsdict.pop(attrname, None) + for member in members: + clsdict.pop(member) def save_dynamic_class(self, obj): - """ - Save a class that can't be stored as module global. + """Save a class that can't be stored as module global. This method is used to serialize classes that are defined inside functions, or that otherwise can't be serialized as attribute lookups from global modules. """ - clsdict = dict(obj.__dict__) # copy dict proxy to a dict + clsdict = _extract_class_dict(obj) clsdict.pop('__weakref__', None) # For ABCMeta in python3.7+, remove _abc_impl as it is not picklable. @@ -496,8 +642,8 @@ class CloudPickler(Pickler): for k in obj.__slots__: clsdict.pop(k, None) - # If type overrides __dict__ as a property, include it in the type kwargs. - # In Python 2, we can't set this attribute after construction. + # If type overrides __dict__ as a property, include it in the type + # kwargs. In Python 2, we can't set this attribute after construction. __dict__ = clsdict.pop('__dict__', None) if isinstance(__dict__, property): type_kwargs['__dict__'] = __dict__ @@ -524,8 +670,16 @@ class CloudPickler(Pickler): write(pickle.MARK) # Create and memoize an skeleton class with obj's name and bases. - tp = type(obj) - self.save_reduce(tp, (obj.__name__, obj.__bases__, type_kwargs), obj=obj) + if Enum is not None and issubclass(obj, Enum): + # Special handling of Enum subclasses + self._save_dynamic_enum(obj, clsdict) + else: + # "Regular" class definition: + tp = type(obj) + self.save_reduce(_make_skeleton_class, + (tp, obj.__name__, obj.__bases__, type_kwargs, + _ensure_tracking(obj), None), + obj=obj) # Now save the rest of obj's __dict__. Any references to obj # encountered while saving will point to the skeleton class. @@ -562,7 +716,12 @@ class CloudPickler(Pickler): save(_fill_function) # skeleton function updater write(pickle.MARK) # beginning of tuple that _fill_function expects - self._save_subimports( + # Extract currently-imported submodules used by func. Storing these + # modules in a smoke _cloudpickle_subimports attribute of the object's + # state will trigger the side effect of importing these modules at + # unpickling time (which is necessary for func to work correctly once + # depickled) + submodules = _find_imported_submodules( code, itertools.chain(f_globals.values(), closure_values or ()), ) @@ -586,45 +745,18 @@ class CloudPickler(Pickler): 'module': func.__module__, 'name': func.__name__, 'doc': func.__doc__, + '_cloudpickle_submodules': submodules } - if hasattr(func, '__annotations__') and sys.version_info >= (3, 7): + if hasattr(func, '__annotations__') and sys.version_info >= (3, 4): state['annotations'] = func.__annotations__ if hasattr(func, '__qualname__'): state['qualname'] = func.__qualname__ + if hasattr(func, '__kwdefaults__'): + state['kwdefaults'] = func.__kwdefaults__ save(state) write(pickle.TUPLE) write(pickle.REDUCE) # applies _fill_function on the tuple - _extract_code_globals_cache = ( - weakref.WeakKeyDictionary() - if not hasattr(sys, "pypy_version_info") - else {}) - - @classmethod - def extract_code_globals(cls, co): - """ - Find all globals names read or written to by codeblock co - """ - out_names = cls._extract_code_globals_cache.get(co) - if out_names is None: - try: - names = co.co_names - except AttributeError: - # PyPy "builtin-code" object - out_names = set() - else: - out_names = {names[oparg] for _, oparg in _walk_global_ops(co)} - - # see if nested function have any global refs - if co.co_consts: - for const in co.co_consts: - if type(const) is types.CodeType: - out_names |= cls.extract_code_globals(const) - - cls._extract_code_globals_cache[co] = out_names - - return out_names - def extract_func_data(self, func): """ Turn the function into a tuple of data necessary to recreate it: @@ -633,7 +765,7 @@ class CloudPickler(Pickler): code = func.__code__ # extract all global ref's - func_global_refs = self.extract_code_globals(code) + func_global_refs = _extract_code_globals(code) # process all variables referenced by global environment f_globals = {} @@ -666,14 +798,60 @@ class CloudPickler(Pickler): # multiple invokations are bound to the same Cloudpickler. base_globals = self.globals_ref.setdefault(id(func.__globals__), {}) + if base_globals == {}: + # Add module attributes used to resolve relative imports + # instructions inside func. + for k in ["__package__", "__name__", "__path__", "__file__"]: + # Some built-in functions/methods such as object.__new__ have + # their __globals__ set to None in PyPy + if func.__globals__ is not None and k in func.__globals__: + base_globals[k] = func.__globals__[k] + return (code, f_globals, defaults, closure, dct, base_globals) - def save_builtin_function(self, obj): - if obj.__module__ == "__builtin__": - return self.save_global(obj) - return self.save_function(obj) + if not PY3: # pragma: no branch + # Python3 comes with native reducers that allow builtin functions and + # methods pickling as module/class attributes. The following method + # extends this for python2. + # Please note that currently, neither pickle nor cloudpickle support + # dynamically created builtin functions/method pickling. + def save_builtin_function_or_method(self, obj): + is_bound = getattr(obj, '__self__', None) is not None + if is_bound: + # obj is a bound builtin method. + rv = (getattr, (obj.__self__, obj.__name__)) + return self.save_reduce(obj=obj, *rv) - dispatch[types.BuiltinFunctionType] = save_builtin_function + is_unbound = hasattr(obj, '__objclass__') + if is_unbound: + # obj is an unbound builtin method (accessed from its class) + rv = (getattr, (obj.__objclass__, obj.__name__)) + return self.save_reduce(obj=obj, *rv) + + # Otherwise, obj is not a method, but a function. Fallback to + # default pickling by attribute. + return Pickler.save_global(self, obj) + + dispatch[types.BuiltinFunctionType] = save_builtin_function_or_method + + # A comprehensive summary of the various kinds of builtin methods can + # be found in PEP 579: https://www.python.org/dev/peps/pep-0579/ + classmethod_descriptor_type = type(float.__dict__['fromhex']) + wrapper_descriptor_type = type(float.__repr__) + method_wrapper_type = type(1.5.__repr__) + + dispatch[classmethod_descriptor_type] = save_builtin_function_or_method + dispatch[wrapper_descriptor_type] = save_builtin_function_or_method + dispatch[method_wrapper_type] = save_builtin_function_or_method + + if sys.version_info[:2] < (3, 4): + method_descriptor = type(str.upper) + dispatch[method_descriptor] = save_builtin_function_or_method + + def save_getset_descriptor(self, obj): + return self.save_reduce(getattr, (obj.__objclass__, obj.__name__)) + + dispatch[types.GetSetDescriptorType] = save_getset_descriptor def save_global(self, obj, name=None, pack=struct.pack): """ @@ -688,23 +866,15 @@ class CloudPickler(Pickler): return self.save_reduce(type, (Ellipsis,), obj=obj) elif obj is type(NotImplemented): return self.save_reduce(type, (NotImplemented,), obj=obj) - - if obj.__module__ == "__main__": - return self.save_dynamic_class(obj) - - try: - return Pickler.save_global(self, obj, name=name) - except Exception: - if obj.__module__ == "__builtin__" or obj.__module__ == "builtins": - if obj in _BUILTIN_TYPE_NAMES: - return self.save_reduce( - _builtin_type, (_BUILTIN_TYPE_NAMES[obj],), obj=obj) - - typ = type(obj) - if typ is not obj and isinstance(obj, (type, types.ClassType)): - return self.save_dynamic_class(obj) - - raise + elif obj in _BUILTIN_TYPE_NAMES: + return self.save_reduce( + _builtin_type, (_BUILTIN_TYPE_NAMES[obj],), obj=obj) + elif name is not None: + Pickler.save_global(self, obj, name=name) + elif not _is_global(obj, name=name): + self.save_dynamic_class(obj) + else: + Pickler.save_global(self, obj, name=name) dispatch[type] = save_global dispatch[types.ClassType] = save_global @@ -717,8 +887,9 @@ class CloudPickler(Pickler): if PY3: # pragma: no branch self.save_reduce(types.MethodType, (obj.__func__, obj.__self__), obj=obj) else: - self.save_reduce(types.MethodType, (obj.__func__, obj.__self__, obj.__self__.__class__), - obj=obj) + self.save_reduce( + types.MethodType, + (obj.__func__, obj.__self__, type(obj.__self__)), obj=obj) dispatch[types.MethodType] = save_instancemethod @@ -767,7 +938,7 @@ class CloudPickler(Pickler): save(stuff) write(pickle.BUILD) - if not PY3: # pragma: no branch + if PY2: # pragma: no branch dispatch[types.InstanceType] = save_inst def save_property(self, obj): @@ -972,50 +1143,6 @@ def dynamic_subimport(name, vars): return mod -# restores function attributes -def _restore_attr(obj, attr): - for key, val in attr.items(): - setattr(obj, key, val) - return obj - - -def _get_module_builtins(): - return pickle.__builtins__ - - -def print_exec(stream): - ei = sys.exc_info() - traceback.print_exception(ei[0], ei[1], ei[2], None, stream) - - -def _modules_to_main(modList): - """Force every module in modList to be placed into main""" - if not modList: - return - - main = sys.modules['__main__'] - for modname in modList: - if type(modname) is str: - try: - mod = __import__(modname) - except Exception: - sys.stderr.write('warning: could not import %s\n. ' - 'Your function may unexpectedly error due to this import failing;' - 'A version mismatch is likely. Specific error was:\n' % modname) - print_exec(sys.stderr) - else: - setattr(main, mod.__name__, mod) - - -# object generators: -def _genpartial(func, args, kwds): - if not args: - args = () - if not kwds: - kwds = {} - return partial(func, *args, **kwds) - - def _gen_ellipsis(): return Ellipsis @@ -1103,6 +1230,15 @@ def _fill_function(*args): func.__module__ = state['module'] if 'qualname' in state: func.__qualname__ = state['qualname'] + if 'kwdefaults' in state: + func.__kwdefaults__ = state['kwdefaults'] + # _cloudpickle_subimports is a set of submodules that must be loaded for + # the pickled function to work correctly at unpickling time. Now that these + # submodules are depickled (hence imported), they can be removed from the + # object's state (the object state only served as a reference holder to + # these submodules) + if '_cloudpickle_submodules' in state: + state.pop('_cloudpickle_submodules') cells = func.__closure__ if cells is not None: @@ -1143,6 +1279,22 @@ def _make_skel_func(code, cell_count, base_globals=None): return types.FunctionType(code, base_globals, None, None, closure) +def _make_skeleton_class(type_constructor, name, bases, type_kwargs, + class_tracker_id, extra): + """Build dynamic class with an empty __dict__ to be filled once memoized + + If class_tracker_id is not None, try to lookup an existing class definition + matching that id. If none is found, track a newly reconstructed class + definition under that id so that other instances stemming from the same + class id will also reuse this class definition. + + The "extra" variable is meant to be a dict (or None) that can be used for + forward compatibility shall the need arise. + """ + skeleton_class = type_constructor(name, bases, type_kwargs) + return _lookup_class_or_track(class_tracker_id, skeleton_class) + + def _rehydrate_skeleton_class(skeleton_class, class_dict): """Put attributes from `class_dict` back on `skeleton_class`. @@ -1161,6 +1313,39 @@ def _rehydrate_skeleton_class(skeleton_class, class_dict): return skeleton_class +def _make_skeleton_enum(bases, name, qualname, members, module, + class_tracker_id, extra): + """Build dynamic enum with an empty __dict__ to be filled once memoized + + The creation of the enum class is inspired by the code of + EnumMeta._create_. + + If class_tracker_id is not None, try to lookup an existing enum definition + matching that id. If none is found, track a newly reconstructed enum + definition under that id so that other instances stemming from the same + class id will also reuse this enum definition. + + The "extra" variable is meant to be a dict (or None) that can be used for + forward compatibility shall the need arise. + """ + # enums always inherit from their base Enum class at the last position in + # the list of base classes: + enum_base = bases[-1] + metacls = enum_base.__class__ + classdict = metacls.__prepare__(name, bases) + + for member_name, member_value in members.items(): + classdict[member_name] = member_value + enum_class = metacls.__new__(metacls, name, bases, classdict) + enum_class.__module__ = module + + # Python 2.7 compat + if qualname is not None: + enum_class.__qualname__ = qualname + + return _lookup_class_or_track(class_tracker_id, enum_class) + + def _is_dynamic(module): """ Return True if the module is special module that cannot be imported by its @@ -1171,7 +1356,29 @@ def _is_dynamic(module): return False if hasattr(module, '__spec__'): - return module.__spec__ is None + if module.__spec__ is not None: + return False + + # In PyPy, Some built-in modules such as _codecs can have their + # __spec__ attribute set to None despite being imported. For such + # modules, the ``_find_spec`` utility of the standard library is used. + parent_name = module.__name__.rpartition('.')[0] + if parent_name: # pragma: no cover + # This code handles the case where an imported package (and not + # module) remains with __spec__ set to None. It is however untested + # as no package in the PyPy stdlib has __spec__ set to None after + # it is imported. + try: + parent = sys.modules[parent_name] + except KeyError: + msg = "parent {!r} not in sys.modules" + raise ImportError(msg.format(parent_name)) + else: + pkgpath = parent.__path__ + else: + pkgpath = None + return _find_spec(module.__name__, pkgpath, module) is None + else: # Backward compat for Python 2 import imp @@ -1186,27 +1393,3 @@ def _is_dynamic(module): except ImportError: return True return False - - -"""Constructors for 3rd party libraries -Note: These can never be renamed due to client compatibility issues""" - - -def _getobject(modname, attribute): - mod = __import__(modname, fromlist=[attribute]) - return mod.__dict__[attribute] - - -""" Use copy_reg to extend global pickle definitions """ - -if sys.version_info < (3, 4): # pragma: no branch - method_descriptor = type(str.upper) - - def _reduce_method_descriptor(obj): - return (getattr, (obj.__objclass__, obj.__name__)) - - try: - import copy_reg as copyreg - except ImportError: - import copyreg - copyreg.pickle(method_descriptor, _reduce_method_descriptor) diff --git a/rllib/tests/test_lstm.py b/rllib/tests/test_lstm.py index 3c3c16e78..dbd44175f 100644 --- a/rllib/tests/test_lstm.py +++ b/rllib/tests/test_lstm.py @@ -84,6 +84,13 @@ class RNNSpyModel(Model): capture_index = 0 def _build_layers_v2(self, input_dict, num_outputs, options): + # Previously, a new class object was created during + # deserialization and this `capture_index` + # variable would be refreshed between class instantiations. + # This behavior is no longer the case, so we manually refresh + # the variable. + RNNSpyModel.capture_index = 0 + def spy(sequences, state_in, state_out, seq_lens): if len(sequences) == 1: return 0 # don't capture inference inputs