Run flake8 in Travis and make code PEP8 compliant. (#387)

This commit is contained in:
Robert Nishihara
2017-03-21 12:57:54 -07:00
committed by Philipp Moritz
parent 083e7a28ad
commit ba02fc0eb0
54 changed files with 2391 additions and 1313 deletions
+22 -12
View File
@@ -2,19 +2,29 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Ray version string
__version__ = "0.01"
import ctypes
# Windows only
if hasattr(ctypes, "windll"):
# Makes sure that all child processes die when we die
# Also makes sure that fatal crashes result in process termination rather than an error dialog (the latter is annoying since we have a lot of processes)
# This is done by associating all child processes with a "job" object that imposes this behavior.
(lambda kernel32: (lambda job: (lambda n: kernel32.SetInformationJobObject(job, 9, "\0" * 17 + chr(0x8 | 0x4 | 0x20) + "\0" * (n - 18), n))(0x90 if ctypes.sizeof(ctypes.c_void_p) > ctypes.sizeof(ctypes.c_int) else 0x70) and kernel32.AssignProcessToJobObject(job, ctypes.c_void_p(kernel32.GetCurrentProcess())))(ctypes.c_void_p(kernel32.CreateJobObjectW(None, None))) if kernel32 is not None else None)(ctypes.windll.kernel32)
from ray.worker import register_class, error_info, init, connect, disconnect, get, put, wait, remote, log_event, log_span, flush_log
from ray.worker import (register_class, error_info, init, connect, disconnect,
get, put, wait, remote, log_event, log_span,
flush_log)
from ray.actor import actor
from ray.actor import get_gpu_ids
from ray.worker import EnvironmentVariable, env
from ray.worker import SCRIPT_MODE, WORKER_MODE, PYTHON_MODE, SILENT_MODE
# Ray version string
__version__ = "0.01"
__all__ = ["register_class", "error_info", "init", "connect", "disconnect",
"get", "put", "wait", "remote", "log_event", "log_span",
"flush_log", "actor", "get_gpu_ids", "EnvironmentVariable", "env",
"SCRIPT_MODE", "WORKER_MODE", "PYTHON_MODE", "SILENT_MODE",
"__version__"]
import ctypes
# Windows only
if hasattr(ctypes, "windll"):
# Makes sure that all child processes die when we die. Also makes sure that
# fatal crashes result in process termination rather than an error dialog
# (the latter is annoying since we have a lot of processes). This is done by
# associating all child processes with a "job" object that imposes this
# behavior.
(lambda kernel32: (lambda job: (lambda n: kernel32.SetInformationJobObject(job, 9, "\0" * 17 + chr(0x8 | 0x4 | 0x20) + "\0" * (n - 18), n))(0x90 if ctypes.sizeof(ctypes.c_void_p) > ctypes.sizeof(ctypes.c_int) else 0x70) and kernel32.AssignProcessToJobObject(job, ctypes.c_void_p(kernel32.GetCurrentProcess())))(ctypes.c_void_p(kernel32.CreateJobObjectW(None, None))) if kernel32 is not None else None)(ctypes.windll.kernel32) # noqa: E501
+60 -21
View File
@@ -18,6 +18,7 @@ import ray.experimental.state as state
# the worker is currently allowed to use.
gpu_ids = []
def get_gpu_ids():
"""Get the IDs of the GPU that are available to the worker.
@@ -26,12 +27,15 @@ def get_gpu_ids():
"""
return gpu_ids
def random_string():
return np.random.bytes(20)
def random_actor_id():
return ray.local_scheduler.ObjectID(random_string())
def get_actor_method_function_id(attr):
"""Get the function ID corresponding to an actor method.
@@ -47,10 +51,14 @@ def get_actor_method_function_id(attr):
assert len(function_id) == 20
return ray.local_scheduler.ObjectID(function_id)
def fetch_and_register_actor(key, worker):
"""Import an actor."""
driver_id, actor_id_str, actor_name, module, pickled_class, assigned_gpu_ids, actor_method_names = \
worker.redis_client.hmget(key, ["driver_id", "actor_id", "name", "module", "class", "gpu_ids", "actor_method_names"])
(driver_id, actor_id_str, actor_name,
module, pickled_class, assigned_gpu_ids,
actor_method_names) = worker.redis_client.hmget(
key, ["driver_id", "actor_id", "name", "module", "class", "gpu_ids",
"actor_method_names"])
actor_id = ray.local_scheduler.ObjectID(actor_id_str)
actor_name = actor_name.decode("ascii")
module = module.decode("ascii")
@@ -64,12 +72,14 @@ def fetch_and_register_actor(key, worker):
class TemporaryActor(object):
pass
worker.actors[actor_id_str] = TemporaryActor()
def temporary_actor_method(*xs):
raise Exception("The actor with name {} failed to be imported, and so "
"cannot execute this method".format(actor_name))
for actor_method_name in actor_method_names:
function_id = get_actor_method_function_id(actor_method_name).id()
worker.functions[driver_id][function_id] = (actor_method_name, temporary_actor_method)
worker.functions[driver_id][function_id] = (actor_method_name,
temporary_actor_method)
try:
unpickled_class = pickling.loads(pickled_class)
@@ -84,11 +94,15 @@ def fetch_and_register_actor(key, worker):
# TODO(pcm): Why is the below line necessary?
unpickled_class.__module__ = module
worker.actors[actor_id_str] = unpickled_class.__new__(unpickled_class)
for (k, v) in inspect.getmembers(unpickled_class, predicate=(lambda x: inspect.isfunction(x) or inspect.ismethod(x))):
for (k, v) in inspect.getmembers(
unpickled_class, predicate=(lambda x: (inspect.isfunction(x) or
inspect.ismethod(x)))):
function_id = get_actor_method_function_id(k).id()
worker.functions[driver_id][function_id] = (k, v)
# We do not set worker.function_properties[driver_id][function_id] because
# we currently do need the actor worker to submit new tasks for the actor.
# We do not set worker.function_properties[driver_id][function_id]
# because we currently do need the actor worker to submit new tasks for
# the actor.
def select_local_scheduler(local_schedulers, num_gpus, worker):
"""Select a local scheduler to assign this actor to.
@@ -119,15 +133,19 @@ def select_local_scheduler(local_schedulers, num_gpus, worker):
# Loop through all of the local schedulers.
for local_scheduler in local_schedulers:
# See if there are enough available GPUs on this local scheduler.
local_scheduler_total_gpus = int(float(local_scheduler[b"num_gpus"].decode("ascii")))
gpus_in_use = worker.redis_client.hget(local_scheduler[b"ray_client_id"], b"gpus_in_use")
local_scheduler_total_gpus = int(float(
local_scheduler[b"num_gpus"].decode("ascii")))
gpus_in_use = worker.redis_client.hget(local_scheduler[b"ray_client_id"],
b"gpus_in_use")
gpus_in_use = 0 if gpus_in_use is None else int(gpus_in_use)
if gpus_in_use + num_gpus <= local_scheduler_total_gpus:
# Attempt to reserve some GPUs for this actor.
new_gpus_in_use = worker.redis_client.hincrby(local_scheduler[b"ray_client_id"], b"gpus_in_use", num_gpus)
new_gpus_in_use = worker.redis_client.hincrby(
local_scheduler[b"ray_client_id"], b"gpus_in_use", num_gpus)
if new_gpus_in_use > local_scheduler_total_gpus:
# If we failed to reserve the GPUs, undo the increment.
worker.redis_client.hincrby(local_scheduler[b"ray_client_id"], b"gpus_in_use", num_gpus)
worker.redis_client.hincrby(local_scheduler[b"ray_client_id"],
b"gpus_in_use", num_gpus)
else:
# We succeeded at reserving the GPUs, so we are done.
local_scheduler_id = local_scheduler[b"ray_client_id"]
@@ -135,10 +153,13 @@ def select_local_scheduler(local_schedulers, num_gpus, worker):
break
if local_scheduler_id is None:
raise Exception("Could not find a node with enough GPUs to create this "
"actor. The local scheduler information is {}.".format(local_schedulers))
"actor. The local scheduler information is {}."
.format(local_schedulers))
return local_scheduler_id, gpu_ids
def export_actor(actor_id, Class, actor_method_names, num_cpus, num_gpus, worker):
def export_actor(actor_id, Class, actor_method_names, num_cpus, num_gpus,
worker):
"""Export an actor to redis.
Args:
@@ -158,13 +179,16 @@ def export_actor(actor_id, Class, actor_method_names, num_cpus, num_gpus, worker
driver_id = worker.task_driver_id.id()
for actor_method_name in actor_method_names:
function_id = get_actor_method_function_id(actor_method_name).id()
worker.function_properties[driver_id][function_id] = (1, num_cpus, num_gpus)
worker.function_properties[driver_id][function_id] = (1, num_cpus,
num_gpus)
# Select a local scheduler for the actor.
local_schedulers = state.get_local_schedulers(worker)
local_scheduler_id, gpu_ids = select_local_scheduler(local_schedulers, num_gpus, worker)
local_scheduler_id, gpu_ids = select_local_scheduler(local_schedulers,
num_gpus, worker)
worker.redis_client.publish("actor_notifications", actor_id.id() + local_scheduler_id)
worker.redis_client.publish("actor_notifications",
actor_id.id() + local_scheduler_id)
d = {"driver_id": driver_id,
"actor_id": actor_id.id(),
@@ -176,6 +200,7 @@ def export_actor(actor_id, Class, actor_method_names, num_cpus, num_gpus, worker
worker.redis_client.hmset(key, d)
worker.redis_client.rpush("Exports", key)
def actor(*args, **kwargs):
def make_actor_decorator(num_cpus=1, num_gpus=0):
def make_actor(Class):
@@ -189,7 +214,8 @@ def actor(*args, **kwargs):
raise Exception("Actors currently do not support **kwargs.")
function_id = get_actor_method_function_id(attr)
# TODO(pcm): Extend args with keyword args.
object_ids = ray.worker.global_worker.submit_task(function_id, "", args,
object_ids = ray.worker.global_worker.submit_task(function_id, "",
args,
actor_id=actor_id)
if len(object_ids) == 1:
return object_ids[0]
@@ -199,24 +225,34 @@ def actor(*args, **kwargs):
class NewClass(object):
def __init__(self, *args, **kwargs):
self._ray_actor_id = random_actor_id()
self._ray_actor_methods = {k: v for (k, v) in inspect.getmembers(Class, predicate=(lambda x: inspect.isfunction(x) or inspect.ismethod(x)))}
export_actor(self._ray_actor_id, Class, self._ray_actor_methods.keys(), num_cpus, num_gpus, ray.worker.global_worker)
self._ray_actor_methods = {
k: v for (k, v) in inspect.getmembers(
Class, predicate=(lambda x: (inspect.isfunction(x) or
inspect.ismethod(x))))}
export_actor(self._ray_actor_id, Class,
self._ray_actor_methods.keys(), num_cpus, num_gpus,
ray.worker.global_worker)
# Call __init__ as a remote function.
if "__init__" in self._ray_actor_methods.keys():
actor_method_call(self._ray_actor_id, "__init__", *args, **kwargs)
else:
print("WARNING: this object has no __init__ method.")
# Make tab completion work.
def __dir__(self):
return self._ray_actor_methods
def __getattribute__(self, attr):
# The following is needed so we can still access self.actor_methods.
if attr in ["_ray_actor_id", "_ray_actor_methods"]:
return super(NewClass, self).__getattribute__(attr)
if attr in self._ray_actor_methods.keys():
return lambda *args, **kwargs: actor_method_call(self._ray_actor_id, attr, *args, **kwargs)
return lambda *args, **kwargs: actor_method_call(
self._ray_actor_id, attr, *args, **kwargs)
# There is no method with this name, so raise an exception.
raise AttributeError("'{}' Actor object has no attribute '{}'".format(Class, attr))
raise AttributeError("'{}' Actor object has no attribute '{}'"
.format(Class, attr))
def __repr__(self):
return "Actor(" + self._ray_actor_id.hex() + ")"
@@ -230,7 +266,9 @@ def actor(*args, **kwargs):
return make_actor_decorator(num_cpus=1, num_gpus=0)(Class)
# In this case, the actor decorator is something like @ray.actor(num_gpus=1).
if len(args) == 0 and len(kwargs) > 0 and all([key in ["num_cpus", "num_gpus"] for key in kwargs.keys()]):
if len(args) == 0 and len(kwargs) > 0 and all([key
in ["num_cpus", "num_gpus"]
for key in kwargs.keys()]):
num_cpus = kwargs["num_cpus"] if "num_cpus" in kwargs.keys() else 1
num_gpus = kwargs["num_gpus"] if "num_gpus" in kwargs.keys() else 0
return make_actor_decorator(num_cpus=num_cpus, num_gpus=num_gpus)
@@ -240,4 +278,5 @@ def actor(*args, **kwargs):
"some of the arguments 'num_cpus' or 'num_gpus' as in "
"'ray.actor(num_gpus=1)'.")
ray.worker.global_worker.fetch_and_register["Actor"] = fetch_and_register_actor
+146 -78
View File
@@ -2,17 +2,16 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import random
import subprocess
import redis
import sys
import time
import unittest
import redis
import ray.services
# Import flatbuffer bindings.
from ray.core.generated.SubscribeToNotificationsReply import SubscribeToNotificationsReply
from ray.core.generated.SubscribeToNotificationsReply \
import SubscribeToNotificationsReply
from ray.core.generated.TaskReply import TaskReply
from ray.core.generated.ResultTableReply import ResultTableReply
@@ -22,6 +21,7 @@ OBJECT_SUBSCRIBE_PREFIX = "OS:"
TASK_PREFIX = "TT:"
OBJECT_CHANNEL_PREFIX = "OC:"
def integerToAsciiHex(num, numbytes):
retstr = b""
# Support 32 and 64 bit architecture.
@@ -36,6 +36,7 @@ def integerToAsciiHex(num, numbytes):
return retstr
def get_next_message(pubsub_client, timeout_seconds=10):
"""Block until the next message is available on the pubsub channel."""
start_time = time.time()
@@ -47,6 +48,7 @@ def get_next_message(pubsub_client, timeout_seconds=10):
if time.time() - start_time > timeout_seconds:
raise Exception("Timed out while waiting for next message.")
class TestGlobalStateStore(unittest.TestCase):
def setUp(self):
@@ -57,102 +59,144 @@ class TestGlobalStateStore(unittest.TestCase):
ray.services.cleanup()
def testInvalidObjectTableAdd(self):
# Check that Redis returns an error when RAY.OBJECT_TABLE_ADD is called with
# the wrong arguments.
# Check that Redis returns an error when RAY.OBJECT_TABLE_ADD is called
# with the wrong arguments.
with self.assertRaises(redis.ResponseError):
self.redis.execute_command("RAY.OBJECT_TABLE_ADD")
with self.assertRaises(redis.ResponseError):
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "hello")
with self.assertRaises(redis.ResponseError):
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id2", "one", "hash2", "manager_id1")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id2", "one",
"hash2", "manager_id1")
with self.assertRaises(redis.ResponseError):
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id2", 1, "hash2", "manager_id1", "extra argument")
# Check that Redis returns an error when RAY.OBJECT_TABLE_ADD adds an object
# ID that is already present with a different hash.
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, "hash1", "manager_id1")
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", "object_id1")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id2", 1,
"hash2", "manager_id1", "extra argument")
# Check that Redis returns an error when RAY.OBJECT_TABLE_ADD adds an
# object ID that is already present with a different hash.
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
"hash1", "manager_id1")
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
"object_id1")
self.assertEqual(set(response), {b"manager_id1"})
with self.assertRaises(redis.ResponseError):
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, "hash2", "manager_id2")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
"hash2", "manager_id2")
# Check that the second manager was added, even though the hash was
# mismatched.
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", "object_id1")
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
"object_id1")
self.assertEqual(set(response), {b"manager_id1", b"manager_id2"})
# Check that it is fine if we add the same object ID multiple times with the
# most recent hash.
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, "hash2", "manager_id1")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, "hash2", "manager_id1")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, "hash2", "manager_id2")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 2, "hash2", "manager_id2")
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", "object_id1")
# Check that it is fine if we add the same object ID multiple times with
# the most recent hash.
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
"hash2", "manager_id1")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
"hash2", "manager_id1")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
"hash2", "manager_id2")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 2,
"hash2", "manager_id2")
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
"object_id1")
self.assertEqual(set(response), {b"manager_id1", b"manager_id2"})
def testObjectTableAddAndLookup(self):
# Try calling RAY.OBJECT_TABLE_LOOKUP with an object ID that has not been
# added yet.
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", "object_id1")
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
"object_id1")
self.assertEqual(response, None)
# Add some managers and try again.
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, "hash1", "manager_id1")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, "hash1", "manager_id2")
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", "object_id1")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
"hash1", "manager_id1")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
"hash1", "manager_id2")
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
"object_id1")
self.assertEqual(set(response), {b"manager_id1", b"manager_id2"})
# Add a manager that already exists again and try again.
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, "hash1", "manager_id2")
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", "object_id1")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
"hash1", "manager_id2")
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
"object_id1")
self.assertEqual(set(response), {b"manager_id1", b"manager_id2"})
# Check that we properly handle NULL characters. In the past, NULL
# characters were handled improperly causing a "hash mismatch" error if two
# object IDs that agreed up to the NULL character were inserted with
# different hashes.
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "\x00object_id3", 1, "hash1", "manager_id1")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "\x00object_id4", 1, "hash2", "manager_id1")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "\x00object_id3", 1,
"hash1", "manager_id1")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "\x00object_id4", 1,
"hash2", "manager_id1")
# Check that NULL characters in the hash are handled properly.
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", 1, "\x00hash1", "manager_id1")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", 1,
"\x00hash1", "manager_id1")
with self.assertRaises(redis.ResponseError):
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", 1, "\x00hash2", "manager_id1")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", 1,
"\x00hash2", "manager_id1")
def testObjectTableAddAndRemove(self):
# Try removing a manager from an object ID that has not been added yet.
with self.assertRaises(redis.ResponseError):
self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1", "manager_id1")
self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1",
"manager_id1")
# Try calling RAY.OBJECT_TABLE_LOOKUP with an object ID that has not been
# added yet.
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", "object_id1")
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
"object_id1")
self.assertEqual(response, None)
# Add some managers and try again.
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, "hash1", "manager_id1")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, "hash1", "manager_id2")
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", "object_id1")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
"hash1", "manager_id1")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
"hash1", "manager_id2")
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
"object_id1")
self.assertEqual(set(response), {b"manager_id1", b"manager_id2"})
# Remove a manager that doesn't exist, and make sure we still have the same set.
self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1", "manager_id3")
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", "object_id1")
# Remove a manager that doesn't exist, and make sure we still have the same
# set.
self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1",
"manager_id3")
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
"object_id1")
self.assertEqual(set(response), {b"manager_id1", b"manager_id2"})
# Remove a manager that does exist. Make sure it gets removed the first
# time and does nothing the second time.
self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1", "manager_id1")
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", "object_id1")
self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1",
"manager_id1")
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
"object_id1")
self.assertEqual(set(response), {b"manager_id2"})
self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1", "manager_id1")
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", "object_id1")
self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1",
"manager_id1")
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
"object_id1")
self.assertEqual(set(response), {b"manager_id2"})
# Remove the last manager, and make sure we have an empty set.
self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1", "manager_id2")
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", "object_id1")
self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1",
"manager_id2")
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
"object_id1")
self.assertEqual(set(response), set())
# Remove a manager from an empty set, and make sure we now have an empty set.
self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1", "manager_id3")
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", "object_id1")
# Remove a manager from an empty set, and make sure we now have an empty
# set.
self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1",
"manager_id3")
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
"object_id1")
self.assertEqual(set(response), set())
def testObjectTableSubscribeToNotifications(self):
# Define a helper method for checking the contents of object notifications.
def check_object_notification(notification_message, object_id, object_size, manager_ids):
notification_object = SubscribeToNotificationsReply.GetRootAsSubscribeToNotificationsReply(notification_message, 0)
def check_object_notification(notification_message, object_id, object_size,
manager_ids):
notification_object = (SubscribeToNotificationsReply
.GetRootAsSubscribeToNotificationsReply(
notification_message, 0))
self.assertEqual(notification_object.ObjectId(), object_id)
self.assertEqual(notification_object.ObjectSize(), object_size)
self.assertEqual(notification_object.ManagerIdsLength(), len(manager_ids))
self.assertEqual(notification_object.ManagerIdsLength(),
len(manager_ids))
for i in range(len(manager_ids)):
self.assertEqual(notification_object.ManagerIds(i), manager_ids[i])
@@ -160,37 +204,45 @@ class TestGlobalStateStore(unittest.TestCase):
p = self.redis.pubsub()
# Subscribe to an object ID.
p.psubscribe("{}manager_id1".format(OBJECT_CHANNEL_PREFIX))
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", data_size, "hash1", "manager_id2")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", data_size,
"hash1", "manager_id2")
# Receive the acknowledgement message.
self.assertEqual(get_next_message(p)["data"], 1)
# Request a notification and receive the data.
self.redis.execute_command("RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS", "manager_id1", "object_id1")
self.redis.execute_command("RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS",
"manager_id1", "object_id1")
# Verify that the notification is correct.
check_object_notification(get_next_message(p)["data"],
b"object_id1",
data_size,
[b"manager_id2"])
# Request a notification for an object that isn't there. Then add the object
# and receive the data. Only the first call to RAY.OBJECT_TABLE_ADD should
# trigger notifications.
self.redis.execute_command("RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS", "manager_id1", "object_id2", "object_id3")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", data_size, "hash1", "manager_id1")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", data_size, "hash1", "manager_id2")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", data_size, "hash1", "manager_id3")
# Request a notification for an object that isn't there. Then add the
# object and receive the data. Only the first call to RAY.OBJECT_TABLE_ADD
# should trigger notifications.
self.redis.execute_command("RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS",
"manager_id1", "object_id2", "object_id3")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", data_size,
"hash1", "manager_id1")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", data_size,
"hash1", "manager_id2")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", data_size,
"hash1", "manager_id3")
# Verify that the notification is correct.
check_object_notification(get_next_message(p)["data"],
b"object_id3",
data_size,
[b"manager_id1"])
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id2", data_size, "hash1", "manager_id3")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id2", data_size,
"hash1", "manager_id3")
# Verify that the notification is correct.
check_object_notification(get_next_message(p)["data"],
b"object_id2",
data_size,
[b"manager_id3"])
# Request notifications for object_id3 again.
self.redis.execute_command("RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS", "manager_id1", "object_id3")
self.redis.execute_command("RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS",
"manager_id1", "object_id3")
# Verify that the notification is correct.
check_object_notification(get_next_message(p)["data"],
b"object_id3",
@@ -199,29 +251,38 @@ class TestGlobalStateStore(unittest.TestCase):
def testResultTableAddAndLookup(self):
def check_result_table_entry(message, task_id, is_put):
result_table_reply = ResultTableReply.GetRootAsResultTableReply(message, 0)
result_table_reply = ResultTableReply.GetRootAsResultTableReply(message,
0)
self.assertEqual(result_table_reply.TaskId(), task_id)
self.assertEqual(result_table_reply.IsPut(), is_put)
# Try looking up something in the result table before anything is added.
response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP", "object_id1")
response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP",
"object_id1")
self.assertIsNone(response)
# Adding the object to the object table should have no effect.
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, "hash1", "manager_id1")
response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP", "object_id1")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
"hash1", "manager_id1")
response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP",
"object_id1")
self.assertIsNone(response)
# Add the result to the result table. The lookup now returns the task ID.
task_id = b"task_id1"
self.redis.execute_command("RAY.RESULT_TABLE_ADD", "object_id1", task_id, 0)
response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP", "object_id1")
self.redis.execute_command("RAY.RESULT_TABLE_ADD", "object_id1", task_id,
0)
response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP",
"object_id1")
check_result_table_entry(response, task_id, False)
# Doing it again should still work.
response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP", "object_id1")
response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP",
"object_id1")
check_result_table_entry(response, task_id, False)
# Try another result table lookup. This should succeed.
task_id = b"task_id2"
self.redis.execute_command("RAY.RESULT_TABLE_ADD", "object_id2", task_id, 1)
response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP", "object_id2")
self.redis.execute_command("RAY.RESULT_TABLE_ADD", "object_id2", task_id,
1)
response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP",
"object_id2")
check_result_table_entry(response, task_id, True)
def testInvalidTaskTableAdd(self):
@@ -251,7 +312,8 @@ class TestGlobalStateStore(unittest.TestCase):
task_status, local_scheduler_id, task_spec = task_args
task_reply_object = TaskReply.GetRootAsTaskReply(message, 0)
self.assertEqual(task_reply_object.State(), task_status)
self.assertEqual(task_reply_object.LocalSchedulerId(), local_scheduler_id)
self.assertEqual(task_reply_object.LocalSchedulerId(),
local_scheduler_id)
self.assertEqual(task_reply_object.TaskSpec(), task_spec)
self.assertEqual(task_reply_object.Updated(), updated)
@@ -263,7 +325,8 @@ class TestGlobalStateStore(unittest.TestCase):
check_task_reply(response, task_args)
task_args[0] = TASK_STATUS_SCHEDULED
self.redis.execute_command("RAY.TASK_TABLE_UPDATE", "task_id", *task_args[:2])
self.redis.execute_command("RAY.TASK_TABLE_UPDATE", "task_id",
*task_args[:2])
response = self.redis.execute_command("RAY.TASK_TABLE_GET", "task_id")
check_task_reply(response, task_args)
@@ -331,9 +394,12 @@ class TestGlobalStateStore(unittest.TestCase):
# Subscribe to the task table.
p = self.redis.pubsub()
p.psubscribe("{prefix}*:*".format(prefix=TASK_PREFIX))
p.psubscribe("{prefix}*:{state}".format(prefix=TASK_PREFIX, state=scheduling_state))
p.psubscribe("{prefix}{local_scheduler_id}:*".format(prefix=TASK_PREFIX, local_scheduler_id=local_scheduler_id))
task_args = [b"task_id", scheduling_state, local_scheduler_id.encode("ascii"), b"task_spec"]
p.psubscribe("{prefix}*:{state}".format(
prefix=TASK_PREFIX, state=scheduling_state))
p.psubscribe("{prefix}{local_scheduler_id}:*".format(
prefix=TASK_PREFIX, local_scheduler_id=local_scheduler_id))
task_args = [b"task_id", scheduling_state,
local_scheduler_id.encode("ascii"), b"task_spec"]
self.redis.execute_command("RAY.TASK_TABLE_ADD", *task_args)
# Receive the acknowledgement message.
self.assertEqual(get_next_message(p)["data"], 1)
@@ -346,8 +412,10 @@ class TestGlobalStateStore(unittest.TestCase):
notification_object = TaskReply.GetRootAsTaskReply(message, 0)
self.assertEqual(notification_object.TaskId(), b"task_id")
self.assertEqual(notification_object.State(), scheduling_state)
self.assertEqual(notification_object.LocalSchedulerId(), local_scheduler_id.encode("ascii"))
self.assertEqual(notification_object.LocalSchedulerId(),
local_scheduler_id.encode("ascii"))
self.assertEqual(notification_object.TaskSpec(), b"task_spec")
if __name__ == "__main__":
unittest.main(verbosity=2)
+58 -41
View File
@@ -11,25 +11,29 @@ import ray.local_scheduler as local_scheduler
ID_SIZE = 20
def random_object_id():
return local_scheduler.ObjectID(np.random.bytes(ID_SIZE))
def random_function_id():
return local_scheduler.ObjectID(np.random.bytes(ID_SIZE))
def random_driver_id():
return local_scheduler.ObjectID(np.random.bytes(ID_SIZE))
def random_task_id():
return local_scheduler.ObjectID(np.random.bytes(ID_SIZE))
BASE_SIMPLE_OBJECTS = [
0, 1, 100000, 0.0, 0.5, 0.9, 100000.1, (), [], {},
"", 990 * "h", u"", 990 * u"h"
]
0, 1, 100000, 0.0, 0.5, 0.9, 100000.1, (), [], {}, "", 990 * "h", u"",
990 * u"h"]
if sys.version_info < (3, 0):
BASE_SIMPLE_OBJECTS += [long(0), long(1), long(100000), long(1 << 100)]
BASE_SIMPLE_OBJECTS += [long(0), long(1), long(100000), long(1 << 100)] # noqa: E501,F821
LIST_SIMPLE_OBJECTS = [[obj] for obj in BASE_SIMPLE_OBJECTS]
TUPLE_SIMPLE_OBJECTS = [(obj,) for obj in BASE_SIMPLE_OBJECTS]
@@ -45,11 +49,14 @@ SIMPLE_OBJECTS = (BASE_SIMPLE_OBJECTS +
l = []
l.append(l)
class Foo(object):
def __init__(self):
pass
BASE_COMPLEX_OBJECTS = [999 * "h", 999 * u"h", l, Foo(), 10 * [10 * [10 * [1]]]]
BASE_COMPLEX_OBJECTS = [999 * "h", 999 * u"h", l, Foo(),
10 * [10 * [10 * [1]]]]
LIST_COMPLEX_OBJECTS = [[obj] for obj in BASE_COMPLEX_OBJECTS]
TUPLE_COMPLEX_OBJECTS = [(obj,) for obj in BASE_COMPLEX_OBJECTS]
@@ -60,6 +67,7 @@ COMPLEX_OBJECTS = (BASE_COMPLEX_OBJECTS +
TUPLE_COMPLEX_OBJECTS +
DICT_COMPLEX_OBJECTS)
class TestSerialization(unittest.TestCase):
def test_serialize_by_value(self):
@@ -69,27 +77,31 @@ class TestSerialization(unittest.TestCase):
for val in COMPLEX_OBJECTS:
self.assertFalse(local_scheduler.check_simple_value(val))
class TestObjectID(unittest.TestCase):
def test_create_object_id(self):
object_id = random_object_id()
random_object_id()
def test_cannot_pickle_object_ids(self):
object_ids = [random_object_id() for _ in range(256)]
def f():
return object_ids
def g(val=object_ids):
return 1
def h():
x = object_ids[0]
object_ids[0]
return 1
# Make sure that object IDs cannot be pickled (including functions that
# close over object IDs).
self.assertRaises(Exception, lambda : pickling.dumps(object_ids[0]))
self.assertRaises(Exception, lambda : pickling.dumps(object_ids))
self.assertRaises(Exception, lambda : pickling.dumps(f))
self.assertRaises(Exception, lambda : pickling.dumps(g))
self.assertRaises(Exception, lambda : pickling.dumps(h))
self.assertRaises(Exception, lambda: pickle.dumps(object_ids[0]))
self.assertRaises(Exception, lambda: pickle.dumps(object_ids))
self.assertRaises(Exception, lambda: pickle.dumps(f))
self.assertRaises(Exception, lambda: pickle.dumps(g))
self.assertRaises(Exception, lambda: pickle.dumps(h))
def test_equality_comparisons(self):
x1 = local_scheduler.ObjectID(ID_SIZE * b"a")
@@ -101,8 +113,10 @@ class TestObjectID(unittest.TestCase):
self.assertNotEqual(x1, y1)
random_strings = [np.random.bytes(ID_SIZE) for _ in range(256)]
object_ids1 = [local_scheduler.ObjectID(random_strings[i]) for i in range(256)]
object_ids2 = [local_scheduler.ObjectID(random_strings[i]) for i in range(256)]
object_ids1 = [local_scheduler.ObjectID(random_strings[i])
for i in range(256)]
object_ids2 = [local_scheduler.ObjectID(random_strings[i])
for i in range(256)]
self.assertEqual(len(set(object_ids1)), 256)
self.assertEqual(len(set(object_ids1 + object_ids2)), 256)
self.assertEqual(set(object_ids1), set(object_ids2))
@@ -113,6 +127,7 @@ class TestObjectID(unittest.TestCase):
{x: y}
set([x, y])
class TestTask(unittest.TestCase):
def check_task(self, task, function_id, num_return_vals, args):
@@ -127,44 +142,46 @@ class TestTask(unittest.TestCase):
self.assertEqual(retrieved_args[i], args[i])
def test_create_and_serialize_task(self):
# TODO(rkn): The function ID should be a FunctionID object, not an ObjectID.
# TODO(rkn): The function ID should be a FunctionID object, not an
# ObjectID.
driver_id = random_driver_id()
parent_id = random_task_id()
function_id = random_function_id()
object_ids = [random_object_id() for _ in range(256)]
args_list = [
[],
1 * [1],
10 * [1],
100 * [1],
1000 * [1],
1 * ["a"],
10 * ["a"],
100 * ["a"],
1000 * ["a"],
[1, 1.3, 2, 1 << 100, "hi", u"hi", [1, 2]],
object_ids[:1],
object_ids[:2],
object_ids[:3],
object_ids[:4],
object_ids[:5],
object_ids[:10],
object_ids[:100],
object_ids[:256],
[1, object_ids[0]],
[object_ids[0], "a"],
[1, object_ids[0], "a"],
[object_ids[0], 1, object_ids[1], "a"],
object_ids[:3] + [1, "hi", 2.3] + object_ids[:5],
object_ids + 100 * ["a"] + object_ids
]
[],
1 * [1],
10 * [1],
100 * [1],
1000 * [1],
1 * ["a"],
10 * ["a"],
100 * ["a"],
1000 * ["a"],
[1, 1.3, 2, 1 << 100, "hi", u"hi", [1, 2]],
object_ids[:1],
object_ids[:2],
object_ids[:3],
object_ids[:4],
object_ids[:5],
object_ids[:10],
object_ids[:100],
object_ids[:256],
[1, object_ids[0]],
[object_ids[0], "a"],
[1, object_ids[0], "a"],
[object_ids[0], 1, object_ids[1], "a"],
object_ids[:3] + [1, "hi", 2.3] + object_ids[:5],
object_ids + 100 * ["a"] + object_ids]
for args in args_list:
for num_return_vals in [0, 1, 2, 3, 5, 10, 100]:
task = local_scheduler.Task(driver_id, function_id, args, num_return_vals, parent_id, 0)
task = local_scheduler.Task(driver_id, function_id, args,
num_return_vals, parent_id, 0)
self.check_task(task, function_id, num_return_vals, args)
data = local_scheduler.task_to_string(task)
task2 = local_scheduler.task_from_string(data)
self.check_task(task2, function_id, num_return_vals, args)
if __name__ == "__main__":
unittest.main(verbosity=2)
+2
View File
@@ -4,3 +4,5 @@ from __future__ import print_function
from .utils import copy_directory
from .tfutils import TensorFlowVariables
__all__ = ["copy_directory", "TensorFlowVariables"]
@@ -4,4 +4,10 @@ from __future__ import print_function
from . import random
from . import linalg
from .core import *
from .core import (BLOCK_SIZE, DistArray, assemble, zeros, ones, copy, eye,
triu, tril, blockwise_dot, dot, transpose, add, subtract,
numpy_to_dist, subblocks)
__all__ = ["random", "linalg", "BLOCK_SIZE", "DistArray", "assemble", "zeros",
"ones", "copy", "eye", "triu", "tril", "blockwise_dot", "dot",
"transpose", "add", "subtract", "numpy_to_dist", "subblocks"]
@@ -6,30 +6,38 @@ import numpy as np
import ray.experimental.array.remote as ra
import ray
__all__ = ["BLOCK_SIZE", "DistArray", "assemble", "zeros", "ones", "copy",
"eye", "triu", "tril", "blockwise_dot", "dot", "transpose", "add", "subtract", "numpy_to_dist", "subblocks"]
BLOCK_SIZE = 10
class DistArray(object):
def __init__(self, shape, objectids=None):
self.shape = shape
self.ndim = len(shape)
self.num_blocks = [int(np.ceil(1.0 * a / BLOCK_SIZE)) for a in self.shape]
self.objectids = objectids if objectids is not None else np.empty(self.num_blocks, dtype=object)
if objectids is not None:
self.objectids = objectids
else:
self.objectids = np.empty(self.num_blocks, dtype=object)
if self.num_blocks != list(self.objectids.shape):
raise Exception("The fields `num_blocks` and `objectids` are inconsistent, `num_blocks` is {} and `objectids` has shape {}".format(self.num_blocks, list(self.objectids.shape)))
raise Exception("The fields `num_blocks` and `objectids` are "
"inconsistent, `num_blocks` is {} and `objectids` has "
"shape {}".format(self.num_blocks,
list(self.objectids.shape)))
@staticmethod
def compute_block_lower(index, shape):
if len(index) != len(shape):
raise Exception("The fields `index` and `shape` must have the same length, but `index` is {} and `shape` is {}.".format(index, shape))
raise Exception("The fields `index` and `shape` must have the same "
"length, but `index` is {} and `shape` is "
"{}.".format(index, shape))
return [elem * BLOCK_SIZE for elem in index]
@staticmethod
def compute_block_upper(index, shape):
if len(index) != len(shape):
raise Exception("The fields `index` and `shape` must have the same length, but `index` is {} and `shape` is {}.".format(index, shape))
raise Exception("The fields `index` and `shape` must have the same "
"length, but `index` is {} and `shape` is "
"{}.".format(index, shape))
upper = []
for i in range(len(shape)):
upper.append(min((index[i] + 1) * BLOCK_SIZE, shape[i]))
@@ -46,59 +54,73 @@ class DistArray(object):
return [int(np.ceil(1.0 * a / BLOCK_SIZE)) for a in shape]
def assemble(self):
"""Assemble an array on this node from a distributed array of object IDs."""
"""Assemble an array from a distributed array of object IDs."""
first_block = ray.get(self.objectids[(0,) * self.ndim])
dtype = first_block.dtype
result = np.zeros(self.shape, dtype=dtype)
for index in np.ndindex(*self.num_blocks):
lower = DistArray.compute_block_lower(index, self.shape)
upper = DistArray.compute_block_upper(index, self.shape)
result[[slice(l, u) for (l, u) in zip(lower, upper)]] = ray.get(self.objectids[index])
result[[slice(l, u) for (l, u) in zip(lower, upper)]] = ray.get(
self.objectids[index])
return result
def __getitem__(self, sliced):
# TODO(rkn): fix this, this is just a placeholder that should work but is inefficient
# TODO(rkn): Fix this, this is just a placeholder that should work but is
# inefficient.
a = self.assemble()
return a[sliced]
# Register the DistArray class with Ray so that it knows how to serialize it.
ray.register_class(DistArray)
@ray.remote
def assemble(a):
return a.assemble()
# TODO(rkn): what should we call this method
# TODO(rkn): What should we call this method?
@ray.remote
def numpy_to_dist(a):
result = DistArray(a.shape)
for index in np.ndindex(*result.num_blocks):
lower = DistArray.compute_block_lower(index, a.shape)
upper = DistArray.compute_block_upper(index, a.shape)
result.objectids[index] = ray.put(a[[slice(l, u) for (l, u) in zip(lower, upper)]])
result.objectids[index] = ray.put(a[[slice(l, u) for (l, u)
in zip(lower, upper)]])
return result
@ray.remote
def zeros(shape, dtype_name="float"):
result = DistArray(shape)
for index in np.ndindex(*result.num_blocks):
result.objectids[index] = ra.zeros.remote(DistArray.compute_block_shape(index, shape), dtype_name=dtype_name)
result.objectids[index] = ra.zeros.remote(
DistArray.compute_block_shape(index, shape), dtype_name=dtype_name)
return result
@ray.remote
def ones(shape, dtype_name="float"):
result = DistArray(shape)
for index in np.ndindex(*result.num_blocks):
result.objectids[index] = ra.ones.remote(DistArray.compute_block_shape(index, shape), dtype_name=dtype_name)
result.objectids[index] = ra.ones.remote(
DistArray.compute_block_shape(index, shape), dtype_name=dtype_name)
return result
@ray.remote
def copy(a):
result = DistArray(a.shape)
for index in np.ndindex(*result.num_blocks):
result.objectids[index] = a.objectids[index] # We don't need to actually copy the objects because cluster-level objects are assumed to be immutable.
# We don't need to actually copy the objects because remote objects are
# immutable.
result.objectids[index] = a.objectids[index]
return result
@ray.remote
def eye(dim1, dim2=-1, dtype_name="float"):
dim2 = dim1 if dim2 == -1 else dim2
@@ -107,15 +129,19 @@ def eye(dim1, dim2=-1, dtype_name="float"):
for (i, j) in np.ndindex(*result.num_blocks):
block_shape = DistArray.compute_block_shape([i, j], shape)
if i == j:
result.objectids[i, j] = ra.eye.remote(block_shape[0], block_shape[1], dtype_name=dtype_name)
result.objectids[i, j] = ra.eye.remote(block_shape[0], block_shape[1],
dtype_name=dtype_name)
else:
result.objectids[i, j] = ra.zeros.remote(block_shape, dtype_name=dtype_name)
result.objectids[i, j] = ra.zeros.remote(block_shape,
dtype_name=dtype_name)
return result
@ray.remote
def triu(a):
if a.ndim != 2:
raise Exception("Input must have 2 dimensions, but a.ndim is " + str(a.ndim))
raise Exception("Input must have 2 dimensions, but a.ndim is "
"{}.".format(a.ndim))
result = DistArray(a.shape)
for (i, j) in np.ndindex(*result.num_blocks):
if i < j:
@@ -126,10 +152,12 @@ def triu(a):
result.objectids[i, j] = ra.zeros_like.remote(a.objectids[i, j])
return result
@ray.remote
def tril(a):
if a.ndim != 2:
raise Exception("Input must have 2 dimensions, but a.ndim is " + str(a.ndim))
raise Exception("Input must have 2 dimensions, but a.ndim is "
"{}.".format(a.ndim))
result = DistArray(a.shape)
for (i, j) in np.ndindex(*result.num_blocks):
if i > j:
@@ -140,25 +168,31 @@ def tril(a):
result.objectids[i, j] = ra.zeros_like.remote(a.objectids[i, j])
return result
@ray.remote
def blockwise_dot(*matrices):
n = len(matrices)
if n % 2 != 0:
raise Exception("blockwise_dot expects an even number of arguments, but len(matrices) is {}.".format(n))
raise Exception("blockwise_dot expects an even number of arguments, but "
"len(matrices) is {}.".format(n))
shape = (matrices[0].shape[0], matrices[n // 2].shape[1])
result = np.zeros(shape)
for i in range(n // 2):
result += np.dot(matrices[i], matrices[n // 2 + i])
return result
@ray.remote
def dot(a, b):
if a.ndim != 2:
raise Exception("dot expects its arguments to be 2-dimensional, but a.ndim = {}.".format(a.ndim))
raise Exception("dot expects its arguments to be 2-dimensional, but "
"a.ndim = {}.".format(a.ndim))
if b.ndim != 2:
raise Exception("dot expects its arguments to be 2-dimensional, but b.ndim = {}.".format(b.ndim))
raise Exception("dot expects its arguments to be 2-dimensional, but "
"b.ndim = {}.".format(b.ndim))
if a.shape[1] != b.shape[0]:
raise Exception("dot expects a.shape[1] to equal b.shape[0], but a.shape = {} and b.shape = {}.".format(a.shape, b.shape))
raise Exception("dot expects a.shape[1] to equal b.shape[0], but a.shape "
"= {} and b.shape = {}.".format(a.shape, b.shape))
shape = [a.shape[0], b.shape[1]]
result = DistArray(shape)
for (i, j) in np.ndindex(*result.num_blocks):
@@ -166,10 +200,12 @@ def dot(a, b):
result.objectids[i, j] = blockwise_dot.remote(*args)
return result
@ray.remote
def subblocks(a, *ranges):
"""
This function produces a distributed array from a subset of the blocks in the `a`. The result and `a` will have the same number of dimensions.For example,
This function produces a distributed array from a subset of the blocks in the
`a`. The result and `a` will have the same number of dimensions.For example,
subblocks(a, [0, 1], [2, 4])
will produce a DistArray whose objectids are
[[a.objectids[0, 2], a.objectids[0, 4]],
@@ -178,50 +214,71 @@ def subblocks(a, *ranges):
"""
ranges = list(ranges)
if len(ranges) != a.ndim:
raise Exception("sub_blocks expects to receive a number of ranges equal to a.ndim, but it received {} ranges and a.ndim = {}.".format(len(ranges), a.ndim))
raise Exception("sub_blocks expects to receive a number of ranges equal "
"to a.ndim, but it received {} ranges and a.ndim = "
"{}.".format(len(ranges), a.ndim))
for i in range(len(ranges)):
if ranges[i] == []: # We allow the user to pass in an empty list to indicate the full range
# We allow the user to pass in an empty list to indicate the full range.
if ranges[i] == []:
ranges[i] = range(a.num_blocks[i])
if not np.alltrue(ranges[i] == np.sort(ranges[i])):
raise Exception("Ranges passed to sub_blocks must be sorted, but the {}th range is {}.".format(i, ranges[i]))
raise Exception("Ranges passed to sub_blocks must be sorted, but the "
"{}th range is {}.".format(i, ranges[i]))
if ranges[i][0] < 0:
raise Exception("Values in the ranges passed to sub_blocks must be at least 0, but the {}th range is {}.".format(i, ranges[i]))
raise Exception("Values in the ranges passed to sub_blocks must be at "
"least 0, but the {}th range is {}.".format(i,
ranges[i]))
if ranges[i][-1] >= a.num_blocks[i]:
raise Exception("Values in the ranges passed to sub_blocks must be less than the relevant number of blocks, but the {}th range is {}, and a.num_blocks = {}.".format(i, ranges[i], a.num_blocks))
raise Exception("Values in the ranges passed to sub_blocks must be less "
"than the relevant number of blocks, but the {}th range "
"is {}, and a.num_blocks = {}.".format(i, ranges[i],
a.num_blocks))
last_index = [r[-1] for r in ranges]
last_block_shape = DistArray.compute_block_shape(last_index, a.shape)
shape = [(len(ranges[i]) - 1) * BLOCK_SIZE + last_block_shape[i] for i in range(a.ndim)]
shape = [(len(ranges[i]) - 1) * BLOCK_SIZE + last_block_shape[i]
for i in range(a.ndim)]
result = DistArray(shape)
for index in np.ndindex(*result.num_blocks):
result.objectids[index] = a.objectids[tuple([ranges[i][index[i]] for i in range(a.ndim)])]
result.objectids[index] = a.objectids[tuple([ranges[i][index[i]]
for i in range(a.ndim)])]
return result
@ray.remote
def transpose(a):
if a.ndim != 2:
raise Exception("transpose expects its argument to be 2-dimensional, but a.ndim = {}, a.shape = {}.".format(a.ndim, a.shape))
raise Exception("transpose expects its argument to be 2-dimensional, but "
"a.ndim = {}, a.shape = {}.".format(a.ndim, a.shape))
result = DistArray([a.shape[1], a.shape[0]])
for i in range(result.num_blocks[0]):
for j in range(result.num_blocks[1]):
result.objectids[i, j] = ra.transpose.remote(a.objectids[j, i])
return result
# TODO(rkn): support broadcasting?
@ray.remote
def add(x1, x2):
if x1.shape != x2.shape:
raise Exception("add expects arguments `x1` and `x2` to have the same shape, but x1.shape = {}, and x2.shape = {}.".format(x1.shape, x2.shape))
raise Exception("add expects arguments `x1` and `x2` to have the same "
"shape, but x1.shape = {}, and x2.shape = {}."
.format(x1.shape, x2.shape))
result = DistArray(x1.shape)
for index in np.ndindex(*result.num_blocks):
result.objectids[index] = ra.add.remote(x1.objectids[index], x2.objectids[index])
result.objectids[index] = ra.add.remote(x1.objectids[index],
x2.objectids[index])
return result
# TODO(rkn): support broadcasting?
@ray.remote
def subtract(x1, x2):
if x1.shape != x2.shape:
raise Exception("subtract expects arguments `x1` and `x2` to have the same shape, but x1.shape = {}, and x2.shape = {}.".format(x1.shape, x2.shape))
raise Exception("subtract expects arguments `x1` and `x2` to have the "
"same shape, but x1.shape = {}, and x2.shape = {}."
.format(x1.shape, x2.shape))
result = DistArray(x1.shape)
for index in np.ndindex(*result.num_blocks):
result.objectids[index] = ra.subtract.remote(x1.objectids[index], x2.objectids[index])
result.objectids[index] = ra.subtract.remote(x1.objectids[index],
x2.objectids[index])
return result
@@ -6,30 +6,32 @@ import numpy as np
import ray.experimental.array.remote as ra
import ray
from .core import *
from . import core
__all__ = ["tsqr", "modified_lu", "tsqr_hr", "qr"]
@ray.remote(num_return_vals=2)
def tsqr(a):
"""
arguments:
a: a distributed matrix
Suppose that
a.shape == (M, N)
K == min(M, N)
return values:
q: DistArray, if q_full = ray.get(DistArray, q).assemble(), then
q_full.shape == (M, K)
np.allclose(np.dot(q_full.T, q_full), np.eye(K)) == True
r: np.ndarray, if r_val = ray.get(np.ndarray, r), then
r_val.shape == (K, N)
np.allclose(r, np.triu(r)) == True
"""Perform a QR decomposition of a tall-skinny matrix.
Args:
a: A distributed matrix with shape MxN (suppose K = min(M, N)).
Returns:
A tuple of q (a DistArray) and r (a numpy array) satisfying the following.
- If q_full = ray.get(DistArray, q).assemble(), then
q_full.shape == (M, K).
- np.allclose(np.dot(q_full.T, q_full), np.eye(K)) == True.
- If r_val = ray.get(np.ndarray, r), then r_val.shape == (K, N).
- np.allclose(r, np.triu(r)) == True.
"""
if len(a.shape) != 2:
raise Exception("tsqr requires len(a.shape) == 2, but a.shape is {}".format(a.shape))
raise Exception("tsqr requires len(a.shape) == 2, but a.shape is "
"{}".format(a.shape))
if a.num_blocks[1] != 1:
raise Exception("tsqr requires a.num_blocks[1] == 1, but a.num_blocks is {}".format(a.num_blocks))
raise Exception("tsqr requires a.num_blocks[1] == 1, but a.num_blocks is "
"{}".format(a.num_blocks))
num_blocks = a.num_blocks[0]
K = int(np.ceil(np.log2(num_blocks))) + 1
@@ -57,9 +59,9 @@ def tsqr(a):
q_shape = a.shape
else:
q_shape = [a.shape[0], a.shape[0]]
q_num_blocks = DistArray.compute_num_blocks(q_shape)
q_num_blocks = core.DistArray.compute_num_blocks(q_shape)
q_objectids = np.empty(q_num_blocks, dtype=object)
q_result = DistArray(q_shape, q_objectids)
q_result = core.DistArray(q_shape, q_objectids)
# reconstruct output
for i in range(num_blocks):
@@ -68,28 +70,35 @@ def tsqr(a):
for j in range(1, K):
if np.mod(ith_index, 2) == 0:
lower = [0, 0]
upper = [a.shape[1], BLOCK_SIZE]
upper = [a.shape[1], core.BLOCK_SIZE]
else:
lower = [a.shape[1], 0]
upper = [2 * a.shape[1], BLOCK_SIZE]
upper = [2 * a.shape[1], core.BLOCK_SIZE]
ith_index //= 2
q_block_current = ra.dot.remote(q_block_current, ra.subarray.remote(q_tree[ith_index, j], lower, upper))
q_block_current = ra.dot.remote(q_block_current,
ra.subarray.remote(q_tree[ith_index, j],
lower, upper))
q_result.objectids[i] = q_block_current
r = current_rs[0]
return q_result, ray.get(r)
# TODO(rkn): This is unoptimized, we really want a block version of this.
# This is Algorithm 5 from
# http://www.eecs.berkeley.edu/Pubs/TechRpts/2013/EECS-2013-175.pdf.
@ray.remote(num_return_vals=3)
def modified_lu(q):
"""
Algorithm 5 from http://www.eecs.berkeley.edu/Pubs/TechRpts/2013/EECS-2013-175.pdf
takes a matrix q with orthonormal columns, returns l, u, s such that q - s = l * u
arguments:
q: a two dimensional orthonormal q
return values:
l: lower triangular
u: upper triangular
s: a diagonal matrix represented by its diagonal
"""Perform a modified LU decomposition of a matrix.
This takes a matrix q with orthonormal columns, returns l, u, s such that
q - s = l * u.
Args:
q: A two dimensional orthonormal matrix q.
Returns:
A tuple of a lower triangular matrix l, an upper triangular matrix u, and a
a vector representing a diagonal matrix s such that q - s = l * u.
"""
q = q.assemble()
m, b = q.shape[0], q.shape[1]
@@ -100,14 +109,19 @@ def modified_lu(q):
for i in range(b):
S[i] = -1 * np.sign(q_work[i, i])
q_work[i, i] -= S[i]
q_work[(i + 1):m, i] /= q_work[i, i] # scale ith column of L by diagonal element
q_work[(i + 1):m, (i + 1):b] -= np.outer(q_work[(i + 1):m, i], q_work[i, (i + 1):b]) # perform Schur complement update
# Scale ith column of L by diagonal element.
q_work[(i + 1):m, i] /= q_work[i, i]
# Perform Schur complement update.
q_work[(i + 1):m, (i + 1):b] -= np.outer(q_work[(i + 1):m, i],
q_work[i, (i + 1):b])
L = np.tril(q_work)
for i in range(b):
L[i, i] = 1
U = np.triu(q_work)[:b, :]
return ray.get(numpy_to_dist.remote(ray.put(L))), U, S # TODO(rkn): get rid of put
# TODO(rkn): Get rid of the put below.
return ray.get(core.numpy_to_dist.remote(ray.put(L))), U, S
@ray.remote(num_return_vals=2)
def tsqr_hr_helper1(u, s, y_top_block, b):
@@ -116,45 +130,60 @@ def tsqr_hr_helper1(u, s, y_top_block, b):
t = -1 * np.dot(u, np.dot(s_full, np.linalg.inv(y_top).T))
return t, y_top
@ray.remote
def tsqr_hr_helper2(s, r_temp):
s_full = np.diag(s)
return np.dot(s_full, r_temp)
# This is Algorithm 6 from
# http://www.eecs.berkeley.edu/Pubs/TechRpts/2013/EECS-2013-175.pdf.
@ray.remote(num_return_vals=4)
def tsqr_hr(a):
"""Algorithm 6 from http://www.eecs.berkeley.edu/Pubs/TechRpts/2013/EECS-2013-175.pdf"""
q, r_temp = tsqr.remote(a)
y, u, s = modified_lu.remote(q)
y_blocked = ray.get(y)
t, y_top = tsqr_hr_helper1.remote(u, s, y_blocked.objectids[0, 0], a.shape[1])
t, y_top = tsqr_hr_helper1.remote(u, s, y_blocked.objectids[0, 0],
a.shape[1])
r = tsqr_hr_helper2.remote(s, r_temp)
return ray.get(y), ray.get(t), ray.get(y_top), ray.get(r)
@ray.remote
def qr_helper1(a_rc, y_ri, t, W_c):
return a_rc - np.dot(y_ri, np.dot(t.T, W_c))
@ray.remote
def qr_helper2(y_ri, a_rc):
return np.dot(y_ri.T, a_rc)
# This is Algorithm 7 from
# http://www.eecs.berkeley.edu/Pubs/TechRpts/2013/EECS-2013-175.pdf.
@ray.remote(num_return_vals=2)
def qr(a):
"""Algorithm 7 from http://www.eecs.berkeley.edu/Pubs/TechRpts/2013/EECS-2013-175.pdf"""
m, n = a.shape[0], a.shape[1]
k = min(m, n)
# we will store our scratch work in a_work
a_work = DistArray(a.shape, np.copy(a.objectids))
a_work = core.DistArray(a.shape, np.copy(a.objectids))
result_dtype = np.linalg.qr(ray.get(a.objectids[0, 0]))[0].dtype.name
r_res = ray.get(zeros.remote([k, n], result_dtype)) # TODO(rkn): It would be preferable not to get this right after creating it.
y_res = ray.get(zeros.remote([m, k], result_dtype)) # TODO(rkn): It would be preferable not to get this right after creating it.
# TODO(rkn): It would be preferable not to get this right after creating it.
r_res = ray.get(core.zeros.remote([k, n], result_dtype))
# TODO(rkn): It would be preferable not to get this right after creating it.
y_res = ray.get(core.zeros.remote([m, k], result_dtype))
Ts = []
for i in range(min(a.num_blocks[0], a.num_blocks[1])): # this differs from the paper, which says "for i in range(a.num_blocks[1])", but that doesn't seem to make any sense when a.num_blocks[1] > a.num_blocks[0]
sub_dist_array = subblocks.remote(a_work, list(range(i, a_work.num_blocks[0])), [i])
# The for loop differs from the paper, which says
# "for i in range(a.num_blocks[1])", but that doesn't seem to make any sense
# when a.num_blocks[1] > a.num_blocks[0].
for i in range(min(a.num_blocks[0], a.num_blocks[1])):
sub_dist_array = core.subblocks.remote(
a_work, list(range(i, a_work.num_blocks[0])), [i])
y, t, _, R = tsqr_hr.remote(sub_dist_array)
y_val = ray.get(y)
@@ -167,7 +196,7 @@ def qr(a):
r_res.objectids[i, i] = ra.dot.remote(eye_temp, R)
else:
r_res.objectids[i, i] = R
Ts.append(numpy_to_dist.remote(t))
Ts.append(core.numpy_to_dist.remote(t))
for c in range(i + 1, a.num_blocks[1]):
W_rcs = []
@@ -182,9 +211,14 @@ def qr(a):
r_res.objectids[i, c] = a_work.objectids[i, c]
# construct q_res from Ys and Ts
q = eye.remote(m, k, dtype_name=result_dtype)
q = core.eye.remote(m, k, dtype_name=result_dtype)
for i in range(len(Ts))[::-1]:
y_col_block = subblocks.remote(y_res, [], [i])
q = subtract.remote(q, dot.remote(y_col_block, dot.remote(Ts[i], dot.remote(transpose.remote(y_col_block), q))))
y_col_block = core.subblocks.remote(y_res, [], [i])
q = core.subtract.remote(
q, core.dot.remote(
y_col_block,
core.dot.remote(Ts[i],
core.dot.remote(core.transpose.remote(y_col_block),
q))))
return ray.get(q), r_res
@@ -6,13 +6,15 @@ import numpy as np
import ray.experimental.array.remote as ra
import ray
from .core import *
from .core import DistArray
@ray.remote
def normal(shape):
num_blocks = DistArray.compute_num_blocks(shape)
objectids = np.empty(num_blocks, dtype=object)
for index in np.ndindex(*num_blocks):
objectids[index] = ra.random.normal.remote(DistArray.compute_block_shape(index, shape))
objectids[index] = ra.random.normal.remote(
DistArray.compute_block_shape(index, shape))
result = DistArray(shape, objectids)
return result
@@ -4,4 +4,10 @@ from __future__ import print_function
from . import random
from . import linalg
from .core import *
from .core import (zeros, zeros_like, ones, eye, dot, vstack, hstack, subarray,
copy, tril, triu, diag, transpose, add, subtract, sum,
shape, sum_list)
__all__ = ["random", "linalg", "zeros", "zeros_like", "ones", "eye", "dot",
"vstack", "hstack", "subarray", "copy", "tril", "triu", "diag",
"transpose", "add", "subtract", "sum", "shape", "sum_list"]
+20 -5
View File
@@ -5,82 +5,97 @@ from __future__ import print_function
import numpy as np
import ray
__all__ = ["zeros", "zeros_like", "ones", "eye", "dot", "vstack", "hstack", "subarray", "copy", "tril", "triu", "diag", "transpose", "add", "subtract", "sum", "shape", "sum_list"]
@ray.remote
def zeros(shape, dtype_name="float", order="C"):
return np.zeros(shape, dtype=np.dtype(dtype_name), order=order)
@ray.remote
def zeros_like(a, dtype_name="None", order="K", subok=True):
dtype_val = None if dtype_name == "None" else np.dtype(dtype_name)
return np.zeros_like(a, dtype=dtype_val, order=order, subok=subok)
@ray.remote
def ones(shape, dtype_name="float", order="C"):
return np.ones(shape, dtype=np.dtype(dtype_name), order=order)
@ray.remote
def eye(N, M=-1, k=0, dtype_name="float"):
M = N if M == -1 else M
return np.eye(N, M=M, k=k, dtype=np.dtype(dtype_name))
@ray.remote
def dot(a, b):
return np.dot(a, b)
@ray.remote
def vstack(*xs):
return np.vstack(xs)
@ray.remote
def hstack(*xs):
return np.hstack(xs)
# TODO(rkn): instead of this, consider implementing slicing
# TODO(rkn): Instead of this, consider implementing slicing.
# TODO(rkn): Be consistent about using "index" versus "indices".
@ray.remote
def subarray(a, lower_indices, upper_indices): # TODO(rkn): be consistent about using "index" versus "indices"
def subarray(a, lower_indices, upper_indices):
return a[[slice(l, u) for (l, u) in zip(lower_indices, upper_indices)]]
@ray.remote
def copy(a, order="K"):
return np.copy(a, order=order)
@ray.remote
def tril(m, k=0):
return np.tril(m, k=k)
@ray.remote
def triu(m, k=0):
return np.triu(m, k=k)
@ray.remote
def diag(v, k=0):
return np.diag(v, k=k)
@ray.remote
def transpose(a, axes=[]):
axes = None if axes == [] else axes
return np.transpose(a, axes=axes)
@ray.remote
def add(x1, x2):
return np.add(x1, x2)
@ray.remote
def subtract(x1, x2):
return np.subtract(x1, x2)
@ray.remote
def sum(x, axis=-1):
return np.sum(x, axis=axis if axis != -1 else None)
@ray.remote
def shape(a):
return np.shape(a)
# We use Any to allow different numerical types as well as numpy arrays.
# TODO(rkn):this isn't in the numpy API, so be careful about exposing this.
@ray.remote
def sum_list(*xs):
return np.sum(xs, axis=0)
+21 -1
View File
@@ -8,84 +8,104 @@ import ray
__all__ = ["matrix_power", "solve", "tensorsolve", "tensorinv", "inv",
"cholesky", "eigvals", "eigvalsh", "pinv", "slogdet", "det",
"svd", "eig", "eigh", "lstsq", "norm", "qr", "cond", "matrix_rank",
"LinAlgError", "multi_dot"]
"multi_dot"]
@ray.remote
def matrix_power(M, n):
return np.linalg.matrix_power(M, n)
@ray.remote
def solve(a, b):
return np.linalg.solve(a, b)
@ray.remote(num_return_vals=2)
def tensorsolve(a):
raise NotImplementedError
@ray.remote(num_return_vals=2)
def tensorinv(a):
raise NotImplementedError
@ray.remote
def inv(a):
return np.linalg.inv(a)
@ray.remote
def cholesky(a):
return np.linalg.cholesky(a)
@ray.remote
def eigvals(a):
return np.linalg.eigvals(a)
@ray.remote
def eigvalsh(a):
raise NotImplementedError
@ray.remote
def pinv(a):
return np.linalg.pinv(a)
@ray.remote
def slogdet(a):
raise NotImplementedError
@ray.remote
def det(a):
return np.linalg.det(a)
@ray.remote(num_return_vals=3)
def svd(a):
return np.linalg.svd(a)
@ray.remote(num_return_vals=2)
def eig(a):
return np.linalg.eig(a)
@ray.remote(num_return_vals=2)
def eigh(a):
return np.linalg.eigh(a)
@ray.remote(num_return_vals=4)
def lstsq(a, b):
return np.linalg.lstsq(a)
@ray.remote
def norm(x):
return np.linalg.norm(x)
@ray.remote(num_return_vals=2)
def qr(a):
return np.linalg.qr(a)
@ray.remote
def cond(x):
return np.linalg.cond(x)
@ray.remote
def matrix_rank(M):
return np.linalg.matrix_rank(M)
@ray.remote
def multi_dot(*a):
raise NotImplementedError
@@ -5,6 +5,7 @@ from __future__ import print_function
import numpy as np
import ray
@ray.remote
def normal(shape):
return np.random.normal(size=shape)
+1
View File
@@ -2,6 +2,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
def get_local_schedulers(worker):
local_schedulers = []
for client in worker.redis_client.keys("CL:*"):
+24 -11
View File
@@ -4,6 +4,7 @@ from __future__ import print_function
import numpy as np
from collections import deque, OrderedDict
def unflatten(vector, shapes):
i = 0
arrays = []
@@ -15,6 +16,7 @@ def unflatten(vector, shapes):
assert len(vector) == i, "Passed weight does not have the correct shape."
return arrays
class TensorFlowVariables(object):
"""An object used to extract variables from a loss function.
@@ -38,11 +40,11 @@ class TensorFlowVariables(object):
variable_names = []
explored_inputs = set([loss])
# We do a BFS on the dependency graph of the input function to find
# We do a BFS on the dependency graph of the input function to find
# the variables.
while len(queue) != 0:
tf_obj = queue.popleft()
# The object put into the queue is not necessarily an operation, so we
# want the op attribute to get the operation underlying the object.
# Only operations contain the inputs that we can explore.
@@ -61,14 +63,16 @@ class TensorFlowVariables(object):
if "Variable" in tf_obj.node_def.op:
variable_names.append(tf_obj.node_def.name)
self.variables = OrderedDict()
for v in [v for v in tf.global_variables() if v.op.node_def.name in variable_names]:
for v in [v for v in tf.global_variables()
if v.op.node_def.name in variable_names]:
self.variables[v.op.node_def.name] = v
self.placeholders = dict()
self.assignment_nodes = []
# Create new placeholders to put in custom weights.
for k, var in self.variables.items():
self.placeholders[k] = tf.placeholder(var.value().dtype, var.get_shape().as_list())
self.placeholders[k] = tf.placeholder(var.value().dtype,
var.get_shape().as_list())
self.assignment_nodes.append(var.assign(self.placeholders[k]))
def set_session(self, sess):
@@ -76,24 +80,30 @@ class TensorFlowVariables(object):
self.sess = sess
def get_flat_size(self):
return sum([np.prod(v.get_shape().as_list()) for v in self.variables.values()])
return sum([np.prod(v.get_shape().as_list())
for v in self.variables.values()])
def _check_sess(self):
"""Checks if the session is set, and if not throw an error message."""
assert self.sess is not None, "The session is not set. Set the session either by passing it into the TensorFlowVariables constructor or by calling set_session(sess)."
assert self.sess is not None, ("The session is not set. Set the session "
"either by passing it into the "
"TensorFlowVariables constructor or by "
"calling set_session(sess).")
def get_flat(self):
"""Gets the weights and returns them as a flat array."""
self._check_sess()
return np.concatenate([v.eval(session=self.sess).flatten() for v in self.variables.values()])
return np.concatenate([v.eval(session=self.sess).flatten()
for v in self.variables.values()])
def set_flat(self, new_weights):
"""Sets the weights to new_weights, converting from a flat array."""
self._check_sess()
shapes = [v.get_shape().as_list() for v in self.variables.values()]
arrays = unflatten(new_weights, shapes)
placeholders = [self.placeholders[k] for k, v in self.variables.items()]
self.sess.run(self.assignment_nodes, feed_dict=dict(zip(placeholders,arrays)))
self.sess.run(self.assignment_nodes,
feed_dict=dict(zip(placeholders, arrays)))
def get_weights(self):
"""Returns the weights of the variables of the loss function in a list."""
@@ -103,4 +113,7 @@ class TensorFlowVariables(object):
def set_weights(self, new_weights):
"""Sets the weights to new_weights."""
self._check_sess()
self.sess.run(self.assignment_nodes, feed_dict={self.placeholders[name]: value for (name, value) in new_weights.items() if name in self.placeholders})
self.sess.run(self.assignment_nodes,
feed_dict={self.placeholders[name]: value
for (name, value) in new_weights.items()
if name in self.placeholders})
+13 -7
View File
@@ -9,6 +9,7 @@ import sys
import ray
def tarred_directory_as_bytes(source_dir):
"""Tar a directory and return it as a byte string.
@@ -26,6 +27,7 @@ def tarred_directory_as_bytes(source_dir):
string_file.seek(0)
return string_file.read()
def tarred_bytes_to_directory(tarred_bytes, target_dir):
"""Take a byte string and untar it.
@@ -38,6 +40,7 @@ def tarred_bytes_to_directory(tarred_bytes, target_dir):
with tarfile.open(fileobj=string_file) as tar:
tar.extractall(path=target_dir)
def copy_directory(source_dir, target_dir=None):
"""Copy a local directory to each machine in the Ray cluster.
@@ -45,17 +48,18 @@ def copy_directory(source_dir, target_dir=None):
example, source_dir can be /a/b/c and target_dir can be /d/e/c. In this case,
the directory /d/e will be added to the Python path of each worker.
Note that this method is not completely safe to use. For example, workers that
do not do the copying and only set their paths (only one worker per node does
the copying) may try to execute functions that use the files in the directory
being copied before the directory being copied has finished untarring.
Note that this method is not completely safe to use. For example, workers
that do not do the copying and only set their paths (only one worker per node
does the copying) may try to execute functions that use the files in the
directory being copied before the directory being copied has finished
untarring.
Args:
source_dir (str): The directory to copy.
target_dir (str): The location to copy it to on the other machines. If this
is not provided, the source_dir will be used. If it is provided and is
different from source_dir, the source_dir also be copied to the target_dir
location on this machine.
different from source_dir, the source_dir also be copied to the
target_dir location on this machine.
"""
target_dir = source_dir if target_dir is None else target_dir
source_dir = os.path.abspath(source_dir)
@@ -63,8 +67,10 @@ def copy_directory(source_dir, target_dir=None):
source_basename = os.path.basename(source_dir)
target_basename = os.path.basename(target_dir)
if source_basename != target_basename:
raise Exception("The source_dir and target_dir must have the same base name, {} != {}".format(source_basename, target_basename))
raise Exception("The source_dir and target_dir must have the same base "
"name, {} != {}".format(source_basename, target_basename))
tarred_bytes = tarred_directory_as_bytes(source_dir)
def f(worker_info):
if worker_info["counter"] == 0:
tarred_bytes_to_directory(tarred_bytes, os.path.dirname(target_dir))
+3 -1
View File
@@ -2,4 +2,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from .global_scheduler_services import *
from .global_scheduler_services import start_global_scheduler
__all__ = ["start_global_scheduler"]
@@ -6,6 +6,7 @@ import os
import subprocess
import time
def start_global_scheduler(redis_address, use_valgrind=False,
use_profiler=False, stdout_file=None,
stderr_file=None):
@@ -15,8 +16,8 @@ def start_global_scheduler(redis_address, use_valgrind=False,
redis_address (str): The address of the Redis instance.
use_valgrind (bool): True if the global scheduler should be started inside
of valgrind. If this is True, use_profiler must be False.
use_profiler (bool): True if the global scheduler should be started inside a
profiler. If this is True, use_valgrind must be False.
use_profiler (bool): True if the global scheduler should be started inside
a profiler. If this is True, use_valgrind must be False.
stdout_file: A file handle opened for writing to redirect stdout to. If no
redirection should happen, then this should be None.
stderr_file: A file handle opened for writing to redirect stderr to. If no
@@ -27,7 +28,9 @@ def start_global_scheduler(redis_address, use_valgrind=False,
"""
if use_valgrind and use_profiler:
raise Exception("Cannot use valgrind and profiler at the same time.")
global_scheduler_executable = os.path.join(os.path.abspath(os.path.dirname(__file__)), "../core/src/global_scheduler/global_scheduler")
global_scheduler_executable = os.path.join(
os.path.abspath(os.path.dirname(__file__)),
"../core/src/global_scheduler/global_scheduler")
command = [global_scheduler_executable, "-r", redis_address]
if use_valgrind:
pid = subprocess.Popen(["valgrind",
@@ -35,7 +38,7 @@ def start_global_scheduler(redis_address, use_valgrind=False,
"--leak-check=full",
"--show-leak-kinds=all",
"--error-exitcode=1"] + command,
stdout=stdout_file, stderr=stderr_file)
stdout=stdout_file, stderr=stderr_file)
time.sleep(1.0)
elif use_profiler:
pid = subprocess.Popen(["valgrind", "--tool=callgrind"] + command,
+57 -31
View File
@@ -7,16 +7,14 @@ import os
import random
import redis
import signal
import subprocess
import sys
import threading
import time
import unittest
import ray.global_scheduler as global_scheduler
import ray.local_scheduler as local_scheduler
import ray.plasma as plasma
from ray.plasma.utils import random_object_id, generate_metadata, write_to_data_buffer, create_object_with_id, create_object
from ray.plasma.utils import create_object
from ray import services
@@ -39,21 +37,27 @@ TASK_STATUS_DONE = 16
DB_CLIENT_PREFIX = "CL:"
TASK_PREFIX = "TT:"
def random_driver_id():
return local_scheduler.ObjectID(np.random.bytes(ID_SIZE))
def random_task_id():
return local_scheduler.ObjectID(np.random.bytes(ID_SIZE))
def random_function_id():
return local_scheduler.ObjectID(np.random.bytes(ID_SIZE))
def random_object_id():
return local_scheduler.ObjectID(np.random.bytes(ID_SIZE))
def new_port():
return random.randint(10000, 65535)
class TestGlobalScheduler(unittest.TestCase):
def setUp(self):
@@ -62,9 +66,11 @@ class TestGlobalScheduler(unittest.TestCase):
redis_port, self.redis_process = services.start_redis(cleanup=False)
redis_address = services.address(node_ip_address, redis_port)
# Create a Redis client.
self.redis_client = redis.StrictRedis(host=node_ip_address, port=redis_port)
self.redis_client = redis.StrictRedis(host=node_ip_address,
port=redis_port)
# Start one global scheduler.
self.p1 = global_scheduler.start_global_scheduler(redis_address, use_valgrind=USE_VALGRIND)
self.p1 = global_scheduler.start_global_scheduler(
redis_address, use_valgrind=USE_VALGRIND)
self.plasma_store_pids = []
self.plasma_manager_pids = []
self.local_scheduler_pids = []
@@ -76,11 +82,15 @@ class TestGlobalScheduler(unittest.TestCase):
plasma_store_name, p2 = plasma.start_plasma_store()
self.plasma_store_pids.append(p2)
# Start the Plasma manager.
# Assumption: Plasma manager name and port are randomly generated by the plasma module.
plasma_manager_name, p3, plasma_manager_port = plasma.start_plasma_manager(plasma_store_name, redis_address)
# Assumption: Plasma manager name and port are randomly generated by the
# plasma module.
manager_info = plasma.start_plasma_manager(plasma_store_name,
redis_address)
plasma_manager_name, p3, plasma_manager_port = manager_info
self.plasma_manager_pids.append(p3)
plasma_address = "{}:{}".format(node_ip_address, plasma_manager_port)
plasma_client = plasma.PlasmaClient(plasma_store_name, plasma_manager_name)
plasma_client = plasma.PlasmaClient(plasma_store_name,
plasma_manager_name)
self.plasma_clients.append(plasma_client)
# Start the local scheduler.
local_scheduler_name, p4 = local_scheduler.start_local_scheduler(
@@ -91,7 +101,7 @@ class TestGlobalScheduler(unittest.TestCase):
static_resource_list=[10, 0])
# Connect to the scheduler.
local_scheduler_client = local_scheduler.LocalSchedulerClient(
local_scheduler_name, NIL_ACTOR_ID, False)
local_scheduler_name, NIL_ACTOR_ID, False)
self.local_scheduler_clients.append(local_scheduler_client)
self.local_scheduler_pids.append(p4)
@@ -149,11 +159,13 @@ class TestGlobalScheduler(unittest.TestCase):
return db_client_id
def test_task_default_resources(self):
task1 = local_scheduler.Task(random_driver_id(), random_function_id(), [random_object_id()], 0, random_task_id(), 0)
task1 = local_scheduler.Task(random_driver_id(), random_function_id(),
[random_object_id()], 0, random_task_id(), 0)
self.assertEqual(task1.required_resources(), [1.0, 0.0])
task2 = local_scheduler.Task(random_driver_id(), random_function_id(),
[random_object_id()], 0, random_task_id(), 0,
local_scheduler.ObjectID(NIL_ACTOR_ID), 0, [1.0, 2.0])
local_scheduler.ObjectID(NIL_ACTOR_ID), 0,
[1.0, 2.0])
self.assertEqual(task2.required_resources(), [1.0, 2.0])
def test_redis_only_single_task(self):
@@ -162,31 +174,37 @@ class TestGlobalScheduler(unittest.TestCase):
task state transitions in Redis only. TODO(atumanov): implement.
"""
# Check precondition for this test:
# There should be 2n+1 db clients: the global scheduler + one local scheduler and one plasma per node.
self.assertEqual(len(self.redis_client.keys("{}*".format(DB_CLIENT_PREFIX))),
2 * NUM_CLUSTER_NODES + 1)
# There should be 2n+1 db clients: the global scheduler + one local
# scheduler and one plasma per node.
self.assertEqual(
len(self.redis_client.keys("{}*".format(DB_CLIENT_PREFIX))),
2 * NUM_CLUSTER_NODES + 1)
db_client_id = self.get_plasma_manager_id()
assert(db_client_id != None)
assert(db_client_id is not None)
assert(db_client_id.startswith(b"CL:"))
db_client_id = db_client_id[len(b"CL:"):] # Remove the CL: prefix.
db_client_id = db_client_id[len(b"CL:"):] # Remove the CL: prefix.
def test_integration_single_task(self):
# There should be three db clients, the global scheduler, the local
# scheduler, and the plasma manager.
self.assertEqual(len(self.redis_client.keys("{}*".format(DB_CLIENT_PREFIX))),
2 * NUM_CLUSTER_NODES + 1)
self.assertEqual(
len(self.redis_client.keys("{}*".format(DB_CLIENT_PREFIX))),
2 * NUM_CLUSTER_NODES + 1)
num_return_vals = [0, 1, 2, 3, 5, 10]
# Insert the object into Redis.
data_size = 0xf1f0
metadata_size = 0x40
plasma_client = self.plasma_clients[0]
object_dep, memory_buffer, metadata = create_object(plasma_client, data_size, metadata_size, seal=True)
object_dep, memory_buffer, metadata = create_object(
plasma_client, data_size, metadata_size, seal=True)
# Sleep before submitting task to local scheduler.
time.sleep(0.1)
# Submit a task to Redis.
task = local_scheduler.Task(random_driver_id(), random_function_id(), [local_scheduler.ObjectID(object_dep)], num_return_vals[0], random_task_id(), 0)
task = local_scheduler.Task(random_driver_id(), random_function_id(),
[local_scheduler.ObjectID(object_dep)],
num_return_vals[0], random_task_id(), 0)
self.local_scheduler_clients[0].submit(task)
time.sleep(0.1)
# There should now be a task in Redis, and it should get assigned to the
@@ -217,8 +235,9 @@ class TestGlobalScheduler(unittest.TestCase):
def integration_many_tasks_helper(self, timesync=True):
# There should be three db clients, the global scheduler, the local
# scheduler, and the plasma manager.
self.assertEqual(len(self.redis_client.keys("{}*".format(DB_CLIENT_PREFIX))),
2 * NUM_CLUSTER_NODES + 1)
self.assertEqual(
len(self.redis_client.keys("{}*".format(DB_CLIENT_PREFIX))),
2 * NUM_CLUSTER_NODES + 1)
num_return_vals = [0, 1, 2, 3, 5, 10]
# Submit a bunch of tasks to Redis.
@@ -228,11 +247,16 @@ class TestGlobalScheduler(unittest.TestCase):
data_size = np.random.randint(1 << 20)
metadata_size = np.random.randint(1 << 10)
plasma_client = self.plasma_clients[0]
object_dep, memory_buffer, metadata = create_object(plasma_client, data_size, metadata_size, seal=True)
object_dep, memory_buffer, metadata = create_object(plasma_client,
data_size,
metadata_size,
seal=True)
if timesync:
# Give 10ms for object info handler to fire (long enough to yield CPU).
time.sleep(0.010)
task = local_scheduler.Task(random_driver_id(), random_function_id(), [local_scheduler.ObjectID(object_dep)], num_return_vals[0], random_task_id(), 0)
task = local_scheduler.Task(random_driver_id(), random_function_id(),
[local_scheduler.ObjectID(object_dep)],
num_return_vals[0], random_task_id(), 0)
self.local_scheduler_clients[0].submit(task)
# Check that there are the correct number of tasks in Redis and that they
# all get assigned to the local scheduler.
@@ -243,17 +267,18 @@ class TestGlobalScheduler(unittest.TestCase):
self.assertLessEqual(len(task_entries), num_tasks)
# First, check if all tasks made it to Redis.
if len(task_entries) == num_tasks:
task_contents = [self.redis_client.hgetall(task_entries[i]) for i in range(len(task_entries))]
task_contents = [self.redis_client.hgetall(task_entries[i])
for i in range(len(task_entries))]
task_statuses = [int(contents[b"state"]) for contents in task_contents]
self.assertTrue(all([
status in [TASK_STATUS_WAITING,
TASK_STATUS_SCHEDULED,
TASK_STATUS_QUEUED] for status in task_statuses
]))
self.assertTrue(all([status in [TASK_STATUS_WAITING,
TASK_STATUS_SCHEDULED,
TASK_STATUS_QUEUED]
for status in task_statuses]))
num_tasks_done = task_statuses.count(TASK_STATUS_QUEUED)
num_tasks_scheduled = task_statuses.count(TASK_STATUS_SCHEDULED)
num_tasks_waiting = task_statuses.count(TASK_STATUS_WAITING)
print("tasks in Redis = {}, tasks waiting = {}, tasks scheduled = {}, tasks queued = {}, retries left = {}"
print("tasks in Redis = {}, tasks waiting = {}, tasks scheduled = {}, "
"tasks queued = {}, retries left = {}"
.format(len(task_entries), num_tasks_waiting,
num_tasks_scheduled, num_tasks_done, num_retries))
if all([status == TASK_STATUS_QUEUED for status in task_statuses]):
@@ -275,6 +300,7 @@ class TestGlobalScheduler(unittest.TestCase):
# notifications.
self.integration_many_tasks_helper(timesync=False)
if __name__ == "__main__":
if len(sys.argv) > 1:
# Pop the argument so we don't mess with unittest's own argument parser.
+7 -2
View File
@@ -2,5 +2,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ray.core.src.local_scheduler.liblocal_scheduler_library import *
from .local_scheduler_services import *
from ray.core.src.local_scheduler.liblocal_scheduler_library import (
Task, LocalSchedulerClient, ObjectID, check_simple_value, task_from_string,
task_to_string)
from .local_scheduler_services import start_local_scheduler
__all__ = ["Task", "LocalSchedulerClient", "ObjectID", "check_simple_value",
"task_from_string", "task_to_string", "start_local_scheduler"]
@@ -7,9 +7,11 @@ import random
import subprocess
import time
def random_name():
return str(random.randint(0, 99999999))
def start_local_scheduler(plasma_store_name,
plasma_manager_name=None,
worker_path=None,
@@ -38,8 +40,8 @@ def start_local_scheduler(plasma_store_name,
running on.
redis_address (str): The address of the Redis instance to connect to. If
this is not provided, then the local scheduler will not connect to Redis.
use_valgrind (bool): True if the local scheduler should be started inside of
valgrind. If this is True, use_profiler must be False.
use_valgrind (bool): True if the local scheduler should be started inside
of valgrind. If this is True, use_profiler must be False.
use_profiler (bool): True if the local scheduler should be started inside a
profiler. If this is True, use_valgrind must be False.
stdout_file: A file handle opened for writing to redirect stdout to. If no
@@ -56,11 +58,14 @@ def start_local_scheduler(plasma_store_name,
A tuple of the name of the local scheduler socket and the process ID of the
local scheduler process.
"""
if (plasma_manager_name == None) != (redis_address == None):
raise Exception("If one of the plasma_manager_name and the redis_address is provided, then both must be provided.")
if (plasma_manager_name is None) != (redis_address is None):
raise Exception("If one of the plasma_manager_name and the redis_address "
"is provided, then both must be provided.")
if use_valgrind and use_profiler:
raise Exception("Cannot use valgrind and profiler at the same time.")
local_scheduler_executable = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../core/src/local_scheduler/local_scheduler")
local_scheduler_executable = os.path.join(os.path.dirname(
os.path.abspath(__file__)),
"../core/src/local_scheduler/local_scheduler")
local_scheduler_name = "/tmp/scheduler{}".format(random_name())
command = [local_scheduler_executable,
"-s", local_scheduler_name,
@@ -90,8 +95,10 @@ def start_local_scheduler(plasma_store_name,
if plasma_address is not None:
command += ["-a", plasma_address]
if static_resource_list is not None:
assert all([isinstance(resource, int) or isinstance(resource, float) for resource in static_resource_list])
command += ["-c", ",".join([str(resource) for resource in static_resource_list])]
assert all([isinstance(resource, int) or isinstance(resource, float)
for resource in static_resource_list])
command += ["-c", ",".join([str(resource) for resource
in static_resource_list])]
if use_valgrind:
pid = subprocess.Popen(["valgrind",
@@ -99,7 +106,7 @@ def start_local_scheduler(plasma_store_name,
"--leak-check=full",
"--show-leak-kinds=all",
"--error-exitcode=1"] + command,
stdout=stdout_file, stderr=stderr_file)
stdout=stdout_file, stderr=stderr_file)
time.sleep(1.0)
elif use_profiler:
pid = subprocess.Popen(["valgrind", "--tool=callgrind"] + command,
+46 -35
View File
@@ -4,9 +4,7 @@ from __future__ import print_function
import numpy as np
import os
import random
import signal
import subprocess
import sys
import threading
import time
@@ -20,18 +18,23 @@ ID_SIZE = 20
NIL_ACTOR_ID = 20 * b"\xff"
def random_object_id():
return local_scheduler.ObjectID(np.random.bytes(ID_SIZE))
def random_driver_id():
return local_scheduler.ObjectID(np.random.bytes(ID_SIZE))
def random_task_id():
return local_scheduler.ObjectID(np.random.bytes(ID_SIZE))
def random_function_id():
return local_scheduler.ObjectID(np.random.bytes(ID_SIZE))
class TestLocalSchedulerClient(unittest.TestCase):
def setUp(self):
@@ -39,10 +42,11 @@ class TestLocalSchedulerClient(unittest.TestCase):
plasma_store_name, self.p1 = plasma.start_plasma_store()
self.plasma_client = plasma.PlasmaClient(plasma_store_name)
# Start a local scheduler.
scheduler_name, self.p2 = local_scheduler.start_local_scheduler(plasma_store_name, use_valgrind=USE_VALGRIND)
scheduler_name, self.p2 = local_scheduler.start_local_scheduler(
plasma_store_name, use_valgrind=USE_VALGRIND)
# Connect to the scheduler.
self.local_scheduler_client = local_scheduler.LocalSchedulerClient(
scheduler_name, NIL_ACTOR_ID, False)
scheduler_name, NIL_ACTOR_ID, False)
def tearDown(self):
# Check that the processes are still alive.
@@ -70,37 +74,38 @@ class TestLocalSchedulerClient(unittest.TestCase):
self.plasma_client.seal(object_id.id())
# Define some arguments to use for the tasks.
args_list = [
[],
#{},
#(),
1 * [1],
10 * [1],
100 * [1],
1000 * [1],
1 * ["a"],
10 * ["a"],
100 * ["a"],
1000 * ["a"],
[1, 1.3, 1 << 100, "hi", u"hi", [1, 2]],
object_ids[:1],
object_ids[:2],
object_ids[:3],
object_ids[:4],
object_ids[:5],
object_ids[:10],
object_ids[:100],
object_ids[:256],
[1, object_ids[0]],
[object_ids[0], "a"],
[1, object_ids[0], "a"],
[object_ids[0], 1, object_ids[1], "a"],
object_ids[:3] + [1, "hi", 2.3] + object_ids[:5],
object_ids + 100 * ["a"] + object_ids
[],
[{}],
[()],
1 * [1],
10 * [1],
100 * [1],
1000 * [1],
1 * ["a"],
10 * ["a"],
100 * ["a"],
1000 * ["a"],
[1, 1.3, 1 << 100, "hi", u"hi", [1, 2]],
object_ids[:1],
object_ids[:2],
object_ids[:3],
object_ids[:4],
object_ids[:5],
object_ids[:10],
object_ids[:100],
object_ids[:256],
[1, object_ids[0]],
[object_ids[0], "a"],
[1, object_ids[0], "a"],
[object_ids[0], 1, object_ids[1], "a"],
object_ids[:3] + [1, "hi", 2.3] + object_ids[:5],
object_ids + 100 * ["a"] + object_ids
]
for args in args_list:
for num_return_vals in [0, 1, 2, 3, 5, 10, 100]:
task = local_scheduler.Task(random_driver_id(), function_id, args, num_return_vals, random_task_id(), 0)
task = local_scheduler.Task(random_driver_id(), function_id, args,
num_return_vals, random_task_id(), 0)
# Submit a task.
self.local_scheduler_client.submit(task)
# Get the task.
@@ -119,7 +124,8 @@ class TestLocalSchedulerClient(unittest.TestCase):
# Submit all of the tasks.
for args in args_list:
for num_return_vals in [0, 1, 2, 3, 5, 10, 100]:
task = local_scheduler.Task(random_driver_id(), function_id, args, num_return_vals, random_task_id(), 0)
task = local_scheduler.Task(random_driver_id(), function_id, args,
num_return_vals, random_task_id(), 0)
self.local_scheduler_client.submit(task)
# Get all of the tasks.
for args in args_list:
@@ -129,8 +135,10 @@ class TestLocalSchedulerClient(unittest.TestCase):
def test_scheduling_when_objects_ready(self):
# Create a task and submit it.
object_id = random_object_id()
task = local_scheduler.Task(random_driver_id(), random_function_id(), [object_id], 0, random_task_id(), 0)
task = local_scheduler.Task(random_driver_id(), random_function_id(),
[object_id], 0, random_task_id(), 0)
self.local_scheduler_client.submit(task)
# Launch a thread to get the task.
def get_task():
self.local_scheduler_client.get_task()
@@ -149,7 +157,9 @@ class TestLocalSchedulerClient(unittest.TestCase):
# Create a task with two dependencies and submit it.
object_id1 = random_object_id()
object_id2 = random_object_id()
task = local_scheduler.Task(random_driver_id(), random_function_id(), [object_id1, object_id2], 0, random_task_id(), 0)
task = local_scheduler.Task(random_driver_id(), random_function_id(),
[object_id1, object_id2], 0, random_task_id(),
0)
self.local_scheduler_client.submit(task)
# Launch a thread to get the task.
@@ -196,9 +206,10 @@ class TestLocalSchedulerClient(unittest.TestCase):
# Wait until the thread finishes so that we know the task was scheduled.
t.join()
if __name__ == "__main__":
if len(sys.argv) > 1:
# pop the argument so we don't mess with unittest's own argument parser
# Pop the argument so we don't mess with unittest's own argument parser.
if sys.argv[-1] == "valgrind":
arg = sys.argv.pop()
USE_VALGRIND = True
+9 -5
View File
@@ -3,13 +3,13 @@ from __future__ import division
from __future__ import print_function
import argparse
import os
import redis
import time
from ray.services import get_ip_address
from ray.services import get_port
class LogMonitor(object):
"""A monitor process for monitoring Ray log files.
@@ -27,15 +27,17 @@ class LogMonitor(object):
def __init__(self, redis_ip_address, redis_port, node_ip_address):
"""Initialize the log monitor object."""
self.node_ip_address = node_ip_address
self.redis_client = redis.StrictRedis(host=redis_ip_address, port=redis_port)
self.redis_client = redis.StrictRedis(host=redis_ip_address,
port=redis_port)
self.log_files = {}
self.log_file_handles = {}
def update_log_filenames(self):
"""Get the most up-to-date list of log files to monitor from Redis."""
num_current_log_files = len(self.log_files)
new_log_filenames = self.redis_client.lrange("LOG_FILENAMES:{}".format(self.node_ip_address),
num_current_log_files, -1)
new_log_filenames = self.redis_client.lrange(
"LOG_FILENAMES:{}".format(self.node_ip_address),
num_current_log_files, -1)
for log_filename in new_log_filenames:
print("Beginning to track file {}".format(log_filename))
assert log_filename not in self.log_files
@@ -50,7 +52,8 @@ class LogMonitor(object):
# If there are any new lines, cache them and also push them to Redis.
if len(new_lines) > 0:
self.log_files[log_filename] += new_lines
redis_key = "LOGFILE:{}:{}".format(self.node_ip_address, log_filename.decode("ascii"))
redis_key = "LOGFILE:{}:{}".format(self.node_ip_address,
log_filename.decode("ascii"))
self.redis_client.rpush(redis_key, *new_lines)
else:
try:
@@ -69,6 +72,7 @@ class LogMonitor(object):
self.check_log_files_and_push_updates()
time.sleep(1)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description=("Parse Redis server for the "
"log monitor to connect to."))
+22 -17
View File
@@ -3,7 +3,6 @@ from __future__ import division
from __future__ import print_function
import argparse
import binascii
from collections import Counter
import logging
import redis
@@ -13,7 +12,8 @@ from ray.services import get_ip_address
from ray.services import get_port
# Import flatbuffer bindings.
from ray.core.generated.SubscribeToDBClientTableReply import SubscribeToDBClientTableReply
from ray.core.generated.SubscribeToDBClientTableReply \
import SubscribeToDBClientTableReply
from ray.core.generated.TaskReply import TaskReply
# These variables must be kept in sync with the C codebase.
@@ -41,6 +41,7 @@ logging.basicConfig()
log = logging.getLogger()
log.setLevel(logging.WARN)
class Monitor(object):
"""A monitor for Ray processes.
@@ -92,7 +93,8 @@ class Monitor(object):
TASK_STATUS_LOST. A local scheduler is deemed dead if it is in
self.dead_local_schedulers.
"""
task_ids = self.redis.scan_iter(match="{prefix}*".format(prefix=TASK_PREFIX))
task_ids = self.redis.scan_iter(
match="{prefix}*".format(prefix=TASK_PREFIX))
num_tasks_updated = 0
for task_id in task_ids:
task_id = task_id[len(TASK_PREFIX):]
@@ -123,11 +125,13 @@ class Monitor(object):
"""
# TODO(swang): Also kill the associated plasma store, since it's no longer
# reachable without a plasma manager.
object_ids = self.redis.scan_iter(match="{prefix}*".format(prefix=OBJECT_PREFIX))
object_ids = self.redis.scan_iter(
match="{prefix}*".format(prefix=OBJECT_PREFIX))
num_objects_removed = 0
for object_id in object_ids:
object_id = object_id[len(OBJECT_PREFIX):]
managers = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", object_id)
managers = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
object_id)
for manager in managers:
if manager in self.dead_plasma_managers:
# If the object was on a dead plasma manager, remove that location
@@ -149,7 +153,8 @@ class Monitor(object):
not miss any notifications for deleted clients that occurred before we
subscribed.
"""
db_client_keys = self.redis.keys("{prefix}*".format(prefix=DB_CLIENT_PREFIX))
db_client_keys = self.redis.keys(
"{prefix}*".format(prefix=DB_CLIENT_PREFIX))
for db_client_key in db_client_keys:
db_client_id = db_client_key[len(DB_CLIENT_PREFIX):]
client_type, deleted = self.redis.hmget(db_client_key,
@@ -175,10 +180,10 @@ class Monitor(object):
flatbuffer. Deletions are processed, insertions are ignored. Cleanup of the
associated state in the state tables should be handled by the caller.
"""
notification_object = SubscribeToDBClientTableReply.GetRootAsSubscribeToDBClientTableReply(data, 0)
notification_object = (SubscribeToDBClientTableReply
.GetRootAsSubscribeToDBClientTableReply(data, 0))
db_client_id = notification_object.DbClientId()
client_type = notification_object.ClientType()
auxiliary_address = notification_object.AuxAddress()
is_insertion = notification_object.IsInsertion()
# If the update was an insertion, we ignore it.
@@ -227,7 +232,6 @@ class Monitor(object):
if not self.subscribed[channel]:
# If the data was an integer, then the message was a response to an
# initial subscription request.
is_subscribe = int(data)
message_handler = self.subscribe_handler
elif channel == PLASMA_MANAGER_HEARTBEAT_CHANNEL:
assert(self.subscribed[channel])
@@ -264,12 +268,11 @@ class Monitor(object):
self.cleanup_task_table()
if len(self.dead_plasma_managers) > 0:
self.cleanup_object_table()
log.debug("{} dead local schedulers, {} plasma "
"managers total, {} dead plasma managers".format(
len(self.dead_local_schedulers),
len(self.live_plasma_managers) + len(self.dead_plasma_managers),
len(self.dead_plasma_managers)
))
log.debug("{} dead local schedulers, {} plasma managers total, {} dead "
"plasma managers".format(len(self.dead_local_schedulers),
(len(self.live_plasma_managers) +
len(self.dead_plasma_managers)),
len(self.dead_plasma_managers)))
# Handle messages from the subscription channels.
while True:
@@ -289,7 +292,8 @@ class Monitor(object):
# Handle plasma managers that timed out during this round.
plasma_manager_ids = list(self.live_plasma_managers.keys())
for plasma_manager_id in plasma_manager_ids:
if self.live_plasma_managers[plasma_manager_id] >= NUM_HEARTBEATS_TIMEOUT:
if ((self.live_plasma_managers
[plasma_manager_id]) >= NUM_HEARTBEATS_TIMEOUT):
log.warn("Timed out {}".format(PLASMA_MANAGER_CLIENT_TYPE))
# Remove the plasma manager from the managers whose heartbeats we're
# tracking.
@@ -299,7 +303,8 @@ class Monitor(object):
# receive the notification for this db_client deletion.
self.redis.execute_command("RAY.DISCONNECT", plasma_manager_id)
# Increment the number of heartbeats that we've missed from each plasma manager.
# Increment the number of heartbeats that we've missed from each plasma
# manager.
for plasma_manager_id in self.live_plasma_managers:
self.live_plasma_managers[plasma_manager_id] += 1
+18 -4
View File
@@ -10,15 +10,29 @@ If you are using Anaconda, try fixing this problem by running:
conda install libgcc
"""
__all__ = ["deserialize_list", "numbuf_error",
"numbuf_plasma_object_exists_error", "read_from_buffer",
"register_callbacks", "retrieve_list", "serialize_list",
"store_list", "write_to_buffer"]
try:
from ray.core.src.numbuf.libnumbuf import *
from ray.core.src.numbuf.libnumbuf import (deserialize_list, numbuf_error,
numbuf_plasma_object_exists_error,
read_from_buffer,
register_callbacks, retrieve_list,
serialize_list, store_list,
write_to_buffer)
except ImportError as e:
if hasattr(e, "msg") and isinstance(e.msg, str) and ("libstdc++" in e.msg or "CXX" in e.msg):
if (hasattr(e, "msg") and isinstance(e.msg, str) and ("libstdc++" in e.msg or
"CXX" in e.msg)):
# This code path should be taken with Python 3.
e.msg += helpful_message
elif hasattr(e, "message") and isinstance(e.message, str) and ("libstdc++" in e.message or "CXX" in e.message):
elif (hasattr(e, "message") and isinstance(e.message, str) and
("libstdc++" in e.message or "CXX" in e.message)):
# This code path should be taken with Python 2.
if hasattr(e, "args") and isinstance(e.args, tuple) and len(e.args) == 1 and isinstance(e.args[0], str):
condition = (hasattr(e, "args") and isinstance(e.args, tuple) and
len(e.args) == 1 and isinstance(e.args[0], str))
if condition:
e.args = (e.args[0] + helpful_message,)
else:
if not hasattr(e, "args"):
+34 -14
View File
@@ -1,55 +1,74 @@
# Note that a little bit of code here is taken and slightly modified from the pickler because it was not possible to change its behavior otherwise.
# Note that a little bit of code here is taken and slightly modified from the
# pickler because it was not possible to change its behavior otherwise.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
from ctypes import c_void_p
from cloudpickle import pickle, cloudpickle, CloudPickler, load, loads
__all__ = ["load", "loads", "dump", "dumps"]
try:
from ctypes import pythonapi
pythonapi.PyCell_Set # Make sure this exists
except:
pythonapi = None
def dump(obj, file, protocol=2):
return BetterPickler(file, protocol).dump(obj)
def dumps(obj):
stringio = cloudpickle.StringIO()
dump(obj, stringio)
return stringio.getvalue()
def _make_skel_func(code, closure, base_globals = None):
""" Creates a skeleton function object that contains just the provided
code and the correct number of cells in func_closure. All other
func attributes (e.g. func_globals) are empty.
def _make_skel_func(code, closure, base_globals=None):
"""Create a skeleton function object.
Creates a skeleton function object that contains just the provided code and
the correct number of cells in func_closure. All other func attributes
(e.g. func_globals) are empty.
"""
if base_globals is None: base_globals = {}
base_globals['__builtins__'] = __builtins__
return _make_skel_func.__class__(code, base_globals, None, None, tuple(closure))
if base_globals is None:
base_globals = {}
base_globals["__builtins__"] = __builtins__
return _make_skel_func.__class__(code, base_globals, None, None,
tuple(closure))
def _fill_function(func, globals, defaults, closure, dict):
""" Fills in the rest of function data into the skeleton function object
that were created via _make_skel_func(), including closures.
"""Fill in the resst of the function data.
This fills in the rest of function data into the skeleton function object
that were created via _make_skel_func(), including closures.
"""
result = cloudpickle._fill_function(func, globals, defaults, dict)
if pythonapi is not None:
for i, v in enumerate(closure):
pythonapi.PyCell_Set(c_void_p(id(result.__closure__[i])), c_void_p(id(v)))
pythonapi.PyCell_Set(c_void_p(id(result.__closure__[i])),
c_void_p(id(v)))
return result
class BetterPickler(CloudPickler):
def save_function_tuple(self, func):
code, f_globals, defaults, closure, dct, base_globals = self.extract_func_data(func)
(code, f_globals, defaults,
closure, dct, base_globals) = self.extract_func_data(func)
self.save(_fill_function)
self.write(pickle.MARK)
self.save(_make_skel_func if pythonapi else cloudpickle._make_skel_func)
self.save((code, map(lambda _: cloudpickle._make_cell(None), closure) if closure and pythonapi is not None else closure, base_globals))
self.save((code,
(map(lambda _: cloudpickle._make_cell(None), closure)
if closure and pythonapi is not None
else closure),
base_globals))
self.write(pickle.REDUCE)
self.memoize(func)
@@ -59,6 +78,7 @@ class BetterPickler(CloudPickler):
self.save(dct)
self.write(pickle.TUPLE)
self.write(pickle.REDUCE)
def save_cell(self, obj):
self.save(cloudpickle._make_cell)
self.save((obj.cell_contents,))
+10 -1
View File
@@ -2,4 +2,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ray.plasma.plasma import *
from ray.plasma.plasma import (PlasmaBuffer, buffers_equal, PlasmaClient,
start_plasma_store, start_plasma_manager,
plasma_object_exists_error,
plasma_out_of_memory_error,
DEFAULT_PLASMA_STORE_MEMORY)
__all__ = ["PlasmaBuffer", "buffers_equal", "PlasmaClient",
"start_plasma_store", "start_plasma_manager",
"plasma_object_exists_error", "plasma_out_of_memory_error",
"DEFAULT_PLASMA_STORE_MEMORY"]
+55 -29
View File
@@ -12,9 +12,14 @@ import ray.core.src.plasma.libplasma as libplasma
from ray.core.src.plasma.libplasma import plasma_object_exists_error
from ray.core.src.plasma.libplasma import plasma_out_of_memory_error
PLASMA_ID_SIZE = 20
__all__ = ["PlasmaBuffer", "buffers_equal", "PlasmaClient",
"start_plasma_store", "start_plasma_manager",
"plasma_object_exists_error", "plasma_out_of_memory_error",
"DEFAULT_PLASMA_STORE_MEMORY"]
PLASMA_WAIT_TIMEOUT = 2 ** 30
class PlasmaBuffer(object):
"""This is the type of objects returned by calls to get with a PlasmaClient.
@@ -47,8 +52,8 @@ class PlasmaBuffer(object):
"""Read from the PlasmaBuffer as if it were just a regular buffer."""
# We currently don't allow slicing plasma buffers. We should handle this
# better, but it requires some care because the slice may be backed by the
# same memory in the object store, but the original plasma buffer may go out
# of scope causing the memory to no longer be accessible.
# same memory in the object store, but the original plasma buffer may go
# out of scope causing the memory to no longer be accessible.
assert not isinstance(index, slice)
value = self.buffer[index]
if sys.version_info >= (3, 0) and not isinstance(index, slice):
@@ -62,8 +67,8 @@ class PlasmaBuffer(object):
"""
# We currently don't allow slicing plasma buffers. We should handle this
# better, but it requires some care because the slice may be backed by the
# same memory in the object store, but the original plasma buffer may go out
# of scope causing the memory to no longer be accessible.
# same memory in the object store, but the original plasma buffer may go
# out of scope causing the memory to no longer be accessible.
assert not isinstance(index, slice)
if sys.version_info >= (3, 0) and not isinstance(index, slice):
value = ord(value)
@@ -73,48 +78,56 @@ class PlasmaBuffer(object):
"""Return the length of the buffer."""
return len(self.buffer)
def buffers_equal(buff1, buff2):
"""Compare two buffers. These buffers may be PlasmaBuffer objects.
This method should only be used in the tests. We implement a special helper
method for doing this because doing comparisons by slicing is much faster, but
we don't want to expose slicing of PlasmaBuffer objects because it currently
is not safe.
method for doing this because doing comparisons by slicing is much faster,
but we don't want to expose slicing of PlasmaBuffer objects because it
currently is not safe.
"""
buff1_to_compare = buff1.buffer if isinstance(buff1, PlasmaBuffer) else buff1
buff2_to_compare = buff2.buffer if isinstance(buff2, PlasmaBuffer) else buff2
return buff1_to_compare[:] == buff2_to_compare[:]
class PlasmaClient(object):
"""The PlasmaClient is used to interface with a plasma store and a plasma manager.
"""The PlasmaClient is used to interface with a plasma store and manager.
The PlasmaClient can ask the PlasmaStore to allocate a new buffer, seal a
buffer, and get a buffer. Buffers are referred to by object IDs, which are
strings.
"""
def __init__(self, store_socket_name, manager_socket_name=None, release_delay=64):
def __init__(self, store_socket_name, manager_socket_name=None,
release_delay=64):
"""Initialize the PlasmaClient.
Args:
store_socket_name (str): Name of the socket the plasma store is listening at.
manager_socket_name (str): Name of the socket the plasma manager is listening at.
store_socket_name (str): Name of the socket the plasma store is listening
at.
manager_socket_name (str): Name of the socket the plasma manager is
listening at.
release_delay (int): The maximum number of objects that the client will
keep and delay releasing (for caching reasons).
"""
self.store_socket_name = store_socket_name
self.manager_socket_name = manager_socket_name
self.alive = True
if manager_socket_name is not None:
self.conn = libplasma.connect(store_socket_name, manager_socket_name, release_delay)
self.conn = libplasma.connect(store_socket_name, manager_socket_name,
release_delay)
else:
self.conn = libplasma.connect(store_socket_name, "", release_delay)
def shutdown(self):
"""Shutdown the client so that it does not send messages.
If we kill the Plasma store and Plasma manager that this client is connected
to, then we can use this method to prevent the client from trying to send
messages to the killed processes.
If we kill the Plasma store and Plasma manager that this client is
connected to, then we can use this method to prevent the client from trying
to send messages to the killed processes.
"""
if self.alive:
libplasma.disconnect(self.conn)
@@ -169,8 +182,8 @@ class PlasmaClient(object):
def get_metadata(self, object_ids, timeout_ms=-1):
"""Create a buffer from the PlasmaStore based on object ID.
If the object has not been sealed yet, this call will block until the object
has been sealed. The retrieved buffer is immutable.
If the object has not been sealed yet, this call will block until the
object has been sealed. The retrieved buffer is immutable.
Args:
object_ids (List[str]): A list of strings used to identify some objects.
@@ -275,7 +288,8 @@ class PlasmaClient(object):
# currently crashes if given duplicate object IDs.
if len(object_ids) != len(set(object_ids)):
raise Exception("Wait requires a list of unique object IDs.")
ready_ids, waiting_ids = libplasma.wait(self.conn, object_ids, timeout, num_returns)
ready_ids, waiting_ids = libplasma.wait(self.conn, object_ids, timeout,
num_returns)
return ready_ids, list(waiting_ids)
def subscribe(self):
@@ -286,11 +300,14 @@ class PlasmaClient(object):
"""Get the next notification from the notification socket."""
return libplasma.receive_notification(self.notification_fd)
DEFAULT_PLASMA_STORE_MEMORY = 10 ** 9
def random_name():
return str(random.randint(0, 99999999))
def start_plasma_store(plasma_store_memory=DEFAULT_PLASMA_STORE_MEMORY,
use_valgrind=False, use_profiler=False,
stdout_file=None, stderr_file=None):
@@ -312,16 +329,20 @@ def start_plasma_store(plasma_store_memory=DEFAULT_PLASMA_STORE_MEMORY,
"""
if use_valgrind and use_profiler:
raise Exception("Cannot use valgrind and profiler at the same time.")
plasma_store_executable = os.path.join(os.path.abspath(os.path.dirname(__file__)), "../core/src/plasma/plasma_store")
plasma_store_executable = os.path.join(os.path.abspath(
os.path.dirname(__file__)),
"../core/src/plasma/plasma_store")
plasma_store_name = "/tmp/plasma_store{}".format(random_name())
command = [plasma_store_executable, "-s", plasma_store_name, "-m", str(plasma_store_memory)]
command = [plasma_store_executable,
"-s", plasma_store_name,
"-m", str(plasma_store_memory)]
if use_valgrind:
pid = subprocess.Popen(["valgrind",
"--track-origins=yes",
"--leak-check=full",
"--show-leak-kinds=all",
"--error-exitcode=1"] + command,
stdout=stdout_file, stderr=stderr_file)
stdout=stdout_file, stderr=stderr_file)
time.sleep(1.0)
elif use_profiler:
pid = subprocess.Popen(["valgrind", "--tool=callgrind"] + command,
@@ -332,13 +353,16 @@ def start_plasma_store(plasma_store_memory=DEFAULT_PLASMA_STORE_MEMORY,
time.sleep(0.1)
return plasma_store_name, pid
def new_port():
return random.randint(10000, 65535)
def start_plasma_manager(store_name, redis_address, node_ip_address="127.0.0.1",
plasma_manager_port=None, num_retries=20,
use_valgrind=False, run_profiler=False,
stdout_file=None, stderr_file=None):
def start_plasma_manager(store_name, redis_address,
node_ip_address="127.0.0.1", plasma_manager_port=None,
num_retries=20, use_valgrind=False,
run_profiler=False, stdout_file=None,
stderr_file=None):
"""Start a plasma manager and return the ports it listens on.
Args:
@@ -361,7 +385,9 @@ def start_plasma_manager(store_name, redis_address, node_ip_address="127.0.0.1",
Raises:
Exception: An exception is raised if the manager could not be started.
"""
plasma_manager_executable = os.path.join(os.path.abspath(os.path.dirname(__file__)), "../core/src/plasma/plasma_manager")
plasma_manager_executable = os.path.join(
os.path.abspath(os.path.dirname(__file__)),
"../core/src/plasma/plasma_manager")
plasma_manager_name = "/tmp/plasma_manager{}".format(random_name())
if plasma_manager_port is not None:
if num_retries != 1:
@@ -385,7 +411,7 @@ def start_plasma_manager(store_name, redis_address, node_ip_address="127.0.0.1",
"--leak-check=full",
"--show-leak-kinds=all",
"--error-exitcode=1"] + command,
stdout=stdout_file, stderr=stderr_file)
stdout=stdout_file, stderr=stderr_file)
elif run_profiler:
process = subprocess.Popen(["valgrind", "--tool=callgrind"] + command,
stdout=stdout_file, stderr=stderr_file)
@@ -396,7 +422,7 @@ def start_plasma_manager(store_name, redis_address, node_ip_address="127.0.0.1",
# port is already in use, then we need it to fail within 0.1 seconds.
time.sleep(0.1)
# See if the process has terminated
if process.poll() == None:
if process.poll() is None:
return plasma_manager_name, process, plasma_manager_port
# Generate a new port and try again.
plasma_manager_port = new_port()
+111 -65
View File
@@ -6,23 +6,22 @@ import numpy as np
import os
import random
import signal
import socket
import struct
import subprocess
import sys
import tempfile
import threading
import time
import unittest
import ray.plasma as plasma
from ray.plasma.utils import random_object_id, generate_metadata, write_to_data_buffer, create_object_with_id, create_object
from ray.plasma.utils import (random_object_id, generate_metadata,
create_object_with_id, create_object)
from ray import services
USE_VALGRIND = False
PLASMA_STORE_MEMORY = 1000000000
def assert_get_object_equal(unit_test, client1, client2, object_id, memory_buffer=None, metadata=None):
def assert_get_object_equal(unit_test, client1, client2, object_id,
memory_buffer=None, metadata=None):
client1_buff = client1.get([object_id])[0]
client2_buff = client2.get([object_id])[0]
client1_metadata = client1.get_metadata([object_id])[0]
@@ -32,7 +31,8 @@ def assert_get_object_equal(unit_test, client1, client2, object_id, memory_buffe
# Check that the buffers from the two clients are the same.
unit_test.assertTrue(plasma.buffers_equal(client1_buff, client2_buff))
# Check that the metadata buffers from the two clients are the same.
unit_test.assertTrue(plasma.buffers_equal(client1_metadata, client2_metadata))
unit_test.assertTrue(plasma.buffers_equal(client1_metadata,
client2_metadata))
# If a reference buffer was provided, check that it is the same as well.
if memory_buffer is not None:
unit_test.assertTrue(plasma.buffers_equal(memory_buffer, client1_buff))
@@ -40,11 +40,13 @@ def assert_get_object_equal(unit_test, client1, client2, object_id, memory_buffe
if metadata is not None:
unit_test.assertTrue(plasma.buffers_equal(metadata, client1_metadata))
class TestPlasmaClient(unittest.TestCase):
def setUp(self):
# Start Plasma store.
plasma_store_name, self.p = plasma.start_plasma_store(use_valgrind=USE_VALGRIND)
plasma_store_name, self.p = plasma.start_plasma_store(
use_valgrind=USE_VALGRIND)
# Connect to Plasma.
self.plasma_client = plasma.PlasmaClient(plasma_store_name, None, 64)
# For the eviction test
@@ -107,7 +109,7 @@ class TestPlasmaClient(unittest.TestCase):
object_id = random_object_id()
self.plasma_client.create(object_id, length, generate_metadata(length))
try:
val = self.plasma_client.create(object_id, length, generate_metadata(length))
self.plasma_client.create(object_id, length, generate_metadata(length))
except plasma.plasma_object_exists_error as e:
pass
else:
@@ -125,20 +127,24 @@ class TestPlasmaClient(unittest.TestCase):
metadata_buffers = []
for i in range(num_object_ids):
if i % 2 == 0:
data_buffer, metadata_buffer = create_object_with_id(self.plasma_client, object_ids[i], 2000, 2000)
data_buffer, metadata_buffer = create_object_with_id(
self.plasma_client, object_ids[i], 2000, 2000)
data_buffers.append(data_buffer)
metadata_buffers.append(metadata_buffer)
# Test timing out from some but not all get calls with various timeouts.
for timeout in [0, 10, 100, 1000]:
data_results = self.plasma_client.get(object_ids, timeout_ms=timeout)
metadata_results = self.plasma_client.get(object_ids, timeout_ms=timeout)
# metadata_results = self.plasma_client.get_metadata(object_ids,
# timeout_ms=timeout)
for i in range(num_object_ids):
if i % 2 == 0:
self.assertTrue(plasma.buffers_equal(data_buffers[i // 2], data_results[i]))
# TODO(rkn): We should compare the metadata as well. But currently the
# types are different (e.g., memoryview versus bytearray).
# self.assertTrue(plasma.buffers_equal(metadata_buffers[i // 2], metadata_results[i]))
self.assertTrue(plasma.buffers_equal(data_buffers[i // 2],
data_results[i]))
# TODO(rkn): We should compare the metadata as well. But currently
# the types are different (e.g., memoryview versus bytearray).
# self.assertTrue(plasma.buffers_equal(metadata_buffers[i // 2],
# metadata_results[i]))
else:
self.assertIsNone(results[i])
@@ -148,7 +154,9 @@ class TestPlasmaClient(unittest.TestCase):
def assert_create_raises_plasma_full(unit_test, size):
partial_size = np.random.randint(size)
try:
_, memory_buffer, _ = create_object(unit_test.plasma_client, partial_size, size - partial_size)
_, memory_buffer, _ = create_object(unit_test.plasma_client,
partial_size,
size - partial_size)
except plasma.plasma_out_of_memory_error as e:
pass
else:
@@ -215,7 +223,7 @@ class TestPlasmaClient(unittest.TestCase):
real_object_ids = [random_object_id() for _ in range(100)]
for object_id in real_object_ids:
self.assertFalse(self.plasma_client.contains(object_id))
memory_buffer = self.plasma_client.create(object_id, 100)
self.plasma_client.create(object_id, 100)
self.plasma_client.seal(object_id)
self.assertTrue(self.plasma_client.contains(object_id))
for object_id in fake_object_ids:
@@ -226,7 +234,7 @@ class TestPlasmaClient(unittest.TestCase):
def test_hash(self):
# Check the hash of an object that doesn't exist.
object_id1 = random_object_id()
h = self.plasma_client.hash(object_id1)
self.plasma_client.hash(object_id1)
length = 1000
# Create a random object, and check that the hash function always returns
@@ -357,7 +365,7 @@ class TestPlasmaClient(unittest.TestCase):
length = 1000
memory_buffer = self.plasma_client.create(object_id, length)
# Make sure we cannot access memory out of bounds.
self.assertRaises(Exception, lambda : memory_buffer[length])
self.assertRaises(Exception, lambda: memory_buffer[length])
# Seal the object.
self.plasma_client.seal(object_id)
# This test is commented out because it currently fails.
@@ -367,6 +375,7 @@ class TestPlasmaClient(unittest.TestCase):
# self.assertRaises(Exception, illegal_assignment)
# Get the object.
memory_buffer = self.plasma_client.get([object_id])[0]
# Make sure the object is read only.
def illegal_assignment():
memory_buffer[0] = chr(0)
@@ -413,18 +422,20 @@ class TestPlasmaClient(unittest.TestCase):
def test_subscribe(self):
# Subscribe to notifications from the Plasma Store.
sock = self.plasma_client.subscribe()
self.plasma_client.subscribe()
for i in [1, 10, 100, 1000, 10000, 100000]:
object_ids = [random_object_id() for _ in range(i)]
metadata_sizes = [np.random.randint(1000) for _ in range(i)]
data_sizes = [np.random.randint(1000) for _ in range(i)]
for j in range(i):
self.plasma_client.create(object_ids[j], size=data_sizes[j],
metadata=bytearray(np.random.bytes(metadata_sizes[j])))
self.plasma_client.create(
object_ids[j], size=data_sizes[j],
metadata=bytearray(np.random.bytes(metadata_sizes[j])))
self.plasma_client.seal(object_ids[j])
# Check that we received notifications for all of the objects.
for j in range(i):
recv_objid, recv_dsize, recv_msize = self.plasma_client.get_next_notification()
notification_info = self.plasma_client.get_next_notification()
recv_objid, recv_dsize, recv_msize = notification_info
self.assertEqual(object_ids[j], recv_objid)
self.assertEqual(data_sizes[j], recv_dsize)
self.assertEqual(metadata_sizes[j], recv_msize)
@@ -432,20 +443,22 @@ class TestPlasmaClient(unittest.TestCase):
def test_subscribe_deletions(self):
# Subscribe to notifications from the Plasma Store. We use plasma_client2
# to make sure that all used objects will get evicted properly.
sock = self.plasma_client2.subscribe()
self.plasma_client2.subscribe()
for i in [1, 10, 100, 1000, 10000, 100000]:
object_ids = [random_object_id() for _ in range(i)]
# Add 1 to the sizes to make sure we have nonzero object sizes.
metadata_sizes = [np.random.randint(1000) + 1 for _ in range(i)]
data_sizes = [np.random.randint(1000) + 1 for _ in range(i)]
for j in range(i):
x = self.plasma_client2.create(object_ids[j], size=data_sizes[j],
metadata=bytearray(np.random.bytes(metadata_sizes[j])))
x = self.plasma_client2.create(
object_ids[j], size=data_sizes[j],
metadata=bytearray(np.random.bytes(metadata_sizes[j])))
self.plasma_client2.seal(object_ids[j])
del x
# Check that we received notifications for creating all of the objects.
for j in range(i):
recv_objid, recv_dsize, recv_msize = self.plasma_client2.get_next_notification()
notification_info = self.plasma_client2.get_next_notification()
recv_objid, recv_dsize, recv_msize = notification_info
self.assertEqual(object_ids[j], recv_objid)
self.assertEqual(data_sizes[j], recv_dsize)
self.assertEqual(metadata_sizes[j], recv_msize)
@@ -453,8 +466,10 @@ class TestPlasmaClient(unittest.TestCase):
# Check that we receive notifications for deleting all objects, as we
# evict them.
for j in range(i):
self.assertEqual(self.plasma_client2.evict(1), data_sizes[j] + metadata_sizes[j])
recv_objid, recv_dsize, recv_msize = self.plasma_client2.get_next_notification()
self.assertEqual(self.plasma_client2.evict(1),
data_sizes[j] + metadata_sizes[j])
notification_info = self.plasma_client2.get_next_notification()
recv_objid, recv_dsize, recv_msize = notification_info
self.assertEqual(object_ids[j], recv_objid)
self.assertEqual(-1, recv_dsize)
self.assertEqual(-1, recv_msize)
@@ -469,18 +484,22 @@ class TestPlasmaClient(unittest.TestCase):
metadata_sizes.append(np.random.randint(1000))
data_sizes.append(np.random.randint(1000))
for i in range(num_object_ids):
x = self.plasma_client2.create(object_ids[i], size=data_sizes[i],
metadata=bytearray(np.random.bytes(metadata_sizes[i])))
x = self.plasma_client2.create(
object_ids[i], size=data_sizes[i],
metadata=bytearray(np.random.bytes(metadata_sizes[i])))
self.plasma_client2.seal(object_ids[i])
del x
for i in range(num_object_ids):
recv_objid, recv_dsize, recv_msize = self.plasma_client2.get_next_notification()
notification_info = self.plasma_client2.get_next_notification()
recv_objid, recv_dsize, recv_msize = notification_info
self.assertEqual(object_ids[i], recv_objid)
self.assertEqual(data_sizes[i], recv_dsize)
self.assertEqual(metadata_sizes[i], recv_msize)
self.assertEqual(self.plasma_client2.evict(1), data_sizes[-1] + metadata_sizes[-1])
self.assertEqual(self.plasma_client2.evict(1),
data_sizes[-1] + metadata_sizes[-1])
for i in range(num_object_ids):
recv_objid, recv_dsize, recv_msize = self.plasma_client2.get_next_notification()
notification_info = self.plasma_client2.get_next_notification()
recv_objid, recv_dsize, recv_msize = notification_info
self.assertEqual(object_ids[i], recv_objid)
self.assertEqual(-1, recv_dsize)
self.assertEqual(-1, recv_msize)
@@ -495,8 +514,10 @@ class TestPlasmaManager(unittest.TestCase):
# Start a Redis server.
redis_address = services.address("127.0.0.1", services.start_redis()[0])
# Start two PlasmaManagers.
manager_name1, self.p4, self.port1 = plasma.start_plasma_manager(store_name1, redis_address, use_valgrind=USE_VALGRIND)
manager_name2, self.p5, self.port2 = plasma.start_plasma_manager(store_name2, redis_address, use_valgrind=USE_VALGRIND)
manager_name1, self.p4, self.port1 = plasma.start_plasma_manager(
store_name1, redis_address, use_valgrind=USE_VALGRIND)
manager_name2, self.p5, self.port2 = plasma.start_plasma_manager(
store_name2, redis_address, use_valgrind=USE_VALGRIND)
# Connect two PlasmaClients.
self.client1 = plasma.PlasmaClient(store_name1, manager_name1)
self.client2 = plasma.PlasmaClient(store_name2, manager_name2)
@@ -513,7 +534,8 @@ class TestPlasmaManager(unittest.TestCase):
# Kill the Plasma store and Plasma manager processes.
if USE_VALGRIND:
time.sleep(1) # give processes opportunity to finish work
# Give processes opportunity to finish work.
time.sleep(1)
for process in self.processes_to_kill:
process.send_signal(signal.SIGTERM)
process.wait()
@@ -530,7 +552,8 @@ class TestPlasmaManager(unittest.TestCase):
def test_fetch(self):
for _ in range(10):
# Create an object.
object_id1, memory_buffer1, metadata1 = create_object(self.client1, 2000, 2000)
object_id1, memory_buffer1, metadata1 = create_object(self.client1, 2000,
2000)
self.client1.fetch([object_id1])
self.assertEqual(self.client1.contains(object_id1), True)
self.assertEqual(self.client2.contains(object_id1), False)
@@ -546,7 +569,8 @@ class TestPlasmaManager(unittest.TestCase):
object_id2 = random_object_id()
self.client1.fetch([object_id2])
self.assertEqual(self.client1.contains(object_id2), False)
memory_buffer2, metadata2 = create_object_with_id(self.client2, object_id2, 2000, 2000)
memory_buffer2, metadata2 = create_object_with_id(self.client2, object_id2,
2000, 2000)
# # Check that the object has been fetched.
# self.assertEqual(self.client1.contains(object_id2), True)
# Compare the two buffers.
@@ -560,11 +584,12 @@ class TestPlasmaManager(unittest.TestCase):
for _ in range(10):
self.client1.fetch([object_id3])
self.client2.fetch([object_id3])
memory_buffer3, metadata3 = create_object_with_id(self.client1, object_id3, 2000, 2000)
memory_buffer3, metadata3 = create_object_with_id(self.client1, object_id3,
2000, 2000)
for _ in range(10):
self.client1.fetch([object_id3])
self.client2.fetch([object_id3])
#TODO(rkn): Right now we must wait for the object table to be updated.
# TODO(rkn): Right now we must wait for the object table to be updated.
while not self.client2.contains(object_id3):
self.client2.fetch([object_id3])
assert_get_object_equal(self, self.client1, self.client2, object_id3,
@@ -573,14 +598,17 @@ class TestPlasmaManager(unittest.TestCase):
def test_fetch_multiple(self):
for _ in range(20):
# Create two objects and a third fake one that doesn't exist.
object_id1, memory_buffer1, metadata1 = create_object(self.client1, 2000, 2000)
object_id1, memory_buffer1, metadata1 = create_object(self.client1, 2000,
2000)
missing_object_id = random_object_id()
object_id2, memory_buffer2, metadata2 = create_object(self.client1, 2000, 2000)
object_id2, memory_buffer2, metadata2 = create_object(self.client1, 2000,
2000)
object_ids = [object_id1, missing_object_id, object_id2]
# Fetch the objects from the other plasma store. The second object ID
# should timeout since it does not exist.
# TODO(rkn): Right now we must wait for the object table to be updated.
while (not self.client2.contains(object_id1)) or (not self.client2.contains(object_id2)):
while ((not self.client2.contains(object_id1)) or
(not self.client2.contains(object_id2))):
self.client2.fetch(object_ids)
# Compare the buffers of the objects that do exist.
assert_get_object_equal(self, self.client1, self.client2, object_id1,
@@ -597,7 +625,8 @@ class TestPlasmaManager(unittest.TestCase):
# Check that we can call fetch with duplicated object IDs.
object_id3 = random_object_id()
self.client1.fetch([object_id3, object_id3])
object_id4, memory_buffer4, metadata4 = create_object(self.client1, 2000, 2000)
object_id4, memory_buffer4, metadata4 = create_object(self.client1, 2000,
2000)
time.sleep(0.1)
# TODO(rkn): Right now we must wait for the object table to be updated.
while not self.client2.contains(object_id4):
@@ -623,7 +652,8 @@ class TestPlasmaManager(unittest.TestCase):
obj_id2 = random_object_id()
self.client1.create(obj_id2, 1000)
# Don't seal.
ready, waiting = self.client1.wait([obj_id2, obj_id1], timeout=100, num_returns=1)
ready, waiting = self.client1.wait([obj_id2, obj_id1], timeout=100,
num_returns=1)
self.assertEqual(set(ready), set([obj_id1]))
self.assertEqual(set(waiting), set([obj_id2]))
@@ -636,11 +666,13 @@ class TestPlasmaManager(unittest.TestCase):
t = threading.Timer(0.1, finish)
t.start()
ready, waiting = self.client1.wait([obj_id3, obj_id2, obj_id1], timeout=1000, num_returns=2)
ready, waiting = self.client1.wait([obj_id3, obj_id2, obj_id1],
timeout=1000, num_returns=2)
self.assertEqual(set(ready), set([obj_id1, obj_id3]))
self.assertEqual(set(waiting), set([obj_id2]))
# Test if the appropriate number of objects is shown if some objects are not ready
# Test if the appropriate number of objects is shown if some objects are
# not ready.
ready, waiting = self.client1.wait([obj_id3, obj_id2, obj_id1], 100, 3)
self.assertEqual(set(ready), set([obj_id1, obj_id3]))
self.assertEqual(set(waiting), set([obj_id2]))
@@ -651,9 +683,9 @@ class TestPlasmaManager(unittest.TestCase):
# Test calling wait a bunch of times.
object_ids = []
# TODO(rkn): Increasing n to 100 (or larger) will cause failures. The
# problem appears to be that the number of timers added to the manager event
# loop slow down the manager so much that some of the asynchronous Redis
# commands timeout triggering fatal failure callbacks.
# problem appears to be that the number of timers added to the manager
# event loop slow down the manager so much that some of the asynchronous
# Redis commands timeout triggering fatal failure callbacks.
n = 40
for i in range(n * (n + 1) // 2):
if i % 2 == 0:
@@ -669,7 +701,8 @@ class TestPlasmaManager(unittest.TestCase):
self.assertEqual(len(ready), i)
retrieved += ready
self.assertEqual(set(retrieved), set(object_ids))
ready, waiting = self.client1.wait(object_ids, timeout=1000, num_returns=len(object_ids))
ready, waiting = self.client1.wait(object_ids, timeout=1000,
num_returns=len(object_ids))
self.assertEqual(set(ready), set(object_ids))
self.assertEqual(waiting, [])
# Try waiting for all of the object IDs on the second client.
@@ -680,7 +713,8 @@ class TestPlasmaManager(unittest.TestCase):
self.assertEqual(len(ready), i)
retrieved += ready
self.assertEqual(set(retrieved), set(object_ids))
ready, waiting = self.client2.wait(object_ids, timeout=1000, num_returns=len(object_ids))
ready, waiting = self.client2.wait(object_ids, timeout=1000,
num_returns=len(object_ids))
self.assertEqual(set(ready), set(object_ids))
self.assertEqual(waiting, [])
@@ -702,7 +736,8 @@ class TestPlasmaManager(unittest.TestCase):
num_attempts = 100
for _ in range(100):
# Create an object.
object_id1, memory_buffer1, metadata1 = create_object(self.client1, 2000, 2000)
object_id1, memory_buffer1, metadata1 = create_object(self.client1, 2000,
2000)
# Transfer the buffer to the the other Plasma store. There is a race
# condition on the create and transfer of the object, so keep trying
# until the object appears on the second Plasma store.
@@ -721,10 +756,12 @@ class TestPlasmaManager(unittest.TestCase):
# self.client1.transfer("127.0.0.1", self.port2, object_id1)
# # Compare the two buffers.
# assert_get_object_equal(self, self.client1, self.client2, object_id1,
# memory_buffer=memory_buffer1, metadata=metadata1)
# memory_buffer=memory_buffer1,
# metadata=metadata1)
# Create an object.
object_id2, memory_buffer2, metadata2 = create_object(self.client2, 20000, 20000)
object_id2, memory_buffer2, metadata2 = create_object(self.client2,
20000, 20000)
# Transfer the buffer to the the other Plasma store. There is a race
# condition on the create and transfer of the object, so keep trying
# until the object appears on the second Plasma store.
@@ -742,17 +779,19 @@ class TestPlasmaManager(unittest.TestCase):
def test_illegal_functionality(self):
# Create an object id string.
object_id = random_object_id()
# object_id = random_object_id()
# Create a new buffer.
# memory_buffer = self.client1.create(object_id, 20000)
# This test is commented out because it currently fails.
# # Transferring the buffer before sealing it should fail.
# self.assertRaises(Exception, lambda : self.manager1.transfer(1, object_id))
# self.assertRaises(Exception,
# lambda : self.manager1.transfer(1, object_id))
pass
def test_stresstest(self):
a = time.time()
object_ids = []
for i in range(10000): # TODO(pcm): increase this to 100000
for i in range(10000): # TODO(pcm): increase this to 100000.
object_id = random_object_id()
object_ids.append(object_id)
self.client1.create(object_id, 1)
@@ -763,13 +802,16 @@ class TestPlasmaManager(unittest.TestCase):
print("it took", b, "seconds to put and transfer the objects")
class TestPlasmaManagerRecovery(unittest.TestCase):
def setUp(self):
# Start a Plasma store.
self.store_name, self.p2 = plasma.start_plasma_store(use_valgrind=USE_VALGRIND)
self.store_name, self.p2 = plasma.start_plasma_store(
use_valgrind=USE_VALGRIND)
# Start a Redis server.
self.redis_address = services.address("127.0.0.1", services.start_redis()[0])
self.redis_address = services.address("127.0.0.1",
services.start_redis()[0])
# Start a PlasmaManagers.
manager_name, self.p3, self.port1 = plasma.start_plasma_manager(
self.store_name,
@@ -791,7 +833,8 @@ class TestPlasmaManagerRecovery(unittest.TestCase):
# Kill the Plasma store and Plasma manager processes.
if USE_VALGRIND:
time.sleep(1) # give processes opportunity to finish work
# Give processes opportunity to finish work.
time.sleep(1)
for process in self.processes_to_kill:
process.send_signal(signal.SIGTERM)
process.wait()
@@ -818,23 +861,26 @@ class TestPlasmaManagerRecovery(unittest.TestCase):
self.assertEqual(waiting, [])
# Start a second plasma manager attached to the same store.
manager_name, self.p5, self.port2 = plasma.start_plasma_manager(self.store_name, self.redis_address, use_valgrind=USE_VALGRIND)
manager_name, self.p5, self.port2 = plasma.start_plasma_manager(
self.store_name, self.redis_address, use_valgrind=USE_VALGRIND)
self.processes_to_kill = [self.p5] + self.processes_to_kill
# Check that the second manager knows about existing objects.
client2 = plasma.PlasmaClient(self.store_name, manager_name)
ready, waiting = [], object_ids
while True:
ready, waiting = client2.wait(object_ids, num_returns=num_objects, timeout=0)
ready, waiting = client2.wait(object_ids, num_returns=num_objects,
timeout=0)
if len(ready) == len(object_ids):
break
self.assertEqual(set(ready), set(object_ids))
self.assertEqual(waiting, [])
if __name__ == "__main__":
if len(sys.argv) > 1:
# pop the argument so we don't mess with unittest's own argument parser
# Pop the argument so we don't mess with unittest's own argument parser.
if sys.argv[-1] == "valgrind":
arg = sys.argv.pop()
USE_VALGRIND = True
+9 -2
View File
@@ -5,9 +5,11 @@ from __future__ import print_function
import numpy as np
import random
def random_object_id():
return np.random.bytes(20)
def generate_metadata(length):
metadata_buffer = bytearray(length)
if length > 0:
@@ -17,6 +19,7 @@ def generate_metadata(length):
metadata_buffer[random.randint(0, length - 1)] = random.randint(0, 255)
return metadata_buffer
def write_to_data_buffer(buff, length):
if length > 0:
buff[0] = chr(random.randint(0, 255))
@@ -24,7 +27,9 @@ def write_to_data_buffer(buff, length):
for _ in range(100):
buff[random.randint(0, length - 1)] = chr(random.randint(0, 255))
def create_object_with_id(client, object_id, data_size, metadata_size, seal=True):
def create_object_with_id(client, object_id, data_size, metadata_size,
seal=True):
metadata = generate_metadata(metadata_size)
memory_buffer = client.create(object_id, data_size, metadata)
write_to_data_buffer(memory_buffer, data_size)
@@ -32,7 +37,9 @@ def create_object_with_id(client, object_id, data_size, metadata_size, seal=True
client.seal(object_id)
return memory_buffer, metadata
def create_object(client, data_size, metadata_size, seal=True):
object_id = random_object_id()
memory_buffer, metadata = create_object_with_id(client, object_id, data_size, metadata_size, seal=seal)
memory_buffer, metadata = create_object_with_id(client, object_id, data_size,
metadata_size, seal=seal)
return object_id, memory_buffer, metadata
+44 -10
View File
@@ -7,6 +7,7 @@ import numpy as np
import ray.numbuf
import ray.pickling as pickling
def check_serializable(cls):
"""Throws an exception if Ray cannot serialize this class efficiently.
@@ -21,15 +22,30 @@ def check_serializable(cls):
# This case works.
return
if not hasattr(cls, "__new__"):
raise Exception("The class {} does not have a '__new__' attribute, and is probably an old-style class. We do not support this. Please either make it a new-style class by inheriting from 'object', or use 'ray.register_class(cls, pickle=True)'. However, note that pickle is inefficient.".format(cls))
raise Exception("The class {} does not have a '__new__' attribute, and is "
"probably an old-style class. We do not support this. "
"Please either make it a new-style class by inheriting "
"from 'object', or use "
"'ray.register_class(cls, pickle=True)'. However, note "
"that pickle is inefficient.".format(cls))
try:
obj = cls.__new__(cls)
except:
raise Exception("The class {} has overridden '__new__', so Ray may not be able to serialize it efficiently. Try using 'ray.register_class(cls, pickle=True)'. However, note that pickle is inefficient.".format(cls))
raise Exception("The class {} has overridden '__new__', so Ray may not be "
"able to serialize it efficiently. Try using "
"'ray.register_class(cls, pickle=True)'. However, note "
"that pickle is inefficient.".format(cls))
if not hasattr(obj, "__dict__"):
raise Exception("Objects of the class {} do not have a `__dict__` attribute, so Ray cannot serialize it efficiently. Try using 'ray.register_class(cls, pickle=True)'. However, note that pickle is inefficient.".format(cls))
raise Exception("Objects of the class {} do not have a `__dict__` "
"attribute, so Ray cannot serialize it efficiently. Try "
"using 'ray.register_class(cls, pickle=True)'. However, "
"note that pickle is inefficient.".format(cls))
if hasattr(obj, "__slots__"):
raise Exception("The class {} uses '__slots__', so Ray may not be able to serialize it efficiently. Try using 'ray.register_class(cls, pickle=True)'. However, note that pickle is inefficient.".format(cls))
raise Exception("The class {} uses '__slots__', so Ray may not be able to "
"serialize it efficiently. Try using "
"'ray.register_class(cls, pickle=True)'. However, note "
"that pickle is inefficient.".format(cls))
# This field keeps track of a whitelisted set of classes that Ray will
# serialize.
@@ -38,10 +54,12 @@ classes_to_pickle = set()
custom_serializers = {}
custom_deserializers = {}
def class_identifier(typ):
"""Return a string that identifies this type."""
return "{}.{}".format(typ.__module__, typ.__name__)
def is_named_tuple(cls):
"""Return True if cls is a namedtuple and False otherwise."""
b = cls.__bases__
@@ -52,7 +70,9 @@ def is_named_tuple(cls):
return False
return all(type(n) == str for n in f)
def add_class_to_whitelist(cls, pickle=False, custom_serializer=None, custom_deserializer=None):
def add_class_to_whitelist(cls, pickle=False, custom_serializer=None,
custom_deserializer=None):
"""Add cls to the list of classes that we can serialize.
Args:
@@ -72,13 +92,21 @@ def add_class_to_whitelist(cls, pickle=False, custom_serializer=None, custom_des
custom_serializers[class_id] = custom_serializer
custom_deserializers[class_id] = custom_deserializer
# Here we define a custom serializer and deserializer for handling numpy
# arrays that contain objects.
def array_custom_serializer(obj):
return obj.tolist(), obj.dtype.str
def array_custom_deserializer(serialized_obj):
return np.array(serialized_obj[0], dtype=np.dtype(serialized_obj[1]))
add_class_to_whitelist(np.ndarray, pickle=False, custom_serializer=array_custom_serializer, custom_deserializer=array_custom_deserializer)
add_class_to_whitelist(np.ndarray, pickle=False,
custom_serializer=array_custom_serializer,
custom_deserializer=array_custom_deserializer)
def serialize(obj):
"""This is the callback that will be used by numbuf.
@@ -94,7 +122,9 @@ def serialize(obj):
"""
class_id = class_identifier(type(obj))
if class_id not in whitelisted_classes:
raise Exception("Ray does not know how to serialize objects of type {}. To fix this, call 'ray.register_class' with this class.".format(type(obj)))
raise Exception("Ray does not know how to serialize objects of type {}. "
"To fix this, call 'ray.register_class' with this class."
.format(type(obj)))
if class_id in classes_to_pickle:
serialized_obj = {"data": pickling.dumps(obj)}
elif class_id in custom_serializers.keys():
@@ -107,10 +137,12 @@ def serialize(obj):
elif hasattr(obj, "__dict__"):
serialized_obj = obj.__dict__
else:
raise Exception("We do not know how to serialize the object '{}'".format(obj))
raise Exception("We do not know how to serialize the object '{}'"
.format(obj))
result = dict(serialized_obj, **{"_pytype_": class_id})
return result
def deserialize(serialized_obj):
"""This is the callback that will be used by numbuf.
@@ -139,11 +171,13 @@ def deserialize(serialized_obj):
obj.__dict__.update(serialized_obj)
return obj
def set_callbacks():
"""Register the custom callbacks with numbuf.
The serialize callback is used to serialize objects that numbuf does not know
how to serialize (for example custom Python classes). The deserialize callback
is used to serialize objects that were serialized by the serialize callback.
how to serialize (for example custom Python classes). The deserialize
callback is used to serialize objects that were serialized by the serialize
callback.
"""
ray.numbuf.register_callbacks(serialize, deserialize)
+162 -103
View File
@@ -10,7 +10,6 @@ import random
import redis
import signal
import socket
import string
import subprocess
import sys
import time
@@ -60,16 +59,20 @@ ObjectStoreAddress = namedtuple("ObjectStoreAddress", ["name",
"manager_name",
"manager_port"])
def address(ip_address, port):
return ip_address + ":" + str(port)
def get_ip_address(address):
try:
ip_address = address.split(":")[0]
except:
raise Exception("Unable to parse IP address from address {}".format(address))
raise Exception("Unable to parse IP address from address "
"{}".format(address))
return ip_address
def get_port(address):
try:
port = int(address.split(":")[1])
@@ -77,12 +80,15 @@ def get_port(address):
raise Exception("Unable to parse port from address {}".format(address))
return port
def new_port():
return random.randint(10000, 65535)
def random_name():
return str(random.randint(0, 99999999))
def kill_process(p):
"""Kill a process.
@@ -92,11 +98,15 @@ def kill_process(p):
Returns:
True if the process was killed successfully and false otherwise.
"""
if p.poll() is not None: # process has already terminated
if p.poll() is not None:
# The process has already terminated.
return True
if RUN_LOCAL_SCHEDULER_PROFILER or RUN_PLASMA_MANAGER_PROFILER or RUN_PLASMA_STORE_PROFILER:
os.kill(p.pid, signal.SIGINT) # Give process signal to write profiler data.
time.sleep(0.1) # Wait for profiling data to be written.
if any([RUN_LOCAL_SCHEDULER_PROFILER, RUN_PLASMA_MANAGER_PROFILER,
RUN_PLASMA_STORE_PROFILER]):
# Give process signal to write profiler data.
os.kill(p.pid, signal.SIGINT)
# Wait for profiling data to be written.
time.sleep(0.1)
# Allow the process one second to exit gracefully.
p.terminate()
@@ -118,6 +128,7 @@ def kill_process(p):
# The process was not killed for some reason.
return False
def cleanup():
"""When running in local mode, shutdown the Ray processes.
@@ -138,6 +149,7 @@ def cleanup():
if not successfully_shut_down:
print("Ray did not shut down properly.")
def all_processes_alive(exclude=[]):
"""Check if all of the processes are still alive.
@@ -147,10 +159,13 @@ def all_processes_alive(exclude=[]):
for process_type, processes in all_processes.items():
# Note that p.poll() returns the exit code that the process exited with, so
# an exit code of None indicates that the process is still alive.
if not all([p.poll() is None for p in processes]) and process_type not in exclude:
processes_alive = [p.poll() is None for p in processes]
if (not all(processes_alive) and process_type not in exclude):
print("A process of type {} has dead.".format(process_type))
return False
return True
def get_node_ip_address(address="8.8.8.8:53"):
"""Determine the IP address of the local node.
@@ -166,6 +181,7 @@ def get_node_ip_address(address="8.8.8.8:53"):
s.connect((ip_address, int(port)))
return s.getsockname()[0]
def record_log_files_in_redis(redis_address, node_ip_address, log_files):
"""Record in Redis that a new log file has been created.
@@ -187,6 +203,7 @@ def record_log_files_in_redis(redis_address, node_ip_address, log_files):
log_file_list_key = "LOG_FILENAMES:{}".format(node_ip_address)
redis_client.rpush(log_file_list_key, log_file.name)
def wait_for_redis_to_start(redis_ip_address, redis_port, num_retries=5):
"""Wait for a Redis server to be available.
@@ -208,7 +225,8 @@ def wait_for_redis_to_start(redis_ip_address, redis_port, num_retries=5):
while counter < num_retries:
try:
# Run some random command and see if it worked.
print("Waiting for redis server at {}:{} to respond...".format(redis_ip_address, redis_port))
print("Waiting for redis server at {}:{} to respond..."
.format(redis_ip_address, redis_port))
redis_client.client_list()
except redis.ConnectionError as e:
# Wait a little bit.
@@ -218,7 +236,10 @@ def wait_for_redis_to_start(redis_ip_address, redis_port, num_retries=5):
else:
break
if counter == num_retries:
raise Exception("Unable to connect to Redis. If the Redis instance is on a different machine, check that your firewall is configured properly.")
raise Exception("Unable to connect to Redis. If the Redis instance is on "
"a different machine, check that your firewall is "
"configured properly.")
def start_redis(node_ip_address="127.0.0.1", port=None, num_retries=20,
stdout_file=None, stderr_file=None, cleanup=True):
@@ -239,13 +260,18 @@ def start_redis(node_ip_address="127.0.0.1", port=None, num_retries=20,
Returns:
A tuple of the port used by Redis and a handle to the process that was
started. If a port is passed in, then the returned port value is the same.
started. If a port is passed in, then the returned port value is the
same.
Raises:
Exception: An exception is raised if Redis could not be started.
"""
redis_filepath = os.path.join(os.path.dirname(os.path.abspath(__file__)), "./core/src/common/thirdparty/redis/src/redis-server")
redis_module = os.path.join(os.path.dirname(os.path.abspath(__file__)), "./core/src/common/redis_module/libray_redis_module.so")
redis_filepath = os.path.join(
os.path.dirname(os.path.abspath(__file__)),
"./core/src/common/thirdparty/redis/src/redis-server")
redis_module = os.path.join(
os.path.dirname(os.path.abspath(__file__)),
"./core/src/common/redis_module/libray_redis_module.so")
assert os.path.isfile(redis_filepath)
assert os.path.isfile(redis_module)
counter = 0
@@ -261,7 +287,7 @@ def start_redis(node_ip_address="127.0.0.1", port=None, num_retries=20,
"--port", str(port),
"--loglevel", "warning",
"--loadmodule", redis_module],
stdout=stdout_file, stderr=stderr_file)
stdout=stdout_file, stderr=stderr_file)
time.sleep(0.1)
# Check if Redis successfully started (or at least if it the executable did
# not exit within 0.1 seconds).
@@ -291,6 +317,7 @@ def start_redis(node_ip_address="127.0.0.1", port=None, num_retries=20,
[stdout_file, stderr_file])
return port, p
def start_log_monitor(redis_address, node_ip_address, stdout_file=None,
stderr_file=None, cleanup=cleanup):
"""Start a log monitor process.
@@ -307,7 +334,9 @@ def start_log_monitor(redis_address, node_ip_address, stdout_file=None,
this process will be killed by services.cleanup() when the Python process
that imported services exits.
"""
log_monitor_filepath = os.path.join(os.path.dirname(os.path.abspath(__file__)), "log_monitor.py")
log_monitor_filepath = os.path.join(
os.path.dirname(os.path.abspath(__file__)),
"log_monitor.py")
p = subprocess.Popen(["python", log_monitor_filepath,
"--redis-address", redis_address,
"--node-ip-address", node_ip_address],
@@ -317,13 +346,15 @@ def start_log_monitor(redis_address, node_ip_address, stdout_file=None,
record_log_files_in_redis(redis_address, node_ip_address,
[stdout_file, stderr_file])
def start_global_scheduler(redis_address, node_ip_address, stdout_file=None,
stderr_file=None, cleanup=True):
"""Start a global scheduler process.
Args:
redis_address (str): The address of the Redis instance.
node_ip_address: The IP address of the node that this scheduler will run on.
node_ip_address: The IP address of the node that this scheduler will run
on.
stdout_file: A file handle opened for writing to redirect stdout to. If no
redirection should happen, then this should be None.
stderr_file: A file handle opened for writing to redirect stderr to. If no
@@ -340,6 +371,7 @@ def start_global_scheduler(redis_address, node_ip_address, stdout_file=None,
record_log_files_in_redis(redis_address, node_ip_address,
[stdout_file, stderr_file])
def start_webui(redis_address, node_ip_address, backend_stdout_file=None,
backend_stderr_file=None, polymer_stdout_file=None,
polymer_stderr_file=None, cleanup=True):
@@ -367,8 +399,11 @@ def start_webui(redis_address, node_ip_address, backend_stdout_file=None,
Return:
True if the web UI was successfully started, otherwise false.
"""
webui_backend_filepath = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../webui/backend/ray_ui.py")
webui_directory = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../webui/")
webui_backend_filepath = os.path.join(
os.path.dirname(os.path.abspath(__file__)),
"../../webui/backend/ray_ui.py")
webui_directory = os.path.join(os.path.dirname(os.path.abspath(__file__)),
"../../webui/")
if sys.version_info >= (3, 0):
python_executable = "python"
@@ -376,7 +411,8 @@ def start_webui(redis_address, node_ip_address, backend_stdout_file=None,
# If the user is using Python 2, it is still possible to run the webserver
# separately with Python 3, so try to find a Python 3 executable.
try:
python_executable = subprocess.check_output(["which", "python3"]).decode("ascii").strip()
python_executable = subprocess.check_output(
["which", "python3"]).decode("ascii").strip()
except Exception as e:
print("Not starting the web UI because the web UI requires Python 3.")
return False
@@ -394,8 +430,8 @@ def start_webui(redis_address, node_ip_address, backend_stdout_file=None,
return False
# Try to start polymer. If this fails, it may that port 8080 is already in
# use. It'd be nice to test for this, but doing so by calling "bind" may start
# using the port and prevent polymer from using it.
# use. It'd be nice to test for this, but doing so by calling "bind" may
# start using the port and prevent polymer from using it.
try:
polymer_process = subprocess.Popen(["polymer", "serve", "--port", "8080"],
cwd=webui_directory,
@@ -433,6 +469,7 @@ def start_webui(redis_address, node_ip_address, backend_stdout_file=None,
return True
def start_local_scheduler(redis_address,
node_ip_address,
plasma_store_name,
@@ -472,8 +509,8 @@ def start_local_scheduler(redis_address,
The name of the local scheduler socket.
"""
if num_cpus is None:
# By default, use the number of hardware execution threads for the number of
# cores.
# By default, use the number of hardware execution threads for the number
# of cores.
num_cpus = multiprocessing.cpu_count()
if num_gpus is None:
# By default, assume this node has no GPUs.
@@ -496,6 +533,7 @@ def start_local_scheduler(redis_address,
[stdout_file, stderr_file])
return local_scheduler_name
def start_objstore(node_ip_address, redis_address, object_manager_port=None,
store_stdout_file=None, store_stderr_file=None,
manager_stdout_file=None, manager_stderr_file=None,
@@ -511,10 +549,10 @@ def start_objstore(node_ip_address, redis_address, object_manager_port=None,
If no redirection should happen, then this should be None.
store_stderr_file: A file handle opened for writing to redirect stderr to.
If no redirection should happen, then this should be None.
manager_stdout_file: A file handle opened for writing to redirect stdout to.
If no redirection should happen, then this should be None.
manager_stderr_file: A file handle opened for writing to redirect stderr to.
If no redirection should happen, then this should be None.
manager_stdout_file: A file handle opened for writing to redirect stdout
to. If no redirection should happen, then this should be None.
manager_stderr_file: A file handle opened for writing to redirect stderr
to. If no redirection should happen, then this should be None.
cleanup (bool): True if using Ray in local mode. If cleanup is true, then
this process will be killed by serices.cleanup() when the Python process
that imported services exits.
@@ -522,8 +560,8 @@ def start_objstore(node_ip_address, redis_address, object_manager_port=None,
with.
Return:
A tuple of the Plasma store socket name, the Plasma manager socket name, and
the plasma manager port.
A tuple of the Plasma store socket name, the Plasma manager socket name,
and the plasma manager port.
"""
if objstore_memory is None:
# Compute a fraction of the system memory for the Plasma store to use.
@@ -559,7 +597,8 @@ def start_objstore(node_ip_address, redis_address, object_manager_port=None,
stderr_file=store_stderr_file)
# Start the plasma manager.
if object_manager_port is not None:
plasma_manager_name, p2, plasma_manager_port = ray.plasma.start_plasma_manager(
(plasma_manager_name, p2,
plasma_manager_port) = ray.plasma.start_plasma_manager(
plasma_store_name,
redis_address,
plasma_manager_port=object_manager_port,
@@ -570,7 +609,8 @@ def start_objstore(node_ip_address, redis_address, object_manager_port=None,
stderr_file=manager_stderr_file)
assert plasma_manager_port == object_manager_port
else:
plasma_manager_name, p2, plasma_manager_port = ray.plasma.start_plasma_manager(
(plasma_manager_name, p2,
plasma_manager_port) = ray.plasma.start_plasma_manager(
plasma_store_name,
redis_address,
node_ip_address=node_ip_address,
@@ -587,6 +627,7 @@ def start_objstore(node_ip_address, redis_address, object_manager_port=None,
return ObjectStoreAddress(plasma_store_name, plasma_manager_name,
plasma_manager_port)
def start_worker(node_ip_address, object_store_name, object_store_manager_name,
local_scheduler_name, redis_address, worker_path,
stdout_file=None, stderr_file=None, cleanup=True):
@@ -599,8 +640,8 @@ def start_worker(node_ip_address, object_store_name, object_store_manager_name,
object_store_manager_name (str): The name of the object store manager.
local_scheduler_name (str): The name of the local scheduler.
redis_address (str): The address that the Redis server is listening on.
worker_path (str): The path of the source code which the worker process will
run.
worker_path (str): The path of the source code which the worker process
will run.
stdout_file: A file handle opened for writing to redirect stdout to. If no
redirection should happen, then this should be None.
stderr_file: A file handle opened for writing to redirect stderr to. If no
@@ -622,6 +663,7 @@ def start_worker(node_ip_address, object_store_name, object_store_manager_name,
record_log_files_in_redis(redis_address, node_ip_address,
[stdout_file, stderr_file])
def start_monitor(redis_address, node_ip_address, stdout_file=None,
stderr_file=None, cleanup=True):
"""Run a process to monitor the other processes.
@@ -637,7 +679,8 @@ def start_monitor(redis_address, node_ip_address, stdout_file=None,
this process will be killed by services.cleanup() when the Python process
that imported services exits. This is True by default.
"""
monitor_path= os.path.join(os.path.dirname(os.path.abspath(__file__)), "monitor.py")
monitor_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
"monitor.py")
command = ["python",
monitor_path,
"--redis-address=" + str(redis_address)]
@@ -647,6 +690,7 @@ def start_monitor(redis_address, node_ip_address, stdout_file=None,
record_log_files_in_redis(redis_address, node_ip_address,
[stdout_file, stderr_file])
def start_ray_processes(address_info=None,
node_ip_address="127.0.0.1",
num_workers=0,
@@ -685,8 +729,9 @@ def start_ray_processes(address_info=None,
start a global scheduler process.
include_redis (bool): If include_redis is True, then start a Redis server
process.
include_log_monitor (bool): If True, then start a log monitor to monitor the
log files for all processes on this node and push their contents to Redis.
include_log_monitor (bool): If True, then start a log monitor to monitor
the log files for all processes on this node and push their contents to
Redis.
include_webui (bool): If True, then attempt to start the web UI. Note that
this is only possible with Python 3.
start_workers_from_local_scheduler (bool): If this flag is True, then start
@@ -713,7 +758,8 @@ def start_ray_processes(address_info=None,
address_info["node_ip_address"] = node_ip_address
if worker_path is None:
worker_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "workers/default_worker.py")
worker_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
"workers/default_worker.py")
# Start Redis if there isn't already an instance running. TODO(rkn): We are
# suppressing the output of Redis because on Linux it prints a bunch of
@@ -721,7 +767,8 @@ def start_ray_processes(address_info=None,
# should address the warnings.
redis_address = address_info.get("redis_address")
if include_redis:
redis_stdout_file, redis_stderr_file = new_log_files("redis", redirect_output)
redis_stdout_file, redis_stderr_file = new_log_files("redis",
redirect_output)
if redis_address is None:
# Start a Redis server. The start_redis method will choose a random port.
redis_port, _ = start_redis(node_ip_address,
@@ -735,7 +782,6 @@ def start_ray_processes(address_info=None,
# A Redis address was provided, so start a Redis server with the given
# port. TODO(rkn): We should check that the IP address corresponds to the
# machine that this method is running on.
redis_ip_address = get_ip_address(redis_address)
redis_port = get_port(redis_address)
new_redis_port, _ = start_redis(port=int(redis_port),
num_retries=1,
@@ -744,7 +790,8 @@ def start_ray_processes(address_info=None,
cleanup=cleanup)
assert redis_port == new_redis_port
# Start monitoring the processes.
monitor_stdout_file, monitor_stderr_file = new_log_files("monitor", redirect_output)
monitor_stdout_file, monitor_stderr_file = new_log_files("monitor",
redirect_output)
start_monitor(redis_address,
node_ip_address,
stdout_file=monitor_stdout_file,
@@ -755,8 +802,8 @@ def start_ray_processes(address_info=None,
# Start the log monitor, if necessary.
if include_log_monitor:
log_monitor_stdout_file, log_monitor_stderr_file = new_log_files("log_monitor",
redirect_output=True)
log_monitor_stdout_file, log_monitor_stderr_file = new_log_files(
"log_monitor", redirect_output=True)
start_log_monitor(redis_address,
node_ip_address,
stdout_file=log_monitor_stdout_file,
@@ -765,7 +812,8 @@ def start_ray_processes(address_info=None,
# Start the global scheduler, if necessary.
if include_global_scheduler:
global_scheduler_stdout_file, global_scheduler_stderr_file = new_log_files("global_scheduler", redirect_output)
global_scheduler_stdout_file, global_scheduler_stderr_file = new_log_files(
"global_scheduler", redirect_output)
start_global_scheduler(redis_address,
node_ip_address,
stdout_file=global_scheduler_stdout_file,
@@ -781,7 +829,8 @@ def start_ray_processes(address_info=None,
local_scheduler_socket_names = address_info["local_scheduler_socket_names"]
# Get the ports to use for the object managers if any are provided.
object_manager_ports = address_info["object_manager_ports"] if "object_manager_ports" in address_info else None
object_manager_ports = (address_info["object_manager_ports"]
if "object_manager_ports" in address_info else None)
if not isinstance(object_manager_ports, list):
object_manager_ports = num_local_schedulers * [object_manager_ports]
assert len(object_manager_ports) == num_local_schedulers
@@ -789,23 +838,26 @@ def start_ray_processes(address_info=None,
# Start any object stores that do not yet exist.
for i in range(num_local_schedulers - len(object_store_addresses)):
# Start Plasma.
plasma_store_stdout_file, plasma_store_stderr_file = new_log_files("plasma_store_{}".format(i), redirect_output)
plasma_manager_stdout_file, plasma_manager_stderr_file = new_log_files("plasma_manager_{}".format(i), redirect_output)
object_store_address = start_objstore(node_ip_address,
redis_address,
object_manager_port=object_manager_ports[i],
store_stdout_file=plasma_store_stdout_file,
store_stderr_file=plasma_store_stderr_file,
manager_stdout_file=plasma_manager_stdout_file,
manager_stderr_file=plasma_manager_stderr_file,
cleanup=cleanup)
plasma_store_stdout_file, plasma_store_stderr_file = new_log_files(
"plasma_store_{}".format(i), redirect_output)
plasma_manager_stdout_file, plasma_manager_stderr_file = new_log_files(
"plasma_manager_{}".format(i), redirect_output)
object_store_address = start_objstore(
node_ip_address,
redis_address,
object_manager_port=object_manager_ports[i],
store_stdout_file=plasma_store_stdout_file,
store_stderr_file=plasma_store_stderr_file,
manager_stdout_file=plasma_manager_stdout_file,
manager_stderr_file=plasma_manager_stderr_file,
cleanup=cleanup)
object_store_addresses.append(object_store_address)
time.sleep(0.1)
# Determine how many workers to start for each local scheduler.
num_workers_per_local_scheduler = [0] * num_local_schedulers
workers_per_local_scheduler = [0] * num_local_schedulers
for i in range(num_workers):
num_workers_per_local_scheduler[i % num_local_schedulers] += 1
workers_per_local_scheduler[i % num_local_schedulers] += 1
# Start any local schedulers that do not yet exist.
for i in range(len(local_scheduler_socket_names), num_local_schedulers):
@@ -815,26 +867,28 @@ def start_ray_processes(address_info=None,
object_store_address.manager_port)
# Determine how many workers this local scheduler should start.
if start_workers_from_local_scheduler:
num_local_scheduler_workers = num_workers_per_local_scheduler[i]
num_workers_per_local_scheduler[i] = 0
num_local_scheduler_workers = workers_per_local_scheduler[i]
workers_per_local_scheduler[i] = 0
else:
# If we're starting the workers from Python, the local scheduler should
# not start any workers.
num_local_scheduler_workers = 0
# Start the local scheduler.
local_scheduler_stdout_file, local_scheduler_stderr_file = new_log_files("local_scheduler_{}".format(i), redirect_output)
local_scheduler_name = start_local_scheduler(redis_address,
node_ip_address,
object_store_address.name,
object_store_address.manager_name,
worker_path,
plasma_address=plasma_address,
stdout_file=local_scheduler_stdout_file,
stderr_file=local_scheduler_stderr_file,
cleanup=cleanup,
num_cpus=num_cpus[i],
num_gpus=num_gpus[i],
num_workers=num_local_scheduler_workers)
local_scheduler_stdout_file, local_scheduler_stderr_file = new_log_files(
"local_scheduler_{}".format(i), redirect_output)
local_scheduler_name = start_local_scheduler(
redis_address,
node_ip_address,
object_store_address.name,
object_store_address.manager_name,
worker_path,
plasma_address=plasma_address,
stdout_file=local_scheduler_stdout_file,
stderr_file=local_scheduler_stderr_file,
cleanup=cleanup,
num_cpus=num_cpus[i],
num_gpus=num_gpus[i],
num_workers=num_local_scheduler_workers)
local_scheduler_socket_names.append(local_scheduler_name)
time.sleep(0.1)
@@ -844,11 +898,12 @@ def start_ray_processes(address_info=None,
assert len(local_scheduler_socket_names) == num_local_schedulers
# Start any workers that the local scheduler has not already started.
for i, num_local_scheduler_workers in enumerate(num_workers_per_local_scheduler):
for i, num_local_scheduler_workers in enumerate(workers_per_local_scheduler):
object_store_address = object_store_addresses[i]
local_scheduler_name = local_scheduler_socket_names[i]
for j in range(num_local_scheduler_workers):
worker_stdout_file, worker_stderr_file = new_log_files("worker_{}_{}".format(i, j), redirect_output)
worker_stdout_file, worker_stderr_file = new_log_files(
"worker_{}_{}".format(i, j), redirect_output)
start_worker(node_ip_address,
object_store_address.name,
object_store_address.manager_name,
@@ -858,17 +913,17 @@ def start_ray_processes(address_info=None,
stdout_file=worker_stdout_file,
stderr_file=worker_stderr_file,
cleanup=cleanup)
num_workers_per_local_scheduler[i] -= 1
workers_per_local_scheduler[i] -= 1
# Make sure that we've started all the workers.
assert(sum(num_workers_per_local_scheduler) == 0)
assert(sum(workers_per_local_scheduler) == 0)
# Try to start the web UI.
if include_webui:
backend_stdout_file, backend_stderr_file = new_log_files("webui_backend",
redirect_output=True)
polymer_stdout_file, polymer_stderr_file = new_log_files("webui_polymer",
redirect_output=True)
backend_stdout_file, backend_stderr_file = new_log_files(
"webui_backend", redirect_output=True)
polymer_stdout_file, polymer_stderr_file = new_log_files(
"webui_polymer", redirect_output=True)
successfully_started = start_webui(redis_address,
node_ip_address,
backend_stdout_file=backend_stdout_file,
@@ -883,6 +938,7 @@ def start_ray_processes(address_info=None,
# Return the addresses of the relevant processes.
return address_info
def start_ray_node(node_ip_address,
redis_address,
object_manager_ports=None,
@@ -905,8 +961,8 @@ def start_ray_node(node_ip_address,
managers. There should be one per object manager being started on this
node (typically just one).
num_workers (int): The number of workers to start.
num_local_schedulers (int): The number of local schedulers to start. This is
also the number of plasma stores and plasma managers to start.
num_local_schedulers (int): The number of local schedulers to start. This
is also the number of plasma stores and plasma managers to start.
worker_path (str): The path of the source code that will be run by the
worker.
cleanup (bool): If cleanup is true, then the processes started here will be
@@ -919,10 +975,8 @@ def start_ray_node(node_ip_address,
A dictionary of the address information for the processes that were
started.
"""
address_info = {
"redis_address": redis_address,
"object_manager_ports": object_manager_ports,
}
address_info = {"redis_address": redis_address,
"object_manager_ports": object_manager_ports}
return start_ray_processes(address_info=address_info,
node_ip_address=node_ip_address,
num_workers=num_workers,
@@ -934,6 +988,7 @@ def start_ray_node(node_ip_address,
num_cpus=num_cpus,
num_gpus=num_gpus)
def start_ray_head(address_info=None,
node_ip_address="127.0.0.1",
num_workers=0,
@@ -974,33 +1029,37 @@ def start_ray_head(address_info=None,
A dictionary of the address information for the processes that were
started.
"""
return start_ray_processes(address_info=address_info,
node_ip_address=node_ip_address,
num_workers=num_workers,
num_local_schedulers=num_local_schedulers,
worker_path=worker_path,
cleanup=cleanup,
redirect_output=redirect_output,
include_global_scheduler=True,
include_log_monitor=True,
include_redis=True,
include_webui=True,
start_workers_from_local_scheduler=start_workers_from_local_scheduler,
num_cpus=num_cpus,
num_gpus=num_gpus)
return start_ray_processes(
address_info=address_info,
node_ip_address=node_ip_address,
num_workers=num_workers,
num_local_schedulers=num_local_schedulers,
worker_path=worker_path,
cleanup=cleanup,
redirect_output=redirect_output,
include_global_scheduler=True,
include_log_monitor=True,
include_redis=True,
include_webui=True,
start_workers_from_local_scheduler=start_workers_from_local_scheduler,
num_cpus=num_cpus,
num_gpus=num_gpus)
def new_log_files(name, redirect_output):
"""Generate partially randomized filenames for log files.
Args:
name (str): descriptive string for this log file.
redirect_output (bool): True if files should be generated for logging stdout
and stderr and false if stdout and stderr should not be redirected.
redirect_output (bool): True if files should be generated for logging
stdout and stderr and false if stdout and stderr should not be
redirected.
Returns:
If redirect_output is true, this will return a tuple of two filehandles. The
first is for redirecting stdout and the second is for redirecting stderr.
If redirect_output is false, this will return a tuple of two None objects.
If redirect_output is true, this will return a tuple of two filehandles.
The first is for redirecting stdout and the second is for redirecting
stderr. If redirect_output is false, this will return a tuple of two None
objects.
"""
if not redirect_output:
return None, None
+17
View File
@@ -8,44 +8,53 @@ import numpy as np
# Test simple functionality
@ray.remote(num_return_vals=2)
def handle_int(a, b):
return a + 1, b + 1
# Test timing
@ray.remote
def empty_function():
pass
@ray.remote
def trivial_function():
return 1
# Test keyword arguments
@ray.remote
def keyword_fct1(a, b="hello"):
return "{} {}".format(a, b)
@ray.remote
def keyword_fct2(a="hello", b="world"):
return "{} {}".format(a, b)
@ray.remote
def keyword_fct3(a, b, c="hello", d="world"):
return "{} {} {} {}".format(a, b, c, d)
# Test variable numbers of arguments
@ray.remote
def varargs_fct1(*a):
return " ".join(map(str, a))
@ray.remote
def varargs_fct2(a, *b):
return " ".join(map(str, b))
try:
@ray.remote
def kwargs_throw_exception(**c):
@@ -64,24 +73,29 @@ except:
# test throwing an exception
@ray.remote
def throw_exception_fct1():
raise Exception("Test function 1 intentionally failed.")
@ray.remote
def throw_exception_fct2():
raise Exception("Test function 2 intentionally failed.")
@ray.remote(num_return_vals=3)
def throw_exception_fct3(x):
raise Exception("Test function 3 intentionally failed.")
# test Python mode
@ray.remote
def python_mode_f():
return np.array([0, 0])
@ray.remote
def python_mode_g(x):
x[0] = 1
@@ -89,14 +103,17 @@ def python_mode_g(x):
# test no return values
@ray.remote
def no_op():
pass
class TestClass(object):
def __init__(self):
self.a = 5
@ray.remote
def test_unknown_type():
return TestClass()
+439 -272
View File
File diff suppressed because it is too large Load Diff
+31 -17
View File
@@ -3,25 +3,33 @@ from __future__ import division
from __future__ import print_function
import argparse
import binascii
import numpy as np
import redis
import traceback
import sys
import binascii
import ray
parser = argparse.ArgumentParser(description="Parse addresses for the worker to connect to.")
parser.add_argument("--node-ip-address", required=True, type=str, help="the ip address of the worker's node")
parser.add_argument("--redis-address", required=True, type=str, help="the address to use for Redis")
parser.add_argument("--object-store-name", required=True, type=str, help="the object store's name")
parser.add_argument("--object-store-manager-name", required=True, type=str, help="the object store manager's name")
parser.add_argument("--local-scheduler-name", required=True, type=str, help="the local scheduler's name")
parser.add_argument("--actor-id", required=False, type=str, help="the actor ID of this worker")
parser = argparse.ArgumentParser(description=("Parse addresses for the worker "
"to connect to."))
parser.add_argument("--node-ip-address", required=True, type=str,
help="the ip address of the worker's node")
parser.add_argument("--redis-address", required=True, type=str,
help="the address to use for Redis")
parser.add_argument("--object-store-name", required=True, type=str,
help="the object store's name")
parser.add_argument("--object-store-manager-name", required=True, type=str,
help="the object store manager's name")
parser.add_argument("--local-scheduler-name", required=True, type=str,
help="the local scheduler's name")
parser.add_argument("--actor-id", required=False, type=str,
help="the actor ID of this worker")
def random_string():
return np.random.bytes(20)
if __name__ == "__main__":
args = parser.parse_args()
info = {"node_ip_address": args.node_ip_address,
@@ -30,7 +38,10 @@ if __name__ == "__main__":
"manager_socket_name": args.object_store_manager_name,
"local_scheduler_socket_name": args.local_scheduler_name}
actor_id = binascii.unhexlify(args.actor_id) if not args.actor_id is None else ray.worker.NIL_ACTOR_ID
if args.actor_id is not None:
actor_id = binascii.unhexlify(args.actor_id)
else:
actor_id = ray.worker.NIL_ACTOR_ID
ray.worker.connect(info, mode=ray.WORKER_MODE, actor_id=actor_id)
@@ -43,24 +54,27 @@ being caught in "python/ray/workers/default_worker.py".
while True:
try:
# This call to main_loop should never return if things are working. Most
# exceptions that are thrown (e.g., inside the execution of a task) should
# be caught and handled inside of the call to main_loop. If an exception
# is thrown here, then that means that there is some error that we didn't
# anticipate.
# exceptions that are thrown (e.g., inside the execution of a task)
# should be caught and handled inside of the call to main_loop. If an
# exception is thrown here, then that means that there is some error that
# we didn't anticipate.
ray.worker.main_loop()
except Exception as e:
traceback_str = traceback.format_exc() + error_explanation
DRIVER_ID_LENGTH = 20
# We use a driver ID of all zeros to push an error message to all drivers.
# We use a driver ID of all zeros to push an error message to all
# drivers.
driver_id = DRIVER_ID_LENGTH * b"\x00"
error_key = b"Error:" + driver_id + b":" + random_string()
redis_ip_address, redis_port = args.redis_address.split(":")
# For this command to work, some other client (on the same machine as
# Redis) must have run "CONFIG SET protected-mode no".
redis_client = redis.StrictRedis(host=redis_ip_address, port=int(redis_port))
redis_client = redis.StrictRedis(host=redis_ip_address,
port=int(redis_port))
redis_client.hmset(error_key, {"type": "worker_crash",
"message": traceback_str,
"note": "This error is unexpected and should not have happened."})
"note": ("This error is unexpected and "
"should not have happened.")})
redis_client.rpush("ErrorKeys", error_key)
# TODO(rkn): Note that if the worker was in the middle of executing a
# task, the any worker or driver that is blocking in a get call and
+16 -11
View File
@@ -2,12 +2,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import subprocess
from setuptools import setup, find_packages
import setuptools.command.install as _install
class install(_install.install):
def run(self):
subprocess.check_call(["../build.sh"])
@@ -16,19 +16,24 @@ class install(_install.install):
# setuptools. So, calling do_egg_install() manually here.
self.do_egg_install()
package_data = {
"ray": ["core/src/common/thirdparty/redis/src/redis-server",
"core/src/common/redis_module/libray_redis_module.so",
"core/src/plasma/plasma_store",
"core/src/plasma/plasma_manager",
"core/src/plasma/libplasma.so",
"core/src/local_scheduler/local_scheduler",
"core/src/local_scheduler/liblocal_scheduler_library.so",
"core/src/numbuf/libarrow.so",
"core/src/numbuf/libnumbuf.so",
"core/src/global_scheduler/global_scheduler"]
}
setup(name="ray",
version="0.0.1",
packages=find_packages(),
package_data={"ray": ["core/src/common/thirdparty/redis/src/redis-server",
"core/src/common/redis_module/libray_redis_module.so",
"core/src/plasma/plasma_store",
"core/src/plasma/plasma_manager",
"core/src/plasma/libplasma.so",
"core/src/local_scheduler/local_scheduler",
"core/src/local_scheduler/liblocal_scheduler_library.so",
"core/src/numbuf/libarrow.so",
"core/src/numbuf/libnumbuf.so",
"core/src/global_scheduler/global_scheduler"]},
package_data=package_data,
cmdclass={"install": install},
install_requires=["numpy",
"funcsigs",