Shard Redis. (#539)

* Implement sharding in the Ray core

* Single node Python modifications to do sharding

* Do the sharding in redis.cc

* Pipe num_redis_shards through start_ray.py and worker.py.

* Use multiple redis shards in multinode tests.

* first steps for sharding ray.global_state

* Fix problem in multinode docker test.

* fix runtest.py

* fix some tests

* fix redis shard startup

* fix redis sharding

* fix

* fix bug introduced by the map-iterator being consumed

* fix sharding bug

* shard event table

* update number of Redis clients to be 64K

* Fix object table tests by flushing shards in between unit tests

* Fix local scheduler tests

* Documentation

* Register shard locations in the primary shard

* Add plasma unit tests back to build

* lint

* lint and fix build

* Fix

* Address Robert's comments

* Refactor start_ray_processes to start Redis shard

* lint

* Fix global scheduler python tests

* Fix redis module test

* Fix plasma test

* Fix component failure test

* Fix local scheduler test

* Fix runtest.py

* Fix global scheduler test for python3

* Fix task_table_test_and_update bug, from actor task table submission race

* Fix jenkins tests.

* Retry Redis shard connections

* Fix test cases

* Convert database clients to DBClient struct

* Fix race condition when subscribing to db client table

* Remove unused lines, add APITest for sharded Ray

* Fix

* Fix memory leak

* Suppress ReconstructionTests output

* Suppress output for APITestSharded

* Reissue task table add/update commands if initial command does not publish to any subscribers.

* fix

* Fix linting.

* fix tests

* fix linting

* fix python test

* fix linting
This commit is contained in:
Stephanie Wang
2017-05-18 17:40:41 -07:00
committed by Philipp Moritz
parent 0a4304725f
commit ee08c8274b
39 changed files with 1336 additions and 651 deletions
+41 -17
View File
@@ -52,7 +52,7 @@ def get_next_message(pubsub_client, timeout_seconds=10):
class TestGlobalStateStore(unittest.TestCase):
def setUp(self):
redis_port, _ = ray.services.start_redis()
redis_port, _ = ray.services.start_redis_instance()
self.redis = redis.StrictRedis(host="localhost", port=redis_port, db=0)
def tearDown(self):
@@ -308,6 +308,10 @@ class TestGlobalStateStore(unittest.TestCase):
TASK_STATUS_SCHEDULED = 2
TASK_STATUS_QUEUED = 4
# make sure somebody will get a notification (checked in the redis module)
p = self.redis.pubsub()
p.psubscribe("{prefix}*:*".format(prefix=TASK_PREFIX))
def check_task_reply(message, task_args, updated=False):
task_status, local_scheduler_id, task_spec = task_args
task_reply_object = TaskReply.GetRootAsTaskReply(message, 0)
@@ -388,33 +392,53 @@ class TestGlobalStateStore(unittest.TestCase):
self.assertNotEqual(get_response, old_response)
check_task_reply(get_response, task_args[1:])
def check_task_subscription(self, p, scheduling_state, local_scheduler_id):
task_args = [b"task_id", scheduling_state,
local_scheduler_id.encode("ascii"), b"task_spec"]
self.redis.execute_command("RAY.TASK_TABLE_ADD", *task_args)
# Receive the data.
message = get_next_message(p)["data"]
# Check that the notification object is correct.
notification_object = TaskReply.GetRootAsTaskReply(message, 0)
self.assertEqual(notification_object.TaskId(), b"task_id")
self.assertEqual(notification_object.State(), scheduling_state)
self.assertEqual(notification_object.LocalSchedulerId(),
local_scheduler_id.encode("ascii"))
self.assertEqual(notification_object.TaskSpec(), b"task_spec")
def testTaskTableSubscribe(self):
scheduling_state = 1
local_scheduler_id = "local_scheduler_id"
# Subscribe to the task table.
p = self.redis.pubsub()
p.psubscribe("{prefix}*:*".format(prefix=TASK_PREFIX))
# Receive acknowledgment.
self.assertEqual(get_next_message(p)["data"], 1)
self.check_task_subscription(p, scheduling_state, local_scheduler_id)
# unsubscribe to make sure there is only one subscriber at a given time
p.punsubscribe("{prefix}*:*".format(prefix=TASK_PREFIX))
# Receive acknowledgment.
self.assertEqual(get_next_message(p)["data"], 0)
p.psubscribe("{prefix}*:{state}".format(
prefix=TASK_PREFIX, state=scheduling_state))
# Receive acknowledgment.
self.assertEqual(get_next_message(p)["data"], 1)
self.check_task_subscription(p, scheduling_state, local_scheduler_id)
p.punsubscribe("{prefix}*:{state}".format(
prefix=TASK_PREFIX, state=scheduling_state))
# Receive acknowledgment.
self.assertEqual(get_next_message(p)["data"], 0)
p.psubscribe("{prefix}{local_scheduler_id}:*".format(
prefix=TASK_PREFIX, local_scheduler_id=local_scheduler_id))
task_args = [b"task_id", scheduling_state,
local_scheduler_id.encode("ascii"), b"task_spec"]
self.redis.execute_command("RAY.TASK_TABLE_ADD", *task_args)
# Receive the acknowledgement message.
# Receive acknowledgment.
self.assertEqual(get_next_message(p)["data"], 1)
self.assertEqual(get_next_message(p)["data"], 2)
self.assertEqual(get_next_message(p)["data"], 3)
# Receive the actual data.
for i in range(3):
message = get_next_message(p)["data"]
# Check that the notification object is correct.
notification_object = TaskReply.GetRootAsTaskReply(message, 0)
self.assertEqual(notification_object.TaskId(), b"task_id")
self.assertEqual(notification_object.State(), scheduling_state)
self.assertEqual(notification_object.LocalSchedulerId(),
local_scheduler_id.encode("ascii"))
self.assertEqual(notification_object.TaskSpec(), b"task_spec")
self.check_task_subscription(p, scheduling_state, local_scheduler_id)
p.punsubscribe("{prefix}{local_scheduler_id}:*".format(
prefix=TASK_PREFIX, local_scheduler_id=local_scheduler_id))
# Receive acknowledgment.
self.assertEqual(get_next_message(p)["data"], 0)
if __name__ == "__main__":
+85 -26
View File
@@ -5,6 +5,7 @@ from __future__ import print_function
import pickle
import redis
import ray
from ray.utils import (decode, binary_to_object_id, binary_to_hex,
hex_to_binary)
@@ -25,14 +26,21 @@ OBJECT_CHANNEL_PREFIX = "OC:"
# This mapping from integer to task state string must be kept up-to-date with
# the scheduling_state enum in task.h.
task_state_mapping = {
1: "WAITING",
2: "SCHEDULED",
4: "QUEUED",
8: "RUNNING",
16: "DONE",
32: "LOST",
64: "RECONSTRUCTING"
TASK_STATUS_WAITING = 1
TASK_STATUS_SCHEDULED = 2
TASK_STATUS_QUEUED = 4
TASK_STATUS_RUNNING = 8
TASK_STATUS_DONE = 16
TASK_STATUS_LOST = 32
TASK_STATUS_RECONSTRUCTING = 64
TASK_STATUS_MAPPING = {
TASK_STATUS_WAITING: "WAITING",
TASK_STATUS_SCHEDULED: "SCHEDULED",
TASK_STATUS_QUEUED: "QUEUED",
TASK_STATUS_RUNNING: "RUNNING",
TASK_STATUS_DONE: "DONE",
TASK_STATUS_LOST: "LOST",
TASK_STATUS_RECONSTRUCTING: "RECONSTRUCTING",
}
@@ -66,8 +74,54 @@ class GlobalState(object):
"""
self.redis_client = redis.StrictRedis(host=redis_ip_address,
port=redis_port)
self.redis_clients = []
num_redis_shards = self.redis_client.get("NumRedisShards")
if num_redis_shards is None:
raise Exception("No entry found for NumRedisShards")
num_redis_shards = int(num_redis_shards)
if (num_redis_shards < 1):
raise Exception("Expected at least one Redis shard, found "
"{}.".format(num_redis_shards))
def _object_table(self, object_id_binary):
ip_address_ports = self.redis_client.lrange("RedisShards", start=0, end=-1)
if len(ip_address_ports) != num_redis_shards:
raise Exception("Expected {} Redis shard addresses, found "
"{}".format(num_redis_shards, len(ip_address_ports)))
for ip_address_port in ip_address_ports:
shard_address, shard_port = ip_address_port.split(b":")
self.redis_clients.append(redis.StrictRedis(host=shard_address,
port=shard_port))
def _execute_command(self, key, *args):
"""Execute a Redis command on the appropriate Redis shard based on key.
Args:
key: The object ID or the task ID that the query is about.
args: The command to run.
Returns:
The value returned by the Redis command.
"""
client = self.redis_clients[key.redis_shard_hash() %
len(self.redis_clients)]
return client.execute_command(*args)
def _keys(self, pattern):
"""Execute the KEYS command on all Redis shards.
Args:
pattern: The KEYS pattern to query.
Returns:
The concatenated list of results from all shards.
"""
result = []
for client in self.redis_clients:
result.extend(client.keys(pattern))
return result
def _object_table(self, object_id):
"""Fetch and parse the object table information for a single object ID.
Args:
@@ -78,16 +132,18 @@ class GlobalState(object):
A dictionary with information about the object ID in question.
"""
# Return information about a single object ID.
object_locations = self.redis_client.execute_command(
"RAY.OBJECT_TABLE_LOOKUP", object_id_binary)
object_locations = self._execute_command(object_id,
"RAY.OBJECT_TABLE_LOOKUP",
object_id.id())
if object_locations is not None:
manager_ids = [binary_to_hex(manager_id)
for manager_id in object_locations]
else:
manager_ids = None
result_table_response = self.redis_client.execute_command(
"RAY.RESULT_TABLE_LOOKUP", object_id_binary)
result_table_response = self._execute_command(object_id,
"RAY.RESULT_TABLE_LOOKUP",
object_id.id())
result_table_message = ResultTableReply.GetRootAsResultTableReply(
result_table_response, 0)
@@ -111,22 +167,21 @@ class GlobalState(object):
self._check_connected()
if object_id is not None:
# Return information about a single object ID.
return self._object_table(object_id.id())
return self._object_table(object_id)
else:
# Return the entire object table.
object_info_keys = self.redis_client.keys(OBJECT_INFO_PREFIX + "*")
object_location_keys = self.redis_client.keys(
OBJECT_LOCATION_PREFIX + "*")
object_info_keys = self._keys(OBJECT_INFO_PREFIX + "*")
object_location_keys = self._keys(OBJECT_LOCATION_PREFIX + "*")
object_ids_binary = set(
[key[len(OBJECT_INFO_PREFIX):] for key in object_info_keys] +
[key[len(OBJECT_LOCATION_PREFIX):] for key in object_location_keys])
results = {}
for object_id_binary in object_ids_binary:
results[binary_to_object_id(object_id_binary)] = self._object_table(
object_id_binary)
binary_to_object_id(object_id_binary))
return results
def _task_table(self, task_id_binary):
def _task_table(self, task_id):
"""Fetch and parse the task table information for a single object task ID.
Args:
@@ -135,12 +190,15 @@ class GlobalState(object):
Returns:
A dictionary with information about the task ID in question.
TASK_STATUS_MAPPING should be used to parse the "State" field into a
human-readable string.
"""
task_table_response = self.redis_client.execute_command(
"RAY.TASK_TABLE_GET", task_id_binary)
task_table_response = self._execute_command(task_id,
"RAY.TASK_TABLE_GET",
task_id.id())
if task_table_response is None:
raise Exception("There is no entry for task ID {} in the task table."
.format(binary_to_hex(task_id_binary)))
.format(binary_to_hex(task_id.id())))
task_table_message = TaskReply.GetRootAsTaskReply(task_table_response, 0)
task_spec = task_table_message.TaskSpec()
task_spec_message = TaskInfo.GetRootAsTaskInfo(task_spec, 0)
@@ -167,7 +225,7 @@ class GlobalState(object):
for i in range(task_spec_message.ReturnsLength())],
"RequiredResources": required_resources}
return {"State": task_state_mapping[task_table_message.State()],
return {"State": task_table_message.State(),
"LocalSchedulerID": binary_to_hex(
task_table_message.LocalSchedulerId()),
"TaskSpec": task_spec_info}
@@ -185,14 +243,15 @@ class GlobalState(object):
"""
self._check_connected()
if task_id is not None:
return self._task_table(hex_to_binary(task_id))
task_id = ray.local_scheduler.ObjectID(hex_to_binary(task_id))
return self._task_table(task_id)
else:
task_table_keys = self.redis_client.keys(TASK_PREFIX + "*")
task_table_keys = self._keys(TASK_PREFIX + "*")
results = {}
for key in task_table_keys:
task_id_binary = key[len(TASK_PREFIX):]
results[binary_to_hex(task_id_binary)] = self._task_table(
task_id_binary)
ray.local_scheduler.ObjectID(task_id_binary))
return results
def function_table(self, function_id=None):
@@ -7,9 +7,9 @@ import subprocess
import time
def start_global_scheduler(redis_address, node_ip_address, use_valgrind=False,
use_profiler=False, stdout_file=None,
stderr_file=None):
def start_global_scheduler(redis_address, node_ip_address,
use_valgrind=False, use_profiler=False,
stdout_file=None, stderr_file=None):
"""Start a global scheduler process.
Args:
+46 -51
View File
@@ -5,7 +5,6 @@ from __future__ import print_function
import numpy as np
import os
import random
import redis
import signal
import sys
import time
@@ -17,6 +16,7 @@ import ray.plasma as plasma
from ray.plasma.utils import create_object
from ray import services
from ray.experimental import state
USE_VALGRIND = False
PLASMA_STORE_MEMORY = 1000000000
@@ -26,13 +26,6 @@ NUM_CLUSTER_NODES = 2
NIL_WORKER_ID = 20 * b"\xff"
NIL_ACTOR_ID = 20 * b"\xff"
# These constants must match the scheduling state enum in task.h.
TASK_STATUS_WAITING = 1
TASK_STATUS_SCHEDULED = 2
TASK_STATUS_QUEUED = 4
TASK_STATUS_RUNNING = 8
TASK_STATUS_DONE = 16
# These constants are an implementation detail of ray_redis_module.cc, so this
# must be kept in sync with that file.
DB_CLIENT_PREFIX = "CL:"
@@ -63,15 +56,17 @@ class TestGlobalScheduler(unittest.TestCase):
def setUp(self):
# Start one Redis server and N pairs of (plasma, local_scheduler)
node_ip_address = "127.0.0.1"
redis_port, self.redis_process = services.start_redis(cleanup=False)
redis_address = services.address(node_ip_address, redis_port)
# Create a Redis client.
self.redis_client = redis.StrictRedis(host=node_ip_address,
port=redis_port)
self.node_ip_address = "127.0.0.1"
redis_address, redis_shards = services.start_redis(self.node_ip_address)
redis_port = services.get_port(redis_address)
time.sleep(0.1)
# Create a client for the global state store.
self.state = state.GlobalState()
self.state._initialize_global_state(self.node_ip_address, redis_port)
# Start one global scheduler.
self.p1 = global_scheduler.start_global_scheduler(
redis_address, node_ip_address, use_valgrind=USE_VALGRIND)
redis_address, self.node_ip_address, use_valgrind=USE_VALGRIND)
self.plasma_store_pids = []
self.plasma_manager_pids = []
self.local_scheduler_pids = []
@@ -89,7 +84,8 @@ class TestGlobalScheduler(unittest.TestCase):
redis_address)
plasma_manager_name, p3, plasma_manager_port = manager_info
self.plasma_manager_pids.append(p3)
plasma_address = "{}:{}".format(node_ip_address, plasma_manager_port)
plasma_address = "{}:{}".format(self.node_ip_address,
plasma_manager_port)
plasma_client = plasma.PlasmaClient(plasma_store_name,
plasma_manager_name)
self.plasma_clients.append(plasma_client)
@@ -116,7 +112,10 @@ class TestGlobalScheduler(unittest.TestCase):
for p4 in self.local_scheduler_pids:
self.assertEqual(p4.poll(), None)
self.assertEqual(self.redis_process.poll(), None)
redis_processes = services.all_processes[
services.PROCESS_TYPE_REDIS_SERVER]
for redis_process in redis_processes:
self.assertEqual(redis_process.poll(), None)
# Kill the global scheduler.
if USE_VALGRIND:
@@ -135,7 +134,9 @@ class TestGlobalScheduler(unittest.TestCase):
p4.kill()
# Kill Redis. In the event that we are using valgrind, this needs to happen
# after we kill the global scheduler.
self.redis_process.kill()
while redis_processes:
redis_process = redis_processes.pop()
redis_process.kill()
def get_plasma_manager_id(self):
"""Get the db_client_id with client_type equal to plasma_manager.
@@ -150,11 +151,10 @@ class TestGlobalScheduler(unittest.TestCase):
"""
db_client_id = None
client_list = self.redis_client.keys("{}*".format(DB_CLIENT_PREFIX))
for client_id in client_list:
response = self.redis_client.hget(client_id, b"client_type")
if response == b"plasma_manager":
db_client_id = client_id
client_list = self.state.client_table()[self.node_ip_address]
for client in client_list:
if client["ClientType"] == "plasma_manager":
db_client_id = client["DBClientID"]
break
return db_client_id
@@ -178,18 +178,16 @@ class TestGlobalScheduler(unittest.TestCase):
# There should be 2n+1 db clients: the global scheduler + one local
# scheduler and one plasma per node.
self.assertEqual(
len(self.redis_client.keys("{}*".format(DB_CLIENT_PREFIX))),
len(self.state.client_table()[self.node_ip_address]),
2 * NUM_CLUSTER_NODES + 1)
db_client_id = self.get_plasma_manager_id()
assert(db_client_id is not None)
assert(db_client_id.startswith(b"CL:"))
db_client_id = db_client_id[len(b"CL:"):] # Remove the CL: prefix.
def test_integration_single_task(self):
# There should be three db clients, the global scheduler, the local
# scheduler, and the plasma manager.
self.assertEqual(
len(self.redis_client.keys("{}*".format(DB_CLIENT_PREFIX))),
len(self.state.client_table()[self.node_ip_address]),
2 * NUM_CLUSTER_NODES + 1)
num_return_vals = [0, 1, 2, 3, 5, 10]
@@ -212,15 +210,15 @@ class TestGlobalScheduler(unittest.TestCase):
# local scheduler
num_retries = 10
while num_retries > 0:
task_entries = self.redis_client.keys("{}*".format(TASK_PREFIX))
task_entries = self.state.task_table()
self.assertLessEqual(len(task_entries), 1)
if len(task_entries) == 1:
task_contents = self.redis_client.hgetall(task_entries[0])
task_status = int(task_contents[b"state"])
self.assertTrue(task_status in [TASK_STATUS_WAITING,
TASK_STATUS_SCHEDULED,
TASK_STATUS_QUEUED])
if task_status == TASK_STATUS_QUEUED:
task_id, task = task_entries.popitem()
task_status = task["State"]
self.assertTrue(task_status in [state.TASK_STATUS_WAITING,
state.TASK_STATUS_SCHEDULED,
state.TASK_STATUS_QUEUED])
if task_status == state.TASK_STATUS_QUEUED:
break
else:
print(task_status)
@@ -228,7 +226,7 @@ class TestGlobalScheduler(unittest.TestCase):
num_retries -= 1
time.sleep(1)
if num_retries <= 0 and task_status != TASK_STATUS_QUEUED:
if num_retries <= 0 and task_status != state.TASK_STATUS_QUEUED:
# Failed to submit and schedule a single task -- bail.
self.tearDown()
sys.exit(1)
@@ -237,7 +235,7 @@ class TestGlobalScheduler(unittest.TestCase):
# There should be three db clients, the global scheduler, the local
# scheduler, and the plasma manager.
self.assertEqual(
len(self.redis_client.keys("{}*".format(DB_CLIENT_PREFIX))),
len(self.state.client_table()[self.node_ip_address]),
2 * NUM_CLUSTER_NODES + 1)
num_return_vals = [0, 1, 2, 3, 5, 10]
@@ -264,34 +262,31 @@ class TestGlobalScheduler(unittest.TestCase):
num_retries = 10
num_tasks_done = 0
while num_retries > 0:
task_entries = self.redis_client.keys("{}*".format(TASK_PREFIX))
task_entries = self.state.task_table()
self.assertLessEqual(len(task_entries), num_tasks)
# First, check if all tasks made it to Redis.
if len(task_entries) == num_tasks:
task_contents = [self.redis_client.hgetall(task_entries[i])
for i in range(len(task_entries))]
task_statuses = [int(contents[b"state"]) for contents in task_contents]
self.assertTrue(all([status in [TASK_STATUS_WAITING,
TASK_STATUS_SCHEDULED,
TASK_STATUS_QUEUED]
task_statuses = [task_entry["State"] for task_entry in
task_entries.values()]
self.assertTrue(all([status in [state.TASK_STATUS_WAITING,
state.TASK_STATUS_SCHEDULED,
state.TASK_STATUS_QUEUED]
for status in task_statuses]))
num_tasks_done = task_statuses.count(TASK_STATUS_QUEUED)
num_tasks_scheduled = task_statuses.count(TASK_STATUS_SCHEDULED)
num_tasks_waiting = task_statuses.count(TASK_STATUS_WAITING)
num_tasks_done = task_statuses.count(state.TASK_STATUS_QUEUED)
num_tasks_scheduled = task_statuses.count(state.TASK_STATUS_SCHEDULED)
num_tasks_waiting = task_statuses.count(state.TASK_STATUS_WAITING)
print("tasks in Redis = {}, tasks waiting = {}, tasks scheduled = {}, "
"tasks queued = {}, retries left = {}"
.format(len(task_entries), num_tasks_waiting,
num_tasks_scheduled, num_tasks_done, num_retries))
if all([status == TASK_STATUS_QUEUED for status in task_statuses]):
if all([status == state.TASK_STATUS_QUEUED for status in
task_statuses]):
# We're done, so pass.
break
num_retries -= 1
time.sleep(0.1)
if num_tasks_done != num_tasks:
# At least one of the tasks failed to schedule.
self.tearDown()
sys.exit(2)
self.assertEqual(num_tasks_done, num_tasks)
def test_integration_many_tasks_handler_sync(self):
self.integration_many_tasks_helper(timesync=True)
+38 -40
View File
@@ -12,11 +12,13 @@ import time
import ray
from ray.services import get_ip_address
from ray.services import get_port
from ray.utils import binary_to_object_id
from ray.utils import binary_to_hex
from ray.utils import hex_to_binary
# Import flatbuffer bindings.
from ray.core.generated.SubscribeToDBClientTableReply \
import SubscribeToDBClientTableReply
from ray.core.generated.TaskReply import TaskReply
from ray.core.generated.DriverTableMessage import DriverTableMessage
# These variables must be kept in sync with the C codebase.
@@ -31,7 +33,6 @@ TASK_STATUS_LOST = 32
PLASMA_MANAGER_HEARTBEAT_CHANNEL = b"plasma_managers"
DRIVER_DEATH_CHANNEL = b"driver_deaths"
# common/redis_module/ray_redis_module.cc
TASK_PREFIX = "TT:"
OBJECT_PREFIX = "OL:"
DB_CLIENT_PREFIX = "CL:"
DB_CLIENT_TABLE_NAME = b"db_clients"
@@ -43,7 +44,7 @@ PLASMA_MANAGER_CLIENT_TYPE = b"plasma_manager"
# Set up logging.
logging.basicConfig()
log = logging.getLogger()
log.setLevel(logging.WARN)
log.setLevel(logging.INFO)
class Monitor(object):
@@ -70,7 +71,11 @@ class Monitor(object):
"""
def __init__(self, redis_address, redis_port):
# Initialize the Redis clients.
self.state = ray.experimental.state.GlobalState()
self.state._initialize_global_state(redis_address, redis_port)
self.redis = redis.StrictRedis(host=redis_address, port=redis_port, db=0)
# TODO(swang): Update pubsub client to use ray.experimental.state once
# subscriptions are implemented there.
self.subscribe_client = self.redis.pubsub()
self.subscribed = {}
# Initialize data structures to keep track of the active database clients.
@@ -97,23 +102,17 @@ class Monitor(object):
TASK_STATUS_LOST. A local scheduler is deemed dead if it is in
self.dead_local_schedulers.
"""
task_ids = self.redis.scan_iter(
match="{prefix}*".format(prefix=TASK_PREFIX))
tasks = self.state.task_table()
num_tasks_updated = 0
for task_id in task_ids:
task_id = task_id[len(TASK_PREFIX):]
response = self.redis.execute_command("RAY.TASK_TABLE_GET", task_id)
# Parse the serialized task object.
task_object = TaskReply.GetRootAsTaskReply(response, 0)
local_scheduler_id = task_object.LocalSchedulerId()
for task_id, task in tasks.items():
# See if the corresponding local scheduler is alive.
if local_scheduler_id in self.dead_local_schedulers:
if task["LocalSchedulerID"] in self.dead_local_schedulers:
# If the task is scheduled on a dead local scheduler, mark the task as
# lost.
ok = self.redis.execute_command("RAY.TASK_TABLE_UPDATE",
task_id,
TASK_STATUS_LOST,
NIL_ID)
key = binary_to_object_id(hex_to_binary(task_id))
ok = self.state._execute_command(
key, "RAY.TASK_TABLE_UPDATE", hex_to_binary(task_id),
ray.experimental.state.TASK_STATUS_LOST, NIL_ID)
if ok != b"OK":
log.warn("Failed to update lost task for dead scheduler.")
num_tasks_updated += 1
@@ -129,19 +128,20 @@ class Monitor(object):
"""
# TODO(swang): Also kill the associated plasma store, since it's no longer
# reachable without a plasma manager.
object_ids = self.redis.scan_iter(
match="{prefix}*".format(prefix=OBJECT_PREFIX))
objects = self.state.object_table()
num_objects_removed = 0
for object_id in object_ids:
object_id = object_id[len(OBJECT_PREFIX):]
managers = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
object_id)
for manager in managers:
for object_id, obj in objects.items():
manager_ids = obj["ManagerIDs"]
if manager_ids is None:
continue
for manager in manager_ids:
if manager in self.dead_plasma_managers:
# If the object was on a dead plasma manager, remove that location
# entry.
ok = self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", object_id,
manager)
ok = self.state._execute_command(object_id,
"RAY.OBJECT_TABLE_REMOVE",
object_id.id(),
hex_to_binary(manager))
if ok != b"OK":
log.warn("Failed to remove object location for dead plasma "
"manager.")
@@ -157,18 +157,16 @@ class Monitor(object):
not miss any notifications for deleted clients that occurred before we
subscribed.
"""
db_client_keys = self.redis.keys(
"{prefix}*".format(prefix=DB_CLIENT_PREFIX))
for db_client_key in db_client_keys:
db_client_id = db_client_key[len(DB_CLIENT_PREFIX):]
client_type, deleted = self.redis.hmget(db_client_key,
[b"client_type", b"deleted"])
deleted = bool(int(deleted))
if deleted:
if client_type == LOCAL_SCHEDULER_CLIENT_TYPE:
self.dead_local_schedulers.add(db_client_id)
elif client_type == PLASMA_MANAGER_CLIENT_TYPE:
self.dead_plasma_managers.add(db_client_id)
clients = self.state.client_table()
for node_ip_address, node_clients in clients.items():
for client in node_clients:
db_client_id = client["DBClientID"]
client_type = client["ClientType"]
if client["Deleted"]:
if client_type == LOCAL_SCHEDULER_CLIENT_TYPE:
self.dead_local_schedulers.add(db_client_id)
elif client_type == PLASMA_MANAGER_CLIENT_TYPE:
self.dead_plasma_managers.add(db_client_id)
def subscribe_handler(self, channel, data):
"""Handle a subscription success message from Redis.
@@ -186,7 +184,7 @@ class Monitor(object):
"""
notification_object = (SubscribeToDBClientTableReply
.GetRootAsSubscribeToDBClientTableReply(data, 0))
db_client_id = notification_object.DbClientId()
db_client_id = binary_to_hex(notification_object.DbClientId())
client_type = notification_object.ClientType()
is_insertion = notification_object.IsInsertion()
@@ -196,7 +194,7 @@ class Monitor(object):
# If the update was a deletion, add them to our accounting for dead
# local schedulers and plasma managers.
log.warn("Removed {}".format(client_type))
log.warn("Removed {}, client ID {}".format(client_type, db_client_id))
if client_type == LOCAL_SCHEDULER_CLIENT_TYPE:
if db_client_id not in self.dead_local_schedulers:
self.dead_local_schedulers.add(db_client_id)
@@ -256,7 +254,7 @@ class Monitor(object):
result = pipe.hget(local_scheduler_id, "gpus_in_use")
gpus_in_use = dict() if result is None else json.loads(result)
driver_id_hex = ray.utils.binary_to_hex(driver_id)
driver_id_hex = binary_to_hex(driver_id)
if driver_id_hex in gpus_in_use:
num_gpus_returned = gpus_in_use.pop(driver_id_hex)
+2 -1
View File
@@ -405,7 +405,8 @@ def start_plasma_manager(store_name, redis_address,
"-m", plasma_manager_name,
"-h", node_ip_address,
"-p", str(plasma_manager_port),
"-r", redis_address]
"-r", redis_address,
]
if use_valgrind:
process = subprocess.Popen(["valgrind",
"--track-origins=yes",
+2 -3
View File
@@ -480,7 +480,7 @@ class TestPlasmaManager(unittest.TestCase):
store_name1, self.p2 = plasma.start_plasma_store(use_valgrind=USE_VALGRIND)
store_name2, self.p3 = plasma.start_plasma_store(use_valgrind=USE_VALGRIND)
# Start a Redis server.
redis_address = services.address("127.0.0.1", services.start_redis()[0])
redis_address, _ = services.start_redis("127.0.0.1")
# Start two PlasmaManagers.
manager_name1, self.p4, self.port1 = plasma.start_plasma_manager(
store_name1, redis_address, use_valgrind=USE_VALGRIND)
@@ -778,8 +778,7 @@ class TestPlasmaManagerRecovery(unittest.TestCase):
self.store_name, self.p2 = plasma.start_plasma_store(
use_valgrind=USE_VALGRIND)
# Start a Redis server.
self.redis_address = services.address("127.0.0.1",
services.start_redis()[0])
self.redis_address, _ = services.start_redis("127.0.0.1")
# Start a PlasmaManagers.
manager_name, self.p3, self.port1 = plasma.start_plasma_manager(
self.store_name,
+110 -45
View File
@@ -240,15 +240,75 @@ def wait_for_redis_to_start(redis_ip_address, redis_port, num_retries=5):
"configured properly.")
def start_redis(node_ip_address="127.0.0.1", port=None, num_retries=20,
stdout_file=None, stderr_file=None, cleanup=True):
"""Start a Redis server.
def start_redis(node_ip_address,
port=None,
num_redis_shards=1,
redirect_output=False,
cleanup=True):
"""Start the Redis global state store.
Args:
node_ip_address: The IP address of the current node. This is only used for
recording the log filenames in Redis.
port (int): If provided, the primary Redis shard will be started on this
port.
num_redis_shards (int): If provided, the number of Redis shards to start,
in addition to the primary one. The default value is one shard.
cleanup (bool): True if using Ray in local mode. If cleanup is true, then
all Redis processes started by this method will be killed by
serices.cleanup() when the Python process that imported services exits.
Returns:
A tuple of the address for the primary Redis shard and a list of addresses
for the remaining shards.
"""
redis_stdout_file, redis_stderr_file = new_log_files(
"redis", redirect_output)
assigned_port, _ = start_redis_instance(
node_ip_address=node_ip_address, port=port,
stdout_file=redis_stdout_file, stderr_file=redis_stderr_file,
cleanup=cleanup)
if port is not None:
assert assigned_port == port
port = assigned_port
redis_address = address(node_ip_address, port)
# Register the number of Redis shards in the primary shard, so that clients
# know how many redis shards to expect under RedisShards.
redis_client = redis.StrictRedis(host=node_ip_address, port=port)
redis_client.set("NumRedisShards", str(num_redis_shards))
# Start other Redis shards listening on random ports. Each Redis shard logs
# to a separate file, prefixed by "redis-<shard number>".
redis_shards = []
for i in range(num_redis_shards):
redis_stdout_file, redis_stderr_file = new_log_files(
"redis-{}".format(i), redirect_output)
redis_shard_port, _ = start_redis_instance(
node_ip_address=node_ip_address, stdout_file=redis_stdout_file,
stderr_file=redis_stderr_file, cleanup=cleanup)
shard_address = address(node_ip_address, redis_shard_port)
redis_shards.append(shard_address)
# Store redis shard information in the primary redis shard.
redis_client.rpush("RedisShards", shard_address)
return redis_address, redis_shards
def start_redis_instance(node_ip_address="127.0.0.1",
port=None,
num_retries=20,
stdout_file=None,
stderr_file=None,
cleanup=True):
"""Start a single Redis server.
Args:
node_ip_address (str): The IP address of the current node. This is only
used for recording the log filenames in Redis.
port (int): If provided, start a Redis server with this port.
num_retries (int): The number of times to attempt to start Redis.
num_retries (int): The number of times to attempt to start Redis. If a port
is provided, this defaults to 1.
stdout_file: A file handle opened for writing to redirect stdout to. If no
redirection should happen, then this should be None.
stderr_file: A file handle opened for writing to redirect stderr to. If no
@@ -275,8 +335,8 @@ def start_redis(node_ip_address="127.0.0.1", port=None, num_retries=20,
assert os.path.isfile(redis_module)
counter = 0
if port is not None:
if num_retries != 1:
raise Exception("num_retries must be 1 if port is specified.")
# If a port is specified, then try only once to connect.
num_retries = 1
else:
port = new_port()
while counter < num_retries:
@@ -356,8 +416,8 @@ def start_log_monitor(redis_address, node_ip_address, stdout_file=None,
[stdout_file, stderr_file])
def start_global_scheduler(redis_address, node_ip_address, stdout_file=None,
stderr_file=None, cleanup=True):
def start_global_scheduler(redis_address, node_ip_address,
stdout_file=None, stderr_file=None, cleanup=True):
"""Start a global scheduler process.
Args:
@@ -372,7 +432,8 @@ def start_global_scheduler(redis_address, node_ip_address, stdout_file=None,
this process will be killed by services.cleanup() when the Python process
that imported services exits.
"""
p = global_scheduler.start_global_scheduler(redis_address, node_ip_address,
p = global_scheduler.start_global_scheduler(redis_address,
node_ip_address,
stdout_file=stdout_file,
stderr_file=stderr_file)
if cleanup:
@@ -545,10 +606,11 @@ def start_local_scheduler(redis_address,
return local_scheduler_name
def start_objstore(node_ip_address, redis_address, object_manager_port=None,
store_stdout_file=None, store_stderr_file=None,
manager_stdout_file=None, manager_stderr_file=None,
cleanup=True, objstore_memory=None):
def start_objstore(node_ip_address, redis_address,
object_manager_port=None, store_stdout_file=None,
store_stderr_file=None, manager_stdout_file=None,
manager_stderr_file=None, cleanup=True,
objstore_memory=None):
"""This method starts an object store process.
Args:
@@ -704,13 +766,14 @@ def start_monitor(redis_address, node_ip_address, stdout_file=None,
def start_ray_processes(address_info=None,
node_ip_address="127.0.0.1",
redis_port=None,
num_workers=None,
num_local_schedulers=1,
num_redis_shards=1,
worker_path=None,
cleanup=True,
redirect_output=False,
include_global_scheduler=False,
include_redis=False,
include_log_monitor=False,
include_webui=False,
start_workers_from_local_scheduler=True,
@@ -723,12 +786,17 @@ def start_ray_processes(address_info=None,
that have already been started. If provided, address_info will be
modified to include processes that are newly started.
node_ip_address (str): The IP address of this node.
redis_port (int): The port that the primary Redis shard should listen to.
If None, then a random port will be chosen. If the key "redis_address" is
in address_info, then this argument will be ignored.
num_workers (int): The number of workers to start.
num_local_schedulers (int): The total number of local schedulers required.
This is also the total number of object stores required. This method will
start new instances of local schedulers and object stores until there are
num_local_schedulers existing instances of each, including ones already
registered with the given address_info.
num_redis_shards: The number of Redis shards to start in addition to the
primary Redis shard.
worker_path (str): The path of the source code that will be run by the
worker.
cleanup (bool): If cleanup is true, then the processes started here will be
@@ -738,8 +806,6 @@ def start_ray_processes(address_info=None,
file.
include_global_scheduler (bool): If include_global_scheduler is True, then
start a global scheduler process.
include_redis (bool): If include_redis is True, then start a Redis server
process.
include_log_monitor (bool): If True, then start a log monitor to monitor
the log files for all processes on this node and push their contents to
Redis.
@@ -785,29 +851,14 @@ def start_ray_processes(address_info=None,
# warning messages when it starts up. Instead of suppressing the output, we
# should address the warnings.
redis_address = address_info.get("redis_address")
if include_redis:
redis_stdout_file, redis_stderr_file = new_log_files("redis",
redirect_output)
if redis_address is None:
# Start a Redis server. The start_redis method will choose a random port.
redis_port, _ = start_redis(node_ip_address,
stdout_file=redis_stdout_file,
stderr_file=redis_stderr_file,
cleanup=cleanup)
redis_address = address(node_ip_address, redis_port)
address_info["redis_address"] = redis_address
time.sleep(0.1)
else:
# A Redis address was provided, so start a Redis server with the given
# port. TODO(rkn): We should check that the IP address corresponds to the
# machine that this method is running on.
redis_port = get_port(redis_address)
new_redis_port, _ = start_redis(port=int(redis_port),
num_retries=1,
stdout_file=redis_stdout_file,
stderr_file=redis_stderr_file,
cleanup=cleanup)
assert redis_port == new_redis_port
redis_shards = address_info.get("redis_shards", [])
if redis_address is None:
redis_address, redis_shards = start_redis(
node_ip_address, port=redis_port, num_redis_shards=num_redis_shards,
redirect_output=redirect_output, cleanup=cleanup)
address_info["redis_address"] = redis_address
time.sleep(0.1)
# Start monitoring the processes.
monitor_stdout_file, monitor_stderr_file = new_log_files("monitor",
redirect_output)
@@ -815,9 +866,14 @@ def start_ray_processes(address_info=None,
node_ip_address,
stdout_file=monitor_stdout_file,
stderr_file=monitor_stderr_file)
else:
if redis_address is None:
raise Exception("Redis address expected")
if redis_shards == []:
# Get redis shards from primary redis instance.
redis_ip_address, redis_port = redis_address.split(":")
redis_client = redis.StrictRedis(host=redis_ip_address, port=redis_port)
redis_shards = redis_client.lrange("RedisShards", start=0, end=-1)
redis_shards = [shard.decode("ascii") for shard in redis_shards]
address_info["redis_shards"] = redis_shards
# Start the log monitor, if necessary.
if include_log_monitor:
@@ -1005,6 +1061,7 @@ def start_ray_node(node_ip_address,
def start_ray_head(address_info=None,
node_ip_address="127.0.0.1",
redis_port=None,
num_workers=0,
num_local_schedulers=1,
worker_path=None,
@@ -1012,7 +1069,8 @@ def start_ray_head(address_info=None,
redirect_output=False,
start_workers_from_local_scheduler=True,
num_cpus=None,
num_gpus=None):
num_gpus=None,
num_redis_shards=None):
"""Start Ray in local mode.
Args:
@@ -1020,6 +1078,9 @@ def start_ray_head(address_info=None,
that have already been started. If provided, address_info will be
modified to include processes that are newly started.
node_ip_address (str): The IP address of this node.
redis_port (int): The port that the primary Redis shard should listen to.
If None, then a random port will be chosen. If the key "redis_address" is
in address_info, then this argument will be ignored.
num_workers (int): The number of workers to start.
num_local_schedulers (int): The total number of local schedulers required.
This is also the total number of object stores required. This method will
@@ -1038,14 +1099,18 @@ def start_ray_head(address_info=None,
Python.
num_cpus (int): number of cpus to configure the local scheduler with.
num_gpus (int): number of gpus to configure the local scheduler with.
num_redis_shards: The number of Redis shards to start in addition to the
primary Redis shard.
Returns:
A dictionary of the address information for the processes that were
started.
"""
num_redis_shards = 1 if num_redis_shards is None else num_redis_shards
return start_ray_processes(
address_info=address_info,
node_ip_address=node_ip_address,
redis_port=redis_port,
num_workers=num_workers,
num_local_schedulers=num_local_schedulers,
worker_path=worker_path,
@@ -1053,11 +1118,11 @@ def start_ray_head(address_info=None,
redirect_output=redirect_output,
include_global_scheduler=True,
include_log_monitor=True,
include_redis=True,
include_webui=False,
start_workers_from_local_scheduler=start_workers_from_local_scheduler,
num_cpus=num_cpus,
num_gpus=num_gpus)
num_gpus=num_gpus,
num_redis_shards=num_redis_shards)
def new_log_files(name, redirect_output):
+28 -15
View File
@@ -992,7 +992,8 @@ def _init(address_info=None,
redirect_output=False,
start_workers_from_local_scheduler=True,
num_cpus=None,
num_gpus=None):
num_gpus=None,
num_redis_shards=None):
"""Helper method to connect to an existing Ray cluster or start a new one.
This method handles two cases. Either a Ray cluster already exists and we
@@ -1002,8 +1003,9 @@ def _init(address_info=None,
Args:
address_info (dict): A dictionary with address information for processes in
a partially-started Ray cluster. If start_ray_local=True, any processes
not in this dictionary will be started. If provided, address_info will be
modified to include processes that are newly started.
not in this dictionary will be started. If provided, an updated
address_info dictionary will be returned to include processes that are
newly started.
start_ray_local (bool): If True then this will start any processes not
already in address_info, including Redis, a global scheduler, local
scheduler(s), object store(s), and worker(s). It will also kill these
@@ -1028,6 +1030,8 @@ def _init(address_info=None,
be configured with.
num_gpus: A list containing the number of GPUs the local schedulers should
be configured with.
num_redis_shards: The number of Redis shards to start in addition to the
primary Redis shard.
Returns:
Address information about the started processes.
@@ -1069,6 +1073,8 @@ def _init(address_info=None,
num_local_schedulers = len(local_schedulers)
else:
num_local_schedulers = 1
# Use 1 additional redis shard if num_redis_shards is not provided.
num_redis_shards = 1 if num_redis_shards is None else num_redis_shards
# Start the scheduler, object store, and some workers. These will be killed
# by the call to cleanup(), which happens when the Python script exits.
address_info = services.start_ray_head(
@@ -1079,20 +1085,24 @@ def _init(address_info=None,
redirect_output=redirect_output,
start_workers_from_local_scheduler=start_workers_from_local_scheduler,
num_cpus=num_cpus,
num_gpus=num_gpus)
num_gpus=num_gpus,
num_redis_shards=num_redis_shards)
else:
if redis_address is None:
raise Exception("If start_ray_local=False, then redis_address must be "
"provided.")
raise Exception("When connecting to an existing cluster, redis_address "
"must be provided.")
if num_workers is not None:
raise Exception("If start_ray_local=False, then num_workers must not be "
"provided.")
raise Exception("When connecting to an existing cluster, num_workers "
"must not be provided.")
if num_local_schedulers is not None:
raise Exception("If start_ray_local=False, then num_local_schedulers "
"must not be provided.")
raise Exception("When connecting to an existing cluster, "
"num_local_schedulers must not be provided.")
if num_cpus is not None or num_gpus is not None:
raise Exception("If start_ray_local=False, then num_cpus and num_gpus "
"must not be provided.")
raise Exception("When connecting to an existing cluster, num_cpus and "
"num_gpus must not be provided.")
if num_redis_shards is not None:
raise Exception("When connecting to an existing cluster, "
"num_redis_shards must not be provided.")
# Get the node IP address if one is not provided.
if node_ip_address is None:
node_ip_address = services.get_node_ip_address(redis_address)
@@ -1121,7 +1131,7 @@ def _init(address_info=None,
def init(redis_address=None, node_ip_address=None, object_id_seed=None,
num_workers=None, driver_mode=SCRIPT_MODE, redirect_output=False,
num_cpus=None, num_gpus=None):
num_cpus=None, num_gpus=None, num_redis_shards=None):
"""Either connect to an existing Ray cluster or start one and connect to it.
This method handles two cases. Either a Ray cluster already exists and we
@@ -1148,6 +1158,8 @@ def init(redis_address=None, node_ip_address=None, object_id_seed=None,
configured with.
num_gpus (int): Number of gpus the user wishes all local schedulers to be
configured with.
num_redis_shards: The number of Redis shards to start in addition to the
primary Redis shard.
Returns:
Address information about the started processes.
@@ -1161,7 +1173,7 @@ def init(redis_address=None, node_ip_address=None, object_id_seed=None,
return _init(address_info=info, start_ray_local=(redis_address is None),
num_workers=num_workers, driver_mode=driver_mode,
redirect_output=redirect_output, num_cpus=num_cpus,
num_gpus=num_gpus)
num_gpus=num_gpus, num_redis_shards=num_redis_shards)
def cleanup(worker=global_worker):
@@ -1577,7 +1589,8 @@ def connect(info, object_id_seed=None, mode=WORKER_MODE, worker=global_worker,
ray.local_scheduler.ObjectID(NIL_ACTOR_ID),
worker.actor_counters[actor_id],
[0, 0])
worker.redis_client.execute_command(
global_state._execute_command(
driver_task.task_id(),
"RAY.TASK_TABLE_ADD",
driver_task.task_id().id(),
TASK_STATUS_RUNNING,
+16 -11
View File
@@ -15,6 +15,9 @@ parser.add_argument("--redis-address", required=False, type=str,
help="the address to use for connecting to Redis")
parser.add_argument("--redis-port", required=False, type=str,
help="the port to use for starting Redis")
parser.add_argument("--num-redis-shards", required=False, type=int,
help=("the number of additional Redis shards to use in "
"addition to the primary Redis shard"))
parser.add_argument("--object-manager-port", required=False, type=int,
help="the port to use for starting the object manager")
parser.add_argument("--num-workers", required=False, type=int,
@@ -75,23 +78,22 @@ if __name__ == "__main__":
print("Using IP address {} for this node.".format(node_ip_address))
address_info = {}
# Use the provided Redis port if there is one.
if args.redis_port is not None:
address_info["redis_address"] = "{}:{}".format(node_ip_address,
args.redis_port)
# Use the provided object manager port if there is one.
if args.object_manager_port is not None:
address_info["object_manager_ports"] = [args.object_manager_port]
if address_info == {}:
address_info = None
address_info = services.start_ray_head(address_info=address_info,
node_ip_address=node_ip_address,
num_workers=args.num_workers,
cleanup=False,
redirect_output=True,
num_cpus=args.num_cpus,
num_gpus=args.num_gpus)
address_info = services.start_ray_head(
address_info=address_info,
node_ip_address=node_ip_address,
redis_port=args.redis_port,
num_workers=args.num_workers,
cleanup=False,
redirect_output=True,
num_cpus=args.num_cpus,
num_gpus=args.num_gpus,
num_redis_shards=args.num_redis_shards)
print(address_info)
print("\nStarted Ray on this node. You can add additional nodes to the "
"cluster by calling\n\n"
@@ -113,6 +115,9 @@ if __name__ == "__main__":
if args.redis_address is None:
raise Exception("If --head is not passed in, --redis-address must be "
"provided.")
if args.num_redis_shards is not None:
raise Exception("If --head is not passed in, --num-redis-shards must "
"not be provided.")
redis_ip_address, redis_port = args.redis_address.split(":")
# Wait for the Redis server to be started. And throw an exception if we
# can't connect to it.
+2
View File
@@ -5,6 +5,8 @@
#include "common.h"
#define DB_CLIENT_PREFIX "CL:"
/**
* Convert an object ID to a flatbuffer string.
*
+10
View File
@@ -172,6 +172,14 @@ static PyObject *PyObjectID_richcompare(PyObjectID *self,
return result;
}
static PyObject *PyObjectID_redis_shard_hash(PyObjectID *self) {
/* NOTE: The hash function used here must match the one in get_redis_context
* in src/common/state/redis.cc. Changes to the hash function should only be
* made through UniqueIDHasher in src/common/common.h */
UniqueIDHasher hash;
return PyLong_FromSize_t(hash(self->object_id));
}
static long PyObjectID_hash(PyObjectID *self) {
PyObject *tuple = PyTuple_New(UNIQUE_ID_SIZE);
for (int i = 0; i < UNIQUE_ID_SIZE; ++i) {
@@ -201,6 +209,8 @@ static PyObject *PyObjectID___reduce__(PyObjectID *self) {
static PyMethodDef PyObjectID_methods[] = {
{"id", (PyCFunction) PyObjectID_id, METH_NOARGS,
"Return the hash associated with this ObjectID"},
{"redis_shard_hash", (PyCFunction) PyObjectID_redis_shard_hash, METH_NOARGS,
"Return the redis shard that this ObjectID is associated with"},
{"hex", (PyCFunction) PyObjectID_hex, METH_NOARGS,
"Return the object ID as a string in hex."},
{"__reduce__", (PyCFunction) PyObjectID___reduce__, METH_NOARGS,
+8 -5
View File
@@ -67,11 +67,14 @@ void RayLogger_log(RayLogger *logger,
if (logger->is_direct) {
DBHandle *db = (DBHandle *) logger->conn;
/* Fill in the client ID and send the message to Redis. */
int status = redisAsyncCommand(
db->context, NULL, NULL, utstring_body(formatted_message),
(char *) db->client.id, sizeof(db->client.id));
if ((status == REDIS_ERR) || db->context->err) {
LOG_REDIS_DEBUG(db->context, "error while logging message to log table");
redisAsyncContext *context = get_redis_context(db, db->client);
int status =
redisAsyncCommand(context, NULL, NULL, utstring_body(formatted_message),
(char *) db->client.id, sizeof(db->client.id));
if ((status == REDIS_ERR) || context->err) {
LOG_REDIS_DEBUG(context, "error while logging message to log table");
}
} else {
/* If we don't own a Redis connection, we leave our client
+11
View File
@@ -1,5 +1,9 @@
#include "net.h"
#include <arpa/inet.h>
#include <sstream>
#include "common.h"
int parse_ip_addr_port(const char *ip_addr_port, char *ip_addr, int *port) {
@@ -11,3 +15,10 @@ int parse_ip_addr_port(const char *ip_addr_port, char *ip_addr, int *port) {
*port = atoi(port_str);
return 0;
}
/* Return true if the ip address is valid. */
bool valid_ip_address(const std::string &ip_address) {
struct sockaddr_in sa;
int result = inet_pton(AF_INET, ip_address.c_str(), &sa.sin_addr);
return result == 1;
}
+16 -1
View File
@@ -27,7 +27,6 @@
* TODO(pcm): Fill this out.
*/
#define DB_CLIENT_PREFIX "CL:"
#define OBJECT_INFO_PREFIX "OI:"
#define OBJECT_LOCATION_PREFIX "OL:"
#define OBJECT_NOTIFICATION_PREFIX "ON:"
@@ -929,6 +928,10 @@ int TaskTableWrite(RedisModuleCtx *ctx,
RedisModuleCallReply *reply =
RedisModule_Call(ctx, "PUBLISH", "ss", publish_topic, publish_message);
/* See how many clients received this publish. */
long long num_clients = RedisModule_CallReplyInteger(reply);
CHECKM(num_clients <= 1, "Published to %lld clients.", num_clients);
RedisModule_FreeString(ctx, publish_message);
RedisModule_FreeString(ctx, publish_topic);
if (existing_task_spec != NULL) {
@@ -938,6 +941,18 @@ int TaskTableWrite(RedisModuleCtx *ctx,
if (reply == NULL) {
return RedisModule_ReplyWithError(ctx, "PUBLISH unsuccessful");
}
if (num_clients == 0) {
LOG_WARN(
"No subscribers received this publish. This most likely means that "
"either the intended recipient has not subscribed yet or that the "
"pubsub connection to the intended recipient has been broken.");
/* This reply will be received by redis_task_table_update_callback or
* redis_task_table_add_task_callback in redis.cc, which will then reissue
* the command. */
return RedisModule_ReplyWithError(ctx,
"No subscribers received message.");
}
}
RedisModule_ReplyWithSimpleString(ctx, "OK");
+5 -2
View File
@@ -11,6 +11,9 @@ typedef struct DBHandle DBHandle;
*
* @param db_address The hostname to use to connect to the database.
* @param db_port The port to use to connect to the database.
* @param db_shards_addresses The list of database shard IP addresses.
* @param db_shards_ports The list of database shard ports, in the same order
* as db_shards_addresses.
* @param client_type The type of this client.
* @param node_ip_address The hostname of the client that is connecting.
* @param num_args The number of extra arguments that should be supplied. This
@@ -21,8 +24,8 @@ typedef struct DBHandle DBHandle;
* @return This returns a handle to the database, which must be freed with
* db_disconnect after use.
*/
DBHandle *db_connect(const char *db_address,
int db_port,
DBHandle *db_connect(const std::string &db_primary_address,
int db_primary_port,
const char *client_type,
const char *node_ip_address,
int num_args,
+15 -4
View File
@@ -29,11 +29,22 @@ void db_client_table_remove(DBHandle *db_handle,
* ==== Subscribing to the db client table ====
*/
/* An entry in the db client table. */
typedef struct {
/** The database client ID. */
DBClientID id;
/** The database client type. */
const char *client_type;
/** An optional auxiliary address for an associated database client on the
* same node. */
const char *aux_address;
/** Whether or not the database client exists. If this is false for an entry,
* then it will never again be true. */
bool is_insertion;
} DBClient;
/* Callback for subscribing to the db client table. */
typedef void (*db_client_table_subscribe_callback)(DBClientID db_client_id,
const char *client_type,
const char *aux_address,
bool is_insertion,
typedef void (*db_client_table_subscribe_callback)(DBClient *db_client,
void *user_context);
/**
+475 -162
View File
@@ -4,6 +4,7 @@
#include <stdbool.h>
#include <stdlib.h>
#include <unistd.h>
#include <vector>
extern "C" {
/* Including hiredis here is necessary on Windows for typedefs used in ae.h. */
@@ -26,6 +27,7 @@ extern "C" {
#include "event_loop.h"
#include "redis.h"
#include "io.h"
#include "net.h"
#include "format/common_generated.h"
@@ -77,52 +79,128 @@ extern int usleep(useconds_t usec);
do { \
} while (0)
DBHandle *db_connect(const char *db_address,
int db_port,
const char *client_type,
const char *node_ip_address,
int num_args,
const char **args) {
/* Check that the number of args is even. These args will be passed to the
* RAY.CONNECT Redis command, which takes arguments in pairs. */
if (num_args % 2 != 0) {
LOG_FATAL("The number of extra args must be divisible by two.");
}
redisAsyncContext *get_redis_context(DBHandle *db, UniqueID id) {
/* NOTE: The hash function used here must match the one in
* PyObjectID_redis_shard_hash in src/common/lib/python/common_extension.cc.
* Changes to the hash function should only be made through
* UniqueIDHasher in src/common/common.h */
UniqueIDHasher index;
return db->contexts[index(id) % db->contexts.size()];
}
DBHandle *db = (DBHandle *) malloc(sizeof(DBHandle));
/* Sync connection for initial handshake */
redisAsyncContext *get_redis_subscribe_context(DBHandle *db, UniqueID id) {
UniqueIDHasher index;
return db->subscribe_contexts[index(id) % db->subscribe_contexts.size()];
}
void get_redis_shards(redisContext *context,
std::vector<std::string> &db_shards_addresses,
std::vector<int> &db_shards_ports) {
/* Get the total number of Redis shards in the system. */
int num_attempts = 0;
redisReply *reply = NULL;
while (num_attempts < REDIS_DB_CONNECT_RETRIES) {
/* Try to read the number of Redis shards from the primary shard. If the
* entry is present, exit. */
reply = (redisReply *) redisCommand(context, "GET NumRedisShards");
if (reply->type != REDIS_REPLY_NIL) {
break;
}
/* Sleep for a little, and try again if the entry isn't there yet. */
freeReplyObject(reply);
usleep(REDIS_DB_CONNECT_WAIT_MS * 1000);
num_attempts++;
continue;
}
CHECKM(num_attempts < REDIS_DB_CONNECT_RETRIES,
"No entry found for NumRedisShards");
CHECKM(reply->type == REDIS_REPLY_STRING,
"Expected string, found Redis type %d for NumRedisShards",
reply->type);
int num_redis_shards = atoi(reply->str);
CHECKM(num_redis_shards >= 1, "Expected at least one Redis shard, found %d.",
num_redis_shards);
freeReplyObject(reply);
/* Get the addresses of all of the Redis shards. */
num_attempts = 0;
while (num_attempts < REDIS_DB_CONNECT_RETRIES) {
/* Try to read the Redis shard locations from the primary shard. If we find
* that all of them are present, exit. */
reply = (redisReply *) redisCommand(context, "LRANGE RedisShards 0 -1");
if (reply->elements == num_redis_shards) {
break;
}
/* Sleep for a little, and try again if not all Redis shard addresses have
* been added yet. */
freeReplyObject(reply);
usleep(REDIS_DB_CONNECT_WAIT_MS * 1000);
num_attempts++;
continue;
}
CHECKM(num_attempts < REDIS_DB_CONNECT_RETRIES,
"Expected %d Redis shard addresses, found %d", num_redis_shards,
(int) reply->elements);
/* Parse the Redis shard addresses. */
char db_shard_address[16];
int db_shard_port;
for (int i = 0; i < reply->elements; ++i) {
/* Parse the shard addresses and ports. */
CHECK(reply->element[i]->type == REDIS_REPLY_STRING);
CHECK(parse_ip_addr_port(reply->element[i]->str, db_shard_address,
&db_shard_port) == 0);
db_shards_addresses.push_back(std::string(db_shard_address));
db_shards_ports.push_back(db_shard_port);
}
freeReplyObject(reply);
}
void db_connect_shard(const std::string &db_address,
int db_port,
DBClientID client,
const char *client_type,
const char *node_ip_address,
int num_args,
const char **args,
DBHandle *db,
redisAsyncContext **context_out,
redisAsyncContext **subscribe_context_out,
redisContext **sync_context_out) {
/* Synchronous connection for initial handshake */
redisReply *reply;
int connection_attempts = 0;
redisContext *context = redisConnect(db_address, db_port);
while (context == NULL || context->err) {
redisContext *sync_context = redisConnect(db_address.c_str(), db_port);
while (sync_context == NULL || sync_context->err) {
if (connection_attempts >= REDIS_DB_CONNECT_RETRIES) {
break;
}
LOG_WARN("Failed to connect to Redis, retrying.");
/* Sleep for a little. */
usleep(REDIS_DB_CONNECT_WAIT_MS * 1000);
context = redisConnect(db_address, db_port);
sync_context = redisConnect(db_address.c_str(), db_port);
connection_attempts += 1;
}
CHECK_REDIS_CONNECT(redisContext, context,
CHECK_REDIS_CONNECT(redisContext, sync_context,
"could not establish synchronous connection to redis "
"%s:%d",
db_address, db_port);
db_address.c_str(), db_port);
/* Configure Redis to generate keyspace notifications for list events. This
* should only need to be done once (by whoever started Redis), but since
* Redis may be started in multiple places (e.g., for testing or when starting
* processes by hand), it is easier to do it multiple times. */
reply = (redisReply *) redisCommand(context,
reply = (redisReply *) redisCommand(sync_context,
"CONFIG SET notify-keyspace-events Kl");
CHECKM(reply != NULL, "db_connect failed on CONFIG SET");
freeReplyObject(reply);
/* Also configure Redis to not run in protected mode, so clients on other
* hosts can connect to it. */
reply = (redisReply *) redisCommand(context, "CONFIG SET protected-mode no");
reply =
(redisReply *) redisCommand(sync_context, "CONFIG SET protected-mode no");
CHECKM(reply != NULL, "db_connect failed on CONFIG SET");
freeReplyObject(reply);
/* Create a client ID for this client. */
DBClientID client = globally_unique_id();
/* Construct the argument arrays for RAY.CONNECT. */
int argc = num_args + 4;
@@ -133,7 +211,7 @@ DBHandle *db_connect(const char *db_address,
argvlen[0] = strlen(argv[0]);
/* Set the client ID argument. */
argv[1] = (char *) client.id;
argvlen[1] = sizeof(db->client.id);
argvlen[1] = sizeof(client.id);
/* Set the node IP address argument. */
argv[2] = node_ip_address;
argvlen[2] = strlen(node_ip_address);
@@ -152,7 +230,7 @@ DBHandle *db_connect(const char *db_address,
/* Register this client with Redis. RAY.CONNECT is a custom Redis command that
* we've defined. */
reply = (redisReply *) redisCommandArgv(context, argc, argv, argvlen);
reply = (redisReply *) redisCommandArgv(sync_context, argc, argv, argvlen);
CHECKM(reply != NULL, "db_connect failed on RAY.CONNECT");
CHECK(reply->type != REDIS_REPLY_ERROR);
CHECK(strcmp(reply->str, "OK") == 0);
@@ -160,25 +238,75 @@ DBHandle *db_connect(const char *db_address,
free(argv);
free(argvlen);
*sync_context_out = sync_context;
/* Establish connection for control data. */
redisAsyncContext *context = redisAsyncConnect(db_address.c_str(), db_port);
CHECK_REDIS_CONNECT(redisAsyncContext, context,
"could not establish asynchronous connection to redis "
"%s:%d",
db_address.c_str(), db_port);
context->data = (void *) db;
*context_out = context;
/* Establish async connection for subscription. */
redisAsyncContext *subscribe_context =
redisAsyncConnect(db_address.c_str(), db_port);
CHECK_REDIS_CONNECT(redisAsyncContext, subscribe_context,
"could not establish asynchronous subscription "
"connection to redis %s:%d",
db_address.c_str(), db_port);
subscribe_context->data = (void *) db;
*subscribe_context_out = subscribe_context;
}
DBHandle *db_connect(const std::string &db_primary_address,
int db_primary_port,
const char *client_type,
const char *node_ip_address,
int num_args,
const char **args) {
/* Check that the number of args is even. These args will be passed to the
* RAY.CONNECT Redis command, which takes arguments in pairs. */
if (num_args % 2 != 0) {
LOG_FATAL("The number of extra args must be divisible by two.");
}
/* Create a client ID for this client. */
DBClientID client = globally_unique_id();
DBHandle *db = new DBHandle();
db->client_type = strdup(client_type);
db->client = client;
db->db_client_cache = NULL;
db->sync_context = context;
/* Establish async connection */
db->context = redisAsyncConnect(db_address, db_port);
CHECK_REDIS_CONNECT(redisAsyncContext, db->context,
"could not establish asynchronous connection to redis "
"%s:%d",
db_address, db_port);
db->context->data = (void *) db;
/* Establish async connection for subscription */
db->sub_context = redisAsyncConnect(db_address, db_port);
CHECK_REDIS_CONNECT(redisAsyncContext, db->sub_context,
"could not establish asynchronous subscription "
"connection to redis %s:%d",
db_address, db_port);
db->sub_context->data = (void *) db;
redisAsyncContext *context;
redisAsyncContext *subscribe_context;
redisContext *sync_context;
/* Connect to the primary redis instance. */
db_connect_shard(db_primary_address, db_primary_port, client, client_type,
node_ip_address, num_args, args, db, &context,
&subscribe_context, &sync_context);
db->context = context;
db->subscribe_context = subscribe_context;
db->sync_context = sync_context;
/* Get the shard locations. */
std::vector<std::string> db_shards_addresses;
std::vector<int> db_shards_ports;
get_redis_shards(db->sync_context, db_shards_addresses, db_shards_ports);
CHECKM(db_shards_addresses.size() > 0, "No Redis shards found");
/* Connect to the shards. */
for (int i = 0; i < db_shards_addresses.size(); ++i) {
db_connect_shard(db_shards_addresses[i], db_shards_ports[i], client,
client_type, node_ip_address, num_args, args, db, &context,
&subscribe_context, &sync_context);
db->contexts.push_back(context);
db->subscribe_contexts.push_back(subscribe_context);
redisFree(sync_context);
}
return db;
}
@@ -193,10 +321,17 @@ void db_disconnect(DBHandle *db) {
CHECK(strcmp(reply->str, "OK") == 0);
freeReplyObject(reply);
/* Clean up the Redis connection state. */
/* Clean up the primary Redis connection state. */
redisFree(db->sync_context);
redisAsyncFree(db->context);
redisAsyncFree(db->sub_context);
redisAsyncFree(db->subscribe_context);
/* Clean up the Redis shards. */
CHECK(db->contexts.size() == db->subscribe_contexts.size());
for (int i = 0; i < db->contexts.size(); ++i) {
redisAsyncFree(db->contexts[i]);
redisAsyncFree(db->subscribe_contexts[i]);
}
/* Clean up memory. */
DBClientCacheEntry *e, *tmp;
@@ -206,21 +341,36 @@ void db_disconnect(DBHandle *db) {
free(e);
}
free(db->client_type);
free(db);
delete db;
}
void db_attach(DBHandle *db, event_loop *loop, bool reattach) {
db->loop = loop;
/* Attach primary redis instance to the event loop. */
int err = redisAeAttach(loop, db->context);
/* If the database is reattached in the tests, redis normally gives
* an error which we can safely ignore. */
if (!reattach) {
CHECKM(err == REDIS_OK, "failed to attach the event loop");
}
err = redisAeAttach(loop, db->sub_context);
err = redisAeAttach(loop, db->subscribe_context);
if (!reattach) {
CHECKM(err == REDIS_OK, "failed to attach the event loop");
}
/* Attach other redis shards to the event loop. */
CHECK(db->contexts.size() == db->subscribe_contexts.size());
for (int i = 0; i < db->contexts.size(); ++i) {
int err = redisAeAttach(loop, db->contexts[i]);
/* If the database is reattached in the tests, redis normally gives
* an error which we can safely ignore. */
if (!reattach) {
CHECKM(err == REDIS_OK, "failed to attach the event loop");
}
err = redisAeAttach(loop, db->subscribe_contexts[i]);
if (!reattach) {
CHECKM(err == REDIS_OK, "failed to attach the event loop");
}
}
}
/*
@@ -264,14 +414,16 @@ void redis_object_table_add(TableCallbackData *callback_data) {
int64_t object_size = info->object_size;
unsigned char *digest = info->digest;
redisAsyncContext *context = get_redis_context(db, obj_id);
int status = redisAsyncCommand(
db->context, redis_object_table_add_callback,
context, redis_object_table_add_callback,
(void *) callback_data->timer_id, "RAY.OBJECT_TABLE_ADD %b %ld %b %b",
obj_id.id, sizeof(obj_id.id), object_size, digest, (size_t) DIGEST_SIZE,
db->client.id, sizeof(db->client.id));
if ((status == REDIS_ERR) || db->context->err) {
LOG_REDIS_DEBUG(db->context, "error in redis_object_table_add");
if ((status == REDIS_ERR) || context->err) {
LOG_REDIS_DEBUG(context, "error in redis_object_table_add");
}
}
@@ -309,13 +461,16 @@ void redis_object_table_remove(TableCallbackData *callback_data) {
if (client_id == NULL) {
client_id = &db->client;
}
redisAsyncContext *context = get_redis_context(db, obj_id);
int status = redisAsyncCommand(
db->context, redis_object_table_remove_callback,
context, redis_object_table_remove_callback,
(void *) callback_data->timer_id, "RAY.OBJECT_TABLE_REMOVE %b %b",
obj_id.id, sizeof(obj_id.id), client_id->id, sizeof(client_id->id));
if ((status == REDIS_ERR) || db->context->err) {
LOG_REDIS_DEBUG(db->context, "error in redis_object_table_remove");
if ((status == REDIS_ERR) || context->err) {
LOG_REDIS_DEBUG(context, "error in redis_object_table_remove");
}
}
@@ -324,12 +479,15 @@ void redis_object_table_lookup(TableCallbackData *callback_data) {
DBHandle *db = callback_data->db_handle;
ObjectID obj_id = callback_data->id;
int status = redisAsyncCommand(
db->context, redis_object_table_lookup_callback,
(void *) callback_data->timer_id, "RAY.OBJECT_TABLE_LOOKUP %b", obj_id.id,
sizeof(obj_id.id));
if ((status == REDIS_ERR) || db->context->err) {
LOG_REDIS_DEBUG(db->context, "error in object_table lookup");
redisAsyncContext *context = get_redis_context(db, obj_id);
int status = redisAsyncCommand(context, redis_object_table_lookup_callback,
(void *) callback_data->timer_id,
"RAY.OBJECT_TABLE_LOOKUP %b", obj_id.id,
sizeof(obj_id.id));
if ((status == REDIS_ERR) || context->err) {
LOG_REDIS_DEBUG(context, "error in object_table lookup");
}
}
@@ -358,13 +516,15 @@ void redis_result_table_add(TableCallbackData *callback_data) {
ResultTableAddInfo *info = (ResultTableAddInfo *) callback_data->data;
int is_put = info->is_put ? 1 : 0;
redisAsyncContext *context = get_redis_context(db, id);
/* Add the result entry to the result table. */
int status = redisAsyncCommand(
db->context, redis_result_table_add_callback,
context, redis_result_table_add_callback,
(void *) callback_data->timer_id, "RAY.RESULT_TABLE_ADD %b %b %d", id.id,
sizeof(id.id), info->task_id.id, sizeof(info->task_id.id), is_put);
if ((status == REDIS_ERR) || db->context->err) {
LOG_REDIS_DEBUG(db->context, "Error in result table add");
if ((status == REDIS_ERR) || context->err) {
LOG_REDIS_DEBUG(context, "Error in result table add");
}
}
@@ -423,12 +583,13 @@ void redis_result_table_lookup(TableCallbackData *callback_data) {
CHECK(callback_data);
DBHandle *db = callback_data->db_handle;
ObjectID id = callback_data->id;
redisAsyncContext *context = get_redis_context(db, id);
int status =
redisAsyncCommand(db->context, redis_result_table_lookup_callback,
redisAsyncCommand(context, redis_result_table_lookup_callback,
(void *) callback_data->timer_id,
"RAY.RESULT_TABLE_LOOKUP %b", id.id, sizeof(id.id));
if ((status == REDIS_ERR) || db->context->err) {
LOG_REDIS_DEBUG(db->context, "Error in result table lookup");
if ((status == REDIS_ERR) || context->err) {
LOG_REDIS_DEBUG(context, "Error in result table lookup");
}
}
@@ -586,28 +747,33 @@ void redis_object_table_subscribe_to_notifications(
* src/common/redismodule/ray_redis_module.cc. */
const char *object_channel_prefix = "OC:";
const char *object_channel_bcast = "BCAST";
int status = REDIS_OK;
/* Subscribe to notifications from the object table. This uses the client ID
* as the channel name so this channel is specific to this client. TODO(rkn):
* The channel name should probably be the client ID with some prefix. */
CHECKM(callback_data->data != NULL,
"Object table subscribe data passed as NULL.");
if (((ObjectTableSubscribeData *) (callback_data->data))->subscribe_all) {
/* Subscribe to the object broadcast channel. */
status = redisAsyncCommand(
db->sub_context, object_table_redis_subscribe_to_notifications_callback,
(void *) callback_data->timer_id, "SUBSCRIBE %s%s",
object_channel_prefix, object_channel_bcast);
} else {
status = redisAsyncCommand(
db->sub_context, object_table_redis_subscribe_to_notifications_callback,
(void *) callback_data->timer_id, "SUBSCRIBE %s%b",
object_channel_prefix, db->client.id, sizeof(db->client.id));
}
for (int i = 0; i < db->subscribe_contexts.size(); ++i) {
int status = REDIS_OK;
/* Subscribe to notifications from the object table. This uses the client ID
* as the channel name so this channel is specific to this client.
* TODO(rkn):
* The channel name should probably be the client ID with some prefix. */
CHECKM(callback_data->data != NULL,
"Object table subscribe data passed as NULL.");
if (((ObjectTableSubscribeData *) (callback_data->data))->subscribe_all) {
/* Subscribe to the object broadcast channel. */
status = redisAsyncCommand(
db->subscribe_contexts[i],
object_table_redis_subscribe_to_notifications_callback,
(void *) callback_data->timer_id, "SUBSCRIBE %s%s",
object_channel_prefix, object_channel_bcast);
} else {
status = redisAsyncCommand(
db->subscribe_contexts[i],
object_table_redis_subscribe_to_notifications_callback,
(void *) callback_data->timer_id, "SUBSCRIBE %s%b",
object_channel_prefix, db->client.id, sizeof(db->client.id));
}
if ((status == REDIS_ERR) || db->sub_context->err) {
LOG_REDIS_DEBUG(db->sub_context,
"error in redis_object_table_subscribe_to_notifications");
if ((status == REDIS_ERR) || db->subscribe_contexts[i]->err) {
LOG_REDIS_DEBUG(db->subscribe_contexts[i],
"error in redis_object_table_subscribe_to_notifications");
}
}
}
@@ -633,31 +799,33 @@ void redis_object_table_request_notifications(
int num_object_ids = request_data->num_object_ids;
ObjectID *object_ids = request_data->object_ids;
/* Create the arguments for the Redis command. */
int num_args = 1 + 1 + num_object_ids;
const char **argv = (const char **) malloc(sizeof(char *) * num_args);
size_t *argvlen = (size_t *) malloc(sizeof(size_t) * num_args);
/* Set the command name argument. */
argv[0] = "RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS";
argvlen[0] = strlen(argv[0]);
/* Set the client ID argument. */
argv[1] = (char *) db->client.id;
argvlen[1] = sizeof(db->client.id);
/* Set the object ID arguments. */
for (int i = 0; i < num_object_ids; ++i) {
argv[2 + i] = (char *) object_ids[i].id;
argvlen[2 + i] = sizeof(object_ids[i].id);
}
redisAsyncContext *context = get_redis_context(db, object_ids[i]);
int status = redisAsyncCommandArgv(
db->context, redis_object_table_request_notifications_callback,
(void *) callback_data->timer_id, num_args, argv, argvlen);
free(argv);
free(argvlen);
/* Create the arguments for the Redis command. */
int num_args = 1 + 1 + 1;
const char **argv = (const char **) malloc(sizeof(char *) * num_args);
size_t *argvlen = (size_t *) malloc(sizeof(size_t) * num_args);
/* Set the command name argument. */
argv[0] = "RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS";
argvlen[0] = strlen(argv[0]);
/* Set the client ID argument. */
argv[1] = (char *) db->client.id;
argvlen[1] = sizeof(db->client.id);
/* Set the object ID arguments. */
argv[2] = (char *) object_ids[i].id;
argvlen[2] = sizeof(object_ids[i].id);
if ((status == REDIS_ERR) || db->context->err) {
LOG_REDIS_DEBUG(db->context,
"error in redis_object_table_subscribe_to_notifications");
int status = redisAsyncCommandArgv(
context, redis_object_table_request_notifications_callback,
(void *) callback_data->timer_id, num_args, argv, argvlen);
free(argv);
free(argvlen);
if ((status == REDIS_ERR) || context->err) {
LOG_REDIS_DEBUG(context,
"error in redis_object_table_subscribe_to_notifications");
}
}
}
@@ -690,12 +858,14 @@ void redis_task_table_get_task(TableCallbackData *callback_data) {
CHECK(callback_data->data == NULL);
TaskID task_id = callback_data->id;
int status = redisAsyncCommand(
db->context, redis_task_table_get_task_callback,
(void *) callback_data->timer_id, "RAY.TASK_TABLE_GET %b", task_id.id,
sizeof(task_id.id));
if ((status == REDIS_ERR) || db->context->err) {
LOG_REDIS_DEBUG(db->context, "error in redis_task_table_get_task");
redisAsyncContext *context = get_redis_context(db, task_id);
int status = redisAsyncCommand(context, redis_task_table_get_task_callback,
(void *) callback_data->timer_id,
"RAY.TASK_TABLE_GET %b", task_id.id,
sizeof(task_id.id));
if ((status == REDIS_ERR) || context->err) {
LOG_REDIS_DEBUG(context, "error in redis_task_table_get_task");
}
}
@@ -706,6 +876,36 @@ void redis_task_table_add_task_callback(redisAsyncContext *c,
/* Do some minimal checking. */
redisReply *reply = (redisReply *) r;
/* If the publish which happens inside of the call to RAY.TASK_TABLE_ADD was
* not received by any subscribers, then reissue the command. TODO(rkn): This
* entire if block should be temporary. Once we address the problem where in
* which a global scheduler may publish a task to a local scheduler before the
* local scheduler has subscribed to the relevant channel, we shouldn't need
* this block any more. */
if (reply->type == REDIS_REPLY_ERROR &&
strcmp(reply->str, "No subscribers received message.") == 0) {
Task *task = (Task *) callback_data->data;
TaskID task_id = Task_task_id(task);
DBClientID local_scheduler_id = Task_local_scheduler(task);
redisAsyncContext *context = get_redis_context(db, task_id);
int state = Task_state(task);
TaskSpec *spec = Task_task_spec(task);
/* Reissue the command. */
CHECKM(task != NULL, "NULL task passed to redis_task_table_add_task.");
int status = redisAsyncCommand(
context, redis_task_table_add_task_callback,
(void *) callback_data->timer_id, "RAY.TASK_TABLE_ADD %b %d %b %b",
task_id.id, sizeof(task_id.id), state, local_scheduler_id.id,
sizeof(local_scheduler_id.id), spec, Task_task_spec_size(task));
if ((status == REDIS_ERR) || context->err) {
LOG_REDIS_DEBUG(context, "error in redis_task_table_add_task");
}
/* Since we are reissuing the same command with the same callback data,
* return early to avoid freeing the callback data. */
return;
}
CHECKM(strcmp(reply->str, "OK") == 0, "reply->str is %s", reply->str);
/* Call the done callback if there is one. */
if (callback_data->done_callback != NULL) {
@@ -722,17 +922,18 @@ void redis_task_table_add_task(TableCallbackData *callback_data) {
Task *task = (Task *) callback_data->data;
TaskID task_id = Task_task_id(task);
DBClientID local_scheduler_id = Task_local_scheduler(task);
redisAsyncContext *context = get_redis_context(db, task_id);
int state = Task_state(task);
TaskSpec *spec = Task_task_spec(task);
CHECKM(task != NULL, "NULL task passed to redis_task_table_add_task.");
int status = redisAsyncCommand(
db->context, redis_task_table_add_task_callback,
context, redis_task_table_add_task_callback,
(void *) callback_data->timer_id, "RAY.TASK_TABLE_ADD %b %d %b %b",
task_id.id, sizeof(task_id.id), state, local_scheduler_id.id,
sizeof(local_scheduler_id.id), spec, Task_task_spec_size(task));
if ((status == REDIS_ERR) || db->context->err) {
LOG_REDIS_DEBUG(db->context, "error in redis_task_table_add_task");
if ((status == REDIS_ERR) || context->err) {
LOG_REDIS_DEBUG(context, "error in redis_task_table_add_task");
}
}
@@ -743,6 +944,35 @@ void redis_task_table_update_callback(redisAsyncContext *c,
/* Do some minimal checking. */
redisReply *reply = (redisReply *) r;
/* If the publish which happens inside of the call to RAY.TASK_TABLE_UPDATE
* was not received by any subscribers, then reissue the command. TODO(rkn):
* This entire if block should be temporary. Once we address the problem where
* in which a global scheduler may publish a task to a local scheduler before
* the local scheduler has subscribed to the relevant channel, we shouldn't
* need this block any more. */
if (reply->type == REDIS_REPLY_ERROR &&
strcmp(reply->str, "No subscribers received message.") == 0) {
Task *task = (Task *) callback_data->data;
TaskID task_id = Task_task_id(task);
redisAsyncContext *context = get_redis_context(db, task_id);
DBClientID local_scheduler_id = Task_local_scheduler(task);
int state = Task_state(task);
/* Reissue the command. */
CHECKM(task != NULL, "NULL task passed to redis_task_table_update.");
int status = redisAsyncCommand(
context, redis_task_table_update_callback,
(void *) callback_data->timer_id, "RAY.TASK_TABLE_UPDATE %b %d %b",
task_id.id, sizeof(task_id.id), state, local_scheduler_id.id,
sizeof(local_scheduler_id.id));
if ((status == REDIS_ERR) || context->err) {
LOG_REDIS_DEBUG(context, "error in redis_task_table_update");
}
/* Since we are reissuing the same command with the same callback data,
* return early to avoid freeing the callback data. */
return;
}
CHECKM(strcmp(reply->str, "OK") == 0, "reply->str is %s", reply->str);
/* Call the done callback if there is one. */
if (callback_data->done_callback != NULL) {
@@ -758,17 +988,18 @@ void redis_task_table_update(TableCallbackData *callback_data) {
DBHandle *db = callback_data->db_handle;
Task *task = (Task *) callback_data->data;
TaskID task_id = Task_task_id(task);
redisAsyncContext *context = get_redis_context(db, task_id);
DBClientID local_scheduler_id = Task_local_scheduler(task);
int state = Task_state(task);
CHECKM(task != NULL, "NULL task passed to redis_task_table_update.");
int status = redisAsyncCommand(
db->context, redis_task_table_update_callback,
context, redis_task_table_update_callback,
(void *) callback_data->timer_id, "RAY.TASK_TABLE_UPDATE %b %d %b",
task_id.id, sizeof(task_id.id), state, local_scheduler_id.id,
sizeof(local_scheduler_id.id));
if ((status == REDIS_ERR) || db->context->err) {
LOG_REDIS_DEBUG(db->context, "error in redis_task_table_update");
if ((status == REDIS_ERR) || context->err) {
LOG_REDIS_DEBUG(context, "error in redis_task_table_update");
}
}
@@ -779,6 +1010,15 @@ void redis_task_table_test_and_update_callback(redisAsyncContext *c,
redisReply *reply = (redisReply *) r;
/* Parse the task from the reply. */
Task *task = parse_and_construct_task_from_redis_reply(reply);
if (task == NULL) {
/* A NULL task means that the task was not in the task table. NOTE(swang):
* For normal tasks, this is not expected behavior, but actor tasks may be
* delayed when added to the task table if they are submitted to a local
* scheduler before it receives the notification that maps the actor to a
* local scheduler. */
LOG_ERROR("No task found during task_table_test_and_update");
return;
}
/* Determine whether the update happened. */
auto message = flatbuffers::GetRoot<TaskReply>(reply->str);
bool updated = message->updated();
@@ -800,18 +1040,19 @@ void redis_task_table_test_and_update_callback(redisAsyncContext *c,
void redis_task_table_test_and_update(TableCallbackData *callback_data) {
DBHandle *db = callback_data->db_handle;
TaskID task_id = callback_data->id;
redisAsyncContext *context = get_redis_context(db, task_id);
TaskTableTestAndUpdateData *update_data =
(TaskTableTestAndUpdateData *) callback_data->data;
int status = redisAsyncCommand(
db->context, redis_task_table_test_and_update_callback,
context, redis_task_table_test_and_update_callback,
(void *) callback_data->timer_id,
"RAY.TASK_TABLE_TEST_AND_UPDATE %b %d %d %b", task_id.id,
sizeof(task_id.id), update_data->test_state_bitmask,
update_data->update_state, update_data->local_scheduler_id.id,
sizeof(update_data->local_scheduler_id.id));
if ((status == REDIS_ERR) || db->context->err) {
LOG_REDIS_DEBUG(db->context, "error in redis_task_table_test_and_update");
if ((status == REDIS_ERR) || context->err) {
LOG_REDIS_DEBUG(context, "error in redis_task_table_test_and_update");
}
}
@@ -879,24 +1120,26 @@ void redis_task_table_subscribe(TableCallbackData *callback_data) {
/* TASK_CHANNEL_PREFIX is defined in ray_redis_module.cc and must be kept in
* sync with that file. */
const char *TASK_CHANNEL_PREFIX = "TT:";
int status;
if (IS_NIL_ID(data->local_scheduler_id)) {
/* TODO(swang): Implement the state_filter by translating the bitmask into
* a Redis key-matching pattern. */
status =
redisAsyncCommand(db->sub_context, redis_task_table_subscribe_callback,
(void *) callback_data->timer_id, "PSUBSCRIBE %s*:%d",
TASK_CHANNEL_PREFIX, data->state_filter);
} else {
DBClientID local_scheduler_id = data->local_scheduler_id;
status =
redisAsyncCommand(db->sub_context, redis_task_table_subscribe_callback,
(void *) callback_data->timer_id, "SUBSCRIBE %s%b:%d",
TASK_CHANNEL_PREFIX, (char *) local_scheduler_id.id,
sizeof(local_scheduler_id.id), data->state_filter);
}
if ((status == REDIS_ERR) || db->sub_context->err) {
LOG_REDIS_DEBUG(db->sub_context, "error in redis_task_table_subscribe");
for (auto subscribe_context : db->subscribe_contexts) {
int status;
if (IS_NIL_ID(data->local_scheduler_id)) {
/* TODO(swang): Implement the state_filter by translating the bitmask into
* a Redis key-matching pattern. */
status = redisAsyncCommand(
subscribe_context, redis_task_table_subscribe_callback,
(void *) callback_data->timer_id, "PSUBSCRIBE %s*:%d",
TASK_CHANNEL_PREFIX, data->state_filter);
} else {
DBClientID local_scheduler_id = data->local_scheduler_id;
status = redisAsyncCommand(
subscribe_context, redis_task_table_subscribe_callback,
(void *) callback_data->timer_id, "SUBSCRIBE %s%b:%d",
TASK_CHANNEL_PREFIX, (char *) local_scheduler_id.id,
sizeof(local_scheduler_id.id), data->state_filter);
}
if ((status == REDIS_ERR) || subscribe_context->err) {
LOG_REDIS_DEBUG(subscribe_context, "error in redis_task_table_subscribe");
}
}
}
@@ -934,6 +1177,55 @@ void redis_db_client_table_remove(TableCallbackData *callback_data) {
}
}
void redis_db_client_table_scan(DBHandle *db,
std::vector<DBClient> &db_clients) {
/* TODO(swang): Integrate this functionality with the Ray Redis module. To do
* this, we need the KEYS or SCAN command in Redis modules. */
/* Get all the database client keys. */
redisReply *reply = (redisReply *) redisCommand(db->sync_context, "KEYS %s*",
DB_CLIENT_PREFIX);
if (reply->type == REDIS_REPLY_NIL) {
return;
}
/* Get all the database client information. */
CHECK(reply->type == REDIS_REPLY_ARRAY);
for (int i = 0; i < reply->elements; ++i) {
redisReply *client_reply = (redisReply *) redisCommand(
db->sync_context, "HGETALL %b", reply->element[i]->str,
reply->element[i]->len);
CHECK(reply->type == REDIS_REPLY_ARRAY);
CHECK(reply->elements > 0);
DBClient db_client;
memset(&db_client, 0, sizeof(db_client));
int num_fields = 0;
/* Parse the fields into a DBClient. */
for (int j = 0; j < client_reply->elements; j = j + 2) {
const char *key = client_reply->element[j]->str;
const char *value = client_reply->element[j + 1]->str;
if (strcmp(key, "ray_client_id") == 0) {
memcpy(db_client.id.id, value, sizeof(db_client.id));
num_fields++;
} else if (strcmp(key, "client_type") == 0) {
db_client.client_type = strdup(value);
num_fields++;
} else if (strcmp(key, "aux_address") == 0) {
db_client.aux_address = strdup(value);
num_fields++;
} else if (strcmp(key, "deleted") == 0) {
bool is_deleted = atoi(value);
db_client.is_insertion = !is_deleted;
num_fields++;
}
}
freeReplyObject(client_reply);
/* The client ID, type, and whether it is deleted are all mandatory fields.
* Auxiliary address is optional. */
CHECK(num_fields >= 3);
db_clients.push_back(db_client);
}
freeReplyObject(reply);
}
void redis_db_client_table_subscribe_callback(redisAsyncContext *c,
void *r,
void *privdata) {
@@ -956,35 +1248,54 @@ void redis_db_client_table_subscribe_callback(redisAsyncContext *c,
/* Note that we do not destroy the callback data yet because the
* subscription callback needs this data. */
event_loop_remove_timer(db->loop, callback_data->timer_id);
/* Get the current db client table entries, in case we missed notifications
* before the initial subscription. This must be done before we process any
* notifications from the subscription channel, so that we don't readd an
* entry that has already been deleted. */
std::vector<DBClient> db_clients;
redis_db_client_table_scan(db, db_clients);
/* Call the subscription callback for all entries that we missed. */
DBClientTableSubscribeData *data =
(DBClientTableSubscribeData *) callback_data->data;
for (auto db_client : db_clients) {
data->subscribe_callback(&db_client, data->subscribe_context);
if (db_client.client_type != NULL) {
free((void *) db_client.client_type);
}
if (db_client.aux_address != NULL) {
free((void *) db_client.aux_address);
}
}
return;
}
/* Otherwise, parse the payload and call the callback. */
auto message =
flatbuffers::GetRoot<SubscribeToDBClientTableReply>(payload->str);
DBClientID client = from_flatbuf(message->db_client_id());
/* Parse the client type and auxiliary address from the response. If there is
* only client type, then the update was a delete. */
char *client_type = (char *) message->client_type()->data();
char *aux_address = (char *) message->aux_address()->data();
bool is_insertion = message->is_insertion();
DBClient db_client;
db_client.id = from_flatbuf(message->db_client_id());
db_client.client_type = (char *) message->client_type()->data();
db_client.aux_address = message->aux_address()->data();
db_client.is_insertion = message->is_insertion();
/* Call the subscription callback. */
DBClientTableSubscribeData *data =
(DBClientTableSubscribeData *) callback_data->data;
if (data->subscribe_callback) {
data->subscribe_callback(client, client_type, aux_address, is_insertion,
data->subscribe_context);
data->subscribe_callback(&db_client, data->subscribe_context);
}
}
void redis_db_client_table_subscribe(TableCallbackData *callback_data) {
DBHandle *db = callback_data->db_handle;
int status = redisAsyncCommand(
db->sub_context, redis_db_client_table_subscribe_callback,
db->subscribe_context, redis_db_client_table_subscribe_callback,
(void *) callback_data->timer_id, "SUBSCRIBE db_clients");
if ((status == REDIS_ERR) || db->sub_context->err) {
LOG_REDIS_DEBUG(db->sub_context,
if ((status == REDIS_ERR) || db->subscribe_context->err) {
LOG_REDIS_DEBUG(db->subscribe_context,
"error in db_client_table_register_callback");
}
}
@@ -1042,10 +1353,10 @@ void redis_local_scheduler_table_subscribe_callback(redisAsyncContext *c,
void redis_local_scheduler_table_subscribe(TableCallbackData *callback_data) {
DBHandle *db = callback_data->db_handle;
int status = redisAsyncCommand(
db->sub_context, redis_local_scheduler_table_subscribe_callback,
db->subscribe_context, redis_local_scheduler_table_subscribe_callback,
(void *) callback_data->timer_id, "SUBSCRIBE local_schedulers");
if ((status == REDIS_ERR) || db->sub_context->err) {
LOG_REDIS_DEBUG(db->sub_context,
if ((status == REDIS_ERR) || db->subscribe_context->err) {
LOG_REDIS_DEBUG(db->subscribe_context,
"error in redis_local_scheduler_table_subscribe");
}
}
@@ -1130,10 +1441,11 @@ void redis_driver_table_subscribe_callback(redisAsyncContext *c,
void redis_driver_table_subscribe(TableCallbackData *callback_data) {
DBHandle *db = callback_data->db_handle;
int status = redisAsyncCommand(
db->sub_context, redis_driver_table_subscribe_callback,
db->subscribe_context, redis_driver_table_subscribe_callback,
(void *) callback_data->timer_id, "SUBSCRIBE driver_deaths");
if ((status == REDIS_ERR) || db->sub_context->err) {
LOG_REDIS_DEBUG(db->sub_context, "error in redis_driver_table_subscribe");
if ((status == REDIS_ERR) || db->subscribe_context->err) {
LOG_REDIS_DEBUG(db->subscribe_context,
"error in redis_driver_table_subscribe");
}
}
@@ -1240,10 +1552,10 @@ void redis_actor_notification_table_subscribe(
TableCallbackData *callback_data) {
DBHandle *db = callback_data->db_handle;
int status = redisAsyncCommand(
db->sub_context, redis_actor_notification_table_subscribe_callback,
db->subscribe_context, redis_actor_notification_table_subscribe_callback,
(void *) callback_data->timer_id, "SUBSCRIBE actor_notifications");
if ((status == REDIS_ERR) || db->sub_context->err) {
LOG_REDIS_DEBUG(db->sub_context,
if ((status == REDIS_ERR) || db->subscribe_context->err) {
LOG_REDIS_DEBUG(db->subscribe_context,
"error in redis_actor_notification_table_subscribe");
}
}
@@ -1290,10 +1602,11 @@ void redis_object_info_subscribe_callback(redisAsyncContext *c,
void redis_object_info_subscribe(TableCallbackData *callback_data) {
DBHandle *db = callback_data->db_handle;
int status = redisAsyncCommand(
db->sub_context, redis_object_info_subscribe_callback,
db->subscribe_context, redis_object_info_subscribe_callback,
(void *) callback_data->timer_id, "PSUBSCRIBE obj:info");
if ((status == REDIS_ERR) || db->sub_context->err) {
LOG_REDIS_DEBUG(db->sub_context, "error in object_info_register_callback");
if ((status == REDIS_ERR) || db->subscribe_context->err) {
LOG_REDIS_DEBUG(db->subscribe_context,
"error in object_info_register_callback");
}
}
@@ -1324,8 +1637,8 @@ void redis_push_error_hmset_callback(redisAsyncContext *c,
"RPUSH ErrorKeys Error:%b:%b",
info->driver_id.id, sizeof(info->driver_id.id),
info->error_key, sizeof(info->error_key));
if ((status == REDIS_ERR) || db->sub_context->err) {
LOG_REDIS_DEBUG(db->sub_context, "error in redis_push_error rpush");
if ((status == REDIS_ERR) || db->subscribe_context->err) {
LOG_REDIS_DEBUG(db->subscribe_context, "error in redis_push_error rpush");
}
}
@@ -1344,8 +1657,8 @@ void redis_push_error(TableCallbackData *callback_data) {
"HMSET Error:%b:%b type %s message %s data %b", info->driver_id.id,
sizeof(info->driver_id.id), info->error_key, sizeof(info->error_key),
error_type, error_message, info->data, info->data_length);
if ((status == REDIS_ERR) || db->sub_context->err) {
LOG_REDIS_DEBUG(db->sub_context, "error in redis_push_error hmset");
if ((status == REDIS_ERR) || db->subscribe_context->err) {
LOG_REDIS_DEBUG(db->subscribe_context, "error in redis_push_error hmset");
}
}
+51 -4
View File
@@ -34,11 +34,24 @@ struct DBHandle {
char *client_type;
/** Unique ID for this client. */
DBClientID client;
/** Redis context for all non-subscribe connections. */
/** Primary redis context for all non-subscribe connections. This is used for
* the database client table, heartbeats, and errors that should be pushed to
* the driver. */
redisAsyncContext *context;
/** Redis context for "subscribe" communication. Yes, we need a separate one
* for that, see https://github.com/redis/hiredis/issues/55. */
redisAsyncContext *sub_context;
/** Primary redis context for "subscribe" communication. A separate context
* is needed for this communication (see
* https://github.com/redis/hiredis/issues/55). This is used for the
* database client table, heartbeats, and errors that should be pushed to
* the driver. */
redisAsyncContext *subscribe_context;
/** Redis contexts for shards for all non-subscribe connections. All requests
* to the object table, task table, and event table should be directed here.
* The correct shard can be retrieved using get_redis_context below. */
std::vector<redisAsyncContext *> contexts;
/** Redis contexts for shards for "subscribe" communication. All requests
* to the object table, task table, and event table should be directed here.
* The correct shard can be retrieved using get_redis_context below. */
std::vector<redisAsyncContext *> subscribe_contexts;
/** The event loop this global state store connection is part of. */
event_loop *loop;
/** Index of the database connection in the event loop */
@@ -51,6 +64,40 @@ struct DBHandle {
redisContext *sync_context;
};
/**
* Get the Redis asynchronous context responsible for non-subscription
* communication for the given UniqueID.
*
* @param db The database handle.
* @param id The ID whose location we are querying for.
* @return The redisAsyncContext responsible for the given ID.
*/
redisAsyncContext *get_redis_context(DBHandle *db, UniqueID id);
/**
* Get the Redis asynchronous context responsible for subscription
* communication for the given UniqueID.
*
* @param db The database handle.
* @param id The ID whose location we are querying for.
* @return The redisAsyncContext responsible for the given ID.
*/
redisAsyncContext *get_redis_subscribe_context(DBHandle *db, UniqueID id);
/**
* Get a list of Redis shard IP addresses from the primary shard.
*
* @param context A Redis context connected to the primary shard.
* @param db_shards_addresses The IP addresses for the shards registered
* with the primary shard will be added to this vector.
* @param db_shards_ports The IP ports for the shards registered with the
* primary shard will be added to this vector, in the same order as
* db_shards_addresses.
*/
void get_redis_shards(redisContext *context,
std::vector<std::string> &db_shards_addresses,
std::vector<int> &db_shards_ports);
void redis_object_table_get_entry(redisAsyncContext *c,
void *r,
void *privdata);
+4 -3
View File
@@ -151,9 +151,10 @@ typedef void (*task_table_subscribe_callback)(Task *task, void *user_context);
* @param local_scheduler_id The db_client_id of the local scheduler whose
* events we want to listen to. If you want to subscribe to updates from
* all local schedulers, pass in NIL_ID.
* @param state_filter Flags for events we want to listen to. If you want
* to listen to all events, use state_filter = TASK_WAITING |
* TASK_SCHEDULED | TASK_RUNNING | TASK_DONE.
* @param state_filter Events we want to listen to. Can have values from the
* enum "scheduling_state" in task.h.
* TODO(pcm): Make it possible to combine these using flags like
* TASK_STATUS_WAITING | TASK_STATUS_SCHEDULED.
* @param retry Information about retrying the request to the database.
* @param done_callback Function to be called when database returns result.
* @param user_context Data that will be passed to done_callback and
+10 -9
View File
@@ -72,12 +72,12 @@ TEST object_table_lookup_test(void) {
event_loop *loop = event_loop_create();
/* This uses manager_port1. */
const char *db_connect_args1[] = {"address", "127.0.0.1:12345"};
DBHandle *db1 = db_connect("127.0.0.1", 6379, "plasma_manager", manager_addr,
2, db_connect_args1);
DBHandle *db1 = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager",
manager_addr, 2, db_connect_args1);
/* This uses manager_port2. */
const char *db_connect_args2[] = {"address", "127.0.0.1:12346"};
DBHandle *db2 = db_connect("127.0.0.1", 6379, "plasma_manager", manager_addr,
2, db_connect_args2);
DBHandle *db2 = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager",
manager_addr, 2, db_connect_args2);
db_attach(db1, loop, false);
db_attach(db2, loop, false);
UniqueID id = globally_unique_id();
@@ -148,8 +148,8 @@ void task_table_test_callback(Task *callback_task, void *user_data) {
TEST task_table_test(void) {
task_table_test_callback_called = 0;
event_loop *loop = event_loop_create();
DBHandle *db =
db_connect("127.0.0.1", 6379, "local_scheduler", "127.0.0.1", 0, NULL);
DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "local_scheduler",
"127.0.0.1", 0, NULL);
db_attach(db, loop, false);
DBClientID local_scheduler_id = globally_unique_id();
int64_t task_spec_size;
@@ -184,8 +184,8 @@ void task_table_all_test_callback(Task *task, void *user_data) {
TEST task_table_all_test(void) {
event_loop *loop = event_loop_create();
DBHandle *db =
db_connect("127.0.0.1", 6379, "local_scheduler", "127.0.0.1", 0, NULL);
DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "local_scheduler",
"127.0.0.1", 0, NULL);
db_attach(db, loop, false);
int64_t task_spec_size;
TaskSpec *spec = example_task_spec(1, 1, &task_spec_size);
@@ -222,7 +222,8 @@ TEST unique_client_id_test(void) {
DBClientID ids[num_conns];
DBHandle *db;
for (int i = 0; i < num_conns; ++i) {
db = db_connect("127.0.0.1", 6379, "plasma_manager", "127.0.0.1", 0, NULL);
db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager",
"127.0.0.1", 0, NULL);
ids[i] = get_db_client_id(db);
db_disconnect(db);
}
+51 -57
View File
@@ -65,6 +65,15 @@ void new_object_task_callback(TaskID task_id, void *user_context) {
new_object_lookup_callback, (void *) db);
}
void task_table_subscribe_done(TaskID task_id, void *user_context) {
RetryInfo retry = {
.num_retries = 5, .timeout = 100, .fail_callback = NULL,
};
DBHandle *db = (DBHandle *) user_context;
task_table_add_task(db, Task_copy(new_object_task), &retry,
new_object_task_callback, db);
}
TEST new_object_test(void) {
new_object_failed = 0;
new_object_succeeded = 0;
@@ -73,16 +82,16 @@ TEST new_object_test(void) {
new_object_task_spec = Task_task_spec(new_object_task);
new_object_task_id = TaskSpec_task_id(new_object_task_spec);
g_loop = event_loop_create();
DBHandle *db =
db_connect("127.0.0.1", 6379, "plasma_manager", "127.0.0.1", 0, NULL);
DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager",
"127.0.0.1", 0, NULL);
db_attach(db, g_loop, false);
RetryInfo retry = {
.num_retries = 5,
.timeout = 100,
.fail_callback = new_object_fail_callback,
};
task_table_add_task(db, Task_copy(new_object_task), &retry,
new_object_task_callback, db);
task_table_subscribe(db, NIL_ID, TASK_STATUS_WAITING, NULL, NULL, &retry,
task_table_subscribe_done, db);
event_loop_run(g_loop);
db_disconnect(db);
destroy_outstanding_callbacks(g_loop);
@@ -109,8 +118,8 @@ TEST new_object_no_task_test(void) {
new_object_id = globally_unique_id();
new_object_task_id = globally_unique_id();
g_loop = event_loop_create();
DBHandle *db =
db_connect("127.0.0.1", 6379, "plasma_manager", "127.0.0.1", 0, NULL);
DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager",
"127.0.0.1", 0, NULL);
db_attach(db, g_loop, false);
RetryInfo retry = {
.num_retries = 5,
@@ -151,8 +160,8 @@ void lookup_fail_callback(UniqueID id, void *user_context, void *user_data) {
TEST lookup_timeout_test(void) {
g_loop = event_loop_create();
DBHandle *db =
db_connect("127.0.0.1", 6379, "plasma_manager", "127.0.0.1", 0, NULL);
DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager",
"127.0.0.1", 0, NULL);
db_attach(db, g_loop, false);
RetryInfo retry = {
.num_retries = 5, .timeout = 100, .fail_callback = lookup_fail_callback,
@@ -161,6 +170,9 @@ TEST lookup_timeout_test(void) {
(void *) lookup_timeout_context);
/* Disconnect the database to see if the lookup times out. */
close(db->context->c.fd);
for (auto context : db->contexts) {
close(context->c.fd);
}
event_loop_run(g_loop);
db_disconnect(db);
destroy_outstanding_callbacks(g_loop);
@@ -187,8 +199,8 @@ void add_fail_callback(UniqueID id, void *user_context, void *user_data) {
TEST add_timeout_test(void) {
g_loop = event_loop_create();
DBHandle *db =
db_connect("127.0.0.1", 6379, "plasma_manager", "127.0.0.1", 0, NULL);
DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager",
"127.0.0.1", 0, NULL);
db_attach(db, g_loop, false);
RetryInfo retry = {
.num_retries = 5, .timeout = 100, .fail_callback = add_fail_callback,
@@ -197,6 +209,9 @@ TEST add_timeout_test(void) {
add_done_callback, (void *) add_timeout_context);
/* Disconnect the database to see if the lookup times out. */
close(db->context->c.fd);
for (auto context : db->contexts) {
close(context->c.fd);
}
event_loop_run(g_loop);
db_disconnect(db);
destroy_outstanding_callbacks(g_loop);
@@ -225,8 +240,8 @@ void subscribe_fail_callback(UniqueID id, void *user_context, void *user_data) {
TEST subscribe_timeout_test(void) {
g_loop = event_loop_create();
DBHandle *db =
db_connect("127.0.0.1", 6379, "plasma_manager", "127.0.0.1", 0, NULL);
DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager",
"127.0.0.1", 0, NULL);
db_attach(db, g_loop, false);
RetryInfo retry = {
.num_retries = 5,
@@ -236,7 +251,10 @@ TEST subscribe_timeout_test(void) {
object_table_subscribe_to_notifications(db, false, subscribe_done_callback,
NULL, &retry, NULL, NULL);
/* Disconnect the database to see if the lookup times out. */
close(db->sub_context->c.fd);
close(db->subscribe_context->c.fd);
for (auto subscribe_context : db->subscribe_contexts) {
close(subscribe_context->c.fd);
}
event_loop_run(g_loop);
db_disconnect(db);
destroy_outstanding_callbacks(g_loop);
@@ -324,8 +342,8 @@ TEST add_lookup_test(void) {
lookup_retry_succeeded = 0;
/* Construct the arguments to db_connect. */
const char *db_connect_args[] = {"address", "127.0.0.1:11235"};
DBHandle *db = db_connect("127.0.0.1", 6379, "plasma_manager", "127.0.0.1", 2,
db_connect_args);
DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager",
"127.0.0.1", 2, db_connect_args);
db_attach(db, g_loop, true);
RetryInfo retry = {
.num_retries = 5,
@@ -385,8 +403,8 @@ void add_remove_callback(ObjectID object_id, bool success, void *user_context) {
TEST add_remove_lookup_test(void) {
g_loop = event_loop_create();
lookup_retry_succeeded = 0;
DBHandle *db =
db_connect("127.0.0.1", 6379, "plasma_manager", "127.0.0.1", 0, NULL);
DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager",
"127.0.0.1", 0, NULL);
db_attach(db, g_loop, true);
RetryInfo retry = {
.num_retries = 5,
@@ -407,29 +425,6 @@ TEST add_remove_lookup_test(void) {
PASS();
}
/* === Test subscribe retry === */
const char *subscribe_retry_context = "subscribe_retry";
int subscribe_retry_succeeded = 0;
int64_t reconnect_sub_context_callback(event_loop *loop,
int64_t timer_id,
void *context) {
DBHandle *db = (DBHandle *) context;
/* Reconnect to redis. This is not reconnecting the pub/sub channel. */
redisAsyncFree(db->sub_context);
redisAsyncFree(db->context);
redisFree(db->sync_context);
db->sub_context = redisAsyncConnect("127.0.0.1", 6379);
db->sub_context->data = (void *) db;
db->context = redisAsyncConnect("127.0.0.1", 6379);
db->context->data = (void *) db;
db->sync_context = redisConnect("127.0.0.1", 6379);
/* Re-attach the database to the event loop (the file descriptor changed). */
db_attach(db, loop, true);
return EVENT_LOOP_TIMER_DONE;
}
/* ==== Test if late succeed is working correctly ==== */
/* === Test lookup late succeed === */
@@ -454,8 +449,8 @@ void lookup_late_done_callback(ObjectID object_id,
TEST lookup_late_test(void) {
g_loop = event_loop_create();
DBHandle *db =
db_connect("127.0.0.1", 6379, "plasma_manager", "127.0.0.1", 0, NULL);
DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager",
"127.0.0.1", 0, NULL);
db_attach(db, g_loop, false);
RetryInfo retry = {
.num_retries = 0,
@@ -498,8 +493,8 @@ void add_late_done_callback(ObjectID object_id,
TEST add_late_test(void) {
g_loop = event_loop_create();
DBHandle *db =
db_connect("127.0.0.1", 6379, "plasma_manager", "127.0.0.1", 0, NULL);
DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager",
"127.0.0.1", 0, NULL);
db_attach(db, g_loop, false);
RetryInfo retry = {
.num_retries = 0, .timeout = 0, .fail_callback = add_late_fail_callback,
@@ -543,8 +538,8 @@ void subscribe_late_done_callback(ObjectID object_id,
TEST subscribe_late_test(void) {
g_loop = event_loop_create();
DBHandle *db =
db_connect("127.0.0.1", 6379, "plasma_manager", "127.0.0.1", 0, NULL);
DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager",
"127.0.0.1", 0, NULL);
db_attach(db, g_loop, false);
RetryInfo retry = {
.num_retries = 0,
@@ -611,8 +606,8 @@ TEST subscribe_success_test(void) {
/* Construct the arguments to db_connect. */
const char *db_connect_args[] = {"address", "127.0.0.1:11236"};
DBHandle *db = db_connect("127.0.0.1", 6379, "plasma_manager", "127.0.0.1", 2,
db_connect_args);
DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager",
"127.0.0.1", 2, db_connect_args);
db_attach(db, g_loop, false);
subscribe_id = globally_unique_id();
@@ -680,8 +675,8 @@ TEST subscribe_object_present_test(void) {
g_loop = event_loop_create();
/* Construct the arguments to db_connect. */
const char *db_connect_args[] = {"address", "127.0.0.1:11236"};
DBHandle *db = db_connect("127.0.0.1", 6379, "plasma_manager", "127.0.0.1", 2,
db_connect_args);
DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager",
"127.0.0.1", 2, db_connect_args);
db_attach(db, g_loop, false);
UniqueID id = globally_unique_id();
RetryInfo retry = {
@@ -732,8 +727,8 @@ void subscribe_object_not_present_object_available_callback(
TEST subscribe_object_not_present_test(void) {
g_loop = event_loop_create();
DBHandle *db =
db_connect("127.0.0.1", 6379, "plasma_manager", "127.0.0.1", 0, NULL);
DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager",
"127.0.0.1", 0, NULL);
db_attach(db, g_loop, false);
UniqueID id = globally_unique_id();
RetryInfo retry = {
@@ -796,8 +791,8 @@ TEST subscribe_object_available_later_test(void) {
g_loop = event_loop_create();
/* Construct the arguments to db_connect. */
const char *db_connect_args[] = {"address", "127.0.0.1:11236"};
DBHandle *db = db_connect("127.0.0.1", 6379, "plasma_manager", "127.0.0.1", 2,
db_connect_args);
DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager",
"127.0.0.1", 2, db_connect_args);
db_attach(db, g_loop, false);
UniqueID id = globally_unique_id();
RetryInfo retry = {
@@ -849,8 +844,8 @@ TEST subscribe_object_available_subscribe_all(void) {
g_loop = event_loop_create();
/* Construct the arguments to db_connect. */
const char *db_connect_args[] = {"address", "127.0.0.1:11236"};
DBHandle *db = db_connect("127.0.0.1", 6379, "plasma_manager", "127.0.0.1", 2,
db_connect_args);
DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager",
"127.0.0.1", 2, db_connect_args);
db_attach(db, g_loop, false);
UniqueID id = globally_unique_id();
RetryInfo retry = {
@@ -904,7 +899,6 @@ SUITE(object_table_tests) {
RUN_REDIS_TEST(add_late_test);
RUN_REDIS_TEST(subscribe_late_test);
RUN_REDIS_TEST(subscribe_success_test);
RUN_REDIS_TEST(subscribe_object_present_test);
RUN_REDIS_TEST(subscribe_object_not_present_test);
RUN_REDIS_TEST(subscribe_object_available_later_test);
RUN_REDIS_TEST(subscribe_object_available_subscribe_all);
+4 -4
View File
@@ -102,8 +102,8 @@ TEST async_redis_socket_test(void) {
utarray_push_back(connections, &socket_fd);
/* Start connection to Redis. */
DBHandle *db =
db_connect("127.0.0.1", 6379, "test_process", "127.0.0.1", 0, NULL);
DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "test_process",
"127.0.0.1", 0, NULL);
db_attach(db, loop, false);
/* Send a command to the Redis process. */
@@ -177,8 +177,8 @@ TEST logging_test(void) {
utarray_push_back(connections, &socket_fd);
/* Start connection to Redis. */
DBHandle *conn =
db_connect("127.0.0.1", 6379, "test_process", "127.0.0.1", 0, NULL);
DBHandle *conn = db_connect(std::string("127.0.0.1"), 6379, "test_process",
"127.0.0.1", 0, NULL);
db_attach(conn, loop, false);
/* Send a command to the Redis process. */
+9 -2
View File
@@ -5,8 +5,14 @@
# Cause the script to exit if a single command fails.
set -e
./src/common/thirdparty/redis/src/redis-server --loglevel warning --loadmodule ./src/common/redis_module/libray_redis_module.so &
# Start the Redis shards.
./src/common/thirdparty/redis/src/redis-server --loglevel warning --loadmodule ./src/common/redis_module/libray_redis_module.so --port 6379 &
./src/common/thirdparty/redis/src/redis-server --loglevel warning --loadmodule ./src/common/redis_module/libray_redis_module.so --port 6380 &
sleep 1s
# Register the shard location with the primary shard.
./src/common/thirdparty/redis/src/redis-cli set NumRedisShards 1
./src/common/thirdparty/redis/src/redis-cli rpush RedisShards 127.0.0.1:6380
./src/common/common_tests
./src/common/db_tests
./src/common/io_tests
@@ -14,4 +20,5 @@ sleep 1s
./src/common/redis_tests
./src/common/task_table_tests
./src/common/object_table_tests
./src/common/thirdparty/redis/src/redis-cli shutdown
./src/common/thirdparty/redis/src/redis-cli -p 6379 shutdown
./src/common/thirdparty/redis/src/redis-cli -p 6380 shutdown
+9 -1
View File
@@ -5,8 +5,14 @@
# Cause the script to exit if a single command fails.
set -e
./src/common/thirdparty/redis/src/redis-server --loglevel warning --loadmodule ./src/common/redis_module/libray_redis_module.so &
# Start the Redis shards.
./src/common/thirdparty/redis/src/redis-server --loglevel warning --loadmodule ./src/common/redis_module/libray_redis_module.so --port 6379 &
./src/common/thirdparty/redis/src/redis-server --loglevel warning --loadmodule ./src/common/redis_module/libray_redis_module.so --port 6380 &
sleep 1s
# Register the shard location with the primary shard.
./src/common/thirdparty/redis/src/redis-cli set NumRedisShards 1
./src/common/thirdparty/redis/src/redis-cli rpush RedisShards 127.0.0.1:6380
valgrind --leak-check=full --error-exitcode=1 ./src/common/common_tests
valgrind --leak-check=full --error-exitcode=1 ./src/common/db_tests
valgrind --leak-check=full --error-exitcode=1 ./src/common/io_tests
@@ -14,4 +20,6 @@ valgrind --leak-check=full --error-exitcode=1 ./src/common/task_tests
valgrind --leak-check=full --error-exitcode=1 ./src/common/redis_tests
valgrind --leak-check=full --error-exitcode=1 ./src/common/task_table_tests
valgrind --leak-check=full --error-exitcode=1 ./src/common/object_table_tests
./src/common/thirdparty/redis/src/redis-cli shutdown
./src/common/thirdparty/redis/src/redis-cli -p 6380 shutdown
+47 -22
View File
@@ -40,8 +40,8 @@ void lookup_nil_success_callback(Task *task, void *context) {
TEST lookup_nil_test(void) {
lookup_nil_id = globally_unique_id();
g_loop = event_loop_create();
DBHandle *db =
db_connect("127.0.0.1", 6379, "plasma_manager", "127.0.0.1", 0, NULL);
DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager",
"127.0.0.1", 0, NULL);
db_attach(db, g_loop, false);
RetryInfo retry = {
.num_retries = 5,
@@ -96,14 +96,16 @@ void add_success_callback(TaskID task_id, void *context) {
TEST add_lookup_test(void) {
add_lookup_task = example_task(1, 1, TASK_STATUS_WAITING);
g_loop = event_loop_create();
DBHandle *db =
db_connect("127.0.0.1", 6379, "plasma_manager", "127.0.0.1", 0, NULL);
DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager",
"127.0.0.1", 0, NULL);
db_attach(db, g_loop, false);
RetryInfo retry = {
.num_retries = 5,
.timeout = 1000,
.fail_callback = add_lookup_fail_callback,
};
task_table_subscribe(db, NIL_ID, TASK_STATUS_WAITING, NULL, NULL, &retry,
NULL, NULL);
task_table_add_task(db, Task_copy(add_lookup_task), &retry,
add_success_callback, (void *) db);
/* Disconnect the database to see if the lookup times out. */
@@ -136,8 +138,8 @@ void subscribe_fail_callback(UniqueID id, void *user_context, void *user_data) {
TEST subscribe_timeout_test(void) {
g_loop = event_loop_create();
DBHandle *db =
db_connect("127.0.0.1", 6379, "plasma_manager", "127.0.0.1", 0, NULL);
DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager",
"127.0.0.1", 0, NULL);
db_attach(db, g_loop, false);
RetryInfo retry = {
.num_retries = 5,
@@ -148,7 +150,10 @@ TEST subscribe_timeout_test(void) {
subscribe_done_callback,
(void *) subscribe_timeout_context);
/* Disconnect the database to see if the subscribe times out. */
close(db->sub_context->c.fd);
close(db->subscribe_context->c.fd);
for (int i = 0; i < db->subscribe_contexts.size(); ++i) {
close(db->subscribe_contexts[i]->c.fd);
}
aeProcessEvents(g_loop, AE_TIME_EVENTS);
event_loop_run(g_loop);
db_disconnect(db);
@@ -177,17 +182,22 @@ void publish_fail_callback(UniqueID id, void *user_context, void *user_data) {
TEST publish_timeout_test(void) {
g_loop = event_loop_create();
DBHandle *db =
db_connect("127.0.0.1", 6379, "plasma_manager", "127.0.0.1", 0, NULL);
DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager",
"127.0.0.1", 0, NULL);
db_attach(db, g_loop, false);
Task *task = example_task(1, 1, TASK_STATUS_WAITING);
RetryInfo retry = {
.num_retries = 5, .timeout = 100, .fail_callback = publish_fail_callback,
};
task_table_subscribe(db, NIL_ID, TASK_STATUS_WAITING, NULL, NULL, &retry,
NULL, NULL);
task_table_add_task(db, task, &retry, publish_done_callback,
(void *) publish_timeout_context);
/* Disconnect the database to see if the publish times out. */
close(db->context->c.fd);
for (int i = 0; i < db->contexts.size(); ++i) {
close(db->contexts[i]->c.fd);
}
aeProcessEvents(g_loop, AE_TIME_EVENTS);
event_loop_run(g_loop);
db_disconnect(db);
@@ -204,9 +214,14 @@ int64_t reconnect_db_callback(event_loop *loop,
void *context) {
DBHandle *db = (DBHandle *) context;
/* Reconnect to redis. */
redisAsyncFree(db->sub_context);
db->sub_context = redisAsyncConnect("127.0.0.1", 6379);
db->sub_context->data = (void *) db;
redisAsyncFree(db->subscribe_context);
db->subscribe_context = redisAsyncConnect("127.0.0.1", 6379);
db->subscribe_context->data = (void *) db;
for (int i = 0; i < db->subscribe_contexts.size(); ++i) {
redisAsyncFree(db->subscribe_contexts[i]);
db->subscribe_contexts[i] = redisAsyncConnect("127.0.0.1", 6380 + i);
db->subscribe_contexts[i]->data = (void *) db;
}
/* Re-attach the database to the event loop (the file descriptor changed). */
db_attach(db, loop, true);
return EVENT_LOOP_TIMER_DONE;
@@ -239,8 +254,8 @@ void subscribe_retry_fail_callback(UniqueID id,
TEST subscribe_retry_test(void) {
g_loop = event_loop_create();
DBHandle *db =
db_connect("127.0.0.1", 6379, "plasma_manager", "127.0.0.1", 0, NULL);
DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager",
"127.0.0.1", 0, NULL);
db_attach(db, g_loop, false);
RetryInfo retry = {
.num_retries = 5,
@@ -251,7 +266,10 @@ TEST subscribe_retry_test(void) {
subscribe_retry_done_callback,
(void *) subscribe_retry_context);
/* Disconnect the database to see if the subscribe times out. */
close(db->sub_context->c.fd);
close(db->subscribe_context->c.fd);
for (int i = 0; i < db->subscribe_contexts.size(); ++i) {
close(db->subscribe_contexts[i]->c.fd);
}
/* Install handler for reconnecting the database. */
event_loop_add_timer(g_loop, 150,
(event_loop_timer_handler) reconnect_db_callback, db);
@@ -286,8 +304,8 @@ void publish_retry_fail_callback(UniqueID id,
TEST publish_retry_test(void) {
g_loop = event_loop_create();
DBHandle *db =
db_connect("127.0.0.1", 6379, "plasma_manager", "127.0.0.1", 0, NULL);
DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager",
"127.0.0.1", 0, NULL);
db_attach(db, g_loop, false);
Task *task = example_task(1, 1, TASK_STATUS_WAITING);
RetryInfo retry = {
@@ -295,10 +313,15 @@ TEST publish_retry_test(void) {
.timeout = 100,
.fail_callback = publish_retry_fail_callback,
};
task_table_subscribe(db, NIL_ID, TASK_STATUS_WAITING, NULL, NULL, &retry,
NULL, NULL);
task_table_add_task(db, task, &retry, publish_retry_done_callback,
(void *) publish_retry_context);
/* Disconnect the database to see if the publish times out. */
close(db->sub_context->c.fd);
close(db->subscribe_context->c.fd);
for (int i = 0; i < db->subscribe_contexts.size(); ++i) {
close(db->subscribe_contexts[i]->c.fd);
}
/* Install handler for reconnecting the database. */
event_loop_add_timer(g_loop, 150,
(event_loop_timer_handler) reconnect_db_callback, db);
@@ -335,8 +358,8 @@ void subscribe_late_done_callback(TaskID task_id, void *user_context) {
TEST subscribe_late_test(void) {
g_loop = event_loop_create();
DBHandle *db =
db_connect("127.0.0.1", 6379, "plasma_manager", "127.0.0.1", 0, NULL);
DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager",
"127.0.0.1", 0, NULL);
db_attach(db, g_loop, false);
RetryInfo retry = {
.num_retries = 0,
@@ -380,8 +403,8 @@ void publish_late_done_callback(TaskID task_id, void *user_context) {
TEST publish_late_test(void) {
g_loop = event_loop_create();
DBHandle *db =
db_connect("127.0.0.1", 6379, "plasma_manager", "127.0.0.1", 0, NULL);
DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager",
"127.0.0.1", 0, NULL);
db_attach(db, g_loop, false);
Task *task = example_task(1, 1, TASK_STATUS_WAITING);
RetryInfo retry = {
@@ -389,6 +412,8 @@ TEST publish_late_test(void) {
.timeout = 0,
.fail_callback = publish_late_fail_callback,
};
task_table_subscribe(db, NIL_ID, TASK_STATUS_WAITING, NULL, NULL, NULL, NULL,
NULL);
task_table_add_task(db, task, &retry, publish_late_done_callback,
(void *) publish_late_context);
/* Install handler for terminating the event loop. */
+22 -1
View File
@@ -2,11 +2,13 @@
#define TEST_COMMON_H
#include <unistd.h>
#include <vector>
#include "common.h"
#include "io.h"
#include "hiredis/hiredis.h"
#include "utstring.h"
#include "state/redis.h"
#ifndef _WIN32
/* This function is actually not declared in standard POSIX, so declare it. */
@@ -48,10 +50,29 @@ static inline int bind_inet_sock_retry(int *fd) {
}
/* Flush redis. */
static inline void flushall_redis() {
static inline void flushall_redis(void) {
/* Flush the primary shard. */
redisContext *context = redisConnect("127.0.0.1", 6379);
std::vector<std::string> db_shards_addresses;
std::vector<int> db_shards_ports;
get_redis_shards(context, db_shards_addresses, db_shards_ports);
freeReplyObject(redisCommand(context, "FLUSHALL"));
/* Readd the shard locations. */
freeReplyObject(redisCommand(context, "SET NumRedisShards %d",
db_shards_addresses.size()));
for (int i = 0; i < db_shards_addresses.size(); ++i) {
freeReplyObject(redisCommand(context, "RPUSH RedisShards %s:%d",
db_shards_addresses[i].c_str(),
db_shards_ports[i]));
}
redisFree(context);
/* Flush the remaining shards. */
for (int i = 0; i < db_shards_addresses.size(); ++i) {
context = redisConnect(db_shards_addresses[i].c_str(), db_shards_ports[i]);
freeReplyObject(redisCommand(context, "FLUSHALL"));
redisFree(context);
}
}
/* Cleanup method for running tests with the greatest library.
+40 -30
View File
@@ -62,14 +62,14 @@ void assign_task_to_local_scheduler(GlobalSchedulerState *state,
GlobalSchedulerState *GlobalSchedulerState_init(event_loop *loop,
const char *node_ip_address,
const char *redis_addr,
int redis_port) {
const char *redis_primary_addr,
int redis_primary_port) {
GlobalSchedulerState *state =
(GlobalSchedulerState *) malloc(sizeof(GlobalSchedulerState));
/* Must initialize state to 0. Sets hashmap head(s) to NULL. */
memset(state, 0, sizeof(GlobalSchedulerState));
state->db = db_connect(redis_addr, redis_port, "global_scheduler",
node_ip_address, 0, NULL);
state->db = db_connect(std::string(redis_primary_addr), redis_primary_port,
"global_scheduler", node_ip_address, 0, NULL);
db_attach(state->db, loop, false);
utarray_new(state->local_schedulers, &local_scheduler_icd);
state->policy_state = GlobalSchedulerPolicyState_init();
@@ -253,26 +253,34 @@ void remove_local_scheduler(GlobalSchedulerState *state, int index) {
* @param aux_address: an ip:port pair for the plasma manager associated with
* this db client.
*/
void process_new_db_client(DBClientID db_client_id,
const char *client_type,
const char *aux_address,
bool is_insertion,
void *user_context) {
void process_new_db_client(DBClient *db_client, void *user_context) {
GlobalSchedulerState *state = (GlobalSchedulerState *) user_context;
char id_string[ID_STRING_SIZE];
LOG_DEBUG("db client table callback for db client = %s",
ObjectID_to_string(db_client_id, id_string, ID_STRING_SIZE));
ObjectID_to_string(db_client->id, id_string, ID_STRING_SIZE));
UNUSED(id_string);
if (strncmp(client_type, "local_scheduler", strlen("local_scheduler")) == 0) {
if (is_insertion) {
/* This is a notification for an insert. */
add_local_scheduler(state, db_client_id, aux_address);
if (strncmp(db_client->client_type, "local_scheduler",
strlen("local_scheduler")) == 0) {
if (db_client->is_insertion) {
/* This is a notification for an insert. We may receive duplicate
* notifications since we read the entire table before processing
* notifications. Filter out local schedulers that we already added. */
for (LocalScheduler *scheduler =
(LocalScheduler *) utarray_front(state->local_schedulers);
scheduler != NULL; scheduler = (LocalScheduler *) utarray_next(
state->local_schedulers, scheduler)) {
if (UNIQUE_ID_EQ(scheduler->id, db_client->id)) {
return;
}
}
add_local_scheduler(state, db_client->id, db_client->aux_address);
} else {
int i = 0;
for (; i < utarray_len(state->local_schedulers); ++i) {
LocalScheduler *active_worker =
(LocalScheduler *) utarray_eltptr(state->local_schedulers, i);
if (DBClientID_equal(active_worker->id, db_client_id)) {
if (DBClientID_equal(active_worker->id, db_client->id)) {
break;
}
}
@@ -418,11 +426,11 @@ int heartbeat_timeout_handler(event_loop *loop, timer_id id, void *context) {
}
void start_server(const char *node_ip_address,
const char *redis_addr,
int redis_port) {
const char *redis_primary_addr,
int redis_primary_port) {
event_loop *loop = event_loop_create();
g_state =
GlobalSchedulerState_init(loop, node_ip_address, redis_addr, redis_port);
g_state = GlobalSchedulerState_init(loop, node_ip_address, redis_primary_addr,
redis_primary_port);
/* TODO(rkn): subscribe to notifications from the object table. */
/* Subscribe to notifications about new local schedulers. TODO(rkn): this
* needs to also get all of the clients that registered with the database
@@ -458,15 +466,15 @@ void start_server(const char *node_ip_address,
int main(int argc, char *argv[]) {
signal(SIGTERM, signal_handler);
/* IP address and port of redis. */
char *redis_addr_port = NULL;
/* IP address and port of the primary redis instance. */
char *redis_primary_addr_port = NULL;
/* The IP address of the node that this global scheduler is running on. */
char *node_ip_address = NULL;
int c;
while ((c = getopt(argc, argv, "h:r:")) != -1) {
switch (c) {
case 'r':
redis_addr_port = optarg;
redis_primary_addr_port = optarg;
break;
case 'h':
node_ip_address = optarg;
@@ -476,16 +484,18 @@ int main(int argc, char *argv[]) {
exit(-1);
}
}
char redis_addr[16];
int redis_port;
if (!redis_addr_port ||
parse_ip_addr_port(redis_addr_port, redis_addr, &redis_port) == -1) {
LOG_ERROR(
"specify the redis address like 127.0.0.1:6379 with the -r switch");
exit(-1);
char redis_primary_addr[16];
int redis_primary_port;
if (!redis_primary_addr_port ||
parse_ip_addr_port(redis_primary_addr_port, redis_primary_addr,
&redis_primary_port) == -1) {
LOG_FATAL(
"specify the primary redis address like 127.0.0.1:6379 with the -r "
"switch");
}
if (!node_ip_address) {
LOG_FATAL("specify the node IP address with the -h switch");
}
start_server(node_ip_address, redis_addr, redis_port);
start_server(node_ip_address, redis_primary_addr, redis_primary_port);
}
+22 -28
View File
@@ -16,6 +16,7 @@
#include "local_scheduler_shared.h"
#include "local_scheduler.h"
#include "local_scheduler_algorithm.h"
#include "net.h"
#include "state/actor_notification_table.h"
#include "state/db.h"
#include "state/driver_table.h"
@@ -296,8 +297,8 @@ const char **parse_command(const char *command) {
LocalSchedulerState *LocalSchedulerState_init(
const char *node_ip_address,
event_loop *loop,
const char *redis_addr,
int redis_port,
const char *redis_primary_addr,
int redis_primary_port,
const char *local_scheduler_socket_name,
const char *plasma_store_socket_name,
const char *plasma_manager_socket_name,
@@ -323,7 +324,7 @@ LocalSchedulerState *LocalSchedulerState_init(
state->loop = loop;
/* Connect to Redis if a Redis address is provided. */
if (redis_addr != NULL) {
if (redis_primary_addr != NULL) {
int num_args;
const char **db_connect_args = NULL;
/* Use UT_string to convert the resource value into a string. */
@@ -354,8 +355,9 @@ LocalSchedulerState *LocalSchedulerState_init(
db_connect_args[4] = "num_gpus";
db_connect_args[5] = utstring_body(num_gpus);
}
state->db = db_connect(redis_addr, redis_port, "local_scheduler",
node_ip_address, num_args, db_connect_args);
state->db = db_connect(std::string(redis_primary_addr), redis_primary_port,
"local_scheduler", node_ip_address, num_args,
db_connect_args);
utstring_free(num_cpus);
utstring_free(num_gpus);
free(db_connect_args);
@@ -548,8 +550,6 @@ void process_plasma_notification(event_loop *loop,
void reconstruct_task_update_callback(Task *task,
void *user_context,
bool updated) {
/* The task ID should be in the task table. */
CHECK(task != NULL);
if (!updated) {
/* The test-and-set of the task's scheduling state failed, so the task was
* either not finished yet, or it was already being reconstructed.
@@ -578,7 +578,6 @@ void reconstruct_task_update_callback(Task *task,
void reconstruct_put_task_update_callback(Task *task,
void *user_context,
bool updated) {
CHECK(task != NULL);
if (updated) {
/* The update to TASK_STATUS_RECONSTRUCTING succeeded, so continue with
* reconstruction as usual. */
@@ -1111,8 +1110,8 @@ int heartbeat_handler(event_loop *loop, timer_id id, void *context) {
void start_server(const char *node_ip_address,
const char *socket_name,
const char *redis_addr,
int redis_port,
const char *redis_primary_addr,
int redis_primary_port,
const char *plasma_store_socket_name,
const char *plasma_manager_socket_name,
const char *plasma_manager_address,
@@ -1126,8 +1125,8 @@ void start_server(const char *node_ip_address,
int fd = bind_ipc_sock(socket_name, true);
event_loop *loop = event_loop_create();
g_state = LocalSchedulerState_init(
node_ip_address, loop, redis_addr, redis_port, socket_name,
plasma_store_socket_name, plasma_manager_socket_name,
node_ip_address, loop, redis_primary_addr, redis_primary_port,
socket_name, plasma_store_socket_name, plasma_manager_socket_name,
plasma_manager_address, global_scheduler_exists, static_resource_conf,
start_worker_command, num_workers);
/* Register a callback for registering new clients. */
@@ -1173,8 +1172,8 @@ int main(int argc, char *argv[]) {
signal(SIGTERM, signal_handler);
/* Path of the listening socket of the local scheduler. */
char *scheduler_socket_name = NULL;
/* IP address and port of redis. */
char *redis_addr_port = NULL;
/* IP address and port of the primary redis instance. */
char *redis_primary_addr_port = NULL;
/* Socket name for the local Plasma store. */
char *plasma_store_socket_name = NULL;
/* Socket name for the local Plasma manager. */
@@ -1199,7 +1198,7 @@ int main(int argc, char *argv[]) {
scheduler_socket_name = optarg;
break;
case 'r':
redis_addr_port = optarg;
redis_primary_addr_port = optarg;
break;
case 'p':
plasma_store_socket_name = optarg;
@@ -1266,7 +1265,7 @@ int main(int argc, char *argv[]) {
char *redis_addr = NULL;
int redis_port = -1;
if (!redis_addr_port) {
if (!redis_primary_addr_port) {
/* Start the local scheduler without connecting to Redis. In this case, all
* submitted tasks will be queued and scheduled locally. */
if (plasma_manager_socket_name) {
@@ -1275,27 +1274,22 @@ int main(int argc, char *argv[]) {
"then a redis address must be provided with the -r switch");
}
} else {
char redis_addr_buffer[16] = {0};
char redis_port_str[6] = {0};
/* Parse the Redis address into an IP address and a port. */
int num_assigned = sscanf(redis_addr_port, "%15[0-9.]:%5[0-9]",
redis_addr_buffer, redis_port_str);
if (num_assigned != 2) {
char redis_primary_addr[16];
int redis_primary_port;
/* Parse the primary Redis address into an IP address and a port. */
if (parse_ip_addr_port(redis_primary_addr_port, redis_primary_addr,
&redis_primary_port) == -1) {
LOG_FATAL(
"if a redis address is provided with the -r switch, it should be "
"formatted like 127.0.0.1:6379");
}
redis_addr = redis_addr_buffer;
redis_port = strtol(redis_port_str, NULL, 10);
if (redis_port == 0) {
LOG_FATAL("Unable to parse port number from redis address %s",
redis_addr_port);
}
if (!plasma_manager_socket_name) {
LOG_FATAL(
"please specify socket for connecting to Plasma manager with -m "
"switch");
}
redis_addr = redis_primary_addr;
redis_port = redis_primary_port;
}
start_server(node_ip_address, scheduler_socket_name, redis_addr, redis_port,
@@ -839,7 +839,8 @@ void handle_actor_task_submitted(LocalSchedulerState *state,
/* Add this task to a queue of tasks that have been submitted but the local
* scheduler doesn't know which actor is responsible for them. These tasks
* will be resubmitted (internally by the local scheduler) whenever a new
* actor notification arrives. */
* actor notification arrives. NOTE(swang): These tasks have not yet been
* added to the task table. */
utarray_push_back(algorithm_state->cached_submitted_actor_tasks, &spec);
utarray_push_back(algorithm_state->cached_submitted_actor_task_sizes,
&task_spec_size);
@@ -18,6 +18,7 @@
#include "task.h"
#include "state/object_table.h"
#include "state/task_table.h"
#include "state/redis.h"
#include "local_scheduler_shared.h"
#include "local_scheduler.h"
@@ -182,7 +183,16 @@ TEST object_reconstruction_test(void) {
/* Add an empty object table entry for the object we want to reconstruct, to
* simulate it having been created and evicted. */
const char *client_id = "clientid";
/* Lookup the shard locations for the object table. */
std::vector<std::string> db_shards_addresses;
std::vector<int> db_shards_ports;
redisContext *context = redisConnect("127.0.0.1", 6379);
get_redis_shards(context, db_shards_addresses, db_shards_ports);
redisFree(context);
/* There should only be one shard, so we can safely add the empty object
* table entry to the first one. */
ASSERT(db_shards_addresses.size() == 1);
context = redisConnect(db_shards_addresses[0].c_str(), db_shards_ports[0]);
redisReply *reply = (redisReply *) redisCommand(
context, "RAY.OBJECT_TABLE_ADD %b %ld %b %s", return_id.id,
sizeof(return_id.id), 1, NIL_DIGEST, (size_t) DIGEST_SIZE, client_id);
@@ -273,7 +283,16 @@ TEST object_reconstruction_recursive_test(void) {
/* Add an empty object table entry for each object we want to reconstruct, to
* simulate their having been created and evicted. */
const char *client_id = "clientid";
/* Lookup the shard locations for the object table. */
std::vector<std::string> db_shards_addresses;
std::vector<int> db_shards_ports;
redisContext *context = redisConnect("127.0.0.1", 6379);
get_redis_shards(context, db_shards_addresses, db_shards_ports);
redisFree(context);
/* There should only be one shard, so we can safely add the empty object
* table entry to the first one. */
ASSERT(db_shards_addresses.size() == 1);
context = redisConnect(db_shards_addresses[0].c_str(), db_shards_ports[0]);
for (int i = 0; i < NUM_TASKS; ++i) {
ObjectID return_id = TaskSpec_return(specs[i], 0);
redisReply *reply = (redisReply *) redisCommand(
@@ -406,8 +425,8 @@ TEST object_reconstruction_suppression_test(void) {
} else {
/* Connect a plasma manager client so we can call object_table_add. */
const char *db_connect_args[] = {"address", "127.0.0.1:12346"};
DBHandle *db = db_connect("127.0.0.1", 6379, "plasma_manager", "127.0.0.1",
2, db_connect_args);
DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager",
"127.0.0.1", 2, db_connect_args);
db_attach(db, local_scheduler->loop, false);
/* Add the object to the object table. */
object_table_add(db, return_id, 1, (unsigned char *) NIL_DIGEST, NULL,
+8 -1
View File
@@ -5,10 +5,17 @@
# Cause the script to exit if a single command fails.
set -e
./src/common/thirdparty/redis/src/redis-server --loglevel warning --loadmodule ./src/common/redis_module/libray_redis_module.so &
# Start the Redis shards.
./src/common/thirdparty/redis/src/redis-server --loglevel warning --loadmodule ./src/common/redis_module/libray_redis_module.so --port 6379 &
./src/common/thirdparty/redis/src/redis-server --loglevel warning --loadmodule ./src/common/redis_module/libray_redis_module.so --port 6380 &
sleep 1s
# Register the shard location with the primary shard.
./src/common/thirdparty/redis/src/redis-cli set NumRedisShards 1
./src/common/thirdparty/redis/src/redis-cli rpush RedisShards 127.0.0.1:6380
./src/plasma/plasma_store -s /tmp/plasma_store_socket_1 -m 100000000 &
sleep 0.5s
./src/local_scheduler/local_scheduler_tests
./src/common/thirdparty/redis/src/redis-cli shutdown
./src/common/thirdparty/redis/src/redis-cli -p 6380 shutdown
killall plasma_store
+8 -1
View File
@@ -5,10 +5,17 @@
# Cause the script to exit if a single command fails.
set -e
./src/common/thirdparty/redis/src/redis-server --loglevel warning --loadmodule ./src/common/redis_module/libray_redis_module.so &
# Start the Redis shards.
./src/common/thirdparty/redis/src/redis-server --loglevel warning --loadmodule ./src/common/redis_module/libray_redis_module.so --port 6379 &
./src/common/thirdparty/redis/src/redis-server --loglevel warning --loadmodule ./src/common/redis_module/libray_redis_module.so --port 6380 &
sleep 1s
# Register the shard location with the primary shard.
./src/common/thirdparty/redis/src/redis-cli set NumRedisShards 1
./src/common/thirdparty/redis/src/redis-cli rpush RedisShards 127.0.0.1:6380
./src/plasma/plasma_store -s /tmp/plasma_store_socket_1 -m 100000000 &
sleep 0.5s
valgrind --leak-check=full --show-leak-kinds=all --error-exitcode=1 ./src/local_scheduler/local_scheduler_tests
./src/common/thirdparty/redis/src/redis-cli shutdown
./src/common/thirdparty/redis/src/redis-cli -p 6380 shutdown
killall plasma_store
+24 -22
View File
@@ -493,8 +493,8 @@ PlasmaManagerState *PlasmaManagerState_init(const char *store_socket_name,
const char *manager_socket_name,
const char *manager_addr,
int manager_port,
const char *db_addr,
int db_port) {
const char *redis_primary_addr,
int redis_primary_port) {
PlasmaManagerState *state =
(PlasmaManagerState *) malloc(sizeof(PlasmaManagerState));
state->loop = event_loop_create();
@@ -504,7 +504,7 @@ PlasmaManagerState *PlasmaManagerState_init(const char *store_socket_name,
state->fetch_requests = NULL;
state->object_wait_requests_local = NULL;
state->object_wait_requests_remote = NULL;
if (db_addr) {
if (redis_primary_addr) {
/* Get the manager port as a string. */
UT_string *manager_address_str;
utstring_new(manager_address_str);
@@ -519,8 +519,9 @@ PlasmaManagerState *PlasmaManagerState_init(const char *store_socket_name,
db_connect_args[3] = manager_socket_name;
db_connect_args[4] = "address";
db_connect_args[5] = utstring_body(manager_address_str);
state->db = db_connect(db_addr, db_port, "plasma_manager", manager_addr,
num_args, db_connect_args);
state->db =
db_connect(std::string(redis_primary_addr), redis_primary_port,
"plasma_manager", manager_addr, num_args, db_connect_args);
utstring_free(manager_address_str);
free(db_connect_args);
db_attach(state->db, state->loop, false);
@@ -1594,8 +1595,8 @@ void start_server(const char *store_socket_name,
const char *manager_socket_name,
const char *master_addr,
int port,
const char *db_addr,
int db_port) {
const char *redis_primary_addr,
int redis_primary_port) {
/* Ignore SIGPIPE signals. If we don't do this, then when we attempt to write
* to a client that has already died, the manager could die. */
signal(SIGPIPE, SIG_IGN);
@@ -1610,9 +1611,9 @@ void start_server(const char *store_socket_name,
int local_sock = bind_ipc_sock(manager_socket_name, false);
CHECKM(local_sock >= 0, "Unable to bind local manager socket");
g_manager_state =
PlasmaManagerState_init(store_socket_name, manager_socket_name,
master_addr, port, db_addr, db_port);
g_manager_state = PlasmaManagerState_init(
store_socket_name, manager_socket_name, master_addr, port,
redis_primary_addr, redis_primary_port);
CHECK(g_manager_state);
CHECK(listen(remote_sock, 5) != -1);
@@ -1664,8 +1665,8 @@ int main(int argc, char *argv[]) {
char *master_addr = NULL;
/* Port number the manager should use. */
int port = -1;
/* IP address and port of state database. */
char *db_host = NULL;
/* IP address and port of the primary redis instance. */
char *redis_primary_addr_port = NULL;
int c;
while ((c = getopt(argc, argv, "s:m:h:p:r:")) != -1) {
switch (c) {
@@ -1682,7 +1683,7 @@ int main(int argc, char *argv[]) {
port = atoi(optarg);
break;
case 'r':
db_host = optarg;
redis_primary_addr_port = optarg;
break;
default:
LOG_FATAL("unknown option %c", c);
@@ -1708,15 +1709,16 @@ int main(int argc, char *argv[]) {
"please specify port the plasma manager shall listen to in the"
"format 12345 with -p switch");
}
char db_addr[16];
int db_port;
if (db_host) {
parse_ip_addr_port(db_host, db_addr, &db_port);
start_server(store_socket_name, manager_socket_name, master_addr, port,
db_addr, db_port);
} else {
start_server(store_socket_name, manager_socket_name, master_addr, port,
NULL, 0);
char redis_primary_addr[16];
int redis_primary_port;
if (!redis_primary_addr_port ||
parse_ip_addr_port(redis_primary_addr_port, redis_primary_addr,
&redis_primary_port) == -1) {
LOG_FATAL(
"specify the primary redis address like 127.0.0.1:6379 with the -r "
"switch");
}
start_server(store_socket_name, manager_socket_name, master_addr, port,
redis_primary_addr, redis_primary_port);
}
#endif
+15 -6
View File
@@ -9,11 +9,18 @@ sleep 1
killall plasma_store
./src/plasma/serialization_tests
./src/common/thirdparty/redis/src/redis-server --loglevel warning --loadmodule ./src/common/redis_module/libray_redis_module.so &
redis_pid=$!
# Start the Redis shards.
./src/common/thirdparty/redis/src/redis-server --loglevel warning --loadmodule ./src/common/redis_module/libray_redis_module.so --port 6379 &
redis_pid1=$!
./src/common/thirdparty/redis/src/redis-server --loglevel warning --loadmodule ./src/common/redis_module/libray_redis_module.so --port 6380 &
redis_pid2=$!
sleep 1
# flush the redis server
./src/common/thirdparty/redis/src/redis-cli flushall &
# Flush the redis server
./src/common/thirdparty/redis/src/redis-cli flushall
# Register the shard location with the primary shard.
./src/common/thirdparty/redis/src/redis-cli set NumRedisShards 1
./src/common/thirdparty/redis/src/redis-cli rpush RedisShards 127.0.0.1:6380
sleep 1
./src/plasma/plasma_store -s /tmp/store1 -m 1000000000 &
plasma1_pid=$!
@@ -31,5 +38,7 @@ kill $plasma4_pid
kill $plasma3_pid
kill $plasma2_pid
kill $plasma1_pid
kill $redis_pid
wait $redis_pid
kill $redis_pid1
wait $redis_pid1
kill $redis_pid2
wait $redis_pid2
+13 -8
View File
@@ -86,8 +86,8 @@ class DockerRunner(object):
else:
return m.group(1)
def _start_head_node(self, docker_image, mem_size, shm_size, num_cpus,
num_gpus, development_mode):
def _start_head_node(self, docker_image, mem_size, shm_size,
num_redis_shards, num_cpus, num_gpus, development_mode):
"""Start the Ray head node inside a docker container."""
mem_arg = ["--memory=" + mem_size] if mem_size else []
shm_arg = ["--shm-size=" + shm_size] if shm_size else []
@@ -99,6 +99,7 @@ class DockerRunner(object):
command = (["docker", "run", "-d"] + mem_arg + shm_arg + volume_arg +
[docker_image, "/ray/scripts/start_ray.sh", "--head",
"--redis-port=6379",
"--num-redis-shards={}".format(num_redis_shards),
"--num-cpus={}".format(num_cpus),
"--num-gpus={}".format(num_gpus)])
print("Starting head node with command:{}".format(command))
@@ -137,8 +138,8 @@ class DockerRunner(object):
self.worker_container_ids.append(container_id)
def start_ray(self, docker_image=None, mem_size=None, shm_size=None,
num_nodes=None, num_cpus=None, num_gpus=None,
development_mode=None):
num_nodes=None, num_redis_shards=1, num_cpus=None,
num_gpus=None, development_mode=None):
"""Start a Ray cluster within docker.
This starts one docker container running the head node and num_nodes - 1
@@ -153,6 +154,7 @@ class DockerRunner(object):
with. This will be passed into `docker run` as the `--shm-size` flag.
num_nodes: The number of nodes to use in the cluster (this counts the
head node as well).
num_redis_shards: The number of Redis shards to use on the head node.
num_cpus: A list of the number of CPUs to start each node with.
num_gpus: A list of the number of GPUs to start each node with.
development_mode: True if you want to mount the local copy of
@@ -163,8 +165,8 @@ class DockerRunner(object):
assert len(num_gpus) == num_nodes
# Launch the head node.
self._start_head_node(docker_image, mem_size, shm_size, num_cpus[0],
num_gpus[0], development_mode)
self._start_head_node(docker_image, mem_size, shm_size, num_redis_shards,
num_cpus[0], num_gpus[0], development_mode)
# Start the worker nodes.
for i in range(num_nodes - 1):
self._start_worker_node(docker_image, mem_size, shm_size,
@@ -252,6 +254,9 @@ if __name__ == "__main__":
parser.add_argument("--shm-size", default="1G", help="shared memory size")
parser.add_argument("--num-nodes", default=1, type=int,
help="number of nodes to use in the cluster")
parser.add_argument("--num-redis-shards", default=1, type=int,
help=("the number of Redis shards to start on the head "
"node"))
parser.add_argument("--num-cpus", type=str,
help=("a comma separated list of values representing "
"the number of CPUs to start each node with"))
@@ -282,8 +287,8 @@ if __name__ == "__main__":
d = DockerRunner()
d.start_ray(docker_image=args.docker_image, mem_size=args.mem_size,
shm_size=args.shm_size, num_nodes=num_nodes,
num_cpus=num_cpus, num_gpus=num_gpus,
development_mode=args.development_mode)
num_redis_shards=args.num_redis_shards, num_cpus=num_cpus,
num_gpus=num_gpus, development_mode=args.development_mode)
try:
run_results = d.run_test(args.test_script, args.num_drivers,
driver_locations=driver_locations)
@@ -14,11 +14,13 @@ echo "Using Docker image" $DOCKER_SHA
python $ROOT_DIR/multi_node_docker_test.py \
--docker-image=$DOCKER_SHA \
--num-nodes=5 \
--num-redis-shards=10 \
--test-script=/ray/test/jenkins_tests/multi_node_tests/test_0.py
python $ROOT_DIR/multi_node_docker_test.py \
--docker-image=$DOCKER_SHA \
--num-nodes=5 \
--num-redis-shards=5 \
--num-gpus=0,1,2,3,4 \
--num-drivers=7 \
--driver-locations=0,1,0,1,2,3,4 \
@@ -27,6 +29,7 @@ python $ROOT_DIR/multi_node_docker_test.py \
python $ROOT_DIR/multi_node_docker_test.py \
--docker-image=$DOCKER_SHA \
--num-nodes=5 \
--num-redis-shards=2 \
--num-gpus=0,0,5,6,50 \
--num-drivers=100 \
--test-script=/ray/test/jenkins_tests/multi_node_tests/many_drivers_test.py
+34 -42
View File
@@ -293,8 +293,16 @@ class WorkerTest(unittest.TestCase):
class APITest(unittest.TestCase):
def init_ray(self, kwargs=None):
if kwargs is None:
kwargs = {}
ray.init(**kwargs)
def tearDown(self):
ray.worker.cleanup()
def testRegisterClass(self):
ray.init(num_workers=2)
self.init_ray({"num_workers": 2})
# Check that putting an object of a class that has not been registered
# throws an exception.
@@ -417,11 +425,9 @@ class APITest(unittest.TestCase):
self.assertFalse(hasattr(c2, "method0"))
self.assertFalse(hasattr(c2, "method1"))
ray.worker.cleanup()
def testKeywordArgs(self):
reload(test_functions)
ray.init(num_workers=1)
self.init_ray()
x = test_functions.keyword_fct1.remote(1)
self.assertEqual(ray.get(x), "1 hello")
@@ -483,11 +489,9 @@ class APITest(unittest.TestCase):
self.assertEqual(ray.get(f3.remote(4)), 4)
ray.worker.cleanup()
def testVariableNumberOfArgs(self):
reload(test_functions)
ray.init(num_workers=1)
self.init_ray()
x = test_functions.varargs_fct1.remote(0, 1, 2)
self.assertEqual(ray.get(x), "0 1 2")
@@ -516,18 +520,14 @@ class APITest(unittest.TestCase):
self.assertEqual(ray.get(f2.remote(1, 2, 3)), (1, 2, (3,)))
self.assertEqual(ray.get(f2.remote(1, 2, 3, 4)), (1, 2, (3, 4)))
ray.worker.cleanup()
def testNoArgs(self):
reload(test_functions)
ray.init(num_workers=1)
self.init_ray()
ray.get(test_functions.no_op.remote())
ray.worker.cleanup()
def testDefiningRemoteFunctions(self):
ray.init(num_workers=3, num_cpus=3)
self.init_ray({"num_cpus": 3})
# Test that we can define a remote function in the shell.
@ray.remote
@@ -584,10 +584,8 @@ class APITest(unittest.TestCase):
self.assertEqual(ray.get(l.remote(1)), 2)
self.assertEqual(ray.get(m.remote(1)), 2)
ray.worker.cleanup()
def testGetMultiple(self):
ray.init(num_workers=0)
self.init_ray()
object_ids = [ray.put(i) for i in range(10)]
self.assertEqual(ray.get(object_ids), list(range(10)))
@@ -597,10 +595,8 @@ class APITest(unittest.TestCase):
results = ray.get([object_ids[i] for i in indices])
self.assertEqual(results, indices)
ray.worker.cleanup()
def testWait(self):
ray.init(num_workers=1, num_cpus=1)
self.init_ray({"num_cpus": 1})
@ray.remote
def f(delay):
@@ -633,12 +629,10 @@ class APITest(unittest.TestCase):
x = ray.put(1)
self.assertRaises(Exception, lambda: ray.wait([x, x]))
ray.worker.cleanup()
def testMultipleWaitsAndGets(self):
# It is important to use three workers here, so that the three tasks
# launched in this experiment can run at the same time.
ray.init(num_workers=3)
self.init_ray()
@ray.remote
def f(delay):
@@ -665,8 +659,6 @@ class APITest(unittest.TestCase):
x = f.remote(1)
ray.get([h.remote([x]), h.remote([x])])
ray.worker.cleanup()
def testCachingEnvironmentVariables(self):
# Test that we can define environment variables before the driver is
# connected.
@@ -690,15 +682,13 @@ class APITest(unittest.TestCase):
ray.env.bar.append(1)
return ray.env.bar
ray.init(num_workers=2)
self.init_ray()
self.assertEqual(ray.get(use_foo.remote()), 1)
self.assertEqual(ray.get(use_foo.remote()), 1)
self.assertEqual(ray.get(use_bar.remote()), [1])
self.assertEqual(ray.get(use_bar.remote()), [1])
ray.worker.cleanup()
def testCachingFunctionsToRun(self):
# Test that we export functions to run on all workers before the driver is
# connected.
@@ -718,7 +708,7 @@ class APITest(unittest.TestCase):
sys.path.append(4)
ray.worker.global_worker.run_function_on_all_workers(f)
ray.init(num_workers=2)
self.init_ray()
@ray.remote
def get_state():
@@ -738,10 +728,8 @@ class APITest(unittest.TestCase):
sys.path.pop()
ray.worker.global_worker.run_function_on_all_workers(f)
ray.worker.cleanup()
def testRunningFunctionOnAllWorkers(self):
ray.init(num_workers=1)
self.init_ray()
def f(worker_info):
sys.path.append("fake_directory")
@@ -764,10 +752,8 @@ class APITest(unittest.TestCase):
return sys.path
self.assertTrue("fake_directory" not in ray.get(get_path2.remote()))
ray.worker.cleanup()
def testLoggingAPI(self):
ray.init(num_workers=1, driver_mode=ray.SILENT_MODE)
self.init_ray({"driver_mode": ray.SILENT_MODE})
def events():
# This is a hack for getting the event log. It is not part of the API.
@@ -815,12 +801,10 @@ class APITest(unittest.TestCase):
wait_for_num_events(3)
self.assertEqual(len(events()), 3)
ray.worker.cleanup()
def testIdenticalFunctionNames(self):
# Define a bunch of remote functions and make sure that we don't
# accidentally call an older version.
ray.init(num_workers=2)
self.init_ray()
num_calls = 200
@@ -878,10 +862,8 @@ class APITest(unittest.TestCase):
result_values = ray.get([g.remote() for _ in range(num_calls)])
self.assertEqual(result_values, num_calls * [5])
ray.worker.cleanup()
def testIllegalAPICalls(self):
ray.init(num_workers=0)
self.init_ray()
# Verify that we cannot call put on an ObjectID.
x = ray.put(1)
@@ -891,7 +873,16 @@ class APITest(unittest.TestCase):
with self.assertRaises(Exception):
ray.get(3)
ray.worker.cleanup()
class APITestSharded(APITest):
def init_ray(self, kwargs=None):
if kwargs is None:
kwargs = {}
kwargs["start_ray_local"] = True
kwargs["num_redis_shards"] = 20
kwargs["redirect_output"] = True
ray.worker._init(**kwargs)
class PythonModeTest(unittest.TestCase):
@@ -1619,7 +1610,8 @@ class GlobalStateAPI(unittest.TestCase):
task_table = ray.global_state.task_table()
self.assertEqual(len(task_table), 1)
self.assertEqual(driver_task_id, list(task_table.keys())[0])
self.assertEqual(task_table[driver_task_id]["State"], "RUNNING")
self.assertEqual(task_table[driver_task_id]["State"],
ray.experimental.state.TASK_STATUS_RUNNING)
self.assertEqual(task_table[driver_task_id]["TaskSpec"]["TaskID"],
driver_task_id)
self.assertEqual(task_table[driver_task_id]["TaskSpec"]["ActorID"],
+26 -21
View File
@@ -6,10 +6,6 @@ import unittest
import ray
import numpy as np
import time
import redis
# Import flatbuffer bindings.
from ray.core.generated.TaskReply import TaskReply
class TaskTests(unittest.TestCase):
@@ -137,26 +133,38 @@ class ReconstructionTests(unittest.TestCase):
num_local_schedulers = 1
def setUp(self):
# Start a Redis instance and Plasma store instances with a total of 1GB
# memory.
# Start the Redis global state store.
node_ip_address = "127.0.0.1"
self.redis_port = ray.services.new_port()
print(self.redis_port)
redis_address = ray.services.address(node_ip_address, self.redis_port)
redis_address, redis_shards = ray.services.start_redis(node_ip_address)
self.redis_ip_address = ray.services.get_ip_address(redis_address)
self.redis_port = ray.services.get_port(redis_address)
time.sleep(0.1)
# Start the Plasma store instances with a total of 1GB memory.
self.plasma_store_memory = 10 ** 9
plasma_addresses = []
objstore_memory = (self.plasma_store_memory // self.num_local_schedulers)
for i in range(self.num_local_schedulers):
store_stdout_file, store_stderr_file = ray.services.new_log_files(
"plasma_store_{}".format(i), True)
manager_stdout_file, manager_stderr_file = ray.services.new_log_files(
"plasma_manager_{}".format(i), True)
plasma_addresses.append(ray.services.start_objstore(
node_ip_address, redis_address, objstore_memory=objstore_memory))
address_info = {"redis_address": redis_address,
"object_store_addresses": plasma_addresses}
node_ip_address, redis_address, objstore_memory=objstore_memory,
store_stdout_file=store_stdout_file,
store_stderr_file=store_stderr_file,
manager_stdout_file=manager_stdout_file,
manager_stderr_file=manager_stderr_file))
# Start the rest of the services in the Ray cluster.
address_info = {"redis_address": redis_address,
"redis_shards": redis_shards,
"object_store_addresses": plasma_addresses}
ray.worker._init(address_info=address_info, start_ray_local=True,
num_workers=1,
num_local_schedulers=self.num_local_schedulers,
num_cpus=[1] * self.num_local_schedulers,
redirect_output=True,
driver_mode=ray.SILENT_MODE)
def tearDown(self):
@@ -164,14 +172,11 @@ class ReconstructionTests(unittest.TestCase):
# Determine the IDs of all local schedulers that had a task scheduled or
# submitted.
r = redis.StrictRedis(port=self.redis_port)
task_ids = r.keys("TT:*")
task_ids = [task_id[3:] for task_id in task_ids]
local_scheduler_ids = []
for task_id in task_ids:
message = r.execute_command("ray.task_table_get", task_id)
task_reply_object = TaskReply.GetRootAsTaskReply(message, 0)
local_scheduler_ids.append(task_reply_object.LocalSchedulerId())
state = ray.experimental.state.GlobalState()
state._initialize_global_state(self.redis_ip_address, self.redis_port)
tasks = state.task_table()
local_scheduler_ids = set(task["LocalSchedulerID"] for task in
tasks.values())
# Make sure that all nodes in the cluster were used by checking that the
# set of local scheduler IDs that had a task scheduled or submitted is
@@ -179,7 +184,7 @@ class ReconstructionTests(unittest.TestCase):
# total number of local schedulers to account for NIL_LOCAL_SCHEDULER_ID.
# This is the local scheduler ID associated with the driver task, since it
# is not scheduled by a particular local scheduler.
self.assertEqual(len(set(local_scheduler_ids)),
self.assertEqual(len(local_scheduler_ids),
self.num_local_schedulers + 1)
# Clean up the Ray cluster.