From 88a5b4e77bb628b56514ec948044f7ee8ef61bdb Mon Sep 17 00:00:00 2001 From: Robert Nishihara Date: Thu, 16 Feb 2017 11:30:35 -0800 Subject: [PATCH] Simplify imports and exports and provide driver isolation for remote functions. (#288) * Remove import counter and export counter. * Provide isolation between drivers for remote functions. * Add test for driver function isolation. * Hash source code into function ID to reduce likelihood of collisions. * Fix failure test example. * Replace assertTrue with assertIn to improve failure messages in tests. * Fix failure test. --- python/ray/actor.py | 44 ++++---- python/ray/worker.py | 206 +++++++++++++--------------------- src/photon/photon_algorithm.c | 2 +- test/failure_test.py | 63 ++++++----- test/multi_node_test.py | 62 +++++++++- test/runtest.py | 56 +++++++++ 6 files changed, 252 insertions(+), 181 deletions(-) diff --git a/python/ray/actor.py b/python/ray/actor.py index a6c9b5733..ef0bab812 100644 --- a/python/ray/actor.py +++ b/python/ray/actor.py @@ -27,18 +27,19 @@ def get_actor_method_function_id(attr): Returns: Function ID corresponding to the method. """ - function_id = hashlib.sha1() - function_id.update(attr.encode("ascii")) - return photon.ObjectID(function_id.digest()) + function_id_hash = hashlib.sha1() + function_id_hash.update(attr.encode("ascii")) + function_id = function_id_hash.digest() + assert len(function_id) == 20 + return photon.ObjectID(function_id) def fetch_and_register_actor(key, worker): """Import an actor.""" - driver_id, actor_id_str, actor_name, module, pickled_class, class_export_counter = \ - worker.redis_client.hmget(key, ["driver_id", "actor_id", "name", "module", "class", "class_export_counter"]) + driver_id, actor_id_str, actor_name, module, pickled_class = \ + worker.redis_client.hmget(key, ["driver_id", "actor_id", "name", "module", "class"]) actor_id = photon.ObjectID(actor_id_str) actor_name = actor_name.decode("ascii") module = module.decode("ascii") - class_export_counter = int(class_export_counter) try: unpickled_class = pickling.loads(pickled_class) except: @@ -49,16 +50,17 @@ def fetch_and_register_actor(key, worker): worker.actors[actor_id_str] = unpickled_class.__new__(unpickled_class) for (k, v) in inspect.getmembers(unpickled_class, predicate=(lambda x: inspect.isfunction(x) or inspect.ismethod(x))): function_id = get_actor_method_function_id(k).id() - worker.function_names[function_id] = k - worker.functions[function_id] = v + worker.functions[driver_id][function_id] = (k, v) + # We do not set worker.function_properties[driver_id][function_id] because + # we currently do need the actor worker to submit new tasks for the actor. -def export_actor(actor_id, Class, worker): +def export_actor(actor_id, Class, actor_method_names, worker): """Export an actor to redis. Args: actor_id: The ID of the actor. Class: Name of the class to be exported as an actor. - worker: The worker class + actor_method_names (list): A list of the names of this actor's methods. """ ray.worker.check_main_thread() if worker.mode is None: @@ -66,28 +68,25 @@ def export_actor(actor_id, Class, worker): key = "Actor:{}".format(actor_id.id()) pickled_class = pickling.dumps(Class) + # For now, all actor methods have 1 return value and require 0 CPUs and GPUs. + driver_id = worker.task_driver_id.id() + for actor_method_name in actor_method_names: + function_id = get_actor_method_function_id(actor_method_name).id() + worker.function_properties[driver_id][function_id] = (1, 0, 0) + # Select a local scheduler for the actor. local_schedulers = state.get_local_schedulers() local_scheduler_id = random.choice(local_schedulers) worker.redis_client.publish("actor_notifications", actor_id.id() + local_scheduler_id) - # The export counter is computed differently depending on whether we are - # currently in a driver or a worker. - if worker.mode in [ray.SCRIPT_MODE, ray.SILENT_MODE]: - export_counter = worker.driver_export_counter - elif worker.mode == ray.WORKER_MODE: - # We don't actually need export counters for actors. - export_counter = 0 - d = {"driver_id": worker.task_driver_id.id(), + d = {"driver_id": driver_id, "actor_id": actor_id.id(), "name": Class.__name__, "module": Class.__module__, - "class": pickled_class, - "class_export_counter": export_counter} + "class": pickled_class} worker.redis_client.hmset(key, d) worker.redis_client.rpush("Exports", key) - worker.driver_export_counter += 1 def actor(Class): # The function actor_method_call gets called if somebody tries to call a @@ -105,7 +104,6 @@ def actor(Class): num_cpus = 0 num_gpus = 0 object_ids = ray.worker.global_worker.submit_task(function_id, "", args, - num_cpus, num_gpus, actor_id=actor_id) if len(object_ids) == 1: return object_ids[0] @@ -116,7 +114,7 @@ def actor(Class): def __init__(self, *args, **kwargs): self._ray_actor_id = random_actor_id() self._ray_actor_methods = {k: v for (k, v) in inspect.getmembers(Class, predicate=(lambda x: inspect.isfunction(x) or inspect.ismethod(x)))} - export_actor(self._ray_actor_id, Class, ray.worker.global_worker) + export_actor(self._ray_actor_id, Class, self._ray_actor_methods, ray.worker.global_worker) # Call __init__ as a remote function. if "__init__" in self._ray_actor_methods.keys(): actor_method_call(self._ray_actor_id, "__init__", *args, **kwargs) diff --git a/python/ray/worker.py b/python/ray/worker.py index 505d69a5b..766136a26 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -2,22 +2,23 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import json -import hashlib -import os -import sys -import time -import traceback -import copy -import collections -import funcsigs -import numpy as np -import colorama import atexit +import collections +import colorama +import copy +import funcsigs +import hashlib +import inspect +import json +import numpy as np +import os import random import redis -import threading import string +import sys +import threading +import time +import traceback # Ray modules import ray.pickling as pickling @@ -373,61 +374,27 @@ class Worker(object): that connect has been called already. cached_functions_to_run (List): A list of functions to run on all of the workers that should be exported as soon as connect is called. - driver_export_counter (int): The number of exports that the driver has - exported. This is only used on the driver. - worker_import_counter (int): The number of exports that the worker has - imported so far. This is only used on the workers. """ def __init__(self): """Initialize a Worker object.""" - self.functions = {} - # Use a defaultdict for the number of return values. If this is accessed - # with a missing key, the default value of 1 is returned, and that key value - # pair is added to the dict. - self.num_return_vals = collections.defaultdict(lambda: 1) - self.function_names = {} - self.function_export_counters = {} + # The functions field is a dictionary that maps a driver ID to a dictionary + # of functions that have been registered for that driver (this inner + # dictionary maps function IDs to a tuple of the function name and the + # function itself). This should only be used on workers that execute remote + # functions. + self.functions = collections.defaultdict(lambda: {}) + # The function_properties field is a dictionary that maps a driver ID to a + # dictionary of functions that have been registered for that driver (this + # inner dictionary maps function IDs to a tuple of the number of values + # returned by that function, the number of CPUs required by that function, + # and the number of GPUs required by that function). This is used when + # submitting a function (which can be done both on workers and on drivers). + self.function_properties = collections.defaultdict(lambda: {}) self.connected = False self.mode = None self.cached_remote_functions = [] self.cached_functions_to_run = [] - # The driver_export_counter and worker_import_counter are used to make sure - # that no task executes before everything it needs is present. For example, - # if we define a remote function f, a worker cannot execute a task for f - # until the worker has imported the function f. - # - When a remote function, a reusable variable, or a function to run is - # exported, the driver_export_counter is incremented. These exports must - # take place from the driver. - # - When an actor is created, the driver_export_counter is NOT - # incremented. Note that an actor can be created from a driver or from - # any worker. - # - When a worker imports a remote function, a reusable variable, or a - # function to run, its worker_import_counter is incremented. - # - Notably, when an actor is imported, its worker_import_counter is NOT - # incremented. - # - Whenever a remote function is DEFINED on the driver, it records the - # value of the driver_export_counter and a worker will not execute that - # remote function until it has imported that many exports (excluding - # actors). - # - When an actor is defined. - # a) If the actor is created on a driver, it records the - # driver_export_counter. - # b) If the actor is created inside a task on a regular worker, it - # records the driver_export_counter associated with the function in - # task creating the actor. - # c) If the actor is created inside a task on an actor worker, it - # records - # The worker that ultimately runs the actor will not execute any tasks - # until it has imported that many imports. - # - # TODO(rkn): These counters must be tracked separately for each driver. - # TODO(rkn): Maybe none of these counters are necessary? When executing a - # regular task, workers can just wait until the function ID is present. When - # executing an actor task, the actor worker can just wait until the actor - # has been defined. - self.driver_export_counter = 0 - self.worker_import_counter = 0 self.fetch_and_register = {} self.actors = {} # Use a defaultdict for the actor counts. If this is accessed with a missing @@ -526,7 +493,7 @@ class Worker(object): assert final_results[i][0] == object_ids[i].id() return [result[1][0] for result in final_results] - def submit_task(self, function_id, func_name, args, num_cpus, num_gpus, actor_id=photon.ObjectID(NIL_ACTOR_ID)): + def submit_task(self, function_id, func_name, args, actor_id=photon.ObjectID(NIL_ACTOR_ID)): """Submit a remote task to the scheduler. Tell the scheduler to schedule the execution of the function with name @@ -538,8 +505,6 @@ class Worker(object): args (List[Any]): The arguments to pass into the function. Arguments can be object IDs or they can be values. If they are values, they must be serializable objecs. - num_cpus (int): The number of cpu cores this task requires to run. - num_gpus (int): The number of gpus this task requires to run. """ with log_span("ray:submit_task", worker=self): check_main_thread() @@ -554,11 +519,14 @@ class Worker(object): else: args_for_photon.append(put(arg)) + # Look up the various function properties. + num_return_vals, num_cpus, num_gpus = self.function_properties[self.task_driver_id.id()][function_id.id()] + # Submit the task to Photon. task = photon.Task(self.task_driver_id, photon.ObjectID(function_id.id()), args_for_photon, - self.num_return_vals[function_id.id()], + num_return_vals, self.current_task_id, self.task_index, actor_id, self.actor_counters[actor_id], @@ -604,7 +572,6 @@ class Worker(object): "function_id": function_to_run_id, "function": pickling.dumps(function)}) self.redis_client.rpush("Exports", key) - self.driver_export_counter += 1 def push_error_to_driver(self, driver_id, error_type, message, data=None): """Push an error message to the driver to be printed in the background. @@ -963,8 +930,6 @@ def cleanup(worker=global_worker): disconnect(worker) worker.set_mode(None) - worker.driver_export_counter = 0 - worker.worker_import_counter = 0 if hasattr(worker, "plasma_client"): worker.plasma_client.shutdown() services.cleanup() @@ -1037,14 +1002,13 @@ If this driver is hanging, start a new one with def fetch_and_register_remote_function(key, worker=global_worker): """Import a remote function.""" - driver_id, function_id_str, function_name, serialized_function, num_return_vals, module, function_export_counter, num_cpus, num_gpus = \ + driver_id, function_id_str, function_name, serialized_function, num_return_vals, module, num_cpus, num_gpus = \ worker.redis_client.hmget(key, ["driver_id", "function_id", "name", "function", "num_return_vals", "module", - "function_export_counter", "num_cpus", "num_gpus"]) function_id = photon.ObjectID(function_id_str) @@ -1053,19 +1017,14 @@ def fetch_and_register_remote_function(key, worker=global_worker): num_cpus = int(num_cpus) num_gpus = int(num_gpus) module = module.decode("ascii") - function_export_counter = int(function_export_counter) - worker.function_names[function_id.id()] = function_name - worker.num_return_vals[function_id.id()] = num_return_vals - worker.function_export_counters[function_id.id()] = function_export_counter # This is a placeholder in case the function can't be unpickled. This will be - # overwritten if the function is unpickled successfully. + # overwritten if the function is successfully registered. def f(): raise Exception("This function was not imported properly.") - worker.functions[function_id.id()] = remote(num_return_vals=num_return_vals, - function_id=function_id, - num_cpus=num_cpus, - num_gpus=num_gpus)(lambda *xs: f()) + remote_f_placeholder = remote(function_id=function_id)(lambda *xs: f()) + worker.functions[driver_id][function_id.id()] = (function_name, remote_f_placeholder) + worker.function_properties[driver_id][function_id.id()] = (num_return_vals, num_cpus, num_gpus) try: function = pickling.loads(serialized_function) @@ -1081,10 +1040,7 @@ def fetch_and_register_remote_function(key, worker=global_worker): else: # TODO(rkn): Why is the below line necessary? function.__module__ = module - worker.functions[function_id.id()] = remote(num_return_vals=num_return_vals, - function_id=function_id, - num_cpus=num_cpus, - num_gpus=num_gpus)(function) + worker.functions[driver_id][function_id.id()] = (function_name, remote(function_id=function_id)(function)) # Add the function to the function table. worker.redis_client.rpush("FunctionTable:{}".format(function_id.id()), worker.worker_id) @@ -1133,10 +1089,7 @@ def import_thread(worker): # in the loop. worker.import_pubsub_client.psubscribe("__keyspace@0__:Exports") worker_info_key = "WorkerInfo:{}".format(worker.worker_id) - worker.redis_client.hset(worker_info_key, "export_counter", 0) - worker.worker_import_counter = 0 - # The number of imports is similar to the worker_import_counter except that it - # also counts actors. + # Keep track of the number of imports that we've imported. num_imported = 0 # Get the exports that occurred before the call to psubscribe. @@ -1157,10 +1110,6 @@ def import_thread(worker): worker.fetch_and_register["Actor"](key, worker) else: raise Exception("This code should be unreachable.") - # Actors do not contribute to the import counter. - if not key.startswith(b"Actor"): - worker.redis_client.hincrby(worker_info_key, "export_counter", 1) - worker.worker_import_counter += 1 num_imported += 1 for msg in worker.import_pubsub_client.listen(): @@ -1189,10 +1138,6 @@ def import_thread(worker): worker.fetch_and_register["Actor"](key, worker) else: raise Exception("This code should be unreachable.") - # Actors do not contribute to the import counter. - if not key.startswith(b"Actor"): - worker.redis_client.hincrby(worker_info_key, "export_counter", 1) - worker.worker_import_counter += 1 num_imported += 1 def connect(info, object_id_seed=None, mode=WORKER_MODE, worker=global_worker, actor_id=NIL_ACTOR_ID): @@ -1522,13 +1467,12 @@ def wait(object_ids, num_returns=1, timeout=None, worker=global_worker): remaining_ids = [photon.ObjectID(object_id) for object_id in remaining_ids] return ready_ids, remaining_ids -def wait_for_valid_import_counter(function_id, driver_id, timeout=5, worker=global_worker): - """Wait until this worker has imported enough to execute the function. +def wait_for_function(function_id, driver_id, timeout=5, worker=global_worker): + """Wait until the function to be executed is present on this worker. - This method will simply loop until the import thread has imported enough of - the exports to execute the function. If we spend too long in this loop, that - may indicate a problem somewhere and we will push an error message to the - user. + This method will simply loop until the import thread has imported the relevant + function. If we spend too long in this loop, that may indicate a problem + somewhere and we will push an error message to the user. If this worker is an actor, then this will wait until the actor has been defined. @@ -1545,17 +1489,14 @@ def wait_for_valid_import_counter(function_id, driver_id, timeout=5, worker=glob num_warnings_sent = 0 while True: with worker.lock: - if worker.actor_id == NIL_ACTOR_ID and function_id.id() in worker.functions and (worker.function_export_counters[function_id.id()] <= worker.worker_import_counter): + if worker.actor_id == NIL_ACTOR_ID and function_id.id() in worker.functions[driver_id]: break elif worker.actor_id != NIL_ACTOR_ID and worker.actor_id in worker.actors: break if time.time() - start_time > timeout * (num_warnings_sent + 1): - if function_id.id() not in worker.functions: - warning_message = "This worker was asked to execute a function that it does not have registered. You may have to restart Ray." - else: - warning_message = "This worker's import counter is too small." + warning_message = "This worker was asked to execute a function that it does not have registered. You may have to restart Ray." if not warning_sent: - worker.push_error_to_driver(driver_id, "import_counter", + worker.push_error_to_driver(driver_id, "wait_for_function", warning_message) warning_sent = True time.sleep(0.001) @@ -1614,18 +1555,18 @@ def main_loop(worker=global_worker): function_id = task.function_id() args = task.arguments() return_object_ids = task.returns() - function_name = worker.function_names[function_id.id()] + function_name, function_executor = worker.functions[worker.task_driver_id.id()][function_id.id()] # Get task arguments from the object store. with log_span("ray:task:get_arguments", worker=worker): - arguments = get_arguments_for_execution(worker.functions[function_id.id()], args, worker) + arguments = get_arguments_for_execution(function_name, args, worker) # Execute the task. with log_span("ray:task:execute", worker=worker): if task.actor_id().id() == NIL_ACTOR_ID: - outputs = worker.functions[task.function_id().id()].executor(arguments) + outputs = function_executor.executor(arguments) else: - outputs = worker.functions[task.function_id().id()](worker.actors[task.actor_id().id()], *arguments) + outputs = function_executor(worker.actors[task.actor_id().id()], *arguments) # Store the outputs in the local object store. with log_span("ray:task:store_outputs", worker=worker): @@ -1680,11 +1621,11 @@ def main_loop(worker=global_worker): task = worker.photon_client.get_task() function_id = task.function_id() - # Check that the number of imports we have is at least as great as the - # export counter for the task. If not, wait until we have imported enough. - # We will push warnings to the user if we spend too long in this loop. - with log_span("ray:wait_for_import_counter", worker=worker): - wait_for_valid_import_counter(function_id, task.driver_id().id(), worker=worker) + # Wait until the function to be executed has actually been registered on + # this worker. We will push warnings to the user if we spend too long in + # this loop. + with log_span("ray:wait_for_function", worker=worker): + wait_for_function(function_id, task.driver_id().id(), worker=worker) # Execute the task. # TODO(rkn): Consider acquiring this lock with a timeout and pushing a @@ -1695,7 +1636,8 @@ def main_loop(worker=global_worker): with worker.lock: log(event_type="ray:acquire_lock", kind=LOG_SPAN_END, worker=worker) - contents = {"function_name": worker.function_names[function_id.id()], + function_name, _ = worker.functions[task.driver_id().id()][function_id.id()] + contents = {"function_name": function_name, "task_id": task.task_id().hex()} with log_span("ray:task", contents=contents, worker=worker): process_task(task) @@ -1703,7 +1645,7 @@ def main_loop(worker=global_worker): # Push all of the log events to the global state store. flush_log() -def _submit_task(function_id, func_name, args, num_cpus, num_gpus, worker=global_worker): +def _submit_task(function_id, func_name, args, worker=global_worker): """This is a wrapper around worker.submit_task. We use this wrapper so that in the remote decorator, we can call _submit_task @@ -1711,7 +1653,7 @@ def _submit_task(function_id, func_name, args, num_cpus, num_gpus, worker=global serialize remote functions, we don't attempt to serialize the worker object, which cannot be serialized. """ - return worker.submit_task(function_id, func_name, args, num_cpus, num_gpus) + return worker.submit_task(function_id, func_name, args) def _mode(worker=global_worker): """This is a wrapper around worker.mode. @@ -1751,14 +1693,13 @@ def _export_environment_variable(name, environment_variable, worker=global_worke "initializer": pickling.dumps(environment_variable.initializer), "reinitializer": pickling.dumps(environment_variable.reinitializer)}) worker.redis_client.rpush("Exports", key) - worker.driver_export_counter += 1 def export_remote_function(function_id, func_name, func, num_return_vals, num_cpus, num_gpus, worker=global_worker): check_main_thread() if _mode(worker) not in [SCRIPT_MODE, SILENT_MODE]: raise Exception("export_remote_function can only be called on a driver.") + worker.function_properties[worker.task_driver_id.id()][function_id.id()] = (num_return_vals, num_cpus, num_gpus) key = "RemoteFunction:{}".format(function_id.id()) - worker.num_return_vals[function_id.id()] = num_return_vals pickled_func = pickling.dumps(func) worker.redis_client.hmset(key, {"driver_id": worker.task_driver_id.id(), "function_id": function_id.id(), @@ -1766,11 +1707,9 @@ def export_remote_function(function_id, func_name, func, num_return_vals, num_cp "module": func.__module__, "function": pickled_func, "num_return_vals": num_return_vals, - "function_export_counter": worker.driver_export_counter, "num_cpus": num_cpus, "num_gpus": num_gpus}) worker.redis_client.rpush("Exports", key) - worker.driver_export_counter += 1 def remote(*args, **kwargs): """This decorator is used to create remote functions. @@ -1778,13 +1717,26 @@ def remote(*args, **kwargs): Args: num_return_vals (int): The number of object IDs that a call to this function should return. + num_cpus (int): The number of CPUs needed to execute this function. This + should only be passed in when defining the remote function on the driver. + num_gpus (int): The number of GPUs needed to execute this function. This + should only be passed in when defining the remote function on the driver. """ worker = global_worker def make_remote_decorator(num_return_vals, num_cpus, num_gpus, func_id=None): def remote_decorator(func): func_name = "{}.{}".format(func.__module__, func.__name__) if func_id is None: - function_id = FunctionID((hashlib.sha256(func_name.encode("ascii")).digest())[:20]) + # Compute the function ID as a hash of the function name as well as the + # source code. We could in principle hash in the values in the closure + # of the function, but that is likely to introduce non-determinism in + # the computation of the function ID. + function_id_hash = hashlib.sha1() + function_id_hash.update(func_name.encode("ascii")) + function_id_hash.update(inspect.getsource(func).encode("ascii")) + function_id = function_id_hash.digest() + assert len(function_id) == 20 + function_id = FunctionID(function_id) else: function_id = func_id @@ -1807,7 +1759,7 @@ def remote(*args, **kwargs): _env()._reinitialize() _env()._running_remote_function_locally = False return result - objectids = _submit_task(function_id, func_name, args, num_cpus, num_gpus) + objectids = _submit_task(function_id, func_name, args) if len(objectids) == 1: return objectids[0] elif len(objectids) > 1: @@ -1909,7 +1861,7 @@ def check_signature_supported(has_kwargs_param, has_vararg_param, keyword_defaul if has_vararg_param and any([d != funcsigs._empty for _, d in keyword_defaults]): raise "Function {} has a *args argument as well as a keyword argument, which is currently not supported.".format(name) -def get_arguments_for_execution(function, serialized_args, worker=global_worker): +def get_arguments_for_execution(function_name, serialized_args, worker=global_worker): """Retrieve the arguments for the remote function. This retrieves the values for the arguments to the remote function that were @@ -1917,8 +1869,8 @@ def get_arguments_for_execution(function, serialized_args, worker=global_worker) This is called by the worker that is executing the remote function. Args: - function (Callable): The remote function whose arguments are being - retrieved. + function_name (str): The name of the remote function whose arguments are + being retrieved. serialized_args (List): The arguments to the function. These are either strings representing serialized objects passed by value or they are ObjectIDs. @@ -1939,7 +1891,7 @@ def get_arguments_for_execution(function, serialized_args, worker=global_worker) if isinstance(argument, RayTaskError): # If the result is a RayTaskError, then the task that created this # object failed, and we should propagate the error message here. - raise RayGetArgumentError(function.__name__, i, arg, argument) + raise RayGetArgumentError(function_name, i, arg, argument) else: # pass the argument by value argument = arg diff --git a/src/photon/photon_algorithm.c b/src/photon/photon_algorithm.c index 8e8aa11cc..1d7e006d7 100644 --- a/src/photon/photon_algorithm.c +++ b/src/photon/photon_algorithm.c @@ -709,7 +709,7 @@ void queue_task_locally(local_scheduler_state *state, /** * Give a task directly to another local scheduler. This is currently only used - * for assigning actor tasks to the local scheduer responsible for that actor. + * for assigning actor tasks to the local scheduler responsible for that actor. * * @param state The scheduler state. * @param algorithm_state The scheduling algorithm state. diff --git a/test/failure_test.py b/test/failure_test.py index 158471964..07db58336 100644 --- a/test/failure_test.py +++ b/test/failure_test.py @@ -2,10 +2,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import unittest +import os import ray import sys +import tempfile import time +import unittest if sys.version_info >= (3, 0): from importlib import reload @@ -67,13 +69,13 @@ class TaskStatusTest(unittest.TestCase): result = ray.error_info() self.assertEqual(len(relevant_errors(b"task")), 2) for task in relevant_errors(b"task"): - self.assertTrue(b"Test function 1 intentionally failed." in task.get(b"message")) + self.assertIn(b"Test function 1 intentionally failed.", task.get(b"message")) x = test_functions.throw_exception_fct2.remote() try: ray.get(x) except Exception as e: - self.assertTrue("Test function 2 intentionally failed." in str(e)) + self.assertIn("Test function 2 intentionally failed.", str(e)) else: self.assertTrue(False) # ray.get should throw an exception @@ -82,7 +84,7 @@ class TaskStatusTest(unittest.TestCase): try: ray.get(ref) except Exception as e: - self.assertTrue("Test function 3 intentionally failed." in str(e)) + self.assertIn("Test function 3 intentionally failed.", str(e)) else: self.assertTrue(False) # ray.get should throw an exception @@ -91,29 +93,40 @@ class TaskStatusTest(unittest.TestCase): def testFailImportingRemoteFunction(self): ray.init(num_workers=2, driver_mode=ray.SILENT_MODE) - # This example is somewhat contrived. It should be successfully pickled, and - # then it should throw an exception when it is unpickled. This may depend a - # bit on the specifics of our pickler. - def reducer(*args): - raise Exception("There is a problem here.") - class Foo(object): - def __init__(self): - self.__name__ = "Foo_object" - self.func_doc = "" - self.__globals__ = {} - def __reduce__(self): - return reducer, () - def __call__(self): - return - f = ray.remote(Foo()) + # Create the contents of a temporary Python file. + temporary_python_file = """ +def temporary_helper_function(): + return 1 +""" + + f = tempfile.NamedTemporaryFile(suffix=".py") + f.write(temporary_python_file.encode("ascii")) + f.flush() + directory = os.path.dirname(f.name) + # Get the module name and strip ".py" from the end. + module_name = os.path.basename(f.name)[:-3] + sys.path.append(directory) + module = __import__(module_name) + + # Define a function that closes over this temporary module. This should fail + # when it is unpickled. + @ray.remote + def g(): + return module.temporary_python_file() + wait_for_errors(b"register_remote_function", 2) - self.assertTrue(b"There is a problem here." in ray.error_info()[0][b"message"]) + self.assertIn(b"No module named", ray.error_info()[0][b"message"]) + self.assertIn(b"No module named", ray.error_info()[1][b"message"]) # Check that if we try to call the function it throws an exception and does # not hang. for _ in range(10): - self.assertRaises(Exception, lambda : ray.get(f.remote())) + self.assertRaises(Exception, lambda : ray.get(g.remote())) + f.close() + + # Clean up the junk we added to sys.path. + sys.path.pop(-1) ray.worker.cleanup() def testFailImportingEnvironmentVariable(self): @@ -128,7 +141,7 @@ class TaskStatusTest(unittest.TestCase): ray.env.foo = ray.EnvironmentVariable(initializer) wait_for_errors(b"register_environment_variable", 2) # Check that the error message is in the task info. - self.assertTrue(b"The initializer failed." in ray.error_info()[0][b"message"]) + self.assertIn(b"The initializer failed.", ray.error_info()[0][b"message"]) ray.worker.cleanup() @@ -146,7 +159,7 @@ class TaskStatusTest(unittest.TestCase): use_foo.remote() wait_for_errors(b"reinitialize_environment_variable", 1) # Check that the error message is in the task info. - self.assertTrue(b"The reinitializer failed." in ray.error_info()[0][b"message"]) + self.assertIn(b"The reinitializer failed.", ray.error_info()[0][b"message"]) ray.worker.cleanup() @@ -160,8 +173,8 @@ class TaskStatusTest(unittest.TestCase): wait_for_errors(b"function_to_run", 2) # Check that the error message is in the task info. self.assertEqual(len(ray.error_info()), 2) - self.assertTrue(b"Function to run failed." in ray.error_info()[0][b"message"]) - self.assertTrue(b"Function to run failed." in ray.error_info()[1][b"message"]) + self.assertIn(b"Function to run failed.", ray.error_info()[0][b"message"]) + self.assertIn(b"Function to run failed.", ray.error_info()[1][b"message"]) ray.worker.cleanup() diff --git a/test/multi_node_test.py b/test/multi_node_test.py index e486f681d..1dabff337 100644 --- a/test/multi_node_test.py +++ b/test/multi_node_test.py @@ -16,16 +16,22 @@ stop_ray_script = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../s class MultiNodeTest(unittest.TestCase): - def testErrorIsolation(self): + def setUp(self): # Start the Ray processes on this machine. out = subprocess.check_output([start_ray_script, "--head"]).decode("ascii") # Get the redis address from the output. redis_substring_prefix = "redis_address=\"" redis_address_location = out.find(redis_substring_prefix) + len(redis_substring_prefix) redis_address = out[redis_address_location:] - redis_address = redis_address.split("\"")[0] + self.redis_address = redis_address.split("\"")[0] + + def tearDown(self): + # Kill the Ray cluster. + subprocess.Popen([stop_ray_script]).wait() + + def testErrorIsolation(self): # Connect a driver to the Ray cluster. - ray.init(redis_address=redis_address, driver_mode=ray.SILENT_MODE) + ray.init(redis_address=self.redis_address, driver_mode=ray.SILENT_MODE) # There shouldn't be any errors yet. self.assertEqual(len(ray.error_info()), 0) @@ -79,7 +85,7 @@ assert len(ray.error_info()) == 1 assert "{}" in ray.error_info()[0][b"message"].decode("ascii") print("success") -""".format(redis_address, error_string2, error_string2) +""".format(self.redis_address, error_string2, error_string2) # Save the driver script as a file so we can call it using subprocess. with tempfile.NamedTemporaryFile() as f: @@ -95,7 +101,53 @@ print("success") self.assertIn(error_string1, ray.error_info()[0][b"message"].decode("ascii")) ray.worker.cleanup() - subprocess.Popen([stop_ray_script]).wait() + + def testRemoteFunctionIsolation(self): + # This test will run multiple remote functions with the same names in two + # different drivers. + # Connect a driver to the Ray cluster. + ray.init(redis_address=self.redis_address, driver_mode=ray.SILENT_MODE) + + # Start another driver and make sure that it can define and call its own + # commands with the same names. + driver_script = """ +import ray +import time +ray.init(redis_address="{}") +@ray.remote +def f(): + return 3 +@ray.remote +def g(x, y): + return 4 +for _ in range(10000): + result = ray.get([f.remote(), g.remote(0, 0)]) + assert result == [3, 4] +print("success") +""".format(self.redis_address) + + # Save the driver script as a file so we can call it using subprocess. + with tempfile.NamedTemporaryFile() as f: + f.write(driver_script.encode("ascii")) + f.flush() + out = subprocess.check_output(["python", f.name]).decode("ascii") + + @ray.remote + def f(): + return 1 + + @ray.remote + def g(x): + return 2 + + for _ in range(10000): + result = ray.get([f.remote(), g.remote(0)]) + self.assertEqual(result, [1, 2]) + + # Make sure the other driver succeeded. + self.assertIn("success", out) + + ray.worker.cleanup() class StartRayScriptTest(unittest.TestCase): diff --git a/test/runtest.py b/test/runtest.py index e017b4bcc..7fec277f9 100644 --- a/test/runtest.py +++ b/test/runtest.py @@ -599,6 +599,62 @@ class APITest(unittest.TestCase): ray.worker.cleanup() + def testIdenticalFunctionNames(self): + # Define a bunch of remote functions and make sure that we don't + # accidentally call an older version. + ray.init(num_workers=2) + + num_remote_functions = 100 + num_calls = 200 + + @ray.remote + def f(): + return 1 + results1 = [f.remote() for _ in range(num_calls)] + @ray.remote + def f(): + return 2 + results2 = [f.remote() for _ in range(num_calls)] + @ray.remote + def f(): + return 3 + results3 = [f.remote() for _ in range(num_calls)] + @ray.remote + def f(): + return 4 + results4 = [f.remote() for _ in range(num_calls)] + @ray.remote + def f(): + return 5 + results5 = [f.remote() for _ in range(num_calls)] + + self.assertEqual(ray.get(results1), num_calls * [1]) + self.assertEqual(ray.get(results2), num_calls * [2]) + self.assertEqual(ray.get(results3), num_calls * [3]) + self.assertEqual(ray.get(results4), num_calls * [4]) + self.assertEqual(ray.get(results5), num_calls * [5]) + + @ray.remote + def g(): + return 1 + @ray.remote + def g(): + return 2 + @ray.remote + def g(): + return 3 + @ray.remote + def g(): + return 4 + @ray.remote + def g(): + return 5 + + result_values = ray.get([g.remote() for _ in range(num_calls)]) + self.assertEqual(result_values, num_calls * [5]) + + ray.worker.cleanup() + class PythonModeTest(unittest.TestCase): def testPythonMode(self):