mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 02:47:10 +08:00
[Java] Support calling Ray APIs from multiple threads (#3646)
This commit is contained in:
@@ -108,9 +108,7 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
@Override
|
||||
public <T> List<T> get(List<UniqueId> objectIds) {
|
||||
boolean wasBlocked = false;
|
||||
// TODO(swang): If we are not on the main thread, then we should generate a
|
||||
// random task ID to pass to the backend.
|
||||
UniqueId taskId = workerContext.getCurrentTask().taskId;
|
||||
UniqueId taskId = workerContext.getCurrentThreadTaskId();
|
||||
|
||||
try {
|
||||
int numObjectIds = objectIds.size();
|
||||
@@ -218,10 +216,8 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
|
||||
@Override
|
||||
public <T> WaitResult<T> wait(List<RayObject<T>> waitList, int numReturns, int timeoutMs) {
|
||||
// TODO(swang): If we are not on the main thread, then we should generate a
|
||||
// random task ID to pass to the backend.
|
||||
return rayletClient.wait(waitList, numReturns, timeoutMs,
|
||||
workerContext.getCurrentTask().taskId);
|
||||
return rayletClient.wait(waitList, numReturns,
|
||||
timeoutMs, workerContext.getCurrentThreadTaskId());
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -237,9 +233,12 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
throw new IllegalArgumentException("Unsupported actor type: " + actor.getClass().getName());
|
||||
}
|
||||
RayActorImpl actorImpl = (RayActorImpl)actor;
|
||||
TaskSpec spec = createTaskSpec(func, actorImpl, args, false, null);
|
||||
spec.getExecutionDependencies().add(((RayActorImpl) actor).getTaskCursor());
|
||||
actorImpl.setTaskCursor(spec.returnIds[1]);
|
||||
TaskSpec spec;
|
||||
synchronized (actor) {
|
||||
spec = createTaskSpec(func, actorImpl, args, false, null);
|
||||
spec.getExecutionDependencies().add(((RayActorImpl) actor).getTaskCursor());
|
||||
actorImpl.setTaskCursor(spec.returnIds[1]);
|
||||
}
|
||||
rayletClient.submitTask(spec);
|
||||
return new RayObjectImpl(spec.returnIds[0]);
|
||||
}
|
||||
@@ -342,4 +341,8 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
public FunctionManager getFunctionManager() {
|
||||
return functionManager;
|
||||
}
|
||||
|
||||
public RayConfig getRayConfig() {
|
||||
return rayConfig;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,10 +11,12 @@ public class RayDevRuntime extends AbstractRayRuntime {
|
||||
super(rayConfig);
|
||||
}
|
||||
|
||||
private MockObjectStore store;
|
||||
|
||||
@Override
|
||||
public void start() {
|
||||
MockObjectStore store = new MockObjectStore(this);
|
||||
objectStoreProxy = new ObjectStoreProxy(this, store);
|
||||
store = new MockObjectStore(this);
|
||||
objectStoreProxy = new ObjectStoreProxy(this, null);
|
||||
rayletClient = new MockRayletClient(this, store);
|
||||
}
|
||||
|
||||
@@ -22,4 +24,8 @@ public class RayDevRuntime extends AbstractRayRuntime {
|
||||
public void shutdown() {
|
||||
// nothing to do
|
||||
}
|
||||
|
||||
public MockObjectStore getObjectStore() {
|
||||
return store;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -74,8 +74,7 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
|
||||
}
|
||||
kvStore = new RedisClient(rayConfig.getRedisAddress());
|
||||
|
||||
ObjectStoreLink store = new PlasmaClient(rayConfig.objectStoreSocketName, "", 0);
|
||||
objectStoreProxy = new ObjectStoreProxy(this, store);
|
||||
objectStoreProxy = new ObjectStoreProxy(this, rayConfig.objectStoreSocketName);
|
||||
|
||||
rayletClient = new RayletClientImpl(
|
||||
rayConfig.rayletSocketName,
|
||||
|
||||
@@ -1,11 +1,17 @@
|
||||
package org.ray.runtime;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import java.util.HashMap;
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.runtime.config.WorkerMode;
|
||||
import org.ray.runtime.task.TaskSpec;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
public class WorkerContext {
|
||||
private static final Logger LOGGER = LoggerFactory.getLogger(WorkerContext.class);
|
||||
|
||||
/**
|
||||
* Worker id.
|
||||
@@ -25,19 +31,53 @@ public class WorkerContext {
|
||||
/**
|
||||
* How many puts have been done by current task.
|
||||
*/
|
||||
private int currentTaskPutCount;
|
||||
private AtomicInteger currentTaskPutCount;
|
||||
|
||||
/**
|
||||
* How many calls have been done by current task.
|
||||
*/
|
||||
private int currentTaskCallCount;
|
||||
private AtomicInteger currentTaskCallCount;
|
||||
|
||||
/**
|
||||
* The ID of main thread which created the worker context.
|
||||
*/
|
||||
private long mainThreadId;
|
||||
/**
|
||||
* If the multi-threading warning message has been logged.
|
||||
*/
|
||||
private AtomicBoolean multiThreadingWarned;
|
||||
|
||||
public WorkerContext(WorkerMode workerMode, UniqueId driverId) {
|
||||
workerId = workerMode == WorkerMode.DRIVER ? driverId : UniqueId.randomId();
|
||||
currentTaskPutCount = 0;
|
||||
currentTaskCallCount = 0;
|
||||
currentTaskPutCount = new AtomicInteger(0);
|
||||
currentTaskCallCount = new AtomicInteger(0);
|
||||
currentClassLoader = null;
|
||||
currentTask = createDummyTask(workerMode, driverId);
|
||||
mainThreadId = Thread.currentThread().getId();
|
||||
multiThreadingWarned = new AtomicBoolean(false);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the current thread's task ID.
|
||||
* This returns the assigned task ID if called on the main thread, else a
|
||||
* random task ID.
|
||||
*/
|
||||
public UniqueId getCurrentThreadTaskId() {
|
||||
UniqueId taskId;
|
||||
if (Thread.currentThread().getId() == mainThreadId) {
|
||||
taskId = currentTask.taskId;
|
||||
} else {
|
||||
taskId = UniqueId.randomId();
|
||||
if (multiThreadingWarned.compareAndSet(false, true)) {
|
||||
LOGGER.warn("Calling Ray.get or Ray.wait in a separate thread " +
|
||||
"may lead to deadlock if the main thread blocks on this " +
|
||||
"thread and there are not enough resources to execute " +
|
||||
"more tasks");
|
||||
}
|
||||
}
|
||||
|
||||
Preconditions.checkState(!taskId.isNil());
|
||||
return taskId;
|
||||
}
|
||||
|
||||
public void setWorkerId(UniqueId workerId) {
|
||||
@@ -49,11 +89,11 @@ public class WorkerContext {
|
||||
}
|
||||
|
||||
public int nextPutIndex() {
|
||||
return ++currentTaskPutCount;
|
||||
return currentTaskPutCount.incrementAndGet();
|
||||
}
|
||||
|
||||
public int nextCallIndex() {
|
||||
return ++currentTaskCallCount;
|
||||
return currentTaskCallCount.incrementAndGet();
|
||||
}
|
||||
|
||||
public UniqueId getCurrentWorkerId() {
|
||||
@@ -66,6 +106,8 @@ public class WorkerContext {
|
||||
|
||||
public void setCurrentTask(TaskSpec currentTask) {
|
||||
this.currentTask = currentTask;
|
||||
currentTaskCallCount.set(0);
|
||||
currentTaskPutCount.set(0);
|
||||
}
|
||||
|
||||
public void setCurrentClassLoader(ClassLoader currentClassLoader) {
|
||||
|
||||
@@ -3,10 +3,13 @@ package org.ray.runtime.objectstore;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import org.apache.arrow.plasma.ObjectStoreLink;
|
||||
import org.apache.arrow.plasma.PlasmaClient;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.ray.api.exception.RayException;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.runtime.AbstractRayRuntime;
|
||||
import org.ray.runtime.RayDevRuntime;
|
||||
import org.ray.runtime.config.RunMode;
|
||||
import org.ray.runtime.util.Serializer;
|
||||
import org.ray.runtime.util.UniqueIdUtil;
|
||||
|
||||
@@ -19,11 +22,18 @@ public class ObjectStoreProxy {
|
||||
private static final int GET_TIMEOUT_MS = 1000;
|
||||
|
||||
private final AbstractRayRuntime runtime;
|
||||
private final ObjectStoreLink store;
|
||||
|
||||
public ObjectStoreProxy(AbstractRayRuntime runtime, ObjectStoreLink store) {
|
||||
private static ThreadLocal<ObjectStoreLink> objectStore;
|
||||
|
||||
public ObjectStoreProxy(AbstractRayRuntime runtime, String storeSocketName) {
|
||||
this.runtime = runtime;
|
||||
this.store = store;
|
||||
objectStore = ThreadLocal.withInitial(() -> {
|
||||
if (runtime.getRayConfig().runMode == RunMode.CLUSTER) {
|
||||
return new PlasmaClient(storeSocketName, "", 0);
|
||||
} else {
|
||||
return ((RayDevRuntime) runtime).getObjectStore();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
public <T> Pair<T, GetStatus> get(UniqueId objectId, boolean isMetadata)
|
||||
@@ -33,10 +43,10 @@ public class ObjectStoreProxy {
|
||||
|
||||
public <T> Pair<T, GetStatus> get(UniqueId id, int timeoutMs, boolean isMetadata)
|
||||
throws RayException {
|
||||
byte[] obj = store.get(id.getBytes(), timeoutMs, isMetadata);
|
||||
byte[] obj = objectStore.get().get(id.getBytes(), timeoutMs, isMetadata);
|
||||
if (obj != null) {
|
||||
T t = Serializer.decode(obj, runtime.getWorkerContext().getCurrentClassLoader());
|
||||
store.release(id.getBytes());
|
||||
objectStore.get().release(id.getBytes());
|
||||
if (t instanceof RayException) {
|
||||
throw (RayException) t;
|
||||
}
|
||||
@@ -53,13 +63,13 @@ public class ObjectStoreProxy {
|
||||
|
||||
public <T> List<Pair<T, GetStatus>> get(List<UniqueId> ids, int timeoutMs, boolean isMetadata)
|
||||
throws RayException {
|
||||
List<byte[]> objs = store.get(UniqueIdUtil.getIdBytes(ids), timeoutMs, isMetadata);
|
||||
List<byte[]> objs = objectStore.get().get(UniqueIdUtil.getIdBytes(ids), timeoutMs, isMetadata);
|
||||
List<Pair<T, GetStatus>> ret = new ArrayList<>();
|
||||
for (int i = 0; i < objs.size(); i++) {
|
||||
byte[] obj = objs.get(i);
|
||||
if (obj != null) {
|
||||
T t = Serializer.decode(obj, runtime.getWorkerContext().getCurrentClassLoader());
|
||||
store.release(ids.get(i).getBytes());
|
||||
objectStore.get().release(ids.get(i).getBytes());
|
||||
if (t instanceof RayException) {
|
||||
throw (RayException) t;
|
||||
}
|
||||
@@ -72,11 +82,11 @@ public class ObjectStoreProxy {
|
||||
}
|
||||
|
||||
public void put(UniqueId id, Object obj, Object metadata) {
|
||||
store.put(id.getBytes(), Serializer.encode(obj), Serializer.encode(metadata));
|
||||
objectStore.get().put(id.getBytes(), Serializer.encode(obj), Serializer.encode(metadata));
|
||||
}
|
||||
|
||||
public void putSerialized(UniqueId id, byte[] obj, byte[] metadata) {
|
||||
store.put(id.getBytes(), obj, metadata);
|
||||
objectStore.get().put(id.getBytes(), obj, metadata);
|
||||
}
|
||||
|
||||
public enum GetStatus {
|
||||
|
||||
Reference in New Issue
Block a user