Propagate backend error to worker (#4039)

This commit is contained in:
Hao Chen
2019-02-16 11:39:15 +08:00
committed by GitHub
parent 4be3d0c5d3
commit de17443dc2
21 changed files with 635 additions and 258 deletions
@@ -1,12 +1,15 @@
package org.ray.runtime;
import static java.util.stream.Collectors.toList;
import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.commons.lang3.tuple.Pair;
import java.util.stream.Collectors;
import org.ray.api.RayActor;
import org.ray.api.RayObject;
import org.ray.api.WaitResult;
@@ -21,7 +24,7 @@ import org.ray.runtime.config.RayConfig;
import org.ray.runtime.functionmanager.FunctionManager;
import org.ray.runtime.functionmanager.RayFunction;
import org.ray.runtime.objectstore.ObjectStoreProxy;
import org.ray.runtime.objectstore.ObjectStoreProxy.GetStatus;
import org.ray.runtime.objectstore.ObjectStoreProxy.GetResult;
import org.ray.runtime.raylet.RayletClient;
import org.ray.runtime.task.ArgumentsBuilder;
import org.ray.runtime.task.TaskSpec;
@@ -37,9 +40,22 @@ public abstract class AbstractRayRuntime implements RayRuntime {
private static final Logger LOGGER = LoggerFactory.getLogger(AbstractRayRuntime.class);
/**
* Default timeout of a get.
*/
private static final int GET_TIMEOUT_MS = 1000;
/**
* Objects are split into batches of this size when fetching or reconstructing them.
*/
private static final int FETCH_BATCH_SIZE = 1000;
private static final int LIMITED_RETRY_COUNTER = 10;
/**
* Print a warning once every this many attempts.
*/
private static final int WARN_PER_NUM_ATTEMPTS = 50;
/**
* Max number of ids to print in the warning message.
*/
private static final int MAX_IDS_TO_PRINT_IN_WARNING = 20;
protected RayConfig rayConfig;
protected WorkerContext workerContext;
@@ -75,7 +91,7 @@ public abstract class AbstractRayRuntime implements RayRuntime {
public <T> void put(UniqueId objectId, T obj) {
UniqueId taskId = workerContext.getCurrentTaskId();
LOGGER.debug("Putting object {}, for task {} ", objectId, taskId);
objectStoreProxy.put(objectId, obj, null);
objectStoreProxy.put(objectId, obj);
}
@@ -87,10 +103,10 @@ public abstract class AbstractRayRuntime implements RayRuntime {
*/
public RayObject<Object> putSerialized(byte[] obj) {
UniqueId objectId = UniqueIdUtil.computePutId(
workerContext.getCurrentTaskId(), workerContext.nextPutIndex());
workerContext.getCurrentTaskId(), workerContext.nextPutIndex());
UniqueId taskId = workerContext.getCurrentTaskId();
LOGGER.debug("Putting serialized object {}, for task {} ", objectId, taskId);
objectStoreProxy.putSerialized(objectId, obj, null);
objectStoreProxy.putSerialized(objectId, obj);
return new RayObjectImpl<>(objectId);
}
@@ -102,63 +118,68 @@ public abstract class AbstractRayRuntime implements RayRuntime {
@Override
public <T> List<T> get(List<UniqueId> objectIds) {
List<T> ret = new ArrayList<>(Collections.nCopies(objectIds.size(), null));
boolean wasBlocked = false;
try {
int numObjectIds = objectIds.size();
// Do an initial fetch for remote objects.
List<List<UniqueId>> fetchBatches = splitIntoBatches(objectIds);
for (List<UniqueId> batch : fetchBatches) {
rayletClient.fetchOrReconstruct(batch, true, workerContext.getCurrentTaskId());
// A map that stores the unready object ids and their original indexes.
Map<UniqueId, Integer> unready = new HashMap<>();
for (int i = 0; i < objectIds.size(); i++) {
unready.put(objectIds.get(i), i);
}
int numAttempts = 0;
// Get the objects. We initially try to get the objects immediately.
List<Pair<T, GetStatus>> ret = objectStoreProxy
.get(objectIds, GET_TIMEOUT_MS, false);
assert ret.size() == numObjectIds;
// Repeat until we get all objects.
while (!unready.isEmpty()) {
List<UniqueId> unreadyIds = new ArrayList<>(unready.keySet());
// Mapping the object IDs that we haven't gotten yet to their original index in objectIds.
Map<UniqueId, Integer> unreadys = new HashMap<>();
for (int i = 0; i < numObjectIds; i++) {
if (ret.get(i).getRight() != GetStatus.SUCCESS) {
unreadys.put(objectIds.get(i), i);
// For the initial fetch, we only fetch the objects, do not reconstruct them.
boolean fetchOnly = numAttempts == 0;
if (!fetchOnly) {
// If fetchOnly is false, this worker will be blocked.
wasBlocked = true;
}
}
wasBlocked = (unreadys.size() > 0);
// Try reconstructing any objects we haven't gotten yet. Try to get them
// until at least PlasmaLink.GET_TIMEOUT_MS milliseconds passes, then repeat.
int retryCounter = 0;
while (unreadys.size() > 0) {
retryCounter++;
List<UniqueId> unreadyList = new ArrayList<>(unreadys.keySet());
List<List<UniqueId>> reconstructBatches = splitIntoBatches(unreadyList);
for (List<UniqueId> batch : reconstructBatches) {
rayletClient.fetchOrReconstruct(batch, false, workerContext.getCurrentTaskId());
// Call `fetchOrReconstruct` in batches.
for (List<UniqueId> batch : splitIntoBatches(unreadyIds)) {
rayletClient.fetchOrReconstruct(batch, fetchOnly, workerContext.getCurrentTaskId());
}
List<Pair<T, GetStatus>> results = objectStoreProxy
.get(unreadyList, GET_TIMEOUT_MS, false);
// Remove any entries for objects we received during this iteration so we
// don't retrieve the same object twice.
for (int i = 0; i < results.size(); i++) {
Pair<T, GetStatus> value = results.get(i);
if (value.getRight() == GetStatus.SUCCESS) {
UniqueId id = unreadyList.get(i);
ret.set(unreadys.get(id), value);
unreadys.remove(id);
// Get the objects from the object store, and parse the result.
List<GetResult<T>> getResults = objectStoreProxy.get(unreadyIds, GET_TIMEOUT_MS);
for (int i = 0; i < getResults.size(); i++) {
GetResult<T> getResult = getResults.get(i);
if (getResult.exists) {
if (getResult.exception != null) {
// If the result is an exception, throw it.
throw getResult.exception;
} else {
// Set the result to the return list, and remove it from the unready map.
UniqueId id = unreadyIds.get(i);
ret.set(unready.get(id), getResult.object);
unready.remove(id);
}
}
}
if (retryCounter % LIMITED_RETRY_COUNTER == 0) {
LOGGER.warn("Attempted {} times to reconstruct objects {}, "
+ "but haven't received response. If this message continues to print,"
+ " it may indicate that the task is hanging, or someting wrong "
+ "happened in raylet backend.",
retryCounter, unreadys.keySet());
numAttempts += 1;
// Print a warning if we've attempted too many times, but some objects are still
// unavailable. Skip it when this attempt just resolved the last pending objects.
if (LOGGER.isWarnEnabled() && !unready.isEmpty()
    && numAttempts % WARN_PER_NUM_ATTEMPTS == 0) {
  // Cap the number of ids in the message so the log line stays readable.
  String ids = unready.keySet().stream()
      .limit(MAX_IDS_TO_PRINT_IN_WARNING)
      .map(UniqueId::toString)
      .collect(Collectors.joining(", "));
  if (unready.size() > MAX_IDS_TO_PRINT_IN_WARNING) {
    ids += ", etc";
  }
  // Report the count from `unready`, not `unreadyIds`: the latter is the snapshot taken
  // before this attempt and still includes objects that were just fetched.
  String msg = String.format("Attempted %d times to reconstruct objects,"
          + " but some objects are still unavailable. If this message continues to print,"
          + " it may indicate that object's creating task is hanging, or something wrong"
          + " happened in raylet backend. %d object(s) pending: %s.", numAttempts,
      unready.size(), ids);
  LOGGER.warn(msg);
}
}
@@ -167,19 +188,10 @@ public abstract class AbstractRayRuntime implements RayRuntime {
workerContext.getCurrentTaskId());
}
List<T> finalRet = new ArrayList<>();
for (Pair<T, GetStatus> value : ret) {
finalRet.add(value.getLeft());
}
return finalRet;
} catch (RayException e) {
LOGGER.error("Failed to get objects for task {}.", workerContext.getCurrentTaskId(), e);
throw e;
return ret;
} finally {
// If there were objects that we weren't able to get locally, let the local
// scheduler know that we're now unblocked.
// If there were objects that we weren't able to get locally, let the raylet backend
// know that we're now unblocked.
if (wasBlocked) {
rayletClient.notifyUnblocked(workerContext.getCurrentTaskId());
}
@@ -252,6 +264,7 @@ public abstract class AbstractRayRuntime implements RayRuntime {
/**
* Create the task specification.
*
* @param func The target remote function.
* @param actor The actor handle. If the task is not an actor task, actor id must be NIL.
* @param args The arguments for the remote function.
@@ -278,7 +291,7 @@ public abstract class AbstractRayRuntime implements RayRuntime {
}
if (!resources.containsKey(ResourceUtil.CPU_LITERAL)
&& !resources.containsKey(ResourceUtil.CPU_LITERAL.toLowerCase())) {
&& !resources.containsKey(ResourceUtil.CPU_LITERAL.toLowerCase())) {
resources.put(ResourceUtil.CPU_LITERAL, 0.0);
}
@@ -323,6 +336,10 @@ public abstract class AbstractRayRuntime implements RayRuntime {
return rayletClient;
}
/**
* Returns the object store proxy held by this runtime.
*/
public ObjectStoreProxy getObjectStoreProxy() {
return objectStoreProxy;
}
/**
* Returns the function manager held by this runtime.
*/
public FunctionManager getFunctionManager() {
return functionManager;
}
@@ -6,7 +6,7 @@ import java.util.List;
import org.ray.api.Checkpointable;
import org.ray.api.Checkpointable.Checkpoint;
import org.ray.api.Checkpointable.CheckpointContext;
import org.ray.api.exception.RayException;
import org.ray.api.exception.RayTaskException;
import org.ray.api.id.UniqueId;
import org.ray.runtime.config.RunMode;
import org.ray.runtime.functionmanager.RayFunction;
@@ -118,7 +118,7 @@ public class Worker {
} catch (Exception e) {
LOGGER.error("Error executing task " + spec, e);
if (!spec.isActorCreationTask()) {
runtime.put(returnId, new RayException("Error executing task " + spec, e));
runtime.put(returnId, new RayTaskException("Error executing task " + spec, e));
} else {
actorCreationException = e;
currentActorId = returnId;
@@ -1,30 +1,42 @@
package org.ray.runtime.objectstore;
import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.arrow.plasma.ObjectStoreLink;
import org.apache.arrow.plasma.ObjectStoreLink.ObjectStoreData;
import org.apache.arrow.plasma.PlasmaClient;
import org.apache.arrow.plasma.exceptions.DuplicateObjectException;
import org.apache.commons.lang3.tuple.Pair;
import org.ray.api.exception.RayActorException;
import org.ray.api.exception.RayException;
import org.ray.api.exception.RayTaskException;
import org.ray.api.exception.RayWorkerException;
import org.ray.api.exception.UnreconstructableException;
import org.ray.api.id.UniqueId;
import org.ray.runtime.AbstractRayRuntime;
import org.ray.runtime.RayDevRuntime;
import org.ray.runtime.config.RunMode;
import org.ray.runtime.generated.ErrorType;
import org.ray.runtime.util.Serializer;
import org.ray.runtime.util.UniqueIdUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* Object store proxy, which handles serialization and deserialization, and utilize a {@code
* org.ray.spi.ObjectStoreLink} to actually store data.
* A class that is used to put/get objects to/from the object store.
*/
public class ObjectStoreProxy {
private static final Logger LOGGER = LoggerFactory.getLogger(ObjectStoreProxy.class);
private static final int GET_TIMEOUT_MS = 1000;
// Metadata byte patterns that mark an object as an error. Each one is the string form of
// an ErrorType value. The charset is pinned explicitly so that the bytes do not depend on
// the JVM's platform default encoding.
private static final byte[] WORKER_EXCEPTION_META = String.valueOf(ErrorType.WORKER_DIED)
    .getBytes(java.nio.charset.StandardCharsets.UTF_8);
private static final byte[] ACTOR_EXCEPTION_META = String.valueOf(ErrorType.ACTOR_DIED)
    .getBytes(java.nio.charset.StandardCharsets.UTF_8);
private static final byte[] UNRECONSTRUCTABLE_EXCEPTION_META = String
    .valueOf(ErrorType.OBJECT_UNRECONSTRUCTABLE)
    .getBytes(java.nio.charset.StandardCharsets.UTF_8);
private final AbstractRayRuntime runtime;
@@ -41,68 +53,134 @@ public class ObjectStoreProxy {
});
}
public <T> Pair<T, GetStatus> get(UniqueId objectId, boolean isMetadata)
throws RayException {
return get(objectId, GET_TIMEOUT_MS, isMetadata);
/**
 * Fetch a single object from the object store.
 *
 * <p>Delegates to the list-based overload with a one-element id list and returns the first
 * entry of its result.
 *
 * @param id Id of the object.
 * @param timeoutMs Timeout in milliseconds.
 * @param <T> Type of the object.
 * @return The GetResult object.
 */
public <T> GetResult<T> get(UniqueId id, int timeoutMs) {
  // The explicit type witness keeps T bound to the caller's type across the chained call.
  return this.<T>get(ImmutableList.of(id), timeoutMs).get(0);
}
public <T> Pair<T, GetStatus> get(UniqueId id, int timeoutMs, boolean isMetadata)
throws RayException {
byte[] obj = objectStore.get().get(id.getBytes(), timeoutMs, isMetadata);
if (obj != null) {
T t = Serializer.decode(obj, runtime.getWorkerContext().getCurrentClassLoader());
objectStore.get().release(id.getBytes());
if (t instanceof RayException) {
throw (RayException) t;
}
return Pair.of(t, GetStatus.SUCCESS);
} else {
return Pair.of(null, GetStatus.FAILED);
}
}
/**
* Get a list of objects from the object store.
*
* @param ids List of the object ids.
* @param timeoutMs Timeout in milliseconds.
* @param <T> Type of these objects.
* @return A list of GetResult objects.
*/
public <T> List<GetResult<T>> get(List<UniqueId> ids, int timeoutMs) {
byte[][] binaryIds = UniqueIdUtil.getIdBytes(ids);
List<ObjectStoreData> dataAndMetaList = objectStore.get().get(binaryIds, timeoutMs);
public <T> List<Pair<T, GetStatus>> get(List<UniqueId> objectIds, boolean isMetadata)
throws RayException {
return get(objectIds, GET_TIMEOUT_MS, isMetadata);
}
List<GetResult<T>> results = new ArrayList<>();
for (int i = 0; i < dataAndMetaList.size(); i++) {
// TODO(hchen): Plasma API returns data and metadata in wrong order, this should be fixed
// from the arrow side first.
byte[] meta = dataAndMetaList.get(i).data;
byte[] data = dataAndMetaList.get(i).metadata;
public <T> List<Pair<T, GetStatus>> get(List<UniqueId> ids, int timeoutMs, boolean isMetadata)
throws RayException {
List<byte[]> objs = objectStore.get().get(UniqueIdUtil.getIdBytes(ids), timeoutMs, isMetadata);
List<Pair<T, GetStatus>> ret = new ArrayList<>();
for (int i = 0; i < objs.size(); i++) {
byte[] obj = objs.get(i);
if (obj != null) {
T t = Serializer.decode(obj, runtime.getWorkerContext().getCurrentClassLoader());
objectStore.get().release(ids.get(i).getBytes());
if (t instanceof RayException) {
throw (RayException) t;
GetResult<T> result;
if (meta != null) {
// If meta is not null, deserialize the exception.
RayException exception = deserializeRayExceptionFromMeta(meta, ids.get(i));
result = new GetResult<>(true, null, exception);
} else if (data != null) {
// If data is not null, deserialize the Java object.
Object object = Serializer.decode(data, runtime.getWorkerContext().getCurrentClassLoader());
if (object instanceof RayException) {
// If the object is a `RayException`, it means that an error occurred during task
// execution.
result = new GetResult<>(true, null, (RayException) object);
} else {
// Otherwise, the object is valid.
result = new GetResult<>(true, (T) object, null);
}
ret.add(Pair.of(t, GetStatus.SUCCESS));
} else {
ret.add(Pair.of(null, GetStatus.FAILED));
// If both meta and data are null, the object doesn't exist in object store.
result = new GetResult<>(false, null, null);
}
if (meta != null || data != null) {
// Release the object from the object store.
objectStore.get().release(binaryIds[i]);
}
results.add(result);
}
return ret;
return results;
}
public void put(UniqueId id, Object obj, Object metadata) {
/**
 * Map error metadata read from the object store to the matching exception.
 *
 * @param meta Metadata bytes of the object.
 * @param objectId Id of the object, used when building an UnreconstructableException.
 * @return The exception that corresponds to the metadata.
 * @throws IllegalArgumentException If the metadata matches none of the known error types.
 */
private RayException deserializeRayExceptionFromMeta(byte[] meta, UniqueId objectId) {
  if (Arrays.equals(meta, WORKER_EXCEPTION_META)) {
    return RayWorkerException.INSTANCE;
  }
  if (Arrays.equals(meta, ACTOR_EXCEPTION_META)) {
    return RayActorException.INSTANCE;
  }
  if (!Arrays.equals(meta, UNRECONSTRUCTABLE_EXCEPTION_META)) {
    // None of the known error markers matched.
    throw new IllegalArgumentException("Unrecognized metadata " + Arrays.toString(meta));
  }
  return new UnreconstructableException(objectId);
}
/**
* Serialize and put an object to the object store.
*
* @param id Id of the object.
* @param object The object to put.
*/
public void put(UniqueId id, Object object) {
try {
objectStore.get().put(id.getBytes(), Serializer.encode(obj), Serializer.encode(metadata));
objectStore.get().put(id.getBytes(), Serializer.encode(object), null);
} catch (DuplicateObjectException e) {
LOGGER.warn(e.getMessage());
}
}
public void putSerialized(UniqueId id, byte[] obj, byte[] metadata) {
/**
* Put an already serialized object to the object store.
*
* @param id Id of the object.
* @param serializedObject The serialized object to put.
*/
public void putSerialized(UniqueId id, byte[] serializedObject) {
try {
objectStore.get().put(id.getBytes(), obj, metadata);
objectStore.get().put(id.getBytes(), serializedObject, null);
} catch (DuplicateObjectException e) {
LOGGER.warn(e.getMessage());
}
}
public enum GetStatus {
SUCCESS, FAILED
/**
* Holds the outcome of a single get operation against the object store.
*/
public static class GetResult<T> {
/**
* Whether this object exists in object store.
*/
public final boolean exists;
/**
* The Java object that was fetched and deserialized from the object store. Note, this field
* only makes sense when {@code exists == true && exception == null}.
*/
public final T object;
/**
* If this field is not null, it is the exception that was obtained in place of the object,
* either decoded from the object's metadata or stored as the object's value.
*/
public final RayException exception;
// Plain field-by-field initialization; no validation is performed on the combination.
GetResult(boolean exists, T object, RayException exception) {
this.exists = exists;
this.object = object;
this.exception = exception;
}
}
}