mirror of
https://github.com/wassname/ray.git
synced 2026-07-01 16:54:21 +08:00
[tune/core] serialization debugging utility (#12142)
Co-authored-by: SangBin Cho <rkooo567@gmail.com> Co-authored-by: Kai Fricke <kai@anyscale.com>
This commit is contained in:
@@ -1,11 +1,47 @@
|
||||
from __future__ import absolute_import
|
||||
|
||||
import os
|
||||
from pickle import PicklingError
|
||||
|
||||
from ray.cloudpickle.cloudpickle import * # noqa
|
||||
from ray.cloudpickle.cloudpickle_fast import CloudPickler, dumps, dump # noqa
|
||||
|
||||
|
||||
# Conform to the convention used by python serialization libraries, which
|
||||
# expose their Pickler subclass at top-level under the "Pickler" name.
|
||||
Pickler = CloudPickler
|
||||
|
||||
__version__ = '1.6.0'
|
||||
__version__ = "1.6.0"
|
||||
|
||||
|
||||
def _warn_msg(obj, method, exc):
    """Build the message shown when ``method(obj)`` fails to serialize.

    Args:
        obj: The object that could not be serialized; rendered with ``str``.
        method: Display name of the failing call, for example
            ``"ray.cloudpickle.dumps"``.
        exc: The exception raised by the underlying pickler, or None. When
            given, its type and text are included so the root cause is not
            lost when the caller raises a new exception with this message.

    Returns:
        str: The failure line, the original error (if any), and the hint to
        re-run with ``RAY_PICKLE_VERBOSE_DEBUG=1``.
    """
    cause = ""
    if exc is not None:
        cause = f"\nOriginal error: {type(exc).__name__}: {exc}"
    return (
        f"{method}({str(obj)}) failed."
        f"{cause}"
        "\nTo check which non-serializable variables are captured "
        "in scope, re-run the ray script with 'RAY_PICKLE_VERBOSE_DEBUG=1'.")
|
||||
|
||||
|
||||
def dump_debug(obj, *args, **kwargs):
    """Call ``dump`` on ``obj`` and add diagnostics if pickling fails.

    All positional and keyword arguments are forwarded unchanged to ``dump``.

    Returns:
        Whatever ``dump`` returns.

    Raises:
        TypeError, PicklingError: When ``dump`` raises one of these. If the
            ``RAY_PICKLE_VERBOSE_DEBUG`` environment variable is set to a
            non-empty value, ``inspect_serializability(obj)`` is run and the
            original exception is re-raised unchanged. Otherwise a new
            exception of the same type is raised with the ``_warn_msg``
            text, explicitly chained to the original.
    """
    try:
        return dump(obj, *args, **kwargs)
    except (TypeError, PicklingError) as exc:
        if os.environ.get("RAY_PICKLE_VERBOSE_DEBUG"):
            # Imported here rather than at module level: check_serialize
            # imports ray.cloudpickle itself.
            from ray.util.check_serialize import inspect_serializability
            inspect_serializability(obj)
            raise
        msg = _warn_msg(obj, "ray.cloudpickle.dump", exc)
        # Chain explicitly so the pickler's own traceback is reported as the
        # direct cause of the friendlier error.
        raise type(exc)(msg) from exc
|
||||
|
||||
|
||||
def dumps_debug(obj, *args, **kwargs):
    """Call ``dumps`` on ``obj`` and add diagnostics if pickling fails.

    All positional and keyword arguments are forwarded unchanged to
    ``dumps``.

    Returns:
        Whatever ``dumps`` returns.

    Raises:
        TypeError, PicklingError: When ``dumps`` raises one of these. If the
            ``RAY_PICKLE_VERBOSE_DEBUG`` environment variable is set to a
            non-empty value, ``inspect_serializability(obj)`` is run and the
            original exception is re-raised unchanged. Otherwise a new
            exception of the same type is raised with the ``_warn_msg``
            text, explicitly chained to the original.
    """
    try:
        return dumps(obj, *args, **kwargs)
    except (TypeError, PicklingError) as exc:
        if os.environ.get("RAY_PICKLE_VERBOSE_DEBUG"):
            # Imported here rather than at module level: check_serialize
            # imports ray.cloudpickle itself.
            from ray.util.check_serialize import inspect_serializability
            inspect_serializability(obj)
            raise
        msg = _warn_msg(obj, "ray.cloudpickle.dumps", exc)
        # Chain explicitly so the pickler's own traceback is reported as the
        # direct cause of the friendlier error.
        raise type(exc)(msg) from exc
|
||||
|
||||
@@ -61,6 +61,7 @@ if __name__ == "__main__":
|
||||
do_link("rllib", force=args.yes, local_path="../../../rllib")
|
||||
do_link("tune", force=args.yes)
|
||||
do_link("autoscaler", force=args.yes)
|
||||
do_link("cloudpickle", force=args.yes)
|
||||
do_link("scripts", force=args.yes)
|
||||
do_link("internal", force=args.yes)
|
||||
do_link("tests", force=args.yes)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""
|
||||
This file defines the common pytest fixtures used in current directory.
|
||||
"""
|
||||
|
||||
import os
|
||||
from contextlib import contextmanager
|
||||
import pytest
|
||||
import subprocess
|
||||
@@ -204,6 +204,13 @@ def call_ray_stop_only():
|
||||
subprocess.check_call(["ray", "stop"])
|
||||
|
||||
|
||||
@pytest.fixture
def enable_pickle_debug():
    """Set ``RAY_PICKLE_VERBOSE_DEBUG=1`` for the duration of one test.

    On teardown the variable is put back to whatever it was before the test
    (or removed if it was unset), even if the test changed or deleted it.
    """
    key = "RAY_PICKLE_VERBOSE_DEBUG"
    previous = os.environ.get(key)
    os.environ[key] = "1"
    try:
        yield
    finally:
        if previous is None:
            # pop() instead of del: the test may already have removed it.
            os.environ.pop(key, None)
        else:
            os.environ[key] = previous
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def two_node_cluster():
|
||||
system_config = {
|
||||
|
||||
@@ -331,6 +331,36 @@ def test_numpy_subclass_serialization_pickle(ray_start_regular):
|
||||
assert repr_orig == repr_ser
|
||||
|
||||
|
||||
def test_inspect_serialization(enable_pickle_debug):
    """A captured lock must be reported by dumps_debug and the inspector.

    Covers three shapes: the lock itself, a function closing over it, and a
    class whose method closes over it.
    """
    import threading
    from ray.cloudpickle import dumps_debug
    from ray.util.check_serialize import inspect_serializability

    lock = threading.Lock()

    def test_func():
        print(lock)

    class test_class:
        def test(self):
            self.lock = lock

    # dumps_debug must surface the pickling failure as a TypeError.
    for unpicklable in (lock, test_func):
        with pytest.raises(TypeError):
            dumps_debug(unpicklable)

    # In every case the inspector should blame the lock.
    for target in (lock, test_func, test_class):
        results = inspect_serializability(target)
        assert list(results[1])[0].obj == lock, results
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"ray_start_regular", [{
|
||||
"local_mode": True
|
||||
|
||||
@@ -286,21 +286,12 @@ class Experiment:
|
||||
try:
|
||||
register_trainable(name, run_object)
|
||||
except (TypeError, PicklingError) as e:
|
||||
msg = (
|
||||
f"{str(e)}. The trainable ({str(run_object)}) could not "
|
||||
"be serialized, which is needed for parallel execution. "
|
||||
"To diagnose the issue, try the following:\n\n"
|
||||
"\t- Run `tune.utils.diagnose_serialization(trainable)` "
|
||||
"to check if non-serializable variables are captured "
|
||||
"in scope.\n"
|
||||
"\t- Try reproducing the issue by calling "
|
||||
"`pickle.dumps(trainable)`.\n"
|
||||
"\t- If the error is typing-related, try removing "
|
||||
"the type annotations and try again.\n\n"
|
||||
"If you have any suggestions on how to improve "
|
||||
"this error message, please reach out to the "
|
||||
"Ray developers on github.com/ray-project/ray/issues/")
|
||||
raise type(e)(msg) from None
|
||||
extra_msg = (f"Other options: "
|
||||
"\n-Try reproducing the issue by calling "
|
||||
"`pickle.dumps(trainable)`. "
|
||||
"\n-If the error is typing-related, try removing "
|
||||
"the type annotations and try again.")
|
||||
raise type(e)(str(e) + " " + extra_msg) from None
|
||||
return name
|
||||
else:
|
||||
raise TuneError("Improper 'run' - not string nor trainable.")
|
||||
|
||||
@@ -119,7 +119,7 @@ class _Registry:
|
||||
from ray.tune import TuneError
|
||||
raise TuneError("Unknown category {} not among {}".format(
|
||||
category, KNOWN_CATEGORIES))
|
||||
self._to_flush[(category, key)] = pickle.dumps(value)
|
||||
self._to_flush[(category, key)] = pickle.dumps_debug(value)
|
||||
if _internal_kv_initialized():
|
||||
self.flush_values()
|
||||
|
||||
|
||||
@@ -1038,8 +1038,6 @@ class TrainableFunctionApiTest(unittest.TestCase):
|
||||
self.assertLessEqual(status.get("PENDING", 0), 1)
|
||||
|
||||
def testMetricCheckingEndToEnd(self):
|
||||
from ray import tune
|
||||
|
||||
def train(config):
|
||||
tune.report(val=4, second=8)
|
||||
|
||||
@@ -1130,6 +1128,50 @@ class TrainableFunctionApiTest(unittest.TestCase):
|
||||
self.assertFalse(found)
|
||||
|
||||
|
||||
class SerializabilityTest(unittest.TestCase):
    # Checks the error raised by tune.run when the trainable captures an
    # object that cannot be pickled, with and without the
    # RAY_PICKLE_VERBOSE_DEBUG environment variable.

    @classmethod
    def setUpClass(cls):
        ray.init(local_mode=True)

    @classmethod
    def tearDownClass(cls):
        ray.shutdown()

    def tearDown(self):
        # testRaisesNonserializable sets the flag; make sure it never leaks
        # into the next test.
        os.environ.pop("RAY_PICKLE_VERBOSE_DEBUG", None)

    @staticmethod
    def _lock_capturing_trainable():
        # Returns a trainable that closes over a threading.Lock, which
        # cannot be pickled.
        import threading
        lock = threading.Lock()

        def train(config):
            print(lock)
            tune.report(val=4, second=8)

        return train

    def testNotRaisesNonserializable(self):
        train = self._lock_capturing_trainable()

        # Without the debug flag, the TypeError message should point the
        # user at RAY_PICKLE_VERBOSE_DEBUG.
        with self.assertRaisesRegex(TypeError, "RAY_PICKLE_VERBOSE_DEBUG"):
            tune.run(train, metric="acc")

    def testRaisesNonserializable(self):
        os.environ["RAY_PICKLE_VERBOSE_DEBUG"] = "1"
        train = self._lock_capturing_trainable()

        # With the debug flag set, the original pickling error is re-raised,
        # so the message names the lock and omits the hint.
        with self.assertRaises(TypeError) as cm:
            tune.run(train, metric="acc")
        msg = cm.exception.args[0]
        # unittest assertions are not stripped under `python -O`.
        self.assertNotIn("RAY_PICKLE_VERBOSE_DEBUG", msg)
        self.assertIn("thread.lock", msg)
|
||||
|
||||
|
||||
class ShimCreationTest(unittest.TestCase):
|
||||
def testCreateScheduler(self):
|
||||
kwargs = {"metric": "metric_foo", "mode": "min"}
|
||||
|
||||
@@ -4,6 +4,7 @@ import types
|
||||
|
||||
from ray import cloudpickle as cloudpickle
|
||||
from ray.utils import binary_to_hex, hex_to_binary
|
||||
from ray.util.debug import log_once
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -15,7 +16,8 @@ class TuneFunctionEncoder(json.JSONEncoder):
|
||||
try:
|
||||
return super(TuneFunctionEncoder, self).default(obj)
|
||||
except Exception:
|
||||
logger.debug("Unable to encode. Falling back to cloudpickle.")
|
||||
if log_once(f"tune_func_encode:{str(obj)}"):
|
||||
logger.debug("Unable to encode. Falling back to cloudpickle.")
|
||||
return self._to_cloudpickle(obj)
|
||||
|
||||
def _to_cloudpickle(self, obj):
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from ray.util import iter
|
||||
from ray.util.actor_pool import ActorPool
|
||||
from ray.util.check_serialize import inspect_serializability
|
||||
from ray.util.debug import log_once, disable_log_once_globally, \
|
||||
enable_periodic_logging
|
||||
from ray.util.placement_group import (placement_group, placement_group_table,
|
||||
@@ -9,5 +10,5 @@ from ray.util import rpdb as pdb
|
||||
__all__ = [
|
||||
"ActorPool", "disable_log_once_globally", "enable_periodic_logging",
|
||||
"iter", "log_once", "pdb", "placement_group", "placement_group_table",
|
||||
"remove_placement_group"
|
||||
"remove_placement_group", "inspect_serializability"
|
||||
]
|
||||
|
||||
@@ -0,0 +1,216 @@
|
||||
"""A utility for debugging serialization issues."""
|
||||
from typing import Any, Tuple, Set, Optional
|
||||
import inspect
|
||||
import ray.cloudpickle as cp
|
||||
import colorama
|
||||
from contextlib import contextmanager
|
||||
|
||||
|
||||
@contextmanager
def _indent(printer):
    """Raise ``printer.level`` by one for the duration of a ``with`` block.

    Args:
        printer: Any object with an integer ``level`` attribute.

    The level is restored in a ``finally`` clause, so an exception raised
    inside the ``with`` body cannot leave the printer permanently indented.
    """
    printer.level += 1
    try:
        yield
    finally:
        printer.level -= 1
|
||||
|
||||
|
||||
class _Printer:
    """Prints messages to stdout, prefixed according to a nesting level."""

    def __init__(self):
        # Current nesting depth; print() emits one space per level.
        self.level = 0

    def indent(self):
        """Return the ``_indent`` context manager bound to this printer."""
        return _indent(self)

    def print(self, msg):
        """Print ``msg`` preceded by ``self.level`` spaces."""
        print(" " * self.level + msg)
|
||||
|
||||
|
||||
_printer = _Printer()
|
||||
|
||||
|
||||
class FailureTuple:
    """One entry in the set of objects found to be non-serializable.

    Attributes:
        obj: The object that fails serialization.
        name: The variable name of the object.
        parent: The object that references the `obj`.
    """

    def __init__(self, obj: Any, name: str, parent: Any):
        self.obj, self.name, self.parent = obj, name, parent

    def __repr__(self):
        return "FailTuple({} [obj={}, parent={}])".format(
            self.name, self.obj, self.parent)
|
||||
|
||||
|
||||
def _inspect_func_serialization(base_obj, depth, parent, failure_set):
    """Adds the first-found non-serializable element to the failure_set.

    Checks the variables captured by ``base_obj`` as reported by
    ``inspect.getclosurevars``: globals first, then nonlocals. The search
    stops at the first variable that ``inspect_serializability`` reports as
    non-serializable; in particular, nonlocals are not examined at all once
    a global has failed.

    Args:
        base_obj: The function whose captured variables are checked.
        depth: Remaining depth budget; each variable is inspected with
            ``depth - 1``.
        parent: Passed through as ``_parent`` for every variable inspected.
        failure_set: Set that the recursive calls add failures to.

    Returns:
        bool: True if a non-serializable captured variable was found.
    """
    assert inspect.isfunction(base_obj)
    closure = inspect.getclosurevars(base_obj)
    found = False
    scopes = (("global", closure.globals), ("nonlocal", closure.nonlocals))
    for kind, variables in scopes:
        if found:
            # A culprit from an earlier scope is already recorded.
            break
        if not variables:
            continue
        _printer.print(f"Detected {len(variables)} {kind} variables. "
                       "Checking serializability...")
        with _printer.indent():
            for name, obj in variables.items():
                serializable, _ = inspect_serializability(
                    obj,
                    name=name,
                    depth=depth - 1,
                    _parent=parent,
                    _failure_set=failure_set)
                if not serializable:
                    found = True
                    break
    if not found:
        _printer.print(
            f"WARNING: Did not find non-serializable object in {base_obj}. "
            "This may be an oversight.")
    return found
|
||||
|
||||
|
||||
def _inspect_generic_serialization(base_obj, depth, parent, failure_set):
    """Adds the first-found non-serializable element to the failure_set.

    First checks every member of ``base_obj`` that is a plain function, then
    every remaining member that is neither a dunder attribute nor a builtin.
    The search stops at the first member that ``inspect_serializability``
    reports as non-serializable; the second pass is skipped entirely when
    the first pass already found one.

    Args:
        base_obj: The non-function object whose members are checked.
        depth: Remaining depth budget; each member is inspected with
            ``depth - 1``.
        parent: Passed through as ``_parent`` for every member inspected.
        failure_set: Set that the recursive calls add failures to.

    Returns:
        bool: True if a non-serializable member was found.
    """
    assert not inspect.isfunction(base_obj)
    functions = inspect.getmembers(base_obj, predicate=inspect.isfunction)
    found = False
    with _printer.indent():
        for name, obj in functions:
            serializable, _ = inspect_serializability(
                obj,
                name=name,
                depth=depth - 1,
                _parent=parent,
                _failure_set=failure_set)
            if not serializable:
                found = True
                break

    if not found:
        with _printer.indent():
            for name, obj in inspect.getmembers(base_obj):
                is_dunder = name.startswith("__") and name.endswith("__")
                # Functions were all checked (and passed) in the first pass,
                # so serializing them again would only repeat the output.
                if (is_dunder or inspect.isbuiltin(obj)
                        or inspect.isfunction(obj)):
                    continue
                serializable, _ = inspect_serializability(
                    obj,
                    name=name,
                    depth=depth - 1,
                    _parent=parent,
                    _failure_set=failure_set)
                if not serializable:
                    found = True
                    break
    if not found:
        _printer.print(
            f"WARNING: Did not find non-serializable object in {base_obj}. "
            "This may be an oversight.")
    return found
|
||||
|
||||
|
||||
def inspect_serializability(
        base_obj: Any,
        name: Optional[str] = None,
        depth: int = 3,
        _parent: Optional[Any] = None,
        _failure_set: Optional[set] = None) -> Tuple[bool, Set[FailureTuple]]:
    """Identifies what objects are preventing serialization.

    Tries ``cp.dumps(base_obj)``. On failure, and while ``depth`` allows,
    recurses into the variables captured by a function or into the members
    of any other object to find what cannot be serialized. Progress and a
    final summary are printed to stdout.

    Args:
        base_obj: Object to be serialized.
        name: Optional display name for ``base_obj``. Defaults to
            ``str(base_obj)``.
        depth: Depth of the scope stack to walk through. Defaults to 3.
            With a depth of zero or less, ``base_obj`` itself is recorded
            as the failure.
        _parent: Internal. The object that references ``base_obj``.
        _failure_set: Internal. Set shared across recursive calls; a call
            that leaves it as None is treated as the top-level call.

    Returns:
        bool: True if serializable.
        set[FailureTuple]: Set of unserializable objects.

    .. versionadded:: 1.1.0

    """
    top_level = _failure_set is None
    declaration = ""
    if top_level:
        # Only the outermost call sets up colorama and prints the banner;
        # recursive calls always pass a failure set.
        colorama.init()
        _failure_set = set()
        declaration = f"Checking Serializability of {base_obj}"
        print("=" * min(len(declaration), 80))
        print(declaration)
        print("=" * min(len(declaration), 80))

    if name is None:
        name = str(base_obj)
    if not top_level:
        _printer.print(f"Serializing '{name}' {base_obj}...")

    try:
        cp.dumps(base_obj)
    except Exception as e:
        _printer.print(f"{colorama.Fore.RED}!!! FAIL{colorama.Fore.RESET} "
                       f"serialization: {e}")
    else:
        return True, _failure_set

    if depth > 0:
        # TODO: we only differentiate between 'function' and 'object'
        # but we should do a better job of diving into something
        # more specific like a Type, Object, etc.
        if inspect.isfunction(base_obj):
            _inspect_func_serialization(
                base_obj,
                depth=depth,
                parent=base_obj,
                failure_set=_failure_set)
        else:
            _inspect_generic_serialization(
                base_obj,
                depth=depth,
                parent=base_obj,
                failure_set=_failure_set)

    # Blame this object itself when the depth budget is exhausted (zero or
    # negative) or when the deeper search recorded nothing.
    if depth <= 0 or not _failure_set:
        try:
            _failure_set.add(FailureTuple(base_obj, name, _parent))
        except TypeError:
            # Guard kept from the original: skip the entry if it turns out
            # to be unhashable.
            pass

    if top_level:
        _print_failure_summary(declaration, _failure_set)
    return False, _failure_set


def _print_failure_summary(declaration, failure_set):
    """Print the closing banner listing every entry in ``failure_set``."""
    rule = "=" * min(len(declaration), 80)
    print(rule)
    if not failure_set:
        print("Nothing failed the inspect_serialization test, though "
              "serialization did not succeed.")
    else:
        fail_vars = f"\n\n\t{colorama.Style.BRIGHT}" + "\n".join(
            str(k)
            for k in failure_set) + f"{colorama.Style.RESET_ALL}\n\n"
        print(f"Variable: {fail_vars}was found to be non-serializable. "
              "There may be multiple other undetected variables that were "
              "non-serializable. ")
        print("Consider either removing the "
              "instantiation/imports of these variables or moving the "
              "instantiation into the scope of the function/class. ")
        print("If you have any suggestions on how to improve "
              "this error message, please reach out to the "
              "Ray developers on github.com/ray-project/ray/issues/")
    print(rule)
|
||||
Reference in New Issue
Block a user