Switch build system to use CMake completely. (#200)

* switch to CMake completely

...

* cleanup

* Run C tests, update installation instructions.
This commit is contained in:
Philipp Moritz
2017-01-17 16:56:40 -08:00
committed by Robert Nishihara
parent ba8933e10f
commit a708e36225
106 changed files with 467 additions and 870 deletions
View File
View File
+252
View File
@@ -0,0 +1,252 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import random
import subprocess
import sys
import time
import unittest
import redis
# Directory containing this test file; the paths below are resolved against it.
_test_dir = os.path.abspath(os.path.dirname(__file__))

# The tests cannot run without the redis-server binary, so fail at import time
# when it is missing.
redis_path = os.path.join(
    _test_dir, "../../core/src/common/thirdparty/redis/src/redis-server")
if not os.path.exists(redis_path):
    raise Exception("You do not have the redis-server binary. Run `make test` in the plasma directory to get it.")

# Path of the ray redis module that is passed to redis-server via --loadmodule.
module_path = os.path.join(
    _test_dir, "../../core/src/common/redis_module/libray_redis_module.so")
print("path to the redis module is {}".format(module_path))

# String prefixes. OBJECT_CHANNEL_PREFIX and TASK_PREFIX are used by the tests
# in this file to build pub/sub subscription patterns.
OBJECT_INFO_PREFIX = "OI:"
OBJECT_LOCATION_PREFIX = "OL:"
OBJECT_SUBSCRIBE_PREFIX = "OS:"
TASK_PREFIX = "TT:"
OBJECT_CHANNEL_PREFIX = "OC:"
def integerToAsciiHex(num, numbytes):
    """Return the low ``numbytes`` bytes of ``num``, least significant first.

    Args:
        num (int): The integer to encode. Bits above ``8 * numbytes`` are
            dropped.
        numbytes (int): The width of the result in bytes. Must be 4 or 8.

    Returns:
        A byte string of length ``numbytes`` in little-endian byte order.

    Raises:
        AssertionError: If ``numbytes`` is neither 4 nor 8.
    """
    # Support 32 and 64 bit architecture.
    assert numbytes == 4 or numbytes == 8
    # Going through bytearray yields a byte string on both Python 2 and
    # Python 3, so no version check or repeated concatenation is needed.
    return bytes(bytearray((num >> (8 * i)) & 0xff for i in range(numbytes)))
class TestGlobalStateStore(unittest.TestCase):
    """Tests for the RAY.* commands provided by the ray redis module.

    Every test talks to its own redis-server process, which is started with
    the module loaded in setUp and stopped in tearDown.
    """

    def setUp(self):
        """Start redis-server on a random port and connect a client to it."""
        redis_port = random.randint(2000, 50000)
        self.redis_process = subprocess.Popen([redis_path,
                                               "--port", str(redis_port),
                                               "--loglevel", "warning",
                                               "--loadmodule", module_path])
        # Presumably gives the server time to start accepting connections.
        time.sleep(1.5)
        self.redis = redis.StrictRedis(host="localhost", port=redis_port, db=0)

    def tearDown(self):
        """Stop the redis-server process started in setUp."""
        self.redis_process.kill()
        # Reap the killed process so that it does not linger as a zombie for
        # the remainder of the test run.
        self.redis_process.wait()

    def testInvalidObjectTableAdd(self):
        # Check that Redis returns an error when RAY.OBJECT_TABLE_ADD is called
        # with the wrong arguments.
        with self.assertRaises(redis.ResponseError):
            self.redis.execute_command("RAY.OBJECT_TABLE_ADD")
        with self.assertRaises(redis.ResponseError):
            self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "hello")
        with self.assertRaises(redis.ResponseError):
            self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id2", "one", "hash2", "manager_id1")
        with self.assertRaises(redis.ResponseError):
            self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id2", 1, "hash2", "manager_id1", "extra argument")
        # Check that Redis returns an error when RAY.OBJECT_TABLE_ADD adds an
        # object ID that is already present with a different hash.
        self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, "hash1", "manager_id1")
        with self.assertRaises(redis.ResponseError):
            self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, "hash2", "manager_id1")
        # Check that it is fine if we add the same object ID multiple times
        # with the same hash.
        self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, "hash1", "manager_id1")
        self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, "hash1", "manager_id1")
        self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, "hash1", "manager_id2")
        self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 2, "hash1", "manager_id2")

    def testObjectTableAddAndLookup(self):
        # Try calling RAY.OBJECT_TABLE_LOOKUP with an object ID that has not
        # been added yet.
        response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", "object_id1")
        self.assertEqual(set(response), set())
        # Add some managers and try again.
        self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, "hash1", "manager_id1")
        self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, "hash1", "manager_id2")
        response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", "object_id1")
        self.assertEqual(set(response), {b"manager_id1", b"manager_id2"})
        # Add a manager that already exists again and try again.
        self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, "hash1", "manager_id2")
        response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", "object_id1")
        self.assertEqual(set(response), {b"manager_id1", b"manager_id2"})
        # Check that we properly handle NULL characters. In the past, NULL
        # characters were handled improperly causing a "hash mismatch" error if
        # two object IDs that agreed up to the NULL character were inserted
        # with different hashes.
        self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "\x00object_id3", 1, "hash1", "manager_id1")
        self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "\x00object_id4", 1, "hash2", "manager_id1")
        # Check that NULL characters in the hash are handled properly.
        self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", 1, "\x00hash1", "manager_id1")
        with self.assertRaises(redis.ResponseError):
            self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", 1, "\x00hash2", "manager_id1")

    def testObjectTableAddAndRemove(self):
        # Try removing a manager from an object ID that has not been added yet.
        with self.assertRaises(redis.ResponseError):
            self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1", "manager_id1")
        # Try calling RAY.OBJECT_TABLE_LOOKUP with an object ID that has not
        # been added yet.
        response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", "object_id1")
        self.assertEqual(set(response), set())
        # Add some managers and try again.
        self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, "hash1", "manager_id1")
        self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, "hash1", "manager_id2")
        response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", "object_id1")
        self.assertEqual(set(response), {b"manager_id1", b"manager_id2"})
        # Remove a manager that doesn't exist, and make sure we still have the
        # same set.
        self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1", "manager_id3")
        response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", "object_id1")
        self.assertEqual(set(response), {b"manager_id1", b"manager_id2"})
        # Remove a manager that does exist. Make sure it gets removed the first
        # time and does nothing the second time.
        self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1", "manager_id1")
        response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", "object_id1")
        self.assertEqual(set(response), {b"manager_id2"})
        self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1", "manager_id1")
        response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", "object_id1")
        self.assertEqual(set(response), {b"manager_id2"})
        # Remove the last manager, and make sure we have an empty set.
        self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1", "manager_id2")
        response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", "object_id1")
        self.assertEqual(set(response), set())
        # Remove a manager from an empty set, and make sure we still have an
        # empty set.
        self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1", "manager_id3")
        response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", "object_id1")
        self.assertEqual(set(response), set())

    def testObjectTableSubscribeToNotifications(self):
        data_size = 0xf1f0
        p = self.redis.pubsub()
        # Subscribe to an object ID.
        p.psubscribe("{}manager_id1".format(OBJECT_CHANNEL_PREFIX))
        self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", data_size, "hash1", "manager_id2")
        # Receive the acknowledgement message.
        self.assertEqual(p.get_message()["data"], 1)
        # Request a notification and receive the data.
        self.redis.execute_command("RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS", "manager_id1", "object_id1")
        self.assertEqual(p.get_message()["data"],
                         b"object_id1 %s MANAGERS manager_id2"
                         % integerToAsciiHex(data_size, 8))
        # Request a notification for an object that isn't there. Then add the
        # object and receive the data. Only the first call to
        # RAY.OBJECT_TABLE_ADD should trigger notifications.
        self.redis.execute_command("RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS", "manager_id1", "object_id2", "object_id3")
        self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", data_size, "hash1", "manager_id1")
        self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", data_size, "hash1", "manager_id2")
        self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", data_size, "hash1", "manager_id3")
        self.assertEqual(p.get_message()["data"],
                         b"object_id3 %s MANAGERS manager_id1"
                         % integerToAsciiHex(data_size, 8))
        self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id2", data_size, "hash1", "manager_id3")
        self.assertEqual(p.get_message()["data"],
                         b"object_id2 %s MANAGERS manager_id3"
                         % integerToAsciiHex(data_size, 8))
        # Request notifications for object_id3 again.
        self.redis.execute_command("RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS", "manager_id1", "object_id3")
        self.assertEqual(p.get_message()["data"],
                         b"object_id3 %s MANAGERS manager_id1 manager_id2 manager_id3"
                         % integerToAsciiHex(data_size, 8))

    def testResultTableAddAndLookup(self):
        # Try looking up something in the result table before anything is
        # added.
        response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP", "object_id1")
        self.assertIsNone(response)
        # Adding the object to the object table should have no effect.
        self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, "hash1", "manager_id1")
        response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP", "object_id1")
        self.assertIsNone(response)
        # Add the result to the result table. This is necessary, but not
        # sufficient because the task is still not in the task table.
        self.redis.execute_command("RAY.RESULT_TABLE_ADD", "object_id1", "task_id1")
        response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP", "object_id1")
        self.assertIsNone(response)
        # Add the task to the task table so that the result table lookup can
        # succeed.
        self.redis.execute_command("RAY.TASK_TABLE_ADD", "task_id1", 1, "local_scheduler_id1", "task_spec1")
        response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP", "object_id1")
        self.assertEqual(response, [1, b"local_scheduler_id1", b"task_spec1"])
        # Doing it again should still work.
        response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP", "object_id1")
        self.assertEqual(response, [1, b"local_scheduler_id1", b"task_spec1"])
        # Try another result table lookup. This should succeed.
        self.redis.execute_command("RAY.TASK_TABLE_ADD", "task_id2", 2, "local_scheduler_id2", "task_spec2")
        self.redis.execute_command("RAY.RESULT_TABLE_ADD", "object_id2", "task_id2")
        response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP", "object_id2")
        self.assertEqual(response, [2, b"local_scheduler_id2", b"task_spec2"])

    def testInvalidTaskTableAdd(self):
        # Check that Redis returns an error when RAY.TASK_TABLE_ADD is called
        # with the wrong arguments.
        with self.assertRaises(redis.ResponseError):
            self.redis.execute_command("RAY.TASK_TABLE_ADD")
        with self.assertRaises(redis.ResponseError):
            self.redis.execute_command("RAY.TASK_TABLE_ADD", "hello")
        with self.assertRaises(redis.ResponseError):
            self.redis.execute_command("RAY.TASK_TABLE_ADD", "task_id", 3, "node_id")
        with self.assertRaises(redis.ResponseError):
            # Non-integer scheduling states should not be added.
            self.redis.execute_command("RAY.TASK_TABLE_ADD", "task_id",
                                       "invalid_state", "node_id", "task_spec")
        with self.assertRaises(redis.ResponseError):
            # Scheduling states with invalid width should not be added.
            self.redis.execute_command("RAY.TASK_TABLE_ADD", "task_id", 101,
                                       "node_id", "task_spec")
        with self.assertRaises(redis.ResponseError):
            # Should not be able to update a non-existent task.
            self.redis.execute_command("RAY.TASK_TABLE_UPDATE", "task_id", 10,
                                       "node_id")

    def testTaskTableAddAndLookup(self):
        # Check that task table adds, updates, and lookups work correctly.
        task_args = [1, b"node_id", b"task_spec"]
        self.redis.execute_command("RAY.TASK_TABLE_ADD", "task_id", *task_args)
        response = self.redis.execute_command("RAY.TASK_TABLE_GET", "task_id")
        self.assertEqual(response, task_args)
        # The update only takes the state and the node ID; the task spec
        # returned by the lookup below must be unchanged.
        task_args[0] = 2
        self.redis.execute_command("RAY.TASK_TABLE_UPDATE", "task_id", *task_args[:2])
        response = self.redis.execute_command("RAY.TASK_TABLE_GET", "task_id")
        self.assertEqual(response, task_args)

    def testTaskTableSubscribe(self):
        scheduling_state = 1
        node_id = "node_id"
        # Subscribe to the task table.
        p = self.redis.pubsub()
        p.psubscribe("{prefix}*:*".format(prefix=TASK_PREFIX))
        p.psubscribe("{prefix}*:{state: >2}".format(prefix=TASK_PREFIX, state=scheduling_state))
        p.psubscribe("{prefix}{node}:*".format(prefix=TASK_PREFIX, node=node_id))
        task_args = [b"task_id", scheduling_state, node_id.encode("ascii"), b"task_spec"]
        self.redis.execute_command("RAY.TASK_TABLE_ADD", *task_args)
        # Receive the acknowledgement message.
        self.assertEqual(p.get_message()["data"], 1)
        self.assertEqual(p.get_message()["data"], 2)
        self.assertEqual(p.get_message()["data"], 3)
        # Receive the actual data, once per subscription.
        for _ in range(3):
            message = p.get_message()["data"]
            message = message.split()
            message[1] = int(message[1])
            self.assertEqual(message, task_args)
# Run every test in this file, with verbose output, when executed as a script.
if __name__ == "__main__":
    unittest.main(verbosity=2)
+166
View File
@@ -0,0 +1,166 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import pickle
import sys
import unittest
import photon
# Length in bytes of the IDs used to construct photon.ObjectID instances in
# this file.
ID_SIZE = 20
def random_object_id():
    """Return a photon.ObjectID built from ID_SIZE random bytes."""
    id_bytes = np.random.bytes(ID_SIZE)
    return photon.ObjectID(id_bytes)
def random_function_id():
    """Return a random function ID, represented here as a photon.ObjectID."""
    id_bytes = np.random.bytes(ID_SIZE)
    return photon.ObjectID(id_bytes)
def random_task_id():
    """Return a random task ID, represented here as a photon.ObjectID."""
    id_bytes = np.random.bytes(ID_SIZE)
    return photon.ObjectID(id_bytes)
# Values that TestSerialization expects photon.check_simple_value to accept,
# first on their own and then wrapped in a list, a tuple and a dict.
BASE_SIMPLE_OBJECTS = [
    0, 1, 100000,
    0.0, 0.5, 0.9, 100000.1,
    (), [], {},
    "", 990 * "h",
    u"", 990 * u"h",
]
if sys.version_info < (3, 0):
    # Python 2 only: also cover the separate ``long`` integer type.
    BASE_SIMPLE_OBJECTS.extend(
        [long(0), long(1), long(100000), long(1 << 100)])

LIST_SIMPLE_OBJECTS = [[obj] for obj in BASE_SIMPLE_OBJECTS]
TUPLE_SIMPLE_OBJECTS = [(obj,) for obj in BASE_SIMPLE_OBJECTS]
DICT_SIMPLE_OBJECTS = [{(): obj} for obj in BASE_SIMPLE_OBJECTS]
SIMPLE_OBJECTS = (BASE_SIMPLE_OBJECTS + LIST_SIMPLE_OBJECTS
                  + TUPLE_SIMPLE_OBJECTS + DICT_SIMPLE_OBJECTS)

# Create some complex objects that cannot be serialized by value in tasks.
# The first one is a list that contains itself.
l = []
l.append(l)
class Foo(object):
    """Empty user-defined class; an instance is used as a complex object."""

    def __init__(self):
        """Create an instance without any attributes."""
        pass
# Values that TestSerialization expects photon.check_simple_value to reject,
# first on their own and then wrapped in a list, a tuple and a dict.
BASE_COMPLEX_OBJECTS = [
    999 * "h",
    999 * u"h",
    l,
    Foo(),
    10 * [10 * [10 * [1]]],
]
LIST_COMPLEX_OBJECTS = [[obj] for obj in BASE_COMPLEX_OBJECTS]
TUPLE_COMPLEX_OBJECTS = [(obj,) for obj in BASE_COMPLEX_OBJECTS]
DICT_COMPLEX_OBJECTS = [{(): obj} for obj in BASE_COMPLEX_OBJECTS]
COMPLEX_OBJECTS = (BASE_COMPLEX_OBJECTS + LIST_COMPLEX_OBJECTS
                   + TUPLE_COMPLEX_OBJECTS + DICT_COMPLEX_OBJECTS)
class TestSerialization(unittest.TestCase):
    """Tests for photon.check_simple_value."""

    def test_serialize_by_value(self):
        """Simple objects must be accepted and complex objects rejected."""
        for simple in SIMPLE_OBJECTS:
            self.assertTrue(photon.check_simple_value(simple))
        for complicated in COMPLEX_OBJECTS:
            self.assertFalse(photon.check_simple_value(complicated))
class TestObjectID(unittest.TestCase):
    """Tests for creating, pickling, comparing and hashing photon.ObjectID."""

    def test_create_object_id(self):
        """Constructing an object ID from random bytes must not raise."""
        random_object_id()

    def test_cannot_pickle_object_ids(self):
        """Pickling object IDs, or functions referring to them, must fail."""
        object_ids = [random_object_id() for _ in range(256)]

        def f():
            return object_ids

        def g(val=object_ids):
            return 1

        def h():
            x = object_ids[0]
            return 1

        # Make sure that object IDs cannot be pickled (including functions
        # that close over object IDs). The calls go through the ``pickle``
        # module imported by this file: calling through an undefined name
        # would raise NameError and satisfy assertRaises(Exception, ...)
        # without exercising pickling at all.
        for unpicklable in (object_ids[0], object_ids, f, g, h):
            self.assertRaises(Exception, pickle.dumps, unpicklable)

    def test_equality_comparisons(self):
        """Object IDs built from equal bytes compare and hash as equal."""
        x1 = photon.ObjectID(ID_SIZE * b"a")
        x2 = photon.ObjectID(ID_SIZE * b"a")
        y1 = photon.ObjectID(ID_SIZE * b"b")
        y2 = photon.ObjectID(ID_SIZE * b"b")
        self.assertEqual(x1, x2)
        self.assertEqual(y1, y2)
        self.assertNotEqual(x1, y1)

        # Two independently constructed lists of IDs over the same random
        # byte strings must collapse to the same 256 set members.
        random_strings = [np.random.bytes(ID_SIZE) for _ in range(256)]
        object_ids1 = [photon.ObjectID(s) for s in random_strings]
        object_ids2 = [photon.ObjectID(s) for s in random_strings]
        self.assertEqual(len(set(object_ids1)), 256)
        self.assertEqual(len(set(object_ids1 + object_ids2)), 256)
        self.assertEqual(set(object_ids1), set(object_ids2))

    def test_hashability(self):
        """Object IDs can be used as dict keys and set members."""
        x = random_object_id()
        y = random_object_id()
        # Building the containers is the test; both raise if IDs are
        # unhashable.
        {x: y}
        set([x, y])
class TestTask(unittest.TestCase):
    """Tests for constructing photon.Task objects and serializing them."""

    def check_task(self, task, function_id, num_return_vals, args):
        """Assert that ``task`` carries the given function ID, returns, args.

        Args:
            task: The photon.Task to inspect.
            function_id: The function ID the task was created with.
            num_return_vals (int): The expected number of return values.
            args (list): The arguments the task was created with.
        """
        self.assertEqual(function_id.id(), task.function_id().id())
        retrieved_args = task.arguments()
        self.assertEqual(num_return_vals, len(task.returns()))
        self.assertEqual(len(args), len(retrieved_args))
        # The lengths were just checked to be equal, so pairing up is safe.
        for retrieved, expected in zip(retrieved_args, args):
            if isinstance(retrieved, photon.ObjectID):
                # Object IDs are compared through their raw ID.
                self.assertEqual(retrieved.id(), expected.id())
            else:
                self.assertEqual(retrieved, expected)

    def test_create_and_serialize_task(self):
        """Tasks keep their contents through task_to_string/task_from_string."""
        # TODO(rkn): The function ID should be a FunctionID object, not an
        # ObjectID.
        parent_id = random_task_id()
        function_id = random_function_id()
        object_ids = [random_object_id() for _ in range(256)]
        args_list = [
            [],
            1 * [1],
            10 * [1],
            100 * [1],
            1000 * [1],
            1 * ["a"],
            10 * ["a"],
            100 * ["a"],
            1000 * ["a"],
            [1, 1.3, 2, 1 << 100, "hi", u"hi", [1, 2]],
            object_ids[:1],
            object_ids[:2],
            object_ids[:3],
            object_ids[:4],
            object_ids[:5],
            object_ids[:10],
            object_ids[:100],
            object_ids[:256],
            [1, object_ids[0]],
            [object_ids[0], "a"],
            [1, object_ids[0], "a"],
            [object_ids[0], 1, object_ids[1], "a"],
            object_ids[:3] + [1, "hi", 2.3] + object_ids[:5],
            object_ids + 100 * ["a"] + object_ids,
        ]
        return_counts = [0, 1, 2, 3, 5, 10, 100]
        for args in args_list:
            for num_return_vals in return_counts:
                task = photon.Task(function_id, args, num_return_vals, parent_id, 0)
                self.check_task(task, function_id, num_return_vals, args)
                # Round-trip through the string form and check again.
                serialized = photon.task_to_string(task)
                restored_task = photon.task_from_string(serialized)
                self.check_task(restored_task, function_id, num_return_vals, args)
# Run every test in this file, with verbose output, when executed as a script.
if __name__ == "__main__":
    unittest.main(verbosity=2)
View File
View File
View File
View File
View File
View File
+5
View File
@@ -0,0 +1,5 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from .global_scheduler_services import *
@@ -0,0 +1,40 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import subprocess
import time
def start_global_scheduler(redis_address, use_valgrind=False, use_profiler=False, redirect_output=False):
    """Start a global scheduler process.

    Args:
        redis_address (str): The address of the Redis instance.
        use_valgrind (bool): True if the global scheduler should be started
            inside of valgrind. If this is True, use_profiler must be False.
        use_profiler (bool): True if the global scheduler should be started
            inside a profiler. If this is True, use_valgrind must be False.
        redirect_output (bool): True if stdout and stderr should be redirected
            to /dev/null.

    Returns:
        The subprocess.Popen object of the global scheduler process (not a
        bare process ID).

    Raises:
        Exception: If both use_valgrind and use_profiler are True.
    """
    if use_valgrind and use_profiler:
        raise Exception("Cannot use valgrind and profiler at the same time.")
    global_scheduler_executable = os.path.join(os.path.abspath(os.path.dirname(__file__)), "../core/src/global_scheduler/global_scheduler")
    command = [global_scheduler_executable, "-r", redis_address]
    # Pick the wrapper command and the pause after launching; the valgrind
    # variants are given a longer pause than the plain executable.
    if use_valgrind:
        command = ["valgrind", "--track-origins=yes", "--leak-check=full", "--show-leak-kinds=all", "--error-exitcode=1"] + command
        startup_delay = 1.0
    elif use_profiler:
        command = ["valgrind", "--tool=callgrind"] + command
        startup_delay = 1.0
    else:
        startup_delay = 0.1
    with open(os.devnull, "w") as devnull:
        output = devnull if redirect_output else None
        process = subprocess.Popen(command, stdout=output, stderr=output)
        time.sleep(startup_delay)
    return process
