Improved KeyboardInterrupt Exception Handling (#5237)

This commit is contained in:
Richard Liaw
2019-07-22 02:29:56 -07:00
committed by GitHub
parent f9043cc49a
commit 53fb876a5f
5 changed files with 34 additions and 3 deletions
+2 -1
View File
@@ -36,6 +36,7 @@ from ray.includes.unique_ids cimport (
)
from ray.includes.task cimport CTaskSpec
from ray.includes.ray_config cimport RayConfig
from ray.exceptions import RayletError
from ray.utils import decode
cimport cpython
@@ -57,7 +58,7 @@ cdef int check_status(const CRayStatus& status) nogil except -1:
with gil:
message = status.message().decode()
raise Exception(message)
raise RayletError(message)
cdef c_vector[CObjectID] ObjectIDsToVector(object_ids):
+13
View File
@@ -77,6 +77,19 @@ class RayActorError(RayError):
return "The actor died unexpectedly before finishing this task."
class RayletError(RayError):
    """Signals that the Raylet client has hit an error.

    One situation in which this can be raised is the raylet being killed.

    Attributes:
        client_exc: The value handed to the constructor; it is stored
            unchanged and interpolated into the string form of this error.
    """

    def __init__(self, client_exc):
        # Keep the original payload so callers can inspect it directly.
        self.client_exc = client_exc

    def __str__(self):
        template = "The Raylet died with this message: {}"
        return template.format(self.client_exc)
class UnreconstructableError(RayError):
"""Indicates that an object is lost and cannot be reconstructed.
+7
View File
@@ -2,6 +2,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import redis
import threading
import traceback
@@ -11,6 +12,10 @@ from ray import cloudpickle as pickle
from ray import profiling
from ray import utils
import logging
logger = logging.getLogger(__name__)
class ImportThread(object):
"""A thread used to import exports from the driver or other workers.
@@ -80,6 +85,8 @@ class ImportThread(object):
num_imported += 1
key = self.redis_client.lindex("Exports", i)
self._process_key(key)
except (OSError, redis.exceptions.ConnectionError) as e:
logger.error("ImportThread: {}".format(e))
finally:
# Close the pubsub client to avoid leaking file descriptors.
import_pubsub_client.close()
@@ -112,6 +112,7 @@ if __name__ == "__main__":
args = parser.parse_args()
if args.ray_redis_address:
ray.init(redis_address=args.ray_redis_address)
datasets.MNIST("~/data", train=True, download=True)
sched = AsyncHyperBandScheduler(
time_attr="training_iteration", metric="mean_accuracy")
tune.run(
+11 -2
View File
@@ -12,6 +12,7 @@ import json
import logging
import numpy as np
import os
import redis
import signal
from six.moves import queue
import sys
@@ -1520,8 +1521,12 @@ def custom_excepthook(type, value, tb):
# If this is a driver, push the exception to redis.
if global_worker.mode == SCRIPT_MODE:
error_message = "".join(traceback.format_tb(tb))
global_worker.redis_client.hmset(b"Drivers:" + global_worker.worker_id,
{"exception": error_message})
try:
global_worker.redis_client.hmset(
b"Drivers:" + global_worker.worker_id,
{"exception": error_message})
except (ConnectionRefusedError, redis.exceptions.ConnectionError):
logger.warning("Could not push exception to redis.")
# Call the normal excepthook.
normal_excepthook(type, value, tb)
@@ -1583,6 +1588,8 @@ def print_logs(redis_client, threads_stopped):
"The driver may not be able to keep up with the "
"stdout/stderr of the workers. To avoid forwarding logs "
"to the driver, use 'ray.init(log_to_driver=False)'.")
except (OSError, redis.exceptions.ConnectionError) as e:
logger.error("print_logs: {}".format(e))
finally:
# Close the pubsub client to avoid leaking file descriptors.
pubsub_client.close()
@@ -1681,6 +1688,8 @@ def listen_error_messages_raylet(worker, task_error_queue, threads_stopped):
task_error_queue.put((error_message, time.time()))
else:
logger.error(error_message)
except (OSError, redis.exceptions.ConnectionError) as e:
logger.error("listen_error_messages_raylet: {}".format(e))
finally:
# Close the pubsub client to avoid leaking file descriptors.
worker.error_message_pubsub_client.close()