diff --git a/python/ray/actor.py b/python/ray/actor.py
index 01134b7ed..c054a642c 100644
--- a/python/ray/actor.py
+++ b/python/ray/actor.py
@@ -39,6 +39,33 @@ def get_actor_method_function_id(attr):
     return ray.local_scheduler.ObjectID(function_id)
 
 
+def get_actor_checkpoint(actor_id, worker):
+    """Get the most recent checkpoint associated with a given actor ID.
+
+    Args:
+        actor_id: The actor ID of the actor to get the checkpoint for.
+        worker: The worker to use to get the checkpoint.
+
+    Returns:
+        If a checkpoint exists, this returns a tuple of the checkpoint index
+            and the checkpoint. Otherwise it returns (-1, None). The checkpoint
+            index is the actor counter of the last task that was executed on
+            the actor before the checkpoint was made.
+    """
+    # Get all of the keys associated with checkpoints for this actor.
+    actor_key = b"Actor:" + actor_id
+    checkpoint_indices = [int(key[len(b"checkpoint_"):])
+                          for key in worker.redis_client.hkeys(actor_key)
+                          if key.startswith(b"checkpoint_")]
+    if len(checkpoint_indices) == 0:
+        return -1, None
+    most_recent_checkpoint_index = max(checkpoint_indices)
+    # Get the most recent checkpoint.
+    checkpoint = worker.redis_client.hget(
+        actor_key, "checkpoint_{}".format(most_recent_checkpoint_index))
+    return most_recent_checkpoint_index, checkpoint
+
+
 def fetch_and_register_actor(actor_class_key, worker):
     """Import an actor.
 
@@ -48,12 +75,15 @@ def fetch_and_register_actor(actor_class_key, worker):
     """
     actor_id_str = worker.actor_id
     (driver_id, class_id, class_name,
-     module, pickled_class, actor_method_names) = worker.redis_client.hmget(
+     module, pickled_class, checkpoint_interval,
+     actor_method_names) = worker.redis_client.hmget(
          actor_class_key, ["driver_id", "class_id", "class_name", "module",
-                           "class", "actor_method_names"])
+                           "class", "checkpoint_interval",
+                           "actor_method_names"])
 
     actor_name = class_name.decode("ascii")
     module = module.decode("ascii")
+    checkpoint_interval = int(checkpoint_interval)
     actor_method_names = json.loads(actor_method_names.decode("ascii"))
 
     # Create a temporary actor with some temporary methods so that if the actor
@@ -62,6 +92,7 @@ def fetch_and_register_actor(actor_class_key, worker):
     class TemporaryActor(object):
         pass
     worker.actors[actor_id_str] = TemporaryActor()
+    worker.actor_checkpoint_interval = checkpoint_interval
 
     def temporary_actor_method(*xs):
         raise Exception("The actor with name {} failed to be imported, and so "
@@ -79,6 +110,7 @@ def fetch_and_register_actor(actor_class_key, worker):
 
     try:
         unpickled_class = pickle.loads(pickled_class)
+        worker.actor_class = unpickled_class
     except Exception:
         # If an exception was thrown when the actor was imported, we record the
         # traceback and notify the scheduler of the failure.
@@ -100,7 +132,8 @@ def fetch_and_register_actor(actor_class_key, worker):
             # for the actor.
 
 
-def export_actor_class(class_id, Class, actor_method_names, worker):
+def export_actor_class(class_id, Class, actor_method_names,
+                       checkpoint_interval, worker):
     if worker.mode is None:
         raise NotImplemented("TODO(pcm): Cache actors")
     key = b"ActorClass:" + class_id
@@ -108,6 +141,7 @@ def export_actor_class(class_id, Class, actor_method_names, worker):
          "class_name": Class.__name__,
          "module": Class.__module__,
          "class": pickle.dumps(Class),
+         "checkpoint_interval": checkpoint_interval,
          "actor_method_names": json.dumps(list(actor_method_names))}
     worker.redis_client.hmset(key, d)
     worker.redis_client.rpush("Exports", key)