+238
View File
@@ -0,0 +1,238 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import os
import random
import redis
import signal
import subprocess
import sys
import threading
import time
import unittest
import global_scheduler
import photon
import plasma
from plasma.utils import random_object_id, generate_metadata, write_to_data_buffer, create_object_with_id, create_object
# Whether the global scheduler is started under valgrind; passed to
# start_global_scheduler and checked in tearDown.
USE_VALGRIND = False
PLASMA_STORE_MEMORY = 10 ** 9
ID_SIZE = 20

# These constants must match the scheduling state enum in task.h.
(TASK_STATUS_WAITING,
 TASK_STATUS_SCHEDULED,
 TASK_STATUS_RUNNING,
 TASK_STATUS_DONE) = (1, 2, 4, 8)

# These constants are an implementation detail of ray_redis_module.c, so this
# must be kept in sync with that file.
DB_CLIENT_PREFIX, TASK_PREFIX = "CL:", "TT:"
def random_task_id():
    """Return a random task ID, represented here as a photon.ObjectID."""
    id_bytes = np.random.bytes(ID_SIZE)
    return photon.ObjectID(id_bytes)
def random_function_id():
    """Return a random function ID, represented here as a photon.ObjectID."""
    id_bytes = np.random.bytes(ID_SIZE)
    return photon.ObjectID(id_bytes)
def new_port():
    """Return a random port number between 10000 and 65535, inclusive."""
    lowest, highest = 10000, 65535
    return random.randint(lowest, highest)
class TestGlobalScheduler(unittest.TestCase):
    """Integration tests for the global scheduler.

    setUp starts Redis, a global scheduler, a Plasma store, a Plasma manager
    and a local scheduler; tearDown checks that they are all still running and
    then stops them.
    """

    def setUp(self):
        """Start all processes and connect the Redis, Plasma and photon clients."""
        # Start a Redis server.
        redis_path = os.path.join(os.path.abspath(os.path.dirname(__file__)), "../../core/src/common/thirdparty/redis/src/redis-server")
        redis_module = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../core/src/common/redis_module/libray_redis_module.so")
        self.assertTrue(os.path.isfile(redis_path))
        self.assertTrue(os.path.isfile(redis_module))
        node_ip_address = "127.0.0.1"
        redis_port = new_port()
        redis_address = "{}:{}".format(node_ip_address, redis_port)
        self.redis_process = subprocess.Popen([redis_path, "--port", str(redis_port), "--loglevel", "warning", "--loadmodule", redis_module])
        time.sleep(0.1)
        # Create a Redis client.
        self.redis_client = redis.StrictRedis(host=node_ip_address, port=redis_port)
        # Start the global scheduler.
        self.p1 = global_scheduler.start_global_scheduler(redis_address, use_valgrind=USE_VALGRIND)
        # Start the Plasma store.
        plasma_store_name, self.p2 = plasma.start_plasma_store()
        # Start the Plasma manager.
        plasma_manager_name, self.p3, plasma_manager_port = plasma.start_plasma_manager(plasma_store_name, redis_address)
        self.plasma_address = "{}:{}".format(node_ip_address, plasma_manager_port)
        self.plasma_client = plasma.PlasmaClient(plasma_store_name, plasma_manager_name)
        # Start the local scheduler.
        local_scheduler_name, self.p4 = photon.start_local_scheduler(
            plasma_store_name,
            plasma_manager_name=plasma_manager_name,
            plasma_address=self.plasma_address,
            redis_address=redis_address)
        # Connect to the scheduler.
        self.photon_client = photon.PhotonClient(local_scheduler_name)

    def tearDown(self):
        """Check that every process survived the test, then stop them all."""
        # Check that the processes are still alive.
        self.assertIsNone(self.p1.poll())
        self.assertIsNone(self.p2.poll())
        self.assertIsNone(self.p3.poll())
        self.assertIsNone(self.p4.poll())
        self.assertIsNone(self.redis_process.poll())

        # Kill the global scheduler.
        if USE_VALGRIND:
            self.p1.send_signal(signal.SIGTERM)
            self.p1.wait()
            if self.p1.returncode != 0:
                os._exit(-1)
        else:
            self.p1.kill()
        self.p2.kill()
        self.p3.kill()
        self.p4.kill()
        # Kill Redis. In the event that we are using valgrind, this needs to
        # happen after we kill the global scheduler.
        self.redis_process.kill()
        # Reap every child so that no zombie processes accumulate across
        # tests. This is done only after all of them have been signalled.
        for process in (self.p1, self.p2, self.p3, self.p4, self.redis_process):
            process.wait()

    def get_plasma_manager_id(self):
        """Get the db_client_id with client_type equal to plasma_manager.

        Iterates over all the client table keys, gets the db_client_id for the
        client with client_type matching plasma_manager. The returned key
        still carries the client table prefix; callers strip it.
        TODO(atumanov): write a separate function to get all plasma manager
        client IDs.

        Returns:
            The db_client_id if one is found and otherwise None.
        """
        client_list = self.redis_client.keys("{}*".format(DB_CLIENT_PREFIX))
        for client_id in client_list:
            response = self.redis_client.hget(client_id, b"client_type")
            if response == b"plasma_manager":
                return client_id
        return None

    def test_redis_only_single_task(self):
        """
        Tests global scheduler functionality by interacting with Redis and
        checking task state transitions in Redis only.
        TODO(atumanov): implement.
        """
        # Check precondition for this test:
        # There should be three db clients, the global scheduler, the local
        # scheduler, and the plasma manager.
        self.assertEqual(len(self.redis_client.keys("{}*".format(DB_CLIENT_PREFIX))), 3)
        db_client_id = self.get_plasma_manager_id()
        self.assertIsNotNone(db_client_id)
        self.assertTrue(db_client_id.startswith(b"CL:"))
        db_client_id = db_client_id[len(b"CL:"):]  # Remove the CL: prefix.

    def test_integration_single_task(self):
        """Submit one task and wait for it to reach the scheduled state."""
        # There should be three db clients, the global scheduler, the local
        # scheduler, and the plasma manager.
        self.assertEqual(len(self.redis_client.keys("{}*".format(DB_CLIENT_PREFIX))), 3)

        num_return_vals = [0, 1, 2, 3, 5, 10]
        # There should not be anything else in Redis yet.
        self.assertEqual(len(self.redis_client.keys("*")), 3)
        # Insert the object into Redis.
        data_size = 0xf1f0
        metadata_size = 0x40
        object_dep, _, _ = create_object(self.plasma_client, data_size, metadata_size, seal=True)

        # Sleep before submitting task to photon.
        time.sleep(0.1)
        # Submit a task to Redis.
        task = photon.Task(random_function_id(), [photon.ObjectID(object_dep)], num_return_vals[0], random_task_id(), 0)
        self.photon_client.submit(task)
        time.sleep(0.1)
        # There should now be a task in Redis, and it should get assigned to
        # the local scheduler.
        # task_status starts as None so that the final check is well defined
        # even if the task never shows up in Redis.
        task_status = None
        num_retries = 10
        while num_retries > 0:
            task_entries = self.redis_client.keys("{}*".format(TASK_PREFIX))
            self.assertLessEqual(len(task_entries), 1)
            if len(task_entries) == 1:
                task_contents = self.redis_client.hgetall(task_entries[0])
                task_status = int(task_contents[b"state"])
                self.assertIn(task_status, [TASK_STATUS_WAITING, TASK_STATUS_SCHEDULED])
                if task_status == TASK_STATUS_SCHEDULED:
                    break
                print(task_status)
            print("The task has not been scheduled yet, trying again.")
            num_retries -= 1
            time.sleep(1)

        if task_status != TASK_STATUS_SCHEDULED:
            # Failed to submit and schedule a single task. Report it as a
            # test failure so that unittest runs tearDown exactly once.
            self.fail("The task was not scheduled; last observed state: {}".format(task_status))

    def integration_many_tasks_helper(self, timesync=True):
        """Submit many tasks and wait for all of them to be scheduled.

        Args:
            timesync (bool): If True, sleep for 10ms after creating each
                object and before submitting the task that depends on it.
        """
        # There should be three db clients, the global scheduler, the local
        # scheduler, and the plasma manager.
        self.assertEqual(len(self.redis_client.keys("{}*".format(DB_CLIENT_PREFIX))), 3)
        num_return_vals = [0, 1, 2, 3, 5, 10]

        # Submit a bunch of tasks to Redis.
        num_tasks = 1000
        for _ in range(num_tasks):
            # Create a new object for each task.
            data_size = np.random.randint(1 << 20)
            metadata_size = np.random.randint(1 << 10)
            object_dep, _, _ = create_object(self.plasma_client, data_size, metadata_size, seal=True)
            if timesync:
                # Give 10ms for object info handler to fire (long enough to
                # yield CPU).
                time.sleep(0.010)
            task = photon.Task(random_function_id(), [photon.ObjectID(object_dep)], num_return_vals[0], random_task_id(), 0)
            self.photon_client.submit(task)
        # Check that there are the correct number of tasks in Redis and that
        # they all get assigned to the local scheduler.
        num_retries = 10
        num_tasks_done = 0
        while num_retries > 0:
            task_entries = self.redis_client.keys("{}*".format(TASK_PREFIX))
            self.assertLessEqual(len(task_entries), num_tasks)
            # First, check if all tasks made it to Redis.
            if len(task_entries) == num_tasks:
                task_contents = [self.redis_client.hgetall(entry) for entry in task_entries]
                task_statuses = [int(contents[b"state"]) for contents in task_contents]
                self.assertTrue(all(status in [TASK_STATUS_WAITING, TASK_STATUS_SCHEDULED] for status in task_statuses))
                num_tasks_done = task_statuses.count(TASK_STATUS_SCHEDULED)
                num_tasks_waiting = task_statuses.count(TASK_STATUS_WAITING)
                print("tasks in Redis = {}, tasks waiting = {}, tasks scheduled = {}, retries left = {}"
                      .format(len(task_entries), num_tasks_waiting, num_tasks_done, num_retries))
                if all(status == TASK_STATUS_SCHEDULED for status in task_statuses):
                    # We're done, so pass.
                    break
            num_retries -= 1
            time.sleep(0.1)

        # At least one of the tasks failed to schedule if the counts differ.
        # Reported as a test failure so that unittest runs tearDown exactly
        # once.
        self.assertEqual(num_tasks_done, num_tasks,
                         "Only {} of {} tasks were scheduled.".format(num_tasks_done, num_tasks))

    def test_integration_many_tasks_handler_sync(self):
        self.integration_many_tasks_helper(timesync=True)

    def test_integration_many_tasks(self):
        # More realistic case: should handle out of order object and task
        # notifications.
        self.integration_many_tasks_helper(timesync=False)
if __name__ == "__main__":
    # A trailing "valgrind" argument runs the global scheduler under valgrind.
    # Only that argument is removed from sys.argv, so that anything else
    # (for example a test name) still reaches unittest's own argument parser.
    if len(sys.argv) > 1 and sys.argv[-1] == "valgrind":
        sys.argv.pop()
        USE_VALGRIND = True
        print("Using valgrind for tests")
    unittest.main(verbosity=2)
+29
View File
@@ -0,0 +1,29 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# See https://github.com/ray-project/ray/issues/131.
helpful_message = """
If you are using Anaconda, try fixing this problem by running:
conda install libgcc
"""

try:
    from core.src.numbuf.libnumbuf import *
except ImportError as e:
    # When the import failure mentions GLIBCXX, attach the hint above to the
    # exception. The exception is re-raised in every case.
    if hasattr(e, "msg") and isinstance(e.msg, str) and "GLIBCXX" in e.msg:
        # This code path should be taken with Python 3.
        e.msg = e.msg + helpful_message
    elif (hasattr(e, "message") and isinstance(e.message, str)
          and "GLIBCXX" in e.message):
        # This code path should be taken with Python 2.
        current_args = getattr(e, "args", ())
        single_string = (isinstance(current_args, tuple)
                         and len(current_args) == 1
                         and isinstance(current_args[0], str))
        if single_string:
            # Extend the one existing message in place.
            e.args = (current_args[0] + helpful_message,)
        else:
            # Otherwise append the hint as an additional argument.
            if not isinstance(current_args, tuple):
                current_args = (current_args,)
            e.args = current_args + (helpful_message,)
    raise
+6
View File
@@ -0,0 +1,6 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from core.src.photon.libphoton import *
from .photon_services import *
View File
+64
View File
@@ -0,0 +1,64 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import random
import subprocess
import time
def random_name():
    """Return a random integer between 0 and 99999999 as a decimal string."""
    number = random.randint(0, 99999999)
    return str(number)
def start_local_scheduler(plasma_store_name, plasma_manager_name=None, plasma_address=None, node_ip_address="127.0.0.1", redis_address=None, use_valgrind=False, use_profiler=False, redirect_output=False):
    """Start a local scheduler process.

    Args:
        plasma_store_name (str): The name of the plasma store socket to
            connect to.
        plasma_manager_name (str): The name of the plasma manager to connect
            to. This does not need to be provided, but if it is, then the
            Redis address must be provided as well.
        plasma_address (str): The address of the plasma manager to connect to.
            This is only used by the global scheduler to figure out which
            plasma managers are connected to which local schedulers.
        node_ip_address (str): The address of the node that this local
            scheduler is running on.
        redis_address (str): The address of the Redis instance to connect to.
            If this is not provided, then the local scheduler will not connect
            to Redis.
        use_valgrind (bool): True if the local scheduler should be started
            inside of valgrind. If this is True, use_profiler must be False.
        use_profiler (bool): True if the local scheduler should be started
            inside a profiler. If this is True, use_valgrind must be False.
        redirect_output (bool): True if stdout and stderr should be redirected
            to /dev/null.

    Returns:
        A tuple of the name of the local scheduler socket and the
        subprocess.Popen object of the local scheduler process (not a bare
        process ID).

    Raises:
        Exception: If exactly one of plasma_manager_name and redis_address is
            provided, or if both use_valgrind and use_profiler are True.
    """
    if (plasma_manager_name is None) != (redis_address is None):
        raise Exception("If one of the plasma_manager_name and the redis_address is provided, then both must be provided.")
    if use_valgrind and use_profiler:
        raise Exception("Cannot use valgrind and profiler at the same time.")
    local_scheduler_executable = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../core/src/photon/photon_scheduler")
    local_scheduler_name = "/tmp/scheduler{}".format(random_name())
    command = [local_scheduler_executable, "-s", local_scheduler_name, "-p", plasma_store_name, "-h", node_ip_address]
    if plasma_manager_name is not None:
        command += ["-m", plasma_manager_name]
    if redis_address is not None:
        command += ["-r", redis_address]
    if plasma_address is not None:
        command += ["-a", plasma_address]
    # Pick the wrapper command and the pause after launching; the valgrind
    # variants are given a longer pause than the plain executable.
    if use_valgrind:
        command = ["valgrind", "--track-origins=yes", "--leak-check=full", "--show-leak-kinds=all", "--error-exitcode=1"] + command
        startup_delay = 1.0
    elif use_profiler:
        command = ["valgrind", "--tool=callgrind"] + command
        startup_delay = 1.0
    else:
        startup_delay = 0.1
    with open(os.devnull, "w") as devnull:
        output = devnull if redirect_output else None
        process = subprocess.Popen(command, stdout=output, stderr=output)
        time.sleep(startup_delay)
    return local_scheduler_name, process
+149
View File
@@ -0,0 +1,149 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import os
import random
import signal
import subprocess
import sys
import threading
import time
import unittest
import photon
import plasma
# When True, setUp starts the local scheduler inside valgrind and tearDown
# stops it with SIGTERM and checks its exit code. The __main__ block at the
# bottom of this file sets it when "valgrind" is given on the command line.
USE_VALGRIND = False
# Number of random bytes used to build each ID in the random_*_id helpers.
ID_SIZE = 20
def random_object_id():
    """Return a photon.ObjectID built from ID_SIZE random bytes."""
    raw_id = np.random.bytes(ID_SIZE)
    return photon.ObjectID(raw_id)
def random_task_id():
    """Return a random ID for a task, represented as a photon.ObjectID."""
    raw_id = np.random.bytes(ID_SIZE)
    return photon.ObjectID(raw_id)
def random_function_id():
    """Return a random ID for a function, represented as a photon.ObjectID."""
    raw_id = np.random.bytes(ID_SIZE)
    return photon.ObjectID(raw_id)
class TestPhotonClient(unittest.TestCase):
    """Exercise a PhotonClient against a local scheduler and a plasma store.

    Every test gets its own plasma store process (self.p1) and local scheduler
    process (self.p2); both are started in setUp and stopped in tearDown.
    """

    def setUp(self):
        """Start the plasma store and local scheduler and connect to both."""
        # Start Plasma store.
        plasma_store_name, self.p1 = plasma.start_plasma_store()
        self.plasma_client = plasma.PlasmaClient(plasma_store_name)
        # Start a local scheduler.
        scheduler_name, self.p2 = photon.start_local_scheduler(
            plasma_store_name, use_valgrind=USE_VALGRIND)
        # Connect to the scheduler.
        self.photon_client = photon.PhotonClient(scheduler_name)

    def tearDown(self):
        """Stop both processes, then fail if either one had already exited.

        The liveness of both processes is recorded first and asserted last, so
        that a process which is still running is always stopped even when the
        other one has died during the test.
        """
        store_status = self.p1.poll()
        scheduler_status = self.p2.poll()
        # Kill Plasma.
        if store_status is None:
            self.p1.kill()
        # Kill the local scheduler.
        if scheduler_status is None:
            if USE_VALGRIND:
                # Terminate with SIGTERM and wait so that the exit code can be
                # inspected; a nonzero code aborts the whole test run.
                self.p2.send_signal(signal.SIGTERM)
                self.p2.wait()
                if self.p2.returncode != 0:
                    os._exit(-1)
            else:
                self.p2.kill()
        # Check that the processes were still alive when the test finished.
        self.assertEqual(store_status, None)
        self.assertEqual(scheduler_status, None)

    def test_submit_and_get_task(self):
        # Submit tasks with many argument lists and return counts and check
        # that get_task hands back matching function IDs, arguments and
        # numbers of return values.
        function_id = random_function_id()
        object_ids = [random_object_id() for i in range(256)]
        # Create and seal the objects in the object store so that we can
        # schedule all of the subsequent tasks.
        for object_id in object_ids:
            self.plasma_client.create(object_id.id(), 0)
            self.plasma_client.seal(object_id.id())
        # Define some arguments to use for the tasks.
        args_list = [
            [],
            #{},
            #(),
            1 * [1],
            10 * [1],
            100 * [1],
            1000 * [1],
            1 * ["a"],
            10 * ["a"],
            100 * ["a"],
            1000 * ["a"],
            [1, 1.3, 1 << 100, "hi", u"hi", [1, 2]],
            object_ids[:1],
            object_ids[:2],
            object_ids[:3],
            object_ids[:4],
            object_ids[:5],
            object_ids[:10],
            object_ids[:100],
            object_ids[:256],
            [1, object_ids[0]],
            [object_ids[0], "a"],
            [1, object_ids[0], "a"],
            [object_ids[0], 1, object_ids[1], "a"],
            object_ids[:3] + [1, "hi", 2.3] + object_ids[:5],
            object_ids + 100 * ["a"] + object_ids
        ]
        return_counts = [0, 1, 2, 3, 5, 10, 100]

        # Submit the tasks one at a time and check each one right away.
        for args in args_list:
            for num_return_vals in return_counts:
                task = photon.Task(function_id, args, num_return_vals,
                                   random_task_id(), 0)
                # Submit a task.
                self.photon_client.submit(task)
                # Get the task.
                new_task = self.photon_client.get_task()
                self.assertEqual(task.function_id().id(),
                                 new_task.function_id().id())
                retrieved_args = new_task.arguments()
                returns = new_task.returns()
                self.assertEqual(len(args), len(retrieved_args))
                self.assertEqual(num_return_vals, len(returns))
                # The lengths were just checked to be equal, so zip pairs up
                # every argument with its retrieved counterpart.
                for original, retrieved in zip(args, retrieved_args):
                    if isinstance(original, photon.ObjectID):
                        self.assertEqual(original.id(), retrieved.id())
                    else:
                        self.assertEqual(original, retrieved)

        # Submit all of the tasks.
        for args in args_list:
            for num_return_vals in return_counts:
                task = photon.Task(function_id, args, num_return_vals,
                                   random_task_id(), 0)
                self.photon_client.submit(task)
        # Get all of the tasks.
        for _ in range(len(args_list) * len(return_counts)):
            self.photon_client.get_task()

    def test_scheduling_when_objects_ready(self):
        # A task whose argument is not yet in the object store should be
        # handed out once that object has been created and sealed.
        # Create a task and submit it.
        object_id = random_object_id()
        task = photon.Task(random_function_id(), [object_id], 0,
                           random_task_id(), 0)
        self.photon_client.submit(task)

        # Launch a thread to get the task.
        def get_task():
            self.photon_client.get_task()
        t = threading.Thread(target=get_task)
        t.start()
        # Sleep to give the thread time to call get_task.
        time.sleep(0.1)
        # Create and seal the object ID in the object store. This should
        # trigger a scheduling event.
        self.plasma_client.create(object_id.id(), 0)
        self.plasma_client.seal(object_id.id())
        # Wait until the thread finishes so that we know the task was
        # scheduled.
        t.join()
if __name__ == "__main__":
    # Only consume the trailing argument when it is our own "valgrind" switch.
    # Any other argument (for example a test name) is left in sys.argv so that
    # unittest's own argument parser still sees it.
    if len(sys.argv) > 1 and sys.argv[-1] == "valgrind":
        sys.argv.pop()
        USE_VALGRIND = True
        print("Using valgrind for tests")
    unittest.main(verbosity=2)
