From 741f4824b1b0e00f96ec386724a27bb2598c4adb Mon Sep 17 00:00:00 2001 From: Robert Nishihara Date: Sat, 25 Jun 2016 09:43:57 -0700 Subject: [PATCH] error messages for gets (#158) --- lib/python/ray/worker.py | 26 ++++++++++++++++++++++---- test/runtest.py | 23 ++++++++++++++++++++--- test/test_functions.py | 12 ++++++++++-- 3 files changed, 52 insertions(+), 9 deletions(-) diff --git a/lib/python/ray/worker.py b/lib/python/ray/worker.py index dcf39131e..50f98aa3a 100644 --- a/lib/python/ray/worker.py +++ b/lib/python/ray/worker.py @@ -13,6 +13,18 @@ import ray from ray.config import LOG_DIRECTORY, LOG_TIMESTAMP import serialization +class RayFailedObject(object): + """If a task throws an exception during execution, a RayFailedObject is stored in the object store for each of the tasks outputs.""" + + def __init__(self, error_message=None): + self.error_message = error_message + + def deserialize(self, primitives): + self.error_message = primitives + + def serialize(self): + return self.error_message + class RayDealloc(object): def __init__(self, handle, segmentid): self.handle = handle @@ -162,7 +174,10 @@ def get(objref, worker=global_worker): ray.lib.request_object(worker.handle, objref) if worker.mode == ray.SHELL_MODE or worker.mode == ray.SCRIPT_MODE: print_task_info(ray.lib.task_info(worker.handle), worker.mode) - return worker.get_object(objref) + value = worker.get_object(objref) + if isinstance(value, RayFailedObject): + raise Exception("The task that created this object reference failed with error message: {}".format(value.error_message)) + return value def put(value, worker=global_worker): objref = ray.lib.get_objref(worker.handle) @@ -180,7 +195,13 @@ def main_loop(worker=global_worker): arguments = get_arguments_for_execution(worker.functions[func_name], args, worker) # get args from objstore try: outputs = worker.functions[func_name].executor(arguments) # execute the function + if len(return_objrefs) == 1: + outputs = (outputs,) except Exception as e: + # Here we are storing RayFailedObjects in the object store to indicate + # failure (this is only interpreted by the worker). + failure_objects = [RayFailedObject(str(e)) for _ in range(len(return_objrefs))] + store_outputs_in_objstore(return_objrefs, failure_objects, worker) ray.lib.notify_task_completed(worker.handle, False, str(e)) # notify the scheduler that the task threw an exception logging.info("Worker through exception with message: {}, while running function {}.".format(str(e), func_name)) else: @@ -308,9 +329,6 @@ def get_arguments_for_execution(function, args, worker=global_worker): # helper method, this should not be called by the user def store_outputs_in_objstore(objrefs, outputs, worker=global_worker): - if len(objrefs) == 1: - outputs = (outputs,) - for i in range(len(objrefs)): if isinstance(outputs[i], ray.lib.ObjRef): # An ObjRef is being returned, so we must alias objrefs[i] so that it refers to the same object that outputs[i] refers to diff --git a/test/runtest.py b/test/runtest.py index 99692b4ab..a0994840d 100644 --- a/test/runtest.py +++ b/test/runtest.py @@ -243,8 +243,8 @@ class TaskStatusTest(unittest.TestCase): worker_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_worker.py") services.start_singlenode_cluster(return_drivers=False, num_workers_per_objstore=3, worker_path=worker_path, driver_mode=ray.WORKER_MODE) test_functions.test_alias_f() - test_functions.throw_exception_fct() - test_functions.throw_exception_fct() + test_functions.throw_exception_fct1() + test_functions.throw_exception_fct1() time.sleep(1) result = ray.task_info() self.assertTrue(len(result["failed_tasks"]) == 2) @@ -252,10 +252,27 @@ class TaskStatusTest(unittest.TestCase): for task in result["failed_tasks"]: self.assertTrue(task.has_key("worker_address")) self.assertTrue(task.has_key("operationid")) - self.assertEqual(task.get("error_message"), "Test function intentionally failed.") + self.assertEqual(task.get("error_message"), "Test function 1 intentionally failed.") self.assertTrue(task["operationid"] not in task_ids) task_ids.add(task["operationid"]) + x = test_functions.throw_exception_fct2() + try: + ray.get(x) + except Exception as e: + self.assertEqual(str(e), "The task that created this object reference failed with error message: Test function 2 intentionally failed.") + else: + self.assertTrue(False) # ray.get should throw an exception + + x, y, z = test_functions.throw_exception_fct3(1.0) + for ref in [x, y, z]: + try: + ray.get(ref) + except Exception as e: + self.assertEqual(str(e), "The task that created this object reference failed with error message: Test function 3 intentionally failed.") + else: + self.assertTrue(False) # ray.get should throw an exception + def check_get_deallocated(data): x = ray.put(data) ray.get(x) diff --git a/test/test_functions.py b/test/test_functions.py index 3509271e6..7e75671af 100644 --- a/test/test_functions.py +++ b/test/test_functions.py @@ -75,5 +75,13 @@ except: # test throwing an exception @ray.remote([], []) -def throw_exception_fct(): - raise Exception("Test function intentionally failed.") +def throw_exception_fct1(): + raise Exception("Test function 1 intentionally failed.") + +@ray.remote([], [int]) +def throw_exception_fct2(): + raise Exception("Test function 2 intentionally failed.") + +@ray.remote([float], [int, str, np.ndarray]) +def throw_exception_fct3(x): + raise Exception("Test function 3 intentionally failed.")