mirror of
https://github.com/wassname/ray.git
synced 2026-07-03 14:19:24 +08:00
[Java] Support direct call for normal tasks (#7193)
This commit is contained in:
+46
-27
@@ -10,6 +10,7 @@ from ray.utils import _random_string
|
||||
from ray.gcs_utils import ErrorType
|
||||
from ray.exceptions import (
|
||||
PlasmaObjectNotAvailable,
|
||||
RayTaskError,
|
||||
RayActorError,
|
||||
RayWorkerError,
|
||||
UnreconstructableError,
|
||||
@@ -52,9 +53,9 @@ class SerializedObject:
|
||||
|
||||
|
||||
class Pickle5SerializedObject(SerializedObject):
|
||||
def __init__(self, inband, writer, contained_object_ids):
|
||||
super(Pickle5SerializedObject, self).__init__(
|
||||
ray_constants.PICKLE5_BUFFER_METADATA, contained_object_ids)
|
||||
def __init__(self, metadata, inband, writer, contained_object_ids):
|
||||
super(Pickle5SerializedObject, self).__init__(metadata,
|
||||
contained_object_ids)
|
||||
self.inband = inband
|
||||
self.writer = writer
|
||||
# cached total bytes
|
||||
@@ -226,24 +227,26 @@ class SerializationContext:
|
||||
|
||||
self._thread_local.object_ids.add(object_id)
|
||||
|
||||
def _deserialize_pickle5_data(self, data):
    """Deserialize an object that was serialized with pickle protocol 5.

    Args:
        data: The serialized payload. It is split by
            ``unpack_pickle5_buffers`` into the in-band pickle stream and
            the out-of-band buffers.

    Returns:
        The deserialized Python object.

    Raises:
        ValueError: If this serialization context is not using pickle
            (``self.use_pickle`` is falsy).
        DeserializationError: If the pickle machinery reports a pickle
            error while loading the payload.
    """
    if not self.use_pickle:
        raise ValueError("Receiving pickle5 serialized objects "
                         "while the serialization context is "
                         "using a custom raw backend.")
    try:
        in_band, buffers = unpack_pickle5_buffers(data)
        if len(buffers) > 0:
            # Out-of-band buffers must be handed back to the unpickler.
            obj = pickle.loads(in_band, buffers=buffers)
        else:
            obj = pickle.loads(in_band)
    # cloudpickle does not provide error types, so the ones from the
    # underlying pickle module are used. PickleError is the common base of
    # PicklingError and UnpicklingError; catching only PicklingError would
    # miss the errors that loading actually raises.
    except pickle.pickle.PickleError as err:
        raise DeserializationError() from err
    return obj
|
||||
|
||||
def _deserialize_object(self, data, metadata, object_id):
|
||||
if metadata:
|
||||
if metadata == ray_constants.PICKLE5_BUFFER_METADATA:
|
||||
if not self.use_pickle:
|
||||
raise ValueError("Receiving pickle5 serialized objects "
|
||||
"while the serialization context is "
|
||||
"using a custom raw backend.")
|
||||
try:
|
||||
in_band, buffers = unpack_pickle5_buffers(data)
|
||||
if len(buffers) > 0:
|
||||
obj = pickle.loads(in_band, buffers=buffers)
|
||||
else:
|
||||
obj = pickle.loads(in_band)
|
||||
# cloudpickle does not provide error types
|
||||
except pickle.pickle.PicklingError:
|
||||
raise DeserializationError()
|
||||
return obj
|
||||
|
||||
return self._deserialize_pickle5_data(data)
|
||||
# Check if the object should be returned as raw bytes.
|
||||
if metadata == ray_constants.RAW_BUFFER_METADATA:
|
||||
if data is None:
|
||||
@@ -252,7 +255,14 @@ class SerializationContext:
|
||||
# Otherwise, return an exception object based on
|
||||
# the error type.
|
||||
error_type = int(metadata)
|
||||
if error_type == ErrorType.Value("WORKER_DIED"):
|
||||
# RayTaskError is serialized with pickle5 in the data field.
|
||||
# TODO (kfstorm): exception serialization should be language
|
||||
# independent.
|
||||
if error_type == ErrorType.Value("TASK_EXECUTION_EXCEPTION"):
|
||||
obj = self._deserialize_pickle5_data(data)
|
||||
assert isinstance(obj, RayTaskError)
|
||||
return obj
|
||||
elif error_type == ErrorType.Value("WORKER_DIED"):
|
||||
return RayWorkerError()
|
||||
elif error_type == ErrorType.Value("ACTOR_DIED"):
|
||||
return RayActorError()
|
||||
@@ -326,15 +336,24 @@ class SerializationContext:
|
||||
# use a special metadata to indicate it's raw binary. So
|
||||
# that this object can also be read by Java.
|
||||
return RawSerializedObject(value)
|
||||
else:
|
||||
# Only RayTaskError is possible to be serialized here. We don't
|
||||
# need to deal with other exception types here.
|
||||
if isinstance(value, RayTaskError):
|
||||
metadata = str(ErrorType.Value(
|
||||
"TASK_EXECUTION_EXCEPTION")).encode("ascii")
|
||||
else:
|
||||
metadata = ray_constants.PICKLE5_BUFFER_METADATA
|
||||
|
||||
assert self.worker.use_pickle
|
||||
assert ray.cloudpickle.FAST_CLOUDPICKLE_USED
|
||||
writer = Pickle5Writer()
|
||||
# TODO(swang): Check that contained_object_ids is empty.
|
||||
inband = pickle.dumps(
|
||||
value, protocol=5, buffer_callback=writer.buffer_callback)
|
||||
return Pickle5SerializedObject(
|
||||
inband, writer, self.get_and_clear_contained_object_ids())
|
||||
assert self.worker.use_pickle
|
||||
assert ray.cloudpickle.FAST_CLOUDPICKLE_USED
|
||||
writer = Pickle5Writer()
|
||||
# TODO(swang): Check that contained_object_ids is empty.
|
||||
inband = pickle.dumps(
|
||||
value, protocol=5, buffer_callback=writer.buffer_callback)
|
||||
return Pickle5SerializedObject(
|
||||
metadata, inband, writer,
|
||||
self.get_and_clear_contained_object_ids())
|
||||
|
||||
def register_custom_serializer(self,
|
||||
cls,
|
||||
|
||||
@@ -5,6 +5,7 @@ import sys
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
@@ -82,6 +83,56 @@ def test_failed_task(ray_start_regular):
|
||||
assert False
|
||||
|
||||
|
||||
def test_get_throws_quickly_when_found_exception(ray_start_regular):
    """Check that ``ray.get`` raises while another requested task is blocked.

    Each scenario submits one actor call that fails and one call
    (``slow_func``) that blocks until a marker file exists. The marker file
    is only created after ``ray.get`` has raised, so the error must have
    been surfaced without waiting for the blocked call to finish.
    """

    def random_path():
        # A unique path inside the system temp directory; not created here.
        return os.path.join(tempfile.gettempdir(), uuid.uuid4().hex)

    def touch(path):
        # Create (or truncate) the file without writing anything to it.
        with open(path, "w"):
            pass

    def wait_for_file(path):
        # Poll until the file shows up.
        while not os.path.exists(path):
            time.sleep(0.1)

    # We use an actor instead of functions here. If we use functions, it's
    # very likely that two normal tasks are submitted before the first worker
    # is registered to Raylet. Since `maximum_startup_concurrency` is 1,
    # the worker pool will wait for the registration of the first worker
    # and skip starting new workers. The result is, the two tasks will be
    # executed sequentially, which breaks an assumption of this test case -
    # the two tasks run in parallel.
    @ray.remote
    class Actor(object):
        def bad_func1(self):
            raise Exception("Test function intentionally failed.")

        def bad_func2(self):
            # Terminates the actor process without raising.
            os._exit(0)

        def slow_func(self, path):
            wait_for_file(path)

    def expect_exception(objects, exception):
        with pytest.raises(ray.exceptions.RayError) as err:
            ray.get(objects)
        # Checked outside the `with` block: statements after the raising
        # call inside the block would never run.
        assert err.type is exception

    def check_scenario(bad_method_name, exception):
        marker = random_path()
        actor = Actor.options(is_direct_call=True, max_concurrency=2).remote()
        failing_call = getattr(actor, bad_method_name).remote()
        blocked_call = actor.slow_func.remote(marker)
        try:
            expect_exception([failing_call, blocked_call], exception)
        finally:
            # Always release `slow_func`, even when the expectation fails,
            # so the actor is not left polling forever.
            touch(marker)

    check_scenario("bad_func1", ray.exceptions.RayTaskError)
    check_scenario("bad_func2", ray.exceptions.RayActorError)
|
||||
|
||||
|
||||
def test_fail_importing_remote_function(ray_start_2_cpus):
|
||||
# Create the contents of a temporary Python file.
|
||||
temporary_python_file = """
|
||||
|
||||
Reference in New Issue
Block a user