View File
+5
View File
@@ -0,0 +1,5 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from plasma.plasma import *
+350
View File
@@ -0,0 +1,350 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import random
import subprocess
import sys
import time
import core.src.plasma.libplasma as libplasma
from core.src.plasma.libplasma import plasma_object_exists_error
from core.src.plasma.libplasma import plasma_out_of_memory_error
# NOTE(review): not referenced in the visible part of this module; presumably
# the length in bytes of a plasma object ID -- confirm.
PLASMA_ID_SIZE = 20
# Default value of the timeout argument of PlasmaClient.wait, whose docstring
# gives the unit as milliseconds.
PLASMA_WAIT_TIMEOUT = 2 ** 30
class PlasmaBuffer(object):
    """Wrapper around a buffer handed out by a PlasmaClient.

    A dedicated class is used instead of the raw buffer so that a destructor
    can tell Plasma when the object is no longer in use, which allows the
    memory backing it in the Plasma store to potentially be freed.

    Attributes:
        buffer (buffer): A buffer containing an object in the Plasma store.
        plasma_id (PlasmaID): The ID of the object in the buffer.
        plasma_client (PlasmaClient): The PlasmaClient that we use to
            communicate with the store and manager.
    """

    def __init__(self, buff, plasma_id, plasma_client):
        """Store the wrapped buffer, its object ID and the owning client."""
        self.plasma_client = plasma_client
        self.plasma_id = plasma_id
        self.buffer = buff

    def __del__(self):
        """Release the object in Plasma unless the client was shut down."""
        if not self.plasma_client.alive:
            return
        libplasma.release(self.plasma_client.conn, self.plasma_id)

    def __getitem__(self, index):
        """Return the element at index, as a one-character str on Python 3."""
        # Slicing is deliberately not supported: a slice may be backed by the
        # same memory in the object store, and that memory can become
        # inaccessible once the original plasma buffer goes out of scope.
        index_is_slice = isinstance(index, slice)
        assert not index_is_slice
        item = self.buffer[index]
        if not index_is_slice and sys.version_info >= (3, 0):
            return chr(item)
        return item

    def __setitem__(self, index, value):
        """Store value at index; on Python 3 value is converted with ord().

        This should fail because the buffer should be read only.
        """
        # Slicing is deliberately not supported, for the same reason as in
        # __getitem__.
        index_is_slice = isinstance(index, slice)
        assert not index_is_slice
        if not index_is_slice and sys.version_info >= (3, 0):
            value = ord(value)
        self.buffer[index] = value

    def __len__(self):
        """Return the number of elements in the wrapped buffer."""
        return len(self.buffer)
def buffers_equal(buff1, buff2):
    """Return whether two buffers have equal contents.

    Either argument may be a PlasmaBuffer, in which case the buffer it wraps
    is compared. This helper is meant for the tests only: comparing by slicing
    is much faster, but slicing of PlasmaBuffer objects is not exposed because
    it currently is not safe.
    """
    unwrapped = [
        buff.buffer if isinstance(buff, PlasmaBuffer) else buff
        for buff in (buff1, buff2)
    ]
    first, second = unwrapped
    return first[:] == second[:]
class PlasmaClient(object):
    """Client for talking to a plasma store and, optionally, a plasma manager.

    Through the client one can ask the store to allocate a new buffer, seal a
    buffer and get a buffer. Buffers are referred to by object IDs, which are
    strings.
    """

    def __init__(self, store_socket_name, manager_socket_name=None, release_delay=64):
        """Connect to the plasma store (and manager, if one is given).

        Args:
            store_socket_name (str): Name of the socket the plasma store is
                listening at.
            manager_socket_name (str): Name of the socket the plasma manager is
                listening at. When None, an empty name is passed to
                libplasma.connect instead.
            release_delay (int): Forwarded unchanged to libplasma.connect.
        """
        self.alive = True
        manager_socket = "" if manager_socket_name is None else manager_socket_name
        self.conn = libplasma.connect(store_socket_name, manager_socket, release_delay)

    def shutdown(self):
        """Disconnect the client so that it stops sending messages.

        After the Plasma store and Plasma manager this client is connected to
        have been killed, calling this keeps the client from trying to send
        messages to the killed processes.
        """
        if self.alive:
            libplasma.disconnect(self.conn)
        self.alive = False

    def create(self, object_id, size, metadata=None):
        """Create a new buffer in the PlasmaStore for a particular object ID.

        The returned buffer is mutable until seal is called.

        Args:
            object_id (str): A string used to identify an object.
            size (int): The size in bytes of the created buffer.
            metadata (buffer): An optional buffer encoding whatever metadata
                the user wishes to encode.

        Raises:
            plasma_object_exists_error: This exception is raised if the object
                could not be created because there already is an object with
                the same ID in the plasma store.
            plasma_out_of_memory_error: This exception is raised if the object
                could not be created because the plasma store is unable to
                evict enough objects to create room for it.
        """
        # An absent metadata argument is sent as an empty bytearray.
        if metadata is None:
            metadata = bytearray(b"")
        raw_buffer = libplasma.create(self.conn, object_id, size, metadata)
        return PlasmaBuffer(raw_buffer, object_id, self)

    def get(self, object_id):
        """Return the data buffer of the object with the given ID.

        If the object has not been sealed yet, this call will block. The
        retrieved buffer is immutable.

        Args:
            object_id (str): A string used to identify an object.
        """
        buffers = libplasma.get(self.conn, object_id)
        return PlasmaBuffer(buffers[0], object_id, self)

    def get_metadata(self, object_id):
        """Return the metadata buffer of the object with the given ID.

        If the object has not been sealed yet, this call will block until the
        object has been sealed. The retrieved buffer is immutable.

        Args:
            object_id (str): A string used to identify an object.
        """
        buffers = libplasma.get(self.conn, object_id)
        return PlasmaBuffer(buffers[1], object_id, self)

    def contains(self, object_id):
        """Check if the object is present and has been sealed in the PlasmaStore.

        Args:
            object_id (str): A string used to identify an object.
        """
        is_present = libplasma.contains(self.conn, object_id)
        return is_present

    def hash(self, object_id):
        """Compute the hash of an object in the object store.

        Args:
            object_id (str): A string used to identify an object.

        Returns:
            A digest string object's SHA256 hash. If the object isn't in the
            object store, the string will have length zero.
        """
        digest = libplasma.hash(self.conn, object_id)
        return digest

    def seal(self, object_id):
        """Seal the buffer in the PlasmaStore for a particular object ID.

        Once a buffer has been sealed, the buffer is immutable and can only be
        accessed through get.

        Args:
            object_id (str): A string used to identify an object.
        """
        libplasma.seal(self.conn, object_id)

    def delete(self, object_id):
        """Delete the buffer in the PlasmaStore for a particular object ID.

        Once a buffer has been deleted, the buffer is no longer accessible.

        Args:
            object_id (str): A string used to identify an object.
        """
        libplasma.delete(self.conn, object_id)

    def evict(self, num_bytes):
        """Evict objects in order to recover some bytes.

        Recover at least num_bytes bytes if possible.

        Args:
            num_bytes (int): The number of bytes to attempt to recover.
        """
        evicted = libplasma.evict(self.conn, num_bytes)
        return evicted

    def transfer(self, addr, port, object_id):
        """Transfer the local object with the given ID to another plasma instance.

        Args:
            addr (str): IPv4 address of the plasma instance the object is sent
                to.
            port (int): Port number of the plasma instance the object is sent
                to.
            object_id (str): A string used to identify an object.
        """
        outcome = libplasma.transfer(self.conn, object_id, addr, port)
        return outcome

    def fetch(self, object_ids):
        """Fetch the objects with the given IDs from other plasma managers.

        Args:
            object_ids (List[str]): A list of strings used to identify the
                objects.
        """
        outcome = libplasma.fetch(self.conn, object_ids)
        return outcome

    def wait(self, object_ids, timeout=PLASMA_WAIT_TIMEOUT, num_returns=1):
        """Wait until num_returns objects in object_ids are ready.

        Args:
            object_ids (List[str]): List of object IDs to wait for.
            timeout (int): Return to the caller after timeout milliseconds.
            num_returns (int): We are waiting for this number of objects to be
                ready.

        Returns:
            ready_ids, waiting_ids (List[str], List[str]): List of object IDs
                that are ready and list of object IDs we might still wait on
                respectively.
        """
        outcome = libplasma.wait(self.conn, object_ids, timeout, num_returns)
        ready, still_waiting = outcome
        # The second element is converted to a list before it is returned.
        return ready, list(still_waiting)

    def subscribe(self):
        """Subscribe to notifications about sealed objects.

        The value returned by libplasma.subscribe is kept on the client as
        notification_fd for use by get_next_notification.
        """
        notification_fd = libplasma.subscribe(self.conn)
        self.notification_fd = notification_fd

    def get_next_notification(self):
        """Get the next notification from the notification socket."""
        notification = libplasma.receive_notification(self.notification_fd)
        return notification
# Default for the plasma_store_memory argument of start_plasma_store, which
# passes it to the store executable through the "-m" flag.
# NOTE(review): presumably a size in bytes (10 ** 9, about 1 GB) -- confirm.
DEFAULT_PLASMA_STORE_MEMORY = 10 ** 9
def random_name():
    """Return a random integer between 0 and 99999999 (inclusive) as a string."""
    suffix = random.randint(0, 99999999)
    return str(suffix)
def start_plasma_store(plasma_store_memory=DEFAULT_PLASMA_STORE_MEMORY, use_valgrind=False, use_profiler=False, redirect_output=False):
    """Launch a plasma store as a child process.

    Args:
        plasma_store_memory (int): Value passed to the store executable through
            its "-m" flag.
        use_valgrind (bool): True if the plasma store should be started inside
            of valgrind. If this is True, use_profiler must be False.
        use_profiler (bool): True if the plasma store should be started inside
            a profiler. If this is True, use_valgrind must be False.
        redirect_output (bool): True if stdout and stderr should be redirected
            to /dev/null.

    Return:
        A tuple of the name of the plasma store socket and the
        subprocess.Popen object of the plasma store process.
    """
    if use_valgrind and use_profiler:
        raise Exception("Cannot use valgrind and profiler at the same time.")
    module_dir = os.path.abspath(os.path.dirname(__file__))
    plasma_store_executable = os.path.join(module_dir, "../core/src/plasma/plasma_store")
    plasma_store_name = "/tmp/plasma_store{}".format(random_name())
    command = [plasma_store_executable, "-s", plasma_store_name, "-m", str(plasma_store_memory)]
    # Pick the wrapper command and how long to wait after launching; the
    # valgrind-based modes are given a longer pause than a plain start.
    if use_valgrind:
        wrapper = ["valgrind", "--track-origins=yes", "--leak-check=full", "--show-leak-kinds=all", "--error-exitcode=1"]
        startup_pause = 1.0
    elif use_profiler:
        wrapper = ["valgrind", "--tool=callgrind"]
        startup_pause = 1.0
    else:
        wrapper = []
        startup_pause = 0.1
    with open(os.devnull, "w") as devnull:
        sink = devnull if redirect_output else None
        store_process = subprocess.Popen(wrapper + command, stdout=sink, stderr=sink)
        time.sleep(startup_pause)
    return plasma_store_name, store_process
def start_plasma_manager(store_name, redis_address, node_ip_address="127.0.0.1", num_retries=20, use_valgrind=False, run_profiler=False, redirect_output=False):
    """Start a plasma manager and return the port it listens on.

    Each attempt picks a new random port. An attempt counts as failed when the
    process has already exited 0.1 seconds after it was launched.

    Args:
        store_name (str): The name of the plasma store socket.
        redis_address (str): The address of the Redis server.
        node_ip_address (str): The IP address of the node.
        num_retries (int): Maximum number of start attempts.
        use_valgrind (bool): True if the Plasma manager should be started
            inside of valgrind and False otherwise.
        run_profiler (bool): True if the Plasma manager should be started
            under valgrind's callgrind tool. Only consulted when use_valgrind
            is False.
        redirect_output (bool): True if stdout and stderr should be redirected
            to /dev/null.

    Returns:
        A tuple of the Plasma manager socket name, the subprocess.Popen object
        of the Plasma manager process, and the port that the manager is
        listening on.

    Raises:
        Exception: An exception is raised if the manager could not be started.
    """
    plasma_manager_executable = os.path.join(os.path.abspath(os.path.dirname(__file__)), "../core/src/plasma/plasma_manager")
    plasma_manager_name = "/tmp/plasma_manager{}".format(random_name())
    # The wrapper does not depend on the port, so choose it once up front.
    if use_valgrind:
        wrapper = ["valgrind", "--track-origins=yes", "--leak-check=full", "--show-leak-kinds=all", "--error-exitcode=1"]
    elif run_profiler:
        wrapper = ["valgrind", "--tool=callgrind"]
    else:
        wrapper = []
    attempt = 0
    while attempt < num_retries:
        if attempt > 0:
            print("Plasma manager failed to start, retrying now.")
        port = random.randint(10000, 65535)
        command = [plasma_manager_executable,
                   "-s", store_name,
                   "-m", plasma_manager_name,
                   "-h", node_ip_address,
                   "-p", str(port),
                   "-r", redis_address]
        with open(os.devnull, "w") as devnull:
            sink = devnull if redirect_output else None
            process = subprocess.Popen(wrapper + command, stdout=sink, stderr=sink)
        # This sleep is critical. If the plasma_manager fails to start because
        # the port is already in use, then we need it to fail within 0.1
        # seconds.
        time.sleep(0.1)
        # A poll() result of None means the process is still running.
        if process.poll() is None:
            return plasma_manager_name, process, port
        attempt += 1
    raise Exception("Couldn't start plasma manager.")
+778
View File
@@ -0,0 +1,778 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import os
import random
import signal
import socket
import struct
import subprocess
import sys
import tempfile
import threading
import time
import unittest
import plasma
from plasma.utils import random_object_id, generate_metadata, write_to_data_buffer, create_object_with_id, create_object
# When True, setUp starts the plasma store inside valgrind and tearDown stops
# it with SIGTERM and checks its exit code instead of killing it.
USE_VALGRIND = False
# NOTE(review): not referenced in the visible tests; presumably the memory
# size the store is expected to be started with -- confirm.
PLASMA_STORE_MEMORY = 1000000000
def assert_get_object_equal(unit_test, client1, client2, object_id, memory_buffer=None, metadata=None):
    """Assert that two clients see the same data and metadata for an object.

    Args:
        unit_test: The test case whose assert methods are used.
        client1: First client; must provide get and get_metadata.
        client2: Second client; must provide get and get_metadata.
        object_id: ID of the object to fetch through both clients.
        memory_buffer: Optional reference data buffer that the data seen by
            client1 must also equal.
        metadata: Optional reference metadata that the metadata seen by
            client1 must also equal.
    """
    data1 = client1.get(object_id)
    data2 = client2.get(object_id)
    meta1 = client1.get_metadata(object_id)
    meta2 = client2.get_metadata(object_id)
    # Compare the lengths first, then the contents, of data and metadata.
    unit_test.assertEqual(len(data1), len(data2))
    unit_test.assertEqual(len(meta1), len(meta2))
    same_data = plasma.buffers_equal(data1, data2)
    unit_test.assertTrue(same_data)
    same_metadata = plasma.buffers_equal(meta1, meta2)
    unit_test.assertTrue(same_metadata)
    # The optional references are compared against what client1 returned.
    if memory_buffer is not None:
        matches_reference = plasma.buffers_equal(memory_buffer, data1)
        unit_test.assertTrue(matches_reference)
    if metadata is not None:
        matches_reference = plasma.buffers_equal(metadata, meta1)
        unit_test.assertTrue(matches_reference)
class TestPlasmaClient(unittest.TestCase):
def setUp(self):
    """Start a plasma store and connect two clients to it."""
    store_socket_name, store_process = plasma.start_plasma_store(use_valgrind=USE_VALGRIND)
    self.p = store_process
    # Main client, connected without a manager and with a release delay of 64.
    self.plasma_client = plasma.PlasmaClient(store_socket_name, None, 64)
    # Second client with a release delay of 0, for the eviction test.
    self.plasma_client2 = plasma.PlasmaClient(store_socket_name, None, 0)
def tearDown(self):
    """Check that the plasma store survived the test, then stop it."""
    # The store must not have exited on its own.
    self.assertEqual(self.p.poll(), None)
    if not USE_VALGRIND:
        self.p.kill()
        return
    # Under valgrind, terminate with SIGTERM and wait so that the exit code
    # can be inspected; a nonzero code aborts the whole test run.
    self.p.send_signal(signal.SIGTERM)
    self.p.wait()
    if self.p.returncode != 0:
        os._exit(-1)
def test_create(self):
    # Write a known pattern into a new object, seal it, and read it back.
    object_id = random_object_id()
    length = 50
    pattern = [chr(i % 256) for i in range(length)]
    # Create a new buffer and write to it.
    memory_buffer = self.plasma_client.create(object_id, length)
    for position, char in enumerate(pattern):
        memory_buffer[position] = char
    # Seal the object.
    self.plasma_client.seal(object_id)
    # Get the object and compare it with what was written.
    memory_buffer = self.plasma_client.get(object_id)
    for position, char in enumerate(pattern):
        self.assertEqual(memory_buffer[position], char)
def test_create_with_metadata(self):
    # For every length below 1000, store an object together with metadata
    # and check that both the data and the metadata can be read back.
    client = self.plasma_client
    for length in range(1000):
        # Create an object id string.
        object_id = random_object_id()
        # Create a random metadata string.
        metadata = generate_metadata(length)
        pattern = [chr(i % 256) for i in range(length)]
        # Create a new buffer and write to it.
        memory_buffer = client.create(object_id, length, metadata)
        for position, char in enumerate(pattern):
            memory_buffer[position] = char
        # Seal the object.
        client.seal(object_id)
        # Get the object.
        memory_buffer = client.get(object_id)
        for position, char in enumerate(pattern):
            self.assertEqual(memory_buffer[position], char)
        # Get the metadata.
        metadata_buffer = client.get_metadata(object_id)
        self.assertEqual(len(metadata), len(metadata_buffer))
        for position, value in enumerate(metadata):
            self.assertEqual(chr(value), metadata_buffer[position])
def test_create_existing(self):
    # This test is partially used to test the code path in which we create
    # an object with an ID that already exists.
    length = 100
    for _ in range(1000):
        object_id = random_object_id()
        self.plasma_client.create(object_id, length, generate_metadata(length))
        # Creating the same ID again must raise plasma_object_exists_error.
        duplicate_rejected = False
        try:
            self.plasma_client.create(object_id, length, generate_metadata(length))
        except plasma.plasma_object_exists_error:
            duplicate_rejected = True
        self.assertTrue(duplicate_rejected)
def test_store_full(self):
    # The store is started with 1GB, so make sure that create throws an
    # exception when it is full.
    def assert_create_raises_plasma_full(unit_test, size):
        # Split the total size at a random point across the two size
        # arguments of create_object.
        partial_size = np.random.randint(size)
        out_of_memory = False
        try:
            _, memory_buffer, _ = create_object(unit_test.plasma_client, partial_size, size - partial_size)
        except plasma.plasma_out_of_memory_error:
            out_of_memory = True
        # If the create above did not throw, fail.
        unit_test.assertTrue(out_of_memory)

    # Create a list to keep some of the buffers in scope.
    memory_buffers = []
    _, memory_buffer, _ = create_object(self.plasma_client, 9 * 10 ** 8, 0)
    memory_buffers.append(memory_buffer)
    # Remaining space is 10 ** 8. Make sure that we can't create an object of
    # size 10 ** 8 + 1, but we can create one of size 10 ** 8.
    assert_create_raises_plasma_full(self, 10 ** 8 + 1)
    for _ in range(2):
        _, memory_buffer, _ = create_object(self.plasma_client, 10 ** 8, 0)
        del memory_buffer
    assert_create_raises_plasma_full(self, 10 ** 8 + 1)
    # Keep filling nine tenths of what is left, one power of ten at a time.
    for exponent in range(7, -1, -1):
        _, memory_buffer, _ = create_object(self.plasma_client, 9 * 10 ** exponent, 0)
        memory_buffers.append(memory_buffer)
        # Remaining space is 10 ** exponent.
        assert_create_raises_plasma_full(self, 10 ** exponent + 1)
    _, memory_buffer, _ = create_object(self.plasma_client, 1, 0)
def test_contains(self):
    # contains must be False for unknown IDs and True once an object has
    # been created and sealed.
    client = self.plasma_client
    fake_object_ids = [random_object_id() for _ in range(100)]
    real_object_ids = [random_object_id() for _ in range(100)]
    for real_id in real_object_ids:
        self.assertFalse(client.contains(real_id))
        memory_buffer = client.create(real_id, 100)
        client.seal(real_id)
        self.assertTrue(client.contains(real_id))
    # IDs that were never created stay absent...
    for fake_id in fake_object_ids:
        self.assertFalse(client.contains(fake_id))
    # ...and the sealed ones are all still present.
    for real_id in real_object_ids:
        self.assertTrue(client.contains(real_id))
def test_hash(self):
    client = self.plasma_client
    # Ask for the hash of an object that doesn't exist.
    object_id1 = random_object_id()
    client.hash(object_id1)

    length = 1000
    base_pattern = [chr(i % 256) for i in range(length)]
    shifted_pattern = [chr((i + 1) % 256) for i in range(length)]

    # Create a random object, and check that the hash function always
    # returns the same value.
    metadata = generate_metadata(length)
    memory_buffer = client.create(object_id1, length, metadata)
    for position, char in enumerate(base_pattern):
        memory_buffer[position] = char
    client.seal(object_id1)
    self.assertEqual(client.hash(object_id1), client.hash(object_id1))

    # Create a second object with the same value as the first, and check
    # that their hashes are equal.
    object_id2 = random_object_id()
    memory_buffer = client.create(object_id2, length, metadata)
    for position, char in enumerate(base_pattern):
        memory_buffer[position] = char
    client.seal(object_id2)
    self.assertEqual(client.hash(object_id1), client.hash(object_id2))

    # Create a third object with a different value from the first two, and
    # check that its hash is different.
    object_id3 = random_object_id()
    metadata = generate_metadata(length)
    memory_buffer = client.create(object_id3, length, metadata)
    for position, char in enumerate(shifted_pattern):
        memory_buffer[position] = char
    client.seal(object_id3)
    self.assertNotEqual(client.hash(object_id1), client.hash(object_id3))

    # Create a fourth object with the same value as the third, but
    # different metadata. Check that its hash differs from those of the
    # first and the third object.
    object_id4 = random_object_id()
    metadata4 = generate_metadata(length)
    memory_buffer = client.create(object_id4, length, metadata4)
    for position, char in enumerate(shifted_pattern):
        memory_buffer[position] = char
    client.seal(object_id4)
    self.assertNotEqual(client.hash(object_id1), client.hash(object_id4))
    self.assertNotEqual(client.hash(object_id3), client.hash(object_id4))
