diff --git a/python/ray/common/redis_module/runtest.py b/python/ray/common/redis_module/runtest.py index 493af9c50..42e6057cd 100644 --- a/python/ray/common/redis_module/runtest.py +++ b/python/ray/common/redis_module/runtest.py @@ -52,7 +52,7 @@ def get_next_message(pubsub_client, timeout_seconds=10): class TestGlobalStateStore(unittest.TestCase): def setUp(self): - redis_port, _ = ray.services.start_redis() + redis_port, _ = ray.services.start_redis_instance() self.redis = redis.StrictRedis(host="localhost", port=redis_port, db=0) def tearDown(self): @@ -308,6 +308,10 @@ class TestGlobalStateStore(unittest.TestCase): TASK_STATUS_SCHEDULED = 2 TASK_STATUS_QUEUED = 4 + # make sure somebody will get a notification (checked in the redis module) + p = self.redis.pubsub() + p.psubscribe("{prefix}*:*".format(prefix=TASK_PREFIX)) + def check_task_reply(message, task_args, updated=False): task_status, local_scheduler_id, task_spec = task_args task_reply_object = TaskReply.GetRootAsTaskReply(message, 0) @@ -388,33 +392,53 @@ class TestGlobalStateStore(unittest.TestCase): self.assertNotEqual(get_response, old_response) check_task_reply(get_response, task_args[1:]) + def check_task_subscription(self, p, scheduling_state, local_scheduler_id): + task_args = [b"task_id", scheduling_state, + local_scheduler_id.encode("ascii"), b"task_spec"] + self.redis.execute_command("RAY.TASK_TABLE_ADD", *task_args) + # Receive the data. + message = get_next_message(p)["data"] + # Check that the notification object is correct. + notification_object = TaskReply.GetRootAsTaskReply(message, 0) + self.assertEqual(notification_object.TaskId(), b"task_id") + self.assertEqual(notification_object.State(), scheduling_state) + self.assertEqual(notification_object.LocalSchedulerId(), + local_scheduler_id.encode("ascii")) + self.assertEqual(notification_object.TaskSpec(), b"task_spec") + def testTaskTableSubscribe(self): scheduling_state = 1 local_scheduler_id = "local_scheduler_id" # Subscribe to the task table. p = self.redis.pubsub() p.psubscribe("{prefix}*:*".format(prefix=TASK_PREFIX)) + # Receive acknowledgment. + self.assertEqual(get_next_message(p)["data"], 1) + self.check_task_subscription(p, scheduling_state, local_scheduler_id) + # unsubscribe to make sure there is only one subscriber at a given time + p.punsubscribe("{prefix}*:*".format(prefix=TASK_PREFIX)) + # Receive acknowledgment. + self.assertEqual(get_next_message(p)["data"], 0) + p.psubscribe("{prefix}*:{state}".format( prefix=TASK_PREFIX, state=scheduling_state)) + # Receive acknowledgment. + self.assertEqual(get_next_message(p)["data"], 1) + self.check_task_subscription(p, scheduling_state, local_scheduler_id) + p.punsubscribe("{prefix}*:{state}".format( + prefix=TASK_PREFIX, state=scheduling_state)) + # Receive acknowledgment. + self.assertEqual(get_next_message(p)["data"], 0) + p.psubscribe("{prefix}{local_scheduler_id}:*".format( prefix=TASK_PREFIX, local_scheduler_id=local_scheduler_id)) - task_args = [b"task_id", scheduling_state, - local_scheduler_id.encode("ascii"), b"task_spec"] - self.redis.execute_command("RAY.TASK_TABLE_ADD", *task_args) - # Receive the acknowledgement message. + # Receive acknowledgment. self.assertEqual(get_next_message(p)["data"], 1) - self.assertEqual(get_next_message(p)["data"], 2) - self.assertEqual(get_next_message(p)["data"], 3) - # Receive the actual data. - for i in range(3): - message = get_next_message(p)["data"] - # Check that the notification object is correct. - notification_object = TaskReply.GetRootAsTaskReply(message, 0) - self.assertEqual(notification_object.TaskId(), b"task_id") - self.assertEqual(notification_object.State(), scheduling_state) - self.assertEqual(notification_object.LocalSchedulerId(), - local_scheduler_id.encode("ascii")) - self.assertEqual(notification_object.TaskSpec(), b"task_spec") + self.check_task_subscription(p, scheduling_state, local_scheduler_id) + p.punsubscribe("{prefix}{local_scheduler_id}:*".format( + prefix=TASK_PREFIX, local_scheduler_id=local_scheduler_id)) + # Receive acknowledgment. + self.assertEqual(get_next_message(p)["data"], 0) if __name__ == "__main__": diff --git a/python/ray/experimental/state.py b/python/ray/experimental/state.py index e7c649f28..7f7c17b34 100644 --- a/python/ray/experimental/state.py +++ b/python/ray/experimental/state.py @@ -5,6 +5,7 @@ from __future__ import print_function import pickle import redis +import ray from ray.utils import (decode, binary_to_object_id, binary_to_hex, hex_to_binary) @@ -25,14 +26,21 @@ OBJECT_CHANNEL_PREFIX = "OC:" # This mapping from integer to task state string must be kept up-to-date with # the scheduling_state enum in task.h. -task_state_mapping = { - 1: "WAITING", - 2: "SCHEDULED", - 4: "QUEUED", - 8: "RUNNING", - 16: "DONE", - 32: "LOST", - 64: "RECONSTRUCTING" +TASK_STATUS_WAITING = 1 +TASK_STATUS_SCHEDULED = 2 +TASK_STATUS_QUEUED = 4 +TASK_STATUS_RUNNING = 8 +TASK_STATUS_DONE = 16 +TASK_STATUS_LOST = 32 +TASK_STATUS_RECONSTRUCTING = 64 +TASK_STATUS_MAPPING = { + TASK_STATUS_WAITING: "WAITING", + TASK_STATUS_SCHEDULED: "SCHEDULED", + TASK_STATUS_QUEUED: "QUEUED", + TASK_STATUS_RUNNING: "RUNNING", + TASK_STATUS_DONE: "DONE", + TASK_STATUS_LOST: "LOST", + TASK_STATUS_RECONSTRUCTING: "RECONSTRUCTING", } @@ -66,8 +74,54 @@ class GlobalState(object): """ self.redis_client = redis.StrictRedis(host=redis_ip_address, port=redis_port) + self.redis_clients = [] + num_redis_shards = self.redis_client.get("NumRedisShards") + if num_redis_shards is None: + raise Exception("No entry found for NumRedisShards") + num_redis_shards = int(num_redis_shards) + if (num_redis_shards < 1): + raise Exception("Expected at least one Redis shard, found " + "{}.".format(num_redis_shards)) - def _object_table(self, object_id_binary): + ip_address_ports = self.redis_client.lrange("RedisShards", start=0, end=-1) + if len(ip_address_ports) != num_redis_shards: + raise Exception("Expected {} Redis shard addresses, found " + "{}".format(num_redis_shards, len(ip_address_ports))) + + for ip_address_port in ip_address_ports: + shard_address, shard_port = ip_address_port.split(b":") + self.redis_clients.append(redis.StrictRedis(host=shard_address, + port=shard_port)) + + def _execute_command(self, key, *args): + """Execute a Redis command on the appropriate Redis shard based on key. + + Args: + key: The object ID or the task ID that the query is about. + args: The command to run. + + Returns: + The value returned by the Redis command. + """ + client = self.redis_clients[key.redis_shard_hash() % + len(self.redis_clients)] + return client.execute_command(*args) + + def _keys(self, pattern): + """Execute the KEYS command on all Redis shards. + + Args: + pattern: The KEYS pattern to query. + + Returns: + The concatenated list of results from all shards. + """ + result = [] + for client in self.redis_clients: + result.extend(client.keys(pattern)) + return result + + def _object_table(self, object_id): """Fetch and parse the object table information for a single object ID. Args: @@ -78,16 +132,18 @@ class GlobalState(object): A dictionary with information about the object ID in question. """ # Return information about a single object ID. - object_locations = self.redis_client.execute_command( - "RAY.OBJECT_TABLE_LOOKUP", object_id_binary) + object_locations = self._execute_command(object_id, + "RAY.OBJECT_TABLE_LOOKUP", + object_id.id()) if object_locations is not None: manager_ids = [binary_to_hex(manager_id) for manager_id in object_locations] else: manager_ids = None - result_table_response = self.redis_client.execute_command( - "RAY.RESULT_TABLE_LOOKUP", object_id_binary) + result_table_response = self._execute_command(object_id, + "RAY.RESULT_TABLE_LOOKUP", + object_id.id()) result_table_message = ResultTableReply.GetRootAsResultTableReply( result_table_response, 0) @@ -111,22 +167,21 @@ class GlobalState(object): self._check_connected() if object_id is not None: # Return information about a single object ID. - return self._object_table(object_id.id()) + return self._object_table(object_id) else: # Return the entire object table. - object_info_keys = self.redis_client.keys(OBJECT_INFO_PREFIX + "*") - object_location_keys = self.redis_client.keys( - OBJECT_LOCATION_PREFIX + "*") + object_info_keys = self._keys(OBJECT_INFO_PREFIX + "*") + object_location_keys = self._keys(OBJECT_LOCATION_PREFIX + "*") object_ids_binary = set( [key[len(OBJECT_INFO_PREFIX):] for key in object_info_keys] + [key[len(OBJECT_LOCATION_PREFIX):] for key in object_location_keys]) results = {} for object_id_binary in object_ids_binary: results[binary_to_object_id(object_id_binary)] = self._object_table( - object_id_binary) + binary_to_object_id(object_id_binary)) return results - def _task_table(self, task_id_binary): + def _task_table(self, task_id): """Fetch and parse the task table information for a single object task ID. Args: @@ -135,12 +190,15 @@ class GlobalState(object): Returns: A dictionary with information about the task ID in question. + TASK_STATUS_MAPPING should be used to parse the "State" field into a + human-readable string. """ - task_table_response = self.redis_client.execute_command( - "RAY.TASK_TABLE_GET", task_id_binary) + task_table_response = self._execute_command(task_id, + "RAY.TASK_TABLE_GET", + task_id.id()) if task_table_response is None: raise Exception("There is no entry for task ID {} in the task table." - .format(binary_to_hex(task_id_binary))) + .format(binary_to_hex(task_id.id()))) task_table_message = TaskReply.GetRootAsTaskReply(task_table_response, 0) task_spec = task_table_message.TaskSpec() task_spec_message = TaskInfo.GetRootAsTaskInfo(task_spec, 0) @@ -167,7 +225,7 @@ class GlobalState(object): for i in range(task_spec_message.ReturnsLength())], "RequiredResources": required_resources} - return {"State": task_state_mapping[task_table_message.State()], + return {"State": task_table_message.State(), "LocalSchedulerID": binary_to_hex( task_table_message.LocalSchedulerId()), "TaskSpec": task_spec_info} @@ -185,14 +243,15 @@ class GlobalState(object): """ self._check_connected() if task_id is not None: - return self._task_table(hex_to_binary(task_id)) + task_id = ray.local_scheduler.ObjectID(hex_to_binary(task_id)) + return self._task_table(task_id) else: - task_table_keys = self.redis_client.keys(TASK_PREFIX + "*") + task_table_keys = self._keys(TASK_PREFIX + "*") results = {} for key in task_table_keys: task_id_binary = key[len(TASK_PREFIX):] results[binary_to_hex(task_id_binary)] = self._task_table( - task_id_binary) + ray.local_scheduler.ObjectID(task_id_binary)) return results def function_table(self, function_id=None): diff --git a/python/ray/global_scheduler/global_scheduler_services.py b/python/ray/global_scheduler/global_scheduler_services.py index eb7ec7ee9..8cbc6f28e 100644 --- a/python/ray/global_scheduler/global_scheduler_services.py +++ b/python/ray/global_scheduler/global_scheduler_services.py @@ -7,9 +7,9 @@ import subprocess import time -def start_global_scheduler(redis_address, node_ip_address, use_valgrind=False, - use_profiler=False, stdout_file=None, - stderr_file=None): +def start_global_scheduler(redis_address, node_ip_address, + use_valgrind=False, use_profiler=False, + stdout_file=None, stderr_file=None): """Start a global scheduler process. Args: diff --git a/python/ray/global_scheduler/test/test.py b/python/ray/global_scheduler/test/test.py index 29df06f3b..af32c194f 100644 --- a/python/ray/global_scheduler/test/test.py +++ b/python/ray/global_scheduler/test/test.py @@ -5,7 +5,6 @@ from __future__ import print_function import numpy as np import os import random -import redis import signal import sys import time @@ -17,6 +16,7 @@ import ray.plasma as plasma from ray.plasma.utils import create_object from ray import services +from ray.experimental import state USE_VALGRIND = False PLASMA_STORE_MEMORY = 1000000000 @@ -26,13 +26,6 @@ NUM_CLUSTER_NODES = 2 NIL_WORKER_ID = 20 * b"\xff" NIL_ACTOR_ID = 20 * b"\xff" -# These constants must match the scheduling state enum in task.h. -TASK_STATUS_WAITING = 1 -TASK_STATUS_SCHEDULED = 2 -TASK_STATUS_QUEUED = 4 -TASK_STATUS_RUNNING = 8 -TASK_STATUS_DONE = 16 - # These constants are an implementation detail of ray_redis_module.cc, so this # must be kept in sync with that file. DB_CLIENT_PREFIX = "CL:" @@ -63,15 +56,17 @@ class TestGlobalScheduler(unittest.TestCase): def setUp(self): # Start one Redis server and N pairs of (plasma, local_scheduler) - node_ip_address = "127.0.0.1" - redis_port, self.redis_process = services.start_redis(cleanup=False) - redis_address = services.address(node_ip_address, redis_port) - # Create a Redis client. - self.redis_client = redis.StrictRedis(host=node_ip_address, - port=redis_port) + self.node_ip_address = "127.0.0.1" + redis_address, redis_shards = services.start_redis(self.node_ip_address) + redis_port = services.get_port(redis_address) + time.sleep(0.1) + # Create a client for the global state store. + self.state = state.GlobalState() + self.state._initialize_global_state(self.node_ip_address, redis_port) + # Start one global scheduler. self.p1 = global_scheduler.start_global_scheduler( - redis_address, node_ip_address, use_valgrind=USE_VALGRIND) + redis_address, self.node_ip_address, use_valgrind=USE_VALGRIND) self.plasma_store_pids = [] self.plasma_manager_pids = [] self.local_scheduler_pids = [] @@ -89,7 +84,8 @@ class TestGlobalScheduler(unittest.TestCase): redis_address) plasma_manager_name, p3, plasma_manager_port = manager_info self.plasma_manager_pids.append(p3) - plasma_address = "{}:{}".format(node_ip_address, plasma_manager_port) + plasma_address = "{}:{}".format(self.node_ip_address, + plasma_manager_port) plasma_client = plasma.PlasmaClient(plasma_store_name, plasma_manager_name) self.plasma_clients.append(plasma_client) @@ -116,7 +112,10 @@ class TestGlobalScheduler(unittest.TestCase): for p4 in self.local_scheduler_pids: self.assertEqual(p4.poll(), None) - self.assertEqual(self.redis_process.poll(), None) + redis_processes = services.all_processes[ + services.PROCESS_TYPE_REDIS_SERVER] + for redis_process in redis_processes: + self.assertEqual(redis_process.poll(), None) # Kill the global scheduler. if USE_VALGRIND: @@ -135,7 +134,9 @@ class TestGlobalScheduler(unittest.TestCase): p4.kill() # Kill Redis. In the event that we are using valgrind, this needs to happen # after we kill the global scheduler. - self.redis_process.kill() + while redis_processes: + redis_process = redis_processes.pop() + redis_process.kill() def get_plasma_manager_id(self): """Get the db_client_id with client_type equal to plasma_manager. @@ -150,11 +151,10 @@ class TestGlobalScheduler(unittest.TestCase): """ db_client_id = None - client_list = self.redis_client.keys("{}*".format(DB_CLIENT_PREFIX)) - for client_id in client_list: - response = self.redis_client.hget(client_id, b"client_type") - if response == b"plasma_manager": - db_client_id = client_id + client_list = self.state.client_table()[self.node_ip_address] + for client in client_list: + if client["ClientType"] == "plasma_manager": + db_client_id = client["DBClientID"] break return db_client_id @@ -178,18 +178,16 @@ class TestGlobalScheduler(unittest.TestCase): # There should be 2n+1 db clients: the global scheduler + one local # scheduler and one plasma per node. self.assertEqual( - len(self.redis_client.keys("{}*".format(DB_CLIENT_PREFIX))), + len(self.state.client_table()[self.node_ip_address]), 2 * NUM_CLUSTER_NODES + 1) db_client_id = self.get_plasma_manager_id() assert(db_client_id is not None) - assert(db_client_id.startswith(b"CL:")) - db_client_id = db_client_id[len(b"CL:"):] # Remove the CL: prefix. def test_integration_single_task(self): # There should be three db clients, the global scheduler, the local # scheduler, and the plasma manager. self.assertEqual( - len(self.redis_client.keys("{}*".format(DB_CLIENT_PREFIX))), + len(self.state.client_table()[self.node_ip_address]), 2 * NUM_CLUSTER_NODES + 1) num_return_vals = [0, 1, 2, 3, 5, 10] @@ -212,15 +210,15 @@ class TestGlobalScheduler(unittest.TestCase): # local scheduler num_retries = 10 while num_retries > 0: - task_entries = self.redis_client.keys("{}*".format(TASK_PREFIX)) + task_entries = self.state.task_table() self.assertLessEqual(len(task_entries), 1) if len(task_entries) == 1: - task_contents = self.redis_client.hgetall(task_entries[0]) - task_status = int(task_contents[b"state"]) - self.assertTrue(task_status in [TASK_STATUS_WAITING, - TASK_STATUS_SCHEDULED, - TASK_STATUS_QUEUED]) - if task_status == TASK_STATUS_QUEUED: + task_id, task = task_entries.popitem() + task_status = task["State"] + self.assertTrue(task_status in [state.TASK_STATUS_WAITING, + state.TASK_STATUS_SCHEDULED, + state.TASK_STATUS_QUEUED]) + if task_status == state.TASK_STATUS_QUEUED: break else: print(task_status) @@ -228,7 +226,7 @@ class TestGlobalScheduler(unittest.TestCase): num_retries -= 1 time.sleep(1) - if num_retries <= 0 and task_status != TASK_STATUS_QUEUED: + if num_retries <= 0 and task_status != state.TASK_STATUS_QUEUED: # Failed to submit and schedule a single task -- bail. self.tearDown() sys.exit(1) @@ -237,7 +235,7 @@ class TestGlobalScheduler(unittest.TestCase): # There should be three db clients, the global scheduler, the local # scheduler, and the plasma manager. self.assertEqual( - len(self.redis_client.keys("{}*".format(DB_CLIENT_PREFIX))), + len(self.state.client_table()[self.node_ip_address]), 2 * NUM_CLUSTER_NODES + 1) num_return_vals = [0, 1, 2, 3, 5, 10] @@ -264,34 +262,31 @@ class TestGlobalScheduler(unittest.TestCase): num_retries = 10 num_tasks_done = 0 while num_retries > 0: - task_entries = self.redis_client.keys("{}*".format(TASK_PREFIX)) + task_entries = self.state.task_table() self.assertLessEqual(len(task_entries), num_tasks) # First, check if all tasks made it to Redis. if len(task_entries) == num_tasks: - task_contents = [self.redis_client.hgetall(task_entries[i]) - for i in range(len(task_entries))] - task_statuses = [int(contents[b"state"]) for contents in task_contents] - self.assertTrue(all([status in [TASK_STATUS_WAITING, - TASK_STATUS_SCHEDULED, - TASK_STATUS_QUEUED] + task_statuses = [task_entry["State"] for task_entry in + task_entries.values()] + self.assertTrue(all([status in [state.TASK_STATUS_WAITING, + state.TASK_STATUS_SCHEDULED, + state.TASK_STATUS_QUEUED] for status in task_statuses])) - num_tasks_done = task_statuses.count(TASK_STATUS_QUEUED) - num_tasks_scheduled = task_statuses.count(TASK_STATUS_SCHEDULED) - num_tasks_waiting = task_statuses.count(TASK_STATUS_WAITING) + num_tasks_done = task_statuses.count(state.TASK_STATUS_QUEUED) + num_tasks_scheduled = task_statuses.count(state.TASK_STATUS_SCHEDULED) + num_tasks_waiting = task_statuses.count(state.TASK_STATUS_WAITING) print("tasks in Redis = {}, tasks waiting = {}, tasks scheduled = {}, " "tasks queued = {}, retries left = {}" .format(len(task_entries), num_tasks_waiting, num_tasks_scheduled, num_tasks_done, num_retries)) - if all([status == TASK_STATUS_QUEUED for status in task_statuses]): + if all([status == state.TASK_STATUS_QUEUED for status in + task_statuses]): # We're done, so pass. break num_retries -= 1 time.sleep(0.1) - if num_tasks_done != num_tasks: - # At least one of the tasks failed to schedule. - self.tearDown() - sys.exit(2) + self.assertEqual(num_tasks_done, num_tasks) def test_integration_many_tasks_handler_sync(self): self.integration_many_tasks_helper(timesync=True) diff --git a/python/ray/monitor.py b/python/ray/monitor.py index eaeb4d746..8bdedae97 100644 --- a/python/ray/monitor.py +++ b/python/ray/monitor.py @@ -12,11 +12,13 @@ import time import ray from ray.services import get_ip_address from ray.services import get_port +from ray.utils import binary_to_object_id +from ray.utils import binary_to_hex +from ray.utils import hex_to_binary # Import flatbuffer bindings. from ray.core.generated.SubscribeToDBClientTableReply \ import SubscribeToDBClientTableReply -from ray.core.generated.TaskReply import TaskReply from ray.core.generated.DriverTableMessage import DriverTableMessage # These variables must be kept in sync with the C codebase. @@ -31,7 +33,6 @@ TASK_STATUS_LOST = 32 PLASMA_MANAGER_HEARTBEAT_CHANNEL = b"plasma_managers" DRIVER_DEATH_CHANNEL = b"driver_deaths" # common/redis_module/ray_redis_module.cc -TASK_PREFIX = "TT:" OBJECT_PREFIX = "OL:" DB_CLIENT_PREFIX = "CL:" DB_CLIENT_TABLE_NAME = b"db_clients" @@ -43,7 +44,7 @@ PLASMA_MANAGER_CLIENT_TYPE = b"plasma_manager" # Set up logging. logging.basicConfig() log = logging.getLogger() -log.setLevel(logging.WARN) +log.setLevel(logging.INFO) class Monitor(object): @@ -70,7 +71,11 @@ class Monitor(object): """ def __init__(self, redis_address, redis_port): # Initialize the Redis clients. + self.state = ray.experimental.state.GlobalState() + self.state._initialize_global_state(redis_address, redis_port) self.redis = redis.StrictRedis(host=redis_address, port=redis_port, db=0) + # TODO(swang): Update pubsub client to use ray.experimental.state once + # subscriptions are implemented there. self.subscribe_client = self.redis.pubsub() self.subscribed = {} # Initialize data structures to keep track of the active database clients. @@ -97,23 +102,17 @@ class Monitor(object): TASK_STATUS_LOST. A local scheduler is deemed dead if it is in self.dead_local_schedulers. """ - task_ids = self.redis.scan_iter( - match="{prefix}*".format(prefix=TASK_PREFIX)) + tasks = self.state.task_table() num_tasks_updated = 0 - for task_id in task_ids: - task_id = task_id[len(TASK_PREFIX):] - response = self.redis.execute_command("RAY.TASK_TABLE_GET", task_id) - # Parse the serialized task object. - task_object = TaskReply.GetRootAsTaskReply(response, 0) - local_scheduler_id = task_object.LocalSchedulerId() + for task_id, task in tasks.items(): # See if the corresponding local scheduler is alive. - if local_scheduler_id in self.dead_local_schedulers: + if task["LocalSchedulerID"] in self.dead_local_schedulers: # If the task is scheduled on a dead local scheduler, mark the task as # lost. - ok = self.redis.execute_command("RAY.TASK_TABLE_UPDATE", - task_id, - TASK_STATUS_LOST, - NIL_ID) + key = binary_to_object_id(hex_to_binary(task_id)) + ok = self.state._execute_command( + key, "RAY.TASK_TABLE_UPDATE", hex_to_binary(task_id), + ray.experimental.state.TASK_STATUS_LOST, NIL_ID) if ok != b"OK": log.warn("Failed to update lost task for dead scheduler.") num_tasks_updated += 1 @@ -129,19 +128,20 @@ class Monitor(object): """ # TODO(swang): Also kill the associated plasma store, since it's no longer # reachable without a plasma manager. - object_ids = self.redis.scan_iter( - match="{prefix}*".format(prefix=OBJECT_PREFIX)) + objects = self.state.object_table() num_objects_removed = 0 - for object_id in object_ids: - object_id = object_id[len(OBJECT_PREFIX):] - managers = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", - object_id) - for manager in managers: + for object_id, obj in objects.items(): + manager_ids = obj["ManagerIDs"] + if manager_ids is None: + continue + for manager in manager_ids: if manager in self.dead_plasma_managers: # If the object was on a dead plasma manager, remove that location # entry. - ok = self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", object_id, - manager) + ok = self.state._execute_command(object_id, + "RAY.OBJECT_TABLE_REMOVE", + object_id.id(), + hex_to_binary(manager)) if ok != b"OK": log.warn("Failed to remove object location for dead plasma " "manager.") @@ -157,18 +157,16 @@ class Monitor(object): not miss any notifications for deleted clients that occurred before we subscribed. """ - db_client_keys = self.redis.keys( - "{prefix}*".format(prefix=DB_CLIENT_PREFIX)) - for db_client_key in db_client_keys: - db_client_id = db_client_key[len(DB_CLIENT_PREFIX):] - client_type, deleted = self.redis.hmget(db_client_key, - [b"client_type", b"deleted"]) - deleted = bool(int(deleted)) - if deleted: - if client_type == LOCAL_SCHEDULER_CLIENT_TYPE: - self.dead_local_schedulers.add(db_client_id) - elif client_type == PLASMA_MANAGER_CLIENT_TYPE: - self.dead_plasma_managers.add(db_client_id) + clients = self.state.client_table() + for node_ip_address, node_clients in clients.items(): + for client in node_clients: + db_client_id = client["DBClientID"] + client_type = client["ClientType"] + if client["Deleted"]: + if client_type == LOCAL_SCHEDULER_CLIENT_TYPE: + self.dead_local_schedulers.add(db_client_id) + elif client_type == PLASMA_MANAGER_CLIENT_TYPE: + self.dead_plasma_managers.add(db_client_id) def subscribe_handler(self, channel, data): """Handle a subscription success message from Redis. @@ -186,7 +184,7 @@ class Monitor(object): """ notification_object = (SubscribeToDBClientTableReply .GetRootAsSubscribeToDBClientTableReply(data, 0)) - db_client_id = notification_object.DbClientId() + db_client_id = binary_to_hex(notification_object.DbClientId()) client_type = notification_object.ClientType() is_insertion = notification_object.IsInsertion() @@ -196,7 +194,7 @@ class Monitor(object): # If the update was a deletion, add them to our accounting for dead # local schedulers and plasma managers. - log.warn("Removed {}".format(client_type)) + log.warn("Removed {}, client ID {}".format(client_type, db_client_id)) if client_type == LOCAL_SCHEDULER_CLIENT_TYPE: if db_client_id not in self.dead_local_schedulers: self.dead_local_schedulers.add(db_client_id) @@ -256,7 +254,7 @@ class Monitor(object): result = pipe.hget(local_scheduler_id, "gpus_in_use") gpus_in_use = dict() if result is None else json.loads(result) - driver_id_hex = ray.utils.binary_to_hex(driver_id) + driver_id_hex = binary_to_hex(driver_id) if driver_id_hex in gpus_in_use: num_gpus_returned = gpus_in_use.pop(driver_id_hex) diff --git a/python/ray/plasma/plasma.py b/python/ray/plasma/plasma.py index cdee371b8..718d3f01f 100644 --- a/python/ray/plasma/plasma.py +++ b/python/ray/plasma/plasma.py @@ -405,7 +405,8 @@ def start_plasma_manager(store_name, redis_address, "-m", plasma_manager_name, "-h", node_ip_address, "-p", str(plasma_manager_port), - "-r", redis_address] + "-r", redis_address, + ] if use_valgrind: process = subprocess.Popen(["valgrind", "--track-origins=yes", diff --git a/python/ray/plasma/test/test.py b/python/ray/plasma/test/test.py index 68546b405..94a670626 100644 --- a/python/ray/plasma/test/test.py +++ b/python/ray/plasma/test/test.py @@ -480,7 +480,7 @@ class TestPlasmaManager(unittest.TestCase): store_name1, self.p2 = plasma.start_plasma_store(use_valgrind=USE_VALGRIND) store_name2, self.p3 = plasma.start_plasma_store(use_valgrind=USE_VALGRIND) # Start a Redis server. - redis_address = services.address("127.0.0.1", services.start_redis()[0]) + redis_address, _ = services.start_redis("127.0.0.1") # Start two PlasmaManagers. manager_name1, self.p4, self.port1 = plasma.start_plasma_manager( store_name1, redis_address, use_valgrind=USE_VALGRIND) @@ -778,8 +778,7 @@ class TestPlasmaManagerRecovery(unittest.TestCase): self.store_name, self.p2 = plasma.start_plasma_store( use_valgrind=USE_VALGRIND) # Start a Redis server. - self.redis_address = services.address("127.0.0.1", - services.start_redis()[0]) + self.redis_address, _ = services.start_redis("127.0.0.1") # Start a PlasmaManagers. manager_name, self.p3, self.port1 = plasma.start_plasma_manager( self.store_name, diff --git a/python/ray/services.py b/python/ray/services.py index 0ac9e6935..5fe278f71 100644 --- a/python/ray/services.py +++ b/python/ray/services.py @@ -240,15 +240,75 @@ def wait_for_redis_to_start(redis_ip_address, redis_port, num_retries=5): "configured properly.") -def start_redis(node_ip_address="127.0.0.1", port=None, num_retries=20, - stdout_file=None, stderr_file=None, cleanup=True): - """Start a Redis server. +def start_redis(node_ip_address, + port=None, + num_redis_shards=1, + redirect_output=False, + cleanup=True): + """Start the Redis global state store. Args: node_ip_address: The IP address of the current node. This is only used for recording the log filenames in Redis. + port (int): If provided, the primary Redis shard will be started on this + port. + num_redis_shards (int): If provided, the number of Redis shards to start, + in addition to the primary one. The default value is one shard. + cleanup (bool): True if using Ray in local mode. If cleanup is true, then + all Redis processes started by this method will be killed by + serices.cleanup() when the Python process that imported services exits. + + Returns: + A tuple of the address for the primary Redis shard and a list of addresses + for the remaining shards. + """ + redis_stdout_file, redis_stderr_file = new_log_files( + "redis", redirect_output) + assigned_port, _ = start_redis_instance( + node_ip_address=node_ip_address, port=port, + stdout_file=redis_stdout_file, stderr_file=redis_stderr_file, + cleanup=cleanup) + if port is not None: + assert assigned_port == port + port = assigned_port + redis_address = address(node_ip_address, port) + + # Register the number of Redis shards in the primary shard, so that clients + # know how many redis shards to expect under RedisShards. + redis_client = redis.StrictRedis(host=node_ip_address, port=port) + redis_client.set("NumRedisShards", str(num_redis_shards)) + + # Start other Redis shards listening on random ports. Each Redis shard logs + # to a separate file, prefixed by "redis-". + redis_shards = [] + for i in range(num_redis_shards): + redis_stdout_file, redis_stderr_file = new_log_files( + "redis-{}".format(i), redirect_output) + redis_shard_port, _ = start_redis_instance( + node_ip_address=node_ip_address, stdout_file=redis_stdout_file, + stderr_file=redis_stderr_file, cleanup=cleanup) + shard_address = address(node_ip_address, redis_shard_port) + redis_shards.append(shard_address) + # Store redis shard information in the primary redis shard. + redis_client.rpush("RedisShards", shard_address) + + return redis_address, redis_shards + + +def start_redis_instance(node_ip_address="127.0.0.1", + port=None, + num_retries=20, + stdout_file=None, + stderr_file=None, + cleanup=True): + """Start a single Redis server. + + Args: + node_ip_address (str): The IP address of the current node. This is only + used for recording the log filenames in Redis. port (int): If provided, start a Redis server with this port. - num_retries (int): The number of times to attempt to start Redis. + num_retries (int): The number of times to attempt to start Redis. If a port + is provided, this defaults to 1. stdout_file: A file handle opened for writing to redirect stdout to. If no redirection should happen, then this should be None. stderr_file: A file handle opened for writing to redirect stderr to. If no @@ -275,8 +335,8 @@ def start_redis(node_ip_address="127.0.0.1", port=None, num_retries=20, assert os.path.isfile(redis_module) counter = 0 if port is not None: - if num_retries != 1: - raise Exception("num_retries must be 1 if port is specified.") + # If a port is specified, then try only once to connect. + num_retries = 1 else: port = new_port() while counter < num_retries: @@ -356,8 +416,8 @@ def start_log_monitor(redis_address, node_ip_address, stdout_file=None, [stdout_file, stderr_file]) -def start_global_scheduler(redis_address, node_ip_address, stdout_file=None, - stderr_file=None, cleanup=True): +def start_global_scheduler(redis_address, node_ip_address, + stdout_file=None, stderr_file=None, cleanup=True): """Start a global scheduler process. Args: @@ -372,7 +432,8 @@ def start_global_scheduler(redis_address, node_ip_address, stdout_file=None, this process will be killed by services.cleanup() when the Python process that imported services exits. """ - p = global_scheduler.start_global_scheduler(redis_address, node_ip_address, + p = global_scheduler.start_global_scheduler(redis_address, + node_ip_address, stdout_file=stdout_file, stderr_file=stderr_file) if cleanup: @@ -545,10 +606,11 @@ def start_local_scheduler(redis_address, return local_scheduler_name -def start_objstore(node_ip_address, redis_address, object_manager_port=None, - store_stdout_file=None, store_stderr_file=None, - manager_stdout_file=None, manager_stderr_file=None, - cleanup=True, objstore_memory=None): +def start_objstore(node_ip_address, redis_address, + object_manager_port=None, store_stdout_file=None, + store_stderr_file=None, manager_stdout_file=None, + manager_stderr_file=None, cleanup=True, + objstore_memory=None): """This method starts an object store process. Args: @@ -704,13 +766,14 @@ def start_monitor(redis_address, node_ip_address, stdout_file=None, def start_ray_processes(address_info=None, node_ip_address="127.0.0.1", + redis_port=None, num_workers=None, num_local_schedulers=1, + num_redis_shards=1, worker_path=None, cleanup=True, redirect_output=False, include_global_scheduler=False, - include_redis=False, include_log_monitor=False, include_webui=False, start_workers_from_local_scheduler=True, @@ -723,12 +786,17 @@ def start_ray_processes(address_info=None, that have already been started. If provided, address_info will be modified to include processes that are newly started. node_ip_address (str): The IP address of this node. + redis_port (int): The port that the primary Redis shard should listen to. + If None, then a random port will be chosen. If the key "redis_address" is + in address_info, then this argument will be ignored. num_workers (int): The number of workers to start. num_local_schedulers (int): The total number of local schedulers required. This is also the total number of object stores required. This method will start new instances of local schedulers and object stores until there are num_local_schedulers existing instances of each, including ones already registered with the given address_info. + num_redis_shards: The number of Redis shards to start in addition to the + primary Redis shard. worker_path (str): The path of the source code that will be run by the worker. cleanup (bool): If cleanup is true, then the processes started here will be @@ -738,8 +806,6 @@ def start_ray_processes(address_info=None, file. include_global_scheduler (bool): If include_global_scheduler is True, then start a global scheduler process. - include_redis (bool): If include_redis is True, then start a Redis server - process. include_log_monitor (bool): If True, then start a log monitor to monitor the log files for all processes on this node and push their contents to Redis. @@ -785,29 +851,14 @@ def start_ray_processes(address_info=None, # warning messages when it starts up. Instead of suppressing the output, we # should address the warnings. redis_address = address_info.get("redis_address") - if include_redis: - redis_stdout_file, redis_stderr_file = new_log_files("redis", - redirect_output) - if redis_address is None: - # Start a Redis server. The start_redis method will choose a random port. - redis_port, _ = start_redis(node_ip_address, - stdout_file=redis_stdout_file, - stderr_file=redis_stderr_file, - cleanup=cleanup) - redis_address = address(node_ip_address, redis_port) - address_info["redis_address"] = redis_address - time.sleep(0.1) - else: - # A Redis address was provided, so start a Redis server with the given - # port. TODO(rkn): We should check that the IP address corresponds to the - # machine that this method is running on. - redis_port = get_port(redis_address) - new_redis_port, _ = start_redis(port=int(redis_port), - num_retries=1, - stdout_file=redis_stdout_file, - stderr_file=redis_stderr_file, - cleanup=cleanup) - assert redis_port == new_redis_port + redis_shards = address_info.get("redis_shards", []) + if redis_address is None: + redis_address, redis_shards = start_redis( + node_ip_address, port=redis_port, num_redis_shards=num_redis_shards, + redirect_output=redirect_output, cleanup=cleanup) + address_info["redis_address"] = redis_address + time.sleep(0.1) + # Start monitoring the processes. monitor_stdout_file, monitor_stderr_file = new_log_files("monitor", redirect_output) @@ -815,9 +866,14 @@ def start_ray_processes(address_info=None, node_ip_address, stdout_file=monitor_stdout_file, stderr_file=monitor_stderr_file) - else: - if redis_address is None: - raise Exception("Redis address expected") + + if redis_shards == []: + # Get redis shards from primary redis instance. + redis_ip_address, redis_port = redis_address.split(":") + redis_client = redis.StrictRedis(host=redis_ip_address, port=redis_port) + redis_shards = redis_client.lrange("RedisShards", start=0, end=-1) + redis_shards = [shard.decode("ascii") for shard in redis_shards] + address_info["redis_shards"] = redis_shards # Start the log monitor, if necessary. if include_log_monitor: @@ -1005,6 +1061,7 @@ def start_ray_node(node_ip_address, def start_ray_head(address_info=None, node_ip_address="127.0.0.1", + redis_port=None, num_workers=0, num_local_schedulers=1, worker_path=None, @@ -1012,7 +1069,8 @@ def start_ray_head(address_info=None, redirect_output=False, start_workers_from_local_scheduler=True, num_cpus=None, - num_gpus=None): + num_gpus=None, + num_redis_shards=None): """Start Ray in local mode. Args: @@ -1020,6 +1078,9 @@ def start_ray_head(address_info=None, that have already been started. If provided, address_info will be modified to include processes that are newly started. node_ip_address (str): The IP address of this node. + redis_port (int): The port that the primary Redis shard should listen to. + If None, then a random port will be chosen. If the key "redis_address" is + in address_info, then this argument will be ignored. num_workers (int): The number of workers to start. num_local_schedulers (int): The total number of local schedulers required. This is also the total number of object stores required. This method will @@ -1038,14 +1099,18 @@ def start_ray_head(address_info=None, Python. num_cpus (int): number of cpus to configure the local scheduler with. num_gpus (int): number of gpus to configure the local scheduler with. + num_redis_shards: The number of Redis shards to start in addition to the + primary Redis shard. Returns: A dictionary of the address information for the processes that were started. """ + num_redis_shards = 1 if num_redis_shards is None else num_redis_shards return start_ray_processes( address_info=address_info, node_ip_address=node_ip_address, + redis_port=redis_port, num_workers=num_workers, num_local_schedulers=num_local_schedulers, worker_path=worker_path, @@ -1053,11 +1118,11 @@ def start_ray_head(address_info=None, redirect_output=redirect_output, include_global_scheduler=True, include_log_monitor=True, - include_redis=True, include_webui=False, start_workers_from_local_scheduler=start_workers_from_local_scheduler, num_cpus=num_cpus, - num_gpus=num_gpus) + num_gpus=num_gpus, + num_redis_shards=num_redis_shards) def new_log_files(name, redirect_output): diff --git a/python/ray/worker.py b/python/ray/worker.py index 0920418bb..c6d9d009b 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -992,7 +992,8 @@ def _init(address_info=None, redirect_output=False, start_workers_from_local_scheduler=True, num_cpus=None, - num_gpus=None): + num_gpus=None, + num_redis_shards=None): """Helper method to connect to an existing Ray cluster or start a new one. This method handles two cases. Either a Ray cluster already exists and we @@ -1002,8 +1003,9 @@ def _init(address_info=None, Args: address_info (dict): A dictionary with address information for processes in a partially-started Ray cluster. If start_ray_local=True, any processes - not in this dictionary will be started. If provided, address_info will be - modified to include processes that are newly started. + not in this dictionary will be started. If provided, an updated + address_info dictionary will be returned to include processes that are + newly started. start_ray_local (bool): If True then this will start any processes not already in address_info, including Redis, a global scheduler, local scheduler(s), object store(s), and worker(s). It will also kill these @@ -1028,6 +1030,8 @@ def _init(address_info=None, be configured with. num_gpus: A list containing the number of GPUs the local schedulers should be configured with. + num_redis_shards: The number of Redis shards to start in addition to the + primary Redis shard. Returns: Address information about the started processes. @@ -1069,6 +1073,8 @@ def _init(address_info=None, num_local_schedulers = len(local_schedulers) else: num_local_schedulers = 1 + # Use 1 additional redis shard if num_redis_shards is not provided. + num_redis_shards = 1 if num_redis_shards is None else num_redis_shards # Start the scheduler, object store, and some workers. These will be killed # by the call to cleanup(), which happens when the Python script exits. address_info = services.start_ray_head( @@ -1079,20 +1085,24 @@ def _init(address_info=None, redirect_output=redirect_output, start_workers_from_local_scheduler=start_workers_from_local_scheduler, num_cpus=num_cpus, - num_gpus=num_gpus) + num_gpus=num_gpus, + num_redis_shards=num_redis_shards) else: if redis_address is None: - raise Exception("If start_ray_local=False, then redis_address must be " - "provided.") + raise Exception("When connecting to an existing cluster, redis_address " + "must be provided.") if num_workers is not None: - raise Exception("If start_ray_local=False, then num_workers must not be " - "provided.") + raise Exception("When connecting to an existing cluster, num_workers " + "must not be provided.") if num_local_schedulers is not None: - raise Exception("If start_ray_local=False, then num_local_schedulers " - "must not be provided.") + raise Exception("When connecting to an existing cluster, " + "num_local_schedulers must not be provided.") if num_cpus is not None or num_gpus is not None: - raise Exception("If start_ray_local=False, then num_cpus and num_gpus " - "must not be provided.") + raise Exception("When connecting to an existing cluster, num_cpus and " + "num_gpus must not be provided.") + if num_redis_shards is not None: + raise Exception("When connecting to an existing cluster, " + "num_redis_shards must not be provided.") # Get the node IP address if one is not provided. if node_ip_address is None: node_ip_address = services.get_node_ip_address(redis_address) @@ -1121,7 +1131,7 @@ def _init(address_info=None, def init(redis_address=None, node_ip_address=None, object_id_seed=None, num_workers=None, driver_mode=SCRIPT_MODE, redirect_output=False, - num_cpus=None, num_gpus=None): + num_cpus=None, num_gpus=None, num_redis_shards=None): """Either connect to an existing Ray cluster or start one and connect to it. This method handles two cases. Either a Ray cluster already exists and we @@ -1148,6 +1158,8 @@ def init(redis_address=None, node_ip_address=None, object_id_seed=None, configured with. num_gpus (int): Number of gpus the user wishes all local schedulers to be configured with. + num_redis_shards: The number of Redis shards to start in addition to the + primary Redis shard. Returns: Address information about the started processes. @@ -1161,7 +1173,7 @@ def init(redis_address=None, node_ip_address=None, object_id_seed=None, return _init(address_info=info, start_ray_local=(redis_address is None), num_workers=num_workers, driver_mode=driver_mode, redirect_output=redirect_output, num_cpus=num_cpus, - num_gpus=num_gpus) + num_gpus=num_gpus, num_redis_shards=num_redis_shards) def cleanup(worker=global_worker): @@ -1577,7 +1589,8 @@ def connect(info, object_id_seed=None, mode=WORKER_MODE, worker=global_worker, ray.local_scheduler.ObjectID(NIL_ACTOR_ID), worker.actor_counters[actor_id], [0, 0]) - worker.redis_client.execute_command( + global_state._execute_command( + driver_task.task_id(), "RAY.TASK_TABLE_ADD", driver_task.task_id().id(), TASK_STATUS_RUNNING, diff --git a/scripts/start_ray.py b/scripts/start_ray.py index a577d537d..31f9cc5af 100644 --- a/scripts/start_ray.py +++ b/scripts/start_ray.py @@ -15,6 +15,9 @@ parser.add_argument("--redis-address", required=False, type=str, help="the address to use for connecting to Redis") parser.add_argument("--redis-port", required=False, type=str, help="the port to use for starting Redis") +parser.add_argument("--num-redis-shards", required=False, type=int, + help=("the number of additional Redis shards to use in " + "addition to the primary Redis shard")) parser.add_argument("--object-manager-port", required=False, type=int, help="the port to use for starting the object manager") parser.add_argument("--num-workers", required=False, type=int, @@ -75,23 +78,22 @@ if __name__ == "__main__": print("Using IP address {} for this node.".format(node_ip_address)) address_info = {} - # Use the provided Redis port if there is one. - if args.redis_port is not None: - address_info["redis_address"] = "{}:{}".format(node_ip_address, - args.redis_port) # Use the provided object manager port if there is one. if args.object_manager_port is not None: address_info["object_manager_ports"] = [args.object_manager_port] if address_info == {}: address_info = None - address_info = services.start_ray_head(address_info=address_info, - node_ip_address=node_ip_address, - num_workers=args.num_workers, - cleanup=False, - redirect_output=True, - num_cpus=args.num_cpus, - num_gpus=args.num_gpus) + address_info = services.start_ray_head( + address_info=address_info, + node_ip_address=node_ip_address, + redis_port=args.redis_port, + num_workers=args.num_workers, + cleanup=False, + redirect_output=True, + num_cpus=args.num_cpus, + num_gpus=args.num_gpus, + num_redis_shards=args.num_redis_shards) print(address_info) print("\nStarted Ray on this node. You can add additional nodes to the " "cluster by calling\n\n" @@ -113,6 +115,9 @@ if __name__ == "__main__": if args.redis_address is None: raise Exception("If --head is not passed in, --redis-address must be " "provided.") + if args.num_redis_shards is not None: + raise Exception("If --head is not passed in, --num-redis-shards must " + "not be provided.") redis_ip_address, redis_port = args.redis_address.split(":") # Wait for the Redis server to be started. And throw an exception if we # can't connect to it. diff --git a/src/common/common_protocol.h b/src/common/common_protocol.h index 232dcd7ba..57a7466eb 100644 --- a/src/common/common_protocol.h +++ b/src/common/common_protocol.h @@ -5,6 +5,8 @@ #include "common.h" +#define DB_CLIENT_PREFIX "CL:" + /** * Convert an object ID to a flatbuffer string. * diff --git a/src/common/lib/python/common_extension.cc b/src/common/lib/python/common_extension.cc index 633be8e46..fca10f130 100644 --- a/src/common/lib/python/common_extension.cc +++ b/src/common/lib/python/common_extension.cc @@ -172,6 +172,14 @@ static PyObject *PyObjectID_richcompare(PyObjectID *self, return result; } +static PyObject *PyObjectID_redis_shard_hash(PyObjectID *self) { + /* NOTE: The hash function used here must match the one in get_redis_context + * in src/common/state/redis.cc. Changes to the hash function should only be + * made through UniqueIDHasher in src/common/common.h */ + UniqueIDHasher hash; + return PyLong_FromSize_t(hash(self->object_id)); +} + static long PyObjectID_hash(PyObjectID *self) { PyObject *tuple = PyTuple_New(UNIQUE_ID_SIZE); for (int i = 0; i < UNIQUE_ID_SIZE; ++i) { @@ -201,6 +209,8 @@ static PyObject *PyObjectID___reduce__(PyObjectID *self) { static PyMethodDef PyObjectID_methods[] = { {"id", (PyCFunction) PyObjectID_id, METH_NOARGS, "Return the hash associated with this ObjectID"}, + {"redis_shard_hash", (PyCFunction) PyObjectID_redis_shard_hash, METH_NOARGS, + "Return the redis shard that this ObjectID is associated with"}, {"hex", (PyCFunction) PyObjectID_hex, METH_NOARGS, "Return the object ID as a string in hex."}, {"__reduce__", (PyCFunction) PyObjectID___reduce__, METH_NOARGS, diff --git a/src/common/logging.cc b/src/common/logging.cc index 50d40ef87..87ba0f20d 100644 --- a/src/common/logging.cc +++ b/src/common/logging.cc @@ -67,11 +67,14 @@ void RayLogger_log(RayLogger *logger, if (logger->is_direct) { DBHandle *db = (DBHandle *) logger->conn; /* Fill in the client ID and send the message to Redis. */ - int status = redisAsyncCommand( - db->context, NULL, NULL, utstring_body(formatted_message), - (char *) db->client.id, sizeof(db->client.id)); - if ((status == REDIS_ERR) || db->context->err) { - LOG_REDIS_DEBUG(db->context, "error while logging message to log table"); + + redisAsyncContext *context = get_redis_context(db, db->client); + + int status = + redisAsyncCommand(context, NULL, NULL, utstring_body(formatted_message), + (char *) db->client.id, sizeof(db->client.id)); + if ((status == REDIS_ERR) || context->err) { + LOG_REDIS_DEBUG(context, "error while logging message to log table"); } } else { /* If we don't own a Redis connection, we leave our client diff --git a/src/common/net.cc b/src/common/net.cc index 3e56cff50..3f2aaf6fa 100644 --- a/src/common/net.cc +++ b/src/common/net.cc @@ -1,5 +1,9 @@ #include "net.h" +#include + +#include + #include "common.h" int parse_ip_addr_port(const char *ip_addr_port, char *ip_addr, int *port) { @@ -11,3 +15,10 @@ int parse_ip_addr_port(const char *ip_addr_port, char *ip_addr, int *port) { *port = atoi(port_str); return 0; } + +/* Return true if the ip address is valid. */ +bool valid_ip_address(const std::string &ip_address) { + struct sockaddr_in sa; + int result = inet_pton(AF_INET, ip_address.c_str(), &sa.sin_addr); + return result == 1; +} diff --git a/src/common/redis_module/ray_redis_module.cc b/src/common/redis_module/ray_redis_module.cc index 1b28c456c..9bbd6df78 100644 --- a/src/common/redis_module/ray_redis_module.cc +++ b/src/common/redis_module/ray_redis_module.cc @@ -27,7 +27,6 @@ * TODO(pcm): Fill this out. */ -#define DB_CLIENT_PREFIX "CL:" #define OBJECT_INFO_PREFIX "OI:" #define OBJECT_LOCATION_PREFIX "OL:" #define OBJECT_NOTIFICATION_PREFIX "ON:" @@ -929,6 +928,10 @@ int TaskTableWrite(RedisModuleCtx *ctx, RedisModuleCallReply *reply = RedisModule_Call(ctx, "PUBLISH", "ss", publish_topic, publish_message); + /* See how many clients received this publish. */ + long long num_clients = RedisModule_CallReplyInteger(reply); + CHECKM(num_clients <= 1, "Published to %lld clients.", num_clients); + RedisModule_FreeString(ctx, publish_message); RedisModule_FreeString(ctx, publish_topic); if (existing_task_spec != NULL) { @@ -938,6 +941,18 @@ int TaskTableWrite(RedisModuleCtx *ctx, if (reply == NULL) { return RedisModule_ReplyWithError(ctx, "PUBLISH unsuccessful"); } + + if (num_clients == 0) { + LOG_WARN( + "No subscribers received this publish. This most likely means that " + "either the intended recipient has not subscribed yet or that the " + "pubsub connection to the intended recipient has been broken."); + /* This reply will be received by redis_task_table_update_callback or + * redis_task_table_add_task_callback in redis.cc, which will then reissue + * the command. */ + return RedisModule_ReplyWithError(ctx, + "No subscribers received message."); + } } RedisModule_ReplyWithSimpleString(ctx, "OK"); diff --git a/src/common/state/db.h b/src/common/state/db.h index 2548b5f21..bfac9e212 100644 --- a/src/common/state/db.h +++ b/src/common/state/db.h @@ -11,6 +11,9 @@ typedef struct DBHandle DBHandle; * * @param db_address The hostname to use to connect to the database. * @param db_port The port to use to connect to the database. + * @param db_shards_addresses The list of database shard IP addresses. + * @param db_shards_ports The list of database shard ports, in the same order + * as db_shards_addresses. * @param client_type The type of this client. * @param node_ip_address The hostname of the client that is connecting. * @param num_args The number of extra arguments that should be supplied. This @@ -21,8 +24,8 @@ typedef struct DBHandle DBHandle; * @return This returns a handle to the database, which must be freed with * db_disconnect after use. */ -DBHandle *db_connect(const char *db_address, - int db_port, +DBHandle *db_connect(const std::string &db_primary_address, + int db_primary_port, const char *client_type, const char *node_ip_address, int num_args, diff --git a/src/common/state/db_client_table.h b/src/common/state/db_client_table.h index c969104b2..009f78c01 100644 --- a/src/common/state/db_client_table.h +++ b/src/common/state/db_client_table.h @@ -29,11 +29,22 @@ void db_client_table_remove(DBHandle *db_handle, * ==== Subscribing to the db client table ==== */ +/* An entry in the db client table. */ +typedef struct { + /** The database client ID. */ + DBClientID id; + /** The database client type. */ + const char *client_type; + /** An optional auxiliary address for an associated database client on the + * same node. */ + const char *aux_address; + /** Whether or not the database client exists. If this is false for an entry, + * then it will never again be true. */ + bool is_insertion; +} DBClient; + /* Callback for subscribing to the db client table. */ -typedef void (*db_client_table_subscribe_callback)(DBClientID db_client_id, - const char *client_type, - const char *aux_address, - bool is_insertion, +typedef void (*db_client_table_subscribe_callback)(DBClient *db_client, void *user_context); /** diff --git a/src/common/state/redis.cc b/src/common/state/redis.cc index cdb29693c..0cdfa09d3 100644 --- a/src/common/state/redis.cc +++ b/src/common/state/redis.cc @@ -4,6 +4,7 @@ #include #include #include +#include extern "C" { /* Including hiredis here is necessary on Windows for typedefs used in ae.h. */ @@ -26,6 +27,7 @@ extern "C" { #include "event_loop.h" #include "redis.h" #include "io.h" +#include "net.h" #include "format/common_generated.h" @@ -77,52 +79,128 @@ extern int usleep(useconds_t usec); do { \ } while (0) -DBHandle *db_connect(const char *db_address, - int db_port, - const char *client_type, - const char *node_ip_address, - int num_args, - const char **args) { - /* Check that the number of args is even. These args will be passed to the - * RAY.CONNECT Redis command, which takes arguments in pairs. */ - if (num_args % 2 != 0) { - LOG_FATAL("The number of extra args must be divisible by two."); - } +redisAsyncContext *get_redis_context(DBHandle *db, UniqueID id) { + /* NOTE: The hash function used here must match the one in + * PyObjectID_redis_shard_hash in src/common/lib/python/common_extension.cc. + * Changes to the hash function should only be made through + * UniqueIDHasher in src/common/common.h */ + UniqueIDHasher index; + return db->contexts[index(id) % db->contexts.size()]; +} - DBHandle *db = (DBHandle *) malloc(sizeof(DBHandle)); - /* Sync connection for initial handshake */ +redisAsyncContext *get_redis_subscribe_context(DBHandle *db, UniqueID id) { + UniqueIDHasher index; + return db->subscribe_contexts[index(id) % db->subscribe_contexts.size()]; +} + +void get_redis_shards(redisContext *context, + std::vector &db_shards_addresses, + std::vector &db_shards_ports) { + /* Get the total number of Redis shards in the system. */ + int num_attempts = 0; + redisReply *reply = NULL; + while (num_attempts < REDIS_DB_CONNECT_RETRIES) { + /* Try to read the number of Redis shards from the primary shard. If the + * entry is present, exit. */ + reply = (redisReply *) redisCommand(context, "GET NumRedisShards"); + if (reply->type != REDIS_REPLY_NIL) { + break; + } + + /* Sleep for a little, and try again if the entry isn't there yet. */ + freeReplyObject(reply); + usleep(REDIS_DB_CONNECT_WAIT_MS * 1000); + num_attempts++; + continue; + } + CHECKM(num_attempts < REDIS_DB_CONNECT_RETRIES, + "No entry found for NumRedisShards"); + CHECKM(reply->type == REDIS_REPLY_STRING, + "Expected string, found Redis type %d for NumRedisShards", + reply->type); + int num_redis_shards = atoi(reply->str); + CHECKM(num_redis_shards >= 1, "Expected at least one Redis shard, found %d.", + num_redis_shards); + freeReplyObject(reply); + + /* Get the addresses of all of the Redis shards. */ + num_attempts = 0; + while (num_attempts < REDIS_DB_CONNECT_RETRIES) { + /* Try to read the Redis shard locations from the primary shard. If we find + * that all of them are present, exit. */ + reply = (redisReply *) redisCommand(context, "LRANGE RedisShards 0 -1"); + if (reply->elements == num_redis_shards) { + break; + } + + /* Sleep for a little, and try again if not all Redis shard addresses have + * been added yet. */ + freeReplyObject(reply); + usleep(REDIS_DB_CONNECT_WAIT_MS * 1000); + num_attempts++; + continue; + } + CHECKM(num_attempts < REDIS_DB_CONNECT_RETRIES, + "Expected %d Redis shard addresses, found %d", num_redis_shards, + (int) reply->elements); + + /* Parse the Redis shard addresses. */ + char db_shard_address[16]; + int db_shard_port; + for (int i = 0; i < reply->elements; ++i) { + /* Parse the shard addresses and ports. */ + CHECK(reply->element[i]->type == REDIS_REPLY_STRING); + CHECK(parse_ip_addr_port(reply->element[i]->str, db_shard_address, + &db_shard_port) == 0); + db_shards_addresses.push_back(std::string(db_shard_address)); + db_shards_ports.push_back(db_shard_port); + } + freeReplyObject(reply); +} + +void db_connect_shard(const std::string &db_address, + int db_port, + DBClientID client, + const char *client_type, + const char *node_ip_address, + int num_args, + const char **args, + DBHandle *db, + redisAsyncContext **context_out, + redisAsyncContext **subscribe_context_out, + redisContext **sync_context_out) { + /* Synchronous connection for initial handshake */ redisReply *reply; int connection_attempts = 0; - redisContext *context = redisConnect(db_address, db_port); - while (context == NULL || context->err) { + redisContext *sync_context = redisConnect(db_address.c_str(), db_port); + while (sync_context == NULL || sync_context->err) { if (connection_attempts >= REDIS_DB_CONNECT_RETRIES) { break; } LOG_WARN("Failed to connect to Redis, retrying."); /* Sleep for a little. */ usleep(REDIS_DB_CONNECT_WAIT_MS * 1000); - context = redisConnect(db_address, db_port); + sync_context = redisConnect(db_address.c_str(), db_port); connection_attempts += 1; } - CHECK_REDIS_CONNECT(redisContext, context, + CHECK_REDIS_CONNECT(redisContext, sync_context, "could not establish synchronous connection to redis " "%s:%d", - db_address, db_port); + db_address.c_str(), db_port); /* Configure Redis to generate keyspace notifications for list events. This * should only need to be done once (by whoever started Redis), but since * Redis may be started in multiple places (e.g., for testing or when starting * processes by hand), it is easier to do it multiple times. */ - reply = (redisReply *) redisCommand(context, + reply = (redisReply *) redisCommand(sync_context, "CONFIG SET notify-keyspace-events Kl"); CHECKM(reply != NULL, "db_connect failed on CONFIG SET"); freeReplyObject(reply); /* Also configure Redis to not run in protected mode, so clients on other * hosts can connect to it. */ - reply = (redisReply *) redisCommand(context, "CONFIG SET protected-mode no"); + reply = + (redisReply *) redisCommand(sync_context, "CONFIG SET protected-mode no"); CHECKM(reply != NULL, "db_connect failed on CONFIG SET"); freeReplyObject(reply); - /* Create a client ID for this client. */ - DBClientID client = globally_unique_id(); /* Construct the argument arrays for RAY.CONNECT. */ int argc = num_args + 4; @@ -133,7 +211,7 @@ DBHandle *db_connect(const char *db_address, argvlen[0] = strlen(argv[0]); /* Set the client ID argument. */ argv[1] = (char *) client.id; - argvlen[1] = sizeof(db->client.id); + argvlen[1] = sizeof(client.id); /* Set the node IP address argument. */ argv[2] = node_ip_address; argvlen[2] = strlen(node_ip_address); @@ -152,7 +230,7 @@ DBHandle *db_connect(const char *db_address, /* Register this client with Redis. RAY.CONNECT is a custom Redis command that * we've defined. */ - reply = (redisReply *) redisCommandArgv(context, argc, argv, argvlen); + reply = (redisReply *) redisCommandArgv(sync_context, argc, argv, argvlen); CHECKM(reply != NULL, "db_connect failed on RAY.CONNECT"); CHECK(reply->type != REDIS_REPLY_ERROR); CHECK(strcmp(reply->str, "OK") == 0); @@ -160,25 +238,75 @@ DBHandle *db_connect(const char *db_address, free(argv); free(argvlen); + *sync_context_out = sync_context; + + /* Establish connection for control data. */ + redisAsyncContext *context = redisAsyncConnect(db_address.c_str(), db_port); + CHECK_REDIS_CONNECT(redisAsyncContext, context, + "could not establish asynchronous connection to redis " + "%s:%d", + db_address.c_str(), db_port); + context->data = (void *) db; + *context_out = context; + + /* Establish async connection for subscription. */ + redisAsyncContext *subscribe_context = + redisAsyncConnect(db_address.c_str(), db_port); + CHECK_REDIS_CONNECT(redisAsyncContext, subscribe_context, + "could not establish asynchronous subscription " + "connection to redis %s:%d", + db_address.c_str(), db_port); + subscribe_context->data = (void *) db; + *subscribe_context_out = subscribe_context; +} + +DBHandle *db_connect(const std::string &db_primary_address, + int db_primary_port, + const char *client_type, + const char *node_ip_address, + int num_args, + const char **args) { + /* Check that the number of args is even. These args will be passed to the + * RAY.CONNECT Redis command, which takes arguments in pairs. */ + if (num_args % 2 != 0) { + LOG_FATAL("The number of extra args must be divisible by two."); + } + + /* Create a client ID for this client. */ + DBClientID client = globally_unique_id(); + + DBHandle *db = new DBHandle(); + db->client_type = strdup(client_type); db->client = client; db->db_client_cache = NULL; - db->sync_context = context; - /* Establish async connection */ - db->context = redisAsyncConnect(db_address, db_port); - CHECK_REDIS_CONNECT(redisAsyncContext, db->context, - "could not establish asynchronous connection to redis " - "%s:%d", - db_address, db_port); - db->context->data = (void *) db; - /* Establish async connection for subscription */ - db->sub_context = redisAsyncConnect(db_address, db_port); - CHECK_REDIS_CONNECT(redisAsyncContext, db->sub_context, - "could not establish asynchronous subscription " - "connection to redis %s:%d", - db_address, db_port); - db->sub_context->data = (void *) db; + redisAsyncContext *context; + redisAsyncContext *subscribe_context; + redisContext *sync_context; + + /* Connect to the primary redis instance. */ + db_connect_shard(db_primary_address, db_primary_port, client, client_type, + node_ip_address, num_args, args, db, &context, + &subscribe_context, &sync_context); + db->context = context; + db->subscribe_context = subscribe_context; + db->sync_context = sync_context; + + /* Get the shard locations. */ + std::vector db_shards_addresses; + std::vector db_shards_ports; + get_redis_shards(db->sync_context, db_shards_addresses, db_shards_ports); + CHECKM(db_shards_addresses.size() > 0, "No Redis shards found"); + /* Connect to the shards. */ + for (int i = 0; i < db_shards_addresses.size(); ++i) { + db_connect_shard(db_shards_addresses[i], db_shards_ports[i], client, + client_type, node_ip_address, num_args, args, db, &context, + &subscribe_context, &sync_context); + db->contexts.push_back(context); + db->subscribe_contexts.push_back(subscribe_context); + redisFree(sync_context); + } return db; } @@ -193,10 +321,17 @@ void db_disconnect(DBHandle *db) { CHECK(strcmp(reply->str, "OK") == 0); freeReplyObject(reply); - /* Clean up the Redis connection state. */ + /* Clean up the primary Redis connection state. */ redisFree(db->sync_context); redisAsyncFree(db->context); - redisAsyncFree(db->sub_context); + redisAsyncFree(db->subscribe_context); + + /* Clean up the Redis shards. */ + CHECK(db->contexts.size() == db->subscribe_contexts.size()); + for (int i = 0; i < db->contexts.size(); ++i) { + redisAsyncFree(db->contexts[i]); + redisAsyncFree(db->subscribe_contexts[i]); + } /* Clean up memory. */ DBClientCacheEntry *e, *tmp; @@ -206,21 +341,36 @@ void db_disconnect(DBHandle *db) { free(e); } free(db->client_type); - free(db); + delete db; } void db_attach(DBHandle *db, event_loop *loop, bool reattach) { db->loop = loop; + /* Attach primary redis instance to the event loop. */ int err = redisAeAttach(loop, db->context); /* If the database is reattached in the tests, redis normally gives * an error which we can safely ignore. */ if (!reattach) { CHECKM(err == REDIS_OK, "failed to attach the event loop"); } - err = redisAeAttach(loop, db->sub_context); + err = redisAeAttach(loop, db->subscribe_context); if (!reattach) { CHECKM(err == REDIS_OK, "failed to attach the event loop"); } + /* Attach other redis shards to the event loop. */ + CHECK(db->contexts.size() == db->subscribe_contexts.size()); + for (int i = 0; i < db->contexts.size(); ++i) { + int err = redisAeAttach(loop, db->contexts[i]); + /* If the database is reattached in the tests, redis normally gives + * an error which we can safely ignore. */ + if (!reattach) { + CHECKM(err == REDIS_OK, "failed to attach the event loop"); + } + err = redisAeAttach(loop, db->subscribe_contexts[i]); + if (!reattach) { + CHECKM(err == REDIS_OK, "failed to attach the event loop"); + } + } } /* @@ -264,14 +414,16 @@ void redis_object_table_add(TableCallbackData *callback_data) { int64_t object_size = info->object_size; unsigned char *digest = info->digest; + redisAsyncContext *context = get_redis_context(db, obj_id); + int status = redisAsyncCommand( - db->context, redis_object_table_add_callback, + context, redis_object_table_add_callback, (void *) callback_data->timer_id, "RAY.OBJECT_TABLE_ADD %b %ld %b %b", obj_id.id, sizeof(obj_id.id), object_size, digest, (size_t) DIGEST_SIZE, db->client.id, sizeof(db->client.id)); - if ((status == REDIS_ERR) || db->context->err) { - LOG_REDIS_DEBUG(db->context, "error in redis_object_table_add"); + if ((status == REDIS_ERR) || context->err) { + LOG_REDIS_DEBUG(context, "error in redis_object_table_add"); } } @@ -309,13 +461,16 @@ void redis_object_table_remove(TableCallbackData *callback_data) { if (client_id == NULL) { client_id = &db->client; } + + redisAsyncContext *context = get_redis_context(db, obj_id); + int status = redisAsyncCommand( - db->context, redis_object_table_remove_callback, + context, redis_object_table_remove_callback, (void *) callback_data->timer_id, "RAY.OBJECT_TABLE_REMOVE %b %b", obj_id.id, sizeof(obj_id.id), client_id->id, sizeof(client_id->id)); - if ((status == REDIS_ERR) || db->context->err) { - LOG_REDIS_DEBUG(db->context, "error in redis_object_table_remove"); + if ((status == REDIS_ERR) || context->err) { + LOG_REDIS_DEBUG(context, "error in redis_object_table_remove"); } } @@ -324,12 +479,15 @@ void redis_object_table_lookup(TableCallbackData *callback_data) { DBHandle *db = callback_data->db_handle; ObjectID obj_id = callback_data->id; - int status = redisAsyncCommand( - db->context, redis_object_table_lookup_callback, - (void *) callback_data->timer_id, "RAY.OBJECT_TABLE_LOOKUP %b", obj_id.id, - sizeof(obj_id.id)); - if ((status == REDIS_ERR) || db->context->err) { - LOG_REDIS_DEBUG(db->context, "error in object_table lookup"); + + redisAsyncContext *context = get_redis_context(db, obj_id); + + int status = redisAsyncCommand(context, redis_object_table_lookup_callback, + (void *) callback_data->timer_id, + "RAY.OBJECT_TABLE_LOOKUP %b", obj_id.id, + sizeof(obj_id.id)); + if ((status == REDIS_ERR) || context->err) { + LOG_REDIS_DEBUG(context, "error in object_table lookup"); } } @@ -358,13 +516,15 @@ void redis_result_table_add(TableCallbackData *callback_data) { ResultTableAddInfo *info = (ResultTableAddInfo *) callback_data->data; int is_put = info->is_put ? 1 : 0; + redisAsyncContext *context = get_redis_context(db, id); + /* Add the result entry to the result table. */ int status = redisAsyncCommand( - db->context, redis_result_table_add_callback, + context, redis_result_table_add_callback, (void *) callback_data->timer_id, "RAY.RESULT_TABLE_ADD %b %b %d", id.id, sizeof(id.id), info->task_id.id, sizeof(info->task_id.id), is_put); - if ((status == REDIS_ERR) || db->context->err) { - LOG_REDIS_DEBUG(db->context, "Error in result table add"); + if ((status == REDIS_ERR) || context->err) { + LOG_REDIS_DEBUG(context, "Error in result table add"); } } @@ -423,12 +583,13 @@ void redis_result_table_lookup(TableCallbackData *callback_data) { CHECK(callback_data); DBHandle *db = callback_data->db_handle; ObjectID id = callback_data->id; + redisAsyncContext *context = get_redis_context(db, id); int status = - redisAsyncCommand(db->context, redis_result_table_lookup_callback, + redisAsyncCommand(context, redis_result_table_lookup_callback, (void *) callback_data->timer_id, "RAY.RESULT_TABLE_LOOKUP %b", id.id, sizeof(id.id)); - if ((status == REDIS_ERR) || db->context->err) { - LOG_REDIS_DEBUG(db->context, "Error in result table lookup"); + if ((status == REDIS_ERR) || context->err) { + LOG_REDIS_DEBUG(context, "Error in result table lookup"); } } @@ -586,28 +747,33 @@ void redis_object_table_subscribe_to_notifications( * src/common/redismodule/ray_redis_module.cc. */ const char *object_channel_prefix = "OC:"; const char *object_channel_bcast = "BCAST"; - int status = REDIS_OK; - /* Subscribe to notifications from the object table. This uses the client ID - * as the channel name so this channel is specific to this client. TODO(rkn): - * The channel name should probably be the client ID with some prefix. */ - CHECKM(callback_data->data != NULL, - "Object table subscribe data passed as NULL."); - if (((ObjectTableSubscribeData *) (callback_data->data))->subscribe_all) { - /* Subscribe to the object broadcast channel. */ - status = redisAsyncCommand( - db->sub_context, object_table_redis_subscribe_to_notifications_callback, - (void *) callback_data->timer_id, "SUBSCRIBE %s%s", - object_channel_prefix, object_channel_bcast); - } else { - status = redisAsyncCommand( - db->sub_context, object_table_redis_subscribe_to_notifications_callback, - (void *) callback_data->timer_id, "SUBSCRIBE %s%b", - object_channel_prefix, db->client.id, sizeof(db->client.id)); - } + for (int i = 0; i < db->subscribe_contexts.size(); ++i) { + int status = REDIS_OK; + /* Subscribe to notifications from the object table. This uses the client ID + * as the channel name so this channel is specific to this client. + * TODO(rkn): + * The channel name should probably be the client ID with some prefix. */ + CHECKM(callback_data->data != NULL, + "Object table subscribe data passed as NULL."); + if (((ObjectTableSubscribeData *) (callback_data->data))->subscribe_all) { + /* Subscribe to the object broadcast channel. */ + status = redisAsyncCommand( + db->subscribe_contexts[i], + object_table_redis_subscribe_to_notifications_callback, + (void *) callback_data->timer_id, "SUBSCRIBE %s%s", + object_channel_prefix, object_channel_bcast); + } else { + status = redisAsyncCommand( + db->subscribe_contexts[i], + object_table_redis_subscribe_to_notifications_callback, + (void *) callback_data->timer_id, "SUBSCRIBE %s%b", + object_channel_prefix, db->client.id, sizeof(db->client.id)); + } - if ((status == REDIS_ERR) || db->sub_context->err) { - LOG_REDIS_DEBUG(db->sub_context, - "error in redis_object_table_subscribe_to_notifications"); + if ((status == REDIS_ERR) || db->subscribe_contexts[i]->err) { + LOG_REDIS_DEBUG(db->subscribe_contexts[i], + "error in redis_object_table_subscribe_to_notifications"); + } } } @@ -633,31 +799,33 @@ void redis_object_table_request_notifications( int num_object_ids = request_data->num_object_ids; ObjectID *object_ids = request_data->object_ids; - /* Create the arguments for the Redis command. */ - int num_args = 1 + 1 + num_object_ids; - const char **argv = (const char **) malloc(sizeof(char *) * num_args); - size_t *argvlen = (size_t *) malloc(sizeof(size_t) * num_args); - /* Set the command name argument. */ - argv[0] = "RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS"; - argvlen[0] = strlen(argv[0]); - /* Set the client ID argument. */ - argv[1] = (char *) db->client.id; - argvlen[1] = sizeof(db->client.id); - /* Set the object ID arguments. */ for (int i = 0; i < num_object_ids; ++i) { - argv[2 + i] = (char *) object_ids[i].id; - argvlen[2 + i] = sizeof(object_ids[i].id); - } + redisAsyncContext *context = get_redis_context(db, object_ids[i]); - int status = redisAsyncCommandArgv( - db->context, redis_object_table_request_notifications_callback, - (void *) callback_data->timer_id, num_args, argv, argvlen); - free(argv); - free(argvlen); + /* Create the arguments for the Redis command. */ + int num_args = 1 + 1 + 1; + const char **argv = (const char **) malloc(sizeof(char *) * num_args); + size_t *argvlen = (size_t *) malloc(sizeof(size_t) * num_args); + /* Set the command name argument. */ + argv[0] = "RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS"; + argvlen[0] = strlen(argv[0]); + /* Set the client ID argument. */ + argv[1] = (char *) db->client.id; + argvlen[1] = sizeof(db->client.id); + /* Set the object ID arguments. */ + argv[2] = (char *) object_ids[i].id; + argvlen[2] = sizeof(object_ids[i].id); - if ((status == REDIS_ERR) || db->context->err) { - LOG_REDIS_DEBUG(db->context, - "error in redis_object_table_subscribe_to_notifications"); + int status = redisAsyncCommandArgv( + context, redis_object_table_request_notifications_callback, + (void *) callback_data->timer_id, num_args, argv, argvlen); + free(argv); + free(argvlen); + + if ((status == REDIS_ERR) || context->err) { + LOG_REDIS_DEBUG(context, + "error in redis_object_table_subscribe_to_notifications"); + } } } @@ -690,12 +858,14 @@ void redis_task_table_get_task(TableCallbackData *callback_data) { CHECK(callback_data->data == NULL); TaskID task_id = callback_data->id; - int status = redisAsyncCommand( - db->context, redis_task_table_get_task_callback, - (void *) callback_data->timer_id, "RAY.TASK_TABLE_GET %b", task_id.id, - sizeof(task_id.id)); - if ((status == REDIS_ERR) || db->context->err) { - LOG_REDIS_DEBUG(db->context, "error in redis_task_table_get_task"); + redisAsyncContext *context = get_redis_context(db, task_id); + + int status = redisAsyncCommand(context, redis_task_table_get_task_callback, + (void *) callback_data->timer_id, + "RAY.TASK_TABLE_GET %b", task_id.id, + sizeof(task_id.id)); + if ((status == REDIS_ERR) || context->err) { + LOG_REDIS_DEBUG(context, "error in redis_task_table_get_task"); } } @@ -706,6 +876,36 @@ void redis_task_table_add_task_callback(redisAsyncContext *c, /* Do some minimal checking. */ redisReply *reply = (redisReply *) r; + + /* If the publish which happens inside of the call to RAY.TASK_TABLE_ADD was + * not received by any subscribers, then reissue the command. TODO(rkn): This + * entire if block should be temporary. Once we address the problem where in + * which a global scheduler may publish a task to a local scheduler before the + * local scheduler has subscribed to the relevant channel, we shouldn't need + * this block any more. */ + if (reply->type == REDIS_REPLY_ERROR && + strcmp(reply->str, "No subscribers received message.") == 0) { + Task *task = (Task *) callback_data->data; + TaskID task_id = Task_task_id(task); + DBClientID local_scheduler_id = Task_local_scheduler(task); + redisAsyncContext *context = get_redis_context(db, task_id); + int state = Task_state(task); + TaskSpec *spec = Task_task_spec(task); + /* Reissue the command. */ + CHECKM(task != NULL, "NULL task passed to redis_task_table_add_task."); + int status = redisAsyncCommand( + context, redis_task_table_add_task_callback, + (void *) callback_data->timer_id, "RAY.TASK_TABLE_ADD %b %d %b %b", + task_id.id, sizeof(task_id.id), state, local_scheduler_id.id, + sizeof(local_scheduler_id.id), spec, Task_task_spec_size(task)); + if ((status == REDIS_ERR) || context->err) { + LOG_REDIS_DEBUG(context, "error in redis_task_table_add_task"); + } + /* Since we are reissuing the same command with the same callback data, + * return early to avoid freeing the callback data. */ + return; + } + CHECKM(strcmp(reply->str, "OK") == 0, "reply->str is %s", reply->str); /* Call the done callback if there is one. */ if (callback_data->done_callback != NULL) { @@ -722,17 +922,18 @@ void redis_task_table_add_task(TableCallbackData *callback_data) { Task *task = (Task *) callback_data->data; TaskID task_id = Task_task_id(task); DBClientID local_scheduler_id = Task_local_scheduler(task); + redisAsyncContext *context = get_redis_context(db, task_id); int state = Task_state(task); TaskSpec *spec = Task_task_spec(task); CHECKM(task != NULL, "NULL task passed to redis_task_table_add_task."); int status = redisAsyncCommand( - db->context, redis_task_table_add_task_callback, + context, redis_task_table_add_task_callback, (void *) callback_data->timer_id, "RAY.TASK_TABLE_ADD %b %d %b %b", task_id.id, sizeof(task_id.id), state, local_scheduler_id.id, sizeof(local_scheduler_id.id), spec, Task_task_spec_size(task)); - if ((status == REDIS_ERR) || db->context->err) { - LOG_REDIS_DEBUG(db->context, "error in redis_task_table_add_task"); + if ((status == REDIS_ERR) || context->err) { + LOG_REDIS_DEBUG(context, "error in redis_task_table_add_task"); } } @@ -743,6 +944,35 @@ void redis_task_table_update_callback(redisAsyncContext *c, /* Do some minimal checking. */ redisReply *reply = (redisReply *) r; + + /* If the publish which happens inside of the call to RAY.TASK_TABLE_UPDATE + * was not received by any subscribers, then reissue the command. TODO(rkn): + * This entire if block should be temporary. Once we address the problem where + * in which a global scheduler may publish a task to a local scheduler before + * the local scheduler has subscribed to the relevant channel, we shouldn't + * need this block any more. */ + if (reply->type == REDIS_REPLY_ERROR && + strcmp(reply->str, "No subscribers received message.") == 0) { + Task *task = (Task *) callback_data->data; + TaskID task_id = Task_task_id(task); + redisAsyncContext *context = get_redis_context(db, task_id); + DBClientID local_scheduler_id = Task_local_scheduler(task); + int state = Task_state(task); + /* Reissue the command. */ + CHECKM(task != NULL, "NULL task passed to redis_task_table_update."); + int status = redisAsyncCommand( + context, redis_task_table_update_callback, + (void *) callback_data->timer_id, "RAY.TASK_TABLE_UPDATE %b %d %b", + task_id.id, sizeof(task_id.id), state, local_scheduler_id.id, + sizeof(local_scheduler_id.id)); + if ((status == REDIS_ERR) || context->err) { + LOG_REDIS_DEBUG(context, "error in redis_task_table_update"); + } + /* Since we are reissuing the same command with the same callback data, + * return early to avoid freeing the callback data. */ + return; + } + CHECKM(strcmp(reply->str, "OK") == 0, "reply->str is %s", reply->str); /* Call the done callback if there is one. */ if (callback_data->done_callback != NULL) { @@ -758,17 +988,18 @@ void redis_task_table_update(TableCallbackData *callback_data) { DBHandle *db = callback_data->db_handle; Task *task = (Task *) callback_data->data; TaskID task_id = Task_task_id(task); + redisAsyncContext *context = get_redis_context(db, task_id); DBClientID local_scheduler_id = Task_local_scheduler(task); int state = Task_state(task); CHECKM(task != NULL, "NULL task passed to redis_task_table_update."); int status = redisAsyncCommand( - db->context, redis_task_table_update_callback, + context, redis_task_table_update_callback, (void *) callback_data->timer_id, "RAY.TASK_TABLE_UPDATE %b %d %b", task_id.id, sizeof(task_id.id), state, local_scheduler_id.id, sizeof(local_scheduler_id.id)); - if ((status == REDIS_ERR) || db->context->err) { - LOG_REDIS_DEBUG(db->context, "error in redis_task_table_update"); + if ((status == REDIS_ERR) || context->err) { + LOG_REDIS_DEBUG(context, "error in redis_task_table_update"); } } @@ -779,6 +1010,15 @@ void redis_task_table_test_and_update_callback(redisAsyncContext *c, redisReply *reply = (redisReply *) r; /* Parse the task from the reply. */ Task *task = parse_and_construct_task_from_redis_reply(reply); + if (task == NULL) { + /* A NULL task means that the task was not in the task table. NOTE(swang): + * For normal tasks, this is not expected behavior, but actor tasks may be + * delayed when added to the task table if they are submitted to a local + * scheduler before it receives the notification that maps the actor to a + * local scheduler. */ + LOG_ERROR("No task found during task_table_test_and_update"); + return; + } /* Determine whether the update happened. */ auto message = flatbuffers::GetRoot(reply->str); bool updated = message->updated(); @@ -800,18 +1040,19 @@ void redis_task_table_test_and_update_callback(redisAsyncContext *c, void redis_task_table_test_and_update(TableCallbackData *callback_data) { DBHandle *db = callback_data->db_handle; TaskID task_id = callback_data->id; + redisAsyncContext *context = get_redis_context(db, task_id); TaskTableTestAndUpdateData *update_data = (TaskTableTestAndUpdateData *) callback_data->data; int status = redisAsyncCommand( - db->context, redis_task_table_test_and_update_callback, + context, redis_task_table_test_and_update_callback, (void *) callback_data->timer_id, "RAY.TASK_TABLE_TEST_AND_UPDATE %b %d %d %b", task_id.id, sizeof(task_id.id), update_data->test_state_bitmask, update_data->update_state, update_data->local_scheduler_id.id, sizeof(update_data->local_scheduler_id.id)); - if ((status == REDIS_ERR) || db->context->err) { - LOG_REDIS_DEBUG(db->context, "error in redis_task_table_test_and_update"); + if ((status == REDIS_ERR) || context->err) { + LOG_REDIS_DEBUG(context, "error in redis_task_table_test_and_update"); } } @@ -879,24 +1120,26 @@ void redis_task_table_subscribe(TableCallbackData *callback_data) { /* TASK_CHANNEL_PREFIX is defined in ray_redis_module.cc and must be kept in * sync with that file. */ const char *TASK_CHANNEL_PREFIX = "TT:"; - int status; - if (IS_NIL_ID(data->local_scheduler_id)) { - /* TODO(swang): Implement the state_filter by translating the bitmask into - * a Redis key-matching pattern. */ - status = - redisAsyncCommand(db->sub_context, redis_task_table_subscribe_callback, - (void *) callback_data->timer_id, "PSUBSCRIBE %s*:%d", - TASK_CHANNEL_PREFIX, data->state_filter); - } else { - DBClientID local_scheduler_id = data->local_scheduler_id; - status = - redisAsyncCommand(db->sub_context, redis_task_table_subscribe_callback, - (void *) callback_data->timer_id, "SUBSCRIBE %s%b:%d", - TASK_CHANNEL_PREFIX, (char *) local_scheduler_id.id, - sizeof(local_scheduler_id.id), data->state_filter); - } - if ((status == REDIS_ERR) || db->sub_context->err) { - LOG_REDIS_DEBUG(db->sub_context, "error in redis_task_table_subscribe"); + for (auto subscribe_context : db->subscribe_contexts) { + int status; + if (IS_NIL_ID(data->local_scheduler_id)) { + /* TODO(swang): Implement the state_filter by translating the bitmask into + * a Redis key-matching pattern. */ + status = redisAsyncCommand( + subscribe_context, redis_task_table_subscribe_callback, + (void *) callback_data->timer_id, "PSUBSCRIBE %s*:%d", + TASK_CHANNEL_PREFIX, data->state_filter); + } else { + DBClientID local_scheduler_id = data->local_scheduler_id; + status = redisAsyncCommand( + subscribe_context, redis_task_table_subscribe_callback, + (void *) callback_data->timer_id, "SUBSCRIBE %s%b:%d", + TASK_CHANNEL_PREFIX, (char *) local_scheduler_id.id, + sizeof(local_scheduler_id.id), data->state_filter); + } + if ((status == REDIS_ERR) || subscribe_context->err) { + LOG_REDIS_DEBUG(subscribe_context, "error in redis_task_table_subscribe"); + } } } @@ -934,6 +1177,55 @@ void redis_db_client_table_remove(TableCallbackData *callback_data) { } } +void redis_db_client_table_scan(DBHandle *db, + std::vector &db_clients) { + /* TODO(swang): Integrate this functionality with the Ray Redis module. To do + * this, we need the KEYS or SCAN command in Redis modules. */ + /* Get all the database client keys. */ + redisReply *reply = (redisReply *) redisCommand(db->sync_context, "KEYS %s*", + DB_CLIENT_PREFIX); + if (reply->type == REDIS_REPLY_NIL) { + return; + } + /* Get all the database client information. */ + CHECK(reply->type == REDIS_REPLY_ARRAY); + for (int i = 0; i < reply->elements; ++i) { + redisReply *client_reply = (redisReply *) redisCommand( + db->sync_context, "HGETALL %b", reply->element[i]->str, + reply->element[i]->len); + CHECK(reply->type == REDIS_REPLY_ARRAY); + CHECK(reply->elements > 0); + DBClient db_client; + memset(&db_client, 0, sizeof(db_client)); + int num_fields = 0; + /* Parse the fields into a DBClient. */ + for (int j = 0; j < client_reply->elements; j = j + 2) { + const char *key = client_reply->element[j]->str; + const char *value = client_reply->element[j + 1]->str; + if (strcmp(key, "ray_client_id") == 0) { + memcpy(db_client.id.id, value, sizeof(db_client.id)); + num_fields++; + } else if (strcmp(key, "client_type") == 0) { + db_client.client_type = strdup(value); + num_fields++; + } else if (strcmp(key, "aux_address") == 0) { + db_client.aux_address = strdup(value); + num_fields++; + } else if (strcmp(key, "deleted") == 0) { + bool is_deleted = atoi(value); + db_client.is_insertion = !is_deleted; + num_fields++; + } + } + freeReplyObject(client_reply); + /* The client ID, type, and whether it is deleted are all mandatory fields. + * Auxiliary address is optional. */ + CHECK(num_fields >= 3); + db_clients.push_back(db_client); + } + freeReplyObject(reply); +} + void redis_db_client_table_subscribe_callback(redisAsyncContext *c, void *r, void *privdata) { @@ -956,35 +1248,54 @@ void redis_db_client_table_subscribe_callback(redisAsyncContext *c, /* Note that we do not destroy the callback data yet because the * subscription callback needs this data. */ event_loop_remove_timer(db->loop, callback_data->timer_id); + + /* Get the current db client table entries, in case we missed notifications + * before the initial subscription. This must be done before we process any + * notifications from the subscription channel, so that we don't readd an + * entry that has already been deleted. */ + std::vector db_clients; + redis_db_client_table_scan(db, db_clients); + /* Call the subscription callback for all entries that we missed. */ + DBClientTableSubscribeData *data = + (DBClientTableSubscribeData *) callback_data->data; + for (auto db_client : db_clients) { + data->subscribe_callback(&db_client, data->subscribe_context); + if (db_client.client_type != NULL) { + free((void *) db_client.client_type); + } + if (db_client.aux_address != NULL) { + free((void *) db_client.aux_address); + } + } return; } /* Otherwise, parse the payload and call the callback. */ auto message = flatbuffers::GetRoot(payload->str); - DBClientID client = from_flatbuf(message->db_client_id()); /* Parse the client type and auxiliary address from the response. If there is * only client type, then the update was a delete. */ - char *client_type = (char *) message->client_type()->data(); - char *aux_address = (char *) message->aux_address()->data(); - bool is_insertion = message->is_insertion(); + DBClient db_client; + db_client.id = from_flatbuf(message->db_client_id()); + db_client.client_type = (char *) message->client_type()->data(); + db_client.aux_address = message->aux_address()->data(); + db_client.is_insertion = message->is_insertion(); /* Call the subscription callback. */ DBClientTableSubscribeData *data = (DBClientTableSubscribeData *) callback_data->data; if (data->subscribe_callback) { - data->subscribe_callback(client, client_type, aux_address, is_insertion, - data->subscribe_context); + data->subscribe_callback(&db_client, data->subscribe_context); } } void redis_db_client_table_subscribe(TableCallbackData *callback_data) { DBHandle *db = callback_data->db_handle; int status = redisAsyncCommand( - db->sub_context, redis_db_client_table_subscribe_callback, + db->subscribe_context, redis_db_client_table_subscribe_callback, (void *) callback_data->timer_id, "SUBSCRIBE db_clients"); - if ((status == REDIS_ERR) || db->sub_context->err) { - LOG_REDIS_DEBUG(db->sub_context, + if ((status == REDIS_ERR) || db->subscribe_context->err) { + LOG_REDIS_DEBUG(db->subscribe_context, "error in db_client_table_register_callback"); } } @@ -1042,10 +1353,10 @@ void redis_local_scheduler_table_subscribe_callback(redisAsyncContext *c, void redis_local_scheduler_table_subscribe(TableCallbackData *callback_data) { DBHandle *db = callback_data->db_handle; int status = redisAsyncCommand( - db->sub_context, redis_local_scheduler_table_subscribe_callback, + db->subscribe_context, redis_local_scheduler_table_subscribe_callback, (void *) callback_data->timer_id, "SUBSCRIBE local_schedulers"); - if ((status == REDIS_ERR) || db->sub_context->err) { - LOG_REDIS_DEBUG(db->sub_context, + if ((status == REDIS_ERR) || db->subscribe_context->err) { + LOG_REDIS_DEBUG(db->subscribe_context, "error in redis_local_scheduler_table_subscribe"); } } @@ -1130,10 +1441,11 @@ void redis_driver_table_subscribe_callback(redisAsyncContext *c, void redis_driver_table_subscribe(TableCallbackData *callback_data) { DBHandle *db = callback_data->db_handle; int status = redisAsyncCommand( - db->sub_context, redis_driver_table_subscribe_callback, + db->subscribe_context, redis_driver_table_subscribe_callback, (void *) callback_data->timer_id, "SUBSCRIBE driver_deaths"); - if ((status == REDIS_ERR) || db->sub_context->err) { - LOG_REDIS_DEBUG(db->sub_context, "error in redis_driver_table_subscribe"); + if ((status == REDIS_ERR) || db->subscribe_context->err) { + LOG_REDIS_DEBUG(db->subscribe_context, + "error in redis_driver_table_subscribe"); } } @@ -1240,10 +1552,10 @@ void redis_actor_notification_table_subscribe( TableCallbackData *callback_data) { DBHandle *db = callback_data->db_handle; int status = redisAsyncCommand( - db->sub_context, redis_actor_notification_table_subscribe_callback, + db->subscribe_context, redis_actor_notification_table_subscribe_callback, (void *) callback_data->timer_id, "SUBSCRIBE actor_notifications"); - if ((status == REDIS_ERR) || db->sub_context->err) { - LOG_REDIS_DEBUG(db->sub_context, + if ((status == REDIS_ERR) || db->subscribe_context->err) { + LOG_REDIS_DEBUG(db->subscribe_context, "error in redis_actor_notification_table_subscribe"); } } @@ -1290,10 +1602,11 @@ void redis_object_info_subscribe_callback(redisAsyncContext *c, void redis_object_info_subscribe(TableCallbackData *callback_data) { DBHandle *db = callback_data->db_handle; int status = redisAsyncCommand( - db->sub_context, redis_object_info_subscribe_callback, + db->subscribe_context, redis_object_info_subscribe_callback, (void *) callback_data->timer_id, "PSUBSCRIBE obj:info"); - if ((status == REDIS_ERR) || db->sub_context->err) { - LOG_REDIS_DEBUG(db->sub_context, "error in object_info_register_callback"); + if ((status == REDIS_ERR) || db->subscribe_context->err) { + LOG_REDIS_DEBUG(db->subscribe_context, + "error in object_info_register_callback"); } } @@ -1324,8 +1637,8 @@ void redis_push_error_hmset_callback(redisAsyncContext *c, "RPUSH ErrorKeys Error:%b:%b", info->driver_id.id, sizeof(info->driver_id.id), info->error_key, sizeof(info->error_key)); - if ((status == REDIS_ERR) || db->sub_context->err) { - LOG_REDIS_DEBUG(db->sub_context, "error in redis_push_error rpush"); + if ((status == REDIS_ERR) || db->subscribe_context->err) { + LOG_REDIS_DEBUG(db->subscribe_context, "error in redis_push_error rpush"); } } @@ -1344,8 +1657,8 @@ void redis_push_error(TableCallbackData *callback_data) { "HMSET Error:%b:%b type %s message %s data %b", info->driver_id.id, sizeof(info->driver_id.id), info->error_key, sizeof(info->error_key), error_type, error_message, info->data, info->data_length); - if ((status == REDIS_ERR) || db->sub_context->err) { - LOG_REDIS_DEBUG(db->sub_context, "error in redis_push_error hmset"); + if ((status == REDIS_ERR) || db->subscribe_context->err) { + LOG_REDIS_DEBUG(db->subscribe_context, "error in redis_push_error hmset"); } } diff --git a/src/common/state/redis.h b/src/common/state/redis.h index cada086ae..33c2e6436 100644 --- a/src/common/state/redis.h +++ b/src/common/state/redis.h @@ -34,11 +34,24 @@ struct DBHandle { char *client_type; /** Unique ID for this client. */ DBClientID client; - /** Redis context for all non-subscribe connections. */ + /** Primary redis context for all non-subscribe connections. This is used for + * the database client table, heartbeats, and errors that should be pushed to + * the driver. */ redisAsyncContext *context; - /** Redis context for "subscribe" communication. Yes, we need a separate one - * for that, see https://github.com/redis/hiredis/issues/55. */ - redisAsyncContext *sub_context; + /** Primary redis context for "subscribe" communication. A separate context + * is needed for this communication (see + * https://github.com/redis/hiredis/issues/55). This is used for the + * database client table, heartbeats, and errors that should be pushed to + * the driver. */ + redisAsyncContext *subscribe_context; + /** Redis contexts for shards for all non-subscribe connections. All requests + * to the object table, task table, and event table should be directed here. + * The correct shard can be retrieved using get_redis_context below. */ + std::vector contexts; + /** Redis contexts for shards for "subscribe" communication. All requests + * to the object table, task table, and event table should be directed here. + * The correct shard can be retrieved using get_redis_context below. */ + std::vector subscribe_contexts; /** The event loop this global state store connection is part of. */ event_loop *loop; /** Index of the database connection in the event loop */ @@ -51,6 +64,40 @@ struct DBHandle { redisContext *sync_context; }; +/** + * Get the Redis asynchronous context responsible for non-subscription + * communication for the given UniqueID. + * + * @param db The database handle. + * @param id The ID whose location we are querying for. + * @return The redisAsyncContext responsible for the given ID. + */ +redisAsyncContext *get_redis_context(DBHandle *db, UniqueID id); + +/** + * Get the Redis asynchronous context responsible for subscription + * communication for the given UniqueID. + * + * @param db The database handle. + * @param id The ID whose location we are querying for. + * @return The redisAsyncContext responsible for the given ID. + */ +redisAsyncContext *get_redis_subscribe_context(DBHandle *db, UniqueID id); + +/** + * Get a list of Redis shard IP addresses from the primary shard. + * + * @param context A Redis context connected to the primary shard. + * @param db_shards_addresses The IP addresses for the shards registered + * with the primary shard will be added to this vector. + * @param db_shards_ports The IP ports for the shards registered with the + * primary shard will be added to this vector, in the same order as + * db_shards_addresses. + */ +void get_redis_shards(redisContext *context, + std::vector &db_shards_addresses, + std::vector &db_shards_ports); + void redis_object_table_get_entry(redisAsyncContext *c, void *r, void *privdata); diff --git a/src/common/state/task_table.h b/src/common/state/task_table.h index fee7b854c..dd5cf5d5d 100644 --- a/src/common/state/task_table.h +++ b/src/common/state/task_table.h @@ -151,9 +151,10 @@ typedef void (*task_table_subscribe_callback)(Task *task, void *user_context); * @param local_scheduler_id The db_client_id of the local scheduler whose * events we want to listen to. If you want to subscribe to updates from * all local schedulers, pass in NIL_ID. - * @param state_filter Flags for events we want to listen to. If you want - * to listen to all events, use state_filter = TASK_WAITING | - * TASK_SCHEDULED | TASK_RUNNING | TASK_DONE. + * @param state_filter Events we want to listen to. Can have values from the + * enum "scheduling_state" in task.h. + * TODO(pcm): Make it possible to combine these using flags like + * TASK_STATUS_WAITING | TASK_STATUS_SCHEDULED. * @param retry Information about retrying the request to the database. * @param done_callback Function to be called when database returns result. * @param user_context Data that will be passed to done_callback and diff --git a/src/common/test/db_tests.cc b/src/common/test/db_tests.cc index 73bff065d..5d54ac9a9 100644 --- a/src/common/test/db_tests.cc +++ b/src/common/test/db_tests.cc @@ -72,12 +72,12 @@ TEST object_table_lookup_test(void) { event_loop *loop = event_loop_create(); /* This uses manager_port1. */ const char *db_connect_args1[] = {"address", "127.0.0.1:12345"}; - DBHandle *db1 = db_connect("127.0.0.1", 6379, "plasma_manager", manager_addr, - 2, db_connect_args1); + DBHandle *db1 = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", + manager_addr, 2, db_connect_args1); /* This uses manager_port2. */ const char *db_connect_args2[] = {"address", "127.0.0.1:12346"}; - DBHandle *db2 = db_connect("127.0.0.1", 6379, "plasma_manager", manager_addr, - 2, db_connect_args2); + DBHandle *db2 = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", + manager_addr, 2, db_connect_args2); db_attach(db1, loop, false); db_attach(db2, loop, false); UniqueID id = globally_unique_id(); @@ -148,8 +148,8 @@ void task_table_test_callback(Task *callback_task, void *user_data) { TEST task_table_test(void) { task_table_test_callback_called = 0; event_loop *loop = event_loop_create(); - DBHandle *db = - db_connect("127.0.0.1", 6379, "local_scheduler", "127.0.0.1", 0, NULL); + DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "local_scheduler", + "127.0.0.1", 0, NULL); db_attach(db, loop, false); DBClientID local_scheduler_id = globally_unique_id(); int64_t task_spec_size; @@ -184,8 +184,8 @@ void task_table_all_test_callback(Task *task, void *user_data) { TEST task_table_all_test(void) { event_loop *loop = event_loop_create(); - DBHandle *db = - db_connect("127.0.0.1", 6379, "local_scheduler", "127.0.0.1", 0, NULL); + DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "local_scheduler", + "127.0.0.1", 0, NULL); db_attach(db, loop, false); int64_t task_spec_size; TaskSpec *spec = example_task_spec(1, 1, &task_spec_size); @@ -222,7 +222,8 @@ TEST unique_client_id_test(void) { DBClientID ids[num_conns]; DBHandle *db; for (int i = 0; i < num_conns; ++i) { - db = db_connect("127.0.0.1", 6379, "plasma_manager", "127.0.0.1", 0, NULL); + db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", + "127.0.0.1", 0, NULL); ids[i] = get_db_client_id(db); db_disconnect(db); } diff --git a/src/common/test/object_table_tests.cc b/src/common/test/object_table_tests.cc index 6a5216829..195f83b56 100644 --- a/src/common/test/object_table_tests.cc +++ b/src/common/test/object_table_tests.cc @@ -65,6 +65,15 @@ void new_object_task_callback(TaskID task_id, void *user_context) { new_object_lookup_callback, (void *) db); } +void task_table_subscribe_done(TaskID task_id, void *user_context) { + RetryInfo retry = { + .num_retries = 5, .timeout = 100, .fail_callback = NULL, + }; + DBHandle *db = (DBHandle *) user_context; + task_table_add_task(db, Task_copy(new_object_task), &retry, + new_object_task_callback, db); +} + TEST new_object_test(void) { new_object_failed = 0; new_object_succeeded = 0; @@ -73,16 +82,16 @@ TEST new_object_test(void) { new_object_task_spec = Task_task_spec(new_object_task); new_object_task_id = TaskSpec_task_id(new_object_task_spec); g_loop = event_loop_create(); - DBHandle *db = - db_connect("127.0.0.1", 6379, "plasma_manager", "127.0.0.1", 0, NULL); + DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", + "127.0.0.1", 0, NULL); db_attach(db, g_loop, false); RetryInfo retry = { .num_retries = 5, .timeout = 100, .fail_callback = new_object_fail_callback, }; - task_table_add_task(db, Task_copy(new_object_task), &retry, - new_object_task_callback, db); + task_table_subscribe(db, NIL_ID, TASK_STATUS_WAITING, NULL, NULL, &retry, + task_table_subscribe_done, db); event_loop_run(g_loop); db_disconnect(db); destroy_outstanding_callbacks(g_loop); @@ -109,8 +118,8 @@ TEST new_object_no_task_test(void) { new_object_id = globally_unique_id(); new_object_task_id = globally_unique_id(); g_loop = event_loop_create(); - DBHandle *db = - db_connect("127.0.0.1", 6379, "plasma_manager", "127.0.0.1", 0, NULL); + DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", + "127.0.0.1", 0, NULL); db_attach(db, g_loop, false); RetryInfo retry = { .num_retries = 5, @@ -151,8 +160,8 @@ void lookup_fail_callback(UniqueID id, void *user_context, void *user_data) { TEST lookup_timeout_test(void) { g_loop = event_loop_create(); - DBHandle *db = - db_connect("127.0.0.1", 6379, "plasma_manager", "127.0.0.1", 0, NULL); + DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", + "127.0.0.1", 0, NULL); db_attach(db, g_loop, false); RetryInfo retry = { .num_retries = 5, .timeout = 100, .fail_callback = lookup_fail_callback, @@ -161,6 +170,9 @@ TEST lookup_timeout_test(void) { (void *) lookup_timeout_context); /* Disconnect the database to see if the lookup times out. */ close(db->context->c.fd); + for (auto context : db->contexts) { + close(context->c.fd); + } event_loop_run(g_loop); db_disconnect(db); destroy_outstanding_callbacks(g_loop); @@ -187,8 +199,8 @@ void add_fail_callback(UniqueID id, void *user_context, void *user_data) { TEST add_timeout_test(void) { g_loop = event_loop_create(); - DBHandle *db = - db_connect("127.0.0.1", 6379, "plasma_manager", "127.0.0.1", 0, NULL); + DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", + "127.0.0.1", 0, NULL); db_attach(db, g_loop, false); RetryInfo retry = { .num_retries = 5, .timeout = 100, .fail_callback = add_fail_callback, @@ -197,6 +209,9 @@ TEST add_timeout_test(void) { add_done_callback, (void *) add_timeout_context); /* Disconnect the database to see if the lookup times out. */ close(db->context->c.fd); + for (auto context : db->contexts) { + close(context->c.fd); + } event_loop_run(g_loop); db_disconnect(db); destroy_outstanding_callbacks(g_loop); @@ -225,8 +240,8 @@ void subscribe_fail_callback(UniqueID id, void *user_context, void *user_data) { TEST subscribe_timeout_test(void) { g_loop = event_loop_create(); - DBHandle *db = - db_connect("127.0.0.1", 6379, "plasma_manager", "127.0.0.1", 0, NULL); + DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", + "127.0.0.1", 0, NULL); db_attach(db, g_loop, false); RetryInfo retry = { .num_retries = 5, @@ -236,7 +251,10 @@ TEST subscribe_timeout_test(void) { object_table_subscribe_to_notifications(db, false, subscribe_done_callback, NULL, &retry, NULL, NULL); /* Disconnect the database to see if the lookup times out. */ - close(db->sub_context->c.fd); + close(db->subscribe_context->c.fd); + for (auto subscribe_context : db->subscribe_contexts) { + close(subscribe_context->c.fd); + } event_loop_run(g_loop); db_disconnect(db); destroy_outstanding_callbacks(g_loop); @@ -324,8 +342,8 @@ TEST add_lookup_test(void) { lookup_retry_succeeded = 0; /* Construct the arguments to db_connect. */ const char *db_connect_args[] = {"address", "127.0.0.1:11235"}; - DBHandle *db = db_connect("127.0.0.1", 6379, "plasma_manager", "127.0.0.1", 2, - db_connect_args); + DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", + "127.0.0.1", 2, db_connect_args); db_attach(db, g_loop, true); RetryInfo retry = { .num_retries = 5, @@ -385,8 +403,8 @@ void add_remove_callback(ObjectID object_id, bool success, void *user_context) { TEST add_remove_lookup_test(void) { g_loop = event_loop_create(); lookup_retry_succeeded = 0; - DBHandle *db = - db_connect("127.0.0.1", 6379, "plasma_manager", "127.0.0.1", 0, NULL); + DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", + "127.0.0.1", 0, NULL); db_attach(db, g_loop, true); RetryInfo retry = { .num_retries = 5, @@ -407,29 +425,6 @@ TEST add_remove_lookup_test(void) { PASS(); } -/* === Test subscribe retry === */ - -const char *subscribe_retry_context = "subscribe_retry"; -int subscribe_retry_succeeded = 0; - -int64_t reconnect_sub_context_callback(event_loop *loop, - int64_t timer_id, - void *context) { - DBHandle *db = (DBHandle *) context; - /* Reconnect to redis. This is not reconnecting the pub/sub channel. */ - redisAsyncFree(db->sub_context); - redisAsyncFree(db->context); - redisFree(db->sync_context); - db->sub_context = redisAsyncConnect("127.0.0.1", 6379); - db->sub_context->data = (void *) db; - db->context = redisAsyncConnect("127.0.0.1", 6379); - db->context->data = (void *) db; - db->sync_context = redisConnect("127.0.0.1", 6379); - /* Re-attach the database to the event loop (the file descriptor changed). */ - db_attach(db, loop, true); - return EVENT_LOOP_TIMER_DONE; -} - /* ==== Test if late succeed is working correctly ==== */ /* === Test lookup late succeed === */ @@ -454,8 +449,8 @@ void lookup_late_done_callback(ObjectID object_id, TEST lookup_late_test(void) { g_loop = event_loop_create(); - DBHandle *db = - db_connect("127.0.0.1", 6379, "plasma_manager", "127.0.0.1", 0, NULL); + DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", + "127.0.0.1", 0, NULL); db_attach(db, g_loop, false); RetryInfo retry = { .num_retries = 0, @@ -498,8 +493,8 @@ void add_late_done_callback(ObjectID object_id, TEST add_late_test(void) { g_loop = event_loop_create(); - DBHandle *db = - db_connect("127.0.0.1", 6379, "plasma_manager", "127.0.0.1", 0, NULL); + DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", + "127.0.0.1", 0, NULL); db_attach(db, g_loop, false); RetryInfo retry = { .num_retries = 0, .timeout = 0, .fail_callback = add_late_fail_callback, @@ -543,8 +538,8 @@ void subscribe_late_done_callback(ObjectID object_id, TEST subscribe_late_test(void) { g_loop = event_loop_create(); - DBHandle *db = - db_connect("127.0.0.1", 6379, "plasma_manager", "127.0.0.1", 0, NULL); + DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", + "127.0.0.1", 0, NULL); db_attach(db, g_loop, false); RetryInfo retry = { .num_retries = 0, @@ -611,8 +606,8 @@ TEST subscribe_success_test(void) { /* Construct the arguments to db_connect. */ const char *db_connect_args[] = {"address", "127.0.0.1:11236"}; - DBHandle *db = db_connect("127.0.0.1", 6379, "plasma_manager", "127.0.0.1", 2, - db_connect_args); + DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", + "127.0.0.1", 2, db_connect_args); db_attach(db, g_loop, false); subscribe_id = globally_unique_id(); @@ -680,8 +675,8 @@ TEST subscribe_object_present_test(void) { g_loop = event_loop_create(); /* Construct the arguments to db_connect. */ const char *db_connect_args[] = {"address", "127.0.0.1:11236"}; - DBHandle *db = db_connect("127.0.0.1", 6379, "plasma_manager", "127.0.0.1", 2, - db_connect_args); + DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", + "127.0.0.1", 2, db_connect_args); db_attach(db, g_loop, false); UniqueID id = globally_unique_id(); RetryInfo retry = { @@ -732,8 +727,8 @@ void subscribe_object_not_present_object_available_callback( TEST subscribe_object_not_present_test(void) { g_loop = event_loop_create(); - DBHandle *db = - db_connect("127.0.0.1", 6379, "plasma_manager", "127.0.0.1", 0, NULL); + DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", + "127.0.0.1", 0, NULL); db_attach(db, g_loop, false); UniqueID id = globally_unique_id(); RetryInfo retry = { @@ -796,8 +791,8 @@ TEST subscribe_object_available_later_test(void) { g_loop = event_loop_create(); /* Construct the arguments to db_connect. */ const char *db_connect_args[] = {"address", "127.0.0.1:11236"}; - DBHandle *db = db_connect("127.0.0.1", 6379, "plasma_manager", "127.0.0.1", 2, - db_connect_args); + DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", + "127.0.0.1", 2, db_connect_args); db_attach(db, g_loop, false); UniqueID id = globally_unique_id(); RetryInfo retry = { @@ -849,8 +844,8 @@ TEST subscribe_object_available_subscribe_all(void) { g_loop = event_loop_create(); /* Construct the arguments to db_connect. */ const char *db_connect_args[] = {"address", "127.0.0.1:11236"}; - DBHandle *db = db_connect("127.0.0.1", 6379, "plasma_manager", "127.0.0.1", 2, - db_connect_args); + DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", + "127.0.0.1", 2, db_connect_args); db_attach(db, g_loop, false); UniqueID id = globally_unique_id(); RetryInfo retry = { @@ -904,7 +899,6 @@ SUITE(object_table_tests) { RUN_REDIS_TEST(add_late_test); RUN_REDIS_TEST(subscribe_late_test); RUN_REDIS_TEST(subscribe_success_test); - RUN_REDIS_TEST(subscribe_object_present_test); RUN_REDIS_TEST(subscribe_object_not_present_test); RUN_REDIS_TEST(subscribe_object_available_later_test); RUN_REDIS_TEST(subscribe_object_available_subscribe_all); diff --git a/src/common/test/redis_tests.cc b/src/common/test/redis_tests.cc index 92cdc3829..78f80f986 100644 --- a/src/common/test/redis_tests.cc +++ b/src/common/test/redis_tests.cc @@ -102,8 +102,8 @@ TEST async_redis_socket_test(void) { utarray_push_back(connections, &socket_fd); /* Start connection to Redis. */ - DBHandle *db = - db_connect("127.0.0.1", 6379, "test_process", "127.0.0.1", 0, NULL); + DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "test_process", + "127.0.0.1", 0, NULL); db_attach(db, loop, false); /* Send a command to the Redis process. */ @@ -177,8 +177,8 @@ TEST logging_test(void) { utarray_push_back(connections, &socket_fd); /* Start connection to Redis. */ - DBHandle *conn = - db_connect("127.0.0.1", 6379, "test_process", "127.0.0.1", 0, NULL); + DBHandle *conn = db_connect(std::string("127.0.0.1"), 6379, "test_process", + "127.0.0.1", 0, NULL); db_attach(conn, loop, false); /* Send a command to the Redis process. */ diff --git a/src/common/test/run_tests.sh b/src/common/test/run_tests.sh index 4fd2f1d4c..cc5c73eb2 100644 --- a/src/common/test/run_tests.sh +++ b/src/common/test/run_tests.sh @@ -5,8 +5,14 @@ # Cause the script to exit if a single command fails. set -e -./src/common/thirdparty/redis/src/redis-server --loglevel warning --loadmodule ./src/common/redis_module/libray_redis_module.so & +# Start the Redis shards. +./src/common/thirdparty/redis/src/redis-server --loglevel warning --loadmodule ./src/common/redis_module/libray_redis_module.so --port 6379 & +./src/common/thirdparty/redis/src/redis-server --loglevel warning --loadmodule ./src/common/redis_module/libray_redis_module.so --port 6380 & sleep 1s +# Register the shard location with the primary shard. +./src/common/thirdparty/redis/src/redis-cli set NumRedisShards 1 +./src/common/thirdparty/redis/src/redis-cli rpush RedisShards 127.0.0.1:6380 + ./src/common/common_tests ./src/common/db_tests ./src/common/io_tests @@ -14,4 +20,5 @@ sleep 1s ./src/common/redis_tests ./src/common/task_table_tests ./src/common/object_table_tests -./src/common/thirdparty/redis/src/redis-cli shutdown +./src/common/thirdparty/redis/src/redis-cli -p 6379 shutdown +./src/common/thirdparty/redis/src/redis-cli -p 6380 shutdown diff --git a/src/common/test/run_valgrind.sh b/src/common/test/run_valgrind.sh index 51edf5f5b..7e700f574 100644 --- a/src/common/test/run_valgrind.sh +++ b/src/common/test/run_valgrind.sh @@ -5,8 +5,14 @@ # Cause the script to exit if a single command fails. set -e -./src/common/thirdparty/redis/src/redis-server --loglevel warning --loadmodule ./src/common/redis_module/libray_redis_module.so & +# Start the Redis shards. +./src/common/thirdparty/redis/src/redis-server --loglevel warning --loadmodule ./src/common/redis_module/libray_redis_module.so --port 6379 & +./src/common/thirdparty/redis/src/redis-server --loglevel warning --loadmodule ./src/common/redis_module/libray_redis_module.so --port 6380 & sleep 1s +# Register the shard location with the primary shard. +./src/common/thirdparty/redis/src/redis-cli set NumRedisShards 1 +./src/common/thirdparty/redis/src/redis-cli rpush RedisShards 127.0.0.1:6380 + valgrind --leak-check=full --error-exitcode=1 ./src/common/common_tests valgrind --leak-check=full --error-exitcode=1 ./src/common/db_tests valgrind --leak-check=full --error-exitcode=1 ./src/common/io_tests @@ -14,4 +20,6 @@ valgrind --leak-check=full --error-exitcode=1 ./src/common/task_tests valgrind --leak-check=full --error-exitcode=1 ./src/common/redis_tests valgrind --leak-check=full --error-exitcode=1 ./src/common/task_table_tests valgrind --leak-check=full --error-exitcode=1 ./src/common/object_table_tests + ./src/common/thirdparty/redis/src/redis-cli shutdown +./src/common/thirdparty/redis/src/redis-cli -p 6380 shutdown diff --git a/src/common/test/task_table_tests.cc b/src/common/test/task_table_tests.cc index 4047cb35c..ae363eca4 100644 --- a/src/common/test/task_table_tests.cc +++ b/src/common/test/task_table_tests.cc @@ -40,8 +40,8 @@ void lookup_nil_success_callback(Task *task, void *context) { TEST lookup_nil_test(void) { lookup_nil_id = globally_unique_id(); g_loop = event_loop_create(); - DBHandle *db = - db_connect("127.0.0.1", 6379, "plasma_manager", "127.0.0.1", 0, NULL); + DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", + "127.0.0.1", 0, NULL); db_attach(db, g_loop, false); RetryInfo retry = { .num_retries = 5, @@ -96,14 +96,16 @@ void add_success_callback(TaskID task_id, void *context) { TEST add_lookup_test(void) { add_lookup_task = example_task(1, 1, TASK_STATUS_WAITING); g_loop = event_loop_create(); - DBHandle *db = - db_connect("127.0.0.1", 6379, "plasma_manager", "127.0.0.1", 0, NULL); + DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", + "127.0.0.1", 0, NULL); db_attach(db, g_loop, false); RetryInfo retry = { .num_retries = 5, .timeout = 1000, .fail_callback = add_lookup_fail_callback, }; + task_table_subscribe(db, NIL_ID, TASK_STATUS_WAITING, NULL, NULL, &retry, + NULL, NULL); task_table_add_task(db, Task_copy(add_lookup_task), &retry, add_success_callback, (void *) db); /* Disconnect the database to see if the lookup times out. */ @@ -136,8 +138,8 @@ void subscribe_fail_callback(UniqueID id, void *user_context, void *user_data) { TEST subscribe_timeout_test(void) { g_loop = event_loop_create(); - DBHandle *db = - db_connect("127.0.0.1", 6379, "plasma_manager", "127.0.0.1", 0, NULL); + DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", + "127.0.0.1", 0, NULL); db_attach(db, g_loop, false); RetryInfo retry = { .num_retries = 5, @@ -148,7 +150,10 @@ TEST subscribe_timeout_test(void) { subscribe_done_callback, (void *) subscribe_timeout_context); /* Disconnect the database to see if the subscribe times out. */ - close(db->sub_context->c.fd); + close(db->subscribe_context->c.fd); + for (int i = 0; i < db->subscribe_contexts.size(); ++i) { + close(db->subscribe_contexts[i]->c.fd); + } aeProcessEvents(g_loop, AE_TIME_EVENTS); event_loop_run(g_loop); db_disconnect(db); @@ -177,17 +182,22 @@ void publish_fail_callback(UniqueID id, void *user_context, void *user_data) { TEST publish_timeout_test(void) { g_loop = event_loop_create(); - DBHandle *db = - db_connect("127.0.0.1", 6379, "plasma_manager", "127.0.0.1", 0, NULL); + DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", + "127.0.0.1", 0, NULL); db_attach(db, g_loop, false); Task *task = example_task(1, 1, TASK_STATUS_WAITING); RetryInfo retry = { .num_retries = 5, .timeout = 100, .fail_callback = publish_fail_callback, }; + task_table_subscribe(db, NIL_ID, TASK_STATUS_WAITING, NULL, NULL, &retry, + NULL, NULL); task_table_add_task(db, task, &retry, publish_done_callback, (void *) publish_timeout_context); /* Disconnect the database to see if the publish times out. */ close(db->context->c.fd); + for (int i = 0; i < db->contexts.size(); ++i) { + close(db->contexts[i]->c.fd); + } aeProcessEvents(g_loop, AE_TIME_EVENTS); event_loop_run(g_loop); db_disconnect(db); @@ -204,9 +214,14 @@ int64_t reconnect_db_callback(event_loop *loop, void *context) { DBHandle *db = (DBHandle *) context; /* Reconnect to redis. */ - redisAsyncFree(db->sub_context); - db->sub_context = redisAsyncConnect("127.0.0.1", 6379); - db->sub_context->data = (void *) db; + redisAsyncFree(db->subscribe_context); + db->subscribe_context = redisAsyncConnect("127.0.0.1", 6379); + db->subscribe_context->data = (void *) db; + for (int i = 0; i < db->subscribe_contexts.size(); ++i) { + redisAsyncFree(db->subscribe_contexts[i]); + db->subscribe_contexts[i] = redisAsyncConnect("127.0.0.1", 6380 + i); + db->subscribe_contexts[i]->data = (void *) db; + } /* Re-attach the database to the event loop (the file descriptor changed). */ db_attach(db, loop, true); return EVENT_LOOP_TIMER_DONE; @@ -239,8 +254,8 @@ void subscribe_retry_fail_callback(UniqueID id, TEST subscribe_retry_test(void) { g_loop = event_loop_create(); - DBHandle *db = - db_connect("127.0.0.1", 6379, "plasma_manager", "127.0.0.1", 0, NULL); + DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", + "127.0.0.1", 0, NULL); db_attach(db, g_loop, false); RetryInfo retry = { .num_retries = 5, @@ -251,7 +266,10 @@ TEST subscribe_retry_test(void) { subscribe_retry_done_callback, (void *) subscribe_retry_context); /* Disconnect the database to see if the subscribe times out. */ - close(db->sub_context->c.fd); + close(db->subscribe_context->c.fd); + for (int i = 0; i < db->subscribe_contexts.size(); ++i) { + close(db->subscribe_contexts[i]->c.fd); + } /* Install handler for reconnecting the database. */ event_loop_add_timer(g_loop, 150, (event_loop_timer_handler) reconnect_db_callback, db); @@ -286,8 +304,8 @@ void publish_retry_fail_callback(UniqueID id, TEST publish_retry_test(void) { g_loop = event_loop_create(); - DBHandle *db = - db_connect("127.0.0.1", 6379, "plasma_manager", "127.0.0.1", 0, NULL); + DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", + "127.0.0.1", 0, NULL); db_attach(db, g_loop, false); Task *task = example_task(1, 1, TASK_STATUS_WAITING); RetryInfo retry = { @@ -295,10 +313,15 @@ TEST publish_retry_test(void) { .timeout = 100, .fail_callback = publish_retry_fail_callback, }; + task_table_subscribe(db, NIL_ID, TASK_STATUS_WAITING, NULL, NULL, &retry, + NULL, NULL); task_table_add_task(db, task, &retry, publish_retry_done_callback, (void *) publish_retry_context); /* Disconnect the database to see if the publish times out. */ - close(db->sub_context->c.fd); + close(db->subscribe_context->c.fd); + for (int i = 0; i < db->subscribe_contexts.size(); ++i) { + close(db->subscribe_contexts[i]->c.fd); + } /* Install handler for reconnecting the database. */ event_loop_add_timer(g_loop, 150, (event_loop_timer_handler) reconnect_db_callback, db); @@ -335,8 +358,8 @@ void subscribe_late_done_callback(TaskID task_id, void *user_context) { TEST subscribe_late_test(void) { g_loop = event_loop_create(); - DBHandle *db = - db_connect("127.0.0.1", 6379, "plasma_manager", "127.0.0.1", 0, NULL); + DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", + "127.0.0.1", 0, NULL); db_attach(db, g_loop, false); RetryInfo retry = { .num_retries = 0, @@ -380,8 +403,8 @@ void publish_late_done_callback(TaskID task_id, void *user_context) { TEST publish_late_test(void) { g_loop = event_loop_create(); - DBHandle *db = - db_connect("127.0.0.1", 6379, "plasma_manager", "127.0.0.1", 0, NULL); + DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", + "127.0.0.1", 0, NULL); db_attach(db, g_loop, false); Task *task = example_task(1, 1, TASK_STATUS_WAITING); RetryInfo retry = { @@ -389,6 +412,8 @@ TEST publish_late_test(void) { .timeout = 0, .fail_callback = publish_late_fail_callback, }; + task_table_subscribe(db, NIL_ID, TASK_STATUS_WAITING, NULL, NULL, NULL, NULL, + NULL); task_table_add_task(db, task, &retry, publish_late_done_callback, (void *) publish_late_context); /* Install handler for terminating the event loop. */ diff --git a/src/common/test/test_common.h b/src/common/test/test_common.h index d4b81036d..94c6ef3bb 100644 --- a/src/common/test/test_common.h +++ b/src/common/test/test_common.h @@ -2,11 +2,13 @@ #define TEST_COMMON_H #include +#include #include "common.h" #include "io.h" #include "hiredis/hiredis.h" #include "utstring.h" +#include "state/redis.h" #ifndef _WIN32 /* This function is actually not declared in standard POSIX, so declare it. */ @@ -48,10 +50,29 @@ static inline int bind_inet_sock_retry(int *fd) { } /* Flush redis. */ -static inline void flushall_redis() { +static inline void flushall_redis(void) { + /* Flush the primary shard. */ redisContext *context = redisConnect("127.0.0.1", 6379); + std::vector db_shards_addresses; + std::vector db_shards_ports; + get_redis_shards(context, db_shards_addresses, db_shards_ports); freeReplyObject(redisCommand(context, "FLUSHALL")); + /* Readd the shard locations. */ + freeReplyObject(redisCommand(context, "SET NumRedisShards %d", + db_shards_addresses.size())); + for (int i = 0; i < db_shards_addresses.size(); ++i) { + freeReplyObject(redisCommand(context, "RPUSH RedisShards %s:%d", + db_shards_addresses[i].c_str(), + db_shards_ports[i])); + } redisFree(context); + + /* Flush the remaining shards. */ + for (int i = 0; i < db_shards_addresses.size(); ++i) { + context = redisConnect(db_shards_addresses[i].c_str(), db_shards_ports[i]); + freeReplyObject(redisCommand(context, "FLUSHALL")); + redisFree(context); + } } /* Cleanup method for running tests with the greatest library. diff --git a/src/global_scheduler/global_scheduler.cc b/src/global_scheduler/global_scheduler.cc index e7956e289..0271cb95c 100644 --- a/src/global_scheduler/global_scheduler.cc +++ b/src/global_scheduler/global_scheduler.cc @@ -62,14 +62,14 @@ void assign_task_to_local_scheduler(GlobalSchedulerState *state, GlobalSchedulerState *GlobalSchedulerState_init(event_loop *loop, const char *node_ip_address, - const char *redis_addr, - int redis_port) { + const char *redis_primary_addr, + int redis_primary_port) { GlobalSchedulerState *state = (GlobalSchedulerState *) malloc(sizeof(GlobalSchedulerState)); /* Must initialize state to 0. Sets hashmap head(s) to NULL. */ memset(state, 0, sizeof(GlobalSchedulerState)); - state->db = db_connect(redis_addr, redis_port, "global_scheduler", - node_ip_address, 0, NULL); + state->db = db_connect(std::string(redis_primary_addr), redis_primary_port, + "global_scheduler", node_ip_address, 0, NULL); db_attach(state->db, loop, false); utarray_new(state->local_schedulers, &local_scheduler_icd); state->policy_state = GlobalSchedulerPolicyState_init(); @@ -253,26 +253,34 @@ void remove_local_scheduler(GlobalSchedulerState *state, int index) { * @param aux_address: an ip:port pair for the plasma manager associated with * this db client. */ -void process_new_db_client(DBClientID db_client_id, - const char *client_type, - const char *aux_address, - bool is_insertion, - void *user_context) { +void process_new_db_client(DBClient *db_client, void *user_context) { GlobalSchedulerState *state = (GlobalSchedulerState *) user_context; char id_string[ID_STRING_SIZE]; LOG_DEBUG("db client table callback for db client = %s", - ObjectID_to_string(db_client_id, id_string, ID_STRING_SIZE)); + ObjectID_to_string(db_client->id, id_string, ID_STRING_SIZE)); UNUSED(id_string); - if (strncmp(client_type, "local_scheduler", strlen("local_scheduler")) == 0) { - if (is_insertion) { - /* This is a notification for an insert. */ - add_local_scheduler(state, db_client_id, aux_address); + if (strncmp(db_client->client_type, "local_scheduler", + strlen("local_scheduler")) == 0) { + if (db_client->is_insertion) { + /* This is a notification for an insert. We may receive duplicate + * notifications since we read the entire table before processing + * notifications. Filter out local schedulers that we already added. */ + for (LocalScheduler *scheduler = + (LocalScheduler *) utarray_front(state->local_schedulers); + scheduler != NULL; scheduler = (LocalScheduler *) utarray_next( + state->local_schedulers, scheduler)) { + if (UNIQUE_ID_EQ(scheduler->id, db_client->id)) { + return; + } + } + + add_local_scheduler(state, db_client->id, db_client->aux_address); } else { int i = 0; for (; i < utarray_len(state->local_schedulers); ++i) { LocalScheduler *active_worker = (LocalScheduler *) utarray_eltptr(state->local_schedulers, i); - if (DBClientID_equal(active_worker->id, db_client_id)) { + if (DBClientID_equal(active_worker->id, db_client->id)) { break; } } @@ -418,11 +426,11 @@ int heartbeat_timeout_handler(event_loop *loop, timer_id id, void *context) { } void start_server(const char *node_ip_address, - const char *redis_addr, - int redis_port) { + const char *redis_primary_addr, + int redis_primary_port) { event_loop *loop = event_loop_create(); - g_state = - GlobalSchedulerState_init(loop, node_ip_address, redis_addr, redis_port); + g_state = GlobalSchedulerState_init(loop, node_ip_address, redis_primary_addr, + redis_primary_port); /* TODO(rkn): subscribe to notifications from the object table. */ /* Subscribe to notifications about new local schedulers. TODO(rkn): this * needs to also get all of the clients that registered with the database @@ -458,15 +466,15 @@ void start_server(const char *node_ip_address, int main(int argc, char *argv[]) { signal(SIGTERM, signal_handler); - /* IP address and port of redis. */ - char *redis_addr_port = NULL; + /* IP address and port of the primary redis instance. */ + char *redis_primary_addr_port = NULL; /* The IP address of the node that this global scheduler is running on. */ char *node_ip_address = NULL; int c; while ((c = getopt(argc, argv, "h:r:")) != -1) { switch (c) { case 'r': - redis_addr_port = optarg; + redis_primary_addr_port = optarg; break; case 'h': node_ip_address = optarg; @@ -476,16 +484,18 @@ int main(int argc, char *argv[]) { exit(-1); } } - char redis_addr[16]; - int redis_port; - if (!redis_addr_port || - parse_ip_addr_port(redis_addr_port, redis_addr, &redis_port) == -1) { - LOG_ERROR( - "specify the redis address like 127.0.0.1:6379 with the -r switch"); - exit(-1); + + char redis_primary_addr[16]; + int redis_primary_port; + if (!redis_primary_addr_port || + parse_ip_addr_port(redis_primary_addr_port, redis_primary_addr, + &redis_primary_port) == -1) { + LOG_FATAL( + "specify the primary redis address like 127.0.0.1:6379 with the -r " + "switch"); } if (!node_ip_address) { LOG_FATAL("specify the node IP address with the -h switch"); } - start_server(node_ip_address, redis_addr, redis_port); + start_server(node_ip_address, redis_primary_addr, redis_primary_port); } diff --git a/src/local_scheduler/local_scheduler.cc b/src/local_scheduler/local_scheduler.cc index 2188b7ac7..315a70280 100644 --- a/src/local_scheduler/local_scheduler.cc +++ b/src/local_scheduler/local_scheduler.cc @@ -16,6 +16,7 @@ #include "local_scheduler_shared.h" #include "local_scheduler.h" #include "local_scheduler_algorithm.h" +#include "net.h" #include "state/actor_notification_table.h" #include "state/db.h" #include "state/driver_table.h" @@ -296,8 +297,8 @@ const char **parse_command(const char *command) { LocalSchedulerState *LocalSchedulerState_init( const char *node_ip_address, event_loop *loop, - const char *redis_addr, - int redis_port, + const char *redis_primary_addr, + int redis_primary_port, const char *local_scheduler_socket_name, const char *plasma_store_socket_name, const char *plasma_manager_socket_name, @@ -323,7 +324,7 @@ LocalSchedulerState *LocalSchedulerState_init( state->loop = loop; /* Connect to Redis if a Redis address is provided. */ - if (redis_addr != NULL) { + if (redis_primary_addr != NULL) { int num_args; const char **db_connect_args = NULL; /* Use UT_string to convert the resource value into a string. */ @@ -354,8 +355,9 @@ LocalSchedulerState *LocalSchedulerState_init( db_connect_args[4] = "num_gpus"; db_connect_args[5] = utstring_body(num_gpus); } - state->db = db_connect(redis_addr, redis_port, "local_scheduler", - node_ip_address, num_args, db_connect_args); + state->db = db_connect(std::string(redis_primary_addr), redis_primary_port, + "local_scheduler", node_ip_address, num_args, + db_connect_args); utstring_free(num_cpus); utstring_free(num_gpus); free(db_connect_args); @@ -548,8 +550,6 @@ void process_plasma_notification(event_loop *loop, void reconstruct_task_update_callback(Task *task, void *user_context, bool updated) { - /* The task ID should be in the task table. */ - CHECK(task != NULL); if (!updated) { /* The test-and-set of the task's scheduling state failed, so the task was * either not finished yet, or it was already being reconstructed. @@ -578,7 +578,6 @@ void reconstruct_task_update_callback(Task *task, void reconstruct_put_task_update_callback(Task *task, void *user_context, bool updated) { - CHECK(task != NULL); if (updated) { /* The update to TASK_STATUS_RECONSTRUCTING succeeded, so continue with * reconstruction as usual. */ @@ -1111,8 +1110,8 @@ int heartbeat_handler(event_loop *loop, timer_id id, void *context) { void start_server(const char *node_ip_address, const char *socket_name, - const char *redis_addr, - int redis_port, + const char *redis_primary_addr, + int redis_primary_port, const char *plasma_store_socket_name, const char *plasma_manager_socket_name, const char *plasma_manager_address, @@ -1126,8 +1125,8 @@ void start_server(const char *node_ip_address, int fd = bind_ipc_sock(socket_name, true); event_loop *loop = event_loop_create(); g_state = LocalSchedulerState_init( - node_ip_address, loop, redis_addr, redis_port, socket_name, - plasma_store_socket_name, plasma_manager_socket_name, + node_ip_address, loop, redis_primary_addr, redis_primary_port, + socket_name, plasma_store_socket_name, plasma_manager_socket_name, plasma_manager_address, global_scheduler_exists, static_resource_conf, start_worker_command, num_workers); /* Register a callback for registering new clients. */ @@ -1173,8 +1172,8 @@ int main(int argc, char *argv[]) { signal(SIGTERM, signal_handler); /* Path of the listening socket of the local scheduler. */ char *scheduler_socket_name = NULL; - /* IP address and port of redis. */ - char *redis_addr_port = NULL; + /* IP address and port of the primary redis instance. */ + char *redis_primary_addr_port = NULL; /* Socket name for the local Plasma store. */ char *plasma_store_socket_name = NULL; /* Socket name for the local Plasma manager. */ @@ -1199,7 +1198,7 @@ int main(int argc, char *argv[]) { scheduler_socket_name = optarg; break; case 'r': - redis_addr_port = optarg; + redis_primary_addr_port = optarg; break; case 'p': plasma_store_socket_name = optarg; @@ -1266,7 +1265,7 @@ int main(int argc, char *argv[]) { char *redis_addr = NULL; int redis_port = -1; - if (!redis_addr_port) { + if (!redis_primary_addr_port) { /* Start the local scheduler without connecting to Redis. In this case, all * submitted tasks will be queued and scheduled locally. */ if (plasma_manager_socket_name) { @@ -1275,27 +1274,22 @@ int main(int argc, char *argv[]) { "then a redis address must be provided with the -r switch"); } } else { - char redis_addr_buffer[16] = {0}; - char redis_port_str[6] = {0}; - /* Parse the Redis address into an IP address and a port. */ - int num_assigned = sscanf(redis_addr_port, "%15[0-9.]:%5[0-9]", - redis_addr_buffer, redis_port_str); - if (num_assigned != 2) { + char redis_primary_addr[16]; + int redis_primary_port; + /* Parse the primary Redis address into an IP address and a port. */ + if (parse_ip_addr_port(redis_primary_addr_port, redis_primary_addr, + &redis_primary_port) == -1) { LOG_FATAL( "if a redis address is provided with the -r switch, it should be " "formatted like 127.0.0.1:6379"); } - redis_addr = redis_addr_buffer; - redis_port = strtol(redis_port_str, NULL, 10); - if (redis_port == 0) { - LOG_FATAL("Unable to parse port number from redis address %s", - redis_addr_port); - } if (!plasma_manager_socket_name) { LOG_FATAL( "please specify socket for connecting to Plasma manager with -m " "switch"); } + redis_addr = redis_primary_addr; + redis_port = redis_primary_port; } start_server(node_ip_address, scheduler_socket_name, redis_addr, redis_port, diff --git a/src/local_scheduler/local_scheduler_algorithm.cc b/src/local_scheduler/local_scheduler_algorithm.cc index 24149df5b..d71fa65a7 100644 --- a/src/local_scheduler/local_scheduler_algorithm.cc +++ b/src/local_scheduler/local_scheduler_algorithm.cc @@ -839,7 +839,8 @@ void handle_actor_task_submitted(LocalSchedulerState *state, /* Add this task to a queue of tasks that have been submitted but the local * scheduler doesn't know which actor is responsible for them. These tasks * will be resubmitted (internally by the local scheduler) whenever a new - * actor notification arrives. */ + * actor notification arrives. NOTE(swang): These tasks have not yet been + * added to the task table. */ utarray_push_back(algorithm_state->cached_submitted_actor_tasks, &spec); utarray_push_back(algorithm_state->cached_submitted_actor_task_sizes, &task_spec_size); diff --git a/src/local_scheduler/test/local_scheduler_tests.cc b/src/local_scheduler/test/local_scheduler_tests.cc index 31ba4c0bc..acdecc91c 100644 --- a/src/local_scheduler/test/local_scheduler_tests.cc +++ b/src/local_scheduler/test/local_scheduler_tests.cc @@ -18,6 +18,7 @@ #include "task.h" #include "state/object_table.h" #include "state/task_table.h" +#include "state/redis.h" #include "local_scheduler_shared.h" #include "local_scheduler.h" @@ -182,7 +183,16 @@ TEST object_reconstruction_test(void) { /* Add an empty object table entry for the object we want to reconstruct, to * simulate it having been created and evicted. */ const char *client_id = "clientid"; + /* Lookup the shard locations for the object table. */ + std::vector db_shards_addresses; + std::vector db_shards_ports; redisContext *context = redisConnect("127.0.0.1", 6379); + get_redis_shards(context, db_shards_addresses, db_shards_ports); + redisFree(context); + /* There should only be one shard, so we can safely add the empty object + * table entry to the first one. */ + ASSERT(db_shards_addresses.size() == 1); + context = redisConnect(db_shards_addresses[0].c_str(), db_shards_ports[0]); redisReply *reply = (redisReply *) redisCommand( context, "RAY.OBJECT_TABLE_ADD %b %ld %b %s", return_id.id, sizeof(return_id.id), 1, NIL_DIGEST, (size_t) DIGEST_SIZE, client_id); @@ -273,7 +283,16 @@ TEST object_reconstruction_recursive_test(void) { /* Add an empty object table entry for each object we want to reconstruct, to * simulate their having been created and evicted. */ const char *client_id = "clientid"; + /* Lookup the shard locations for the object table. */ + std::vector db_shards_addresses; + std::vector db_shards_ports; redisContext *context = redisConnect("127.0.0.1", 6379); + get_redis_shards(context, db_shards_addresses, db_shards_ports); + redisFree(context); + /* There should only be one shard, so we can safely add the empty object + * table entry to the first one. */ + ASSERT(db_shards_addresses.size() == 1); + context = redisConnect(db_shards_addresses[0].c_str(), db_shards_ports[0]); for (int i = 0; i < NUM_TASKS; ++i) { ObjectID return_id = TaskSpec_return(specs[i], 0); redisReply *reply = (redisReply *) redisCommand( @@ -406,8 +425,8 @@ TEST object_reconstruction_suppression_test(void) { } else { /* Connect a plasma manager client so we can call object_table_add. */ const char *db_connect_args[] = {"address", "127.0.0.1:12346"}; - DBHandle *db = db_connect("127.0.0.1", 6379, "plasma_manager", "127.0.0.1", - 2, db_connect_args); + DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", + "127.0.0.1", 2, db_connect_args); db_attach(db, local_scheduler->loop, false); /* Add the object to the object table. */ object_table_add(db, return_id, 1, (unsigned char *) NIL_DIGEST, NULL, diff --git a/src/local_scheduler/test/run_tests.sh b/src/local_scheduler/test/run_tests.sh index 36f31566c..4cb5732a9 100644 --- a/src/local_scheduler/test/run_tests.sh +++ b/src/local_scheduler/test/run_tests.sh @@ -5,10 +5,17 @@ # Cause the script to exit if a single command fails. set -e -./src/common/thirdparty/redis/src/redis-server --loglevel warning --loadmodule ./src/common/redis_module/libray_redis_module.so & +# Start the Redis shards. +./src/common/thirdparty/redis/src/redis-server --loglevel warning --loadmodule ./src/common/redis_module/libray_redis_module.so --port 6379 & +./src/common/thirdparty/redis/src/redis-server --loglevel warning --loadmodule ./src/common/redis_module/libray_redis_module.so --port 6380 & sleep 1s +# Register the shard location with the primary shard. +./src/common/thirdparty/redis/src/redis-cli set NumRedisShards 1 +./src/common/thirdparty/redis/src/redis-cli rpush RedisShards 127.0.0.1:6380 + ./src/plasma/plasma_store -s /tmp/plasma_store_socket_1 -m 100000000 & sleep 0.5s ./src/local_scheduler/local_scheduler_tests ./src/common/thirdparty/redis/src/redis-cli shutdown +./src/common/thirdparty/redis/src/redis-cli -p 6380 shutdown killall plasma_store diff --git a/src/local_scheduler/test/run_valgrind.sh b/src/local_scheduler/test/run_valgrind.sh index 4aee85aed..9c4e49a3f 100644 --- a/src/local_scheduler/test/run_valgrind.sh +++ b/src/local_scheduler/test/run_valgrind.sh @@ -5,10 +5,17 @@ # Cause the script to exit if a single command fails. set -e -./src/common/thirdparty/redis/src/redis-server --loglevel warning --loadmodule ./src/common/redis_module/libray_redis_module.so & +# Start the Redis shards. +./src/common/thirdparty/redis/src/redis-server --loglevel warning --loadmodule ./src/common/redis_module/libray_redis_module.so --port 6379 & +./src/common/thirdparty/redis/src/redis-server --loglevel warning --loadmodule ./src/common/redis_module/libray_redis_module.so --port 6380 & sleep 1s +# Register the shard location with the primary shard. +./src/common/thirdparty/redis/src/redis-cli set NumRedisShards 1 +./src/common/thirdparty/redis/src/redis-cli rpush RedisShards 127.0.0.1:6380 + ./src/plasma/plasma_store -s /tmp/plasma_store_socket_1 -m 100000000 & sleep 0.5s valgrind --leak-check=full --show-leak-kinds=all --error-exitcode=1 ./src/local_scheduler/local_scheduler_tests ./src/common/thirdparty/redis/src/redis-cli shutdown +./src/common/thirdparty/redis/src/redis-cli -p 6380 shutdown killall plasma_store diff --git a/src/plasma/plasma_manager.cc b/src/plasma/plasma_manager.cc index 57bab658a..31b3d9662 100644 --- a/src/plasma/plasma_manager.cc +++ b/src/plasma/plasma_manager.cc @@ -493,8 +493,8 @@ PlasmaManagerState *PlasmaManagerState_init(const char *store_socket_name, const char *manager_socket_name, const char *manager_addr, int manager_port, - const char *db_addr, - int db_port) { + const char *redis_primary_addr, + int redis_primary_port) { PlasmaManagerState *state = (PlasmaManagerState *) malloc(sizeof(PlasmaManagerState)); state->loop = event_loop_create(); @@ -504,7 +504,7 @@ PlasmaManagerState *PlasmaManagerState_init(const char *store_socket_name, state->fetch_requests = NULL; state->object_wait_requests_local = NULL; state->object_wait_requests_remote = NULL; - if (db_addr) { + if (redis_primary_addr) { /* Get the manager port as a string. */ UT_string *manager_address_str; utstring_new(manager_address_str); @@ -519,8 +519,9 @@ PlasmaManagerState *PlasmaManagerState_init(const char *store_socket_name, db_connect_args[3] = manager_socket_name; db_connect_args[4] = "address"; db_connect_args[5] = utstring_body(manager_address_str); - state->db = db_connect(db_addr, db_port, "plasma_manager", manager_addr, - num_args, db_connect_args); + state->db = + db_connect(std::string(redis_primary_addr), redis_primary_port, + "plasma_manager", manager_addr, num_args, db_connect_args); utstring_free(manager_address_str); free(db_connect_args); db_attach(state->db, state->loop, false); @@ -1594,8 +1595,8 @@ void start_server(const char *store_socket_name, const char *manager_socket_name, const char *master_addr, int port, - const char *db_addr, - int db_port) { + const char *redis_primary_addr, + int redis_primary_port) { /* Ignore SIGPIPE signals. If we don't do this, then when we attempt to write * to a client that has already died, the manager could die. */ signal(SIGPIPE, SIG_IGN); @@ -1610,9 +1611,9 @@ void start_server(const char *store_socket_name, int local_sock = bind_ipc_sock(manager_socket_name, false); CHECKM(local_sock >= 0, "Unable to bind local manager socket"); - g_manager_state = - PlasmaManagerState_init(store_socket_name, manager_socket_name, - master_addr, port, db_addr, db_port); + g_manager_state = PlasmaManagerState_init( + store_socket_name, manager_socket_name, master_addr, port, + redis_primary_addr, redis_primary_port); CHECK(g_manager_state); CHECK(listen(remote_sock, 5) != -1); @@ -1664,8 +1665,8 @@ int main(int argc, char *argv[]) { char *master_addr = NULL; /* Port number the manager should use. */ int port = -1; - /* IP address and port of state database. */ - char *db_host = NULL; + /* IP address and port of the primary redis instance. */ + char *redis_primary_addr_port = NULL; int c; while ((c = getopt(argc, argv, "s:m:h:p:r:")) != -1) { switch (c) { @@ -1682,7 +1683,7 @@ int main(int argc, char *argv[]) { port = atoi(optarg); break; case 'r': - db_host = optarg; + redis_primary_addr_port = optarg; break; default: LOG_FATAL("unknown option %c", c); @@ -1708,15 +1709,16 @@ int main(int argc, char *argv[]) { "please specify port the plasma manager shall listen to in the" "format 12345 with -p switch"); } - char db_addr[16]; - int db_port; - if (db_host) { - parse_ip_addr_port(db_host, db_addr, &db_port); - start_server(store_socket_name, manager_socket_name, master_addr, port, - db_addr, db_port); - } else { - start_server(store_socket_name, manager_socket_name, master_addr, port, - NULL, 0); + char redis_primary_addr[16]; + int redis_primary_port; + if (!redis_primary_addr_port || + parse_ip_addr_port(redis_primary_addr_port, redis_primary_addr, + &redis_primary_port) == -1) { + LOG_FATAL( + "specify the primary redis address like 127.0.0.1:6379 with the -r " + "switch"); } + start_server(store_socket_name, manager_socket_name, master_addr, port, + redis_primary_addr, redis_primary_port); } #endif diff --git a/src/plasma/test/run_tests.sh b/src/plasma/test/run_tests.sh index ccb1da562..82f8ff994 100644 --- a/src/plasma/test/run_tests.sh +++ b/src/plasma/test/run_tests.sh @@ -9,11 +9,18 @@ sleep 1 killall plasma_store ./src/plasma/serialization_tests -./src/common/thirdparty/redis/src/redis-server --loglevel warning --loadmodule ./src/common/redis_module/libray_redis_module.so & -redis_pid=$! +# Start the Redis shards. +./src/common/thirdparty/redis/src/redis-server --loglevel warning --loadmodule ./src/common/redis_module/libray_redis_module.so --port 6379 & +redis_pid1=$! +./src/common/thirdparty/redis/src/redis-server --loglevel warning --loadmodule ./src/common/redis_module/libray_redis_module.so --port 6380 & +redis_pid2=$! sleep 1 -# flush the redis server -./src/common/thirdparty/redis/src/redis-cli flushall & + +# Flush the redis server +./src/common/thirdparty/redis/src/redis-cli flushall +# Register the shard location with the primary shard. +./src/common/thirdparty/redis/src/redis-cli set NumRedisShards 1 +./src/common/thirdparty/redis/src/redis-cli rpush RedisShards 127.0.0.1:6380 sleep 1 ./src/plasma/plasma_store -s /tmp/store1 -m 1000000000 & plasma1_pid=$! @@ -31,5 +38,7 @@ kill $plasma4_pid kill $plasma3_pid kill $plasma2_pid kill $plasma1_pid -kill $redis_pid -wait $redis_pid +kill $redis_pid1 +wait $redis_pid1 +kill $redis_pid2 +wait $redis_pid2 diff --git a/test/jenkins_tests/multi_node_docker_test.py b/test/jenkins_tests/multi_node_docker_test.py index 9362f2bd9..6dd69b1bd 100644 --- a/test/jenkins_tests/multi_node_docker_test.py +++ b/test/jenkins_tests/multi_node_docker_test.py @@ -86,8 +86,8 @@ class DockerRunner(object): else: return m.group(1) - def _start_head_node(self, docker_image, mem_size, shm_size, num_cpus, - num_gpus, development_mode): + def _start_head_node(self, docker_image, mem_size, shm_size, + num_redis_shards, num_cpus, num_gpus, development_mode): """Start the Ray head node inside a docker container.""" mem_arg = ["--memory=" + mem_size] if mem_size else [] shm_arg = ["--shm-size=" + shm_size] if shm_size else [] @@ -99,6 +99,7 @@ class DockerRunner(object): command = (["docker", "run", "-d"] + mem_arg + shm_arg + volume_arg + [docker_image, "/ray/scripts/start_ray.sh", "--head", "--redis-port=6379", + "--num-redis-shards={}".format(num_redis_shards), "--num-cpus={}".format(num_cpus), "--num-gpus={}".format(num_gpus)]) print("Starting head node with command:{}".format(command)) @@ -137,8 +138,8 @@ class DockerRunner(object): self.worker_container_ids.append(container_id) def start_ray(self, docker_image=None, mem_size=None, shm_size=None, - num_nodes=None, num_cpus=None, num_gpus=None, - development_mode=None): + num_nodes=None, num_redis_shards=1, num_cpus=None, + num_gpus=None, development_mode=None): """Start a Ray cluster within docker. This starts one docker container running the head node and num_nodes - 1 @@ -153,6 +154,7 @@ class DockerRunner(object): with. This will be passed into `docker run` as the `--shm-size` flag. num_nodes: The number of nodes to use in the cluster (this counts the head node as well). + num_redis_shards: The number of Redis shards to use on the head node. num_cpus: A list of the number of CPUs to start each node with. num_gpus: A list of the number of GPUs to start each node with. development_mode: True if you want to mount the local copy of @@ -163,8 +165,8 @@ class DockerRunner(object): assert len(num_gpus) == num_nodes # Launch the head node. - self._start_head_node(docker_image, mem_size, shm_size, num_cpus[0], - num_gpus[0], development_mode) + self._start_head_node(docker_image, mem_size, shm_size, num_redis_shards, + num_cpus[0], num_gpus[0], development_mode) # Start the worker nodes. for i in range(num_nodes - 1): self._start_worker_node(docker_image, mem_size, shm_size, @@ -252,6 +254,9 @@ if __name__ == "__main__": parser.add_argument("--shm-size", default="1G", help="shared memory size") parser.add_argument("--num-nodes", default=1, type=int, help="number of nodes to use in the cluster") + parser.add_argument("--num-redis-shards", default=1, type=int, + help=("the number of Redis shards to start on the head " + "node")) parser.add_argument("--num-cpus", type=str, help=("a comma separated list of values representing " "the number of CPUs to start each node with")) @@ -282,8 +287,8 @@ if __name__ == "__main__": d = DockerRunner() d.start_ray(docker_image=args.docker_image, mem_size=args.mem_size, shm_size=args.shm_size, num_nodes=num_nodes, - num_cpus=num_cpus, num_gpus=num_gpus, - development_mode=args.development_mode) + num_redis_shards=args.num_redis_shards, num_cpus=num_cpus, + num_gpus=num_gpus, development_mode=args.development_mode) try: run_results = d.run_test(args.test_script, args.num_drivers, driver_locations=driver_locations) diff --git a/test/jenkins_tests/run_multi_node_tests.sh b/test/jenkins_tests/run_multi_node_tests.sh index 7042fb80c..91c9884b1 100755 --- a/test/jenkins_tests/run_multi_node_tests.sh +++ b/test/jenkins_tests/run_multi_node_tests.sh @@ -14,11 +14,13 @@ echo "Using Docker image" $DOCKER_SHA python $ROOT_DIR/multi_node_docker_test.py \ --docker-image=$DOCKER_SHA \ --num-nodes=5 \ + --num-redis-shards=10 \ --test-script=/ray/test/jenkins_tests/multi_node_tests/test_0.py python $ROOT_DIR/multi_node_docker_test.py \ --docker-image=$DOCKER_SHA \ --num-nodes=5 \ + --num-redis-shards=5 \ --num-gpus=0,1,2,3,4 \ --num-drivers=7 \ --driver-locations=0,1,0,1,2,3,4 \ @@ -27,6 +29,7 @@ python $ROOT_DIR/multi_node_docker_test.py \ python $ROOT_DIR/multi_node_docker_test.py \ --docker-image=$DOCKER_SHA \ --num-nodes=5 \ + --num-redis-shards=2 \ --num-gpus=0,0,5,6,50 \ --num-drivers=100 \ --test-script=/ray/test/jenkins_tests/multi_node_tests/many_drivers_test.py diff --git a/test/runtest.py b/test/runtest.py index 46178a880..6bc31b5da 100644 --- a/test/runtest.py +++ b/test/runtest.py @@ -293,8 +293,16 @@ class WorkerTest(unittest.TestCase): class APITest(unittest.TestCase): + def init_ray(self, kwargs=None): + if kwargs is None: + kwargs = {} + ray.init(**kwargs) + + def tearDown(self): + ray.worker.cleanup() + def testRegisterClass(self): - ray.init(num_workers=2) + self.init_ray({"num_workers": 2}) # Check that putting an object of a class that has not been registered # throws an exception. @@ -417,11 +425,9 @@ class APITest(unittest.TestCase): self.assertFalse(hasattr(c2, "method0")) self.assertFalse(hasattr(c2, "method1")) - ray.worker.cleanup() - def testKeywordArgs(self): reload(test_functions) - ray.init(num_workers=1) + self.init_ray() x = test_functions.keyword_fct1.remote(1) self.assertEqual(ray.get(x), "1 hello") @@ -483,11 +489,9 @@ class APITest(unittest.TestCase): self.assertEqual(ray.get(f3.remote(4)), 4) - ray.worker.cleanup() - def testVariableNumberOfArgs(self): reload(test_functions) - ray.init(num_workers=1) + self.init_ray() x = test_functions.varargs_fct1.remote(0, 1, 2) self.assertEqual(ray.get(x), "0 1 2") @@ -516,18 +520,14 @@ class APITest(unittest.TestCase): self.assertEqual(ray.get(f2.remote(1, 2, 3)), (1, 2, (3,))) self.assertEqual(ray.get(f2.remote(1, 2, 3, 4)), (1, 2, (3, 4))) - ray.worker.cleanup() - def testNoArgs(self): reload(test_functions) - ray.init(num_workers=1) + self.init_ray() ray.get(test_functions.no_op.remote()) - ray.worker.cleanup() - def testDefiningRemoteFunctions(self): - ray.init(num_workers=3, num_cpus=3) + self.init_ray({"num_cpus": 3}) # Test that we can define a remote function in the shell. @ray.remote @@ -584,10 +584,8 @@ class APITest(unittest.TestCase): self.assertEqual(ray.get(l.remote(1)), 2) self.assertEqual(ray.get(m.remote(1)), 2) - ray.worker.cleanup() - def testGetMultiple(self): - ray.init(num_workers=0) + self.init_ray() object_ids = [ray.put(i) for i in range(10)] self.assertEqual(ray.get(object_ids), list(range(10))) @@ -597,10 +595,8 @@ class APITest(unittest.TestCase): results = ray.get([object_ids[i] for i in indices]) self.assertEqual(results, indices) - ray.worker.cleanup() - def testWait(self): - ray.init(num_workers=1, num_cpus=1) + self.init_ray({"num_cpus": 1}) @ray.remote def f(delay): @@ -633,12 +629,10 @@ class APITest(unittest.TestCase): x = ray.put(1) self.assertRaises(Exception, lambda: ray.wait([x, x])) - ray.worker.cleanup() - def testMultipleWaitsAndGets(self): # It is important to use three workers here, so that the three tasks # launched in this experiment can run at the same time. - ray.init(num_workers=3) + self.init_ray() @ray.remote def f(delay): @@ -665,8 +659,6 @@ class APITest(unittest.TestCase): x = f.remote(1) ray.get([h.remote([x]), h.remote([x])]) - ray.worker.cleanup() - def testCachingEnvironmentVariables(self): # Test that we can define environment variables before the driver is # connected. @@ -690,15 +682,13 @@ class APITest(unittest.TestCase): ray.env.bar.append(1) return ray.env.bar - ray.init(num_workers=2) + self.init_ray() self.assertEqual(ray.get(use_foo.remote()), 1) self.assertEqual(ray.get(use_foo.remote()), 1) self.assertEqual(ray.get(use_bar.remote()), [1]) self.assertEqual(ray.get(use_bar.remote()), [1]) - ray.worker.cleanup() - def testCachingFunctionsToRun(self): # Test that we export functions to run on all workers before the driver is # connected. @@ -718,7 +708,7 @@ class APITest(unittest.TestCase): sys.path.append(4) ray.worker.global_worker.run_function_on_all_workers(f) - ray.init(num_workers=2) + self.init_ray() @ray.remote def get_state(): @@ -738,10 +728,8 @@ class APITest(unittest.TestCase): sys.path.pop() ray.worker.global_worker.run_function_on_all_workers(f) - ray.worker.cleanup() - def testRunningFunctionOnAllWorkers(self): - ray.init(num_workers=1) + self.init_ray() def f(worker_info): sys.path.append("fake_directory") @@ -764,10 +752,8 @@ class APITest(unittest.TestCase): return sys.path self.assertTrue("fake_directory" not in ray.get(get_path2.remote())) - ray.worker.cleanup() - def testLoggingAPI(self): - ray.init(num_workers=1, driver_mode=ray.SILENT_MODE) + self.init_ray({"driver_mode": ray.SILENT_MODE}) def events(): # This is a hack for getting the event log. It is not part of the API. @@ -815,12 +801,10 @@ class APITest(unittest.TestCase): wait_for_num_events(3) self.assertEqual(len(events()), 3) - ray.worker.cleanup() - def testIdenticalFunctionNames(self): # Define a bunch of remote functions and make sure that we don't # accidentally call an older version. - ray.init(num_workers=2) + self.init_ray() num_calls = 200 @@ -878,10 +862,8 @@ class APITest(unittest.TestCase): result_values = ray.get([g.remote() for _ in range(num_calls)]) self.assertEqual(result_values, num_calls * [5]) - ray.worker.cleanup() - def testIllegalAPICalls(self): - ray.init(num_workers=0) + self.init_ray() # Verify that we cannot call put on an ObjectID. x = ray.put(1) @@ -891,7 +873,16 @@ class APITest(unittest.TestCase): with self.assertRaises(Exception): ray.get(3) - ray.worker.cleanup() + +class APITestSharded(APITest): + + def init_ray(self, kwargs=None): + if kwargs is None: + kwargs = {} + kwargs["start_ray_local"] = True + kwargs["num_redis_shards"] = 20 + kwargs["redirect_output"] = True + ray.worker._init(**kwargs) class PythonModeTest(unittest.TestCase): @@ -1619,7 +1610,8 @@ class GlobalStateAPI(unittest.TestCase): task_table = ray.global_state.task_table() self.assertEqual(len(task_table), 1) self.assertEqual(driver_task_id, list(task_table.keys())[0]) - self.assertEqual(task_table[driver_task_id]["State"], "RUNNING") + self.assertEqual(task_table[driver_task_id]["State"], + ray.experimental.state.TASK_STATUS_RUNNING) self.assertEqual(task_table[driver_task_id]["TaskSpec"]["TaskID"], driver_task_id) self.assertEqual(task_table[driver_task_id]["TaskSpec"]["ActorID"], diff --git a/test/stress_tests.py b/test/stress_tests.py index 2cb5ac54f..bb61bf93c 100644 --- a/test/stress_tests.py +++ b/test/stress_tests.py @@ -6,10 +6,6 @@ import unittest import ray import numpy as np import time -import redis - -# Import flatbuffer bindings. -from ray.core.generated.TaskReply import TaskReply class TaskTests(unittest.TestCase): @@ -137,26 +133,38 @@ class ReconstructionTests(unittest.TestCase): num_local_schedulers = 1 def setUp(self): - # Start a Redis instance and Plasma store instances with a total of 1GB - # memory. + # Start the Redis global state store. node_ip_address = "127.0.0.1" - self.redis_port = ray.services.new_port() - print(self.redis_port) - redis_address = ray.services.address(node_ip_address, self.redis_port) + redis_address, redis_shards = ray.services.start_redis(node_ip_address) + self.redis_ip_address = ray.services.get_ip_address(redis_address) + self.redis_port = ray.services.get_port(redis_address) + time.sleep(0.1) + + # Start the Plasma store instances with a total of 1GB memory. self.plasma_store_memory = 10 ** 9 plasma_addresses = [] objstore_memory = (self.plasma_store_memory // self.num_local_schedulers) for i in range(self.num_local_schedulers): + store_stdout_file, store_stderr_file = ray.services.new_log_files( + "plasma_store_{}".format(i), True) + manager_stdout_file, manager_stderr_file = ray.services.new_log_files( + "plasma_manager_{}".format(i), True) plasma_addresses.append(ray.services.start_objstore( - node_ip_address, redis_address, objstore_memory=objstore_memory)) - address_info = {"redis_address": redis_address, - "object_store_addresses": plasma_addresses} + node_ip_address, redis_address, objstore_memory=objstore_memory, + store_stdout_file=store_stdout_file, + store_stderr_file=store_stderr_file, + manager_stdout_file=manager_stdout_file, + manager_stderr_file=manager_stderr_file)) # Start the rest of the services in the Ray cluster. + address_info = {"redis_address": redis_address, + "redis_shards": redis_shards, + "object_store_addresses": plasma_addresses} ray.worker._init(address_info=address_info, start_ray_local=True, num_workers=1, num_local_schedulers=self.num_local_schedulers, num_cpus=[1] * self.num_local_schedulers, + redirect_output=True, driver_mode=ray.SILENT_MODE) def tearDown(self): @@ -164,14 +172,11 @@ class ReconstructionTests(unittest.TestCase): # Determine the IDs of all local schedulers that had a task scheduled or # submitted. - r = redis.StrictRedis(port=self.redis_port) - task_ids = r.keys("TT:*") - task_ids = [task_id[3:] for task_id in task_ids] - local_scheduler_ids = [] - for task_id in task_ids: - message = r.execute_command("ray.task_table_get", task_id) - task_reply_object = TaskReply.GetRootAsTaskReply(message, 0) - local_scheduler_ids.append(task_reply_object.LocalSchedulerId()) + state = ray.experimental.state.GlobalState() + state._initialize_global_state(self.redis_ip_address, self.redis_port) + tasks = state.task_table() + local_scheduler_ids = set(task["LocalSchedulerID"] for task in + tasks.values()) # Make sure that all nodes in the cluster were used by checking that the # set of local scheduler IDs that had a task scheduled or submitted is @@ -179,7 +184,7 @@ class ReconstructionTests(unittest.TestCase): # total number of local schedulers to account for NIL_LOCAL_SCHEDULER_ID. # This is the local scheduler ID associated with the driver task, since it # is not scheduled by a particular local scheduler. - self.assertEqual(len(set(local_scheduler_ids)), + self.assertEqual(len(local_scheduler_ids), self.num_local_schedulers + 1) # Clean up the Ray cluster.