Shard Redis. (#539)

* Implement sharding in the Ray core

* Single node Python modifications to do sharding

* Do the sharding in redis.cc

* Pipe num_redis_shards through start_ray.py and worker.py.

* Use multiple redis shards in multinode tests.

* first steps for sharding ray.global_state

* Fix problem in multinode docker test.

* fix runtest.py

* fix some tests

* fix redis shard startup

* fix redis sharding

* fix

* fix bug introduced by the map-iterator being consumed

* fix sharding bug

* shard event table

* update number of Redis clients to be 64K

* Fix object table tests by flushing shards in between unit tests

* Fix local scheduler tests

* Documentation

* Register shard locations in the primary shard

* Add plasma unit tests back to build

* lint

* lint and fix build

* Fix

* Address Robert's comments

* Refactor start_ray_processes to start Redis shard

* lint

* Fix global scheduler python tests

* Fix redis module test

* Fix plasma test

* Fix component failure test

* Fix local scheduler test

* Fix runtest.py

* Fix global scheduler test for python3

* Fix task_table_test_and_update bug, from actor task table submission race

* Fix jenkins tests.

* Retry Redis shard connections

* Fix test cases

* Convert database clients to DBClient struct

* Fix race condition when subscribing to db client table

* Remove unused lines, add APITest for sharded Ray

* Fix

* Fix memory leak

* Suppress ReconstructionTests output

* Suppress output for APITestSharded

* Reissue task table add/update commands if initial command does not publish to any subscribers.

* fix

* Fix linting.

* fix tests

* fix linting

* fix python test

* fix linting
This commit is contained in:
Stephanie Wang
2017-05-18 17:40:41 -07:00
committed by Philipp Moritz
parent 0a4304725f
commit ee08c8274b
39 changed files with 1336 additions and 651 deletions
+41 -17
View File
@@ -52,7 +52,7 @@ def get_next_message(pubsub_client, timeout_seconds=10):
class TestGlobalStateStore(unittest.TestCase):
def setUp(self):
redis_port, _ = ray.services.start_redis()
redis_port, _ = ray.services.start_redis_instance()
self.redis = redis.StrictRedis(host="localhost", port=redis_port, db=0)
def tearDown(self):
@@ -308,6 +308,10 @@ class TestGlobalStateStore(unittest.TestCase):
TASK_STATUS_SCHEDULED = 2
TASK_STATUS_QUEUED = 4
# make sure somebody will get a notification (checked in the redis module)
p = self.redis.pubsub()
p.psubscribe("{prefix}*:*".format(prefix=TASK_PREFIX))
def check_task_reply(message, task_args, updated=False):
task_status, local_scheduler_id, task_spec = task_args
task_reply_object = TaskReply.GetRootAsTaskReply(message, 0)
@@ -388,33 +392,53 @@ class TestGlobalStateStore(unittest.TestCase):
self.assertNotEqual(get_response, old_response)
check_task_reply(get_response, task_args[1:])
def check_task_subscription(self, p, scheduling_state, local_scheduler_id):
task_args = [b"task_id", scheduling_state,
local_scheduler_id.encode("ascii"), b"task_spec"]
self.redis.execute_command("RAY.TASK_TABLE_ADD", *task_args)
# Receive the data.
message = get_next_message(p)["data"]
# Check that the notification object is correct.
notification_object = TaskReply.GetRootAsTaskReply(message, 0)
self.assertEqual(notification_object.TaskId(), b"task_id")
self.assertEqual(notification_object.State(), scheduling_state)
self.assertEqual(notification_object.LocalSchedulerId(),
local_scheduler_id.encode("ascii"))
self.assertEqual(notification_object.TaskSpec(), b"task_spec")
def testTaskTableSubscribe(self):
scheduling_state = 1
local_scheduler_id = "local_scheduler_id"
# Subscribe to the task table.
p = self.redis.pubsub()
p.psubscribe("{prefix}*:*".format(prefix=TASK_PREFIX))
# Receive acknowledgment.
self.assertEqual(get_next_message(p)["data"], 1)
self.check_task_subscription(p, scheduling_state, local_scheduler_id)
# unsubscribe to make sure there is only one subscriber at a given time
p.punsubscribe("{prefix}*:*".format(prefix=TASK_PREFIX))
# Receive acknowledgment.
self.assertEqual(get_next_message(p)["data"], 0)
p.psubscribe("{prefix}*:{state}".format(
prefix=TASK_PREFIX, state=scheduling_state))
# Receive acknowledgment.
self.assertEqual(get_next_message(p)["data"], 1)
self.check_task_subscription(p, scheduling_state, local_scheduler_id)
p.punsubscribe("{prefix}*:{state}".format(
prefix=TASK_PREFIX, state=scheduling_state))
# Receive acknowledgment.
self.assertEqual(get_next_message(p)["data"], 0)
p.psubscribe("{prefix}{local_scheduler_id}:*".format(
prefix=TASK_PREFIX, local_scheduler_id=local_scheduler_id))
task_args = [b"task_id", scheduling_state,
local_scheduler_id.encode("ascii"), b"task_spec"]
self.redis.execute_command("RAY.TASK_TABLE_ADD", *task_args)
# Receive the acknowledgement message.
# Receive acknowledgment.
self.assertEqual(get_next_message(p)["data"], 1)
self.assertEqual(get_next_message(p)["data"], 2)
self.assertEqual(get_next_message(p)["data"], 3)
# Receive the actual data.
for i in range(3):
message = get_next_message(p)["data"]
# Check that the notification object is correct.
notification_object = TaskReply.GetRootAsTaskReply(message, 0)
self.assertEqual(notification_object.TaskId(), b"task_id")
self.assertEqual(notification_object.State(), scheduling_state)
self.assertEqual(notification_object.LocalSchedulerId(),
local_scheduler_id.encode("ascii"))
self.assertEqual(notification_object.TaskSpec(), b"task_spec")
self.check_task_subscription(p, scheduling_state, local_scheduler_id)
p.punsubscribe("{prefix}{local_scheduler_id}:*".format(
prefix=TASK_PREFIX, local_scheduler_id=local_scheduler_id))
# Receive acknowledgment.
self.assertEqual(get_next_message(p)["data"], 0)
if __name__ == "__main__":
+85 -26
View File
@@ -5,6 +5,7 @@ from __future__ import print_function
import pickle
import redis
import ray
from ray.utils import (decode, binary_to_object_id, binary_to_hex,
hex_to_binary)
@@ -25,14 +26,21 @@ OBJECT_CHANNEL_PREFIX = "OC:"
# This mapping from integer to task state string must be kept up-to-date with
# the scheduling_state enum in task.h.
task_state_mapping = {
1: "WAITING",
2: "SCHEDULED",
4: "QUEUED",
8: "RUNNING",
16: "DONE",
32: "LOST",
64: "RECONSTRUCTING"
TASK_STATUS_WAITING = 1
TASK_STATUS_SCHEDULED = 2
TASK_STATUS_QUEUED = 4
TASK_STATUS_RUNNING = 8
TASK_STATUS_DONE = 16
TASK_STATUS_LOST = 32
TASK_STATUS_RECONSTRUCTING = 64
TASK_STATUS_MAPPING = {
TASK_STATUS_WAITING: "WAITING",
TASK_STATUS_SCHEDULED: "SCHEDULED",
TASK_STATUS_QUEUED: "QUEUED",
TASK_STATUS_RUNNING: "RUNNING",
TASK_STATUS_DONE: "DONE",
TASK_STATUS_LOST: "LOST",
TASK_STATUS_RECONSTRUCTING: "RECONSTRUCTING",
}
@@ -66,8 +74,54 @@ class GlobalState(object):
"""
self.redis_client = redis.StrictRedis(host=redis_ip_address,
port=redis_port)
self.redis_clients = []
num_redis_shards = self.redis_client.get("NumRedisShards")
if num_redis_shards is None:
raise Exception("No entry found for NumRedisShards")
num_redis_shards = int(num_redis_shards)
if (num_redis_shards < 1):
raise Exception("Expected at least one Redis shard, found "
"{}.".format(num_redis_shards))
def _object_table(self, object_id_binary):
ip_address_ports = self.redis_client.lrange("RedisShards", start=0, end=-1)
if len(ip_address_ports) != num_redis_shards:
raise Exception("Expected {} Redis shard addresses, found "
"{}".format(num_redis_shards, len(ip_address_ports)))
for ip_address_port in ip_address_ports:
shard_address, shard_port = ip_address_port.split(b":")
self.redis_clients.append(redis.StrictRedis(host=shard_address,
port=shard_port))
def _execute_command(self, key, *args):
"""Execute a Redis command on the appropriate Redis shard based on key.
Args:
key: The object ID or the task ID that the query is about.
args: The command to run.
Returns:
The value returned by the Redis command.
"""
client = self.redis_clients[key.redis_shard_hash() %
len(self.redis_clients)]
return client.execute_command(*args)
def _keys(self, pattern):
"""Execute the KEYS command on all Redis shards.
Args:
pattern: The KEYS pattern to query.
Returns:
The concatenated list of results from all shards.
"""
result = []
for client in self.redis_clients:
result.extend(client.keys(pattern))
return result
def _object_table(self, object_id):
"""Fetch and parse the object table information for a single object ID.
Args:
@@ -78,16 +132,18 @@ class GlobalState(object):
A dictionary with information about the object ID in question.
"""
# Return information about a single object ID.
object_locations = self.redis_client.execute_command(
"RAY.OBJECT_TABLE_LOOKUP", object_id_binary)
object_locations = self._execute_command(object_id,
"RAY.OBJECT_TABLE_LOOKUP",
object_id.id())
if object_locations is not None:
manager_ids = [binary_to_hex(manager_id)
for manager_id in object_locations]
else:
manager_ids = None
result_table_response = self.redis_client.execute_command(
"RAY.RESULT_TABLE_LOOKUP", object_id_binary)
result_table_response = self._execute_command(object_id,
"RAY.RESULT_TABLE_LOOKUP",
object_id.id())
result_table_message = ResultTableReply.GetRootAsResultTableReply(
result_table_response, 0)
@@ -111,22 +167,21 @@ class GlobalState(object):
self._check_connected()
if object_id is not None:
# Return information about a single object ID.
return self._object_table(object_id.id())
return self._object_table(object_id)
else:
# Return the entire object table.
object_info_keys = self.redis_client.keys(OBJECT_INFO_PREFIX + "*")
object_location_keys = self.redis_client.keys(
OBJECT_LOCATION_PREFIX + "*")
object_info_keys = self._keys(OBJECT_INFO_PREFIX + "*")
object_location_keys = self._keys(OBJECT_LOCATION_PREFIX + "*")
object_ids_binary = set(
[key[len(OBJECT_INFO_PREFIX):] for key in object_info_keys] +
[key[len(OBJECT_LOCATION_PREFIX):] for key in object_location_keys])
results = {}
for object_id_binary in object_ids_binary:
results[binary_to_object_id(object_id_binary)] = self._object_table(
object_id_binary)
binary_to_object_id(object_id_binary))
return results
def _task_table(self, task_id_binary):
def _task_table(self, task_id):
"""Fetch and parse the task table information for a single object task ID.
Args:
@@ -135,12 +190,15 @@ class GlobalState(object):
Returns:
A dictionary with information about the task ID in question.
TASK_STATUS_MAPPING should be used to parse the "State" field into a
human-readable string.
"""
task_table_response = self.redis_client.execute_command(
"RAY.TASK_TABLE_GET", task_id_binary)
task_table_response = self._execute_command(task_id,
"RAY.TASK_TABLE_GET",
task_id.id())
if task_table_response is None:
raise Exception("There is no entry for task ID {} in the task table."
.format(binary_to_hex(task_id_binary)))
.format(binary_to_hex(task_id.id())))
task_table_message = TaskReply.GetRootAsTaskReply(task_table_response, 0)
task_spec = task_table_message.TaskSpec()
task_spec_message = TaskInfo.GetRootAsTaskInfo(task_spec, 0)
@@ -167,7 +225,7 @@ class GlobalState(object):
for i in range(task_spec_message.ReturnsLength())],
"RequiredResources": required_resources}
return {"State": task_state_mapping[task_table_message.State()],
return {"State": task_table_message.State(),
"LocalSchedulerID": binary_to_hex(
task_table_message.LocalSchedulerId()),
"TaskSpec": task_spec_info}
@@ -185,14 +243,15 @@ class GlobalState(object):
"""
self._check_connected()
if task_id is not None:
return self._task_table(hex_to_binary(task_id))
task_id = ray.local_scheduler.ObjectID(hex_to_binary(task_id))
return self._task_table(task_id)
else:
task_table_keys = self.redis_client.keys(TASK_PREFIX + "*")
task_table_keys = self._keys(TASK_PREFIX + "*")
results = {}
for key in task_table_keys:
task_id_binary = key[len(TASK_PREFIX):]
results[binary_to_hex(task_id_binary)] = self._task_table(
task_id_binary)
ray.local_scheduler.ObjectID(task_id_binary))
return results
def function_table(self, function_id=None):
@@ -7,9 +7,9 @@ import subprocess
import time
def start_global_scheduler(redis_address, node_ip_address, use_valgrind=False,
use_profiler=False, stdout_file=None,
stderr_file=None):
def start_global_scheduler(redis_address, node_ip_address,
use_valgrind=False, use_profiler=False,
stdout_file=None, stderr_file=None):
"""Start a global scheduler process.
Args:
+46 -51
View File
@@ -5,7 +5,6 @@ from __future__ import print_function
import numpy as np
import os
import random
import redis
import signal
import sys
import time
@@ -17,6 +16,7 @@ import ray.plasma as plasma
from ray.plasma.utils import create_object
from ray import services
from ray.experimental import state
USE_VALGRIND = False
PLASMA_STORE_MEMORY = 1000000000
@@ -26,13 +26,6 @@ NUM_CLUSTER_NODES = 2
NIL_WORKER_ID = 20 * b"\xff"
NIL_ACTOR_ID = 20 * b"\xff"
# These constants must match the scheduling state enum in task.h.
TASK_STATUS_WAITING = 1
TASK_STATUS_SCHEDULED = 2
TASK_STATUS_QUEUED = 4
TASK_STATUS_RUNNING = 8
TASK_STATUS_DONE = 16
# These constants are an implementation detail of ray_redis_module.cc, so this
# must be kept in sync with that file.
DB_CLIENT_PREFIX = "CL:"
@@ -63,15 +56,17 @@ class TestGlobalScheduler(unittest.TestCase):
def setUp(self):
# Start one Redis server and N pairs of (plasma, local_scheduler)
node_ip_address = "127.0.0.1"
redis_port, self.redis_process = services.start_redis(cleanup=False)
redis_address = services.address(node_ip_address, redis_port)
# Create a Redis client.
self.redis_client = redis.StrictRedis(host=node_ip_address,
port=redis_port)
self.node_ip_address = "127.0.0.1"
redis_address, redis_shards = services.start_redis(self.node_ip_address)
redis_port = services.get_port(redis_address)
time.sleep(0.1)
# Create a client for the global state store.
self.state = state.GlobalState()
self.state._initialize_global_state(self.node_ip_address, redis_port)
# Start one global scheduler.
self.p1 = global_scheduler.start_global_scheduler(
redis_address, node_ip_address, use_valgrind=USE_VALGRIND)
redis_address, self.node_ip_address, use_valgrind=USE_VALGRIND)
self.plasma_store_pids = []
self.plasma_manager_pids = []
self.local_scheduler_pids = []
@@ -89,7 +84,8 @@ class TestGlobalScheduler(unittest.TestCase):
redis_address)
plasma_manager_name, p3, plasma_manager_port = manager_info
self.plasma_manager_pids.append(p3)
plasma_address = "{}:{}".format(node_ip_address, plasma_manager_port)
plasma_address = "{}:{}".format(self.node_ip_address,
plasma_manager_port)
plasma_client = plasma.PlasmaClient(plasma_store_name,
plasma_manager_name)
self.plasma_clients.append(plasma_client)
@@ -116,7 +112,10 @@ class TestGlobalScheduler(unittest.TestCase):
for p4 in self.local_scheduler_pids:
self.assertEqual(p4.poll(), None)
self.assertEqual(self.redis_process.poll(), None)
redis_processes = services.all_processes[
services.PROCESS_TYPE_REDIS_SERVER]
for redis_process in redis_processes:
self.assertEqual(redis_process.poll(), None)
# Kill the global scheduler.
if USE_VALGRIND:
@@ -135,7 +134,9 @@ class TestGlobalScheduler(unittest.TestCase):
p4.kill()
# Kill Redis. In the event that we are using valgrind, this needs to happen
# after we kill the global scheduler.
self.redis_process.kill()
while redis_processes:
redis_process = redis_processes.pop()
redis_process.kill()
def get_plasma_manager_id(self):
"""Get the db_client_id with client_type equal to plasma_manager.
@@ -150,11 +151,10 @@ class TestGlobalScheduler(unittest.TestCase):
"""
db_client_id = None
client_list = self.redis_client.keys("{}*".format(DB_CLIENT_PREFIX))
for client_id in client_list:
response = self.redis_client.hget(client_id, b"client_type")
if response == b"plasma_manager":
db_client_id = client_id
client_list = self.state.client_table()[self.node_ip_address]
for client in client_list:
if client["ClientType"] == "plasma_manager":
db_client_id = client["DBClientID"]
break
return db_client_id
@@ -178,18 +178,16 @@ class TestGlobalScheduler(unittest.TestCase):
# There should be 2n+1 db clients: the global scheduler + one local
# scheduler and one plasma per node.
self.assertEqual(
len(self.redis_client.keys("{}*".format(DB_CLIENT_PREFIX))),
len(self.state.client_table()[self.node_ip_address]),
2 * NUM_CLUSTER_NODES + 1)
db_client_id = self.get_plasma_manager_id()
assert(db_client_id is not None)
assert(db_client_id.startswith(b"CL:"))
db_client_id = db_client_id[len(b"CL:"):] # Remove the CL: prefix.
def test_integration_single_task(self):
# There should be three db clients, the global scheduler, the local
# scheduler, and the plasma manager.
self.assertEqual(
len(self.redis_client.keys("{}*".format(DB_CLIENT_PREFIX))),
len(self.state.client_table()[self.node_ip_address]),
2 * NUM_CLUSTER_NODES + 1)
num_return_vals = [0, 1, 2, 3, 5, 10]
@@ -212,15 +210,15 @@ class TestGlobalScheduler(unittest.TestCase):
# local scheduler
num_retries = 10
while num_retries > 0:
task_entries = self.redis_client.keys("{}*".format(TASK_PREFIX))
task_entries = self.state.task_table()
self.assertLessEqual(len(task_entries), 1)
if len(task_entries) == 1:
task_contents = self.redis_client.hgetall(task_entries[0])
task_status = int(task_contents[b"state"])
self.assertTrue(task_status in [TASK_STATUS_WAITING,
TASK_STATUS_SCHEDULED,
TASK_STATUS_QUEUED])
if task_status == TASK_STATUS_QUEUED:
task_id, task = task_entries.popitem()
task_status = task["State"]
self.assertTrue(task_status in [state.TASK_STATUS_WAITING,
state.TASK_STATUS_SCHEDULED,
state.TASK_STATUS_QUEUED])
if task_status == state.TASK_STATUS_QUEUED:
break
else:
print(task_status)
@@ -228,7 +226,7 @@ class TestGlobalScheduler(unittest.TestCase):
num_retries -= 1
time.sleep(1)
if num_retries <= 0 and task_status != TASK_STATUS_QUEUED:
if num_retries <= 0 and task_status != state.TASK_STATUS_QUEUED:
# Failed to submit and schedule a single task -- bail.
self.tearDown()
sys.exit(1)
@@ -237,7 +235,7 @@ class TestGlobalScheduler(unittest.TestCase):
# There should be three db clients, the global scheduler, the local
# scheduler, and the plasma manager.
self.assertEqual(
len(self.redis_client.keys("{}*".format(DB_CLIENT_PREFIX))),
len(self.state.client_table()[self.node_ip_address]),
2 * NUM_CLUSTER_NODES + 1)
num_return_vals = [0, 1, 2, 3, 5, 10]
@@ -264,34 +262,31 @@ class TestGlobalScheduler(unittest.TestCase):
num_retries = 10
num_tasks_done = 0
while num_retries > 0:
task_entries = self.redis_client.keys("{}*".format(TASK_PREFIX))
task_entries = self.state.task_table()
self.assertLessEqual(len(task_entries), num_tasks)
# First, check if all tasks made it to Redis.
if len(task_entries) == num_tasks:
task_contents = [self.redis_client.hgetall(task_entries[i])
for i in range(len(task_entries))]
task_statuses = [int(contents[b"state"]) for contents in task_contents]
self.assertTrue(all([status in [TASK_STATUS_WAITING,
TASK_STATUS_SCHEDULED,
TASK_STATUS_QUEUED]
task_statuses = [task_entry["State"] for task_entry in
task_entries.values()]
self.assertTrue(all([status in [state.TASK_STATUS_WAITING,
state.TASK_STATUS_SCHEDULED,
state.TASK_STATUS_QUEUED]
for status in task_statuses]))
num_tasks_done = task_statuses.count(TASK_STATUS_QUEUED)
num_tasks_scheduled = task_statuses.count(TASK_STATUS_SCHEDULED)
num_tasks_waiting = task_statuses.count(TASK_STATUS_WAITING)
num_tasks_done = task_statuses.count(state.TASK_STATUS_QUEUED)
num_tasks_scheduled = task_statuses.count(state.TASK_STATUS_SCHEDULED)
num_tasks_waiting = task_statuses.count(state.TASK_STATUS_WAITING)
print("tasks in Redis = {}, tasks waiting = {}, tasks scheduled = {}, "
"tasks queued = {}, retries left = {}"
.format(len(task_entries), num_tasks_waiting,
num_tasks_scheduled, num_tasks_done, num_retries))
if all([status == TASK_STATUS_QUEUED for status in task_statuses]):
if all([status == state.TASK_STATUS_QUEUED for status in
task_statuses]):
# We're done, so pass.
break
num_retries -= 1
time.sleep(0.1)
if num_tasks_done != num_tasks:
# At least one of the tasks failed to schedule.
self.tearDown()
sys.exit(2)
self.assertEqual(num_tasks_done, num_tasks)
def test_integration_many_tasks_handler_sync(self):
self.integration_many_tasks_helper(timesync=True)
+38 -40
View File
@@ -12,11 +12,13 @@ import time
import ray
from ray.services import get_ip_address
from ray.services import get_port
from ray.utils import binary_to_object_id
from ray.utils import binary_to_hex
from ray.utils import hex_to_binary
# Import flatbuffer bindings.
from ray.core.generated.SubscribeToDBClientTableReply \
import SubscribeToDBClientTableReply
from ray.core.generated.TaskReply import TaskReply
from ray.core.generated.DriverTableMessage import DriverTableMessage
# These variables must be kept in sync with the C codebase.
@@ -31,7 +33,6 @@ TASK_STATUS_LOST = 32
PLASMA_MANAGER_HEARTBEAT_CHANNEL = b"plasma_managers"
DRIVER_DEATH_CHANNEL = b"driver_deaths"
# common/redis_module/ray_redis_module.cc
TASK_PREFIX = "TT:"
OBJECT_PREFIX = "OL:"
DB_CLIENT_PREFIX = "CL:"
DB_CLIENT_TABLE_NAME = b"db_clients"
@@ -43,7 +44,7 @@ PLASMA_MANAGER_CLIENT_TYPE = b"plasma_manager"
# Set up logging.
logging.basicConfig()
log = logging.getLogger()
log.setLevel(logging.WARN)
log.setLevel(logging.INFO)
class Monitor(object):
@@ -70,7 +71,11 @@ class Monitor(object):
"""
def __init__(self, redis_address, redis_port):
# Initialize the Redis clients.
self.state = ray.experimental.state.GlobalState()
self.state._initialize_global_state(redis_address, redis_port)
self.redis = redis.StrictRedis(host=redis_address, port=redis_port, db=0)
# TODO(swang): Update pubsub client to use ray.experimental.state once
# subscriptions are implemented there.
self.subscribe_client = self.redis.pubsub()
self.subscribed = {}
# Initialize data structures to keep track of the active database clients.
@@ -97,23 +102,17 @@ class Monitor(object):
TASK_STATUS_LOST. A local scheduler is deemed dead if it is in
self.dead_local_schedulers.
"""
task_ids = self.redis.scan_iter(
match="{prefix}*".format(prefix=TASK_PREFIX))
tasks = self.state.task_table()
num_tasks_updated = 0
for task_id in task_ids:
task_id = task_id[len(TASK_PREFIX):]
response = self.redis.execute_command("RAY.TASK_TABLE_GET", task_id)
# Parse the serialized task object.
task_object = TaskReply.GetRootAsTaskReply(response, 0)
local_scheduler_id = task_object.LocalSchedulerId()
for task_id, task in tasks.items():
# See if the corresponding local scheduler is alive.
if local_scheduler_id in self.dead_local_schedulers:
if task["LocalSchedulerID"] in self.dead_local_schedulers:
# If the task is scheduled on a dead local scheduler, mark the task as
# lost.
ok = self.redis.execute_command("RAY.TASK_TABLE_UPDATE",
task_id,
TASK_STATUS_LOST,
NIL_ID)
key = binary_to_object_id(hex_to_binary(task_id))
ok = self.state._execute_command(
key, "RAY.TASK_TABLE_UPDATE", hex_to_binary(task_id),
ray.experimental.state.TASK_STATUS_LOST, NIL_ID)
if ok != b"OK":
log.warn("Failed to update lost task for dead scheduler.")
num_tasks_updated += 1
@@ -129,19 +128,20 @@ class Monitor(object):
"""
# TODO(swang): Also kill the associated plasma store, since it's no longer
# reachable without a plasma manager.
object_ids = self.redis.scan_iter(
match="{prefix}*".format(prefix=OBJECT_PREFIX))
objects = self.state.object_table()
num_objects_removed = 0
for object_id in object_ids:
object_id = object_id[len(OBJECT_PREFIX):]
managers = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
object_id)
for manager in managers:
for object_id, obj in objects.items():
manager_ids = obj["ManagerIDs"]
if manager_ids is None:
continue
for manager in manager_ids:
if manager in self.dead_plasma_managers:
# If the object was on a dead plasma manager, remove that location
# entry.
ok = self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", object_id,
manager)
ok = self.state._execute_command(object_id,
"RAY.OBJECT_TABLE_REMOVE",
object_id.id(),
hex_to_binary(manager))
if ok != b"OK":
log.warn("Failed to remove object location for dead plasma "
"manager.")
@@ -157,18 +157,16 @@ class Monitor(object):
not miss any notifications for deleted clients that occurred before we
subscribed.
"""
db_client_keys = self.redis.keys(
"{prefix}*".format(prefix=DB_CLIENT_PREFIX))
for db_client_key in db_client_keys:
db_client_id = db_client_key[len(DB_CLIENT_PREFIX):]
client_type, deleted = self.redis.hmget(db_client_key,
[b"client_type", b"deleted"])
deleted = bool(int(deleted))
if deleted:
if client_type == LOCAL_SCHEDULER_CLIENT_TYPE:
self.dead_local_schedulers.add(db_client_id)
elif client_type == PLASMA_MANAGER_CLIENT_TYPE:
self.dead_plasma_managers.add(db_client_id)
clients = self.state.client_table()
for node_ip_address, node_clients in clients.items():
for client in node_clients:
db_client_id = client["DBClientID"]
client_type = client["ClientType"]
if client["Deleted"]:
if client_type == LOCAL_SCHEDULER_CLIENT_TYPE:
self.dead_local_schedulers.add(db_client_id)
elif client_type == PLASMA_MANAGER_CLIENT_TYPE:
self.dead_plasma_managers.add(db_client_id)
def subscribe_handler(self, channel, data):
"""Handle a subscription success message from Redis.
@@ -186,7 +184,7 @@ class Monitor(object):
"""
notification_object = (SubscribeToDBClientTableReply
.GetRootAsSubscribeToDBClientTableReply(data, 0))
db_client_id = notification_object.DbClientId()
db_client_id = binary_to_hex(notification_object.DbClientId())
client_type = notification_object.ClientType()
is_insertion = notification_object.IsInsertion()
@@ -196,7 +194,7 @@ class Monitor(object):
# If the update was a deletion, add them to our accounting for dead
# local schedulers and plasma managers.
log.warn("Removed {}".format(client_type))
log.warn("Removed {}, client ID {}".format(client_type, db_client_id))
if client_type == LOCAL_SCHEDULER_CLIENT_TYPE:
if db_client_id not in self.dead_local_schedulers:
self.dead_local_schedulers.add(db_client_id)
@@ -256,7 +254,7 @@ class Monitor(object):
result = pipe.hget(local_scheduler_id, "gpus_in_use")
gpus_in_use = dict() if result is None else json.loads(result)
driver_id_hex = ray.utils.binary_to_hex(driver_id)
driver_id_hex = binary_to_hex(driver_id)
if driver_id_hex in gpus_in_use:
num_gpus_returned = gpus_in_use.pop(driver_id_hex)
+2 -1
View File
@@ -405,7 +405,8 @@ def start_plasma_manager(store_name, redis_address,
"-m", plasma_manager_name,
"-h", node_ip_address,
"-p", str(plasma_manager_port),
"-r", redis_address]
"-r", redis_address,
]
if use_valgrind:
process = subprocess.Popen(["valgrind",
"--track-origins=yes",
+2 -3
View File
@@ -480,7 +480,7 @@ class TestPlasmaManager(unittest.TestCase):
store_name1, self.p2 = plasma.start_plasma_store(use_valgrind=USE_VALGRIND)
store_name2, self.p3 = plasma.start_plasma_store(use_valgrind=USE_VALGRIND)
# Start a Redis server.
redis_address = services.address("127.0.0.1", services.start_redis()[0])
redis_address, _ = services.start_redis("127.0.0.1")
# Start two PlasmaManagers.
manager_name1, self.p4, self.port1 = plasma.start_plasma_manager(
store_name1, redis_address, use_valgrind=USE_VALGRIND)
@@ -778,8 +778,7 @@ class TestPlasmaManagerRecovery(unittest.TestCase):
self.store_name, self.p2 = plasma.start_plasma_store(
use_valgrind=USE_VALGRIND)
# Start a Redis server.
self.redis_address = services.address("127.0.0.1",
services.start_redis()[0])
self.redis_address, _ = services.start_redis("127.0.0.1")
# Start a PlasmaManagers.
manager_name, self.p3, self.port1 = plasma.start_plasma_manager(
self.store_name,
+110 -45
View File
@@ -240,15 +240,75 @@ def wait_for_redis_to_start(redis_ip_address, redis_port, num_retries=5):
"configured properly.")
def start_redis(node_ip_address="127.0.0.1", port=None, num_retries=20,
stdout_file=None, stderr_file=None, cleanup=True):
"""Start a Redis server.
def start_redis(node_ip_address,
port=None,
num_redis_shards=1,
redirect_output=False,
cleanup=True):
"""Start the Redis global state store.
Args:
node_ip_address: The IP address of the current node. This is only used for
recording the log filenames in Redis.
port (int): If provided, the primary Redis shard will be started on this
port.
num_redis_shards (int): If provided, the number of Redis shards to start,
in addition to the primary one. The default value is one shard.
cleanup (bool): True if using Ray in local mode. If cleanup is true, then
all Redis processes started by this method will be killed by
serices.cleanup() when the Python process that imported services exits.
Returns:
A tuple of the address for the primary Redis shard and a list of addresses
for the remaining shards.
"""
redis_stdout_file, redis_stderr_file = new_log_files(
"redis", redirect_output)
assigned_port, _ = start_redis_instance(
node_ip_address=node_ip_address, port=port,
stdout_file=redis_stdout_file, stderr_file=redis_stderr_file,
cleanup=cleanup)
if port is not None:
assert assigned_port == port
port = assigned_port
redis_address = address(node_ip_address, port)
# Register the number of Redis shards in the primary shard, so that clients
# know how many redis shards to expect under RedisShards.
redis_client = redis.StrictRedis(host=node_ip_address, port=port)
redis_client.set("NumRedisShards", str(num_redis_shards))
# Start other Redis shards listening on random ports. Each Redis shard logs
# to a separate file, prefixed by "redis-<shard number>".
redis_shards = []
for i in range(num_redis_shards):
redis_stdout_file, redis_stderr_file = new_log_files(
"redis-{}".format(i), redirect_output)
redis_shard_port, _ = start_redis_instance(
node_ip_address=node_ip_address, stdout_file=redis_stdout_file,
stderr_file=redis_stderr_file, cleanup=cleanup)
shard_address = address(node_ip_address, redis_shard_port)
redis_shards.append(shard_address)
# Store redis shard information in the primary redis shard.
redis_client.rpush("RedisShards", shard_address)
return redis_address, redis_shards
def start_redis_instance(node_ip_address="127.0.0.1",
port=None,
num_retries=20,
stdout_file=None,
stderr_file=None,
cleanup=True):
"""Start a single Redis server.
Args:
node_ip_address (str): The IP address of the current node. This is only
used for recording the log filenames in Redis.
port (int): If provided, start a Redis server with this port.
num_retries (int): The number of times to attempt to start Redis.
num_retries (int): The number of times to attempt to start Redis. If a port
is provided, this defaults to 1.
stdout_file: A file handle opened for writing to redirect stdout to. If no
redirection should happen, then this should be None.
stderr_file: A file handle opened for writing to redirect stderr to. If no
@@ -275,8 +335,8 @@ def start_redis(node_ip_address="127.0.0.1", port=None, num_retries=20,
assert os.path.isfile(redis_module)
counter = 0
if port is not None:
if num_retries != 1:
raise Exception("num_retries must be 1 if port is specified.")
# If a port is specified, then try only once to connect.
num_retries = 1
else:
port = new_port()
while counter < num_retries:
@@ -356,8 +416,8 @@ def start_log_monitor(redis_address, node_ip_address, stdout_file=None,
[stdout_file, stderr_file])
def start_global_scheduler(redis_address, node_ip_address, stdout_file=None,
stderr_file=None, cleanup=True):
def start_global_scheduler(redis_address, node_ip_address,
stdout_file=None, stderr_file=None, cleanup=True):
"""Start a global scheduler process.
Args:
@@ -372,7 +432,8 @@ def start_global_scheduler(redis_address, node_ip_address, stdout_file=None,
this process will be killed by services.cleanup() when the Python process
that imported services exits.
"""
p = global_scheduler.start_global_scheduler(redis_address, node_ip_address,
p = global_scheduler.start_global_scheduler(redis_address,
node_ip_address,
stdout_file=stdout_file,
stderr_file=stderr_file)
if cleanup:
@@ -545,10 +606,11 @@ def start_local_scheduler(redis_address,
return local_scheduler_name
def start_objstore(node_ip_address, redis_address, object_manager_port=None,
store_stdout_file=None, store_stderr_file=None,
manager_stdout_file=None, manager_stderr_file=None,
cleanup=True, objstore_memory=None):
def start_objstore(node_ip_address, redis_address,
object_manager_port=None, store_stdout_file=None,
store_stderr_file=None, manager_stdout_file=None,
manager_stderr_file=None, cleanup=True,
objstore_memory=None):
"""This method starts an object store process.
Args:
@@ -704,13 +766,14 @@ def start_monitor(redis_address, node_ip_address, stdout_file=None,
def start_ray_processes(address_info=None,
node_ip_address="127.0.0.1",
redis_port=None,
num_workers=None,
num_local_schedulers=1,
num_redis_shards=1,
worker_path=None,
cleanup=True,
redirect_output=False,
include_global_scheduler=False,
include_redis=False,
include_log_monitor=False,
include_webui=False,
start_workers_from_local_scheduler=True,
@@ -723,12 +786,17 @@ def start_ray_processes(address_info=None,
that have already been started. If provided, address_info will be
modified to include processes that are newly started.
node_ip_address (str): The IP address of this node.
redis_port (int): The port that the primary Redis shard should listen to.
If None, then a random port will be chosen. If the key "redis_address" is
in address_info, then this argument will be ignored.
num_workers (int): The number of workers to start.
num_local_schedulers (int): The total number of local schedulers required.
This is also the total number of object stores required. This method will
start new instances of local schedulers and object stores until there are
num_local_schedulers existing instances of each, including ones already
registered with the given address_info.
num_redis_shards: The number of Redis shards to start in addition to the
primary Redis shard.
worker_path (str): The path of the source code that will be run by the
worker.
cleanup (bool): If cleanup is true, then the processes started here will be
@@ -738,8 +806,6 @@ def start_ray_processes(address_info=None,
file.
include_global_scheduler (bool): If include_global_scheduler is True, then
start a global scheduler process.
include_redis (bool): If include_redis is True, then start a Redis server
process.
include_log_monitor (bool): If True, then start a log monitor to monitor
the log files for all processes on this node and push their contents to
Redis.
@@ -785,29 +851,14 @@ def start_ray_processes(address_info=None,
# warning messages when it starts up. Instead of suppressing the output, we
# should address the warnings.
redis_address = address_info.get("redis_address")
if include_redis:
redis_stdout_file, redis_stderr_file = new_log_files("redis",
redirect_output)
if redis_address is None:
# Start a Redis server. The start_redis method will choose a random port.
redis_port, _ = start_redis(node_ip_address,
stdout_file=redis_stdout_file,
stderr_file=redis_stderr_file,
cleanup=cleanup)
redis_address = address(node_ip_address, redis_port)
address_info["redis_address"] = redis_address
time.sleep(0.1)
else:
# A Redis address was provided, so start a Redis server with the given
# port. TODO(rkn): We should check that the IP address corresponds to the
# machine that this method is running on.
redis_port = get_port(redis_address)
new_redis_port, _ = start_redis(port=int(redis_port),
num_retries=1,
stdout_file=redis_stdout_file,
stderr_file=redis_stderr_file,
cleanup=cleanup)
assert redis_port == new_redis_port
redis_shards = address_info.get("redis_shards", [])
if redis_address is None:
redis_address, redis_shards = start_redis(
node_ip_address, port=redis_port, num_redis_shards=num_redis_shards,
redirect_output=redirect_output, cleanup=cleanup)
address_info["redis_address"] = redis_address
time.sleep(0.1)
# Start monitoring the processes.
monitor_stdout_file, monitor_stderr_file = new_log_files("monitor",
redirect_output)
@@ -815,9 +866,14 @@ def start_ray_processes(address_info=None,
node_ip_address,
stdout_file=monitor_stdout_file,
stderr_file=monitor_stderr_file)
else:
if redis_address is None:
raise Exception("Redis address expected")
if redis_shards == []:
# Get redis shards from primary redis instance.
redis_ip_address, redis_port = redis_address.split(":")
redis_client = redis.StrictRedis(host=redis_ip_address, port=redis_port)
redis_shards = redis_client.lrange("RedisShards", start=0, end=-1)
redis_shards = [shard.decode("ascii") for shard in redis_shards]
address_info["redis_shards"] = redis_shards
# Start the log monitor, if necessary.
if include_log_monitor:
@@ -1005,6 +1061,7 @@ def start_ray_node(node_ip_address,
def start_ray_head(address_info=None,
node_ip_address="127.0.0.1",
redis_port=None,
num_workers=0,
num_local_schedulers=1,
worker_path=None,
@@ -1012,7 +1069,8 @@ def start_ray_head(address_info=None,
redirect_output=False,
start_workers_from_local_scheduler=True,
num_cpus=None,
num_gpus=None):
num_gpus=None,
num_redis_shards=None):
"""Start Ray in local mode.
Args:
@@ -1020,6 +1078,9 @@ def start_ray_head(address_info=None,
that have already been started. If provided, address_info will be
modified to include processes that are newly started.
node_ip_address (str): The IP address of this node.
redis_port (int): The port that the primary Redis shard should listen to.
If None, then a random port will be chosen. If the key "redis_address" is
in address_info, then this argument will be ignored.
num_workers (int): The number of workers to start.
num_local_schedulers (int): The total number of local schedulers required.
This is also the total number of object stores required. This method will
@@ -1038,14 +1099,18 @@ def start_ray_head(address_info=None,
Python.
num_cpus (int): number of cpus to configure the local scheduler with.
num_gpus (int): number of gpus to configure the local scheduler with.
num_redis_shards: The number of Redis shards to start in addition to the
primary Redis shard.
Returns:
A dictionary of the address information for the processes that were
started.
"""
num_redis_shards = 1 if num_redis_shards is None else num_redis_shards
return start_ray_processes(
address_info=address_info,
node_ip_address=node_ip_address,
redis_port=redis_port,
num_workers=num_workers,
num_local_schedulers=num_local_schedulers,
worker_path=worker_path,
@@ -1053,11 +1118,11 @@ def start_ray_head(address_info=None,
redirect_output=redirect_output,
include_global_scheduler=True,
include_log_monitor=True,
include_redis=True,
include_webui=False,
start_workers_from_local_scheduler=start_workers_from_local_scheduler,
num_cpus=num_cpus,
num_gpus=num_gpus)
num_gpus=num_gpus,
num_redis_shards=num_redis_shards)
def new_log_files(name, redirect_output):
+28 -15
View File
@@ -992,7 +992,8 @@ def _init(address_info=None,
redirect_output=False,
start_workers_from_local_scheduler=True,
num_cpus=None,
num_gpus=None):
num_gpus=None,
num_redis_shards=None):
"""Helper method to connect to an existing Ray cluster or start a new one.
This method handles two cases. Either a Ray cluster already exists and we
@@ -1002,8 +1003,9 @@ def _init(address_info=None,
Args:
address_info (dict): A dictionary with address information for processes in
a partially-started Ray cluster. If start_ray_local=True, any processes
not in this dictionary will be started. If provided, address_info will be
modified to include processes that are newly started.
not in this dictionary will be started. If provided, an updated
address_info dictionary will be returned to include processes that are
newly started.
start_ray_local (bool): If True then this will start any processes not
already in address_info, including Redis, a global scheduler, local
scheduler(s), object store(s), and worker(s). It will also kill these
@@ -1028,6 +1030,8 @@ def _init(address_info=None,
be configured with.
num_gpus: A list containing the number of GPUs the local schedulers should
be configured with.
num_redis_shards: The number of Redis shards to start in addition to the
primary Redis shard.
Returns:
Address information about the started processes.
@@ -1069,6 +1073,8 @@ def _init(address_info=None,
num_local_schedulers = len(local_schedulers)
else:
num_local_schedulers = 1
# Use 1 additional redis shard if num_redis_shards is not provided.
num_redis_shards = 1 if num_redis_shards is None else num_redis_shards
# Start the scheduler, object store, and some workers. These will be killed
# by the call to cleanup(), which happens when the Python script exits.
address_info = services.start_ray_head(
@@ -1079,20 +1085,24 @@ def _init(address_info=None,
redirect_output=redirect_output,
start_workers_from_local_scheduler=start_workers_from_local_scheduler,
num_cpus=num_cpus,
num_gpus=num_gpus)
num_gpus=num_gpus,
num_redis_shards=num_redis_shards)
else:
if redis_address is None:
raise Exception("If start_ray_local=False, then redis_address must be "
"provided.")
raise Exception("When connecting to an existing cluster, redis_address "
"must be provided.")
if num_workers is not None:
raise Exception("If start_ray_local=False, then num_workers must not be "
"provided.")
raise Exception("When connecting to an existing cluster, num_workers "
"must not be provided.")
if num_local_schedulers is not None:
raise Exception("If start_ray_local=False, then num_local_schedulers "
"must not be provided.")
raise Exception("When connecting to an existing cluster, "
"num_local_schedulers must not be provided.")
if num_cpus is not None or num_gpus is not None:
raise Exception("If start_ray_local=False, then num_cpus and num_gpus "
"must not be provided.")
raise Exception("When connecting to an existing cluster, num_cpus and "
"num_gpus must not be provided.")
if num_redis_shards is not None:
raise Exception("When connecting to an existing cluster, "
"num_redis_shards must not be provided.")
# Get the node IP address if one is not provided.
if node_ip_address is None:
node_ip_address = services.get_node_ip_address(redis_address)
@@ -1121,7 +1131,7 @@ def _init(address_info=None,
def init(redis_address=None, node_ip_address=None, object_id_seed=None,
num_workers=None, driver_mode=SCRIPT_MODE, redirect_output=False,
num_cpus=None, num_gpus=None):
num_cpus=None, num_gpus=None, num_redis_shards=None):
"""Either connect to an existing Ray cluster or start one and connect to it.
This method handles two cases. Either a Ray cluster already exists and we
@@ -1148,6 +1158,8 @@ def init(redis_address=None, node_ip_address=None, object_id_seed=None,
configured with.
num_gpus (int): Number of gpus the user wishes all local schedulers to be
configured with.
num_redis_shards: The number of Redis shards to start in addition to the
primary Redis shard.
Returns:
Address information about the started processes.
@@ -1161,7 +1173,7 @@ def init(redis_address=None, node_ip_address=None, object_id_seed=None,
return _init(address_info=info, start_ray_local=(redis_address is None),
num_workers=num_workers, driver_mode=driver_mode,
redirect_output=redirect_output, num_cpus=num_cpus,
num_gpus=num_gpus)
num_gpus=num_gpus, num_redis_shards=num_redis_shards)
def cleanup(worker=global_worker):
@@ -1577,7 +1589,8 @@ def connect(info, object_id_seed=None, mode=WORKER_MODE, worker=global_worker,
ray.local_scheduler.ObjectID(NIL_ACTOR_ID),
worker.actor_counters[actor_id],
[0, 0])
worker.redis_client.execute_command(
global_state._execute_command(
driver_task.task_id(),
"RAY.TASK_TABLE_ADD",
driver_task.task_id().id(),
TASK_STATUS_RUNNING,