def test_many_hashes(self):
    client = self.plasma_client
    hashes = []
    length = 2 ** 10

    # Objects of the same length, each filled with one repeated byte value.
    for byte_value in range(256):
        object_id = random_object_id()
        memory_buffer = client.create(object_id, length)
        for position in range(length):
            memory_buffer[position] = chr(byte_value)
        client.seal(object_id)
        hashes.append(client.hash(object_id))

    # Objects of the same length that are all zeros except for a single
    # byte set to 1, at a different position for each object.
    for marked_position in range(length):
        object_id = random_object_id()
        memory_buffer = client.create(object_id, length)
        for position in range(length):
            memory_buffer[position] = chr(0)
        memory_buffer[marked_position] = chr(1)
        client.seal(object_id)
        hashes.append(client.hash(object_id))

    # Create objects of varying length, all with value 0.
    for object_length in range(length):
        object_id = random_object_id()
        memory_buffer = client.create(object_id, object_length)
        for position in range(object_length):
            memory_buffer[position] = chr(0)
        client.seal(object_id)
        hashes.append(client.hash(object_id))

    # Check that all hashes were unique.
    expected_count = 256 + length + length
    self.assertEqual(len(set(hashes)), expected_count)
# def test_individual_delete(self):
# length = 100
# # Create an object id string.
# object_id = random_object_id()
# # Create a random metadata string.
# metadata = generate_metadata(100)
# # Create a new buffer and write to it.
# memory_buffer = self.plasma_client.create(object_id, length, metadata)
# for i in range(length):
# memory_buffer[i] = chr(i % 256)
# # Seal the object.
# self.plasma_client.seal(object_id)
# # Check that the object is present.
# self.assertTrue(self.plasma_client.contains(object_id))
# # Delete the object.
# self.plasma_client.delete(object_id)
# # Make sure the object is no longer present.
# self.assertFalse(self.plasma_client.contains(object_id))
#
# def test_delete(self):
# # Create some objects.
# object_ids = [random_object_id() for _ in range(100)]
# for object_id in object_ids:
# length = 100
# # Create a random metadata string.
# metadata = generate_metadata(100)
# # Create a new buffer and write to it.
# memory_buffer = self.plasma_client.create(object_id, length, metadata)
# for i in range(length):
# memory_buffer[i] = chr(i % 256)
# # Seal the object.
# self.plasma_client.seal(object_id)
# # Check that the object is present.
# self.assertTrue(self.plasma_client.contains(object_id))
#
# # Delete the objects and make sure they are no longer present.
# for object_id in object_ids:
# # Delete the object.
# self.plasma_client.delete(object_id)
# # Make sure the object is no longer present.
# self.assertFalse(self.plasma_client.contains(object_id))
def test_illegal_functionality(self):
    # Create an object id string.
    object_id = random_object_id()
    # Create a new buffer.
    length = 1000
    memory_buffer = self.plasma_client.create(object_id, length)
    # Make sure we cannot access memory out of bounds.
    with self.assertRaises(Exception):
        memory_buffer[length]
    # Seal the object.
    self.plasma_client.seal(object_id)
    # This test is commented out because it currently fails.
    # # Make sure the object is read only now.
    # def illegal_assignment():
    #   memory_buffer[0] = chr(0)
    # self.assertRaises(Exception, illegal_assignment)
    # Get the object.
    memory_buffer = self.plasma_client.get(object_id)
    # Make sure the object is read only.
    with self.assertRaises(Exception):
        memory_buffer[0] = chr(0)
def test_evict(self):
    # Uses the second client, which setUp connects with a release delay
    # of 0. The order of create, seal, del and evict below is significant.
    store = self.plasma_client2

    # A single sealed and released object is evicted in full.
    first_id = random_object_id()
    first = store.create(first_id, 1000)
    store.seal(first_id)
    del first
    self.assertEqual(store.evict(1), 1000)

    # Only the third object has been sealed and released at this point, so
    # only its 998 bytes are reported.
    second_id = random_object_id()
    third_id = random_object_id()
    second = store.create(second_id, 999)
    third = store.create(third_id, 998)
    store.seal(third_id)
    del third
    self.assertEqual(store.evict(1000), 998)

    # The fourth object is released before the second one; the evictions
    # then report the fourth object's size first and the second's after.
    fourth_id = random_object_id()
    fourth = store.create(fourth_id, 997)
    store.seal(fourth_id)
    del fourth
    store.seal(second_id)
    del second
    self.assertEqual(store.evict(1), 997)
    self.assertEqual(store.evict(1), 999)

    # Asking for 2000 bytes with three released objects evicts all three.
    fifth_id = random_object_id()
    sixth_id = random_object_id()
    seventh_id = random_object_id()
    fifth = store.create(fifth_id, 996)
    sixth = store.create(sixth_id, 995)
    seventh = store.create(seventh_id, 994)
    store.seal(fifth_id)
    store.seal(sixth_id)
    store.seal(seventh_id)
    del fifth
    del sixth
    del seventh
    self.assertEqual(store.evict(2000), 996 + 995 + 994)
def test_subscribe(self):
    # Subscribe to notifications from the Plasma Store. subscribe() keeps
    # the notification socket on the client and returns nothing, so there
    # is no result to hold on to here.
    self.plasma_client.subscribe()
    for count in [1, 10, 100, 1000, 10000, 100000]:
        object_ids = [random_object_id() for _ in range(count)]
        metadata_sizes = [np.random.randint(1000) for _ in range(count)]
        data_sizes = [np.random.randint(1000) for _ in range(count)]
        # Create and seal every object with random metadata of the chosen
        # size.
        for object_id, data_size, metadata_size in zip(object_ids, data_sizes, metadata_sizes):
            self.plasma_client.create(object_id, size=data_size,
                                      metadata=bytearray(np.random.bytes(metadata_size)))
            self.plasma_client.seal(object_id)
        # Check that we received notifications for all of the objects, in
        # the order in which they were sealed.
        for object_id, data_size, metadata_size in zip(object_ids, data_sizes, metadata_sizes):
            recv_objid, recv_dsize, recv_msize = self.plasma_client.get_next_notification()
            self.assertEqual(object_id, recv_objid)
            self.assertEqual(data_size, recv_dsize)
            self.assertEqual(metadata_size, recv_msize)
def test_subscribe_deletions(self):
  """Check that eviction produces deletion notifications.

  The test first verifies creation notifications for batches of objects,
  then evicts the objects one at a time and expects a notification carrying
  the object ID with both sizes reported as -1. A final scenario checks that
  a single evict() call can produce several deletion notifications.
  """
  # Subscribe to notifications from the Plasma Store. We use plasma_client2
  # to make sure that all used objects will get evicted properly.
  sock = self.plasma_client2.subscribe()
  for i in [1, 10, 100, 1000, 10000, 100000]:
    object_ids = [random_object_id() for _ in range(i)]
    # Add 1 to the sizes to make sure we have nonzero object sizes.
    metadata_sizes = [np.random.randint(1000) + 1 for _ in range(i)]
    data_sizes = [np.random.randint(1000) + 1 for _ in range(i)]
    for j in range(i):
      x = self.plasma_client2.create(object_ids[j], size=data_sizes[j],
                                     metadata=bytearray(np.random.bytes(metadata_sizes[j])))
      self.plasma_client2.seal(object_ids[j])
      # Drop the buffer so the object is no longer held by this test.
      del x
    # Check that we received notifications for creating all of the objects.
    for j in range(i):
      recv_objid, recv_dsize, recv_msize = self.plasma_client2.get_next_notification()
      self.assertEqual(object_ids[j], recv_objid)
      self.assertEqual(data_sizes[j], recv_dsize)
      self.assertEqual(metadata_sizes[j], recv_msize)
    # Check that we receive notifications for deleting all objects, as we
    # evict them. Each evict(1) is expected to free exactly one object, in
    # creation order, and to report data size plus metadata size.
    for j in range(i):
      self.assertEqual(self.plasma_client2.evict(1), data_sizes[j] + metadata_sizes[j])
      recv_objid, recv_dsize, recv_msize = self.plasma_client2.get_next_notification()
      self.assertEqual(object_ids[j], recv_objid)
      # Deletions are signalled with -1 for both sizes.
      self.assertEqual(-1, recv_dsize)
      self.assertEqual(-1, recv_msize)
  # Test multiple deletion notifications. The first 9 object IDs have size 0,
  # and the last has a nonzero size. When Plasma evicts 1 byte, it will evict
  # all objects, so we should receive deletion notifications for each.
  # NOTE(review): unlike the loop above, no 1 is added to the random sizes
  # here, so the last object's sizes can be 0 as well -- confirm whether the
  # "nonzero" wording is meant strictly.
  num_object_ids = 10
  object_ids = [random_object_id() for _ in range(num_object_ids)]
  metadata_sizes = [0] * (num_object_ids - 1)
  data_sizes = [0] * (num_object_ids - 1)
  metadata_sizes.append(np.random.randint(1000))
  data_sizes.append(np.random.randint(1000))
  for i in range(num_object_ids):
    x = self.plasma_client2.create(object_ids[i], size=data_sizes[i],
                                   metadata=bytearray(np.random.bytes(metadata_sizes[i])))
    self.plasma_client2.seal(object_ids[i])
    del x
  # Consume the creation notifications first.
  for i in range(num_object_ids):
    recv_objid, recv_dsize, recv_msize = self.plasma_client2.get_next_notification()
    self.assertEqual(object_ids[i], recv_objid)
    self.assertEqual(data_sizes[i], recv_dsize)
    self.assertEqual(metadata_sizes[i], recv_msize)
  # Only the last object contributes bytes to the amount reported as freed.
  self.assertEqual(self.plasma_client2.evict(1), data_sizes[-1] + metadata_sizes[-1])
  for i in range(num_object_ids):
    recv_objid, recv_dsize, recv_msize = self.plasma_client2.get_next_notification()
    self.assertEqual(object_ids[i], recv_objid)
    self.assertEqual(-1, recv_dsize)
    self.assertEqual(-1, recv_msize)
class TestPlasmaManager(unittest.TestCase):
def setUp(self):
  """Start the processes needed by each test and connect two clients.

  Launches two Plasma stores, one Redis server with the Ray module loaded,
  and two Plasma managers (one per store), then connects one PlasmaClient to
  each store/manager pair.
  """
  # Start two PlasmaStores.
  store_name1, self.p2 = plasma.start_plasma_store(use_valgrind=USE_VALGRIND)
  store_name2, self.p3 = plasma.start_plasma_store(use_valgrind=USE_VALGRIND)

  # Start a Redis server. Both files are located relative to this test file
  # and must exist.
  redis_path = os.path.join(os.path.abspath(os.path.dirname(__file__)), "../../core/src/common/thirdparty/redis/src/redis-server")
  redis_module = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../core/src/common/redis_module/libray_redis_module.so")
  assert os.path.isfile(redis_path)
  assert os.path.isfile(redis_module)
  redis_port = 6379
  redis_command = [redis_path,
                   "--port", str(redis_port),
                   "--loadmodule", redis_module]
  # Discard the server's standard output.
  with open(os.devnull, "w") as devnull:
    self.redis_process = subprocess.Popen(redis_command, stdout=devnull)
  time.sleep(0.1)

  # Start two PlasmaManagers, both pointed at the Redis server above.
  redis_address = "{}:{}".format("127.0.0.1", redis_port)
  manager_name1, self.p4, self.port1 = plasma.start_plasma_manager(
      store_name1, redis_address, use_valgrind=USE_VALGRIND)
  manager_name2, self.p5, self.port2 = plasma.start_plasma_manager(
      store_name2, redis_address, use_valgrind=USE_VALGRIND)

  # Connect two PlasmaClients.
  self.client1 = plasma.PlasmaClient(store_name1, manager_name1)
  self.client2 = plasma.PlasmaClient(store_name2, manager_name2)

  # Store the processes that will be explicitly killed during tearDown so
  # that a test case can remove ones that will be killed during the test.
  # NOTE: If this specific order is changed, valgrind will fail.
  self.processes_to_kill = [self.p4, self.p5, self.p2, self.p3]
def tearDown(self):
  """Verify the helper processes survived the test, then shut them down.

  Every process in `self.processes_to_kill` must still be running. Under
  valgrind the processes are terminated with SIGTERM and a nonzero exit code
  aborts the whole test run; otherwise they are killed outright.

  Cleanup runs in a `finally` clause so that a failed liveness assertion no
  longer leaves the remaining Plasma processes and the Redis server (which
  occupies a fixed port) running for the following tests.
  """
  try:
    # Check that the processes are still alive.
    for process in self.processes_to_kill:
      self.assertEqual(process.poll(), None)
    if USE_VALGRIND:
      time.sleep(1)  # give processes opportunity to finish work
      for process in self.processes_to_kill:
        process.send_signal(signal.SIGTERM)
        process.wait()
        if process.returncode != 0:
          print("aborting due to valgrind error")
          os._exit(-1)
  finally:
    # Kill whatever Plasma store and manager processes are still running.
    # Processes already reaped above (valgrind path) or already dead are
    # skipped.
    for process in self.processes_to_kill:
      if process.poll() is None:
        process.kill()
    self.redis_process.kill()
def test_fetch(self):
  """Exercise fetch() for existing, not-yet-created and repeatedly requested objects."""
  if self.redis_process is None:
    print("Cannot test fetch without a running redis instance.")
    self.assertTrue(False)
  first, second = self.client1, self.client2

  for _ in range(10):
    # Create an object on the first client; fetching it locally must not make
    # it appear on the second client.
    object_id1, memory_buffer1, metadata1 = create_object(first, 2000, 2000)
    first.fetch([object_id1])
    self.assertEqual(first.contains(object_id1), True)
    self.assertEqual(second.contains(object_id1), False)
    # Fetch the object from the other plasma manager.
    # TODO(rkn): Right now we must wait for the object table to be updated.
    while not second.contains(object_id1):
      second.fetch([object_id1])
    # Compare the two buffers.
    assert_get_object_equal(self, first, second, object_id1,
                            memory_buffer=memory_buffer1, metadata=metadata1)

  # Test that we can call fetch on object IDs that don't exist yet.
  object_id2 = random_object_id()
  first.fetch([object_id2])
  self.assertEqual(first.contains(object_id2), False)
  memory_buffer2, metadata2 = create_object_with_id(second, object_id2, 2000, 2000)
  # The follow-up checks are currently disabled:
  # # Check that the object has been fetched.
  # self.assertEqual(self.client1.contains(object_id2), True)
  # Compare the two buffers.
  # assert_get_object_equal(self, self.client1, self.client2, object_id2,
  #                         memory_buffer=memory_buffer2, metadata=metadata2)

  # Test calling the same fetch request a bunch of times, both before and
  # after the object exists.
  object_id3 = random_object_id()
  self.assertEqual(first.contains(object_id3), False)
  self.assertEqual(second.contains(object_id3), False)
  for _ in range(10):
    first.fetch([object_id3])
    second.fetch([object_id3])
  memory_buffer3, metadata3 = create_object_with_id(first, object_id3, 2000, 2000)
  for _ in range(10):
    first.fetch([object_id3])
    second.fetch([object_id3])
  # TODO(rkn): Right now we must wait for the object table to be updated.
  while not second.contains(object_id3):
    second.fetch([object_id3])
  assert_get_object_equal(self, first, second, object_id3,
                          memory_buffer=memory_buffer3, metadata=metadata3)
def test_fetch_multiple(self):
  """Exercise fetch() with several object IDs, including missing and duplicated ones."""
  if self.redis_process is None:
    print("Cannot test fetch without a running redis instance.")
    self.assertTrue(False)
  source, target = self.client1, self.client2

  for _ in range(20):
    # Create two objects and a third fake one that doesn't exist.
    object_id1, memory_buffer1, metadata1 = create_object(source, 2000, 2000)
    missing_object_id = random_object_id()
    object_id2, memory_buffer2, metadata2 = create_object(source, 2000, 2000)
    object_ids = [object_id1, missing_object_id, object_id2]
    # Fetch the objects from the other plasma store. The second object ID
    # should timeout since it does not exist.
    # TODO(rkn): Right now we must wait for the object table to be updated.
    while not (target.contains(object_id1) and target.contains(object_id2)):
      target.fetch(object_ids)
    # Compare the buffers of the objects that do exist.
    assert_get_object_equal(self, source, target, object_id1,
                            memory_buffer=memory_buffer1, metadata=metadata1)
    assert_get_object_equal(self, source, target, object_id2,
                            memory_buffer=memory_buffer2, metadata=metadata2)
    # Fetch in the other direction. The fake object still does not exist.
    source.fetch(object_ids)
    assert_get_object_equal(self, target, source, object_id1,
                            memory_buffer=memory_buffer1, metadata=metadata1)
    assert_get_object_equal(self, target, source, object_id2,
                            memory_buffer=memory_buffer2, metadata=metadata2)

  # Check that we can call fetch with duplicated object IDs.
  object_id3 = random_object_id()
  source.fetch([object_id3, object_id3])
  object_id4, memory_buffer4, metadata4 = create_object(source, 2000, 2000)
  time.sleep(0.1)
  duplicated_ids = [object_id3, object_id3, object_id4, object_id4]
  # TODO(rkn): Right now we must wait for the object table to be updated.
  while not target.contains(object_id4):
    target.fetch(duplicated_ids)
  assert_get_object_equal(self, target, source, object_id4,
                          memory_buffer=memory_buffer4, metadata=metadata4)
