Fix getting runtime context dict in driver (#13417)

This commit is contained in:
Edward Oakes
2021-01-14 14:41:53 -06:00
committed by GitHub
parent 411e37ce3f
commit 7ba87b8abe
2 changed files with 40 additions and 10 deletions
+10 -10
View File
@@ -12,10 +12,10 @@ class RuntimeContext(object):
self.worker = worker
def get(self):
"""Get a dictionary of the current_context.
"""Get a dictionary of the current context.
For fields that are not available (for example actor id inside a task)
won't be included in the field.
Fields that are not available (e.g., actor ID inside a task) won't be
included in the field.
Returns:
dict: Dictionary of the current context.
@@ -23,14 +23,14 @@ class RuntimeContext(object):
context = {
"job_id": self.job_id,
"node_id": self.node_id,
"task_id": self.task_id,
"actor_id": self.actor_id
}
# Remove fields that are None.
return {
key: value
for key, value in context.items() if value is not None
}
if self.worker.mode == ray.worker.WORKER_MODE:
if self.task_id is not None:
context["task_id"] = self.task_id
if self.actor_id is not None:
context["actor_id"] = self.actor_id
return context
@property
def job_id(self):
+30
View File
@@ -64,6 +64,36 @@ def test_was_current_actor_reconstructed(shutdown_only):
ray.get(f.remote())
def test_get_context_dict(ray_start_regular):
context_dict = ray.get_runtime_context().get()
assert context_dict["node_id"] is not None
assert context_dict["job_id"] is not None
assert "actor_id" not in context_dict
assert "task_id" not in context_dict
@ray.remote
class Actor:
def check(self, node_id, job_id):
context_dict = ray.get_runtime_context().get()
assert context_dict["node_id"] == node_id
assert context_dict["job_id"] == job_id
assert context_dict["actor_id"] is not None
assert context_dict["task_id"] is not None
a = Actor.remote()
ray.get(a.check.remote(context_dict["node_id"], context_dict["job_id"]))
@ray.remote
def task(node_id, job_id):
context_dict = ray.get_runtime_context().get()
assert context_dict["node_id"] == node_id
assert context_dict["job_id"] == job_id
assert context_dict["task_id"] is not None
assert "actor_id" not in context_dict
ray.get(task.remote(context_dict["node_id"], context_dict["job_id"]))
if __name__ == "__main__":
import pytest
sys.exit(pytest.main(["-v", __file__]))