From eb912b68b1ef8e77e613bd4eb69c0b6df982a63d Mon Sep 17 00:00:00 2001 From: Kai Yang Date: Sat, 7 Dec 2019 16:28:29 +0800 Subject: [PATCH] [Java] Fix `instanceof RayPyActor` (#6377) --- .../org/ray/runtime/actor/NativeRayActor.java | 59 +++++++++---------- .../ray/runtime/actor/NativeRayJavaActor.java | 29 +++++++++ .../ray/runtime/actor/NativeRayPyActor.java | 40 +++++++++++++ .../ray/runtime/task/NativeTaskSubmitter.java | 2 +- .../main/java/org/ray/api/test/ActorTest.java | 3 + .../org_ray_runtime_actor_NativeRayActor.cc | 10 ---- .../org_ray_runtime_actor_NativeRayActor.h | 8 --- 7 files changed, 102 insertions(+), 49 deletions(-) create mode 100644 java/runtime/src/main/java/org/ray/runtime/actor/NativeRayJavaActor.java create mode 100644 java/runtime/src/main/java/org/ray/runtime/actor/NativeRayPyActor.java diff --git a/java/runtime/src/main/java/org/ray/runtime/actor/NativeRayActor.java b/java/runtime/src/main/java/org/ray/runtime/actor/NativeRayActor.java index a4d7c31b9..7c6f79368 100644 --- a/java/runtime/src/main/java/org/ray/runtime/actor/NativeRayActor.java +++ b/java/runtime/src/main/java/org/ray/runtime/actor/NativeRayActor.java @@ -8,40 +8,54 @@ import java.io.ObjectOutput; import java.util.List; import org.ray.api.Ray; import org.ray.api.RayActor; -import org.ray.api.RayPyActor; import org.ray.api.id.ActorId; -import org.ray.api.id.UniqueId; import org.ray.api.runtime.RayRuntime; -import org.ray.runtime.AbstractRayRuntime; -import org.ray.runtime.RayNativeRuntime; import org.ray.runtime.RayMultiWorkerNativeRuntime; +import org.ray.runtime.RayNativeRuntime; import org.ray.runtime.generated.Common.Language; /** - * RayActor implementation for cluster mode. This is a wrapper class for C++ ActorHandle. + * RayActor abstract language-independent implementation for cluster mode. This is a wrapper class + * for C++ ActorHandle. */ -public class NativeRayActor implements RayActor, RayPyActor, Externalizable { +public abstract class NativeRayActor implements RayActor, Externalizable { /** * Address of core worker. */ - private long nativeCoreWorkerPointer; + long nativeCoreWorkerPointer; /** * ID of the actor. */ - private byte[] actorId; + byte[] actorId; - public NativeRayActor(long nativeCoreWorkerPointer, byte[] actorId) { + private Language language; + + NativeRayActor(long nativeCoreWorkerPointer, byte[] actorId, Language language) { Preconditions.checkState(nativeCoreWorkerPointer != 0); Preconditions.checkState(!ActorId.fromBytes(actorId).isNil()); this.nativeCoreWorkerPointer = nativeCoreWorkerPointer; this.actorId = actorId; + this.language = language; } /** * Required by FST */ - public NativeRayActor() { + NativeRayActor() { + } + + public static NativeRayActor create(long nativeCoreWorkerPointer, byte[] actorId, + Language language) { + Preconditions.checkState(nativeCoreWorkerPointer != 0); + switch (language) { + case JAVA: + return new NativeRayJavaActor(nativeCoreWorkerPointer, actorId); + case PYTHON: + return new NativeRayPyActor(nativeCoreWorkerPointer, actorId); + default: + throw new IllegalStateException("Unknown actor handle language: " + language); + } } @Override @@ -50,30 +64,17 @@ public class NativeRayActor implements RayActor, RayPyActor, Externalizable { } public Language getLanguage() { - return Language.forNumber(nativeGetLanguage(nativeCoreWorkerPointer, actorId)); + return language; } public boolean isDirectCallActor() { return nativeIsDirectCallActor(nativeCoreWorkerPointer, actorId); } - @Override - public String getModuleName() { - Preconditions.checkState(getLanguage() == Language.PYTHON); - return nativeGetActorCreationTaskFunctionDescriptor( - nativeCoreWorkerPointer, actorId).get(0); - } - - @Override - public String getClassName() { - Preconditions.checkState(getLanguage() == Language.PYTHON); - return nativeGetActorCreationTaskFunctionDescriptor( - nativeCoreWorkerPointer, actorId).get(1); - } - @Override public void writeExternal(ObjectOutput out) throws IOException { out.writeObject(nativeSerialize(nativeCoreWorkerPointer, actorId)); + out.writeObject(language); } @Override @@ -82,11 +83,11 @@ public class NativeRayActor implements RayActor, RayPyActor, Externalizable { if (runtime instanceof RayMultiWorkerNativeRuntime) { runtime = ((RayMultiWorkerNativeRuntime) runtime).getCurrentRuntime(); } - Preconditions.checkState(runtime instanceof RayNativeRuntime); - nativeCoreWorkerPointer = ((RayNativeRuntime)runtime).getNativeCoreWorkerPointer(); + nativeCoreWorkerPointer = ((RayNativeRuntime) runtime).getNativeCoreWorkerPointer(); actorId = nativeDeserialize(nativeCoreWorkerPointer, (byte[]) in.readObject()); + language = (Language) in.readObject(); } @Override @@ -94,11 +95,9 @@ public class NativeRayActor implements RayActor, RayPyActor, Externalizable { // TODO(zhijunfu): do we need to free the ActorHandle in core worker? } - private static native int nativeGetLanguage(long nativeCoreWorkerPointer, byte[] actorId); - private static native boolean nativeIsDirectCallActor(long nativeCoreWorkerPointer, byte[] actorId); - private static native List nativeGetActorCreationTaskFunctionDescriptor( + static native List nativeGetActorCreationTaskFunctionDescriptor( long nativeCoreWorkerPointer, byte[] actorId); private static native byte[] nativeSerialize(long nativeCoreWorkerPointer, byte[] actorId); diff --git a/java/runtime/src/main/java/org/ray/runtime/actor/NativeRayJavaActor.java b/java/runtime/src/main/java/org/ray/runtime/actor/NativeRayJavaActor.java new file mode 100644 index 000000000..d103c04cb --- /dev/null +++ b/java/runtime/src/main/java/org/ray/runtime/actor/NativeRayJavaActor.java @@ -0,0 +1,29 @@ +package org.ray.runtime.actor; + +import com.google.common.base.Preconditions; +import java.io.IOException; +import java.io.ObjectInput; +import org.ray.runtime.generated.Common.Language; + +/** + * RayActor Java implementation for cluster mode. + */ +public class NativeRayJavaActor extends NativeRayActor { + + NativeRayJavaActor(long nativeCoreWorkerPointer, byte[] actorId) { + super(nativeCoreWorkerPointer, actorId, Language.JAVA); + } + + /** + * Required by FST + */ + public NativeRayJavaActor() { + super(); + } + + @Override + public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + super.readExternal(in); + Preconditions.checkState(getLanguage() == Language.JAVA); + } +} diff --git a/java/runtime/src/main/java/org/ray/runtime/actor/NativeRayPyActor.java b/java/runtime/src/main/java/org/ray/runtime/actor/NativeRayPyActor.java new file mode 100644 index 000000000..40fbc1581 --- /dev/null +++ b/java/runtime/src/main/java/org/ray/runtime/actor/NativeRayPyActor.java @@ -0,0 +1,40 @@ +package org.ray.runtime.actor; + +import com.google.common.base.Preconditions; +import java.io.IOException; +import java.io.ObjectInput; +import org.ray.api.RayPyActor; +import org.ray.runtime.generated.Common.Language; + +/** + * RayActor Python implementation for cluster mode. + */ +public class NativeRayPyActor extends NativeRayActor implements RayPyActor { + + NativeRayPyActor(long nativeCoreWorkerPointer, byte[] actorId) { + super(nativeCoreWorkerPointer, actorId, Language.PYTHON); + } + + /** + * Required by FST + */ + public NativeRayPyActor() { + super(); + } + + @Override + public String getModuleName() { + return nativeGetActorCreationTaskFunctionDescriptor(nativeCoreWorkerPointer, actorId).get(0); + } + + @Override + public String getClassName() { + return nativeGetActorCreationTaskFunctionDescriptor(nativeCoreWorkerPointer, actorId).get(1); + } + + @Override + public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + super.readExternal(in); + Preconditions.checkState(getLanguage() == Language.PYTHON); + } +} diff --git a/java/runtime/src/main/java/org/ray/runtime/task/NativeTaskSubmitter.java b/java/runtime/src/main/java/org/ray/runtime/task/NativeTaskSubmitter.java index bc1083d9b..803161420 100644 --- a/java/runtime/src/main/java/org/ray/runtime/task/NativeTaskSubmitter.java +++ b/java/runtime/src/main/java/org/ray/runtime/task/NativeTaskSubmitter.java @@ -37,7 +37,7 @@ public class NativeTaskSubmitter implements TaskSubmitter { ActorCreationOptions options) { byte[] actorId = nativeCreateActor(nativeCoreWorkerPointer, functionDescriptor, args, options); - return new NativeRayActor(nativeCoreWorkerPointer, actorId); + return NativeRayActor.create(nativeCoreWorkerPointer, actorId, functionDescriptor.getLanguage()); } @Override diff --git a/java/test/src/main/java/org/ray/api/test/ActorTest.java b/java/test/src/main/java/org/ray/api/test/ActorTest.java index 5dd81443f..500b9c31b 100644 --- a/java/test/src/main/java/org/ray/api/test/ActorTest.java +++ b/java/test/src/main/java/org/ray/api/test/ActorTest.java @@ -7,6 +7,7 @@ import java.util.concurrent.TimeUnit; import org.ray.api.Ray; import org.ray.api.RayActor; import org.ray.api.RayObject; +import org.ray.api.RayPyActor; import org.ray.api.TestUtils; import org.ray.api.TestUtils.LargeObject; import org.ray.api.annotation.RayRemote; @@ -50,6 +51,8 @@ public class ActorTest extends BaseTest { // Test creating an actor from a constructor RayActor actor = Ray.createActor(Counter::new, 1); Assert.assertNotEquals(actor.getId(), UniqueId.NIL); + // A java actor is not a python actor + Assert.assertFalse(actor instanceof RayPyActor); // Test calling an actor Assert.assertEquals(Integer.valueOf(1), Ray.call(Counter::getValue, actor).get()); Ray.call(Counter::increase, actor, 1); diff --git a/src/ray/core_worker/lib/java/org_ray_runtime_actor_NativeRayActor.cc b/src/ray/core_worker/lib/java/org_ray_runtime_actor_NativeRayActor.cc index e4379f98e..636118d7d 100644 --- a/src/ray/core_worker/lib/java/org_ray_runtime_actor_NativeRayActor.cc +++ b/src/ray/core_worker/lib/java/org_ray_runtime_actor_NativeRayActor.cc @@ -13,16 +13,6 @@ inline ray::CoreWorker &GetCoreWorker(jlong nativeCoreWorkerPointer) { extern "C" { #endif -JNIEXPORT jint JNICALL Java_org_ray_runtime_actor_NativeRayActor_nativeGetLanguage( - JNIEnv *env, jclass o, jlong nativeCoreWorkerPointer, jbyteArray actorId) { - auto actor_id = JavaByteArrayToId(env, actorId); - ray::ActorHandle *native_actor_handle = nullptr; - auto status = GetCoreWorker(nativeCoreWorkerPointer) - .GetActorHandle(actor_id, &native_actor_handle); - THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (jint)0); - return (jint)native_actor_handle->ActorLanguage(); -} - JNIEXPORT jboolean JNICALL Java_org_ray_runtime_actor_NativeRayActor_nativeIsDirectCallActor( JNIEnv *env, jclass o, jlong nativeCoreWorkerPointer, jbyteArray actorId) { diff --git a/src/ray/core_worker/lib/java/org_ray_runtime_actor_NativeRayActor.h b/src/ray/core_worker/lib/java/org_ray_runtime_actor_NativeRayActor.h index 8f75e3a82..9a0f9c427 100644 --- a/src/ray/core_worker/lib/java/org_ray_runtime_actor_NativeRayActor.h +++ b/src/ray/core_worker/lib/java/org_ray_runtime_actor_NativeRayActor.h @@ -7,14 +7,6 @@ #ifdef __cplusplus extern "C" { #endif -/* - * Class: org_ray_runtime_actor_NativeRayActor - * Method: nativeGetLanguage - * Signature: (J[B)I - */ -JNIEXPORT jint JNICALL Java_org_ray_runtime_actor_NativeRayActor_nativeGetLanguage( - JNIEnv *, jclass, jlong, jbyteArray); - /* * Class: org_ray_runtime_actor_NativeRayActor * Method: nativeIsDirectCallActor