From 794a0932496e87de74823467f7c8044e9c53c889 Mon Sep 17 00:00:00 2001 From: Wang Qing Date: Tue, 19 Feb 2019 15:57:30 +0800 Subject: [PATCH] Add runtime_context to get some runtime fields in worker (#4065) --- python/ray/__init__.py | 2 ++ python/ray/runtime_context.py | 34 ++++++++++++++++++++++++++++++++++ python/ray/worker.py | 1 + test/runtest.py | 13 +++++++------ 4 files changed, 44 insertions(+), 6 deletions(-) create mode 100644 python/ray/runtime_context.py diff --git a/python/ray/__init__.py b/python/ray/__init__.py index a38797c68..3d243c322 100644 --- a/python/ray/__init__.py +++ b/python/ray/__init__.py @@ -92,6 +92,7 @@ import ray.internal # noqa: E402 # some functions in the worker. import ray.actor # noqa: F401 from ray.actor import method # noqa: E402 +from ray.runtime_context import _get_runtime_context # noqa: E402 # Ray version string. __version__ = "0.7.0.dev0" @@ -103,6 +104,7 @@ __all__ = [ "WORKER_MODE", "__version__", "_config", + "_get_runtime_context", "actor", "connect", "disconnect", diff --git a/python/ray/runtime_context.py b/python/ray/runtime_context.py new file mode 100644 index 000000000..cb3b004cb --- /dev/null +++ b/python/ray/runtime_context.py @@ -0,0 +1,34 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import ray.worker + + +class RuntimeContext(object): + """A class used for getting runtime context.""" + + def __init__(self, worker=None): + self.worker = worker + + @property + def current_driver_id(self): + """Get current driver ID for this worker or driver. + + Returns: + If called by a driver, this returns the driver ID. If called in + a task, return the driver ID of the associated driver. + """ + assert self.worker is not None + return self.worker.task_driver_id + + +_runtime_context = None + + +def _get_runtime_context(): + global _runtime_context + if _runtime_context is None: + _runtime_context = RuntimeContext(ray.worker.get_global_worker()) + + return _runtime_context diff --git a/python/ray/worker.py b/python/ray/worker.py index 556a4f765..3af7c7011 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -44,6 +44,7 @@ from ray import ( ) from ray import import_thread from ray import profiling + from ray.core.generated.ErrorType import ErrorType from ray.exceptions import ( RayActorError, diff --git a/test/runtest.py b/test/runtest.py index e1d281f17..0f8f0f10e 100644 --- a/test/runtest.py +++ b/test/runtest.py @@ -2590,16 +2590,17 @@ def test_workers(shutdown_only): def test_specific_driver_id(): dummy_driver_id = ray.DriverID(b"00112233445566778899") - ray.init(driver_id=dummy_driver_id) + ray.init(num_cpus=1, driver_id=dummy_driver_id) + # in driver + assert dummy_driver_id == ray._get_runtime_context().current_driver_id + + # in worker @ray.remote def f(): - return ray.worker.global_worker.task_driver_id.binary() + return ray._get_runtime_context().current_driver_id - assert dummy_driver_id.binary() == ray.worker.global_worker.worker_id - - task_driver_id = ray.get(f.remote()) - assert dummy_driver_id.binary() == task_driver_id + assert dummy_driver_id == ray.get(f.remote()) ray.shutdown()