Add API for querying global control state. (#431)

* Add API for querying global control state.

* Fix linting.

* Fix errors in Python 2.

* Fix bug in test.

* Fix bug in test.
This commit is contained in:
Robert Nishihara
2017-04-06 23:51:12 -07:00
committed by Philipp Moritz
parent 320109a5bd
commit 7af6f462fb
5 changed files with 388 additions and 2 deletions
+2 -1
View File
@@ -9,6 +9,7 @@ from ray.actor import actor
from ray.actor import get_gpu_ids
from ray.worker import EnvironmentVariable, env
from ray.worker import SCRIPT_MODE, WORKER_MODE, PYTHON_MODE, SILENT_MODE
from ray.worker import global_state
# Ray version string
__version__ = "0.01"
@@ -17,7 +18,7 @@ __all__ = ["register_class", "error_info", "init", "connect", "disconnect",
"get", "put", "wait", "remote", "log_event", "log_span",
"flush_log", "actor", "get_gpu_ids", "EnvironmentVariable", "env",
"SCRIPT_MODE", "WORKER_MODE", "PYTHON_MODE", "SILENT_MODE",
"__version__"]
"global_state", "__version__"]
import ctypes
# Windows only
+248
View File
@@ -2,6 +2,62 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import binascii
import pickle
import redis
import sys
import ray.local_scheduler
# Import flatbuffer bindings.
from ray.core.generated.TaskInfo import TaskInfo
from ray.core.generated.TaskReply import TaskReply
from ray.core.generated.ResultTableReply import ResultTableReply
# Prefixes of the Redis keys under which the control state is stored. These
# must be kept up-to-date with the definitions in ray_redis_module.cc.
#
# The query code in this module lists a table by matching the pattern
# `<PREFIX>*` and recovers the binary ID by slicing the prefix off the front
# of each key, so a prefix that drifts from the Redis module's definition
# makes the corresponding lookups return nothing.
# NOTE(review): OBJECT_SUBSCRIBE_PREFIX and OBJECT_CHANNEL_PREFIX are not used
# by the query code visible here -- confirm their meaning against
# ray_redis_module.cc.
DB_CLIENT_PREFIX = "CL:"
OBJECT_INFO_PREFIX = "OI:"
OBJECT_LOCATION_PREFIX = "OL:"
OBJECT_SUBSCRIBE_PREFIX = "OS:"
TASK_PREFIX = "TT:"
OBJECT_CHANNEL_PREFIX = "OC:"
# Human-readable names for the integer task states stored in the task table.
# The integer codes are distinct powers of two. This table must be kept
# up-to-date with the scheduling_state enum in task.h.
task_state_mapping = {
    1: "WAITING",
    2: "SCHEDULED",
    4: "QUEUED",
    8: "RUNNING",
    16: "DONE",
    32: "LOST",
    64: "RECONSTRUCTING",
}
def decode(byte_str):
    """Convert a byte string to the interpreter's native string type.

    On Python 3 the bytes are decoded as ASCII into a unicode string. On
    Python 2 the argument is handed back untouched.

    Args:
        byte_str: The byte string to convert.

    Returns:
        The ASCII-decoded string on Python 3, otherwise byte_str itself.
    """
    if sys.version_info < (3, 0):
        return byte_str
    return byte_str.decode("ascii")
def binary_to_object_id(binary_object_id):
    """Wrap a binary object ID in a ray.local_scheduler.ObjectID.

    Args:
        binary_object_id: A string of bytes holding the object ID.

    Returns:
        The ray.local_scheduler.ObjectID constructed from the given bytes.
    """
    object_id = ray.local_scheduler.ObjectID(binary_object_id)
    return object_id
def binary_to_hex(identifier):
    """Return the hexadecimal representation of a binary identifier.

    Args:
        identifier: A string of bytes.

    Returns:
        The hex encoding of identifier. On Python 3 the hexlified bytes are
        decoded so that a unicode string is returned; on Python 2 the
        hexlified byte string is returned as is.
    """
    hexlified = binascii.hexlify(identifier)
    if sys.version_info < (3, 0):
        return hexlified
    return hexlified.decode()
def hex_to_binary(hex_identifier):
    """Return the binary data encoded by a hexadecimal identifier.

    This is the inverse of binary_to_hex.

    Args:
        hex_identifier: A hex string, for example one produced by
            binary_to_hex.

    Returns:
        A string of bytes with the decoded binary data.
    """
    binary_identifier = binascii.unhexlify(hex_identifier)
    return binary_identifier
