From 509685d240d5f424655681806fce5d1e3a8ab95a Mon Sep 17 00:00:00 2001 From: Robert Nishihara Date: Tue, 3 Jan 2017 18:41:03 -0800 Subject: [PATCH] Let the worker know about remote functions that failed to unpickle. (#175) * Let the worker know about remote functions that failed to unpickle. * Cleanup. --- lib/python/ray/worker.py | 14 ++++++++++---- test/failure_test.py | 11 ++++++++--- 2 files changed, 18 insertions(+), 7 deletions(-) diff --git a/lib/python/ray/worker.py b/lib/python/ray/worker.py index 18176cdca..a677035e0 100644 --- a/lib/python/ray/worker.py +++ b/lib/python/ray/worker.py @@ -909,6 +909,16 @@ def fetch_and_register_remote_function(key, worker=global_worker): num_return_vals = int(num_return_vals) module = module.decode("ascii") function_export_counter = int(function_export_counter) + + worker.function_names[function_id.id()] = function_name + worker.num_return_vals[function_id.id()] = num_return_vals + worker.function_export_counters[function_id.id()] = function_export_counter + # This is a placeholder in case the function can't be unpickled. This will be + # overwritten if the function is unpickled successfully. + def f(): + raise Exception("This function was not imported properly.") + worker.functions[function_id.id()] = remote(num_return_vals=num_return_vals, function_id=function_id)(lambda *xs: f()) + try: function = pickling.loads(serialized_function) except: @@ -924,11 +934,7 @@ def fetch_and_register_remote_function(key, worker=global_worker): else: # TODO(rkn): Why is the below line necessary? function.__module__ = module - function_name = "{}.{}".format(function.__module__, function.__name__) worker.functions[function_id.id()] = remote(num_return_vals=num_return_vals, function_id=function_id)(function) - worker.function_names[function_id.id()] = function_name - worker.num_return_vals[function_id.id()] = num_return_vals - worker.function_export_counters[function_id.id()] = function_export_counter # Add the function to the function table. worker.redis_client.rpush("FunctionTable:{}".format(function_id.id()), worker.worker_id) diff --git a/test/failure_test.py b/test/failure_test.py index a4f55cdc5..69e36185e 100644 --- a/test/failure_test.py +++ b/test/failure_test.py @@ -103,10 +103,15 @@ class TaskStatusTest(unittest.TestCase): return reducer, () def __call__(self): return - ray.remote(Foo()) - wait_for_errors(b"RemoteFunctionImportError", 1) + f = ray.remote(Foo()) + wait_for_errors(b"RemoteFunctionImportError", 2) self.assertTrue(b"There is a problem here." in ray.error_info()[b"RemoteFunctionImportError"][0][b"message"]) + # Check that if we try to call the function it throws an exception and does + # not hang. + for _ in range(10): + self.assertRaises(Exception, lambda : ray.get(f.remote())) + ray.worker.cleanup() def testFailImportingReusableVariable(self): @@ -119,7 +124,7 @@ class TaskStatusTest(unittest.TestCase): raise Exception("The initializer failed.") return 0 ray.reusables.foo = ray.Reusable(initializer) - wait_for_errors(b"ReusableVariableImportError", 1) + wait_for_errors(b"ReusableVariableImportError", 2) # Check that the error message is in the task info. self.assertTrue(b"The initializer failed." in ray.error_info()[b"ReusableVariableImportError"][0][b"message"])