diff --git a/python/ray/runtime_context.py b/python/ray/runtime_context.py index 318824ed3..fa922cfa0 100644 --- a/python/ray/runtime_context.py +++ b/python/ray/runtime_context.py @@ -12,10 +12,10 @@ class RuntimeContext(object): self.worker = worker def get(self): - """Get a dictionary of the current_context. + """Get a dictionary of the current context. - For fields that are not available (for example actor id inside a task) - won't be included in the field. + Fields that are not available (e.g., actor ID inside a task) won't be + included in the field. Returns: dict: Dictionary of the current context. @@ -23,14 +23,14 @@ class RuntimeContext(object): context = { "job_id": self.job_id, "node_id": self.node_id, - "task_id": self.task_id, - "actor_id": self.actor_id - } - # Remove fields that are None. - return { - key: value - for key, value in context.items() if value is not None } + if self.worker.mode == ray.worker.WORKER_MODE: + if self.task_id is not None: + context["task_id"] = self.task_id + if self.actor_id is not None: + context["actor_id"] = self.actor_id + + return context @property def job_id(self): diff --git a/python/ray/tests/test_runtime_context.py b/python/ray/tests/test_runtime_context.py index aa56df592..5f24f538f 100644 --- a/python/ray/tests/test_runtime_context.py +++ b/python/ray/tests/test_runtime_context.py @@ -64,6 +64,36 @@ def test_was_current_actor_reconstructed(shutdown_only): ray.get(f.remote()) +def test_get_context_dict(ray_start_regular): + context_dict = ray.get_runtime_context().get() + assert context_dict["node_id"] is not None + assert context_dict["job_id"] is not None + assert "actor_id" not in context_dict + assert "task_id" not in context_dict + + @ray.remote + class Actor: + def check(self, node_id, job_id): + context_dict = ray.get_runtime_context().get() + assert context_dict["node_id"] == node_id + assert context_dict["job_id"] == job_id + assert context_dict["actor_id"] is not None + assert context_dict["task_id"] is not None + + a = Actor.remote() + ray.get(a.check.remote(context_dict["node_id"], context_dict["job_id"])) + + @ray.remote + def task(node_id, job_id): + context_dict = ray.get_runtime_context().get() + assert context_dict["node_id"] == node_id + assert context_dict["job_id"] == job_id + assert context_dict["task_id"] is not None + assert "actor_id" not in context_dict + + ray.get(task.remote(context_dict["node_id"], context_dict["job_id"])) + + if __name__ == "__main__": import pytest sys.exit(pytest.main(["-v", __file__]))