From fe6daef85e1f3cd2aef12b39487124a6e2dcc2b9 Mon Sep 17 00:00:00 2001
From: Lixin Wei
Date: Thu, 27 Aug 2020 11:11:42 +0800
Subject: [PATCH] [Core]Add runtime context for python worker (#10309)

* add runtime context for python

* fixed

* code fixed

* test added

* lint

* lint
---
 python/ray/__init__.py                   |  3 +-
 python/ray/runtime_context.py            | 57 ++++++++++++++++++
 python/ray/state.py                      |  1 +
 python/ray/tests/test_runtime_context.py | 76 ++++++++++++++++++++++++
 4 files changed, 136 insertions(+), 1 deletion(-)
 create mode 100644 python/ray/runtime_context.py
 create mode 100644 python/ray/tests/test_runtime_context.py

diff --git a/python/ray/__init__.py b/python/ray/__init__.py
index 9409b5abf..6b0bdd1b2 100644
--- a/python/ray/__init__.py
+++ b/python/ray/__init__.py
@@ -91,6 +91,7 @@ import ray.projects  # noqa: E402
 import ray.actor  # noqa: F401
 from ray.actor import method  # noqa: E402
 from ray.cross_language import java_function, java_actor_class  # noqa: E402
+from ray.runtime_context import get_runtime_context  # noqa: E402
 from ray import util  # noqa: E402
 
 # Replaced with the current commit when building the wheels.
@@ -100,7 +101,7 @@ __version__ = "0.9.0.dev0"
 __all__ = [
     "__version__",
     "_config",
-    "_get_runtime_context",
+    "get_runtime_context",
     "actor",
     "actors",
     "available_resources",
diff --git a/python/ray/runtime_context.py b/python/ray/runtime_context.py
new file mode 100644
index 000000000..387d8a48c
--- /dev/null
+++ b/python/ray/runtime_context.py
@@ -0,0 +1,57 @@
+import ray.worker
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class RuntimeContext(object):
+    """A class used for getting runtime context."""
+
+    def __init__(self, worker):
+        assert worker is not None
+        self.worker = worker
+
+    @property
+    def current_job_id(self):
+        """Get current job ID for this worker or driver.
+
+        Returns:
+            If called by a driver, this returns the job ID. If called in
+            a task, return the job ID of the associated driver.
+        """
+        return self.worker.current_job_id
+
+    @property
+    def current_actor_id(self):
+        """Get the current actor ID in this worker.
+
+        Returns:
+            The current actor id in this worker.
+        """
+        # Only worker mode has an actor_id.
+        assert self.worker.mode == ray.worker.WORKER_MODE, (
+            "This method is only available when the process is a "
+            f"worker. Current mode: {self.worker.mode}")
+        return self.worker.actor_id
+
+    @property
+    def was_current_actor_reconstructed(self):
+        """Check whether this actor has been restarted.
+
+        Returns:
+            Whether this actor has ever been restarted.
+        """
+        # TODO: this method should not be called in a normal task.
+        actor_info = ray.state.actors(self.current_actor_id.hex())
+        return bool(actor_info) and actor_info["NumRestarts"] != 0
+
+
+_runtime_context = None
+
+
+def get_runtime_context():
+    """Return the process-wide RuntimeContext, creating it on first use."""
+    global _runtime_context
+    if _runtime_context is None:
+        _runtime_context = RuntimeContext(ray.worker.global_worker)
+    return _runtime_context
diff --git a/python/ray/state.py b/python/ray/state.py
index 6b3394fed..005ab1a92 100644
--- a/python/ray/state.py
+++ b/python/ray/state.py
@@ -251,6 +251,7 @@ class GlobalState:
                     actor_table_data.owner_address.raylet_id),
             },
             "State": actor_table_data.state,
+            "NumRestarts": actor_table_data.num_restarts,
             "Timestamp": actor_table_data.timestamp,
         }
         return actor_info
diff --git a/python/ray/tests/test_runtime_context.py b/python/ray/tests/test_runtime_context.py
new file mode 100644
index 000000000..b71516b1c
--- /dev/null
+++ b/python/ray/tests/test_runtime_context.py
@@ -0,0 +1,76 @@
+import ray
+import os
+import signal
+import time
+import sys
+
+
+def test_was_current_actor_reconstructed():
+    ray.init()
+
+    @ray.remote(max_restarts=10)
+    class A(object):
+        def __init__(self):
+            self._was_reconstructed = ray.get_runtime_context(
+            ).was_current_actor_reconstructed
+
+        def get_was_reconstructed(self):
+            return self._was_reconstructed
+
+        def update_was_reconstructed(self):
+            return ray.get_runtime_context().was_current_actor_reconstructed
+
+        def get_pid(self):
+            return os.getpid()
+
+        # The following methods implement the checkpointable interface.
+        def should_checkpoint(self, checkpoint_context):
+            return False
+
+        def save_checkpoint(self, actor_id, checkpoint_id):
+            pass
+
+        def load_checkpoint(self, actor_id, available_checkpoints):
+            pass
+
+        def checkpoint_expired(self, actor_id, checkpoint_id):
+            pass
+
+    a = A.remote()
+    # `was_reconstructed` should be False when it's called in actor.
+    assert ray.get(a.get_was_reconstructed.remote()) is False
+    # `was_reconstructed` should be False when it's called in a remote method
+    # and the actor never fails.
+    assert ray.get(a.update_was_reconstructed.remote()) is False
+
+    pid = ray.get(a.get_pid.remote())
+    os.kill(pid, signal.SIGKILL)
+    time.sleep(2)
+    # These 2 methods should return True because
+    # this actor failed and was restarted.
+    assert ray.get(a.get_was_reconstructed.remote()) is True
+    assert ray.get(a.update_was_reconstructed.remote()) is True
+
+    ray.shutdown()
+
+
+def test_runtime_context_interface():
+    ray.init()
+
+    @ray.remote(max_restarts=10)
+    class A(object):
+        def current_job_id(self):
+            return ray.get_runtime_context().current_job_id
+
+        def current_actor_id(self):
+            return ray.get_runtime_context().current_actor_id
+
+    a = A.remote()
+    assert ray.get(a.current_job_id.remote()) is not None
+    assert ray.get(a.current_actor_id.remote()) is not None
+    ray.shutdown()
+
+
+if __name__ == "__main__":
+    import pytest
+    sys.exit(pytest.main(["-v", __file__]))