mirror of
https://github.com/wassname/ray.git
synced 2026-06-30 22:20:31 +08:00
[Java] Local and distributed ref counting in Java (#9371)
This commit is contained in:
@@ -1,7 +1,5 @@
|
||||
package io.ray.api;
|
||||
|
||||
import io.ray.api.id.ObjectId;
|
||||
|
||||
/**
|
||||
* Represents a reference to an object in the object store.
|
||||
* @param <T> The object type.
|
||||
@@ -14,15 +12,5 @@ public interface ObjectRef<T> {
|
||||
*/
|
||||
T get();
|
||||
|
||||
/**
|
||||
* Get the object id.
|
||||
*/
|
||||
ObjectId getId();
|
||||
|
||||
/**
|
||||
* Get the Object type.
|
||||
*/
|
||||
Class<T> getType();
|
||||
|
||||
}
|
||||
|
||||
|
||||
@@ -1,13 +1,11 @@
|
||||
package io.ray.api;
|
||||
|
||||
import io.ray.api.id.ObjectId;
|
||||
import io.ray.api.id.UniqueId;
|
||||
import io.ray.api.placementgroup.PlacementGroup;
|
||||
import io.ray.api.placementgroup.PlacementStrategy;
|
||||
import io.ray.api.runtime.RayRuntime;
|
||||
import io.ray.api.runtime.RayRuntimeFactory;
|
||||
import io.ray.api.runtimecontext.RuntimeContext;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
@@ -67,25 +65,13 @@ public final class Ray extends RayCall {
|
||||
}
|
||||
|
||||
/**
|
||||
* Get an object by id from the object store.
|
||||
* Get an object by `ObjectRef` from the object store.
|
||||
*
|
||||
* @param objectId The ID of the object to get.
|
||||
* @param objectType The type of the object to get.
|
||||
* @param objectRef The reference of the object to get.
|
||||
* @return The Java object.
|
||||
*/
|
||||
public static <T> T get(ObjectId objectId, Class<T> objectType) {
|
||||
return runtime.get(objectId, objectType);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get a list of objects by ids from the object store.
|
||||
*
|
||||
* @param objectIds The list of object IDs.
|
||||
* @param objectType The type of object.
|
||||
* @return A list of Java objects.
|
||||
*/
|
||||
public static <T> List<T> get(List<ObjectId> objectIds, Class<T> objectType) {
|
||||
return runtime.get(objectIds, objectType);
|
||||
public static <T> T get(ObjectRef<T> objectRef) {
|
||||
return runtime.get(objectRef);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -95,13 +81,7 @@ public final class Ray extends RayCall {
|
||||
* @return A list of Java objects.
|
||||
*/
|
||||
public static <T> List<T> get(List<ObjectRef<T>> objectList) {
|
||||
List<ObjectId> objectIds = new ArrayList<>();
|
||||
Class<T> objectType = null;
|
||||
for (ObjectRef<T> o : objectList) {
|
||||
objectIds.add(o.getId());
|
||||
objectType = o.getType();
|
||||
}
|
||||
return runtime.get(objectIds, objectType);
|
||||
return runtime.get(objectList);
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -10,7 +10,6 @@ import io.ray.api.function.PyActorMethod;
|
||||
import io.ray.api.function.PyFunction;
|
||||
import io.ray.api.function.RayFunc;
|
||||
import io.ray.api.id.ActorId;
|
||||
import io.ray.api.id.ObjectId;
|
||||
import io.ray.api.id.UniqueId;
|
||||
import io.ray.api.options.ActorCreationOptions;
|
||||
import io.ray.api.options.CallOptions;
|
||||
@@ -43,20 +42,18 @@ public interface RayRuntime {
|
||||
/**
|
||||
* Get an object from the object store.
|
||||
*
|
||||
* @param objectId The ID of the object to get.
|
||||
* @param objectType The type of the object to get.
|
||||
* @param objectRef The reference of the object to get.
|
||||
* @return The Java object.
|
||||
*/
|
||||
<T> T get(ObjectId objectId, Class<T> objectType);
|
||||
<T> T get(ObjectRef<T> objectRef);
|
||||
|
||||
/**
|
||||
* Get a list of objects from the object store.
|
||||
*
|
||||
* @param objectIds The list of object IDs.
|
||||
* @param objectType The type of object.
|
||||
* @param objectRefs The list of object references.
|
||||
* @return A list of Java objects.
|
||||
*/
|
||||
<T> List<T> get(List<ObjectId> objectIds, Class<T> objectType);
|
||||
<T> List<T> get(List<ObjectRef<T>> objectRefs);
|
||||
|
||||
/**
|
||||
* Wait for a list of RayObjects to be locally available, until specified number of objects are
|
||||
@@ -72,11 +69,11 @@ public interface RayRuntime {
|
||||
/**
|
||||
* Free a list of objects from Plasma Store.
|
||||
*
|
||||
* @param objectIds The object ids to free.
|
||||
* @param objectRefs The object references to free.
|
||||
* @param localOnly Whether only free objects for local object store or not.
|
||||
* @param deleteCreatingTasks Whether also delete objects' creating tasks from GCS.
|
||||
*/
|
||||
void free(List<ObjectId> objectIds, boolean localOnly, boolean deleteCreatingTasks);
|
||||
void free(List<ObjectRef<?>> objectRefs, boolean localOnly, boolean deleteCreatingTasks);
|
||||
|
||||
/**
|
||||
* Set the resource for the specific node.
|
||||
|
||||
@@ -8,7 +8,6 @@
|
||||
<suppress checks="SummaryJavadoc" files=".*" />
|
||||
<suppress checks="MemberNameCheck" files="PathConfig.java"/>
|
||||
<suppress checks="MemberNameCheck" files="RayParameters.java"/>
|
||||
<suppress checks="NoFinalizer" files="NativeRayActor.java"/>
|
||||
<suppress checks="AbbreviationAsWordInNameCheck" files="RayParameters.java"/>
|
||||
<suppress checks=".*" files="RayCall.java"/>
|
||||
<suppress checks=".*" files="ActorCall.java"/>
|
||||
|
||||
@@ -18,7 +18,7 @@ def gen_java_deps():
|
||||
"org.msgpack:msgpack-core:0.8.20",
|
||||
"org.ow2.asm:asm:6.0",
|
||||
"org.slf4j:slf4j-log4j12:1.7.25",
|
||||
"org.testng:testng:6.9.10",
|
||||
"org.testng:testng:7.1.0",
|
||||
"redis.clients:jedis:2.8.0",
|
||||
"net.java.dev.jna:jna:5.5.0",
|
||||
],
|
||||
|
||||
@@ -112,7 +112,7 @@
|
||||
<dependency>
|
||||
<groupId>org.testng</groupId>
|
||||
<artifactId>testng</artifactId>
|
||||
<version>6.9.10</version>
|
||||
<version>7.1.0</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>redis.clients</groupId>
|
||||
|
||||
@@ -36,10 +36,12 @@ import io.ray.runtime.task.ArgumentsBuilder;
|
||||
import io.ray.runtime.task.FunctionArg;
|
||||
import io.ray.runtime.task.TaskExecutor;
|
||||
import io.ray.runtime.task.TaskSubmitter;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.concurrent.Callable;
|
||||
import java.util.stream.Collectors;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
@@ -72,9 +74,6 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal {
|
||||
runtimeContext = new RuntimeContextImpl(this);
|
||||
}
|
||||
|
||||
@Override
|
||||
public abstract void shutdown();
|
||||
|
||||
@Override
|
||||
public <T> ObjectRef<T> put(T obj) {
|
||||
ObjectId objectId = objectStore.put(obj);
|
||||
@@ -82,19 +81,27 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal {
|
||||
}
|
||||
|
||||
@Override
|
||||
public <T> T get(ObjectId objectId, Class<T> objectType) throws RayException {
|
||||
List<T> ret = get(ImmutableList.of(objectId), objectType);
|
||||
public <T> T get(ObjectRef<T> objectRef) throws RayException {
|
||||
List<T> ret = get(ImmutableList.of(objectRef));
|
||||
return ret.get(0);
|
||||
}
|
||||
|
||||
@Override
|
||||
public <T> List<T> get(List<ObjectId> objectIds, Class<T> objectType) {
|
||||
public <T> List<T> get(List<ObjectRef<T>> objectRefs) {
|
||||
List<ObjectId> objectIds = new ArrayList<>();
|
||||
Class<T> objectType = null;
|
||||
for (ObjectRef<T> o : objectRefs) {
|
||||
ObjectRefImpl<T> objectRefImpl = (ObjectRefImpl<T>) o;
|
||||
objectIds.add(objectRefImpl.getId());
|
||||
objectType = objectRefImpl.getType();
|
||||
}
|
||||
return objectStore.get(objectIds, objectType);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void free(List<ObjectId> objectIds, boolean localOnly, boolean deleteCreatingTasks) {
|
||||
objectStore.delete(objectIds, localOnly, deleteCreatingTasks);
|
||||
public void free(List<ObjectRef<?>> objectRefs, boolean localOnly, boolean deleteCreatingTasks) {
|
||||
objectStore.delete(objectRefs.stream().map(ref -> ((ObjectRefImpl<?>) ref).getId()).collect(
|
||||
Collectors.toList()), localOnly, deleteCreatingTasks);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -22,6 +22,9 @@ import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.concurrent.locks.Lock;
|
||||
import java.util.concurrent.locks.ReadWriteLock;
|
||||
import java.util.concurrent.locks.ReentrantReadWriteLock;
|
||||
import org.apache.commons.io.FileUtils;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
@@ -35,6 +38,15 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
|
||||
|
||||
private RunManager manager = null;
|
||||
|
||||
/**
|
||||
* In Java, GC runs in a standalone thread, and we can't control the exact
|
||||
* timing of garbage collection. By using this lock, when
|
||||
* {@link NativeObjectStore#nativeRemoveLocalReference} is executing, the core
|
||||
* worker will not be shut down, therefore it guarantees some kind of
|
||||
* thread-safety. Note that this guarantee only works for driver.
|
||||
*/
|
||||
private final ReadWriteLock shutdownLock = new ReentrantReadWriteLock();
|
||||
|
||||
|
||||
static {
|
||||
LOGGER.debug("Loading native libraries.");
|
||||
@@ -105,7 +117,7 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
|
||||
|
||||
taskExecutor = new NativeTaskExecutor(this);
|
||||
workerContext = new NativeWorkerContext();
|
||||
objectStore = new NativeObjectStore(workerContext);
|
||||
objectStore = new NativeObjectStore(workerContext, shutdownLock);
|
||||
taskSubmitter = new NativeTaskSubmitter();
|
||||
|
||||
LOGGER.debug("RayNativeRuntime started with store {}, raylet {}",
|
||||
@@ -114,19 +126,25 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
|
||||
|
||||
@Override
|
||||
public void shutdown() {
|
||||
if (rayConfig.workerMode == WorkerType.DRIVER) {
|
||||
nativeShutdown();
|
||||
if (null != manager) {
|
||||
manager.cleanup();
|
||||
manager = null;
|
||||
Lock writeLock = shutdownLock.readLock();
|
||||
writeLock.lock();
|
||||
try {
|
||||
if (rayConfig.workerMode == WorkerType.DRIVER) {
|
||||
nativeShutdown();
|
||||
if (null != manager) {
|
||||
manager.cleanup();
|
||||
manager = null;
|
||||
}
|
||||
}
|
||||
if (null != gcsClient) {
|
||||
gcsClient.destroy();
|
||||
gcsClient = null;
|
||||
}
|
||||
RayConfig.reset();
|
||||
LOGGER.debug("RayNativeRuntime shutdown");
|
||||
} finally {
|
||||
writeLock.unlock();
|
||||
}
|
||||
if (null != gcsClient) {
|
||||
gcsClient.destroy();
|
||||
gcsClient = null;
|
||||
}
|
||||
RayConfig.reset();
|
||||
LOGGER.debug("RayNativeRuntime shutdown");
|
||||
}
|
||||
|
||||
// For test purpose only
|
||||
|
||||
@@ -2,6 +2,7 @@ package io.ray.runtime.object;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import io.ray.api.id.ObjectId;
|
||||
import io.ray.api.id.UniqueId;
|
||||
import io.ray.runtime.context.WorkerContext;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
@@ -96,4 +97,12 @@ public class LocalModeObjectStore extends ObjectStore {
|
||||
pool.remove(objectId);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void addLocalReference(UniqueId workerId, ObjectId objectId) {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void removeLocalReference(UniqueId workerId, ObjectId objectId) {
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,8 +2,13 @@ package io.ray.runtime.object;
|
||||
|
||||
import io.ray.api.id.BaseId;
|
||||
import io.ray.api.id.ObjectId;
|
||||
import io.ray.api.id.UniqueId;
|
||||
import io.ray.runtime.context.WorkerContext;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.locks.Lock;
|
||||
import java.util.concurrent.locks.ReadWriteLock;
|
||||
import java.util.stream.Collectors;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
@@ -15,8 +20,11 @@ public class NativeObjectStore extends ObjectStore {
|
||||
|
||||
private static final Logger LOGGER = LoggerFactory.getLogger(NativeObjectStore.class);
|
||||
|
||||
public NativeObjectStore(WorkerContext workerContext) {
|
||||
private final ReadWriteLock shutdownLock;
|
||||
|
||||
public NativeObjectStore(WorkerContext workerContext, ReadWriteLock shutdownLock) {
|
||||
super(workerContext);
|
||||
this.shutdownLock = shutdownLock;
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -44,6 +52,31 @@ public class NativeObjectStore extends ObjectStore {
|
||||
nativeDelete(toBinaryList(objectIds), localOnly, deleteCreatingTasks);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void addLocalReference(UniqueId workerId, ObjectId objectId) {
|
||||
nativeAddLocalReference(workerId.getBytes(), objectId.getBytes());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void removeLocalReference(UniqueId workerId, ObjectId objectId) {
|
||||
Lock readLock = shutdownLock.readLock();
|
||||
readLock.lock();
|
||||
try {
|
||||
nativeRemoveLocalReference(workerId.getBytes(), objectId.getBytes());
|
||||
} finally {
|
||||
readLock.unlock();
|
||||
}
|
||||
}
|
||||
|
||||
public Map<ObjectId, long[]> getAllReferenceCounts() {
|
||||
Map<ObjectId, long[]> referenceCounts = new HashMap<>();
|
||||
for (Map.Entry<byte[], long[]> entry :
|
||||
nativeGetAllReferenceCounts().entrySet()) {
|
||||
referenceCounts.put(new ObjectId(entry.getKey()), entry.getValue());
|
||||
}
|
||||
return referenceCounts;
|
||||
}
|
||||
|
||||
private static List<byte[]> toBinaryList(List<ObjectId> ids) {
|
||||
return ids.stream().map(BaseId::getBytes).collect(Collectors.toList());
|
||||
}
|
||||
@@ -59,4 +92,10 @@ public class NativeObjectStore extends ObjectStore {
|
||||
|
||||
private static native void nativeDelete(List<byte[]> objectIds, boolean localOnly,
|
||||
boolean deleteCreatingTasks);
|
||||
|
||||
private static native void nativeAddLocalReference(byte[] workerId, byte[] objectId);
|
||||
|
||||
private static native void nativeRemoveLocalReference(byte[] workerId, byte[] objectId);
|
||||
|
||||
private static native Map<byte[], long[]> nativeGetAllReferenceCounts();
|
||||
}
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
package io.ray.runtime.object;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import io.ray.api.id.BaseId;
|
||||
import io.ray.api.id.ObjectId;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* Binary representation of a ray object. See `RayObject` class in C++ for details.
|
||||
@@ -9,11 +14,17 @@ public class NativeRayObject {
|
||||
|
||||
public byte[] data;
|
||||
public byte[] metadata;
|
||||
public List<byte[]> containedObjectIds;
|
||||
|
||||
public NativeRayObject(byte[] data, byte[] metadata) {
|
||||
Preconditions.checkState(bufferLength(data) > 0 || bufferLength(metadata) > 0);
|
||||
this.data = data;
|
||||
this.metadata = metadata;
|
||||
this.containedObjectIds = Collections.emptyList();
|
||||
}
|
||||
|
||||
public void setContainedObjectIds(List<ObjectId> containedObjectIds) {
|
||||
this.containedObjectIds = toBinaryList(containedObjectIds);
|
||||
}
|
||||
|
||||
private static int bufferLength(byte[] buffer) {
|
||||
@@ -23,6 +34,10 @@ public class NativeRayObject {
|
||||
return buffer.length;
|
||||
}
|
||||
|
||||
private static List<byte[]> toBinaryList(List<ObjectId> ids) {
|
||||
return ids.stream().map(BaseId::getBytes).collect(Collectors.toList());
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "<data>: " + bufferLength(data) + ", <metadata>: " + bufferLength(metadata);
|
||||
|
||||
@@ -1,54 +1,56 @@
|
||||
package io.ray.runtime.object;
|
||||
|
||||
import com.google.common.base.FinalizableReferenceQueue;
|
||||
import com.google.common.base.FinalizableWeakReference;
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.collect.Sets;
|
||||
import io.ray.api.ObjectRef;
|
||||
import io.ray.api.Ray;
|
||||
import io.ray.api.id.ObjectId;
|
||||
import java.io.Serializable;
|
||||
import io.ray.api.id.UniqueId;
|
||||
import io.ray.runtime.RayRuntimeInternal;
|
||||
import java.io.Externalizable;
|
||||
import java.io.IOException;
|
||||
import java.io.ObjectInput;
|
||||
import java.io.ObjectOutput;
|
||||
import java.lang.ref.Reference;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
|
||||
/**
|
||||
* Implementation of {@link ObjectRef}.
|
||||
*/
|
||||
public final class ObjectRefImpl<T> implements ObjectRef<T>, Serializable {
|
||||
public final class ObjectRefImpl<T> implements ObjectRef<T>, Externalizable {
|
||||
|
||||
private final ObjectId id;
|
||||
private static final FinalizableReferenceQueue REFERENCE_QUEUE = new FinalizableReferenceQueue();
|
||||
|
||||
/**
|
||||
* Cache the result of `Ray.get()`.
|
||||
*
|
||||
* Note, this is necessary for direct calls, in which case, it's not allowed to call `Ray.get` on
|
||||
* the same object twice.
|
||||
*/
|
||||
private transient T object;
|
||||
private static final Set<Reference<ObjectRefImpl<?>>> REFERENCES = Sets.newConcurrentHashSet();
|
||||
|
||||
private ObjectId id;
|
||||
|
||||
// In GC thread, we don't know which worker this object binds to, so we need to
|
||||
// store the worker ID for later uses.
|
||||
private transient UniqueId workerId;
|
||||
|
||||
private Class<T> type;
|
||||
|
||||
/**
|
||||
* Whether the object is already gotten from the object store.
|
||||
*/
|
||||
private transient boolean objectGotten;
|
||||
|
||||
public ObjectRefImpl(ObjectId id, Class<T> type) {
|
||||
this.id = id;
|
||||
this.type = type;
|
||||
object = null;
|
||||
objectGotten = false;
|
||||
addLocalReference();
|
||||
}
|
||||
|
||||
public ObjectRefImpl() {}
|
||||
|
||||
@Override
|
||||
public synchronized T get() {
|
||||
if (!objectGotten) {
|
||||
object = Ray.get(id, type);
|
||||
objectGotten = true;
|
||||
}
|
||||
return object;
|
||||
return Ray.get(this);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ObjectId getId() {
|
||||
return id;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Class<T> getType() {
|
||||
return type;
|
||||
}
|
||||
@@ -57,4 +59,56 @@ public final class ObjectRefImpl<T> implements ObjectRef<T>, Serializable {
|
||||
public String toString() {
|
||||
return "ObjectRef(" + id.toString() + ")";
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeExternal(ObjectOutput out) throws IOException {
|
||||
out.writeObject(this.getId());
|
||||
out.writeObject(this.getType());
|
||||
ObjectSerializer.addContainedObjectId(this.getId());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
|
||||
this.id = (ObjectId) in.readObject();
|
||||
this.type = (Class<T>) in.readObject();
|
||||
addLocalReference();
|
||||
}
|
||||
|
||||
private void addLocalReference() {
|
||||
Preconditions.checkState(workerId == null);
|
||||
RayRuntimeInternal runtime = (RayRuntimeInternal) Ray.internal();
|
||||
workerId = runtime.getWorkerContext().getCurrentWorkerId();
|
||||
runtime.getObjectStore().addLocalReference(workerId, id);
|
||||
new ObjectRefImplReference(this);
|
||||
}
|
||||
|
||||
private static final class ObjectRefImplReference extends
|
||||
FinalizableWeakReference<ObjectRefImpl<?>> {
|
||||
|
||||
private final UniqueId workerId;
|
||||
private final ObjectId objectId;
|
||||
private final AtomicBoolean removed;
|
||||
|
||||
public ObjectRefImplReference(ObjectRefImpl<?> obj) {
|
||||
super(obj, REFERENCE_QUEUE);
|
||||
this.workerId = obj.workerId;
|
||||
this.objectId = obj.id;
|
||||
this.removed = new AtomicBoolean(false);
|
||||
REFERENCES.add(this);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void finalizeReferent() {
|
||||
// This method may be invoked multiple times on the same instance (due to explicit invoking in
|
||||
// unit tests). So if `workerId` is null, it means this method has been invoked.
|
||||
if (!removed.getAndSet(true)) {
|
||||
REFERENCES.remove(this);
|
||||
RayRuntimeInternal runtime = (RayRuntimeInternal) Ray.internal();
|
||||
// It's possible that GC is executed after the runtime is shutdown.
|
||||
if (runtime != null) {
|
||||
runtime.getObjectStore().removeLocalReference(workerId, objectId);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,7 +8,11 @@ import io.ray.api.id.ObjectId;
|
||||
import io.ray.runtime.generated.Gcs.ErrorType;
|
||||
import io.ray.runtime.serializer.Serializer;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
|
||||
/**
|
||||
@@ -31,6 +35,12 @@ public class ObjectSerializer {
|
||||
public static final byte[] OBJECT_METADATA_TYPE_PYTHON = "PYTHON".getBytes();
|
||||
public static final byte[] OBJECT_METADATA_TYPE_RAW = "RAW".getBytes();
|
||||
|
||||
// When an outer object is being serialized, the nested ObjectRefs are all
|
||||
// serialized and the writeExternal method of the nested ObjectRefs are
|
||||
// executed. So after the outer object is serialized, the containedObjectIds
|
||||
// field will contain all the nested object IDs.
|
||||
static ThreadLocal<Set<ObjectId>> containedObjectIds = ThreadLocal.withInitial(HashSet::new);
|
||||
|
||||
/**
|
||||
* Deserialize an object from an {@link NativeRayObject} instance.
|
||||
*
|
||||
@@ -100,9 +110,29 @@ public class ObjectSerializer {
|
||||
byte[] serializedBytes = Serializer.encode(object).getLeft();
|
||||
return new NativeRayObject(serializedBytes, TASK_EXECUTION_EXCEPTION_META);
|
||||
} else {
|
||||
Pair<byte[], Boolean> serialized = Serializer.encode(object);
|
||||
return new NativeRayObject(serialized.getLeft(), serialized.getRight() ?
|
||||
OBJECT_METADATA_TYPE_CROSS_LANGUAGE : OBJECT_METADATA_TYPE_JAVA);
|
||||
try {
|
||||
Pair<byte[], Boolean> serialized = Serializer.encode(object);
|
||||
NativeRayObject nativeRayObject = new NativeRayObject(serialized.getLeft(),
|
||||
serialized.getRight()
|
||||
? OBJECT_METADATA_TYPE_CROSS_LANGUAGE
|
||||
: OBJECT_METADATA_TYPE_JAVA);
|
||||
nativeRayObject.setContainedObjectIds(getAndClearContainedObjectIds());
|
||||
return nativeRayObject;
|
||||
} catch (Exception e) {
|
||||
// Clear `containedObjectIds`.
|
||||
getAndClearContainedObjectIds();
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void addContainedObjectId(ObjectId objectId) {
|
||||
containedObjectIds.get().add(objectId);
|
||||
}
|
||||
|
||||
private static List<ObjectId> getAndClearContainedObjectIds() {
|
||||
List<ObjectId> ids = new ArrayList<>(containedObjectIds.get());
|
||||
containedObjectIds.get().clear();
|
||||
return ids;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import io.ray.api.ObjectRef;
|
||||
import io.ray.api.WaitResult;
|
||||
import io.ray.api.exception.RayException;
|
||||
import io.ray.api.id.ObjectId;
|
||||
import io.ray.api.id.UniqueId;
|
||||
import io.ray.runtime.context.WorkerContext;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
@@ -137,7 +138,8 @@ public abstract class ObjectStore {
|
||||
return new WaitResult<>(Collections.emptyList(), Collections.emptyList());
|
||||
}
|
||||
|
||||
List<ObjectId> ids = waitList.stream().map(ObjectRef::getId).collect(Collectors.toList());
|
||||
List<ObjectId> ids = waitList.stream().map(ref -> ((ObjectRefImpl<?>) ref).getId())
|
||||
.collect(Collectors.toList());
|
||||
|
||||
List<Boolean> ready = wait(ids, numReturns, timeoutMs);
|
||||
List<ObjectRef<T>> readyList = new ArrayList<>();
|
||||
@@ -164,4 +166,18 @@ public abstract class ObjectStore {
|
||||
*/
|
||||
public abstract void delete(List<ObjectId> objectIds, boolean localOnly,
|
||||
boolean deleteCreatingTasks);
|
||||
|
||||
/**
|
||||
* Increase the local reference count for this object ID.
|
||||
* @param workerId The ID of the worker to increase on.
|
||||
* @param objectId The object ID to increase the reference count for.
|
||||
*/
|
||||
public abstract void addLocalReference(UniqueId workerId, ObjectId objectId);
|
||||
|
||||
/**
|
||||
* Decrease the reference count for this object ID.
|
||||
* @param workerId The ID of the worker to decrease on.
|
||||
* @param objectId The object ID to decrease the reference count for.
|
||||
*/
|
||||
public abstract void removeLocalReference(UniqueId workerId, ObjectId objectId);
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import io.ray.api.id.ObjectId;
|
||||
import io.ray.runtime.RayRuntimeInternal;
|
||||
import io.ray.runtime.generated.Common.Language;
|
||||
import io.ray.runtime.object.NativeRayObject;
|
||||
import io.ray.runtime.object.ObjectRefImpl;
|
||||
import io.ray.runtime.object.ObjectSerializer;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.ArrayList;
|
||||
@@ -22,7 +23,8 @@ public class ArgumentsBuilder {
|
||||
* If the the size of an argument's serialized data is smaller than this number, the argument will
|
||||
* be passed by value. Otherwise it'll be passed by reference.
|
||||
*/
|
||||
private static final int LARGEST_SIZE_PASS_BY_VALUE = 100 * 1024;
|
||||
// TODO(kfstorm): Read from internal config `max_direct_call_object_size`.
|
||||
public static final int LARGEST_SIZE_PASS_BY_VALUE = 100 * 1024;
|
||||
|
||||
/**
|
||||
* This dummy type is also defined in signature.py. Please keep it synced.
|
||||
@@ -39,7 +41,8 @@ public class ArgumentsBuilder {
|
||||
ObjectId id = null;
|
||||
NativeRayObject value = null;
|
||||
if (arg instanceof ObjectRef) {
|
||||
id = ((ObjectRef) arg).getId();
|
||||
Preconditions.checkState(arg instanceof ObjectRefImpl);
|
||||
id = ((ObjectRefImpl<?>) arg).getId();
|
||||
} else {
|
||||
value = ObjectSerializer.serialize(arg);
|
||||
if (language != Language.JAVA) {
|
||||
|
||||
+1
-1
@@ -74,7 +74,7 @@
|
||||
<dependency>
|
||||
<groupId>org.testng</groupId>
|
||||
<artifactId>testng</artifactId>
|
||||
<version>6.9.10</version>
|
||||
<version>7.1.0</version>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
<build>
|
||||
|
||||
@@ -42,6 +42,10 @@ public class ActorTest extends BaseTest {
|
||||
value += largeObject.data.length;
|
||||
return value;
|
||||
}
|
||||
|
||||
public TestUtils.LargeObject createLargeObject() {
|
||||
return new TestUtils.LargeObject();
|
||||
}
|
||||
}
|
||||
|
||||
public void testCreateAndCallActor() {
|
||||
@@ -68,8 +72,7 @@ public class ActorTest extends BaseTest {
|
||||
ObjectRef<Integer> result = actor.task(Counter::getValue).remote();
|
||||
Assert.assertEquals(result.get(), Integer.valueOf(1));
|
||||
Assert.assertEquals(result.get(), Integer.valueOf(1));
|
||||
// TODO(hchen): The following code will still fail, and can be fixed by using ref counting.
|
||||
// Assert.assertEquals(Ray.get(result.getId()), Integer.valueOf(1));
|
||||
Assert.assertEquals(Ray.get(result), Integer.valueOf(1));
|
||||
}
|
||||
|
||||
public void testCallActorWithLargeObject() {
|
||||
@@ -117,33 +120,28 @@ public class ActorTest extends BaseTest {
|
||||
Collections.singletonList(actor), 100).remote().get());
|
||||
}
|
||||
|
||||
// TODO(qwang): Will re-enable this test case once ref counting is supported in Java.
|
||||
@Test(enabled = false, groups = {"cluster"})
|
||||
// This test case follows `test_internal_free` in `python/ray/tests/test_advanced.py`.
|
||||
@Test(groups = {"cluster"})
|
||||
public void testUnreconstructableActorObject() throws InterruptedException {
|
||||
// The UnreconstructableException is created by raylet.
|
||||
ActorHandle<Counter> counter = Ray.actor(Counter::new, 100).remote();
|
||||
// Call an actor method.
|
||||
ObjectRef value = counter.task(Counter::getValue).remote();
|
||||
Assert.assertEquals(100, value.get());
|
||||
// Delete the object from the object store.
|
||||
Ray.internal().free(ImmutableList.of(value.getId()), false, false);
|
||||
// Wait until the object is deleted, because the above free operation is async.
|
||||
while (true) {
|
||||
Boolean result = TestUtils.getRuntime().getObjectStore()
|
||||
.wait(ImmutableList.of(value.getId()), 1, 0).get(0);
|
||||
if (!result) {
|
||||
break;
|
||||
}
|
||||
TimeUnit.MILLISECONDS.sleep(100);
|
||||
}
|
||||
Ray.internal().free(ImmutableList.of(value), false, false);
|
||||
// Wait for delete RPC to propagate
|
||||
TimeUnit.SECONDS.sleep(1);
|
||||
// Free deletes from in-memory store.
|
||||
Assert.expectThrows(UnreconstructableException.class, () -> value.get());
|
||||
|
||||
try {
|
||||
// Try getting the object again, this should throw an UnreconstructableException.
|
||||
// Use `Ray.get()` to bypass the cache in `RayObjectImpl`.
|
||||
Ray.get(value.getId(), value.getType());
|
||||
Assert.fail("This line should not be reachable.");
|
||||
} catch (UnreconstructableException e) {
|
||||
Assert.assertEquals(value.getId(), e.objectId);
|
||||
}
|
||||
// Call an actor method.
|
||||
ObjectRef<TestUtils.LargeObject> largeValue = counter.task(Counter::createLargeObject).remote();
|
||||
Assert.assertTrue(largeValue.get() instanceof TestUtils.LargeObject);
|
||||
// Delete the object from the object store.
|
||||
Ray.internal().free(ImmutableList.of(largeValue), false, false);
|
||||
// Wait for delete RPC to propagate
|
||||
TimeUnit.SECONDS.sleep(1);
|
||||
// Free deletes big objects from plasma store.
|
||||
Assert.expectThrows(UnreconstructableException.class, () -> largeValue.get());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -51,7 +51,7 @@ public abstract class BaseMultiLanguageTest {
|
||||
}
|
||||
}
|
||||
|
||||
@BeforeClass(alwaysRun = true)
|
||||
@BeforeClass(alwaysRun = true, inheritGroups = false)
|
||||
public void setUp() {
|
||||
// Delete existing socket files.
|
||||
for (String socket : ImmutableList.of(RAYLET_SOCKET_NAME, PLASMA_STORE_SOCKET_NAME)) {
|
||||
@@ -106,7 +106,7 @@ public abstract class BaseMultiLanguageTest {
|
||||
return ImmutableMap.of();
|
||||
}
|
||||
|
||||
@AfterClass(alwaysRun = true)
|
||||
@AfterClass(alwaysRun = true, inheritGroups = false)
|
||||
public void tearDown() {
|
||||
// Disconnect to the cluster.
|
||||
Ray.shutdown();
|
||||
|
||||
@@ -33,7 +33,7 @@ public class GcsClientTest extends BaseTest {
|
||||
Assert.assertEquals(allNodeInfo.size(), 1);
|
||||
Assert.assertEquals(allNodeInfo.get(0).nodeAddress, config.nodeIp);
|
||||
Assert.assertTrue(allNodeInfo.get(0).isAlive);
|
||||
Assert.assertEquals(allNodeInfo.get(0).resources.get("A"), 8.0);
|
||||
Assert.assertEquals((double) allNodeInfo.get(0).resources.get("A"), 8.0);
|
||||
}
|
||||
|
||||
@Test
|
||||
|
||||
@@ -0,0 +1,96 @@
|
||||
package io.ray.test;
|
||||
|
||||
import io.ray.api.ActorHandle;
|
||||
import io.ray.api.ObjectRef;
|
||||
import io.ray.api.Ray;
|
||||
import java.lang.ref.WeakReference;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.IntStream;
|
||||
import org.testng.Assert;
|
||||
import org.testng.annotations.AfterClass;
|
||||
import org.testng.annotations.BeforeClass;
|
||||
import org.testng.annotations.Test;
|
||||
|
||||
@Test(groups = {"cluster"})
|
||||
public class GlobalGcTest extends BaseTest {
|
||||
|
||||
@BeforeClass
|
||||
public void setUp() {
|
||||
System.setProperty("ray.object-store.size", "140 MB");
|
||||
}
|
||||
|
||||
@AfterClass
|
||||
public void tearDown() {
|
||||
System.clearProperty("ray.object-store.size");
|
||||
}
|
||||
|
||||
public static class LargeObjectWithCyclicRef {
|
||||
|
||||
private final LargeObjectWithCyclicRef loop;
|
||||
|
||||
private final ObjectRef<TestUtils.LargeObject> largeObject;
|
||||
|
||||
public LargeObjectWithCyclicRef() {
|
||||
this.loop = this;
|
||||
this.largeObject = Ray.put(new TestUtils.LargeObject(40 * 1024 * 1024));
|
||||
}
|
||||
}
|
||||
|
||||
public static class GarbageHolder {
|
||||
|
||||
private WeakReference<LargeObjectWithCyclicRef> garbage;
|
||||
|
||||
public GarbageHolder() {
|
||||
LargeObjectWithCyclicRef x = new LargeObjectWithCyclicRef();
|
||||
garbage = new WeakReference<>(x);
|
||||
}
|
||||
|
||||
public boolean hasGarbage() {
|
||||
return garbage.get() != null;
|
||||
}
|
||||
|
||||
public TestUtils.LargeObject returnLargeObject() {
|
||||
return new TestUtils.LargeObject(80 * 1024 * 1024);
|
||||
}
|
||||
}
|
||||
|
||||
private void testGlobalGcWhenFull(boolean withPut) {
|
||||
// Local driver.
|
||||
WeakReference<LargeObjectWithCyclicRef> localRef = new WeakReference<>(
|
||||
new LargeObjectWithCyclicRef());
|
||||
|
||||
// Remote workers.
|
||||
List<ActorHandle<GarbageHolder>> actors = IntStream
|
||||
.range(0, 2).mapToObj(i -> Ray.actor(GarbageHolder::new).remote())
|
||||
.collect(Collectors.toList());
|
||||
|
||||
Assert.assertNotNull(localRef.get());
|
||||
for (ActorHandle<GarbageHolder> actor : actors) {
|
||||
Assert.assertTrue(actor.task(GarbageHolder::hasGarbage).remote().get());
|
||||
}
|
||||
|
||||
if (withPut) {
|
||||
// GC should be triggered for all workers, including the local driver,
|
||||
// when the driver tries to Ray.put a value that doesn't fit in the
|
||||
// object store. This should cause the captured ObjectRefs to be evicted.
|
||||
Ray.put(new TestUtils.LargeObject(80 * 1024 * 1024));
|
||||
} else {
|
||||
// GC should be triggered for all workers, including the local driver,
|
||||
// when a remote task tries to put a return value that doesn't fit in
|
||||
// the object store. This should cause the captured ObjectRefs' to be evicted.
|
||||
actors.get(0).task(GarbageHolder::returnLargeObject).remote().get();
|
||||
}
|
||||
|
||||
TestUtils.waitForCondition(() -> localRef.get() == null && actors.stream().noneMatch(
|
||||
a -> a.task(GarbageHolder::hasGarbage).remote().get()), 10 * 1000);
|
||||
}
|
||||
|
||||
public void testGlobalGcWhenFullWithPut() {
|
||||
testGlobalGcWhenFull(true);
|
||||
}
|
||||
|
||||
public void testGlobalGcWhenFullWithReturn() {
|
||||
testGlobalGcWhenFull(false);
|
||||
}
|
||||
}
|
||||
@@ -143,7 +143,7 @@ public class MultiThreadingTest extends BaseTest {
|
||||
final ActorHandle<Echo> fooActor = Ray.actor(Echo::new).remote();
|
||||
return new Runnable[]{
|
||||
() -> Ray.put(1),
|
||||
() -> Ray.get(fooObject.getId(), fooObject.getType()),
|
||||
() -> Ray.get(ImmutableList.of(fooObject)),
|
||||
fooObject::get,
|
||||
() -> Ray.wait(ImmutableList.of(fooObject)),
|
||||
Ray::getRuntimeContext,
|
||||
|
||||
@@ -3,7 +3,6 @@ package io.ray.test;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import io.ray.api.ObjectRef;
|
||||
import io.ray.api.Ray;
|
||||
import io.ray.api.id.ObjectId;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
import org.testng.Assert;
|
||||
@@ -37,8 +36,8 @@ public class ObjectStoreTest extends BaseTest {
|
||||
@Test
|
||||
public void testGetMultipleObjects() {
|
||||
List<Integer> ints = ImmutableList.of(1, 2, 3, 4, 5);
|
||||
List<ObjectId> ids = ints.stream().map(obj -> Ray.put(obj).getId())
|
||||
List<ObjectRef<Integer>> refs = ints.stream().map(Ray::put)
|
||||
.collect(Collectors.toList());
|
||||
Assert.assertEquals(ints, Ray.get(ids, Integer.class));
|
||||
Assert.assertEquals(ints, Ray.get(refs));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import com.google.common.collect.ImmutableList;
|
||||
import io.ray.api.ObjectRef;
|
||||
import io.ray.api.Ray;
|
||||
import io.ray.api.id.TaskId;
|
||||
import io.ray.runtime.object.ObjectRefImpl;
|
||||
import java.util.Arrays;
|
||||
import org.testng.Assert;
|
||||
import org.testng.annotations.Test;
|
||||
@@ -19,11 +20,11 @@ public class PlasmaFreeTest extends BaseTest {
|
||||
ObjectRef<String> helloId = Ray.task(PlasmaFreeTest::hello).remote();
|
||||
String helloString = helloId.get();
|
||||
Assert.assertEquals("hello", helloString);
|
||||
Ray.internal().free(ImmutableList.of(helloId.getId()), true, false);
|
||||
Ray.internal().free(ImmutableList.of(helloId), true, false);
|
||||
|
||||
final boolean result = TestUtils.waitForCondition(() ->
|
||||
!TestUtils.getRuntime().getObjectStore()
|
||||
.wait(ImmutableList.of(helloId.getId()), 1, 0).get(0), 50);
|
||||
.wait(ImmutableList.of(((ObjectRefImpl<String>) helloId).getId()), 1, 0).get(0), 50);
|
||||
if (TestUtils.isSingleProcessMode()) {
|
||||
Assert.assertTrue(result);
|
||||
} else {
|
||||
@@ -36,9 +37,10 @@ public class PlasmaFreeTest extends BaseTest {
|
||||
public void testDeleteCreatingTasks() {
|
||||
ObjectRef<String> helloId = Ray.task(PlasmaFreeTest::hello).remote();
|
||||
Assert.assertEquals("hello", helloId.get());
|
||||
Ray.internal().free(ImmutableList.of(helloId.getId()), true, true);
|
||||
Ray.internal().free(ImmutableList.of(helloId), true, true);
|
||||
|
||||
TaskId taskId = TaskId.fromBytes(Arrays.copyOf(helloId.getId().getBytes(), TaskId.LENGTH));
|
||||
TaskId taskId = TaskId.fromBytes(
|
||||
Arrays.copyOf(((ObjectRefImpl<String>) helloId).getId().getBytes(), TaskId.LENGTH));
|
||||
final boolean result = TestUtils.waitForCondition(
|
||||
() -> !TestUtils.getRuntime().getGcsClient()
|
||||
.rayletTaskExistsInGcs(taskId), 50);
|
||||
|
||||
@@ -1,21 +0,0 @@
|
||||
package io.ray.test;
|
||||
|
||||
import io.ray.api.Ray;
|
||||
import io.ray.api.id.ObjectId;
|
||||
import io.ray.runtime.object.ObjectStore;
|
||||
import org.testng.Assert;
|
||||
import org.testng.annotations.Test;
|
||||
|
||||
public class PlasmaStoreTest extends BaseTest {
|
||||
|
||||
@Test(groups = {"cluster"})
|
||||
public void testPutWithDuplicateId() {
|
||||
ObjectId objectId = ObjectId.fromRandom();
|
||||
ObjectStore objectStore = TestUtils.getRuntime().getObjectStore();
|
||||
objectStore.put("1", objectId);
|
||||
Assert.assertEquals(Ray.get(objectId, String.class), "1");
|
||||
objectStore.put("2", objectId);
|
||||
// Putting the second object with duplicate ID should fail but ignored.
|
||||
Assert.assertEquals(Ray.get(objectId, String.class), "1");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,362 @@
|
||||
package io.ray.test;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.ImmutableMap;
|
||||
import com.google.gson.Gson;
|
||||
import io.ray.api.ActorHandle;
|
||||
import io.ray.api.ObjectRef;
|
||||
import io.ray.api.Ray;
|
||||
import io.ray.api.id.ObjectId;
|
||||
import io.ray.runtime.object.NativeObjectStore;
|
||||
import io.ray.runtime.object.ObjectRefImpl;
|
||||
import java.lang.ref.Reference;
|
||||
import java.lang.reflect.Field;
|
||||
import java.lang.reflect.Method;
|
||||
import java.time.Duration;
|
||||
import java.time.Instant;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import org.testng.Assert;
|
||||
import org.testng.annotations.AfterClass;
|
||||
import org.testng.annotations.BeforeClass;
|
||||
import org.testng.annotations.BeforeMethod;
|
||||
import org.testng.annotations.Test;
|
||||
|
||||
@Test(groups = {"cluster"})
|
||||
public class ReferenceCountingTest extends BaseTest {
|
||||
@BeforeClass
|
||||
public void setUp() {
|
||||
System.setProperty("ray.object-store.size", "100 MB");
|
||||
}
|
||||
|
||||
@AfterClass
|
||||
public void tearDown() {
|
||||
System.clearProperty("ray.object-store.size");
|
||||
}
|
||||
|
||||
/**
|
||||
* Because we can't explicitly GC an Java object. We use this helper method to manually remove an
|
||||
* local reference.
|
||||
*/
|
||||
private static void del(ObjectRef<?> obj) {
|
||||
try {
|
||||
Field referencesField = ObjectRefImpl.class.getDeclaredField("REFERENCES");
|
||||
referencesField.setAccessible(true);
|
||||
Set<?> references = (Set<?>) referencesField.get(null);
|
||||
Class<?> referenceClass = Class
|
||||
.forName("io.ray.runtime.object.ObjectRefImpl$ObjectRefImplReference");
|
||||
Method finalizeReferentMethod = referenceClass.getDeclaredMethod("finalizeReferent");
|
||||
finalizeReferentMethod.setAccessible(true);
|
||||
for (Object reference : references) {
|
||||
if (obj.equals(((Reference<?>)reference).get())) {
|
||||
finalizeReferentMethod.invoke(reference);
|
||||
break;
|
||||
}
|
||||
}
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
private void checkRefCounts(Map<ObjectId, long[]> expected, Duration timeout) {
|
||||
Instant start = Instant.now();
|
||||
while (true) {
|
||||
Map<ObjectId, long[]> actual =
|
||||
((NativeObjectStore) TestUtils.getRuntime().getObjectStore()).getAllReferenceCounts();
|
||||
try {
|
||||
Assert.assertEqualsDeep(actual, expected);
|
||||
return;
|
||||
} catch (AssertionError e) {
|
||||
if (Duration.between(start, Instant.now()).compareTo(timeout) >= 0) {
|
||||
System.out.println("Actual: " + new Gson().toJson(actual));
|
||||
System.out.println("Expected: " + new Gson().toJson(expected));
|
||||
throw e;
|
||||
} else {
|
||||
try {
|
||||
Thread.sleep(100);
|
||||
} catch (InterruptedException ex) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void checkRefCounts(Map<ObjectId, long[]> expected) {
|
||||
checkRefCounts(expected, Duration.ofSeconds(10));
|
||||
}
|
||||
|
||||
private void checkRefCounts(ObjectId objectId, long localRefCount, long submittedTaskRefCount) {
|
||||
checkRefCounts(ImmutableMap.of(objectId, new long[] {localRefCount, submittedTaskRefCount}));
|
||||
}
|
||||
|
||||
private void checkRefCounts(ObjectId objectId1, long localRefCount1, long submittedTaskRefCount1,
|
||||
ObjectId objectId2, long localRefCount2, long submittedTaskRefCount2) {
|
||||
checkRefCounts(ImmutableMap.of(objectId1, new long[] {localRefCount1, submittedTaskRefCount1},
|
||||
objectId2, new long[] {localRefCount2, submittedTaskRefCount2}));
|
||||
}
|
||||
|
||||
private static void fillObjectStoreAndGet(ObjectId objectId, boolean succeed) {
|
||||
fillObjectStoreAndGet(objectId, succeed, 40 * 1024 * 1024, 5);
|
||||
}
|
||||
|
||||
private static void fillObjectStoreAndGet(ObjectId objectId, boolean succeed, int objectSize,
|
||||
int numObjects) {
|
||||
for (int i = 0; i < numObjects; i++) {
|
||||
Ray.put(new TestUtils.LargeObject(objectSize));
|
||||
}
|
||||
if (succeed) {
|
||||
TestUtils.getRuntime().getObjectStore().getRaw(ImmutableList.of(objectId), Long.MAX_VALUE);
|
||||
} else {
|
||||
List<Boolean> result =
|
||||
TestUtils.getRuntime().getObjectStore().wait(ImmutableList.of(objectId), 1, 100);
|
||||
Assert.assertFalse(result.get(0));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Based on Python test case `test_local_refcounts`.
|
||||
*/
|
||||
public void testLocalRefCounts() {
|
||||
ObjectRefImpl<Object> obj1 = (ObjectRefImpl<Object>) Ray.put(null);
|
||||
checkRefCounts(obj1.getId(), 1, 0);
|
||||
ObjectRef<Object> obj1Copy = new ObjectRefImpl<>(obj1.getId(), obj1.getType());
|
||||
checkRefCounts(obj1.getId(), 2, 0);
|
||||
|
||||
del(obj1);
|
||||
checkRefCounts(obj1.getId(), 1, 0);
|
||||
del(obj1Copy);
|
||||
checkRefCounts(ImmutableMap.of());
|
||||
}
|
||||
|
||||
private static int oneDep(Object obj) {
|
||||
return oneDep(obj, null);
|
||||
}
|
||||
|
||||
private static int oneDep(Object obj, ActorHandle<SignalActor> singal) {
|
||||
return oneDep(obj, singal, false);
|
||||
}
|
||||
|
||||
private static int oneDep(Object obj, ActorHandle<SignalActor> singal, boolean fail) {
|
||||
if (singal != null) {
|
||||
singal.task(SignalActor::waitSignal).remote().get();
|
||||
}
|
||||
if (fail) {
|
||||
throw new RuntimeException("failed on purpose");
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
private static TestUtils.LargeObject oneDepLarge(Object obj, ActorHandle<SignalActor> singal) {
|
||||
if (singal != null) {
|
||||
singal.task(SignalActor::waitSignal).remote().get();
|
||||
}
|
||||
// This will be spilled to plasma.
|
||||
return new TestUtils.LargeObject(10 * 1024 * 1024);
|
||||
}
|
||||
|
||||
private static void sendSignal(ActorHandle<SignalActor> signal) {
|
||||
ObjectRef<Integer> result = signal.task(SignalActor::sendSignal).remote();
|
||||
result.get();
|
||||
// Remove the reference immediately, otherwise it will affect subsequent tests.
|
||||
del(result);
|
||||
}
|
||||
|
||||
/**
|
||||
* Based on Python test case `test_dependency_refcounts`.
|
||||
*/
|
||||
public void testDependencyRefCounts() {
|
||||
{
|
||||
// Test that regular plasma dependency refcounts are decremented once the
|
||||
// task finishes.
|
||||
ActorHandle<SignalActor> signal = SignalActor.create();
|
||||
ObjectRefImpl<TestUtils.LargeObject> largeDep = (ObjectRefImpl<TestUtils.LargeObject>) Ray
|
||||
.put(new TestUtils.LargeObject());
|
||||
ObjectRefImpl<Object> result = (ObjectRefImpl<Object>)
|
||||
Ray.<TestUtils.LargeObject, ActorHandle<SignalActor>, Object>task(
|
||||
ReferenceCountingTest::oneDep, largeDep, signal).remote();
|
||||
checkRefCounts(largeDep.getId(), 1, 1, result.getId(), 1, 0);
|
||||
sendSignal(signal);
|
||||
// Reference count should be removed once the task finishes.
|
||||
checkRefCounts(largeDep.getId(), 1, 0, result.getId(), 1, 0);
|
||||
del(largeDep);
|
||||
del(result);
|
||||
checkRefCounts(ImmutableMap.of());
|
||||
}
|
||||
|
||||
{
|
||||
// Test that inlined dependency refcounts are decremented once they are
|
||||
// inlined.
|
||||
ActorHandle<SignalActor> signal = SignalActor.create();
|
||||
ObjectRefImpl<Integer> dep = (ObjectRefImpl<Integer>)
|
||||
Ray.<Integer, ActorHandle<SignalActor>, Integer>task(ReferenceCountingTest::oneDep,
|
||||
Integer.valueOf(1), signal).remote();
|
||||
checkRefCounts(dep.getId(), 1, 0);
|
||||
ObjectRefImpl<Object> result = (ObjectRefImpl<Object>)
|
||||
Ray.<Integer, Object>task(ReferenceCountingTest::oneDep, dep).remote();
|
||||
checkRefCounts(dep.getId(), 1, 1, result.getId(), 1, 0);
|
||||
sendSignal(signal);
|
||||
// Reference count should be removed as soon as the dependency is inlined.
|
||||
checkRefCounts(dep.getId(), 1, 0, result.getId(), 1, 0);
|
||||
del(dep);
|
||||
del(result);
|
||||
checkRefCounts(ImmutableMap.of());
|
||||
}
|
||||
|
||||
{
|
||||
// Test that spilled plasma dependency refcounts are decremented once
|
||||
// the task finishes.
|
||||
ActorHandle<SignalActor> signal1 = SignalActor.create();
|
||||
ActorHandle<SignalActor> signal2 = SignalActor.create();
|
||||
ObjectRefImpl<TestUtils.LargeObject> dep = (ObjectRefImpl<TestUtils.LargeObject>)
|
||||
Ray.<TestUtils.LargeObject, ActorHandle<SignalActor>, TestUtils.LargeObject>task(
|
||||
ReferenceCountingTest::oneDepLarge, (TestUtils.LargeObject) null, signal1).remote();
|
||||
checkRefCounts(dep.getId(), 1, 0);
|
||||
ObjectRefImpl<Integer> result = (ObjectRefImpl<Integer>)
|
||||
Ray.<TestUtils.LargeObject, ActorHandle<SignalActor>, Integer>task(
|
||||
ReferenceCountingTest::oneDep, dep, signal2).remote();
|
||||
checkRefCounts(dep.getId(), 1, 1, result.getId(), 1, 0);
|
||||
sendSignal(signal1);
|
||||
dep.get(); // TODO(kfstorm): timeout=10
|
||||
// Reference count should remain because the dependency is in plasma.
|
||||
checkRefCounts(dep.getId(), 1, 1, result.getId(), 1, 0);
|
||||
sendSignal(signal2);
|
||||
// Reference count should be removed because the task finished.
|
||||
checkRefCounts(dep.getId(), 1, 0, result.getId(), 1, 0);
|
||||
del(dep);
|
||||
del(result);
|
||||
checkRefCounts(ImmutableMap.of());
|
||||
}
|
||||
|
||||
{
|
||||
// Test that regular plasma dependency refcounts are decremented if a task
|
||||
// fails.
|
||||
ActorHandle<SignalActor> signal = SignalActor.create();
|
||||
ObjectRefImpl<TestUtils.LargeObject> largeDep = (ObjectRefImpl<TestUtils.LargeObject>)
|
||||
Ray.put(new TestUtils.LargeObject(10 * 1024 * 1024));
|
||||
ObjectRefImpl<Integer> result = (ObjectRefImpl<Integer>)
|
||||
Ray.<TestUtils.LargeObject, ActorHandle<SignalActor>, Boolean, Integer>task(
|
||||
ReferenceCountingTest::oneDep, largeDep, signal, /* fail= */true).remote();
|
||||
checkRefCounts(largeDep.getId(), 1, 1, result.getId(), 1, 0);
|
||||
sendSignal(signal);
|
||||
// Reference count should be removed once the task finishes.
|
||||
checkRefCounts(largeDep.getId(), 1, 0, result.getId(), 1, 0);
|
||||
del(largeDep);
|
||||
del(result);
|
||||
checkRefCounts(ImmutableMap.of());
|
||||
}
|
||||
|
||||
{
|
||||
// Test that spilled plasma dependency refcounts are decremented if a task
|
||||
// fails.
|
||||
ActorHandle<SignalActor> signal1 = SignalActor.create();
|
||||
ActorHandle<SignalActor> signal2 = SignalActor.create();
|
||||
ObjectRefImpl<TestUtils.LargeObject> dep = (ObjectRefImpl<TestUtils.LargeObject>)
|
||||
Ray.<Integer, ActorHandle<SignalActor>, TestUtils.LargeObject>task(
|
||||
ReferenceCountingTest::oneDepLarge, (Integer) null, signal1).remote();
|
||||
checkRefCounts(dep.getId(), 1, 0);
|
||||
ObjectRefImpl<Integer> result = (ObjectRefImpl<Integer>)
|
||||
Ray.<TestUtils.LargeObject, ActorHandle<SignalActor>, Boolean, Integer>task(
|
||||
ReferenceCountingTest::oneDep, dep, signal2, /* fail= */true).remote();
|
||||
checkRefCounts(dep.getId(), 1, 1, result.getId(), 1, 0);
|
||||
sendSignal(signal1);
|
||||
dep.get(); // TODO(kfstorm): timeout=10
|
||||
// Reference count should remain because the dependency is in plasma.
|
||||
checkRefCounts(dep.getId(), 1, 1, result.getId(), 1, 0);
|
||||
sendSignal(signal2);
|
||||
// Reference count should be removed because the task finished.
|
||||
checkRefCounts(dep.getId(), 1, 0, result.getId(), 1, 0);
|
||||
del(dep);
|
||||
del(result);
|
||||
checkRefCounts(ImmutableMap.of());
|
||||
}
|
||||
}
|
||||
|
||||
private static int fooBasicPinning(Object arg) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
public static class ActorBasicPinning {
|
||||
private ObjectRef<TestUtils.LargeObject> largeObject;
|
||||
|
||||
public ActorBasicPinning() {
|
||||
// Hold a long-lived reference to a ray.put object's ID. The object
|
||||
// should not be garbage collected while the actor is alive because
|
||||
// the object is pinned by the raylet.
|
||||
largeObject = Ray.put(new TestUtils.LargeObject(25 * 1024 * 1024));
|
||||
}
|
||||
|
||||
public TestUtils.LargeObject getLargeObject() {
|
||||
return largeObject.get();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Based on Python test case `test_basic_pinning`.
|
||||
*/
|
||||
public void testBasicPinning() {
|
||||
ActorHandle<ActorBasicPinning> actor = Ray.actor(ActorBasicPinning::new).remote();
|
||||
|
||||
// Fill up the object store with short-lived objects. These should be
|
||||
// evicted before the long-lived object whose reference is held by
|
||||
// the actor.
|
||||
for (int i = 0; i < 10; i++) {
|
||||
ObjectRef<Integer> intermediateResult = Ray
|
||||
.task(ReferenceCountingTest::fooBasicPinning, new TestUtils.LargeObject(10 * 1024 * 1024))
|
||||
.remote();
|
||||
intermediateResult.get();
|
||||
}
|
||||
// The ray.get below would fail with only LRU eviction, as the object
|
||||
// that was ray.put by the actor would have been evicted.
|
||||
actor.task(ActorBasicPinning::getLargeObject).remote().get();
|
||||
}
|
||||
|
||||
private static Object pending(TestUtils.LargeObject input1, Integer input2) {
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Based on Python test case `test_pending_task_dependency_pinning`.
|
||||
*/
|
||||
public void testPendingTaskDependencyPinning() {
|
||||
// The object that is ray.put here will go out of scope immediately, so if
|
||||
// pending task dependencies aren't considered, it will be evicted before
|
||||
// the ray.get below due to the subsequent ray.puts that fill up the object
|
||||
// store.
|
||||
TestUtils.LargeObject input1 = new TestUtils.LargeObject(40 * 1024 * 1024);
|
||||
ActorHandle<SignalActor> signal = SignalActor.create();
|
||||
ObjectRef<Object> result = Ray
|
||||
.task(ReferenceCountingTest::pending, input1, signal.task(SignalActor::waitSignal).remote())
|
||||
.remote();
|
||||
|
||||
for (int i = 0; i < 2; i++) {
|
||||
Ray.put(new TestUtils.LargeObject(40 * 1024 * 1024));
|
||||
}
|
||||
|
||||
sendSignal(signal);
|
||||
result.get();
|
||||
}
|
||||
|
||||
/**
|
||||
* Test that an object containing object IDs within it pins the inner IDs. Based on Python test
|
||||
* case `test_basic_nested_ids`.
|
||||
*/
|
||||
public void testBasicNestedIds() {
|
||||
ObjectRefImpl<byte[]> inner = (ObjectRefImpl<byte[]>) Ray.put(new byte[40 * 1024 * 1024]);
|
||||
ObjectRef<List<ObjectRef<byte[]>>> outer = Ray.put(Collections.singletonList(inner));
|
||||
|
||||
// Remove the local reference to the inner object.
|
||||
del(inner);
|
||||
|
||||
// Check that the outer reference pins the inner object.
|
||||
fillObjectStoreAndGet(inner.getId(), true);
|
||||
|
||||
// Remove the outer reference and check that the inner object gets evicted.
|
||||
del(outer);
|
||||
fillObjectStoreAndGet(inner.getId(), false);
|
||||
}
|
||||
|
||||
// TODO(kfstorm): Add more test cases
|
||||
}
|
||||
@@ -0,0 +1,28 @@
|
||||
package io.ray.test;
|
||||
|
||||
import io.ray.api.ActorHandle;
|
||||
import io.ray.api.Ray;
|
||||
import java.util.concurrent.Semaphore;
|
||||
|
||||
public class SignalActor {
|
||||
|
||||
private Semaphore semaphore;
|
||||
|
||||
public SignalActor() {
|
||||
this.semaphore = new Semaphore(0);
|
||||
}
|
||||
|
||||
public int sendSignal() {
|
||||
this.semaphore.release();
|
||||
return 0;
|
||||
}
|
||||
|
||||
public int waitSignal() throws InterruptedException {
|
||||
this.semaphore.acquire();
|
||||
return 0;
|
||||
}
|
||||
|
||||
public static ActorHandle<SignalActor> create() {
|
||||
return Ray.actor(SignalActor::new).setMaxConcurrency(2).remote();
|
||||
}
|
||||
}
|
||||
@@ -4,7 +4,6 @@ import com.google.common.collect.ImmutableList;
|
||||
import io.ray.api.ActorHandle;
|
||||
import io.ray.api.ObjectRef;
|
||||
import io.ray.api.Ray;
|
||||
import io.ray.api.id.ObjectId;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import org.testng.Assert;
|
||||
@@ -21,12 +20,12 @@ public class StressTest extends BaseTest {
|
||||
for (int numIterations : ImmutableList.of(1, 10, 100, 1000)) {
|
||||
int numTasks = 1000 / numIterations;
|
||||
for (int i = 0; i < numIterations; i++) {
|
||||
List<ObjectId> resultIds = new ArrayList<>();
|
||||
List<ObjectRef<Integer>> results = new ArrayList<>();
|
||||
for (int j = 0; j < numTasks; j++) {
|
||||
resultIds.add(Ray.task(StressTest::echo, 1).remote().getId());
|
||||
results.add(Ray.task(StressTest::echo, 1).remote());
|
||||
}
|
||||
|
||||
for (Integer result : Ray.<Integer>get(resultIds, Integer.class)) {
|
||||
for (Integer result : Ray.get(results)) {
|
||||
Assert.assertEquals(result, Integer.valueOf(1));
|
||||
}
|
||||
}
|
||||
@@ -58,12 +57,12 @@ public class StressTest extends BaseTest {
|
||||
}
|
||||
|
||||
public int ping(int n) {
|
||||
List<ObjectId> objectIds = new ArrayList<>();
|
||||
List<ObjectRef<Integer>> objectRefs = new ArrayList<>();
|
||||
for (int i = 0; i < n; i++) {
|
||||
objectIds.add(actor.task(Actor::ping).remote().getId());
|
||||
objectRefs.add(actor.task(Actor::ping).remote());
|
||||
}
|
||||
int sum = 0;
|
||||
for (Integer result : Ray.<Integer>get(objectIds, Integer.class)) {
|
||||
for (Integer result : Ray.get(objectRefs)) {
|
||||
sum += result;
|
||||
}
|
||||
return sum;
|
||||
@@ -72,13 +71,13 @@ public class StressTest extends BaseTest {
|
||||
|
||||
public void testSubmittingManyTasksToOneActor() throws Exception {
|
||||
ActorHandle<Actor> actor = Ray.actor(Actor::new).remote();
|
||||
List<ObjectId> objectIds = new ArrayList<>();
|
||||
List<ObjectRef<Integer>> objectRefs = new ArrayList<>();
|
||||
for (int i = 0; i < 10; i++) {
|
||||
ActorHandle<Worker> worker = Ray.actor(Worker::new, actor).remote();
|
||||
objectIds.add(worker.task(Worker::ping, 100).remote().getId());
|
||||
objectRefs.add(worker.task(Worker::ping, 100).remote());
|
||||
}
|
||||
|
||||
for (Integer result : Ray.<Integer>get(objectIds, Integer.class)) {
|
||||
for (Integer result : Ray.get(objectRefs)) {
|
||||
Assert.assertEquals(result, Integer.valueOf(100));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
package io.ray.test;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import io.ray.api.ObjectRef;
|
||||
import io.ray.api.Ray;
|
||||
import io.ray.runtime.RayRuntimeInternal;
|
||||
import io.ray.runtime.RayRuntimeProxy;
|
||||
import io.ray.runtime.config.RunMode;
|
||||
import io.ray.runtime.task.ArgumentsBuilder;
|
||||
import java.io.Serializable;
|
||||
import java.util.function.Supplier;
|
||||
import org.testng.Assert;
|
||||
@@ -13,7 +15,16 @@ public class TestUtils {
|
||||
|
||||
public static class LargeObject implements Serializable {
|
||||
|
||||
public byte[] data = new byte[1024 * 1024];
|
||||
public byte[] data;
|
||||
|
||||
public LargeObject() {
|
||||
this(1024 * 1024);
|
||||
}
|
||||
|
||||
public LargeObject(int size) {
|
||||
Preconditions.checkState(size > ArgumentsBuilder.LARGEST_SIZE_PASS_BY_VALUE);
|
||||
data = new byte[size];
|
||||
}
|
||||
}
|
||||
|
||||
private static final int WAIT_INTERVAL_MS = 5;
|
||||
@@ -55,11 +66,12 @@ public class TestUtils {
|
||||
|
||||
/**
|
||||
* Warm up the cluster to make sure there's at least one idle worker.
|
||||
* <p>
|
||||
* This is needed before calling `wait`. Because, in Travis CI, starting a new worker
|
||||
* process could be slower than the wait timeout.
|
||||
* TODO(hchen): We should consider supporting always reversing a certain number of
|
||||
* idle workers in Raylet's worker pool.
|
||||
* <p/>
|
||||
* This is needed before calling `wait`. Because, in Travis CI, starting a new worker process
|
||||
* could be slower than the wait timeout.
|
||||
* <p/>
|
||||
* TODO(hchen): We should consider supporting always reversing a certain number of idle workers in
|
||||
* Raylet's worker pool.
|
||||
*/
|
||||
public static void warmUpCluster() {
|
||||
ObjectRef<String> obj = Ray.task(TestUtils::hi).remote();
|
||||
|
||||
@@ -50,7 +50,6 @@ public class Exercise04 {
|
||||
System.out.printf("%d ready object(s): \n", waitResult.getReady().size());
|
||||
waitResult.getReady().forEach(rayObject -> System.out.println(rayObject.get()));
|
||||
System.out.printf("%d unready object(s): \n", waitResult.getUnready().size());
|
||||
waitResult.getUnready().forEach(rayObject -> System.out.println(rayObject.getId()));
|
||||
} catch (Throwable t) {
|
||||
t.printStackTrace();
|
||||
} finally {
|
||||
|
||||
Reference in New Issue
Block a user