From 7ee0fdba3dd1420729c2780b823a65fc3cbc6386 Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Sat, 5 Sep 2020 10:09:19 +0800 Subject: [PATCH] [Java] Throw exception if Ray.init() is not called and users try to access ray API (#10497) --- java/api/src/main/java/io/ray/api/Ray.java | 42 +++++++++++-------- .../io/ray/runtime/object/ObjectRefImpl.java | 9 ++-- .../io/ray/test/BaseMultiLanguageTest.java | 2 +- .../src/main/java/io/ray/test/BaseTest.java | 2 +- .../streaming/api/context/ClusterStarter.java | 2 +- .../api/context/StreamingContext.java | 4 +- 6 files changed, 35 insertions(+), 26 deletions(-) diff --git a/java/api/src/main/java/io/ray/api/Ray.java b/java/api/src/main/java/io/ray/api/Ray.java index 623256db4..20c5422c6 100644 --- a/java/api/src/main/java/io/ray/api/Ray.java +++ b/java/api/src/main/java/io/ray/api/Ray.java @@ -49,7 +49,7 @@ public final class Ray extends RayCall { */ public static synchronized void shutdown() { if (runtime != null) { - runtime.shutdown(); + internal().shutdown(); runtime = null; } } @@ -61,7 +61,7 @@ public final class Ray extends RayCall { * @return A ObjectRef instance that represents the in-store object. */ public static ObjectRef put(T obj) { - return runtime.put(obj); + return internal().put(obj); } /** @@ -71,7 +71,7 @@ public final class Ray extends RayCall { * @return The Java object. */ public static T get(ObjectRef objectRef) { - return runtime.get(objectRef); + return internal().get(objectRef); } /** @@ -81,7 +81,7 @@ public final class Ray extends RayCall { * @return A list of Java objects. */ public static List get(List> objectList) { - return runtime.get(objectList); + return internal().get(objectList); } /** @@ -95,7 +95,7 @@ public final class Ray extends RayCall { */ public static WaitResult wait(List> waitList, int numReturns, int timeoutMs) { - return runtime.wait(waitList, numReturns, timeoutMs); + return internal().wait(waitList, numReturns, timeoutMs); } /** @@ -107,7 +107,7 @@ public final class Ray extends RayCall { * @return Two lists, one containing locally available objects, one containing the rest. */ public static WaitResult wait(List> waitList, int numReturns) { - return runtime.wait(waitList, numReturns, Integer.MAX_VALUE); + return internal().wait(waitList, numReturns, Integer.MAX_VALUE); } /** @@ -118,7 +118,7 @@ public final class Ray extends RayCall { * @return Two lists, one containing locally available objects, one containing the rest. */ public static WaitResult wait(List> waitList) { - return runtime.wait(waitList, waitList.size(), Integer.MAX_VALUE); + return internal().wait(waitList, waitList.size(), Integer.MAX_VALUE); } /** @@ -132,7 +132,7 @@ public final class Ray extends RayCall { * Optional.empty() */ public static Optional getActor(String name) { - return runtime.getActor(name, false); + return internal().getActor(name, false); } /** @@ -146,7 +146,7 @@ public final class Ray extends RayCall { * Optional.empty() */ public static Optional getGlobalActor(String name) { - return runtime.getActor(name, true); + return internal().getActor(name, true); } /** @@ -156,7 +156,7 @@ public final class Ray extends RayCall { * @return The async context. */ public static Object getAsyncContext() { - return runtime.getAsyncContext(); + return internal().getAsyncContext(); } /** @@ -165,7 +165,7 @@ public final class Ray extends RayCall { * @param asyncContext The async context to set. */ public static void setAsyncContext(Object asyncContext) { - runtime.setAsyncContext(asyncContext); + internal().setAsyncContext(asyncContext); } // TODO (kfstorm): add the `rollbackAsyncContext` API to allow rollbacking the async context of @@ -181,7 +181,7 @@ public final class Ray extends RayCall { * @return The wrapped runnable. */ public static Runnable wrapRunnable(Runnable runnable) { - return runtime.wrapRunnable(runnable); + return internal().wrapRunnable(runnable); } /** @@ -192,13 +192,21 @@ public final class Ray extends RayCall { * @return The wrapped callable. */ public static Callable wrapCallable(Callable callable) { - return runtime.wrapCallable(callable); + return internal().wrapCallable(callable); + } + + public static boolean isInitialized() { + return runtime != null; } /** * Get the underlying runtime instance. */ public static RayRuntime internal() { + if (runtime == null) { + throw new IllegalStateException( + "Ray has not been started yet. You can start Ray with 'Ray.init()'"); + } return runtime; } @@ -207,21 +215,21 @@ public final class Ray extends RayCall { * Set the resource for the specific node. */ public static void setResource(UniqueId nodeId, String resourceName, double capacity) { - runtime.setResource(resourceName, capacity, nodeId); + internal().setResource(resourceName, capacity, nodeId); } /** * Set the resource for local node. */ public static void setResource(String resourceName, double capacity) { - runtime.setResource(resourceName, capacity, UniqueId.NIL); + internal().setResource(resourceName, capacity, UniqueId.NIL); } /** * Get the runtime context. */ public static RuntimeContext getRuntimeContext() { - return runtime.getRuntimeContext(); + return internal().getRuntimeContext(); } /** @@ -239,7 +247,7 @@ public final class Ray extends RayCall { */ public static PlacementGroup createPlacementGroup(List> bundles, PlacementStrategy strategy) { - return runtime.createPlacementGroup(bundles, strategy); + return internal().createPlacementGroup(bundles, strategy); } /** diff --git a/java/runtime/src/main/java/io/ray/runtime/object/ObjectRefImpl.java b/java/runtime/src/main/java/io/ray/runtime/object/ObjectRefImpl.java index 839ed5575..d52494d54 100644 --- a/java/runtime/src/main/java/io/ray/runtime/object/ObjectRefImpl.java +++ b/java/runtime/src/main/java/io/ray/runtime/object/ObjectRefImpl.java @@ -40,7 +40,8 @@ public final class ObjectRefImpl implements ObjectRef, Externalizable { addLocalReference(); } - public ObjectRefImpl() {} + public ObjectRefImpl() { + } @Override public synchronized T get() { @@ -103,10 +104,10 @@ public final class ObjectRefImpl implements ObjectRef, Externalizable { // unit tests). So if `workerId` is null, it means this method has been invoked. if (!removed.getAndSet(true)) { REFERENCES.remove(this); - RayRuntimeInternal runtime = (RayRuntimeInternal) Ray.internal(); // It's possible that GC is executed after the runtime is shutdown. - if (runtime != null) { - runtime.getObjectStore().removeLocalReference(workerId, objectId); + if (Ray.isInitialized()) { + ((RayRuntimeInternal) (Ray.internal())).getObjectStore() + .removeLocalReference(workerId, objectId); } } } diff --git a/java/test/src/main/java/io/ray/test/BaseMultiLanguageTest.java b/java/test/src/main/java/io/ray/test/BaseMultiLanguageTest.java index c2c3d0c3a..4d89bd1db 100644 --- a/java/test/src/main/java/io/ray/test/BaseMultiLanguageTest.java +++ b/java/test/src/main/java/io/ray/test/BaseMultiLanguageTest.java @@ -91,7 +91,7 @@ public abstract class BaseMultiLanguageTest { } // Connect to the cluster. - Assert.assertNull(Ray.internal()); + Assert.assertFalse(Ray.isInitialized()); System.setProperty("ray.redis.address", "127.0.0.1:6379"); System.setProperty("ray.object-store.socket-name", PLASMA_STORE_SOCKET_NAME); System.setProperty("ray.raylet.socket-name", RAYLET_SOCKET_NAME); diff --git a/java/test/src/main/java/io/ray/test/BaseTest.java b/java/test/src/main/java/io/ray/test/BaseTest.java index 14643ec59..07f2dbf4f 100644 --- a/java/test/src/main/java/io/ray/test/BaseTest.java +++ b/java/test/src/main/java/io/ray/test/BaseTest.java @@ -19,7 +19,7 @@ public class BaseTest { @BeforeMethod(alwaysRun = true) public void setUpBase(Method method) { - Assert.assertNull(Ray.internal()); + Assert.assertFalse(Ray.isInitialized()); Ray.init(); // These files need to be deleted after each test case. filesToDelete = ImmutableList.of( diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/context/ClusterStarter.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/context/ClusterStarter.java index 89d166af7..1c820e6f2 100644 --- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/context/ClusterStarter.java +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/context/ClusterStarter.java @@ -23,7 +23,7 @@ class ClusterStarter { private static final String RAYLET_SOCKET_NAME = "/tmp/ray/raylet_socket"; static synchronized void startCluster(boolean isCrossLanguage, boolean isLocal) { - Preconditions.checkArgument(Ray.internal() == null); + Preconditions.checkArgument(!Ray.isInitialized()); RayConfig.reset(); if (!isLocal) { System.setProperty("ray.raylet.config.num_workers_per_process_java", "1"); diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/context/StreamingContext.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/context/StreamingContext.java index 1fb30e043..d3264e95f 100644 --- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/context/StreamingContext.java +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/context/StreamingContext.java @@ -63,7 +63,7 @@ public class StreamingContext implements Serializable { jobGraph.printJobGraph(); LOG.info("JobGraph digraph\n{}", jobGraph.generateDigraph()); - if (Ray.internal() == null) { + if (!Ray.isInitialized()) { if (Config.MEMORY_CHANNEL.equalsIgnoreCase(jobConfig.get(Config.CHANNEL_TYPE))) { Preconditions.checkArgument(!jobGraph.isCrossLanguageGraph()); ClusterStarter.startCluster(false, true); @@ -102,7 +102,7 @@ public class StreamingContext implements Serializable { } public void stop() { - if (Ray.internal() != null) { + if (Ray.isInitialized()) { ClusterStarter.stopCluster(jobGraph.isCrossLanguageGraph()); } }