diff --git a/java/api/src/main/java/org/ray/api/options/ActorCreationOptions.java b/java/api/src/main/java/org/ray/api/options/ActorCreationOptions.java index 5577a35d4..56901fe36 100644 --- a/java/api/src/main/java/org/ray/api/options/ActorCreationOptions.java +++ b/java/api/src/main/java/org/ray/api/options/ActorCreationOptions.java @@ -10,24 +10,17 @@ public class ActorCreationOptions extends BaseTaskOptions { public static final int NO_RECONSTRUCTION = 0; public static final int INFINITE_RECONSTRUCTION = (int) Math.pow(2, 30); - // DO NOT set this environment variable. It's only used for test purposes. - // Please use `setUseDirectCall` instead. - public static final boolean DEFAULT_USE_DIRECT_CALL = "1" - .equals(System.getenv("ACTOR_CREATION_OPTIONS_DEFAULT_USE_DIRECT_CALL")); public final int maxReconstructions; - public final boolean useDirectCall; - public final String jvmOptions; public final int maxConcurrency; private ActorCreationOptions(Map resources, int maxReconstructions, boolean useDirectCall, String jvmOptions, int maxConcurrency) { - super(resources); + super(resources, useDirectCall); this.maxReconstructions = maxReconstructions; - this.useDirectCall = useDirectCall; this.jvmOptions = jvmOptions; this.maxConcurrency = maxConcurrency; } diff --git a/java/api/src/main/java/org/ray/api/options/BaseTaskOptions.java b/java/api/src/main/java/org/ray/api/options/BaseTaskOptions.java index 501ad151e..054e44391 100644 --- a/java/api/src/main/java/org/ray/api/options/BaseTaskOptions.java +++ b/java/api/src/main/java/org/ray/api/options/BaseTaskOptions.java @@ -7,13 +7,21 @@ import java.util.Map; * The options class for RayCall or ActorCreation. */ public abstract class BaseTaskOptions { + // DO NOT set this environment variable. It's only used for test purposes. + // Please use `setUseDirectCall` instead. + public static final boolean DEFAULT_USE_DIRECT_CALL = "1" + .equals(System.getenv("DEFAULT_USE_DIRECT_CALL")); + public final Map resources; + public final boolean useDirectCall; + public BaseTaskOptions() { resources = new HashMap<>(); + useDirectCall = DEFAULT_USE_DIRECT_CALL; } - public BaseTaskOptions(Map resources) { + public BaseTaskOptions(Map resources, boolean useDirectCall) { for (Map.Entry entry : resources.entrySet()) { if (entry.getValue().compareTo(0.0) <= 0) { throw new IllegalArgumentException(String.format("Resource capacity should be " + @@ -21,6 +29,7 @@ public abstract class BaseTaskOptions { } } this.resources = resources; + this.useDirectCall = useDirectCall; } } diff --git a/java/api/src/main/java/org/ray/api/options/CallOptions.java b/java/api/src/main/java/org/ray/api/options/CallOptions.java index 1e5b61bf1..b4be033fd 100644 --- a/java/api/src/main/java/org/ray/api/options/CallOptions.java +++ b/java/api/src/main/java/org/ray/api/options/CallOptions.java @@ -8,8 +8,8 @@ import java.util.Map; */ public class CallOptions extends BaseTaskOptions { - private CallOptions(Map resources) { - super(resources); + private CallOptions(Map resources, boolean useDirectCall) { + super(resources, useDirectCall); } /** @@ -18,14 +18,23 @@ public class CallOptions extends BaseTaskOptions { public static class Builder { private Map resources = new HashMap<>(); + private boolean useDirectCall = DEFAULT_USE_DIRECT_CALL; public Builder setResources(Map resources) { this.resources = resources; return this; } + // Since direct call is not fully supported yet (see issue #5559), + // users are not allowed to set the option to true. + // TODO (kfstorm): uncomment when direct call is ready. + // public Builder setUseDirectCall(boolean useDirectCall) { + // this.useDirectCall = useDirectCall; + // return this; + // } + public CallOptions createCallOptions() { - return new CallOptions(resources); + return new CallOptions(resources, useDirectCall); } } } diff --git a/java/test.sh b/java/test.sh index 59080fbb3..d1f259a64 100755 --- a/java/test.sh +++ b/java/test.sh @@ -34,7 +34,7 @@ echo "Running tests under cluster mode." ENABLE_MULTI_LANGUAGE_TESTS=1 run_testng java -cp $ROOT_DIR/../bazel-bin/java/all_tests_deploy.jar org.testng.TestNG -d /tmp/ray_java_test_output $ROOT_DIR/testng.xml echo "Running tests under cluster mode with direct actor call turned on." -ENABLE_MULTI_LANGUAGE_TESTS=1 ACTOR_CREATION_OPTIONS_DEFAULT_USE_DIRECT_CALL=1 run_testng java -cp $ROOT_DIR/../bazel-bin/java/all_tests_deploy.jar org.testng.TestNG -d /tmp/ray_java_test_output $ROOT_DIR/testng.xml +ENABLE_MULTI_LANGUAGE_TESTS=1 DEFAULT_USE_DIRECT_CALL=1 run_testng java -cp $ROOT_DIR/../bazel-bin/java/all_tests_deploy.jar org.testng.TestNG -d /tmp/ray_java_test_output $ROOT_DIR/testng.xml echo "Running tests under single-process mode." # bazel test //java:all_tests --jvmopt="-Dray.run-mode=SINGLE_PROCESS" --test_output="errors" || single_exit_code=$? diff --git a/java/test/src/main/java/org/ray/api/RayAlterSuiteListener.java b/java/test/src/main/java/org/ray/api/RayAlterSuiteListener.java deleted file mode 100644 index d5d042c1d..000000000 --- a/java/test/src/main/java/org/ray/api/RayAlterSuiteListener.java +++ /dev/null @@ -1,23 +0,0 @@ -package org.ray.api; - -import java.util.List; -import org.ray.api.options.ActorCreationOptions; -import org.testng.IAlterSuiteListener; -import org.testng.xml.XmlGroups; -import org.testng.xml.XmlRun; -import org.testng.xml.XmlSuite; - -public class RayAlterSuiteListener implements IAlterSuiteListener { - - @Override - public void alter(List suites) { - XmlSuite suite = suites.get(0); - if (ActorCreationOptions.DEFAULT_USE_DIRECT_CALL) { - XmlGroups groups = new XmlGroups(); - XmlRun run = new XmlRun(); - run.onInclude("directCall"); - groups.setRun(run); - suite.setGroups(groups); - } - } -} diff --git a/java/test/src/main/java/org/ray/api/test/ActorConcurrentCallTest.java b/java/test/src/main/java/org/ray/api/test/ActorConcurrentCallTest.java index 850bae9dd..6e0853890 100644 --- a/java/test/src/main/java/org/ray/api/test/ActorConcurrentCallTest.java +++ b/java/test/src/main/java/org/ray/api/test/ActorConcurrentCallTest.java @@ -13,7 +13,7 @@ import org.testng.Assert; import org.testng.annotations.Test; -@Test(groups = {"directCall"}) +@Test public class ActorConcurrentCallTest extends BaseTest { @RayRemote diff --git a/java/test/src/main/java/org/ray/api/test/ActorReconstructionTest.java b/java/test/src/main/java/org/ray/api/test/ActorReconstructionTest.java index 43ccfe0ff..2fed0a1ab 100644 --- a/java/test/src/main/java/org/ray/api/test/ActorReconstructionTest.java +++ b/java/test/src/main/java/org/ray/api/test/ActorReconstructionTest.java @@ -17,7 +17,7 @@ import org.ray.api.options.ActorCreationOptions; import org.testng.Assert; import org.testng.annotations.Test; -@Test(groups = {"directCall"}) +@Test public class ActorReconstructionTest extends BaseTest { @RayRemote() diff --git a/java/test/src/main/java/org/ray/api/test/ActorTest.java b/java/test/src/main/java/org/ray/api/test/ActorTest.java index 3bf3c79f2..3688a7119 100644 --- a/java/test/src/main/java/org/ray/api/test/ActorTest.java +++ b/java/test/src/main/java/org/ray/api/test/ActorTest.java @@ -16,7 +16,7 @@ import org.ray.api.id.UniqueId; import org.testng.Assert; import org.testng.annotations.Test; -@Test(groups = {"directCall"}) +@Test public class ActorTest extends BaseTest { @RayRemote diff --git a/java/test/src/main/java/org/ray/api/test/CrossLanguageInvocationTest.java b/java/test/src/main/java/org/ray/api/test/CrossLanguageInvocationTest.java index 5b22b9951..5a765a15f 100644 --- a/java/test/src/main/java/org/ray/api/test/CrossLanguageInvocationTest.java +++ b/java/test/src/main/java/org/ray/api/test/CrossLanguageInvocationTest.java @@ -60,7 +60,7 @@ public class CrossLanguageInvocationTest extends BaseMultiLanguageTest { Assert.assertEquals(res.get(), "[Python]py_func -> [Java]bytesEcho -> hello".getBytes()); } - @Test(groups = {"directCall"}) + @Test public void testCallingPythonActor() { // Python worker doesn't support direct call yet. TestUtils.skipTestIfDirectActorCallEnabled(); diff --git a/java/test/src/main/java/org/ray/api/test/FailureTest.java b/java/test/src/main/java/org/ray/api/test/FailureTest.java index a7091c2b2..dcba4656f 100644 --- a/java/test/src/main/java/org/ray/api/test/FailureTest.java +++ b/java/test/src/main/java/org/ray/api/test/FailureTest.java @@ -91,14 +91,14 @@ public class FailureTest extends BaseTest { assertTaskFailedWithRayTaskException(Ray.call(FailureTest::badFunc)); } - @Test(groups = {"directCall"}) + @Test public void testActorCreationFailure() { TestUtils.skipTestUnderSingleProcess(); RayActor actor = Ray.createActor(BadActor::new, true); assertTaskFailedWithRayTaskException(Ray.call(BadActor::badMethod, actor)); } - @Test(groups = {"directCall"}) + @Test public void testActorTaskFailure() { TestUtils.skipTestUnderSingleProcess(); RayActor actor = Ray.createActor(BadActor::new, false); @@ -117,7 +117,7 @@ public class FailureTest extends BaseTest { } } - @Test(groups = {"directCall"}) + @Test public void testActorProcessDying() { TestUtils.skipTestUnderSingleProcess(); // This test case hangs if the worker to worker connection is implemented with grpc. diff --git a/java/test/src/main/java/org/ray/api/test/KillActorTest.java b/java/test/src/main/java/org/ray/api/test/KillActorTest.java index 9cf7b4eec..4f584f397 100644 --- a/java/test/src/main/java/org/ray/api/test/KillActorTest.java +++ b/java/test/src/main/java/org/ray/api/test/KillActorTest.java @@ -10,7 +10,7 @@ import org.ray.api.exception.RayActorException; import org.testng.Assert; import org.testng.annotations.Test; -@Test(groups = { "directCall" }) +@Test public class KillActorTest extends BaseTest { @RayRemote diff --git a/java/test/src/main/java/org/ray/api/test/MultiThreadingTest.java b/java/test/src/main/java/org/ray/api/test/MultiThreadingTest.java index 1ba045f51..b50bf5306 100644 --- a/java/test/src/main/java/org/ray/api/test/MultiThreadingTest.java +++ b/java/test/src/main/java/org/ray/api/test/MultiThreadingTest.java @@ -21,7 +21,7 @@ import org.slf4j.LoggerFactory; import org.testng.Assert; import org.testng.annotations.Test; -@Test(groups = {"directCall"}) +@Test public class MultiThreadingTest extends BaseTest { private static final Logger LOGGER = LoggerFactory.getLogger(MultiThreadingTest.class); diff --git a/java/test/src/main/java/org/ray/api/test/StressTest.java b/java/test/src/main/java/org/ray/api/test/StressTest.java index fad127cd8..e2d150bc7 100644 --- a/java/test/src/main/java/org/ray/api/test/StressTest.java +++ b/java/test/src/main/java/org/ray/api/test/StressTest.java @@ -74,7 +74,7 @@ public class StressTest extends BaseTest { } } - @Test(enabled = false, groups = {"directCall"}) + @Test(enabled = false) public void testSubmittingManyTasksToOneActor() throws Exception { TestUtils.skipTestUnderSingleProcess(); RayActor actor = Ray.createActor(Actor::new); diff --git a/java/testng.xml b/java/testng.xml index 9d788abc6..f4659af98 100644 --- a/java/testng.xml +++ b/java/testng.xml @@ -8,7 +8,6 @@ - diff --git a/python/ray/serialization.py b/python/ray/serialization.py index 686474cd0..c7a352366 100644 --- a/python/ray/serialization.py +++ b/python/ray/serialization.py @@ -10,6 +10,7 @@ from ray.utils import _random_string from ray.gcs_utils import ErrorType from ray.exceptions import ( PlasmaObjectNotAvailable, + RayTaskError, RayActorError, RayWorkerError, UnreconstructableError, @@ -52,9 +53,9 @@ class SerializedObject: class Pickle5SerializedObject(SerializedObject): - def __init__(self, inband, writer, contained_object_ids): - super(Pickle5SerializedObject, self).__init__( - ray_constants.PICKLE5_BUFFER_METADATA, contained_object_ids) + def __init__(self, metadata, inband, writer, contained_object_ids): + super(Pickle5SerializedObject, self).__init__(metadata, + contained_object_ids) self.inband = inband self.writer = writer # cached total bytes @@ -226,24 +227,26 @@ class SerializationContext: self._thread_local.object_ids.add(object_id) + def _deserialize_pickle5_data(self, data): + if not self.use_pickle: + raise ValueError("Receiving pickle5 serialized objects " + "while the serialization context is " + "using a custom raw backend.") + try: + in_band, buffers = unpack_pickle5_buffers(data) + if len(buffers) > 0: + obj = pickle.loads(in_band, buffers=buffers) + else: + obj = pickle.loads(in_band) + # cloudpickle does not provide error types + except pickle.pickle.PicklingError: + raise DeserializationError() + return obj + def _deserialize_object(self, data, metadata, object_id): if metadata: if metadata == ray_constants.PICKLE5_BUFFER_METADATA: - if not self.use_pickle: - raise ValueError("Receiving pickle5 serialized objects " - "while the serialization context is " - "using a custom raw backend.") - try: - in_band, buffers = unpack_pickle5_buffers(data) - if len(buffers) > 0: - obj = pickle.loads(in_band, buffers=buffers) - else: - obj = pickle.loads(in_band) - # cloudpickle does not provide error types - except pickle.pickle.PicklingError: - raise DeserializationError() - return obj - + return self._deserialize_pickle5_data(data) # Check if the object should be returned as raw bytes. if metadata == ray_constants.RAW_BUFFER_METADATA: if data is None: @@ -252,7 +255,14 @@ class SerializationContext: # Otherwise, return an exception object based on # the error type. error_type = int(metadata) - if error_type == ErrorType.Value("WORKER_DIED"): + # RayTaskError is serialized with pickle5 in the data field. + # TODO (kfstorm): exception serialization should be language + # independent. + if error_type == ErrorType.Value("TASK_EXECUTION_EXCEPTION"): + obj = self._deserialize_pickle5_data(data) + assert isinstance(obj, RayTaskError) + return obj + elif error_type == ErrorType.Value("WORKER_DIED"): return RayWorkerError() elif error_type == ErrorType.Value("ACTOR_DIED"): return RayActorError() @@ -326,15 +336,24 @@ class SerializationContext: # use a special metadata to indicate it's raw binary. So # that this object can also be read by Java. return RawSerializedObject(value) + else: + # Only RayTaskError is possible to be serialized here. We don't + # need to deal with other exception types here. + if isinstance(value, RayTaskError): + metadata = str(ErrorType.Value( + "TASK_EXECUTION_EXCEPTION")).encode("ascii") + else: + metadata = ray_constants.PICKLE5_BUFFER_METADATA - assert self.worker.use_pickle - assert ray.cloudpickle.FAST_CLOUDPICKLE_USED - writer = Pickle5Writer() - # TODO(swang): Check that contained_object_ids is empty. - inband = pickle.dumps( - value, protocol=5, buffer_callback=writer.buffer_callback) - return Pickle5SerializedObject( - inband, writer, self.get_and_clear_contained_object_ids()) + assert self.worker.use_pickle + assert ray.cloudpickle.FAST_CLOUDPICKLE_USED + writer = Pickle5Writer() + # TODO(swang): Check that contained_object_ids is empty. + inband = pickle.dumps( + value, protocol=5, buffer_callback=writer.buffer_callback) + return Pickle5SerializedObject( + metadata, inband, writer, + self.get_and_clear_contained_object_ids()) def register_custom_serializer(self, cls, diff --git a/python/ray/tests/test_failure.py b/python/ray/tests/test_failure.py index ced05d1be..84d86738b 100644 --- a/python/ray/tests/test_failure.py +++ b/python/ray/tests/test_failure.py @@ -5,6 +5,7 @@ import sys import tempfile import threading import time +import uuid import numpy as np import pytest @@ -82,6 +83,56 @@ def test_failed_task(ray_start_regular): assert False +def test_get_throws_quickly_when_found_exception(ray_start_regular): + def random_path(): + return os.path.join(tempfile.gettempdir(), uuid.uuid4().hex) + + def touch(path): + with open(path, "w"): + pass + + def wait_for_file(path): + while True: + if os.path.exists(path): + break + time.sleep(0.1) + + # We use an actor instead of functions here. If we use functions, it's + # very likely that two normal tasks are submitted before the first worker + # is registered to Raylet. Since `maximum_startup_concurrency` is 1, + # the worker pool will wait for the registration of the first worker + # and skip starting new workers. The result is, the two tasks will be + # executed sequentially, which breaks an assumption of this test case - + # the two tasks run in parallel. + @ray.remote + class Actor(object): + def bad_func1(self): + raise Exception("Test function intentionally failed.") + + def bad_func2(self): + os._exit(0) + + def slow_func(self, path): + wait_for_file(path) + + def expect_exception(objects, exception): + with pytest.raises(ray.exceptions.RayError) as err: + ray.get(objects) + assert err.type is exception + + f = random_path() + actor = Actor.options(is_direct_call=True, max_concurrency=2).remote() + expect_exception([actor.bad_func1.remote(), + actor.slow_func.remote(f)], ray.exceptions.RayTaskError) + touch(f) + + f = random_path() + actor = Actor.options(is_direct_call=True, max_concurrency=2).remote() + expect_exception([actor.bad_func2.remote(), + actor.slow_func.remote(f)], ray.exceptions.RayActorError) + touch(f) + + def test_fail_importing_remote_function(ray_start_2_cpus): # Create the contents of a temporary Python file. temporary_python_file = """ diff --git a/src/ray/core_worker/lib/java/jni_init.cc b/src/ray/core_worker/lib/java/jni_init.cc index 9b3fe5ee1..56271f15d 100644 --- a/src/ray/core_worker/lib/java/jni_init.cc +++ b/src/ray/core_worker/lib/java/jni_init.cc @@ -50,11 +50,11 @@ jfieldID java_function_arg_value; jclass java_base_task_options_class; jfieldID java_base_task_options_resources; +jfieldID java_base_task_options_use_direct_call; +jfieldID java_base_task_options_default_use_direct_call; jclass java_actor_creation_options_class; -jfieldID java_actor_creation_options_default_use_direct_call; jfieldID java_actor_creation_options_max_reconstructions; -jfieldID java_actor_creation_options_use_direct_call; jfieldID java_actor_creation_options_jvm_options; jfieldID java_actor_creation_options_max_concurrency; @@ -155,15 +155,15 @@ jint JNI_OnLoad(JavaVM *vm, void *reserved) { java_base_task_options_class = LoadClass(env, "org/ray/api/options/BaseTaskOptions"); java_base_task_options_resources = env->GetFieldID(java_base_task_options_class, "resources", "Ljava/util/Map;"); + java_base_task_options_use_direct_call = + env->GetFieldID(java_base_task_options_class, "useDirectCall", "Z"); + java_base_task_options_default_use_direct_call = + env->GetStaticFieldID(java_base_task_options_class, "DEFAULT_USE_DIRECT_CALL", "Z"); java_actor_creation_options_class = LoadClass(env, "org/ray/api/options/ActorCreationOptions"); - java_actor_creation_options_default_use_direct_call = env->GetStaticFieldID( - java_actor_creation_options_class, "DEFAULT_USE_DIRECT_CALL", "Z"); java_actor_creation_options_max_reconstructions = env->GetFieldID(java_actor_creation_options_class, "maxReconstructions", "I"); - java_actor_creation_options_use_direct_call = - env->GetFieldID(java_actor_creation_options_class, "useDirectCall", "Z"); java_actor_creation_options_jvm_options = env->GetFieldID( java_actor_creation_options_class, "jvmOptions", "Ljava/lang/String;"); java_actor_creation_options_max_concurrency = @@ -190,9 +190,7 @@ jint JNI_OnLoad(JavaVM *vm, void *reserved) { "(Ljava/util/List;Ljava/util/List;)Ljava/util/List;"); java_task_executor_get = env->GetStaticMethodID( - java_task_executor_class, - "get", - "([B)Lorg/ray/runtime/task/TaskExecutor;"); + java_task_executor_class, "get", "([B)Lorg/ray/runtime/task/TaskExecutor;"); return CURRENT_JNI_VERSION; } diff --git a/src/ray/core_worker/lib/java/jni_utils.h b/src/ray/core_worker/lib/java/jni_utils.h index d2b39f1dd..fb6e57ed6 100644 --- a/src/ray/core_worker/lib/java/jni_utils.h +++ b/src/ray/core_worker/lib/java/jni_utils.h @@ -94,15 +94,15 @@ extern jfieldID java_function_arg_value; extern jclass java_base_task_options_class; /// resources field of BaseTaskOptions class extern jfieldID java_base_task_options_resources; +/// useDirectCall field of BaseTaskOptions class +extern jfieldID java_base_task_options_use_direct_call; +/// DEFAULT_USE_DIRECT_CALL field of BaseTaskOptions class +extern jfieldID java_base_task_options_default_use_direct_call; /// ActorCreationOptions class extern jclass java_actor_creation_options_class; -/// DEFAULT_USE_DIRECT_CALL field of ActorCreationOptions class -extern jfieldID java_actor_creation_options_default_use_direct_call; /// maxReconstructions field of ActorCreationOptions class extern jfieldID java_actor_creation_options_max_reconstructions; -/// useDirectCall field of ActorCreationOptions class -extern jfieldID java_actor_creation_options_use_direct_call; /// jvmOptions field of ActorCreationOptions class extern jfieldID java_actor_creation_options_jvm_options; /// maxConcurrency field of ActorCreationOptions class diff --git a/src/ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskSubmitter.cc b/src/ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskSubmitter.cc index 03d3a567e..dcd07a983 100644 --- a/src/ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskSubmitter.cc +++ b/src/ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskSubmitter.cc @@ -76,13 +76,19 @@ inline std::unordered_map ToResources(JNIEnv *env, inline ray::TaskOptions ToTaskOptions(JNIEnv *env, jint numReturns, jobject callOptions) { std::unordered_map resources; + bool use_direct_call; if (callOptions) { jobject java_resources = env->GetObjectField(callOptions, java_base_task_options_resources); resources = ToResources(env, java_resources); + use_direct_call = + env->GetBooleanField(callOptions, java_base_task_options_use_direct_call); + } else { + use_direct_call = env->GetStaticBooleanField( + java_base_task_options_class, java_base_task_options_default_use_direct_call); } - ray::TaskOptions task_options{numReturns, /*is_direct_call=*/false, resources}; + ray::TaskOptions task_options{numReturns, use_direct_call, resources}; return task_options; } @@ -97,7 +103,7 @@ inline ray::ActorCreationOptions ToActorCreationOptions(JNIEnv *env, max_reconstructions = static_cast(env->GetIntField( actorCreationOptions, java_actor_creation_options_max_reconstructions)); use_direct_call = env->GetBooleanField(actorCreationOptions, - java_actor_creation_options_use_direct_call); + java_base_task_options_use_direct_call); jobject java_resources = env->GetObjectField(actorCreationOptions, java_base_task_options_resources); resources = ToResources(env, java_resources); @@ -110,9 +116,8 @@ inline ray::ActorCreationOptions ToActorCreationOptions(JNIEnv *env, max_concurrency = static_cast(env->GetIntField( actorCreationOptions, java_actor_creation_options_max_concurrency)); } else { - use_direct_call = - env->GetStaticBooleanField(java_actor_creation_options_class, - java_actor_creation_options_default_use_direct_call); + use_direct_call = env->GetStaticBooleanField( + java_base_task_options_class, java_base_task_options_default_use_direct_call); } ray::ActorCreationOptions actor_creation_options{ @@ -139,9 +144,10 @@ JNIEXPORT jobject JNICALL Java_org_ray_runtime_task_NativeTaskSubmitter_nativeSu auto task_options = ToTaskOptions(env, numReturns, callOptions); std::vector return_ids; + // TODO (kfstorm): Allow setting `max_retries` via `CallOptions`. auto status = GetCoreWorker(nativeCoreWorkerPointer) .SubmitTask(ray_function, task_args, task_options, &return_ids, - /*max_retries=*/1); + /*max_retries=*/0); THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr); diff --git a/src/ray/core_worker/store_provider/memory_store/memory_store.cc b/src/ray/core_worker/store_provider/memory_store/memory_store.cc index 210cda9c5..e7a366356 100644 --- a/src/ray/core_worker/store_provider/memory_store/memory_store.cc +++ b/src/ray/core_worker/store_provider/memory_store/memory_store.cc @@ -10,7 +10,7 @@ namespace ray { class GetRequest { public: GetRequest(absl::flat_hash_set object_ids, size_t num_objects, - bool remove_after_get); + bool remove_after_get, bool abort_if_any_object_is_exception); const absl::flat_hash_set &ObjectIds() const; @@ -40,6 +40,8 @@ class GetRequest { // Whether the requested objects should be removed from store // after `get` returns. const bool remove_after_get_; + // Whether we should abort the waiting if any object is an exception. + const bool abort_if_any_object_is_exception_; // Whether all the requested objects are available. bool is_ready_; mutable std::mutex mutex_; @@ -47,10 +49,11 @@ class GetRequest { }; GetRequest::GetRequest(absl::flat_hash_set object_ids, size_t num_objects, - bool remove_after_get) + bool remove_after_get, bool abort_if_any_object_is_exception_) : object_ids_(std::move(object_ids)), num_objects_(num_objects), remove_after_get_(remove_after_get), + abort_if_any_object_is_exception_(abort_if_any_object_is_exception_), is_ready_(false) { RAY_CHECK(num_objects_ <= object_ids_.size()); } @@ -91,7 +94,9 @@ void GetRequest::Set(const ObjectID &object_id, std::shared_ptr objec return; // We have already hit the number of objects to return limit. } objects_.emplace(object_id, object); - if (objects_.size() == num_objects_) { + if (objects_.size() == num_objects_ || + (abort_if_any_object_is_exception_ && object->IsException() && + !object->IsInPlasmaError())) { is_ready_ = true; cv_.notify_all(); } @@ -219,6 +224,15 @@ Status CoreWorkerMemoryStore::Get(const std::vector &object_ids, int num_objects, int64_t timeout_ms, const WorkerContext &ctx, bool remove_after_get, std::vector> *results) { + return GetImpl(object_ids, num_objects, timeout_ms, ctx, remove_after_get, results, + /*abort_if_any_object_is_exception=*/true); +} + +Status CoreWorkerMemoryStore::GetImpl(const std::vector &object_ids, + int num_objects, int64_t timeout_ms, + const WorkerContext &ctx, bool remove_after_get, + std::vector> *results, + bool abort_if_any_object_is_exception) { (*results).resize(object_ids.size(), nullptr); std::shared_ptr get_request; @@ -263,8 +277,9 @@ Status CoreWorkerMemoryStore::Get(const std::vector &object_ids, size_t required_objects = num_objects - (object_ids.size() - remaining_ids.size()); // Otherwise, create a GetRequest to track remaining objects. - get_request = std::make_shared(std::move(remaining_ids), required_objects, - remove_after_get); + get_request = + std::make_shared(std::move(remaining_ids), required_objects, + remove_after_get, abort_if_any_object_is_exception); for (const auto &object_id : get_request->ObjectIds()) { object_get_requests_[object_id].push_back(get_request); } @@ -378,7 +393,8 @@ Status CoreWorkerMemoryStore::Wait(const absl::flat_hash_set &object_i std::vector id_vector(object_ids.begin(), object_ids.end()); std::vector> result_objects; RAY_CHECK(object_ids.size() == id_vector.size()); - auto status = Get(id_vector, num_objects, timeout_ms, ctx, false, &result_objects); + auto status = GetImpl(id_vector, num_objects, timeout_ms, ctx, false, &result_objects, + /*abort_if_any_object_is_exception=*/false); // Ignore TimedOut statuses since we return ready objects explicitly. if (!status.IsTimedOut()) { RAY_RETURN_NOT_OK(status); diff --git a/src/ray/core_worker/store_provider/memory_store/memory_store.h b/src/ray/core_worker/store_provider/memory_store/memory_store.h index 2089de675..38ea97c71 100644 --- a/src/ray/core_worker/store_provider/memory_store/memory_store.h +++ b/src/ray/core_worker/store_provider/memory_store/memory_store.h @@ -134,6 +134,14 @@ class CoreWorkerMemoryStore { uint64_t UsedMemory(); private: + /// See the public version of `Get` for meaning of the other arguments. + /// \param[in] abort_if_any_object_is_exception Whether we should abort if any object + /// is an exception. + Status GetImpl(const std::vector &object_ids, int num_objects, + int64_t timeout_ms, const WorkerContext &ctx, bool remove_after_get, + std::vector> *results, + bool abort_if_any_object_is_exception); + /// Optional callback for putting objects into the plasma store. std::function store_in_plasma_;