From a21523c709fd50538430070fdbe05992f575d5e6 Mon Sep 17 00:00:00 2001 From: Richard Liaw Date: Wed, 2 Dec 2020 00:52:17 -0800 Subject: [PATCH] [tune/core] serialization debugging utility (#12142) Co-authored-by: SangBin Cho Co-authored-by: Kai Fricke --- doc/source/package-ref.rst | 6 +- doc/source/serialization.rst | 42 +++++ python/ray/cloudpickle/__init__.py | 38 ++++- python/ray/setup-dev.py | 1 + python/ray/tests/conftest.py | 9 +- python/ray/tests/test_serialization.py | 30 ++++ python/ray/tune/experiment.py | 21 +-- python/ray/tune/registry.py | 2 +- python/ray/tune/tests/test_api.py | 46 +++++- python/ray/tune/utils/serialization.py | 4 +- python/ray/util/__init__.py | 3 +- python/ray/util/check_serialize.py | 216 +++++++++++++++++++++++++ 12 files changed, 394 insertions(+), 24 deletions(-) create mode 100644 python/ray/util/check_serialize.py diff --git a/doc/source/package-ref.rst b/doc/source/package-ref.rst index dd6defcb2..db3cbd560 100644 --- a/doc/source/package-ref.rst +++ b/doc/source/package-ref.rst @@ -198,11 +198,13 @@ Histogram .. _package-ref-debugging-apis: -Debugger APIs -------------- +Debugging APIs +-------------- .. autofunction:: ray.util.pdb.set_trace +.. autofunction:: ray.util.inspect_serializability + Experimental APIs ----------------- diff --git a/doc/source/serialization.rst b/doc/source/serialization.rst index e1cadf7cb..6bdada7d0 100644 --- a/doc/source/serialization.rst +++ b/doc/source/serialization.rst @@ -64,6 +64,48 @@ Serialization notes - Lock objects are mostly unserializable, because copying a lock is meaningless and could cause serious concurrency problems. You may have to come up with a workaround if your object contains a lock. +Troubleshooting +--------------- + +Use ``ray.util.inspect_serializability`` to identify tricky pickling issues. This function can be used to trace a potential non-serializable object within any Python object -- whether it be a function, class, or object instance. + +Below, we demonstrate this behavior on a function with a non-serializable object (threading lock): + +.. code-block:: python + + from ray.util import inspect_serializability + import threading + + lock = threading.Lock() + + def test(): + print(lock) + + inspect_serializability(test, name="test") + +The resulting output is: + + +.. code-block:: bash + + ============================================================= + Checking Serializability of + ============================================================= + !!! FAIL serialization: can't pickle _thread.lock objects + Detected 1 global variables. Checking serializability... + Serializing 'lock' ... + !!! FAIL serialization: can't pickle _thread.lock objects + WARNING: Did not find non-serializable object in . This may be an oversight. + ============================================================= + Variable: + + lock [obj=, parent=] + + was found to be non-serializable. There may be multiple other undetected variables that were non-serializable. + Consider either removing the instantiation/imports of these variables or moving the instantiation into the scope of the function/class. + If you have any suggestions on how to improve this error message, please reach out to the Ray developers on github.com/ray-project/ray/issues/ + ============================================================= + Known Issues ------------ diff --git a/python/ray/cloudpickle/__init__.py b/python/ray/cloudpickle/__init__.py index b28e91ee8..fc90f8b3c 100644 --- a/python/ray/cloudpickle/__init__.py +++ b/python/ray/cloudpickle/__init__.py @@ -1,11 +1,47 @@ from __future__ import absolute_import +import os +from pickle import PicklingError from ray.cloudpickle.cloudpickle import * # noqa from ray.cloudpickle.cloudpickle_fast import CloudPickler, dumps, dump # noqa + # Conform to the convention used by python serialization libraries, which # expose their Pickler subclass at top-level under the "Pickler" name. Pickler = CloudPickler -__version__ = '1.6.0' +__version__ = "1.6.0" + + +def _warn_msg(obj, method, exc): + return ( + f"{method}({str(obj)}) failed." + "\nTo check which non-serializable variables are captured " + "in scope, re-run the ray script with 'RAY_PICKLE_VERBOSE_DEBUG=1'.") + + +def dump_debug(obj, *args, **kwargs): + try: + return dump(obj, *args, **kwargs) + except (TypeError, PicklingError) as exc: + if os.environ.get("RAY_PICKLE_VERBOSE_DEBUG"): + from ray.util.check_serialize import inspect_serializability + inspect_serializability(obj) + raise + else: + msg = _warn_msg(obj, "ray.cloudpickle.dump", exc) + raise type(exc)(msg) + + +def dumps_debug(obj, *args, **kwargs): + try: + return dumps(obj, *args, **kwargs) + except (TypeError, PicklingError) as exc: + if os.environ.get("RAY_PICKLE_VERBOSE_DEBUG"): + from ray.util.check_serialize import inspect_serializability + inspect_serializability(obj) + raise + else: + msg = _warn_msg(obj, "ray.cloudpickle.dumps", exc) + raise type(exc)(msg) diff --git a/python/ray/setup-dev.py b/python/ray/setup-dev.py index 88b526010..963b2b977 100755 --- a/python/ray/setup-dev.py +++ b/python/ray/setup-dev.py @@ -61,6 +61,7 @@ if __name__ == "__main__": do_link("rllib", force=args.yes, local_path="../../../rllib") do_link("tune", force=args.yes) do_link("autoscaler", force=args.yes) + do_link("cloudpickle", force=args.yes) do_link("scripts", force=args.yes) do_link("internal", force=args.yes) do_link("tests", force=args.yes) diff --git a/python/ray/tests/conftest.py b/python/ray/tests/conftest.py index 017f89afe..1c48a28d3 100644 --- a/python/ray/tests/conftest.py +++ b/python/ray/tests/conftest.py @@ -1,7 +1,7 @@ """ This file defines the common pytest fixtures used in current directory. """ - +import os from contextlib import contextmanager import pytest import subprocess @@ -204,6 +204,13 @@ def call_ray_stop_only(): subprocess.check_call(["ray", "stop"]) +@pytest.fixture +def enable_pickle_debug(): + os.environ["RAY_PICKLE_VERBOSE_DEBUG"] = "1" + yield + del os.environ["RAY_PICKLE_VERBOSE_DEBUG"] + + @pytest.fixture() def two_node_cluster(): system_config = { diff --git a/python/ray/tests/test_serialization.py b/python/ray/tests/test_serialization.py index 240fe4bc6..0c88ebd22 100644 --- a/python/ray/tests/test_serialization.py +++ b/python/ray/tests/test_serialization.py @@ -331,6 +331,36 @@ def test_numpy_subclass_serialization_pickle(ray_start_regular): assert repr_orig == repr_ser +def test_inspect_serialization(enable_pickle_debug): + import threading + from ray.cloudpickle import dumps_debug + + lock = threading.Lock() + + with pytest.raises(TypeError): + dumps_debug(lock) + + def test_func(): + print(lock) + + with pytest.raises(TypeError): + dumps_debug(test_func) + + class test_class: + def test(self): + self.lock = lock + + from ray.util.check_serialize import inspect_serializability + results = inspect_serializability(lock) + assert list(results[1])[0].obj == lock, results + + results = inspect_serializability(test_func) + assert list(results[1])[0].obj == lock, results + + results = inspect_serializability(test_class) + assert list(results[1])[0].obj == lock, results + + @pytest.mark.parametrize( "ray_start_regular", [{ "local_mode": True diff --git a/python/ray/tune/experiment.py b/python/ray/tune/experiment.py index 98ae6e640..bf50951a3 100644 --- a/python/ray/tune/experiment.py +++ b/python/ray/tune/experiment.py @@ -286,21 +286,12 @@ class Experiment: try: register_trainable(name, run_object) except (TypeError, PicklingError) as e: - msg = ( - f"{str(e)}. The trainable ({str(run_object)}) could not " - "be serialized, which is needed for parallel execution. " - "To diagnose the issue, try the following:\n\n" - "\t- Run `tune.utils.diagnose_serialization(trainable)` " - "to check if non-serializable variables are captured " - "in scope.\n" - "\t- Try reproducing the issue by calling " - "`pickle.dumps(trainable)`.\n" - "\t- If the error is typing-related, try removing " - "the type annotations and try again.\n\n" - "If you have any suggestions on how to improve " - "this error message, please reach out to the " - "Ray developers on github.com/ray-project/ray/issues/") - raise type(e)(msg) from None + extra_msg = (f"Other options: " + "\n-Try reproducing the issue by calling " + "`pickle.dumps(trainable)`. " + "\n-If the error is typing-related, try removing " + "the type annotations and try again.") + raise type(e)(str(e) + " " + extra_msg) from None return name else: raise TuneError("Improper 'run' - not string nor trainable.") diff --git a/python/ray/tune/registry.py b/python/ray/tune/registry.py index 7409aaa0b..5ffa4e416 100644 --- a/python/ray/tune/registry.py +++ b/python/ray/tune/registry.py @@ -119,7 +119,7 @@ class _Registry: from ray.tune import TuneError raise TuneError("Unknown category {} not among {}".format( category, KNOWN_CATEGORIES)) - self._to_flush[(category, key)] = pickle.dumps(value) + self._to_flush[(category, key)] = pickle.dumps_debug(value) if _internal_kv_initialized(): self.flush_values() diff --git a/python/ray/tune/tests/test_api.py b/python/ray/tune/tests/test_api.py index d16a512f7..ea325fa45 100644 --- a/python/ray/tune/tests/test_api.py +++ b/python/ray/tune/tests/test_api.py @@ -1038,8 +1038,6 @@ class TrainableFunctionApiTest(unittest.TestCase): self.assertLessEqual(status.get("PENDING", 0), 1) def testMetricCheckingEndToEnd(self): - from ray import tune - def train(config): tune.report(val=4, second=8) @@ -1130,6 +1128,50 @@ class TrainableFunctionApiTest(unittest.TestCase): self.assertFalse(found) +class SerializabilityTest(unittest.TestCase): + @classmethod + def setUpClass(cls): + ray.init(local_mode=True) + + @classmethod + def tearDownClass(cls): + ray.shutdown() + + def tearDown(self): + if "RAY_PICKLE_VERBOSE_DEBUG" in os.environ: + del os.environ["RAY_PICKLE_VERBOSE_DEBUG"] + + def testNotRaisesNonserializable(self): + import threading + lock = threading.Lock() + + def train(config): + print(lock) + tune.report(val=4, second=8) + + with self.assertRaisesRegex(TypeError, "RAY_PICKLE_VERBOSE_DEBUG"): + # The trial runner raises a ValueError, but the experiment fails + # with a TuneError + tune.run(train, metric="acc") + + def testRaisesNonserializable(self): + os.environ["RAY_PICKLE_VERBOSE_DEBUG"] = "1" + import threading + lock = threading.Lock() + + def train(config): + print(lock) + tune.report(val=4, second=8) + + with self.assertRaises(TypeError) as cm: + # The trial runner raises a ValueError, but the experiment fails + # with a TuneError + tune.run(train, metric="acc") + msg = cm.exception.args[0] + assert "RAY_PICKLE_VERBOSE_DEBUG" not in msg + assert "thread.lock" in msg + + class ShimCreationTest(unittest.TestCase): def testCreateScheduler(self): kwargs = {"metric": "metric_foo", "mode": "min"} diff --git a/python/ray/tune/utils/serialization.py b/python/ray/tune/utils/serialization.py index de6061647..5efc3a5ae 100644 --- a/python/ray/tune/utils/serialization.py +++ b/python/ray/tune/utils/serialization.py @@ -4,6 +4,7 @@ import types from ray import cloudpickle as cloudpickle from ray.utils import binary_to_hex, hex_to_binary +from ray.util.debug import log_once logger = logging.getLogger(__name__) @@ -15,7 +16,8 @@ class TuneFunctionEncoder(json.JSONEncoder): try: return super(TuneFunctionEncoder, self).default(obj) except Exception: - logger.debug("Unable to encode. Falling back to cloudpickle.") + if log_once(f"tune_func_encode:{str(obj)}"): + logger.debug("Unable to encode. Falling back to cloudpickle.") return self._to_cloudpickle(obj) def _to_cloudpickle(self, obj): diff --git a/python/ray/util/__init__.py b/python/ray/util/__init__.py index 2a6d0a029..be876c649 100644 --- a/python/ray/util/__init__.py +++ b/python/ray/util/__init__.py @@ -1,5 +1,6 @@ from ray.util import iter from ray.util.actor_pool import ActorPool +from ray.util.check_serialize import inspect_serializability from ray.util.debug import log_once, disable_log_once_globally, \ enable_periodic_logging from ray.util.placement_group import (placement_group, placement_group_table, @@ -9,5 +10,5 @@ from ray.util import rpdb as pdb __all__ = [ "ActorPool", "disable_log_once_globally", "enable_periodic_logging", "iter", "log_once", "pdb", "placement_group", "placement_group_table", - "remove_placement_group" + "remove_placement_group", "inspect_serializability" ] diff --git a/python/ray/util/check_serialize.py b/python/ray/util/check_serialize.py new file mode 100644 index 000000000..4b8cb2c19 --- /dev/null +++ b/python/ray/util/check_serialize.py @@ -0,0 +1,216 @@ +"""A utility for debugging serialization issues.""" +from typing import Any, Tuple, Set, Optional +import inspect +import ray.cloudpickle as cp +import colorama +from contextlib import contextmanager + + +@contextmanager +def _indent(printer): + printer.level += 1 + yield + printer.level -= 1 + + +class _Printer: + def __init__(self): + self.level = 0 + + def indent(self): + return _indent(self) + + def print(self, msg): + indent = " " * self.level + print(indent + msg) + + +_printer = _Printer() + + +class FailureTuple: + """Represents the serialization 'frame'. + + Attributes: + obj: The object that fails serialization. + name: The variable name of the object. + parent: The object that references the `obj`. + """ + + def __init__(self, obj: Any, name: str, parent: Any): + self.obj = obj + self.name = name + self.parent = parent + + def __repr__(self): + return f"FailTuple({self.name} [obj={self.obj}, parent={self.parent}])" + + +def _inspect_func_serialization(base_obj, depth, parent, failure_set): + """Adds the first-found non-serializable element to the failure_set.""" + assert inspect.isfunction(base_obj) + closure = inspect.getclosurevars(base_obj) + found = False + if closure.globals: + _printer.print(f"Detected {len(closure.globals)} global variables. " + "Checking serializability...") + + with _printer.indent(): + for name, obj in closure.globals.items(): + serializable, _ = inspect_serializability( + obj, + name=name, + depth=depth - 1, + _parent=parent, + _failure_set=failure_set) + found = found or not serializable + if found: + break + + if closure.nonlocals: + _printer.print( + f"Detected {len(closure.nonlocals)} nonlocal variables. " + "Checking serializability...") + with _printer.indent(): + for name, obj in closure.nonlocals.items(): + serializable, _ = inspect_serializability( + obj, + name=name, + depth=depth - 1, + _parent=parent, + _failure_set=failure_set) + found = found or not serializable + if found: + break + if not found: + _printer.print( + f"WARNING: Did not find non-serializable object in {base_obj}. " + "This may be an oversight.") + return found + + +def _inspect_generic_serialization(base_obj, depth, parent, failure_set): + """Adds the first-found non-serializable element to the failure_set.""" + assert not inspect.isfunction(base_obj) + functions = inspect.getmembers(base_obj, predicate=inspect.isfunction) + found = False + with _printer.indent(): + for name, obj in functions: + serializable, _ = inspect_serializability( + obj, + name=name, + depth=depth - 1, + _parent=parent, + _failure_set=failure_set) + found = found or not serializable + if found: + break + + with _printer.indent(): + members = inspect.getmembers(base_obj) + for name, obj in members: + if name.startswith("__") and name.endswith( + "__") or inspect.isbuiltin(obj): + continue + serializable, _ = inspect_serializability( + obj, + name=name, + depth=depth - 1, + _parent=parent, + _failure_set=failure_set) + found = found or not serializable + if found: + break + if not found: + _printer.print( + f"WARNING: Did not find non-serializable object in {base_obj}. " + "This may be an oversight.") + return found + + +def inspect_serializability( + base_obj: Any, + name: Optional[str] = None, + depth: int = 3, + _parent: Optional[Any] = None, + _failure_set: Optional[set] = None) -> Tuple[bool, Set[FailureTuple]]: + """Identifies what objects are preventing serialization. + + Args: + base_obj: Object to be serialized. + name: Optional name of string. + depth: Depth of the scope stack to walk through. Defaults to 3. + + Returns: + bool: True if serializable. + set[FailureTuple]: Set of unserializable objects. + + .. versionadded:: 1.1.0 + + """ + colorama.init() + top_level = False + declaration = "" + found = False + if _failure_set is None: + top_level = True + _failure_set = set() + declaration = f"Checking Serializability of {base_obj}" + print("=" * min(len(declaration), 80)) + print(declaration) + print("=" * min(len(declaration), 80)) + + if name is None: + name = str(base_obj) + else: + _printer.print(f"Serializing '{name}' {base_obj}...") + try: + cp.dumps(base_obj) + return True, _failure_set + except Exception as e: + _printer.print(f"{colorama.Fore.RED}!!! FAIL{colorama.Fore.RESET} " + f"serialization: {e}") + found = True + try: + if depth == 0: + _failure_set.add(FailureTuple(base_obj, name, _parent)) + # Some objects may not be hashable, so we skip adding this to the set. + except Exception: + pass + + if depth <= 0: + return False, _failure_set + + # TODO: we only differentiate between 'function' and 'object' + # but we should do a better job of diving into something + # more specific like a Type, Object, etc. + if inspect.isfunction(base_obj): + _inspect_func_serialization( + base_obj, depth=depth, parent=base_obj, failure_set=_failure_set) + else: + _inspect_generic_serialization( + base_obj, depth=depth, parent=base_obj, failure_set=_failure_set) + + if not _failure_set: + _failure_set.add(FailureTuple(base_obj, name, _parent)) + + if top_level: + print("=" * min(len(declaration), 80)) + if not _failure_set: + print("Nothing failed the inspect_serialization test, though " + "serialization did not succeed.") + else: + fail_vars = f"\n\n\t{colorama.Style.BRIGHT}" + "\n".join( + str(k) + for k in _failure_set) + f"{colorama.Style.RESET_ALL}\n\n" + print(f"Variable: {fail_vars}was found to be non-serializable. " + "There may be multiple other undetected variables that were " + "non-serializable. ") + print("Consider either removing the " + "instantiation/imports of these variables or moving the " + "instantiation into the scope of the function/class. ") + print("If you have any suggestions on how to improve " + "this error message, please reach out to the " + "Ray developers on github.com/ray-project/ray/issues/") + print("=" * min(len(declaration), 80)) + return not found, _failure_set