Cross language exception (#10023)

This commit is contained in:
fyrestone
2020-08-26 10:46:05 +08:00
committed by GitHub
parent 1e99b814f0
commit 08adbb371f
30 changed files with 441 additions and 137 deletions
+3
View File
@@ -222,6 +222,9 @@ cdef class Language:
cdef from_native(const CLanguage& lang):
return Language(<int32_t>lang)
def value(self):
return <int32_t>self.lang
def __eq__(self, other):
return (isinstance(other, Language) and
(<int32_t>self.lang) == (<int32_t>(<Language>other).lang))
+31 -1
View File
@@ -1,14 +1,44 @@
import os
from traceback import format_exception
import colorama
import ray
import ray.cloudpickle as pickle
from ray.core.generated.common_pb2 import RayException, Language
import setproctitle
class RayError(Exception):
"""Super class of all ray exception types."""
pass
def to_bytes(self):
# Extract exc_info from exception object.
exc_info = (type(self), self, self.__traceback__)
formatted_exception_string = "\n".join(format_exception(*exc_info))
return RayException(
language=ray.Language.PYTHON.value(),
serialized_exception=pickle.dumps(self),
formatted_exception_string=formatted_exception_string
).SerializeToString()
@staticmethod
def from_bytes(b):
ray_exception = RayException()
ray_exception.ParseFromString(b)
if ray_exception.language == ray.Language.PYTHON.value():
return pickle.loads(ray_exception.serialized_exception)
else:
return CrossLanguageError(ray_exception)
class CrossLanguageError(RayError):
"""Raised from another language."""
def __init__(self, ray_exception):
super().__init__("An exception raised from {}:\n{}".format(
Language.Name(ray_exception.language),
ray_exception.formatted_exception_string))
class RayConnectionError(RayError):
+44 -13
View File
@@ -1,19 +1,50 @@
from ray.core.generated.common_pb2 import ErrorType
from ray.core.generated.gcs_pb2 import (
ActorCheckpointIdData, ActorTableData, GcsNodeInfo, JobTableData,
JobConfig, ErrorTableData, ErrorType, GcsEntry, HeartbeatBatchTableData,
HeartbeatTableData, ObjectTableData, ProfileTableData, TablePrefix,
TablePubsub, TaskTableData, ResourceMap, ResourceTableData,
ObjectLocationInfo, PubSubMessage, WorkerTableData,
PlacementGroupTableData)
ActorCheckpointIdData,
ActorTableData,
GcsNodeInfo,
JobTableData,
JobConfig,
ErrorTableData,
GcsEntry,
HeartbeatBatchTableData,
HeartbeatTableData,
ObjectTableData,
ProfileTableData,
TablePrefix,
TablePubsub,
TaskTableData,
ResourceMap,
ResourceTableData,
ObjectLocationInfo,
PubSubMessage,
WorkerTableData,
PlacementGroupTableData,
)
__all__ = [
"ActorCheckpointIdData", "ActorTableData", "GcsNodeInfo", "JobTableData",
"JobConfig", "ErrorTableData", "ErrorType", "GcsEntry",
"HeartbeatBatchTableData", "HeartbeatTableData", "ObjectTableData",
"ProfileTableData", "TablePrefix", "TablePubsub", "TaskTableData",
"ResourceMap", "ResourceTableData", "construct_error_message",
"ObjectLocationInfo", "PubSubMessage", "WorkerTableData",
"PlacementGroupTableData"
"ActorCheckpointIdData",
"ActorTableData",
"GcsNodeInfo",
"JobTableData",
"JobConfig",
"ErrorTableData",
"ErrorType",
"GcsEntry",
"HeartbeatBatchTableData",
"HeartbeatTableData",
"ObjectTableData",
"ProfileTableData",
"TablePrefix",
"TablePubsub",
"TaskTableData",
"ResourceMap",
"ResourceTableData",
"construct_error_message",
"ObjectLocationInfo",
"PubSubMessage",
"WorkerTableData",
"PlacementGroupTableData",
]
FUNCTION_PREFIX = "RemoteFunction:"
+17 -16
View File
@@ -9,6 +9,7 @@ import ray.utils
from ray.utils import _random_string
from ray.gcs_utils import ErrorType
from ray.exceptions import (
RayError,
PlasmaObjectNotAvailable,
RayTaskError,
RayActorError,
@@ -221,10 +222,10 @@ class SerializationContext:
def _deserialize_msgpack_data(self, data, metadata):
msgpack_data, pickle5_data = split_buffer(data)
if metadata == ray_constants.OBJECT_METADATA_TYPE_CROSS_LANGUAGE:
python_objects = []
else:
if metadata == ray_constants.OBJECT_METADATA_TYPE_PYTHON:
python_objects = self._deserialize_pickle5_data(pickle5_data)
else:
python_objects = []
try:
@@ -262,8 +263,7 @@ class SerializationContext:
# independent.
if error_type == ErrorType.Value("TASK_EXECUTION_EXCEPTION"):
obj = self._deserialize_msgpack_data(data, metadata)
assert isinstance(obj, RayTaskError)
return obj
return RayError.from_bytes(obj)
elif error_type == ErrorType.Value("WORKER_DIED"):
return RayWorkerError()
elif error_type == ErrorType.Value("ACTOR_DIED"):
@@ -347,7 +347,16 @@ class SerializationContext:
metadata, inband, writer,
self.get_and_clear_contained_object_refs())
def _serialize_to_msgpack(self, metadata, value):
def _serialize_to_msgpack(self, value):
# Only RayTaskError is possible to be serialized here. We don't
# need to deal with other exception types here.
if isinstance(value, RayTaskError):
metadata = str(
ErrorType.Value("TASK_EXECUTION_EXCEPTION")).encode("ascii")
value = value.to_bytes()
else:
metadata = ray_constants.OBJECT_METADATA_TYPE_CROSS_LANGUAGE
python_objects = []
def _python_serializer(o):
@@ -358,10 +367,10 @@ class SerializationContext:
msgpack_data = MessagePackSerializer.dumps(value, _python_serializer)
if python_objects:
metadata = ray_constants.OBJECT_METADATA_TYPE_PYTHON
pickle5_serialized_object = \
self._serialize_to_pickle5(metadata, python_objects)
else:
metadata = ray_constants.OBJECT_METADATA_TYPE_CROSS_LANGUAGE
pickle5_serialized_object = None
return MessagePackSerializedObject(metadata, msgpack_data,
@@ -379,15 +388,7 @@ class SerializationContext:
# that this object can also be read by Java.
return RawSerializedObject(value)
else:
# Only RayTaskError is possible to be serialized here. We don't
# need to deal with other exception types here.
if isinstance(value, RayTaskError):
metadata = str(ErrorType.Value(
"TASK_EXECUTION_EXCEPTION")).encode("ascii")
else:
metadata = ray_constants.OBJECT_METADATA_TYPE_PYTHON
return self._serialize_to_msgpack(metadata, value)
return self._serialize_to_msgpack(value)
def register_custom_serializer(self,
cls,
+17
View File
@@ -12,6 +12,7 @@ import redis
import ray
import ray.ray_constants as ray_constants
from ray.exceptions import RayTaskError
from ray.cluster_utils import Cluster
from ray.test_utils import (
wait_for_condition,
@@ -452,6 +453,22 @@ def test_actor_scope_or_intentionally_killed_message(ray_start_regular,
errors)
def test_exception_chain(ray_start_regular):
@ray.remote
def bar():
return 1 / 0
@ray.remote
def foo():
return ray.get(bar.remote())
r = foo.remote()
try:
ray.get(r)
except ZeroDivisionError as ex:
assert isinstance(ex, RayTaskError)
@pytest.mark.skip("This test does not work yet.")
@pytest.mark.parametrize(
"ray_start_object_store_memory", [10**6], indirect=True)