mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 15:40:09 +08:00
[Java] Local and distributed ref counting in Java (#9371)
This commit is contained in:
@@ -36,10 +36,12 @@ import io.ray.runtime.task.ArgumentsBuilder;
|
||||
import io.ray.runtime.task.FunctionArg;
|
||||
import io.ray.runtime.task.TaskExecutor;
|
||||
import io.ray.runtime.task.TaskSubmitter;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.concurrent.Callable;
|
||||
import java.util.stream.Collectors;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
@@ -72,9 +74,6 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal {
|
||||
runtimeContext = new RuntimeContextImpl(this);
|
||||
}
|
||||
|
||||
@Override
|
||||
public abstract void shutdown();
|
||||
|
||||
@Override
|
||||
public <T> ObjectRef<T> put(T obj) {
|
||||
ObjectId objectId = objectStore.put(obj);
|
||||
@@ -82,19 +81,27 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal {
|
||||
}
|
||||
|
||||
@Override
|
||||
public <T> T get(ObjectId objectId, Class<T> objectType) throws RayException {
|
||||
List<T> ret = get(ImmutableList.of(objectId), objectType);
|
||||
public <T> T get(ObjectRef<T> objectRef) throws RayException {
|
||||
List<T> ret = get(ImmutableList.of(objectRef));
|
||||
return ret.get(0);
|
||||
}
|
||||
|
||||
@Override
|
||||
public <T> List<T> get(List<ObjectId> objectIds, Class<T> objectType) {
|
||||
public <T> List<T> get(List<ObjectRef<T>> objectRefs) {
|
||||
List<ObjectId> objectIds = new ArrayList<>();
|
||||
Class<T> objectType = null;
|
||||
for (ObjectRef<T> o : objectRefs) {
|
||||
ObjectRefImpl<T> objectRefImpl = (ObjectRefImpl<T>) o;
|
||||
objectIds.add(objectRefImpl.getId());
|
||||
objectType = objectRefImpl.getType();
|
||||
}
|
||||
return objectStore.get(objectIds, objectType);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void free(List<ObjectId> objectIds, boolean localOnly, boolean deleteCreatingTasks) {
|
||||
objectStore.delete(objectIds, localOnly, deleteCreatingTasks);
|
||||
public void free(List<ObjectRef<?>> objectRefs, boolean localOnly, boolean deleteCreatingTasks) {
|
||||
objectStore.delete(objectRefs.stream().map(ref -> ((ObjectRefImpl<?>) ref).getId()).collect(
|
||||
Collectors.toList()), localOnly, deleteCreatingTasks);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -22,6 +22,9 @@ import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.concurrent.locks.Lock;
|
||||
import java.util.concurrent.locks.ReadWriteLock;
|
||||
import java.util.concurrent.locks.ReentrantReadWriteLock;
|
||||
import org.apache.commons.io.FileUtils;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
@@ -35,6 +38,15 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
|
||||
|
||||
private RunManager manager = null;
|
||||
|
||||
/**
|
||||
* In Java, GC runs in a standalone thread, and we can't control the exact
|
||||
* timing of garbage collection. By using this lock, when
|
||||
* {@link NativeObjectStore#nativeRemoveLocalReference} is executing, the core
|
||||
* worker will not be shut down, therefore it guarantees some kind of
|
||||
* thread-safety. Note that this guarantee only works for driver.
|
||||
*/
|
||||
private final ReadWriteLock shutdownLock = new ReentrantReadWriteLock();
|
||||
|
||||
|
||||
static {
|
||||
LOGGER.debug("Loading native libraries.");
|
||||
@@ -105,7 +117,7 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
|
||||
|
||||
taskExecutor = new NativeTaskExecutor(this);
|
||||
workerContext = new NativeWorkerContext();
|
||||
objectStore = new NativeObjectStore(workerContext);
|
||||
objectStore = new NativeObjectStore(workerContext, shutdownLock);
|
||||
taskSubmitter = new NativeTaskSubmitter();
|
||||
|
||||
LOGGER.debug("RayNativeRuntime started with store {}, raylet {}",
|
||||
@@ -114,19 +126,25 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
|
||||
|
||||
@Override
|
||||
public void shutdown() {
|
||||
if (rayConfig.workerMode == WorkerType.DRIVER) {
|
||||
nativeShutdown();
|
||||
if (null != manager) {
|
||||
manager.cleanup();
|
||||
manager = null;
|
||||
Lock writeLock = shutdownLock.readLock();
|
||||
writeLock.lock();
|
||||
try {
|
||||
if (rayConfig.workerMode == WorkerType.DRIVER) {
|
||||
nativeShutdown();
|
||||
if (null != manager) {
|
||||
manager.cleanup();
|
||||
manager = null;
|
||||
}
|
||||
}
|
||||
if (null != gcsClient) {
|
||||
gcsClient.destroy();
|
||||
gcsClient = null;
|
||||
}
|
||||
RayConfig.reset();
|
||||
LOGGER.debug("RayNativeRuntime shutdown");
|
||||
} finally {
|
||||
writeLock.unlock();
|
||||
}
|
||||
if (null != gcsClient) {
|
||||
gcsClient.destroy();
|
||||
gcsClient = null;
|
||||
}
|
||||
RayConfig.reset();
|
||||
LOGGER.debug("RayNativeRuntime shutdown");
|
||||
}
|
||||
|
||||
// For test purpose only
|
||||
|
||||
@@ -2,6 +2,7 @@ package io.ray.runtime.object;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import io.ray.api.id.ObjectId;
|
||||
import io.ray.api.id.UniqueId;
|
||||
import io.ray.runtime.context.WorkerContext;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
@@ -96,4 +97,12 @@ public class LocalModeObjectStore extends ObjectStore {
|
||||
pool.remove(objectId);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void addLocalReference(UniqueId workerId, ObjectId objectId) {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void removeLocalReference(UniqueId workerId, ObjectId objectId) {
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,8 +2,13 @@ package io.ray.runtime.object;
|
||||
|
||||
import io.ray.api.id.BaseId;
|
||||
import io.ray.api.id.ObjectId;
|
||||
import io.ray.api.id.UniqueId;
|
||||
import io.ray.runtime.context.WorkerContext;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.locks.Lock;
|
||||
import java.util.concurrent.locks.ReadWriteLock;
|
||||
import java.util.stream.Collectors;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
@@ -15,8 +20,11 @@ public class NativeObjectStore extends ObjectStore {
|
||||
|
||||
private static final Logger LOGGER = LoggerFactory.getLogger(NativeObjectStore.class);
|
||||
|
||||
public NativeObjectStore(WorkerContext workerContext) {
|
||||
private final ReadWriteLock shutdownLock;
|
||||
|
||||
public NativeObjectStore(WorkerContext workerContext, ReadWriteLock shutdownLock) {
|
||||
super(workerContext);
|
||||
this.shutdownLock = shutdownLock;
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -44,6 +52,31 @@ public class NativeObjectStore extends ObjectStore {
|
||||
nativeDelete(toBinaryList(objectIds), localOnly, deleteCreatingTasks);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void addLocalReference(UniqueId workerId, ObjectId objectId) {
|
||||
nativeAddLocalReference(workerId.getBytes(), objectId.getBytes());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void removeLocalReference(UniqueId workerId, ObjectId objectId) {
|
||||
Lock readLock = shutdownLock.readLock();
|
||||
readLock.lock();
|
||||
try {
|
||||
nativeRemoveLocalReference(workerId.getBytes(), objectId.getBytes());
|
||||
} finally {
|
||||
readLock.unlock();
|
||||
}
|
||||
}
|
||||
|
||||
public Map<ObjectId, long[]> getAllReferenceCounts() {
|
||||
Map<ObjectId, long[]> referenceCounts = new HashMap<>();
|
||||
for (Map.Entry<byte[], long[]> entry :
|
||||
nativeGetAllReferenceCounts().entrySet()) {
|
||||
referenceCounts.put(new ObjectId(entry.getKey()), entry.getValue());
|
||||
}
|
||||
return referenceCounts;
|
||||
}
|
||||
|
||||
private static List<byte[]> toBinaryList(List<ObjectId> ids) {
|
||||
return ids.stream().map(BaseId::getBytes).collect(Collectors.toList());
|
||||
}
|
||||
@@ -59,4 +92,10 @@ public class NativeObjectStore extends ObjectStore {
|
||||
|
||||
private static native void nativeDelete(List<byte[]> objectIds, boolean localOnly,
|
||||
boolean deleteCreatingTasks);
|
||||
|
||||
private static native void nativeAddLocalReference(byte[] workerId, byte[] objectId);
|
||||
|
||||
private static native void nativeRemoveLocalReference(byte[] workerId, byte[] objectId);
|
||||
|
||||
private static native Map<byte[], long[]> nativeGetAllReferenceCounts();
|
||||
}
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
package io.ray.runtime.object;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import io.ray.api.id.BaseId;
|
||||
import io.ray.api.id.ObjectId;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* Binary representation of a ray object. See `RayObject` class in C++ for details.
|
||||
@@ -9,11 +14,17 @@ public class NativeRayObject {
|
||||
|
||||
public byte[] data;
|
||||
public byte[] metadata;
|
||||
public List<byte[]> containedObjectIds;
|
||||
|
||||
public NativeRayObject(byte[] data, byte[] metadata) {
|
||||
Preconditions.checkState(bufferLength(data) > 0 || bufferLength(metadata) > 0);
|
||||
this.data = data;
|
||||
this.metadata = metadata;
|
||||
this.containedObjectIds = Collections.emptyList();
|
||||
}
|
||||
|
||||
public void setContainedObjectIds(List<ObjectId> containedObjectIds) {
|
||||
this.containedObjectIds = toBinaryList(containedObjectIds);
|
||||
}
|
||||
|
||||
private static int bufferLength(byte[] buffer) {
|
||||
@@ -23,6 +34,10 @@ public class NativeRayObject {
|
||||
return buffer.length;
|
||||
}
|
||||
|
||||
private static List<byte[]> toBinaryList(List<ObjectId> ids) {
|
||||
return ids.stream().map(BaseId::getBytes).collect(Collectors.toList());
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "<data>: " + bufferLength(data) + ", <metadata>: " + bufferLength(metadata);
|
||||
|
||||
@@ -1,54 +1,56 @@
|
||||
package io.ray.runtime.object;
|
||||
|
||||
import com.google.common.base.FinalizableReferenceQueue;
|
||||
import com.google.common.base.FinalizableWeakReference;
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.collect.Sets;
|
||||
import io.ray.api.ObjectRef;
|
||||
import io.ray.api.Ray;
|
||||
import io.ray.api.id.ObjectId;
|
||||
import java.io.Serializable;
|
||||
import io.ray.api.id.UniqueId;
|
||||
import io.ray.runtime.RayRuntimeInternal;
|
||||
import java.io.Externalizable;
|
||||
import java.io.IOException;
|
||||
import java.io.ObjectInput;
|
||||
import java.io.ObjectOutput;
|
||||
import java.lang.ref.Reference;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
|
||||
/**
|
||||
* Implementation of {@link ObjectRef}.
|
||||
*/
|
||||
public final class ObjectRefImpl<T> implements ObjectRef<T>, Serializable {
|
||||
public final class ObjectRefImpl<T> implements ObjectRef<T>, Externalizable {
|
||||
|
||||
private final ObjectId id;
|
||||
private static final FinalizableReferenceQueue REFERENCE_QUEUE = new FinalizableReferenceQueue();
|
||||
|
||||
/**
|
||||
* Cache the result of `Ray.get()`.
|
||||
*
|
||||
* Note, this is necessary for direct calls, in which case, it's not allowed to call `Ray.get` on
|
||||
* the same object twice.
|
||||
*/
|
||||
private transient T object;
|
||||
private static final Set<Reference<ObjectRefImpl<?>>> REFERENCES = Sets.newConcurrentHashSet();
|
||||
|
||||
private ObjectId id;
|
||||
|
||||
// In GC thread, we don't know which worker this object binds to, so we need to
|
||||
// store the worker ID for later uses.
|
||||
private transient UniqueId workerId;
|
||||
|
||||
private Class<T> type;
|
||||
|
||||
/**
|
||||
* Whether the object is already gotten from the object store.
|
||||
*/
|
||||
private transient boolean objectGotten;
|
||||
|
||||
public ObjectRefImpl(ObjectId id, Class<T> type) {
|
||||
this.id = id;
|
||||
this.type = type;
|
||||
object = null;
|
||||
objectGotten = false;
|
||||
addLocalReference();
|
||||
}
|
||||
|
||||
public ObjectRefImpl() {}
|
||||
|
||||
@Override
|
||||
public synchronized T get() {
|
||||
if (!objectGotten) {
|
||||
object = Ray.get(id, type);
|
||||
objectGotten = true;
|
||||
}
|
||||
return object;
|
||||
return Ray.get(this);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ObjectId getId() {
|
||||
return id;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Class<T> getType() {
|
||||
return type;
|
||||
}
|
||||
@@ -57,4 +59,56 @@ public final class ObjectRefImpl<T> implements ObjectRef<T>, Serializable {
|
||||
public String toString() {
|
||||
return "ObjectRef(" + id.toString() + ")";
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeExternal(ObjectOutput out) throws IOException {
|
||||
out.writeObject(this.getId());
|
||||
out.writeObject(this.getType());
|
||||
ObjectSerializer.addContainedObjectId(this.getId());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
|
||||
this.id = (ObjectId) in.readObject();
|
||||
this.type = (Class<T>) in.readObject();
|
||||
addLocalReference();
|
||||
}
|
||||
|
||||
private void addLocalReference() {
|
||||
Preconditions.checkState(workerId == null);
|
||||
RayRuntimeInternal runtime = (RayRuntimeInternal) Ray.internal();
|
||||
workerId = runtime.getWorkerContext().getCurrentWorkerId();
|
||||
runtime.getObjectStore().addLocalReference(workerId, id);
|
||||
new ObjectRefImplReference(this);
|
||||
}
|
||||
|
||||
private static final class ObjectRefImplReference extends
|
||||
FinalizableWeakReference<ObjectRefImpl<?>> {
|
||||
|
||||
private final UniqueId workerId;
|
||||
private final ObjectId objectId;
|
||||
private final AtomicBoolean removed;
|
||||
|
||||
public ObjectRefImplReference(ObjectRefImpl<?> obj) {
|
||||
super(obj, REFERENCE_QUEUE);
|
||||
this.workerId = obj.workerId;
|
||||
this.objectId = obj.id;
|
||||
this.removed = new AtomicBoolean(false);
|
||||
REFERENCES.add(this);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void finalizeReferent() {
|
||||
// This method may be invoked multiple times on the same instance (due to explicit invoking in
|
||||
// unit tests). So if `workerId` is null, it means this method has been invoked.
|
||||
if (!removed.getAndSet(true)) {
|
||||
REFERENCES.remove(this);
|
||||
RayRuntimeInternal runtime = (RayRuntimeInternal) Ray.internal();
|
||||
// It's possible that GC is executed after the runtime is shutdown.
|
||||
if (runtime != null) {
|
||||
runtime.getObjectStore().removeLocalReference(workerId, objectId);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,7 +8,11 @@ import io.ray.api.id.ObjectId;
|
||||
import io.ray.runtime.generated.Gcs.ErrorType;
|
||||
import io.ray.runtime.serializer.Serializer;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
|
||||
/**
|
||||
@@ -31,6 +35,12 @@ public class ObjectSerializer {
|
||||
public static final byte[] OBJECT_METADATA_TYPE_PYTHON = "PYTHON".getBytes();
|
||||
public static final byte[] OBJECT_METADATA_TYPE_RAW = "RAW".getBytes();
|
||||
|
||||
// When an outer object is being serialized, the nested ObjectRefs are all
|
||||
// serialized and the writeExternal method of the nested ObjectRefs are
|
||||
// executed. So after the outer object is serialized, the containedObjectIds
|
||||
// field will contain all the nested object IDs.
|
||||
static ThreadLocal<Set<ObjectId>> containedObjectIds = ThreadLocal.withInitial(HashSet::new);
|
||||
|
||||
/**
|
||||
* Deserialize an object from an {@link NativeRayObject} instance.
|
||||
*
|
||||
@@ -100,9 +110,29 @@ public class ObjectSerializer {
|
||||
byte[] serializedBytes = Serializer.encode(object).getLeft();
|
||||
return new NativeRayObject(serializedBytes, TASK_EXECUTION_EXCEPTION_META);
|
||||
} else {
|
||||
Pair<byte[], Boolean> serialized = Serializer.encode(object);
|
||||
return new NativeRayObject(serialized.getLeft(), serialized.getRight() ?
|
||||
OBJECT_METADATA_TYPE_CROSS_LANGUAGE : OBJECT_METADATA_TYPE_JAVA);
|
||||
try {
|
||||
Pair<byte[], Boolean> serialized = Serializer.encode(object);
|
||||
NativeRayObject nativeRayObject = new NativeRayObject(serialized.getLeft(),
|
||||
serialized.getRight()
|
||||
? OBJECT_METADATA_TYPE_CROSS_LANGUAGE
|
||||
: OBJECT_METADATA_TYPE_JAVA);
|
||||
nativeRayObject.setContainedObjectIds(getAndClearContainedObjectIds());
|
||||
return nativeRayObject;
|
||||
} catch (Exception e) {
|
||||
// Clear `containedObjectIds`.
|
||||
getAndClearContainedObjectIds();
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void addContainedObjectId(ObjectId objectId) {
|
||||
containedObjectIds.get().add(objectId);
|
||||
}
|
||||
|
||||
private static List<ObjectId> getAndClearContainedObjectIds() {
|
||||
List<ObjectId> ids = new ArrayList<>(containedObjectIds.get());
|
||||
containedObjectIds.get().clear();
|
||||
return ids;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import io.ray.api.ObjectRef;
|
||||
import io.ray.api.WaitResult;
|
||||
import io.ray.api.exception.RayException;
|
||||
import io.ray.api.id.ObjectId;
|
||||
import io.ray.api.id.UniqueId;
|
||||
import io.ray.runtime.context.WorkerContext;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
@@ -137,7 +138,8 @@ public abstract class ObjectStore {
|
||||
return new WaitResult<>(Collections.emptyList(), Collections.emptyList());
|
||||
}
|
||||
|
||||
List<ObjectId> ids = waitList.stream().map(ObjectRef::getId).collect(Collectors.toList());
|
||||
List<ObjectId> ids = waitList.stream().map(ref -> ((ObjectRefImpl<?>) ref).getId())
|
||||
.collect(Collectors.toList());
|
||||
|
||||
List<Boolean> ready = wait(ids, numReturns, timeoutMs);
|
||||
List<ObjectRef<T>> readyList = new ArrayList<>();
|
||||
@@ -164,4 +166,18 @@ public abstract class ObjectStore {
|
||||
*/
|
||||
public abstract void delete(List<ObjectId> objectIds, boolean localOnly,
|
||||
boolean deleteCreatingTasks);
|
||||
|
||||
/**
|
||||
* Increase the local reference count for this object ID.
|
||||
* @param workerId The ID of the worker to increase on.
|
||||
* @param objectId The object ID to increase the reference count for.
|
||||
*/
|
||||
public abstract void addLocalReference(UniqueId workerId, ObjectId objectId);
|
||||
|
||||
/**
|
||||
* Decrease the reference count for this object ID.
|
||||
* @param workerId The ID of the worker to decrease on.
|
||||
* @param objectId The object ID to decrease the reference count for.
|
||||
*/
|
||||
public abstract void removeLocalReference(UniqueId workerId, ObjectId objectId);
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import io.ray.api.id.ObjectId;
|
||||
import io.ray.runtime.RayRuntimeInternal;
|
||||
import io.ray.runtime.generated.Common.Language;
|
||||
import io.ray.runtime.object.NativeRayObject;
|
||||
import io.ray.runtime.object.ObjectRefImpl;
|
||||
import io.ray.runtime.object.ObjectSerializer;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.ArrayList;
|
||||
@@ -22,7 +23,8 @@ public class ArgumentsBuilder {
|
||||
* If the the size of an argument's serialized data is smaller than this number, the argument will
|
||||
* be passed by value. Otherwise it'll be passed by reference.
|
||||
*/
|
||||
private static final int LARGEST_SIZE_PASS_BY_VALUE = 100 * 1024;
|
||||
// TODO(kfstorm): Read from internal config `max_direct_call_object_size`.
|
||||
public static final int LARGEST_SIZE_PASS_BY_VALUE = 100 * 1024;
|
||||
|
||||
/**
|
||||
* This dummy type is also defined in signature.py. Please keep it synced.
|
||||
@@ -39,7 +41,8 @@ public class ArgumentsBuilder {
|
||||
ObjectId id = null;
|
||||
NativeRayObject value = null;
|
||||
if (arg instanceof ObjectRef) {
|
||||
id = ((ObjectRef) arg).getId();
|
||||
Preconditions.checkState(arg instanceof ObjectRefImpl);
|
||||
id = ((ObjectRefImpl<?>) arg).getId();
|
||||
} else {
|
||||
value = ObjectSerializer.serialize(arg);
|
||||
if (language != Language.JAVA) {
|
||||
|
||||
Reference in New Issue
Block a user