mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 20:02:10 +08:00
Improved KeyboardInterrupt Exception Handling (#5237)
This commit is contained in:
@@ -36,6 +36,7 @@ from ray.includes.unique_ids cimport (
|
||||
)
|
||||
from ray.includes.task cimport CTaskSpec
|
||||
from ray.includes.ray_config cimport RayConfig
|
||||
from ray.exceptions import RayletError
|
||||
from ray.utils import decode
|
||||
|
||||
cimport cpython
|
||||
@@ -57,7 +58,7 @@ cdef int check_status(const CRayStatus& status) nogil except -1:
|
||||
|
||||
with gil:
|
||||
message = status.message().decode()
|
||||
raise Exception(message)
|
||||
raise RayletError(message)
|
||||
|
||||
|
||||
cdef c_vector[CObjectID] ObjectIDsToVector(object_ids):
|
||||
|
||||
@@ -77,6 +77,19 @@ class RayActorError(RayError):
|
||||
return "The actor died unexpectedly before finishing this task."
|
||||
|
||||
|
||||
class RayletError(RayError):
    """Error raised when the Raylet client reports a failure.

    One situation in which this exception can be thrown is the raylet
    being killed.

    Attributes:
        client_exc: The value handed to the constructor; it is embedded
            verbatim in the string form of this error.
    """

    def __init__(self, client_exc):
        # Keep the underlying client error/message so __str__ can report it.
        self.client_exc = client_exc

    def __str__(self):
        template = "The Raylet died with this message: {}"
        return template.format(self.client_exc)
|
||||
|
||||
|
||||
class UnreconstructableError(RayError):
|
||||
"""Indicates that an object is lost and cannot be reconstructed.
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import redis
|
||||
import threading
|
||||
import traceback
|
||||
|
||||
@@ -11,6 +12,10 @@ from ray import cloudpickle as pickle
|
||||
from ray import profiling
|
||||
from ray import utils
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ImportThread(object):
|
||||
"""A thread used to import exports from the driver or other workers.
|
||||
@@ -80,6 +85,8 @@ class ImportThread(object):
|
||||
num_imported += 1
|
||||
key = self.redis_client.lindex("Exports", i)
|
||||
self._process_key(key)
|
||||
except (OSError, redis.exceptions.ConnectionError) as e:
|
||||
logger.error("ImportThread: {}".format(e))
|
||||
finally:
|
||||
# Close the pubsub client to avoid leaking file descriptors.
|
||||
import_pubsub_client.close()
|
||||
|
||||
@@ -112,6 +112,7 @@ if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
if args.ray_redis_address:
|
||||
ray.init(redis_address=args.ray_redis_address)
|
||||
datasets.MNIST("~/data", train=True, download=True)
|
||||
sched = AsyncHyperBandScheduler(
|
||||
time_attr="training_iteration", metric="mean_accuracy")
|
||||
tune.run(
|
||||
|
||||
+11
-2
@@ -12,6 +12,7 @@ import json
|
||||
import logging
|
||||
import numpy as np
|
||||
import os
|
||||
import redis
|
||||
import signal
|
||||
from six.moves import queue
|
||||
import sys
|
||||
@@ -1520,8 +1521,12 @@ def custom_excepthook(type, value, tb):
|
||||
# If this is a driver, push the exception to redis.
|
||||
if global_worker.mode == SCRIPT_MODE:
|
||||
error_message = "".join(traceback.format_tb(tb))
|
||||
global_worker.redis_client.hmset(b"Drivers:" + global_worker.worker_id,
|
||||
{"exception": error_message})
|
||||
try:
|
||||
global_worker.redis_client.hmset(
|
||||
b"Drivers:" + global_worker.worker_id,
|
||||
{"exception": error_message})
|
||||
except (ConnectionRefusedError, redis.exceptions.ConnectionError):
|
||||
logger.warning("Could not push exception to redis.")
|
||||
# Call the normal excepthook.
|
||||
normal_excepthook(type, value, tb)
|
||||
|
||||
@@ -1583,6 +1588,8 @@ def print_logs(redis_client, threads_stopped):
|
||||
"The driver may not be able to keep up with the "
|
||||
"stdout/stderr of the workers. To avoid forwarding logs "
|
||||
"to the driver, use 'ray.init(log_to_driver=False)'.")
|
||||
except (OSError, redis.exceptions.ConnectionError) as e:
|
||||
logger.error("print_logs: {}".format(e))
|
||||
finally:
|
||||
# Close the pubsub client to avoid leaking file descriptors.
|
||||
pubsub_client.close()
|
||||
@@ -1681,6 +1688,8 @@ def listen_error_messages_raylet(worker, task_error_queue, threads_stopped):
|
||||
task_error_queue.put((error_message, time.time()))
|
||||
else:
|
||||
logger.error(error_message)
|
||||
except (OSError, redis.exceptions.ConnectionError) as e:
|
||||
logger.error("listen_error_messages_raylet: {}".format(e))
|
||||
finally:
|
||||
# Close the pubsub client to avoid leaking file descriptors.
|
||||
worker.error_message_pubsub_client.close()
|
||||
|
||||
Reference in New Issue
Block a user