[Java] Support calling Ray APIs from multiple threads (#3646)

This commit is contained in:
Wang Qing
2018-12-28 17:44:31 +08:00
committed by Hao Chen
parent 0b682d043e
commit c59b506c6e
6 changed files with 201 additions and 29 deletions
@@ -108,9 +108,7 @@ public abstract class AbstractRayRuntime implements RayRuntime {
@Override
public <T> List<T> get(List<UniqueId> objectIds) {
boolean wasBlocked = false;
// TODO(swang): If we are not on the main thread, then we should generate a
// random task ID to pass to the backend.
UniqueId taskId = workerContext.getCurrentTask().taskId;
UniqueId taskId = workerContext.getCurrentThreadTaskId();
try {
int numObjectIds = objectIds.size();
@@ -218,10 +216,8 @@ public abstract class AbstractRayRuntime implements RayRuntime {
/**
 * Waits on the given objects by delegating to {@code rayletClient.wait}.
 *
 * <p>{@code waitList}, {@code numReturns} and {@code timeoutMs} are forwarded unchanged; the
 * extra task id argument is obtained from {@code workerContext}. The meaning of the count and
 * timeout is defined by the raylet client and is not visible here.
 *
 * @param waitList the objects to wait on
 * @param numReturns forwarded to the raylet client
 * @param timeoutMs forwarded to the raylet client
 * @return whatever {@code rayletClient.wait} returns
 */
@Override
public <T> WaitResult<T> wait(List<RayObject<T>> waitList, int numReturns, int timeoutMs) {
// TODO(swang): If we are not on the main thread, then we should generate a
// random task ID to pass to the backend.
return rayletClient.wait(waitList, numReturns, timeoutMs,
workerContext.getCurrentTask().taskId);
return rayletClient.wait(waitList, numReturns,
timeoutMs, workerContext.getCurrentThreadTaskId());
}
@Override
@@ -237,9 +233,12 @@ public abstract class AbstractRayRuntime implements RayRuntime {
throw new IllegalArgumentException("Unsupported actor type: " + actor.getClass().getName());
}
RayActorImpl actorImpl = (RayActorImpl)actor;
TaskSpec spec = createTaskSpec(func, actorImpl, args, false, null);
spec.getExecutionDependencies().add(((RayActorImpl) actor).getTaskCursor());
actorImpl.setTaskCursor(spec.returnIds[1]);
TaskSpec spec;
synchronized (actor) {
spec = createTaskSpec(func, actorImpl, args, false, null);
spec.getExecutionDependencies().add(((RayActorImpl) actor).getTaskCursor());
actorImpl.setTaskCursor(spec.returnIds[1]);
}
rayletClient.submitTask(spec);
return new RayObjectImpl(spec.returnIds[0]);
}
@@ -342,4 +341,8 @@ public abstract class AbstractRayRuntime implements RayRuntime {
public FunctionManager getFunctionManager() {
return functionManager;
}
/**
 * Returns the configuration object held by this runtime.
 *
 * @return the runtime's {@code rayConfig}
 */
public RayConfig getRayConfig() {
  return this.rayConfig;
}
}
@@ -11,10 +11,12 @@ public class RayDevRuntime extends AbstractRayRuntime {
super(rayConfig);
}
private MockObjectStore store;
/**
 * Creates the components this runtime works with: a {@code MockObjectStore}, an
 * {@code ObjectStoreProxy}, and a {@code MockRayletClient} that is handed the mock store.
 */
@Override
public void start() {
MockObjectStore store = new MockObjectStore(this);
objectStoreProxy = new ObjectStoreProxy(this, store);
store = new MockObjectStore(this);
objectStoreProxy = new ObjectStoreProxy(this, null);
rayletClient = new MockRayletClient(this, store);
}
@@ -22,4 +24,8 @@ public class RayDevRuntime extends AbstractRayRuntime {
public void shutdown() {
// nothing to do
}
/**
 * Returns the mock object store of this runtime.
 *
 * @return the {@code store} field; it has no initializer and is assigned in {@code start()},
 *     so it is {@code null} before that call
 */
public MockObjectStore getObjectStore() {
  return this.store;
}
}
@@ -74,8 +74,7 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
}
kvStore = new RedisClient(rayConfig.getRedisAddress());
ObjectStoreLink store = new PlasmaClient(rayConfig.objectStoreSocketName, "", 0);
objectStoreProxy = new ObjectStoreProxy(this, store);
objectStoreProxy = new ObjectStoreProxy(this, rayConfig.objectStoreSocketName);
rayletClient = new RayletClientImpl(
rayConfig.rayletSocketName,
@@ -1,11 +1,17 @@
package org.ray.runtime;
import com.google.common.base.Preconditions;
import java.util.HashMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import org.ray.api.id.UniqueId;
import org.ray.runtime.config.WorkerMode;
import org.ray.runtime.task.TaskSpec;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public class WorkerContext {
private static final Logger LOGGER = LoggerFactory.getLogger(WorkerContext.class);
/**
* Worker id.
@@ -25,19 +31,53 @@ public class WorkerContext {
/**
* How many puts have been done by current task.
*/
private int currentTaskPutCount;
private AtomicInteger currentTaskPutCount;
/**
* How many calls have been done by current task.
*/
private int currentTaskCallCount;
private AtomicInteger currentTaskCallCount;
/**
 * ID of the thread that constructed this worker context; that thread is treated as the
 * main thread.
 */
private long mainThreadId;
/**
 * Whether the multi-threading warning has already been logged, so that it is emitted
 * only once.
 */
private AtomicBoolean multiThreadingWarned;
/**
 * Creates the context for a driver or worker process.
 *
 * <p>In {@code DRIVER} mode the worker id is {@code driverId}; in any other mode it is a random
 * id. Both per-task counters start at zero, no class loader is set, the current task is the one
 * built by {@code createDummyTask}, and the constructing thread is recorded as the main thread.
 *
 * @param workerMode the mode this process runs in
 * @param driverId the driver's id; also passed to {@code createDummyTask}
 */
public WorkerContext(WorkerMode workerMode, UniqueId driverId) {
workerId = workerMode == WorkerMode.DRIVER ? driverId : UniqueId.randomId();
currentTaskPutCount = 0;
currentTaskCallCount = 0;
currentTaskPutCount = new AtomicInteger(0);
currentTaskCallCount = new AtomicInteger(0);
currentClassLoader = null;
currentTask = createDummyTask(workerMode, driverId);
mainThreadId = Thread.currentThread().getId();
multiThreadingWarned = new AtomicBoolean(false);
}
/**
 * Returns the task ID to attribute to a Ray call issued from the calling thread.
 *
 * <p>On the main thread (the one that constructed this context) this is the ID of the
 * current task. On any other thread a fresh random ID is generated, and the first such
 * call logs a one-time warning about a possible deadlock.
 *
 * @return a non-nil task ID
 * @throws IllegalStateException if the resulting ID is nil
 */
public UniqueId getCurrentThreadTaskId() {
  final boolean onMainThread = Thread.currentThread().getId() == mainThreadId;
  final UniqueId result;
  if (onMainThread) {
    result = currentTask.taskId;
  } else {
    result = UniqueId.randomId();
    // compareAndSet succeeds for exactly one caller, so the warning is logged at most once.
    final boolean firstOffThreadCall = multiThreadingWarned.compareAndSet(false, true);
    if (firstOffThreadCall) {
      LOGGER.warn("Calling Ray.get or Ray.wait in a separate thread " +
          "may lead to deadlock if the main thread blocks on this " +
          "thread and there are not enough resources to execute " +
          "more tasks");
    }
  }
  Preconditions.checkState(!result.isNil());
  return result;
}
public void setWorkerId(UniqueId workerId) {
@@ -49,11 +89,11 @@ public class WorkerContext {
}
/**
 * Advances the current task's put counter and returns the new value.
 *
 * @return the counter value after the increment, i.e. the count including this call
 */
public int nextPutIndex() {
return ++currentTaskPutCount;
return currentTaskPutCount.incrementAndGet();
}
/**
 * Advances the current task's call counter and returns the new value.
 *
 * @return the counter value after the increment, i.e. the count including this call
 */
public int nextCallIndex() {
return ++currentTaskCallCount;
return currentTaskCallCount.incrementAndGet();
}
public UniqueId getCurrentWorkerId() {
@@ -66,6 +106,8 @@ public class WorkerContext {
/**
 * Makes the given task the current one and resets both per-task counters to zero.
 *
 * @param currentTask the task to make current
 */
public void setCurrentTask(TaskSpec currentTask) {
this.currentTask = currentTask;
// NOTE(review): the assignment above and the two resets below are separate steps, not a
// single atomic update — confirm no other thread can call nextPutIndex()/nextCallIndex()
// while the task is being switched.
currentTaskCallCount.set(0);
currentTaskPutCount.set(0);
}
public void setCurrentClassLoader(ClassLoader currentClassLoader) {
@@ -3,10 +3,13 @@ package org.ray.runtime.objectstore;
import java.util.ArrayList;
import java.util.List;
import org.apache.arrow.plasma.ObjectStoreLink;
import org.apache.arrow.plasma.PlasmaClient;
import org.apache.commons.lang3.tuple.Pair;
import org.ray.api.exception.RayException;
import org.ray.api.id.UniqueId;
import org.ray.runtime.AbstractRayRuntime;
import org.ray.runtime.RayDevRuntime;
import org.ray.runtime.config.RunMode;
import org.ray.runtime.util.Serializer;
import org.ray.runtime.util.UniqueIdUtil;
@@ -19,11 +22,18 @@ public class ObjectStoreProxy {
private static final int GET_TIMEOUT_MS = 1000;
private final AbstractRayRuntime runtime;
private final ObjectStoreLink store;
public ObjectStoreProxy(AbstractRayRuntime runtime, ObjectStoreLink store) {
// Store client for the calling thread, created on that thread's first access by the
// initializer installed in the constructor. Held per proxy instance rather than in a static
// field: the constructor assigns it, so a static field would be replaced for every existing
// proxy each time another proxy is constructed.
private final ThreadLocal<ObjectStoreLink> objectStore;
/**
 * Creates a proxy whose underlying store client is resolved separately for each thread.
 *
 * <p>The first access from a thread runs the initializer below: in {@code CLUSTER} run mode it
 * constructs a new {@code PlasmaClient} for {@code storeSocketName}; in any other mode it casts
 * {@code runtime} to {@code RayDevRuntime} and uses that runtime's object store.
 *
 * @param runtime the owning runtime; supplies the run mode and, outside cluster mode, the store
 * @param storeSocketName socket name handed to {@code PlasmaClient}; only read in cluster mode
 */
public ObjectStoreProxy(AbstractRayRuntime runtime, String storeSocketName) {
this.runtime = runtime;
this.store = store;
objectStore = ThreadLocal.withInitial(() -> {
if (runtime.getRayConfig().runMode == RunMode.CLUSTER) {
return new PlasmaClient(storeSocketName, "", 0);
} else {
return ((RayDevRuntime) runtime).getObjectStore();
}
});
}
public <T> Pair<T, GetStatus> get(UniqueId objectId, boolean isMetadata)
@@ -33,10 +43,10 @@ public class ObjectStoreProxy {
public <T> Pair<T, GetStatus> get(UniqueId id, int timeoutMs, boolean isMetadata)
throws RayException {
byte[] obj = store.get(id.getBytes(), timeoutMs, isMetadata);
byte[] obj = objectStore.get().get(id.getBytes(), timeoutMs, isMetadata);
if (obj != null) {
T t = Serializer.decode(obj, runtime.getWorkerContext().getCurrentClassLoader());
store.release(id.getBytes());
objectStore.get().release(id.getBytes());
if (t instanceof RayException) {
throw (RayException) t;
}
@@ -53,13 +63,13 @@ public class ObjectStoreProxy {
public <T> List<Pair<T, GetStatus>> get(List<UniqueId> ids, int timeoutMs, boolean isMetadata)
throws RayException {
List<byte[]> objs = store.get(UniqueIdUtil.getIdBytes(ids), timeoutMs, isMetadata);
List<byte[]> objs = objectStore.get().get(UniqueIdUtil.getIdBytes(ids), timeoutMs, isMetadata);
List<Pair<T, GetStatus>> ret = new ArrayList<>();
for (int i = 0; i < objs.size(); i++) {
byte[] obj = objs.get(i);
if (obj != null) {
T t = Serializer.decode(obj, runtime.getWorkerContext().getCurrentClassLoader());
store.release(ids.get(i).getBytes());
objectStore.get().release(ids.get(i).getBytes());
if (t instanceof RayException) {
throw (RayException) t;
}
@@ -72,11 +82,11 @@ public class ObjectStoreProxy {
}
/**
 * Encodes {@code obj} and {@code metadata} with {@code Serializer.encode} and stores the
 * resulting bytes under the given id.
 *
 * @param id the id whose bytes are used as the key in the store
 * @param obj the value to encode and store
 * @param metadata the metadata to encode and store alongside the value
 */
public void put(UniqueId id, Object obj, Object metadata) {
store.put(id.getBytes(), Serializer.encode(obj), Serializer.encode(metadata));
objectStore.get().put(id.getBytes(), Serializer.encode(obj), Serializer.encode(metadata));
}
/**
 * Stores the given byte arrays under the given id as they are, without passing them through
 * {@code Serializer.encode}.
 *
 * @param id the id whose bytes are used as the key in the store
 * @param obj the value bytes to store
 * @param metadata the metadata bytes to store alongside the value
 */
public void putSerialized(UniqueId id, byte[] obj, byte[] metadata) {
store.put(id.getBytes(), obj, metadata);
objectStore.get().put(id.getBytes(), obj, metadata);
}
public enum GetStatus {