diff --git a/python/ray/__init__.py b/python/ray/__init__.py index 8fd789798..cfee4b088 100644 --- a/python/ray/__init__.py +++ b/python/ray/__init__.py @@ -57,13 +57,14 @@ from ray.worker import global_state # noqa: E402 # We import ray.actor because some code is run in actor.py which initializes # some functions in the worker. import ray.actor # noqa: F401 +from ray.actor import method # noqa: E402 # Ray version string. TODO(rkn): This is also defined separately in setup.py. # Fix this. __version__ = "0.3.0" __all__ = ["error_info", "init", "connect", "disconnect", "get", "put", "wait", - "remote", "log_event", "log_span", "flush_log", "actor", + "remote", "log_event", "log_span", "flush_log", "actor", "method", "get_gpu_ids", "get_webui_url", "register_custom_serializer", "SCRIPT_MODE", "WORKER_MODE", "PYTHON_MODE", "SILENT_MODE", "global_state", "_config", "__version__"] diff --git a/python/ray/actor.py b/python/ray/actor.py index 17d8b3590..86a8ac621 100644 --- a/python/ray/actor.py +++ b/python/ray/actor.py @@ -210,15 +210,19 @@ def fetch_and_register_actor(actor_class_key, worker): actor_id_str = worker.actor_id (driver_id, class_id, class_name, module, pickled_class, checkpoint_interval, - actor_method_names) = worker.redis_client.hmget( + actor_method_names, + actor_method_num_return_vals) = worker.redis_client.hmget( actor_class_key, ["driver_id", "class_id", "class_name", "module", "class", "checkpoint_interval", - "actor_method_names"]) + "actor_method_names", + "actor_method_num_return_vals"]) actor_name = class_name.decode("ascii") module = module.decode("ascii") checkpoint_interval = int(checkpoint_interval) actor_method_names = json.loads(actor_method_names.decode("ascii")) + actor_method_num_return_vals = json.loads( + actor_method_num_return_vals.decode("ascii")) # Create a temporary actor with some temporary methods so that if the actor # fails to be unpickled, the temporary actor can be used (just to produce @@ -233,7 +237,7 @@ def fetch_and_register_actor(actor_class_key, worker): "cannot execute this method".format(actor_name)) # Register the actor method signatures. register_actor_signatures(worker, driver_id, class_name, - actor_method_names) + actor_method_names, actor_method_num_return_vals) # Register the actor method executors. for actor_method_name in actor_method_names: function_id = compute_actor_method_function_id(class_name, @@ -287,7 +291,8 @@ def fetch_and_register_actor(actor_class_key, worker): def register_actor_signatures(worker, driver_id, class_name, - actor_method_names): + actor_method_names, + actor_method_num_return_vals): """Register an actor's method signatures in the worker. Args: @@ -295,16 +300,20 @@ def register_actor_signatures(worker, driver_id, class_name, driver_id: The ID of the driver that this actor is associated with. actor_id: The ID of the actor. actor_method_names: The names of the methods to register. + actor_method_num_return_vals: A list of the number of return values for + each of the actor's methods. """ - for actor_method_name in actor_method_names: + assert len(actor_method_names) == len(actor_method_num_return_vals) + for actor_method_name, num_return_vals in zip( + actor_method_names, actor_method_num_return_vals): # TODO(rkn): When we create a second actor, we are probably overwriting # the values from the first actor here. This may or may not be a # problem. function_id = compute_actor_method_function_id(class_name, actor_method_name).id() - # For now, all actor methods have 1 return value. worker.function_properties[driver_id][function_id] = ( - FunctionProperties(num_return_vals=2, + # The extra return value is an actor dummy object. + FunctionProperties(num_return_vals=num_return_vals + 1, resources={"CPU": 1}, max_calls=0)) @@ -329,6 +338,7 @@ def publish_actor_class_to_key(key, actor_class_info, worker): def export_actor_class(class_id, Class, actor_method_names, + actor_method_num_return_vals, checkpoint_interval, worker): key = b"ActorClass:" + class_id actor_class_info = { @@ -336,7 +346,9 @@ def export_actor_class(class_id, Class, actor_method_names, "module": Class.__module__, "class": pickle.dumps(Class), "checkpoint_interval": checkpoint_interval, - "actor_method_names": json.dumps(list(actor_method_names))} + "actor_method_names": json.dumps(list(actor_method_names)), + "actor_method_num_return_vals": json.dumps( + actor_method_num_return_vals)} if worker.mode is None: # This means that 'ray.init()' has not been called yet and so we must @@ -356,8 +368,8 @@ def export_actor_class(class_id, Class, actor_method_names, # https://github.com/ray-project/ray/issues/1146. -def export_actor(actor_id, class_id, class_name, actor_method_names, resources, - worker): +def export_actor(actor_id, class_id, class_name, actor_method_names, + actor_method_num_return_vals, resources, worker): """Export an actor to redis. Args: @@ -365,6 +377,8 @@ def export_actor(actor_id, class_id, class_name, actor_method_names, resources, class_id (str): A random ID for the actor class. class_name (str): The actor class name. actor_method_names (list): A list of the names of this actor's methods. + actor_method_num_return_vals: A list of the number of return values for + each of the actor's methods. resources: A dictionary mapping resource name to the quantity of that resource required by the actor. """ @@ -375,7 +389,7 @@ def export_actor(actor_id, class_id, class_name, actor_method_names, resources, driver_id = worker.task_driver_id.id() register_actor_signatures(worker, driver_id, class_name, - actor_method_names) + actor_method_names, actor_method_num_return_vals) # Select a local scheduler for the actor. key = b"Actor:" + actor_id.id() @@ -403,6 +417,19 @@ def export_actor(actor_id, class_id, class_name, actor_method_names, resources, worker.redis_client) +def method(*args, **kwargs): + assert len(args) == 0 + assert len(kwargs) == 1 + assert "num_return_vals" in kwargs + num_return_vals = kwargs["num_return_vals"] + + def annotate_method(method): + method.__ray_num_return_vals__ = num_return_vals + return method + + return annotate_method + + # Create objects to wrap method invocations. This is done so that we can # invoke methods with actor.method.remote() instead of actor.method(). class ActorMethod(object): @@ -441,13 +468,14 @@ class ActorHandleWrapper(object): can tell that an argument is an ActorHandle. """ def __init__(self, actor_id, actor_handle_id, actor_cursor, actor_counter, - actor_method_names, method_signatures, checkpoint_interval, - class_name): + actor_method_names, actor_method_num_return_vals, + method_signatures, checkpoint_interval, class_name): self.actor_id = actor_id self.actor_handle_id = actor_handle_id self.actor_cursor = actor_cursor self.actor_counter = actor_counter self.actor_method_names = actor_method_names + self.actor_method_num_return_vals = actor_method_num_return_vals # TODO(swang): Fetch this information from Redis so that we don't have # to fall back to pickle. self.method_signatures = method_signatures @@ -474,6 +502,7 @@ def wrap_actor_handle(actor_handle): actor_handle._ray_actor_cursor, 0, # Reset the actor counter. actor_handle._ray_actor_method_names, + actor_handle._ray_actor_method_num_return_vals, actor_handle._ray_method_signatures, actor_handle._ray_checkpoint_interval, actor_handle._ray_class_name) @@ -493,7 +522,8 @@ def unwrap_actor_handle(worker, wrapper): """ driver_id = worker.task_driver_id.id() register_actor_signatures(worker, driver_id, wrapper.class_name, - wrapper.actor_method_names) + wrapper.actor_method_names, + wrapper.actor_method_num_return_vals) actor_handle_class = make_actor_handle_class(wrapper.class_name) actor_object = actor_handle_class.__new__(actor_handle_class) @@ -503,6 +533,7 @@ def unwrap_actor_handle(worker, wrapper): wrapper.actor_cursor, wrapper.actor_counter, wrapper.actor_method_names, + wrapper.actor_method_num_return_vals, wrapper.method_signatures, wrapper.checkpoint_interval) return actor_object @@ -530,13 +561,16 @@ def make_actor_handle_class(class_name): "called on the original Class.") def _manual_init(self, actor_id, actor_handle_id, actor_cursor, - actor_counter, actor_method_names, method_signatures, + actor_counter, actor_method_names, + actor_method_num_return_vals, method_signatures, checkpoint_interval): self._ray_actor_id = actor_id self._ray_actor_handle_id = actor_handle_id self._ray_actor_cursor = actor_cursor self._ray_actor_counter = actor_counter self._ray_actor_method_names = actor_method_names + self._ray_actor_method_num_return_vals = ( + actor_method_num_return_vals) self._ray_method_signatures = method_signatures self._ray_checkpoint_interval = checkpoint_interval self._ray_class_name = class_name @@ -702,6 +736,13 @@ def actor_handle_from_class(Class, class_id, resources, checkpoint_interval): actor_method_names = [method_name for method_name, _ in actor_methods] + actor_method_num_return_vals = [] + for _, method in actor_methods: + if hasattr(method, "__ray_num_return_vals__"): + actor_method_num_return_vals.append( + method.__ray_num_return_vals__) + else: + actor_method_num_return_vals.append(1) # Do not export the actor class or the actor if run in PYTHON_MODE # Instead, instantiate the actor locally and add it to # global_worker's dictionary @@ -712,18 +753,21 @@ def actor_handle_from_class(Class, class_id, resources, checkpoint_interval): # Export the actor. if not exported: export_actor_class(class_id, Class, actor_method_names, + actor_method_num_return_vals, checkpoint_interval, ray.worker.global_worker) exported.append(0) export_actor(actor_id, class_id, class_name, - actor_method_names, resources, - ray.worker.global_worker) + actor_method_names, actor_method_num_return_vals, + resources, ray.worker.global_worker) # Instantiate the actor handle. actor_object = cls.__new__(cls) actor_object._manual_init(actor_id, actor_handle_id, actor_cursor, actor_counter, actor_method_names, - method_signatures, checkpoint_interval) + actor_method_num_return_vals, + method_signatures, + checkpoint_interval) # Call __init__ as a remote function. if "__init__" in actor_object._ray_actor_method_names: diff --git a/test/actor_test.py b/test/actor_test.py index 930370ae8..5d1d30338 100644 --- a/test/actor_test.py +++ b/test/actor_test.py @@ -264,6 +264,40 @@ class ActorAPI(unittest.TestCase): self.assertEqual(actor_class_info[b"class_name"], b"Foo") self.assertEqual(actor_class_info[b"module"], b"__main__") + def testMultipleReturnValues(self): + ray.init(num_workers=0) + + @ray.remote + class Foo(object): + def method0(self): + return 1 + + @ray.method(num_return_vals=1) + def method1(self): + return 1 + + @ray.method(num_return_vals=2) + def method2(self): + return 1, 2 + + @ray.method(num_return_vals=3) + def method3(self): + return 1, 2, 3 + + f = Foo.remote() + + id0 = f.method0.remote() + self.assertEqual(ray.get(id0), 1) + + id1 = f.method1.remote() + self.assertEqual(ray.get(id1), 1) + + id2a, id2b = f.method2.remote() + self.assertEqual(ray.get([id2a, id2b]), [1, 2]) + + id3a, id3b, id3c = f.method3.remote() + self.assertEqual(ray.get([id3a, id3b, id3c]), [1, 2, 3]) + class ActorMethods(unittest.TestCase):