Refine multi-threading support (#3672)

* [Python] refine multi-threading support

fix

* [java] refine multithreading code

fix java

* format
This commit is contained in:
Hao Chen
2019-01-11 05:58:11 +08:00
committed by Stephanie Wang
parent 71243203a4
commit 597abb24ea
9 changed files with 394 additions and 313 deletions
@@ -71,14 +71,14 @@ public abstract class AbstractRayRuntime implements RayRuntime {
@Override
public <T> RayObject<T> put(T obj) {
UniqueId objectId = UniqueIdUtil.computePutId(
workerContext.getCurrentTask().taskId, workerContext.nextPutIndex());
workerContext.getCurrentTaskId(), workerContext.nextPutIndex());
put(objectId, obj);
return new RayObjectImpl<>(objectId);
}
public <T> void put(UniqueId objectId, T obj) {
UniqueId taskId = workerContext.getCurrentTask().taskId;
UniqueId taskId = workerContext.getCurrentTaskId();
LOGGER.debug("Putting object {}, for task {} ", objectId, taskId);
objectStoreProxy.put(objectId, obj, null);
}
@@ -92,8 +92,8 @@ public abstract class AbstractRayRuntime implements RayRuntime {
*/
public RayObject<Object> putSerialized(byte[] obj) {
UniqueId objectId = UniqueIdUtil.computePutId(
workerContext.getCurrentTask().taskId, workerContext.nextPutIndex());
UniqueId taskId = workerContext.getCurrentTask().taskId;
workerContext.getCurrentTaskId(), workerContext.nextPutIndex());
UniqueId taskId = workerContext.getCurrentTaskId();
LOGGER.debug("Putting serialized object {}, for task {} ", objectId, taskId);
objectStoreProxy.putSerialized(objectId, obj, null);
return new RayObjectImpl<>(objectId);
@@ -108,7 +108,6 @@ public abstract class AbstractRayRuntime implements RayRuntime {
@Override
public <T> List<T> get(List<UniqueId> objectIds) {
boolean wasBlocked = false;
UniqueId taskId = workerContext.getCurrentThreadTaskId();
try {
int numObjectIds = objectIds.size();
@@ -117,7 +116,7 @@ public abstract class AbstractRayRuntime implements RayRuntime {
List<List<UniqueId>> fetchBatches =
splitIntoBatches(objectIds, FETCH_BATCH_SIZE);
for (List<UniqueId> batch : fetchBatches) {
rayletClient.fetchOrReconstruct(batch, true, taskId);
rayletClient.fetchOrReconstruct(batch, true, workerContext.getCurrentTaskId());
}
// Get the objects. We initially try to get the objects immediately.
@@ -144,7 +143,7 @@ public abstract class AbstractRayRuntime implements RayRuntime {
splitIntoBatches(unreadyList, FETCH_BATCH_SIZE);
for (List<UniqueId> batch : reconstructBatches) {
rayletClient.fetchOrReconstruct(batch, false, taskId);
rayletClient.fetchOrReconstruct(batch, false, workerContext.getCurrentTaskId());
}
List<Pair<T, GetStatus>> results = objectStoreProxy
@@ -171,7 +170,8 @@ public abstract class AbstractRayRuntime implements RayRuntime {
}
if (LOGGER.isDebugEnabled()) {
LOGGER.debug("Got objects {} for task {}.", Arrays.toString(objectIds.toArray()), taskId);
LOGGER.debug("Got objects {} for task {}.", Arrays.toString(objectIds.toArray()),
workerContext.getCurrentTaskId());
}
List<T> finalRet = new ArrayList<>();
@@ -182,13 +182,13 @@ public abstract class AbstractRayRuntime implements RayRuntime {
return finalRet;
} catch (RayException e) {
LOGGER.error("Failed to get objects for task {}.", taskId, e);
LOGGER.error("Failed to get objects for task {}.", workerContext.getCurrentTaskId(), e);
throw e;
} finally {
// If there were objects that we weren't able to get locally, let the local
// scheduler know that we're now unblocked.
if (wasBlocked) {
rayletClient.notifyUnblocked(taskId);
rayletClient.notifyUnblocked(workerContext.getCurrentTaskId());
}
}
}
@@ -217,7 +217,7 @@ public abstract class AbstractRayRuntime implements RayRuntime {
@Override
public <T> WaitResult<T> wait(List<RayObject<T>> waitList, int numReturns, int timeoutMs) {
return rayletClient.wait(waitList, numReturns,
timeoutMs, workerContext.getCurrentThreadTaskId());
timeoutMs, workerContext.getCurrentTaskId());
}
@Override
@@ -277,9 +277,8 @@ public abstract class AbstractRayRuntime implements RayRuntime {
*/
private TaskSpec createTaskSpec(RayFunc func, RayActorImpl actor, Object[] args,
boolean isActorCreationTask, BaseTaskOptions taskOptions) {
final TaskSpec current = workerContext.getCurrentTask();
UniqueId taskId = rayletClient.generateTaskId(current.driverId,
current.taskId, workerContext.nextCallIndex());
UniqueId taskId = rayletClient.generateTaskId(workerContext.getCurrentDriverId(),
workerContext.getCurrentTaskId(), workerContext.nextTaskIndex());
int numReturns = actor.getId().isNil() ? 1 : 2;
UniqueId[] returnIds = genReturnIds(taskId, numReturns);
@@ -304,11 +303,11 @@ public abstract class AbstractRayRuntime implements RayRuntime {
if (taskOptions instanceof ActorCreationOptions) {
maxActorReconstruction = ((ActorCreationOptions) taskOptions).maxReconstructions;
}
RayFunction rayFunction = functionManager.getFunction(current.driverId, func);
RayFunction rayFunction = functionManager.getFunction(workerContext.getCurrentDriverId(), func);
return new TaskSpec(
current.driverId,
workerContext.getCurrentDriverId(),
taskId,
current.taskId,
workerContext.getCurrentTaskId(),
-1,
actorCreationId,
maxActorReconstruction,
@@ -39,7 +39,7 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
path += ":";
}
path += rayConfig.libraryPath.stream().collect(Collectors.joining(":"));
path += String.join(":", rayConfig.libraryPath);
// This is a hack to reset library path at runtime,
// see https://stackoverflow.com/questions/15409223/.
@@ -80,7 +80,7 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
rayConfig.rayletSocketName,
workerContext.getCurrentWorkerId(),
rayConfig.workerMode == WorkerMode.WORKER,
workerContext.getCurrentTask().taskId
workerContext.getCurrentDriverId()
);
// register
@@ -43,8 +43,7 @@ public class Worker {
RayFunction rayFunction = runtime.getFunctionManager()
.getFunction(spec.driverId, spec.functionDescriptor);
// Set context
runtime.getWorkerContext().setCurrentTask(spec);
runtime.getWorkerContext().setCurrentClassLoader(rayFunction.classLoader);
runtime.getWorkerContext().setCurrentTask(spec, rayFunction.classLoader);
Thread.currentThread().setContextClassLoader(rayFunction.classLoader);
// Get local actor object and arguments.
Object actor = spec.isActorTask() ? runtime.localActors.get(spec.actorId) : null;
@@ -67,6 +66,7 @@ public class Worker {
LOGGER.error("Error executing task " + spec, e);
runtime.put(returnId, new RayException("Error executing task " + spec, e));
} finally {
runtime.getWorkerContext().setCurrentTask(null, null);
Thread.currentThread().setContextClassLoader(oldLoader);
}
}
@@ -1,9 +1,6 @@
package org.ray.runtime;
import com.google.common.base.Preconditions;
import java.util.HashMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import org.ray.api.id.UniqueId;
import org.ray.runtime.config.WorkerMode;
import org.ray.runtime.task.TaskSpec;
@@ -11,123 +8,114 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public class WorkerContext {
private static final Logger LOGGER = LoggerFactory.getLogger(WorkerContext.class);
/**
* Worker id.
*/
private UniqueId workerId;
/**
* Current task.
*/
private TaskSpec currentTask;
private ThreadLocal<UniqueId> currentTaskId;
/**
* Current class loader.
* Number of objects that have been put from current task.
*/
private ThreadLocal<Integer> putIndex;
/**
* Number of tasks that have been submitted from current task.
*/
private ThreadLocal<Integer> taskIndex;
private UniqueId currentDriverId;
private ClassLoader currentClassLoader;
/**
* How many puts have been done by current task.
*/
private AtomicInteger currentTaskPutCount;
/**
* How many calls have been done by current task.
*/
private AtomicInteger currentTaskCallCount;
/**
* The ID of main thread which created the worker context.
*/
private long mainThreadId;
/**
* If the multi-threading warning message has been logged.
*/
private AtomicBoolean multiThreadingWarned;
public WorkerContext(WorkerMode workerMode, UniqueId driverId) {
workerId = workerMode == WorkerMode.DRIVER ? driverId : UniqueId.randomId();
currentTaskPutCount = new AtomicInteger(0);
currentTaskCallCount = new AtomicInteger(0);
currentClassLoader = null;
currentTask = createDummyTask(workerMode, driverId);
mainThreadId = Thread.currentThread().getId();
multiThreadingWarned = new AtomicBoolean(false);
taskIndex = ThreadLocal.withInitial(() -> 0);
putIndex = ThreadLocal.withInitial(() -> 0);
currentTaskId = ThreadLocal.withInitial(UniqueId::randomId);
if (workerMode == WorkerMode.DRIVER) {
workerId = driverId;
currentTaskId.set(UniqueId.randomId());
currentDriverId = driverId;
currentClassLoader = null;
} else {
workerId = UniqueId.randomId();
setCurrentTask(null, null);
}
}
/**
* Get the current thread's task ID.
* This returns the assigned task ID if called on the main thread, else a
* random task ID.
* @return For the main thread, this method returns the ID of this worker's current running task;
* for other threads, this method returns a random ID.
*/
public UniqueId getCurrentThreadTaskId() {
UniqueId taskId;
if (Thread.currentThread().getId() == mainThreadId) {
taskId = currentTask.taskId;
public UniqueId getCurrentTaskId() {
return currentTaskId.get();
}
/**
* Set the current task which is being executed by the current worker. Note, this method can only
* be called from the main thread.
*/
public void setCurrentTask(TaskSpec task, ClassLoader classLoader) {
Preconditions.checkState(
Thread.currentThread().getId() == mainThreadId,
"This method should only be called from the main thread."
);
if (task != null) {
currentTaskId.set(task.taskId);
currentDriverId = task.driverId;
} else {
taskId = UniqueId.randomId();
if (multiThreadingWarned.compareAndSet(false, true)) {
LOGGER.warn("Calling Ray.get or Ray.wait in a separate thread " +
"may lead to deadlock if the main thread blocks on this " +
"thread and there are not enough resources to execute " +
"more tasks");
}
currentTaskId.set(UniqueId.NIL);
currentDriverId = UniqueId.NIL;
}
Preconditions.checkState(!taskId.isNil());
return taskId;
}
public void setWorkerId(UniqueId workerId) {
this.workerId = workerId;
}
public TaskSpec getCurrentTask() {
return currentTask;
taskIndex.set(0);
putIndex.set(0);
currentClassLoader = classLoader;
}
/**
* Increment the put index and return the new value.
*/
public int nextPutIndex() {
return currentTaskPutCount.incrementAndGet();
putIndex.set(putIndex.get() + 1);
return putIndex.get();
}
public int nextCallIndex() {
return currentTaskCallCount.incrementAndGet();
/**
* Increment the task index and return the new value.
*/
public int nextTaskIndex() {
taskIndex.set(taskIndex.get() + 1);
return taskIndex.get();
}
/**
* @return The ID of the current worker.
*/
public UniqueId getCurrentWorkerId() {
return workerId;
}
/**
* @return If this worker is a driver, this method returns the driver ID; Otherwise, it returns
* the driver ID of the current running task.
*/
public UniqueId getCurrentDriverId() {
return currentDriverId;
}
/**
* @return The class loader which is associated with the current driver.
*/
public ClassLoader getCurrentClassLoader() {
return currentClassLoader;
}
public void setCurrentTask(TaskSpec currentTask) {
this.currentTask = currentTask;
currentTaskCallCount.set(0);
currentTaskPutCount.set(0);
}
public void setCurrentClassLoader(ClassLoader currentClassLoader) {
this.currentClassLoader = currentClassLoader;
}
private TaskSpec createDummyTask(WorkerMode workerMode, UniqueId driverId) {
return new TaskSpec(
driverId,
workerMode == WorkerMode.DRIVER ? UniqueId.randomId() : UniqueId.NIL,
UniqueId.NIL,
0,
UniqueId.NIL,
0,
UniqueId.NIL,
UniqueId.NIL,
0,
null,
null,
new HashMap<>(),
null);
}
}
@@ -95,7 +95,7 @@ public class MockObjectStore implements ObjectStoreLink {
}
private String logPrefix() {
return runtime.getWorkerContext().getCurrentTask().taskId + "-" + getUserTrace() + " -> ";
return runtime.getWorkerContext().getCurrentTaskId() + "-" + getUserTrace() + " -> ";
}
private String getUserTrace() {
@@ -79,6 +79,9 @@ public class RayletClientImpl implements RayletClient {
@Override
public void submitTask(TaskSpec spec) {
LOGGER.debug("Submitting task: {}", spec);
Preconditions.checkState(!spec.parentTaskId.isNil());
Preconditions.checkState(!spec.driverId.isNil());
ByteBuffer info = convertTaskSpecToFlatbuffer(spec);
byte[] cursorId = null;
if (!spec.getExecutionDependencies().isEmpty()) {