diff --git a/java/api/src/main/java/io/ray/api/ObjectRef.java b/java/api/src/main/java/io/ray/api/ObjectRef.java index 15319bbe9..6b58781dc 100644 --- a/java/api/src/main/java/io/ray/api/ObjectRef.java +++ b/java/api/src/main/java/io/ray/api/ObjectRef.java @@ -1,7 +1,5 @@ package io.ray.api; -import io.ray.api.id.ObjectId; - /** * Represents a reference to an object in the object store. * @param The object type. @@ -14,15 +12,5 @@ public interface ObjectRef { */ T get(); - /** - * Get the object id. - */ - ObjectId getId(); - - /** - * Get the Object type. - */ - Class getType(); - } diff --git a/java/api/src/main/java/io/ray/api/Ray.java b/java/api/src/main/java/io/ray/api/Ray.java index 693c50dd0..1e9d68482 100644 --- a/java/api/src/main/java/io/ray/api/Ray.java +++ b/java/api/src/main/java/io/ray/api/Ray.java @@ -1,13 +1,11 @@ package io.ray.api; -import io.ray.api.id.ObjectId; import io.ray.api.id.UniqueId; import io.ray.api.placementgroup.PlacementGroup; import io.ray.api.placementgroup.PlacementStrategy; import io.ray.api.runtime.RayRuntime; import io.ray.api.runtime.RayRuntimeFactory; import io.ray.api.runtimecontext.RuntimeContext; -import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Optional; @@ -67,25 +65,13 @@ public final class Ray extends RayCall { } /** - * Get an object by id from the object store. + * Get an object by `ObjectRef` from the object store. * - * @param objectId The ID of the object to get. - * @param objectType The type of the object to get. + * @param objectRef The reference of the object to get. * @return The Java object. */ - public static T get(ObjectId objectId, Class objectType) { - return runtime.get(objectId, objectType); - } - - /** - * Get a list of objects by ids from the object store. - * - * @param objectIds The list of object IDs. - * @param objectType The type of object. - * @return A list of Java objects. 
- */ - public static List get(List objectIds, Class objectType) { - return runtime.get(objectIds, objectType); + public static T get(ObjectRef objectRef) { + return runtime.get(objectRef); } /** @@ -95,13 +81,7 @@ public final class Ray extends RayCall { * @return A list of Java objects. */ public static List get(List> objectList) { - List objectIds = new ArrayList<>(); - Class objectType = null; - for (ObjectRef o : objectList) { - objectIds.add(o.getId()); - objectType = o.getType(); - } - return runtime.get(objectIds, objectType); + return runtime.get(objectList); } /** diff --git a/java/api/src/main/java/io/ray/api/runtime/RayRuntime.java b/java/api/src/main/java/io/ray/api/runtime/RayRuntime.java index 9340567ab..477fb029d 100644 --- a/java/api/src/main/java/io/ray/api/runtime/RayRuntime.java +++ b/java/api/src/main/java/io/ray/api/runtime/RayRuntime.java @@ -10,7 +10,6 @@ import io.ray.api.function.PyActorMethod; import io.ray.api.function.PyFunction; import io.ray.api.function.RayFunc; import io.ray.api.id.ActorId; -import io.ray.api.id.ObjectId; import io.ray.api.id.UniqueId; import io.ray.api.options.ActorCreationOptions; import io.ray.api.options.CallOptions; @@ -43,20 +42,18 @@ public interface RayRuntime { /** * Get an object from the object store. * - * @param objectId The ID of the object to get. - * @param objectType The type of the object to get. + * @param objectRef The reference of the object to get. * @return The Java object. */ - T get(ObjectId objectId, Class objectType); + T get(ObjectRef objectRef); /** * Get a list of objects from the object store. * - * @param objectIds The list of object IDs. - * @param objectType The type of object. + * @param objectRefs The list of object references. * @return A list of Java objects. 
*/ - List get(List objectIds, Class objectType); + List get(List> objectRefs); /** * Wait for a list of RayObjects to be locally available, until specified number of objects are @@ -72,11 +69,11 @@ public interface RayRuntime { /** * Free a list of objects from Plasma Store. * - * @param objectIds The object ids to free. + * @param objectRefs The object references to free. * @param localOnly Whether only free objects for local object store or not. * @param deleteCreatingTasks Whether also delete objects' creating tasks from GCS. */ - void free(List objectIds, boolean localOnly, boolean deleteCreatingTasks); + void free(List> objectRefs, boolean localOnly, boolean deleteCreatingTasks); /** * Set the resource for the specific node. diff --git a/java/checkstyle-suppressions.xml b/java/checkstyle-suppressions.xml index e76437d14..88233117b 100644 --- a/java/checkstyle-suppressions.xml +++ b/java/checkstyle-suppressions.xml @@ -8,7 +8,6 @@ - diff --git a/java/dependencies.bzl b/java/dependencies.bzl index d1af092ef..e833b2800 100644 --- a/java/dependencies.bzl +++ b/java/dependencies.bzl @@ -18,7 +18,7 @@ def gen_java_deps(): "org.msgpack:msgpack-core:0.8.20", "org.ow2.asm:asm:6.0", "org.slf4j:slf4j-log4j12:1.7.25", - "org.testng:testng:6.9.10", + "org.testng:testng:7.1.0", "redis.clients:jedis:2.8.0", "net.java.dev.jna:jna:5.5.0", ], diff --git a/java/runtime/pom.xml b/java/runtime/pom.xml index 61ae25c51..280163cea 100644 --- a/java/runtime/pom.xml +++ b/java/runtime/pom.xml @@ -112,7 +112,7 @@ org.testng testng - 6.9.10 + 7.1.0 redis.clients diff --git a/java/runtime/src/main/java/io/ray/runtime/AbstractRayRuntime.java b/java/runtime/src/main/java/io/ray/runtime/AbstractRayRuntime.java index b2acec342..a8ed4fa3e 100644 --- a/java/runtime/src/main/java/io/ray/runtime/AbstractRayRuntime.java +++ b/java/runtime/src/main/java/io/ray/runtime/AbstractRayRuntime.java @@ -36,10 +36,12 @@ import io.ray.runtime.task.ArgumentsBuilder; import io.ray.runtime.task.FunctionArg; 
import io.ray.runtime.task.TaskExecutor; import io.ray.runtime.task.TaskSubmitter; +import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.concurrent.Callable; +import java.util.stream.Collectors; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -72,9 +74,6 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal { runtimeContext = new RuntimeContextImpl(this); } - @Override - public abstract void shutdown(); - @Override public ObjectRef put(T obj) { ObjectId objectId = objectStore.put(obj); @@ -82,19 +81,27 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal { } @Override - public T get(ObjectId objectId, Class objectType) throws RayException { - List ret = get(ImmutableList.of(objectId), objectType); + public T get(ObjectRef objectRef) throws RayException { + List ret = get(ImmutableList.of(objectRef)); return ret.get(0); } @Override - public List get(List objectIds, Class objectType) { + public List get(List> objectRefs) { + List objectIds = new ArrayList<>(); + Class objectType = null; + for (ObjectRef o : objectRefs) { + ObjectRefImpl objectRefImpl = (ObjectRefImpl) o; + objectIds.add(objectRefImpl.getId()); + objectType = objectRefImpl.getType(); + } return objectStore.get(objectIds, objectType); } @Override - public void free(List objectIds, boolean localOnly, boolean deleteCreatingTasks) { - objectStore.delete(objectIds, localOnly, deleteCreatingTasks); + public void free(List> objectRefs, boolean localOnly, boolean deleteCreatingTasks) { + objectStore.delete(objectRefs.stream().map(ref -> ((ObjectRefImpl) ref).getId()).collect( + Collectors.toList()), localOnly, deleteCreatingTasks); } @Override diff --git a/java/runtime/src/main/java/io/ray/runtime/RayNativeRuntime.java b/java/runtime/src/main/java/io/ray/runtime/RayNativeRuntime.java index 7773783d8..47eda035c 100644 --- a/java/runtime/src/main/java/io/ray/runtime/RayNativeRuntime.java +++ 
b/java/runtime/src/main/java/io/ray/runtime/RayNativeRuntime.java @@ -22,6 +22,9 @@ import java.io.File; import java.io.IOException; import java.util.Map; import java.util.Optional; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReadWriteLock; +import java.util.concurrent.locks.ReentrantReadWriteLock; import org.apache.commons.io.FileUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -35,6 +38,15 @@ public final class RayNativeRuntime extends AbstractRayRuntime { private RunManager manager = null; + /** + * In Java, GC runs in a standalone thread, and we can't control the exact + * timing of garbage collection. By using this lock, when + * {@link NativeObjectStore#nativeRemoveLocalReference} is executing, the core + * worker will not be shut down, therefore it guarantees some kind of + * thread-safety. Note that this guarantee only works for driver. + */ + private final ReadWriteLock shutdownLock = new ReentrantReadWriteLock(); + static { LOGGER.debug("Loading native libraries."); @@ -105,7 +117,7 @@ public final class RayNativeRuntime extends AbstractRayRuntime { taskExecutor = new NativeTaskExecutor(this); workerContext = new NativeWorkerContext(); - objectStore = new NativeObjectStore(workerContext); + objectStore = new NativeObjectStore(workerContext, shutdownLock); taskSubmitter = new NativeTaskSubmitter(); LOGGER.debug("RayNativeRuntime started with store {}, raylet {}", @@ -114,19 +126,25 @@ public final class RayNativeRuntime extends AbstractRayRuntime { @Override public void shutdown() { - if (rayConfig.workerMode == WorkerType.DRIVER) { - nativeShutdown(); - if (null != manager) { - manager.cleanup(); - manager = null; + Lock writeLock = shutdownLock.writeLock(); // Exclusive: excludes GC-thread readers. + writeLock.lock(); + try { + if (rayConfig.workerMode == WorkerType.DRIVER) { + nativeShutdown(); + if (null != manager) { + manager.cleanup(); + manager = null; + } } + if (null != gcsClient) { + gcsClient.destroy(); + gcsClient = null; + } +
RayConfig.reset(); + LOGGER.debug("RayNativeRuntime shutdown"); + } finally { + writeLock.unlock(); } - if (null != gcsClient) { - gcsClient.destroy(); - gcsClient = null; - } - RayConfig.reset(); - LOGGER.debug("RayNativeRuntime shutdown"); } // For test purpose only diff --git a/java/runtime/src/main/java/io/ray/runtime/object/LocalModeObjectStore.java b/java/runtime/src/main/java/io/ray/runtime/object/LocalModeObjectStore.java index 9a448c331..5a0b20f53 100644 --- a/java/runtime/src/main/java/io/ray/runtime/object/LocalModeObjectStore.java +++ b/java/runtime/src/main/java/io/ray/runtime/object/LocalModeObjectStore.java @@ -2,6 +2,7 @@ package io.ray.runtime.object; import com.google.common.base.Preconditions; import io.ray.api.id.ObjectId; +import io.ray.api.id.UniqueId; import io.ray.runtime.context.WorkerContext; import java.util.ArrayList; import java.util.List; @@ -96,4 +97,12 @@ public class LocalModeObjectStore extends ObjectStore { pool.remove(objectId); } } + + @Override + public void addLocalReference(UniqueId workerId, ObjectId objectId) { + } + + @Override + public void removeLocalReference(UniqueId workerId, ObjectId objectId) { + } } diff --git a/java/runtime/src/main/java/io/ray/runtime/object/NativeObjectStore.java b/java/runtime/src/main/java/io/ray/runtime/object/NativeObjectStore.java index 9d0c6a317..fe39413b3 100644 --- a/java/runtime/src/main/java/io/ray/runtime/object/NativeObjectStore.java +++ b/java/runtime/src/main/java/io/ray/runtime/object/NativeObjectStore.java @@ -2,8 +2,13 @@ package io.ray.runtime.object; import io.ray.api.id.BaseId; import io.ray.api.id.ObjectId; +import io.ray.api.id.UniqueId; import io.ray.runtime.context.WorkerContext; +import java.util.HashMap; import java.util.List; +import java.util.Map; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReadWriteLock; import java.util.stream.Collectors; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -15,8 +20,11 @@ public class 
NativeObjectStore extends ObjectStore { private static final Logger LOGGER = LoggerFactory.getLogger(NativeObjectStore.class); - public NativeObjectStore(WorkerContext workerContext) { + private final ReadWriteLock shutdownLock; + + public NativeObjectStore(WorkerContext workerContext, ReadWriteLock shutdownLock) { super(workerContext); + this.shutdownLock = shutdownLock; } @Override @@ -44,6 +52,31 @@ public class NativeObjectStore extends ObjectStore { nativeDelete(toBinaryList(objectIds), localOnly, deleteCreatingTasks); } + @Override + public void addLocalReference(UniqueId workerId, ObjectId objectId) { + nativeAddLocalReference(workerId.getBytes(), objectId.getBytes()); + } + + @Override + public void removeLocalReference(UniqueId workerId, ObjectId objectId) { + Lock readLock = shutdownLock.readLock(); + readLock.lock(); + try { + nativeRemoveLocalReference(workerId.getBytes(), objectId.getBytes()); + } finally { + readLock.unlock(); + } + } + + public Map getAllReferenceCounts() { + Map referenceCounts = new HashMap<>(); + for (Map.Entry entry : + nativeGetAllReferenceCounts().entrySet()) { + referenceCounts.put(new ObjectId(entry.getKey()), entry.getValue()); + } + return referenceCounts; + } + private static List toBinaryList(List ids) { return ids.stream().map(BaseId::getBytes).collect(Collectors.toList()); } @@ -59,4 +92,10 @@ public class NativeObjectStore extends ObjectStore { private static native void nativeDelete(List objectIds, boolean localOnly, boolean deleteCreatingTasks); + + private static native void nativeAddLocalReference(byte[] workerId, byte[] objectId); + + private static native void nativeRemoveLocalReference(byte[] workerId, byte[] objectId); + + private static native Map nativeGetAllReferenceCounts(); } diff --git a/java/runtime/src/main/java/io/ray/runtime/object/NativeRayObject.java b/java/runtime/src/main/java/io/ray/runtime/object/NativeRayObject.java index 662ec0e3c..d3b0ae823 100644 --- 
a/java/runtime/src/main/java/io/ray/runtime/object/NativeRayObject.java +++ b/java/runtime/src/main/java/io/ray/runtime/object/NativeRayObject.java @@ -1,6 +1,11 @@ package io.ray.runtime.object; import com.google.common.base.Preconditions; +import io.ray.api.id.BaseId; +import io.ray.api.id.ObjectId; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; /** * Binary representation of a ray object. See `RayObject` class in C++ for details. @@ -9,11 +14,17 @@ public class NativeRayObject { public byte[] data; public byte[] metadata; + public List containedObjectIds; public NativeRayObject(byte[] data, byte[] metadata) { Preconditions.checkState(bufferLength(data) > 0 || bufferLength(metadata) > 0); this.data = data; this.metadata = metadata; + this.containedObjectIds = Collections.emptyList(); + } + + public void setContainedObjectIds(List containedObjectIds) { + this.containedObjectIds = toBinaryList(containedObjectIds); } private static int bufferLength(byte[] buffer) { @@ -23,6 +34,10 @@ public class NativeRayObject { return buffer.length; } + private static List toBinaryList(List ids) { + return ids.stream().map(BaseId::getBytes).collect(Collectors.toList()); + } + @Override public String toString() { return ": " + bufferLength(data) + ", : " + bufferLength(metadata); diff --git a/java/runtime/src/main/java/io/ray/runtime/object/ObjectRefImpl.java b/java/runtime/src/main/java/io/ray/runtime/object/ObjectRefImpl.java index dcd5d4ac7..839ed5575 100644 --- a/java/runtime/src/main/java/io/ray/runtime/object/ObjectRefImpl.java +++ b/java/runtime/src/main/java/io/ray/runtime/object/ObjectRefImpl.java @@ -1,54 +1,56 @@ package io.ray.runtime.object; +import com.google.common.base.FinalizableReferenceQueue; +import com.google.common.base.FinalizableWeakReference; +import com.google.common.base.Preconditions; +import com.google.common.collect.Sets; import io.ray.api.ObjectRef; import io.ray.api.Ray; import io.ray.api.id.ObjectId; 
-import java.io.Serializable; +import io.ray.api.id.UniqueId; +import io.ray.runtime.RayRuntimeInternal; +import java.io.Externalizable; +import java.io.IOException; +import java.io.ObjectInput; +import java.io.ObjectOutput; +import java.lang.ref.Reference; +import java.util.Set; +import java.util.concurrent.atomic.AtomicBoolean; /** * Implementation of {@link ObjectRef}. */ -public final class ObjectRefImpl implements ObjectRef, Serializable { +public final class ObjectRefImpl implements ObjectRef, Externalizable { - private final ObjectId id; + private static final FinalizableReferenceQueue REFERENCE_QUEUE = new FinalizableReferenceQueue(); - /** - * Cache the result of `Ray.get()`. - * - * Note, this is necessary for direct calls, in which case, it's not allowed to call `Ray.get` on - * the same object twice. - */ - private transient T object; + private static final Set>> REFERENCES = Sets.newConcurrentHashSet(); + + private ObjectId id; + + // In GC thread, we don't know which worker this object binds to, so we need to + // store the worker ID for later uses. + private transient UniqueId workerId; private Class type; - /** - * Whether the object is already gotten from the object store. 
- */ - private transient boolean objectGotten; - public ObjectRefImpl(ObjectId id, Class type) { this.id = id; this.type = type; - object = null; - objectGotten = false; + addLocalReference(); } + public ObjectRefImpl() {} + @Override public synchronized T get() { - if (!objectGotten) { - object = Ray.get(id, type); - objectGotten = true; - } - return object; + return Ray.get(this); } - @Override public ObjectId getId() { return id; } - @Override public Class getType() { return type; } @@ -57,4 +59,56 @@ public final class ObjectRefImpl implements ObjectRef, Serializable { public String toString() { return "ObjectRef(" + id.toString() + ")"; } + + @Override + public void writeExternal(ObjectOutput out) throws IOException { + out.writeObject(this.getId()); + out.writeObject(this.getType()); + ObjectSerializer.addContainedObjectId(this.getId()); + } + + @Override + public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + this.id = (ObjectId) in.readObject(); + this.type = (Class) in.readObject(); + addLocalReference(); + } + + private void addLocalReference() { + Preconditions.checkState(workerId == null); + RayRuntimeInternal runtime = (RayRuntimeInternal) Ray.internal(); + workerId = runtime.getWorkerContext().getCurrentWorkerId(); + runtime.getObjectStore().addLocalReference(workerId, id); + new ObjectRefImplReference(this); + } + + private static final class ObjectRefImplReference extends + FinalizableWeakReference> { + + private final UniqueId workerId; + private final ObjectId objectId; + private final AtomicBoolean removed; + + public ObjectRefImplReference(ObjectRefImpl obj) { + super(obj, REFERENCE_QUEUE); + this.workerId = obj.workerId; + this.objectId = obj.id; + this.removed = new AtomicBoolean(false); + REFERENCES.add(this); + } + + @Override + public void finalizeReferent() { + // This method may be invoked multiple times on the same instance (due to explicit invoking in + // unit tests). 
So if `removed` is already set, it means this method has been invoked. + if (!removed.getAndSet(true)) { + REFERENCES.remove(this); + RayRuntimeInternal runtime = (RayRuntimeInternal) Ray.internal(); + // It's possible that GC is executed after the runtime is shutdown. + if (runtime != null) { + runtime.getObjectStore().removeLocalReference(workerId, objectId); + } + } + } + } } diff --git a/java/runtime/src/main/java/io/ray/runtime/object/ObjectSerializer.java b/java/runtime/src/main/java/io/ray/runtime/object/ObjectSerializer.java index 35ab285a8..8b2c239aa 100644 --- a/java/runtime/src/main/java/io/ray/runtime/object/ObjectSerializer.java +++ b/java/runtime/src/main/java/io/ray/runtime/object/ObjectSerializer.java @@ -8,7 +8,11 @@ import io.ray.api.id.ObjectId; import io.ray.runtime.generated.Gcs.ErrorType; import io.ray.runtime.serializer.Serializer; import java.nio.ByteBuffer; +import java.util.ArrayList; import java.util.Arrays; +import java.util.HashSet; +import java.util.List; +import java.util.Set; import org.apache.commons.lang3.tuple.Pair; /** @@ -31,6 +35,12 @@ public class ObjectSerializer { public static final byte[] OBJECT_METADATA_TYPE_PYTHON = "PYTHON".getBytes(); public static final byte[] OBJECT_METADATA_TYPE_RAW = "RAW".getBytes(); + // When an outer object is being serialized, the nested ObjectRefs are all + // serialized and the writeExternal method of the nested ObjectRefs are + // executed. So after the outer object is serialized, the containedObjectIds + // field will contain all the nested object IDs. + static ThreadLocal> containedObjectIds = ThreadLocal.withInitial(HashSet::new); + /** * Deserialize an object from an {@link NativeRayObject} instance. 
* @@ -100,9 +110,29 @@ public class ObjectSerializer { byte[] serializedBytes = Serializer.encode(object).getLeft(); return new NativeRayObject(serializedBytes, TASK_EXECUTION_EXCEPTION_META); } else { - Pair serialized = Serializer.encode(object); - return new NativeRayObject(serialized.getLeft(), serialized.getRight() ? - OBJECT_METADATA_TYPE_CROSS_LANGUAGE : OBJECT_METADATA_TYPE_JAVA); + try { + Pair serialized = Serializer.encode(object); + NativeRayObject nativeRayObject = new NativeRayObject(serialized.getLeft(), + serialized.getRight() + ? OBJECT_METADATA_TYPE_CROSS_LANGUAGE + : OBJECT_METADATA_TYPE_JAVA); + nativeRayObject.setContainedObjectIds(getAndClearContainedObjectIds()); + return nativeRayObject; + } catch (Exception e) { + // Clear `containedObjectIds`. + getAndClearContainedObjectIds(); + throw e; + } } } + + static void addContainedObjectId(ObjectId objectId) { + containedObjectIds.get().add(objectId); + } + + private static List getAndClearContainedObjectIds() { + List ids = new ArrayList<>(containedObjectIds.get()); + containedObjectIds.get().clear(); + return ids; + } } diff --git a/java/runtime/src/main/java/io/ray/runtime/object/ObjectStore.java b/java/runtime/src/main/java/io/ray/runtime/object/ObjectStore.java index d0c64783c..a30e5d8a4 100644 --- a/java/runtime/src/main/java/io/ray/runtime/object/ObjectStore.java +++ b/java/runtime/src/main/java/io/ray/runtime/object/ObjectStore.java @@ -5,6 +5,7 @@ import io.ray.api.ObjectRef; import io.ray.api.WaitResult; import io.ray.api.exception.RayException; import io.ray.api.id.ObjectId; +import io.ray.api.id.UniqueId; import io.ray.runtime.context.WorkerContext; import java.util.ArrayList; import java.util.Collections; @@ -137,7 +138,8 @@ public abstract class ObjectStore { return new WaitResult<>(Collections.emptyList(), Collections.emptyList()); } - List ids = waitList.stream().map(ObjectRef::getId).collect(Collectors.toList()); + List ids = waitList.stream().map(ref -> ((ObjectRefImpl) 
ref).getId()) + .collect(Collectors.toList()); List ready = wait(ids, numReturns, timeoutMs); List> readyList = new ArrayList<>(); @@ -164,4 +166,18 @@ public abstract class ObjectStore { */ public abstract void delete(List objectIds, boolean localOnly, boolean deleteCreatingTasks); + + /** + * Increase the local reference count for this object ID. + * @param workerId The ID of the worker to increase on. + * @param objectId The object ID to increase the reference count for. + */ + public abstract void addLocalReference(UniqueId workerId, ObjectId objectId); + + /** + * Decrease the reference count for this object ID. + * @param workerId The ID of the worker to decrease on. + * @param objectId The object ID to decrease the reference count for. + */ + public abstract void removeLocalReference(UniqueId workerId, ObjectId objectId); } diff --git a/java/runtime/src/main/java/io/ray/runtime/task/ArgumentsBuilder.java b/java/runtime/src/main/java/io/ray/runtime/task/ArgumentsBuilder.java index 7d1b30ffd..9455fedb4 100644 --- a/java/runtime/src/main/java/io/ray/runtime/task/ArgumentsBuilder.java +++ b/java/runtime/src/main/java/io/ray/runtime/task/ArgumentsBuilder.java @@ -7,6 +7,7 @@ import io.ray.api.id.ObjectId; import io.ray.runtime.RayRuntimeInternal; import io.ray.runtime.generated.Common.Language; import io.ray.runtime.object.NativeRayObject; +import io.ray.runtime.object.ObjectRefImpl; import io.ray.runtime.object.ObjectSerializer; import java.nio.ByteBuffer; import java.util.ArrayList; @@ -22,7 +23,8 @@ public class ArgumentsBuilder { * If the the size of an argument's serialized data is smaller than this number, the argument will * be passed by value. Otherwise it'll be passed by reference. */ - private static final int LARGEST_SIZE_PASS_BY_VALUE = 100 * 1024; + // TODO(kfstorm): Read from internal config `max_direct_call_object_size`. + public static final int LARGEST_SIZE_PASS_BY_VALUE = 100 * 1024; /** * This dummy type is also defined in signature.py. 
Please keep it synced. @@ -39,7 +41,8 @@ public class ArgumentsBuilder { ObjectId id = null; NativeRayObject value = null; if (arg instanceof ObjectRef) { - id = ((ObjectRef) arg).getId(); + Preconditions.checkState(arg instanceof ObjectRefImpl); + id = ((ObjectRefImpl) arg).getId(); } else { value = ObjectSerializer.serialize(arg); if (language != Language.JAVA) { diff --git a/java/test/pom.xml b/java/test/pom.xml index 9cd773a05..be2ce21b0 100644 --- a/java/test/pom.xml +++ b/java/test/pom.xml @@ -74,7 +74,7 @@ org.testng testng - 6.9.10 + 7.1.0 diff --git a/java/test/src/main/java/io/ray/test/ActorTest.java b/java/test/src/main/java/io/ray/test/ActorTest.java index 3712da1cf..1e7e40b04 100644 --- a/java/test/src/main/java/io/ray/test/ActorTest.java +++ b/java/test/src/main/java/io/ray/test/ActorTest.java @@ -42,6 +42,10 @@ public class ActorTest extends BaseTest { value += largeObject.data.length; return value; } + + public TestUtils.LargeObject createLargeObject() { + return new TestUtils.LargeObject(); + } } public void testCreateAndCallActor() { @@ -68,8 +72,7 @@ public class ActorTest extends BaseTest { ObjectRef result = actor.task(Counter::getValue).remote(); Assert.assertEquals(result.get(), Integer.valueOf(1)); Assert.assertEquals(result.get(), Integer.valueOf(1)); - // TODO(hchen): The following code will still fail, and can be fixed by using ref counting. - // Assert.assertEquals(Ray.get(result.getId()), Integer.valueOf(1)); + Assert.assertEquals(Ray.get(result), Integer.valueOf(1)); } public void testCallActorWithLargeObject() { @@ -117,33 +120,28 @@ public class ActorTest extends BaseTest { Collections.singletonList(actor), 100).remote().get()); } - // TODO(qwang): Will re-enable this test case once ref counting is supported in Java. - @Test(enabled = false, groups = {"cluster"}) + // This test case follows `test_internal_free` in `python/ray/tests/test_advanced.py`. 
+ @Test(groups = {"cluster"}) public void testUnreconstructableActorObject() throws InterruptedException { - // The UnreconstructableException is created by raylet. ActorHandle counter = Ray.actor(Counter::new, 100).remote(); // Call an actor method. ObjectRef value = counter.task(Counter::getValue).remote(); Assert.assertEquals(100, value.get()); // Delete the object from the object store. - Ray.internal().free(ImmutableList.of(value.getId()), false, false); - // Wait until the object is deleted, because the above free operation is async. - while (true) { - Boolean result = TestUtils.getRuntime().getObjectStore() - .wait(ImmutableList.of(value.getId()), 1, 0).get(0); - if (!result) { - break; - } - TimeUnit.MILLISECONDS.sleep(100); - } + Ray.internal().free(ImmutableList.of(value), false, false); + // Wait for delete RPC to propagate + TimeUnit.SECONDS.sleep(1); + // Free deletes from in-memory store. + Assert.expectThrows(UnreconstructableException.class, () -> value.get()); - try { - // Try getting the object again, this should throw an UnreconstructableException. - // Use `Ray.get()` to bypass the cache in `RayObjectImpl`. - Ray.get(value.getId(), value.getType()); - Assert.fail("This line should not be reachable."); - } catch (UnreconstructableException e) { - Assert.assertEquals(value.getId(), e.objectId); - } + // Call an actor method. + ObjectRef largeValue = counter.task(Counter::createLargeObject).remote(); + Assert.assertTrue(largeValue.get() instanceof TestUtils.LargeObject); + // Delete the object from the object store. + Ray.internal().free(ImmutableList.of(largeValue), false, false); + // Wait for delete RPC to propagate + TimeUnit.SECONDS.sleep(1); + // Free deletes big objects from plasma store. 
+ Assert.expectThrows(UnreconstructableException.class, () -> largeValue.get()); } } diff --git a/java/test/src/main/java/io/ray/test/BaseMultiLanguageTest.java b/java/test/src/main/java/io/ray/test/BaseMultiLanguageTest.java index a45793ee7..77e388430 100644 --- a/java/test/src/main/java/io/ray/test/BaseMultiLanguageTest.java +++ b/java/test/src/main/java/io/ray/test/BaseMultiLanguageTest.java @@ -51,7 +51,7 @@ public abstract class BaseMultiLanguageTest { } } - @BeforeClass(alwaysRun = true) + @BeforeClass(alwaysRun = true, inheritGroups = false) public void setUp() { // Delete existing socket files. for (String socket : ImmutableList.of(RAYLET_SOCKET_NAME, PLASMA_STORE_SOCKET_NAME)) { @@ -106,7 +106,7 @@ public abstract class BaseMultiLanguageTest { return ImmutableMap.of(); } - @AfterClass(alwaysRun = true) + @AfterClass(alwaysRun = true, inheritGroups = false) public void tearDown() { // Disconnect to the cluster. Ray.shutdown(); diff --git a/java/test/src/main/java/io/ray/test/GcsClientTest.java b/java/test/src/main/java/io/ray/test/GcsClientTest.java index 6a5b2c002..d17e16a8b 100644 --- a/java/test/src/main/java/io/ray/test/GcsClientTest.java +++ b/java/test/src/main/java/io/ray/test/GcsClientTest.java @@ -33,7 +33,7 @@ public class GcsClientTest extends BaseTest { Assert.assertEquals(allNodeInfo.size(), 1); Assert.assertEquals(allNodeInfo.get(0).nodeAddress, config.nodeIp); Assert.assertTrue(allNodeInfo.get(0).isAlive); - Assert.assertEquals(allNodeInfo.get(0).resources.get("A"), 8.0); + Assert.assertEquals((double) allNodeInfo.get(0).resources.get("A"), 8.0); } @Test diff --git a/java/test/src/main/java/io/ray/test/GlobalGcTest.java b/java/test/src/main/java/io/ray/test/GlobalGcTest.java new file mode 100644 index 000000000..a75da086e --- /dev/null +++ b/java/test/src/main/java/io/ray/test/GlobalGcTest.java @@ -0,0 +1,96 @@ +package io.ray.test; + +import io.ray.api.ActorHandle; +import io.ray.api.ObjectRef; +import io.ray.api.Ray; +import 
java.lang.ref.WeakReference; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import org.testng.Assert; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +@Test(groups = {"cluster"}) +public class GlobalGcTest extends BaseTest { + + @BeforeClass + public void setUp() { + System.setProperty("ray.object-store.size", "140 MB"); + } + + @AfterClass + public void tearDown() { + System.clearProperty("ray.object-store.size"); + } + + public static class LargeObjectWithCyclicRef { + + private final LargeObjectWithCyclicRef loop; + + private final ObjectRef largeObject; + + public LargeObjectWithCyclicRef() { + this.loop = this; + this.largeObject = Ray.put(new TestUtils.LargeObject(40 * 1024 * 1024)); + } + } + + public static class GarbageHolder { + + private WeakReference garbage; + + public GarbageHolder() { + LargeObjectWithCyclicRef x = new LargeObjectWithCyclicRef(); + garbage = new WeakReference<>(x); + } + + public boolean hasGarbage() { + return garbage.get() != null; + } + + public TestUtils.LargeObject returnLargeObject() { + return new TestUtils.LargeObject(80 * 1024 * 1024); + } + } + + private void testGlobalGcWhenFull(boolean withPut) { + // Local driver. + WeakReference localRef = new WeakReference<>( + new LargeObjectWithCyclicRef()); + + // Remote workers. + List> actors = IntStream + .range(0, 2).mapToObj(i -> Ray.actor(GarbageHolder::new).remote()) + .collect(Collectors.toList()); + + Assert.assertNotNull(localRef.get()); + for (ActorHandle actor : actors) { + Assert.assertTrue(actor.task(GarbageHolder::hasGarbage).remote().get()); + } + + if (withPut) { + // GC should be triggered for all workers, including the local driver, + // when the driver tries to Ray.put a value that doesn't fit in the + // object store. This should cause the captured ObjectRefs to be evicted. 
+ Ray.put(new TestUtils.LargeObject(80 * 1024 * 1024)); + } else { + // GC should be triggered for all workers, including the local driver, + // when a remote task tries to put a return value that doesn't fit in + // the object store. This should cause the captured ObjectRefs' to be evicted. + actors.get(0).task(GarbageHolder::returnLargeObject).remote().get(); + } + + TestUtils.waitForCondition(() -> localRef.get() == null && actors.stream().noneMatch( + a -> a.task(GarbageHolder::hasGarbage).remote().get()), 10 * 1000); + } + + public void testGlobalGcWhenFullWithPut() { + testGlobalGcWhenFull(true); + } + + public void testGlobalGcWhenFullWithReturn() { + testGlobalGcWhenFull(false); + } +} diff --git a/java/test/src/main/java/io/ray/test/MultiThreadingTest.java b/java/test/src/main/java/io/ray/test/MultiThreadingTest.java index 7a63c6435..d7d9f1425 100644 --- a/java/test/src/main/java/io/ray/test/MultiThreadingTest.java +++ b/java/test/src/main/java/io/ray/test/MultiThreadingTest.java @@ -143,7 +143,7 @@ public class MultiThreadingTest extends BaseTest { final ActorHandle fooActor = Ray.actor(Echo::new).remote(); return new Runnable[]{ () -> Ray.put(1), - () -> Ray.get(fooObject.getId(), fooObject.getType()), + () -> Ray.get(ImmutableList.of(fooObject)), fooObject::get, () -> Ray.wait(ImmutableList.of(fooObject)), Ray::getRuntimeContext, diff --git a/java/test/src/main/java/io/ray/test/ObjectStoreTest.java b/java/test/src/main/java/io/ray/test/ObjectStoreTest.java index 308763061..7643d8001 100644 --- a/java/test/src/main/java/io/ray/test/ObjectStoreTest.java +++ b/java/test/src/main/java/io/ray/test/ObjectStoreTest.java @@ -3,7 +3,6 @@ package io.ray.test; import com.google.common.collect.ImmutableList; import io.ray.api.ObjectRef; import io.ray.api.Ray; -import io.ray.api.id.ObjectId; import java.util.List; import java.util.stream.Collectors; import org.testng.Assert; @@ -37,8 +36,8 @@ public class ObjectStoreTest extends BaseTest { @Test public void 
testGetMultipleObjects() { List ints = ImmutableList.of(1, 2, 3, 4, 5); - List ids = ints.stream().map(obj -> Ray.put(obj).getId()) + List> refs = ints.stream().map(Ray::put) .collect(Collectors.toList()); - Assert.assertEquals(ints, Ray.get(ids, Integer.class)); + Assert.assertEquals(ints, Ray.get(refs)); } } diff --git a/java/test/src/main/java/io/ray/test/PlasmaFreeTest.java b/java/test/src/main/java/io/ray/test/PlasmaFreeTest.java index 99adafc7a..1b924f3c0 100644 --- a/java/test/src/main/java/io/ray/test/PlasmaFreeTest.java +++ b/java/test/src/main/java/io/ray/test/PlasmaFreeTest.java @@ -4,6 +4,7 @@ import com.google.common.collect.ImmutableList; import io.ray.api.ObjectRef; import io.ray.api.Ray; import io.ray.api.id.TaskId; +import io.ray.runtime.object.ObjectRefImpl; import java.util.Arrays; import org.testng.Assert; import org.testng.annotations.Test; @@ -19,11 +20,11 @@ public class PlasmaFreeTest extends BaseTest { ObjectRef helloId = Ray.task(PlasmaFreeTest::hello).remote(); String helloString = helloId.get(); Assert.assertEquals("hello", helloString); - Ray.internal().free(ImmutableList.of(helloId.getId()), true, false); + Ray.internal().free(ImmutableList.of(helloId), true, false); final boolean result = TestUtils.waitForCondition(() -> !TestUtils.getRuntime().getObjectStore() - .wait(ImmutableList.of(helloId.getId()), 1, 0).get(0), 50); + .wait(ImmutableList.of(((ObjectRefImpl) helloId).getId()), 1, 0).get(0), 50); if (TestUtils.isSingleProcessMode()) { Assert.assertTrue(result); } else { @@ -36,9 +37,10 @@ public class PlasmaFreeTest extends BaseTest { public void testDeleteCreatingTasks() { ObjectRef helloId = Ray.task(PlasmaFreeTest::hello).remote(); Assert.assertEquals("hello", helloId.get()); - Ray.internal().free(ImmutableList.of(helloId.getId()), true, true); + Ray.internal().free(ImmutableList.of(helloId), true, true); - TaskId taskId = TaskId.fromBytes(Arrays.copyOf(helloId.getId().getBytes(), TaskId.LENGTH)); + TaskId taskId = 
TaskId.fromBytes( + Arrays.copyOf(((ObjectRefImpl) helloId).getId().getBytes(), TaskId.LENGTH)); final boolean result = TestUtils.waitForCondition( () -> !TestUtils.getRuntime().getGcsClient() .rayletTaskExistsInGcs(taskId), 50); diff --git a/java/test/src/main/java/io/ray/test/PlasmaStoreTest.java b/java/test/src/main/java/io/ray/test/PlasmaStoreTest.java deleted file mode 100644 index dcbabf655..000000000 --- a/java/test/src/main/java/io/ray/test/PlasmaStoreTest.java +++ /dev/null @@ -1,21 +0,0 @@ -package io.ray.test; - -import io.ray.api.Ray; -import io.ray.api.id.ObjectId; -import io.ray.runtime.object.ObjectStore; -import org.testng.Assert; -import org.testng.annotations.Test; - -public class PlasmaStoreTest extends BaseTest { - - @Test(groups = {"cluster"}) - public void testPutWithDuplicateId() { - ObjectId objectId = ObjectId.fromRandom(); - ObjectStore objectStore = TestUtils.getRuntime().getObjectStore(); - objectStore.put("1", objectId); - Assert.assertEquals(Ray.get(objectId, String.class), "1"); - objectStore.put("2", objectId); - // Putting the second object with duplicate ID should fail but ignored. 
- Assert.assertEquals(Ray.get(objectId, String.class), "1"); - } -} diff --git a/java/test/src/main/java/io/ray/test/ReferenceCountingTest.java b/java/test/src/main/java/io/ray/test/ReferenceCountingTest.java new file mode 100644 index 000000000..24ccbb21d --- /dev/null +++ b/java/test/src/main/java/io/ray/test/ReferenceCountingTest.java @@ -0,0 +1,362 @@ +package io.ray.test; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.gson.Gson; +import io.ray.api.ActorHandle; +import io.ray.api.ObjectRef; +import io.ray.api.Ray; +import io.ray.api.id.ObjectId; +import io.ray.runtime.object.NativeObjectStore; +import io.ray.runtime.object.ObjectRefImpl; +import java.lang.ref.Reference; +import java.lang.reflect.Field; +import java.lang.reflect.Method; +import java.time.Duration; +import java.time.Instant; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Set; +import org.testng.Assert; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.Test; + +@Test(groups = {"cluster"}) +public class ReferenceCountingTest extends BaseTest { + @BeforeClass + public void setUp() { + System.setProperty("ray.object-store.size", "100 MB"); + } + + @AfterClass + public void tearDown() { + System.clearProperty("ray.object-store.size"); + } + + /** + * Because we can't explicitly GC an Java object. We use this helper method to manually remove an + * local reference. 
+ */ + private static void del(ObjectRef obj) { + try { + Field referencesField = ObjectRefImpl.class.getDeclaredField("REFERENCES"); + referencesField.setAccessible(true); + Set references = (Set) referencesField.get(null); + Class referenceClass = Class + .forName("io.ray.runtime.object.ObjectRefImpl$ObjectRefImplReference"); + Method finalizeReferentMethod = referenceClass.getDeclaredMethod("finalizeReferent"); + finalizeReferentMethod.setAccessible(true); + for (Object reference : references) { + if (obj.equals(((Reference)reference).get())) { + finalizeReferentMethod.invoke(reference); + break; + } + } + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + private void checkRefCounts(Map expected, Duration timeout) { + Instant start = Instant.now(); + while (true) { + Map actual = + ((NativeObjectStore) TestUtils.getRuntime().getObjectStore()).getAllReferenceCounts(); + try { + Assert.assertEqualsDeep(actual, expected); + return; + } catch (AssertionError e) { + if (Duration.between(start, Instant.now()).compareTo(timeout) >= 0) { + System.out.println("Actual: " + new Gson().toJson(actual)); + System.out.println("Expected: " + new Gson().toJson(expected)); + throw e; + } else { + try { + Thread.sleep(100); + } catch (InterruptedException ex) { + throw new RuntimeException(e); + } + } + } + } + } + + private void checkRefCounts(Map expected) { + checkRefCounts(expected, Duration.ofSeconds(10)); + } + + private void checkRefCounts(ObjectId objectId, long localRefCount, long submittedTaskRefCount) { + checkRefCounts(ImmutableMap.of(objectId, new long[] {localRefCount, submittedTaskRefCount})); + } + + private void checkRefCounts(ObjectId objectId1, long localRefCount1, long submittedTaskRefCount1, + ObjectId objectId2, long localRefCount2, long submittedTaskRefCount2) { + checkRefCounts(ImmutableMap.of(objectId1, new long[] {localRefCount1, submittedTaskRefCount1}, + objectId2, new long[] {localRefCount2, submittedTaskRefCount2})); + } + + private 
static void fillObjectStoreAndGet(ObjectId objectId, boolean succeed) { + fillObjectStoreAndGet(objectId, succeed, 40 * 1024 * 1024, 5); + } + + private static void fillObjectStoreAndGet(ObjectId objectId, boolean succeed, int objectSize, + int numObjects) { + for (int i = 0; i < numObjects; i++) { + Ray.put(new TestUtils.LargeObject(objectSize)); + } + if (succeed) { + TestUtils.getRuntime().getObjectStore().getRaw(ImmutableList.of(objectId), Long.MAX_VALUE); + } else { + List result = + TestUtils.getRuntime().getObjectStore().wait(ImmutableList.of(objectId), 1, 100); + Assert.assertFalse(result.get(0)); + } + } + + /** + * Based on Python test case `test_local_refcounts`. + */ + public void testLocalRefCounts() { + ObjectRefImpl obj1 = (ObjectRefImpl) Ray.put(null); + checkRefCounts(obj1.getId(), 1, 0); + ObjectRef obj1Copy = new ObjectRefImpl<>(obj1.getId(), obj1.getType()); + checkRefCounts(obj1.getId(), 2, 0); + + del(obj1); + checkRefCounts(obj1.getId(), 1, 0); + del(obj1Copy); + checkRefCounts(ImmutableMap.of()); + } + + private static int oneDep(Object obj) { + return oneDep(obj, null); + } + + private static int oneDep(Object obj, ActorHandle singal) { + return oneDep(obj, singal, false); + } + + private static int oneDep(Object obj, ActorHandle singal, boolean fail) { + if (singal != null) { + singal.task(SignalActor::waitSignal).remote().get(); + } + if (fail) { + throw new RuntimeException("failed on purpose"); + } + return 0; + } + + private static TestUtils.LargeObject oneDepLarge(Object obj, ActorHandle singal) { + if (singal != null) { + singal.task(SignalActor::waitSignal).remote().get(); + } + // This will be spilled to plasma. + return new TestUtils.LargeObject(10 * 1024 * 1024); + } + + private static void sendSignal(ActorHandle signal) { + ObjectRef result = signal.task(SignalActor::sendSignal).remote(); + result.get(); + // Remove the reference immediately, otherwise it will affect subsequent tests. 
+ del(result); + } + + /** + * Based on Python test case `test_dependency_refcounts`. + */ + public void testDependencyRefCounts() { + { + // Test that regular plasma dependency refcounts are decremented once the + // task finishes. + ActorHandle signal = SignalActor.create(); + ObjectRefImpl largeDep = (ObjectRefImpl) Ray + .put(new TestUtils.LargeObject()); + ObjectRefImpl result = (ObjectRefImpl) + Ray., Object>task( + ReferenceCountingTest::oneDep, largeDep, signal).remote(); + checkRefCounts(largeDep.getId(), 1, 1, result.getId(), 1, 0); + sendSignal(signal); + // Reference count should be removed once the task finishes. + checkRefCounts(largeDep.getId(), 1, 0, result.getId(), 1, 0); + del(largeDep); + del(result); + checkRefCounts(ImmutableMap.of()); + } + + { + // Test that inlined dependency refcounts are decremented once they are + // inlined. + ActorHandle signal = SignalActor.create(); + ObjectRefImpl dep = (ObjectRefImpl) + Ray., Integer>task(ReferenceCountingTest::oneDep, + Integer.valueOf(1), signal).remote(); + checkRefCounts(dep.getId(), 1, 0); + ObjectRefImpl result = (ObjectRefImpl) + Ray.task(ReferenceCountingTest::oneDep, dep).remote(); + checkRefCounts(dep.getId(), 1, 1, result.getId(), 1, 0); + sendSignal(signal); + // Reference count should be removed as soon as the dependency is inlined. + checkRefCounts(dep.getId(), 1, 0, result.getId(), 1, 0); + del(dep); + del(result); + checkRefCounts(ImmutableMap.of()); + } + + { + // Test that spilled plasma dependency refcounts are decremented once + // the task finishes. 
+ ActorHandle signal1 = SignalActor.create(); + ActorHandle signal2 = SignalActor.create(); + ObjectRefImpl dep = (ObjectRefImpl) + Ray., TestUtils.LargeObject>task( + ReferenceCountingTest::oneDepLarge, (TestUtils.LargeObject) null, signal1).remote(); + checkRefCounts(dep.getId(), 1, 0); + ObjectRefImpl result = (ObjectRefImpl) + Ray., Integer>task( + ReferenceCountingTest::oneDep, dep, signal2).remote(); + checkRefCounts(dep.getId(), 1, 1, result.getId(), 1, 0); + sendSignal(signal1); + dep.get(); // TODO(kfstorm): timeout=10 + // Reference count should remain because the dependency is in plasma. + checkRefCounts(dep.getId(), 1, 1, result.getId(), 1, 0); + sendSignal(signal2); + // Reference count should be removed because the task finished. + checkRefCounts(dep.getId(), 1, 0, result.getId(), 1, 0); + del(dep); + del(result); + checkRefCounts(ImmutableMap.of()); + } + + { + // Test that regular plasma dependency refcounts are decremented if a task + // fails. + ActorHandle signal = SignalActor.create(); + ObjectRefImpl largeDep = (ObjectRefImpl) + Ray.put(new TestUtils.LargeObject(10 * 1024 * 1024)); + ObjectRefImpl result = (ObjectRefImpl) + Ray., Boolean, Integer>task( + ReferenceCountingTest::oneDep, largeDep, signal, /* fail= */true).remote(); + checkRefCounts(largeDep.getId(), 1, 1, result.getId(), 1, 0); + sendSignal(signal); + // Reference count should be removed once the task finishes. + checkRefCounts(largeDep.getId(), 1, 0, result.getId(), 1, 0); + del(largeDep); + del(result); + checkRefCounts(ImmutableMap.of()); + } + + { + // Test that spilled plasma dependency refcounts are decremented if a task + // fails. 
+ ActorHandle signal1 = SignalActor.create(); + ActorHandle signal2 = SignalActor.create(); + ObjectRefImpl dep = (ObjectRefImpl) + Ray., TestUtils.LargeObject>task( + ReferenceCountingTest::oneDepLarge, (Integer) null, signal1).remote(); + checkRefCounts(dep.getId(), 1, 0); + ObjectRefImpl result = (ObjectRefImpl) + Ray., Boolean, Integer>task( + ReferenceCountingTest::oneDep, dep, signal2, /* fail= */true).remote(); + checkRefCounts(dep.getId(), 1, 1, result.getId(), 1, 0); + sendSignal(signal1); + dep.get(); // TODO(kfstorm): timeout=10 + // Reference count should remain because the dependency is in plasma. + checkRefCounts(dep.getId(), 1, 1, result.getId(), 1, 0); + sendSignal(signal2); + // Reference count should be removed because the task finished. + checkRefCounts(dep.getId(), 1, 0, result.getId(), 1, 0); + del(dep); + del(result); + checkRefCounts(ImmutableMap.of()); + } + } + + private static int fooBasicPinning(Object arg) { + return 0; + } + + public static class ActorBasicPinning { + private ObjectRef largeObject; + + public ActorBasicPinning() { + // Hold a long-lived reference to a ray.put object's ID. The object + // should not be garbage collected while the actor is alive because + // the object is pinned by the raylet. + largeObject = Ray.put(new TestUtils.LargeObject(25 * 1024 * 1024)); + } + + public TestUtils.LargeObject getLargeObject() { + return largeObject.get(); + } + } + + /** + * Based on Python test case `test_basic_pinning`. + */ + public void testBasicPinning() { + ActorHandle actor = Ray.actor(ActorBasicPinning::new).remote(); + + // Fill up the object store with short-lived objects. These should be + // evicted before the long-lived object whose reference is held by + // the actor. 
+ for (int i = 0; i < 10; i++) { + ObjectRef intermediateResult = Ray + .task(ReferenceCountingTest::fooBasicPinning, new TestUtils.LargeObject(10 * 1024 * 1024)) + .remote(); + intermediateResult.get(); + } + // The ray.get below would fail with only LRU eviction, as the object + // that was ray.put by the actor would have been evicted. + actor.task(ActorBasicPinning::getLargeObject).remote().get(); + } + + private static Object pending(TestUtils.LargeObject input1, Integer input2) { + return null; + } + + /** + * Based on Python test case `test_pending_task_dependency_pinning`. + */ + public void testPendingTaskDependencyPinning() { + // The object that is ray.put here will go out of scope immediately, so if + // pending task dependencies aren't considered, it will be evicted before + // the ray.get below due to the subsequent ray.puts that fill up the object + // store. + TestUtils.LargeObject input1 = new TestUtils.LargeObject(40 * 1024 * 1024); + ActorHandle signal = SignalActor.create(); + ObjectRef result = Ray + .task(ReferenceCountingTest::pending, input1, signal.task(SignalActor::waitSignal).remote()) + .remote(); + + for (int i = 0; i < 2; i++) { + Ray.put(new TestUtils.LargeObject(40 * 1024 * 1024)); + } + + sendSignal(signal); + result.get(); + } + + /** + * Test that an object containing object IDs within it pins the inner IDs. Based on Python test + * case `test_basic_nested_ids`. + */ + public void testBasicNestedIds() { + ObjectRefImpl inner = (ObjectRefImpl) Ray.put(new byte[40 * 1024 * 1024]); + ObjectRef>> outer = Ray.put(Collections.singletonList(inner)); + + // Remove the local reference to the inner object. + del(inner); + + // Check that the outer reference pins the inner object. + fillObjectStoreAndGet(inner.getId(), true); + + // Remove the outer reference and check that the inner object gets evicted. 
+ del(outer); + fillObjectStoreAndGet(inner.getId(), false); + } + + // TODO(kfstorm): Add more test cases +} diff --git a/java/test/src/main/java/io/ray/test/SignalActor.java b/java/test/src/main/java/io/ray/test/SignalActor.java new file mode 100644 index 000000000..722c5d578 --- /dev/null +++ b/java/test/src/main/java/io/ray/test/SignalActor.java @@ -0,0 +1,28 @@ +package io.ray.test; + +import io.ray.api.ActorHandle; +import io.ray.api.Ray; +import java.util.concurrent.Semaphore; + +public class SignalActor { + + private Semaphore semaphore; + + public SignalActor() { + this.semaphore = new Semaphore(0); + } + + public int sendSignal() { + this.semaphore.release(); + return 0; + } + + public int waitSignal() throws InterruptedException { + this.semaphore.acquire(); + return 0; + } + + public static ActorHandle create() { + return Ray.actor(SignalActor::new).setMaxConcurrency(2).remote(); + } +} \ No newline at end of file diff --git a/java/test/src/main/java/io/ray/test/StressTest.java b/java/test/src/main/java/io/ray/test/StressTest.java index 1a5c5e4d9..42260de91 100644 --- a/java/test/src/main/java/io/ray/test/StressTest.java +++ b/java/test/src/main/java/io/ray/test/StressTest.java @@ -4,7 +4,6 @@ import com.google.common.collect.ImmutableList; import io.ray.api.ActorHandle; import io.ray.api.ObjectRef; import io.ray.api.Ray; -import io.ray.api.id.ObjectId; import java.util.ArrayList; import java.util.List; import org.testng.Assert; @@ -21,12 +20,12 @@ public class StressTest extends BaseTest { for (int numIterations : ImmutableList.of(1, 10, 100, 1000)) { int numTasks = 1000 / numIterations; for (int i = 0; i < numIterations; i++) { - List resultIds = new ArrayList<>(); + List> results = new ArrayList<>(); for (int j = 0; j < numTasks; j++) { - resultIds.add(Ray.task(StressTest::echo, 1).remote().getId()); + results.add(Ray.task(StressTest::echo, 1).remote()); } - for (Integer result : Ray.get(resultIds, Integer.class)) { + for (Integer result : 
Ray.get(results)) { Assert.assertEquals(result, Integer.valueOf(1)); } } @@ -58,12 +57,12 @@ public class StressTest extends BaseTest { } public int ping(int n) { - List objectIds = new ArrayList<>(); + List> objectRefs = new ArrayList<>(); for (int i = 0; i < n; i++) { - objectIds.add(actor.task(Actor::ping).remote().getId()); + objectRefs.add(actor.task(Actor::ping).remote()); } int sum = 0; - for (Integer result : Ray.get(objectIds, Integer.class)) { + for (Integer result : Ray.get(objectRefs)) { sum += result; } return sum; @@ -72,13 +71,13 @@ public class StressTest extends BaseTest { public void testSubmittingManyTasksToOneActor() throws Exception { ActorHandle actor = Ray.actor(Actor::new).remote(); - List objectIds = new ArrayList<>(); + List> objectRefs = new ArrayList<>(); for (int i = 0; i < 10; i++) { ActorHandle worker = Ray.actor(Worker::new, actor).remote(); - objectIds.add(worker.task(Worker::ping, 100).remote().getId()); + objectRefs.add(worker.task(Worker::ping, 100).remote()); } - for (Integer result : Ray.get(objectIds, Integer.class)) { + for (Integer result : Ray.get(objectRefs)) { Assert.assertEquals(result, Integer.valueOf(100)); } } diff --git a/java/test/src/main/java/io/ray/test/TestUtils.java b/java/test/src/main/java/io/ray/test/TestUtils.java index 4dac2ccbb..88f4d25cd 100644 --- a/java/test/src/main/java/io/ray/test/TestUtils.java +++ b/java/test/src/main/java/io/ray/test/TestUtils.java @@ -1,10 +1,12 @@ package io.ray.test; +import com.google.common.base.Preconditions; import io.ray.api.ObjectRef; import io.ray.api.Ray; import io.ray.runtime.RayRuntimeInternal; import io.ray.runtime.RayRuntimeProxy; import io.ray.runtime.config.RunMode; +import io.ray.runtime.task.ArgumentsBuilder; import java.io.Serializable; import java.util.function.Supplier; import org.testng.Assert; @@ -13,7 +15,16 @@ public class TestUtils { public static class LargeObject implements Serializable { - public byte[] data = new byte[1024 * 1024]; + public byte[] 
data; + + public LargeObject() { + this(1024 * 1024); + } + + public LargeObject(int size) { + Preconditions.checkState(size > ArgumentsBuilder.LARGEST_SIZE_PASS_BY_VALUE); + data = new byte[size]; + } } private static final int WAIT_INTERVAL_MS = 5; @@ -55,11 +66,12 @@ public class TestUtils { /** * Warm up the cluster to make sure there's at least one idle worker. - *

- * This is needed before calling `wait`. Because, in Travis CI, starting a new worker - * process could be slower than the wait timeout. - * TODO(hchen): We should consider supporting always reversing a certain number of - * idle workers in Raylet's worker pool. + *

+ * This is needed before calling `wait` because, in Travis CI, starting a new worker process + * could be slower than the wait timeout. + *

+ * TODO(hchen): We should consider supporting always reserving a certain number of idle workers in + * Raylet's worker pool. */ public static void warmUpCluster() { ObjectRef obj = Ray.task(TestUtils::hi).remote(); diff --git a/java/tutorial/src/main/java/io/ray/exercise/Exercise04.java b/java/tutorial/src/main/java/io/ray/exercise/Exercise04.java index fbed4725c..7c2e56739 100644 --- a/java/tutorial/src/main/java/io/ray/exercise/Exercise04.java +++ b/java/tutorial/src/main/java/io/ray/exercise/Exercise04.java @@ -50,7 +50,6 @@ public class Exercise04 { System.out.printf("%d ready object(s): \n", waitResult.getReady().size()); waitResult.getReady().forEach(rayObject -> System.out.println(rayObject.get())); System.out.printf("%d unready object(s): \n", waitResult.getUnready().size()); - waitResult.getUnready().forEach(rayObject -> System.out.println(rayObject.getId())); } catch (Throwable t) { t.printStackTrace(); } finally { diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index 1a52bf306..d3201cdb1 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -164,6 +164,18 @@ void CoreWorkerProcess::EnsureInitialized() { << "shutdown."; } +std::shared_ptr CoreWorkerProcess::TryGetWorker(const WorkerID &worker_id) { + if (!instance_) { + return nullptr; + } + absl::ReaderMutexLock workers_lock(&instance_->worker_map_mutex_); + auto it = instance_->workers_.find(worker_id); + if (it != instance_->workers_.end()) { + return it->second; + } + return nullptr; +} + CoreWorker &CoreWorkerProcess::GetCoreWorker() { EnsureInitialized(); if (instance_->options_.num_workers == 1) { diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h index 2915ede6f..24e09463c 100644 --- a/src/ray/core_worker/core_worker.h +++ b/src/ray/core_worker/core_worker.h @@ -178,6 +178,13 @@ class CoreWorkerProcess { /// `CoreWorkerProcess` has full control of the destruction timing of `CoreWorker`.
static CoreWorker &GetCoreWorker(); + /// Try to get the `CoreWorker` instance by worker ID. + /// If the process is not initialized or no worker has the given ID, returns a null pointer. + /// + /// \param[in] worker_id The worker ID. + /// \return The `CoreWorker` instance, or a null pointer if it was not found. + static std::shared_ptr TryGetWorker(const WorkerID &worker_id); + /// Set the core worker associated with the current thread by worker ID. /// Currently used by Java worker only. /// diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc b/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc index c602fca27..f67e8e869 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc +++ b/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc @@ -18,10 +18,10 @@ #include +#include "jni_utils.h" #include "ray/common/id.h" #include "ray/core_worker/actor_handle.h" #include "ray/core_worker/core_worker.h" -#include "jni_utils.h" thread_local JNIEnv *local_env = nullptr; jobject java_task_executor = nullptr; @@ -72,6 +72,19 @@ jobject ToJavaArgs(JNIEnv *env, jbooleanArray java_check_results, } } +JNIEnv *GetJNIEnv() { + JNIEnv *env = local_env; + if (!env) { + // Attach the native thread to JVM. + auto status = + jvm->AttachCurrentThreadAsDaemon(reinterpret_cast(&env), nullptr); + RAY_CHECK(status == JNI_OK) << "Failed to get JNIEnv. Return code: " << status; + local_env = env; + } + RAY_CHECK(env); + return env; +} + #ifdef __cplusplus extern "C" { #endif @@ -98,16 +111,7 @@ JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeInitialize( const std::vector &arg_reference_ids, const std::vector &return_ids, std::vector> *results) { - JNIEnv *env = local_env; - if (!env) { - // Attach the native thread to JVM. - auto status = - jvm->AttachCurrentThreadAsDaemon(reinterpret_cast(&env), nullptr); - RAY_CHECK(status == JNI_OK) << "Failed to get JNIEnv.
Return code: " << status; - local_env = env; - } - - RAY_CHECK(env); + JNIEnv *env = GetJNIEnv(); RAY_CHECK(java_task_executor); // convert RayFunction @@ -141,6 +145,8 @@ JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeInitialize( env->CallObjectMethod(java_task_executor, java_task_executor_execute, ray_function_array_list, args_array_list); RAY_CHECK_JAVA_EXCEPTION(env); + + // Process return objects. if (!return_ids.empty()) { std::vector> return_objects; JavaListToNativeVector>( @@ -148,8 +154,26 @@ JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeInitialize( [](JNIEnv *env, jobject java_native_ray_object) { return JavaNativeRayObjectToNativeRayObject(env, java_native_ray_object); }); - for (auto &obj : return_objects) { - results->push_back(obj); + std::vector data_sizes; + std::vector> metadatas; + std::vector> contained_object_ids; + for (size_t i = 0; i < return_objects.size(); i++) { + data_sizes.push_back( + return_objects[i]->HasData() ? return_objects[i]->GetData()->Size() : 0); + metadatas.push_back(return_objects[i]->GetMetadata()); + contained_object_ids.push_back(return_objects[i]->GetNestedIds()); + } + RAY_CHECK_OK(ray::CoreWorkerProcess::GetCoreWorker().AllocateReturnObjects( + return_ids, data_sizes, metadatas, contained_object_ids, results)); + for (size_t i = 0; i < data_sizes.size(); i++) { + auto result = (*results)[i]; + // A nullptr is returned if the object already exists. + if (result != nullptr) { + if (result->HasData()) { + memcpy(result->GetData()->Data(), return_objects[i]->GetData()->Data(), + data_sizes[i]); + } + } } } @@ -159,6 +183,26 @@ JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeInitialize( return ray::Status::OK(); }; + auto gc_collect = []() { + // A Java worker process usually contains more than one worker. + // A LocalGC request is likely to be received by multiple workers in a short time. 
+ // Here we ensure that the 1 second interval of `System.gc()` execution is + // guaranteed no matter how frequent the requests are received and how many workers + // the process has. + static absl::Mutex mutex; + static int64_t last_gc_time_ms = 0; + absl::MutexLock lock(&mutex); + int64_t start = current_time_ms(); + if (last_gc_time_ms + 1000 < start) { + JNIEnv *env = GetJNIEnv(); + RAY_LOG(INFO) << "Calling System.gc() ..."; + env->CallStaticObjectMethod(java_system_class, java_system_gc); + last_gc_time_ms = current_time_ms(); + RAY_LOG(INFO) << "GC finished in " << (double) (last_gc_time_ms - start) / 1000 + << " seconds."; + } + }; + ray::CoreWorkerOptions options = { static_cast(workerMode), // worker_type ray::Language::JAVA, // langauge @@ -178,10 +222,10 @@ JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeInitialize( "", // stderr_file task_execution_callback, // task_execution_callback nullptr, // check_signals - nullptr, // gc_collect + gc_collect, // gc_collect nullptr, // get_lang_stack nullptr, // kill_main - false, // ref_counting_enabled + true, // ref_counting_enabled false, // is_local_mode static_cast(numWorkersPerProcess), // num_workers }; diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_object_NativeObjectStore.cc b/src/ray/core_worker/lib/java/io_ray_runtime_object_NativeObjectStore.cc index 7083db2f5..b806c369d 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_object_NativeObjectStore.cc +++ b/src/ray/core_worker/lib/java/io_ray_runtime_object_NativeObjectStore.cc @@ -14,10 +14,47 @@ #include "io_ray_runtime_object_NativeObjectStore.h" #include +#include "jni_utils.h" #include "ray/common/id.h" #include "ray/core_worker/common.h" #include "ray/core_worker/core_worker.h" -#include "jni_utils.h" + +ray::Status PutSerializedObject(JNIEnv *env, jobject obj, ray::ObjectID object_id, + ray::ObjectID *out_object_id, bool pin_object = true) { + auto native_ray_object = JavaNativeRayObjectToNativeRayObject(env, 
obj); + RAY_CHECK(native_ray_object != nullptr); + + size_t data_size = 0; + if (native_ray_object->HasData()) { + data_size = native_ray_object->GetData()->Size(); + } + std::shared_ptr data; + ray::Status status; + if (object_id.IsNil()) { + status = ray::CoreWorkerProcess::GetCoreWorker().Create( + native_ray_object->GetMetadata(), data_size, native_ray_object->GetNestedIds(), + out_object_id, &data); + } else { + status = ray::CoreWorkerProcess::GetCoreWorker().Create( + native_ray_object->GetMetadata(), data_size, object_id, &data); + *out_object_id = object_id; + } + if (!status.ok()) { + return status; + } + + // If data is nullptr, that means the ObjectID already existed, which we ignore. + // TODO(edoakes): this is hacky, we should return the error instead and deal with it + // here. + if (data != nullptr) { + if (data->Size() > 0) { + memcpy(data->Data(), native_ray_object->GetData()->Data(), data->Size()); + } + RAY_CHECK_OK(ray::CoreWorkerProcess::GetCoreWorker().Seal( + *out_object_id, pin_object && object_id.IsNil())); + } + return ray::Status::OK(); +} #ifdef __cplusplus extern "C" { @@ -26,10 +63,9 @@ extern "C" { JNIEXPORT jbyteArray JNICALL Java_io_ray_runtime_object_NativeObjectStore_nativePut__Lio_ray_runtime_object_NativeRayObject_2( JNIEnv *env, jclass, jobject obj) { - auto ray_object = JavaNativeRayObjectToNativeRayObject(env, obj); - RAY_CHECK(ray_object != nullptr); ray::ObjectID object_id; - auto status = ray::CoreWorkerProcess::GetCoreWorker().Put(*ray_object, {}, &object_id); + auto status = PutSerializedObject(env, obj, /*object_id=*/ray::ObjectID::Nil(), + /*out_object_id=*/&object_id, /*pin_object=*/true); THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr); return IdToJavaByteArray(env, object_id); } @@ -38,9 +74,10 @@ JNIEXPORT void JNICALL Java_io_ray_runtime_object_NativeObjectStore_nativePut___3BLio_ray_runtime_object_NativeRayObject_2( JNIEnv *env, jclass, jbyteArray objectId, jobject obj) { auto object_id = 
JavaByteArrayToId(env, objectId); - auto ray_object = JavaNativeRayObjectToNativeRayObject(env, obj); - RAY_CHECK(ray_object != nullptr); - auto status = ray::CoreWorkerProcess::GetCoreWorker().Put(*ray_object, {}, object_id); + ray::ObjectID dummy_object_id; + auto status = + PutSerializedObject(env, obj, object_id, + /*out_object_id=*/&dummy_object_id, /*pin_object=*/true); THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0); } @@ -71,7 +108,10 @@ JNIEXPORT jobject JNICALL Java_io_ray_runtime_object_NativeObjectStore_nativeWai object_ids, (int)numObjects, (int64_t)timeoutMs, &results); THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr); return NativeVectorToJavaList(env, results, [](JNIEnv *env, const bool &item) { - return env->NewObject(java_boolean_class, java_boolean_init, (jboolean)item); + jobject java_item = + env->NewObject(java_boolean_class, java_boolean_init, (jboolean)item); + RAY_CHECK_JAVA_EXCEPTION(env); + return java_item; }); } @@ -88,6 +128,49 @@ JNIEXPORT void JNICALL Java_io_ray_runtime_object_NativeObjectStore_nativeDelete THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0); } +JNIEXPORT void JNICALL +Java_io_ray_runtime_object_NativeObjectStore_nativeAddLocalReference( + JNIEnv *env, jclass, jbyteArray workerId, jbyteArray objectId) { + auto worker_id = JavaByteArrayToId(env, workerId); + auto object_id = JavaByteArrayToId(env, objectId); + auto core_worker = ray::CoreWorkerProcess::TryGetWorker(worker_id); + RAY_CHECK(core_worker); + core_worker->AddLocalReference(object_id); +} + +JNIEXPORT void JNICALL +Java_io_ray_runtime_object_NativeObjectStore_nativeRemoveLocalReference( + JNIEnv *env, jclass, jbyteArray workerId, jbyteArray objectId) { + auto worker_id = JavaByteArrayToId(env, workerId); + auto object_id = JavaByteArrayToId(env, objectId); + // We can't control the timing of Java GC, so it's normal that this method is called but + // core worker is shutting down (or already shut down). 
If we can't get a core worker + // instance here, skip calling the `RemoveLocalReference` method. + auto core_worker = ray::CoreWorkerProcess::TryGetWorker(worker_id); + if (core_worker) { + core_worker->RemoveLocalReference(object_id); + } +} + +JNIEXPORT jobject JNICALL +Java_io_ray_runtime_object_NativeObjectStore_nativeGetAllReferenceCounts(JNIEnv *env, + jclass) { + auto reference_counts = ray::CoreWorkerProcess::GetCoreWorker().GetAllReferenceCounts(); + return NativeMapToJavaMap>( + env, reference_counts, + [](JNIEnv *env, const ray::ObjectID &key) { + return IdToJavaByteArray(env, key); + }, + [](JNIEnv *env, const std::pair &value) { + jlongArray array = env->NewLongArray(2); + jlong *elements = env->GetLongArrayElements(array, nullptr); + elements[0] = static_cast(value.first); + elements[1] = static_cast(value.second); + env->ReleaseLongArrayElements(array, elements, 0); + return array; + }); +} + #ifdef __cplusplus } #endif diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_object_NativeObjectStore.h b/src/ray/core_worker/lib/java/io_ray_runtime_object_NativeObjectStore.h index 183ccc4fa..bc9719b48 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_object_NativeObjectStore.h +++ b/src/ray/core_worker/lib/java/io_ray_runtime_object_NativeObjectStore.h @@ -65,6 +65,35 @@ JNIEXPORT jobject JNICALL Java_io_ray_runtime_object_NativeObjectStore_nativeWai JNIEXPORT void JNICALL Java_io_ray_runtime_object_NativeObjectStore_nativeDelete( JNIEnv *, jclass, jobject, jboolean, jboolean); +/* + * Class: io_ray_runtime_object_NativeObjectStore + * Method: nativeAddLocalReference + * Signature: ([B[B)V + */ +JNIEXPORT void JNICALL +Java_io_ray_runtime_object_NativeObjectStore_nativeAddLocalReference(JNIEnv *, jclass, + jbyteArray, + jbyteArray); + +/* + * Class: io_ray_runtime_object_NativeObjectStore + * Method: nativeRemoveLocalReference + * Signature: ([B[B)V + */ +JNIEXPORT void JNICALL 
+Java_io_ray_runtime_object_NativeObjectStore_nativeRemoveLocalReference(JNIEnv *, jclass, + jbyteArray, + jbyteArray); + +/* + * Class: io_ray_runtime_object_NativeObjectStore + * Method: nativeGetAllReferenceCounts + * Signature: ()Ljava/util/Map; + */ +JNIEXPORT jobject JNICALL +Java_io_ray_runtime_object_NativeObjectStore_nativeGetAllReferenceCounts(JNIEnv *, + jclass); + #ifdef __cplusplus } #endif diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.cc b/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.cc index 00a84f52e..6c1218e6e 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.cc +++ b/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.cc @@ -76,7 +76,6 @@ inline std::vector> ToTaskArgs(JNIEnv *env, jobjec inline std::unordered_map ToResources(JNIEnv *env, jobject java_resources) { - std::unordered_map resources; return JavaMapToNativeMap( env, java_resources, [](JNIEnv *env, jobject java_key) { diff --git a/src/ray/core_worker/lib/java/jni_init.cc b/src/ray/core_worker/lib/java/jni_init.cc index a6472e33a..30bd1b949 100644 --- a/src/ray/core_worker/lib/java/jni_init.cc +++ b/src/ray/core_worker/lib/java/jni_init.cc @@ -34,6 +34,10 @@ jmethodID java_array_list_init_with_capacity; jclass java_map_class; jmethodID java_map_entry_set; +jmethodID java_map_put; + +jclass java_hash_map_class; +jmethodID java_hash_map_init; jclass java_set_class; jmethodID java_set_iterator; @@ -46,6 +50,9 @@ jclass java_map_entry_class; jmethodID java_map_entry_get_key; jmethodID java_map_entry_get_value; +jclass java_system_class; +jmethodID java_system_gc; + jclass java_ray_exception_class; jclass java_jni_exception_util_class; @@ -86,6 +93,7 @@ jclass java_native_ray_object_class; jmethodID java_native_ray_object_init; jfieldID java_native_ray_object_data; jfieldID java_native_ray_object_metadata; +jfieldID java_native_ray_object_contained_object_ids; jclass 
java_task_executor_class; jmethodID java_task_executor_parse_function_arguments; @@ -135,6 +143,11 @@ jint JNI_OnLoad(JavaVM *vm, void *reserved) { java_map_class = LoadClass(env, "java/util/Map"); java_map_entry_set = env->GetMethodID(java_map_class, "entrySet", "()Ljava/util/Set;"); + java_map_put = env->GetMethodID( + java_map_class, "put", "(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;"); + + java_hash_map_class = LoadClass(env, "java/util/HashMap"); + java_hash_map_init = env->GetMethodID(java_hash_map_class, "", "()V"); java_set_class = LoadClass(env, "java/util/Set"); java_set_iterator = @@ -151,6 +164,9 @@ jint JNI_OnLoad(JavaVM *vm, void *reserved) { java_map_entry_get_value = env->GetMethodID(java_map_entry_class, "getValue", "()Ljava/lang/Object;"); + java_system_class = LoadClass(env, "java/lang/System"); + java_system_gc = env->GetStaticMethodID(java_system_class, "gc", "()V"); + java_ray_exception_class = LoadClass(env, "io/ray/api/exception/RayException"); java_jni_exception_util_class = LoadClass(env, "io/ray/runtime/util/JniExceptionUtil"); @@ -220,6 +236,8 @@ jint JNI_OnLoad(JavaVM *vm, void *reserved) { env->GetFieldID(java_native_ray_object_class, "data", "[B"); java_native_ray_object_metadata = env->GetFieldID(java_native_ray_object_class, "metadata", "[B"); + java_native_ray_object_contained_object_ids = env->GetFieldID( + java_native_ray_object_class, "containedObjectIds", "Ljava/util/List;"); java_task_executor_class = LoadClass(env, "io/ray/runtime/task/TaskExecutor"); java_task_executor_parse_function_arguments = env->GetMethodID( @@ -241,9 +259,11 @@ void JNI_OnUnload(JavaVM *vm, void *reserved) { env->DeleteGlobalRef(java_list_class); env->DeleteGlobalRef(java_array_list_class); env->DeleteGlobalRef(java_map_class); + env->DeleteGlobalRef(java_hash_map_class); env->DeleteGlobalRef(java_set_class); env->DeleteGlobalRef(java_iterator_class); env->DeleteGlobalRef(java_map_entry_class); + env->DeleteGlobalRef(java_system_class); 
env->DeleteGlobalRef(java_ray_exception_class); env->DeleteGlobalRef(java_jni_exception_util_class); env->DeleteGlobalRef(java_base_id_class); diff --git a/src/ray/core_worker/lib/java/jni_utils.h b/src/ray/core_worker/lib/java/jni_utils.h index 750eb5b2d..67207dc60 100644 --- a/src/ray/core_worker/lib/java/jni_utils.h +++ b/src/ray/core_worker/lib/java/jni_utils.h @@ -58,6 +58,13 @@ extern jmethodID java_array_list_init_with_capacity; extern jclass java_map_class; /// entrySet method of Map interface extern jmethodID java_map_entry_set; +/// put method of Map interface +extern jmethodID java_map_put; + +/// HashMap class +extern jclass java_hash_map_class; +/// Constructor of HashMap class +extern jmethodID java_hash_map_init; /// Set interface extern jclass java_set_class; @@ -78,6 +85,11 @@ extern jmethodID java_map_entry_get_key; /// getValue method of Map.Entry interface extern jmethodID java_map_entry_get_value; +/// System class +extern jclass java_system_class; +/// gc method of System class +extern jmethodID java_system_gc; + /// RayException class extern jclass java_ray_exception_class; @@ -149,6 +161,8 @@ extern jmethodID java_native_ray_object_init; extern jfieldID java_native_ray_object_data; /// metadata field of NativeRayObject class extern jfieldID java_native_ray_object_metadata; +/// containedObjectIds field of NativeRayObject class +extern jfieldID java_native_ray_object_contained_object_ids; /// TaskExecutor class extern jclass java_task_executor_class; @@ -323,6 +337,7 @@ inline jobject NativeVectorToJavaList( jobject java_list = env->NewObject(java_array_list_class, java_array_list_init_with_capacity, (jint)native_vector.size()); + RAY_CHECK_JAVA_EXCEPTION(env); for (const auto &item : native_vector) { auto element = element_converter(env, item); env->CallVoidMethod(java_list, java_list_add, element); @@ -364,10 +379,11 @@ inline std::unordered_map JavaMapToNativeMap( RAY_CHECK_JAVA_EXCEPTION(env); jobject map_entry = 
env->CallObjectMethod(iterator, java_iterator_next); RAY_CHECK_JAVA_EXCEPTION(env); - auto java_key = (jstring)env->CallObjectMethod(map_entry, java_map_entry_get_key); + auto java_key = env->CallObjectMethod(map_entry, java_map_entry_get_key); RAY_CHECK_JAVA_EXCEPTION(env); key_type key = key_converter(env, java_key); auto java_value = env->CallObjectMethod(map_entry, java_map_entry_get_value); + RAY_CHECK_JAVA_EXCEPTION(env); value_type value = value_converter(env, java_value); native_map.emplace(key, value); env->DeleteLocalRef(java_key); @@ -381,6 +397,25 @@ inline std::unordered_map JavaMapToNativeMap( return native_map; } +/// Convert a C++ std::unordered_map to a Java Map +template +inline jobject NativeMapToJavaMap( + JNIEnv *env, const std::unordered_map &native_map, + const std::function &key_converter, + const std::function &value_converter) { + jobject java_map = env->NewObject(java_hash_map_class, java_hash_map_init); + RAY_CHECK_JAVA_EXCEPTION(env); + for (const auto &entry : native_map) { + jobject java_key = key_converter(env, entry.first); + jobject java_value = value_converter(env, entry.second); + env->CallObjectMethod(java_map, java_map_put, java_key, java_value); + RAY_CHECK_JAVA_EXCEPTION(env); + env->DeleteLocalRef(java_key); + env->DeleteLocalRef(java_value); + } + return java_map; +} + /// Convert a C++ ray::Buffer to a Java byte array. inline jbyteArray NativeBufferToJavaByteArray(JNIEnv *env, const std::shared_ptr buffer) { @@ -423,9 +458,16 @@ inline std::shared_ptr JavaNativeRayObjectToNativeRayObject( if (metadata_buffer && metadata_buffer->Size() == 0) { metadata_buffer = nullptr; } - // TODO: Support nested IDs for Java. 
+ + auto java_contained_ids = + env->GetObjectField(java_obj, java_native_ray_object_contained_object_ids); + std::vector contained_object_ids; + JavaListToNativeVector( + env, java_contained_ids, &contained_object_ids, [](JNIEnv *env, jobject id) { + return JavaByteArrayToId(env, static_cast(id)); + }); return std::make_shared(data_buffer, metadata_buffer, - std::vector()); + contained_object_ids); } /// Convert a C++ ray::RayObject to a Java NativeRayObject. @@ -438,6 +480,7 @@ inline jobject NativeRayObjectToJavaNativeRayObject( auto java_metadata = NativeBufferToJavaByteArray(env, rayObject->GetMetadata()); auto java_obj = env->NewObject(java_native_ray_object_class, java_native_ray_object_init, java_data, java_metadata); + RAY_CHECK_JAVA_EXCEPTION(env); env->DeleteLocalRef(java_metadata); env->DeleteLocalRef(java_data); return java_obj; diff --git a/streaming/java/dependencies.bzl b/streaming/java/dependencies.bzl index 998d88434..18230f0b9 100644 --- a/streaming/java/dependencies.bzl +++ b/streaming/java/dependencies.bzl @@ -16,12 +16,12 @@ def gen_streaming_java_deps(): "org.slf4j:slf4j-api:1.7.12", "org.slf4j:slf4j-log4j12:1.7.25", "org.apache.logging.log4j:log4j-core:2.8.2", - "org.testng:testng:6.9.10", + "org.testng:testng:7.1.0", "log4j:log4j:1.2.17", "org.mockito:mockito-all:1.10.19", "org.apache.commons:commons-lang3:3.3.2", "org.msgpack:msgpack-core:0.8.20", - "org.testng:testng:6.9.10", + "org.testng:testng:7.1.0", "org.mockito:mockito-all:1.10.19", "org.powermock:powermock-module-testng:1.6.6", "org.powermock:powermock-api-mockito:1.6.6", diff --git a/streaming/java/pom.xml b/streaming/java/pom.xml index 9fea52c93..ac3c74468 100644 --- a/streaming/java/pom.xml +++ b/streaming/java/pom.xml @@ -49,7 +49,7 @@ 0.1-SNAPSHOT 1.7.25 1.2.17 - 6.9.10 + 7.1.0 1.10.19 27.0.1-jre 2.57 diff --git a/streaming/java/streaming-api/pom.xml b/streaming/java/streaming-api/pom.xml index 0f71d8b77..2cc92a321 100644 --- a/streaming/java/streaming-api/pom.xml +++ 
b/streaming/java/streaming-api/pom.xml @@ -65,7 +65,7 @@ org.testng testng - 6.9.10 + 7.1.0 diff --git a/streaming/java/streaming-runtime/pom.xml b/streaming/java/streaming-runtime/pom.xml index 1a5e57d8e..d3066ba0e 100644 --- a/streaming/java/streaming-runtime/pom.xml +++ b/streaming/java/streaming-runtime/pom.xml @@ -117,7 +117,7 @@ org.testng testng - 6.9.10 + 7.1.0 diff --git a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/core/graph/ExecutionGraphTest.java b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/core/graph/ExecutionGraphTest.java index 36c3b71ad..2c49eadaa 100644 --- a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/core/graph/ExecutionGraphTest.java +++ b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/core/graph/ExecutionGraphTest.java @@ -66,7 +66,7 @@ public class ExecutionGraphTest extends BaseUnitTest { List upStreamVertices = upStream.getExecutionVertices(); List downStreamVertices = downStream.getExecutionVertices(); upStreamVertices.forEach(vertex -> { - Assert.assertEquals(vertex.getResource().get(ResourceType.CPU.name()), 2.0); + Assert.assertEquals((double) vertex.getResource().get(ResourceType.CPU.name()), 2.0); vertex.getOutputEdges().forEach(upStreamOutPutEdge -> { Assert.assertTrue(downStreamVertices.contains(upStreamOutPutEdge.getTargetExecutionVertex())); });