[Java] Support direct call for normal tasks (#7193)

This commit is contained in:
Kai Yang
2020-02-21 10:03:34 +08:00
committed by GitHub
parent f27bb6eb47
commit 007333b960
21 changed files with 184 additions and 99 deletions
+46 -27
View File
@@ -10,6 +10,7 @@ from ray.utils import _random_string
from ray.gcs_utils import ErrorType
from ray.exceptions import (
PlasmaObjectNotAvailable,
RayTaskError,
RayActorError,
RayWorkerError,
UnreconstructableError,
@@ -52,9 +53,9 @@ class SerializedObject:
class Pickle5SerializedObject(SerializedObject):
def __init__(self, inband, writer, contained_object_ids):
super(Pickle5SerializedObject, self).__init__(
ray_constants.PICKLE5_BUFFER_METADATA, contained_object_ids)
def __init__(self, metadata, inband, writer, contained_object_ids):
super(Pickle5SerializedObject, self).__init__(metadata,
contained_object_ids)
self.inband = inband
self.writer = writer
# cached total bytes
@@ -226,24 +227,26 @@ class SerializationContext:
self._thread_local.object_ids.add(object_id)
def _deserialize_pickle5_data(self, data):
if not self.use_pickle:
raise ValueError("Receiving pickle5 serialized objects "
"while the serialization context is "
"using a custom raw backend.")
try:
in_band, buffers = unpack_pickle5_buffers(data)
if len(buffers) > 0:
obj = pickle.loads(in_band, buffers=buffers)
else:
obj = pickle.loads(in_band)
# cloudpickle does not provide error types
except pickle.pickle.PicklingError:
raise DeserializationError()
return obj
def _deserialize_object(self, data, metadata, object_id):
if metadata:
if metadata == ray_constants.PICKLE5_BUFFER_METADATA:
if not self.use_pickle:
raise ValueError("Receiving pickle5 serialized objects "
"while the serialization context is "
"using a custom raw backend.")
try:
in_band, buffers = unpack_pickle5_buffers(data)
if len(buffers) > 0:
obj = pickle.loads(in_band, buffers=buffers)
else:
obj = pickle.loads(in_band)
# cloudpickle does not provide error types
except pickle.pickle.PicklingError:
raise DeserializationError()
return obj
return self._deserialize_pickle5_data(data)
# Check if the object should be returned as raw bytes.
if metadata == ray_constants.RAW_BUFFER_METADATA:
if data is None:
@@ -252,7 +255,14 @@ class SerializationContext:
# Otherwise, return an exception object based on
# the error type.
error_type = int(metadata)
if error_type == ErrorType.Value("WORKER_DIED"):
# RayTaskError is serialized with pickle5 in the data field.
# TODO (kfstorm): exception serialization should be language
# independent.
if error_type == ErrorType.Value("TASK_EXECUTION_EXCEPTION"):
obj = self._deserialize_pickle5_data(data)
assert isinstance(obj, RayTaskError)
return obj
elif error_type == ErrorType.Value("WORKER_DIED"):
return RayWorkerError()
elif error_type == ErrorType.Value("ACTOR_DIED"):
return RayActorError()
@@ -326,15 +336,24 @@ class SerializationContext:
# use a special metadata to indicate it's raw binary. So
# that this object can also be read by Java.
return RawSerializedObject(value)
else:
# Only RayTaskError is possible to be serialized here. We don't
# need to deal with other exception types here.
if isinstance(value, RayTaskError):
metadata = str(ErrorType.Value(
"TASK_EXECUTION_EXCEPTION")).encode("ascii")
else:
metadata = ray_constants.PICKLE5_BUFFER_METADATA
assert self.worker.use_pickle
assert ray.cloudpickle.FAST_CLOUDPICKLE_USED
writer = Pickle5Writer()
# TODO(swang): Check that contained_object_ids is empty.
inband = pickle.dumps(
value, protocol=5, buffer_callback=writer.buffer_callback)
return Pickle5SerializedObject(
inband, writer, self.get_and_clear_contained_object_ids())
assert self.worker.use_pickle
assert ray.cloudpickle.FAST_CLOUDPICKLE_USED
writer = Pickle5Writer()
# TODO(swang): Check that contained_object_ids is empty.
inband = pickle.dumps(
value, protocol=5, buffer_callback=writer.buffer_callback)
return Pickle5SerializedObject(
metadata, inband, writer,
self.get_and_clear_contained_object_ids())
def register_custom_serializer(self,
cls,
+51
View File
@@ -5,6 +5,7 @@ import sys
import tempfile
import threading
import time
import uuid
import numpy as np
import pytest
@@ -82,6 +83,56 @@ def test_failed_task(ray_start_regular):
assert False
def test_get_throws_quickly_when_found_exception(ray_start_regular):
    """`ray.get` on a list must raise once one object has failed.

    Each scenario submits a failing actor call together with a call that
    blocks until a signal file exists. The signal file is only created
    after `ray.get` has raised, so the error has to surface while the
    blocking call is still pending.
    """

    def random_path():
        # A unique file name inside the system temp directory.
        unique_name = uuid.uuid4().hex
        return os.path.join(tempfile.gettempdir(), unique_name)

    def touch(path):
        # Creating the file is the signal that releases `wait_for_file`.
        with open(path, "w"):
            pass

    def wait_for_file(path):
        # Poll every 100 ms until the signal file exists.
        while not os.path.exists(path):
            time.sleep(0.1)

    # We use an actor instead of functions here. If we use functions, it's
    # very likely that two normal tasks are submitted before the first worker
    # is registered to Raylet. Since `maximum_startup_concurrency` is 1,
    # the worker pool will wait for the registration of the first worker
    # and skip starting new workers. The result is, the two tasks will be
    # executed sequentially, which breaks an assumption of this test case -
    # the two tasks run in parallel.
    @ray.remote
    class Actor(object):
        def bad_func1(self):
            # Fails with an ordinary Python exception.
            raise Exception("Test function intentionally failed.")

        def bad_func2(self):
            # Kills the actor process outright.
            os._exit(0)

        def slow_func(self, path):
            # Blocks until the test creates `path`.
            wait_for_file(path)

    def expect_exception(objects, exception):
        # Getting `objects` must raise exactly `exception`, not a subclass.
        with pytest.raises(ray.exceptions.RayError) as err:
            ray.get(objects)
        assert err.type is exception

    # Scenario 1: the failing call raises inside the actor method.
    f = random_path()
    actor = Actor.options(is_direct_call=True, max_concurrency=2).remote()
    expect_exception(
        [actor.bad_func1.remote(), actor.slow_func.remote(f)],
        ray.exceptions.RayTaskError)
    touch(f)

    # Scenario 2: the failing call terminates the actor process.
    f = random_path()
    actor = Actor.options(is_direct_call=True, max_concurrency=2).remote()
    expect_exception(
        [actor.bad_func2.remote(), actor.slow_func.remote(f)],
        ray.exceptions.RayActorError)
    touch(f)
def test_fail_importing_remote_function(ray_start_2_cpus):
# Create the contents of a temporary Python file.
temporary_python_file = """