@@ -173,6 +207,18 @@ def reconstruct_actor_state(actor_id, worker):
         actor_id: The ID of the actor being reconstructed.
         worker: The worker object that is running the actor.
     """
+    # Get the most recent actor checkpoint.
+    checkpoint_index, checkpoint = get_actor_checkpoint(actor_id, worker)
+    if checkpoint is not None:
+        print("Loading actor state from checkpoint {}"
+              .format(checkpoint_index))
+        # Wait for the actor to have been defined.
+        worker._wait_for_actor()
+        # TODO(rkn): Restoring from the checkpoint may fail, so this should be
+        # in a try-except block and we should give a good error message.
+        worker.actors[actor_id] = (
+            worker.actor_class.__ray_restore_from_checkpoint__(checkpoint))
+
     # TODO(rkn): This call is expensive. It'd be nice to find a way to get only
     # the tasks that are relevant to this actor.
     tasks = ray.global_state.task_table()
@@ -238,10 +284,18 @@ def reconstruct_actor_state(actor_id, worker):
         # local scheduler does bookkeeping about this actor's resource
         # utilization and things like that. It's also important for updating
         # some state on the worker.
-        worker.submit_task(
-            hex_to_object_id(task_spec_info["FunctionID"]),
-            task_spec_info["Args"],
-            actor_id=hex_to_object_id(task_spec_info["ActorID"]))
+        if task_spec_info["ActorCounter"] > checkpoint_index:
+            worker.submit_task(
+                hex_to_object_id(task_spec_info["FunctionID"]),
+                task_spec_info["Args"],
+                actor_id=hex_to_object_id(task_spec_info["ActorID"]))
+        else:
+            # Pass in a dummy task with no arguments to avoid having to
+            # unnecessarily reconstruct past arguments.
+            worker.submit_task(
+                hex_to_object_id(task_spec_info["FunctionID"]),
+                [],
+                actor_id=hex_to_object_id(task_spec_info["ActorID"]))
 
         # Clear the extra state that we set.
         del worker.task_driver_id
@@ -250,18 +304,22 @@ def reconstruct_actor_state(actor_id, worker):
 
         # Get the task from the local scheduler.
         retrieved_task = worker._get_next_task_from_local_scheduler()
 
-        # Assert that the retrieved task is the same as the constructed task.
-        assert (ray.local_scheduler.task_to_string(task_spec) ==
-                ray.local_scheduler.task_to_string(retrieved_task))
-        # Wait for the task to be ready and execute the task.
-        worker._wait_for_and_process_task(retrieved_task)
+        # If the task happened before the most recent checkpoint, ignore it.
+        # Otherwise, execute it.
+        if retrieved_task.actor_counter() > checkpoint_index:
+            # Assert that the retrieved task is the same as the constructed
+            # task.
+            assert (ray.local_scheduler.task_to_string(task_spec) ==
+                    ray.local_scheduler.task_to_string(retrieved_task))
+            # Wait for the task to be ready and then execute it.
+            worker._wait_for_and_process_task(retrieved_task)
 
     # Enter the main loop to receive and process tasks.
     worker.main_loop()
 
 
-def make_actor(cls, num_cpus, num_gpus):
+def make_actor(cls, num_cpus, num_gpus, checkpoint_interval):
     # Modify the class to have an additional method that will be used for
     # terminating the worker.
class Class(cls): @@ -278,6 +336,26 @@ def make_actor(cls, num_cpus, num_gpus): import os os._exit(0) + def __ray_save_checkpoint__(self): + if hasattr(self, "__ray_save__"): + object_to_serialize = self.__ray_save__() + else: + object_to_serialize = self + return pickle.dumps(object_to_serialize) + + @classmethod + def __ray_restore_from_checkpoint__(cls, pickled_checkpoint): + checkpoint = pickle.loads(pickled_checkpoint) + if hasattr(cls, "__ray_restore__"): + actor_object = cls.__new__(cls) + actor_object.__ray_restore__(checkpoint) + else: + # TODO(rkn): It's possible that this will cause problems. When + # you unpickle the same object twice, the two objects will not + # have the same class. + actor_object = pickle.loads(checkpoint) + return actor_object + Class.__module__ = cls.__module__ Class.__name__ = cls.__name__ @@ -363,6 +441,7 @@ def make_actor(cls, num_cpus, num_gpus): if len(exported) == 0: export_actor_class(class_id, Class, self._ray_actor_methods.keys(), + checkpoint_interval, ray.worker.global_worker) exported.append(0) # Export the actor. diff --git a/python/ray/worker.py b/python/ray/worker.py index c8811ad82..987e85a1c 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -581,6 +581,13 @@ class Worker(object): "data": data}) self.redis_client.rpush("ErrorKeys", error_key) + def _wait_for_actor(self): + """Wait until the actor has been imported.""" + assert self.actor_id != NIL_ACTOR_ID + # Wait until the actor has been imported. + while self.actor_id not in self.actors: + time.sleep(0.001) + def _wait_for_function(self, function_id, driver_id, timeout=10): """Wait until the function to be executed is present on this worker. @@ -764,6 +771,35 @@ class Worker(object): data={"function_id": function_id.id(), "function_name": function_name}) + def _checkpoint_actor_state(self, actor_counter): + """Checkpoint the actor state. + + This currently saves the checkpoint to Redis, but the checkpoint really + needs to go somewhere else. 
+
+        Args:
+            actor_counter: The index of the most recent task that ran on this
+                actor.
+        """
+        print("Saving actor checkpoint. actor_counter = {}."
+              .format(actor_counter))
+        actor_key = b"Actor:" + self.actor_id
+        checkpoint = self.actors[self.actor_id].__ray_save_checkpoint__()
+        # Save the checkpoint in Redis. TODO(rkn): Checkpoints should not
+        # be stored in Redis. Fix this.
+        self.redis_client.hset(
+            actor_key,
+            "checkpoint_{}".format(actor_counter),
+            checkpoint)
+        # Remove the previous checkpoints if there is one.
+        checkpoint_indices = [int(key[len(b"checkpoint_"):])
+                              for key in self.redis_client.hkeys(actor_key)
+                              if key.startswith(b"checkpoint_")]
+        for index in checkpoint_indices:
+            if index < actor_counter:
+                self.redis_client.hdel(actor_key,
+                                       "checkpoint_{}".format(index))
+
     def _wait_for_and_process_task(self, task):
         """Wait for a task to be ready and process the task.
 
@@ -811,6 +847,13 @@ class Worker(object):
             ray.worker.global_worker.local_scheduler_client.disconnect()
             os._exit(0)
 
+        # Checkpoint the actor state if it is the right time to do so.
+        actor_counter = task.actor_counter()
+        if (self.actor_id != NIL_ACTOR_ID and
+                self.actor_checkpoint_interval != -1 and
+                actor_counter % self.actor_checkpoint_interval == 0):
+            self._checkpoint_actor_state(actor_counter)
+
     def _get_next_task_from_local_scheduler(self):
         """Get the next task from the local scheduler.
 
