From 6172f94c044fcd351819eb4779763140984a0e27 Mon Sep 17 00:00:00 2001 From: Robert Nishihara Date: Tue, 29 May 2018 16:25:54 -0700 Subject: [PATCH] Implement Python global state API for xray. (#2125) * Implement global state API for xray. * Fix object table. * Fixes for log structure. * Implement cluster_resources. * Add driver task to task table. * Remove python flatbuffers code * Get some global state API tests running. * Python linting. * Fix linting. * Fix mock modules for doc * Copy over flatbuffer bindings. * Fix for tests. * Linting * Fix monitor crash. --- doc/source/conf.py | 5 +- python/ray/experimental/state.py | 406 +++++++++++++++------- python/ray/monitor.py | 34 +- python/ray/services.py | 9 + python/ray/worker.py | 26 +- python/setup.py | 19 +- src/common/lib/python/common_extension.cc | 34 +- src/ray/raylet/CMakeLists.txt | 11 + test/runtest.py | 109 ++++-- 9 files changed, 474 insertions(+), 179 deletions(-) diff --git a/doc/source/conf.py b/doc/source/conf.py index 4c8bf3d20..3b82a3819 100644 --- a/doc/source/conf.py +++ b/doc/source/conf.py @@ -38,7 +38,10 @@ MOCK_MODULES = ["gym", "ray.core.generated.TaskReply", "ray.core.generated.ResultTableReply", "ray.core.generated.TaskExecutionDependencies", - "ray.core.generated.ClientTableData"] + "ray.core.generated.ClientTableData", + "ray.core.generated.GcsTableEntry", + "ray.core.generated.ObjectTableData", + "ray.core.generated.ray.protocol.Task"] for mod_name in MOCK_MODULES: sys.modules[mod_name] = mock.Mock() diff --git a/python/ray/experimental/state.py b/python/ray/experimental/state.py index 9e0643a99..cc23a4e58 100644 --- a/python/ray/experimental/state.py +++ b/python/ray/experimental/state.py @@ -6,6 +6,7 @@ import copy from collections import defaultdict import heapq import json +import os import redis import sys import time @@ -20,6 +21,12 @@ from ray.core.generated.ResultTableReply import ResultTableReply from ray.core.generated.TaskExecutionDependencies import \ TaskExecutionDependencies +from ray.core.generated.ClientTableData import ClientTableData +from ray.core.generated.GcsTableEntry import GcsTableEntry +from ray.core.generated.ObjectTableData import ObjectTableData + +from ray.core.generated.ray.protocol.Task import Task + # These prefixes must be kept up-to-date with the definitions in # ray_redis_module.cc. DB_CLIENT_PREFIX = "CL:" @@ -30,6 +37,16 @@ TASK_PREFIX = "TT:" FUNCTION_PREFIX = "RemoteFunction:" OBJECT_CHANNEL_PREFIX = "OC:" +# These prefixes must be kept up-to-date with the TablePrefix enum in gcs.fbs. +# TODO(rkn): We should use scoped enums, in which case we should be able to +# just access the flatbuffer generated values. +TablePrefix_RAYLET_TASK = 2 +TablePrefix_RAYLET_TASK_string = "TASK" +TablePrefix_CLIENT = 3 +TablePrefix_CLIENT_string = "CLIENT" +TablePrefix_OBJECT = 4 +TablePrefix_OBJECT_string = "OBJECT" + # This mapping from integer to task state string must be kept up-to-date with # the scheduling_state enum in task.h. TASK_STATUS_WAITING = 1 @@ -57,7 +74,9 @@ class GlobalState(object): # backend to cut down on # of request RPCs. Attributes: - redis_client: The redis client used to query the redis server. + redis_client: The Redis client used to query the primary redis server. + redis_clients: Redis clients for each of the Redis shards. + use_raylet: True if we are using the raylet code path. """ def __init__(self): @@ -65,8 +84,10 @@ class GlobalState(object): # The redis server storing metadata, such as function table, client # table, log files, event logs, workers/actions info. self.redis_client = None - # A list of redis shards, storing the object table & task table. + # Clients for the redis shards, storing the object table & task table. self.redis_clients = None + # True if we are using the raylet code path and false otherwise. + self.use_raylet = None def _check_connected(self): """Check that the object has been initialized before it is used. @@ -97,8 +118,6 @@ class GlobalState(object): redis_ip_address: The IP address of the node that the Redis server lives on. redis_port: The port that the Redis server is listening on. - timeout: The maximum amount of time (in seconds) that we should - wait for the keys in Redis to be populated. """ self.redis_client = redis.StrictRedis( host=redis_ip_address, port=redis_port) @@ -138,6 +157,16 @@ class GlobalState(object): "ip_address_ports = {}".format( num_redis_shards, ip_address_ports)) + use_raylet = self.redis_client.get("UseRaylet") + if use_raylet is not None: + self.use_raylet = int(use_raylet) == 1 + elif os.environ.get("RAY_USE_XRAY") == "1": + # This environment variable is used in our testing setup. + print("Detected environment variable 'RAY_USE_XRAY'.") + self.use_raylet = True + else: + self.use_raylet = False + # Get the rest of the information. self.redis_clients = [] for ip_address_port in ip_address_ports: @@ -188,28 +217,48 @@ class GlobalState(object): object_id = ray.ObjectID(hex_to_binary(object_id)) # Return information about a single object ID. - object_locations = self._execute_command(object_id, - "RAY.OBJECT_TABLE_LOOKUP", - object_id.id()) - if object_locations is not None: - manager_ids = [ - binary_to_hex(manager_id) for manager_id in object_locations - ] + if not self.use_raylet: + # Use the non-raylet code path. + object_locations = self._execute_command( + object_id, "RAY.OBJECT_TABLE_LOOKUP", object_id.id()) + if object_locations is not None: + manager_ids = [ + binary_to_hex(manager_id) + for manager_id in object_locations + ] + else: + manager_ids = None + + result_table_response = self._execute_command( + object_id, "RAY.RESULT_TABLE_LOOKUP", object_id.id()) + result_table_message = ResultTableReply.GetRootAsResultTableReply( + result_table_response, 0) + + result = { + "ManagerIDs": manager_ids, + "TaskID": binary_to_hex(result_table_message.TaskId()), + "IsPut": bool(result_table_message.IsPut()), + "DataSize": result_table_message.DataSize(), + "Hash": binary_to_hex(result_table_message.Hash()) + } + else: - manager_ids = None + # Use the raylet code path. + message = self.redis_client.execute_command( + "RAY.TABLE_LOOKUP", TablePrefix_OBJECT, "", object_id.id()) + result = [] + gcs_entry = GcsTableEntry.GetRootAsGcsTableEntry(message, 0) - result_table_response = self._execute_command( - object_id, "RAY.RESULT_TABLE_LOOKUP", object_id.id()) - result_table_message = ResultTableReply.GetRootAsResultTableReply( - result_table_response, 0) - - result = { - "ManagerIDs": manager_ids, - "TaskID": binary_to_hex(result_table_message.TaskId()), - "IsPut": bool(result_table_message.IsPut()), - "DataSize": result_table_message.DataSize(), - "Hash": binary_to_hex(result_table_message.Hash()) - } + for i in range(gcs_entry.EntriesLength()): + entry = ObjectTableData.GetRootAsObjectTableData( + gcs_entry.Entries(i), 0) + object_info = { + "DataSize": entry.ObjectSize(), + "Manager": entry.Manager(), + "IsEviction": entry.IsEviction(), + "NumEvictions": entry.NumEvictions() + } + result.append(object_info) return result @@ -220,7 +269,6 @@ class GlobalState(object): object_id: An object ID to fetch information about. If this is None, then the entire object table is fetched. - Returns: Information from the object table. """ @@ -230,13 +278,23 @@ class GlobalState(object): return self._object_table(object_id) else: # Return the entire object table. - object_info_keys = self._keys(OBJECT_INFO_PREFIX + "*") - object_location_keys = self._keys(OBJECT_LOCATION_PREFIX + "*") - object_ids_binary = set( - [key[len(OBJECT_INFO_PREFIX):] for key in object_info_keys] + [ + if not self.use_raylet: + object_info_keys = self._keys(OBJECT_INFO_PREFIX + "*") + object_location_keys = self._keys(OBJECT_LOCATION_PREFIX + "*") + object_ids_binary = set([ + key[len(OBJECT_INFO_PREFIX):] for key in object_info_keys + ] + [ key[len(OBJECT_LOCATION_PREFIX):] for key in object_location_keys ]) + else: + object_keys = self.redis_client.keys( + TablePrefix_OBJECT_string + ":*") + object_ids_binary = { + key[len(TablePrefix_OBJECT_string + ":"):] + for key in object_keys + } + results = {} for object_id_binary in object_ids_binary: results[binary_to_object_id(object_id_binary)] = ( @@ -255,58 +313,108 @@ class GlobalState(object): TASK_STATUS_MAPPING should be used to parse the "State" field into a human-readable string. """ - task_table_response = self._execute_command(task_id, - "RAY.TASK_TABLE_GET", - task_id.id()) - if task_table_response is None: - raise Exception("There is no entry for task ID {} in the task " - "table.".format(binary_to_hex(task_id.id()))) - task_table_message = TaskReply.GetRootAsTaskReply( - task_table_response, 0) - task_spec = task_table_message.TaskSpec() - task_spec = ray.local_scheduler.task_from_string(task_spec) + if not self.use_raylet: + # Use the non-raylet code path. + task_table_response = self._execute_command( + task_id, "RAY.TASK_TABLE_GET", task_id.id()) + if task_table_response is None: + raise Exception("There is no entry for task ID {} in the task " + "table.".format(binary_to_hex(task_id.id()))) + task_table_message = TaskReply.GetRootAsTaskReply( + task_table_response, 0) + task_spec = task_table_message.TaskSpec() + task_spec = ray.local_scheduler.task_from_string(task_spec) - task_spec_info = { - "DriverID": binary_to_hex(task_spec.driver_id().id()), - "TaskID": binary_to_hex(task_spec.task_id().id()), - "ParentTaskID": binary_to_hex(task_spec.parent_task_id().id()), - "ParentCounter": task_spec.parent_counter(), - "ActorID": binary_to_hex(task_spec.actor_id().id()), - "ActorCreationID": binary_to_hex( - task_spec.actor_creation_id().id()), - "ActorCreationDummyObjectID": binary_to_hex( - task_spec.actor_creation_dummy_object_id().id()), - "ActorCounter": task_spec.actor_counter(), - "FunctionID": binary_to_hex(task_spec.function_id().id()), - "Args": task_spec.arguments(), - "ReturnObjectIDs": task_spec.returns(), - "RequiredResources": task_spec.required_resources() - } + task_spec_info = { + "DriverID": binary_to_hex(task_spec.driver_id().id()), + "TaskID": binary_to_hex(task_spec.task_id().id()), + "ParentTaskID": binary_to_hex(task_spec.parent_task_id().id()), + "ParentCounter": task_spec.parent_counter(), + "ActorID": binary_to_hex(task_spec.actor_id().id()), + "ActorCreationID": binary_to_hex( + task_spec.actor_creation_id().id()), + "ActorCreationDummyObjectID": binary_to_hex( + task_spec.actor_creation_dummy_object_id().id()), + "ActorCounter": task_spec.actor_counter(), + "FunctionID": binary_to_hex(task_spec.function_id().id()), + "Args": task_spec.arguments(), + "ReturnObjectIDs": task_spec.returns(), + "RequiredResources": task_spec.required_resources() + } - execution_dependencies_message = ( - TaskExecutionDependencies.GetRootAsTaskExecutionDependencies( - task_table_message.ExecutionDependencies(), 0)) - execution_dependencies = [ - ray.ObjectID( - execution_dependencies_message.ExecutionDependencies(i)) - for i in range( - execution_dependencies_message.ExecutionDependenciesLength()) - ] + execution_dependencies_message = ( + TaskExecutionDependencies.GetRootAsTaskExecutionDependencies( + task_table_message.ExecutionDependencies(), 0)) + execution_dependencies = [ + ray.ObjectID( + execution_dependencies_message.ExecutionDependencies(i)) + for i in range(execution_dependencies_message. + ExecutionDependenciesLength()) + ] - # TODO(rkn): The return fields ExecutionDependenciesString and - # ExecutionDependencies are redundant, so we should remove - # ExecutionDependencies. However, it is currently used in monitor.py. + # TODO(rkn): The return fields ExecutionDependenciesString and + # ExecutionDependencies are redundant, so we should remove + # ExecutionDependencies. However, it is currently used in + # monitor.py. - return { - "State": task_table_message.State(), - "LocalSchedulerID": binary_to_hex( - task_table_message.LocalSchedulerId()), - "ExecutionDependenciesString": task_table_message. - ExecutionDependencies(), - "ExecutionDependencies": execution_dependencies, - "SpillbackCount": task_table_message.SpillbackCount(), - "TaskSpec": task_spec_info - } + return { + "State": task_table_message.State(), + "LocalSchedulerID": binary_to_hex( + task_table_message.LocalSchedulerId()), + "ExecutionDependenciesString": task_table_message. + ExecutionDependencies(), + "ExecutionDependencies": execution_dependencies, + "SpillbackCount": task_table_message.SpillbackCount(), + "TaskSpec": task_spec_info + } + + else: + # Use the raylet code path. + message = self.redis_client.execute_command( + "RAY.TABLE_LOOKUP", TablePrefix_RAYLET_TASK, "", task_id.id()) + gcs_entries = GcsTableEntry.GetRootAsGcsTableEntry(message, 0) + + info = [] + for i in range(gcs_entries.EntriesLength()): + task_table_message = Task.GetRootAsTask( + gcs_entries.Entries(i), 0) + + task_table_message = Task.GetRootAsTask( + gcs_entries.Entries(0), 0) + execution_spec = task_table_message.TaskExecutionSpec() + task_spec = task_table_message.TaskSpecification() + task_spec = ray.local_scheduler.task_from_string(task_spec) + task_spec_info = { + "DriverID": binary_to_hex(task_spec.driver_id().id()), + "TaskID": binary_to_hex(task_spec.task_id().id()), + "ParentTaskID": binary_to_hex( + task_spec.parent_task_id().id()), + "ParentCounter": task_spec.parent_counter(), + "ActorID": binary_to_hex(task_spec.actor_id().id()), + "ActorCreationID": binary_to_hex( + task_spec.actor_creation_id().id()), + "ActorCreationDummyObjectID": binary_to_hex( + task_spec.actor_creation_dummy_object_id().id()), + "ActorCounter": task_spec.actor_counter(), + "FunctionID": binary_to_hex(task_spec.function_id().id()), + "Args": task_spec.arguments(), + "ReturnObjectIDs": task_spec.returns(), + "RequiredResources": task_spec.required_resources() + } + + info.append({ + "ExecutionSpec": { + "Dependencies": [ + execution_spec.Dependencies(i) + for i in range(execution_spec.DependenciesLength()) + ], + "LastTimestamp": execution_spec.LastTimestamp(), + "NumForwards": execution_spec.NumForwards() + }, + "TaskSpec": task_spec_info + }) + + return info def task_table(self, task_id=None): """Fetch and parse the task table information for one or more task IDs. @@ -315,7 +423,6 @@ class GlobalState(object): task_id: A hex string of the task ID to fetch information about. If this is None, then the task object table is fetched. - Returns: Information from the task table. """ @@ -324,10 +431,21 @@ class GlobalState(object): task_id = ray.ObjectID(hex_to_binary(task_id)) return self._task_table(task_id) else: - task_table_keys = self._keys(TASK_PREFIX + "*") + if not self.use_raylet: + task_table_keys = self._keys(TASK_PREFIX + "*") + task_ids_binary = [ + key[len(TASK_PREFIX):] for key in task_table_keys + ] + else: + task_table_keys = self.redis_client.keys( + TablePrefix_RAYLET_TASK_string + ":*") + task_ids_binary = [ + key[len(TablePrefix_RAYLET_TASK_string + ":"):] + for key in task_table_keys + ] + results = {} - for key in task_table_keys: - task_id_binary = key[len(TASK_PREFIX):] + for task_id_binary in task_ids_binary: results[binary_to_hex(task_id_binary)] = self._task_table( ray.ObjectID(task_id_binary)) return results @@ -359,41 +477,76 @@ class GlobalState(object): Information about the Ray clients in the cluster. """ self._check_connected() - db_client_keys = self.redis_client.keys(DB_CLIENT_PREFIX + "*") - node_info = {} - for key in db_client_keys: - client_info = self.redis_client.hgetall(key) - node_ip_address = decode(client_info[b"node_ip_address"]) - if node_ip_address not in node_info: - node_info[node_ip_address] = [] - client_info_parsed = {} - assert b"client_type" in client_info - assert b"deleted" in client_info - assert b"ray_client_id" in client_info - for field, value in client_info.items(): - if field == b"node_ip_address": - pass - elif field == b"client_type": - client_info_parsed["ClientType"] = decode(value) - elif field == b"deleted": - client_info_parsed["Deleted"] = bool(int(decode(value))) - elif field == b"ray_client_id": - client_info_parsed["DBClientID"] = binary_to_hex(value) - elif field == b"manager_address": - client_info_parsed["AuxAddress"] = decode(value) - elif field == b"local_scheduler_socket_name": - client_info_parsed["LocalSchedulerSocketName"] = ( - decode(value)) - elif client_info[b"client_type"] == b"local_scheduler": - # The remaining fields are resource types. - client_info_parsed[field.decode("ascii")] = float( - decode(value)) - else: - client_info_parsed[field.decode("ascii")] = decode(value) + if not self.use_raylet: + db_client_keys = self.redis_client.keys(DB_CLIENT_PREFIX + "*") + node_info = {} + for key in db_client_keys: + client_info = self.redis_client.hgetall(key) + node_ip_address = decode(client_info[b"node_ip_address"]) + if node_ip_address not in node_info: + node_info[node_ip_address] = [] + client_info_parsed = {} + assert b"client_type" in client_info + assert b"deleted" in client_info + assert b"ray_client_id" in client_info + for field, value in client_info.items(): + if field == b"node_ip_address": + pass + elif field == b"client_type": + client_info_parsed["ClientType"] = decode(value) + elif field == b"deleted": + client_info_parsed["Deleted"] = bool( + int(decode(value))) + elif field == b"ray_client_id": + client_info_parsed["DBClientID"] = binary_to_hex(value) + elif field == b"manager_address": + client_info_parsed["AuxAddress"] = decode(value) + elif field == b"local_scheduler_socket_name": + client_info_parsed["LocalSchedulerSocketName"] = ( + decode(value)) + elif client_info[b"client_type"] == b"local_scheduler": + # The remaining fields are resource types. + client_info_parsed[field.decode("ascii")] = float( + decode(value)) + else: + client_info_parsed[field.decode("ascii")] = decode( + value) - node_info[node_ip_address].append(client_info_parsed) + node_info[node_ip_address].append(client_info_parsed) - return node_info + return node_info + + else: + # This is the raylet code path. + NIL_CLIENT_ID = 20 * b"\xff" + message = self.redis_client.execute_command( + "RAY.TABLE_LOOKUP", TablePrefix_CLIENT, "", NIL_CLIENT_ID) + node_info = [] + gcs_entry = GcsTableEntry.GetRootAsGcsTableEntry(message, 0) + + for i in range(gcs_entry.EntriesLength()): + client = ClientTableData.GetRootAsClientTableData( + gcs_entry.Entries(i), 0) + + resources = { + client.ResourcesTotalLabel(i).decode("ascii"): + client.ResourcesTotalCapacity(i) + for i in range(client.ResourcesTotalLabelLength()) + } + node_info.append({ + "ClientID": ray.utils.binary_to_hex(client.ClientId()), + "IsInsertion": client.IsInsertion(), + "NodeManagerAddress": client.NodeManagerAddress().decode( + "ascii"), + "NodeManagerPort": client.NodeManagerPort(), + "ObjectManagerPort": client.ObjectManagerPort(), + "ObjectStoreSocketName": client.ObjectStoreSocketName() + .decode("ascii"), + "RayletSocketName": client.RayletSocketName().decode( + "ascii"), + "Resources": resources + }) + return node_info def log_files(self): """Fetch and return a dictionary of log file names to outputs. @@ -451,7 +604,7 @@ class GlobalState(object): # The heap is used to maintain the set of x tasks that occurred the # most recently across all of the workers, where x is defined as the # function parameter num. The key is the start time of the "get_task" - # component of each task. Calling heappop will result in the taks with + # component of each task. Calling heappop will result in the task with # the earliest "get_task_start" to be removed from the heap. heap = [] heapq.heapify(heap) @@ -889,6 +1042,8 @@ class GlobalState(object): Returns: A list of the live local schedulers. """ + if self.use_raylet: + raise Exception("The local_schedulers() method is deprecated.") clients = self.client_table() local_schedulers = [] for ip_address, client_list in clients.items(): @@ -972,15 +1127,22 @@ class GlobalState(object): A dictionary mapping resource name to the total quantity of that resource in the cluster. """ - local_schedulers = self.local_schedulers() resources = defaultdict(lambda: 0) + if not self.use_raylet: + local_schedulers = self.local_schedulers() - for local_scheduler in local_schedulers: - for key, value in local_scheduler.items(): - if key not in [ - "ClientType", "Deleted", "DBClientID", "AuxAddress", - "LocalSchedulerSocketName" - ]: + for local_scheduler in local_schedulers: + for key, value in local_scheduler.items(): + if key not in [ + "ClientType", "Deleted", "DBClientID", + "AuxAddress", "LocalSchedulerSocketName" + ]: + resources[key] += value + + else: + clients = self.client_table() + for client in clients: + for key, value in client["Resources"].items(): resources[key] += value return dict(resources) diff --git a/python/ray/monitor.py b/python/ray/monitor.py index 52b6b20cd..e1ef42128 100644 --- a/python/ray/monitor.py +++ b/python/ray/monitor.py @@ -65,6 +65,8 @@ class Monitor(object): Attributes: redis: A connection to the Redis server. + use_raylet: A bool indicating whether to use the raylet code path or + not. subscribe_client: A pubsub client for the Redis server. This is used to receive notifications about failed components. subscribed: A dictionary mapping channel names (str) to whether or not @@ -84,6 +86,7 @@ class Monitor(object): # Initialize the Redis clients. self.state = ray.experimental.state.GlobalState() self.state._initialize_global_state(redis_address, redis_port) + self.use_raylet = self.state.use_raylet self.redis = redis.StrictRedis( host=redis_address, port=redis_port, db=0) # TODO(swang): Update pubsub client to use ray.experimental.state once @@ -207,6 +210,11 @@ class Monitor(object): that we do not miss any notifications for deleted clients that occurred before we subscribed. """ + # Exit if we are using the raylet code path because client_table is + # implemented differently. TODO(rkn): Fix this. + if self.use_raylet: + return + clients = self.state.client_table() for node_ip_address, node_clients in clients.items(): for client in node_clients: @@ -514,18 +522,22 @@ class Monitor(object): # Handle messages from the subscription channels. while True: - # Update the mapping from local scheduler client ID to IP address. - # This is only used to update the load metrics for the autoscaler. - local_schedulers = self.state.local_schedulers() - self.local_scheduler_id_to_ip_map = {} - for local_scheduler_info in local_schedulers: - client_id = local_scheduler_info["DBClientID"] - ip_address = local_scheduler_info["AuxAddress"].split(":")[0] - self.local_scheduler_id_to_ip_map[client_id] = ip_address + # TODO(rkn): The autoscaler needs to be re-enabled for xray. + if not self.use_raylet: + # Update the mapping from local scheduler client ID to IP + # address. This is only used to update the load metrics for the + # autoscaler. + local_schedulers = self.state.local_schedulers() + self.local_scheduler_id_to_ip_map = {} + for local_scheduler_info in local_schedulers: + client_id = local_scheduler_info["DBClientID"] + ip_address = local_scheduler_info["AuxAddress"].split(":")[ + 0] + self.local_scheduler_id_to_ip_map[client_id] = ip_address - # Process autoscaling actions - if self.autoscaler: - self.autoscaler.update() + # Process autoscaling actions + if self.autoscaler: + self.autoscaler.update() # Record how many dead local schedulers and plasma managers we had # at the beginning of this round. num_dead_local_schedulers = len(self.dead_local_schedulers) diff --git a/python/ray/services.py b/python/ray/services.py index f1fdcc3ac..e46c480ff 100644 --- a/python/ray/services.py +++ b/python/ray/services.py @@ -400,6 +400,7 @@ def start_redis(node_ip_address, redis_shard_ports=None, num_redis_shards=1, redis_max_clients=None, + use_raylet=False, redirect_output=False, redirect_worker_output=False, cleanup=True, @@ -418,6 +419,8 @@ def start_redis(node_ip_address, shard. redis_max_clients: If this is provided, Ray will attempt to configure Redis with this maxclients number. + use_raylet: True if the new raylet code path should be used. This is + not supported yet. redirect_output (bool): True if output should be redirected to a file and false otherwise. redirect_worker_output (bool): True if worker output should be @@ -472,6 +475,11 @@ def start_redis(node_ip_address, port = assigned_port redis_address = address(node_ip_address, port) + redis_client = redis.StrictRedis(host=node_ip_address, port=port) + + # Store whether we're using the raylet code path or not. + redis_client.set("UseRaylet", 1 if use_raylet else 0) + # Register the number of Redis shards in the primary shard, so that clients # know how many redis shards to expect under RedisShards. primary_redis_client = redis.StrictRedis(host=node_ip_address, port=port) @@ -1314,6 +1322,7 @@ def start_ray_processes(address_info=None, redis_shard_ports=redis_shard_ports, num_redis_shards=num_redis_shards, redis_max_clients=redis_max_clients, + use_raylet=use_raylet, redirect_output=True, redirect_worker_output=redirect_worker_output, cleanup=cleanup) diff --git a/python/ray/worker.py b/python/ray/worker.py index 9b66141d5..75fc83047 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -2142,12 +2142,26 @@ def connect(info, ray.ObjectID(NIL_ACTOR_ID), ray.ObjectID(NIL_ACTOR_ID), ray.ObjectID(NIL_ACTOR_ID), ray.ObjectID(NIL_ACTOR_ID), nil_actor_counter, False, [], {"CPU": 0}, worker.use_raylet) - global_state._execute_command( - driver_task.task_id(), "RAY.TASK_TABLE_ADD", - driver_task.task_id().id(), - TASK_STATUS_RUNNING, NIL_LOCAL_SCHEDULER_ID, - driver_task.execution_dependencies_string(), 0, - ray.local_scheduler.task_to_string(driver_task)) + + # Add the driver task to the task table. + if not worker.use_raylet: + global_state._execute_command( + driver_task.task_id(), "RAY.TASK_TABLE_ADD", + driver_task.task_id().id(), TASK_STATUS_RUNNING, + NIL_LOCAL_SCHEDULER_ID, + driver_task.execution_dependencies_string(), 0, + ray.local_scheduler.task_to_string(driver_task)) + else: + TablePubsub_RAYLET_TASK = 2 + + # TODO(rkn): When we shard the GCS in xray, we will need to change + # this to use _execute_command. + global_state.redis_client.execute_command( + "RAY.TABLE_ADD", state.TablePrefix_RAYLET_TASK, + TablePubsub_RAYLET_TASK, + driver_task.task_id().id(), + driver_task._serialized_raylet_task()) + # Set the driver's current task ID to the task ID assigned to the # driver task. worker.current_task_id = driver_task.task_id() diff --git a/python/setup.py b/python/setup.py index 8f71376e6..565880682 100644 --- a/python/setup.py +++ b/python/setup.py @@ -26,6 +26,13 @@ ray_files = [ "ray/WebUI.ipynb" ] +# These are the directories where automatically generated Python flatbuffer +# bindings are created. +generated_python_directories = [ + "ray/core/generated", "ray/core/generated/ray", + "ray/core/generated/ray/protocol" +] + optional_ray_files = [] ray_ui_files = [ @@ -77,14 +84,14 @@ class build_ext(_build_ext.build_ext): files_to_include = ray_files + pyarrow_files + # Copy over the autogenerated flatbuffer Python bindings. + for directory in generated_python_directories: + for filename in os.listdir(directory): + if filename[-3:] == ".py": + files_to_include.append(os.path.join(directory, filename)) + for filename in files_to_include: self.move_file(filename) - # Copy over the autogenerated flatbuffer Python bindings. - generated_python_directory = "ray/core/generated" - for filename in os.listdir(generated_python_directory): - if filename[-3:] == ".py": - self.move_file( - os.path.join(generated_python_directory, filename)) # Try to copy over the optional files. for filename in optional_ray_files: diff --git a/src/common/lib/python/common_extension.cc b/src/common/lib/python/common_extension.cc index 82e811281..22d7877ba 100644 --- a/src/common/lib/python/common_extension.cc +++ b/src/common/lib/python/common_extension.cc @@ -10,7 +10,9 @@ #include "common.h" #include "common_extension.h" #include "common_protocol.h" +#include "ray/raylet/task.h" #include "ray/raylet/task_spec.h" +#include "ray/raylet/task_execution_spec.h" #include "task.h" #include @@ -143,7 +145,15 @@ PyObject *PyTask_to_string(PyObject *self, PyObject *args) { return NULL; } PyTask *task = (PyTask *) arg; - return PyBytes_FromStringAndSize((char *) task->spec, task->size); + if (!use_raylet(task)) { + return PyBytes_FromStringAndSize((char *) task->spec, task->size); + } else { + flatbuffers::FlatBufferBuilder fbb; + auto task_spec_string = task->task_spec->ToFlatbuffer(fbb); + fbb.Finish(task_spec_string); + return PyBytes_FromStringAndSize((char *) fbb.GetBufferPointer(), + fbb.GetSize()); + } } static PyObject *PyObjectID_id(PyObject *self) { @@ -693,6 +703,23 @@ static PyObject *PyTask_execution_dependencies_string(PyTask *self) { fbb.GetSize()); } +static PyObject *PyTask_to_serialized_flatbuf(PyTask *self) { + RAY_CHECK(use_raylet(self)); + + const std::vector execution_dependencies( + *self->execution_dependencies); + auto const execution_spec = ray::raylet::TaskExecutionSpecification( + std::move(execution_dependencies)); + auto const task = ray::raylet::Task(execution_spec, *self->task_spec); + + flatbuffers::FlatBufferBuilder fbb; + auto task_flatbuffer = task.ToFlatbuffer(fbb); + fbb.Finish(task_flatbuffer); + + return PyBytes_FromStringAndSize( + reinterpret_cast(fbb.GetBufferPointer()), fbb.GetSize()); +} + static PyMethodDef PyTask_methods[] = { {"function_id", (PyCFunction) PyTask_function_id, METH_NOARGS, "Return the function ID for this task."}, @@ -722,6 +749,11 @@ static PyMethodDef PyTask_methods[] = { {"execution_dependencies_string", (PyCFunction) PyTask_execution_dependencies_string, METH_NOARGS, "Return the execution dependencies for the task as a string."}, + {"_serialized_raylet_task", (PyCFunction) PyTask_to_serialized_flatbuf, + METH_NOARGS, + "This is a hack used to create a serialized flatbuffer object for the " + "driver task. We're doing this because creating the flatbuffer object in " + "Python didn't seem to work."}, {NULL} /* Sentinel */ }; diff --git a/src/ray/raylet/CMakeLists.txt b/src/ray/raylet/CMakeLists.txt index 8149fb35e..aca592de0 100644 --- a/src/ray/raylet/CMakeLists.txt +++ b/src/ray/raylet/CMakeLists.txt @@ -19,6 +19,17 @@ add_custom_command( add_custom_target(gen_node_manager_fbs DEPENDS ${NODE_MANAGER_FBS_OUTPUT_FILES}) +# Generate Python bindings for the flatbuffers objects. +set(PYTHON_OUTPUT_DIR ${CMAKE_BINARY_DIR}/generated/) +add_custom_command( + TARGET gen_node_manager_fbs + COMMAND ${FLATBUFFERS_COMPILER} -p -o ${PYTHON_OUTPUT_DIR} ${NODE_MANAGER_FBS_SRC} + DEPENDS ${FBS_DEPENDS} + COMMENT "Running flatc compiler on ${NODE_MANAGER_FBS_SRC}" + VERBATIM) + +add_dependencies(gen_node_manager_fbs flatbuffers_ep) + ADD_RAY_TEST(object_manager_integration_test STATIC_LINK_LIBS ray_static ${PLASMA_STATIC_LIB} ${ARROW_STATIC_LIB} gtest gtest_main pthread ${Boost_SYSTEM_LIBRARY}) ADD_RAY_TEST(worker_pool_test STATIC_LINK_LIBS ray_static ${PLASMA_STATIC_LIB} ${ARROW_STATIC_LIB} gtest gtest_main gmock_main pthread ${Boost_SYSTEM_LIBRARY}) diff --git a/test/runtest.py b/test/runtest.py index 2022f4310..1286c7ad7 100644 --- a/test/runtest.py +++ b/test/runtest.py @@ -1904,9 +1904,6 @@ class GlobalStateAPI(unittest.TestCase): def tearDown(self): ray.worker.cleanup() - @unittest.skipIf( - os.environ.get("RAY_USE_XRAY") == "1", - "This test does not work with xray yet.") def testGlobalStateAPI(self): with self.assertRaises(Exception): ray.global_state.object_table() @@ -1942,27 +1939,57 @@ class GlobalStateAPI(unittest.TestCase): task_table = ray.global_state.task_table() self.assertEqual(len(task_table), 1) self.assertEqual(driver_task_id, list(task_table.keys())[0]) - self.assertEqual(task_table[driver_task_id]["State"], - ray.experimental.state.TASK_STATUS_RUNNING) - self.assertEqual(task_table[driver_task_id]["TaskSpec"]["TaskID"], - driver_task_id) - self.assertEqual(task_table[driver_task_id]["TaskSpec"]["ActorID"], - ID_SIZE * "ff") - self.assertEqual(task_table[driver_task_id]["TaskSpec"]["Args"], []) - self.assertEqual(task_table[driver_task_id]["TaskSpec"]["DriverID"], - driver_id) - self.assertEqual(task_table[driver_task_id]["TaskSpec"]["FunctionID"], - ID_SIZE * "ff") - self.assertEqual( - (task_table[driver_task_id]["TaskSpec"]["ReturnObjectIDs"]), []) + if not ray.worker.global_worker.use_raylet: + self.assertEqual(task_table[driver_task_id]["State"], + ray.experimental.state.TASK_STATUS_RUNNING) + if not ray.worker.global_worker.use_raylet: + self.assertEqual(task_table[driver_task_id]["TaskSpec"]["TaskID"], + driver_task_id) + self.assertEqual(task_table[driver_task_id]["TaskSpec"]["ActorID"], + ID_SIZE * "ff") + self.assertEqual(task_table[driver_task_id]["TaskSpec"]["Args"], + []) + self.assertEqual( + task_table[driver_task_id]["TaskSpec"]["DriverID"], driver_id) + self.assertEqual( + task_table[driver_task_id]["TaskSpec"]["FunctionID"], + ID_SIZE * "ff") + self.assertEqual( + (task_table[driver_task_id]["TaskSpec"]["ReturnObjectIDs"]), + []) + + else: + self.assertEqual(len(task_table[driver_task_id]), 1) + self.assertEqual( + task_table[driver_task_id][0]["TaskSpec"]["TaskID"], + driver_task_id) + self.assertEqual( + task_table[driver_task_id][0]["TaskSpec"]["ActorID"], + ID_SIZE * "ff") + self.assertEqual(task_table[driver_task_id][0]["TaskSpec"]["Args"], + []) + self.assertEqual( + task_table[driver_task_id][0]["TaskSpec"]["DriverID"], + driver_id) + self.assertEqual( + task_table[driver_task_id][0]["TaskSpec"]["FunctionID"], + ID_SIZE * "ff") + self.assertEqual( + (task_table[driver_task_id][0]["TaskSpec"]["ReturnObjectIDs"]), + []) client_table = ray.global_state.client_table() node_ip_address = ray.worker.global_worker.node_ip_address - self.assertEqual(len(client_table[node_ip_address]), 3) - manager_client = [ - c for c in client_table[node_ip_address] - if c["ClientType"] == "plasma_manager" - ][0] + + if not ray.worker.global_worker.use_raylet: + self.assertEqual(len(client_table[node_ip_address]), 3) + manager_client = [ + c for c in client_table[node_ip_address] + if c["ClientType"] == "plasma_manager" + ][0] + else: + assert len(client_table) == 1 + assert client_table[0]["NodeManagerAddress"] == node_ip_address @ray.remote def f(*xs): @@ -1980,11 +2007,17 @@ class GlobalStateAPI(unittest.TestCase): task_id_set = set(task_table.keys()) task_id_set.remove(driver_task_id) task_id = list(task_id_set)[0] - if task_table[task_id]["State"] == "DONE": + if ray.worker.global_worker.use_raylet: + break + if (task_table[task_id]["State"] == + ray.experimental.state.TASK_STATUS_DONE): break time.sleep(0.1) function_table = ray.global_state.function_table() - task_spec = task_table[task_id]["TaskSpec"] + if not ray.worker.global_worker.use_raylet: + task_spec = task_table[task_id]["TaskSpec"] + else: + task_spec = task_table[task_id][0]["TaskSpec"] self.assertEqual(task_spec["ActorID"], ID_SIZE * "ff") self.assertEqual(task_spec["Args"], [1, "hi", x_id]) self.assertEqual(task_spec["DriverID"], driver_id) @@ -2015,19 +2048,31 @@ class GlobalStateAPI(unittest.TestCase): "update.") # Wait for the object table to be updated. - wait_for_object_table() + if not ray.worker.global_worker.use_raylet: + wait_for_object_table() + object_table = ray.global_state.object_table() self.assertEqual(len(object_table), 2) - self.assertEqual(object_table[x_id]["IsPut"], True) - self.assertEqual(object_table[x_id]["TaskID"], driver_task_id) - self.assertEqual(object_table[x_id]["ManagerIDs"], - [manager_client["DBClientID"]]) + if not ray.worker.global_worker.use_raylet: + self.assertEqual(object_table[x_id]["IsPut"], True) + self.assertEqual(object_table[x_id]["TaskID"], driver_task_id) + self.assertEqual(object_table[x_id]["ManagerIDs"], + [manager_client["DBClientID"]]) - self.assertEqual(object_table[result_id]["IsPut"], False) - self.assertEqual(object_table[result_id]["TaskID"], task_id) - self.assertEqual(object_table[result_id]["ManagerIDs"], - [manager_client["DBClientID"]]) + self.assertEqual(object_table[result_id]["IsPut"], False) + self.assertEqual(object_table[result_id]["TaskID"], task_id) + self.assertEqual(object_table[result_id]["ManagerIDs"], + [manager_client["DBClientID"]]) + + else: + assert len(object_table[x_id]) == 1 + self.assertEqual(object_table[x_id][0]["IsEviction"], False) + self.assertEqual(object_table[x_id][0]["NumEvictions"], 0) + + assert len(object_table[result_id]) == 1 + self.assertEqual(object_table[result_id][0]["IsEviction"], False) + self.assertEqual(object_table[result_id][0]["NumEvictions"], 0) self.assertEqual(object_table[x_id], ray.global_state.object_table(x_id))