mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 14:31:15 +08:00
Propagate backend error to worker (#4039)
This commit is contained in:
@@ -1,12 +1,15 @@
|
||||
package org.ray.runtime;
|
||||
|
||||
import static java.util.stream.Collectors.toList;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import java.util.stream.Collectors;
|
||||
import org.ray.api.RayActor;
|
||||
import org.ray.api.RayObject;
|
||||
import org.ray.api.WaitResult;
|
||||
@@ -21,7 +24,7 @@ import org.ray.runtime.config.RayConfig;
|
||||
import org.ray.runtime.functionmanager.FunctionManager;
|
||||
import org.ray.runtime.functionmanager.RayFunction;
|
||||
import org.ray.runtime.objectstore.ObjectStoreProxy;
|
||||
import org.ray.runtime.objectstore.ObjectStoreProxy.GetStatus;
|
||||
import org.ray.runtime.objectstore.ObjectStoreProxy.GetResult;
|
||||
import org.ray.runtime.raylet.RayletClient;
|
||||
import org.ray.runtime.task.ArgumentsBuilder;
|
||||
import org.ray.runtime.task.TaskSpec;
|
||||
@@ -37,9 +40,22 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
|
||||
private static final Logger LOGGER = LoggerFactory.getLogger(AbstractRayRuntime.class);
|
||||
|
||||
/**
|
||||
* Default timeout of a get.
|
||||
*/
|
||||
private static final int GET_TIMEOUT_MS = 1000;
|
||||
/**
|
||||
* Split objects in this batch size when fetching or reconstructing them.
|
||||
*/
|
||||
private static final int FETCH_BATCH_SIZE = 1000;
|
||||
private static final int LIMITED_RETRY_COUNTER = 10;
|
||||
/**
|
||||
* Print a warning every this number of attempts.
|
||||
*/
|
||||
private static final int WARN_PER_NUM_ATTEMPTS = 50;
|
||||
/**
|
||||
* Max number of ids to print in the warning message.
|
||||
*/
|
||||
private static final int MAX_IDS_TO_PRINT_IN_WARNING = 20;
|
||||
|
||||
protected RayConfig rayConfig;
|
||||
protected WorkerContext workerContext;
|
||||
@@ -75,7 +91,7 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
public <T> void put(UniqueId objectId, T obj) {
|
||||
UniqueId taskId = workerContext.getCurrentTaskId();
|
||||
LOGGER.debug("Putting object {}, for task {} ", objectId, taskId);
|
||||
objectStoreProxy.put(objectId, obj, null);
|
||||
objectStoreProxy.put(objectId, obj);
|
||||
}
|
||||
|
||||
|
||||
@@ -87,10 +103,10 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
*/
|
||||
public RayObject<Object> putSerialized(byte[] obj) {
|
||||
UniqueId objectId = UniqueIdUtil.computePutId(
|
||||
workerContext.getCurrentTaskId(), workerContext.nextPutIndex());
|
||||
workerContext.getCurrentTaskId(), workerContext.nextPutIndex());
|
||||
UniqueId taskId = workerContext.getCurrentTaskId();
|
||||
LOGGER.debug("Putting serialized object {}, for task {} ", objectId, taskId);
|
||||
objectStoreProxy.putSerialized(objectId, obj, null);
|
||||
objectStoreProxy.putSerialized(objectId, obj);
|
||||
return new RayObjectImpl<>(objectId);
|
||||
}
|
||||
|
||||
@@ -102,63 +118,68 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
|
||||
@Override
|
||||
public <T> List<T> get(List<UniqueId> objectIds) {
|
||||
List<T> ret = new ArrayList<>(Collections.nCopies(objectIds.size(), null));
|
||||
boolean wasBlocked = false;
|
||||
|
||||
try {
|
||||
int numObjectIds = objectIds.size();
|
||||
|
||||
// Do an initial fetch for remote objects.
|
||||
List<List<UniqueId>> fetchBatches = splitIntoBatches(objectIds);
|
||||
for (List<UniqueId> batch : fetchBatches) {
|
||||
rayletClient.fetchOrReconstruct(batch, true, workerContext.getCurrentTaskId());
|
||||
// A map that stores the unready object ids and their original indexes.
|
||||
Map<UniqueId, Integer> unready = new HashMap<>();
|
||||
for (int i = 0; i < objectIds.size(); i++) {
|
||||
unready.put(objectIds.get(i), i);
|
||||
}
|
||||
int numAttempts = 0;
|
||||
|
||||
// Get the objects. We initially try to get the objects immediately.
|
||||
List<Pair<T, GetStatus>> ret = objectStoreProxy
|
||||
.get(objectIds, GET_TIMEOUT_MS, false);
|
||||
assert ret.size() == numObjectIds;
|
||||
// Repeat until we get all objects.
|
||||
while (!unready.isEmpty()) {
|
||||
List<UniqueId> unreadyIds = new ArrayList<>(unready.keySet());
|
||||
|
||||
// Mapping the object IDs that we haven't gotten yet to their original index in objectIds.
|
||||
Map<UniqueId, Integer> unreadys = new HashMap<>();
|
||||
for (int i = 0; i < numObjectIds; i++) {
|
||||
if (ret.get(i).getRight() != GetStatus.SUCCESS) {
|
||||
unreadys.put(objectIds.get(i), i);
|
||||
// For the initial fetch, we only fetch the objects, do not reconstruct them.
|
||||
boolean fetchOnly = numAttempts == 0;
|
||||
if (!fetchOnly) {
|
||||
// If fetchOnly is false, this worker will be blocked.
|
||||
wasBlocked = true;
|
||||
}
|
||||
}
|
||||
wasBlocked = (unreadys.size() > 0);
|
||||
|
||||
// Try reconstructing any objects we haven't gotten yet. Try to get them
|
||||
// until at least PlasmaLink.GET_TIMEOUT_MS milliseconds passes, then repeat.
|
||||
int retryCounter = 0;
|
||||
while (unreadys.size() > 0) {
|
||||
retryCounter++;
|
||||
List<UniqueId> unreadyList = new ArrayList<>(unreadys.keySet());
|
||||
List<List<UniqueId>> reconstructBatches = splitIntoBatches(unreadyList);
|
||||
|
||||
for (List<UniqueId> batch : reconstructBatches) {
|
||||
rayletClient.fetchOrReconstruct(batch, false, workerContext.getCurrentTaskId());
|
||||
// Call `fetchOrReconstruct` in batches.
|
||||
for (List<UniqueId> batch : splitIntoBatches(unreadyIds)) {
|
||||
rayletClient.fetchOrReconstruct(batch, fetchOnly, workerContext.getCurrentTaskId());
|
||||
}
|
||||
|
||||
List<Pair<T, GetStatus>> results = objectStoreProxy
|
||||
.get(unreadyList, GET_TIMEOUT_MS, false);
|
||||
|
||||
// Remove any entries for objects we received during this iteration so we
|
||||
// don't retrieve the same object twice.
|
||||
for (int i = 0; i < results.size(); i++) {
|
||||
Pair<T, GetStatus> value = results.get(i);
|
||||
if (value.getRight() == GetStatus.SUCCESS) {
|
||||
UniqueId id = unreadyList.get(i);
|
||||
ret.set(unreadys.get(id), value);
|
||||
unreadys.remove(id);
|
||||
// Get the objects from the object store, and parse the result.
|
||||
List<GetResult<T>> getResults = objectStoreProxy.get(unreadyIds, GET_TIMEOUT_MS);
|
||||
for (int i = 0; i < getResults.size(); i++) {
|
||||
GetResult<T> getResult = getResults.get(i);
|
||||
if (getResult.exists) {
|
||||
if (getResult.exception != null) {
|
||||
// If the result is an exception, throw it.
|
||||
throw getResult.exception;
|
||||
} else {
|
||||
// Set the result to the return list, and remove it from the unready map.
|
||||
UniqueId id = unreadyIds.get(i);
|
||||
ret.set(unready.get(id), getResult.object);
|
||||
unready.remove(id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (retryCounter % LIMITED_RETRY_COUNTER == 0) {
|
||||
LOGGER.warn("Attempted {} times to reconstruct objects {}, "
|
||||
+ "but haven't received response. If this message continues to print,"
|
||||
+ " it may indicate that the task is hanging, or someting wrong "
|
||||
+ "happened in raylet backend.",
|
||||
retryCounter, unreadys.keySet());
|
||||
numAttempts += 1;
|
||||
if (LOGGER.isWarnEnabled() && numAttempts % WARN_PER_NUM_ATTEMPTS == 0) {
|
||||
// Print a warning if we've attempted too many times, but some objects are still
|
||||
// unavailable.
|
||||
List<UniqueId> idsToPrint = new ArrayList<>(unready.keySet());
|
||||
if (idsToPrint.size() > MAX_IDS_TO_PRINT_IN_WARNING) {
|
||||
idsToPrint = idsToPrint.subList(0, MAX_IDS_TO_PRINT_IN_WARNING);
|
||||
}
|
||||
String ids = idsToPrint.stream().map(UniqueId::toString)
|
||||
.collect(Collectors.joining(", "));
|
||||
if (idsToPrint.size() < unready.size()) {
|
||||
ids += ", etc";
|
||||
}
|
||||
String msg = String.format("Attempted %d times to reconstruct objects,"
|
||||
+ " but some objects are still unavailable. If this message continues to print,"
|
||||
+ " it may indicate that object's creating task is hanging, or something wrong"
|
||||
+ " happened in raylet backend. %d object(s) pending: %s.", numAttempts,
|
||||
unreadyIds.size(), ids);
|
||||
LOGGER.warn(msg);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -167,19 +188,10 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
workerContext.getCurrentTaskId());
|
||||
}
|
||||
|
||||
List<T> finalRet = new ArrayList<>();
|
||||
|
||||
for (Pair<T, GetStatus> value : ret) {
|
||||
finalRet.add(value.getLeft());
|
||||
}
|
||||
|
||||
return finalRet;
|
||||
} catch (RayException e) {
|
||||
LOGGER.error("Failed to get objects for task {}.", workerContext.getCurrentTaskId(), e);
|
||||
throw e;
|
||||
return ret;
|
||||
} finally {
|
||||
// If there were objects that we weren't able to get locally, let the local
|
||||
// scheduler know that we're now unblocked.
|
||||
// If there were objects that we weren't able to get locally, let the raylet backend
|
||||
// know that we're now unblocked.
|
||||
if (wasBlocked) {
|
||||
rayletClient.notifyUnblocked(workerContext.getCurrentTaskId());
|
||||
}
|
||||
@@ -252,6 +264,7 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
|
||||
/**
|
||||
* Create the task specification.
|
||||
*
|
||||
* @param func The target remote function.
|
||||
* @param actor The actor handle. If the task is not an actor task, actor id must be NIL.
|
||||
* @param args The arguments for the remote function.
|
||||
@@ -278,7 +291,7 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
}
|
||||
|
||||
if (!resources.containsKey(ResourceUtil.CPU_LITERAL)
|
||||
&& !resources.containsKey(ResourceUtil.CPU_LITERAL.toLowerCase())) {
|
||||
&& !resources.containsKey(ResourceUtil.CPU_LITERAL.toLowerCase())) {
|
||||
resources.put(ResourceUtil.CPU_LITERAL, 0.0);
|
||||
}
|
||||
|
||||
@@ -323,6 +336,10 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
return rayletClient;
|
||||
}
|
||||
|
||||
public ObjectStoreProxy getObjectStoreProxy() {
|
||||
return objectStoreProxy;
|
||||
}
|
||||
|
||||
public FunctionManager getFunctionManager() {
|
||||
return functionManager;
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ import java.util.List;
|
||||
import org.ray.api.Checkpointable;
|
||||
import org.ray.api.Checkpointable.Checkpoint;
|
||||
import org.ray.api.Checkpointable.CheckpointContext;
|
||||
import org.ray.api.exception.RayException;
|
||||
import org.ray.api.exception.RayTaskException;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.runtime.config.RunMode;
|
||||
import org.ray.runtime.functionmanager.RayFunction;
|
||||
@@ -118,7 +118,7 @@ public class Worker {
|
||||
} catch (Exception e) {
|
||||
LOGGER.error("Error executing task " + spec, e);
|
||||
if (!spec.isActorCreationTask()) {
|
||||
runtime.put(returnId, new RayException("Error executing task " + spec, e));
|
||||
runtime.put(returnId, new RayTaskException("Error executing task " + spec, e));
|
||||
} else {
|
||||
actorCreationException = e;
|
||||
currentActorId = returnId;
|
||||
|
||||
@@ -1,30 +1,42 @@
|
||||
package org.ray.runtime.objectstore;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import org.apache.arrow.plasma.ObjectStoreLink;
|
||||
import org.apache.arrow.plasma.ObjectStoreLink.ObjectStoreData;
|
||||
import org.apache.arrow.plasma.PlasmaClient;
|
||||
import org.apache.arrow.plasma.exceptions.DuplicateObjectException;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.ray.api.exception.RayActorException;
|
||||
import org.ray.api.exception.RayException;
|
||||
import org.ray.api.exception.RayTaskException;
|
||||
import org.ray.api.exception.RayWorkerException;
|
||||
import org.ray.api.exception.UnreconstructableException;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.runtime.AbstractRayRuntime;
|
||||
import org.ray.runtime.RayDevRuntime;
|
||||
import org.ray.runtime.config.RunMode;
|
||||
import org.ray.runtime.generated.ErrorType;
|
||||
import org.ray.runtime.util.Serializer;
|
||||
import org.ray.runtime.util.UniqueIdUtil;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
/**
|
||||
* Object store proxy, which handles serialization and deserialization, and utilize a {@code
|
||||
* org.ray.spi.ObjectStoreLink} to actually store data.
|
||||
* A class that is used to put/get objects to/from the object store.
|
||||
*/
|
||||
public class ObjectStoreProxy {
|
||||
|
||||
private static final Logger LOGGER = LoggerFactory.getLogger(ObjectStoreProxy.class);
|
||||
|
||||
private static final int GET_TIMEOUT_MS = 1000;
|
||||
private static final byte[] WORKER_EXCEPTION_META = String.valueOf(ErrorType.WORKER_DIED)
|
||||
.getBytes();
|
||||
private static final byte[] ACTOR_EXCEPTION_META = String.valueOf(ErrorType.ACTOR_DIED)
|
||||
.getBytes();
|
||||
private static final byte[] UNRECONSTRUCTABLE_EXCEPTION_META = String
|
||||
.valueOf(ErrorType.OBJECT_UNRECONSTRUCTABLE).getBytes();
|
||||
|
||||
private final AbstractRayRuntime runtime;
|
||||
|
||||
@@ -41,68 +53,134 @@ public class ObjectStoreProxy {
|
||||
});
|
||||
}
|
||||
|
||||
public <T> Pair<T, GetStatus> get(UniqueId objectId, boolean isMetadata)
|
||||
throws RayException {
|
||||
return get(objectId, GET_TIMEOUT_MS, isMetadata);
|
||||
/**
|
||||
* Get an object from the object store.
|
||||
*
|
||||
* @param id Id of the object.
|
||||
* @param timeoutMs Timeout in milliseconds.
|
||||
* @param <T> Type of the object.
|
||||
* @return The GetResult object.
|
||||
*/
|
||||
public <T> GetResult<T> get(UniqueId id, int timeoutMs) {
|
||||
List<GetResult<T>> list = get(ImmutableList.of(id), timeoutMs);
|
||||
return list.get(0);
|
||||
}
|
||||
|
||||
public <T> Pair<T, GetStatus> get(UniqueId id, int timeoutMs, boolean isMetadata)
|
||||
throws RayException {
|
||||
byte[] obj = objectStore.get().get(id.getBytes(), timeoutMs, isMetadata);
|
||||
if (obj != null) {
|
||||
T t = Serializer.decode(obj, runtime.getWorkerContext().getCurrentClassLoader());
|
||||
objectStore.get().release(id.getBytes());
|
||||
if (t instanceof RayException) {
|
||||
throw (RayException) t;
|
||||
}
|
||||
return Pair.of(t, GetStatus.SUCCESS);
|
||||
} else {
|
||||
return Pair.of(null, GetStatus.FAILED);
|
||||
}
|
||||
}
|
||||
/**
|
||||
* Get a list of objects from the object store.
|
||||
*
|
||||
* @param ids List of the object ids.
|
||||
* @param timeoutMs Timeout in milliseconds.
|
||||
* @param <T> Type of these objects.
|
||||
* @return A list of GetResult objects.
|
||||
*/
|
||||
public <T> List<GetResult<T>> get(List<UniqueId> ids, int timeoutMs) {
|
||||
byte[][] binaryIds = UniqueIdUtil.getIdBytes(ids);
|
||||
List<ObjectStoreData> dataAndMetaList = objectStore.get().get(binaryIds, timeoutMs);
|
||||
|
||||
public <T> List<Pair<T, GetStatus>> get(List<UniqueId> objectIds, boolean isMetadata)
|
||||
throws RayException {
|
||||
return get(objectIds, GET_TIMEOUT_MS, isMetadata);
|
||||
}
|
||||
List<GetResult<T>> results = new ArrayList<>();
|
||||
for (int i = 0; i < dataAndMetaList.size(); i++) {
|
||||
// TODO(hchen): Plasma API returns data and metadata in wrong order, this should be fixed
|
||||
// from the arrow side first.
|
||||
byte[] meta = dataAndMetaList.get(i).data;
|
||||
byte[] data = dataAndMetaList.get(i).metadata;
|
||||
|
||||
public <T> List<Pair<T, GetStatus>> get(List<UniqueId> ids, int timeoutMs, boolean isMetadata)
|
||||
throws RayException {
|
||||
List<byte[]> objs = objectStore.get().get(UniqueIdUtil.getIdBytes(ids), timeoutMs, isMetadata);
|
||||
List<Pair<T, GetStatus>> ret = new ArrayList<>();
|
||||
for (int i = 0; i < objs.size(); i++) {
|
||||
byte[] obj = objs.get(i);
|
||||
if (obj != null) {
|
||||
T t = Serializer.decode(obj, runtime.getWorkerContext().getCurrentClassLoader());
|
||||
objectStore.get().release(ids.get(i).getBytes());
|
||||
if (t instanceof RayException) {
|
||||
throw (RayException) t;
|
||||
GetResult<T> result;
|
||||
if (meta != null) {
|
||||
// If meta is not null, deserialize the exception.
|
||||
RayException exception = deserializeRayExceptionFromMeta(meta, ids.get(i));
|
||||
result = new GetResult<>(true, null, exception);
|
||||
} else if (data != null) {
|
||||
// If data is not null, deserialize the Java object.
|
||||
Object object = Serializer.decode(data, runtime.getWorkerContext().getCurrentClassLoader());
|
||||
if (object instanceof RayException) {
|
||||
// If the object is a `RayException`, it means that an error occurred during task
|
||||
// execution.
|
||||
result = new GetResult<>(true, null, (RayException) object);
|
||||
} else {
|
||||
// Otherwise, the object is valid.
|
||||
result = new GetResult<>(true, (T) object, null);
|
||||
}
|
||||
ret.add(Pair.of(t, GetStatus.SUCCESS));
|
||||
} else {
|
||||
ret.add(Pair.of(null, GetStatus.FAILED));
|
||||
// If both meta and data are null, the object doesn't exist in object store.
|
||||
result = new GetResult<>(false, null, null);
|
||||
}
|
||||
|
||||
if (meta != null || data != null) {
|
||||
// Release the object from object store..
|
||||
objectStore.get().release(binaryIds[i]);
|
||||
}
|
||||
|
||||
results.add(result);
|
||||
}
|
||||
return ret;
|
||||
return results;
|
||||
}
|
||||
|
||||
public void put(UniqueId id, Object obj, Object metadata) {
|
||||
private RayException deserializeRayExceptionFromMeta(byte[] meta, UniqueId objectId) {
|
||||
if (Arrays.equals(meta, WORKER_EXCEPTION_META)) {
|
||||
return RayWorkerException.INSTANCE;
|
||||
} else if (Arrays.equals(meta, ACTOR_EXCEPTION_META)) {
|
||||
return RayActorException.INSTANCE;
|
||||
} else if (Arrays.equals(meta, UNRECONSTRUCTABLE_EXCEPTION_META)) {
|
||||
return new UnreconstructableException(objectId);
|
||||
}
|
||||
throw new IllegalArgumentException("Unrecognized metadata " + Arrays.toString(meta));
|
||||
}
|
||||
|
||||
/**
|
||||
* Serialize and put an object to the object store.
|
||||
*
|
||||
* @param id Id of the object.
|
||||
* @param object The object to put.
|
||||
*/
|
||||
public void put(UniqueId id, Object object) {
|
||||
try {
|
||||
objectStore.get().put(id.getBytes(), Serializer.encode(obj), Serializer.encode(metadata));
|
||||
objectStore.get().put(id.getBytes(), Serializer.encode(object), null);
|
||||
} catch (DuplicateObjectException e) {
|
||||
LOGGER.warn(e.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
public void putSerialized(UniqueId id, byte[] obj, byte[] metadata) {
|
||||
/**
|
||||
* Put an already serialized object to the object store.
|
||||
*
|
||||
* @param id Id of the object.
|
||||
* @param serializedObject The serialized object to put.
|
||||
*/
|
||||
public void putSerialized(UniqueId id, byte[] serializedObject) {
|
||||
try {
|
||||
objectStore.get().put(id.getBytes(), obj, metadata);
|
||||
objectStore.get().put(id.getBytes(), serializedObject, null);
|
||||
} catch (DuplicateObjectException e) {
|
||||
LOGGER.warn(e.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
public enum GetStatus {
|
||||
SUCCESS, FAILED
|
||||
/**
|
||||
* A class that represents the result of a get operation.
|
||||
*/
|
||||
public static class GetResult<T> {
|
||||
|
||||
/**
|
||||
* Whether this object exists in object store.
|
||||
*/
|
||||
public final boolean exists;
|
||||
|
||||
/**
|
||||
* The Java object that was fetched and deserialized from the object store. Note, this field
|
||||
* only makes sense when @code{exists == true && exception !=null}.
|
||||
*/
|
||||
public final T object;
|
||||
|
||||
/**
|
||||
* If this field is not null, it represents the exception that occurred during object's creating
|
||||
* task.
|
||||
*/
|
||||
public final RayException exception;
|
||||
|
||||
GetResult(boolean exists, T object, RayException exception) {
|
||||
this.exists = exists;
|
||||
this.object = object;
|
||||
this.exception = exception;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user