@@ -2118,11 +2161,13 @@ def remote(*args, **kwargs):
             the driver.
         max_calls (int): The maximum number of tasks of this kind that can be
             run on a worker before the worker needs to be restarted.
+        checkpoint_interval (int): The number of tasks to run between
+            checkpoints of the actor state.
     """
     worker = global_worker
 
     def make_remote_decorator(num_return_vals, num_cpus, num_gpus,
-                              max_calls, func_id=None):
+                              max_calls, checkpoint_interval, func_id=None):
         def remote_decorator(func_or_class):
             if inspect.isfunction(func_or_class):
                 function_properties = FunctionProperties(
@@ -2133,7 +2178,8 @@ def remote(*args, **kwargs):
                 return remote_function_decorator(func_or_class,
                                                  function_properties)
             if inspect.isclass(func_or_class):
-                return worker.make_actor(func_or_class, num_cpus, num_gpus)
+                return worker.make_actor(func_or_class, num_cpus, num_gpus,
+                                         checkpoint_interval)
             raise Exception("The @ray.remote decorator must be applied to "
                             "either a function or to a class.")
 
@@ -2203,17 +2249,21 @@ def remote(*args, **kwargs):
     num_cpus = kwargs["num_cpus"] if "num_cpus" in kwargs else 1
     num_gpus = kwargs["num_gpus"] if "num_gpus" in kwargs else 0
     max_calls = kwargs["max_calls"] if "max_calls" in kwargs else 0
+    checkpoint_interval = (kwargs["checkpoint_interval"]
+                           if "checkpoint_interval" in kwargs else -1)
 
     if _mode() == WORKER_MODE:
         if "function_id" in kwargs:
             function_id = kwargs["function_id"]
             return make_remote_decorator(num_return_vals, num_cpus, num_gpus,
-                                         max_calls, function_id)
+                                         max_calls, checkpoint_interval,
+                                         function_id)
 
     if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
         # This is the case where the decorator is just @ray.remote.
-        return make_remote_decorator(num_return_vals, num_cpus,
-                                     num_gpus, max_calls)(args[0])
+        return make_remote_decorator(
+            num_return_vals, num_cpus,
+            num_gpus, max_calls, checkpoint_interval)(args[0])
     else:
         # This is the case where the decorator is something like
         # @ray.remote(num_return_vals=2).
@@ -2223,13 +2273,16 @@ def remote(*args, **kwargs):
                         "the arguments 'num_return_vals', 'num_cpus', "
                         "'num_gpus', or 'max_calls', like "
                         "'@ray.remote(num_return_vals=2)'.")
-        assert len(args) == 0 and ("num_return_vals" in kwargs or
-                                   "num_cpus" in kwargs or
-                                   "num_gpus" in kwargs or
-                                   "max_calls" in kwargs), error_string
+        assert (len(args) == 0 and
+                ("num_return_vals" in kwargs or
+                 "num_cpus" in kwargs or
+                 "num_gpus" in kwargs or
+                 "max_calls" in kwargs or
+                 "checkpoint_interval" in kwargs)), error_string
         for key in kwargs:
             assert key in ["num_return_vals", "num_cpus",
-                           "num_gpus", "max_calls"], error_string
+                           "num_gpus", "max_calls",
+                           "checkpoint_interval"], error_string
         assert "function_id" not in kwargs
         return make_remote_decorator(num_return_vals, num_cpus, num_gpus,
-                                     max_calls)
+                                     max_calls, checkpoint_interval)
diff --git a/test/actor_test.py b/test/actor_test.py
index 23dba657e..f9075a0e4 100644
--- a/test/actor_test.py
+++ b/test/actor_test.py
@@ -1214,6 +1214,79 @@ class ActorReconstruction(unittest.TestCase):
 
         ray.worker.cleanup()
 
+    def testCheckpointing(self):
+        ray.worker._init(start_ray_local=True, num_local_schedulers=2,
+                         num_workers=0, redirect_output=True)
+
+        @ray.remote(checkpoint_interval=5)
+        class Counter(object):
+            def __init__(self):
+                self.x = 0
+                # The number of times that inc has been called. We won't bother
+                # restoring this in the checkpoint
+                self.num_inc_calls = 0
+
+            def local_plasma(self):
+                return ray.worker.global_worker.plasma_client.store_socket_name
+
+            def inc(self, *xs):
+                self.num_inc_calls += 1
+                self.x += 1
+                return self.x
+
+            def get_num_inc_calls(self):
+                return self.num_inc_calls
+
+            def test_restore(self):
+                # This method will only work if __ray_restore__ has been run.
+                return self.y
+
+            def __ray_save__(self):
+                return self.x, -1
+
+            def __ray_restore__(self, checkpoint):
+                self.x, val = checkpoint
+                self.num_inc_calls = 0
+                # Test that __ray_save__ has been run.
+                assert val == -1
+                self.y = self.x
+
+        local_plasma = ray.worker.global_worker.plasma_client.store_socket_name
+
+        # Create an actor that is not on the local scheduler.
+        actor = Counter.remote()
+        while ray.get(actor.local_plasma.remote()) == local_plasma:
+            actor = Counter.remote()
+
+        args = [ray.put(0) for _ in range(100)]
+        ids = [actor.inc.remote(*args[i:]) for i in range(100)]
+
+        # Wait for the last task to finish running.
+        ray.get(ids[-1])
+
+        # Kill the second local scheduler.
+        process = ray.services.all_processes[
+            ray.services.PROCESS_TYPE_LOCAL_SCHEDULER][1]
+        process.kill()
+        process.wait()
+        # Kill the corresponding plasma store to get rid of the cached objects.
+        process = ray.services.all_processes[
+            ray.services.PROCESS_TYPE_PLASMA_STORE][1]
+        process.kill()
+        process.wait()
+
+        # Get all of the results. TODO(rkn): This currently doesn't work.
+        # results = ray.get(ids)
+        # self.assertEqual(results, list(range(1, 1 + len(results))))
+
+        self.assertEqual(ray.get(actor.test_restore.remote()), 99)
+
+        # The inc method should only have executed once on the new actor (for
+        # the one method call since the most recent checkpoint).
+        self.assertEqual(ray.get(actor.get_num_inc_calls.remote()), 1)
+
+        ray.worker.cleanup()
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)