mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 01:59:23 +08:00
Refine multi-threading support (#3672)
* [Python] Refine multi-threading support; fix. * [Java] Refine multi-threading code; fix Java. * Format.
This commit is contained in:
@@ -71,14 +71,14 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
@Override
|
||||
public <T> RayObject<T> put(T obj) {
|
||||
UniqueId objectId = UniqueIdUtil.computePutId(
|
||||
workerContext.getCurrentTask().taskId, workerContext.nextPutIndex());
|
||||
workerContext.getCurrentTaskId(), workerContext.nextPutIndex());
|
||||
|
||||
put(objectId, obj);
|
||||
return new RayObjectImpl<>(objectId);
|
||||
}
|
||||
|
||||
public <T> void put(UniqueId objectId, T obj) {
|
||||
UniqueId taskId = workerContext.getCurrentTask().taskId;
|
||||
UniqueId taskId = workerContext.getCurrentTaskId();
|
||||
LOGGER.debug("Putting object {}, for task {} ", objectId, taskId);
|
||||
objectStoreProxy.put(objectId, obj, null);
|
||||
}
|
||||
@@ -92,8 +92,8 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
*/
|
||||
public RayObject<Object> putSerialized(byte[] obj) {
|
||||
UniqueId objectId = UniqueIdUtil.computePutId(
|
||||
workerContext.getCurrentTask().taskId, workerContext.nextPutIndex());
|
||||
UniqueId taskId = workerContext.getCurrentTask().taskId;
|
||||
workerContext.getCurrentTaskId(), workerContext.nextPutIndex());
|
||||
UniqueId taskId = workerContext.getCurrentTaskId();
|
||||
LOGGER.debug("Putting serialized object {}, for task {} ", objectId, taskId);
|
||||
objectStoreProxy.putSerialized(objectId, obj, null);
|
||||
return new RayObjectImpl<>(objectId);
|
||||
@@ -108,7 +108,6 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
@Override
|
||||
public <T> List<T> get(List<UniqueId> objectIds) {
|
||||
boolean wasBlocked = false;
|
||||
UniqueId taskId = workerContext.getCurrentThreadTaskId();
|
||||
|
||||
try {
|
||||
int numObjectIds = objectIds.size();
|
||||
@@ -117,7 +116,7 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
List<List<UniqueId>> fetchBatches =
|
||||
splitIntoBatches(objectIds, FETCH_BATCH_SIZE);
|
||||
for (List<UniqueId> batch : fetchBatches) {
|
||||
rayletClient.fetchOrReconstruct(batch, true, taskId);
|
||||
rayletClient.fetchOrReconstruct(batch, true, workerContext.getCurrentTaskId());
|
||||
}
|
||||
|
||||
// Get the objects. We initially try to get the objects immediately.
|
||||
@@ -144,7 +143,7 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
splitIntoBatches(unreadyList, FETCH_BATCH_SIZE);
|
||||
|
||||
for (List<UniqueId> batch : reconstructBatches) {
|
||||
rayletClient.fetchOrReconstruct(batch, false, taskId);
|
||||
rayletClient.fetchOrReconstruct(batch, false, workerContext.getCurrentTaskId());
|
||||
}
|
||||
|
||||
List<Pair<T, GetStatus>> results = objectStoreProxy
|
||||
@@ -171,7 +170,8 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
}
|
||||
|
||||
if (LOGGER.isDebugEnabled()) {
|
||||
LOGGER.debug("Got objects {} for task {}.", Arrays.toString(objectIds.toArray()), taskId);
|
||||
LOGGER.debug("Got objects {} for task {}.", Arrays.toString(objectIds.toArray()),
|
||||
workerContext.getCurrentTaskId());
|
||||
}
|
||||
|
||||
List<T> finalRet = new ArrayList<>();
|
||||
@@ -182,13 +182,13 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
|
||||
return finalRet;
|
||||
} catch (RayException e) {
|
||||
LOGGER.error("Failed to get objects for task {}.", taskId, e);
|
||||
LOGGER.error("Failed to get objects for task {}.", workerContext.getCurrentTaskId(), e);
|
||||
throw e;
|
||||
} finally {
|
||||
// If there were objects that we weren't able to get locally, let the local
|
||||
// scheduler know that we're now unblocked.
|
||||
if (wasBlocked) {
|
||||
rayletClient.notifyUnblocked(taskId);
|
||||
rayletClient.notifyUnblocked(workerContext.getCurrentTaskId());
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -217,7 +217,7 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
@Override
|
||||
public <T> WaitResult<T> wait(List<RayObject<T>> waitList, int numReturns, int timeoutMs) {
|
||||
return rayletClient.wait(waitList, numReturns,
|
||||
timeoutMs, workerContext.getCurrentThreadTaskId());
|
||||
timeoutMs, workerContext.getCurrentTaskId());
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -277,9 +277,8 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
*/
|
||||
private TaskSpec createTaskSpec(RayFunc func, RayActorImpl actor, Object[] args,
|
||||
boolean isActorCreationTask, BaseTaskOptions taskOptions) {
|
||||
final TaskSpec current = workerContext.getCurrentTask();
|
||||
UniqueId taskId = rayletClient.generateTaskId(current.driverId,
|
||||
current.taskId, workerContext.nextCallIndex());
|
||||
UniqueId taskId = rayletClient.generateTaskId(workerContext.getCurrentDriverId(),
|
||||
workerContext.getCurrentTaskId(), workerContext.nextTaskIndex());
|
||||
int numReturns = actor.getId().isNil() ? 1 : 2;
|
||||
UniqueId[] returnIds = genReturnIds(taskId, numReturns);
|
||||
|
||||
@@ -304,11 +303,11 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
if (taskOptions instanceof ActorCreationOptions) {
|
||||
maxActorReconstruction = ((ActorCreationOptions) taskOptions).maxReconstructions;
|
||||
}
|
||||
RayFunction rayFunction = functionManager.getFunction(current.driverId, func);
|
||||
RayFunction rayFunction = functionManager.getFunction(workerContext.getCurrentDriverId(), func);
|
||||
return new TaskSpec(
|
||||
current.driverId,
|
||||
workerContext.getCurrentDriverId(),
|
||||
taskId,
|
||||
current.taskId,
|
||||
workerContext.getCurrentTaskId(),
|
||||
-1,
|
||||
actorCreationId,
|
||||
maxActorReconstruction,
|
||||
|
||||
@@ -39,7 +39,7 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
|
||||
path += ":";
|
||||
}
|
||||
|
||||
path += rayConfig.libraryPath.stream().collect(Collectors.joining(":"));
|
||||
path += String.join(":", rayConfig.libraryPath);
|
||||
|
||||
// This is a hack to reset library path at runtime,
|
||||
// see https://stackoverflow.com/questions/15409223/.
|
||||
@@ -80,7 +80,7 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
|
||||
rayConfig.rayletSocketName,
|
||||
workerContext.getCurrentWorkerId(),
|
||||
rayConfig.workerMode == WorkerMode.WORKER,
|
||||
workerContext.getCurrentTask().taskId
|
||||
workerContext.getCurrentDriverId()
|
||||
);
|
||||
|
||||
// register
|
||||
|
||||
@@ -43,8 +43,7 @@ public class Worker {
|
||||
RayFunction rayFunction = runtime.getFunctionManager()
|
||||
.getFunction(spec.driverId, spec.functionDescriptor);
|
||||
// Set context
|
||||
runtime.getWorkerContext().setCurrentTask(spec);
|
||||
runtime.getWorkerContext().setCurrentClassLoader(rayFunction.classLoader);
|
||||
runtime.getWorkerContext().setCurrentTask(spec, rayFunction.classLoader);
|
||||
Thread.currentThread().setContextClassLoader(rayFunction.classLoader);
|
||||
// Get local actor object and arguments.
|
||||
Object actor = spec.isActorTask() ? runtime.localActors.get(spec.actorId) : null;
|
||||
@@ -67,6 +66,7 @@ public class Worker {
|
||||
LOGGER.error("Error executing task " + spec, e);
|
||||
runtime.put(returnId, new RayException("Error executing task " + spec, e));
|
||||
} finally {
|
||||
runtime.getWorkerContext().setCurrentTask(null, null);
|
||||
Thread.currentThread().setContextClassLoader(oldLoader);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,9 +1,6 @@
|
||||
package org.ray.runtime;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import java.util.HashMap;
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.runtime.config.WorkerMode;
|
||||
import org.ray.runtime.task.TaskSpec;
|
||||
@@ -11,123 +8,114 @@ import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
public class WorkerContext {
|
||||
|
||||
private static final Logger LOGGER = LoggerFactory.getLogger(WorkerContext.class);
|
||||
|
||||
/**
|
||||
* Worker id.
|
||||
*/
|
||||
private UniqueId workerId;
|
||||
|
||||
/**
|
||||
* Current task.
|
||||
*/
|
||||
private TaskSpec currentTask;
|
||||
private ThreadLocal<UniqueId> currentTaskId;
|
||||
|
||||
/**
|
||||
* Current class loader.
|
||||
* Number of objects that have been put from current task.
|
||||
*/
|
||||
private ThreadLocal<Integer> putIndex;
|
||||
|
||||
/**
|
||||
* Number of tasks that have been submitted from current task.
|
||||
*/
|
||||
private ThreadLocal<Integer> taskIndex;
|
||||
|
||||
private UniqueId currentDriverId;
|
||||
|
||||
private ClassLoader currentClassLoader;
|
||||
|
||||
/**
|
||||
* How many puts have been done by current task.
|
||||
*/
|
||||
private AtomicInteger currentTaskPutCount;
|
||||
|
||||
/**
|
||||
* How many calls have been done by current task.
|
||||
*/
|
||||
private AtomicInteger currentTaskCallCount;
|
||||
|
||||
/**
|
||||
* The ID of main thread which created the worker context.
|
||||
*/
|
||||
private long mainThreadId;
|
||||
/**
|
||||
* If the multi-threading warning message has been logged.
|
||||
*/
|
||||
private AtomicBoolean multiThreadingWarned;
|
||||
|
||||
|
||||
public WorkerContext(WorkerMode workerMode, UniqueId driverId) {
|
||||
workerId = workerMode == WorkerMode.DRIVER ? driverId : UniqueId.randomId();
|
||||
currentTaskPutCount = new AtomicInteger(0);
|
||||
currentTaskCallCount = new AtomicInteger(0);
|
||||
currentClassLoader = null;
|
||||
currentTask = createDummyTask(workerMode, driverId);
|
||||
mainThreadId = Thread.currentThread().getId();
|
||||
multiThreadingWarned = new AtomicBoolean(false);
|
||||
taskIndex = ThreadLocal.withInitial(() -> 0);
|
||||
putIndex = ThreadLocal.withInitial(() -> 0);
|
||||
currentTaskId = ThreadLocal.withInitial(UniqueId::randomId);
|
||||
if (workerMode == WorkerMode.DRIVER) {
|
||||
workerId = driverId;
|
||||
currentTaskId.set(UniqueId.randomId());
|
||||
currentDriverId = driverId;
|
||||
currentClassLoader = null;
|
||||
} else {
|
||||
workerId = UniqueId.randomId();
|
||||
setCurrentTask(null, null);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the current thread's task ID.
|
||||
* This returns the assigned task ID if called on the main thread, else a
|
||||
* random task ID.
|
||||
* @return For the main thread, this method returns the ID of this worker's current running task;
|
||||
* for other threads, this method returns a random ID.
|
||||
*/
|
||||
public UniqueId getCurrentThreadTaskId() {
|
||||
UniqueId taskId;
|
||||
if (Thread.currentThread().getId() == mainThreadId) {
|
||||
taskId = currentTask.taskId;
|
||||
public UniqueId getCurrentTaskId() {
|
||||
return currentTaskId.get();
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the current task which is being executed by the current worker. Note, this method can only
|
||||
* be called from the main thread.
|
||||
*/
|
||||
public void setCurrentTask(TaskSpec task, ClassLoader classLoader) {
|
||||
Preconditions.checkState(
|
||||
Thread.currentThread().getId() == mainThreadId,
|
||||
"This method should only be called from the main thread."
|
||||
);
|
||||
if (task != null) {
|
||||
currentTaskId.set(task.taskId);
|
||||
currentDriverId = task.driverId;
|
||||
} else {
|
||||
taskId = UniqueId.randomId();
|
||||
if (multiThreadingWarned.compareAndSet(false, true)) {
|
||||
LOGGER.warn("Calling Ray.get or Ray.wait in a separate thread " +
|
||||
"may lead to deadlock if the main thread blocks on this " +
|
||||
"thread and there are not enough resources to execute " +
|
||||
"more tasks");
|
||||
}
|
||||
currentTaskId.set(UniqueId.NIL);
|
||||
currentDriverId = UniqueId.NIL;
|
||||
}
|
||||
|
||||
Preconditions.checkState(!taskId.isNil());
|
||||
return taskId;
|
||||
}
|
||||
|
||||
public void setWorkerId(UniqueId workerId) {
|
||||
this.workerId = workerId;
|
||||
}
|
||||
|
||||
public TaskSpec getCurrentTask() {
|
||||
return currentTask;
|
||||
taskIndex.set(0);
|
||||
putIndex.set(0);
|
||||
currentClassLoader = classLoader;
|
||||
}
|
||||
|
||||
/**
|
||||
* Increment the put index and return the new value.
|
||||
*/
|
||||
public int nextPutIndex() {
|
||||
return currentTaskPutCount.incrementAndGet();
|
||||
putIndex.set(putIndex.get() + 1);
|
||||
return putIndex.get();
|
||||
}
|
||||
|
||||
public int nextCallIndex() {
|
||||
return currentTaskCallCount.incrementAndGet();
|
||||
/**
|
||||
* Increment the task index and return the new value.
|
||||
*/
|
||||
public int nextTaskIndex() {
|
||||
taskIndex.set(taskIndex.get() + 1);
|
||||
return taskIndex.get();
|
||||
}
|
||||
|
||||
/**
|
||||
* @return The ID of the current worker.
|
||||
*/
|
||||
public UniqueId getCurrentWorkerId() {
|
||||
return workerId;
|
||||
}
|
||||
|
||||
/**
|
||||
* @return If this worker is a driver, this method returns the driver ID; otherwise, it returns
|
||||
* the driver ID of the current running task.
|
||||
*/
|
||||
public UniqueId getCurrentDriverId() {
|
||||
return currentDriverId;
|
||||
}
|
||||
|
||||
/**
|
||||
* @return The class loader which is associated with the current driver.
|
||||
*/
|
||||
public ClassLoader getCurrentClassLoader() {
|
||||
return currentClassLoader;
|
||||
}
|
||||
|
||||
public void setCurrentTask(TaskSpec currentTask) {
|
||||
this.currentTask = currentTask;
|
||||
currentTaskCallCount.set(0);
|
||||
currentTaskPutCount.set(0);
|
||||
}
|
||||
|
||||
public void setCurrentClassLoader(ClassLoader currentClassLoader) {
|
||||
this.currentClassLoader = currentClassLoader;
|
||||
}
|
||||
|
||||
private TaskSpec createDummyTask(WorkerMode workerMode, UniqueId driverId) {
|
||||
return new TaskSpec(
|
||||
driverId,
|
||||
workerMode == WorkerMode.DRIVER ? UniqueId.randomId() : UniqueId.NIL,
|
||||
UniqueId.NIL,
|
||||
0,
|
||||
UniqueId.NIL,
|
||||
0,
|
||||
UniqueId.NIL,
|
||||
UniqueId.NIL,
|
||||
0,
|
||||
null,
|
||||
null,
|
||||
new HashMap<>(),
|
||||
null);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -95,7 +95,7 @@ public class MockObjectStore implements ObjectStoreLink {
|
||||
}
|
||||
|
||||
private String logPrefix() {
|
||||
return runtime.getWorkerContext().getCurrentTask().taskId + "-" + getUserTrace() + " -> ";
|
||||
return runtime.getWorkerContext().getCurrentTaskId() + "-" + getUserTrace() + " -> ";
|
||||
}
|
||||
|
||||
private String getUserTrace() {
|
||||
|
||||
@@ -79,6 +79,9 @@ public class RayletClientImpl implements RayletClient {
|
||||
@Override
|
||||
public void submitTask(TaskSpec spec) {
|
||||
LOGGER.debug("Submitting task: {}", spec);
|
||||
Preconditions.checkState(!spec.parentTaskId.isNil());
|
||||
Preconditions.checkState(!spec.driverId.isNil());
|
||||
|
||||
ByteBuffer info = convertTaskSpecToFlatbuffer(spec);
|
||||
byte[] cursorId = null;
|
||||
if (!spec.getExecutionDependencies().isEmpty()) {
|
||||
|
||||
Reference in New Issue
Block a user