diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index 8005a4d04..83d48eafc 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -36,6 +36,7 @@ from ray.includes.unique_ids cimport ( ) from ray.includes.task cimport CTaskSpec from ray.includes.ray_config cimport RayConfig +from ray.exceptions import RayletError from ray.utils import decode cimport cpython @@ -57,7 +58,7 @@ cdef int check_status(const CRayStatus& status) nogil except -1: with gil: message = status.message().decode() - raise Exception(message) + raise RayletError(message) cdef c_vector[CObjectID] ObjectIDsToVector(object_ids): diff --git a/python/ray/exceptions.py b/python/ray/exceptions.py index 2cd5ed56d..1d1dfcd97 100644 --- a/python/ray/exceptions.py +++ b/python/ray/exceptions.py @@ -77,6 +77,19 @@ class RayActorError(RayError): return "The actor died unexpectedly before finishing this task." +class RayletError(RayError): + """Indicates that the Raylet client has errored. + + This exception can be thrown when the raylet is killed. + """ + + def __init__(self, client_exc): + self.client_exc = client_exc + + def __str__(self): + return "The Raylet died with this message: {}".format(self.client_exc) + + class UnreconstructableError(RayError): """Indicates that an object is lost and cannot be reconstructed. diff --git a/python/ray/import_thread.py b/python/ray/import_thread.py index 91f5e8b1d..484410f1c 100644 --- a/python/ray/import_thread.py +++ b/python/ray/import_thread.py @@ -2,6 +2,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import redis import threading import traceback @@ -11,6 +12,10 @@ from ray import cloudpickle as pickle from ray import profiling from ray import utils +import logging + +logger = logging.getLogger(__name__) + class ImportThread(object): """A thread used to import exports from the driver or other workers. @@ -80,6 +85,8 @@ class ImportThread(object): num_imported += 1 key = self.redis_client.lindex("Exports", i) self._process_key(key) + except (OSError, redis.exceptions.ConnectionError) as e: + logger.error("ImportThread: {}".format(e)) finally: # Close the pubsub client to avoid leaking file descriptors. import_pubsub_client.close() diff --git a/python/ray/tune/examples/mnist_pytorch.py b/python/ray/tune/examples/mnist_pytorch.py index acef9fc51..b791f7415 100644 --- a/python/ray/tune/examples/mnist_pytorch.py +++ b/python/ray/tune/examples/mnist_pytorch.py @@ -112,6 +112,7 @@ if __name__ == "__main__": args = parser.parse_args() if args.ray_redis_address: ray.init(redis_address=args.ray_redis_address) + datasets.MNIST("~/data", train=True, download=True) sched = AsyncHyperBandScheduler( time_attr="training_iteration", metric="mean_accuracy") tune.run( diff --git a/python/ray/worker.py b/python/ray/worker.py index 40bbd8c75..e9729d243 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -12,6 +12,7 @@ import json import logging import numpy as np import os +import redis import signal from six.moves import queue import sys @@ -1520,8 +1521,12 @@ def custom_excepthook(type, value, tb): # If this is a driver, push the exception to redis. if global_worker.mode == SCRIPT_MODE: error_message = "".join(traceback.format_tb(tb)) - global_worker.redis_client.hmset(b"Drivers:" + global_worker.worker_id, - {"exception": error_message}) + try: + global_worker.redis_client.hmset( + b"Drivers:" + global_worker.worker_id, + {"exception": error_message}) + except (ConnectionRefusedError, redis.exceptions.ConnectionError): + logger.warning("Could not push exception to redis.") # Call the normal excepthook. normal_excepthook(type, value, tb) @@ -1583,6 +1588,8 @@ def print_logs(redis_client, threads_stopped): "The driver may not be able to keep up with the " "stdout/stderr of the workers. To avoid forwarding logs " "to the driver, use 'ray.init(log_to_driver=False)'.") + except (OSError, redis.exceptions.ConnectionError) as e: + logger.error("print_logs: {}".format(e)) finally: # Close the pubsub client to avoid leaking file descriptors. pubsub_client.close() @@ -1681,6 +1688,8 @@ def listen_error_messages_raylet(worker, task_error_queue, threads_stopped): task_error_queue.put((error_message, time.time())) else: logger.error(error_message) + except (OSError, redis.exceptions.ConnectionError) as e: + logger.error("listen_error_messages_raylet: {}".format(e)) finally: # Close the pubsub client to avoid leaking file descriptors. worker.error_message_pubsub_client.close()