[tune/core] serialization debugging utility (#12142)

Co-authored-by: SangBin Cho <rkooo567@gmail.com>
Co-authored-by: Kai Fricke <kai@anyscale.com>
This commit is contained in:
Richard Liaw
2020-12-02 00:52:17 -08:00
committed by GitHub
parent 63b85df828
commit a21523c709
12 changed files with 394 additions and 24 deletions
+37 -1
View File
@@ -1,11 +1,47 @@
from __future__ import absolute_import
import os
from pickle import PicklingError
from ray.cloudpickle.cloudpickle import * # noqa
from ray.cloudpickle.cloudpickle_fast import CloudPickler, dumps, dump # noqa
# Conform to the convention used by python serialization libraries, which
# expose their Pickler subclass at top-level under the "Pickler" name.
Pickler = CloudPickler
__version__ = '1.6.0'
__version__ = "1.6.0"
def _warn_msg(obj, method, exc):
return (
f"{method}({str(obj)}) failed."
"\nTo check which non-serializable variables are captured "
"in scope, re-run the ray script with 'RAY_PICKLE_VERBOSE_DEBUG=1'.")
def dump_debug(obj, *args, **kwargs):
try:
return dump(obj, *args, **kwargs)
except (TypeError, PicklingError) as exc:
if os.environ.get("RAY_PICKLE_VERBOSE_DEBUG"):
from ray.util.check_serialize import inspect_serializability
inspect_serializability(obj)
raise
else:
msg = _warn_msg(obj, "ray.cloudpickle.dump", exc)
raise type(exc)(msg)
def dumps_debug(obj, *args, **kwargs):
try:
return dumps(obj, *args, **kwargs)
except (TypeError, PicklingError) as exc:
if os.environ.get("RAY_PICKLE_VERBOSE_DEBUG"):
from ray.util.check_serialize import inspect_serializability
inspect_serializability(obj)
raise
else:
msg = _warn_msg(obj, "ray.cloudpickle.dumps", exc)
raise type(exc)(msg)
+1
View File
@@ -61,6 +61,7 @@ if __name__ == "__main__":
do_link("rllib", force=args.yes, local_path="../../../rllib")
do_link("tune", force=args.yes)
do_link("autoscaler", force=args.yes)
do_link("cloudpickle", force=args.yes)
do_link("scripts", force=args.yes)
do_link("internal", force=args.yes)
do_link("tests", force=args.yes)
+8 -1
View File
@@ -1,7 +1,7 @@
"""
This file defines the common pytest fixtures used in current directory.
"""
import os
from contextlib import contextmanager
import pytest
import subprocess
@@ -204,6 +204,13 @@ def call_ray_stop_only():
subprocess.check_call(["ray", "stop"])
@pytest.fixture
def enable_pickle_debug():
os.environ["RAY_PICKLE_VERBOSE_DEBUG"] = "1"
yield
del os.environ["RAY_PICKLE_VERBOSE_DEBUG"]
@pytest.fixture()
def two_node_cluster():
system_config = {
+30
View File
@@ -331,6 +331,36 @@ def test_numpy_subclass_serialization_pickle(ray_start_regular):
assert repr_orig == repr_ser
def test_inspect_serialization(enable_pickle_debug):
import threading
from ray.cloudpickle import dumps_debug
lock = threading.Lock()
with pytest.raises(TypeError):
dumps_debug(lock)
def test_func():
print(lock)
with pytest.raises(TypeError):
dumps_debug(test_func)
class test_class:
def test(self):
self.lock = lock
from ray.util.check_serialize import inspect_serializability
results = inspect_serializability(lock)
assert list(results[1])[0].obj == lock, results
results = inspect_serializability(test_func)
assert list(results[1])[0].obj == lock, results
results = inspect_serializability(test_class)
assert list(results[1])[0].obj == lock, results
@pytest.mark.parametrize(
"ray_start_regular", [{
"local_mode": True
+6 -15
View File
@@ -286,21 +286,12 @@ class Experiment:
try:
register_trainable(name, run_object)
except (TypeError, PicklingError) as e:
msg = (
f"{str(e)}. The trainable ({str(run_object)}) could not "
"be serialized, which is needed for parallel execution. "
"To diagnose the issue, try the following:\n\n"
"\t- Run `tune.utils.diagnose_serialization(trainable)` "
"to check if non-serializable variables are captured "
"in scope.\n"
"\t- Try reproducing the issue by calling "
"`pickle.dumps(trainable)`.\n"
"\t- If the error is typing-related, try removing "
"the type annotations and try again.\n\n"
"If you have any suggestions on how to improve "
"this error message, please reach out to the "
"Ray developers on github.com/ray-project/ray/issues/")
raise type(e)(msg) from None
extra_msg = (f"Other options: "
"\n-Try reproducing the issue by calling "
"`pickle.dumps(trainable)`. "
"\n-If the error is typing-related, try removing "
"the type annotations and try again.")
raise type(e)(str(e) + " " + extra_msg) from None
return name
else:
raise TuneError("Improper 'run' - not string nor trainable.")
+1 -1
View File
@@ -119,7 +119,7 @@ class _Registry:
from ray.tune import TuneError
raise TuneError("Unknown category {} not among {}".format(
category, KNOWN_CATEGORIES))
self._to_flush[(category, key)] = pickle.dumps(value)
self._to_flush[(category, key)] = pickle.dumps_debug(value)
if _internal_kv_initialized():
self.flush_values()
+44 -2
View File
@@ -1038,8 +1038,6 @@ class TrainableFunctionApiTest(unittest.TestCase):
self.assertLessEqual(status.get("PENDING", 0), 1)
def testMetricCheckingEndToEnd(self):
from ray import tune
def train(config):
tune.report(val=4, second=8)
@@ -1130,6 +1128,50 @@ class TrainableFunctionApiTest(unittest.TestCase):
self.assertFalse(found)
class SerializabilityTest(unittest.TestCase):
@classmethod
def setUpClass(cls):
ray.init(local_mode=True)
@classmethod
def tearDownClass(cls):
ray.shutdown()
def tearDown(self):
if "RAY_PICKLE_VERBOSE_DEBUG" in os.environ:
del os.environ["RAY_PICKLE_VERBOSE_DEBUG"]
def testNotRaisesNonserializable(self):
import threading
lock = threading.Lock()
def train(config):
print(lock)
tune.report(val=4, second=8)
with self.assertRaisesRegex(TypeError, "RAY_PICKLE_VERBOSE_DEBUG"):
# The trial runner raises a ValueError, but the experiment fails
# with a TuneError
tune.run(train, metric="acc")
def testRaisesNonserializable(self):
os.environ["RAY_PICKLE_VERBOSE_DEBUG"] = "1"
import threading
lock = threading.Lock()
def train(config):
print(lock)
tune.report(val=4, second=8)
with self.assertRaises(TypeError) as cm:
# The trial runner raises a ValueError, but the experiment fails
# with a TuneError
tune.run(train, metric="acc")
msg = cm.exception.args[0]
assert "RAY_PICKLE_VERBOSE_DEBUG" not in msg
assert "thread.lock" in msg
class ShimCreationTest(unittest.TestCase):
def testCreateScheduler(self):
kwargs = {"metric": "metric_foo", "mode": "min"}
+3 -1
View File
@@ -4,6 +4,7 @@ import types
from ray import cloudpickle as cloudpickle
from ray.utils import binary_to_hex, hex_to_binary
from ray.util.debug import log_once
logger = logging.getLogger(__name__)
@@ -15,7 +16,8 @@ class TuneFunctionEncoder(json.JSONEncoder):
try:
return super(TuneFunctionEncoder, self).default(obj)
except Exception:
logger.debug("Unable to encode. Falling back to cloudpickle.")
if log_once(f"tune_func_encode:{str(obj)}"):
logger.debug("Unable to encode. Falling back to cloudpickle.")
return self._to_cloudpickle(obj)
def _to_cloudpickle(self, obj):
+2 -1
View File
@@ -1,5 +1,6 @@
from ray.util import iter
from ray.util.actor_pool import ActorPool
from ray.util.check_serialize import inspect_serializability
from ray.util.debug import log_once, disable_log_once_globally, \
enable_periodic_logging
from ray.util.placement_group import (placement_group, placement_group_table,
@@ -9,5 +10,5 @@ from ray.util import rpdb as pdb
__all__ = [
"ActorPool", "disable_log_once_globally", "enable_periodic_logging",
"iter", "log_once", "pdb", "placement_group", "placement_group_table",
"remove_placement_group"
"remove_placement_group", "inspect_serializability"
]
+216
View File
@@ -0,0 +1,216 @@
"""A utility for debugging serialization issues."""
from typing import Any, Tuple, Set, Optional
import inspect
import ray.cloudpickle as cp
import colorama
from contextlib import contextmanager
@contextmanager
def _indent(printer):
printer.level += 1
yield
printer.level -= 1
class _Printer:
def __init__(self):
self.level = 0
def indent(self):
return _indent(self)
def print(self, msg):
indent = " " * self.level
print(indent + msg)
_printer = _Printer()
class FailureTuple:
"""Represents the serialization 'frame'.
Attributes:
obj: The object that fails serialization.
name: The variable name of the object.
parent: The object that references the `obj`.
"""
def __init__(self, obj: Any, name: str, parent: Any):
self.obj = obj
self.name = name
self.parent = parent
def __repr__(self):
return f"FailTuple({self.name} [obj={self.obj}, parent={self.parent}])"
def _inspect_func_serialization(base_obj, depth, parent, failure_set):
"""Adds the first-found non-serializable element to the failure_set."""
assert inspect.isfunction(base_obj)
closure = inspect.getclosurevars(base_obj)
found = False
if closure.globals:
_printer.print(f"Detected {len(closure.globals)} global variables. "
"Checking serializability...")
with _printer.indent():
for name, obj in closure.globals.items():
serializable, _ = inspect_serializability(
obj,
name=name,
depth=depth - 1,
_parent=parent,
_failure_set=failure_set)
found = found or not serializable
if found:
break
if closure.nonlocals:
_printer.print(
f"Detected {len(closure.nonlocals)} nonlocal variables. "
"Checking serializability...")
with _printer.indent():
for name, obj in closure.nonlocals.items():
serializable, _ = inspect_serializability(
obj,
name=name,
depth=depth - 1,
_parent=parent,
_failure_set=failure_set)
found = found or not serializable
if found:
break
if not found:
_printer.print(
f"WARNING: Did not find non-serializable object in {base_obj}. "
"This may be an oversight.")
return found
def _inspect_generic_serialization(base_obj, depth, parent, failure_set):
"""Adds the first-found non-serializable element to the failure_set."""
assert not inspect.isfunction(base_obj)
functions = inspect.getmembers(base_obj, predicate=inspect.isfunction)
found = False
with _printer.indent():
for name, obj in functions:
serializable, _ = inspect_serializability(
obj,
name=name,
depth=depth - 1,
_parent=parent,
_failure_set=failure_set)
found = found or not serializable
if found:
break
with _printer.indent():
members = inspect.getmembers(base_obj)
for name, obj in members:
if name.startswith("__") and name.endswith(
"__") or inspect.isbuiltin(obj):
continue
serializable, _ = inspect_serializability(
obj,
name=name,
depth=depth - 1,
_parent=parent,
_failure_set=failure_set)
found = found or not serializable
if found:
break
if not found:
_printer.print(
f"WARNING: Did not find non-serializable object in {base_obj}. "
"This may be an oversight.")
return found
def inspect_serializability(
base_obj: Any,
name: Optional[str] = None,
depth: int = 3,
_parent: Optional[Any] = None,
_failure_set: Optional[set] = None) -> Tuple[bool, Set[FailureTuple]]:
"""Identifies what objects are preventing serialization.
Args:
base_obj: Object to be serialized.
name: Optional name of string.
depth: Depth of the scope stack to walk through. Defaults to 3.
Returns:
bool: True if serializable.
set[FailureTuple]: Set of unserializable objects.
.. versionadded:: 1.1.0
"""
colorama.init()
top_level = False
declaration = ""
found = False
if _failure_set is None:
top_level = True
_failure_set = set()
declaration = f"Checking Serializability of {base_obj}"
print("=" * min(len(declaration), 80))
print(declaration)
print("=" * min(len(declaration), 80))
if name is None:
name = str(base_obj)
else:
_printer.print(f"Serializing '{name}' {base_obj}...")
try:
cp.dumps(base_obj)
return True, _failure_set
except Exception as e:
_printer.print(f"{colorama.Fore.RED}!!! FAIL{colorama.Fore.RESET} "
f"serialization: {e}")
found = True
try:
if depth == 0:
_failure_set.add(FailureTuple(base_obj, name, _parent))
# Some objects may not be hashable, so we skip adding this to the set.
except Exception:
pass
if depth <= 0:
return False, _failure_set
# TODO: we only differentiate between 'function' and 'object'
# but we should do a better job of diving into something
# more specific like a Type, Object, etc.
if inspect.isfunction(base_obj):
_inspect_func_serialization(
base_obj, depth=depth, parent=base_obj, failure_set=_failure_set)
else:
_inspect_generic_serialization(
base_obj, depth=depth, parent=base_obj, failure_set=_failure_set)
if not _failure_set:
_failure_set.add(FailureTuple(base_obj, name, _parent))
if top_level:
print("=" * min(len(declaration), 80))
if not _failure_set:
print("Nothing failed the inspect_serialization test, though "
"serialization did not succeed.")
else:
fail_vars = f"\n\n\t{colorama.Style.BRIGHT}" + "\n".join(
str(k)
for k in _failure_set) + f"{colorama.Style.RESET_ALL}\n\n"
print(f"Variable: {fail_vars}was found to be non-serializable. "
"There may be multiple other undetected variables that were "
"non-serializable. ")
print("Consider either removing the "
"instantiation/imports of these variables or moving the "
"instantiation into the scope of the function/class. ")
print("If you have any suggestions on how to improve "
"this error message, please reach out to the "
"Ray developers on github.com/ray-project/ray/issues/")
print("=" * min(len(declaration), 80))
return not found, _failure_set