From ddadc18ef65652e67e451abe13b3b53f6782e592 Mon Sep 17 00:00:00 2001 From: Mitchell Stern Date: Thu, 5 Sep 2019 13:18:57 -0700 Subject: [PATCH] Fix bug in ray.errors and update its default behavior (#5576) --- python/ray/state.py | 33 ++++++++++++++++++--------------- python/ray/tests/test_stress.py | 5 +++-- python/ray/tests/utils.py | 9 ++++++++- python/ray/worker.py | 4 ++-- 4 files changed, 31 insertions(+), 20 deletions(-) diff --git a/python/ray/state.py b/python/ray/state.py index 9ba192eb8..7581bc7a9 100644 --- a/python/ray/state.py +++ b/python/ray/state.py @@ -945,8 +945,9 @@ class GlobalState(object): None, then this method retrieves the errors for all jobs. Returns: - A dictionary mapping driver ID to a list of the error messages for - that driver. + A list of the error messages for the specified driver if one was + given, or a dictionary mapping from job ID to a list of error + messages for that driver otherwise. """ self._check_connected() @@ -1057,12 +1058,11 @@ class DeprecatedGlobalState(object): "instead.") return ray.available_resources() - def error_messages(self, job_id=None): + def error_messages(self, all_jobs=False): logger.warning( "ray.global_state.error_messages() is deprecated and will be " - "removed in a subsequent release. Use ray.errors() " - "instead.") - return ray.errors(job_id=job_id) + "removed in a subsequent release. Use ray.errors() instead.") + return ray.errors(all_jobs=all_jobs) state = GlobalState() @@ -1185,19 +1185,22 @@ def available_resources(): return state.available_resources() -def errors(include_cluster_errors=True): +def errors(all_jobs=False): """Get error messages from the cluster. Args: - include_cluster_errors: True if we should include error messages for - all drivers, and false if we should only include error messages for - this specific driver. + all_jobs: False if we should only include error messages for this + specific job, or True if we should include error messages for all + jobs. Returns: - Error messages pushed from the cluster. + Error messages pushed from the cluster. This will be a single list if + all_jobs is False, or a dictionary mapping from job ID to a list of + error messages for that job if all_jobs is True. """ - worker = ray.worker.global_worker - error_messages = state.error_messages(job_id=worker.current_job_id) - if include_cluster_errors: - error_messages += state.error_messages(job_id=ray.JobID.nil()) + if not all_jobs: + worker = ray.worker.global_worker + error_messages = state.error_messages(job_id=worker.current_job_id) + else: + error_messages = state.error_messages(job_id=None) return error_messages diff --git a/python/ray/tests/test_stress.py b/python/ray/tests/test_stress.py index 036ccda98..0ab8ec501 100644 --- a/python/ray/tests/test_stress.py +++ b/python/ray/tests/test_stress.py @@ -11,6 +11,7 @@ import time import ray from ray.tests.cluster_utils import Cluster +from ray.tests.utils import flat_errors import ray.ray_constants as ray_constants @@ -397,13 +398,13 @@ def wait_for_errors(error_check): errors = [] time_left = 100 while time_left > 0: - errors = ray.errors() + errors = flat_errors() if error_check(errors): break time_left -= 1 time.sleep(1) - # Make sure that enough errors came through. + # Make sure that enough errors came through. assert error_check(errors) return errors diff --git a/python/ray/tests/utils.py b/python/ray/tests/utils.py index e7dff2639..5203a7d8e 100644 --- a/python/ray/tests/utils.py +++ b/python/ray/tests/utils.py @@ -75,8 +75,15 @@ def run_string_as_driver_nonblocking(driver_script): [sys.executable, f.name], stdout=subprocess.PIPE) +def flat_errors(): + errors = [] + for job_errors in ray.errors(all_jobs=True).values(): + errors.extend(job_errors) + return errors + + def relevant_errors(error_type): - return [info for info in ray.errors() if info["type"] == error_type] + return [error for error in flat_errors() if error["type"] == error_type] def wait_for_errors(error_type, num_errors, timeout=10): diff --git a/python/ray/worker.py b/python/ray/worker.py index 011b73ee8..cb9286799 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -1764,8 +1764,8 @@ def listen_error_messages_raylet(worker, task_error_queue, threads_stopped): # worker.error_message_pubsub_client.psubscribe("*") try: - # Get the exports that occurred before the call to subscribe. - error_messages = ray.errors(include_cluster_errors=False) + # Get the errors that occurred before the call to subscribe. + error_messages = ray.errors() for error_message in error_messages: logger.error(error_message)