Shard Redis. (#539)

* Implement sharding in the Ray core

* Single node Python modifications to do sharding

* Do the sharding in redis.cc

* Pipe num_redis_shards through start_ray.py and worker.py.

* Use multiple redis shards in multinode tests.

* first steps for sharding ray.global_state

* Fix problem in multinode docker test.

* fix runtest.py

* fix some tests

* fix redis shard startup

* fix redis sharding

* fix

* fix bug introduced by the map-iterator being consumed

* fix sharding bug

* shard event table

* update number of Redis clients to be 64K

* Fix object table tests by flushing shards in between unit tests

* Fix local scheduler tests

* Documentation

* Register shard locations in the primary shard

* Add plasma unit tests back to build

* lint

* lint and fix build

* Fix

* Address Robert's comments

* Refactor start_ray_processes to start Redis shard

* lint

* Fix global scheduler python tests

* Fix redis module test

* Fix plasma test

* Fix component failure test

* Fix local scheduler test

* Fix runtest.py

* Fix global scheduler test for python3

* Fix task_table_test_and_update bug, from actor task table submission race

* Fix jenkins tests.

* Retry Redis shard connections

* Fix test cases

* Convert database clients to DBClient struct

* Fix race condition when subscribing to db client table

* Remove unused lines, add APITest for sharded Ray

* Fix

* Fix memory leak

* Suppress ReconstructionTests output

* Suppress output for APITestSharded

* Reissue task table add/update commands if initial command does not publish to any subscribers.

* fix

* Fix linting.

* fix tests

* fix linting

* fix python test

* fix linting
This commit is contained in:
Stephanie Wang
2017-05-18 17:40:41 -07:00
committed by Philipp Moritz
parent 0a4304725f
commit ee08c8274b
39 changed files with 1336 additions and 651 deletions
+85 -26
View File
@@ -5,6 +5,7 @@ from __future__ import print_function
import pickle
import redis
import ray
from ray.utils import (decode, binary_to_object_id, binary_to_hex,
hex_to_binary)
@@ -25,14 +26,21 @@ OBJECT_CHANNEL_PREFIX = "OC:"
# This mapping from integer to task state string must be kept up-to-date with
# the scheduling_state enum in task.h.
task_state_mapping = {
1: "WAITING",
2: "SCHEDULED",
4: "QUEUED",
8: "RUNNING",
16: "DONE",
32: "LOST",
64: "RECONSTRUCTING"
TASK_STATUS_WAITING = 1
TASK_STATUS_SCHEDULED = 2
TASK_STATUS_QUEUED = 4
TASK_STATUS_RUNNING = 8
TASK_STATUS_DONE = 16
TASK_STATUS_LOST = 32
TASK_STATUS_RECONSTRUCTING = 64
TASK_STATUS_MAPPING = {
TASK_STATUS_WAITING: "WAITING",
TASK_STATUS_SCHEDULED: "SCHEDULED",
TASK_STATUS_QUEUED: "QUEUED",
TASK_STATUS_RUNNING: "RUNNING",
TASK_STATUS_DONE: "DONE",
TASK_STATUS_LOST: "LOST",
TASK_STATUS_RECONSTRUCTING: "RECONSTRUCTING",
}
@@ -66,8 +74,54 @@ class GlobalState(object):
"""
self.redis_client = redis.StrictRedis(host=redis_ip_address,
port=redis_port)
self.redis_clients = []
num_redis_shards = self.redis_client.get("NumRedisShards")
if num_redis_shards is None:
raise Exception("No entry found for NumRedisShards")
num_redis_shards = int(num_redis_shards)
if (num_redis_shards < 1):
raise Exception("Expected at least one Redis shard, found "
"{}.".format(num_redis_shards))
def _object_table(self, object_id_binary):
ip_address_ports = self.redis_client.lrange("RedisShards", start=0, end=-1)
if len(ip_address_ports) != num_redis_shards:
raise Exception("Expected {} Redis shard addresses, found "
"{}".format(num_redis_shards, len(ip_address_ports)))
for ip_address_port in ip_address_ports:
shard_address, shard_port = ip_address_port.split(b":")
self.redis_clients.append(redis.StrictRedis(host=shard_address,
port=shard_port))
def _execute_command(self, key, *args):
"""Execute a Redis command on the appropriate Redis shard based on key.
Args:
key: The object ID or the task ID that the query is about.
args: The command to run.
Returns:
The value returned by the Redis command.
"""
client = self.redis_clients[key.redis_shard_hash() %
len(self.redis_clients)]
return client.execute_command(*args)
def _keys(self, pattern):
"""Execute the KEYS command on all Redis shards.
Args:
pattern: The KEYS pattern to query.
Returns:
The concatenated list of results from all shards.
"""
result = []
for client in self.redis_clients:
result.extend(client.keys(pattern))
return result
def _object_table(self, object_id):
"""Fetch and parse the object table information for a single object ID.
Args:
@@ -78,16 +132,18 @@ class GlobalState(object):
A dictionary with information about the object ID in question.
"""
# Return information about a single object ID.
object_locations = self.redis_client.execute_command(
"RAY.OBJECT_TABLE_LOOKUP", object_id_binary)
object_locations = self._execute_command(object_id,
"RAY.OBJECT_TABLE_LOOKUP",
object_id.id())
if object_locations is not None:
manager_ids = [binary_to_hex(manager_id)
for manager_id in object_locations]
else:
manager_ids = None
result_table_response = self.redis_client.execute_command(
"RAY.RESULT_TABLE_LOOKUP", object_id_binary)
result_table_response = self._execute_command(object_id,
"RAY.RESULT_TABLE_LOOKUP",
object_id.id())
result_table_message = ResultTableReply.GetRootAsResultTableReply(
result_table_response, 0)
@@ -111,22 +167,21 @@ class GlobalState(object):
self._check_connected()
if object_id is not None:
# Return information about a single object ID.
return self._object_table(object_id.id())
return self._object_table(object_id)
else:
# Return the entire object table.
object_info_keys = self.redis_client.keys(OBJECT_INFO_PREFIX + "*")
object_location_keys = self.redis_client.keys(
OBJECT_LOCATION_PREFIX + "*")
object_info_keys = self._keys(OBJECT_INFO_PREFIX + "*")
object_location_keys = self._keys(OBJECT_LOCATION_PREFIX + "*")
object_ids_binary = set(
[key[len(OBJECT_INFO_PREFIX):] for key in object_info_keys] +
[key[len(OBJECT_LOCATION_PREFIX):] for key in object_location_keys])
results = {}
for object_id_binary in object_ids_binary:
results[binary_to_object_id(object_id_binary)] = self._object_table(
object_id_binary)
binary_to_object_id(object_id_binary))
return results
def _task_table(self, task_id_binary):
def _task_table(self, task_id):
"""Fetch and parse the task table information for a single object task ID.
Args:
@@ -135,12 +190,15 @@ class GlobalState(object):
Returns:
A dictionary with information about the task ID in question.
TASK_STATUS_MAPPING should be used to parse the "State" field into a
human-readable string.
"""
task_table_response = self.redis_client.execute_command(
"RAY.TASK_TABLE_GET", task_id_binary)
task_table_response = self._execute_command(task_id,
"RAY.TASK_TABLE_GET",
task_id.id())
if task_table_response is None:
raise Exception("There is no entry for task ID {} in the task table."
.format(binary_to_hex(task_id_binary)))
.format(binary_to_hex(task_id.id())))
task_table_message = TaskReply.GetRootAsTaskReply(task_table_response, 0)
task_spec = task_table_message.TaskSpec()
task_spec_message = TaskInfo.GetRootAsTaskInfo(task_spec, 0)
@@ -167,7 +225,7 @@ class GlobalState(object):
for i in range(task_spec_message.ReturnsLength())],
"RequiredResources": required_resources}
return {"State": task_state_mapping[task_table_message.State()],
return {"State": task_table_message.State(),
"LocalSchedulerID": binary_to_hex(
task_table_message.LocalSchedulerId()),
"TaskSpec": task_spec_info}
@@ -185,14 +243,15 @@ class GlobalState(object):
"""
self._check_connected()
if task_id is not None:
return self._task_table(hex_to_binary(task_id))
task_id = ray.local_scheduler.ObjectID(hex_to_binary(task_id))
return self._task_table(task_id)
else:
task_table_keys = self.redis_client.keys(TASK_PREFIX + "*")
task_table_keys = self._keys(TASK_PREFIX + "*")
results = {}
for key in task_table_keys:
task_id_binary = key[len(TASK_PREFIX):]
results[binary_to_hex(task_id_binary)] = self._task_table(
task_id_binary)
ray.local_scheduler.ObjectID(task_id_binary))
return results
def function_table(self, function_id=None):