mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 21:38:18 +08:00
Shard Redis. (#539)
* Implement sharding in the Ray core * Single node Python modifications to do sharding * Do the sharding in redis.cc * Pipe num_redis_shards through start_ray.py and worker.py. * Use multiple redis shards in multinode tests. * first steps for sharding ray.global_state * Fix problem in multinode docker test. * fix runtest.py * fix some tests * fix redis shard startup * fix redis sharding * fix * fix bug introduced by the map-iterator being consumed * fix sharding bug * shard event table * update number of Redis clients to be 64K * Fix object table tests by flushing shards in between unit tests * Fix local scheduler tests * Documentation * Register shard locations in the primary shard * Add plasma unit tests back to build * lint * lint and fix build * Fix * Address Robert's comments * Refactor start_ray_processes to start Redis shard * lint * Fix global scheduler python tests * Fix redis module test * Fix plasma test * Fix component failure test * Fix local scheduler test * Fix runtest.py * Fix global scheduler test for python3 * Fix task_table_test_and_update bug, from actor task table submission race * Fix jenkins tests. * Retry Redis shard connections * Fix test cases * Convert database clients to DBClient struct * Fix race condition when subscribing to db client table * Remove unused lines, add APITest for sharded Ray * Fix * Fix memory leak * Suppress ReconstructionTests output * Suppress output for APITestSharded * Reissue task table add/update commands if initial command does not publish to any subscribers. * fix * Fix linting. * fix tests * fix linting * fix python test * fix linting
This commit is contained in:
committed by
Philipp Moritz
parent
0a4304725f
commit
ee08c8274b
@@ -52,7 +52,7 @@ def get_next_message(pubsub_client, timeout_seconds=10):
|
||||
class TestGlobalStateStore(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
redis_port, _ = ray.services.start_redis()
|
||||
redis_port, _ = ray.services.start_redis_instance()
|
||||
self.redis = redis.StrictRedis(host="localhost", port=redis_port, db=0)
|
||||
|
||||
def tearDown(self):
|
||||
@@ -308,6 +308,10 @@ class TestGlobalStateStore(unittest.TestCase):
|
||||
TASK_STATUS_SCHEDULED = 2
|
||||
TASK_STATUS_QUEUED = 4
|
||||
|
||||
# make sure somebody will get a notification (checked in the redis module)
|
||||
p = self.redis.pubsub()
|
||||
p.psubscribe("{prefix}*:*".format(prefix=TASK_PREFIX))
|
||||
|
||||
def check_task_reply(message, task_args, updated=False):
|
||||
task_status, local_scheduler_id, task_spec = task_args
|
||||
task_reply_object = TaskReply.GetRootAsTaskReply(message, 0)
|
||||
@@ -388,33 +392,53 @@ class TestGlobalStateStore(unittest.TestCase):
|
||||
self.assertNotEqual(get_response, old_response)
|
||||
check_task_reply(get_response, task_args[1:])
|
||||
|
||||
def check_task_subscription(self, p, scheduling_state, local_scheduler_id):
|
||||
task_args = [b"task_id", scheduling_state,
|
||||
local_scheduler_id.encode("ascii"), b"task_spec"]
|
||||
self.redis.execute_command("RAY.TASK_TABLE_ADD", *task_args)
|
||||
# Receive the data.
|
||||
message = get_next_message(p)["data"]
|
||||
# Check that the notification object is correct.
|
||||
notification_object = TaskReply.GetRootAsTaskReply(message, 0)
|
||||
self.assertEqual(notification_object.TaskId(), b"task_id")
|
||||
self.assertEqual(notification_object.State(), scheduling_state)
|
||||
self.assertEqual(notification_object.LocalSchedulerId(),
|
||||
local_scheduler_id.encode("ascii"))
|
||||
self.assertEqual(notification_object.TaskSpec(), b"task_spec")
|
||||
|
||||
def testTaskTableSubscribe(self):
|
||||
scheduling_state = 1
|
||||
local_scheduler_id = "local_scheduler_id"
|
||||
# Subscribe to the task table.
|
||||
p = self.redis.pubsub()
|
||||
p.psubscribe("{prefix}*:*".format(prefix=TASK_PREFIX))
|
||||
# Receive acknowledgment.
|
||||
self.assertEqual(get_next_message(p)["data"], 1)
|
||||
self.check_task_subscription(p, scheduling_state, local_scheduler_id)
|
||||
# unsubscribe to make sure there is only one subscriber at a given time
|
||||
p.punsubscribe("{prefix}*:*".format(prefix=TASK_PREFIX))
|
||||
# Receive acknowledgment.
|
||||
self.assertEqual(get_next_message(p)["data"], 0)
|
||||
|
||||
p.psubscribe("{prefix}*:{state}".format(
|
||||
prefix=TASK_PREFIX, state=scheduling_state))
|
||||
# Receive acknowledgment.
|
||||
self.assertEqual(get_next_message(p)["data"], 1)
|
||||
self.check_task_subscription(p, scheduling_state, local_scheduler_id)
|
||||
p.punsubscribe("{prefix}*:{state}".format(
|
||||
prefix=TASK_PREFIX, state=scheduling_state))
|
||||
# Receive acknowledgment.
|
||||
self.assertEqual(get_next_message(p)["data"], 0)
|
||||
|
||||
p.psubscribe("{prefix}{local_scheduler_id}:*".format(
|
||||
prefix=TASK_PREFIX, local_scheduler_id=local_scheduler_id))
|
||||
task_args = [b"task_id", scheduling_state,
|
||||
local_scheduler_id.encode("ascii"), b"task_spec"]
|
||||
self.redis.execute_command("RAY.TASK_TABLE_ADD", *task_args)
|
||||
# Receive the acknowledgement message.
|
||||
# Receive acknowledgment.
|
||||
self.assertEqual(get_next_message(p)["data"], 1)
|
||||
self.assertEqual(get_next_message(p)["data"], 2)
|
||||
self.assertEqual(get_next_message(p)["data"], 3)
|
||||
# Receive the actual data.
|
||||
for i in range(3):
|
||||
message = get_next_message(p)["data"]
|
||||
# Check that the notification object is correct.
|
||||
notification_object = TaskReply.GetRootAsTaskReply(message, 0)
|
||||
self.assertEqual(notification_object.TaskId(), b"task_id")
|
||||
self.assertEqual(notification_object.State(), scheduling_state)
|
||||
self.assertEqual(notification_object.LocalSchedulerId(),
|
||||
local_scheduler_id.encode("ascii"))
|
||||
self.assertEqual(notification_object.TaskSpec(), b"task_spec")
|
||||
self.check_task_subscription(p, scheduling_state, local_scheduler_id)
|
||||
p.punsubscribe("{prefix}{local_scheduler_id}:*".format(
|
||||
prefix=TASK_PREFIX, local_scheduler_id=local_scheduler_id))
|
||||
# Receive acknowledgment.
|
||||
self.assertEqual(get_next_message(p)["data"], 0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -5,6 +5,7 @@ from __future__ import print_function
|
||||
import pickle
|
||||
import redis
|
||||
|
||||
import ray
|
||||
from ray.utils import (decode, binary_to_object_id, binary_to_hex,
|
||||
hex_to_binary)
|
||||
|
||||
@@ -25,14 +26,21 @@ OBJECT_CHANNEL_PREFIX = "OC:"
|
||||
|
||||
# This mapping from integer to task state string must be kept up-to-date with
|
||||
# the scheduling_state enum in task.h.
|
||||
task_state_mapping = {
|
||||
1: "WAITING",
|
||||
2: "SCHEDULED",
|
||||
4: "QUEUED",
|
||||
8: "RUNNING",
|
||||
16: "DONE",
|
||||
32: "LOST",
|
||||
64: "RECONSTRUCTING"
|
||||
TASK_STATUS_WAITING = 1
|
||||
TASK_STATUS_SCHEDULED = 2
|
||||
TASK_STATUS_QUEUED = 4
|
||||
TASK_STATUS_RUNNING = 8
|
||||
TASK_STATUS_DONE = 16
|
||||
TASK_STATUS_LOST = 32
|
||||
TASK_STATUS_RECONSTRUCTING = 64
|
||||
TASK_STATUS_MAPPING = {
|
||||
TASK_STATUS_WAITING: "WAITING",
|
||||
TASK_STATUS_SCHEDULED: "SCHEDULED",
|
||||
TASK_STATUS_QUEUED: "QUEUED",
|
||||
TASK_STATUS_RUNNING: "RUNNING",
|
||||
TASK_STATUS_DONE: "DONE",
|
||||
TASK_STATUS_LOST: "LOST",
|
||||
TASK_STATUS_RECONSTRUCTING: "RECONSTRUCTING",
|
||||
}
|
||||
|
||||
|
||||
@@ -66,8 +74,54 @@ class GlobalState(object):
|
||||
"""
|
||||
self.redis_client = redis.StrictRedis(host=redis_ip_address,
|
||||
port=redis_port)
|
||||
self.redis_clients = []
|
||||
num_redis_shards = self.redis_client.get("NumRedisShards")
|
||||
if num_redis_shards is None:
|
||||
raise Exception("No entry found for NumRedisShards")
|
||||
num_redis_shards = int(num_redis_shards)
|
||||
if (num_redis_shards < 1):
|
||||
raise Exception("Expected at least one Redis shard, found "
|
||||
"{}.".format(num_redis_shards))
|
||||
|
||||
def _object_table(self, object_id_binary):
|
||||
ip_address_ports = self.redis_client.lrange("RedisShards", start=0, end=-1)
|
||||
if len(ip_address_ports) != num_redis_shards:
|
||||
raise Exception("Expected {} Redis shard addresses, found "
|
||||
"{}".format(num_redis_shards, len(ip_address_ports)))
|
||||
|
||||
for ip_address_port in ip_address_ports:
|
||||
shard_address, shard_port = ip_address_port.split(b":")
|
||||
self.redis_clients.append(redis.StrictRedis(host=shard_address,
|
||||
port=shard_port))
|
||||
|
||||
def _execute_command(self, key, *args):
|
||||
"""Execute a Redis command on the appropriate Redis shard based on key.
|
||||
|
||||
Args:
|
||||
key: The object ID or the task ID that the query is about.
|
||||
args: The command to run.
|
||||
|
||||
Returns:
|
||||
The value returned by the Redis command.
|
||||
"""
|
||||
client = self.redis_clients[key.redis_shard_hash() %
|
||||
len(self.redis_clients)]
|
||||
return client.execute_command(*args)
|
||||
|
||||
def _keys(self, pattern):
|
||||
"""Execute the KEYS command on all Redis shards.
|
||||
|
||||
Args:
|
||||
pattern: The KEYS pattern to query.
|
||||
|
||||
Returns:
|
||||
The concatenated list of results from all shards.
|
||||
"""
|
||||
result = []
|
||||
for client in self.redis_clients:
|
||||
result.extend(client.keys(pattern))
|
||||
return result
|
||||
|
||||
def _object_table(self, object_id):
|
||||
"""Fetch and parse the object table information for a single object ID.
|
||||
|
||||
Args:
|
||||
@@ -78,16 +132,18 @@ class GlobalState(object):
|
||||
A dictionary with information about the object ID in question.
|
||||
"""
|
||||
# Return information about a single object ID.
|
||||
object_locations = self.redis_client.execute_command(
|
||||
"RAY.OBJECT_TABLE_LOOKUP", object_id_binary)
|
||||
object_locations = self._execute_command(object_id,
|
||||
"RAY.OBJECT_TABLE_LOOKUP",
|
||||
object_id.id())
|
||||
if object_locations is not None:
|
||||
manager_ids = [binary_to_hex(manager_id)
|
||||
for manager_id in object_locations]
|
||||
else:
|
||||
manager_ids = None
|
||||
|
||||
result_table_response = self.redis_client.execute_command(
|
||||
"RAY.RESULT_TABLE_LOOKUP", object_id_binary)
|
||||
result_table_response = self._execute_command(object_id,
|
||||
"RAY.RESULT_TABLE_LOOKUP",
|
||||
object_id.id())
|
||||
result_table_message = ResultTableReply.GetRootAsResultTableReply(
|
||||
result_table_response, 0)
|
||||
|
||||
@@ -111,22 +167,21 @@ class GlobalState(object):
|
||||
self._check_connected()
|
||||
if object_id is not None:
|
||||
# Return information about a single object ID.
|
||||
return self._object_table(object_id.id())
|
||||
return self._object_table(object_id)
|
||||
else:
|
||||
# Return the entire object table.
|
||||
object_info_keys = self.redis_client.keys(OBJECT_INFO_PREFIX + "*")
|
||||
object_location_keys = self.redis_client.keys(
|
||||
OBJECT_LOCATION_PREFIX + "*")
|
||||
object_info_keys = self._keys(OBJECT_INFO_PREFIX + "*")
|
||||
object_location_keys = self._keys(OBJECT_LOCATION_PREFIX + "*")
|
||||
object_ids_binary = set(
|
||||
[key[len(OBJECT_INFO_PREFIX):] for key in object_info_keys] +
|
||||
[key[len(OBJECT_LOCATION_PREFIX):] for key in object_location_keys])
|
||||
results = {}
|
||||
for object_id_binary in object_ids_binary:
|
||||
results[binary_to_object_id(object_id_binary)] = self._object_table(
|
||||
object_id_binary)
|
||||
binary_to_object_id(object_id_binary))
|
||||
return results
|
||||
|
||||
def _task_table(self, task_id_binary):
|
||||
def _task_table(self, task_id):
|
||||
"""Fetch and parse the task table information for a single object task ID.
|
||||
|
||||
Args:
|
||||
@@ -135,12 +190,15 @@ class GlobalState(object):
|
||||
|
||||
Returns:
|
||||
A dictionary with information about the task ID in question.
|
||||
TASK_STATUS_MAPPING should be used to parse the "State" field into a
|
||||
human-readable string.
|
||||
"""
|
||||
task_table_response = self.redis_client.execute_command(
|
||||
"RAY.TASK_TABLE_GET", task_id_binary)
|
||||
task_table_response = self._execute_command(task_id,
|
||||
"RAY.TASK_TABLE_GET",
|
||||
task_id.id())
|
||||
if task_table_response is None:
|
||||
raise Exception("There is no entry for task ID {} in the task table."
|
||||
.format(binary_to_hex(task_id_binary)))
|
||||
.format(binary_to_hex(task_id.id())))
|
||||
task_table_message = TaskReply.GetRootAsTaskReply(task_table_response, 0)
|
||||
task_spec = task_table_message.TaskSpec()
|
||||
task_spec_message = TaskInfo.GetRootAsTaskInfo(task_spec, 0)
|
||||
@@ -167,7 +225,7 @@ class GlobalState(object):
|
||||
for i in range(task_spec_message.ReturnsLength())],
|
||||
"RequiredResources": required_resources}
|
||||
|
||||
return {"State": task_state_mapping[task_table_message.State()],
|
||||
return {"State": task_table_message.State(),
|
||||
"LocalSchedulerID": binary_to_hex(
|
||||
task_table_message.LocalSchedulerId()),
|
||||
"TaskSpec": task_spec_info}
|
||||
@@ -185,14 +243,15 @@ class GlobalState(object):
|
||||
"""
|
||||
self._check_connected()
|
||||
if task_id is not None:
|
||||
return self._task_table(hex_to_binary(task_id))
|
||||
task_id = ray.local_scheduler.ObjectID(hex_to_binary(task_id))
|
||||
return self._task_table(task_id)
|
||||
else:
|
||||
task_table_keys = self.redis_client.keys(TASK_PREFIX + "*")
|
||||
task_table_keys = self._keys(TASK_PREFIX + "*")
|
||||
results = {}
|
||||
for key in task_table_keys:
|
||||
task_id_binary = key[len(TASK_PREFIX):]
|
||||
results[binary_to_hex(task_id_binary)] = self._task_table(
|
||||
task_id_binary)
|
||||
ray.local_scheduler.ObjectID(task_id_binary))
|
||||
return results
|
||||
|
||||
def function_table(self, function_id=None):
|
||||
|
||||
@@ -7,9 +7,9 @@ import subprocess
|
||||
import time
|
||||
|
||||
|
||||
def start_global_scheduler(redis_address, node_ip_address, use_valgrind=False,
|
||||
use_profiler=False, stdout_file=None,
|
||||
stderr_file=None):
|
||||
def start_global_scheduler(redis_address, node_ip_address,
|
||||
use_valgrind=False, use_profiler=False,
|
||||
stdout_file=None, stderr_file=None):
|
||||
"""Start a global scheduler process.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -5,7 +5,6 @@ from __future__ import print_function
|
||||
import numpy as np
|
||||
import os
|
||||
import random
|
||||
import redis
|
||||
import signal
|
||||
import sys
|
||||
import time
|
||||
@@ -17,6 +16,7 @@ import ray.plasma as plasma
|
||||
from ray.plasma.utils import create_object
|
||||
|
||||
from ray import services
|
||||
from ray.experimental import state
|
||||
|
||||
USE_VALGRIND = False
|
||||
PLASMA_STORE_MEMORY = 1000000000
|
||||
@@ -26,13 +26,6 @@ NUM_CLUSTER_NODES = 2
|
||||
NIL_WORKER_ID = 20 * b"\xff"
|
||||
NIL_ACTOR_ID = 20 * b"\xff"
|
||||
|
||||
# These constants must match the scheduling state enum in task.h.
|
||||
TASK_STATUS_WAITING = 1
|
||||
TASK_STATUS_SCHEDULED = 2
|
||||
TASK_STATUS_QUEUED = 4
|
||||
TASK_STATUS_RUNNING = 8
|
||||
TASK_STATUS_DONE = 16
|
||||
|
||||
# These constants are an implementation detail of ray_redis_module.cc, so this
|
||||
# must be kept in sync with that file.
|
||||
DB_CLIENT_PREFIX = "CL:"
|
||||
@@ -63,15 +56,17 @@ class TestGlobalScheduler(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
# Start one Redis server and N pairs of (plasma, local_scheduler)
|
||||
node_ip_address = "127.0.0.1"
|
||||
redis_port, self.redis_process = services.start_redis(cleanup=False)
|
||||
redis_address = services.address(node_ip_address, redis_port)
|
||||
# Create a Redis client.
|
||||
self.redis_client = redis.StrictRedis(host=node_ip_address,
|
||||
port=redis_port)
|
||||
self.node_ip_address = "127.0.0.1"
|
||||
redis_address, redis_shards = services.start_redis(self.node_ip_address)
|
||||
redis_port = services.get_port(redis_address)
|
||||
time.sleep(0.1)
|
||||
# Create a client for the global state store.
|
||||
self.state = state.GlobalState()
|
||||
self.state._initialize_global_state(self.node_ip_address, redis_port)
|
||||
|
||||
# Start one global scheduler.
|
||||
self.p1 = global_scheduler.start_global_scheduler(
|
||||
redis_address, node_ip_address, use_valgrind=USE_VALGRIND)
|
||||
redis_address, self.node_ip_address, use_valgrind=USE_VALGRIND)
|
||||
self.plasma_store_pids = []
|
||||
self.plasma_manager_pids = []
|
||||
self.local_scheduler_pids = []
|
||||
@@ -89,7 +84,8 @@ class TestGlobalScheduler(unittest.TestCase):
|
||||
redis_address)
|
||||
plasma_manager_name, p3, plasma_manager_port = manager_info
|
||||
self.plasma_manager_pids.append(p3)
|
||||
plasma_address = "{}:{}".format(node_ip_address, plasma_manager_port)
|
||||
plasma_address = "{}:{}".format(self.node_ip_address,
|
||||
plasma_manager_port)
|
||||
plasma_client = plasma.PlasmaClient(plasma_store_name,
|
||||
plasma_manager_name)
|
||||
self.plasma_clients.append(plasma_client)
|
||||
@@ -116,7 +112,10 @@ class TestGlobalScheduler(unittest.TestCase):
|
||||
for p4 in self.local_scheduler_pids:
|
||||
self.assertEqual(p4.poll(), None)
|
||||
|
||||
self.assertEqual(self.redis_process.poll(), None)
|
||||
redis_processes = services.all_processes[
|
||||
services.PROCESS_TYPE_REDIS_SERVER]
|
||||
for redis_process in redis_processes:
|
||||
self.assertEqual(redis_process.poll(), None)
|
||||
|
||||
# Kill the global scheduler.
|
||||
if USE_VALGRIND:
|
||||
@@ -135,7 +134,9 @@ class TestGlobalScheduler(unittest.TestCase):
|
||||
p4.kill()
|
||||
# Kill Redis. In the event that we are using valgrind, this needs to happen
|
||||
# after we kill the global scheduler.
|
||||
self.redis_process.kill()
|
||||
while redis_processes:
|
||||
redis_process = redis_processes.pop()
|
||||
redis_process.kill()
|
||||
|
||||
def get_plasma_manager_id(self):
|
||||
"""Get the db_client_id with client_type equal to plasma_manager.
|
||||
@@ -150,11 +151,10 @@ class TestGlobalScheduler(unittest.TestCase):
|
||||
"""
|
||||
db_client_id = None
|
||||
|
||||
client_list = self.redis_client.keys("{}*".format(DB_CLIENT_PREFIX))
|
||||
for client_id in client_list:
|
||||
response = self.redis_client.hget(client_id, b"client_type")
|
||||
if response == b"plasma_manager":
|
||||
db_client_id = client_id
|
||||
client_list = self.state.client_table()[self.node_ip_address]
|
||||
for client in client_list:
|
||||
if client["ClientType"] == "plasma_manager":
|
||||
db_client_id = client["DBClientID"]
|
||||
break
|
||||
|
||||
return db_client_id
|
||||
@@ -178,18 +178,16 @@ class TestGlobalScheduler(unittest.TestCase):
|
||||
# There should be 2n+1 db clients: the global scheduler + one local
|
||||
# scheduler and one plasma per node.
|
||||
self.assertEqual(
|
||||
len(self.redis_client.keys("{}*".format(DB_CLIENT_PREFIX))),
|
||||
len(self.state.client_table()[self.node_ip_address]),
|
||||
2 * NUM_CLUSTER_NODES + 1)
|
||||
db_client_id = self.get_plasma_manager_id()
|
||||
assert(db_client_id is not None)
|
||||
assert(db_client_id.startswith(b"CL:"))
|
||||
db_client_id = db_client_id[len(b"CL:"):] # Remove the CL: prefix.
|
||||
|
||||
def test_integration_single_task(self):
|
||||
# There should be three db clients, the global scheduler, the local
|
||||
# scheduler, and the plasma manager.
|
||||
self.assertEqual(
|
||||
len(self.redis_client.keys("{}*".format(DB_CLIENT_PREFIX))),
|
||||
len(self.state.client_table()[self.node_ip_address]),
|
||||
2 * NUM_CLUSTER_NODES + 1)
|
||||
|
||||
num_return_vals = [0, 1, 2, 3, 5, 10]
|
||||
@@ -212,15 +210,15 @@ class TestGlobalScheduler(unittest.TestCase):
|
||||
# local scheduler
|
||||
num_retries = 10
|
||||
while num_retries > 0:
|
||||
task_entries = self.redis_client.keys("{}*".format(TASK_PREFIX))
|
||||
task_entries = self.state.task_table()
|
||||
self.assertLessEqual(len(task_entries), 1)
|
||||
if len(task_entries) == 1:
|
||||
task_contents = self.redis_client.hgetall(task_entries[0])
|
||||
task_status = int(task_contents[b"state"])
|
||||
self.assertTrue(task_status in [TASK_STATUS_WAITING,
|
||||
TASK_STATUS_SCHEDULED,
|
||||
TASK_STATUS_QUEUED])
|
||||
if task_status == TASK_STATUS_QUEUED:
|
||||
task_id, task = task_entries.popitem()
|
||||
task_status = task["State"]
|
||||
self.assertTrue(task_status in [state.TASK_STATUS_WAITING,
|
||||
state.TASK_STATUS_SCHEDULED,
|
||||
state.TASK_STATUS_QUEUED])
|
||||
if task_status == state.TASK_STATUS_QUEUED:
|
||||
break
|
||||
else:
|
||||
print(task_status)
|
||||
@@ -228,7 +226,7 @@ class TestGlobalScheduler(unittest.TestCase):
|
||||
num_retries -= 1
|
||||
time.sleep(1)
|
||||
|
||||
if num_retries <= 0 and task_status != TASK_STATUS_QUEUED:
|
||||
if num_retries <= 0 and task_status != state.TASK_STATUS_QUEUED:
|
||||
# Failed to submit and schedule a single task -- bail.
|
||||
self.tearDown()
|
||||
sys.exit(1)
|
||||
@@ -237,7 +235,7 @@ class TestGlobalScheduler(unittest.TestCase):
|
||||
# There should be three db clients, the global scheduler, the local
|
||||
# scheduler, and the plasma manager.
|
||||
self.assertEqual(
|
||||
len(self.redis_client.keys("{}*".format(DB_CLIENT_PREFIX))),
|
||||
len(self.state.client_table()[self.node_ip_address]),
|
||||
2 * NUM_CLUSTER_NODES + 1)
|
||||
num_return_vals = [0, 1, 2, 3, 5, 10]
|
||||
|
||||
@@ -264,34 +262,31 @@ class TestGlobalScheduler(unittest.TestCase):
|
||||
num_retries = 10
|
||||
num_tasks_done = 0
|
||||
while num_retries > 0:
|
||||
task_entries = self.redis_client.keys("{}*".format(TASK_PREFIX))
|
||||
task_entries = self.state.task_table()
|
||||
self.assertLessEqual(len(task_entries), num_tasks)
|
||||
# First, check if all tasks made it to Redis.
|
||||
if len(task_entries) == num_tasks:
|
||||
task_contents = [self.redis_client.hgetall(task_entries[i])
|
||||
for i in range(len(task_entries))]
|
||||
task_statuses = [int(contents[b"state"]) for contents in task_contents]
|
||||
self.assertTrue(all([status in [TASK_STATUS_WAITING,
|
||||
TASK_STATUS_SCHEDULED,
|
||||
TASK_STATUS_QUEUED]
|
||||
task_statuses = [task_entry["State"] for task_entry in
|
||||
task_entries.values()]
|
||||
self.assertTrue(all([status in [state.TASK_STATUS_WAITING,
|
||||
state.TASK_STATUS_SCHEDULED,
|
||||
state.TASK_STATUS_QUEUED]
|
||||
for status in task_statuses]))
|
||||
num_tasks_done = task_statuses.count(TASK_STATUS_QUEUED)
|
||||
num_tasks_scheduled = task_statuses.count(TASK_STATUS_SCHEDULED)
|
||||
num_tasks_waiting = task_statuses.count(TASK_STATUS_WAITING)
|
||||
num_tasks_done = task_statuses.count(state.TASK_STATUS_QUEUED)
|
||||
num_tasks_scheduled = task_statuses.count(state.TASK_STATUS_SCHEDULED)
|
||||
num_tasks_waiting = task_statuses.count(state.TASK_STATUS_WAITING)
|
||||
print("tasks in Redis = {}, tasks waiting = {}, tasks scheduled = {}, "
|
||||
"tasks queued = {}, retries left = {}"
|
||||
.format(len(task_entries), num_tasks_waiting,
|
||||
num_tasks_scheduled, num_tasks_done, num_retries))
|
||||
if all([status == TASK_STATUS_QUEUED for status in task_statuses]):
|
||||
if all([status == state.TASK_STATUS_QUEUED for status in
|
||||
task_statuses]):
|
||||
# We're done, so pass.
|
||||
break
|
||||
num_retries -= 1
|
||||
time.sleep(0.1)
|
||||
|
||||
if num_tasks_done != num_tasks:
|
||||
# At least one of the tasks failed to schedule.
|
||||
self.tearDown()
|
||||
sys.exit(2)
|
||||
self.assertEqual(num_tasks_done, num_tasks)
|
||||
|
||||
def test_integration_many_tasks_handler_sync(self):
|
||||
self.integration_many_tasks_helper(timesync=True)
|
||||
|
||||
+38
-40
@@ -12,11 +12,13 @@ import time
|
||||
import ray
|
||||
from ray.services import get_ip_address
|
||||
from ray.services import get_port
|
||||
from ray.utils import binary_to_object_id
|
||||
from ray.utils import binary_to_hex
|
||||
from ray.utils import hex_to_binary
|
||||
|
||||
# Import flatbuffer bindings.
|
||||
from ray.core.generated.SubscribeToDBClientTableReply \
|
||||
import SubscribeToDBClientTableReply
|
||||
from ray.core.generated.TaskReply import TaskReply
|
||||
from ray.core.generated.DriverTableMessage import DriverTableMessage
|
||||
|
||||
# These variables must be kept in sync with the C codebase.
|
||||
@@ -31,7 +33,6 @@ TASK_STATUS_LOST = 32
|
||||
PLASMA_MANAGER_HEARTBEAT_CHANNEL = b"plasma_managers"
|
||||
DRIVER_DEATH_CHANNEL = b"driver_deaths"
|
||||
# common/redis_module/ray_redis_module.cc
|
||||
TASK_PREFIX = "TT:"
|
||||
OBJECT_PREFIX = "OL:"
|
||||
DB_CLIENT_PREFIX = "CL:"
|
||||
DB_CLIENT_TABLE_NAME = b"db_clients"
|
||||
@@ -43,7 +44,7 @@ PLASMA_MANAGER_CLIENT_TYPE = b"plasma_manager"
|
||||
# Set up logging.
|
||||
logging.basicConfig()
|
||||
log = logging.getLogger()
|
||||
log.setLevel(logging.WARN)
|
||||
log.setLevel(logging.INFO)
|
||||
|
||||
|
||||
class Monitor(object):
|
||||
@@ -70,7 +71,11 @@ class Monitor(object):
|
||||
"""
|
||||
def __init__(self, redis_address, redis_port):
|
||||
# Initialize the Redis clients.
|
||||
self.state = ray.experimental.state.GlobalState()
|
||||
self.state._initialize_global_state(redis_address, redis_port)
|
||||
self.redis = redis.StrictRedis(host=redis_address, port=redis_port, db=0)
|
||||
# TODO(swang): Update pubsub client to use ray.experimental.state once
|
||||
# subscriptions are implemented there.
|
||||
self.subscribe_client = self.redis.pubsub()
|
||||
self.subscribed = {}
|
||||
# Initialize data structures to keep track of the active database clients.
|
||||
@@ -97,23 +102,17 @@ class Monitor(object):
|
||||
TASK_STATUS_LOST. A local scheduler is deemed dead if it is in
|
||||
self.dead_local_schedulers.
|
||||
"""
|
||||
task_ids = self.redis.scan_iter(
|
||||
match="{prefix}*".format(prefix=TASK_PREFIX))
|
||||
tasks = self.state.task_table()
|
||||
num_tasks_updated = 0
|
||||
for task_id in task_ids:
|
||||
task_id = task_id[len(TASK_PREFIX):]
|
||||
response = self.redis.execute_command("RAY.TASK_TABLE_GET", task_id)
|
||||
# Parse the serialized task object.
|
||||
task_object = TaskReply.GetRootAsTaskReply(response, 0)
|
||||
local_scheduler_id = task_object.LocalSchedulerId()
|
||||
for task_id, task in tasks.items():
|
||||
# See if the corresponding local scheduler is alive.
|
||||
if local_scheduler_id in self.dead_local_schedulers:
|
||||
if task["LocalSchedulerID"] in self.dead_local_schedulers:
|
||||
# If the task is scheduled on a dead local scheduler, mark the task as
|
||||
# lost.
|
||||
ok = self.redis.execute_command("RAY.TASK_TABLE_UPDATE",
|
||||
task_id,
|
||||
TASK_STATUS_LOST,
|
||||
NIL_ID)
|
||||
key = binary_to_object_id(hex_to_binary(task_id))
|
||||
ok = self.state._execute_command(
|
||||
key, "RAY.TASK_TABLE_UPDATE", hex_to_binary(task_id),
|
||||
ray.experimental.state.TASK_STATUS_LOST, NIL_ID)
|
||||
if ok != b"OK":
|
||||
log.warn("Failed to update lost task for dead scheduler.")
|
||||
num_tasks_updated += 1
|
||||
@@ -129,19 +128,20 @@ class Monitor(object):
|
||||
"""
|
||||
# TODO(swang): Also kill the associated plasma store, since it's no longer
|
||||
# reachable without a plasma manager.
|
||||
object_ids = self.redis.scan_iter(
|
||||
match="{prefix}*".format(prefix=OBJECT_PREFIX))
|
||||
objects = self.state.object_table()
|
||||
num_objects_removed = 0
|
||||
for object_id in object_ids:
|
||||
object_id = object_id[len(OBJECT_PREFIX):]
|
||||
managers = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
|
||||
object_id)
|
||||
for manager in managers:
|
||||
for object_id, obj in objects.items():
|
||||
manager_ids = obj["ManagerIDs"]
|
||||
if manager_ids is None:
|
||||
continue
|
||||
for manager in manager_ids:
|
||||
if manager in self.dead_plasma_managers:
|
||||
# If the object was on a dead plasma manager, remove that location
|
||||
# entry.
|
||||
ok = self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", object_id,
|
||||
manager)
|
||||
ok = self.state._execute_command(object_id,
|
||||
"RAY.OBJECT_TABLE_REMOVE",
|
||||
object_id.id(),
|
||||
hex_to_binary(manager))
|
||||
if ok != b"OK":
|
||||
log.warn("Failed to remove object location for dead plasma "
|
||||
"manager.")
|
||||
@@ -157,18 +157,16 @@ class Monitor(object):
|
||||
not miss any notifications for deleted clients that occurred before we
|
||||
subscribed.
|
||||
"""
|
||||
db_client_keys = self.redis.keys(
|
||||
"{prefix}*".format(prefix=DB_CLIENT_PREFIX))
|
||||
for db_client_key in db_client_keys:
|
||||
db_client_id = db_client_key[len(DB_CLIENT_PREFIX):]
|
||||
client_type, deleted = self.redis.hmget(db_client_key,
|
||||
[b"client_type", b"deleted"])
|
||||
deleted = bool(int(deleted))
|
||||
if deleted:
|
||||
if client_type == LOCAL_SCHEDULER_CLIENT_TYPE:
|
||||
self.dead_local_schedulers.add(db_client_id)
|
||||
elif client_type == PLASMA_MANAGER_CLIENT_TYPE:
|
||||
self.dead_plasma_managers.add(db_client_id)
|
||||
clients = self.state.client_table()
|
||||
for node_ip_address, node_clients in clients.items():
|
||||
for client in node_clients:
|
||||
db_client_id = client["DBClientID"]
|
||||
client_type = client["ClientType"]
|
||||
if client["Deleted"]:
|
||||
if client_type == LOCAL_SCHEDULER_CLIENT_TYPE:
|
||||
self.dead_local_schedulers.add(db_client_id)
|
||||
elif client_type == PLASMA_MANAGER_CLIENT_TYPE:
|
||||
self.dead_plasma_managers.add(db_client_id)
|
||||
|
||||
def subscribe_handler(self, channel, data):
|
||||
"""Handle a subscription success message from Redis.
|
||||
@@ -186,7 +184,7 @@ class Monitor(object):
|
||||
"""
|
||||
notification_object = (SubscribeToDBClientTableReply
|
||||
.GetRootAsSubscribeToDBClientTableReply(data, 0))
|
||||
db_client_id = notification_object.DbClientId()
|
||||
db_client_id = binary_to_hex(notification_object.DbClientId())
|
||||
client_type = notification_object.ClientType()
|
||||
is_insertion = notification_object.IsInsertion()
|
||||
|
||||
@@ -196,7 +194,7 @@ class Monitor(object):
|
||||
|
||||
# If the update was a deletion, add them to our accounting for dead
|
||||
# local schedulers and plasma managers.
|
||||
log.warn("Removed {}".format(client_type))
|
||||
log.warn("Removed {}, client ID {}".format(client_type, db_client_id))
|
||||
if client_type == LOCAL_SCHEDULER_CLIENT_TYPE:
|
||||
if db_client_id not in self.dead_local_schedulers:
|
||||
self.dead_local_schedulers.add(db_client_id)
|
||||
@@ -256,7 +254,7 @@ class Monitor(object):
|
||||
result = pipe.hget(local_scheduler_id, "gpus_in_use")
|
||||
gpus_in_use = dict() if result is None else json.loads(result)
|
||||
|
||||
driver_id_hex = ray.utils.binary_to_hex(driver_id)
|
||||
driver_id_hex = binary_to_hex(driver_id)
|
||||
if driver_id_hex in gpus_in_use:
|
||||
num_gpus_returned = gpus_in_use.pop(driver_id_hex)
|
||||
|
||||
|
||||
@@ -405,7 +405,8 @@ def start_plasma_manager(store_name, redis_address,
|
||||
"-m", plasma_manager_name,
|
||||
"-h", node_ip_address,
|
||||
"-p", str(plasma_manager_port),
|
||||
"-r", redis_address]
|
||||
"-r", redis_address,
|
||||
]
|
||||
if use_valgrind:
|
||||
process = subprocess.Popen(["valgrind",
|
||||
"--track-origins=yes",
|
||||
|
||||
@@ -480,7 +480,7 @@ class TestPlasmaManager(unittest.TestCase):
|
||||
store_name1, self.p2 = plasma.start_plasma_store(use_valgrind=USE_VALGRIND)
|
||||
store_name2, self.p3 = plasma.start_plasma_store(use_valgrind=USE_VALGRIND)
|
||||
# Start a Redis server.
|
||||
redis_address = services.address("127.0.0.1", services.start_redis()[0])
|
||||
redis_address, _ = services.start_redis("127.0.0.1")
|
||||
# Start two PlasmaManagers.
|
||||
manager_name1, self.p4, self.port1 = plasma.start_plasma_manager(
|
||||
store_name1, redis_address, use_valgrind=USE_VALGRIND)
|
||||
@@ -778,8 +778,7 @@ class TestPlasmaManagerRecovery(unittest.TestCase):
|
||||
self.store_name, self.p2 = plasma.start_plasma_store(
|
||||
use_valgrind=USE_VALGRIND)
|
||||
# Start a Redis server.
|
||||
self.redis_address = services.address("127.0.0.1",
|
||||
services.start_redis()[0])
|
||||
self.redis_address, _ = services.start_redis("127.0.0.1")
|
||||
# Start a PlasmaManagers.
|
||||
manager_name, self.p3, self.port1 = plasma.start_plasma_manager(
|
||||
self.store_name,
|
||||
|
||||
+110
-45
@@ -240,15 +240,75 @@ def wait_for_redis_to_start(redis_ip_address, redis_port, num_retries=5):
|
||||
"configured properly.")
|
||||
|
||||
|
||||
def start_redis(node_ip_address="127.0.0.1", port=None, num_retries=20,
|
||||
stdout_file=None, stderr_file=None, cleanup=True):
|
||||
"""Start a Redis server.
|
||||
def start_redis(node_ip_address,
|
||||
port=None,
|
||||
num_redis_shards=1,
|
||||
redirect_output=False,
|
||||
cleanup=True):
|
||||
"""Start the Redis global state store.
|
||||
|
||||
Args:
|
||||
node_ip_address: The IP address of the current node. This is only used for
|
||||
recording the log filenames in Redis.
|
||||
port (int): If provided, the primary Redis shard will be started on this
|
||||
port.
|
||||
num_redis_shards (int): If provided, the number of Redis shards to start,
|
||||
in addition to the primary one. The default value is one shard.
|
||||
cleanup (bool): True if using Ray in local mode. If cleanup is true, then
|
||||
all Redis processes started by this method will be killed by
|
||||
serices.cleanup() when the Python process that imported services exits.
|
||||
|
||||
Returns:
|
||||
A tuple of the address for the primary Redis shard and a list of addresses
|
||||
for the remaining shards.
|
||||
"""
|
||||
redis_stdout_file, redis_stderr_file = new_log_files(
|
||||
"redis", redirect_output)
|
||||
assigned_port, _ = start_redis_instance(
|
||||
node_ip_address=node_ip_address, port=port,
|
||||
stdout_file=redis_stdout_file, stderr_file=redis_stderr_file,
|
||||
cleanup=cleanup)
|
||||
if port is not None:
|
||||
assert assigned_port == port
|
||||
port = assigned_port
|
||||
redis_address = address(node_ip_address, port)
|
||||
|
||||
# Register the number of Redis shards in the primary shard, so that clients
|
||||
# know how many redis shards to expect under RedisShards.
|
||||
redis_client = redis.StrictRedis(host=node_ip_address, port=port)
|
||||
redis_client.set("NumRedisShards", str(num_redis_shards))
|
||||
|
||||
# Start other Redis shards listening on random ports. Each Redis shard logs
|
||||
# to a separate file, prefixed by "redis-<shard number>".
|
||||
redis_shards = []
|
||||
for i in range(num_redis_shards):
|
||||
redis_stdout_file, redis_stderr_file = new_log_files(
|
||||
"redis-{}".format(i), redirect_output)
|
||||
redis_shard_port, _ = start_redis_instance(
|
||||
node_ip_address=node_ip_address, stdout_file=redis_stdout_file,
|
||||
stderr_file=redis_stderr_file, cleanup=cleanup)
|
||||
shard_address = address(node_ip_address, redis_shard_port)
|
||||
redis_shards.append(shard_address)
|
||||
# Store redis shard information in the primary redis shard.
|
||||
redis_client.rpush("RedisShards", shard_address)
|
||||
|
||||
return redis_address, redis_shards
|
||||
|
||||
|
||||
def start_redis_instance(node_ip_address="127.0.0.1",
|
||||
port=None,
|
||||
num_retries=20,
|
||||
stdout_file=None,
|
||||
stderr_file=None,
|
||||
cleanup=True):
|
||||
"""Start a single Redis server.
|
||||
|
||||
Args:
|
||||
node_ip_address (str): The IP address of the current node. This is only
|
||||
used for recording the log filenames in Redis.
|
||||
port (int): If provided, start a Redis server with this port.
|
||||
num_retries (int): The number of times to attempt to start Redis.
|
||||
num_retries (int): The number of times to attempt to start Redis. If a port
|
||||
is provided, this defaults to 1.
|
||||
stdout_file: A file handle opened for writing to redirect stdout to. If no
|
||||
redirection should happen, then this should be None.
|
||||
stderr_file: A file handle opened for writing to redirect stderr to. If no
|
||||
@@ -275,8 +335,8 @@ def start_redis(node_ip_address="127.0.0.1", port=None, num_retries=20,
|
||||
assert os.path.isfile(redis_module)
|
||||
counter = 0
|
||||
if port is not None:
|
||||
if num_retries != 1:
|
||||
raise Exception("num_retries must be 1 if port is specified.")
|
||||
# If a port is specified, then try only once to connect.
|
||||
num_retries = 1
|
||||
else:
|
||||
port = new_port()
|
||||
while counter < num_retries:
|
||||
@@ -356,8 +416,8 @@ def start_log_monitor(redis_address, node_ip_address, stdout_file=None,
|
||||
[stdout_file, stderr_file])
|
||||
|
||||
|
||||
def start_global_scheduler(redis_address, node_ip_address, stdout_file=None,
|
||||
stderr_file=None, cleanup=True):
|
||||
def start_global_scheduler(redis_address, node_ip_address,
|
||||
stdout_file=None, stderr_file=None, cleanup=True):
|
||||
"""Start a global scheduler process.
|
||||
|
||||
Args:
|
||||
@@ -372,7 +432,8 @@ def start_global_scheduler(redis_address, node_ip_address, stdout_file=None,
|
||||
this process will be killed by services.cleanup() when the Python process
|
||||
that imported services exits.
|
||||
"""
|
||||
p = global_scheduler.start_global_scheduler(redis_address, node_ip_address,
|
||||
p = global_scheduler.start_global_scheduler(redis_address,
|
||||
node_ip_address,
|
||||
stdout_file=stdout_file,
|
||||
stderr_file=stderr_file)
|
||||
if cleanup:
|
||||
@@ -545,10 +606,11 @@ def start_local_scheduler(redis_address,
|
||||
return local_scheduler_name
|
||||
|
||||
|
||||
def start_objstore(node_ip_address, redis_address, object_manager_port=None,
|
||||
store_stdout_file=None, store_stderr_file=None,
|
||||
manager_stdout_file=None, manager_stderr_file=None,
|
||||
cleanup=True, objstore_memory=None):
|
||||
def start_objstore(node_ip_address, redis_address,
|
||||
object_manager_port=None, store_stdout_file=None,
|
||||
store_stderr_file=None, manager_stdout_file=None,
|
||||
manager_stderr_file=None, cleanup=True,
|
||||
objstore_memory=None):
|
||||
"""This method starts an object store process.
|
||||
|
||||
Args:
|
||||
@@ -704,13 +766,14 @@ def start_monitor(redis_address, node_ip_address, stdout_file=None,
|
||||
|
||||
def start_ray_processes(address_info=None,
|
||||
node_ip_address="127.0.0.1",
|
||||
redis_port=None,
|
||||
num_workers=None,
|
||||
num_local_schedulers=1,
|
||||
num_redis_shards=1,
|
||||
worker_path=None,
|
||||
cleanup=True,
|
||||
redirect_output=False,
|
||||
include_global_scheduler=False,
|
||||
include_redis=False,
|
||||
include_log_monitor=False,
|
||||
include_webui=False,
|
||||
start_workers_from_local_scheduler=True,
|
||||
@@ -723,12 +786,17 @@ def start_ray_processes(address_info=None,
|
||||
that have already been started. If provided, address_info will be
|
||||
modified to include processes that are newly started.
|
||||
node_ip_address (str): The IP address of this node.
|
||||
redis_port (int): The port that the primary Redis shard should listen to.
|
||||
If None, then a random port will be chosen. If the key "redis_address" is
|
||||
in address_info, then this argument will be ignored.
|
||||
num_workers (int): The number of workers to start.
|
||||
num_local_schedulers (int): The total number of local schedulers required.
|
||||
This is also the total number of object stores required. This method will
|
||||
start new instances of local schedulers and object stores until there are
|
||||
num_local_schedulers existing instances of each, including ones already
|
||||
registered with the given address_info.
|
||||
num_redis_shards: The number of Redis shards to start in addition to the
|
||||
primary Redis shard.
|
||||
worker_path (str): The path of the source code that will be run by the
|
||||
worker.
|
||||
cleanup (bool): If cleanup is true, then the processes started here will be
|
||||
@@ -738,8 +806,6 @@ def start_ray_processes(address_info=None,
|
||||
file.
|
||||
include_global_scheduler (bool): If include_global_scheduler is True, then
|
||||
start a global scheduler process.
|
||||
include_redis (bool): If include_redis is True, then start a Redis server
|
||||
process.
|
||||
include_log_monitor (bool): If True, then start a log monitor to monitor
|
||||
the log files for all processes on this node and push their contents to
|
||||
Redis.
|
||||
@@ -785,29 +851,14 @@ def start_ray_processes(address_info=None,
|
||||
# warning messages when it starts up. Instead of suppressing the output, we
|
||||
# should address the warnings.
|
||||
redis_address = address_info.get("redis_address")
|
||||
if include_redis:
|
||||
redis_stdout_file, redis_stderr_file = new_log_files("redis",
|
||||
redirect_output)
|
||||
if redis_address is None:
|
||||
# Start a Redis server. The start_redis method will choose a random port.
|
||||
redis_port, _ = start_redis(node_ip_address,
|
||||
stdout_file=redis_stdout_file,
|
||||
stderr_file=redis_stderr_file,
|
||||
cleanup=cleanup)
|
||||
redis_address = address(node_ip_address, redis_port)
|
||||
address_info["redis_address"] = redis_address
|
||||
time.sleep(0.1)
|
||||
else:
|
||||
# A Redis address was provided, so start a Redis server with the given
|
||||
# port. TODO(rkn): We should check that the IP address corresponds to the
|
||||
# machine that this method is running on.
|
||||
redis_port = get_port(redis_address)
|
||||
new_redis_port, _ = start_redis(port=int(redis_port),
|
||||
num_retries=1,
|
||||
stdout_file=redis_stdout_file,
|
||||
stderr_file=redis_stderr_file,
|
||||
cleanup=cleanup)
|
||||
assert redis_port == new_redis_port
|
||||
redis_shards = address_info.get("redis_shards", [])
|
||||
if redis_address is None:
|
||||
redis_address, redis_shards = start_redis(
|
||||
node_ip_address, port=redis_port, num_redis_shards=num_redis_shards,
|
||||
redirect_output=redirect_output, cleanup=cleanup)
|
||||
address_info["redis_address"] = redis_address
|
||||
time.sleep(0.1)
|
||||
|
||||
# Start monitoring the processes.
|
||||
monitor_stdout_file, monitor_stderr_file = new_log_files("monitor",
|
||||
redirect_output)
|
||||
@@ -815,9 +866,14 @@ def start_ray_processes(address_info=None,
|
||||
node_ip_address,
|
||||
stdout_file=monitor_stdout_file,
|
||||
stderr_file=monitor_stderr_file)
|
||||
else:
|
||||
if redis_address is None:
|
||||
raise Exception("Redis address expected")
|
||||
|
||||
if redis_shards == []:
|
||||
# Get redis shards from primary redis instance.
|
||||
redis_ip_address, redis_port = redis_address.split(":")
|
||||
redis_client = redis.StrictRedis(host=redis_ip_address, port=redis_port)
|
||||
redis_shards = redis_client.lrange("RedisShards", start=0, end=-1)
|
||||
redis_shards = [shard.decode("ascii") for shard in redis_shards]
|
||||
address_info["redis_shards"] = redis_shards
|
||||
|
||||
# Start the log monitor, if necessary.
|
||||
if include_log_monitor:
|
||||
@@ -1005,6 +1061,7 @@ def start_ray_node(node_ip_address,
|
||||
|
||||
def start_ray_head(address_info=None,
|
||||
node_ip_address="127.0.0.1",
|
||||
redis_port=None,
|
||||
num_workers=0,
|
||||
num_local_schedulers=1,
|
||||
worker_path=None,
|
||||
@@ -1012,7 +1069,8 @@ def start_ray_head(address_info=None,
|
||||
redirect_output=False,
|
||||
start_workers_from_local_scheduler=True,
|
||||
num_cpus=None,
|
||||
num_gpus=None):
|
||||
num_gpus=None,
|
||||
num_redis_shards=None):
|
||||
"""Start Ray in local mode.
|
||||
|
||||
Args:
|
||||
@@ -1020,6 +1078,9 @@ def start_ray_head(address_info=None,
|
||||
that have already been started. If provided, address_info will be
|
||||
modified to include processes that are newly started.
|
||||
node_ip_address (str): The IP address of this node.
|
||||
redis_port (int): The port that the primary Redis shard should listen to.
|
||||
If None, then a random port will be chosen. If the key "redis_address" is
|
||||
in address_info, then this argument will be ignored.
|
||||
num_workers (int): The number of workers to start.
|
||||
num_local_schedulers (int): The total number of local schedulers required.
|
||||
This is also the total number of object stores required. This method will
|
||||
@@ -1038,14 +1099,18 @@ def start_ray_head(address_info=None,
|
||||
Python.
|
||||
num_cpus (int): number of cpus to configure the local scheduler with.
|
||||
num_gpus (int): number of gpus to configure the local scheduler with.
|
||||
num_redis_shards: The number of Redis shards to start in addition to the
|
||||
primary Redis shard.
|
||||
|
||||
Returns:
|
||||
A dictionary of the address information for the processes that were
|
||||
started.
|
||||
"""
|
||||
num_redis_shards = 1 if num_redis_shards is None else num_redis_shards
|
||||
return start_ray_processes(
|
||||
address_info=address_info,
|
||||
node_ip_address=node_ip_address,
|
||||
redis_port=redis_port,
|
||||
num_workers=num_workers,
|
||||
num_local_schedulers=num_local_schedulers,
|
||||
worker_path=worker_path,
|
||||
@@ -1053,11 +1118,11 @@ def start_ray_head(address_info=None,
|
||||
redirect_output=redirect_output,
|
||||
include_global_scheduler=True,
|
||||
include_log_monitor=True,
|
||||
include_redis=True,
|
||||
include_webui=False,
|
||||
start_workers_from_local_scheduler=start_workers_from_local_scheduler,
|
||||
num_cpus=num_cpus,
|
||||
num_gpus=num_gpus)
|
||||
num_gpus=num_gpus,
|
||||
num_redis_shards=num_redis_shards)
|
||||
|
||||
|
||||
def new_log_files(name, redirect_output):
|
||||
|
||||
+28
-15
@@ -992,7 +992,8 @@ def _init(address_info=None,
|
||||
redirect_output=False,
|
||||
start_workers_from_local_scheduler=True,
|
||||
num_cpus=None,
|
||||
num_gpus=None):
|
||||
num_gpus=None,
|
||||
num_redis_shards=None):
|
||||
"""Helper method to connect to an existing Ray cluster or start a new one.
|
||||
|
||||
This method handles two cases. Either a Ray cluster already exists and we
|
||||
@@ -1002,8 +1003,9 @@ def _init(address_info=None,
|
||||
Args:
|
||||
address_info (dict): A dictionary with address information for processes in
|
||||
a partially-started Ray cluster. If start_ray_local=True, any processes
|
||||
not in this dictionary will be started. If provided, address_info will be
|
||||
modified to include processes that are newly started.
|
||||
not in this dictionary will be started. If provided, an updated
|
||||
address_info dictionary will be returned to include processes that are
|
||||
newly started.
|
||||
start_ray_local (bool): If True then this will start any processes not
|
||||
already in address_info, including Redis, a global scheduler, local
|
||||
scheduler(s), object store(s), and worker(s). It will also kill these
|
||||
@@ -1028,6 +1030,8 @@ def _init(address_info=None,
|
||||
be configured with.
|
||||
num_gpus: A list containing the number of GPUs the local schedulers should
|
||||
be configured with.
|
||||
num_redis_shards: The number of Redis shards to start in addition to the
|
||||
primary Redis shard.
|
||||
|
||||
Returns:
|
||||
Address information about the started processes.
|
||||
@@ -1069,6 +1073,8 @@ def _init(address_info=None,
|
||||
num_local_schedulers = len(local_schedulers)
|
||||
else:
|
||||
num_local_schedulers = 1
|
||||
# Use 1 additional redis shard if num_redis_shards is not provided.
|
||||
num_redis_shards = 1 if num_redis_shards is None else num_redis_shards
|
||||
# Start the scheduler, object store, and some workers. These will be killed
|
||||
# by the call to cleanup(), which happens when the Python script exits.
|
||||
address_info = services.start_ray_head(
|
||||
@@ -1079,20 +1085,24 @@ def _init(address_info=None,
|
||||
redirect_output=redirect_output,
|
||||
start_workers_from_local_scheduler=start_workers_from_local_scheduler,
|
||||
num_cpus=num_cpus,
|
||||
num_gpus=num_gpus)
|
||||
num_gpus=num_gpus,
|
||||
num_redis_shards=num_redis_shards)
|
||||
else:
|
||||
if redis_address is None:
|
||||
raise Exception("If start_ray_local=False, then redis_address must be "
|
||||
"provided.")
|
||||
raise Exception("When connecting to an existing cluster, redis_address "
|
||||
"must be provided.")
|
||||
if num_workers is not None:
|
||||
raise Exception("If start_ray_local=False, then num_workers must not be "
|
||||
"provided.")
|
||||
raise Exception("When connecting to an existing cluster, num_workers "
|
||||
"must not be provided.")
|
||||
if num_local_schedulers is not None:
|
||||
raise Exception("If start_ray_local=False, then num_local_schedulers "
|
||||
"must not be provided.")
|
||||
raise Exception("When connecting to an existing cluster, "
|
||||
"num_local_schedulers must not be provided.")
|
||||
if num_cpus is not None or num_gpus is not None:
|
||||
raise Exception("If start_ray_local=False, then num_cpus and num_gpus "
|
||||
"must not be provided.")
|
||||
raise Exception("When connecting to an existing cluster, num_cpus and "
|
||||
"num_gpus must not be provided.")
|
||||
if num_redis_shards is not None:
|
||||
raise Exception("When connecting to an existing cluster, "
|
||||
"num_redis_shards must not be provided.")
|
||||
# Get the node IP address if one is not provided.
|
||||
if node_ip_address is None:
|
||||
node_ip_address = services.get_node_ip_address(redis_address)
|
||||
@@ -1121,7 +1131,7 @@ def _init(address_info=None,
|
||||
|
||||
def init(redis_address=None, node_ip_address=None, object_id_seed=None,
|
||||
num_workers=None, driver_mode=SCRIPT_MODE, redirect_output=False,
|
||||
num_cpus=None, num_gpus=None):
|
||||
num_cpus=None, num_gpus=None, num_redis_shards=None):
|
||||
"""Either connect to an existing Ray cluster or start one and connect to it.
|
||||
|
||||
This method handles two cases. Either a Ray cluster already exists and we
|
||||
@@ -1148,6 +1158,8 @@ def init(redis_address=None, node_ip_address=None, object_id_seed=None,
|
||||
configured with.
|
||||
num_gpus (int): Number of gpus the user wishes all local schedulers to be
|
||||
configured with.
|
||||
num_redis_shards: The number of Redis shards to start in addition to the
|
||||
primary Redis shard.
|
||||
|
||||
Returns:
|
||||
Address information about the started processes.
|
||||
@@ -1161,7 +1173,7 @@ def init(redis_address=None, node_ip_address=None, object_id_seed=None,
|
||||
return _init(address_info=info, start_ray_local=(redis_address is None),
|
||||
num_workers=num_workers, driver_mode=driver_mode,
|
||||
redirect_output=redirect_output, num_cpus=num_cpus,
|
||||
num_gpus=num_gpus)
|
||||
num_gpus=num_gpus, num_redis_shards=num_redis_shards)
|
||||
|
||||
|
||||
def cleanup(worker=global_worker):
|
||||
@@ -1577,7 +1589,8 @@ def connect(info, object_id_seed=None, mode=WORKER_MODE, worker=global_worker,
|
||||
ray.local_scheduler.ObjectID(NIL_ACTOR_ID),
|
||||
worker.actor_counters[actor_id],
|
||||
[0, 0])
|
||||
worker.redis_client.execute_command(
|
||||
global_state._execute_command(
|
||||
driver_task.task_id(),
|
||||
"RAY.TASK_TABLE_ADD",
|
||||
driver_task.task_id().id(),
|
||||
TASK_STATUS_RUNNING,
|
||||
|
||||
Reference in New Issue
Block a user