Allow multiple threads to call ray.get and ray.wait (#3244)

* Handle multiple threads calling ray.get

* Multithreaded ray.wait

* Pass in current task ID in java backend

* Add multithreaded actor to tests, add warning messages to worker for multithreaded ray.get

* Fix test

* Some cleanups

* Improve error message

* Add assertion

* Cleanup, throw error in HandleTaskUnblocked if task not actually blocked

* lint

* Fix python worker reset

* Fix references to reconstruct_objects

* Linting

* java lint

* Fix java

* Fix iterator
This commit is contained in:
Stephanie Wang
2018-11-07 22:39:28 -08:00
committed by GitHub
parent 0bab8ed95c
commit d950e92f63
23 changed files with 460 additions and 281 deletions
@@ -88,6 +88,8 @@ public abstract class AbstractRayRuntime implements RayRuntime {
@Override
public <T> List<T> get(List<UniqueId> objectIds) {
boolean wasBlocked = false;
// TODO(swang): If we are not on the main thread, then we should generate a
// random task ID to pass to the backend.
UniqueId taskId = workerContext.getCurrentTask().taskId;
try {
@@ -97,7 +99,7 @@ public abstract class AbstractRayRuntime implements RayRuntime {
List<List<UniqueId>> fetchBatches =
splitIntoBatches(objectIds, FETCH_BATCH_SIZE);
for (List<UniqueId> batch : fetchBatches) {
rayletClient.reconstructObjects(batch, true);
rayletClient.fetchOrReconstruct(batch, true, taskId);
}
// Get the objects. We initially try to get the objects immediately.
@@ -122,7 +124,7 @@ public abstract class AbstractRayRuntime implements RayRuntime {
splitIntoBatches(unreadyList, FETCH_BATCH_SIZE);
for (List<UniqueId> batch : reconstructBatches) {
rayletClient.reconstructObjects(batch, false);
rayletClient.fetchOrReconstruct(batch, false, taskId);
}
List<Pair<T, GetStatus>> results = objectStoreProxy
@@ -157,7 +159,7 @@ public abstract class AbstractRayRuntime implements RayRuntime {
// If there were objects that we weren't able to get locally, let the local
// scheduler know that we're now unblocked.
if (wasBlocked) {
rayletClient.notifyUnblocked();
rayletClient.notifyUnblocked(taskId);
}
}
}
@@ -185,7 +187,10 @@ public abstract class AbstractRayRuntime implements RayRuntime {
@Override
public <T> WaitResult<T> wait(List<RayObject<T>> waitList, int numReturns, int timeoutMs) {
return rayletClient.wait(waitList, numReturns, timeoutMs);
// TODO(swang): If we are not on the main thread, then we should generate a
// random task ID to pass to the backend.
return rayletClient.wait(waitList, numReturns, timeoutMs,
workerContext.getCurrentTask().taskId);
}
@Override
@@ -66,12 +66,13 @@ public class MockRayletClient implements RayletClient {
}
@Override
public void reconstructObjects(List<UniqueId> objectIds, boolean fetchOnly) {
public void fetchOrReconstruct(List<UniqueId> objectIds, boolean fetchOnly,
UniqueId currentTaskId) {
}
@Override
public void notifyUnblocked() {
public void notifyUnblocked(UniqueId currentTaskId) {
}
@@ -81,7 +82,8 @@ public class MockRayletClient implements RayletClient {
}
@Override
public <T> WaitResult<T> wait(List<RayObject<T>> waitFor, int numReturns, int timeoutMs) {
public <T> WaitResult<T> wait(List<RayObject<T>> waitFor, int numReturns, int
timeoutMs, UniqueId currentTaskId) {
return new WaitResult<T>(
waitFor,
ImmutableList.of()
@@ -15,13 +15,14 @@ public interface RayletClient {
TaskSpec getTask();
void reconstructObjects(List<UniqueId> objectIds, boolean fetchOnly);
void fetchOrReconstruct(List<UniqueId> objectIds, boolean fetchOnly, UniqueId currentTaskId);
void notifyUnblocked();
void notifyUnblocked(UniqueId currentTaskId);
UniqueId generateTaskId(UniqueId driverId, UniqueId parentTaskId, int taskIndex);
<T> WaitResult<T> wait(List<RayObject<T>> waitFor, int numReturns, int timeoutMs);
<T> WaitResult<T> wait(List<RayObject<T>> waitFor, int numReturns, int
timeoutMs, UniqueId currentTaskId);
void freePlasmaObjects(List<UniqueId> objectIds, boolean localOnly);
}
@@ -44,14 +44,15 @@ public class RayletClientImpl implements RayletClient {
}
@Override
public <T> WaitResult<T> wait(List<RayObject<T>> waitFor, int numReturns, int timeoutMs) {
public <T> WaitResult<T> wait(List<RayObject<T>> waitFor, int numReturns, int
timeoutMs, UniqueId currentTaskId) {
List<UniqueId> ids = new ArrayList<>();
for (RayObject<T> element : waitFor) {
ids.add(element.getId());
}
boolean[] ready = nativeWaitObject(client, UniqueIdUtil.getIdBytes(ids),
numReturns, timeoutMs, false);
numReturns, timeoutMs, false, currentTaskId.getBytes());
List<RayObject<T>> readyList = new ArrayList<>();
List<RayObject<T>> unreadyList = new ArrayList<>();
@@ -87,12 +88,14 @@ public class RayletClientImpl implements RayletClient {
}
@Override
public void reconstructObjects(List<UniqueId> objectIds, boolean fetchOnly) {
public void fetchOrReconstruct(List<UniqueId> objectIds, boolean fetchOnly,
UniqueId currentTaskId) {
if (RayLog.core.isInfoEnabled()) {
RayLog.core.info("Reconstructing objects for task {}, object IDs are {}",
RayLog.core.info("Blocked on objects for task {}, object IDs are {}",
UniqueIdUtil.computeTaskId(objectIds.get(0)), objectIds);
}
nativeReconstructObjects(client, UniqueIdUtil.getIdBytes(objectIds), fetchOnly);
nativeFetchOrReconstruct(client, UniqueIdUtil.getIdBytes(objectIds),
fetchOnly, currentTaskId.getBytes());
}
@Override
@@ -102,8 +105,8 @@ public class RayletClientImpl implements RayletClient {
}
@Override
public void notifyUnblocked() {
nativeNotifyUnblocked(client);
public void notifyUnblocked(UniqueId currentTaskId) {
nativeNotifyUnblocked(client, currentTaskId.getBytes());
}
@Override
@@ -271,15 +274,15 @@ public class RayletClientImpl implements RayletClient {
private static native void nativeDestroy(long client);
private static native void nativeReconstructObjects(long client, byte[][] objectIds,
boolean fetchOnly);
private static native void nativeFetchOrReconstruct(long client, byte[][] objectIds,
boolean fetchOnly, byte[] currentTaskId);
private static native void nativeNotifyUnblocked(long client);
private static native void nativeNotifyUnblocked(long client, byte[] currentTaskId);
private static native void nativePutObject(long client, byte[] taskId, byte[] objectId);
private static native boolean[] nativeWaitObject(long conn, byte[][] objectIds,
int numReturns, int timeout, boolean waitLocal);
int numReturns, int timeout, boolean waitLocal, byte[] currentTaskId);
private static native byte[] nativeGenerateTaskId(byte[] driverId, byte[] parentTaskId,
int taskIndex);