mirror of
https://github.com/wassname/ray.git
synced 2026-07-01 07:45:26 +08:00
Allow multiple threads to call ray.get and ray.wait (#3244)
* Handle multiple threads calling ray.get
* Multithreaded ray.wait
* Pass in current task ID in java backend
* Add multithreaded actor to tests, add warning messages to worker for multithreaded ray.get
* Fix test
* Some cleanups
* Improve error message
* Add assertion
* Cleanup, throw error in HandleTaskUnblocked if task not actually blocked
* lint
* Fix python worker reset
* Fix references to reconstruct_objects
* Linting
* java lint
* Fix java
* Fix iterator
This commit is contained in:
@@ -88,6 +88,8 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
@Override
|
||||
public <T> List<T> get(List<UniqueId> objectIds) {
|
||||
boolean wasBlocked = false;
|
||||
// TODO(swang): If we are not on the main thread, then we should generate a
|
||||
// random task ID to pass to the backend.
|
||||
UniqueId taskId = workerContext.getCurrentTask().taskId;
|
||||
|
||||
try {
|
||||
@@ -97,7 +99,7 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
List<List<UniqueId>> fetchBatches =
|
||||
splitIntoBatches(objectIds, FETCH_BATCH_SIZE);
|
||||
for (List<UniqueId> batch : fetchBatches) {
|
||||
rayletClient.reconstructObjects(batch, true);
|
||||
rayletClient.fetchOrReconstruct(batch, true, taskId);
|
||||
}
|
||||
|
||||
// Get the objects. We initially try to get the objects immediately.
|
||||
@@ -122,7 +124,7 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
splitIntoBatches(unreadyList, FETCH_BATCH_SIZE);
|
||||
|
||||
for (List<UniqueId> batch : reconstructBatches) {
|
||||
rayletClient.reconstructObjects(batch, false);
|
||||
rayletClient.fetchOrReconstruct(batch, false, taskId);
|
||||
}
|
||||
|
||||
List<Pair<T, GetStatus>> results = objectStoreProxy
|
||||
@@ -157,7 +159,7 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
// If there were objects that we weren't able to get locally, let the local
|
||||
// scheduler know that we're now unblocked.
|
||||
if (wasBlocked) {
|
||||
rayletClient.notifyUnblocked();
|
||||
rayletClient.notifyUnblocked(taskId);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -185,7 +187,10 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
|
||||
@Override
|
||||
public <T> WaitResult<T> wait(List<RayObject<T>> waitList, int numReturns, int timeoutMs) {
|
||||
return rayletClient.wait(waitList, numReturns, timeoutMs);
|
||||
// TODO(swang): If we are not on the main thread, then we should generate a
|
||||
// random task ID to pass to the backend.
|
||||
return rayletClient.wait(waitList, numReturns, timeoutMs,
|
||||
workerContext.getCurrentTask().taskId);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -66,12 +66,13 @@ public class MockRayletClient implements RayletClient {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void reconstructObjects(List<UniqueId> objectIds, boolean fetchOnly) {
|
||||
public void fetchOrReconstruct(List<UniqueId> objectIds, boolean fetchOnly,
|
||||
UniqueId currentTaskId) {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public void notifyUnblocked() {
|
||||
public void notifyUnblocked(UniqueId currentTaskId) {
|
||||
|
||||
}
|
||||
|
||||
@@ -81,7 +82,8 @@ public class MockRayletClient implements RayletClient {
|
||||
}
|
||||
|
||||
@Override
|
||||
public <T> WaitResult<T> wait(List<RayObject<T>> waitFor, int numReturns, int timeoutMs) {
|
||||
public <T> WaitResult<T> wait(List<RayObject<T>> waitFor, int numReturns, int
|
||||
timeoutMs, UniqueId currentTaskId) {
|
||||
return new WaitResult<T>(
|
||||
waitFor,
|
||||
ImmutableList.of()
|
||||
|
||||
@@ -15,13 +15,14 @@ public interface RayletClient {
|
||||
|
||||
TaskSpec getTask();
|
||||
|
||||
void reconstructObjects(List<UniqueId> objectIds, boolean fetchOnly);
|
||||
void fetchOrReconstruct(List<UniqueId> objectIds, boolean fetchOnly, UniqueId currentTaskId);
|
||||
|
||||
void notifyUnblocked();
|
||||
void notifyUnblocked(UniqueId currentTaskId);
|
||||
|
||||
UniqueId generateTaskId(UniqueId driverId, UniqueId parentTaskId, int taskIndex);
|
||||
|
||||
<T> WaitResult<T> wait(List<RayObject<T>> waitFor, int numReturns, int timeoutMs);
|
||||
<T> WaitResult<T> wait(List<RayObject<T>> waitFor, int numReturns, int
|
||||
timeoutMs, UniqueId currentTaskId);
|
||||
|
||||
void freePlasmaObjects(List<UniqueId> objectIds, boolean localOnly);
|
||||
}
|
||||
|
||||
@@ -44,14 +44,15 @@ public class RayletClientImpl implements RayletClient {
|
||||
}
|
||||
|
||||
@Override
|
||||
public <T> WaitResult<T> wait(List<RayObject<T>> waitFor, int numReturns, int timeoutMs) {
|
||||
public <T> WaitResult<T> wait(List<RayObject<T>> waitFor, int numReturns, int
|
||||
timeoutMs, UniqueId currentTaskId) {
|
||||
List<UniqueId> ids = new ArrayList<>();
|
||||
for (RayObject<T> element : waitFor) {
|
||||
ids.add(element.getId());
|
||||
}
|
||||
|
||||
boolean[] ready = nativeWaitObject(client, UniqueIdUtil.getIdBytes(ids),
|
||||
numReturns, timeoutMs, false);
|
||||
numReturns, timeoutMs, false, currentTaskId.getBytes());
|
||||
List<RayObject<T>> readyList = new ArrayList<>();
|
||||
List<RayObject<T>> unreadyList = new ArrayList<>();
|
||||
|
||||
@@ -87,12 +88,14 @@ public class RayletClientImpl implements RayletClient {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void reconstructObjects(List<UniqueId> objectIds, boolean fetchOnly) {
|
||||
public void fetchOrReconstruct(List<UniqueId> objectIds, boolean fetchOnly,
|
||||
UniqueId currentTaskId) {
|
||||
if (RayLog.core.isInfoEnabled()) {
|
||||
RayLog.core.info("Reconstructing objects for task {}, object IDs are {}",
|
||||
RayLog.core.info("Blocked on objects for task {}, object IDs are {}",
|
||||
UniqueIdUtil.computeTaskId(objectIds.get(0)), objectIds);
|
||||
}
|
||||
nativeReconstructObjects(client, UniqueIdUtil.getIdBytes(objectIds), fetchOnly);
|
||||
nativeFetchOrReconstruct(client, UniqueIdUtil.getIdBytes(objectIds),
|
||||
fetchOnly, currentTaskId.getBytes());
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -102,8 +105,8 @@ public class RayletClientImpl implements RayletClient {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void notifyUnblocked() {
|
||||
nativeNotifyUnblocked(client);
|
||||
public void notifyUnblocked(UniqueId currentTaskId) {
|
||||
nativeNotifyUnblocked(client, currentTaskId.getBytes());
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -271,15 +274,15 @@ public class RayletClientImpl implements RayletClient {
|
||||
|
||||
private static native void nativeDestroy(long client);
|
||||
|
||||
private static native void nativeReconstructObjects(long client, byte[][] objectIds,
|
||||
boolean fetchOnly);
|
||||
private static native void nativeFetchOrReconstruct(long client, byte[][] objectIds,
|
||||
boolean fetchOnly, byte[] currentTaskId);
|
||||
|
||||
private static native void nativeNotifyUnblocked(long client);
|
||||
private static native void nativeNotifyUnblocked(long client, byte[] currentTaskId);
|
||||
|
||||
private static native void nativePutObject(long client, byte[] taskId, byte[] objectId);
|
||||
|
||||
private static native boolean[] nativeWaitObject(long conn, byte[][] objectIds,
|
||||
int numReturns, int timeout, boolean waitLocal);
|
||||
int numReturns, int timeout, boolean waitLocal, byte[] currentTaskId);
|
||||
|
||||
private static native byte[] nativeGenerateTaskId(byte[] driverId, byte[] parentTaskId,
|
||||
int taskIndex);
|
||||
|
||||
Reference in New Issue
Block a user