def get_local_schedulers(worker):
local_schedulers = []
@@ -12,3 +68,195 @@ def get_local_schedulers(worker):
if client_info[b"client_type"] == b"local_scheduler":
local_schedulers.append(client_info)
return local_schedulers
class GlobalState(object):
    """A class used to interface with the Ray control state.

    The object starts out disconnected. The public query methods
    (object_table, task_table and client_table) raise an exception until
    _initialize_global_state has been called to connect to Redis.

    Attributes:
        redis_client: The redis client used to query the redis server. This
            is None until _initialize_global_state has been called.
    """

    def __init__(self):
        """Create a GlobalState object that is not yet connected to Redis."""
        self.redis_client = None

    def _check_connected(self):
        """Check that the object has been initialized before it is used.

        Raises:
            Exception: An exception is raised if ray.init() has not been
                called yet, i.e. if no redis client has been created.
        """
        if self.redis_client is None:
            raise Exception("The ray.global_state API cannot be used before "
                            "ray.init has been called.")

    def _initialize_global_state(self, redis_ip_address, redis_port):
        """Initialize the GlobalState object by connecting to Redis.

        Args:
            redis_ip_address: The IP address of the node that the Redis
                server lives on.
            redis_port: The port that the Redis server is listening on.
        """
        self.redis_client = redis.StrictRedis(host=redis_ip_address,
                                              port=redis_port)

    def _object_table(self, object_id_binary):
        """Fetch and parse the object table information for a single object ID.

        Args:
            object_id_binary: A string of bytes with the object ID to get
                information about.

        Returns:
            A dictionary with information about the object ID in question.
            It has the keys "ManagerIDs" (a list of hex strings, or None if
            the object table lookup returned None), "TaskID" (a hex string)
            and "IsPut" (a bool).
        """
        object_locations = self.redis_client.execute_command(
            "RAY.OBJECT_TABLE_LOOKUP", object_id_binary)
        if object_locations is None:
            manager_ids = None
        else:
            manager_ids = [binary_to_hex(manager_id)
                           for manager_id in object_locations]

        result_table_response = self.redis_client.execute_command(
            "RAY.RESULT_TABLE_LOOKUP", object_id_binary)
        # NOTE(review): unlike the object table lookup above, this response
        # is parsed without a None check. Confirm that
        # RAY.RESULT_TABLE_LOOKUP always returns a reply for the object IDs
        # that reach this point.
        result_table_message = ResultTableReply.GetRootAsResultTableReply(
            result_table_response, 0)

        return {"ManagerIDs": manager_ids,
                "TaskID": binary_to_hex(result_table_message.TaskId()),
                "IsPut": bool(result_table_message.IsPut())}

    def object_table(self, object_id=None):
        """Fetch and parse object table information for one or more object IDs.

        Args:
            object_id: An object ID to fetch information about. If this is
                None, then the entire object table is fetched.

        Returns:
            Information from the object table. For a single object ID this
            is the dictionary for that object. Otherwise it is a dictionary
            mapping every known object ID to its dictionary.

        Raises:
            Exception: An exception is raised if ray.init() has not been
                called yet.
        """
        self._check_connected()
        if object_id is not None:
            # Return information about a single object ID.
            return self._object_table(object_id.id())

        # Return the entire object table. An object may show up under the
        # info prefix, the location prefix, or both, so collect the IDs from
        # both key spaces into one set to visit each object exactly once.
        object_info_keys = self.redis_client.keys(OBJECT_INFO_PREFIX + "*")
        object_location_keys = self.redis_client.keys(
            OBJECT_LOCATION_PREFIX + "*")
        object_ids_binary = {key[len(OBJECT_INFO_PREFIX):]
                             for key in object_info_keys}
        object_ids_binary.update(key[len(OBJECT_LOCATION_PREFIX):]
                                 for key in object_location_keys)

        results = {}
        for object_id_binary in object_ids_binary:
            results[binary_to_object_id(object_id_binary)] = (
                self._object_table(object_id_binary))
        return results

    def _task_table(self, task_id_binary):
        """Fetch and parse the task table information for a single task ID.

        Args:
            task_id_binary: A string of bytes with the task ID to get
                information about.

        Returns:
            A dictionary with information about the task ID in question. It
            has the keys "State" (a task state name), "LocalSchedulerID" (a
            hex string) and "TaskSpec" (a dictionary describing the task).

        Raises:
            Exception: An exception is raised if there is no entry for the
                task ID in the task table.
        """
        task_table_response = self.redis_client.execute_command(
            "RAY.TASK_TABLE_GET", task_id_binary)
        if task_table_response is None:
            message = "There is no entry for task ID {} in the task table."
            raise Exception(message.format(binary_to_hex(task_id_binary)))

        task_table_message = TaskReply.GetRootAsTaskReply(
            task_table_response, 0)
        task_spec_message = TaskInfo.GetRootAsTaskInfo(
            task_table_message.TaskSpec(), 0)

        # An argument is either passed by reference (a non-empty object ID)
        # or by value (pickled data).
        args = []
        for i in range(task_spec_message.ArgsLength()):
            arg = task_spec_message.Args(i)
            if len(arg.ObjectId()) != 0:
                args.append(binary_to_object_id(arg.ObjectId()))
            else:
                # NOTE(security): this unpickles bytes read from Redis.
                # Unpickling can execute arbitrary code, so this is only safe
                # if everything able to write to the Redis server is trusted.
                args.append(pickle.loads(arg.Data()))

        # The resource vector is expected to hold exactly two entries: CPUs
        # at index 0 and GPUs at index 1.
        assert task_spec_message.RequiredResourcesLength() == 2
        required_resources = {"CPUs": task_spec_message.RequiredResources(0),
                              "GPUs": task_spec_message.RequiredResources(1)}

        return_object_ids = [
            binary_to_object_id(task_spec_message.Returns(i))
            for i in range(task_spec_message.ReturnsLength())]

        task_spec_info = {
            "DriverID": binary_to_hex(task_spec_message.DriverId()),
            "TaskID": binary_to_hex(task_spec_message.TaskId()),
            "ParentTaskID": binary_to_hex(task_spec_message.ParentTaskId()),
            "ParentCounter": task_spec_message.ParentCounter(),
            "ActorID": binary_to_hex(task_spec_message.ActorId()),
            "ActorCounter": task_spec_message.ActorCounter(),
            "FunctionID": binary_to_hex(task_spec_message.FunctionId()),
            "Args": args,
            "ReturnObjectIDs": return_object_ids,
            "RequiredResources": required_resources}

        return {"State": task_state_mapping[task_table_message.State()],
                "LocalSchedulerID": binary_to_hex(
                    task_table_message.LocalSchedulerId()),
                "TaskSpec": task_spec_info}

    def task_table(self, task_id=None):
        """Fetch and parse the task table information for one or more task IDs.

        Args:
            task_id: A hex string of the task ID to fetch information about.
                If this is None, then the entire task table is fetched.

        Returns:
            Information from the task table. For a single task ID this is
            the dictionary for that task. Otherwise it is a dictionary
            mapping the hex string of every task ID in the task table to its
            dictionary.

        Raises:
            Exception: An exception is raised if ray.init() has not been
                called yet, or if a task ID is given that has no entry in
                the task table.
        """
        self._check_connected()
        if task_id is not None:
            return self._task_table(hex_to_binary(task_id))

        task_table_keys = self.redis_client.keys(TASK_PREFIX + "*")
        results = {}
        for key in task_table_keys:
            # The binary task ID is whatever follows the prefix in the key.
            task_id_binary = key[len(TASK_PREFIX):]
            results[binary_to_hex(task_id_binary)] = self._task_table(
                task_id_binary)
        return results

    def client_table(self):
        """Fetch and parse the Redis DB client table.

        Returns:
            Information about the Ray clients in the cluster. This is a
            dictionary mapping each node IP address to a list with one
            dictionary per client on that node. Every client dictionary has
            the keys "ClientType", "Deleted" and "DBClientID". The keys
            "AuxAddress", "NumCPUs", "NumGPUs" and
            "LocalSchedulerSocketName" are present only if the client
            registered the corresponding field.

        Raises:
            Exception: An exception is raised if ray.init() has not been
                called yet.
        """
        self._check_connected()
        db_client_keys = self.redis_client.keys(DB_CLIENT_PREFIX + "*")
        node_info = {}
        for key in db_client_keys:
            client_info = self.redis_client.hgetall(key)
            node_ip_address = decode(client_info[b"node_ip_address"])
            client_info_parsed = {
                "ClientType": decode(client_info[b"client_type"]),
                "Deleted": bool(int(decode(client_info[b"deleted"]))),
                "DBClientID": binary_to_hex(client_info[b"ray_client_id"]),
            }
            # The remaining fields are only set by some client types.
            if b"aux_address" in client_info:
                client_info_parsed["AuxAddress"] = decode(
                    client_info[b"aux_address"])
            if b"num_cpus" in client_info:
                client_info_parsed["NumCPUs"] = float(
                    decode(client_info[b"num_cpus"]))
            if b"num_gpus" in client_info:
                client_info_parsed["NumGPUs"] = float(
                    decode(client_info[b"num_gpus"]))
            if b"local_scheduler_socket_name" in client_info:
                client_info_parsed["LocalSchedulerSocketName"] = decode(
                    client_info[b"local_scheduler_socket_name"])
            # Group the clients by the node that they are running on.
            node_info.setdefault(node_ip_address, []).append(
                client_info_parsed)
        return node_info
+6
View File
@@ -20,6 +20,7 @@ import time
import traceback
# Ray modules
import ray.experimental.state as state
import ray.pickling as pickling
import ray.serialization as serialization
import ray.services as services
@@ -686,6 +687,8 @@ We use a global Worker object to ensure that there is a single worker object
per worker process.
"""
global_state = state.GlobalState()
env = RayEnvironmentVariables()
"""RayEnvironmentVariables: The environment variables that are shared by tasks.
@@ -1361,6 +1364,9 @@ def connect(info, object_id_seed=None, mode=WORKER_MODE, worker=global_worker,
port=int(redis_port))
worker.lock = threading.Lock()
# Create an object for interfacing with the global state.
global_state._initialize_global_state(redis_ip_address, int(redis_port))
# Register the worker with Redis.
if mode in [SCRIPT_MODE, SILENT_MODE]:
# The concept of a driver is the same as the concept of a "job". Register