From de17443dc26bc5ba8463b8df0741c5b2c741220e Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Sat, 16 Feb 2019 11:39:15 +0800 Subject: [PATCH] Propagate backend error to worker (#4039) --- doc/source/conf.py | 42 +++-- .../ray/api/exception/RayActorException.java | 16 ++ .../ray/api/exception/RayTaskException.java | 15 ++ .../ray/api/exception/RayWorkerException.java | 13 ++ .../exception/UnreconstructableException.java | 23 +++ .../org/ray/runtime/AbstractRayRuntime.java | 145 ++++++++------- .../src/main/java/org/ray/runtime/Worker.java | 4 +- .../runtime/objectstore/ObjectStoreProxy.java | 166 ++++++++++++----- .../ray/api/test/ActorReconstructionTest.java | 8 +- .../main/java/org/ray/api/test/ActorTest.java | 31 ++++ .../java/org/ray/api/test/FailureTest.java | 57 +++++- python/ray/exceptions.py | 105 +++++++++++ python/ray/worker.py | 173 +++++++++--------- src/ray/gcs/format/gcs.fbs | 20 ++ src/ray/raylet/node_manager.cc | 41 +++-- src/ray/raylet/node_manager.h | 3 +- test/actor_test.py | 10 +- test/component_failures_test.py | 4 +- test/failure_test.py | 7 +- test/runtest.py | 4 +- test/test_signal.py | 6 +- 21 files changed, 635 insertions(+), 258 deletions(-) create mode 100644 java/api/src/main/java/org/ray/api/exception/RayActorException.java create mode 100644 java/api/src/main/java/org/ray/api/exception/RayTaskException.java create mode 100644 java/api/src/main/java/org/ray/api/exception/RayWorkerException.java create mode 100644 java/api/src/main/java/org/ray/api/exception/UnreconstructableException.java create mode 100644 python/ray/exceptions.py diff --git a/doc/source/conf.py b/doc/source/conf.py index bf8db51e5..057f6ba35 100644 --- a/doc/source/conf.py +++ b/doc/source/conf.py @@ -19,22 +19,38 @@ import shlex # These lines added to enable Sphinx to work without installing Ray. import mock MOCK_MODULES = [ - "gym", "gym.spaces", "scipy", "scipy.signal", "tensorflow", - "tensorflow.contrib", "tensorflow.contrib.all_reduce", - "tensorflow.contrib.all_reduce.python", "tensorflow.contrib.layers", - "tensorflow.contrib.slim", "tensorflow.contrib.rnn", "tensorflow.core", - "tensorflow.core.util", "tensorflow.python", "tensorflow.python.client", - "tensorflow.python.util", "ray.core.generated", + "gym", + "gym.spaces", + "ray._raylet", + "ray.core.generated", "ray.core.generated.ActorCheckpointIdData", - "ray.core.generated.ClientTableData", "ray.core.generated.GcsTableEntry", - "ray.core.generated.HeartbeatTableData", + "ray.core.generated.ClientTableData", + "ray.core.generated.DriverTableData", + "ray.core.generated.ErrorTableData", + "ray.core.generated.ErrorType", + "ray.core.generated.GcsTableEntry", "ray.core.generated.HeartbeatBatchTableData", - "ray.core.generated.DriverTableData", "ray.core.generated.ErrorTableData", - "ray.core.generated.ProfileTableData", + "ray.core.generated.HeartbeatTableData", + "ray.core.generated.Language", "ray.core.generated.ObjectTableData", - "ray.core.generated.ray.protocol.Task", "ray.core.generated.TablePrefix", - "ray.core.generated.TablePubsub", "ray.core.generated.Language", - "ray._raylet" + "ray.core.generated.ProfileTableData", + "ray.core.generated.TablePrefix", + "ray.core.generated.TablePubsub", + "ray.core.generated.ray.protocol.Task", + "scipy", + "scipy.signal", + "tensorflow", + "tensorflow.contrib", + "tensorflow.contrib.all_reduce", + "tensorflow.contrib.all_reduce.python", + "tensorflow.contrib.layers", + "tensorflow.contrib.rnn", + "tensorflow.contrib.slim", + "tensorflow.core", + "tensorflow.core.util", + "tensorflow.python", + "tensorflow.python.client", + "tensorflow.python.util", ] for mod_name in MOCK_MODULES: sys.modules[mod_name] = mock.Mock() diff --git a/java/api/src/main/java/org/ray/api/exception/RayActorException.java b/java/api/src/main/java/org/ray/api/exception/RayActorException.java new file mode 100644 index 000000000..42ac9d408 --- /dev/null +++ b/java/api/src/main/java/org/ray/api/exception/RayActorException.java @@ -0,0 +1,16 @@ +package org.ray.api.exception; + +/** + * Indicates that the actor died unexpectedly before finishing a task. + * + * This exception could happen either because the actor process dies while executing a task, or + * because a task is submitted to a dead actor. + */ +public class RayActorException extends RayException { + + public static final RayActorException INSTANCE = new RayActorException(); + + private RayActorException() { + super("The actor died unexpectedly before finishing this task."); + } +} diff --git a/java/api/src/main/java/org/ray/api/exception/RayTaskException.java b/java/api/src/main/java/org/ray/api/exception/RayTaskException.java new file mode 100644 index 000000000..d2ba9ac3a --- /dev/null +++ b/java/api/src/main/java/org/ray/api/exception/RayTaskException.java @@ -0,0 +1,15 @@ +package org.ray.api.exception; + +/** + * Indicates that a task threw an exception during execution. + * + * If a task throws an exception during execution, a RayTaskException is stored in the object store + * as the task's output. Then when the object is retrieved from the object store, this exception + * will be thrown and propagate the error message. + */ +public class RayTaskException extends RayException { + + public RayTaskException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/java/api/src/main/java/org/ray/api/exception/RayWorkerException.java b/java/api/src/main/java/org/ray/api/exception/RayWorkerException.java new file mode 100644 index 000000000..512cca614 --- /dev/null +++ b/java/api/src/main/java/org/ray/api/exception/RayWorkerException.java @@ -0,0 +1,13 @@ +package org.ray.api.exception; + +/** + * Indicates that the worker died unexpectedly while executing a task. + */ +public class RayWorkerException extends RayException { + + public static final RayWorkerException INSTANCE = new RayWorkerException(); + + private RayWorkerException() { + super("The worker died unexpectedly while executing this task."); + } +} diff --git a/java/api/src/main/java/org/ray/api/exception/UnreconstructableException.java b/java/api/src/main/java/org/ray/api/exception/UnreconstructableException.java new file mode 100644 index 000000000..8362295ba --- /dev/null +++ b/java/api/src/main/java/org/ray/api/exception/UnreconstructableException.java @@ -0,0 +1,23 @@ +package org.ray.api.exception; + +import org.ray.api.id.UniqueId; + +/** + * Indicates that an object is lost (either evicted or explicitly deleted) and cannot be + * reconstructed. + * + * Note, this exception only happens for actor objects. If actor's current state is after object's + * creating task, the actor cannot re-run the task to reconstruct the object. + */ +public class UnreconstructableException extends RayException { + + public final UniqueId objectId; + + public UnreconstructableException(UniqueId objectId) { + super(String.format( + "Object %s is lost (either evicted or explicitly deleted) and cannot be reconstructed.", + objectId)); + this.objectId = objectId; + } + +} diff --git a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java index 2944da07e..2ad1028f8 100644 --- a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java +++ b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java @@ -1,12 +1,15 @@ package org.ray.runtime; +import static java.util.stream.Collectors.toList; + import com.google.common.collect.ImmutableList; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; -import org.apache.commons.lang3.tuple.Pair; +import java.util.stream.Collectors; import org.ray.api.RayActor; import org.ray.api.RayObject; import org.ray.api.WaitResult; @@ -21,7 +24,7 @@ import org.ray.runtime.config.RayConfig; import org.ray.runtime.functionmanager.FunctionManager; import org.ray.runtime.functionmanager.RayFunction; import org.ray.runtime.objectstore.ObjectStoreProxy; -import org.ray.runtime.objectstore.ObjectStoreProxy.GetStatus; +import org.ray.runtime.objectstore.ObjectStoreProxy.GetResult; import org.ray.runtime.raylet.RayletClient; import org.ray.runtime.task.ArgumentsBuilder; import org.ray.runtime.task.TaskSpec; @@ -37,9 +40,22 @@ public abstract class AbstractRayRuntime implements RayRuntime { private static final Logger LOGGER = LoggerFactory.getLogger(AbstractRayRuntime.class); + /** + * Default timeout of a get. + */ private static final int GET_TIMEOUT_MS = 1000; + /** + * Split objects in this batch size when fetching or reconstructing them. + */ private static final int FETCH_BATCH_SIZE = 1000; - private static final int LIMITED_RETRY_COUNTER = 10; + /** + * Print a warning every this number of attempts. + */ + private static final int WARN_PER_NUM_ATTEMPTS = 50; + /** + * Max number of ids to print in the warning message. + */ + private static final int MAX_IDS_TO_PRINT_IN_WARNING = 20; protected RayConfig rayConfig; protected WorkerContext workerContext; @@ -75,7 +91,7 @@ public abstract class AbstractRayRuntime implements RayRuntime { public void put(UniqueId objectId, T obj) { UniqueId taskId = workerContext.getCurrentTaskId(); LOGGER.debug("Putting object {}, for task {} ", objectId, taskId); - objectStoreProxy.put(objectId, obj, null); + objectStoreProxy.put(objectId, obj); } @@ -87,10 +103,10 @@ public abstract class AbstractRayRuntime implements RayRuntime { */ public RayObject putSerialized(byte[] obj) { UniqueId objectId = UniqueIdUtil.computePutId( - workerContext.getCurrentTaskId(), workerContext.nextPutIndex()); + workerContext.getCurrentTaskId(), workerContext.nextPutIndex()); UniqueId taskId = workerContext.getCurrentTaskId(); LOGGER.debug("Putting serialized object {}, for task {} ", objectId, taskId); - objectStoreProxy.putSerialized(objectId, obj, null); + objectStoreProxy.putSerialized(objectId, obj); return new RayObjectImpl<>(objectId); } @@ -102,63 +118,68 @@ public abstract class AbstractRayRuntime implements RayRuntime { @Override public List get(List objectIds) { + List ret = new ArrayList<>(Collections.nCopies(objectIds.size(), null)); boolean wasBlocked = false; try { - int numObjectIds = objectIds.size(); - - // Do an initial fetch for remote objects. - List> fetchBatches = splitIntoBatches(objectIds); - for (List batch : fetchBatches) { - rayletClient.fetchOrReconstruct(batch, true, workerContext.getCurrentTaskId()); + // A map that stores the unready object ids and their original indexes. + Map unready = new HashMap<>(); + for (int i = 0; i < objectIds.size(); i++) { + unready.put(objectIds.get(i), i); } + int numAttempts = 0; - // Get the objects. We initially try to get the objects immediately. - List> ret = objectStoreProxy - .get(objectIds, GET_TIMEOUT_MS, false); - assert ret.size() == numObjectIds; + // Repeat until we get all objects. + while (!unready.isEmpty()) { + List unreadyIds = new ArrayList<>(unready.keySet()); - // Mapping the object IDs that we haven't gotten yet to their original index in objectIds. - Map unreadys = new HashMap<>(); - for (int i = 0; i < numObjectIds; i++) { - if (ret.get(i).getRight() != GetStatus.SUCCESS) { - unreadys.put(objectIds.get(i), i); + // For the initial fetch, we only fetch the objects, do not reconstruct them. + boolean fetchOnly = numAttempts == 0; + if (!fetchOnly) { + // If fetchOnly is false, this worker will be blocked. + wasBlocked = true; } - } - wasBlocked = (unreadys.size() > 0); - - // Try reconstructing any objects we haven't gotten yet. Try to get them - // until at least PlasmaLink.GET_TIMEOUT_MS milliseconds passes, then repeat. - int retryCounter = 0; - while (unreadys.size() > 0) { - retryCounter++; - List unreadyList = new ArrayList<>(unreadys.keySet()); - List> reconstructBatches = splitIntoBatches(unreadyList); - - for (List batch : reconstructBatches) { - rayletClient.fetchOrReconstruct(batch, false, workerContext.getCurrentTaskId()); + // Call `fetchOrReconstruct` in batches. + for (List batch : splitIntoBatches(unreadyIds)) { + rayletClient.fetchOrReconstruct(batch, fetchOnly, workerContext.getCurrentTaskId()); } - List> results = objectStoreProxy - .get(unreadyList, GET_TIMEOUT_MS, false); - - // Remove any entries for objects we received during this iteration so we - // don't retrieve the same object twice. - for (int i = 0; i < results.size(); i++) { - Pair value = results.get(i); - if (value.getRight() == GetStatus.SUCCESS) { - UniqueId id = unreadyList.get(i); - ret.set(unreadys.get(id), value); - unreadys.remove(id); + // Get the objects from the object store, and parse the result. + List> getResults = objectStoreProxy.get(unreadyIds, GET_TIMEOUT_MS); + for (int i = 0; i < getResults.size(); i++) { + GetResult getResult = getResults.get(i); + if (getResult.exists) { + if (getResult.exception != null) { + // If the result is an exception, throw it. + throw getResult.exception; + } else { + // Set the result to the return list, and remove it from the unready map. + UniqueId id = unreadyIds.get(i); + ret.set(unready.get(id), getResult.object); + unready.remove(id); + } } } - if (retryCounter % LIMITED_RETRY_COUNTER == 0) { - LOGGER.warn("Attempted {} times to reconstruct objects {}, " - + "but haven't received response. If this message continues to print," - + " it may indicate that the task is hanging, or someting wrong " - + "happened in raylet backend.", - retryCounter, unreadys.keySet()); + numAttempts += 1; + if (LOGGER.isWarnEnabled() && numAttempts % WARN_PER_NUM_ATTEMPTS == 0) { + // Print a warning if we've attempted too many times, but some objects are still + // unavailable. + List idsToPrint = new ArrayList<>(unready.keySet()); + if (idsToPrint.size() > MAX_IDS_TO_PRINT_IN_WARNING) { + idsToPrint = idsToPrint.subList(0, MAX_IDS_TO_PRINT_IN_WARNING); + } + String ids = idsToPrint.stream().map(UniqueId::toString) + .collect(Collectors.joining(", ")); + if (idsToPrint.size() < unready.size()) { + ids += ", etc"; + } + String msg = String.format("Attempted %d times to reconstruct objects," + + " but some objects are still unavailable. If this message continues to print," + + " it may indicate that object's creating task is hanging, or something wrong" + + " happened in raylet backend. %d object(s) pending: %s.", numAttempts, + unreadyIds.size(), ids); + LOGGER.warn(msg); } } @@ -167,19 +188,10 @@ public abstract class AbstractRayRuntime implements RayRuntime { workerContext.getCurrentTaskId()); } - List finalRet = new ArrayList<>(); - - for (Pair value : ret) { - finalRet.add(value.getLeft()); - } - - return finalRet; - } catch (RayException e) { - LOGGER.error("Failed to get objects for task {}.", workerContext.getCurrentTaskId(), e); - throw e; + return ret; } finally { - // If there were objects that we weren't able to get locally, let the local - // scheduler know that we're now unblocked. + // If there were objects that we weren't able to get locally, let the raylet backend + // know that we're now unblocked. if (wasBlocked) { rayletClient.notifyUnblocked(workerContext.getCurrentTaskId()); } @@ -252,6 +264,7 @@ public abstract class AbstractRayRuntime implements RayRuntime { /** * Create the task specification. + * * @param func The target remote function. * @param actor The actor handle. If the task is not an actor task, actor id must be NIL. * @param args The arguments for the remote function. @@ -278,7 +291,7 @@ public abstract class AbstractRayRuntime implements RayRuntime { } if (!resources.containsKey(ResourceUtil.CPU_LITERAL) - && !resources.containsKey(ResourceUtil.CPU_LITERAL.toLowerCase())) { + && !resources.containsKey(ResourceUtil.CPU_LITERAL.toLowerCase())) { resources.put(ResourceUtil.CPU_LITERAL, 0.0); } @@ -323,6 +336,10 @@ public abstract class AbstractRayRuntime implements RayRuntime { return rayletClient; } + public ObjectStoreProxy getObjectStoreProxy() { + return objectStoreProxy; + } + public FunctionManager getFunctionManager() { return functionManager; } diff --git a/java/runtime/src/main/java/org/ray/runtime/Worker.java b/java/runtime/src/main/java/org/ray/runtime/Worker.java index e18e5f5e8..79ef90107 100644 --- a/java/runtime/src/main/java/org/ray/runtime/Worker.java +++ b/java/runtime/src/main/java/org/ray/runtime/Worker.java @@ -6,7 +6,7 @@ import java.util.List; import org.ray.api.Checkpointable; import org.ray.api.Checkpointable.Checkpoint; import org.ray.api.Checkpointable.CheckpointContext; -import org.ray.api.exception.RayException; +import org.ray.api.exception.RayTaskException; import org.ray.api.id.UniqueId; import org.ray.runtime.config.RunMode; import org.ray.runtime.functionmanager.RayFunction; @@ -118,7 +118,7 @@ public class Worker { } catch (Exception e) { LOGGER.error("Error executing task " + spec, e); if (!spec.isActorCreationTask()) { - runtime.put(returnId, new RayException("Error executing task " + spec, e)); + runtime.put(returnId, new RayTaskException("Error executing task " + spec, e)); } else { actorCreationException = e; currentActorId = returnId; diff --git a/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java b/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java index 31026930f..448c05364 100644 --- a/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java +++ b/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java @@ -1,30 +1,42 @@ package org.ray.runtime.objectstore; +import com.google.common.collect.ImmutableList; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import org.apache.arrow.plasma.ObjectStoreLink; +import org.apache.arrow.plasma.ObjectStoreLink.ObjectStoreData; import org.apache.arrow.plasma.PlasmaClient; import org.apache.arrow.plasma.exceptions.DuplicateObjectException; import org.apache.commons.lang3.tuple.Pair; +import org.ray.api.exception.RayActorException; import org.ray.api.exception.RayException; +import org.ray.api.exception.RayTaskException; +import org.ray.api.exception.RayWorkerException; +import org.ray.api.exception.UnreconstructableException; import org.ray.api.id.UniqueId; import org.ray.runtime.AbstractRayRuntime; import org.ray.runtime.RayDevRuntime; import org.ray.runtime.config.RunMode; +import org.ray.runtime.generated.ErrorType; import org.ray.runtime.util.Serializer; import org.ray.runtime.util.UniqueIdUtil; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** - * Object store proxy, which handles serialization and deserialization, and utilize a {@code - * org.ray.spi.ObjectStoreLink} to actually store data. + * A class that is used to put/get objects to/from the object store. */ public class ObjectStoreProxy { private static final Logger LOGGER = LoggerFactory.getLogger(ObjectStoreProxy.class); - private static final int GET_TIMEOUT_MS = 1000; + private static final byte[] WORKER_EXCEPTION_META = String.valueOf(ErrorType.WORKER_DIED) + .getBytes(); + private static final byte[] ACTOR_EXCEPTION_META = String.valueOf(ErrorType.ACTOR_DIED) + .getBytes(); + private static final byte[] UNRECONSTRUCTABLE_EXCEPTION_META = String + .valueOf(ErrorType.OBJECT_UNRECONSTRUCTABLE).getBytes(); private final AbstractRayRuntime runtime; @@ -41,68 +53,134 @@ public class ObjectStoreProxy { }); } - public Pair get(UniqueId objectId, boolean isMetadata) - throws RayException { - return get(objectId, GET_TIMEOUT_MS, isMetadata); + /** + * Get an object from the object store. + * + * @param id Id of the object. + * @param timeoutMs Timeout in milliseconds. + * @param Type of the object. + * @return The GetResult object. + */ + public GetResult get(UniqueId id, int timeoutMs) { + List> list = get(ImmutableList.of(id), timeoutMs); + return list.get(0); } - public Pair get(UniqueId id, int timeoutMs, boolean isMetadata) - throws RayException { - byte[] obj = objectStore.get().get(id.getBytes(), timeoutMs, isMetadata); - if (obj != null) { - T t = Serializer.decode(obj, runtime.getWorkerContext().getCurrentClassLoader()); - objectStore.get().release(id.getBytes()); - if (t instanceof RayException) { - throw (RayException) t; - } - return Pair.of(t, GetStatus.SUCCESS); - } else { - return Pair.of(null, GetStatus.FAILED); - } - } + /** + * Get a list of objects from the object store. + * + * @param ids List of the object ids. + * @param timeoutMs Timeout in milliseconds. + * @param Type of these objects. + * @return A list of GetResult objects. + */ + public List> get(List ids, int timeoutMs) { + byte[][] binaryIds = UniqueIdUtil.getIdBytes(ids); + List dataAndMetaList = objectStore.get().get(binaryIds, timeoutMs); - public List> get(List objectIds, boolean isMetadata) - throws RayException { - return get(objectIds, GET_TIMEOUT_MS, isMetadata); - } + List> results = new ArrayList<>(); + for (int i = 0; i < dataAndMetaList.size(); i++) { + // TODO(hchen): Plasma API returns data and metadata in wrong order, this should be fixed + // from the arrow side first. + byte[] meta = dataAndMetaList.get(i).data; + byte[] data = dataAndMetaList.get(i).metadata; - public List> get(List ids, int timeoutMs, boolean isMetadata) - throws RayException { - List objs = objectStore.get().get(UniqueIdUtil.getIdBytes(ids), timeoutMs, isMetadata); - List> ret = new ArrayList<>(); - for (int i = 0; i < objs.size(); i++) { - byte[] obj = objs.get(i); - if (obj != null) { - T t = Serializer.decode(obj, runtime.getWorkerContext().getCurrentClassLoader()); - objectStore.get().release(ids.get(i).getBytes()); - if (t instanceof RayException) { - throw (RayException) t; + GetResult result; + if (meta != null) { + // If meta is not null, deserialize the exception. + RayException exception = deserializeRayExceptionFromMeta(meta, ids.get(i)); + result = new GetResult<>(true, null, exception); + } else if (data != null) { + // If data is not null, deserialize the Java object. + Object object = Serializer.decode(data, runtime.getWorkerContext().getCurrentClassLoader()); + if (object instanceof RayException) { + // If the object is a `RayException`, it means that an error occurred during task + // execution. + result = new GetResult<>(true, null, (RayException) object); + } else { + // Otherwise, the object is valid. + result = new GetResult<>(true, (T) object, null); } - ret.add(Pair.of(t, GetStatus.SUCCESS)); } else { - ret.add(Pair.of(null, GetStatus.FAILED)); + // If both meta and data are null, the object doesn't exist in object store. + result = new GetResult<>(false, null, null); } + + if (meta != null || data != null) { + // Release the object from object store.. + objectStore.get().release(binaryIds[i]); + } + + results.add(result); } - return ret; + return results; } - public void put(UniqueId id, Object obj, Object metadata) { + private RayException deserializeRayExceptionFromMeta(byte[] meta, UniqueId objectId) { + if (Arrays.equals(meta, WORKER_EXCEPTION_META)) { + return RayWorkerException.INSTANCE; + } else if (Arrays.equals(meta, ACTOR_EXCEPTION_META)) { + return RayActorException.INSTANCE; + } else if (Arrays.equals(meta, UNRECONSTRUCTABLE_EXCEPTION_META)) { + return new UnreconstructableException(objectId); + } + throw new IllegalArgumentException("Unrecognized metadata " + Arrays.toString(meta)); + } + + /** + * Serialize and put an object to the object store. + * + * @param id Id of the object. + * @param object The object to put. + */ + public void put(UniqueId id, Object object) { try { - objectStore.get().put(id.getBytes(), Serializer.encode(obj), Serializer.encode(metadata)); + objectStore.get().put(id.getBytes(), Serializer.encode(object), null); } catch (DuplicateObjectException e) { LOGGER.warn(e.getMessage()); } } - public void putSerialized(UniqueId id, byte[] obj, byte[] metadata) { + /** + * Put an already serialized object to the object store. + * + * @param id Id of the object. + * @param serializedObject The serialized object to put. + */ + public void putSerialized(UniqueId id, byte[] serializedObject) { try { - objectStore.get().put(id.getBytes(), obj, metadata); + objectStore.get().put(id.getBytes(), serializedObject, null); } catch (DuplicateObjectException e) { LOGGER.warn(e.getMessage()); } } - public enum GetStatus { - SUCCESS, FAILED + /** + * A class that represents the result of a get operation. + */ + public static class GetResult { + + /** + * Whether this object exists in object store. + */ + public final boolean exists; + + /** + * The Java object that was fetched and deserialized from the object store. Note, this field + * only makes sense when @code{exists == true && exception !=null}. + */ + public final T object; + + /** + * If this field is not null, it represents the exception that occurred during object's creating + * task. + */ + public final RayException exception; + + GetResult(boolean exists, T object, RayException exception) { + this.exists = exists; + this.object = object; + this.exception = exception; + } } } diff --git a/java/test/src/main/java/org/ray/api/test/ActorReconstructionTest.java b/java/test/src/main/java/org/ray/api/test/ActorReconstructionTest.java index fd82cf4cf..c10a51516 100644 --- a/java/test/src/main/java/org/ray/api/test/ActorReconstructionTest.java +++ b/java/test/src/main/java/org/ray/api/test/ActorReconstructionTest.java @@ -10,6 +10,7 @@ import org.ray.api.Checkpointable; import org.ray.api.Ray; import org.ray.api.RayActor; import org.ray.api.annotation.RayRemote; +import org.ray.api.exception.RayActorException; import org.ray.api.id.UniqueId; import org.ray.api.options.ActorCreationOptions; import org.testng.Assert; @@ -60,11 +61,8 @@ public class ActorReconstructionTest extends BaseTest { try { Ray.call(Counter::increase, actor).get(); Assert.fail("The above task didn't fail."); - } catch (StringIndexOutOfBoundsException e) { - // Raylet backend will put invalid data in task's result to indicate the task has failed. - // Thus, Java deserialization will fail and throw `StringIndexOutOfBoundsException`. - // TODO(hchen): we should use object's metadata to indicate task failure, - // instead of throwing this exception. + } catch (RayActorException e) { + // We should receive a RayActorException because the actor is dead. } } diff --git a/java/test/src/main/java/org/ray/api/test/ActorTest.java b/java/test/src/main/java/org/ray/api/test/ActorTest.java index e0fc595e1..96be700b9 100644 --- a/java/test/src/main/java/org/ray/api/test/ActorTest.java +++ b/java/test/src/main/java/org/ray/api/test/ActorTest.java @@ -1,11 +1,16 @@ package org.ray.api.test; +import com.google.common.collect.ImmutableList; +import java.util.concurrent.TimeUnit; import org.ray.api.Ray; import org.ray.api.RayActor; import org.ray.api.RayObject; import org.ray.api.annotation.RayRemote; +import org.ray.api.exception.UnreconstructableException; import org.ray.api.id.UniqueId; +import org.ray.runtime.AbstractRayRuntime; import org.ray.runtime.RayActorImpl; +import org.ray.runtime.objectstore.ObjectStoreProxy.GetResult; import org.testng.Assert; import org.testng.annotations.Test; @@ -83,4 +88,30 @@ public class ActorTest extends BaseTest { Assert.assertEquals(Integer.valueOf(103), Ray.call(Counter::increase, counter2, 2).get()); } + @Test + public void testUnreconstructableActorObject() throws InterruptedException { + RayActor counter = Ray.createActor(Counter::new, 100); + // Call an actor method. + RayObject value = Ray.call(Counter::getValue, counter); + Assert.assertEquals(100, value.get()); + // Delete the object from the object store. + Ray.internal().free(ImmutableList.of(value.getId()), false); + // Wait until the object is deleted, because the above free operation is async. + while (true) { + GetResult result = ((AbstractRayRuntime) + Ray.internal()).getObjectStoreProxy().get(value.getId(), 0); + if (!result.exists) { + break; + } + TimeUnit.MILLISECONDS.sleep(100); + } + + try { + // Try getting the object again, this should throw an UnreconstructableException. + value.get(); + Assert.fail("This line should not be reachable."); + } catch (UnreconstructableException e) { + Assert.assertEquals(value.getId(), e.objectId); + } + } } diff --git a/java/test/src/main/java/org/ray/api/test/FailureTest.java b/java/test/src/main/java/org/ray/api/test/FailureTest.java index 3c1fa94d8..bf1710422 100644 --- a/java/test/src/main/java/org/ray/api/test/FailureTest.java +++ b/java/test/src/main/java/org/ray/api/test/FailureTest.java @@ -3,7 +3,9 @@ package org.ray.api.test; import org.ray.api.Ray; import org.ray.api.RayActor; import org.ray.api.RayObject; -import org.ray.api.exception.RayException; +import org.ray.api.exception.RayActorException; +import org.ray.api.exception.RayTaskException; +import org.ray.api.exception.RayWorkerException; import org.testng.Assert; import org.testng.annotations.Test; @@ -15,6 +17,11 @@ public class FailureTest extends BaseTest { throw new RuntimeException(EXCEPTION_MESSAGE); } + public static int badFunc2() { + System.exit(-1); + return 0; + } + public static class BadActor { public BadActor(boolean failOnCreation) { @@ -23,17 +30,21 @@ public class FailureTest extends BaseTest { } } - public int func() { + public int badMethod() { throw new RuntimeException(EXCEPTION_MESSAGE); } + + public int badMethod2() { + System.exit(-1); + return 0; + } } - private static void assertTaskFail(RayObject rayObject) { + private static void assertTaskFailedWithRayTaskException(RayObject rayObject) { try { rayObject.get(); Assert.fail("Task didn't fail."); - } catch (RayException e) { - e.printStackTrace(); + } catch (RayTaskException e) { Throwable rootCause = e.getCause(); while (rootCause.getCause() != null) { rootCause = rootCause.getCause(); @@ -45,19 +56,49 @@ public class FailureTest extends BaseTest { @Test public void testNormalTaskFailure() { - assertTaskFail(Ray.call(FailureTest::badFunc)); + assertTaskFailedWithRayTaskException(Ray.call(FailureTest::badFunc)); } @Test public void testActorCreationFailure() { RayActor actor = Ray.createActor(BadActor::new, true); - assertTaskFail(Ray.call(BadActor::func, actor)); + assertTaskFailedWithRayTaskException(Ray.call(BadActor::badMethod, actor)); } @Test public void testActorTaskFailure() { RayActor actor = Ray.createActor(BadActor::new, false); - assertTaskFail(Ray.call(BadActor::func, actor)); + assertTaskFailedWithRayTaskException(Ray.call(BadActor::badMethod, actor)); + } + + @Test + public void testWorkerProcessDying() { + try { + Ray.call(FailureTest::badFunc2).get(); + Assert.fail("This line shouldn't be reached."); + } catch (RayWorkerException e) { + // When the worker process dies while executing a task, we should receive an + // RayWorkerException. + } + } + + @Test + public void testActorProcessDying() { + RayActor actor = Ray.createActor(BadActor::new, false); + try { + Ray.call(BadActor::badMethod2, actor).get(); + Assert.fail("This line shouldn't be reached."); + } catch (RayActorException e) { + // When the actor process dies while executing a task, we should receive an + // RayActorException. + } + try { + Ray.call(BadActor::badMethod, actor).get(); + Assert.fail("This line shouldn't be reached."); + } catch (RayActorException e) { + // When a actor task is submitted to a dead actor, we should also receive an + // RayActorException. + } } } diff --git a/python/ray/exceptions.py b/python/ray/exceptions.py new file mode 100644 index 000000000..2cd5ed56d --- /dev/null +++ b/python/ray/exceptions.py @@ -0,0 +1,105 @@ +import os + +import colorama + +try: + import setproctitle +except ImportError: + setproctitle = None + + +class RayError(Exception): + """Super class of all ray exception types.""" + pass + + +class RayTaskError(RayError): + """Indicates that a task threw an exception during execution. + + If a task throws an exception during execution, a RayTaskError is stored in + the object store for each of the task's outputs. When an object is + retrieved from the object store, the Python method that retrieved it checks + to see if the object is a RayTaskError and if it is then an exception is + thrown propagating the error message. + + Attributes: + function_name (str): The name of the function that failed and produced + the RayTaskError. + traceback_str (str): The traceback from the exception. + """ + + def __init__(self, function_name, traceback_str): + """Initialize a RayTaskError.""" + if setproctitle: + self.proctitle = setproctitle.getproctitle() + else: + self.proctitle = "ray_worker" + self.pid = os.getpid() + self.host = os.uname()[1] + self.function_name = function_name + self.traceback_str = traceback_str + assert traceback_str is not None + + def __str__(self): + """Format a RayTaskError as a string.""" + lines = self.traceback_str.split("\n") + out = [] + in_worker = False + for line in lines: + if line.startswith("Traceback "): + out.append("{}{}{} (pid={}, host={})".format( + colorama.Fore.CYAN, self.proctitle, colorama.Fore.RESET, + self.pid, self.host)) + elif in_worker: + in_worker = False + elif "ray/worker.py" in line or "ray/function_manager.py" in line: + in_worker = True + else: + out.append(line) + return "\n".join(out) + + +class RayWorkerError(RayError): + """Indicates that the worker died unexpectedly while executing a task.""" + + def __str__(self): + return "The worker died unexpectedly while executing this task." + + +class RayActorError(RayError): + """Indicates that the actor died unexpectedly before finishing a task. + + This exception could happen either because the actor process dies while + executing a task, or because a task is submitted to a dead actor. + """ + + def __str__(self): + return "The actor died unexpectedly before finishing this task." + + +class UnreconstructableError(RayError): + """Indicates that an object is lost and cannot be reconstructed. + + Note, this exception only happens for actor objects. If actor's current + state is after object's creating task, the actor cannot re-run the task to + reconstruct the object. + + Attributes: + object_id: ID of the object. + """ + + def __init__(self, object_id): + self.object_id = object_id + + def __str__(self): + return ("Object {} is lost (either evicted or explicitly deleted) and " + + "cannot be reconstructed.").format(self.object_id.hex()) + + +RAY_EXCEPTION_TYPES = [ + RayError, + RayTaskError, + RayWorkerError, + RayActorError, + UnreconstructableError, +] diff --git a/python/ray/worker.py b/python/ray/worker.py index 01aebf44c..556a4f765 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -4,7 +4,6 @@ from __future__ import print_function from contextlib import contextmanager import atexit -import colorama import faulthandler import hashlib import inspect @@ -28,18 +27,43 @@ import ray.experimental.state as state import ray.gcs_utils import ray.memory_monitor as memory_monitor import ray.node +import ray.parameter +import ray.ray_constants as ray_constants import ray.remote_function import ray.serialization as serialization import ray.services as services import ray.signature -import ray.ray_constants as ray_constants + +from ray import ( + ActorHandleID, + ActorID, + ClientID, + DriverID, + ObjectID, + TaskID, +) from ray import import_thread -from ray import ObjectID, DriverID, ActorID, ActorHandleID, ClientID, TaskID from ray import profiling -from ray.function_manager import (FunctionActorManager, FunctionDescriptor) -import ray.parameter -from ray.utils import (check_oversized_pickle, is_cython, _random_string, - thread_safe_client, setup_logger) +from ray.core.generated.ErrorType import ErrorType +from ray.exceptions import ( + RayActorError, + RayError, + RayTaskError, + RayWorkerError, + UnreconstructableError, + RAY_EXCEPTION_TYPES, +) +from ray.function_manager import ( + FunctionActorManager, + FunctionDescriptor, +) +from ray.utils import ( + _random_string, + check_oversized_pickle, + is_cython, + setup_logger, + thread_safe_client, +) SCRIPT_MODE = 0 WORKER_MODE = 1 @@ -68,55 +92,6 @@ except ImportError: setproctitle = None -class RayTaskError(Exception): - """An object used internally to represent a task that threw an exception. - - If a task throws an exception during execution, a RayTaskError is stored in - the object store for each of the task's outputs. When an object is - retrieved from the object store, the Python method that retrieved it checks - to see if the object is a RayTaskError and if it is then an exception is - thrown propagating the error message. - - Currently, we either use the exception attribute or the traceback attribute - but not both. - - Attributes: - function_name (str): The name of the function that failed and produced - the RayTaskError. - traceback_str (str): The traceback from the exception. - """ - - def __init__(self, function_name, traceback_str): - """Initialize a RayTaskError.""" - if setproctitle: - self.proctitle = setproctitle.getproctitle() - else: - self.proctitle = "ray_worker" - self.pid = os.getpid() - self.host = os.uname()[1] - self.function_name = function_name - self.traceback_str = traceback_str - assert traceback_str is not None - - def __str__(self): - """Format a RayTaskError as a string.""" - lines = self.traceback_str.split("\n") - out = [] - in_worker = False - for line in lines: - if line.startswith("Traceback "): - out.append("{}{}{} (pid={}, host={})".format( - colorama.Fore.CYAN, self.proctitle, colorama.Fore.RESET, - self.pid, self.host)) - elif in_worker: - in_worker = False - elif "ray/worker.py" in line or "ray/function_manager.py" in line: - in_worker = True - else: - out.append(line) - return "\n".join(out) - - class ActorCheckpointInfo(object): """Information used to maintain actor checkpoints.""" @@ -400,6 +375,8 @@ class Worker(object): start_time = time.time() # Only send the warning once. warning_sent = False + serialization_context = self.get_serialization_context( + self.task_driver_id) while True: try: # We divide very large get requests into smaller get requests @@ -407,23 +384,23 @@ class Worker(object): # long time, if the store is blocked, it can block the manager # as well as a consequence. results = [] - for i in range(0, len(object_ids), - ray._config.worker_get_request_size()): - results += self.plasma_client.get( - object_ids[i:( - i + ray._config.worker_get_request_size())], + batch_size = ray._config.worker_fetch_request_size() + for i in range(0, len(object_ids), batch_size): + metadata_data_pairs = self.plasma_client.get_buffers( + object_ids[i:i + batch_size], timeout, - self.get_serialization_context(self.task_driver_id)) + with_meta=True, + ) + for j in range(len(metadata_data_pairs)): + metadata, data = metadata_data_pairs[j] + results.append( + self._deserialize_object_from_arrow( + data, + metadata, + object_ids[i + j], + serialization_context, + )) return results - except pyarrow.lib.ArrowInvalid: - # TODO(ekl): the local scheduler could include relevant - # metadata in the task kill case for a better error message - invalid_error = RayTaskError( - "", - "Invalid return value: likely worker died or was killed " - "while executing the task; check previous logs or dmesg " - "for errors.") - return [invalid_error] * len(object_ids) except pyarrow.DeserializationCallbackError: # Wait a little bit for the import thread to import the class. # If we currently have the worker lock, we need to release it @@ -448,6 +425,30 @@ class Worker(object): driver_id=self.task_driver_id) warning_sent = True + def _deserialize_object_from_arrow(self, data, metadata, object_id, + serialization_context): + if metadata: + # If metadata is not empty, return an exception object based on + # the error type. + error_type = int(metadata) + if error_type == ErrorType.WORKER_DIED: + return RayWorkerError() + elif error_type == ErrorType.ACTOR_DIED: + return RayActorError() + elif error_type == ErrorType.OBJECT_UNRECONSTRUCTABLE: + return UnreconstructableError(ray.ObjectID(object_id.binary())) + else: + assert False, "Unrecognized error type " + str(error_type) + elif data: + # If data is not empty, deserialize the object. + # Note, the lock is needed because `serialization_context` isn't + # thread-safe. + with self.plasma_client.lock: + return pyarrow.deserialize(data, serialization_context) + else: + # Object isn't available in plasma. + return plasma.ObjectNotAvailable + def get_object(self, object_ids): """Get the value or values in the object store associated with the IDs. @@ -741,7 +742,7 @@ class Worker(object): passed by value. Raises: - RayTaskError: This exception is raised if a task that + RayError: This exception is raised if a task that created one of the arguments failed. """ arguments = [] @@ -749,7 +750,7 @@ class Worker(object): if isinstance(arg, ObjectID): # get the object from the local object store argument = self.get_object([arg])[0] - if isinstance(argument, RayTaskError): + if isinstance(argument, RayError): raise argument else: # pass the argument by value @@ -831,11 +832,6 @@ class Worker(object): with profiling.profile("task:deserialize_arguments"): arguments = self._get_arguments_for_execution( function_name, args) - except RayTaskError as e: - self._handle_process_task_failure( - function_descriptor, return_object_ids, e, - ray.utils.format_error_message(traceback.format_exc())) - return except Exception as e: self._handle_process_task_failure( function_descriptor, return_object_ids, e, @@ -1155,12 +1151,15 @@ def _initialize_serialization(driver_id, worker=global_worker): worker.serialization_context_map[driver_id] = serialization_context - register_custom_serializer( - RayTaskError, - use_dict=True, - local=True, - driver_id=driver_id, - class_id="ray.RayTaskError") + # Register exception types. + for error_cls in RAY_EXCEPTION_TYPES: + register_custom_serializer( + error_cls, + use_dict=True, + local=True, + driver_id=driver_id, + class_id=error_cls.__module__ + ". " + error_cls.__name__, + ) # Tell Ray to serialize lambdas with pickle. register_custom_serializer( type(lambda: 0), @@ -2229,14 +2228,14 @@ def get(object_ids): if isinstance(object_ids, list): values = worker.get_object(object_ids) for i, value in enumerate(values): - if isinstance(value, RayTaskError): + if isinstance(value, RayError): last_task_error_raise_time = time.time() raise value return values else: value = worker.get_object([object_ids])[0] - if isinstance(value, RayTaskError): - # If the result is a RayTaskError, then the task that created + if isinstance(value, RayError): + # If the result is a RayError, then the task that created # this object failed, and we should propagate the error message # here. last_task_error_raise_time = time.time() diff --git a/src/ray/gcs/format/gcs.fbs b/src/ray/gcs/format/gcs.fbs index e643279f0..fcb70ab37 100644 --- a/src/ray/gcs/format/gcs.fbs +++ b/src/ray/gcs/format/gcs.fbs @@ -347,3 +347,23 @@ table ActorCheckpointIdData { // A list of the timestamps for each of the above `checkpoint_ids`. timestamps: [long]; } + +// This enum type is used as object's metadata to indicate the object's creating +// task has failed because of a certain error. +// TODO(hchen): We may want to make these errors more specific. E.g., we may want +// to distinguish between intentional and expected actor failures, and between +// worker process failure and node failure. +enum ErrorType:int { + // Indicates that a task failed because the worker died unexpectedly while executing it. + WORKER_DIED = 1, + // Indicates that a task failed because the actor died unexpectedly before finishing it. + ACTOR_DIED = 2, + // Indicates that an object is lost and cannot be reconstructed. + // Note, this currently only happens to actor objects. When the actor's state is already + // after the object's creating task, the actor cannot re-run the task. + // TODO(hchen): we may want to reuse this error type for more cases. E.g., + // 1) A object that was put by the driver. + // 2) The object's creating task is already cleaned up from GCS (this currently + // crashes raylet). + OBJECT_UNRECONSTRUCTABLE = 3, +} diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index deb099bb0..39f980d66 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -574,7 +574,7 @@ void NodeManager::HandleActorStateTransition(const ActorID &actor_id, auto tasks_to_remove = local_queues_.GetTaskIdsForActor(actor_id); auto removed_tasks = local_queues_.RemoveTasks(tasks_to_remove); for (auto const &task : removed_tasks) { - TreatTaskAsFailed(task); + TreatTaskAsFailed(task, ErrorType::ACTOR_DIED); } } else { RAY_CHECK(actor_registration.GetState() == ActorState::RECONSTRUCTING); @@ -858,7 +858,7 @@ void NodeManager::ProcessDisconnectClientMessage( // `HandleDisconnectedActor`. if (actor_id.is_nil()) { const Task &task = local_queues_.RemoveTask(task_id); - TreatTaskAsFailed(task); + TreatTaskAsFailed(task, ErrorType::WORKER_DIED); } if (!intentional_disconnect) { @@ -1214,9 +1214,10 @@ bool NodeManager::CheckDependencyManagerInvariant() const { return true; } -void NodeManager::TreatTaskAsFailed(const Task &task) { +void NodeManager::TreatTaskAsFailed(const Task &task, const ErrorType &error_type) { const TaskSpecification &spec = task.GetTaskSpecification(); - RAY_LOG(DEBUG) << "Treating task " << spec.TaskId() << " as failed."; + RAY_LOG(DEBUG) << "Treating task " << spec.TaskId() << " as failed because of error " + << EnumNameErrorType(error_type) << "."; // If this was an actor creation task that tried to resume from a checkpoint, // then erase it here since the task did not finish. if (spec.IsActorCreationTask()) { @@ -1231,20 +1232,22 @@ void NodeManager::TreatTaskAsFailed(const Task &task) { // information about the TaskSpecification implementation. num_returns -= 1; } + const std::string meta = std::to_string(static_cast(error_type)); for (int64_t i = 0; i < num_returns; i++) { - const ObjectID object_id = spec.ReturnId(i); - - std::shared_ptr data; - // TODO(ekl): this writes an invalid arrow object, which is sufficient to - // signal that the worker failed, but it would be nice to return more - // detailed failure metadata in the future. - arrow::Status status = - store_client_.Create(object_id.to_plasma_id(), 1, NULL, 0, &data); - if (!status.IsPlasmaObjectExists()) { - // TODO(rkn): We probably don't want this checks. E.g., if the object - // store is full, we don't want to kill the raylet. - RAY_ARROW_CHECK_OK(status); - RAY_ARROW_CHECK_OK(store_client_.Seal(object_id.to_plasma_id())); + const auto object_id = spec.ReturnId(i).to_plasma_id(); + arrow::Status status = store_client_.CreateAndSeal(object_id, "", meta); + if (!status.ok() && !status.IsPlasmaObjectExists()) { + // If we failed to save the error code, log a warning and push an error message + // to the driver. + std::ostringstream stream; + stream << "An plasma error (" << status.ToString() << ") occurred while saving" + << " error code to object " << object_id << ". Anyone who's getting this" + << " object may hang forever."; + std::string error_message = stream.str(); + RAY_LOG(WARNING) << error_message; + RAY_CHECK_OK(gcs_client_->error_table().PushErrorToDriver( + task.GetTaskSpecification().DriverId(), "task", error_message, + current_time_ms())); } } // A task failing is equivalent to assigning and finishing the task, so clean @@ -1297,7 +1300,7 @@ void NodeManager::TreatTaskAsFailedIfLost(const Task &task) { // The object does not exist on any nodes but has been created // before, so the object has been lost. Mark the task as failed to // prevent any tasks that depend on this object from hanging. - TreatTaskAsFailed(task); + TreatTaskAsFailed(task, ErrorType::OBJECT_UNRECONSTRUCTABLE); *task_marked_as_failed = true; } } @@ -1343,7 +1346,7 @@ void NodeManager::SubmitTask(const Task &task, const Lineage &uncommitted_lineag if (actor_entry->second.GetState() == ActorState::DEAD) { // If this actor is dead, either because the actor process is dead // or because its residing node is dead, treat this task as failed. - TreatTaskAsFailed(task); + TreatTaskAsFailed(task, ErrorType::ACTOR_DIED); } else { // If this actor is alive, check whether this actor is local. auto node_manager_id = actor_entry->second.GetNodeManagerId(); diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h index dd9ac71bd..47bc86e53 100644 --- a/src/ray/raylet/node_manager.h +++ b/src/ray/raylet/node_manager.h @@ -157,8 +157,9 @@ class NodeManager { /// the local queue. /// /// \param task The task to fail. + /// \param error_type The type of the error that caused this task to fail. /// \return Void. - void TreatTaskAsFailed(const Task &task); + void TreatTaskAsFailed(const Task &task, const ErrorType &error_type); /// This is similar to TreatTaskAsFailed, but it will only mark the task as /// failed if at least one of the task's return values is lost. A return /// value is lost if it has been created before, but no longer exists on any diff --git a/test/actor_test.py b/test/actor_test.py index 4a3a0ad58..16642540c 100644 --- a/test/actor_test.py +++ b/test/actor_test.py @@ -1405,7 +1405,7 @@ def test_exception_raised_when_actor_node_dies(head_node_cluster): # Submit some new actor tasks. x_ids = [actor.inc.remote() for _ in range(5)] for x_id in x_ids: - with pytest.raises(ray.worker.RayTaskError): + with pytest.raises(ray.exceptions.RayActorError): # There is some small chance that ray.get will actually # succeed (if the object is transferred before the raylet # dies). @@ -2128,7 +2128,7 @@ def test_actor_eviction(shutdown_only): try: ray.get(obj) num_success += 1 - except ray.worker.RayTaskError: + except ray.exceptions.UnreconstructableError: num_evicted += 1 # Some objects should have been evicted, and some should still be in the # object store. @@ -2173,7 +2173,7 @@ def test_actor_reconstruction(ray_start_regular): pid = ray.get(actor.get_pid.remote()) os.kill(pid, signal.SIGKILL) # The actor has exceeded max reconstructions, and this task should fail. - with pytest.raises(ray.worker.RayTaskError): + with pytest.raises(ray.exceptions.RayActorError): ray.get(actor.increase.remote()) # Create another actor. @@ -2181,7 +2181,7 @@ def test_actor_reconstruction(ray_start_regular): # Intentionlly exit the actor actor.__ray_terminate__.remote() # Check that the actor won't be reconstructed. - with pytest.raises(ray.worker.RayTaskError): + with pytest.raises(ray.exceptions.RayActorError): ray.get(actor.increase.remote()) @@ -2241,7 +2241,7 @@ def test_actor_reconstruction_on_node_failure(head_node_cluster): object_store_socket = ray.get(actor.get_object_store_socket.remote()) kill_node(object_store_socket) # The actor has exceeded max reconstructions, and this task should fail. - with pytest.raises(ray.worker.RayTaskError): + with pytest.raises(ray.exceptions.RayActorError): ray.get(actor.increase.remote()) diff --git a/test/component_failures_test.py b/test/component_failures_test.py index 07d775885..a4b84bb7b 100644 --- a/test/component_failures_test.py +++ b/test/component_failures_test.py @@ -279,7 +279,7 @@ def test_worker_failed(ray_start_workers_separate_multinode): for object_id in object_ids: try: ray.get(object_id) - except ray.worker.RayTaskError: + except (ray.exceptions.RayTaskError, ray.exceptions.RayWorkerError): pass @@ -424,7 +424,7 @@ def test_actor_creation_node_failure(ray_start_cluster): for i, out in enumerate(children_out): try: ray.get(out) - except ray.worker.RayTaskError: + except ray.exceptions.RayActorError: children[i] = Child.remote(death_probability) # Remove a node. Any actor creation tasks that were forwarded to this # node must be reconstructed. diff --git a/test/failure_test.py b/test/failure_test.py index 203e6d96d..543cb7140 100644 --- a/test/failure_test.py +++ b/test/failure_test.py @@ -319,7 +319,8 @@ def test_worker_dying(ray_start_regular): def f(): eval("exit()") - f.remote() + with pytest.raises(ray.exceptions.RayWorkerError): + ray.get(f.remote()) wait_for_errors(ray_constants.WORKER_DIED_PUSH_ERROR, 1) @@ -340,9 +341,9 @@ def test_actor_worker_dying(ray_start_regular): a = Actor.remote() [obj], _ = ray.wait([a.kill.remote()], timeout=5.0) - with pytest.raises(Exception): + with pytest.raises(ray.exceptions.RayActorError): ray.get(obj) - with pytest.raises(Exception): + with pytest.raises(ray.exceptions.RayTaskError): ray.get(consume.remote(obj)) wait_for_errors(ray_constants.WORKER_DIED_PUSH_ERROR, 1) diff --git a/test/runtest.py b/test/runtest.py index b2ff6c730..de0fe5699 100644 --- a/test/runtest.py +++ b/test/runtest.py @@ -2621,7 +2621,7 @@ def test_inline_objects(shutdown_only): value = ray.get(inline_object) assert value == "inline" inlined += 1 - except ray.worker.RayTaskError: + except ray.exceptions.UnreconstructableError: pass # Make sure some objects were inlined. Some of them may not get inlined # because we evict the object soon after creating it. @@ -2638,7 +2638,7 @@ def test_inline_objects(shutdown_only): ray.worker.global_worker.plasma_client.delete([plasma_id]) # Objects created by an actor that were evicted and larger than the # maximum inline object size cannot be retrieved or reconstructed. - with pytest.raises(ray.worker.RayTaskError): + with pytest.raises(ray.exceptions.UnreconstructableError): ray.get(non_inline_object) == 10000 * [1] diff --git a/test/test_signal.py b/test/test_signal.py index e86812c8e..3ff8e7734 100644 --- a/test/test_signal.py +++ b/test/test_signal.py @@ -106,7 +106,7 @@ def test_task_crash(ray_start): try: ray.get(object_id) except Exception as e: - assert type(e) == ray.worker.RayTaskError + assert type(e) == ray.exceptions.RayTaskError finally: result_list = signal.receive([object_id], timeout=5) assert len(result_list) == 1 @@ -142,7 +142,7 @@ def test_actor_crash(ray_start): try: ray.get(a.crash.remote()) except Exception as e: - assert type(e) == ray.worker.RayTaskError + assert type(e) == ray.exceptions.RayTaskError finally: result_list = signal.receive([a], timeout=5) assert len(result_list) == 1 @@ -184,7 +184,7 @@ def test_actor_crash_init2(ray_start): try: ray.get(a.method.remote()) except Exception as e: - assert type(e) == ray.worker.RayTaskError + assert type(e) == ray.exceptions.RayTaskError finally: result_list = receive_all_signals([a], timeout=5) assert len(result_list) == 2