def test_wait(self):
  """Exercise PlasmaClient.wait() under several conditions.

  Covers: a pure timeout, an object that is already local, a mix of sealed
  and unsealed objects, an object sealed by a timer thread while wait() is
  blocked, asking for more objects than can become ready, many consecutive
  wait calls on both clients, and wait returning as soon as the requested
  number of objects is available.
  """
  # Test timeout. The object is never created, so this call can only return
  # by timing out.
  obj_id0 = random_object_id()
  self.client1.wait([obj_id0], timeout=100, num_returns=1)
  # If we get here, the test worked.
  # Test wait if local objects available.
  obj_id1 = random_object_id()
  self.client1.create(obj_id1, 1000)
  self.client1.seal(obj_id1)
  ready, waiting = self.client1.wait([obj_id1], timeout=100, num_returns=1)
  self.assertEqual(set(ready), set([obj_id1]))
  self.assertEqual(waiting, [])
  # Test wait if only one object available and only one object waited for.
  obj_id2 = random_object_id()
  self.client1.create(obj_id2, 1000)
  # Don't seal.
  ready, waiting = self.client1.wait([obj_id2, obj_id1], timeout=100, num_returns=1)
  self.assertEqual(set(ready), set([obj_id1]))
  self.assertEqual(set(waiting), set([obj_id2]))
  # Test wait if object is sealed later.
  obj_id3 = random_object_id()
  def finish():
    # Runs on the timer thread: create and seal obj_id3 through the second
    # client while the wait below is in progress.
    self.client2.create(obj_id3, 1000)
    self.client2.seal(obj_id3)
  t = threading.Timer(0.1, finish)
  t.start()
  ready, waiting = self.client1.wait([obj_id3, obj_id2, obj_id1], timeout=1000, num_returns=2)
  self.assertEqual(set(ready), set([obj_id1, obj_id3]))
  self.assertEqual(set(waiting), set([obj_id2]))
  # Test if the appropriate number of objects is shown if some objects are not ready
  # (three are requested, but obj_id2 is still unsealed).
  ready, waiting = self.client1.wait([obj_id3, obj_id2, obj_id1], 100, 3)
  self.assertEqual(set(ready), set([obj_id1, obj_id3]))
  self.assertEqual(set(waiting), set([obj_id2]))
  # Don't forget to seal obj_id2.
  self.client1.seal(obj_id2)
  # Test calling wait a bunch of times.
  object_ids = []
  # TODO(rkn): Increasing n to 100 (or larger) will cause failures. The
  # problem appears to be that the number of timers added to the manager event
  # loop slow down the manager so much that some of the asynchronous Redis
  # commands timeout triggering fatal failure callbacks.
  n = 40
  # Create 1 + 2 + ... + n objects, alternating between the two clients, so
  # that the loops below can ask for 1, 2, ..., n objects in turn.
  for i in range(n * (n + 1) // 2):
    if i % 2 == 0:
      object_id, _, _ = create_object(self.client1, 200, 200)
    else:
      object_id, _, _ = create_object(self.client2, 200, 200)
    object_ids.append(object_id)
  # Try waiting for all of the object IDs on the first client.
  waiting = object_ids
  retrieved = []
  for i in range(1, n + 1):
    # Each round waits only on the IDs left over from the previous round.
    ready, waiting = self.client1.wait(waiting, timeout=1000, num_returns=i)
    self.assertEqual(len(ready), i)
    retrieved += ready
  self.assertEqual(set(retrieved), set(object_ids))
  ready, waiting = self.client1.wait(object_ids, timeout=1000, num_returns=len(object_ids))
  self.assertEqual(set(ready), set(object_ids))
  self.assertEqual(waiting, [])
  # Try waiting for all of the object IDs on the second client.
  waiting = object_ids
  retrieved = []
  for i in range(1, n + 1):
    ready, waiting = self.client2.wait(waiting, timeout=1000, num_returns=i)
    self.assertEqual(len(ready), i)
    retrieved += ready
  self.assertEqual(set(retrieved), set(object_ids))
  ready, waiting = self.client2.wait(object_ids, timeout=1000, num_returns=len(object_ids))
  self.assertEqual(set(ready), set(object_ids))
  self.assertEqual(waiting, [])
  # Make sure that wait returns when the requested number of object IDs are
  # available and does not wait for all object IDs to be available.
  # The last ID is 20 zero bytes rather than a random ID.
  object_ids = [random_object_id() for _ in range(9)] + [20 * b'\x00']
  object_ids_perm = object_ids[:]
  random.shuffle(object_ids_perm)
  for i in range(10):
    # Create the objects one at a time in shuffled order; after each creation
    # exactly the first i + 1 shuffled IDs are expected to be ready.
    if i % 2 == 0:
      create_object_with_id(self.client1, object_ids_perm[i], 2000, 2000)
    else:
      create_object_with_id(self.client2, object_ids_perm[i], 2000, 2000)
    ready, waiting = self.client1.wait(object_ids, num_returns=(i + 1))
    self.assertEqual(set(ready), set(object_ids_perm[:(i + 1)]))
    self.assertEqual(set(waiting), set(object_ids_perm[(i + 1):]))
def test_transfer(self):
  """Transfer objects in both directions between the two stores and compare them."""
  host = "127.0.0.1"
  for _ in range(100):
    # Create an object on the first client and push it to the other
    # PlasmaStore.
    object_id1, memory_buffer1, metadata1 = create_object(self.client1, 2000, 2000)
    self.client1.transfer(host, self.port2, object_id1)
    # Compare the two buffers.
    assert_get_object_equal(self, self.client1, self.client2, object_id1,
                            memory_buffer=memory_buffer1, metadata=metadata1)
    # Transferring the same buffer a second time is currently disabled:
    # # Transfer the buffer again.
    # self.client1.transfer("127.0.0.1", self.port2, object_id1)
    # # Compare the two buffers.
    # assert_get_object_equal(self, self.client1, self.client2, object_id1,
    #                         memory_buffer=memory_buffer1, metadata=metadata1)

    # Create a larger object on the second client and push it the other way.
    object_id2, memory_buffer2, metadata2 = create_object(self.client2, 20000, 20000)
    self.client2.transfer(host, self.port1, object_id2)
    # Compare the two buffers.
    assert_get_object_equal(self, self.client1, self.client2, object_id2,
                            memory_buffer=memory_buffer2, metadata=metadata2)
def test_illegal_put(self):
  """
  Test doing a put at the same object ID, but with different object data. The
  first put should succeed. The second put should cause the plasma manager to
  exit with a fatal error.
  """
  if USE_VALGRIND:
    # Don't run this test when we are using valgrind because when processes
    # die without freeing up their state, valgrind complains.
    return
  # Create and seal the first object.
  length = 1000
  object_id = random_object_id()
  memory_buffer1 = self.client1.create(object_id, length)
  for i in range(length):
    memory_buffer1[i] = chr(i % 256)
  self.client1.seal(object_id)
  # Create and seal the second object. It has all the same data as the first
  # object, except that the first byte is 1 instead of 0.
  memory_buffer2 = self.client2.create(object_id, length)
  for i in range(length):
    j = i
    if j == 0:
      j += 1
    memory_buffer2[i] = chr(j % 256)
  self.client2.seal(object_id)
  # Give the second manager some time to complete the seal, then make sure it
  # exited.
  # Use floats throughout: with an integer start value the elapsed time was
  # an int whenever the manager had already exited on the first poll, and
  # formatting an int with a precision raises ValueError.
  total_wait = 100.0
  time_left = total_wait
  while time_left > 0:
    if self.p5.poll() is not None:
      # The manager is gone; tearDown must not expect it to be alive.
      self.processes_to_kill.remove(self.p5)
      break
    time_left -= 0.1
    time.sleep(0.1)
  # Fixed-point format so that e.g. 12.3 is not rendered as "1.2e+01".
  print("Time waiting for plasma manager to fail = {:.2f}".format(total_wait - time_left))
  self.assertNotEqual(self.p5.poll(), None)
def test_illegal_functionality(self):
  """Placeholder for checks on illegal operations; only generates an object ID for now."""
  # Create an object id string.
  object_id = random_object_id()
  # The remainder of this test is commented out because it currently fails:
  # # Create a new buffer.
  # memory_buffer = self.client1.create(object_id, 20000)
  # # Transferring the buffer before sealing it should fail.
  # self.assertRaises(Exception, lambda : self.manager1.transfer(1, object_id))
def test_stresstest(self):
  """Create, seal and transfer many one-byte objects and print the elapsed time."""
  num_objects = 10000  # TODO(pcm): increase this to 100000
  start_time = time.time()
  object_ids = []
  for _ in range(num_objects):
    new_id = random_object_id()
    object_ids.append(new_id)
    self.client1.create(new_id, 1)
    self.client1.seal(new_id)
  # Push every object to the second manager.
  for new_id in object_ids:
    self.client1.transfer("127.0.0.1", self.port2, new_id)
  elapsed = time.time() - start_time
  print("it took", elapsed, "seconds to put and transfer the objects")
if __name__ == "__main__":
  # A trailing "valgrind" argument switches the tests to valgrind mode. It is
  # popped so we don't mess with unittest's own argument parser.
  if len(sys.argv) > 1 and sys.argv[-1] == "valgrind":
    arg = sys.argv.pop()
    USE_VALGRIND = True
    print("Using valgrind for tests")
  unittest.main(verbosity=2)
+38
View File
@@ -0,0 +1,38 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import random
def random_object_id():
  """Return 20 random bytes from NumPy's global random state, for use as an object ID."""
  id_length = 20
  return np.random.bytes(id_length)
def generate_metadata(length):
  """Return a bytearray of `length` bytes with some randomly chosen contents.

  For a non-empty buffer the first and last bytes are randomized, followed by
  100 random single-byte writes at random positions. An empty buffer is
  returned unchanged.
  """
  metadata_buffer = bytearray(length)
  if not length > 0:
    return metadata_buffer
  metadata_buffer[0] = random.randint(0, 255)
  metadata_buffer[-1] = random.randint(0, 255)
  for _ in range(100):
    # Draw the value before the position, matching the evaluation order of a
    # subscript assignment.
    value = random.randint(0, 255)
    position = random.randint(0, length - 1)
    metadata_buffer[position] = value
  return metadata_buffer
def write_to_data_buffer(buff, length):
  """Overwrite some bytes of a writable buffer with random values.

  For a non-empty buffer the first and last bytes are randomized, followed by
  100 random single-byte writes at random positions in [0, length - 1].
  Nothing is written when `length` is 0.

  The bytes are written as integers through a NumPy uint8 view of the buffer.
  Assigning `chr(...)` directly only works on Python 2; on Python 3 a buffer
  rejects a `str` value with a TypeError.

  Args:
    buff: A writable object supporting the buffer protocol.
    length: The number of bytes in `buff`.
  """
  if length > 0:
    # The view shares memory with `buff`, so writes go straight through.
    array = np.frombuffer(buff, dtype="uint8")
    array[0] = random.randint(0, 255)
    array[-1] = random.randint(0, 255)
    for _ in range(100):
      array[random.randint(0, length - 1)] = random.randint(0, 255)
def create_object_with_id(client, object_id, data_size, metadata_size, seal=True):
  """Create an object with the given ID and randomized data and metadata.

  Args:
    client: The client used to create (and optionally seal) the object.
    object_id: The ID to create the object under.
    data_size: Size in bytes of the data buffer.
    metadata_size: Size in bytes of the generated metadata.
    seal: Whether to seal the object after writing to it.

  Returns:
    A tuple of the data buffer returned by the client and the metadata.
  """
  object_metadata = generate_metadata(metadata_size)
  data_buffer = client.create(object_id, data_size, object_metadata)
  write_to_data_buffer(data_buffer, data_size)
  if not seal:
    return data_buffer, object_metadata
  client.seal(object_id)
  return data_buffer, object_metadata
def create_object(client, data_size, metadata_size, seal=True):
  """Create an object under a fresh random ID.

  Returns:
    A tuple of the new object ID, the data buffer and the metadata.
  """
  new_id = random_object_id()
  data_buffer, object_metadata = create_object_with_id(
      client, new_id, data_size, metadata_size, seal=seal)
  return new_id, data_buffer, object_metadata
+20
View File
@@ -0,0 +1,20 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Ray version string
__version__ = "0.01"
import ctypes
# Windows only
if hasattr(ctypes, "windll"):
  # Makes sure that all child processes die when we die
  # Also makes sure that fatal crashes result in process termination rather than an error dialog (the latter is annoying since we have a lot of processes)
  # This is done by associating all child processes with a "job" object that imposes this behavior.
  #
  # How the one-liner below works:
  #   1. CreateJobObjectW(None, None) creates an anonymous job object.
  #   2. SetInformationJobObject(job, 9, buffer, n) configures it from a
  #      hand-built buffer of n bytes (0x90 when pointers are wider than a C
  #      int, 0x70 otherwise) that is all zeros except for the byte at
  #      offset 17, which is set to 0x8 | 0x4 | 0x20.
  #   3. Only if that call returns a truthy value is the current process
  #      assigned to the job with AssignProcessToJobObject.
  # NOTE(review): information class 9 and the buffer layout presumably
  # correspond to JOBOBJECT_EXTENDED_LIMIT_INFORMATION, with the byte at
  # offset 17 being the second byte of LimitFlags (kill-on-job-close,
  # breakaway-ok, die-on-unhandled-exception) -- confirm against the Win32
  # documentation.
  # NOTE(review): the buffer is built as a `str`; confirm how ctypes passes
  # it on Python 3.
  (lambda kernel32: (lambda job: (lambda n: kernel32.SetInformationJobObject(job, 9, "\0" * 17 + chr(0x8 | 0x4 | 0x20) + "\0" * (n - 18), n))(0x90 if ctypes.sizeof(ctypes.c_void_p) > ctypes.sizeof(ctypes.c_int) else 0x70) and kernel32.AssignProcessToJobObject(job, ctypes.c_void_p(kernel32.GetCurrentProcess())))(ctypes.c_void_p(kernel32.CreateJobObjectW(None, None))) if kernel32 is not None else None)(ctypes.windll.kernel32)
import ray.experimental
import ray.serialization
from ray.worker import register_class, error_info, init, connect, disconnect, get, put, wait, remote, log_event, log_span, flush_log
from ray.worker import EnvironmentVariable, env
from ray.worker import SCRIPT_MODE, WORKER_MODE, PYTHON_MODE, SILENT_MODE
+6
View File
@@ -0,0 +1,6 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from .utils import copy_directory
from .tfutils import TensorFlowVariables
@@ -0,0 +1,7 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from . import random
from . import linalg
from .core import *
@@ -0,0 +1,227 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import ray.experimental.array.remote as ra
import ray
__all__ = ["BLOCK_SIZE", "DistArray", "assemble", "zeros", "ones", "copy",
"eye", "triu", "tril", "blockwise_dot", "dot", "transpose", "add", "subtract", "numpy_to_dist", "subblocks"]
BLOCK_SIZE = 10


class DistArray(object):
  """An array split into blocks of at most BLOCK_SIZE elements per dimension.

  Attributes:
    shape: The shape of the full array.
    ndim: The number of dimensions, i.e. len(shape).
    num_blocks: The number of blocks along each dimension.
    objectids: A NumPy object array of shape `num_blocks` holding one entry
      (an object ID, as consumed by ray.get) per block.
  """

  def __init__(self, shape, objectids=None):
    """Create a distributed array.

    Args:
      shape: The shape of the full array.
      objectids: Optional object array with one entry per block. When omitted
        an empty object array of the right shape is allocated.

    Raises:
      Exception: If the shape of `objectids` does not match the number of
        blocks implied by `shape`.
    """
    self.shape = shape
    self.ndim = len(shape)
    self.num_blocks = DistArray.compute_num_blocks(self.shape)
    self.objectids = objectids if objectids is not None else np.empty(self.num_blocks, dtype=object)
    if self.num_blocks != list(self.objectids.shape):
      raise Exception("The fields `num_blocks` and `objectids` are inconsistent, `num_blocks` is {} and `objectids` has shape {}".format(self.num_blocks, list(self.objectids.shape)))

  @staticmethod
  def compute_block_lower(index, shape):
    """Return the inclusive lower corner of the block at `index`."""
    if len(index) != len(shape):
      raise Exception("The fields `index` and `shape` must have the same length, but `index` is {} and `shape` is {}.".format(index, shape))
    return [elem * BLOCK_SIZE for elem in index]

  @staticmethod
  def compute_block_upper(index, shape):
    """Return the exclusive upper corner of the block at `index`, clipped to `shape`."""
    if len(index) != len(shape):
      raise Exception("The fields `index` and `shape` must have the same length, but `index` is {} and `shape` is {}.".format(index, shape))
    return [min((i + 1) * BLOCK_SIZE, dim) for (i, dim) in zip(index, shape)]

  @staticmethod
  def compute_block_shape(index, shape):
    """Return the shape of the block at `index`."""
    lower = DistArray.compute_block_lower(index, shape)
    upper = DistArray.compute_block_upper(index, shape)
    return [u - l for (l, u) in zip(lower, upper)]

  @staticmethod
  def compute_num_blocks(shape):
    """Return the number of blocks needed along each dimension of `shape`."""
    return [int(np.ceil(1.0 * a / BLOCK_SIZE)) for a in shape]

  def assemble(self):
    """Assemble an array on this node from a distributed array of object IDs."""
    # The dtype of the result is taken from the first block.
    first_block = ray.get(self.objectids[(0,) * self.ndim])
    dtype = first_block.dtype
    result = np.zeros(self.shape, dtype=dtype)
    for index in np.ndindex(*self.num_blocks):
      lower = DistArray.compute_block_lower(index, self.shape)
      upper = DistArray.compute_block_upper(index, self.shape)
      # Multidimensional indexing needs a tuple of slices; a list of slices
      # is rejected by current NumPy releases.
      block_slices = tuple(slice(l, u) for (l, u) in zip(lower, upper))
      result[block_slices] = ray.get(self.objectids[index])
    return result

  def __getitem__(self, sliced):
    """Index into the assembled array."""
    # TODO(rkn): fix this, this is just a placeholder that should work but is inefficient
    a = self.assemble()
    return a[sliced]
# Register the DistArray class with Ray so that it knows how to serialize it.
ray.register_class(DistArray)
@ray.remote
def assemble(a):
  """Remote task returning the result of `a.assemble()`."""
  assembled = a.assemble()
  return assembled
# TODO(rkn): what should we call this method
@ray.remote
def numpy_to_dist(a):
  """Split the array `a` into blocks and return them as a DistArray.

  Each block of `a` is stored with ray.put and the resulting object ID is
  recorded at the block's index in the returned DistArray.
  """
  result = DistArray(a.shape)
  for index in np.ndindex(*result.num_blocks):
    lower = DistArray.compute_block_lower(index, a.shape)
    upper = DistArray.compute_block_upper(index, a.shape)
    # Multidimensional indexing needs a tuple of slices; a list of slices is
    # rejected by current NumPy releases.
    block_slices = tuple(slice(l, u) for (l, u) in zip(lower, upper))
    result.objectids[index] = ray.put(a[block_slices])
  return result
@ray.remote
def zeros(shape, dtype_name="float"):
  """Return a DistArray of the given shape whose blocks come from `ra.zeros` tasks."""
  dist = DistArray(shape)
  for block_index in np.ndindex(*dist.num_blocks):
    block_shape = DistArray.compute_block_shape(block_index, shape)
    dist.objectids[block_index] = ra.zeros.remote(block_shape, dtype_name=dtype_name)
  return dist
@ray.remote
def ones(shape, dtype_name="float"):
  """Return a DistArray of the given shape whose blocks come from `ra.ones` tasks."""
  dist = DistArray(shape)
  for block_index in np.ndindex(*dist.num_blocks):
    block_shape = DistArray.compute_block_shape(block_index, shape)
    dist.objectids[block_index] = ra.ones.remote(block_shape, dtype_name=dtype_name)
  return dist
@ray.remote
def copy(a):
  """Return a new DistArray that refers to the same blocks as `a`."""
  duplicate = DistArray(a.shape)
  # We don't need to actually copy the objects because cluster-level objects
  # are assumed to be immutable.
  for block_index in np.ndindex(*duplicate.num_blocks):
    duplicate.objectids[block_index] = a.objectids[block_index]
  return duplicate
@ray.remote
def eye(dim1, dim2=-1, dtype_name="float"):
  """Return a dim1 x dim2 DistArray built from `ra.eye` diagonal blocks and `ra.zeros` elsewhere.

  A `dim2` of -1 means "same as dim1".
  """
  if dim2 == -1:
    dim2 = dim1
  shape = [dim1, dim2]
  dist = DistArray(shape)
  for (row, col) in np.ndindex(*dist.num_blocks):
    block_shape = DistArray.compute_block_shape([row, col], shape)
    if row != col:
      # Off-diagonal blocks are all zeros.
      dist.objectids[row, col] = ra.zeros.remote(block_shape, dtype_name=dtype_name)
    else:
      dist.objectids[row, col] = ra.eye.remote(block_shape[0], block_shape[1], dtype_name=dtype_name)
  return dist
@ray.remote
def triu(a):
  """Return the upper-triangular part of a 2-dimensional DistArray, block by block."""
  if a.ndim != 2:
    raise Exception("Input must have 2 dimensions, but a.ndim is " + str(a.ndim))
  result = DistArray(a.shape)
  for (row, col) in np.ndindex(*result.num_blocks):
    block = a.objectids[row, col]
    if row == col:
      # Diagonal blocks are themselves reduced to their upper triangle.
      result.objectids[row, col] = ra.triu.remote(block)
    elif row < col:
      # Blocks above the diagonal are kept.
      result.objectids[row, col] = ra.copy.remote(block)
    else:
      # Blocks below the diagonal are zeroed.
      result.objectids[row, col] = ra.zeros_like.remote(block)
  return result
@ray.remote
def tril(a):
  """Return the lower-triangular part of a 2-dimensional DistArray, block by block."""
  if a.ndim != 2:
    raise Exception("Input must have 2 dimensions, but a.ndim is " + str(a.ndim))
  result = DistArray(a.shape)
  for (row, col) in np.ndindex(*result.num_blocks):
    block = a.objectids[row, col]
    if row == col:
      # Diagonal blocks are themselves reduced to their lower triangle.
      result.objectids[row, col] = ra.tril.remote(block)
    elif row > col:
      # Blocks below the diagonal are kept.
      result.objectids[row, col] = ra.copy.remote(block)
    else:
      # Blocks above the diagonal are zeroed.
      result.objectids[row, col] = ra.zeros_like.remote(block)
  return result
@ray.remote
def blockwise_dot(*matrices):
  """Sum the products of paired matrices.

  The arguments are split into a first and a second half; the i-th matrix of
  the first half is multiplied with the i-th matrix of the second half and
  all products are accumulated.

  Raises:
    Exception: If an odd number of matrices is passed.
  """
  n = len(matrices)
  if n % 2 != 0:
    raise Exception("blockwise_dot expects an even number of arguments, but len(matrices) is {}.".format(n))
  half = n // 2
  total = np.zeros((matrices[0].shape[0], matrices[half].shape[1]))
  for left, right in zip(matrices[:half], matrices[half:]):
    total += np.dot(left, right)
  return total
@ray.remote
def dot(a, b):
  """Return the matrix product of two 2-dimensional DistArrays.

  Each result block is computed by a `blockwise_dot` task over the matching
  block row of `a` and block column of `b`.

  Raises:
    Exception: If an argument is not 2-dimensional or the inner dimensions
      do not match.
  """
  if a.ndim != 2:
    raise Exception("dot expects its arguments to be 2-dimensional, but a.ndim = {}.".format(a.ndim))
  if b.ndim != 2:
    raise Exception("dot expects its arguments to be 2-dimensional, but b.ndim = {}.".format(b.ndim))
  if a.shape[1] != b.shape[0]:
    raise Exception("dot expects a.shape[1] to equal b.shape[0], but a.shape = {} and b.shape = {}.".format(a.shape, b.shape))
  product = DistArray([a.shape[0], b.shape[1]])
  for (row, col) in np.ndindex(*product.num_blocks):
    row_blocks = list(a.objectids[row, :])
    col_blocks = list(b.objectids[:, col])
    product.objectids[row, col] = blockwise_dot.remote(*(row_blocks + col_blocks))
  return product
@ray.remote
def subblocks(a, *ranges):
  """
  This function produces a distributed array from a subset of the blocks in
  `a`. The result and `a` will have the same number of dimensions. For example,
      subblocks(a, [0, 1], [2, 4])
  will produce a DistArray whose objectids are
      [[a.objectids[0, 2], a.objectids[0, 4]],
       [a.objectids[1, 2], a.objectids[1, 4]]]
  We allow the user to pass in an empty list [] to indicate the full range.

  Raises:
    Exception: If the number of ranges differs from a.ndim, or if a range is
      unsorted, starts below 0, or ends beyond the number of blocks.
  """
  ranges = list(ranges)
  if len(ranges) != a.ndim:
    raise Exception("sub_blocks expects to receive a number of ranges equal to a.ndim, but it received {} ranges and a.ndim = {}.".format(len(ranges), a.ndim))
  for i in range(len(ranges)):
    # We allow the user to pass in an empty sequence to indicate the full
    # range. Testing the length (rather than `== []`) also covers empty
    # tuples and arrays.
    if len(ranges[i]) == 0:
      ranges[i] = range(a.num_blocks[i])
    # np.all replaces np.alltrue, which was removed in NumPy 2.0.
    if not np.all(np.asarray(ranges[i]) == np.sort(ranges[i])):
      raise Exception("Ranges passed to sub_blocks must be sorted, but the {}th range is {}.".format(i, ranges[i]))
    if ranges[i][0] < 0:
      raise Exception("Values in the ranges passed to sub_blocks must be at least 0, but the {}th range is {}.".format(i, ranges[i]))
    if ranges[i][-1] >= a.num_blocks[i]:
      raise Exception("Values in the ranges passed to sub_blocks must be less than the relevant number of blocks, but the {}th range is {}, and a.num_blocks = {}.".format(i, ranges[i], a.num_blocks))
  # Every selected block is full-sized except possibly the last one along
  # each dimension, whose shape is looked up in `a`.
  last_index = [r[-1] for r in ranges]
  last_block_shape = DistArray.compute_block_shape(last_index, a.shape)
  shape = [(len(ranges[i]) - 1) * BLOCK_SIZE + last_block_shape[i] for i in range(a.ndim)]
  result = DistArray(shape)
  for index in np.ndindex(*result.num_blocks):
    result.objectids[index] = a.objectids[tuple([ranges[i][index[i]] for i in range(a.ndim)])]
  return result
@ray.remote
def transpose(a):
  """Return the transpose of a 2-dimensional DistArray.

  Block (i, j) of the result is the `ra.transpose` of block (j, i) of `a`.
  """
  if a.ndim != 2:
    raise Exception("transpose expects its argument to be 2-dimensional, but a.ndim = {}, a.shape = {}.".format(a.ndim, a.shape))
  transposed = DistArray([a.shape[1], a.shape[0]])
  for (row, col) in np.ndindex(*transposed.num_blocks):
    transposed.objectids[row, col] = ra.transpose.remote(a.objectids[col, row])
  return transposed
# TODO(rkn): support broadcasting?
@ray.remote
def add(x1, x2):
  """Return the blockwise sum of two DistArrays of identical shape."""
  if x1.shape != x2.shape:
    raise Exception("add expects arguments `x1` and `x2` to have the same shape, but x1.shape = {}, and x2.shape = {}.".format(x1.shape, x2.shape))
  total = DistArray(x1.shape)
  for block_index in np.ndindex(*total.num_blocks):
    left_block = x1.objectids[block_index]
    right_block = x2.objectids[block_index]
    total.objectids[block_index] = ra.add.remote(left_block, right_block)
  return total
# TODO(rkn): support broadcasting?
@ray.remote
def subtract(x1, x2):
  """Return the blockwise difference `x1 - x2` of two DistArrays of identical shape."""
  if x1.shape != x2.shape:
    raise Exception("subtract expects arguments `x1` and `x2` to have the same shape, but x1.shape = {}, and x2.shape = {}.".format(x1.shape, x2.shape))
  difference = DistArray(x1.shape)
  for block_index in np.ndindex(*difference.num_blocks):
    left_block = x1.objectids[block_index]
    right_block = x2.objectids[block_index]
    difference.objectids[block_index] = ra.subtract.remote(left_block, right_block)
  return difference
@@ -0,0 +1,190 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import ray.experimental.array.remote as ra
import ray
from .core import *
__all__ = ["tsqr", "modified_lu", "tsqr_hr", "qr"]
@ray.remote(num_return_vals=2)
def tsqr(a):
  """QR-factorize a distributed matrix that consists of a single block column.

  arguments:
    a: a distributed matrix
  Suppose that
    a.shape == (M, N)
    K == min(M, N)
  return values:
    q: DistArray, if q_full = q.assemble(), then
      q_full.shape == (M, K)
      np.allclose(np.dot(q_full.T, q_full), np.eye(K)) == True
    r: np.ndarray with
      r.shape == (K, N)
      np.allclose(r, np.triu(r)) == True
  raises:
    Exception if `a` is not 2-dimensional or has more than one block column.
  """
  if len(a.shape) != 2:
    raise Exception("tsqr requires len(a.shape) == 2, but a.shape is {}".format(a.shape))
  if a.num_blocks[1] != 1:
    raise Exception("tsqr requires a.num_blocks[1] == 1, but a.num_blocks is {}".format(a.num_blocks))
  num_blocks = a.num_blocks[0]
  # K is the number of levels of the reduction tree: one level for the
  # per-block factorizations plus one per pairwise merge round.
  K = int(np.ceil(np.log2(num_blocks))) + 1
  # q_tree[i, j] holds the Q factor produced for node i at tree level j.
  q_tree = np.empty((num_blocks, K), dtype=object)
  current_rs = []
  # Level 0: factorize every block on its own.
  for i in range(num_blocks):
    block = a.objectids[i, 0]
    q, r = ra.linalg.qr.remote(block)
    q_tree[i, 0] = q
    current_rs.append(r)
  # Levels 1..K-1: stack neighbouring R factors in pairs and factorize each
  # stack, roughly halving the number of R factors per level.
  for j in range(1, K):
    new_rs = []
    for i in range(int(np.ceil(1.0 * len(current_rs) / 2))):
      # The slice holds a single R factor when the count is odd.
      stacked_rs = ra.vstack.remote(*current_rs[(2 * i):(2 * i + 2)])
      q, r = ra.linalg.qr.remote(stacked_rs)
      q_tree[i, j] = q
      new_rs.append(r)
    current_rs = new_rs
  assert len(current_rs) == 1, "len(current_rs) = " + str(len(current_rs))

  # handle the special case in which the whole DistArray "a" fits in one block
  # and has fewer rows than columns, this is a bit ugly so think about how to
  # remove it
  if a.shape[0] >= a.shape[1]:
    q_shape = a.shape
  else:
    q_shape = [a.shape[0], a.shape[0]]
  q_num_blocks = DistArray.compute_num_blocks(q_shape)
  q_objectids = np.empty(q_num_blocks, dtype=object)
  q_result = DistArray(q_shape, q_objectids)

  # reconstruct output
  for i in range(num_blocks):
    # Walk from leaf i up to the root, multiplying the block's Q factor by
    # the relevant part of each ancestor's Q factor.
    q_block_current = q_tree[i, 0]
    ith_index = i
    for j in range(1, K):
      # An even index was the top half of the stacked pair at this level, an
      # odd index the bottom half; pick the matching rows of the parent's Q.
      if np.mod(ith_index, 2) == 0:
        lower = [0, 0]
        upper = [a.shape[1], BLOCK_SIZE]
      else:
        lower = [a.shape[1], 0]
        upper = [2 * a.shape[1], BLOCK_SIZE]
      # Index of the parent node at level j.
      ith_index //= 2
      q_block_current = ra.dot.remote(q_block_current, ra.subarray.remote(q_tree[ith_index, j], lower, upper))
    q_result.objectids[i] = q_block_current
  r = current_rs[0]
  # The R factor is fetched so that it is returned as a value.
  return q_result, ray.get(r)
# TODO(rkn): This is unoptimized, we really want a block version of this.
@ray.remote(num_return_vals=3)
def modified_lu(q):
  """
  Algorithm 5 from http://www.eecs.berkeley.edu/Pubs/TechRpts/2013/EECS-2013-175.pdf
  takes a matrix q with orthonormal columns, returns l, u, s such that q - s = l * u

  arguments:
    q: a two dimensional orthonormal q
  return values:
    l: lower triangular
    u: upper triangular
    s: a diagonal matrix represented by its diagonal
  """
  q = q.assemble()
  m, b = q.shape[0], q.shape[1]
  S = np.zeros(b)
  q_work = np.copy(q)

  for i in range(b):
    # Treat sign(0) as +1. With np.sign alone a zero diagonal entry gives
    # S[i] == 0, leaves the pivot at zero and makes the scaling below divide
    # by zero.
    diag_sign = np.sign(q_work[i, i])
    if diag_sign == 0:
      diag_sign = 1
    S[i] = -1 * diag_sign
    q_work[i, i] -= S[i]
    q_work[(i + 1):m, i] /= q_work[i, i]  # scale ith column of L by diagonal element
    q_work[(i + 1):m, (i + 1):b] -= np.outer(q_work[(i + 1):m, i], q_work[i, (i + 1):b])  # perform Schur complement update

  # L is the unit lower triangle of the work matrix, U the upper triangle of
  # its first b rows.
  L = np.tril(q_work)
  for i in range(b):
    L[i, i] = 1
  U = np.triu(q_work)[:b, :]
  return ray.get(numpy_to_dist.remote(ray.put(L))), U, S  # TODO(rkn): get rid of put
@ray.remote(num_return_vals=2)
def tsqr_hr_helper1(u, s, y_top_block, b):
  """Return `(t, y_top)` where `y_top` is the leading b x b part of `y_top_block`.

  `t` is computed as -u * diag(s) * inv(y_top).T.
  """
  y_top = y_top_block[:b, :b]
  s_full = np.diag(s)
  y_top_inv_t = np.linalg.inv(y_top).T
  inner = np.dot(s_full, y_top_inv_t)
  t = -1 * np.dot(u, inner)
  return t, y_top
@ray.remote
def tsqr_hr_helper2(s, r_temp):
  """Return diag(s) multiplied by `r_temp`."""
  diagonal = np.diag(s)
  scaled = np.dot(diagonal, r_temp)
  return scaled
@ray.remote(num_return_vals=4)
def tsqr_hr(a):
  """Algorithm 6 from http://www.eecs.berkeley.edu/Pubs/TechRpts/2013/EECS-2013-175.pdf"""
  # Chain the remote steps: TSQR, then the modified LU of its Q factor.
  q_id, r_temp_id = tsqr.remote(a)
  y_id, u_id, s_id = modified_lu.remote(q_id)
  # The first block of y is needed as an argument below, so y is fetched here.
  y_blocked = ray.get(y_id)
  top_block_id = y_blocked.objectids[0, 0]
  t_id, y_top_id = tsqr_hr_helper1.remote(u_id, s_id, top_block_id, a.shape[1])
  r_id = tsqr_hr_helper2.remote(s_id, r_temp_id)
  # Fetch all four results before returning them.
  y_val = ray.get(y_id)
  t_val = ray.get(t_id)
  y_top_val = ray.get(y_top_id)
  r_val = ray.get(r_id)
  return y_val, t_val, y_top_val, r_val
@ray.remote
def qr_helper1(a_rc, y_ri, t, W_c):
  """Return a_rc - y_ri * (t.T * W_c)."""
  projected = np.dot(t.T, W_c)
  correction = np.dot(y_ri, projected)
  return a_rc - correction
@ray.remote
def qr_helper2(y_ri, a_rc):
  """Return y_ri.T * a_rc."""
  y_ri_transposed = y_ri.T
  return np.dot(y_ri_transposed, a_rc)
@ray.remote(num_return_vals=2)
def qr(a):
    """Algorithm 7 from http://www.eecs.berkeley.edu/Pubs/TechRpts/2013/EECS-2013-175.pdf

    Blocked QR factorization of a distributed array. Each block column is
    factored with tsqr_hr, and the remaining block columns are then updated
    with the resulting y and t factors.

    Args:
        a: The distributed array to factor. Its shape, num_blocks and
            objectids attributes are read.

    Returns:
        A pair (q, r_res): the fetched value of the accumulated q, and the
        distributed array holding the blocks of r.
    """
    m, n = a.shape[0], a.shape[1]
    k = min(m, n)

    # we will store our scratch work in a_work
    # The object ID array is copied so that assignments into a_work.objectids
    # do not modify a.objectids.
    a_work = DistArray(a.shape, np.copy(a.objectids))

    # A local QR of the first block is run only to learn the result dtype name.
    result_dtype = np.linalg.qr(ray.get(a.objectids[0, 0]))[0].dtype.name

    r_res = ray.get(zeros.remote([k, n], result_dtype))  # TODO(rkn): It would be preferable not to get this right after creating it.
    y_res = ray.get(zeros.remote([m, k], result_dtype))  # TODO(rkn): It would be preferable not to get this right after creating it.
    # One t factor per processed block column, consumed in reverse order below.
    Ts = []

    for i in range(min(a.num_blocks[0], a.num_blocks[1])):  # this differs from the paper, which says "for i in range(a.num_blocks[1])", but that doesn't seem to make any sense when a.num_blocks[1] > a.num_blocks[0]
        # Factor block column i, restricted to block rows i and below.
        sub_dist_array = subblocks.remote(a_work, list(range(i, a_work.num_blocks[0])), [i])
        y, t, _, R = tsqr_hr.remote(sub_dist_array)
        y_val = ray.get(y)

        # Copy the blocks of y into block column i of y_res, offset by i rows.
        for j in range(i, a.num_blocks[0]):
            y_res.objectids[j, i] = y_val.objectids[j - i, 0]
        if a.shape[0] > a.shape[1]:
            # in this case, R needs to be square
            R_shape = ray.get(ra.shape.remote(R))
            eye_temp = ra.eye.remote(R_shape[1], R_shape[0], dtype_name=result_dtype)
            r_res.objectids[i, i] = ra.dot.remote(eye_temp, R)
        else:
            r_res.objectids[i, i] = R
        Ts.append(numpy_to_dist.remote(t))

        # Update every block column to the right of column i.
        for c in range(i + 1, a.num_blocks[1]):
            # W_c is the sum over block rows r >= i of y_ri.T * a_work[r, c].
            W_rcs = []
            for r in range(i, a.num_blocks[0]):
                y_ri = y_val.objectids[r - i, 0]
                W_rcs.append(qr_helper2.remote(y_ri, a_work.objectids[r, c]))
            W_c = ra.sum_list.remote(*W_rcs)
            # Each block in the column is replaced by its updated value.
            for r in range(i, a.num_blocks[0]):
                y_ri = y_val.objectids[r - i, 0]
                A_rc = qr_helper1.remote(a_work.objectids[r, c], y_ri, t, W_c)
                a_work.objectids[r, c] = A_rc
            # Block row i of the updated column becomes block (i, c) of r.
            r_res.objectids[i, c] = a_work.objectids[i, c]

    # construct q_res from Ys and Ts
    q = eye.remote(m, k, dtype_name=result_dtype)
    # The factors are applied from the last block column back to the first:
    # q <- q - y * (T * (y.T * q)).
    for i in range(len(Ts))[::-1]:
        y_col_block = subblocks.remote(y_res, [], [i])
        q = subtract.remote(q, dot.remote(y_col_block, dot.remote(Ts[i], dot.remote(transpose.remote(y_col_block), q))))

    return ray.get(q), r_res
@@ -0,0 +1,18 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import ray.experimental.array.remote as ra
import ray
from .core import *
@ray.remote
def normal(shape):
    """Create a distributed array whose blocks are filled by ra.random.normal.

    Args:
        shape: The shape of the full distributed array.

    Returns:
        A DistArray of the given shape, with one remote block per block index.
    """
    block_grid = DistArray.compute_num_blocks(shape)
    block_ids = np.empty(block_grid, dtype=object)
    for block_index in np.ndindex(*block_grid):
        block_shape = DistArray.compute_block_shape(block_index, shape)
        block_ids[block_index] = ra.random.normal.remote(block_shape)
    return DistArray(shape, block_ids)
@@ -0,0 +1,7 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from . import random
from . import linalg
from .core import *
@@ -0,0 +1,86 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import ray
# Names exported by a star-import of this module; each entry is one of the
# remote functions defined below.
__all__ = ["zeros", "zeros_like", "ones", "eye", "dot", "vstack", "hstack", "subarray", "copy", "tril", "triu", "diag", "transpose", "add", "subtract", "sum", "shape", "sum_list"]
@ray.remote
def zeros(shape, dtype_name="float", order="C"):
    """Remote wrapper around np.zeros; the dtype is passed by name."""
    dtype = np.dtype(dtype_name)
    return np.zeros(shape, dtype=dtype, order=order)
@ray.remote
def zeros_like(a, dtype_name="None", order="K", subok=True):
    """Remote wrapper around np.zeros_like.

    The string "None" for dtype_name stands for "no dtype override".
    """
    if dtype_name == "None":
        dtype_val = None
    else:
        dtype_val = np.dtype(dtype_name)
    return np.zeros_like(a, dtype=dtype_val, order=order, subok=subok)
@ray.remote
def ones(shape, dtype_name="float", order="C"):
    """Remote wrapper around np.ones; the dtype is passed by name."""
    dtype = np.dtype(dtype_name)
    return np.ones(shape, dtype=dtype, order=order)
@ray.remote
def eye(N, M=-1, k=0, dtype_name="float"):
    """Remote wrapper around np.eye.

    M == -1 means "use N columns", giving a square result.
    """
    if M == -1:
        M = N
    dtype = np.dtype(dtype_name)
    return np.eye(N, M=M, k=k, dtype=dtype)
@ray.remote
def dot(a, b):
    """Remote wrapper around np.dot."""
    product = np.dot(a, b)
    return product
@ray.remote
def vstack(*xs):
    """Remote wrapper around np.vstack; the arrays are passed as positional arguments."""
    stacked = np.vstack(xs)
    return stacked
@ray.remote
def hstack(*xs):
    """Remote wrapper around np.hstack; the arrays are passed as positional arguments."""
    stacked = np.hstack(xs)
    return stacked
# TODO(rkn): instead of this, consider implementing slicing
@ray.remote
def subarray(a, lower_indices, upper_indices):  # TODO(rkn): be consistent about using "index" versus "indices"
    """Return the sub-block a[l0:u0, l1:u1, ...] of an array.

    Args:
        a: The array to slice.
        lower_indices: The inclusive lower bound for each leading axis.
        upper_indices: The exclusive upper bound for each leading axis.

    Returns:
        The slice of a selected by pairing each lower bound with its upper bound.
    """
    # The index must be a tuple of slices. Indexing with a *list* of slices is
    # deprecated by NumPy and rejected by recent releases.
    index = tuple(slice(lower, upper) for (lower, upper) in zip(lower_indices, upper_indices))
    return a[index]
@ray.remote
def copy(a, order="K"):
    """Remote wrapper around np.copy."""
    duplicate = np.copy(a, order=order)
    return duplicate
@ray.remote
def tril(m, k=0):
    """Remote wrapper around np.tril."""
    lower = np.tril(m, k=k)
    return lower
@ray.remote
def triu(m, k=0):
    """Remote wrapper around np.triu."""
    upper = np.triu(m, k=k)
    return upper
@ray.remote
def diag(v, k=0):
    """Remote wrapper around np.diag."""
    result = np.diag(v, k=k)
    return result
@ray.remote
def transpose(a, axes=[]):
    """Remote wrapper around np.transpose.

    An empty list for axes stands for "no axes given" and is forwarded to
    NumPy as None. The default list is never mutated.
    """
    if axes == []:
        axes = None
    return np.transpose(a, axes=axes)
@ray.remote
def add(x1, x2):
    """Remote wrapper around np.add."""
    total = np.add(x1, x2)
    return total
@ray.remote
def subtract(x1, x2):
    """Remote wrapper around np.subtract."""
    difference = np.subtract(x1, x2)
    return difference
@ray.remote
def sum(x, axis=-1):
    """Remote wrapper around np.sum.

    axis == -1 stands for "no axis" and is forwarded to NumPy as None, so the
    sum runs over all elements.
    """
    if axis != -1:
        numpy_axis = axis
    else:
        numpy_axis = None
    return np.sum(x, axis=numpy_axis)
@ray.remote
def shape(a):
    """Remote wrapper around np.shape."""
    dimensions = np.shape(a)
    return dimensions
# We use Any to allow different numerical types as well as numpy arrays.
# TODO(rkn):this isn't in the numpy API, so be careful about exposing this.
@ray.remote
def sum_list(*xs):
    """Return the elementwise sum of the positional arguments (np.sum over axis 0)."""
    total = np.sum(xs, axis=0)
    return total
@@ -0,0 +1,91 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import ray
# "LinAlgError" is listed in __all__ below, so it has to exist in this module;
# otherwise a star-import of the module fails with an AttributeError.
LinAlgError = np.linalg.LinAlgError

# Names exported by a star-import of this module.
__all__ = ["matrix_power", "solve", "tensorsolve", "tensorinv", "inv",
           "cholesky", "eigvals", "eigvalsh", "pinv", "slogdet", "det",
           "svd", "eig", "eigh", "lstsq", "norm", "qr", "cond", "matrix_rank",
           "LinAlgError", "multi_dot"]
@ray.remote
def matrix_power(M, n):
    """Remote wrapper around np.linalg.matrix_power."""
    powered = np.linalg.matrix_power(M, n)
    return powered
@ray.remote
def solve(a, b):
    """Remote wrapper around np.linalg.solve."""
    solution = np.linalg.solve(a, b)
    return solution
@ray.remote(num_return_vals=2)
def tensorsolve(a):
    """Placeholder for np.linalg.tensorsolve; not implemented yet.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError()
@ray.remote(num_return_vals=2)
def tensorinv(a):
    """Placeholder for np.linalg.tensorinv; not implemented yet.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError()
@ray.remote
def inv(a):
    """Remote wrapper around np.linalg.inv."""
    inverse = np.linalg.inv(a)
    return inverse
@ray.remote
def cholesky(a):
    """Remote wrapper around np.linalg.cholesky."""
    factor = np.linalg.cholesky(a)
    return factor
@ray.remote
def eigvals(a):
    """Remote wrapper around np.linalg.eigvals."""
    eigenvalues = np.linalg.eigvals(a)
    return eigenvalues
@ray.remote
def eigvalsh(a):
    """Placeholder for np.linalg.eigvalsh; not implemented yet.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError()
@ray.remote
def pinv(a):
    """Remote wrapper around np.linalg.pinv."""
    pseudo_inverse = np.linalg.pinv(a)
    return pseudo_inverse
@ray.remote
def slogdet(a):
    """Placeholder for np.linalg.slogdet; not implemented yet.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError()
@ray.remote
def det(a):
    """Remote wrapper around np.linalg.det."""
    determinant = np.linalg.det(a)
    return determinant
@ray.remote(num_return_vals=3)
def svd(a):
    """Remote wrapper around np.linalg.svd; the result is returned unchanged."""
    decomposition = np.linalg.svd(a)
    return decomposition
@ray.remote(num_return_vals=2)
def eig(a):
    """Remote wrapper around np.linalg.eig; the result is returned unchanged."""
    decomposition = np.linalg.eig(a)
    return decomposition
@ray.remote(num_return_vals=2)
def eigh(a):
    """Remote wrapper around np.linalg.eigh; the result is returned unchanged."""
    decomposition = np.linalg.eigh(a)
    return decomposition
@ray.remote(num_return_vals=4)
def lstsq(a, b):
    """Remote wrapper around np.linalg.lstsq.

    Args:
        a: The coefficient matrix.
        b: The ordinate (right-hand side) values.

    Returns:
        The four values returned by np.linalg.lstsq(a, b).
    """
    # b must be forwarded: np.linalg.lstsq requires both the coefficient
    # matrix and the right-hand side.
    return np.linalg.lstsq(a, b)
@ray.remote
def norm(x):
    """Remote wrapper around np.linalg.norm with its default arguments."""
    magnitude = np.linalg.norm(x)
    return magnitude
@ray.remote(num_return_vals=2)
def qr(a):
    """Remote wrapper around np.linalg.qr; the result is returned unchanged."""
    factors = np.linalg.qr(a)
    return factors
@ray.remote
def cond(x):
    """Remote wrapper around np.linalg.cond with its default arguments."""
    condition_number = np.linalg.cond(x)
    return condition_number
@ray.remote
def matrix_rank(M):
    """Remote wrapper around np.linalg.matrix_rank with its default arguments."""
    rank = np.linalg.matrix_rank(M)
    return rank
@ray.remote
def multi_dot(*a):
    """Placeholder for np.linalg.multi_dot; not implemented yet.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError()
@@ -0,0 +1,10 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import ray
@ray.remote
def normal(shape):
    """Remote wrapper around np.random.normal; shape is passed as the size argument."""
    samples = np.random.normal(size=shape)
    return samples
+78
View File
@@ -0,0 +1,78 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
def unflatten(vector, shapes):
    """Split a flat vector into consecutive arrays with the given shapes.

    Args:
        vector: A one dimensional NumPy array holding all values back to back.
        shapes: An iterable of shapes. The i-th output consumes as many
            elements of vector as the product of the i-th shape.

    Returns:
        A list with one array per entry of shapes, in the same order.

    Raises:
        AssertionError: If the total size of the shapes does not match
            len(vector).
    """
    i = 0
    arrays = []
    for shape in shapes:
        # np.prod of an empty shape (a scalar) is the float 1.0, which cannot
        # be used as a slice bound, so the size is converted to an int.
        size = int(np.prod(shape))
        array = vector[i:(i + size)].reshape(shape)
        arrays.append(array)
        i += size
    assert len(vector) == i, "Passed weight does not have the correct shape."
    return arrays
class TensorFlowVariables(object):
    """An object used to extract variables from a loss function.

    This object also provides methods for getting and setting the weights of the
    relevant variables.

    Attributes:
        sess (tf.Session): The tensorflow session used to run assignment.
        loss: The loss function passed in by the user.
        variables (List[tf.Variable]): Extracted variables from the loss.
        assignment_placeholders (Dict[str, tf.placeholder]): The nodes that
            weights get passed to, keyed by variable name.
        assignment_nodes (List[tf.Tensor]): The nodes that assign the weights,
            in the same order as variables.
    """

    def __init__(self, loss, sess=None):
        """Creates a TensorFlowVariables instance.

        Args:
            loss: The loss node whose graph is searched for variables.
            sess: The session to use. It can also be set later with
                set_session.
        """
        # Imported here so that importing this module does not require
        # tensorflow.
        import tensorflow as tf
        self.sess = sess
        self.loss = loss
        # A set gives constant-time membership tests in the filter below.
        variable_names = set(op.node_def.name for op in loss.graph.get_operations() if op.node_def.op == "Variable")
        self.variables = [v for v in tf.trainable_variables() if v.op.node_def.name in variable_names]
        self.assignment_placeholders = dict()
        self.assignment_nodes = []

        # Create new placeholders to put in custom weights.
        for var in self.variables:
            name = var.op.node_def.name
            placeholder = tf.placeholder(var.value().dtype, var.get_shape().as_list())
            self.assignment_placeholders[name] = placeholder
            self.assignment_nodes.append(var.assign(placeholder))

    def set_session(self, sess):
        """Modifies the current session used by the class."""
        self.sess = sess

    def get_flat_size(self):
        """Returns the total number of elements over all variables."""
        return sum([np.prod(v.get_shape().as_list()) for v in self.variables])

    def _check_sess(self):
        """Checks if the session is set, and if not throw an error message."""
        assert self.sess is not None, "The session is not set. Set the session either by passing it into the TensorFlowVariables constructor or by calling set_session(sess)."

    def get_flat(self):
        """Gets the weights and returns them as a flat array."""
        self._check_sess()
        return np.concatenate([v.eval(session=self.sess).flatten() for v in self.variables])

    def set_flat(self, new_weights):
        """Sets the weights to new_weights, converting from a flat array."""
        self._check_sess()
        shapes = [v.get_shape().as_list() for v in self.variables]
        arrays = unflatten(new_weights, shapes)
        placeholders = [self.assignment_placeholders[v.op.node_def.name] for v in self.variables]
        self.sess.run(self.assignment_nodes, feed_dict=dict(zip(placeholders, arrays)))

    def get_weights(self):
        """Returns the weights of the variables of the loss function as a dict keyed by variable name."""
        self._check_sess()
        return {v.op.node_def.name: v.eval(session=self.sess) for v in self.variables}

    def set_weights(self, new_weights):
        """Sets the weights to new_weights.

        Args:
            new_weights (dict): Maps variable names to their new values. It may
                cover only a subset of the variables; the others are left
                untouched.

        Raises:
            KeyError: If new_weights contains a name that is not one of the
                extracted variables.
        """
        self._check_sess()
        feed_dict = {self.assignment_placeholders[name]: value for (name, value) in new_weights.items()}
        # Only run the assignments whose placeholder is being fed. Running
        # every assignment node would need a value for every placeholder and
        # so would fail for a partial update.
        nodes = [node for (var, node) in zip(self.variables, self.assignment_nodes) if var.op.node_def.name in new_weights]
        self.sess.run(nodes, feed_dict=feed_dict)
+74
View File
@@ -0,0 +1,74 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import io
import os
import tarfile
import sys
import ray
def tarred_directory_as_bytes(source_dir):
    """Pack a directory into a gzipped tar archive held in memory.

    Args:
        source_dir (str): The name of the directory to tar.

    Returns:
        A byte string representing the tarred file. Inside the archive the
        directory is stored under its base name.
    """
    buffer = io.BytesIO()
    archive_name = os.path.basename(source_dir)
    # Write the archive straight into the in-memory buffer.
    with tarfile.open(mode="w:gz", fileobj=buffer) as archive:
        archive.add(source_dir, arcname=archive_name)
    return buffer.getvalue()
def tarred_bytes_to_directory(tarred_bytes, target_dir):
    """Unpack an in-memory tar archive into a directory.

    Args:
        tarred_bytes (str): A byte string representing the tarred file. This
            should be the output of tarred_directory_as_bytes.
        target_dir (str): The directory to create the untarred files in.
    """
    archive = tarfile.open(fileobj=io.BytesIO(tarred_bytes))
    try:
        archive.extractall(path=target_dir)
    finally:
        archive.close()
def copy_directory(source_dir, target_dir=None):
    """Copy a local directory to each machine in the Ray cluster.

    Note that both source_dir and target_dir must have the same basename. For
    example, source_dir can be /a/b/c and target_dir can be /d/e/c. In this
    case, the directory /d/e will be added to the Python path of each worker.

    Note that this method is not completely safe to use. For example, workers
    that do not do the copying and only set their paths (only one worker per
    node does the copying) may try to execute functions that use the files in
    the directory being copied before the directory being copied has finished
    untarring.

    Args:
        source_dir (str): The directory to copy.
        target_dir (str): The location to copy it to on the other machines. If
            this is not provided, the source_dir will be used. If it is
            provided and is different from source_dir, the source_dir will
            also be copied to the target_dir location on this machine.

    Raises:
        Exception: If source_dir and target_dir have different base names.
    """
    if target_dir is None:
        target_dir = source_dir
    source_dir = os.path.abspath(source_dir)
    target_dir = os.path.abspath(target_dir)

    source_basename = os.path.basename(source_dir)
    target_basename = os.path.basename(target_dir)
    if source_basename != target_basename:
        raise Exception("The source_dir and target_dir must have the same base name, {} != {}".format(source_basename, target_basename))

    tarred_bytes = tarred_directory_as_bytes(source_dir)

    def f(worker_info):
        # The archive unpacks into a folder named after the base name, so it
        # is extracted into (and the path is extended with) the parent.
        parent_dir = os.path.dirname(target_dir)
        if worker_info["counter"] == 0:
            tarred_bytes_to_directory(tarred_bytes, parent_dir)
        sys.path.append(parent_dir)

    # Run this function on all workers to copy the directory to all nodes and
    # to add the directory to the Python path of each worker.
    ray.worker.global_worker.run_function_on_all_workers(f)
+67
View File
@@ -0,0 +1,67 @@
# Note that a little bit of code here is taken and slightly modified from the pickler because it was not possible to change its behavior otherwise.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
from ctypes import c_void_p
from cloudpickle import pickle, cloudpickle, CloudPickler, load, loads
# Probe for ctypes.pythonapi.PyCell_Set, which _fill_function uses to fill in
# closure cells. If it is unavailable for any reason, pythonapi is set to None
# and the code below falls back to cloudpickle's own helpers.
try:
    from ctypes import pythonapi
    pythonapi.PyCell_Set  # Make sure this exists
except:
    pythonapi = None
def dump(obj, file, protocol=2):
    """Pickle obj into the open file object using BetterPickler."""
    pickler = BetterPickler(file, protocol)
    return pickler.dump(obj)
def dumps(obj):
    """Pickle obj with dump() and return the serialized value."""
    buffer = cloudpickle.StringIO()
    dump(obj, buffer)
    serialized = buffer.getvalue()
    return serialized
def _make_skel_func(code, closure, base_globals = None):
""" Creates a skeleton function object that contains just the provided
code and the correct number of cells in func_closure. All other
func attributes (e.g. func_globals) are empty.
"""
if base_globals is None: base_globals = {}
base_globals['__builtins__'] = __builtins__
return _make_skel_func.__class__(code, base_globals, None, None, tuple(closure))
def _fill_function(func, globals, defaults, closure, dict):
    """Complete a skeleton function created via _make_skel_func().

    The globals, defaults and attribute dict are filled in by cloudpickle's
    _fill_function. When pythonapi is available, each cell of the function's
    closure is then set to the corresponding entry of closure.

    Returns:
        The filled-in function.
    """
    result = cloudpickle._fill_function(func, globals, defaults, dict)
    if pythonapi is None:
        return result
    for index, value in enumerate(closure):
        cell = result.__closure__[index]
        # PyCell_Set is given the two objects via their id() values.
        pythonapi.PyCell_Set(c_void_p(id(cell)), c_void_p(id(value)))
    return result
class BetterPickler(CloudPickler):
    """A CloudPickler that overrides how functions and closure cells are saved.

    The statements in save_function_tuple write pickle opcodes directly, so
    their order must not be changed.
    """

    def save_function_tuple(self, func):
        """Pickle func as a _fill_function(skeleton, ...) call.

        A skeleton function is pickled first and memoized under func, and the
        globals, defaults, closure and attribute dict are pickled afterwards as
        the remaining arguments of _fill_function.
        """
        code, f_globals, defaults, closure, dct, base_globals = self.extract_func_data(func)

        self.save(_fill_function)
        self.write(pickle.MARK)

        # Use this module's skeleton builder only when pythonapi is available.
        self.save(_make_skel_func if pythonapi else cloudpickle._make_skel_func)
        # With pythonapi, the skeleton gets one placeholder cell per closure
        # entry; the real closure is saved further down and set by
        # _fill_function.
        self.save((code, map(lambda _: cloudpickle._make_cell(None), closure) if closure and pythonapi is not None else closure, base_globals))
        self.write(pickle.REDUCE)
        # Memoize func after the skeleton is written and before the remaining
        # data is saved.
        self.memoize(func)

        self.save(f_globals)
        self.save(defaults)
        self.save(closure)
        self.save(dct)
        self.write(pickle.TUPLE)
        self.write(pickle.REDUCE)

    def save_cell(self, obj):
        """Pickle a closure cell as a call to cloudpickle._make_cell(contents)."""
        self.save(cloudpickle._make_cell)
        self.save((obj.cell_contents,))
        self.write(pickle.REDUCE)

    # Copy the parent's dispatch table so the extra entry does not affect
    # CloudPickler itself.
    dispatch = CloudPickler.dispatch.copy()
    # The key is the cell type, obtained from the closure of a throwaway
    # nested lambda.
    dispatch[(lambda _: lambda: _)(0).__closure__[0].__class__] = save_cell
+143
View File
@@ -0,0 +1,143 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import numbuf
import ray.pickling as pickling
def check_serializable(cls):
    """Throws an exception if Ray cannot serialize this class efficiently.

    Named tuples are always accepted. Any other class must be instantiable
    with cls.__new__(cls) alone, and its instances must have a __dict__ and
    must not use __slots__.

    Args:
        cls (type): The class to be serialized.

    Raises:
        Exception: An exception is raised if Ray cannot serialize this class
            efficiently.
    """
    if is_named_tuple(cls):
        # This case works.
        return
    if not hasattr(cls, "__new__"):
        raise Exception("The class {} does not have a '__new__' attribute, and is probably an old-style class. We do not support this. Please either make it a new-style class by inheriting from 'object', or use 'ray.register_class(cls, pickle=True)'. However, note that pickle is inefficient.".format(cls))
    try:
        obj = cls.__new__(cls)
    except Exception:
        # Only ordinary errors are translated. A bare except would also turn
        # KeyboardInterrupt and SystemExit into this misleading message.
        raise Exception("The class {} has overridden '__new__', so Ray may not be able to serialize it efficiently. Try using 'ray.register_class(cls, pickle=True)'. However, note that pickle is inefficient.".format(cls))
    if not hasattr(obj, "__dict__"):
        raise Exception("Objects of the class {} do not have a `__dict__` attribute, so Ray cannot serialize it efficiently. Try using 'ray.register_class(cls, pickle=True)'. However, note that pickle is inefficient.".format(cls))
    if hasattr(obj, "__slots__"):
        raise Exception("The class {} uses '__slots__', so Ray may not be able to serialize it efficiently. Try using 'ray.register_class(cls, pickle=True)'. However, note that pickle is inefficient.".format(cls))
# This field keeps track of a whitelisted set of classes that Ray will
# serialize. It maps a class identifier string to the class itself.
whitelisted_classes = {}
# Identifiers of the classes that are serialized with pickle.
classes_to_pickle = set()
# Optional per-class custom serializers and deserializers, keyed by class
# identifier.
custom_serializers = {}
custom_deserializers = {}
def class_identifier(typ):
    """Return a string of the form "<module>.<name>" that identifies this type."""
    module_name = typ.__module__
    type_name = typ.__name__
    return "{}.{}".format(module_name, type_name)
def is_named_tuple(cls):
    """Return True if cls is a namedtuple and False otherwise.

    A class counts as a namedtuple when tuple is its only base class and it
    has a _fields tuple whose entries are all exactly of type str.
    """
    bases = cls.__bases__
    if not (len(bases) == 1 and bases[0] == tuple):
        return False
    fields = getattr(cls, "_fields", None)
    if not isinstance(fields, tuple):
        return False
    for field_name in fields:
        if type(field_name) != str:
            return False
    return True
def add_class_to_whitelist(cls, pickle=False, custom_serializer=None, custom_deserializer=None):
    """Add cls to the list of classes that we can serialize.

    Args:
        cls (type): The class that we can serialize.
        pickle (bool): True if the serialization should be done with pickle.
            False if it should be done efficiently with Ray.
        custom_serializer: This argument is optional, but can be provided to
            serialize objects of the class in a particular way.
        custom_deserializer: This argument is optional, but can be provided to
            deserialize objects of the class in a particular way. It is only
            recorded when custom_serializer is also given.
    """
    class_id = class_identifier(cls)
    whitelisted_classes[class_id] = cls
    if pickle:
        classes_to_pickle.add(class_id)
    if custom_serializer is None:
        return
    # The serializer and deserializer are always registered as a pair.
    custom_serializers[class_id] = custom_serializer
    custom_deserializers[class_id] = custom_deserializer
# Here we define a custom serializer and deserializer for handling numpy
# arrays that contain objects.
def array_custom_serializer(obj):
    """Serialize a numpy array as a pair (nested list of values, dtype string)."""
    values = obj.tolist()
    dtype_string = obj.dtype.str
    return values, dtype_string
def array_custom_deserializer(serialized_obj):
    """Rebuild a numpy array from the (values, dtype string) pair made by array_custom_serializer."""
    values = serialized_obj[0]
    dtype = np.dtype(serialized_obj[1])
    return np.array(values, dtype=dtype)
# Register np.ndarray with the custom serializer/deserializer pair above
# instead of pickle.
add_class_to_whitelist(np.ndarray, pickle=False, custom_serializer=array_custom_serializer, custom_deserializer=array_custom_deserializer)
def serialize(obj):
    """This is the callback that will be used by numbuf.

    If numbuf does not know how to serialize an object, it will call this
    method.

    Args:
        obj (object): A Python object.

    Returns:
        A dictionary that has the key "_pytype_" to identify the class, and
        contains all information needed to reconstruct the object.

    Raises:
        Exception: If the class of obj has not been whitelisted, or if obj is
            neither a namedtuple nor an object with a __dict__.
    """
    obj_type = type(obj)
    class_id = class_identifier(obj_type)
    if class_id not in whitelisted_classes:
        raise Exception("Ray does not know how to serialize objects of type {}. To fix this, call 'ray.register_class' with this class.".format(type(obj)))

    if class_id in classes_to_pickle:
        fields = {"data": pickling.dumps(obj)}
    elif class_id in custom_serializers:
        fields = {"data": custom_serializers[class_id](obj)}
    elif is_named_tuple(obj_type):
        # Handle the namedtuple case.
        fields = {"_ray_getnewargs_": obj.__getnewargs__()}
    elif hasattr(obj, "__dict__"):
        fields = obj.__dict__
    else:
        raise Exception("We do not know how to serialize the object '{}'".format(obj))

    # Copy before tagging so that obj.__dict__ itself is never modified.
    result = dict(fields)
    result["_pytype_"] = class_id
    return result
def deserialize(serialized_obj):
    """This is the callback that will be used by numbuf.

    If numbuf encounters a dictionary that contains the key "_pytype_" during
    deserialization, it will ask this callback to deserialize the object.

    Args:
        serialized_obj (object): A dictionary that contains the key "_pytype_".
            In the plain __dict__ case the "_pytype_" key is removed from it.

    Returns:
        A Python object.
    """
    class_id = serialized_obj["_pytype_"]
    cls = whitelisted_classes[class_id]

    if class_id in classes_to_pickle:
        return pickling.loads(serialized_obj["data"])
    if class_id in custom_deserializers:
        return custom_deserializers[class_id](serialized_obj["data"])
    if "_ray_getnewargs_" in serialized_obj:
        # The namedtuple case: rebuild from the saved constructor arguments.
        return cls.__new__(cls, *serialized_obj["_ray_getnewargs_"])

    # In this case, serialized_obj should just be the __dict__ field.
    obj = cls.__new__(cls)
    serialized_obj.pop("_pytype_")
    obj.__dict__.update(serialized_obj)
    return obj
# Register the callbacks with numbuf. This runs once, when the module is
# imported.
numbuf.register_callbacks(serialize, deserialize)
+529
View File
@@ -0,0 +1,529 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import psutil
import os
import random
import redis
import signal
import socket
import string
import subprocess
import sys
import time
from collections import namedtuple
# Ray modules
import photon
import plasma
import global_scheduler
# all_processes is a list of the scheduler, object store, and worker processes
# that have been started by this services module if Ray is being used in local
# mode. The start_* functions append to it when cleanup=True, and cleanup()
# shuts the processes down in reverse order.
all_processes = []

# True if processes are run in the valgrind profiler. When any of these is
# set, cleanup() sends SIGINT first so that profiler data can be written.
RUN_PHOTON_PROFILER = False
RUN_PLASMA_MANAGER_PROFILER = False
RUN_PLASMA_STORE_PROFILER = False

# ObjectStoreAddress tuples contain all information necessary to connect to an
# object store. The fields are:
# - name: The socket name for the object store
# - manager_name: The socket name for the object store manager
# - manager_port: The Internet port that the object store manager listens on
ObjectStoreAddress = namedtuple("ObjectStoreAddress", ["name",
                                                       "manager_name",
                                                       "manager_port"])
def address(host, port):
    """Return the string "<host>:<port>" for a host string and a port."""
    port_text = str(port)
    return host + ":" + port_text
def get_port(address):
    """Extract the port from an address of the form "<host>:<port>".

    Args:
        address (str): The address to parse.

    Returns:
        The port as an int.

    Raises:
        Exception: If the address has no ":<port>" part or the port is not an
            integer.
    """
    try:
        port = int(address.split(":")[1])
    except Exception:
        # Catch ordinary parse failures only. A bare except would also trap
        # KeyboardInterrupt and SystemExit.
        raise Exception("Unable to parse port from address {}".format(address))
    return port
def new_port():
    """Return a random port number between 10000 and 65535, inclusive."""
    lowest, highest = 10000, 65535
    return random.randint(lowest, highest)
def random_name():
    """Return a random integer between 0 and 99999999, inclusive, as a string."""
    number = random.randint(0, 99999999)
    return str(number)
def cleanup():
    """When running in local mode, shutdown the Ray processes.

    This method is used to shutdown processes that were started with
    services.start_ray_local(). It kills all scheduler, object store, and
    worker processes that were started by this services module. Driver
    processes are started and disconnected by worker.py.

    A message is printed saying whether every process was shut down, and
    all_processes is reset to an empty list.
    """
    global all_processes
    successfully_shut_down = True
    profiling = RUN_PHOTON_PROFILER or RUN_PLASMA_MANAGER_PROFILER or RUN_PLASMA_STORE_PROFILER
    # Terminate the processes in reverse order.
    for p in all_processes[::-1]:
        if p.poll() is not None:  # process has already terminated
            continue
        if profiling:
            os.kill(p.pid, signal.SIGINT)  # Give process signal to write profiler data.
            time.sleep(0.1)  # Wait for profiling data to be written.
        p.kill()
        time.sleep(0.05)  # is this necessary?
        if p.poll() is not None:
            continue
        p.terminate()
        time.sleep(0.05)  # is this necessary?
        # poll must be *called* here. Comparing the bound method itself to
        # None is always true, which would hide every failed shutdown.
        if p.poll() is not None:
            continue
        successfully_shut_down = False
    if successfully_shut_down:
        if len(all_processes) > 0:
            print("Successfully shut down Ray.")
    else:
        print("Ray did not shut down properly.")
    all_processes = []
def all_processes_alive():
    """Return True if none of the processes in all_processes has exited."""
    # Every process is polled before the results are combined.
    exit_codes = [p.poll() for p in all_processes]
    return all(code is None for code in exit_codes)
def get_node_ip_address(address="8.8.8.8:53"):
    """Determine the IP address of the local node.

    Args:
        address (str): The IP address and port of any known live service on
            the network you care about.

    Returns:
        The IP address of the current node.
    """
    host, port = address.split(":")
    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    try:
        # The datagram socket is connected only so that getsockname() reports
        # the local address chosen for reaching the given host.
        s.connect((host, int(port)))
        return s.getsockname()[0]
    finally:
        # Always release the socket; previously it was left open.
        s.close()
def wait_for_redis_to_start(redis_host, redis_port, num_retries=5):
    """Wait for a Redis server to be available.

    This is accomplished by creating a Redis client and sending a random
    command to the server until the command gets through.

    Args:
        redis_host (str): The IP address of the redis server.
        redis_port (int): The port of the redis server.
        num_retries (int): The number of times to try connecting with redis.
            The client will sleep for one second between attempts.

    Raises:
        Exception: An exception is raised if we could not connect with Redis.
    """
    client = redis.StrictRedis(host=redis_host, port=redis_port)
    for _ in range(num_retries):
        try:
            # Run some random command and see if it worked.
            client.client_list()
        except redis.ConnectionError:
            # Wait a little bit.
            time.sleep(1)
            print("Failed to connect to the redis server, retrying.")
            continue
        # The command got through, so the server is up.
        return
    raise Exception("Unable to connect to Redis. If the Redis instance is on a different machine, check that your firewall is configured properly.")
def start_redis(node_ip_address, num_retries=20, cleanup=True, redirect_output=False):
    """Start a Redis server.

    Args:
        node_ip_address (str): The IP address of this node. It is only used to
            build the returned address.
        num_retries (int): The number of times to attempt to start Redis.
        cleanup (bool): True if using Ray in local mode. If cleanup is true,
            then this process will be killed by services.cleanup() when the
            Python process that imported services exits.
        redirect_output (bool): True if stdout and stderr should be redirected
            to /dev/null.

    Returns:
        The address of the Redis server, in the form
        "<node_ip_address>:<port>".

    Raises:
        Exception: An exception is raised if Redis could not be started.
    """
    module_dir = os.path.dirname(os.path.abspath(__file__))
    redis_filepath = os.path.join(module_dir, "../core/src/common/thirdparty/redis/src/redis-server")
    redis_module = os.path.join(module_dir, "../core/src/common/redis_module/libray_redis_module.so")
    assert os.path.isfile(redis_filepath)
    assert os.path.isfile(redis_module)

    started = False
    for attempt in range(num_retries):
        if attempt > 0:
            print("Redis failed to start, retrying now.")
        port = new_port()
        with open(os.devnull, "w") as devnull:
            sink = devnull if redirect_output else None
            p = subprocess.Popen([redis_filepath, "--port", str(port), "--loglevel", "warning", "--loadmodule", redis_module], stdout=sink, stderr=sink)
        time.sleep(0.1)
        # Check if Redis successfully started (or at least if the executable
        # did not exit within 0.1 seconds).
        if p.poll() is None:
            if cleanup:
                all_processes.append(p)
            started = True
            break
    if not started:
        raise Exception("Couldn't start Redis.")

    # Create a Redis client just for configuring Redis.
    redis_client = redis.StrictRedis(host="127.0.0.1", port=port)
    # Wait for the Redis server to start.
    wait_for_redis_to_start("127.0.0.1", port)
    # Configure Redis to generate keyspace notifications. TODO(rkn): Change
    # this to only generate notifications for the export keys.
    redis_client.config_set("notify-keyspace-events", "Kl")
    # Configure Redis to not run in protected mode so that processes on other
    # hosts can connect to it. TODO(rkn): Do this in a more secure way.
    redis_client.config_set("protected-mode", "no")
    return address(node_ip_address, port)
def start_global_scheduler(redis_address, cleanup=True, redirect_output=False):
    """Start a global scheduler process.

    Args:
        redis_address (str): The address of the Redis instance.
        cleanup (bool): True if using Ray in local mode. If cleanup is true,
            then this process will be killed by services.cleanup() when the
            Python process that imported services exits.
        redirect_output (bool): True if stdout and stderr should be redirected
            to /dev/null.
    """
    scheduler_process = global_scheduler.start_global_scheduler(redis_address, redirect_output=redirect_output)
    if not cleanup:
        return
    all_processes.append(scheduler_process)
def start_local_scheduler(redis_address, node_ip_address, plasma_store_name, plasma_manager_name, plasma_address=None, cleanup=True, redirect_output=False):
    """Start a local scheduler process.

    Args:
        redis_address (str): The address of the Redis instance.
        node_ip_address (str): The IP address of the node that this local
            scheduler is running on.
        plasma_store_name (str): The name of the plasma store socket to connect
            to.
        plasma_manager_name (str): The name of the plasma manager socket to
            connect to.
        plasma_address: Forwarded unchanged to photon.start_local_scheduler.
        cleanup (bool): True if using Ray in local mode. If cleanup is true,
            then this process will be killed by services.cleanup() when the
            Python process that imported services exits.
        redirect_output (bool): True if stdout and stderr should be redirected
            to /dev/null.

    Return:
        The name of the local scheduler socket.
    """
    started = photon.start_local_scheduler(
        plasma_store_name,
        plasma_manager_name,
        node_ip_address=node_ip_address,
        redis_address=redis_address,
        plasma_address=plasma_address,
        use_profiler=RUN_PHOTON_PROFILER,
        redirect_output=redirect_output)
    local_scheduler_name, scheduler_process = started
    if cleanup:
        all_processes.append(scheduler_process)
    return local_scheduler_name
def start_objstore(node_ip_address, redis_address, cleanup=True, redirect_output=False, objstore_memory=None):
    """This method starts an object store process.

    Args:
        node_ip_address (str): The IP address of the node running the object
            store.
        redis_address (str): The address of the Redis instance to connect to.
        cleanup (bool): True if using Ray in local mode. If cleanup is true,
            then this process will be killed by services.cleanup() when the
            Python process that imported services exits.
        redirect_output (bool): True if stdout and stderr should be redirected
            to /dev/null.
        objstore_memory (int): The amount of memory to give the Plasma store.
            If None, it is derived from the system memory and, on Linux, the
            space available in /dev/shm.

    Return:
        A tuple of the Plasma store socket name, the Plasma manager socket
        name, and the plasma manager port.
    """
    if objstore_memory is None:
        # Compute a fraction of the system memory for the Plasma store to use.
        system_memory = psutil.virtual_memory().total
        if sys.platform in ("linux", "linux2"):
            # On linux we use /dev/shm, its size is half the size of the
            # physical memory. To not overflow it, we set the plasma memory
            # limit to 0.4 times the size of the physical memory.
            objstore_memory = int(system_memory * 0.4)
            # Compare the requested memory size to the memory available in
            # /dev/shm.
            shm_fd = os.open("/dev/shm", os.O_RDONLY)
            try:
                stats = os.fstatvfs(shm_fd)
                # f_bsize is the block size and f_bavail is the number of
                # available blocks.
                shm_avail = stats.f_bsize * stats.f_bavail
                if objstore_memory > shm_avail:
                    print("Warning: Reducing object store memory because /dev/shm has only {} bytes available. You may be able to free up space by deleting files in /dev/shm. If you are inside a Docker container, you may need to pass an argument with the flag '--shm-size' to 'docker run'.".format(shm_avail))
                    objstore_memory = int(shm_avail * 0.8)
            finally:
                os.close(shm_fd)
        else:
            objstore_memory = int(system_memory * 0.8)

    # Start the Plasma store.
    plasma_store_name, store_process = plasma.start_plasma_store(
        plasma_store_memory=objstore_memory,
        use_profiler=RUN_PLASMA_STORE_PROFILER,
        redirect_output=redirect_output)
    # Start the plasma manager.
    plasma_manager_name, manager_process, plasma_manager_port = plasma.start_plasma_manager(
        plasma_store_name,
        redis_address,
        node_ip_address=node_ip_address,
        run_profiler=RUN_PLASMA_MANAGER_PROFILER,
        redirect_output=redirect_output)
    if cleanup:
        all_processes.extend([store_process, manager_process])

    return ObjectStoreAddress(plasma_store_name, plasma_manager_name, plasma_manager_port)
def start_worker(node_ip_address, object_store_name, object_store_manager_name, local_scheduler_name, redis_address, worker_path, cleanup=True, redirect_output=False):
    """This method starts a worker process.

    Args:
        node_ip_address (str): The IP address of the node that this worker is
            running on.
        object_store_name (str): The name of the object store.
        object_store_manager_name (str): The name of the object store manager.
        local_scheduler_name (str): The name of the local scheduler.
        redis_address: The address that the Redis server is listening on. It
            is converted with str() before being passed to the worker.
        worker_path (str): The path of the source code which the worker process
            will run.
        cleanup (bool): True if using Ray in local mode. If cleanup is true,
            then this process will be killed by services.cleanup() when the
            Python process that imported services exits. This is True by
            default.
        redirect_output (bool): True if stdout and stderr should be redirected
            to /dev/null.
    """
    command = ["python", worker_path]
    command.append("--node-ip-address=" + node_ip_address)
    command.append("--object-store-name=" + object_store_name)
    command.append("--object-store-manager-name=" + object_store_manager_name)
    command.append("--local-scheduler-name=" + local_scheduler_name)
    command.append("--redis-address=" + str(redis_address))
    with open(os.devnull, "w") as devnull:
        sink = devnull if redirect_output else None
        worker_process = subprocess.Popen(command, stdout=sink, stderr=sink)
    if cleanup:
        all_processes.append(worker_process)
def start_webui(redis_port, cleanup=True, redirect_output=False):
    """Launch the web interface as a Node.js subprocess.

    The script "../webui/index.js" (relative to this file) is run with
    "nodejs" on Linux and "node" elsewhere, and is given the Redis port as its
    only argument. Unless output is redirected, stdout is written to
    /tmp/webui_out.txt and stderr is inherited from this process.

    Args:
        redis_port (int): The redis server's port
        cleanup (bool): True if using Ray in local mode. If cleanup is true,
            the process is appended to all_processes so that it will be killed
            by services.cleanup() when the Python process that imported
            services exits. This is True by default.
        redirect_output (bool): True if stdout and stderr should be redirected
            to /dev/null.
    """
    on_linux = sys.platform in ("linux", "linux2")
    node_binary = "nodejs" if on_linux else "node"
    this_directory = os.path.abspath(os.path.dirname(__file__))
    webui_script = os.path.join(this_directory, "../webui/index.js")
    argv = [node_binary, webui_script, str(redis_port)]
    # The output file is opened (and truncated) even when output is redirected.
    with open("/tmp/webui_out.txt", "wb") as out, \
            open(os.devnull, "w") as FNULL:
        if redirect_output:
            stdout, stderr = FNULL, FNULL
        else:
            stdout, stderr = out, None
        process = subprocess.Popen(argv, stdout=stdout, stderr=stderr)
    if cleanup:
        all_processes.append(process)
def start_ray_processes(address_info=None,
                        node_ip_address="127.0.0.1",
                        num_workers=0,
                        num_local_schedulers=1,
                        worker_path=None,
                        cleanup=True,
                        redirect_output=False,
                        include_global_scheduler=False):
    """Start whichever Ray processes are still missing on this node.

    Redis, object stores and local schedulers already listed in address_info
    are reused; new ones are started until there are num_local_schedulers
    object stores and local schedulers. Workers are then started and assigned
    to the local schedulers round-robin.

    Args:
        address_info (dict): A dictionary with address information for
            processes that have already been started. If provided,
            address_info will be modified to include processes that are newly
            started.
        node_ip_address (str): The IP address of this node.
        num_workers (int): The number of workers to start.
        num_local_schedulers (int): The total number of local schedulers
            required. This is also the total number of object stores required.
            Existing instances registered in address_info count towards this
            total; an assertion fails if the final counts differ from it.
        worker_path (str): The path of the source code that will be run by the
            worker. Defaults to "workers/default_worker.py" next to this file.
        cleanup (bool): If cleanup is true, then the processes started here
            will be killed by services.cleanup() when the Python process that
            called this method exits.
        redirect_output (bool): True if stdout and stderr should be redirected
            to /dev/null.
        include_global_scheduler (bool): If include_global_scheduler is True,
            then start a global scheduler process.

    Returns:
        The address_info dictionary, updated with the addresses of the
        processes that were started.
    """
    address_info = {} if address_info is None else address_info
    address_info["node_ip_address"] = node_ip_address

    if worker_path is None:
        services_dir = os.path.dirname(os.path.abspath(__file__))
        worker_path = os.path.join(services_dir, "workers/default_worker.py")

    # Start Redis only if the caller did not supply an existing instance.
    # TODO(rkn): We are suppressing the output of Redis because on Linux it
    # prints a bunch of warning messages when it starts up. Instead of
    # suppressing the output, we should address the warnings.
    redis_address = address_info.get("redis_address")
    if redis_address is None:
        redis_address = start_redis(node_ip_address, cleanup=cleanup,
                                    redirect_output=redirect_output)
        address_info["redis_address"] = redis_address
        time.sleep(0.1)
    # NOTE(review): the returned port is not used below. The call is kept
    # because get_port may reject a malformed address -- confirm before
    # removing it.
    get_port(redis_address)

    # Start the global scheduler, if necessary.
    if include_global_scheduler:
        start_global_scheduler(redis_address, cleanup=cleanup,
                               redirect_output=redirect_output)

    # Reuse the lists of existing services, creating empty ones if absent.
    object_store_addresses = address_info.setdefault(
        "object_store_addresses", [])
    local_scheduler_socket_names = address_info.setdefault(
        "local_scheduler_socket_names", [])

    # Start any object stores (Plasma) that do not yet exist.
    missing_stores = num_local_schedulers - len(object_store_addresses)
    for _ in range(missing_stores):
        new_store = start_objstore(node_ip_address, redis_address,
                                   cleanup=cleanup,
                                   redirect_output=redirect_output)
        object_store_addresses.append(new_store)
        time.sleep(0.1)

    # Start any local schedulers that do not yet exist, connecting each one
    # to the object store at the same index.
    first_missing = len(local_scheduler_socket_names)
    for index in range(first_missing, num_local_schedulers):
        store = object_store_addresses[index]
        plasma_address = "{}:{}".format(node_ip_address, store.manager_port)
        scheduler_name = start_local_scheduler(
            redis_address,
            node_ip_address,
            store.name,
            store.manager_name,
            plasma_address=plasma_address,
            cleanup=cleanup,
            redirect_output=redirect_output)
        local_scheduler_socket_names.append(scheduler_name)
        time.sleep(0.1)

    # Make sure that we have exactly num_local_schedulers instances of object
    # stores and local schedulers.
    assert len(object_store_addresses) == num_local_schedulers
    assert len(local_scheduler_socket_names) == num_local_schedulers

    # Start the workers, spreading them over the local schedulers.
    for worker_index in range(num_workers):
        slot = worker_index % num_local_schedulers
        store = object_store_addresses[slot]
        start_worker(node_ip_address,
                     store.name,
                     store.manager_name,
                     local_scheduler_socket_names[slot],
                     redis_address,
                     worker_path,
                     cleanup=cleanup,
                     redirect_output=redirect_output)

    # Return the addresses of the relevant processes.
    return address_info
def start_ray_node(node_ip_address,
                   redis_address,
                   num_workers=0,
                   num_local_schedulers=1,
                   worker_path=None,
                   cleanup=True,
                   redirect_output=False):
    """Start the Ray processes for a single non-master node.

    This assumes that the Ray processes on some master node have already been
    started. The given redis_address is handed to start_ray_processes as
    existing address information, and no global scheduler is requested.

    Args:
        node_ip_address (str): The IP address of this node.
        redis_address (str): The address of the Redis server.
        num_workers (int): The number of workers to start.
        num_local_schedulers (int): The number of local schedulers to start.
            This is also the number of plasma stores and plasma managers to
            start.
        worker_path (str): The path of the source code that will be run by the
            worker.
        cleanup (bool): If cleanup is true, then the processes started here
            will be killed by services.cleanup() when the Python process that
            called this method exits.
        redirect_output (bool): True if stdout and stderr should be redirected
            to /dev/null.

    Returns:
        A dictionary of the address information for the processes that were
        started.
    """
    return start_ray_processes(
        address_info={"redis_address": redis_address},
        node_ip_address=node_ip_address,
        num_workers=num_workers,
        num_local_schedulers=num_local_schedulers,
        worker_path=worker_path,
        cleanup=cleanup,
        redirect_output=redirect_output)
def start_ray_local(address_info=None,
                    node_ip_address="127.0.0.1",
                    num_workers=0,
                    num_local_schedulers=1,
                    worker_path=None,
                    cleanup=True,
                    redirect_output=False):
    """Start Ray in local mode.

    This forwards every argument to start_ray_processes and additionally
    requests a global scheduler.

    Args:
        address_info (dict): A dictionary with address information for
            processes that have already been started. If provided,
            address_info will be modified to include processes that are newly
            started.
        node_ip_address (str): The IP address of this node.
        num_workers (int): The number of workers to start.
        num_local_schedulers (int): The total number of local schedulers
            required. This is also the total number of object stores required.
            New instances are started until that many exist, including ones
            already registered with the given address_info.
        worker_path (str): The path of the source code that will be run by the
            worker.
        cleanup (bool): If cleanup is true, then the processes started here
            will be killed by services.cleanup() when the Python process that
            called this method exits.
        redirect_output (bool): True if stdout and stderr should be redirected
            to /dev/null.

    Returns:
        A dictionary of the address information for the processes that were
        started.
    """
    forwarded = dict(address_info=address_info,
                     node_ip_address=node_ip_address,
                     num_workers=num_workers,
                     num_local_schedulers=num_local_schedulers,
                     worker_path=worker_path,
                     cleanup=cleanup,
                     redirect_output=redirect_output)
    return start_ray_processes(include_global_scheduler=True, **forwarded)
View File
+102
View File
@@ -0,0 +1,102 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import ray
import numpy as np
# Test simple functionality
@ray.remote(num_return_vals=2)
def handle_int(a, b):
    """Return both arguments incremented by one, as two return values."""
    a_plus_one = a + 1
    b_plus_one = b + 1
    return a_plus_one, b_plus_one
# Test timing
@ray.remote
def empty_function():
    """Do nothing and return None."""
    return None
@ray.remote
def trivial_function():
    """Return the constant 1."""
    result = 1
    return result
# Test keyword arguments
@ray.remote
def keyword_fct1(a, b="hello"):
    """Return a and b formatted with a single space between them."""
    template = "{} {}"
    return template.format(a, b)
@ray.remote
def keyword_fct2(a="hello", b="world"):
    """Return a and b formatted with a single space between them."""
    template = "{} {}"
    return template.format(a, b)
@ray.remote
def keyword_fct3(a, b, c="hello", d="world"):
    """Return the four arguments formatted with single spaces between them."""
    template = "{} {} {} {}"
    return template.format(a, b, c, d)
# Test variable numbers of arguments
@ray.remote
def varargs_fct1(*a):
    """Return the str() of every positional argument, joined by spaces."""
    pieces = [str(item) for item in a]
    return " ".join(pieces)
@ray.remote
def varargs_fct2(a, *b):
    """Return the str() of the extra positional arguments, joined by spaces.

    The first argument, a, does not appear in the result.
    """
    pieces = [str(item) for item in b]
    return " ".join(pieces)
# Record whether decorating a function that accepts **kwargs raises.
try:
    @ray.remote
    def kwargs_throw_exception(**c):
        """Return an empty tuple; only the signature matters here."""
        return ()
    kwargs_exception_thrown = False
except Exception:
    # Catch Exception rather than using a bare except so that
    # KeyboardInterrupt and SystemExit are not mistaken for the failure
    # being recorded.
    kwargs_exception_thrown = True
# Record whether decorating a function that combines a keyword default with
# *args raises.
try:
    @ray.remote
    def varargs_and_kwargs_throw_exception(a, b="hi", *c):
        """Return a, b and the tuple c formatted with spaces between them."""
        return "{} {} {}".format(a, b, c)
    varargs_and_kwargs_exception_thrown = False
except Exception:
    # Catch Exception rather than using a bare except so that
    # KeyboardInterrupt and SystemExit are not mistaken for the failure
    # being recorded.
    varargs_and_kwargs_exception_thrown = True
# test throwing an exception
@ray.remote
def throw_exception_fct1():
    """Always raise an Exception with a fixed message."""
    message = "Test function 1 intentionally failed."
    raise Exception(message)
@ray.remote
def throw_exception_fct2():
    """Always raise an Exception with a fixed message."""
    message = "Test function 2 intentionally failed."
    raise Exception(message)
@ray.remote(num_return_vals=3)
def throw_exception_fct3(x):
    """Always raise an Exception; x is ignored.

    The function is declared with three return values but never returns.
    """
    message = "Test function 3 intentionally failed."
    raise Exception(message)
# test Python mode
@ray.remote
def python_mode_f():
    """Return a new numpy array holding two zeros."""
    initial_values = [0, 0]
    return np.array(initial_values)
@ray.remote
def python_mode_g(x):
    """Set the first element of x to 1 in place and return x."""
    first = 0
    # The argument itself is mutated; no copy is made here.
    x[first] = 1
    return x
# test no return values
@ray.remote
def no_op():
    """Do nothing and return None."""
    return None
class TestClass(object):
    """Minimal user-defined class with a single attribute, a."""

    def __init__(self):
        # a is always initialised to 5.
        initial_value = 5
        self.a = initial_value
@ray.remote
def test_unknown_type():
    """Return a fresh TestClass instance."""
    instance = TestClass()
    return instance
+1743
View File
File diff suppressed because it is too large Load Diff
View File
+61
View File
@@ -0,0 +1,61 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import numpy as np
import redis
import traceback
import ray
parser = argparse.ArgumentParser(
    description="Parse addresses for the worker to connect to.")
# Every option is a required string; only the flag and its help text differ.
for _flag, _help in (
        ("--node-ip-address", "the ip address of the worker's node"),
        ("--redis-address", "the address to use for Redis"),
        ("--object-store-name", "the object store's name"),
        ("--object-store-manager-name", "the object store manager's name"),
        ("--local-scheduler-name", "the local scheduler's name")):
    parser.add_argument(_flag, required=True, type=str, help=_help)
def random_string():
    """Return 20 random bytes drawn from numpy's random generator."""
    num_bytes = 20
    return np.random.bytes(num_bytes)
if __name__ == "__main__":
    args = parser.parse_args()
    # Connection details handed to ray.worker.connect, taken from the command
    # line.
    connection_info = {
        "node_ip_address": args.node_ip_address,
        "redis_address": args.redis_address,
        "store_socket_name": args.object_store_name,
        "manager_socket_name": args.object_store_manager_name,
        "local_scheduler_socket_name": args.local_scheduler_name,
    }
    ray.worker.connect(connection_info, ray.WORKER_MODE)

    # Appended to the traceback of any exception that escapes main_loop.
    error_explanation = """
This error is unexpected and should not have happened. Somehow a worker crashed
in an unanticipated way causing the main_loop to throw an exception, which is
being caught in "lib/python/ray/workers/default_worker.py".
"""

    while True:
        try:
            # This call to main_loop should never return if things are
            # working. Most exceptions that are thrown (e.g., inside the
            # execution of a task) should be caught and handled inside of the
            # call to main_loop. If an exception is thrown here, then that
            # means that there is some error that we didn't anticipate.
            ray.worker.main_loop()
        except Exception:
            failure_report = {
                "message": traceback.format_exc() + error_explanation,
                "note": "This error is unexpected and should not have happened.",
            }
            error_key = "WorkerError:{}".format(random_string())
            host, port = args.redis_address.split(":")
            # For this command to work, some other client (on the same
            # machine as Redis) must have run "CONFIG SET protected-mode no".
            client = redis.StrictRedis(host=host, port=int(port))
            # Store the report under its own key, then record that key in the
            # "ErrorKeys" list.
            client.hmset(error_key, failure_report)
            client.rpush("ErrorKeys", error_key)
            # TODO(rkn): Note that if the worker was in the middle of
            # executing a task, then any worker or driver that is blocking in
            # a get call and waiting for the output of that task will hang.
            # We need to address this.

        # After putting the error message in Redis, this worker will attempt
        # to reenter the main loop. TODO(rkn): We should probably reset its
        # state and call connect again.
+41
View File
@@ -0,0 +1,41 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import subprocess
from setuptools import setup, find_packages
import setuptools.command.install as _install
class install(_install.install):
    """Install command that runs the build script before installing."""

    def run(self):
        """Run ../build.sh, then perform an egg install.

        The script path is relative, so it is resolved against the current
        working directory. check_call raises if the script exits with a
        non-zero status.
        """
        build_script = "../build.sh"
        subprocess.check_call([build_script])
        # Calling _install.install.run(self) does not fetch required packages
        # and instead performs an old-style install. See command/install.py in
        # setuptools. So, calling do_egg_install() manually here.
        self.do_egg_install()
setup(
    # Package identity.
    name="ray",
    version="0.0.1",
    license="Apache 2.0",
    # Python packages plus the files listed for the "core" package.
    packages=find_packages(),
    package_data={
        "core": [
            "src/common/thirdparty/redis/src/redis-server",
            "src/common/redis_module/libray_redis_module.so",
            "src/plasma/plasma_store",
            "src/plasma/plasma_manager",
            "src/plasma/libplasma.so",
            "src/photon/photon_scheduler",
            "src/photon/libphoton.so",
            "src/numbuf/libarrow.so",
            "src/numbuf/libnumbuf.so",
            "src/global_scheduler/global_scheduler",
        ],
    },
    include_package_data=True,
    zip_safe=False,
    # Use the custom install command defined in this file.
    cmdclass={"install": install},
    # Runtime dependencies.
    install_requires=[
        "numpy",
        "funcsigs",
        "colorama",
        "psutil",
        "redis",
        "cloudpickle >= 0.2.2",
    ],
)
View File