mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 19:00:36 +08:00
Shard Redis. (#539)
* Implement sharding in the Ray core * Single node Python modifications to do sharding * Do the sharding in redis.cc * Pipe num_redis_shards through start_ray.py and worker.py. * Use multiple redis shards in multinode tests. * first steps for sharding ray.global_state * Fix problem in multinode docker test. * fix runtest.py * fix some tests * fix redis shard startup * fix redis sharding * fix * fix bug introduced by the map-iterator being consumed * fix sharding bug * shard event table * update number of Redis clients to be 64K * Fix object table tests by flushing shards in between unit tests * Fix local scheduler tests * Documentation * Register shard locations in the primary shard * Add plasma unit tests back to build * lint * lint and fix build * Fix * Address Robert's comments * Refactor start_ray_processes to start Redis shard * lint * Fix global scheduler python tests * Fix redis module test * Fix plasma test * Fix component failure test * Fix local scheduler test * Fix runtest.py * Fix global scheduler test for python3 * Fix task_table_test_and_update bug, from actor task table submission race * Fix jenkins tests. * Retry Redis shard connections * Fix test cases * Convert database clients to DBClient struct * Fix race condition when subscribing to db client table * Remove unused lines, add APITest for sharded Ray * Fix * Fix memory leak * Suppress ReconstructionTests output * Suppress output for APITestSharded * Reissue task table add/update commands if initial command does not publish to any subscribers. * fix * Fix linting. * fix tests * fix linting * fix python test * fix linting
This commit is contained in:
committed by
Philipp Moritz
parent
0a4304725f
commit
ee08c8274b
@@ -52,7 +52,7 @@ def get_next_message(pubsub_client, timeout_seconds=10):
|
||||
class TestGlobalStateStore(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
redis_port, _ = ray.services.start_redis()
|
||||
redis_port, _ = ray.services.start_redis_instance()
|
||||
self.redis = redis.StrictRedis(host="localhost", port=redis_port, db=0)
|
||||
|
||||
def tearDown(self):
|
||||
@@ -308,6 +308,10 @@ class TestGlobalStateStore(unittest.TestCase):
|
||||
TASK_STATUS_SCHEDULED = 2
|
||||
TASK_STATUS_QUEUED = 4
|
||||
|
||||
# make sure somebody will get a notification (checked in the redis module)
|
||||
p = self.redis.pubsub()
|
||||
p.psubscribe("{prefix}*:*".format(prefix=TASK_PREFIX))
|
||||
|
||||
def check_task_reply(message, task_args, updated=False):
|
||||
task_status, local_scheduler_id, task_spec = task_args
|
||||
task_reply_object = TaskReply.GetRootAsTaskReply(message, 0)
|
||||
@@ -388,33 +392,53 @@ class TestGlobalStateStore(unittest.TestCase):
|
||||
self.assertNotEqual(get_response, old_response)
|
||||
check_task_reply(get_response, task_args[1:])
|
||||
|
||||
def check_task_subscription(self, p, scheduling_state, local_scheduler_id):
|
||||
task_args = [b"task_id", scheduling_state,
|
||||
local_scheduler_id.encode("ascii"), b"task_spec"]
|
||||
self.redis.execute_command("RAY.TASK_TABLE_ADD", *task_args)
|
||||
# Receive the data.
|
||||
message = get_next_message(p)["data"]
|
||||
# Check that the notification object is correct.
|
||||
notification_object = TaskReply.GetRootAsTaskReply(message, 0)
|
||||
self.assertEqual(notification_object.TaskId(), b"task_id")
|
||||
self.assertEqual(notification_object.State(), scheduling_state)
|
||||
self.assertEqual(notification_object.LocalSchedulerId(),
|
||||
local_scheduler_id.encode("ascii"))
|
||||
self.assertEqual(notification_object.TaskSpec(), b"task_spec")
|
||||
|
||||
def testTaskTableSubscribe(self):
|
||||
scheduling_state = 1
|
||||
local_scheduler_id = "local_scheduler_id"
|
||||
# Subscribe to the task table.
|
||||
p = self.redis.pubsub()
|
||||
p.psubscribe("{prefix}*:*".format(prefix=TASK_PREFIX))
|
||||
# Receive acknowledgment.
|
||||
self.assertEqual(get_next_message(p)["data"], 1)
|
||||
self.check_task_subscription(p, scheduling_state, local_scheduler_id)
|
||||
# unsubscribe to make sure there is only one subscriber at a given time
|
||||
p.punsubscribe("{prefix}*:*".format(prefix=TASK_PREFIX))
|
||||
# Receive acknowledgment.
|
||||
self.assertEqual(get_next_message(p)["data"], 0)
|
||||
|
||||
p.psubscribe("{prefix}*:{state}".format(
|
||||
prefix=TASK_PREFIX, state=scheduling_state))
|
||||
# Receive acknowledgment.
|
||||
self.assertEqual(get_next_message(p)["data"], 1)
|
||||
self.check_task_subscription(p, scheduling_state, local_scheduler_id)
|
||||
p.punsubscribe("{prefix}*:{state}".format(
|
||||
prefix=TASK_PREFIX, state=scheduling_state))
|
||||
# Receive acknowledgment.
|
||||
self.assertEqual(get_next_message(p)["data"], 0)
|
||||
|
||||
p.psubscribe("{prefix}{local_scheduler_id}:*".format(
|
||||
prefix=TASK_PREFIX, local_scheduler_id=local_scheduler_id))
|
||||
task_args = [b"task_id", scheduling_state,
|
||||
local_scheduler_id.encode("ascii"), b"task_spec"]
|
||||
self.redis.execute_command("RAY.TASK_TABLE_ADD", *task_args)
|
||||
# Receive the acknowledgement message.
|
||||
# Receive acknowledgment.
|
||||
self.assertEqual(get_next_message(p)["data"], 1)
|
||||
self.assertEqual(get_next_message(p)["data"], 2)
|
||||
self.assertEqual(get_next_message(p)["data"], 3)
|
||||
# Receive the actual data.
|
||||
for i in range(3):
|
||||
message = get_next_message(p)["data"]
|
||||
# Check that the notification object is correct.
|
||||
notification_object = TaskReply.GetRootAsTaskReply(message, 0)
|
||||
self.assertEqual(notification_object.TaskId(), b"task_id")
|
||||
self.assertEqual(notification_object.State(), scheduling_state)
|
||||
self.assertEqual(notification_object.LocalSchedulerId(),
|
||||
local_scheduler_id.encode("ascii"))
|
||||
self.assertEqual(notification_object.TaskSpec(), b"task_spec")
|
||||
self.check_task_subscription(p, scheduling_state, local_scheduler_id)
|
||||
p.punsubscribe("{prefix}{local_scheduler_id}:*".format(
|
||||
prefix=TASK_PREFIX, local_scheduler_id=local_scheduler_id))
|
||||
# Receive acknowledgment.
|
||||
self.assertEqual(get_next_message(p)["data"], 0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -5,6 +5,7 @@ from __future__ import print_function
|
||||
import pickle
|
||||
import redis
|
||||
|
||||
import ray
|
||||
from ray.utils import (decode, binary_to_object_id, binary_to_hex,
|
||||
hex_to_binary)
|
||||
|
||||
@@ -25,14 +26,21 @@ OBJECT_CHANNEL_PREFIX = "OC:"
|
||||
|
||||
# This mapping from integer to task state string must be kept up-to-date with
|
||||
# the scheduling_state enum in task.h.
|
||||
task_state_mapping = {
|
||||
1: "WAITING",
|
||||
2: "SCHEDULED",
|
||||
4: "QUEUED",
|
||||
8: "RUNNING",
|
||||
16: "DONE",
|
||||
32: "LOST",
|
||||
64: "RECONSTRUCTING"
|
||||
TASK_STATUS_WAITING = 1
|
||||
TASK_STATUS_SCHEDULED = 2
|
||||
TASK_STATUS_QUEUED = 4
|
||||
TASK_STATUS_RUNNING = 8
|
||||
TASK_STATUS_DONE = 16
|
||||
TASK_STATUS_LOST = 32
|
||||
TASK_STATUS_RECONSTRUCTING = 64
|
||||
TASK_STATUS_MAPPING = {
|
||||
TASK_STATUS_WAITING: "WAITING",
|
||||
TASK_STATUS_SCHEDULED: "SCHEDULED",
|
||||
TASK_STATUS_QUEUED: "QUEUED",
|
||||
TASK_STATUS_RUNNING: "RUNNING",
|
||||
TASK_STATUS_DONE: "DONE",
|
||||
TASK_STATUS_LOST: "LOST",
|
||||
TASK_STATUS_RECONSTRUCTING: "RECONSTRUCTING",
|
||||
}
|
||||
|
||||
|
||||
@@ -66,8 +74,54 @@ class GlobalState(object):
|
||||
"""
|
||||
self.redis_client = redis.StrictRedis(host=redis_ip_address,
|
||||
port=redis_port)
|
||||
self.redis_clients = []
|
||||
num_redis_shards = self.redis_client.get("NumRedisShards")
|
||||
if num_redis_shards is None:
|
||||
raise Exception("No entry found for NumRedisShards")
|
||||
num_redis_shards = int(num_redis_shards)
|
||||
if (num_redis_shards < 1):
|
||||
raise Exception("Expected at least one Redis shard, found "
|
||||
"{}.".format(num_redis_shards))
|
||||
|
||||
def _object_table(self, object_id_binary):
|
||||
ip_address_ports = self.redis_client.lrange("RedisShards", start=0, end=-1)
|
||||
if len(ip_address_ports) != num_redis_shards:
|
||||
raise Exception("Expected {} Redis shard addresses, found "
|
||||
"{}".format(num_redis_shards, len(ip_address_ports)))
|
||||
|
||||
for ip_address_port in ip_address_ports:
|
||||
shard_address, shard_port = ip_address_port.split(b":")
|
||||
self.redis_clients.append(redis.StrictRedis(host=shard_address,
|
||||
port=shard_port))
|
||||
|
||||
def _execute_command(self, key, *args):
|
||||
"""Execute a Redis command on the appropriate Redis shard based on key.
|
||||
|
||||
Args:
|
||||
key: The object ID or the task ID that the query is about.
|
||||
args: The command to run.
|
||||
|
||||
Returns:
|
||||
The value returned by the Redis command.
|
||||
"""
|
||||
client = self.redis_clients[key.redis_shard_hash() %
|
||||
len(self.redis_clients)]
|
||||
return client.execute_command(*args)
|
||||
|
||||
def _keys(self, pattern):
|
||||
"""Execute the KEYS command on all Redis shards.
|
||||
|
||||
Args:
|
||||
pattern: The KEYS pattern to query.
|
||||
|
||||
Returns:
|
||||
The concatenated list of results from all shards.
|
||||
"""
|
||||
result = []
|
||||
for client in self.redis_clients:
|
||||
result.extend(client.keys(pattern))
|
||||
return result
|
||||
|
||||
def _object_table(self, object_id):
|
||||
"""Fetch and parse the object table information for a single object ID.
|
||||
|
||||
Args:
|
||||
@@ -78,16 +132,18 @@ class GlobalState(object):
|
||||
A dictionary with information about the object ID in question.
|
||||
"""
|
||||
# Return information about a single object ID.
|
||||
object_locations = self.redis_client.execute_command(
|
||||
"RAY.OBJECT_TABLE_LOOKUP", object_id_binary)
|
||||
object_locations = self._execute_command(object_id,
|
||||
"RAY.OBJECT_TABLE_LOOKUP",
|
||||
object_id.id())
|
||||
if object_locations is not None:
|
||||
manager_ids = [binary_to_hex(manager_id)
|
||||
for manager_id in object_locations]
|
||||
else:
|
||||
manager_ids = None
|
||||
|
||||
result_table_response = self.redis_client.execute_command(
|
||||
"RAY.RESULT_TABLE_LOOKUP", object_id_binary)
|
||||
result_table_response = self._execute_command(object_id,
|
||||
"RAY.RESULT_TABLE_LOOKUP",
|
||||
object_id.id())
|
||||
result_table_message = ResultTableReply.GetRootAsResultTableReply(
|
||||
result_table_response, 0)
|
||||
|
||||
@@ -111,22 +167,21 @@ class GlobalState(object):
|
||||
self._check_connected()
|
||||
if object_id is not None:
|
||||
# Return information about a single object ID.
|
||||
return self._object_table(object_id.id())
|
||||
return self._object_table(object_id)
|
||||
else:
|
||||
# Return the entire object table.
|
||||
object_info_keys = self.redis_client.keys(OBJECT_INFO_PREFIX + "*")
|
||||
object_location_keys = self.redis_client.keys(
|
||||
OBJECT_LOCATION_PREFIX + "*")
|
||||
object_info_keys = self._keys(OBJECT_INFO_PREFIX + "*")
|
||||
object_location_keys = self._keys(OBJECT_LOCATION_PREFIX + "*")
|
||||
object_ids_binary = set(
|
||||
[key[len(OBJECT_INFO_PREFIX):] for key in object_info_keys] +
|
||||
[key[len(OBJECT_LOCATION_PREFIX):] for key in object_location_keys])
|
||||
results = {}
|
||||
for object_id_binary in object_ids_binary:
|
||||
results[binary_to_object_id(object_id_binary)] = self._object_table(
|
||||
object_id_binary)
|
||||
binary_to_object_id(object_id_binary))
|
||||
return results
|
||||
|
||||
def _task_table(self, task_id_binary):
|
||||
def _task_table(self, task_id):
|
||||
"""Fetch and parse the task table information for a single object task ID.
|
||||
|
||||
Args:
|
||||
@@ -135,12 +190,15 @@ class GlobalState(object):
|
||||
|
||||
Returns:
|
||||
A dictionary with information about the task ID in question.
|
||||
TASK_STATUS_MAPPING should be used to parse the "State" field into a
|
||||
human-readable string.
|
||||
"""
|
||||
task_table_response = self.redis_client.execute_command(
|
||||
"RAY.TASK_TABLE_GET", task_id_binary)
|
||||
task_table_response = self._execute_command(task_id,
|
||||
"RAY.TASK_TABLE_GET",
|
||||
task_id.id())
|
||||
if task_table_response is None:
|
||||
raise Exception("There is no entry for task ID {} in the task table."
|
||||
.format(binary_to_hex(task_id_binary)))
|
||||
.format(binary_to_hex(task_id.id())))
|
||||
task_table_message = TaskReply.GetRootAsTaskReply(task_table_response, 0)
|
||||
task_spec = task_table_message.TaskSpec()
|
||||
task_spec_message = TaskInfo.GetRootAsTaskInfo(task_spec, 0)
|
||||
@@ -167,7 +225,7 @@ class GlobalState(object):
|
||||
for i in range(task_spec_message.ReturnsLength())],
|
||||
"RequiredResources": required_resources}
|
||||
|
||||
return {"State": task_state_mapping[task_table_message.State()],
|
||||
return {"State": task_table_message.State(),
|
||||
"LocalSchedulerID": binary_to_hex(
|
||||
task_table_message.LocalSchedulerId()),
|
||||
"TaskSpec": task_spec_info}
|
||||
@@ -185,14 +243,15 @@ class GlobalState(object):
|
||||
"""
|
||||
self._check_connected()
|
||||
if task_id is not None:
|
||||
return self._task_table(hex_to_binary(task_id))
|
||||
task_id = ray.local_scheduler.ObjectID(hex_to_binary(task_id))
|
||||
return self._task_table(task_id)
|
||||
else:
|
||||
task_table_keys = self.redis_client.keys(TASK_PREFIX + "*")
|
||||
task_table_keys = self._keys(TASK_PREFIX + "*")
|
||||
results = {}
|
||||
for key in task_table_keys:
|
||||
task_id_binary = key[len(TASK_PREFIX):]
|
||||
results[binary_to_hex(task_id_binary)] = self._task_table(
|
||||
task_id_binary)
|
||||
ray.local_scheduler.ObjectID(task_id_binary))
|
||||
return results
|
||||
|
||||
def function_table(self, function_id=None):
|
||||
|
||||
@@ -7,9 +7,9 @@ import subprocess
|
||||
import time
|
||||
|
||||
|
||||
def start_global_scheduler(redis_address, node_ip_address, use_valgrind=False,
|
||||
use_profiler=False, stdout_file=None,
|
||||
stderr_file=None):
|
||||
def start_global_scheduler(redis_address, node_ip_address,
|
||||
use_valgrind=False, use_profiler=False,
|
||||
stdout_file=None, stderr_file=None):
|
||||
"""Start a global scheduler process.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -5,7 +5,6 @@ from __future__ import print_function
|
||||
import numpy as np
|
||||
import os
|
||||
import random
|
||||
import redis
|
||||
import signal
|
||||
import sys
|
||||
import time
|
||||
@@ -17,6 +16,7 @@ import ray.plasma as plasma
|
||||
from ray.plasma.utils import create_object
|
||||
|
||||
from ray import services
|
||||
from ray.experimental import state
|
||||
|
||||
USE_VALGRIND = False
|
||||
PLASMA_STORE_MEMORY = 1000000000
|
||||
@@ -26,13 +26,6 @@ NUM_CLUSTER_NODES = 2
|
||||
NIL_WORKER_ID = 20 * b"\xff"
|
||||
NIL_ACTOR_ID = 20 * b"\xff"
|
||||
|
||||
# These constants must match the scheduling state enum in task.h.
|
||||
TASK_STATUS_WAITING = 1
|
||||
TASK_STATUS_SCHEDULED = 2
|
||||
TASK_STATUS_QUEUED = 4
|
||||
TASK_STATUS_RUNNING = 8
|
||||
TASK_STATUS_DONE = 16
|
||||
|
||||
# These constants are an implementation detail of ray_redis_module.cc, so this
|
||||
# must be kept in sync with that file.
|
||||
DB_CLIENT_PREFIX = "CL:"
|
||||
@@ -63,15 +56,17 @@ class TestGlobalScheduler(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
# Start one Redis server and N pairs of (plasma, local_scheduler)
|
||||
node_ip_address = "127.0.0.1"
|
||||
redis_port, self.redis_process = services.start_redis(cleanup=False)
|
||||
redis_address = services.address(node_ip_address, redis_port)
|
||||
# Create a Redis client.
|
||||
self.redis_client = redis.StrictRedis(host=node_ip_address,
|
||||
port=redis_port)
|
||||
self.node_ip_address = "127.0.0.1"
|
||||
redis_address, redis_shards = services.start_redis(self.node_ip_address)
|
||||
redis_port = services.get_port(redis_address)
|
||||
time.sleep(0.1)
|
||||
# Create a client for the global state store.
|
||||
self.state = state.GlobalState()
|
||||
self.state._initialize_global_state(self.node_ip_address, redis_port)
|
||||
|
||||
# Start one global scheduler.
|
||||
self.p1 = global_scheduler.start_global_scheduler(
|
||||
redis_address, node_ip_address, use_valgrind=USE_VALGRIND)
|
||||
redis_address, self.node_ip_address, use_valgrind=USE_VALGRIND)
|
||||
self.plasma_store_pids = []
|
||||
self.plasma_manager_pids = []
|
||||
self.local_scheduler_pids = []
|
||||
@@ -89,7 +84,8 @@ class TestGlobalScheduler(unittest.TestCase):
|
||||
redis_address)
|
||||
plasma_manager_name, p3, plasma_manager_port = manager_info
|
||||
self.plasma_manager_pids.append(p3)
|
||||
plasma_address = "{}:{}".format(node_ip_address, plasma_manager_port)
|
||||
plasma_address = "{}:{}".format(self.node_ip_address,
|
||||
plasma_manager_port)
|
||||
plasma_client = plasma.PlasmaClient(plasma_store_name,
|
||||
plasma_manager_name)
|
||||
self.plasma_clients.append(plasma_client)
|
||||
@@ -116,7 +112,10 @@ class TestGlobalScheduler(unittest.TestCase):
|
||||
for p4 in self.local_scheduler_pids:
|
||||
self.assertEqual(p4.poll(), None)
|
||||
|
||||
self.assertEqual(self.redis_process.poll(), None)
|
||||
redis_processes = services.all_processes[
|
||||
services.PROCESS_TYPE_REDIS_SERVER]
|
||||
for redis_process in redis_processes:
|
||||
self.assertEqual(redis_process.poll(), None)
|
||||
|
||||
# Kill the global scheduler.
|
||||
if USE_VALGRIND:
|
||||
@@ -135,7 +134,9 @@ class TestGlobalScheduler(unittest.TestCase):
|
||||
p4.kill()
|
||||
# Kill Redis. In the event that we are using valgrind, this needs to happen
|
||||
# after we kill the global scheduler.
|
||||
self.redis_process.kill()
|
||||
while redis_processes:
|
||||
redis_process = redis_processes.pop()
|
||||
redis_process.kill()
|
||||
|
||||
def get_plasma_manager_id(self):
|
||||
"""Get the db_client_id with client_type equal to plasma_manager.
|
||||
@@ -150,11 +151,10 @@ class TestGlobalScheduler(unittest.TestCase):
|
||||
"""
|
||||
db_client_id = None
|
||||
|
||||
client_list = self.redis_client.keys("{}*".format(DB_CLIENT_PREFIX))
|
||||
for client_id in client_list:
|
||||
response = self.redis_client.hget(client_id, b"client_type")
|
||||
if response == b"plasma_manager":
|
||||
db_client_id = client_id
|
||||
client_list = self.state.client_table()[self.node_ip_address]
|
||||
for client in client_list:
|
||||
if client["ClientType"] == "plasma_manager":
|
||||
db_client_id = client["DBClientID"]
|
||||
break
|
||||
|
||||
return db_client_id
|
||||
@@ -178,18 +178,16 @@ class TestGlobalScheduler(unittest.TestCase):
|
||||
# There should be 2n+1 db clients: the global scheduler + one local
|
||||
# scheduler and one plasma per node.
|
||||
self.assertEqual(
|
||||
len(self.redis_client.keys("{}*".format(DB_CLIENT_PREFIX))),
|
||||
len(self.state.client_table()[self.node_ip_address]),
|
||||
2 * NUM_CLUSTER_NODES + 1)
|
||||
db_client_id = self.get_plasma_manager_id()
|
||||
assert(db_client_id is not None)
|
||||
assert(db_client_id.startswith(b"CL:"))
|
||||
db_client_id = db_client_id[len(b"CL:"):] # Remove the CL: prefix.
|
||||
|
||||
def test_integration_single_task(self):
|
||||
# There should be three db clients, the global scheduler, the local
|
||||
# scheduler, and the plasma manager.
|
||||
self.assertEqual(
|
||||
len(self.redis_client.keys("{}*".format(DB_CLIENT_PREFIX))),
|
||||
len(self.state.client_table()[self.node_ip_address]),
|
||||
2 * NUM_CLUSTER_NODES + 1)
|
||||
|
||||
num_return_vals = [0, 1, 2, 3, 5, 10]
|
||||
@@ -212,15 +210,15 @@ class TestGlobalScheduler(unittest.TestCase):
|
||||
# local scheduler
|
||||
num_retries = 10
|
||||
while num_retries > 0:
|
||||
task_entries = self.redis_client.keys("{}*".format(TASK_PREFIX))
|
||||
task_entries = self.state.task_table()
|
||||
self.assertLessEqual(len(task_entries), 1)
|
||||
if len(task_entries) == 1:
|
||||
task_contents = self.redis_client.hgetall(task_entries[0])
|
||||
task_status = int(task_contents[b"state"])
|
||||
self.assertTrue(task_status in [TASK_STATUS_WAITING,
|
||||
TASK_STATUS_SCHEDULED,
|
||||
TASK_STATUS_QUEUED])
|
||||
if task_status == TASK_STATUS_QUEUED:
|
||||
task_id, task = task_entries.popitem()
|
||||
task_status = task["State"]
|
||||
self.assertTrue(task_status in [state.TASK_STATUS_WAITING,
|
||||
state.TASK_STATUS_SCHEDULED,
|
||||
state.TASK_STATUS_QUEUED])
|
||||
if task_status == state.TASK_STATUS_QUEUED:
|
||||
break
|
||||
else:
|
||||
print(task_status)
|
||||
@@ -228,7 +226,7 @@ class TestGlobalScheduler(unittest.TestCase):
|
||||
num_retries -= 1
|
||||
time.sleep(1)
|
||||
|
||||
if num_retries <= 0 and task_status != TASK_STATUS_QUEUED:
|
||||
if num_retries <= 0 and task_status != state.TASK_STATUS_QUEUED:
|
||||
# Failed to submit and schedule a single task -- bail.
|
||||
self.tearDown()
|
||||
sys.exit(1)
|
||||
@@ -237,7 +235,7 @@ class TestGlobalScheduler(unittest.TestCase):
|
||||
# There should be three db clients, the global scheduler, the local
|
||||
# scheduler, and the plasma manager.
|
||||
self.assertEqual(
|
||||
len(self.redis_client.keys("{}*".format(DB_CLIENT_PREFIX))),
|
||||
len(self.state.client_table()[self.node_ip_address]),
|
||||
2 * NUM_CLUSTER_NODES + 1)
|
||||
num_return_vals = [0, 1, 2, 3, 5, 10]
|
||||
|
||||
@@ -264,34 +262,31 @@ class TestGlobalScheduler(unittest.TestCase):
|
||||
num_retries = 10
|
||||
num_tasks_done = 0
|
||||
while num_retries > 0:
|
||||
task_entries = self.redis_client.keys("{}*".format(TASK_PREFIX))
|
||||
task_entries = self.state.task_table()
|
||||
self.assertLessEqual(len(task_entries), num_tasks)
|
||||
# First, check if all tasks made it to Redis.
|
||||
if len(task_entries) == num_tasks:
|
||||
task_contents = [self.redis_client.hgetall(task_entries[i])
|
||||
for i in range(len(task_entries))]
|
||||
task_statuses = [int(contents[b"state"]) for contents in task_contents]
|
||||
self.assertTrue(all([status in [TASK_STATUS_WAITING,
|
||||
TASK_STATUS_SCHEDULED,
|
||||
TASK_STATUS_QUEUED]
|
||||
task_statuses = [task_entry["State"] for task_entry in
|
||||
task_entries.values()]
|
||||
self.assertTrue(all([status in [state.TASK_STATUS_WAITING,
|
||||
state.TASK_STATUS_SCHEDULED,
|
||||
state.TASK_STATUS_QUEUED]
|
||||
for status in task_statuses]))
|
||||
num_tasks_done = task_statuses.count(TASK_STATUS_QUEUED)
|
||||
num_tasks_scheduled = task_statuses.count(TASK_STATUS_SCHEDULED)
|
||||
num_tasks_waiting = task_statuses.count(TASK_STATUS_WAITING)
|
||||
num_tasks_done = task_statuses.count(state.TASK_STATUS_QUEUED)
|
||||
num_tasks_scheduled = task_statuses.count(state.TASK_STATUS_SCHEDULED)
|
||||
num_tasks_waiting = task_statuses.count(state.TASK_STATUS_WAITING)
|
||||
print("tasks in Redis = {}, tasks waiting = {}, tasks scheduled = {}, "
|
||||
"tasks queued = {}, retries left = {}"
|
||||
.format(len(task_entries), num_tasks_waiting,
|
||||
num_tasks_scheduled, num_tasks_done, num_retries))
|
||||
if all([status == TASK_STATUS_QUEUED for status in task_statuses]):
|
||||
if all([status == state.TASK_STATUS_QUEUED for status in
|
||||
task_statuses]):
|
||||
# We're done, so pass.
|
||||
break
|
||||
num_retries -= 1
|
||||
time.sleep(0.1)
|
||||
|
||||
if num_tasks_done != num_tasks:
|
||||
# At least one of the tasks failed to schedule.
|
||||
self.tearDown()
|
||||
sys.exit(2)
|
||||
self.assertEqual(num_tasks_done, num_tasks)
|
||||
|
||||
def test_integration_many_tasks_handler_sync(self):
|
||||
self.integration_many_tasks_helper(timesync=True)
|
||||
|
||||
+38
-40
@@ -12,11 +12,13 @@ import time
|
||||
import ray
|
||||
from ray.services import get_ip_address
|
||||
from ray.services import get_port
|
||||
from ray.utils import binary_to_object_id
|
||||
from ray.utils import binary_to_hex
|
||||
from ray.utils import hex_to_binary
|
||||
|
||||
# Import flatbuffer bindings.
|
||||
from ray.core.generated.SubscribeToDBClientTableReply \
|
||||
import SubscribeToDBClientTableReply
|
||||
from ray.core.generated.TaskReply import TaskReply
|
||||
from ray.core.generated.DriverTableMessage import DriverTableMessage
|
||||
|
||||
# These variables must be kept in sync with the C codebase.
|
||||
@@ -31,7 +33,6 @@ TASK_STATUS_LOST = 32
|
||||
PLASMA_MANAGER_HEARTBEAT_CHANNEL = b"plasma_managers"
|
||||
DRIVER_DEATH_CHANNEL = b"driver_deaths"
|
||||
# common/redis_module/ray_redis_module.cc
|
||||
TASK_PREFIX = "TT:"
|
||||
OBJECT_PREFIX = "OL:"
|
||||
DB_CLIENT_PREFIX = "CL:"
|
||||
DB_CLIENT_TABLE_NAME = b"db_clients"
|
||||
@@ -43,7 +44,7 @@ PLASMA_MANAGER_CLIENT_TYPE = b"plasma_manager"
|
||||
# Set up logging.
|
||||
logging.basicConfig()
|
||||
log = logging.getLogger()
|
||||
log.setLevel(logging.WARN)
|
||||
log.setLevel(logging.INFO)
|
||||
|
||||
|
||||
class Monitor(object):
|
||||
@@ -70,7 +71,11 @@ class Monitor(object):
|
||||
"""
|
||||
def __init__(self, redis_address, redis_port):
|
||||
# Initialize the Redis clients.
|
||||
self.state = ray.experimental.state.GlobalState()
|
||||
self.state._initialize_global_state(redis_address, redis_port)
|
||||
self.redis = redis.StrictRedis(host=redis_address, port=redis_port, db=0)
|
||||
# TODO(swang): Update pubsub client to use ray.experimental.state once
|
||||
# subscriptions are implemented there.
|
||||
self.subscribe_client = self.redis.pubsub()
|
||||
self.subscribed = {}
|
||||
# Initialize data structures to keep track of the active database clients.
|
||||
@@ -97,23 +102,17 @@ class Monitor(object):
|
||||
TASK_STATUS_LOST. A local scheduler is deemed dead if it is in
|
||||
self.dead_local_schedulers.
|
||||
"""
|
||||
task_ids = self.redis.scan_iter(
|
||||
match="{prefix}*".format(prefix=TASK_PREFIX))
|
||||
tasks = self.state.task_table()
|
||||
num_tasks_updated = 0
|
||||
for task_id in task_ids:
|
||||
task_id = task_id[len(TASK_PREFIX):]
|
||||
response = self.redis.execute_command("RAY.TASK_TABLE_GET", task_id)
|
||||
# Parse the serialized task object.
|
||||
task_object = TaskReply.GetRootAsTaskReply(response, 0)
|
||||
local_scheduler_id = task_object.LocalSchedulerId()
|
||||
for task_id, task in tasks.items():
|
||||
# See if the corresponding local scheduler is alive.
|
||||
if local_scheduler_id in self.dead_local_schedulers:
|
||||
if task["LocalSchedulerID"] in self.dead_local_schedulers:
|
||||
# If the task is scheduled on a dead local scheduler, mark the task as
|
||||
# lost.
|
||||
ok = self.redis.execute_command("RAY.TASK_TABLE_UPDATE",
|
||||
task_id,
|
||||
TASK_STATUS_LOST,
|
||||
NIL_ID)
|
||||
key = binary_to_object_id(hex_to_binary(task_id))
|
||||
ok = self.state._execute_command(
|
||||
key, "RAY.TASK_TABLE_UPDATE", hex_to_binary(task_id),
|
||||
ray.experimental.state.TASK_STATUS_LOST, NIL_ID)
|
||||
if ok != b"OK":
|
||||
log.warn("Failed to update lost task for dead scheduler.")
|
||||
num_tasks_updated += 1
|
||||
@@ -129,19 +128,20 @@ class Monitor(object):
|
||||
"""
|
||||
# TODO(swang): Also kill the associated plasma store, since it's no longer
|
||||
# reachable without a plasma manager.
|
||||
object_ids = self.redis.scan_iter(
|
||||
match="{prefix}*".format(prefix=OBJECT_PREFIX))
|
||||
objects = self.state.object_table()
|
||||
num_objects_removed = 0
|
||||
for object_id in object_ids:
|
||||
object_id = object_id[len(OBJECT_PREFIX):]
|
||||
managers = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
|
||||
object_id)
|
||||
for manager in managers:
|
||||
for object_id, obj in objects.items():
|
||||
manager_ids = obj["ManagerIDs"]
|
||||
if manager_ids is None:
|
||||
continue
|
||||
for manager in manager_ids:
|
||||
if manager in self.dead_plasma_managers:
|
||||
# If the object was on a dead plasma manager, remove that location
|
||||
# entry.
|
||||
ok = self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", object_id,
|
||||
manager)
|
||||
ok = self.state._execute_command(object_id,
|
||||
"RAY.OBJECT_TABLE_REMOVE",
|
||||
object_id.id(),
|
||||
hex_to_binary(manager))
|
||||
if ok != b"OK":
|
||||
log.warn("Failed to remove object location for dead plasma "
|
||||
"manager.")
|
||||
@@ -157,18 +157,16 @@ class Monitor(object):
|
||||
not miss any notifications for deleted clients that occurred before we
|
||||
subscribed.
|
||||
"""
|
||||
db_client_keys = self.redis.keys(
|
||||
"{prefix}*".format(prefix=DB_CLIENT_PREFIX))
|
||||
for db_client_key in db_client_keys:
|
||||
db_client_id = db_client_key[len(DB_CLIENT_PREFIX):]
|
||||
client_type, deleted = self.redis.hmget(db_client_key,
|
||||
[b"client_type", b"deleted"])
|
||||
deleted = bool(int(deleted))
|
||||
if deleted:
|
||||
if client_type == LOCAL_SCHEDULER_CLIENT_TYPE:
|
||||
self.dead_local_schedulers.add(db_client_id)
|
||||
elif client_type == PLASMA_MANAGER_CLIENT_TYPE:
|
||||
self.dead_plasma_managers.add(db_client_id)
|
||||
clients = self.state.client_table()
|
||||
for node_ip_address, node_clients in clients.items():
|
||||
for client in node_clients:
|
||||
db_client_id = client["DBClientID"]
|
||||
client_type = client["ClientType"]
|
||||
if client["Deleted"]:
|
||||
if client_type == LOCAL_SCHEDULER_CLIENT_TYPE:
|
||||
self.dead_local_schedulers.add(db_client_id)
|
||||
elif client_type == PLASMA_MANAGER_CLIENT_TYPE:
|
||||
self.dead_plasma_managers.add(db_client_id)
|
||||
|
||||
def subscribe_handler(self, channel, data):
|
||||
"""Handle a subscription success message from Redis.
|
||||
@@ -186,7 +184,7 @@ class Monitor(object):
|
||||
"""
|
||||
notification_object = (SubscribeToDBClientTableReply
|
||||
.GetRootAsSubscribeToDBClientTableReply(data, 0))
|
||||
db_client_id = notification_object.DbClientId()
|
||||
db_client_id = binary_to_hex(notification_object.DbClientId())
|
||||
client_type = notification_object.ClientType()
|
||||
is_insertion = notification_object.IsInsertion()
|
||||
|
||||
@@ -196,7 +194,7 @@ class Monitor(object):
|
||||
|
||||
# If the update was a deletion, add them to our accounting for dead
|
||||
# local schedulers and plasma managers.
|
||||
log.warn("Removed {}".format(client_type))
|
||||
log.warn("Removed {}, client ID {}".format(client_type, db_client_id))
|
||||
if client_type == LOCAL_SCHEDULER_CLIENT_TYPE:
|
||||
if db_client_id not in self.dead_local_schedulers:
|
||||
self.dead_local_schedulers.add(db_client_id)
|
||||
@@ -256,7 +254,7 @@ class Monitor(object):
|
||||
result = pipe.hget(local_scheduler_id, "gpus_in_use")
|
||||
gpus_in_use = dict() if result is None else json.loads(result)
|
||||
|
||||
driver_id_hex = ray.utils.binary_to_hex(driver_id)
|
||||
driver_id_hex = binary_to_hex(driver_id)
|
||||
if driver_id_hex in gpus_in_use:
|
||||
num_gpus_returned = gpus_in_use.pop(driver_id_hex)
|
||||
|
||||
|
||||
@@ -405,7 +405,8 @@ def start_plasma_manager(store_name, redis_address,
|
||||
"-m", plasma_manager_name,
|
||||
"-h", node_ip_address,
|
||||
"-p", str(plasma_manager_port),
|
||||
"-r", redis_address]
|
||||
"-r", redis_address,
|
||||
]
|
||||
if use_valgrind:
|
||||
process = subprocess.Popen(["valgrind",
|
||||
"--track-origins=yes",
|
||||
|
||||
@@ -480,7 +480,7 @@ class TestPlasmaManager(unittest.TestCase):
|
||||
store_name1, self.p2 = plasma.start_plasma_store(use_valgrind=USE_VALGRIND)
|
||||
store_name2, self.p3 = plasma.start_plasma_store(use_valgrind=USE_VALGRIND)
|
||||
# Start a Redis server.
|
||||
redis_address = services.address("127.0.0.1", services.start_redis()[0])
|
||||
redis_address, _ = services.start_redis("127.0.0.1")
|
||||
# Start two PlasmaManagers.
|
||||
manager_name1, self.p4, self.port1 = plasma.start_plasma_manager(
|
||||
store_name1, redis_address, use_valgrind=USE_VALGRIND)
|
||||
@@ -778,8 +778,7 @@ class TestPlasmaManagerRecovery(unittest.TestCase):
|
||||
self.store_name, self.p2 = plasma.start_plasma_store(
|
||||
use_valgrind=USE_VALGRIND)
|
||||
# Start a Redis server.
|
||||
self.redis_address = services.address("127.0.0.1",
|
||||
services.start_redis()[0])
|
||||
self.redis_address, _ = services.start_redis("127.0.0.1")
|
||||
# Start a PlasmaManagers.
|
||||
manager_name, self.p3, self.port1 = plasma.start_plasma_manager(
|
||||
self.store_name,
|
||||
|
||||
+110
-45
@@ -240,15 +240,75 @@ def wait_for_redis_to_start(redis_ip_address, redis_port, num_retries=5):
|
||||
"configured properly.")
|
||||
|
||||
|
||||
def start_redis(node_ip_address="127.0.0.1", port=None, num_retries=20,
|
||||
stdout_file=None, stderr_file=None, cleanup=True):
|
||||
"""Start a Redis server.
|
||||
def start_redis(node_ip_address,
|
||||
port=None,
|
||||
num_redis_shards=1,
|
||||
redirect_output=False,
|
||||
cleanup=True):
|
||||
"""Start the Redis global state store.
|
||||
|
||||
Args:
|
||||
node_ip_address: The IP address of the current node. This is only used for
|
||||
recording the log filenames in Redis.
|
||||
port (int): If provided, the primary Redis shard will be started on this
|
||||
port.
|
||||
num_redis_shards (int): If provided, the number of Redis shards to start,
|
||||
in addition to the primary one. The default value is one shard.
|
||||
cleanup (bool): True if using Ray in local mode. If cleanup is true, then
|
||||
all Redis processes started by this method will be killed by
|
||||
serices.cleanup() when the Python process that imported services exits.
|
||||
|
||||
Returns:
|
||||
A tuple of the address for the primary Redis shard and a list of addresses
|
||||
for the remaining shards.
|
||||
"""
|
||||
redis_stdout_file, redis_stderr_file = new_log_files(
|
||||
"redis", redirect_output)
|
||||
assigned_port, _ = start_redis_instance(
|
||||
node_ip_address=node_ip_address, port=port,
|
||||
stdout_file=redis_stdout_file, stderr_file=redis_stderr_file,
|
||||
cleanup=cleanup)
|
||||
if port is not None:
|
||||
assert assigned_port == port
|
||||
port = assigned_port
|
||||
redis_address = address(node_ip_address, port)
|
||||
|
||||
# Register the number of Redis shards in the primary shard, so that clients
|
||||
# know how many redis shards to expect under RedisShards.
|
||||
redis_client = redis.StrictRedis(host=node_ip_address, port=port)
|
||||
redis_client.set("NumRedisShards", str(num_redis_shards))
|
||||
|
||||
# Start other Redis shards listening on random ports. Each Redis shard logs
|
||||
# to a separate file, prefixed by "redis-<shard number>".
|
||||
redis_shards = []
|
||||
for i in range(num_redis_shards):
|
||||
redis_stdout_file, redis_stderr_file = new_log_files(
|
||||
"redis-{}".format(i), redirect_output)
|
||||
redis_shard_port, _ = start_redis_instance(
|
||||
node_ip_address=node_ip_address, stdout_file=redis_stdout_file,
|
||||
stderr_file=redis_stderr_file, cleanup=cleanup)
|
||||
shard_address = address(node_ip_address, redis_shard_port)
|
||||
redis_shards.append(shard_address)
|
||||
# Store redis shard information in the primary redis shard.
|
||||
redis_client.rpush("RedisShards", shard_address)
|
||||
|
||||
return redis_address, redis_shards
|
||||
|
||||
|
||||
def start_redis_instance(node_ip_address="127.0.0.1",
|
||||
port=None,
|
||||
num_retries=20,
|
||||
stdout_file=None,
|
||||
stderr_file=None,
|
||||
cleanup=True):
|
||||
"""Start a single Redis server.
|
||||
|
||||
Args:
|
||||
node_ip_address (str): The IP address of the current node. This is only
|
||||
used for recording the log filenames in Redis.
|
||||
port (int): If provided, start a Redis server with this port.
|
||||
num_retries (int): The number of times to attempt to start Redis.
|
||||
num_retries (int): The number of times to attempt to start Redis. If a port
|
||||
is provided, this defaults to 1.
|
||||
stdout_file: A file handle opened for writing to redirect stdout to. If no
|
||||
redirection should happen, then this should be None.
|
||||
stderr_file: A file handle opened for writing to redirect stderr to. If no
|
||||
@@ -275,8 +335,8 @@ def start_redis(node_ip_address="127.0.0.1", port=None, num_retries=20,
|
||||
assert os.path.isfile(redis_module)
|
||||
counter = 0
|
||||
if port is not None:
|
||||
if num_retries != 1:
|
||||
raise Exception("num_retries must be 1 if port is specified.")
|
||||
# If a port is specified, then try only once to connect.
|
||||
num_retries = 1
|
||||
else:
|
||||
port = new_port()
|
||||
while counter < num_retries:
|
||||
@@ -356,8 +416,8 @@ def start_log_monitor(redis_address, node_ip_address, stdout_file=None,
|
||||
[stdout_file, stderr_file])
|
||||
|
||||
|
||||
def start_global_scheduler(redis_address, node_ip_address, stdout_file=None,
|
||||
stderr_file=None, cleanup=True):
|
||||
def start_global_scheduler(redis_address, node_ip_address,
|
||||
stdout_file=None, stderr_file=None, cleanup=True):
|
||||
"""Start a global scheduler process.
|
||||
|
||||
Args:
|
||||
@@ -372,7 +432,8 @@ def start_global_scheduler(redis_address, node_ip_address, stdout_file=None,
|
||||
this process will be killed by services.cleanup() when the Python process
|
||||
that imported services exits.
|
||||
"""
|
||||
p = global_scheduler.start_global_scheduler(redis_address, node_ip_address,
|
||||
p = global_scheduler.start_global_scheduler(redis_address,
|
||||
node_ip_address,
|
||||
stdout_file=stdout_file,
|
||||
stderr_file=stderr_file)
|
||||
if cleanup:
|
||||
@@ -545,10 +606,11 @@ def start_local_scheduler(redis_address,
|
||||
return local_scheduler_name
|
||||
|
||||
|
||||
def start_objstore(node_ip_address, redis_address, object_manager_port=None,
|
||||
store_stdout_file=None, store_stderr_file=None,
|
||||
manager_stdout_file=None, manager_stderr_file=None,
|
||||
cleanup=True, objstore_memory=None):
|
||||
def start_objstore(node_ip_address, redis_address,
|
||||
object_manager_port=None, store_stdout_file=None,
|
||||
store_stderr_file=None, manager_stdout_file=None,
|
||||
manager_stderr_file=None, cleanup=True,
|
||||
objstore_memory=None):
|
||||
"""This method starts an object store process.
|
||||
|
||||
Args:
|
||||
@@ -704,13 +766,14 @@ def start_monitor(redis_address, node_ip_address, stdout_file=None,
|
||||
|
||||
def start_ray_processes(address_info=None,
|
||||
node_ip_address="127.0.0.1",
|
||||
redis_port=None,
|
||||
num_workers=None,
|
||||
num_local_schedulers=1,
|
||||
num_redis_shards=1,
|
||||
worker_path=None,
|
||||
cleanup=True,
|
||||
redirect_output=False,
|
||||
include_global_scheduler=False,
|
||||
include_redis=False,
|
||||
include_log_monitor=False,
|
||||
include_webui=False,
|
||||
start_workers_from_local_scheduler=True,
|
||||
@@ -723,12 +786,17 @@ def start_ray_processes(address_info=None,
|
||||
that have already been started. If provided, address_info will be
|
||||
modified to include processes that are newly started.
|
||||
node_ip_address (str): The IP address of this node.
|
||||
redis_port (int): The port that the primary Redis shard should listen to.
|
||||
If None, then a random port will be chosen. If the key "redis_address" is
|
||||
in address_info, then this argument will be ignored.
|
||||
num_workers (int): The number of workers to start.
|
||||
num_local_schedulers (int): The total number of local schedulers required.
|
||||
This is also the total number of object stores required. This method will
|
||||
start new instances of local schedulers and object stores until there are
|
||||
num_local_schedulers existing instances of each, including ones already
|
||||
registered with the given address_info.
|
||||
num_redis_shards: The number of Redis shards to start in addition to the
|
||||
primary Redis shard.
|
||||
worker_path (str): The path of the source code that will be run by the
|
||||
worker.
|
||||
cleanup (bool): If cleanup is true, then the processes started here will be
|
||||
@@ -738,8 +806,6 @@ def start_ray_processes(address_info=None,
|
||||
file.
|
||||
include_global_scheduler (bool): If include_global_scheduler is True, then
|
||||
start a global scheduler process.
|
||||
include_redis (bool): If include_redis is True, then start a Redis server
|
||||
process.
|
||||
include_log_monitor (bool): If True, then start a log monitor to monitor
|
||||
the log files for all processes on this node and push their contents to
|
||||
Redis.
|
||||
@@ -785,29 +851,14 @@ def start_ray_processes(address_info=None,
|
||||
# warning messages when it starts up. Instead of suppressing the output, we
|
||||
# should address the warnings.
|
||||
redis_address = address_info.get("redis_address")
|
||||
if include_redis:
|
||||
redis_stdout_file, redis_stderr_file = new_log_files("redis",
|
||||
redirect_output)
|
||||
if redis_address is None:
|
||||
# Start a Redis server. The start_redis method will choose a random port.
|
||||
redis_port, _ = start_redis(node_ip_address,
|
||||
stdout_file=redis_stdout_file,
|
||||
stderr_file=redis_stderr_file,
|
||||
cleanup=cleanup)
|
||||
redis_address = address(node_ip_address, redis_port)
|
||||
address_info["redis_address"] = redis_address
|
||||
time.sleep(0.1)
|
||||
else:
|
||||
# A Redis address was provided, so start a Redis server with the given
|
||||
# port. TODO(rkn): We should check that the IP address corresponds to the
|
||||
# machine that this method is running on.
|
||||
redis_port = get_port(redis_address)
|
||||
new_redis_port, _ = start_redis(port=int(redis_port),
|
||||
num_retries=1,
|
||||
stdout_file=redis_stdout_file,
|
||||
stderr_file=redis_stderr_file,
|
||||
cleanup=cleanup)
|
||||
assert redis_port == new_redis_port
|
||||
redis_shards = address_info.get("redis_shards", [])
|
||||
if redis_address is None:
|
||||
redis_address, redis_shards = start_redis(
|
||||
node_ip_address, port=redis_port, num_redis_shards=num_redis_shards,
|
||||
redirect_output=redirect_output, cleanup=cleanup)
|
||||
address_info["redis_address"] = redis_address
|
||||
time.sleep(0.1)
|
||||
|
||||
# Start monitoring the processes.
|
||||
monitor_stdout_file, monitor_stderr_file = new_log_files("monitor",
|
||||
redirect_output)
|
||||
@@ -815,9 +866,14 @@ def start_ray_processes(address_info=None,
|
||||
node_ip_address,
|
||||
stdout_file=monitor_stdout_file,
|
||||
stderr_file=monitor_stderr_file)
|
||||
else:
|
||||
if redis_address is None:
|
||||
raise Exception("Redis address expected")
|
||||
|
||||
if redis_shards == []:
|
||||
# Get redis shards from primary redis instance.
|
||||
redis_ip_address, redis_port = redis_address.split(":")
|
||||
redis_client = redis.StrictRedis(host=redis_ip_address, port=redis_port)
|
||||
redis_shards = redis_client.lrange("RedisShards", start=0, end=-1)
|
||||
redis_shards = [shard.decode("ascii") for shard in redis_shards]
|
||||
address_info["redis_shards"] = redis_shards
|
||||
|
||||
# Start the log monitor, if necessary.
|
||||
if include_log_monitor:
|
||||
@@ -1005,6 +1061,7 @@ def start_ray_node(node_ip_address,
|
||||
|
||||
def start_ray_head(address_info=None,
|
||||
node_ip_address="127.0.0.1",
|
||||
redis_port=None,
|
||||
num_workers=0,
|
||||
num_local_schedulers=1,
|
||||
worker_path=None,
|
||||
@@ -1012,7 +1069,8 @@ def start_ray_head(address_info=None,
|
||||
redirect_output=False,
|
||||
start_workers_from_local_scheduler=True,
|
||||
num_cpus=None,
|
||||
num_gpus=None):
|
||||
num_gpus=None,
|
||||
num_redis_shards=None):
|
||||
"""Start Ray in local mode.
|
||||
|
||||
Args:
|
||||
@@ -1020,6 +1078,9 @@ def start_ray_head(address_info=None,
|
||||
that have already been started. If provided, address_info will be
|
||||
modified to include processes that are newly started.
|
||||
node_ip_address (str): The IP address of this node.
|
||||
redis_port (int): The port that the primary Redis shard should listen to.
|
||||
If None, then a random port will be chosen. If the key "redis_address" is
|
||||
in address_info, then this argument will be ignored.
|
||||
num_workers (int): The number of workers to start.
|
||||
num_local_schedulers (int): The total number of local schedulers required.
|
||||
This is also the total number of object stores required. This method will
|
||||
@@ -1038,14 +1099,18 @@ def start_ray_head(address_info=None,
|
||||
Python.
|
||||
num_cpus (int): number of cpus to configure the local scheduler with.
|
||||
num_gpus (int): number of gpus to configure the local scheduler with.
|
||||
num_redis_shards: The number of Redis shards to start in addition to the
|
||||
primary Redis shard.
|
||||
|
||||
Returns:
|
||||
A dictionary of the address information for the processes that were
|
||||
started.
|
||||
"""
|
||||
num_redis_shards = 1 if num_redis_shards is None else num_redis_shards
|
||||
return start_ray_processes(
|
||||
address_info=address_info,
|
||||
node_ip_address=node_ip_address,
|
||||
redis_port=redis_port,
|
||||
num_workers=num_workers,
|
||||
num_local_schedulers=num_local_schedulers,
|
||||
worker_path=worker_path,
|
||||
@@ -1053,11 +1118,11 @@ def start_ray_head(address_info=None,
|
||||
redirect_output=redirect_output,
|
||||
include_global_scheduler=True,
|
||||
include_log_monitor=True,
|
||||
include_redis=True,
|
||||
include_webui=False,
|
||||
start_workers_from_local_scheduler=start_workers_from_local_scheduler,
|
||||
num_cpus=num_cpus,
|
||||
num_gpus=num_gpus)
|
||||
num_gpus=num_gpus,
|
||||
num_redis_shards=num_redis_shards)
|
||||
|
||||
|
||||
def new_log_files(name, redirect_output):
|
||||
|
||||
+28
-15
@@ -992,7 +992,8 @@ def _init(address_info=None,
|
||||
redirect_output=False,
|
||||
start_workers_from_local_scheduler=True,
|
||||
num_cpus=None,
|
||||
num_gpus=None):
|
||||
num_gpus=None,
|
||||
num_redis_shards=None):
|
||||
"""Helper method to connect to an existing Ray cluster or start a new one.
|
||||
|
||||
This method handles two cases. Either a Ray cluster already exists and we
|
||||
@@ -1002,8 +1003,9 @@ def _init(address_info=None,
|
||||
Args:
|
||||
address_info (dict): A dictionary with address information for processes in
|
||||
a partially-started Ray cluster. If start_ray_local=True, any processes
|
||||
not in this dictionary will be started. If provided, address_info will be
|
||||
modified to include processes that are newly started.
|
||||
not in this dictionary will be started. If provided, an updated
|
||||
address_info dictionary will be returned to include processes that are
|
||||
newly started.
|
||||
start_ray_local (bool): If True then this will start any processes not
|
||||
already in address_info, including Redis, a global scheduler, local
|
||||
scheduler(s), object store(s), and worker(s). It will also kill these
|
||||
@@ -1028,6 +1030,8 @@ def _init(address_info=None,
|
||||
be configured with.
|
||||
num_gpus: A list containing the number of GPUs the local schedulers should
|
||||
be configured with.
|
||||
num_redis_shards: The number of Redis shards to start in addition to the
|
||||
primary Redis shard.
|
||||
|
||||
Returns:
|
||||
Address information about the started processes.
|
||||
@@ -1069,6 +1073,8 @@ def _init(address_info=None,
|
||||
num_local_schedulers = len(local_schedulers)
|
||||
else:
|
||||
num_local_schedulers = 1
|
||||
# Use 1 additional redis shard if num_redis_shards is not provided.
|
||||
num_redis_shards = 1 if num_redis_shards is None else num_redis_shards
|
||||
# Start the scheduler, object store, and some workers. These will be killed
|
||||
# by the call to cleanup(), which happens when the Python script exits.
|
||||
address_info = services.start_ray_head(
|
||||
@@ -1079,20 +1085,24 @@ def _init(address_info=None,
|
||||
redirect_output=redirect_output,
|
||||
start_workers_from_local_scheduler=start_workers_from_local_scheduler,
|
||||
num_cpus=num_cpus,
|
||||
num_gpus=num_gpus)
|
||||
num_gpus=num_gpus,
|
||||
num_redis_shards=num_redis_shards)
|
||||
else:
|
||||
if redis_address is None:
|
||||
raise Exception("If start_ray_local=False, then redis_address must be "
|
||||
"provided.")
|
||||
raise Exception("When connecting to an existing cluster, redis_address "
|
||||
"must be provided.")
|
||||
if num_workers is not None:
|
||||
raise Exception("If start_ray_local=False, then num_workers must not be "
|
||||
"provided.")
|
||||
raise Exception("When connecting to an existing cluster, num_workers "
|
||||
"must not be provided.")
|
||||
if num_local_schedulers is not None:
|
||||
raise Exception("If start_ray_local=False, then num_local_schedulers "
|
||||
"must not be provided.")
|
||||
raise Exception("When connecting to an existing cluster, "
|
||||
"num_local_schedulers must not be provided.")
|
||||
if num_cpus is not None or num_gpus is not None:
|
||||
raise Exception("If start_ray_local=False, then num_cpus and num_gpus "
|
||||
"must not be provided.")
|
||||
raise Exception("When connecting to an existing cluster, num_cpus and "
|
||||
"num_gpus must not be provided.")
|
||||
if num_redis_shards is not None:
|
||||
raise Exception("When connecting to an existing cluster, "
|
||||
"num_redis_shards must not be provided.")
|
||||
# Get the node IP address if one is not provided.
|
||||
if node_ip_address is None:
|
||||
node_ip_address = services.get_node_ip_address(redis_address)
|
||||
@@ -1121,7 +1131,7 @@ def _init(address_info=None,
|
||||
|
||||
def init(redis_address=None, node_ip_address=None, object_id_seed=None,
|
||||
num_workers=None, driver_mode=SCRIPT_MODE, redirect_output=False,
|
||||
num_cpus=None, num_gpus=None):
|
||||
num_cpus=None, num_gpus=None, num_redis_shards=None):
|
||||
"""Either connect to an existing Ray cluster or start one and connect to it.
|
||||
|
||||
This method handles two cases. Either a Ray cluster already exists and we
|
||||
@@ -1148,6 +1158,8 @@ def init(redis_address=None, node_ip_address=None, object_id_seed=None,
|
||||
configured with.
|
||||
num_gpus (int): Number of gpus the user wishes all local schedulers to be
|
||||
configured with.
|
||||
num_redis_shards: The number of Redis shards to start in addition to the
|
||||
primary Redis shard.
|
||||
|
||||
Returns:
|
||||
Address information about the started processes.
|
||||
@@ -1161,7 +1173,7 @@ def init(redis_address=None, node_ip_address=None, object_id_seed=None,
|
||||
return _init(address_info=info, start_ray_local=(redis_address is None),
|
||||
num_workers=num_workers, driver_mode=driver_mode,
|
||||
redirect_output=redirect_output, num_cpus=num_cpus,
|
||||
num_gpus=num_gpus)
|
||||
num_gpus=num_gpus, num_redis_shards=num_redis_shards)
|
||||
|
||||
|
||||
def cleanup(worker=global_worker):
|
||||
@@ -1577,7 +1589,8 @@ def connect(info, object_id_seed=None, mode=WORKER_MODE, worker=global_worker,
|
||||
ray.local_scheduler.ObjectID(NIL_ACTOR_ID),
|
||||
worker.actor_counters[actor_id],
|
||||
[0, 0])
|
||||
worker.redis_client.execute_command(
|
||||
global_state._execute_command(
|
||||
driver_task.task_id(),
|
||||
"RAY.TASK_TABLE_ADD",
|
||||
driver_task.task_id().id(),
|
||||
TASK_STATUS_RUNNING,
|
||||
|
||||
+16
-11
@@ -15,6 +15,9 @@ parser.add_argument("--redis-address", required=False, type=str,
|
||||
help="the address to use for connecting to Redis")
|
||||
parser.add_argument("--redis-port", required=False, type=str,
|
||||
help="the port to use for starting Redis")
|
||||
parser.add_argument("--num-redis-shards", required=False, type=int,
|
||||
help=("the number of additional Redis shards to use in "
|
||||
"addition to the primary Redis shard"))
|
||||
parser.add_argument("--object-manager-port", required=False, type=int,
|
||||
help="the port to use for starting the object manager")
|
||||
parser.add_argument("--num-workers", required=False, type=int,
|
||||
@@ -75,23 +78,22 @@ if __name__ == "__main__":
|
||||
print("Using IP address {} for this node.".format(node_ip_address))
|
||||
|
||||
address_info = {}
|
||||
# Use the provided Redis port if there is one.
|
||||
if args.redis_port is not None:
|
||||
address_info["redis_address"] = "{}:{}".format(node_ip_address,
|
||||
args.redis_port)
|
||||
# Use the provided object manager port if there is one.
|
||||
if args.object_manager_port is not None:
|
||||
address_info["object_manager_ports"] = [args.object_manager_port]
|
||||
if address_info == {}:
|
||||
address_info = None
|
||||
|
||||
address_info = services.start_ray_head(address_info=address_info,
|
||||
node_ip_address=node_ip_address,
|
||||
num_workers=args.num_workers,
|
||||
cleanup=False,
|
||||
redirect_output=True,
|
||||
num_cpus=args.num_cpus,
|
||||
num_gpus=args.num_gpus)
|
||||
address_info = services.start_ray_head(
|
||||
address_info=address_info,
|
||||
node_ip_address=node_ip_address,
|
||||
redis_port=args.redis_port,
|
||||
num_workers=args.num_workers,
|
||||
cleanup=False,
|
||||
redirect_output=True,
|
||||
num_cpus=args.num_cpus,
|
||||
num_gpus=args.num_gpus,
|
||||
num_redis_shards=args.num_redis_shards)
|
||||
print(address_info)
|
||||
print("\nStarted Ray on this node. You can add additional nodes to the "
|
||||
"cluster by calling\n\n"
|
||||
@@ -113,6 +115,9 @@ if __name__ == "__main__":
|
||||
if args.redis_address is None:
|
||||
raise Exception("If --head is not passed in, --redis-address must be "
|
||||
"provided.")
|
||||
if args.num_redis_shards is not None:
|
||||
raise Exception("If --head is not passed in, --num-redis-shards must "
|
||||
"not be provided.")
|
||||
redis_ip_address, redis_port = args.redis_address.split(":")
|
||||
# Wait for the Redis server to be started. And throw an exception if we
|
||||
# can't connect to it.
|
||||
|
||||
@@ -5,6 +5,8 @@
|
||||
|
||||
#include "common.h"
|
||||
|
||||
#define DB_CLIENT_PREFIX "CL:"
|
||||
|
||||
/**
|
||||
* Convert an object ID to a flatbuffer string.
|
||||
*
|
||||
|
||||
@@ -172,6 +172,14 @@ static PyObject *PyObjectID_richcompare(PyObjectID *self,
|
||||
return result;
|
||||
}
|
||||
|
||||
static PyObject *PyObjectID_redis_shard_hash(PyObjectID *self) {
|
||||
/* NOTE: The hash function used here must match the one in get_redis_context
|
||||
* in src/common/state/redis.cc. Changes to the hash function should only be
|
||||
* made through UniqueIDHasher in src/common/common.h */
|
||||
UniqueIDHasher hash;
|
||||
return PyLong_FromSize_t(hash(self->object_id));
|
||||
}
|
||||
|
||||
static long PyObjectID_hash(PyObjectID *self) {
|
||||
PyObject *tuple = PyTuple_New(UNIQUE_ID_SIZE);
|
||||
for (int i = 0; i < UNIQUE_ID_SIZE; ++i) {
|
||||
@@ -201,6 +209,8 @@ static PyObject *PyObjectID___reduce__(PyObjectID *self) {
|
||||
static PyMethodDef PyObjectID_methods[] = {
|
||||
{"id", (PyCFunction) PyObjectID_id, METH_NOARGS,
|
||||
"Return the hash associated with this ObjectID"},
|
||||
{"redis_shard_hash", (PyCFunction) PyObjectID_redis_shard_hash, METH_NOARGS,
|
||||
"Return the redis shard that this ObjectID is associated with"},
|
||||
{"hex", (PyCFunction) PyObjectID_hex, METH_NOARGS,
|
||||
"Return the object ID as a string in hex."},
|
||||
{"__reduce__", (PyCFunction) PyObjectID___reduce__, METH_NOARGS,
|
||||
|
||||
@@ -67,11 +67,14 @@ void RayLogger_log(RayLogger *logger,
|
||||
if (logger->is_direct) {
|
||||
DBHandle *db = (DBHandle *) logger->conn;
|
||||
/* Fill in the client ID and send the message to Redis. */
|
||||
int status = redisAsyncCommand(
|
||||
db->context, NULL, NULL, utstring_body(formatted_message),
|
||||
(char *) db->client.id, sizeof(db->client.id));
|
||||
if ((status == REDIS_ERR) || db->context->err) {
|
||||
LOG_REDIS_DEBUG(db->context, "error while logging message to log table");
|
||||
|
||||
redisAsyncContext *context = get_redis_context(db, db->client);
|
||||
|
||||
int status =
|
||||
redisAsyncCommand(context, NULL, NULL, utstring_body(formatted_message),
|
||||
(char *) db->client.id, sizeof(db->client.id));
|
||||
if ((status == REDIS_ERR) || context->err) {
|
||||
LOG_REDIS_DEBUG(context, "error while logging message to log table");
|
||||
}
|
||||
} else {
|
||||
/* If we don't own a Redis connection, we leave our client
|
||||
|
||||
@@ -1,5 +1,9 @@
|
||||
#include "net.h"
|
||||
|
||||
#include <arpa/inet.h>
|
||||
|
||||
#include <sstream>
|
||||
|
||||
#include "common.h"
|
||||
|
||||
int parse_ip_addr_port(const char *ip_addr_port, char *ip_addr, int *port) {
|
||||
@@ -11,3 +15,10 @@ int parse_ip_addr_port(const char *ip_addr_port, char *ip_addr, int *port) {
|
||||
*port = atoi(port_str);
|
||||
return 0;
|
||||
}
|
||||
|
||||
/* Return true if the ip address is valid. */
|
||||
bool valid_ip_address(const std::string &ip_address) {
|
||||
struct sockaddr_in sa;
|
||||
int result = inet_pton(AF_INET, ip_address.c_str(), &sa.sin_addr);
|
||||
return result == 1;
|
||||
}
|
||||
|
||||
@@ -27,7 +27,6 @@
|
||||
* TODO(pcm): Fill this out.
|
||||
*/
|
||||
|
||||
#define DB_CLIENT_PREFIX "CL:"
|
||||
#define OBJECT_INFO_PREFIX "OI:"
|
||||
#define OBJECT_LOCATION_PREFIX "OL:"
|
||||
#define OBJECT_NOTIFICATION_PREFIX "ON:"
|
||||
@@ -929,6 +928,10 @@ int TaskTableWrite(RedisModuleCtx *ctx,
|
||||
RedisModuleCallReply *reply =
|
||||
RedisModule_Call(ctx, "PUBLISH", "ss", publish_topic, publish_message);
|
||||
|
||||
/* See how many clients received this publish. */
|
||||
long long num_clients = RedisModule_CallReplyInteger(reply);
|
||||
CHECKM(num_clients <= 1, "Published to %lld clients.", num_clients);
|
||||
|
||||
RedisModule_FreeString(ctx, publish_message);
|
||||
RedisModule_FreeString(ctx, publish_topic);
|
||||
if (existing_task_spec != NULL) {
|
||||
@@ -938,6 +941,18 @@ int TaskTableWrite(RedisModuleCtx *ctx,
|
||||
if (reply == NULL) {
|
||||
return RedisModule_ReplyWithError(ctx, "PUBLISH unsuccessful");
|
||||
}
|
||||
|
||||
if (num_clients == 0) {
|
||||
LOG_WARN(
|
||||
"No subscribers received this publish. This most likely means that "
|
||||
"either the intended recipient has not subscribed yet or that the "
|
||||
"pubsub connection to the intended recipient has been broken.");
|
||||
/* This reply will be received by redis_task_table_update_callback or
|
||||
* redis_task_table_add_task_callback in redis.cc, which will then reissue
|
||||
* the command. */
|
||||
return RedisModule_ReplyWithError(ctx,
|
||||
"No subscribers received message.");
|
||||
}
|
||||
}
|
||||
|
||||
RedisModule_ReplyWithSimpleString(ctx, "OK");
|
||||
|
||||
@@ -11,6 +11,9 @@ typedef struct DBHandle DBHandle;
|
||||
*
|
||||
* @param db_address The hostname to use to connect to the database.
|
||||
* @param db_port The port to use to connect to the database.
|
||||
* @param db_shards_addresses The list of database shard IP addresses.
|
||||
* @param db_shards_ports The list of database shard ports, in the same order
|
||||
* as db_shards_addresses.
|
||||
* @param client_type The type of this client.
|
||||
* @param node_ip_address The hostname of the client that is connecting.
|
||||
* @param num_args The number of extra arguments that should be supplied. This
|
||||
@@ -21,8 +24,8 @@ typedef struct DBHandle DBHandle;
|
||||
* @return This returns a handle to the database, which must be freed with
|
||||
* db_disconnect after use.
|
||||
*/
|
||||
DBHandle *db_connect(const char *db_address,
|
||||
int db_port,
|
||||
DBHandle *db_connect(const std::string &db_primary_address,
|
||||
int db_primary_port,
|
||||
const char *client_type,
|
||||
const char *node_ip_address,
|
||||
int num_args,
|
||||
|
||||
@@ -29,11 +29,22 @@ void db_client_table_remove(DBHandle *db_handle,
|
||||
* ==== Subscribing to the db client table ====
|
||||
*/
|
||||
|
||||
/* An entry in the db client table. */
|
||||
typedef struct {
|
||||
/** The database client ID. */
|
||||
DBClientID id;
|
||||
/** The database client type. */
|
||||
const char *client_type;
|
||||
/** An optional auxiliary address for an associated database client on the
|
||||
* same node. */
|
||||
const char *aux_address;
|
||||
/** Whether or not the database client exists. If this is false for an entry,
|
||||
* then it will never again be true. */
|
||||
bool is_insertion;
|
||||
} DBClient;
|
||||
|
||||
/* Callback for subscribing to the db client table. */
|
||||
typedef void (*db_client_table_subscribe_callback)(DBClientID db_client_id,
|
||||
const char *client_type,
|
||||
const char *aux_address,
|
||||
bool is_insertion,
|
||||
typedef void (*db_client_table_subscribe_callback)(DBClient *db_client,
|
||||
void *user_context);
|
||||
|
||||
/**
|
||||
|
||||
+475
-162
@@ -4,6 +4,7 @@
|
||||
#include <stdbool.h>
|
||||
#include <stdlib.h>
|
||||
#include <unistd.h>
|
||||
#include <vector>
|
||||
|
||||
extern "C" {
|
||||
/* Including hiredis here is necessary on Windows for typedefs used in ae.h. */
|
||||
@@ -26,6 +27,7 @@ extern "C" {
|
||||
#include "event_loop.h"
|
||||
#include "redis.h"
|
||||
#include "io.h"
|
||||
#include "net.h"
|
||||
|
||||
#include "format/common_generated.h"
|
||||
|
||||
@@ -77,52 +79,128 @@ extern int usleep(useconds_t usec);
|
||||
do { \
|
||||
} while (0)
|
||||
|
||||
DBHandle *db_connect(const char *db_address,
|
||||
int db_port,
|
||||
const char *client_type,
|
||||
const char *node_ip_address,
|
||||
int num_args,
|
||||
const char **args) {
|
||||
/* Check that the number of args is even. These args will be passed to the
|
||||
* RAY.CONNECT Redis command, which takes arguments in pairs. */
|
||||
if (num_args % 2 != 0) {
|
||||
LOG_FATAL("The number of extra args must be divisible by two.");
|
||||
}
|
||||
redisAsyncContext *get_redis_context(DBHandle *db, UniqueID id) {
|
||||
/* NOTE: The hash function used here must match the one in
|
||||
* PyObjectID_redis_shard_hash in src/common/lib/python/common_extension.cc.
|
||||
* Changes to the hash function should only be made through
|
||||
* UniqueIDHasher in src/common/common.h */
|
||||
UniqueIDHasher index;
|
||||
return db->contexts[index(id) % db->contexts.size()];
|
||||
}
|
||||
|
||||
DBHandle *db = (DBHandle *) malloc(sizeof(DBHandle));
|
||||
/* Sync connection for initial handshake */
|
||||
redisAsyncContext *get_redis_subscribe_context(DBHandle *db, UniqueID id) {
|
||||
UniqueIDHasher index;
|
||||
return db->subscribe_contexts[index(id) % db->subscribe_contexts.size()];
|
||||
}
|
||||
|
||||
void get_redis_shards(redisContext *context,
|
||||
std::vector<std::string> &db_shards_addresses,
|
||||
std::vector<int> &db_shards_ports) {
|
||||
/* Get the total number of Redis shards in the system. */
|
||||
int num_attempts = 0;
|
||||
redisReply *reply = NULL;
|
||||
while (num_attempts < REDIS_DB_CONNECT_RETRIES) {
|
||||
/* Try to read the number of Redis shards from the primary shard. If the
|
||||
* entry is present, exit. */
|
||||
reply = (redisReply *) redisCommand(context, "GET NumRedisShards");
|
||||
if (reply->type != REDIS_REPLY_NIL) {
|
||||
break;
|
||||
}
|
||||
|
||||
/* Sleep for a little, and try again if the entry isn't there yet. */
|
||||
freeReplyObject(reply);
|
||||
usleep(REDIS_DB_CONNECT_WAIT_MS * 1000);
|
||||
num_attempts++;
|
||||
continue;
|
||||
}
|
||||
CHECKM(num_attempts < REDIS_DB_CONNECT_RETRIES,
|
||||
"No entry found for NumRedisShards");
|
||||
CHECKM(reply->type == REDIS_REPLY_STRING,
|
||||
"Expected string, found Redis type %d for NumRedisShards",
|
||||
reply->type);
|
||||
int num_redis_shards = atoi(reply->str);
|
||||
CHECKM(num_redis_shards >= 1, "Expected at least one Redis shard, found %d.",
|
||||
num_redis_shards);
|
||||
freeReplyObject(reply);
|
||||
|
||||
/* Get the addresses of all of the Redis shards. */
|
||||
num_attempts = 0;
|
||||
while (num_attempts < REDIS_DB_CONNECT_RETRIES) {
|
||||
/* Try to read the Redis shard locations from the primary shard. If we find
|
||||
* that all of them are present, exit. */
|
||||
reply = (redisReply *) redisCommand(context, "LRANGE RedisShards 0 -1");
|
||||
if (reply->elements == num_redis_shards) {
|
||||
break;
|
||||
}
|
||||
|
||||
/* Sleep for a little, and try again if not all Redis shard addresses have
|
||||
* been added yet. */
|
||||
freeReplyObject(reply);
|
||||
usleep(REDIS_DB_CONNECT_WAIT_MS * 1000);
|
||||
num_attempts++;
|
||||
continue;
|
||||
}
|
||||
CHECKM(num_attempts < REDIS_DB_CONNECT_RETRIES,
|
||||
"Expected %d Redis shard addresses, found %d", num_redis_shards,
|
||||
(int) reply->elements);
|
||||
|
||||
/* Parse the Redis shard addresses. */
|
||||
char db_shard_address[16];
|
||||
int db_shard_port;
|
||||
for (int i = 0; i < reply->elements; ++i) {
|
||||
/* Parse the shard addresses and ports. */
|
||||
CHECK(reply->element[i]->type == REDIS_REPLY_STRING);
|
||||
CHECK(parse_ip_addr_port(reply->element[i]->str, db_shard_address,
|
||||
&db_shard_port) == 0);
|
||||
db_shards_addresses.push_back(std::string(db_shard_address));
|
||||
db_shards_ports.push_back(db_shard_port);
|
||||
}
|
||||
freeReplyObject(reply);
|
||||
}
|
||||
|
||||
void db_connect_shard(const std::string &db_address,
|
||||
int db_port,
|
||||
DBClientID client,
|
||||
const char *client_type,
|
||||
const char *node_ip_address,
|
||||
int num_args,
|
||||
const char **args,
|
||||
DBHandle *db,
|
||||
redisAsyncContext **context_out,
|
||||
redisAsyncContext **subscribe_context_out,
|
||||
redisContext **sync_context_out) {
|
||||
/* Synchronous connection for initial handshake */
|
||||
redisReply *reply;
|
||||
int connection_attempts = 0;
|
||||
redisContext *context = redisConnect(db_address, db_port);
|
||||
while (context == NULL || context->err) {
|
||||
redisContext *sync_context = redisConnect(db_address.c_str(), db_port);
|
||||
while (sync_context == NULL || sync_context->err) {
|
||||
if (connection_attempts >= REDIS_DB_CONNECT_RETRIES) {
|
||||
break;
|
||||
}
|
||||
LOG_WARN("Failed to connect to Redis, retrying.");
|
||||
/* Sleep for a little. */
|
||||
usleep(REDIS_DB_CONNECT_WAIT_MS * 1000);
|
||||
context = redisConnect(db_address, db_port);
|
||||
sync_context = redisConnect(db_address.c_str(), db_port);
|
||||
connection_attempts += 1;
|
||||
}
|
||||
CHECK_REDIS_CONNECT(redisContext, context,
|
||||
CHECK_REDIS_CONNECT(redisContext, sync_context,
|
||||
"could not establish synchronous connection to redis "
|
||||
"%s:%d",
|
||||
db_address, db_port);
|
||||
db_address.c_str(), db_port);
|
||||
/* Configure Redis to generate keyspace notifications for list events. This
|
||||
* should only need to be done once (by whoever started Redis), but since
|
||||
* Redis may be started in multiple places (e.g., for testing or when starting
|
||||
* processes by hand), it is easier to do it multiple times. */
|
||||
reply = (redisReply *) redisCommand(context,
|
||||
reply = (redisReply *) redisCommand(sync_context,
|
||||
"CONFIG SET notify-keyspace-events Kl");
|
||||
CHECKM(reply != NULL, "db_connect failed on CONFIG SET");
|
||||
freeReplyObject(reply);
|
||||
/* Also configure Redis to not run in protected mode, so clients on other
|
||||
* hosts can connect to it. */
|
||||
reply = (redisReply *) redisCommand(context, "CONFIG SET protected-mode no");
|
||||
reply =
|
||||
(redisReply *) redisCommand(sync_context, "CONFIG SET protected-mode no");
|
||||
CHECKM(reply != NULL, "db_connect failed on CONFIG SET");
|
||||
freeReplyObject(reply);
|
||||
/* Create a client ID for this client. */
|
||||
DBClientID client = globally_unique_id();
|
||||
|
||||
/* Construct the argument arrays for RAY.CONNECT. */
|
||||
int argc = num_args + 4;
|
||||
@@ -133,7 +211,7 @@ DBHandle *db_connect(const char *db_address,
|
||||
argvlen[0] = strlen(argv[0]);
|
||||
/* Set the client ID argument. */
|
||||
argv[1] = (char *) client.id;
|
||||
argvlen[1] = sizeof(db->client.id);
|
||||
argvlen[1] = sizeof(client.id);
|
||||
/* Set the node IP address argument. */
|
||||
argv[2] = node_ip_address;
|
||||
argvlen[2] = strlen(node_ip_address);
|
||||
@@ -152,7 +230,7 @@ DBHandle *db_connect(const char *db_address,
|
||||
|
||||
/* Register this client with Redis. RAY.CONNECT is a custom Redis command that
|
||||
* we've defined. */
|
||||
reply = (redisReply *) redisCommandArgv(context, argc, argv, argvlen);
|
||||
reply = (redisReply *) redisCommandArgv(sync_context, argc, argv, argvlen);
|
||||
CHECKM(reply != NULL, "db_connect failed on RAY.CONNECT");
|
||||
CHECK(reply->type != REDIS_REPLY_ERROR);
|
||||
CHECK(strcmp(reply->str, "OK") == 0);
|
||||
@@ -160,25 +238,75 @@ DBHandle *db_connect(const char *db_address,
|
||||
free(argv);
|
||||
free(argvlen);
|
||||
|
||||
*sync_context_out = sync_context;
|
||||
|
||||
/* Establish connection for control data. */
|
||||
redisAsyncContext *context = redisAsyncConnect(db_address.c_str(), db_port);
|
||||
CHECK_REDIS_CONNECT(redisAsyncContext, context,
|
||||
"could not establish asynchronous connection to redis "
|
||||
"%s:%d",
|
||||
db_address.c_str(), db_port);
|
||||
context->data = (void *) db;
|
||||
*context_out = context;
|
||||
|
||||
/* Establish async connection for subscription. */
|
||||
redisAsyncContext *subscribe_context =
|
||||
redisAsyncConnect(db_address.c_str(), db_port);
|
||||
CHECK_REDIS_CONNECT(redisAsyncContext, subscribe_context,
|
||||
"could not establish asynchronous subscription "
|
||||
"connection to redis %s:%d",
|
||||
db_address.c_str(), db_port);
|
||||
subscribe_context->data = (void *) db;
|
||||
*subscribe_context_out = subscribe_context;
|
||||
}
|
||||
|
||||
DBHandle *db_connect(const std::string &db_primary_address,
|
||||
int db_primary_port,
|
||||
const char *client_type,
|
||||
const char *node_ip_address,
|
||||
int num_args,
|
||||
const char **args) {
|
||||
/* Check that the number of args is even. These args will be passed to the
|
||||
* RAY.CONNECT Redis command, which takes arguments in pairs. */
|
||||
if (num_args % 2 != 0) {
|
||||
LOG_FATAL("The number of extra args must be divisible by two.");
|
||||
}
|
||||
|
||||
/* Create a client ID for this client. */
|
||||
DBClientID client = globally_unique_id();
|
||||
|
||||
DBHandle *db = new DBHandle();
|
||||
|
||||
db->client_type = strdup(client_type);
|
||||
db->client = client;
|
||||
db->db_client_cache = NULL;
|
||||
db->sync_context = context;
|
||||
|
||||
/* Establish async connection */
|
||||
db->context = redisAsyncConnect(db_address, db_port);
|
||||
CHECK_REDIS_CONNECT(redisAsyncContext, db->context,
|
||||
"could not establish asynchronous connection to redis "
|
||||
"%s:%d",
|
||||
db_address, db_port);
|
||||
db->context->data = (void *) db;
|
||||
/* Establish async connection for subscription */
|
||||
db->sub_context = redisAsyncConnect(db_address, db_port);
|
||||
CHECK_REDIS_CONNECT(redisAsyncContext, db->sub_context,
|
||||
"could not establish asynchronous subscription "
|
||||
"connection to redis %s:%d",
|
||||
db_address, db_port);
|
||||
db->sub_context->data = (void *) db;
|
||||
redisAsyncContext *context;
|
||||
redisAsyncContext *subscribe_context;
|
||||
redisContext *sync_context;
|
||||
|
||||
/* Connect to the primary redis instance. */
|
||||
db_connect_shard(db_primary_address, db_primary_port, client, client_type,
|
||||
node_ip_address, num_args, args, db, &context,
|
||||
&subscribe_context, &sync_context);
|
||||
db->context = context;
|
||||
db->subscribe_context = subscribe_context;
|
||||
db->sync_context = sync_context;
|
||||
|
||||
/* Get the shard locations. */
|
||||
std::vector<std::string> db_shards_addresses;
|
||||
std::vector<int> db_shards_ports;
|
||||
get_redis_shards(db->sync_context, db_shards_addresses, db_shards_ports);
|
||||
CHECKM(db_shards_addresses.size() > 0, "No Redis shards found");
|
||||
/* Connect to the shards. */
|
||||
for (int i = 0; i < db_shards_addresses.size(); ++i) {
|
||||
db_connect_shard(db_shards_addresses[i], db_shards_ports[i], client,
|
||||
client_type, node_ip_address, num_args, args, db, &context,
|
||||
&subscribe_context, &sync_context);
|
||||
db->contexts.push_back(context);
|
||||
db->subscribe_contexts.push_back(subscribe_context);
|
||||
redisFree(sync_context);
|
||||
}
|
||||
|
||||
return db;
|
||||
}
|
||||
@@ -193,10 +321,17 @@ void db_disconnect(DBHandle *db) {
|
||||
CHECK(strcmp(reply->str, "OK") == 0);
|
||||
freeReplyObject(reply);
|
||||
|
||||
/* Clean up the Redis connection state. */
|
||||
/* Clean up the primary Redis connection state. */
|
||||
redisFree(db->sync_context);
|
||||
redisAsyncFree(db->context);
|
||||
redisAsyncFree(db->sub_context);
|
||||
redisAsyncFree(db->subscribe_context);
|
||||
|
||||
/* Clean up the Redis shards. */
|
||||
CHECK(db->contexts.size() == db->subscribe_contexts.size());
|
||||
for (int i = 0; i < db->contexts.size(); ++i) {
|
||||
redisAsyncFree(db->contexts[i]);
|
||||
redisAsyncFree(db->subscribe_contexts[i]);
|
||||
}
|
||||
|
||||
/* Clean up memory. */
|
||||
DBClientCacheEntry *e, *tmp;
|
||||
@@ -206,21 +341,36 @@ void db_disconnect(DBHandle *db) {
|
||||
free(e);
|
||||
}
|
||||
free(db->client_type);
|
||||
free(db);
|
||||
delete db;
|
||||
}
|
||||
|
||||
void db_attach(DBHandle *db, event_loop *loop, bool reattach) {
|
||||
db->loop = loop;
|
||||
/* Attach primary redis instance to the event loop. */
|
||||
int err = redisAeAttach(loop, db->context);
|
||||
/* If the database is reattached in the tests, redis normally gives
|
||||
* an error which we can safely ignore. */
|
||||
if (!reattach) {
|
||||
CHECKM(err == REDIS_OK, "failed to attach the event loop");
|
||||
}
|
||||
err = redisAeAttach(loop, db->sub_context);
|
||||
err = redisAeAttach(loop, db->subscribe_context);
|
||||
if (!reattach) {
|
||||
CHECKM(err == REDIS_OK, "failed to attach the event loop");
|
||||
}
|
||||
/* Attach other redis shards to the event loop. */
|
||||
CHECK(db->contexts.size() == db->subscribe_contexts.size());
|
||||
for (int i = 0; i < db->contexts.size(); ++i) {
|
||||
int err = redisAeAttach(loop, db->contexts[i]);
|
||||
/* If the database is reattached in the tests, redis normally gives
|
||||
* an error which we can safely ignore. */
|
||||
if (!reattach) {
|
||||
CHECKM(err == REDIS_OK, "failed to attach the event loop");
|
||||
}
|
||||
err = redisAeAttach(loop, db->subscribe_contexts[i]);
|
||||
if (!reattach) {
|
||||
CHECKM(err == REDIS_OK, "failed to attach the event loop");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
@@ -264,14 +414,16 @@ void redis_object_table_add(TableCallbackData *callback_data) {
|
||||
int64_t object_size = info->object_size;
|
||||
unsigned char *digest = info->digest;
|
||||
|
||||
redisAsyncContext *context = get_redis_context(db, obj_id);
|
||||
|
||||
int status = redisAsyncCommand(
|
||||
db->context, redis_object_table_add_callback,
|
||||
context, redis_object_table_add_callback,
|
||||
(void *) callback_data->timer_id, "RAY.OBJECT_TABLE_ADD %b %ld %b %b",
|
||||
obj_id.id, sizeof(obj_id.id), object_size, digest, (size_t) DIGEST_SIZE,
|
||||
db->client.id, sizeof(db->client.id));
|
||||
|
||||
if ((status == REDIS_ERR) || db->context->err) {
|
||||
LOG_REDIS_DEBUG(db->context, "error in redis_object_table_add");
|
||||
if ((status == REDIS_ERR) || context->err) {
|
||||
LOG_REDIS_DEBUG(context, "error in redis_object_table_add");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -309,13 +461,16 @@ void redis_object_table_remove(TableCallbackData *callback_data) {
|
||||
if (client_id == NULL) {
|
||||
client_id = &db->client;
|
||||
}
|
||||
|
||||
redisAsyncContext *context = get_redis_context(db, obj_id);
|
||||
|
||||
int status = redisAsyncCommand(
|
||||
db->context, redis_object_table_remove_callback,
|
||||
context, redis_object_table_remove_callback,
|
||||
(void *) callback_data->timer_id, "RAY.OBJECT_TABLE_REMOVE %b %b",
|
||||
obj_id.id, sizeof(obj_id.id), client_id->id, sizeof(client_id->id));
|
||||
|
||||
if ((status == REDIS_ERR) || db->context->err) {
|
||||
LOG_REDIS_DEBUG(db->context, "error in redis_object_table_remove");
|
||||
if ((status == REDIS_ERR) || context->err) {
|
||||
LOG_REDIS_DEBUG(context, "error in redis_object_table_remove");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -324,12 +479,15 @@ void redis_object_table_lookup(TableCallbackData *callback_data) {
|
||||
DBHandle *db = callback_data->db_handle;
|
||||
|
||||
ObjectID obj_id = callback_data->id;
|
||||
int status = redisAsyncCommand(
|
||||
db->context, redis_object_table_lookup_callback,
|
||||
(void *) callback_data->timer_id, "RAY.OBJECT_TABLE_LOOKUP %b", obj_id.id,
|
||||
sizeof(obj_id.id));
|
||||
if ((status == REDIS_ERR) || db->context->err) {
|
||||
LOG_REDIS_DEBUG(db->context, "error in object_table lookup");
|
||||
|
||||
redisAsyncContext *context = get_redis_context(db, obj_id);
|
||||
|
||||
int status = redisAsyncCommand(context, redis_object_table_lookup_callback,
|
||||
(void *) callback_data->timer_id,
|
||||
"RAY.OBJECT_TABLE_LOOKUP %b", obj_id.id,
|
||||
sizeof(obj_id.id));
|
||||
if ((status == REDIS_ERR) || context->err) {
|
||||
LOG_REDIS_DEBUG(context, "error in object_table lookup");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -358,13 +516,15 @@ void redis_result_table_add(TableCallbackData *callback_data) {
|
||||
ResultTableAddInfo *info = (ResultTableAddInfo *) callback_data->data;
|
||||
int is_put = info->is_put ? 1 : 0;
|
||||
|
||||
redisAsyncContext *context = get_redis_context(db, id);
|
||||
|
||||
/* Add the result entry to the result table. */
|
||||
int status = redisAsyncCommand(
|
||||
db->context, redis_result_table_add_callback,
|
||||
context, redis_result_table_add_callback,
|
||||
(void *) callback_data->timer_id, "RAY.RESULT_TABLE_ADD %b %b %d", id.id,
|
||||
sizeof(id.id), info->task_id.id, sizeof(info->task_id.id), is_put);
|
||||
if ((status == REDIS_ERR) || db->context->err) {
|
||||
LOG_REDIS_DEBUG(db->context, "Error in result table add");
|
||||
if ((status == REDIS_ERR) || context->err) {
|
||||
LOG_REDIS_DEBUG(context, "Error in result table add");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -423,12 +583,13 @@ void redis_result_table_lookup(TableCallbackData *callback_data) {
|
||||
CHECK(callback_data);
|
||||
DBHandle *db = callback_data->db_handle;
|
||||
ObjectID id = callback_data->id;
|
||||
redisAsyncContext *context = get_redis_context(db, id);
|
||||
int status =
|
||||
redisAsyncCommand(db->context, redis_result_table_lookup_callback,
|
||||
redisAsyncCommand(context, redis_result_table_lookup_callback,
|
||||
(void *) callback_data->timer_id,
|
||||
"RAY.RESULT_TABLE_LOOKUP %b", id.id, sizeof(id.id));
|
||||
if ((status == REDIS_ERR) || db->context->err) {
|
||||
LOG_REDIS_DEBUG(db->context, "Error in result table lookup");
|
||||
if ((status == REDIS_ERR) || context->err) {
|
||||
LOG_REDIS_DEBUG(context, "Error in result table lookup");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -586,28 +747,33 @@ void redis_object_table_subscribe_to_notifications(
|
||||
* src/common/redismodule/ray_redis_module.cc. */
|
||||
const char *object_channel_prefix = "OC:";
|
||||
const char *object_channel_bcast = "BCAST";
|
||||
int status = REDIS_OK;
|
||||
/* Subscribe to notifications from the object table. This uses the client ID
|
||||
* as the channel name so this channel is specific to this client. TODO(rkn):
|
||||
* The channel name should probably be the client ID with some prefix. */
|
||||
CHECKM(callback_data->data != NULL,
|
||||
"Object table subscribe data passed as NULL.");
|
||||
if (((ObjectTableSubscribeData *) (callback_data->data))->subscribe_all) {
|
||||
/* Subscribe to the object broadcast channel. */
|
||||
status = redisAsyncCommand(
|
||||
db->sub_context, object_table_redis_subscribe_to_notifications_callback,
|
||||
(void *) callback_data->timer_id, "SUBSCRIBE %s%s",
|
||||
object_channel_prefix, object_channel_bcast);
|
||||
} else {
|
||||
status = redisAsyncCommand(
|
||||
db->sub_context, object_table_redis_subscribe_to_notifications_callback,
|
||||
(void *) callback_data->timer_id, "SUBSCRIBE %s%b",
|
||||
object_channel_prefix, db->client.id, sizeof(db->client.id));
|
||||
}
|
||||
for (int i = 0; i < db->subscribe_contexts.size(); ++i) {
|
||||
int status = REDIS_OK;
|
||||
/* Subscribe to notifications from the object table. This uses the client ID
|
||||
* as the channel name so this channel is specific to this client.
|
||||
* TODO(rkn):
|
||||
* The channel name should probably be the client ID with some prefix. */
|
||||
CHECKM(callback_data->data != NULL,
|
||||
"Object table subscribe data passed as NULL.");
|
||||
if (((ObjectTableSubscribeData *) (callback_data->data))->subscribe_all) {
|
||||
/* Subscribe to the object broadcast channel. */
|
||||
status = redisAsyncCommand(
|
||||
db->subscribe_contexts[i],
|
||||
object_table_redis_subscribe_to_notifications_callback,
|
||||
(void *) callback_data->timer_id, "SUBSCRIBE %s%s",
|
||||
object_channel_prefix, object_channel_bcast);
|
||||
} else {
|
||||
status = redisAsyncCommand(
|
||||
db->subscribe_contexts[i],
|
||||
object_table_redis_subscribe_to_notifications_callback,
|
||||
(void *) callback_data->timer_id, "SUBSCRIBE %s%b",
|
||||
object_channel_prefix, db->client.id, sizeof(db->client.id));
|
||||
}
|
||||
|
||||
if ((status == REDIS_ERR) || db->sub_context->err) {
|
||||
LOG_REDIS_DEBUG(db->sub_context,
|
||||
"error in redis_object_table_subscribe_to_notifications");
|
||||
if ((status == REDIS_ERR) || db->subscribe_contexts[i]->err) {
|
||||
LOG_REDIS_DEBUG(db->subscribe_contexts[i],
|
||||
"error in redis_object_table_subscribe_to_notifications");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -633,31 +799,33 @@ void redis_object_table_request_notifications(
|
||||
int num_object_ids = request_data->num_object_ids;
|
||||
ObjectID *object_ids = request_data->object_ids;
|
||||
|
||||
/* Create the arguments for the Redis command. */
|
||||
int num_args = 1 + 1 + num_object_ids;
|
||||
const char **argv = (const char **) malloc(sizeof(char *) * num_args);
|
||||
size_t *argvlen = (size_t *) malloc(sizeof(size_t) * num_args);
|
||||
/* Set the command name argument. */
|
||||
argv[0] = "RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS";
|
||||
argvlen[0] = strlen(argv[0]);
|
||||
/* Set the client ID argument. */
|
||||
argv[1] = (char *) db->client.id;
|
||||
argvlen[1] = sizeof(db->client.id);
|
||||
/* Set the object ID arguments. */
|
||||
for (int i = 0; i < num_object_ids; ++i) {
|
||||
argv[2 + i] = (char *) object_ids[i].id;
|
||||
argvlen[2 + i] = sizeof(object_ids[i].id);
|
||||
}
|
||||
redisAsyncContext *context = get_redis_context(db, object_ids[i]);
|
||||
|
||||
int status = redisAsyncCommandArgv(
|
||||
db->context, redis_object_table_request_notifications_callback,
|
||||
(void *) callback_data->timer_id, num_args, argv, argvlen);
|
||||
free(argv);
|
||||
free(argvlen);
|
||||
/* Create the arguments for the Redis command. */
|
||||
int num_args = 1 + 1 + 1;
|
||||
const char **argv = (const char **) malloc(sizeof(char *) * num_args);
|
||||
size_t *argvlen = (size_t *) malloc(sizeof(size_t) * num_args);
|
||||
/* Set the command name argument. */
|
||||
argv[0] = "RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS";
|
||||
argvlen[0] = strlen(argv[0]);
|
||||
/* Set the client ID argument. */
|
||||
argv[1] = (char *) db->client.id;
|
||||
argvlen[1] = sizeof(db->client.id);
|
||||
/* Set the object ID arguments. */
|
||||
argv[2] = (char *) object_ids[i].id;
|
||||
argvlen[2] = sizeof(object_ids[i].id);
|
||||
|
||||
if ((status == REDIS_ERR) || db->context->err) {
|
||||
LOG_REDIS_DEBUG(db->context,
|
||||
"error in redis_object_table_subscribe_to_notifications");
|
||||
int status = redisAsyncCommandArgv(
|
||||
context, redis_object_table_request_notifications_callback,
|
||||
(void *) callback_data->timer_id, num_args, argv, argvlen);
|
||||
free(argv);
|
||||
free(argvlen);
|
||||
|
||||
if ((status == REDIS_ERR) || context->err) {
|
||||
LOG_REDIS_DEBUG(context,
|
||||
"error in redis_object_table_subscribe_to_notifications");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -690,12 +858,14 @@ void redis_task_table_get_task(TableCallbackData *callback_data) {
|
||||
CHECK(callback_data->data == NULL);
|
||||
TaskID task_id = callback_data->id;
|
||||
|
||||
int status = redisAsyncCommand(
|
||||
db->context, redis_task_table_get_task_callback,
|
||||
(void *) callback_data->timer_id, "RAY.TASK_TABLE_GET %b", task_id.id,
|
||||
sizeof(task_id.id));
|
||||
if ((status == REDIS_ERR) || db->context->err) {
|
||||
LOG_REDIS_DEBUG(db->context, "error in redis_task_table_get_task");
|
||||
redisAsyncContext *context = get_redis_context(db, task_id);
|
||||
|
||||
int status = redisAsyncCommand(context, redis_task_table_get_task_callback,
|
||||
(void *) callback_data->timer_id,
|
||||
"RAY.TASK_TABLE_GET %b", task_id.id,
|
||||
sizeof(task_id.id));
|
||||
if ((status == REDIS_ERR) || context->err) {
|
||||
LOG_REDIS_DEBUG(context, "error in redis_task_table_get_task");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -706,6 +876,36 @@ void redis_task_table_add_task_callback(redisAsyncContext *c,
|
||||
|
||||
/* Do some minimal checking. */
|
||||
redisReply *reply = (redisReply *) r;
|
||||
|
||||
/* If the publish which happens inside of the call to RAY.TASK_TABLE_ADD was
|
||||
* not received by any subscribers, then reissue the command. TODO(rkn): This
|
||||
* entire if block should be temporary. Once we address the problem where in
|
||||
* which a global scheduler may publish a task to a local scheduler before the
|
||||
* local scheduler has subscribed to the relevant channel, we shouldn't need
|
||||
* this block any more. */
|
||||
if (reply->type == REDIS_REPLY_ERROR &&
|
||||
strcmp(reply->str, "No subscribers received message.") == 0) {
|
||||
Task *task = (Task *) callback_data->data;
|
||||
TaskID task_id = Task_task_id(task);
|
||||
DBClientID local_scheduler_id = Task_local_scheduler(task);
|
||||
redisAsyncContext *context = get_redis_context(db, task_id);
|
||||
int state = Task_state(task);
|
||||
TaskSpec *spec = Task_task_spec(task);
|
||||
/* Reissue the command. */
|
||||
CHECKM(task != NULL, "NULL task passed to redis_task_table_add_task.");
|
||||
int status = redisAsyncCommand(
|
||||
context, redis_task_table_add_task_callback,
|
||||
(void *) callback_data->timer_id, "RAY.TASK_TABLE_ADD %b %d %b %b",
|
||||
task_id.id, sizeof(task_id.id), state, local_scheduler_id.id,
|
||||
sizeof(local_scheduler_id.id), spec, Task_task_spec_size(task));
|
||||
if ((status == REDIS_ERR) || context->err) {
|
||||
LOG_REDIS_DEBUG(context, "error in redis_task_table_add_task");
|
||||
}
|
||||
/* Since we are reissuing the same command with the same callback data,
|
||||
* return early to avoid freeing the callback data. */
|
||||
return;
|
||||
}
|
||||
|
||||
CHECKM(strcmp(reply->str, "OK") == 0, "reply->str is %s", reply->str);
|
||||
/* Call the done callback if there is one. */
|
||||
if (callback_data->done_callback != NULL) {
|
||||
@@ -722,17 +922,18 @@ void redis_task_table_add_task(TableCallbackData *callback_data) {
|
||||
Task *task = (Task *) callback_data->data;
|
||||
TaskID task_id = Task_task_id(task);
|
||||
DBClientID local_scheduler_id = Task_local_scheduler(task);
|
||||
redisAsyncContext *context = get_redis_context(db, task_id);
|
||||
int state = Task_state(task);
|
||||
TaskSpec *spec = Task_task_spec(task);
|
||||
|
||||
CHECKM(task != NULL, "NULL task passed to redis_task_table_add_task.");
|
||||
int status = redisAsyncCommand(
|
||||
db->context, redis_task_table_add_task_callback,
|
||||
context, redis_task_table_add_task_callback,
|
||||
(void *) callback_data->timer_id, "RAY.TASK_TABLE_ADD %b %d %b %b",
|
||||
task_id.id, sizeof(task_id.id), state, local_scheduler_id.id,
|
||||
sizeof(local_scheduler_id.id), spec, Task_task_spec_size(task));
|
||||
if ((status == REDIS_ERR) || db->context->err) {
|
||||
LOG_REDIS_DEBUG(db->context, "error in redis_task_table_add_task");
|
||||
if ((status == REDIS_ERR) || context->err) {
|
||||
LOG_REDIS_DEBUG(context, "error in redis_task_table_add_task");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -743,6 +944,35 @@ void redis_task_table_update_callback(redisAsyncContext *c,
|
||||
|
||||
/* Do some minimal checking. */
|
||||
redisReply *reply = (redisReply *) r;
|
||||
|
||||
/* If the publish which happens inside of the call to RAY.TASK_TABLE_UPDATE
|
||||
* was not received by any subscribers, then reissue the command. TODO(rkn):
|
||||
* This entire if block should be temporary. Once we address the problem where
|
||||
* in which a global scheduler may publish a task to a local scheduler before
|
||||
* the local scheduler has subscribed to the relevant channel, we shouldn't
|
||||
* need this block any more. */
|
||||
if (reply->type == REDIS_REPLY_ERROR &&
|
||||
strcmp(reply->str, "No subscribers received message.") == 0) {
|
||||
Task *task = (Task *) callback_data->data;
|
||||
TaskID task_id = Task_task_id(task);
|
||||
redisAsyncContext *context = get_redis_context(db, task_id);
|
||||
DBClientID local_scheduler_id = Task_local_scheduler(task);
|
||||
int state = Task_state(task);
|
||||
/* Reissue the command. */
|
||||
CHECKM(task != NULL, "NULL task passed to redis_task_table_update.");
|
||||
int status = redisAsyncCommand(
|
||||
context, redis_task_table_update_callback,
|
||||
(void *) callback_data->timer_id, "RAY.TASK_TABLE_UPDATE %b %d %b",
|
||||
task_id.id, sizeof(task_id.id), state, local_scheduler_id.id,
|
||||
sizeof(local_scheduler_id.id));
|
||||
if ((status == REDIS_ERR) || context->err) {
|
||||
LOG_REDIS_DEBUG(context, "error in redis_task_table_update");
|
||||
}
|
||||
/* Since we are reissuing the same command with the same callback data,
|
||||
* return early to avoid freeing the callback data. */
|
||||
return;
|
||||
}
|
||||
|
||||
CHECKM(strcmp(reply->str, "OK") == 0, "reply->str is %s", reply->str);
|
||||
/* Call the done callback if there is one. */
|
||||
if (callback_data->done_callback != NULL) {
|
||||
@@ -758,17 +988,18 @@ void redis_task_table_update(TableCallbackData *callback_data) {
|
||||
DBHandle *db = callback_data->db_handle;
|
||||
Task *task = (Task *) callback_data->data;
|
||||
TaskID task_id = Task_task_id(task);
|
||||
redisAsyncContext *context = get_redis_context(db, task_id);
|
||||
DBClientID local_scheduler_id = Task_local_scheduler(task);
|
||||
int state = Task_state(task);
|
||||
|
||||
CHECKM(task != NULL, "NULL task passed to redis_task_table_update.");
|
||||
int status = redisAsyncCommand(
|
||||
db->context, redis_task_table_update_callback,
|
||||
context, redis_task_table_update_callback,
|
||||
(void *) callback_data->timer_id, "RAY.TASK_TABLE_UPDATE %b %d %b",
|
||||
task_id.id, sizeof(task_id.id), state, local_scheduler_id.id,
|
||||
sizeof(local_scheduler_id.id));
|
||||
if ((status == REDIS_ERR) || db->context->err) {
|
||||
LOG_REDIS_DEBUG(db->context, "error in redis_task_table_update");
|
||||
if ((status == REDIS_ERR) || context->err) {
|
||||
LOG_REDIS_DEBUG(context, "error in redis_task_table_update");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -779,6 +1010,15 @@ void redis_task_table_test_and_update_callback(redisAsyncContext *c,
|
||||
redisReply *reply = (redisReply *) r;
|
||||
/* Parse the task from the reply. */
|
||||
Task *task = parse_and_construct_task_from_redis_reply(reply);
|
||||
if (task == NULL) {
|
||||
/* A NULL task means that the task was not in the task table. NOTE(swang):
|
||||
* For normal tasks, this is not expected behavior, but actor tasks may be
|
||||
* delayed when added to the task table if they are submitted to a local
|
||||
* scheduler before it receives the notification that maps the actor to a
|
||||
* local scheduler. */
|
||||
LOG_ERROR("No task found during task_table_test_and_update");
|
||||
return;
|
||||
}
|
||||
/* Determine whether the update happened. */
|
||||
auto message = flatbuffers::GetRoot<TaskReply>(reply->str);
|
||||
bool updated = message->updated();
|
||||
@@ -800,18 +1040,19 @@ void redis_task_table_test_and_update_callback(redisAsyncContext *c,
|
||||
void redis_task_table_test_and_update(TableCallbackData *callback_data) {
|
||||
DBHandle *db = callback_data->db_handle;
|
||||
TaskID task_id = callback_data->id;
|
||||
redisAsyncContext *context = get_redis_context(db, task_id);
|
||||
TaskTableTestAndUpdateData *update_data =
|
||||
(TaskTableTestAndUpdateData *) callback_data->data;
|
||||
|
||||
int status = redisAsyncCommand(
|
||||
db->context, redis_task_table_test_and_update_callback,
|
||||
context, redis_task_table_test_and_update_callback,
|
||||
(void *) callback_data->timer_id,
|
||||
"RAY.TASK_TABLE_TEST_AND_UPDATE %b %d %d %b", task_id.id,
|
||||
sizeof(task_id.id), update_data->test_state_bitmask,
|
||||
update_data->update_state, update_data->local_scheduler_id.id,
|
||||
sizeof(update_data->local_scheduler_id.id));
|
||||
if ((status == REDIS_ERR) || db->context->err) {
|
||||
LOG_REDIS_DEBUG(db->context, "error in redis_task_table_test_and_update");
|
||||
if ((status == REDIS_ERR) || context->err) {
|
||||
LOG_REDIS_DEBUG(context, "error in redis_task_table_test_and_update");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -879,24 +1120,26 @@ void redis_task_table_subscribe(TableCallbackData *callback_data) {
|
||||
/* TASK_CHANNEL_PREFIX is defined in ray_redis_module.cc and must be kept in
|
||||
* sync with that file. */
|
||||
const char *TASK_CHANNEL_PREFIX = "TT:";
|
||||
int status;
|
||||
if (IS_NIL_ID(data->local_scheduler_id)) {
|
||||
/* TODO(swang): Implement the state_filter by translating the bitmask into
|
||||
* a Redis key-matching pattern. */
|
||||
status =
|
||||
redisAsyncCommand(db->sub_context, redis_task_table_subscribe_callback,
|
||||
(void *) callback_data->timer_id, "PSUBSCRIBE %s*:%d",
|
||||
TASK_CHANNEL_PREFIX, data->state_filter);
|
||||
} else {
|
||||
DBClientID local_scheduler_id = data->local_scheduler_id;
|
||||
status =
|
||||
redisAsyncCommand(db->sub_context, redis_task_table_subscribe_callback,
|
||||
(void *) callback_data->timer_id, "SUBSCRIBE %s%b:%d",
|
||||
TASK_CHANNEL_PREFIX, (char *) local_scheduler_id.id,
|
||||
sizeof(local_scheduler_id.id), data->state_filter);
|
||||
}
|
||||
if ((status == REDIS_ERR) || db->sub_context->err) {
|
||||
LOG_REDIS_DEBUG(db->sub_context, "error in redis_task_table_subscribe");
|
||||
for (auto subscribe_context : db->subscribe_contexts) {
|
||||
int status;
|
||||
if (IS_NIL_ID(data->local_scheduler_id)) {
|
||||
/* TODO(swang): Implement the state_filter by translating the bitmask into
|
||||
* a Redis key-matching pattern. */
|
||||
status = redisAsyncCommand(
|
||||
subscribe_context, redis_task_table_subscribe_callback,
|
||||
(void *) callback_data->timer_id, "PSUBSCRIBE %s*:%d",
|
||||
TASK_CHANNEL_PREFIX, data->state_filter);
|
||||
} else {
|
||||
DBClientID local_scheduler_id = data->local_scheduler_id;
|
||||
status = redisAsyncCommand(
|
||||
subscribe_context, redis_task_table_subscribe_callback,
|
||||
(void *) callback_data->timer_id, "SUBSCRIBE %s%b:%d",
|
||||
TASK_CHANNEL_PREFIX, (char *) local_scheduler_id.id,
|
||||
sizeof(local_scheduler_id.id), data->state_filter);
|
||||
}
|
||||
if ((status == REDIS_ERR) || subscribe_context->err) {
|
||||
LOG_REDIS_DEBUG(subscribe_context, "error in redis_task_table_subscribe");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -934,6 +1177,55 @@ void redis_db_client_table_remove(TableCallbackData *callback_data) {
|
||||
}
|
||||
}
|
||||
|
||||
void redis_db_client_table_scan(DBHandle *db,
|
||||
std::vector<DBClient> &db_clients) {
|
||||
/* TODO(swang): Integrate this functionality with the Ray Redis module. To do
|
||||
* this, we need the KEYS or SCAN command in Redis modules. */
|
||||
/* Get all the database client keys. */
|
||||
redisReply *reply = (redisReply *) redisCommand(db->sync_context, "KEYS %s*",
|
||||
DB_CLIENT_PREFIX);
|
||||
if (reply->type == REDIS_REPLY_NIL) {
|
||||
return;
|
||||
}
|
||||
/* Get all the database client information. */
|
||||
CHECK(reply->type == REDIS_REPLY_ARRAY);
|
||||
for (int i = 0; i < reply->elements; ++i) {
|
||||
redisReply *client_reply = (redisReply *) redisCommand(
|
||||
db->sync_context, "HGETALL %b", reply->element[i]->str,
|
||||
reply->element[i]->len);
|
||||
CHECK(reply->type == REDIS_REPLY_ARRAY);
|
||||
CHECK(reply->elements > 0);
|
||||
DBClient db_client;
|
||||
memset(&db_client, 0, sizeof(db_client));
|
||||
int num_fields = 0;
|
||||
/* Parse the fields into a DBClient. */
|
||||
for (int j = 0; j < client_reply->elements; j = j + 2) {
|
||||
const char *key = client_reply->element[j]->str;
|
||||
const char *value = client_reply->element[j + 1]->str;
|
||||
if (strcmp(key, "ray_client_id") == 0) {
|
||||
memcpy(db_client.id.id, value, sizeof(db_client.id));
|
||||
num_fields++;
|
||||
} else if (strcmp(key, "client_type") == 0) {
|
||||
db_client.client_type = strdup(value);
|
||||
num_fields++;
|
||||
} else if (strcmp(key, "aux_address") == 0) {
|
||||
db_client.aux_address = strdup(value);
|
||||
num_fields++;
|
||||
} else if (strcmp(key, "deleted") == 0) {
|
||||
bool is_deleted = atoi(value);
|
||||
db_client.is_insertion = !is_deleted;
|
||||
num_fields++;
|
||||
}
|
||||
}
|
||||
freeReplyObject(client_reply);
|
||||
/* The client ID, type, and whether it is deleted are all mandatory fields.
|
||||
* Auxiliary address is optional. */
|
||||
CHECK(num_fields >= 3);
|
||||
db_clients.push_back(db_client);
|
||||
}
|
||||
freeReplyObject(reply);
|
||||
}
|
||||
|
||||
void redis_db_client_table_subscribe_callback(redisAsyncContext *c,
|
||||
void *r,
|
||||
void *privdata) {
|
||||
@@ -956,35 +1248,54 @@ void redis_db_client_table_subscribe_callback(redisAsyncContext *c,
|
||||
/* Note that we do not destroy the callback data yet because the
|
||||
* subscription callback needs this data. */
|
||||
event_loop_remove_timer(db->loop, callback_data->timer_id);
|
||||
|
||||
/* Get the current db client table entries, in case we missed notifications
|
||||
* before the initial subscription. This must be done before we process any
|
||||
* notifications from the subscription channel, so that we don't readd an
|
||||
* entry that has already been deleted. */
|
||||
std::vector<DBClient> db_clients;
|
||||
redis_db_client_table_scan(db, db_clients);
|
||||
/* Call the subscription callback for all entries that we missed. */
|
||||
DBClientTableSubscribeData *data =
|
||||
(DBClientTableSubscribeData *) callback_data->data;
|
||||
for (auto db_client : db_clients) {
|
||||
data->subscribe_callback(&db_client, data->subscribe_context);
|
||||
if (db_client.client_type != NULL) {
|
||||
free((void *) db_client.client_type);
|
||||
}
|
||||
if (db_client.aux_address != NULL) {
|
||||
free((void *) db_client.aux_address);
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
/* Otherwise, parse the payload and call the callback. */
|
||||
auto message =
|
||||
flatbuffers::GetRoot<SubscribeToDBClientTableReply>(payload->str);
|
||||
DBClientID client = from_flatbuf(message->db_client_id());
|
||||
|
||||
/* Parse the client type and auxiliary address from the response. If there is
|
||||
* only client type, then the update was a delete. */
|
||||
char *client_type = (char *) message->client_type()->data();
|
||||
char *aux_address = (char *) message->aux_address()->data();
|
||||
bool is_insertion = message->is_insertion();
|
||||
DBClient db_client;
|
||||
db_client.id = from_flatbuf(message->db_client_id());
|
||||
db_client.client_type = (char *) message->client_type()->data();
|
||||
db_client.aux_address = message->aux_address()->data();
|
||||
db_client.is_insertion = message->is_insertion();
|
||||
|
||||
/* Call the subscription callback. */
|
||||
DBClientTableSubscribeData *data =
|
||||
(DBClientTableSubscribeData *) callback_data->data;
|
||||
if (data->subscribe_callback) {
|
||||
data->subscribe_callback(client, client_type, aux_address, is_insertion,
|
||||
data->subscribe_context);
|
||||
data->subscribe_callback(&db_client, data->subscribe_context);
|
||||
}
|
||||
}
|
||||
|
||||
void redis_db_client_table_subscribe(TableCallbackData *callback_data) {
|
||||
DBHandle *db = callback_data->db_handle;
|
||||
int status = redisAsyncCommand(
|
||||
db->sub_context, redis_db_client_table_subscribe_callback,
|
||||
db->subscribe_context, redis_db_client_table_subscribe_callback,
|
||||
(void *) callback_data->timer_id, "SUBSCRIBE db_clients");
|
||||
if ((status == REDIS_ERR) || db->sub_context->err) {
|
||||
LOG_REDIS_DEBUG(db->sub_context,
|
||||
if ((status == REDIS_ERR) || db->subscribe_context->err) {
|
||||
LOG_REDIS_DEBUG(db->subscribe_context,
|
||||
"error in db_client_table_register_callback");
|
||||
}
|
||||
}
|
||||
@@ -1042,10 +1353,10 @@ void redis_local_scheduler_table_subscribe_callback(redisAsyncContext *c,
|
||||
void redis_local_scheduler_table_subscribe(TableCallbackData *callback_data) {
|
||||
DBHandle *db = callback_data->db_handle;
|
||||
int status = redisAsyncCommand(
|
||||
db->sub_context, redis_local_scheduler_table_subscribe_callback,
|
||||
db->subscribe_context, redis_local_scheduler_table_subscribe_callback,
|
||||
(void *) callback_data->timer_id, "SUBSCRIBE local_schedulers");
|
||||
if ((status == REDIS_ERR) || db->sub_context->err) {
|
||||
LOG_REDIS_DEBUG(db->sub_context,
|
||||
if ((status == REDIS_ERR) || db->subscribe_context->err) {
|
||||
LOG_REDIS_DEBUG(db->subscribe_context,
|
||||
"error in redis_local_scheduler_table_subscribe");
|
||||
}
|
||||
}
|
||||
@@ -1130,10 +1441,11 @@ void redis_driver_table_subscribe_callback(redisAsyncContext *c,
|
||||
void redis_driver_table_subscribe(TableCallbackData *callback_data) {
|
||||
DBHandle *db = callback_data->db_handle;
|
||||
int status = redisAsyncCommand(
|
||||
db->sub_context, redis_driver_table_subscribe_callback,
|
||||
db->subscribe_context, redis_driver_table_subscribe_callback,
|
||||
(void *) callback_data->timer_id, "SUBSCRIBE driver_deaths");
|
||||
if ((status == REDIS_ERR) || db->sub_context->err) {
|
||||
LOG_REDIS_DEBUG(db->sub_context, "error in redis_driver_table_subscribe");
|
||||
if ((status == REDIS_ERR) || db->subscribe_context->err) {
|
||||
LOG_REDIS_DEBUG(db->subscribe_context,
|
||||
"error in redis_driver_table_subscribe");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1240,10 +1552,10 @@ void redis_actor_notification_table_subscribe(
|
||||
TableCallbackData *callback_data) {
|
||||
DBHandle *db = callback_data->db_handle;
|
||||
int status = redisAsyncCommand(
|
||||
db->sub_context, redis_actor_notification_table_subscribe_callback,
|
||||
db->subscribe_context, redis_actor_notification_table_subscribe_callback,
|
||||
(void *) callback_data->timer_id, "SUBSCRIBE actor_notifications");
|
||||
if ((status == REDIS_ERR) || db->sub_context->err) {
|
||||
LOG_REDIS_DEBUG(db->sub_context,
|
||||
if ((status == REDIS_ERR) || db->subscribe_context->err) {
|
||||
LOG_REDIS_DEBUG(db->subscribe_context,
|
||||
"error in redis_actor_notification_table_subscribe");
|
||||
}
|
||||
}
|
||||
@@ -1290,10 +1602,11 @@ void redis_object_info_subscribe_callback(redisAsyncContext *c,
|
||||
void redis_object_info_subscribe(TableCallbackData *callback_data) {
|
||||
DBHandle *db = callback_data->db_handle;
|
||||
int status = redisAsyncCommand(
|
||||
db->sub_context, redis_object_info_subscribe_callback,
|
||||
db->subscribe_context, redis_object_info_subscribe_callback,
|
||||
(void *) callback_data->timer_id, "PSUBSCRIBE obj:info");
|
||||
if ((status == REDIS_ERR) || db->sub_context->err) {
|
||||
LOG_REDIS_DEBUG(db->sub_context, "error in object_info_register_callback");
|
||||
if ((status == REDIS_ERR) || db->subscribe_context->err) {
|
||||
LOG_REDIS_DEBUG(db->subscribe_context,
|
||||
"error in object_info_register_callback");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1324,8 +1637,8 @@ void redis_push_error_hmset_callback(redisAsyncContext *c,
|
||||
"RPUSH ErrorKeys Error:%b:%b",
|
||||
info->driver_id.id, sizeof(info->driver_id.id),
|
||||
info->error_key, sizeof(info->error_key));
|
||||
if ((status == REDIS_ERR) || db->sub_context->err) {
|
||||
LOG_REDIS_DEBUG(db->sub_context, "error in redis_push_error rpush");
|
||||
if ((status == REDIS_ERR) || db->subscribe_context->err) {
|
||||
LOG_REDIS_DEBUG(db->subscribe_context, "error in redis_push_error rpush");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1344,8 +1657,8 @@ void redis_push_error(TableCallbackData *callback_data) {
|
||||
"HMSET Error:%b:%b type %s message %s data %b", info->driver_id.id,
|
||||
sizeof(info->driver_id.id), info->error_key, sizeof(info->error_key),
|
||||
error_type, error_message, info->data, info->data_length);
|
||||
if ((status == REDIS_ERR) || db->sub_context->err) {
|
||||
LOG_REDIS_DEBUG(db->sub_context, "error in redis_push_error hmset");
|
||||
if ((status == REDIS_ERR) || db->subscribe_context->err) {
|
||||
LOG_REDIS_DEBUG(db->subscribe_context, "error in redis_push_error hmset");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -34,11 +34,24 @@ struct DBHandle {
|
||||
char *client_type;
|
||||
/** Unique ID for this client. */
|
||||
DBClientID client;
|
||||
/** Redis context for all non-subscribe connections. */
|
||||
/** Primary redis context for all non-subscribe connections. This is used for
|
||||
* the database client table, heartbeats, and errors that should be pushed to
|
||||
* the driver. */
|
||||
redisAsyncContext *context;
|
||||
/** Redis context for "subscribe" communication. Yes, we need a separate one
|
||||
* for that, see https://github.com/redis/hiredis/issues/55. */
|
||||
redisAsyncContext *sub_context;
|
||||
/** Primary redis context for "subscribe" communication. A separate context
|
||||
* is needed for this communication (see
|
||||
* https://github.com/redis/hiredis/issues/55). This is used for the
|
||||
* database client table, heartbeats, and errors that should be pushed to
|
||||
* the driver. */
|
||||
redisAsyncContext *subscribe_context;
|
||||
/** Redis contexts for shards for all non-subscribe connections. All requests
|
||||
* to the object table, task table, and event table should be directed here.
|
||||
* The correct shard can be retrieved using get_redis_context below. */
|
||||
std::vector<redisAsyncContext *> contexts;
|
||||
/** Redis contexts for shards for "subscribe" communication. All requests
|
||||
* to the object table, task table, and event table should be directed here.
|
||||
* The correct shard can be retrieved using get_redis_context below. */
|
||||
std::vector<redisAsyncContext *> subscribe_contexts;
|
||||
/** The event loop this global state store connection is part of. */
|
||||
event_loop *loop;
|
||||
/** Index of the database connection in the event loop */
|
||||
@@ -51,6 +64,40 @@ struct DBHandle {
|
||||
redisContext *sync_context;
|
||||
};
|
||||
|
||||
/**
|
||||
* Get the Redis asynchronous context responsible for non-subscription
|
||||
* communication for the given UniqueID.
|
||||
*
|
||||
* @param db The database handle.
|
||||
* @param id The ID whose location we are querying for.
|
||||
* @return The redisAsyncContext responsible for the given ID.
|
||||
*/
|
||||
redisAsyncContext *get_redis_context(DBHandle *db, UniqueID id);
|
||||
|
||||
/**
|
||||
* Get the Redis asynchronous context responsible for subscription
|
||||
* communication for the given UniqueID.
|
||||
*
|
||||
* @param db The database handle.
|
||||
* @param id The ID whose location we are querying for.
|
||||
* @return The redisAsyncContext responsible for the given ID.
|
||||
*/
|
||||
redisAsyncContext *get_redis_subscribe_context(DBHandle *db, UniqueID id);
|
||||
|
||||
/**
|
||||
* Get a list of Redis shard IP addresses from the primary shard.
|
||||
*
|
||||
* @param context A Redis context connected to the primary shard.
|
||||
* @param db_shards_addresses The IP addresses for the shards registered
|
||||
* with the primary shard will be added to this vector.
|
||||
* @param db_shards_ports The IP ports for the shards registered with the
|
||||
* primary shard will be added to this vector, in the same order as
|
||||
* db_shards_addresses.
|
||||
*/
|
||||
void get_redis_shards(redisContext *context,
|
||||
std::vector<std::string> &db_shards_addresses,
|
||||
std::vector<int> &db_shards_ports);
|
||||
|
||||
void redis_object_table_get_entry(redisAsyncContext *c,
|
||||
void *r,
|
||||
void *privdata);
|
||||
|
||||
@@ -151,9 +151,10 @@ typedef void (*task_table_subscribe_callback)(Task *task, void *user_context);
|
||||
* @param local_scheduler_id The db_client_id of the local scheduler whose
|
||||
* events we want to listen to. If you want to subscribe to updates from
|
||||
* all local schedulers, pass in NIL_ID.
|
||||
* @param state_filter Flags for events we want to listen to. If you want
|
||||
* to listen to all events, use state_filter = TASK_WAITING |
|
||||
* TASK_SCHEDULED | TASK_RUNNING | TASK_DONE.
|
||||
* @param state_filter Events we want to listen to. Can have values from the
|
||||
* enum "scheduling_state" in task.h.
|
||||
* TODO(pcm): Make it possible to combine these using flags like
|
||||
* TASK_STATUS_WAITING | TASK_STATUS_SCHEDULED.
|
||||
* @param retry Information about retrying the request to the database.
|
||||
* @param done_callback Function to be called when database returns result.
|
||||
* @param user_context Data that will be passed to done_callback and
|
||||
|
||||
@@ -72,12 +72,12 @@ TEST object_table_lookup_test(void) {
|
||||
event_loop *loop = event_loop_create();
|
||||
/* This uses manager_port1. */
|
||||
const char *db_connect_args1[] = {"address", "127.0.0.1:12345"};
|
||||
DBHandle *db1 = db_connect("127.0.0.1", 6379, "plasma_manager", manager_addr,
|
||||
2, db_connect_args1);
|
||||
DBHandle *db1 = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager",
|
||||
manager_addr, 2, db_connect_args1);
|
||||
/* This uses manager_port2. */
|
||||
const char *db_connect_args2[] = {"address", "127.0.0.1:12346"};
|
||||
DBHandle *db2 = db_connect("127.0.0.1", 6379, "plasma_manager", manager_addr,
|
||||
2, db_connect_args2);
|
||||
DBHandle *db2 = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager",
|
||||
manager_addr, 2, db_connect_args2);
|
||||
db_attach(db1, loop, false);
|
||||
db_attach(db2, loop, false);
|
||||
UniqueID id = globally_unique_id();
|
||||
@@ -148,8 +148,8 @@ void task_table_test_callback(Task *callback_task, void *user_data) {
|
||||
TEST task_table_test(void) {
|
||||
task_table_test_callback_called = 0;
|
||||
event_loop *loop = event_loop_create();
|
||||
DBHandle *db =
|
||||
db_connect("127.0.0.1", 6379, "local_scheduler", "127.0.0.1", 0, NULL);
|
||||
DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "local_scheduler",
|
||||
"127.0.0.1", 0, NULL);
|
||||
db_attach(db, loop, false);
|
||||
DBClientID local_scheduler_id = globally_unique_id();
|
||||
int64_t task_spec_size;
|
||||
@@ -184,8 +184,8 @@ void task_table_all_test_callback(Task *task, void *user_data) {
|
||||
|
||||
TEST task_table_all_test(void) {
|
||||
event_loop *loop = event_loop_create();
|
||||
DBHandle *db =
|
||||
db_connect("127.0.0.1", 6379, "local_scheduler", "127.0.0.1", 0, NULL);
|
||||
DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "local_scheduler",
|
||||
"127.0.0.1", 0, NULL);
|
||||
db_attach(db, loop, false);
|
||||
int64_t task_spec_size;
|
||||
TaskSpec *spec = example_task_spec(1, 1, &task_spec_size);
|
||||
@@ -222,7 +222,8 @@ TEST unique_client_id_test(void) {
|
||||
DBClientID ids[num_conns];
|
||||
DBHandle *db;
|
||||
for (int i = 0; i < num_conns; ++i) {
|
||||
db = db_connect("127.0.0.1", 6379, "plasma_manager", "127.0.0.1", 0, NULL);
|
||||
db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager",
|
||||
"127.0.0.1", 0, NULL);
|
||||
ids[i] = get_db_client_id(db);
|
||||
db_disconnect(db);
|
||||
}
|
||||
|
||||
@@ -65,6 +65,15 @@ void new_object_task_callback(TaskID task_id, void *user_context) {
|
||||
new_object_lookup_callback, (void *) db);
|
||||
}
|
||||
|
||||
void task_table_subscribe_done(TaskID task_id, void *user_context) {
|
||||
RetryInfo retry = {
|
||||
.num_retries = 5, .timeout = 100, .fail_callback = NULL,
|
||||
};
|
||||
DBHandle *db = (DBHandle *) user_context;
|
||||
task_table_add_task(db, Task_copy(new_object_task), &retry,
|
||||
new_object_task_callback, db);
|
||||
}
|
||||
|
||||
TEST new_object_test(void) {
|
||||
new_object_failed = 0;
|
||||
new_object_succeeded = 0;
|
||||
@@ -73,16 +82,16 @@ TEST new_object_test(void) {
|
||||
new_object_task_spec = Task_task_spec(new_object_task);
|
||||
new_object_task_id = TaskSpec_task_id(new_object_task_spec);
|
||||
g_loop = event_loop_create();
|
||||
DBHandle *db =
|
||||
db_connect("127.0.0.1", 6379, "plasma_manager", "127.0.0.1", 0, NULL);
|
||||
DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager",
|
||||
"127.0.0.1", 0, NULL);
|
||||
db_attach(db, g_loop, false);
|
||||
RetryInfo retry = {
|
||||
.num_retries = 5,
|
||||
.timeout = 100,
|
||||
.fail_callback = new_object_fail_callback,
|
||||
};
|
||||
task_table_add_task(db, Task_copy(new_object_task), &retry,
|
||||
new_object_task_callback, db);
|
||||
task_table_subscribe(db, NIL_ID, TASK_STATUS_WAITING, NULL, NULL, &retry,
|
||||
task_table_subscribe_done, db);
|
||||
event_loop_run(g_loop);
|
||||
db_disconnect(db);
|
||||
destroy_outstanding_callbacks(g_loop);
|
||||
@@ -109,8 +118,8 @@ TEST new_object_no_task_test(void) {
|
||||
new_object_id = globally_unique_id();
|
||||
new_object_task_id = globally_unique_id();
|
||||
g_loop = event_loop_create();
|
||||
DBHandle *db =
|
||||
db_connect("127.0.0.1", 6379, "plasma_manager", "127.0.0.1", 0, NULL);
|
||||
DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager",
|
||||
"127.0.0.1", 0, NULL);
|
||||
db_attach(db, g_loop, false);
|
||||
RetryInfo retry = {
|
||||
.num_retries = 5,
|
||||
@@ -151,8 +160,8 @@ void lookup_fail_callback(UniqueID id, void *user_context, void *user_data) {
|
||||
|
||||
TEST lookup_timeout_test(void) {
|
||||
g_loop = event_loop_create();
|
||||
DBHandle *db =
|
||||
db_connect("127.0.0.1", 6379, "plasma_manager", "127.0.0.1", 0, NULL);
|
||||
DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager",
|
||||
"127.0.0.1", 0, NULL);
|
||||
db_attach(db, g_loop, false);
|
||||
RetryInfo retry = {
|
||||
.num_retries = 5, .timeout = 100, .fail_callback = lookup_fail_callback,
|
||||
@@ -161,6 +170,9 @@ TEST lookup_timeout_test(void) {
|
||||
(void *) lookup_timeout_context);
|
||||
/* Disconnect the database to see if the lookup times out. */
|
||||
close(db->context->c.fd);
|
||||
for (auto context : db->contexts) {
|
||||
close(context->c.fd);
|
||||
}
|
||||
event_loop_run(g_loop);
|
||||
db_disconnect(db);
|
||||
destroy_outstanding_callbacks(g_loop);
|
||||
@@ -187,8 +199,8 @@ void add_fail_callback(UniqueID id, void *user_context, void *user_data) {
|
||||
|
||||
TEST add_timeout_test(void) {
|
||||
g_loop = event_loop_create();
|
||||
DBHandle *db =
|
||||
db_connect("127.0.0.1", 6379, "plasma_manager", "127.0.0.1", 0, NULL);
|
||||
DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager",
|
||||
"127.0.0.1", 0, NULL);
|
||||
db_attach(db, g_loop, false);
|
||||
RetryInfo retry = {
|
||||
.num_retries = 5, .timeout = 100, .fail_callback = add_fail_callback,
|
||||
@@ -197,6 +209,9 @@ TEST add_timeout_test(void) {
|
||||
add_done_callback, (void *) add_timeout_context);
|
||||
/* Disconnect the database to see if the lookup times out. */
|
||||
close(db->context->c.fd);
|
||||
for (auto context : db->contexts) {
|
||||
close(context->c.fd);
|
||||
}
|
||||
event_loop_run(g_loop);
|
||||
db_disconnect(db);
|
||||
destroy_outstanding_callbacks(g_loop);
|
||||
@@ -225,8 +240,8 @@ void subscribe_fail_callback(UniqueID id, void *user_context, void *user_data) {
|
||||
|
||||
TEST subscribe_timeout_test(void) {
|
||||
g_loop = event_loop_create();
|
||||
DBHandle *db =
|
||||
db_connect("127.0.0.1", 6379, "plasma_manager", "127.0.0.1", 0, NULL);
|
||||
DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager",
|
||||
"127.0.0.1", 0, NULL);
|
||||
db_attach(db, g_loop, false);
|
||||
RetryInfo retry = {
|
||||
.num_retries = 5,
|
||||
@@ -236,7 +251,10 @@ TEST subscribe_timeout_test(void) {
|
||||
object_table_subscribe_to_notifications(db, false, subscribe_done_callback,
|
||||
NULL, &retry, NULL, NULL);
|
||||
/* Disconnect the database to see if the lookup times out. */
|
||||
close(db->sub_context->c.fd);
|
||||
close(db->subscribe_context->c.fd);
|
||||
for (auto subscribe_context : db->subscribe_contexts) {
|
||||
close(subscribe_context->c.fd);
|
||||
}
|
||||
event_loop_run(g_loop);
|
||||
db_disconnect(db);
|
||||
destroy_outstanding_callbacks(g_loop);
|
||||
@@ -324,8 +342,8 @@ TEST add_lookup_test(void) {
|
||||
lookup_retry_succeeded = 0;
|
||||
/* Construct the arguments to db_connect. */
|
||||
const char *db_connect_args[] = {"address", "127.0.0.1:11235"};
|
||||
DBHandle *db = db_connect("127.0.0.1", 6379, "plasma_manager", "127.0.0.1", 2,
|
||||
db_connect_args);
|
||||
DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager",
|
||||
"127.0.0.1", 2, db_connect_args);
|
||||
db_attach(db, g_loop, true);
|
||||
RetryInfo retry = {
|
||||
.num_retries = 5,
|
||||
@@ -385,8 +403,8 @@ void add_remove_callback(ObjectID object_id, bool success, void *user_context) {
|
||||
TEST add_remove_lookup_test(void) {
|
||||
g_loop = event_loop_create();
|
||||
lookup_retry_succeeded = 0;
|
||||
DBHandle *db =
|
||||
db_connect("127.0.0.1", 6379, "plasma_manager", "127.0.0.1", 0, NULL);
|
||||
DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager",
|
||||
"127.0.0.1", 0, NULL);
|
||||
db_attach(db, g_loop, true);
|
||||
RetryInfo retry = {
|
||||
.num_retries = 5,
|
||||
@@ -407,29 +425,6 @@ TEST add_remove_lookup_test(void) {
|
||||
PASS();
|
||||
}
|
||||
|
||||
/* === Test subscribe retry === */
|
||||
|
||||
const char *subscribe_retry_context = "subscribe_retry";
|
||||
int subscribe_retry_succeeded = 0;
|
||||
|
||||
int64_t reconnect_sub_context_callback(event_loop *loop,
|
||||
int64_t timer_id,
|
||||
void *context) {
|
||||
DBHandle *db = (DBHandle *) context;
|
||||
/* Reconnect to redis. This is not reconnecting the pub/sub channel. */
|
||||
redisAsyncFree(db->sub_context);
|
||||
redisAsyncFree(db->context);
|
||||
redisFree(db->sync_context);
|
||||
db->sub_context = redisAsyncConnect("127.0.0.1", 6379);
|
||||
db->sub_context->data = (void *) db;
|
||||
db->context = redisAsyncConnect("127.0.0.1", 6379);
|
||||
db->context->data = (void *) db;
|
||||
db->sync_context = redisConnect("127.0.0.1", 6379);
|
||||
/* Re-attach the database to the event loop (the file descriptor changed). */
|
||||
db_attach(db, loop, true);
|
||||
return EVENT_LOOP_TIMER_DONE;
|
||||
}
|
||||
|
||||
/* ==== Test if late succeed is working correctly ==== */
|
||||
|
||||
/* === Test lookup late succeed === */
|
||||
@@ -454,8 +449,8 @@ void lookup_late_done_callback(ObjectID object_id,
|
||||
|
||||
TEST lookup_late_test(void) {
|
||||
g_loop = event_loop_create();
|
||||
DBHandle *db =
|
||||
db_connect("127.0.0.1", 6379, "plasma_manager", "127.0.0.1", 0, NULL);
|
||||
DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager",
|
||||
"127.0.0.1", 0, NULL);
|
||||
db_attach(db, g_loop, false);
|
||||
RetryInfo retry = {
|
||||
.num_retries = 0,
|
||||
@@ -498,8 +493,8 @@ void add_late_done_callback(ObjectID object_id,
|
||||
|
||||
TEST add_late_test(void) {
|
||||
g_loop = event_loop_create();
|
||||
DBHandle *db =
|
||||
db_connect("127.0.0.1", 6379, "plasma_manager", "127.0.0.1", 0, NULL);
|
||||
DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager",
|
||||
"127.0.0.1", 0, NULL);
|
||||
db_attach(db, g_loop, false);
|
||||
RetryInfo retry = {
|
||||
.num_retries = 0, .timeout = 0, .fail_callback = add_late_fail_callback,
|
||||
@@ -543,8 +538,8 @@ void subscribe_late_done_callback(ObjectID object_id,
|
||||
|
||||
TEST subscribe_late_test(void) {
|
||||
g_loop = event_loop_create();
|
||||
DBHandle *db =
|
||||
db_connect("127.0.0.1", 6379, "plasma_manager", "127.0.0.1", 0, NULL);
|
||||
DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager",
|
||||
"127.0.0.1", 0, NULL);
|
||||
db_attach(db, g_loop, false);
|
||||
RetryInfo retry = {
|
||||
.num_retries = 0,
|
||||
@@ -611,8 +606,8 @@ TEST subscribe_success_test(void) {
|
||||
|
||||
/* Construct the arguments to db_connect. */
|
||||
const char *db_connect_args[] = {"address", "127.0.0.1:11236"};
|
||||
DBHandle *db = db_connect("127.0.0.1", 6379, "plasma_manager", "127.0.0.1", 2,
|
||||
db_connect_args);
|
||||
DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager",
|
||||
"127.0.0.1", 2, db_connect_args);
|
||||
db_attach(db, g_loop, false);
|
||||
subscribe_id = globally_unique_id();
|
||||
|
||||
@@ -680,8 +675,8 @@ TEST subscribe_object_present_test(void) {
|
||||
g_loop = event_loop_create();
|
||||
/* Construct the arguments to db_connect. */
|
||||
const char *db_connect_args[] = {"address", "127.0.0.1:11236"};
|
||||
DBHandle *db = db_connect("127.0.0.1", 6379, "plasma_manager", "127.0.0.1", 2,
|
||||
db_connect_args);
|
||||
DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager",
|
||||
"127.0.0.1", 2, db_connect_args);
|
||||
db_attach(db, g_loop, false);
|
||||
UniqueID id = globally_unique_id();
|
||||
RetryInfo retry = {
|
||||
@@ -732,8 +727,8 @@ void subscribe_object_not_present_object_available_callback(
|
||||
|
||||
TEST subscribe_object_not_present_test(void) {
|
||||
g_loop = event_loop_create();
|
||||
DBHandle *db =
|
||||
db_connect("127.0.0.1", 6379, "plasma_manager", "127.0.0.1", 0, NULL);
|
||||
DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager",
|
||||
"127.0.0.1", 0, NULL);
|
||||
db_attach(db, g_loop, false);
|
||||
UniqueID id = globally_unique_id();
|
||||
RetryInfo retry = {
|
||||
@@ -796,8 +791,8 @@ TEST subscribe_object_available_later_test(void) {
|
||||
g_loop = event_loop_create();
|
||||
/* Construct the arguments to db_connect. */
|
||||
const char *db_connect_args[] = {"address", "127.0.0.1:11236"};
|
||||
DBHandle *db = db_connect("127.0.0.1", 6379, "plasma_manager", "127.0.0.1", 2,
|
||||
db_connect_args);
|
||||
DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager",
|
||||
"127.0.0.1", 2, db_connect_args);
|
||||
db_attach(db, g_loop, false);
|
||||
UniqueID id = globally_unique_id();
|
||||
RetryInfo retry = {
|
||||
@@ -849,8 +844,8 @@ TEST subscribe_object_available_subscribe_all(void) {
|
||||
g_loop = event_loop_create();
|
||||
/* Construct the arguments to db_connect. */
|
||||
const char *db_connect_args[] = {"address", "127.0.0.1:11236"};
|
||||
DBHandle *db = db_connect("127.0.0.1", 6379, "plasma_manager", "127.0.0.1", 2,
|
||||
db_connect_args);
|
||||
DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager",
|
||||
"127.0.0.1", 2, db_connect_args);
|
||||
db_attach(db, g_loop, false);
|
||||
UniqueID id = globally_unique_id();
|
||||
RetryInfo retry = {
|
||||
@@ -904,7 +899,6 @@ SUITE(object_table_tests) {
|
||||
RUN_REDIS_TEST(add_late_test);
|
||||
RUN_REDIS_TEST(subscribe_late_test);
|
||||
RUN_REDIS_TEST(subscribe_success_test);
|
||||
RUN_REDIS_TEST(subscribe_object_present_test);
|
||||
RUN_REDIS_TEST(subscribe_object_not_present_test);
|
||||
RUN_REDIS_TEST(subscribe_object_available_later_test);
|
||||
RUN_REDIS_TEST(subscribe_object_available_subscribe_all);
|
||||
|
||||
@@ -102,8 +102,8 @@ TEST async_redis_socket_test(void) {
|
||||
utarray_push_back(connections, &socket_fd);
|
||||
|
||||
/* Start connection to Redis. */
|
||||
DBHandle *db =
|
||||
db_connect("127.0.0.1", 6379, "test_process", "127.0.0.1", 0, NULL);
|
||||
DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "test_process",
|
||||
"127.0.0.1", 0, NULL);
|
||||
db_attach(db, loop, false);
|
||||
|
||||
/* Send a command to the Redis process. */
|
||||
@@ -177,8 +177,8 @@ TEST logging_test(void) {
|
||||
utarray_push_back(connections, &socket_fd);
|
||||
|
||||
/* Start connection to Redis. */
|
||||
DBHandle *conn =
|
||||
db_connect("127.0.0.1", 6379, "test_process", "127.0.0.1", 0, NULL);
|
||||
DBHandle *conn = db_connect(std::string("127.0.0.1"), 6379, "test_process",
|
||||
"127.0.0.1", 0, NULL);
|
||||
db_attach(conn, loop, false);
|
||||
|
||||
/* Send a command to the Redis process. */
|
||||
|
||||
@@ -5,8 +5,14 @@
|
||||
# Cause the script to exit if a single command fails.
|
||||
set -e
|
||||
|
||||
./src/common/thirdparty/redis/src/redis-server --loglevel warning --loadmodule ./src/common/redis_module/libray_redis_module.so &
|
||||
# Start the Redis shards.
|
||||
./src/common/thirdparty/redis/src/redis-server --loglevel warning --loadmodule ./src/common/redis_module/libray_redis_module.so --port 6379 &
|
||||
./src/common/thirdparty/redis/src/redis-server --loglevel warning --loadmodule ./src/common/redis_module/libray_redis_module.so --port 6380 &
|
||||
sleep 1s
|
||||
# Register the shard location with the primary shard.
|
||||
./src/common/thirdparty/redis/src/redis-cli set NumRedisShards 1
|
||||
./src/common/thirdparty/redis/src/redis-cli rpush RedisShards 127.0.0.1:6380
|
||||
|
||||
./src/common/common_tests
|
||||
./src/common/db_tests
|
||||
./src/common/io_tests
|
||||
@@ -14,4 +20,5 @@ sleep 1s
|
||||
./src/common/redis_tests
|
||||
./src/common/task_table_tests
|
||||
./src/common/object_table_tests
|
||||
./src/common/thirdparty/redis/src/redis-cli shutdown
|
||||
./src/common/thirdparty/redis/src/redis-cli -p 6379 shutdown
|
||||
./src/common/thirdparty/redis/src/redis-cli -p 6380 shutdown
|
||||
|
||||
@@ -5,8 +5,14 @@
|
||||
# Cause the script to exit if a single command fails.
|
||||
set -e
|
||||
|
||||
./src/common/thirdparty/redis/src/redis-server --loglevel warning --loadmodule ./src/common/redis_module/libray_redis_module.so &
|
||||
# Start the Redis shards.
|
||||
./src/common/thirdparty/redis/src/redis-server --loglevel warning --loadmodule ./src/common/redis_module/libray_redis_module.so --port 6379 &
|
||||
./src/common/thirdparty/redis/src/redis-server --loglevel warning --loadmodule ./src/common/redis_module/libray_redis_module.so --port 6380 &
|
||||
sleep 1s
|
||||
# Register the shard location with the primary shard.
|
||||
./src/common/thirdparty/redis/src/redis-cli set NumRedisShards 1
|
||||
./src/common/thirdparty/redis/src/redis-cli rpush RedisShards 127.0.0.1:6380
|
||||
|
||||
valgrind --leak-check=full --error-exitcode=1 ./src/common/common_tests
|
||||
valgrind --leak-check=full --error-exitcode=1 ./src/common/db_tests
|
||||
valgrind --leak-check=full --error-exitcode=1 ./src/common/io_tests
|
||||
@@ -14,4 +20,6 @@ valgrind --leak-check=full --error-exitcode=1 ./src/common/task_tests
|
||||
valgrind --leak-check=full --error-exitcode=1 ./src/common/redis_tests
|
||||
valgrind --leak-check=full --error-exitcode=1 ./src/common/task_table_tests
|
||||
valgrind --leak-check=full --error-exitcode=1 ./src/common/object_table_tests
|
||||
|
||||
./src/common/thirdparty/redis/src/redis-cli shutdown
|
||||
./src/common/thirdparty/redis/src/redis-cli -p 6380 shutdown
|
||||
|
||||
@@ -40,8 +40,8 @@ void lookup_nil_success_callback(Task *task, void *context) {
|
||||
TEST lookup_nil_test(void) {
|
||||
lookup_nil_id = globally_unique_id();
|
||||
g_loop = event_loop_create();
|
||||
DBHandle *db =
|
||||
db_connect("127.0.0.1", 6379, "plasma_manager", "127.0.0.1", 0, NULL);
|
||||
DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager",
|
||||
"127.0.0.1", 0, NULL);
|
||||
db_attach(db, g_loop, false);
|
||||
RetryInfo retry = {
|
||||
.num_retries = 5,
|
||||
@@ -96,14 +96,16 @@ void add_success_callback(TaskID task_id, void *context) {
|
||||
TEST add_lookup_test(void) {
|
||||
add_lookup_task = example_task(1, 1, TASK_STATUS_WAITING);
|
||||
g_loop = event_loop_create();
|
||||
DBHandle *db =
|
||||
db_connect("127.0.0.1", 6379, "plasma_manager", "127.0.0.1", 0, NULL);
|
||||
DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager",
|
||||
"127.0.0.1", 0, NULL);
|
||||
db_attach(db, g_loop, false);
|
||||
RetryInfo retry = {
|
||||
.num_retries = 5,
|
||||
.timeout = 1000,
|
||||
.fail_callback = add_lookup_fail_callback,
|
||||
};
|
||||
task_table_subscribe(db, NIL_ID, TASK_STATUS_WAITING, NULL, NULL, &retry,
|
||||
NULL, NULL);
|
||||
task_table_add_task(db, Task_copy(add_lookup_task), &retry,
|
||||
add_success_callback, (void *) db);
|
||||
/* Disconnect the database to see if the lookup times out. */
|
||||
@@ -136,8 +138,8 @@ void subscribe_fail_callback(UniqueID id, void *user_context, void *user_data) {
|
||||
|
||||
TEST subscribe_timeout_test(void) {
|
||||
g_loop = event_loop_create();
|
||||
DBHandle *db =
|
||||
db_connect("127.0.0.1", 6379, "plasma_manager", "127.0.0.1", 0, NULL);
|
||||
DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager",
|
||||
"127.0.0.1", 0, NULL);
|
||||
db_attach(db, g_loop, false);
|
||||
RetryInfo retry = {
|
||||
.num_retries = 5,
|
||||
@@ -148,7 +150,10 @@ TEST subscribe_timeout_test(void) {
|
||||
subscribe_done_callback,
|
||||
(void *) subscribe_timeout_context);
|
||||
/* Disconnect the database to see if the subscribe times out. */
|
||||
close(db->sub_context->c.fd);
|
||||
close(db->subscribe_context->c.fd);
|
||||
for (int i = 0; i < db->subscribe_contexts.size(); ++i) {
|
||||
close(db->subscribe_contexts[i]->c.fd);
|
||||
}
|
||||
aeProcessEvents(g_loop, AE_TIME_EVENTS);
|
||||
event_loop_run(g_loop);
|
||||
db_disconnect(db);
|
||||
@@ -177,17 +182,22 @@ void publish_fail_callback(UniqueID id, void *user_context, void *user_data) {
|
||||
|
||||
TEST publish_timeout_test(void) {
|
||||
g_loop = event_loop_create();
|
||||
DBHandle *db =
|
||||
db_connect("127.0.0.1", 6379, "plasma_manager", "127.0.0.1", 0, NULL);
|
||||
DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager",
|
||||
"127.0.0.1", 0, NULL);
|
||||
db_attach(db, g_loop, false);
|
||||
Task *task = example_task(1, 1, TASK_STATUS_WAITING);
|
||||
RetryInfo retry = {
|
||||
.num_retries = 5, .timeout = 100, .fail_callback = publish_fail_callback,
|
||||
};
|
||||
task_table_subscribe(db, NIL_ID, TASK_STATUS_WAITING, NULL, NULL, &retry,
|
||||
NULL, NULL);
|
||||
task_table_add_task(db, task, &retry, publish_done_callback,
|
||||
(void *) publish_timeout_context);
|
||||
/* Disconnect the database to see if the publish times out. */
|
||||
close(db->context->c.fd);
|
||||
for (int i = 0; i < db->contexts.size(); ++i) {
|
||||
close(db->contexts[i]->c.fd);
|
||||
}
|
||||
aeProcessEvents(g_loop, AE_TIME_EVENTS);
|
||||
event_loop_run(g_loop);
|
||||
db_disconnect(db);
|
||||
@@ -204,9 +214,14 @@ int64_t reconnect_db_callback(event_loop *loop,
|
||||
void *context) {
|
||||
DBHandle *db = (DBHandle *) context;
|
||||
/* Reconnect to redis. */
|
||||
redisAsyncFree(db->sub_context);
|
||||
db->sub_context = redisAsyncConnect("127.0.0.1", 6379);
|
||||
db->sub_context->data = (void *) db;
|
||||
redisAsyncFree(db->subscribe_context);
|
||||
db->subscribe_context = redisAsyncConnect("127.0.0.1", 6379);
|
||||
db->subscribe_context->data = (void *) db;
|
||||
for (int i = 0; i < db->subscribe_contexts.size(); ++i) {
|
||||
redisAsyncFree(db->subscribe_contexts[i]);
|
||||
db->subscribe_contexts[i] = redisAsyncConnect("127.0.0.1", 6380 + i);
|
||||
db->subscribe_contexts[i]->data = (void *) db;
|
||||
}
|
||||
/* Re-attach the database to the event loop (the file descriptor changed). */
|
||||
db_attach(db, loop, true);
|
||||
return EVENT_LOOP_TIMER_DONE;
|
||||
@@ -239,8 +254,8 @@ void subscribe_retry_fail_callback(UniqueID id,
|
||||
|
||||
TEST subscribe_retry_test(void) {
|
||||
g_loop = event_loop_create();
|
||||
DBHandle *db =
|
||||
db_connect("127.0.0.1", 6379, "plasma_manager", "127.0.0.1", 0, NULL);
|
||||
DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager",
|
||||
"127.0.0.1", 0, NULL);
|
||||
db_attach(db, g_loop, false);
|
||||
RetryInfo retry = {
|
||||
.num_retries = 5,
|
||||
@@ -251,7 +266,10 @@ TEST subscribe_retry_test(void) {
|
||||
subscribe_retry_done_callback,
|
||||
(void *) subscribe_retry_context);
|
||||
/* Disconnect the database to see if the subscribe times out. */
|
||||
close(db->sub_context->c.fd);
|
||||
close(db->subscribe_context->c.fd);
|
||||
for (int i = 0; i < db->subscribe_contexts.size(); ++i) {
|
||||
close(db->subscribe_contexts[i]->c.fd);
|
||||
}
|
||||
/* Install handler for reconnecting the database. */
|
||||
event_loop_add_timer(g_loop, 150,
|
||||
(event_loop_timer_handler) reconnect_db_callback, db);
|
||||
@@ -286,8 +304,8 @@ void publish_retry_fail_callback(UniqueID id,
|
||||
|
||||
TEST publish_retry_test(void) {
|
||||
g_loop = event_loop_create();
|
||||
DBHandle *db =
|
||||
db_connect("127.0.0.1", 6379, "plasma_manager", "127.0.0.1", 0, NULL);
|
||||
DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager",
|
||||
"127.0.0.1", 0, NULL);
|
||||
db_attach(db, g_loop, false);
|
||||
Task *task = example_task(1, 1, TASK_STATUS_WAITING);
|
||||
RetryInfo retry = {
|
||||
@@ -295,10 +313,15 @@ TEST publish_retry_test(void) {
|
||||
.timeout = 100,
|
||||
.fail_callback = publish_retry_fail_callback,
|
||||
};
|
||||
task_table_subscribe(db, NIL_ID, TASK_STATUS_WAITING, NULL, NULL, &retry,
|
||||
NULL, NULL);
|
||||
task_table_add_task(db, task, &retry, publish_retry_done_callback,
|
||||
(void *) publish_retry_context);
|
||||
/* Disconnect the database to see if the publish times out. */
|
||||
close(db->sub_context->c.fd);
|
||||
close(db->subscribe_context->c.fd);
|
||||
for (int i = 0; i < db->subscribe_contexts.size(); ++i) {
|
||||
close(db->subscribe_contexts[i]->c.fd);
|
||||
}
|
||||
/* Install handler for reconnecting the database. */
|
||||
event_loop_add_timer(g_loop, 150,
|
||||
(event_loop_timer_handler) reconnect_db_callback, db);
|
||||
@@ -335,8 +358,8 @@ void subscribe_late_done_callback(TaskID task_id, void *user_context) {
|
||||
|
||||
TEST subscribe_late_test(void) {
|
||||
g_loop = event_loop_create();
|
||||
DBHandle *db =
|
||||
db_connect("127.0.0.1", 6379, "plasma_manager", "127.0.0.1", 0, NULL);
|
||||
DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager",
|
||||
"127.0.0.1", 0, NULL);
|
||||
db_attach(db, g_loop, false);
|
||||
RetryInfo retry = {
|
||||
.num_retries = 0,
|
||||
@@ -380,8 +403,8 @@ void publish_late_done_callback(TaskID task_id, void *user_context) {
|
||||
|
||||
TEST publish_late_test(void) {
|
||||
g_loop = event_loop_create();
|
||||
DBHandle *db =
|
||||
db_connect("127.0.0.1", 6379, "plasma_manager", "127.0.0.1", 0, NULL);
|
||||
DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager",
|
||||
"127.0.0.1", 0, NULL);
|
||||
db_attach(db, g_loop, false);
|
||||
Task *task = example_task(1, 1, TASK_STATUS_WAITING);
|
||||
RetryInfo retry = {
|
||||
@@ -389,6 +412,8 @@ TEST publish_late_test(void) {
|
||||
.timeout = 0,
|
||||
.fail_callback = publish_late_fail_callback,
|
||||
};
|
||||
task_table_subscribe(db, NIL_ID, TASK_STATUS_WAITING, NULL, NULL, NULL, NULL,
|
||||
NULL);
|
||||
task_table_add_task(db, task, &retry, publish_late_done_callback,
|
||||
(void *) publish_late_context);
|
||||
/* Install handler for terminating the event loop. */
|
||||
|
||||
@@ -2,11 +2,13 @@
|
||||
#define TEST_COMMON_H
|
||||
|
||||
#include <unistd.h>
|
||||
#include <vector>
|
||||
|
||||
#include "common.h"
|
||||
#include "io.h"
|
||||
#include "hiredis/hiredis.h"
|
||||
#include "utstring.h"
|
||||
#include "state/redis.h"
|
||||
|
||||
#ifndef _WIN32
|
||||
/* This function is actually not declared in standard POSIX, so declare it. */
|
||||
@@ -48,10 +50,29 @@ static inline int bind_inet_sock_retry(int *fd) {
|
||||
}
|
||||
|
||||
/* Flush redis. */
|
||||
static inline void flushall_redis() {
|
||||
static inline void flushall_redis(void) {
|
||||
/* Flush the primary shard. */
|
||||
redisContext *context = redisConnect("127.0.0.1", 6379);
|
||||
std::vector<std::string> db_shards_addresses;
|
||||
std::vector<int> db_shards_ports;
|
||||
get_redis_shards(context, db_shards_addresses, db_shards_ports);
|
||||
freeReplyObject(redisCommand(context, "FLUSHALL"));
|
||||
/* Readd the shard locations. */
|
||||
freeReplyObject(redisCommand(context, "SET NumRedisShards %d",
|
||||
db_shards_addresses.size()));
|
||||
for (int i = 0; i < db_shards_addresses.size(); ++i) {
|
||||
freeReplyObject(redisCommand(context, "RPUSH RedisShards %s:%d",
|
||||
db_shards_addresses[i].c_str(),
|
||||
db_shards_ports[i]));
|
||||
}
|
||||
redisFree(context);
|
||||
|
||||
/* Flush the remaining shards. */
|
||||
for (int i = 0; i < db_shards_addresses.size(); ++i) {
|
||||
context = redisConnect(db_shards_addresses[i].c_str(), db_shards_ports[i]);
|
||||
freeReplyObject(redisCommand(context, "FLUSHALL"));
|
||||
redisFree(context);
|
||||
}
|
||||
}
|
||||
|
||||
/* Cleanup method for running tests with the greatest library.
|
||||
|
||||
@@ -62,14 +62,14 @@ void assign_task_to_local_scheduler(GlobalSchedulerState *state,
|
||||
|
||||
GlobalSchedulerState *GlobalSchedulerState_init(event_loop *loop,
|
||||
const char *node_ip_address,
|
||||
const char *redis_addr,
|
||||
int redis_port) {
|
||||
const char *redis_primary_addr,
|
||||
int redis_primary_port) {
|
||||
GlobalSchedulerState *state =
|
||||
(GlobalSchedulerState *) malloc(sizeof(GlobalSchedulerState));
|
||||
/* Must initialize state to 0. Sets hashmap head(s) to NULL. */
|
||||
memset(state, 0, sizeof(GlobalSchedulerState));
|
||||
state->db = db_connect(redis_addr, redis_port, "global_scheduler",
|
||||
node_ip_address, 0, NULL);
|
||||
state->db = db_connect(std::string(redis_primary_addr), redis_primary_port,
|
||||
"global_scheduler", node_ip_address, 0, NULL);
|
||||
db_attach(state->db, loop, false);
|
||||
utarray_new(state->local_schedulers, &local_scheduler_icd);
|
||||
state->policy_state = GlobalSchedulerPolicyState_init();
|
||||
@@ -253,26 +253,34 @@ void remove_local_scheduler(GlobalSchedulerState *state, int index) {
|
||||
* @param aux_address: an ip:port pair for the plasma manager associated with
|
||||
* this db client.
|
||||
*/
|
||||
void process_new_db_client(DBClientID db_client_id,
|
||||
const char *client_type,
|
||||
const char *aux_address,
|
||||
bool is_insertion,
|
||||
void *user_context) {
|
||||
void process_new_db_client(DBClient *db_client, void *user_context) {
|
||||
GlobalSchedulerState *state = (GlobalSchedulerState *) user_context;
|
||||
char id_string[ID_STRING_SIZE];
|
||||
LOG_DEBUG("db client table callback for db client = %s",
|
||||
ObjectID_to_string(db_client_id, id_string, ID_STRING_SIZE));
|
||||
ObjectID_to_string(db_client->id, id_string, ID_STRING_SIZE));
|
||||
UNUSED(id_string);
|
||||
if (strncmp(client_type, "local_scheduler", strlen("local_scheduler")) == 0) {
|
||||
if (is_insertion) {
|
||||
/* This is a notification for an insert. */
|
||||
add_local_scheduler(state, db_client_id, aux_address);
|
||||
if (strncmp(db_client->client_type, "local_scheduler",
|
||||
strlen("local_scheduler")) == 0) {
|
||||
if (db_client->is_insertion) {
|
||||
/* This is a notification for an insert. We may receive duplicate
|
||||
* notifications since we read the entire table before processing
|
||||
* notifications. Filter out local schedulers that we already added. */
|
||||
for (LocalScheduler *scheduler =
|
||||
(LocalScheduler *) utarray_front(state->local_schedulers);
|
||||
scheduler != NULL; scheduler = (LocalScheduler *) utarray_next(
|
||||
state->local_schedulers, scheduler)) {
|
||||
if (UNIQUE_ID_EQ(scheduler->id, db_client->id)) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
add_local_scheduler(state, db_client->id, db_client->aux_address);
|
||||
} else {
|
||||
int i = 0;
|
||||
for (; i < utarray_len(state->local_schedulers); ++i) {
|
||||
LocalScheduler *active_worker =
|
||||
(LocalScheduler *) utarray_eltptr(state->local_schedulers, i);
|
||||
if (DBClientID_equal(active_worker->id, db_client_id)) {
|
||||
if (DBClientID_equal(active_worker->id, db_client->id)) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -418,11 +426,11 @@ int heartbeat_timeout_handler(event_loop *loop, timer_id id, void *context) {
|
||||
}
|
||||
|
||||
void start_server(const char *node_ip_address,
|
||||
const char *redis_addr,
|
||||
int redis_port) {
|
||||
const char *redis_primary_addr,
|
||||
int redis_primary_port) {
|
||||
event_loop *loop = event_loop_create();
|
||||
g_state =
|
||||
GlobalSchedulerState_init(loop, node_ip_address, redis_addr, redis_port);
|
||||
g_state = GlobalSchedulerState_init(loop, node_ip_address, redis_primary_addr,
|
||||
redis_primary_port);
|
||||
/* TODO(rkn): subscribe to notifications from the object table. */
|
||||
/* Subscribe to notifications about new local schedulers. TODO(rkn): this
|
||||
* needs to also get all of the clients that registered with the database
|
||||
@@ -458,15 +466,15 @@ void start_server(const char *node_ip_address,
|
||||
|
||||
int main(int argc, char *argv[]) {
|
||||
signal(SIGTERM, signal_handler);
|
||||
/* IP address and port of redis. */
|
||||
char *redis_addr_port = NULL;
|
||||
/* IP address and port of the primary redis instance. */
|
||||
char *redis_primary_addr_port = NULL;
|
||||
/* The IP address of the node that this global scheduler is running on. */
|
||||
char *node_ip_address = NULL;
|
||||
int c;
|
||||
while ((c = getopt(argc, argv, "h:r:")) != -1) {
|
||||
switch (c) {
|
||||
case 'r':
|
||||
redis_addr_port = optarg;
|
||||
redis_primary_addr_port = optarg;
|
||||
break;
|
||||
case 'h':
|
||||
node_ip_address = optarg;
|
||||
@@ -476,16 +484,18 @@ int main(int argc, char *argv[]) {
|
||||
exit(-1);
|
||||
}
|
||||
}
|
||||
char redis_addr[16];
|
||||
int redis_port;
|
||||
if (!redis_addr_port ||
|
||||
parse_ip_addr_port(redis_addr_port, redis_addr, &redis_port) == -1) {
|
||||
LOG_ERROR(
|
||||
"specify the redis address like 127.0.0.1:6379 with the -r switch");
|
||||
exit(-1);
|
||||
|
||||
char redis_primary_addr[16];
|
||||
int redis_primary_port;
|
||||
if (!redis_primary_addr_port ||
|
||||
parse_ip_addr_port(redis_primary_addr_port, redis_primary_addr,
|
||||
&redis_primary_port) == -1) {
|
||||
LOG_FATAL(
|
||||
"specify the primary redis address like 127.0.0.1:6379 with the -r "
|
||||
"switch");
|
||||
}
|
||||
if (!node_ip_address) {
|
||||
LOG_FATAL("specify the node IP address with the -h switch");
|
||||
}
|
||||
start_server(node_ip_address, redis_addr, redis_port);
|
||||
start_server(node_ip_address, redis_primary_addr, redis_primary_port);
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
#include "local_scheduler_shared.h"
|
||||
#include "local_scheduler.h"
|
||||
#include "local_scheduler_algorithm.h"
|
||||
#include "net.h"
|
||||
#include "state/actor_notification_table.h"
|
||||
#include "state/db.h"
|
||||
#include "state/driver_table.h"
|
||||
@@ -296,8 +297,8 @@ const char **parse_command(const char *command) {
|
||||
LocalSchedulerState *LocalSchedulerState_init(
|
||||
const char *node_ip_address,
|
||||
event_loop *loop,
|
||||
const char *redis_addr,
|
||||
int redis_port,
|
||||
const char *redis_primary_addr,
|
||||
int redis_primary_port,
|
||||
const char *local_scheduler_socket_name,
|
||||
const char *plasma_store_socket_name,
|
||||
const char *plasma_manager_socket_name,
|
||||
@@ -323,7 +324,7 @@ LocalSchedulerState *LocalSchedulerState_init(
|
||||
state->loop = loop;
|
||||
|
||||
/* Connect to Redis if a Redis address is provided. */
|
||||
if (redis_addr != NULL) {
|
||||
if (redis_primary_addr != NULL) {
|
||||
int num_args;
|
||||
const char **db_connect_args = NULL;
|
||||
/* Use UT_string to convert the resource value into a string. */
|
||||
@@ -354,8 +355,9 @@ LocalSchedulerState *LocalSchedulerState_init(
|
||||
db_connect_args[4] = "num_gpus";
|
||||
db_connect_args[5] = utstring_body(num_gpus);
|
||||
}
|
||||
state->db = db_connect(redis_addr, redis_port, "local_scheduler",
|
||||
node_ip_address, num_args, db_connect_args);
|
||||
state->db = db_connect(std::string(redis_primary_addr), redis_primary_port,
|
||||
"local_scheduler", node_ip_address, num_args,
|
||||
db_connect_args);
|
||||
utstring_free(num_cpus);
|
||||
utstring_free(num_gpus);
|
||||
free(db_connect_args);
|
||||
@@ -548,8 +550,6 @@ void process_plasma_notification(event_loop *loop,
|
||||
void reconstruct_task_update_callback(Task *task,
|
||||
void *user_context,
|
||||
bool updated) {
|
||||
/* The task ID should be in the task table. */
|
||||
CHECK(task != NULL);
|
||||
if (!updated) {
|
||||
/* The test-and-set of the task's scheduling state failed, so the task was
|
||||
* either not finished yet, or it was already being reconstructed.
|
||||
@@ -578,7 +578,6 @@ void reconstruct_task_update_callback(Task *task,
|
||||
void reconstruct_put_task_update_callback(Task *task,
|
||||
void *user_context,
|
||||
bool updated) {
|
||||
CHECK(task != NULL);
|
||||
if (updated) {
|
||||
/* The update to TASK_STATUS_RECONSTRUCTING succeeded, so continue with
|
||||
* reconstruction as usual. */
|
||||
@@ -1111,8 +1110,8 @@ int heartbeat_handler(event_loop *loop, timer_id id, void *context) {
|
||||
|
||||
void start_server(const char *node_ip_address,
|
||||
const char *socket_name,
|
||||
const char *redis_addr,
|
||||
int redis_port,
|
||||
const char *redis_primary_addr,
|
||||
int redis_primary_port,
|
||||
const char *plasma_store_socket_name,
|
||||
const char *plasma_manager_socket_name,
|
||||
const char *plasma_manager_address,
|
||||
@@ -1126,8 +1125,8 @@ void start_server(const char *node_ip_address,
|
||||
int fd = bind_ipc_sock(socket_name, true);
|
||||
event_loop *loop = event_loop_create();
|
||||
g_state = LocalSchedulerState_init(
|
||||
node_ip_address, loop, redis_addr, redis_port, socket_name,
|
||||
plasma_store_socket_name, plasma_manager_socket_name,
|
||||
node_ip_address, loop, redis_primary_addr, redis_primary_port,
|
||||
socket_name, plasma_store_socket_name, plasma_manager_socket_name,
|
||||
plasma_manager_address, global_scheduler_exists, static_resource_conf,
|
||||
start_worker_command, num_workers);
|
||||
/* Register a callback for registering new clients. */
|
||||
@@ -1173,8 +1172,8 @@ int main(int argc, char *argv[]) {
|
||||
signal(SIGTERM, signal_handler);
|
||||
/* Path of the listening socket of the local scheduler. */
|
||||
char *scheduler_socket_name = NULL;
|
||||
/* IP address and port of redis. */
|
||||
char *redis_addr_port = NULL;
|
||||
/* IP address and port of the primary redis instance. */
|
||||
char *redis_primary_addr_port = NULL;
|
||||
/* Socket name for the local Plasma store. */
|
||||
char *plasma_store_socket_name = NULL;
|
||||
/* Socket name for the local Plasma manager. */
|
||||
@@ -1199,7 +1198,7 @@ int main(int argc, char *argv[]) {
|
||||
scheduler_socket_name = optarg;
|
||||
break;
|
||||
case 'r':
|
||||
redis_addr_port = optarg;
|
||||
redis_primary_addr_port = optarg;
|
||||
break;
|
||||
case 'p':
|
||||
plasma_store_socket_name = optarg;
|
||||
@@ -1266,7 +1265,7 @@ int main(int argc, char *argv[]) {
|
||||
|
||||
char *redis_addr = NULL;
|
||||
int redis_port = -1;
|
||||
if (!redis_addr_port) {
|
||||
if (!redis_primary_addr_port) {
|
||||
/* Start the local scheduler without connecting to Redis. In this case, all
|
||||
* submitted tasks will be queued and scheduled locally. */
|
||||
if (plasma_manager_socket_name) {
|
||||
@@ -1275,27 +1274,22 @@ int main(int argc, char *argv[]) {
|
||||
"then a redis address must be provided with the -r switch");
|
||||
}
|
||||
} else {
|
||||
char redis_addr_buffer[16] = {0};
|
||||
char redis_port_str[6] = {0};
|
||||
/* Parse the Redis address into an IP address and a port. */
|
||||
int num_assigned = sscanf(redis_addr_port, "%15[0-9.]:%5[0-9]",
|
||||
redis_addr_buffer, redis_port_str);
|
||||
if (num_assigned != 2) {
|
||||
char redis_primary_addr[16];
|
||||
int redis_primary_port;
|
||||
/* Parse the primary Redis address into an IP address and a port. */
|
||||
if (parse_ip_addr_port(redis_primary_addr_port, redis_primary_addr,
|
||||
&redis_primary_port) == -1) {
|
||||
LOG_FATAL(
|
||||
"if a redis address is provided with the -r switch, it should be "
|
||||
"formatted like 127.0.0.1:6379");
|
||||
}
|
||||
redis_addr = redis_addr_buffer;
|
||||
redis_port = strtol(redis_port_str, NULL, 10);
|
||||
if (redis_port == 0) {
|
||||
LOG_FATAL("Unable to parse port number from redis address %s",
|
||||
redis_addr_port);
|
||||
}
|
||||
if (!plasma_manager_socket_name) {
|
||||
LOG_FATAL(
|
||||
"please specify socket for connecting to Plasma manager with -m "
|
||||
"switch");
|
||||
}
|
||||
redis_addr = redis_primary_addr;
|
||||
redis_port = redis_primary_port;
|
||||
}
|
||||
|
||||
start_server(node_ip_address, scheduler_socket_name, redis_addr, redis_port,
|
||||
|
||||
@@ -839,7 +839,8 @@ void handle_actor_task_submitted(LocalSchedulerState *state,
|
||||
/* Add this task to a queue of tasks that have been submitted but the local
|
||||
* scheduler doesn't know which actor is responsible for them. These tasks
|
||||
* will be resubmitted (internally by the local scheduler) whenever a new
|
||||
* actor notification arrives. */
|
||||
* actor notification arrives. NOTE(swang): These tasks have not yet been
|
||||
* added to the task table. */
|
||||
utarray_push_back(algorithm_state->cached_submitted_actor_tasks, &spec);
|
||||
utarray_push_back(algorithm_state->cached_submitted_actor_task_sizes,
|
||||
&task_spec_size);
|
||||
|
||||
@@ -18,6 +18,7 @@
|
||||
#include "task.h"
|
||||
#include "state/object_table.h"
|
||||
#include "state/task_table.h"
|
||||
#include "state/redis.h"
|
||||
|
||||
#include "local_scheduler_shared.h"
|
||||
#include "local_scheduler.h"
|
||||
@@ -182,7 +183,16 @@ TEST object_reconstruction_test(void) {
|
||||
/* Add an empty object table entry for the object we want to reconstruct, to
|
||||
* simulate it having been created and evicted. */
|
||||
const char *client_id = "clientid";
|
||||
/* Lookup the shard locations for the object table. */
|
||||
std::vector<std::string> db_shards_addresses;
|
||||
std::vector<int> db_shards_ports;
|
||||
redisContext *context = redisConnect("127.0.0.1", 6379);
|
||||
get_redis_shards(context, db_shards_addresses, db_shards_ports);
|
||||
redisFree(context);
|
||||
/* There should only be one shard, so we can safely add the empty object
|
||||
* table entry to the first one. */
|
||||
ASSERT(db_shards_addresses.size() == 1);
|
||||
context = redisConnect(db_shards_addresses[0].c_str(), db_shards_ports[0]);
|
||||
redisReply *reply = (redisReply *) redisCommand(
|
||||
context, "RAY.OBJECT_TABLE_ADD %b %ld %b %s", return_id.id,
|
||||
sizeof(return_id.id), 1, NIL_DIGEST, (size_t) DIGEST_SIZE, client_id);
|
||||
@@ -273,7 +283,16 @@ TEST object_reconstruction_recursive_test(void) {
|
||||
/* Add an empty object table entry for each object we want to reconstruct, to
|
||||
* simulate their having been created and evicted. */
|
||||
const char *client_id = "clientid";
|
||||
/* Lookup the shard locations for the object table. */
|
||||
std::vector<std::string> db_shards_addresses;
|
||||
std::vector<int> db_shards_ports;
|
||||
redisContext *context = redisConnect("127.0.0.1", 6379);
|
||||
get_redis_shards(context, db_shards_addresses, db_shards_ports);
|
||||
redisFree(context);
|
||||
/* There should only be one shard, so we can safely add the empty object
|
||||
* table entry to the first one. */
|
||||
ASSERT(db_shards_addresses.size() == 1);
|
||||
context = redisConnect(db_shards_addresses[0].c_str(), db_shards_ports[0]);
|
||||
for (int i = 0; i < NUM_TASKS; ++i) {
|
||||
ObjectID return_id = TaskSpec_return(specs[i], 0);
|
||||
redisReply *reply = (redisReply *) redisCommand(
|
||||
@@ -406,8 +425,8 @@ TEST object_reconstruction_suppression_test(void) {
|
||||
} else {
|
||||
/* Connect a plasma manager client so we can call object_table_add. */
|
||||
const char *db_connect_args[] = {"address", "127.0.0.1:12346"};
|
||||
DBHandle *db = db_connect("127.0.0.1", 6379, "plasma_manager", "127.0.0.1",
|
||||
2, db_connect_args);
|
||||
DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager",
|
||||
"127.0.0.1", 2, db_connect_args);
|
||||
db_attach(db, local_scheduler->loop, false);
|
||||
/* Add the object to the object table. */
|
||||
object_table_add(db, return_id, 1, (unsigned char *) NIL_DIGEST, NULL,
|
||||
|
||||
@@ -5,10 +5,17 @@
|
||||
# Cause the script to exit if a single command fails.
|
||||
set -e
|
||||
|
||||
./src/common/thirdparty/redis/src/redis-server --loglevel warning --loadmodule ./src/common/redis_module/libray_redis_module.so &
|
||||
# Start the Redis shards.
|
||||
./src/common/thirdparty/redis/src/redis-server --loglevel warning --loadmodule ./src/common/redis_module/libray_redis_module.so --port 6379 &
|
||||
./src/common/thirdparty/redis/src/redis-server --loglevel warning --loadmodule ./src/common/redis_module/libray_redis_module.so --port 6380 &
|
||||
sleep 1s
|
||||
# Register the shard location with the primary shard.
|
||||
./src/common/thirdparty/redis/src/redis-cli set NumRedisShards 1
|
||||
./src/common/thirdparty/redis/src/redis-cli rpush RedisShards 127.0.0.1:6380
|
||||
|
||||
./src/plasma/plasma_store -s /tmp/plasma_store_socket_1 -m 100000000 &
|
||||
sleep 0.5s
|
||||
./src/local_scheduler/local_scheduler_tests
|
||||
./src/common/thirdparty/redis/src/redis-cli shutdown
|
||||
./src/common/thirdparty/redis/src/redis-cli -p 6380 shutdown
|
||||
killall plasma_store
|
||||
|
||||
@@ -5,10 +5,17 @@
|
||||
# Cause the script to exit if a single command fails.
|
||||
set -e
|
||||
|
||||
./src/common/thirdparty/redis/src/redis-server --loglevel warning --loadmodule ./src/common/redis_module/libray_redis_module.so &
|
||||
# Start the Redis shards.
|
||||
./src/common/thirdparty/redis/src/redis-server --loglevel warning --loadmodule ./src/common/redis_module/libray_redis_module.so --port 6379 &
|
||||
./src/common/thirdparty/redis/src/redis-server --loglevel warning --loadmodule ./src/common/redis_module/libray_redis_module.so --port 6380 &
|
||||
sleep 1s
|
||||
# Register the shard location with the primary shard.
|
||||
./src/common/thirdparty/redis/src/redis-cli set NumRedisShards 1
|
||||
./src/common/thirdparty/redis/src/redis-cli rpush RedisShards 127.0.0.1:6380
|
||||
|
||||
./src/plasma/plasma_store -s /tmp/plasma_store_socket_1 -m 100000000 &
|
||||
sleep 0.5s
|
||||
valgrind --leak-check=full --show-leak-kinds=all --error-exitcode=1 ./src/local_scheduler/local_scheduler_tests
|
||||
./src/common/thirdparty/redis/src/redis-cli shutdown
|
||||
./src/common/thirdparty/redis/src/redis-cli -p 6380 shutdown
|
||||
killall plasma_store
|
||||
|
||||
@@ -493,8 +493,8 @@ PlasmaManagerState *PlasmaManagerState_init(const char *store_socket_name,
|
||||
const char *manager_socket_name,
|
||||
const char *manager_addr,
|
||||
int manager_port,
|
||||
const char *db_addr,
|
||||
int db_port) {
|
||||
const char *redis_primary_addr,
|
||||
int redis_primary_port) {
|
||||
PlasmaManagerState *state =
|
||||
(PlasmaManagerState *) malloc(sizeof(PlasmaManagerState));
|
||||
state->loop = event_loop_create();
|
||||
@@ -504,7 +504,7 @@ PlasmaManagerState *PlasmaManagerState_init(const char *store_socket_name,
|
||||
state->fetch_requests = NULL;
|
||||
state->object_wait_requests_local = NULL;
|
||||
state->object_wait_requests_remote = NULL;
|
||||
if (db_addr) {
|
||||
if (redis_primary_addr) {
|
||||
/* Get the manager port as a string. */
|
||||
UT_string *manager_address_str;
|
||||
utstring_new(manager_address_str);
|
||||
@@ -519,8 +519,9 @@ PlasmaManagerState *PlasmaManagerState_init(const char *store_socket_name,
|
||||
db_connect_args[3] = manager_socket_name;
|
||||
db_connect_args[4] = "address";
|
||||
db_connect_args[5] = utstring_body(manager_address_str);
|
||||
state->db = db_connect(db_addr, db_port, "plasma_manager", manager_addr,
|
||||
num_args, db_connect_args);
|
||||
state->db =
|
||||
db_connect(std::string(redis_primary_addr), redis_primary_port,
|
||||
"plasma_manager", manager_addr, num_args, db_connect_args);
|
||||
utstring_free(manager_address_str);
|
||||
free(db_connect_args);
|
||||
db_attach(state->db, state->loop, false);
|
||||
@@ -1594,8 +1595,8 @@ void start_server(const char *store_socket_name,
|
||||
const char *manager_socket_name,
|
||||
const char *master_addr,
|
||||
int port,
|
||||
const char *db_addr,
|
||||
int db_port) {
|
||||
const char *redis_primary_addr,
|
||||
int redis_primary_port) {
|
||||
/* Ignore SIGPIPE signals. If we don't do this, then when we attempt to write
|
||||
* to a client that has already died, the manager could die. */
|
||||
signal(SIGPIPE, SIG_IGN);
|
||||
@@ -1610,9 +1611,9 @@ void start_server(const char *store_socket_name,
|
||||
int local_sock = bind_ipc_sock(manager_socket_name, false);
|
||||
CHECKM(local_sock >= 0, "Unable to bind local manager socket");
|
||||
|
||||
g_manager_state =
|
||||
PlasmaManagerState_init(store_socket_name, manager_socket_name,
|
||||
master_addr, port, db_addr, db_port);
|
||||
g_manager_state = PlasmaManagerState_init(
|
||||
store_socket_name, manager_socket_name, master_addr, port,
|
||||
redis_primary_addr, redis_primary_port);
|
||||
CHECK(g_manager_state);
|
||||
|
||||
CHECK(listen(remote_sock, 5) != -1);
|
||||
@@ -1664,8 +1665,8 @@ int main(int argc, char *argv[]) {
|
||||
char *master_addr = NULL;
|
||||
/* Port number the manager should use. */
|
||||
int port = -1;
|
||||
/* IP address and port of state database. */
|
||||
char *db_host = NULL;
|
||||
/* IP address and port of the primary redis instance. */
|
||||
char *redis_primary_addr_port = NULL;
|
||||
int c;
|
||||
while ((c = getopt(argc, argv, "s:m:h:p:r:")) != -1) {
|
||||
switch (c) {
|
||||
@@ -1682,7 +1683,7 @@ int main(int argc, char *argv[]) {
|
||||
port = atoi(optarg);
|
||||
break;
|
||||
case 'r':
|
||||
db_host = optarg;
|
||||
redis_primary_addr_port = optarg;
|
||||
break;
|
||||
default:
|
||||
LOG_FATAL("unknown option %c", c);
|
||||
@@ -1708,15 +1709,16 @@ int main(int argc, char *argv[]) {
|
||||
"please specify port the plasma manager shall listen to in the"
|
||||
"format 12345 with -p switch");
|
||||
}
|
||||
char db_addr[16];
|
||||
int db_port;
|
||||
if (db_host) {
|
||||
parse_ip_addr_port(db_host, db_addr, &db_port);
|
||||
start_server(store_socket_name, manager_socket_name, master_addr, port,
|
||||
db_addr, db_port);
|
||||
} else {
|
||||
start_server(store_socket_name, manager_socket_name, master_addr, port,
|
||||
NULL, 0);
|
||||
char redis_primary_addr[16];
|
||||
int redis_primary_port;
|
||||
if (!redis_primary_addr_port ||
|
||||
parse_ip_addr_port(redis_primary_addr_port, redis_primary_addr,
|
||||
&redis_primary_port) == -1) {
|
||||
LOG_FATAL(
|
||||
"specify the primary redis address like 127.0.0.1:6379 with the -r "
|
||||
"switch");
|
||||
}
|
||||
start_server(store_socket_name, manager_socket_name, master_addr, port,
|
||||
redis_primary_addr, redis_primary_port);
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -9,11 +9,18 @@ sleep 1
|
||||
killall plasma_store
|
||||
./src/plasma/serialization_tests
|
||||
|
||||
./src/common/thirdparty/redis/src/redis-server --loglevel warning --loadmodule ./src/common/redis_module/libray_redis_module.so &
|
||||
redis_pid=$!
|
||||
# Start the Redis shards.
|
||||
./src/common/thirdparty/redis/src/redis-server --loglevel warning --loadmodule ./src/common/redis_module/libray_redis_module.so --port 6379 &
|
||||
redis_pid1=$!
|
||||
./src/common/thirdparty/redis/src/redis-server --loglevel warning --loadmodule ./src/common/redis_module/libray_redis_module.so --port 6380 &
|
||||
redis_pid2=$!
|
||||
sleep 1
|
||||
# flush the redis server
|
||||
./src/common/thirdparty/redis/src/redis-cli flushall &
|
||||
|
||||
# Flush the redis server
|
||||
./src/common/thirdparty/redis/src/redis-cli flushall
|
||||
# Register the shard location with the primary shard.
|
||||
./src/common/thirdparty/redis/src/redis-cli set NumRedisShards 1
|
||||
./src/common/thirdparty/redis/src/redis-cli rpush RedisShards 127.0.0.1:6380
|
||||
sleep 1
|
||||
./src/plasma/plasma_store -s /tmp/store1 -m 1000000000 &
|
||||
plasma1_pid=$!
|
||||
@@ -31,5 +38,7 @@ kill $plasma4_pid
|
||||
kill $plasma3_pid
|
||||
kill $plasma2_pid
|
||||
kill $plasma1_pid
|
||||
kill $redis_pid
|
||||
wait $redis_pid
|
||||
kill $redis_pid1
|
||||
wait $redis_pid1
|
||||
kill $redis_pid2
|
||||
wait $redis_pid2
|
||||
|
||||
@@ -86,8 +86,8 @@ class DockerRunner(object):
|
||||
else:
|
||||
return m.group(1)
|
||||
|
||||
def _start_head_node(self, docker_image, mem_size, shm_size, num_cpus,
|
||||
num_gpus, development_mode):
|
||||
def _start_head_node(self, docker_image, mem_size, shm_size,
|
||||
num_redis_shards, num_cpus, num_gpus, development_mode):
|
||||
"""Start the Ray head node inside a docker container."""
|
||||
mem_arg = ["--memory=" + mem_size] if mem_size else []
|
||||
shm_arg = ["--shm-size=" + shm_size] if shm_size else []
|
||||
@@ -99,6 +99,7 @@ class DockerRunner(object):
|
||||
command = (["docker", "run", "-d"] + mem_arg + shm_arg + volume_arg +
|
||||
[docker_image, "/ray/scripts/start_ray.sh", "--head",
|
||||
"--redis-port=6379",
|
||||
"--num-redis-shards={}".format(num_redis_shards),
|
||||
"--num-cpus={}".format(num_cpus),
|
||||
"--num-gpus={}".format(num_gpus)])
|
||||
print("Starting head node with command:{}".format(command))
|
||||
@@ -137,8 +138,8 @@ class DockerRunner(object):
|
||||
self.worker_container_ids.append(container_id)
|
||||
|
||||
def start_ray(self, docker_image=None, mem_size=None, shm_size=None,
|
||||
num_nodes=None, num_cpus=None, num_gpus=None,
|
||||
development_mode=None):
|
||||
num_nodes=None, num_redis_shards=1, num_cpus=None,
|
||||
num_gpus=None, development_mode=None):
|
||||
"""Start a Ray cluster within docker.
|
||||
|
||||
This starts one docker container running the head node and num_nodes - 1
|
||||
@@ -153,6 +154,7 @@ class DockerRunner(object):
|
||||
with. This will be passed into `docker run` as the `--shm-size` flag.
|
||||
num_nodes: The number of nodes to use in the cluster (this counts the
|
||||
head node as well).
|
||||
num_redis_shards: The number of Redis shards to use on the head node.
|
||||
num_cpus: A list of the number of CPUs to start each node with.
|
||||
num_gpus: A list of the number of GPUs to start each node with.
|
||||
development_mode: True if you want to mount the local copy of
|
||||
@@ -163,8 +165,8 @@ class DockerRunner(object):
|
||||
assert len(num_gpus) == num_nodes
|
||||
|
||||
# Launch the head node.
|
||||
self._start_head_node(docker_image, mem_size, shm_size, num_cpus[0],
|
||||
num_gpus[0], development_mode)
|
||||
self._start_head_node(docker_image, mem_size, shm_size, num_redis_shards,
|
||||
num_cpus[0], num_gpus[0], development_mode)
|
||||
# Start the worker nodes.
|
||||
for i in range(num_nodes - 1):
|
||||
self._start_worker_node(docker_image, mem_size, shm_size,
|
||||
@@ -252,6 +254,9 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--shm-size", default="1G", help="shared memory size")
|
||||
parser.add_argument("--num-nodes", default=1, type=int,
|
||||
help="number of nodes to use in the cluster")
|
||||
parser.add_argument("--num-redis-shards", default=1, type=int,
|
||||
help=("the number of Redis shards to start on the head "
|
||||
"node"))
|
||||
parser.add_argument("--num-cpus", type=str,
|
||||
help=("a comma separated list of values representing "
|
||||
"the number of CPUs to start each node with"))
|
||||
@@ -282,8 +287,8 @@ if __name__ == "__main__":
|
||||
d = DockerRunner()
|
||||
d.start_ray(docker_image=args.docker_image, mem_size=args.mem_size,
|
||||
shm_size=args.shm_size, num_nodes=num_nodes,
|
||||
num_cpus=num_cpus, num_gpus=num_gpus,
|
||||
development_mode=args.development_mode)
|
||||
num_redis_shards=args.num_redis_shards, num_cpus=num_cpus,
|
||||
num_gpus=num_gpus, development_mode=args.development_mode)
|
||||
try:
|
||||
run_results = d.run_test(args.test_script, args.num_drivers,
|
||||
driver_locations=driver_locations)
|
||||
|
||||
@@ -14,11 +14,13 @@ echo "Using Docker image" $DOCKER_SHA
|
||||
python $ROOT_DIR/multi_node_docker_test.py \
|
||||
--docker-image=$DOCKER_SHA \
|
||||
--num-nodes=5 \
|
||||
--num-redis-shards=10 \
|
||||
--test-script=/ray/test/jenkins_tests/multi_node_tests/test_0.py
|
||||
|
||||
python $ROOT_DIR/multi_node_docker_test.py \
|
||||
--docker-image=$DOCKER_SHA \
|
||||
--num-nodes=5 \
|
||||
--num-redis-shards=5 \
|
||||
--num-gpus=0,1,2,3,4 \
|
||||
--num-drivers=7 \
|
||||
--driver-locations=0,1,0,1,2,3,4 \
|
||||
@@ -27,6 +29,7 @@ python $ROOT_DIR/multi_node_docker_test.py \
|
||||
python $ROOT_DIR/multi_node_docker_test.py \
|
||||
--docker-image=$DOCKER_SHA \
|
||||
--num-nodes=5 \
|
||||
--num-redis-shards=2 \
|
||||
--num-gpus=0,0,5,6,50 \
|
||||
--num-drivers=100 \
|
||||
--test-script=/ray/test/jenkins_tests/multi_node_tests/many_drivers_test.py
|
||||
|
||||
+34
-42
@@ -293,8 +293,16 @@ class WorkerTest(unittest.TestCase):
|
||||
|
||||
class APITest(unittest.TestCase):
|
||||
|
||||
def init_ray(self, kwargs=None):
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
ray.init(**kwargs)
|
||||
|
||||
def tearDown(self):
|
||||
ray.worker.cleanup()
|
||||
|
||||
def testRegisterClass(self):
|
||||
ray.init(num_workers=2)
|
||||
self.init_ray({"num_workers": 2})
|
||||
|
||||
# Check that putting an object of a class that has not been registered
|
||||
# throws an exception.
|
||||
@@ -417,11 +425,9 @@ class APITest(unittest.TestCase):
|
||||
self.assertFalse(hasattr(c2, "method0"))
|
||||
self.assertFalse(hasattr(c2, "method1"))
|
||||
|
||||
ray.worker.cleanup()
|
||||
|
||||
def testKeywordArgs(self):
|
||||
reload(test_functions)
|
||||
ray.init(num_workers=1)
|
||||
self.init_ray()
|
||||
|
||||
x = test_functions.keyword_fct1.remote(1)
|
||||
self.assertEqual(ray.get(x), "1 hello")
|
||||
@@ -483,11 +489,9 @@ class APITest(unittest.TestCase):
|
||||
|
||||
self.assertEqual(ray.get(f3.remote(4)), 4)
|
||||
|
||||
ray.worker.cleanup()
|
||||
|
||||
def testVariableNumberOfArgs(self):
|
||||
reload(test_functions)
|
||||
ray.init(num_workers=1)
|
||||
self.init_ray()
|
||||
|
||||
x = test_functions.varargs_fct1.remote(0, 1, 2)
|
||||
self.assertEqual(ray.get(x), "0 1 2")
|
||||
@@ -516,18 +520,14 @@ class APITest(unittest.TestCase):
|
||||
self.assertEqual(ray.get(f2.remote(1, 2, 3)), (1, 2, (3,)))
|
||||
self.assertEqual(ray.get(f2.remote(1, 2, 3, 4)), (1, 2, (3, 4)))
|
||||
|
||||
ray.worker.cleanup()
|
||||
|
||||
def testNoArgs(self):
|
||||
reload(test_functions)
|
||||
ray.init(num_workers=1)
|
||||
self.init_ray()
|
||||
|
||||
ray.get(test_functions.no_op.remote())
|
||||
|
||||
ray.worker.cleanup()
|
||||
|
||||
def testDefiningRemoteFunctions(self):
|
||||
ray.init(num_workers=3, num_cpus=3)
|
||||
self.init_ray({"num_cpus": 3})
|
||||
|
||||
# Test that we can define a remote function in the shell.
|
||||
@ray.remote
|
||||
@@ -584,10 +584,8 @@ class APITest(unittest.TestCase):
|
||||
self.assertEqual(ray.get(l.remote(1)), 2)
|
||||
self.assertEqual(ray.get(m.remote(1)), 2)
|
||||
|
||||
ray.worker.cleanup()
|
||||
|
||||
def testGetMultiple(self):
|
||||
ray.init(num_workers=0)
|
||||
self.init_ray()
|
||||
object_ids = [ray.put(i) for i in range(10)]
|
||||
self.assertEqual(ray.get(object_ids), list(range(10)))
|
||||
|
||||
@@ -597,10 +595,8 @@ class APITest(unittest.TestCase):
|
||||
results = ray.get([object_ids[i] for i in indices])
|
||||
self.assertEqual(results, indices)
|
||||
|
||||
ray.worker.cleanup()
|
||||
|
||||
def testWait(self):
|
||||
ray.init(num_workers=1, num_cpus=1)
|
||||
self.init_ray({"num_cpus": 1})
|
||||
|
||||
@ray.remote
|
||||
def f(delay):
|
||||
@@ -633,12 +629,10 @@ class APITest(unittest.TestCase):
|
||||
x = ray.put(1)
|
||||
self.assertRaises(Exception, lambda: ray.wait([x, x]))
|
||||
|
||||
ray.worker.cleanup()
|
||||
|
||||
def testMultipleWaitsAndGets(self):
|
||||
# It is important to use three workers here, so that the three tasks
|
||||
# launched in this experiment can run at the same time.
|
||||
ray.init(num_workers=3)
|
||||
self.init_ray()
|
||||
|
||||
@ray.remote
|
||||
def f(delay):
|
||||
@@ -665,8 +659,6 @@ class APITest(unittest.TestCase):
|
||||
x = f.remote(1)
|
||||
ray.get([h.remote([x]), h.remote([x])])
|
||||
|
||||
ray.worker.cleanup()
|
||||
|
||||
def testCachingEnvironmentVariables(self):
|
||||
# Test that we can define environment variables before the driver is
|
||||
# connected.
|
||||
@@ -690,15 +682,13 @@ class APITest(unittest.TestCase):
|
||||
ray.env.bar.append(1)
|
||||
return ray.env.bar
|
||||
|
||||
ray.init(num_workers=2)
|
||||
self.init_ray()
|
||||
|
||||
self.assertEqual(ray.get(use_foo.remote()), 1)
|
||||
self.assertEqual(ray.get(use_foo.remote()), 1)
|
||||
self.assertEqual(ray.get(use_bar.remote()), [1])
|
||||
self.assertEqual(ray.get(use_bar.remote()), [1])
|
||||
|
||||
ray.worker.cleanup()
|
||||
|
||||
def testCachingFunctionsToRun(self):
|
||||
# Test that we export functions to run on all workers before the driver is
|
||||
# connected.
|
||||
@@ -718,7 +708,7 @@ class APITest(unittest.TestCase):
|
||||
sys.path.append(4)
|
||||
ray.worker.global_worker.run_function_on_all_workers(f)
|
||||
|
||||
ray.init(num_workers=2)
|
||||
self.init_ray()
|
||||
|
||||
@ray.remote
|
||||
def get_state():
|
||||
@@ -738,10 +728,8 @@ class APITest(unittest.TestCase):
|
||||
sys.path.pop()
|
||||
ray.worker.global_worker.run_function_on_all_workers(f)
|
||||
|
||||
ray.worker.cleanup()
|
||||
|
||||
def testRunningFunctionOnAllWorkers(self):
|
||||
ray.init(num_workers=1)
|
||||
self.init_ray()
|
||||
|
||||
def f(worker_info):
|
||||
sys.path.append("fake_directory")
|
||||
@@ -764,10 +752,8 @@ class APITest(unittest.TestCase):
|
||||
return sys.path
|
||||
self.assertTrue("fake_directory" not in ray.get(get_path2.remote()))
|
||||
|
||||
ray.worker.cleanup()
|
||||
|
||||
def testLoggingAPI(self):
|
||||
ray.init(num_workers=1, driver_mode=ray.SILENT_MODE)
|
||||
self.init_ray({"driver_mode": ray.SILENT_MODE})
|
||||
|
||||
def events():
|
||||
# This is a hack for getting the event log. It is not part of the API.
|
||||
@@ -815,12 +801,10 @@ class APITest(unittest.TestCase):
|
||||
wait_for_num_events(3)
|
||||
self.assertEqual(len(events()), 3)
|
||||
|
||||
ray.worker.cleanup()
|
||||
|
||||
def testIdenticalFunctionNames(self):
|
||||
# Define a bunch of remote functions and make sure that we don't
|
||||
# accidentally call an older version.
|
||||
ray.init(num_workers=2)
|
||||
self.init_ray()
|
||||
|
||||
num_calls = 200
|
||||
|
||||
@@ -878,10 +862,8 @@ class APITest(unittest.TestCase):
|
||||
result_values = ray.get([g.remote() for _ in range(num_calls)])
|
||||
self.assertEqual(result_values, num_calls * [5])
|
||||
|
||||
ray.worker.cleanup()
|
||||
|
||||
def testIllegalAPICalls(self):
|
||||
ray.init(num_workers=0)
|
||||
self.init_ray()
|
||||
|
||||
# Verify that we cannot call put on an ObjectID.
|
||||
x = ray.put(1)
|
||||
@@ -891,7 +873,16 @@ class APITest(unittest.TestCase):
|
||||
with self.assertRaises(Exception):
|
||||
ray.get(3)
|
||||
|
||||
ray.worker.cleanup()
|
||||
|
||||
class APITestSharded(APITest):
|
||||
|
||||
def init_ray(self, kwargs=None):
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
kwargs["start_ray_local"] = True
|
||||
kwargs["num_redis_shards"] = 20
|
||||
kwargs["redirect_output"] = True
|
||||
ray.worker._init(**kwargs)
|
||||
|
||||
|
||||
class PythonModeTest(unittest.TestCase):
|
||||
@@ -1619,7 +1610,8 @@ class GlobalStateAPI(unittest.TestCase):
|
||||
task_table = ray.global_state.task_table()
|
||||
self.assertEqual(len(task_table), 1)
|
||||
self.assertEqual(driver_task_id, list(task_table.keys())[0])
|
||||
self.assertEqual(task_table[driver_task_id]["State"], "RUNNING")
|
||||
self.assertEqual(task_table[driver_task_id]["State"],
|
||||
ray.experimental.state.TASK_STATUS_RUNNING)
|
||||
self.assertEqual(task_table[driver_task_id]["TaskSpec"]["TaskID"],
|
||||
driver_task_id)
|
||||
self.assertEqual(task_table[driver_task_id]["TaskSpec"]["ActorID"],
|
||||
|
||||
+26
-21
@@ -6,10 +6,6 @@ import unittest
|
||||
import ray
|
||||
import numpy as np
|
||||
import time
|
||||
import redis
|
||||
|
||||
# Import flatbuffer bindings.
|
||||
from ray.core.generated.TaskReply import TaskReply
|
||||
|
||||
|
||||
class TaskTests(unittest.TestCase):
|
||||
@@ -137,26 +133,38 @@ class ReconstructionTests(unittest.TestCase):
|
||||
num_local_schedulers = 1
|
||||
|
||||
def setUp(self):
|
||||
# Start a Redis instance and Plasma store instances with a total of 1GB
|
||||
# memory.
|
||||
# Start the Redis global state store.
|
||||
node_ip_address = "127.0.0.1"
|
||||
self.redis_port = ray.services.new_port()
|
||||
print(self.redis_port)
|
||||
redis_address = ray.services.address(node_ip_address, self.redis_port)
|
||||
redis_address, redis_shards = ray.services.start_redis(node_ip_address)
|
||||
self.redis_ip_address = ray.services.get_ip_address(redis_address)
|
||||
self.redis_port = ray.services.get_port(redis_address)
|
||||
time.sleep(0.1)
|
||||
|
||||
# Start the Plasma store instances with a total of 1GB memory.
|
||||
self.plasma_store_memory = 10 ** 9
|
||||
plasma_addresses = []
|
||||
objstore_memory = (self.plasma_store_memory // self.num_local_schedulers)
|
||||
for i in range(self.num_local_schedulers):
|
||||
store_stdout_file, store_stderr_file = ray.services.new_log_files(
|
||||
"plasma_store_{}".format(i), True)
|
||||
manager_stdout_file, manager_stderr_file = ray.services.new_log_files(
|
||||
"plasma_manager_{}".format(i), True)
|
||||
plasma_addresses.append(ray.services.start_objstore(
|
||||
node_ip_address, redis_address, objstore_memory=objstore_memory))
|
||||
address_info = {"redis_address": redis_address,
|
||||
"object_store_addresses": plasma_addresses}
|
||||
node_ip_address, redis_address, objstore_memory=objstore_memory,
|
||||
store_stdout_file=store_stdout_file,
|
||||
store_stderr_file=store_stderr_file,
|
||||
manager_stdout_file=manager_stdout_file,
|
||||
manager_stderr_file=manager_stderr_file))
|
||||
|
||||
# Start the rest of the services in the Ray cluster.
|
||||
address_info = {"redis_address": redis_address,
|
||||
"redis_shards": redis_shards,
|
||||
"object_store_addresses": plasma_addresses}
|
||||
ray.worker._init(address_info=address_info, start_ray_local=True,
|
||||
num_workers=1,
|
||||
num_local_schedulers=self.num_local_schedulers,
|
||||
num_cpus=[1] * self.num_local_schedulers,
|
||||
redirect_output=True,
|
||||
driver_mode=ray.SILENT_MODE)
|
||||
|
||||
def tearDown(self):
|
||||
@@ -164,14 +172,11 @@ class ReconstructionTests(unittest.TestCase):
|
||||
|
||||
# Determine the IDs of all local schedulers that had a task scheduled or
|
||||
# submitted.
|
||||
r = redis.StrictRedis(port=self.redis_port)
|
||||
task_ids = r.keys("TT:*")
|
||||
task_ids = [task_id[3:] for task_id in task_ids]
|
||||
local_scheduler_ids = []
|
||||
for task_id in task_ids:
|
||||
message = r.execute_command("ray.task_table_get", task_id)
|
||||
task_reply_object = TaskReply.GetRootAsTaskReply(message, 0)
|
||||
local_scheduler_ids.append(task_reply_object.LocalSchedulerId())
|
||||
state = ray.experimental.state.GlobalState()
|
||||
state._initialize_global_state(self.redis_ip_address, self.redis_port)
|
||||
tasks = state.task_table()
|
||||
local_scheduler_ids = set(task["LocalSchedulerID"] for task in
|
||||
tasks.values())
|
||||
|
||||
# Make sure that all nodes in the cluster were used by checking that the
|
||||
# set of local scheduler IDs that had a task scheduled or submitted is
|
||||
@@ -179,7 +184,7 @@ class ReconstructionTests(unittest.TestCase):
|
||||
# total number of local schedulers to account for NIL_LOCAL_SCHEDULER_ID.
|
||||
# This is the local scheduler ID associated with the driver task, since it
|
||||
# is not scheduled by a particular local scheduler.
|
||||
self.assertEqual(len(set(local_scheduler_ids)),
|
||||
self.assertEqual(len(local_scheduler_ids),
|
||||
self.num_local_schedulers + 1)
|
||||
|
||||
# Clean up the Ray cluster.
|
||||
|
||||
Reference in New Issue
Block a user