Run flake8 in Travis and make code PEP8 compliant. (#387)

This commit is contained in:
Robert Nishihara
2017-03-21 12:57:54 -07:00
committed by Philipp Moritz
parent 083e7a28ad
commit ba02fc0eb0
54 changed files with 2391 additions and 1313 deletions
+4
View File
@@ -35,6 +35,10 @@ matrix:
- cd doc
- pip install -r requirements-doc.txt
- sphinx-build -W -b html -d _build/doctrees source _build/html
- cd ..
# Run Python linting.
- flake8 --ignore=E111,E114
--exclude=python/ray/core/src/common/flatbuffers_ep-prefix/,python/ray/core/generated/,src/numbuf/thirdparty/,src/common/format/,examples/,doc/source/conf.py
- os: linux
dist: trusty
env: VALGRIND=1 PYTHON=2.7
+3
View File
@@ -64,6 +64,9 @@ elif [[ "$LINT" == "1" ]]; then
# Install miniconda.
wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh
bash miniconda.sh -b -p $HOME/miniconda
export PATH="$HOME/miniconda/bin:$PATH"
# Install Python linting tools.
pip install flake8
else
echo "Unrecognized environment."
exit 1
+22 -12
View File
@@ -2,19 +2,29 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Ray version string
__version__ = "0.01"
import ctypes
# Windows only
if hasattr(ctypes, "windll"):
# Makes sure that all child processes die when we die
# Also makes sure that fatal crashes result in process termination rather than an error dialog (the latter is annoying since we have a lot of processes)
# This is done by associating all child processes with a "job" object that imposes this behavior.
(lambda kernel32: (lambda job: (lambda n: kernel32.SetInformationJobObject(job, 9, "\0" * 17 + chr(0x8 | 0x4 | 0x20) + "\0" * (n - 18), n))(0x90 if ctypes.sizeof(ctypes.c_void_p) > ctypes.sizeof(ctypes.c_int) else 0x70) and kernel32.AssignProcessToJobObject(job, ctypes.c_void_p(kernel32.GetCurrentProcess())))(ctypes.c_void_p(kernel32.CreateJobObjectW(None, None))) if kernel32 is not None else None)(ctypes.windll.kernel32)
from ray.worker import register_class, error_info, init, connect, disconnect, get, put, wait, remote, log_event, log_span, flush_log
from ray.worker import (register_class, error_info, init, connect, disconnect,
get, put, wait, remote, log_event, log_span,
flush_log)
from ray.actor import actor
from ray.actor import get_gpu_ids
from ray.worker import EnvironmentVariable, env
from ray.worker import SCRIPT_MODE, WORKER_MODE, PYTHON_MODE, SILENT_MODE
# Ray version string
__version__ = "0.01"
__all__ = ["register_class", "error_info", "init", "connect", "disconnect",
"get", "put", "wait", "remote", "log_event", "log_span",
"flush_log", "actor", "get_gpu_ids", "EnvironmentVariable", "env",
"SCRIPT_MODE", "WORKER_MODE", "PYTHON_MODE", "SILENT_MODE",
"__version__"]
import ctypes
# Windows only
if hasattr(ctypes, "windll"):
# Makes sure that all child processes die when we die. Also makes sure that
# fatal crashes result in process termination rather than an error dialog
# (the latter is annoying since we have a lot of processes). This is done by
# associating all child processes with a "job" object that imposes this
# behavior.
(lambda kernel32: (lambda job: (lambda n: kernel32.SetInformationJobObject(job, 9, "\0" * 17 + chr(0x8 | 0x4 | 0x20) + "\0" * (n - 18), n))(0x90 if ctypes.sizeof(ctypes.c_void_p) > ctypes.sizeof(ctypes.c_int) else 0x70) and kernel32.AssignProcessToJobObject(job, ctypes.c_void_p(kernel32.GetCurrentProcess())))(ctypes.c_void_p(kernel32.CreateJobObjectW(None, None))) if kernel32 is not None else None)(ctypes.windll.kernel32) # noqa: E501
+60 -21
View File
@@ -18,6 +18,7 @@ import ray.experimental.state as state
# the worker is currently allowed to use.
gpu_ids = []
def get_gpu_ids():
"""Get the IDs of the GPUs that are available to the worker.
@@ -26,12 +27,15 @@ def get_gpu_ids():
"""
return gpu_ids
def random_string():
return np.random.bytes(20)
def random_actor_id():
return ray.local_scheduler.ObjectID(random_string())
def get_actor_method_function_id(attr):
"""Get the function ID corresponding to an actor method.
@@ -47,10 +51,14 @@ def get_actor_method_function_id(attr):
assert len(function_id) == 20
return ray.local_scheduler.ObjectID(function_id)
def fetch_and_register_actor(key, worker):
"""Import an actor."""
driver_id, actor_id_str, actor_name, module, pickled_class, assigned_gpu_ids, actor_method_names = \
worker.redis_client.hmget(key, ["driver_id", "actor_id", "name", "module", "class", "gpu_ids", "actor_method_names"])
(driver_id, actor_id_str, actor_name,
module, pickled_class, assigned_gpu_ids,
actor_method_names) = worker.redis_client.hmget(
key, ["driver_id", "actor_id", "name", "module", "class", "gpu_ids",
"actor_method_names"])
actor_id = ray.local_scheduler.ObjectID(actor_id_str)
actor_name = actor_name.decode("ascii")
module = module.decode("ascii")
@@ -64,12 +72,14 @@ def fetch_and_register_actor(key, worker):
class TemporaryActor(object):
pass
worker.actors[actor_id_str] = TemporaryActor()
def temporary_actor_method(*xs):
raise Exception("The actor with name {} failed to be imported, and so "
"cannot execute this method".format(actor_name))
for actor_method_name in actor_method_names:
function_id = get_actor_method_function_id(actor_method_name).id()
worker.functions[driver_id][function_id] = (actor_method_name, temporary_actor_method)
worker.functions[driver_id][function_id] = (actor_method_name,
temporary_actor_method)
try:
unpickled_class = pickling.loads(pickled_class)
@@ -84,11 +94,15 @@ def fetch_and_register_actor(key, worker):
# TODO(pcm): Why is the below line necessary?
unpickled_class.__module__ = module
worker.actors[actor_id_str] = unpickled_class.__new__(unpickled_class)
for (k, v) in inspect.getmembers(unpickled_class, predicate=(lambda x: inspect.isfunction(x) or inspect.ismethod(x))):
for (k, v) in inspect.getmembers(
unpickled_class, predicate=(lambda x: (inspect.isfunction(x) or
inspect.ismethod(x)))):
function_id = get_actor_method_function_id(k).id()
worker.functions[driver_id][function_id] = (k, v)
# We do not set worker.function_properties[driver_id][function_id] because
# we currently do need the actor worker to submit new tasks for the actor.
# We do not set worker.function_properties[driver_id][function_id]
# because we currently do not need the actor worker to submit new tasks
# for the actor.
def select_local_scheduler(local_schedulers, num_gpus, worker):
"""Select a local scheduler to assign this actor to.
@@ -119,15 +133,19 @@ def select_local_scheduler(local_schedulers, num_gpus, worker):
# Loop through all of the local schedulers.
for local_scheduler in local_schedulers:
# See if there are enough available GPUs on this local scheduler.
local_scheduler_total_gpus = int(float(local_scheduler[b"num_gpus"].decode("ascii")))
gpus_in_use = worker.redis_client.hget(local_scheduler[b"ray_client_id"], b"gpus_in_use")
local_scheduler_total_gpus = int(float(
local_scheduler[b"num_gpus"].decode("ascii")))
gpus_in_use = worker.redis_client.hget(local_scheduler[b"ray_client_id"],
b"gpus_in_use")
gpus_in_use = 0 if gpus_in_use is None else int(gpus_in_use)
if gpus_in_use + num_gpus <= local_scheduler_total_gpus:
# Attempt to reserve some GPUs for this actor.
new_gpus_in_use = worker.redis_client.hincrby(local_scheduler[b"ray_client_id"], b"gpus_in_use", num_gpus)
new_gpus_in_use = worker.redis_client.hincrby(
local_scheduler[b"ray_client_id"], b"gpus_in_use", num_gpus)
if new_gpus_in_use > local_scheduler_total_gpus:
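# If we failed to reserve the GPUs, undo the increment.
# NOTE(review): the call below adds num_gpus a second time instead of
# subtracting it; undoing the increment presumably needs -num_gpus. Confirm.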
worker.redis_client.hincrby(local_scheduler[b"ray_client_id"], b"gpus_in_use", num_gpus)
worker.redis_client.hincrby(local_scheduler[b"ray_client_id"],
b"gpus_in_use", num_gpus)
else:
# We succeeded at reserving the GPUs, so we are done.
local_scheduler_id = local_scheduler[b"ray_client_id"]
@@ -135,10 +153,13 @@ def select_local_scheduler(local_schedulers, num_gpus, worker):
break
if local_scheduler_id is None:
raise Exception("Could not find a node with enough GPUs to create this "
"actor. The local scheduler information is {}.".format(local_schedulers))
"actor. The local scheduler information is {}."
.format(local_schedulers))
return local_scheduler_id, gpu_ids
def export_actor(actor_id, Class, actor_method_names, num_cpus, num_gpus, worker):
def export_actor(actor_id, Class, actor_method_names, num_cpus, num_gpus,
worker):
"""Export an actor to redis.
Args:
@@ -158,13 +179,16 @@ def export_actor(actor_id, Class, actor_method_names, num_cpus, num_gpus, worker
driver_id = worker.task_driver_id.id()
for actor_method_name in actor_method_names:
function_id = get_actor_method_function_id(actor_method_name).id()
worker.function_properties[driver_id][function_id] = (1, num_cpus, num_gpus)
worker.function_properties[driver_id][function_id] = (1, num_cpus,
num_gpus)
# Select a local scheduler for the actor.
local_schedulers = state.get_local_schedulers(worker)
local_scheduler_id, gpu_ids = select_local_scheduler(local_schedulers, num_gpus, worker)
local_scheduler_id, gpu_ids = select_local_scheduler(local_schedulers,
num_gpus, worker)
worker.redis_client.publish("actor_notifications", actor_id.id() + local_scheduler_id)
worker.redis_client.publish("actor_notifications",
actor_id.id() + local_scheduler_id)
d = {"driver_id": driver_id,
"actor_id": actor_id.id(),
@@ -176,6 +200,7 @@ def export_actor(actor_id, Class, actor_method_names, num_cpus, num_gpus, worker
worker.redis_client.hmset(key, d)
worker.redis_client.rpush("Exports", key)
def actor(*args, **kwargs):
def make_actor_decorator(num_cpus=1, num_gpus=0):
def make_actor(Class):
@@ -189,7 +214,8 @@ def actor(*args, **kwargs):
raise Exception("Actors currently do not support **kwargs.")
function_id = get_actor_method_function_id(attr)
# TODO(pcm): Extend args with keyword args.
object_ids = ray.worker.global_worker.submit_task(function_id, "", args,
object_ids = ray.worker.global_worker.submit_task(function_id, "",
args,
actor_id=actor_id)
if len(object_ids) == 1:
return object_ids[0]
@@ -199,24 +225,34 @@ def actor(*args, **kwargs):
class NewClass(object):
def __init__(self, *args, **kwargs):
self._ray_actor_id = random_actor_id()
self._ray_actor_methods = {k: v for (k, v) in inspect.getmembers(Class, predicate=(lambda x: inspect.isfunction(x) or inspect.ismethod(x)))}
export_actor(self._ray_actor_id, Class, self._ray_actor_methods.keys(), num_cpus, num_gpus, ray.worker.global_worker)
self._ray_actor_methods = {
k: v for (k, v) in inspect.getmembers(
Class, predicate=(lambda x: (inspect.isfunction(x) or
inspect.ismethod(x))))}
export_actor(self._ray_actor_id, Class,
self._ray_actor_methods.keys(), num_cpus, num_gpus,
ray.worker.global_worker)
# Call __init__ as a remote function.
if "__init__" in self._ray_actor_methods.keys():
actor_method_call(self._ray_actor_id, "__init__", *args, **kwargs)
else:
print("WARNING: this object has no __init__ method.")
# Make tab completion work.
def __dir__(self):
return self._ray_actor_methods
def __getattribute__(self, attr):
# The following is needed so we can still access
# self._ray_actor_id and self._ray_actor_methods.
if attr in ["_ray_actor_id", "_ray_actor_methods"]:
return super(NewClass, self).__getattribute__(attr)
if attr in self._ray_actor_methods.keys():
return lambda *args, **kwargs: actor_method_call(self._ray_actor_id, attr, *args, **kwargs)
return lambda *args, **kwargs: actor_method_call(
self._ray_actor_id, attr, *args, **kwargs)
# There is no method with this name, so raise an exception.
raise AttributeError("'{}' Actor object has no attribute '{}'".format(Class, attr))
raise AttributeError("'{}' Actor object has no attribute '{}'"
.format(Class, attr))
def __repr__(self):
return "Actor(" + self._ray_actor_id.hex() + ")"
@@ -230,7 +266,9 @@ def actor(*args, **kwargs):
return make_actor_decorator(num_cpus=1, num_gpus=0)(Class)
# In this case, the actor decorator is something like @ray.actor(num_gpus=1).
if len(args) == 0 and len(kwargs) > 0 and all([key in ["num_cpus", "num_gpus"] for key in kwargs.keys()]):
if len(args) == 0 and len(kwargs) > 0 and all([key
in ["num_cpus", "num_gpus"]
for key in kwargs.keys()]):
num_cpus = kwargs["num_cpus"] if "num_cpus" in kwargs.keys() else 1
num_gpus = kwargs["num_gpus"] if "num_gpus" in kwargs.keys() else 0
return make_actor_decorator(num_cpus=num_cpus, num_gpus=num_gpus)
@@ -240,4 +278,5 @@ def actor(*args, **kwargs):
"some of the arguments 'num_cpus' or 'num_gpus' as in "
"'ray.actor(num_gpus=1)'.")
ray.worker.global_worker.fetch_and_register["Actor"] = fetch_and_register_actor
+146 -78
View File
@@ -2,17 +2,16 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import random
import subprocess
import redis
import sys
import time
import unittest
import redis
import ray.services
# Import flatbuffer bindings.
from ray.core.generated.SubscribeToNotificationsReply import SubscribeToNotificationsReply
from ray.core.generated.SubscribeToNotificationsReply \
import SubscribeToNotificationsReply
from ray.core.generated.TaskReply import TaskReply
from ray.core.generated.ResultTableReply import ResultTableReply
@@ -22,6 +21,7 @@ OBJECT_SUBSCRIBE_PREFIX = "OS:"
TASK_PREFIX = "TT:"
OBJECT_CHANNEL_PREFIX = "OC:"
def integerToAsciiHex(num, numbytes):
retstr = b""
# Support 32 and 64 bit architecture.
@@ -36,6 +36,7 @@ def integerToAsciiHex(num, numbytes):
return retstr
def get_next_message(pubsub_client, timeout_seconds=10):
"""Block until the next message is available on the pubsub channel."""
start_time = time.time()
@@ -47,6 +48,7 @@ def get_next_message(pubsub_client, timeout_seconds=10):
if time.time() - start_time > timeout_seconds:
raise Exception("Timed out while waiting for next message.")
class TestGlobalStateStore(unittest.TestCase):
def setUp(self):
@@ -57,102 +59,144 @@ class TestGlobalStateStore(unittest.TestCase):
ray.services.cleanup()
def testInvalidObjectTableAdd(self):
# Check that Redis returns an error when RAY.OBJECT_TABLE_ADD is called with
# the wrong arguments.
# Check that Redis returns an error when RAY.OBJECT_TABLE_ADD is called
# with the wrong arguments.
with self.assertRaises(redis.ResponseError):
self.redis.execute_command("RAY.OBJECT_TABLE_ADD")
with self.assertRaises(redis.ResponseError):
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "hello")
with self.assertRaises(redis.ResponseError):
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id2", "one", "hash2", "manager_id1")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id2", "one",
"hash2", "manager_id1")
with self.assertRaises(redis.ResponseError):
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id2", 1, "hash2", "manager_id1", "extra argument")
# Check that Redis returns an error when RAY.OBJECT_TABLE_ADD adds an object
# ID that is already present with a different hash.
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, "hash1", "manager_id1")
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", "object_id1")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id2", 1,
"hash2", "manager_id1", "extra argument")
# Check that Redis returns an error when RAY.OBJECT_TABLE_ADD adds an
# object ID that is already present with a different hash.
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
"hash1", "manager_id1")
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
"object_id1")
self.assertEqual(set(response), {b"manager_id1"})
with self.assertRaises(redis.ResponseError):
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, "hash2", "manager_id2")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
"hash2", "manager_id2")
# Check that the second manager was added, even though the hash was
# mismatched.
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", "object_id1")
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
"object_id1")
self.assertEqual(set(response), {b"manager_id1", b"manager_id2"})
# Check that it is fine if we add the same object ID multiple times with the
# most recent hash.
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, "hash2", "manager_id1")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, "hash2", "manager_id1")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, "hash2", "manager_id2")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 2, "hash2", "manager_id2")
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", "object_id1")
# Check that it is fine if we add the same object ID multiple times with
# the most recent hash.
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
"hash2", "manager_id1")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
"hash2", "manager_id1")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
"hash2", "manager_id2")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 2,
"hash2", "manager_id2")
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
"object_id1")
self.assertEqual(set(response), {b"manager_id1", b"manager_id2"})
def testObjectTableAddAndLookup(self):
# Try calling RAY.OBJECT_TABLE_LOOKUP with an object ID that has not been
# added yet.
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", "object_id1")
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
"object_id1")
self.assertEqual(response, None)
# Add some managers and try again.
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, "hash1", "manager_id1")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, "hash1", "manager_id2")
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", "object_id1")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
"hash1", "manager_id1")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
"hash1", "manager_id2")
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
"object_id1")
self.assertEqual(set(response), {b"manager_id1", b"manager_id2"})
# Add a manager that already exists again and try again.
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, "hash1", "manager_id2")
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", "object_id1")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
"hash1", "manager_id2")
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
"object_id1")
self.assertEqual(set(response), {b"manager_id1", b"manager_id2"})
# Check that we properly handle NULL characters. In the past, NULL
# characters were handled improperly causing a "hash mismatch" error if two
# object IDs that agreed up to the NULL character were inserted with
# different hashes.
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "\x00object_id3", 1, "hash1", "manager_id1")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "\x00object_id4", 1, "hash2", "manager_id1")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "\x00object_id3", 1,
"hash1", "manager_id1")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "\x00object_id4", 1,
"hash2", "manager_id1")
# Check that NULL characters in the hash are handled properly.
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", 1, "\x00hash1", "manager_id1")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", 1,
"\x00hash1", "manager_id1")
with self.assertRaises(redis.ResponseError):
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", 1, "\x00hash2", "manager_id1")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", 1,
"\x00hash2", "manager_id1")
def testObjectTableAddAndRemove(self):
# Try removing a manager from an object ID that has not been added yet.
with self.assertRaises(redis.ResponseError):
self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1", "manager_id1")
self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1",
"manager_id1")
# Try calling RAY.OBJECT_TABLE_LOOKUP with an object ID that has not been
# added yet.
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", "object_id1")
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
"object_id1")
self.assertEqual(response, None)
# Add some managers and try again.
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, "hash1", "manager_id1")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, "hash1", "manager_id2")
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", "object_id1")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
"hash1", "manager_id1")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
"hash1", "manager_id2")
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
"object_id1")
self.assertEqual(set(response), {b"manager_id1", b"manager_id2"})
# Remove a manager that doesn't exist, and make sure we still have the same set.
self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1", "manager_id3")
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", "object_id1")
# Remove a manager that doesn't exist, and make sure we still have the same
# set.
self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1",
"manager_id3")
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
"object_id1")
self.assertEqual(set(response), {b"manager_id1", b"manager_id2"})
# Remove a manager that does exist. Make sure it gets removed the first
# time and does nothing the second time.
self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1", "manager_id1")
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", "object_id1")
self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1",
"manager_id1")
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
"object_id1")
self.assertEqual(set(response), {b"manager_id2"})
self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1", "manager_id1")
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", "object_id1")
self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1",
"manager_id1")
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
"object_id1")
self.assertEqual(set(response), {b"manager_id2"})
# Remove the last manager, and make sure we have an empty set.
self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1", "manager_id2")
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", "object_id1")
self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1",
"manager_id2")
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
"object_id1")
self.assertEqual(set(response), set())
# Remove a manager from an empty set, and make sure we now have an empty set.
self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1", "manager_id3")
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", "object_id1")
# Remove a manager from an empty set, and make sure we still have an empty
# set.
self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1",
"manager_id3")
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
"object_id1")
self.assertEqual(set(response), set())
def testObjectTableSubscribeToNotifications(self):
# Define a helper method for checking the contents of object notifications.
def check_object_notification(notification_message, object_id, object_size, manager_ids):
notification_object = SubscribeToNotificationsReply.GetRootAsSubscribeToNotificationsReply(notification_message, 0)
def check_object_notification(notification_message, object_id, object_size,
manager_ids):
notification_object = (SubscribeToNotificationsReply
.GetRootAsSubscribeToNotificationsReply(
notification_message, 0))
self.assertEqual(notification_object.ObjectId(), object_id)
self.assertEqual(notification_object.ObjectSize(), object_size)
self.assertEqual(notification_object.ManagerIdsLength(), len(manager_ids))
self.assertEqual(notification_object.ManagerIdsLength(),
len(manager_ids))
for i in range(len(manager_ids)):
self.assertEqual(notification_object.ManagerIds(i), manager_ids[i])
@@ -160,37 +204,45 @@ class TestGlobalStateStore(unittest.TestCase):
p = self.redis.pubsub()
# Subscribe to an object ID.
p.psubscribe("{}manager_id1".format(OBJECT_CHANNEL_PREFIX))
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", data_size, "hash1", "manager_id2")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", data_size,
"hash1", "manager_id2")
# Receive the acknowledgement message.
self.assertEqual(get_next_message(p)["data"], 1)
# Request a notification and receive the data.
self.redis.execute_command("RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS", "manager_id1", "object_id1")
self.redis.execute_command("RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS",
"manager_id1", "object_id1")
# Verify that the notification is correct.
check_object_notification(get_next_message(p)["data"],
b"object_id1",
data_size,
[b"manager_id2"])
# Request a notification for an object that isn't there. Then add the object
# and receive the data. Only the first call to RAY.OBJECT_TABLE_ADD should
# trigger notifications.
self.redis.execute_command("RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS", "manager_id1", "object_id2", "object_id3")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", data_size, "hash1", "manager_id1")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", data_size, "hash1", "manager_id2")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", data_size, "hash1", "manager_id3")
# Request a notification for an object that isn't there. Then add the
# object and receive the data. Only the first call to RAY.OBJECT_TABLE_ADD
# should trigger notifications.
self.redis.execute_command("RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS",
"manager_id1", "object_id2", "object_id3")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", data_size,
"hash1", "manager_id1")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", data_size,
"hash1", "manager_id2")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", data_size,
"hash1", "manager_id3")
# Verify that the notification is correct.
check_object_notification(get_next_message(p)["data"],
b"object_id3",
data_size,
[b"manager_id1"])
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id2", data_size, "hash1", "manager_id3")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id2", data_size,
"hash1", "manager_id3")
# Verify that the notification is correct.
check_object_notification(get_next_message(p)["data"],
b"object_id2",
data_size,
[b"manager_id3"])
# Request notifications for object_id3 again.
self.redis.execute_command("RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS", "manager_id1", "object_id3")
self.redis.execute_command("RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS",
"manager_id1", "object_id3")
# Verify that the notification is correct.
check_object_notification(get_next_message(p)["data"],
b"object_id3",
@@ -199,29 +251,38 @@ class TestGlobalStateStore(unittest.TestCase):
def testResultTableAddAndLookup(self):
def check_result_table_entry(message, task_id, is_put):
result_table_reply = ResultTableReply.GetRootAsResultTableReply(message, 0)
result_table_reply = ResultTableReply.GetRootAsResultTableReply(message,
0)
self.assertEqual(result_table_reply.TaskId(), task_id)
self.assertEqual(result_table_reply.IsPut(), is_put)
# Try looking up something in the result table before anything is added.
response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP", "object_id1")
response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP",
"object_id1")
self.assertIsNone(response)
# Adding the object to the object table should have no effect.
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, "hash1", "manager_id1")
response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP", "object_id1")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
"hash1", "manager_id1")
response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP",
"object_id1")
self.assertIsNone(response)
# Add the result to the result table. The lookup now returns the task ID.
task_id = b"task_id1"
self.redis.execute_command("RAY.RESULT_TABLE_ADD", "object_id1", task_id, 0)
response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP", "object_id1")
self.redis.execute_command("RAY.RESULT_TABLE_ADD", "object_id1", task_id,
0)
response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP",
"object_id1")
check_result_table_entry(response, task_id, False)
# Doing it again should still work.
response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP", "object_id1")
response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP",
"object_id1")
check_result_table_entry(response, task_id, False)
# Try another result table lookup. This should succeed.
task_id = b"task_id2"
self.redis.execute_command("RAY.RESULT_TABLE_ADD", "object_id2", task_id, 1)
response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP", "object_id2")
self.redis.execute_command("RAY.RESULT_TABLE_ADD", "object_id2", task_id,
1)
response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP",
"object_id2")
check_result_table_entry(response, task_id, True)
def testInvalidTaskTableAdd(self):
@@ -251,7 +312,8 @@ class TestGlobalStateStore(unittest.TestCase):
task_status, local_scheduler_id, task_spec = task_args
task_reply_object = TaskReply.GetRootAsTaskReply(message, 0)
self.assertEqual(task_reply_object.State(), task_status)
self.assertEqual(task_reply_object.LocalSchedulerId(), local_scheduler_id)
self.assertEqual(task_reply_object.LocalSchedulerId(),
local_scheduler_id)
self.assertEqual(task_reply_object.TaskSpec(), task_spec)
self.assertEqual(task_reply_object.Updated(), updated)
@@ -263,7 +325,8 @@ class TestGlobalStateStore(unittest.TestCase):
check_task_reply(response, task_args)
task_args[0] = TASK_STATUS_SCHEDULED
self.redis.execute_command("RAY.TASK_TABLE_UPDATE", "task_id", *task_args[:2])
self.redis.execute_command("RAY.TASK_TABLE_UPDATE", "task_id",
*task_args[:2])
response = self.redis.execute_command("RAY.TASK_TABLE_GET", "task_id")
check_task_reply(response, task_args)
@@ -331,9 +394,12 @@ class TestGlobalStateStore(unittest.TestCase):
# Subscribe to the task table.
p = self.redis.pubsub()
p.psubscribe("{prefix}*:*".format(prefix=TASK_PREFIX))
p.psubscribe("{prefix}*:{state}".format(prefix=TASK_PREFIX, state=scheduling_state))
p.psubscribe("{prefix}{local_scheduler_id}:*".format(prefix=TASK_PREFIX, local_scheduler_id=local_scheduler_id))
task_args = [b"task_id", scheduling_state, local_scheduler_id.encode("ascii"), b"task_spec"]
p.psubscribe("{prefix}*:{state}".format(
prefix=TASK_PREFIX, state=scheduling_state))
p.psubscribe("{prefix}{local_scheduler_id}:*".format(
prefix=TASK_PREFIX, local_scheduler_id=local_scheduler_id))
task_args = [b"task_id", scheduling_state,
local_scheduler_id.encode("ascii"), b"task_spec"]
self.redis.execute_command("RAY.TASK_TABLE_ADD", *task_args)
# Receive the acknowledgement message.
self.assertEqual(get_next_message(p)["data"], 1)
@@ -346,8 +412,10 @@ class TestGlobalStateStore(unittest.TestCase):
notification_object = TaskReply.GetRootAsTaskReply(message, 0)
self.assertEqual(notification_object.TaskId(), b"task_id")
self.assertEqual(notification_object.State(), scheduling_state)
self.assertEqual(notification_object.LocalSchedulerId(), local_scheduler_id.encode("ascii"))
self.assertEqual(notification_object.LocalSchedulerId(),
local_scheduler_id.encode("ascii"))
self.assertEqual(notification_object.TaskSpec(), b"task_spec")
if __name__ == "__main__":
unittest.main(verbosity=2)
+58 -41
View File
@@ -11,25 +11,29 @@ import ray.local_scheduler as local_scheduler
ID_SIZE = 20
def random_object_id():
return local_scheduler.ObjectID(np.random.bytes(ID_SIZE))
def random_function_id():
return local_scheduler.ObjectID(np.random.bytes(ID_SIZE))
def random_driver_id():
  """Return a driver ID (currently an ObjectID) of ID_SIZE random bytes."""
  raw_id = np.random.bytes(ID_SIZE)
  return local_scheduler.ObjectID(raw_id)
def random_task_id():
  """Return a task ID (currently an ObjectID) of ID_SIZE random bytes."""
  raw_id = np.random.bytes(ID_SIZE)
  return local_scheduler.ObjectID(raw_id)
BASE_SIMPLE_OBJECTS = [
0, 1, 100000, 0.0, 0.5, 0.9, 100000.1, (), [], {},
"", 990 * "h", u"", 990 * u"h"
]
0, 1, 100000, 0.0, 0.5, 0.9, 100000.1, (), [], {}, "", 990 * "h", u"",
990 * u"h"]
if sys.version_info < (3, 0):
BASE_SIMPLE_OBJECTS += [long(0), long(1), long(100000), long(1 << 100)]
BASE_SIMPLE_OBJECTS += [long(0), long(1), long(100000), long(1 << 100)] # noqa: E501,F821
LIST_SIMPLE_OBJECTS = [[obj] for obj in BASE_SIMPLE_OBJECTS]
TUPLE_SIMPLE_OBJECTS = [(obj,) for obj in BASE_SIMPLE_OBJECTS]
@@ -45,11 +49,14 @@ SIMPLE_OBJECTS = (BASE_SIMPLE_OBJECTS +
l = []
l.append(l)
class Foo(object):
  """Empty user-defined class used as a sample non-simple object."""

  def __init__(self):
    """Take no arguments and set no attributes."""
BASE_COMPLEX_OBJECTS = [999 * "h", 999 * u"h", l, Foo(), 10 * [10 * [10 * [1]]]]
BASE_COMPLEX_OBJECTS = [999 * "h", 999 * u"h", l, Foo(),
10 * [10 * [10 * [1]]]]
LIST_COMPLEX_OBJECTS = [[obj] for obj in BASE_COMPLEX_OBJECTS]
TUPLE_COMPLEX_OBJECTS = [(obj,) for obj in BASE_COMPLEX_OBJECTS]
@@ -60,6 +67,7 @@ COMPLEX_OBJECTS = (BASE_COMPLEX_OBJECTS +
TUPLE_COMPLEX_OBJECTS +
DICT_COMPLEX_OBJECTS)
class TestSerialization(unittest.TestCase):
def test_serialize_by_value(self):
@@ -69,27 +77,31 @@ class TestSerialization(unittest.TestCase):
for val in COMPLEX_OBJECTS:
self.assertFalse(local_scheduler.check_simple_value(val))
class TestObjectID(unittest.TestCase):
def test_create_object_id(self):
  """Check that constructing a random object ID does not raise."""
  # The result is deliberately discarded; only construction is exercised.
  random_object_id()
def test_cannot_pickle_object_ids(self):
  """Check that object IDs and functions using them cannot be pickled."""
  object_ids = [random_object_id() for _ in range(256)]

  def f():
    return object_ids

  def g(val=object_ids):
    return 1

  def h():
    object_ids[0]
    return 1

  # Pickling must fail for a single ID, for the list of IDs, and for
  # functions that reach the IDs through a closure (f, h) or a default
  # argument (g). The call stays inside the assertRaises context so that any
  # exception raised while evaluating it is what the assertion observes.
  for unpicklable in (object_ids[0], object_ids, f, g, h):
    with self.assertRaises(Exception):
      pickle.dumps(unpicklable)
def test_equality_comparisons(self):
x1 = local_scheduler.ObjectID(ID_SIZE * b"a")
@@ -101,8 +113,10 @@ class TestObjectID(unittest.TestCase):
self.assertNotEqual(x1, y1)
random_strings = [np.random.bytes(ID_SIZE) for _ in range(256)]
object_ids1 = [local_scheduler.ObjectID(random_strings[i]) for i in range(256)]
object_ids2 = [local_scheduler.ObjectID(random_strings[i]) for i in range(256)]
object_ids1 = [local_scheduler.ObjectID(random_strings[i])
for i in range(256)]
object_ids2 = [local_scheduler.ObjectID(random_strings[i])
for i in range(256)]
self.assertEqual(len(set(object_ids1)), 256)
self.assertEqual(len(set(object_ids1 + object_ids2)), 256)
self.assertEqual(set(object_ids1), set(object_ids2))
@@ -113,6 +127,7 @@ class TestObjectID(unittest.TestCase):
{x: y}
set([x, y])
class TestTask(unittest.TestCase):
def check_task(self, task, function_id, num_return_vals, args):
@@ -127,44 +142,46 @@ class TestTask(unittest.TestCase):
self.assertEqual(retrieved_args[i], args[i])
def test_create_and_serialize_task(self):
  """Build tasks from many argument lists and round-trip them as strings."""
  # TODO(rkn): The function ID should be a FunctionID object, not an
  # ObjectID.
  driver_id = random_driver_id()
  parent_id = random_task_id()
  function_id = random_function_id()
  object_ids = [random_object_id() for _ in range(256)]
  # Argument lists: empty, repeated ints, repeated strings, mixed values,
  # prefixes of the object ID list, and mixtures of IDs and plain values.
  args_list = (
      [[]] +
      [count * [1] for count in (1, 10, 100, 1000)] +
      [count * ["a"] for count in (1, 10, 100, 1000)] +
      [[1, 1.3, 2, 1 << 100, "hi", u"hi", [1, 2]]] +
      [object_ids[:stop] for stop in (1, 2, 3, 4, 5, 10, 100, 256)] +
      [[1, object_ids[0]],
       [object_ids[0], "a"],
       [1, object_ids[0], "a"],
       [object_ids[0], 1, object_ids[1], "a"],
       object_ids[:3] + [1, "hi", 2.3] + object_ids[:5],
       object_ids + 100 * ["a"] + object_ids])
  return_counts = [0, 1, 2, 3, 5, 10, 100]
  for args in args_list:
    for num_return_vals in return_counts:
      task = local_scheduler.Task(driver_id, function_id, args,
                                  num_return_vals, parent_id, 0)
      self.check_task(task, function_id, num_return_vals, args)
      # The task must survive serialization to a string and back.
      serialized = local_scheduler.task_to_string(task)
      restored = local_scheduler.task_from_string(serialized)
      self.check_task(restored, function_id, num_return_vals, args)
if __name__ == "__main__":
unittest.main(verbosity=2)
+2
View File
@@ -4,3 +4,5 @@ from __future__ import print_function
from .utils import copy_directory
from .tfutils import TensorFlowVariables
__all__ = ["copy_directory", "TensorFlowVariables"]
@@ -4,4 +4,10 @@ from __future__ import print_function
from . import random
from . import linalg
from .core import *
from .core import (BLOCK_SIZE, DistArray, assemble, zeros, ones, copy, eye,
triu, tril, blockwise_dot, dot, transpose, add, subtract,
numpy_to_dist, subblocks)
__all__ = ["random", "linalg", "BLOCK_SIZE", "DistArray", "assemble", "zeros",
"ones", "copy", "eye", "triu", "tril", "blockwise_dot", "dot",
"transpose", "add", "subtract", "numpy_to_dist", "subblocks"]
@@ -6,30 +6,38 @@ import numpy as np
import ray.experimental.array.remote as ra
import ray
__all__ = ["BLOCK_SIZE", "DistArray", "assemble", "zeros", "ones", "copy",
"eye", "triu", "tril", "blockwise_dot", "dot", "transpose", "add", "subtract", "numpy_to_dist", "subblocks"]
BLOCK_SIZE = 10
class DistArray(object):
def __init__(self, shape, objectids=None):
  """Record the array shape and the grid of per-block object IDs.

  Args:
    shape: The shape of the full array.
    objectids: An object array with one entry per block. If None, an empty
      object array with the block grid shape is allocated.

  Raises:
    Exception: If the shape of objectids does not match num_blocks.
  """
  self.shape = shape
  self.ndim = len(shape)
  # Each dimension is split into ceil(size / BLOCK_SIZE) blocks.
  self.num_blocks = [int(np.ceil(1.0 * dim / BLOCK_SIZE))
                     for dim in self.shape]
  if objectids is None:
    objectids = np.empty(self.num_blocks, dtype=object)
  self.objectids = objectids
  actual_blocks = list(self.objectids.shape)
  if self.num_blocks != actual_blocks:
    raise Exception("The fields `num_blocks` and `objectids` are "
                    "inconsistent, `num_blocks` is {} and `objectids` has "
                    "shape {}".format(self.num_blocks, actual_blocks))
@staticmethod
def compute_block_lower(index, shape):
  """Return the lower corner of the block at the given block index.

  Args:
    index: The block index, one entry per dimension.
    shape: The shape of the full array; only its length is checked here.

  Raises:
    Exception: If index and shape have different lengths.
  """
  if len(index) == len(shape):
    return [block * BLOCK_SIZE for block in index]
  raise Exception("The fields `index` and `shape` must have the same "
                  "length, but `index` is {} and `shape` is "
                  "{}.".format(index, shape))
@staticmethod
def compute_block_upper(index, shape):
if len(index) != len(shape):
raise Exception("The fields `index` and `shape` must have the same length, but `index` is {} and `shape` is {}.".format(index, shape))
raise Exception("The fields `index` and `shape` must have the same "
"length, but `index` is {} and `shape` is "
"{}.".format(index, shape))
upper = []
for i in range(len(shape)):
upper.append(min((index[i] + 1) * BLOCK_SIZE, shape[i]))
@@ -46,59 +54,73 @@ class DistArray(object):
return [int(np.ceil(1.0 * a / BLOCK_SIZE)) for a in shape]
def assemble(self):
  """Assemble an array from a distributed array of object IDs.

  Returns:
    A numpy array of shape self.shape whose dtype is taken from the first
    block and whose contents are the fetched blocks.
  """
  first_block = ray.get(self.objectids[(0,) * self.ndim])
  dtype = first_block.dtype
  result = np.zeros(self.shape, dtype=dtype)
  for index in np.ndindex(*self.num_blocks):
    lower = DistArray.compute_block_lower(index, self.shape)
    upper = DistArray.compute_block_upper(index, self.shape)
    # Index with a tuple of slices. NumPy deprecated and later removed
    # multidimensional indexing with a list of slices.
    block_slices = tuple(slice(l, u) for (l, u) in zip(lower, upper))
    result[block_slices] = ray.get(self.objectids[index])
  return result
def __getitem__(self, sliced):
  """Index into the array by assembling all of it first."""
  # TODO(rkn): Fix this, this is just a placeholder that should work but is
  # inefficient.
  return self.assemble()[sliced]
# Register the DistArray class with Ray so that it knows how to serialize it.
ray.register_class(DistArray)
@ray.remote
def assemble(a):
  """Remote function returning the result of a.assemble()."""
  assembled = a.assemble()
  return assembled
# TODO(rkn): what should we call this method
# TODO(rkn): What should we call this method?
@ray.remote
def numpy_to_dist(a):
  """Split the array a into blocks and return them as a DistArray.

  Args:
    a: The array to distribute; its shape attribute sets the result shape.

  Returns:
    A DistArray of shape a.shape whose object IDs come from ray.put applied
    to each block of a.
  """
  result = DistArray(a.shape)
  for index in np.ndindex(*result.num_blocks):
    lower = DistArray.compute_block_lower(index, a.shape)
    upper = DistArray.compute_block_upper(index, a.shape)
    # Index with a tuple of slices. NumPy deprecated and later removed
    # multidimensional indexing with a list of slices.
    block_slices = tuple(slice(l, u) for (l, u) in zip(lower, upper))
    result.objectids[index] = ray.put(a[block_slices])
  return result
@ray.remote
def zeros(shape, dtype_name="float"):
  """Create a DistArray of the given shape with blocks from ra.zeros."""
  result = DistArray(shape)
  for index in np.ndindex(*result.num_blocks):
    block_shape = DistArray.compute_block_shape(index, shape)
    result.objectids[index] = ra.zeros.remote(block_shape,
                                              dtype_name=dtype_name)
  return result
@ray.remote
def ones(shape, dtype_name="float"):
  """Create a DistArray of the given shape with blocks from ra.ones."""
  result = DistArray(shape)
  for index in np.ndindex(*result.num_blocks):
    block_shape = DistArray.compute_block_shape(index, shape)
    result.objectids[index] = ra.ones.remote(block_shape,
                                             dtype_name=dtype_name)
  return result
@ray.remote
def copy(a):
  """Return a new DistArray that refers to the same block object IDs as a."""
  source_ids = a.objectids
  result = DistArray(a.shape)
  # We don't need to actually copy the objects because remote objects are
  # immutable.
  for index in np.ndindex(*result.num_blocks):
    result.objectids[index] = source_ids[index]
  return result
@ray.remote
def eye(dim1, dim2=-1, dtype_name="float"):
dim2 = dim1 if dim2 == -1 else dim2
@@ -107,15 +129,19 @@ def eye(dim1, dim2=-1, dtype_name="float"):
for (i, j) in np.ndindex(*result.num_blocks):
block_shape = DistArray.compute_block_shape([i, j], shape)
if i == j:
result.objectids[i, j] = ra.eye.remote(block_shape[0], block_shape[1], dtype_name=dtype_name)
result.objectids[i, j] = ra.eye.remote(block_shape[0], block_shape[1],
dtype_name=dtype_name)
else:
result.objectids[i, j] = ra.zeros.remote(block_shape, dtype_name=dtype_name)
result.objectids[i, j] = ra.zeros.remote(block_shape,
dtype_name=dtype_name)
return result
@ray.remote
def triu(a):
if a.ndim != 2:
raise Exception("Input must have 2 dimensions, but a.ndim is " + str(a.ndim))
raise Exception("Input must have 2 dimensions, but a.ndim is "
"{}.".format(a.ndim))
result = DistArray(a.shape)
for (i, j) in np.ndindex(*result.num_blocks):
if i < j:
@@ -126,10 +152,12 @@ def triu(a):
result.objectids[i, j] = ra.zeros_like.remote(a.objectids[i, j])
return result
@ray.remote
def tril(a):
if a.ndim != 2:
raise Exception("Input must have 2 dimensions, but a.ndim is " + str(a.ndim))
raise Exception("Input must have 2 dimensions, but a.ndim is "
"{}.".format(a.ndim))
result = DistArray(a.shape)
for (i, j) in np.ndindex(*result.num_blocks):
if i > j:
@@ -140,25 +168,31 @@ def tril(a):
result.objectids[i, j] = ra.zeros_like.remote(a.objectids[i, j])
return result
@ray.remote
def blockwise_dot(*matrices):
  """Sum the products of the first half of matrices with the second half.

  The i-th matrix of the first half is multiplied by the i-th matrix of the
  second half, and the products are accumulated.

  Raises:
    Exception: If an odd number of matrices is passed.
  """
  n = len(matrices)
  if n % 2 != 0:
    raise Exception("blockwise_dot expects an even number of arguments, but "
                    "len(matrices) is {}.".format(n))
  half = n // 2
  result = np.zeros((matrices[0].shape[0], matrices[half].shape[1]))
  for left, right in zip(matrices[:half], matrices[half:]):
    result += np.dot(left, right)
  return result
@ray.remote
def dot(a, b):
if a.ndim != 2:
raise Exception("dot expects its arguments to be 2-dimensional, but a.ndim = {}.".format(a.ndim))
raise Exception("dot expects its arguments to be 2-dimensional, but "
"a.ndim = {}.".format(a.ndim))
if b.ndim != 2:
raise Exception("dot expects its arguments to be 2-dimensional, but b.ndim = {}.".format(b.ndim))
raise Exception("dot expects its arguments to be 2-dimensional, but "
"b.ndim = {}.".format(b.ndim))
if a.shape[1] != b.shape[0]:
raise Exception("dot expects a.shape[1] to equal b.shape[0], but a.shape = {} and b.shape = {}.".format(a.shape, b.shape))
raise Exception("dot expects a.shape[1] to equal b.shape[0], but a.shape "
"= {} and b.shape = {}.".format(a.shape, b.shape))
shape = [a.shape[0], b.shape[1]]
result = DistArray(shape)
for (i, j) in np.ndindex(*result.num_blocks):
@@ -166,10 +200,12 @@ def dot(a, b):
result.objectids[i, j] = blockwise_dot.remote(*args)
return result
@ray.remote
def subblocks(a, *ranges):
"""
This function produces a distributed array from a subset of the blocks in the `a`. The result and `a` will have the same number of dimensions.For example,
This function produces a distributed array from a subset of the blocks in
`a`. The result and `a` will have the same number of dimensions. For example,
subblocks(a, [0, 1], [2, 4])
will produce a DistArray whose objectids are
[[a.objectids[0, 2], a.objectids[0, 4]],
@@ -178,50 +214,71 @@ def subblocks(a, *ranges):
"""
ranges = list(ranges)
if len(ranges) != a.ndim:
raise Exception("sub_blocks expects to receive a number of ranges equal to a.ndim, but it received {} ranges and a.ndim = {}.".format(len(ranges), a.ndim))
raise Exception("sub_blocks expects to receive a number of ranges equal "
"to a.ndim, but it received {} ranges and a.ndim = "
"{}.".format(len(ranges), a.ndim))
for i in range(len(ranges)):
if ranges[i] == []: # We allow the user to pass in an empty list to indicate the full range
# We allow the user to pass in an empty list to indicate the full range.
if ranges[i] == []:
ranges[i] = range(a.num_blocks[i])
if not np.alltrue(ranges[i] == np.sort(ranges[i])):
raise Exception("Ranges passed to sub_blocks must be sorted, but the {}th range is {}.".format(i, ranges[i]))
raise Exception("Ranges passed to sub_blocks must be sorted, but the "
"{}th range is {}.".format(i, ranges[i]))
if ranges[i][0] < 0:
raise Exception("Values in the ranges passed to sub_blocks must be at least 0, but the {}th range is {}.".format(i, ranges[i]))
raise Exception("Values in the ranges passed to sub_blocks must be at "
"least 0, but the {}th range is {}.".format(i,
ranges[i]))
if ranges[i][-1] >= a.num_blocks[i]:
raise Exception("Values in the ranges passed to sub_blocks must be less than the relevant number of blocks, but the {}th range is {}, and a.num_blocks = {}.".format(i, ranges[i], a.num_blocks))
raise Exception("Values in the ranges passed to sub_blocks must be less "
"than the relevant number of blocks, but the {}th range "
"is {}, and a.num_blocks = {}.".format(i, ranges[i],
a.num_blocks))
last_index = [r[-1] for r in ranges]
last_block_shape = DistArray.compute_block_shape(last_index, a.shape)
shape = [(len(ranges[i]) - 1) * BLOCK_SIZE + last_block_shape[i] for i in range(a.ndim)]
shape = [(len(ranges[i]) - 1) * BLOCK_SIZE + last_block_shape[i]
for i in range(a.ndim)]
result = DistArray(shape)
for index in np.ndindex(*result.num_blocks):
result.objectids[index] = a.objectids[tuple([ranges[i][index[i]] for i in range(a.ndim)])]
result.objectids[index] = a.objectids[tuple([ranges[i][index[i]]
for i in range(a.ndim)])]
return result
@ray.remote
def transpose(a):
  """Return the transpose of a 2-dimensional DistArray.

  Raises:
    Exception: If a.ndim is not 2.
  """
  if a.ndim != 2:
    raise Exception("transpose expects its argument to be 2-dimensional, but "
                    "a.ndim = {}, a.shape = {}.".format(a.ndim, a.shape))
  result = DistArray([a.shape[1], a.shape[0]])
  # Block (i, j) of the result is the transpose of block (j, i) of a.
  for (i, j) in np.ndindex(*result.num_blocks):
    result.objectids[i, j] = ra.transpose.remote(a.objectids[j, i])
  return result
# TODO(rkn): support broadcasting?
@ray.remote
def add(x1, x2):
  """Add two DistArrays of the same shape block by block.

  Raises:
    Exception: If x1.shape and x2.shape differ.
  """
  if x1.shape != x2.shape:
    message = ("add expects arguments `x1` and `x2` to have the same "
               "shape, but x1.shape = {}, and x2.shape = {}.")
    raise Exception(message.format(x1.shape, x2.shape))
  result = DistArray(x1.shape)
  for index in np.ndindex(*result.num_blocks):
    left_block, right_block = x1.objectids[index], x2.objectids[index]
    result.objectids[index] = ra.add.remote(left_block, right_block)
  return result
# TODO(rkn): support broadcasting?
@ray.remote
def subtract(x1, x2):
  """Subtract x2 from x1 block by block; both must have the same shape.

  Raises:
    Exception: If x1.shape and x2.shape differ.
  """
  if x1.shape != x2.shape:
    message = ("subtract expects arguments `x1` and `x2` to have the "
               "same shape, but x1.shape = {}, and x2.shape = {}.")
    raise Exception(message.format(x1.shape, x2.shape))
  result = DistArray(x1.shape)
  for index in np.ndindex(*result.num_blocks):
    left_block, right_block = x1.objectids[index], x2.objectids[index]
    result.objectids[index] = ra.subtract.remote(left_block, right_block)
  return result
@@ -6,30 +6,32 @@ import numpy as np
import ray.experimental.array.remote as ra
import ray
from .core import *
from . import core
__all__ = ["tsqr", "modified_lu", "tsqr_hr", "qr"]
@ray.remote(num_return_vals=2)
def tsqr(a):
"""
arguments:
a: a distributed matrix
Suppose that
a.shape == (M, N)
K == min(M, N)
return values:
q: DistArray, if q_full = ray.get(DistArray, q).assemble(), then
q_full.shape == (M, K)
np.allclose(np.dot(q_full.T, q_full), np.eye(K)) == True
r: np.ndarray, if r_val = ray.get(np.ndarray, r), then
r_val.shape == (K, N)
np.allclose(r, np.triu(r)) == True
"""Perform a QR decomposition of a tall-skinny matrix.
Args:
a: A distributed matrix with shape MxN (suppose K = min(M, N)).
Returns:
A tuple of q (a DistArray) and r (a numpy array) satisfying the following.
- If q_full = ray.get(DistArray, q).assemble(), then
q_full.shape == (M, K).
- np.allclose(np.dot(q_full.T, q_full), np.eye(K)) == True.
- If r_val = ray.get(np.ndarray, r), then r_val.shape == (K, N).
- np.allclose(r, np.triu(r)) == True.
"""
if len(a.shape) != 2:
raise Exception("tsqr requires len(a.shape) == 2, but a.shape is {}".format(a.shape))
raise Exception("tsqr requires len(a.shape) == 2, but a.shape is "
"{}".format(a.shape))
if a.num_blocks[1] != 1:
raise Exception("tsqr requires a.num_blocks[1] == 1, but a.num_blocks is {}".format(a.num_blocks))
raise Exception("tsqr requires a.num_blocks[1] == 1, but a.num_blocks is "
"{}".format(a.num_blocks))
num_blocks = a.num_blocks[0]
K = int(np.ceil(np.log2(num_blocks))) + 1
@@ -57,9 +59,9 @@ def tsqr(a):
q_shape = a.shape
else:
q_shape = [a.shape[0], a.shape[0]]
q_num_blocks = DistArray.compute_num_blocks(q_shape)
q_num_blocks = core.DistArray.compute_num_blocks(q_shape)
q_objectids = np.empty(q_num_blocks, dtype=object)
q_result = DistArray(q_shape, q_objectids)
q_result = core.DistArray(q_shape, q_objectids)
# reconstruct output
for i in range(num_blocks):
@@ -68,28 +70,35 @@ def tsqr(a):
for j in range(1, K):
if np.mod(ith_index, 2) == 0:
lower = [0, 0]
upper = [a.shape[1], BLOCK_SIZE]
upper = [a.shape[1], core.BLOCK_SIZE]
else:
lower = [a.shape[1], 0]
upper = [2 * a.shape[1], BLOCK_SIZE]
upper = [2 * a.shape[1], core.BLOCK_SIZE]
ith_index //= 2
q_block_current = ra.dot.remote(q_block_current, ra.subarray.remote(q_tree[ith_index, j], lower, upper))
q_block_current = ra.dot.remote(q_block_current,
ra.subarray.remote(q_tree[ith_index, j],
lower, upper))
q_result.objectids[i] = q_block_current
r = current_rs[0]
return q_result, ray.get(r)
# TODO(rkn): This is unoptimized, we really want a block version of this.
# This is Algorithm 5 from
# http://www.eecs.berkeley.edu/Pubs/TechRpts/2013/EECS-2013-175.pdf.
@ray.remote(num_return_vals=3)
def modified_lu(q):
"""
Algorithm 5 from http://www.eecs.berkeley.edu/Pubs/TechRpts/2013/EECS-2013-175.pdf
takes a matrix q with orthonormal columns, returns l, u, s such that q - s = l * u
arguments:
q: a two dimensional orthonormal q
return values:
l: lower triangular
u: upper triangular
s: a diagonal matrix represented by its diagonal
"""Perform a modified LU decomposition of a matrix.
This takes a matrix q with orthonormal columns, returns l, u, s such that
q - s = l * u.
Args:
q: A two dimensional orthonormal matrix q.
Returns:
A tuple of a lower triangular matrix l, an upper triangular matrix u, and a
vector representing a diagonal matrix s such that q - s = l * u.
"""
q = q.assemble()
m, b = q.shape[0], q.shape[1]
@@ -100,14 +109,19 @@ def modified_lu(q):
for i in range(b):
S[i] = -1 * np.sign(q_work[i, i])
q_work[i, i] -= S[i]
q_work[(i + 1):m, i] /= q_work[i, i] # scale ith column of L by diagonal element
q_work[(i + 1):m, (i + 1):b] -= np.outer(q_work[(i + 1):m, i], q_work[i, (i + 1):b]) # perform Schur complement update
# Scale ith column of L by diagonal element.
q_work[(i + 1):m, i] /= q_work[i, i]
# Perform Schur complement update.
q_work[(i + 1):m, (i + 1):b] -= np.outer(q_work[(i + 1):m, i],
q_work[i, (i + 1):b])
L = np.tril(q_work)
for i in range(b):
L[i, i] = 1
U = np.triu(q_work)[:b, :]
return ray.get(numpy_to_dist.remote(ray.put(L))), U, S # TODO(rkn): get rid of put
# TODO(rkn): Get rid of the put below.
return ray.get(core.numpy_to_dist.remote(ray.put(L))), U, S
@ray.remote(num_return_vals=2)
def tsqr_hr_helper1(u, s, y_top_block, b):
@@ -116,45 +130,60 @@ def tsqr_hr_helper1(u, s, y_top_block, b):
t = -1 * np.dot(u, np.dot(s_full, np.linalg.inv(y_top).T))
return t, y_top
@ray.remote
def tsqr_hr_helper2(s, r_temp):
  """Return np.diag(s) multiplied by r_temp."""
  return np.dot(np.diag(s), r_temp)
# This is Algorithm 6 from
# http://www.eecs.berkeley.edu/Pubs/TechRpts/2013/EECS-2013-175.pdf.
@ray.remote(num_return_vals=4)
def tsqr_hr(a):
  """Chain tsqr, modified_lu and the tsqr_hr helpers on the matrix a.

  Args:
    a: The distributed matrix passed to tsqr; a.shape[1] is also used.

  Returns:
    A tuple of the fetched values of y, t, y_top and r.
  """
  q, r_temp = tsqr.remote(a)
  y, u, s = modified_lu.remote(q)
  # Fetch y once. The fetched value supplies the top block's object ID and is
  # also the first return value, so no second ray.get(y) is needed.
  y_blocked = ray.get(y)
  t, y_top = tsqr_hr_helper1.remote(u, s, y_blocked.objectids[0, 0],
                                    a.shape[1])
  r = tsqr_hr_helper2.remote(s, r_temp)
  return y_blocked, ray.get(t), ray.get(y_top), ray.get(r)
@ray.remote
def qr_helper1(a_rc, y_ri, t, W_c):
  """Return a_rc - y_ri * (t.T * W_c), using matrix products."""
  correction = np.dot(y_ri, np.dot(t.T, W_c))
  return a_rc - correction
@ray.remote
def qr_helper2(y_ri, a_rc):
  """Return the matrix product of y_ri.T and a_rc."""
  y_transposed = y_ri.T
  return np.dot(y_transposed, a_rc)
# This is Algorithm 7 from
# http://www.eecs.berkeley.edu/Pubs/TechRpts/2013/EECS-2013-175.pdf.
@ray.remote(num_return_vals=2)
def qr(a):
"""Algorithm 7 from http://www.eecs.berkeley.edu/Pubs/TechRpts/2013/EECS-2013-175.pdf"""
m, n = a.shape[0], a.shape[1]
k = min(m, n)
# we will store our scratch work in a_work
a_work = DistArray(a.shape, np.copy(a.objectids))
a_work = core.DistArray(a.shape, np.copy(a.objectids))
result_dtype = np.linalg.qr(ray.get(a.objectids[0, 0]))[0].dtype.name
r_res = ray.get(zeros.remote([k, n], result_dtype)) # TODO(rkn): It would be preferable not to get this right after creating it.
y_res = ray.get(zeros.remote([m, k], result_dtype)) # TODO(rkn): It would be preferable not to get this right after creating it.
# TODO(rkn): It would be preferable not to get this right after creating it.
r_res = ray.get(core.zeros.remote([k, n], result_dtype))
# TODO(rkn): It would be preferable not to get this right after creating it.
y_res = ray.get(core.zeros.remote([m, k], result_dtype))
Ts = []
for i in range(min(a.num_blocks[0], a.num_blocks[1])): # this differs from the paper, which says "for i in range(a.num_blocks[1])", but that doesn't seem to make any sense when a.num_blocks[1] > a.num_blocks[0]
sub_dist_array = subblocks.remote(a_work, list(range(i, a_work.num_blocks[0])), [i])
# The for loop differs from the paper, which says
# "for i in range(a.num_blocks[1])", but that doesn't seem to make any sense
# when a.num_blocks[1] > a.num_blocks[0].
for i in range(min(a.num_blocks[0], a.num_blocks[1])):
sub_dist_array = core.subblocks.remote(
a_work, list(range(i, a_work.num_blocks[0])), [i])
y, t, _, R = tsqr_hr.remote(sub_dist_array)
y_val = ray.get(y)
@@ -167,7 +196,7 @@ def qr(a):
r_res.objectids[i, i] = ra.dot.remote(eye_temp, R)
else:
r_res.objectids[i, i] = R
Ts.append(numpy_to_dist.remote(t))
Ts.append(core.numpy_to_dist.remote(t))
for c in range(i + 1, a.num_blocks[1]):
W_rcs = []
@@ -182,9 +211,14 @@ def qr(a):
r_res.objectids[i, c] = a_work.objectids[i, c]
# construct q_res from Ys and Ts
q = eye.remote(m, k, dtype_name=result_dtype)
q = core.eye.remote(m, k, dtype_name=result_dtype)
for i in range(len(Ts))[::-1]:
y_col_block = subblocks.remote(y_res, [], [i])
q = subtract.remote(q, dot.remote(y_col_block, dot.remote(Ts[i], dot.remote(transpose.remote(y_col_block), q))))
y_col_block = core.subblocks.remote(y_res, [], [i])
q = core.subtract.remote(
q, core.dot.remote(
y_col_block,
core.dot.remote(Ts[i],
core.dot.remote(core.transpose.remote(y_col_block),
q))))
return ray.get(q), r_res
@@ -6,13 +6,15 @@ import numpy as np
import ray.experimental.array.remote as ra
import ray
from .core import *
from .core import DistArray
@ray.remote
def normal(shape):
  """Create a DistArray whose blocks come from ra.random.normal."""
  block_grid = DistArray.compute_num_blocks(shape)
  block_ids = np.empty(block_grid, dtype=object)
  for index in np.ndindex(*block_grid):
    block_shape = DistArray.compute_block_shape(index, shape)
    block_ids[index] = ra.random.normal.remote(block_shape)
  return DistArray(shape, block_ids)
@@ -4,4 +4,10 @@ from __future__ import print_function
from . import random
from . import linalg
from .core import *
from .core import (zeros, zeros_like, ones, eye, dot, vstack, hstack, subarray,
copy, tril, triu, diag, transpose, add, subtract, sum,
shape, sum_list)
__all__ = ["random", "linalg", "zeros", "zeros_like", "ones", "eye", "dot",
"vstack", "hstack", "subarray", "copy", "tril", "triu", "diag",
"transpose", "add", "subtract", "sum", "shape", "sum_list"]
+20 -5
View File
@@ -5,82 +5,97 @@ from __future__ import print_function
import numpy as np
import ray
__all__ = ["zeros", "zeros_like", "ones", "eye", "dot", "vstack", "hstack", "subarray", "copy", "tril", "triu", "diag", "transpose", "add", "subtract", "sum", "shape", "sum_list"]
@ray.remote
def zeros(shape, dtype_name="float", order="C"):
  """Remote wrapper around np.zeros; the dtype is given by name."""
  dtype = np.dtype(dtype_name)
  return np.zeros(shape, dtype=dtype, order=order)
@ray.remote
def zeros_like(a, dtype_name="None", order="K", subok=True):
  """Remote wrapper around np.zeros_like.

  The string "None" for dtype_name stands for dtype=None.
  """
  if dtype_name == "None":
    dtype_val = None
  else:
    dtype_val = np.dtype(dtype_name)
  return np.zeros_like(a, dtype=dtype_val, order=order, subok=subok)
@ray.remote
def ones(shape, dtype_name="float", order="C"):
  """Remote wrapper around np.ones; the dtype is given by name."""
  dtype = np.dtype(dtype_name)
  return np.ones(shape, dtype=dtype, order=order)
@ray.remote
def eye(N, M=-1, k=0, dtype_name="float"):
  """Remote wrapper around np.eye; M == -1 means use N columns."""
  if M == -1:
    M = N
  return np.eye(N, M=M, k=k, dtype=np.dtype(dtype_name))
@ray.remote
def dot(a, b):
  """Remote wrapper around np.dot."""
  product = np.dot(a, b)
  return product
@ray.remote
def vstack(*xs):
  """Remote wrapper around np.vstack applied to the positional arguments."""
  stacked = np.vstack(xs)
  return stacked
@ray.remote
def hstack(*xs):
  """Remote wrapper around np.hstack applied to the positional arguments."""
  stacked = np.hstack(xs)
  return stacked
# TODO(rkn): instead of this, consider implementing slicing
# TODO(rkn): Instead of this, consider implementing slicing.
# TODO(rkn): Be consistent about using "index" versus "indices".
@ray.remote
def subarray(a, lower_indices, upper_indices):
  """Return the sub-block of a between per-dimension lower and upper bounds.

  Args:
    a: The array to slice.
    lower_indices: The inclusive start index for each dimension.
    upper_indices: The exclusive stop index for each dimension.

  Returns:
    a indexed with slice(lower, upper) in each dimension.
  """
  # Index with a tuple of slices. NumPy deprecated and later removed
  # multidimensional indexing with a list of slices.
  bounds = zip(lower_indices, upper_indices)
  return a[tuple(slice(l, u) for (l, u) in bounds)]
@ray.remote
def copy(a, order="K"):
  """Remote wrapper around np.copy."""
  duplicate = np.copy(a, order=order)
  return duplicate
@ray.remote
def tril(m, k=0):
  """Remote wrapper around np.tril."""
  lower = np.tril(m, k=k)
  return lower
@ray.remote
def triu(m, k=0):
  """Remote wrapper around np.triu."""
  upper = np.triu(m, k=k)
  return upper
@ray.remote
def diag(v, k=0):
  """Remote wrapper around np.diag."""
  diagonal = np.diag(v, k=k)
  return diagonal
@ray.remote
def transpose(a, axes=[]):
  """Remote wrapper around np.transpose; axes == [] stands for axes=None."""
  # The default list is only compared against, never mutated.
  if axes == []:
    axes = None
  return np.transpose(a, axes=axes)
@ray.remote
def add(x1, x2):
  """Remote wrapper around np.add."""
  total = np.add(x1, x2)
  return total
@ray.remote
def subtract(x1, x2):
  """Remote wrapper around np.subtract."""
  difference = np.subtract(x1, x2)
  return difference
@ray.remote
def sum(x, axis=-1):
  """Remote wrapper around np.sum; axis == -1 stands for axis=None."""
  if axis == -1:
    axis = None
  return np.sum(x, axis=axis)
@ray.remote
def shape(a):
  """Remote wrapper around np.shape."""
  dimensions = np.shape(a)
  return dimensions
# We use Any to allow different numerical types as well as numpy arrays.
# TODO(rkn): This isn't in the numpy API, so be careful about exposing this.
@ray.remote
def sum_list(*xs):
  """Return np.sum over axis 0 of the positional arguments."""
  total = np.sum(xs, axis=0)
  return total
+21 -1
View File
@@ -8,84 +8,104 @@ import ray
__all__ = ["matrix_power", "solve", "tensorsolve", "tensorinv", "inv",
"cholesky", "eigvals", "eigvalsh", "pinv", "slogdet", "det",
"svd", "eig", "eigh", "lstsq", "norm", "qr", "cond", "matrix_rank",
"LinAlgError", "multi_dot"]
"multi_dot"]
@ray.remote
def matrix_power(M, n):
  """Remote wrapper around np.linalg.matrix_power."""
  power = np.linalg.matrix_power(M, n)
  return power
@ray.remote
def solve(a, b):
  """Remote wrapper around np.linalg.solve."""
  solution = np.linalg.solve(a, b)
  return solution
@ray.remote(num_return_vals=2)
def tensorsolve(a):
  """Placeholder that is not implemented; always raises NotImplementedError."""
  raise NotImplementedError()
@ray.remote(num_return_vals=2)
def tensorinv(a):
  """Placeholder that is not implemented; always raises NotImplementedError."""
  raise NotImplementedError()
@ray.remote
def inv(a):
  """Remote wrapper around np.linalg.inv."""
  inverse = np.linalg.inv(a)
  return inverse
@ray.remote
def cholesky(a):
  """Remote wrapper around np.linalg.cholesky."""
  factor = np.linalg.cholesky(a)
  return factor
@ray.remote
def eigvals(a):
  """Remote wrapper around np.linalg.eigvals."""
  eigenvalues = np.linalg.eigvals(a)
  return eigenvalues
@ray.remote
def eigvalsh(a):
  """Placeholder that is not implemented; always raises NotImplementedError."""
  raise NotImplementedError()
@ray.remote
def pinv(a):
  """Remote wrapper around np.linalg.pinv."""
  pseudo_inverse = np.linalg.pinv(a)
  return pseudo_inverse
@ray.remote
def slogdet(a):
  """Placeholder that is not implemented; always raises NotImplementedError."""
  raise NotImplementedError()
@ray.remote
def det(a):
  """Remote wrapper around np.linalg.det."""
  determinant = np.linalg.det(a)
  return determinant
@ray.remote(num_return_vals=3)
def svd(a):
  """Remote wrapper around np.linalg.svd, declared with three return values."""
  decomposition = np.linalg.svd(a)
  return decomposition
@ray.remote(num_return_vals=2)
def eig(a):
  """Remote wrapper around np.linalg.eig, declared with two return values."""
  decomposition = np.linalg.eig(a)
  return decomposition
@ray.remote(num_return_vals=2)
def eigh(a):
  """Remote wrapper around np.linalg.eigh, declared with two return values."""
  decomposition = np.linalg.eigh(a)
  return decomposition
@ray.remote(num_return_vals=4)
def lstsq(a, b):
  """Remote wrapper around np.linalg.lstsq.

  Args:
    a: The coefficient matrix.
    b: The right-hand side (dependent variable) values.

  Returns:
    The values returned by np.linalg.lstsq(a, b); the remote function is
    declared with four return values.
  """
  # np.linalg.lstsq needs both the coefficient matrix and the right-hand
  # side, so b has to be forwarded along with a.
  return np.linalg.lstsq(a, b)
@ray.remote
def norm(x):
  """Remote wrapper around np.linalg.norm with default arguments."""
  magnitude = np.linalg.norm(x)
  return magnitude
@ray.remote(num_return_vals=2)
def qr(a):
  """Remote wrapper around np.linalg.qr, declared with two return values."""
  decomposition = np.linalg.qr(a)
  return decomposition
@ray.remote
def cond(x):
  """Remote wrapper around np.linalg.cond with default arguments."""
  condition_number = np.linalg.cond(x)
  return condition_number
@ray.remote
def matrix_rank(M):
  """Remote wrapper around np.linalg.matrix_rank with default arguments."""
  rank = np.linalg.matrix_rank(M)
  return rank
@ray.remote
def multi_dot(*a):
  """Placeholder that is not implemented; always raises NotImplementedError."""
  raise NotImplementedError()
@@ -5,6 +5,7 @@ from __future__ import print_function
import numpy as np
import ray
@ray.remote
def normal(shape):
  """Remote wrapper around np.random.normal with size=shape."""
  samples = np.random.normal(size=shape)
  return samples
+1
View File
@@ -2,6 +2,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
def get_local_schedulers(worker):
local_schedulers = []
for client in worker.redis_client.keys("CL:*"):
+24 -11
View File
@@ -4,6 +4,7 @@ from __future__ import print_function
import numpy as np
from collections import deque, OrderedDict
def unflatten(vector, shapes):
i = 0
arrays = []
@@ -15,6 +16,7 @@ def unflatten(vector, shapes):
assert len(vector) == i, "Passed weight does not have the correct shape."
return arrays
class TensorFlowVariables(object):
"""An object used to extract variables from a loss function.
@@ -38,11 +40,11 @@ class TensorFlowVariables(object):
variable_names = []
explored_inputs = set([loss])
# We do a BFS on the dependency graph of the input function to find
# We do a BFS on the dependency graph of the input function to find
# the variables.
while len(queue) != 0:
tf_obj = queue.popleft()
# The object put into the queue is not necessarily an operation, so we
# want the op attribute to get the operation underlying the object.
# Only operations contain the inputs that we can explore.
@@ -61,14 +63,16 @@ class TensorFlowVariables(object):
if "Variable" in tf_obj.node_def.op:
variable_names.append(tf_obj.node_def.name)
self.variables = OrderedDict()
for v in [v for v in tf.global_variables() if v.op.node_def.name in variable_names]:
for v in [v for v in tf.global_variables()
if v.op.node_def.name in variable_names]:
self.variables[v.op.node_def.name] = v
self.placeholders = dict()
self.assignment_nodes = []
# Create new placeholders to put in custom weights.
for k, var in self.variables.items():
self.placeholders[k] = tf.placeholder(var.value().dtype, var.get_shape().as_list())
self.placeholders[k] = tf.placeholder(var.value().dtype,
var.get_shape().as_list())
self.assignment_nodes.append(var.assign(self.placeholders[k]))
def set_session(self, sess):
@@ -76,24 +80,30 @@ class TensorFlowVariables(object):
self.sess = sess
def get_flat_size(self):
return sum([np.prod(v.get_shape().as_list()) for v in self.variables.values()])
return sum([np.prod(v.get_shape().as_list())
for v in self.variables.values()])
def _check_sess(self):
"""Checks if the session is set, and if not throw an error message."""
assert self.sess is not None, "The session is not set. Set the session either by passing it into the TensorFlowVariables constructor or by calling set_session(sess)."
assert self.sess is not None, ("The session is not set. Set the session "
"either by passing it into the "
"TensorFlowVariables constructor or by "
"calling set_session(sess).")
def get_flat(self):
"""Gets the weights and returns them as a flat array."""
self._check_sess()
return np.concatenate([v.eval(session=self.sess).flatten() for v in self.variables.values()])
return np.concatenate([v.eval(session=self.sess).flatten()
for v in self.variables.values()])
def set_flat(self, new_weights):
"""Sets the weights to new_weights, converting from a flat array."""
self._check_sess()
shapes = [v.get_shape().as_list() for v in self.variables.values()]
arrays = unflatten(new_weights, shapes)
placeholders = [self.placeholders[k] for k, v in self.variables.items()]
self.sess.run(self.assignment_nodes, feed_dict=dict(zip(placeholders,arrays)))
self.sess.run(self.assignment_nodes,
feed_dict=dict(zip(placeholders, arrays)))
def get_weights(self):
"""Returns the weights of the variables of the loss function in a list."""
@@ -103,4 +113,7 @@ class TensorFlowVariables(object):
def set_weights(self, new_weights):
"""Sets the weights to new_weights."""
self._check_sess()
self.sess.run(self.assignment_nodes, feed_dict={self.placeholders[name]: value for (name, value) in new_weights.items() if name in self.placeholders})
self.sess.run(self.assignment_nodes,
feed_dict={self.placeholders[name]: value
for (name, value) in new_weights.items()
if name in self.placeholders})
+13 -7
View File
@@ -9,6 +9,7 @@ import sys
import ray
def tarred_directory_as_bytes(source_dir):
"""Tar a directory and return it as a byte string.
@@ -26,6 +27,7 @@ def tarred_directory_as_bytes(source_dir):
string_file.seek(0)
return string_file.read()
def tarred_bytes_to_directory(tarred_bytes, target_dir):
"""Take a byte string and untar it.
@@ -38,6 +40,7 @@ def tarred_bytes_to_directory(tarred_bytes, target_dir):
with tarfile.open(fileobj=string_file) as tar:
tar.extractall(path=target_dir)
def copy_directory(source_dir, target_dir=None):
"""Copy a local directory to each machine in the Ray cluster.
@@ -45,17 +48,18 @@ def copy_directory(source_dir, target_dir=None):
example, source_dir can be /a/b/c and target_dir can be /d/e/c. In this case,
the directory /d/e will be added to the Python path of each worker.
Note that this method is not completely safe to use. For example, workers that
do not do the copying and only set their paths (only one worker per node does
the copying) may try to execute functions that use the files in the directory
being copied before the directory being copied has finished untarring.
Note that this method is not completely safe to use. For example, workers
that do not do the copying and only set their paths (only one worker per node
does the copying) may try to execute functions that use the files in the
directory being copied before the directory being copied has finished
untarring.
Args:
source_dir (str): The directory to copy.
target_dir (str): The location to copy it to on the other machines. If this
is not provided, the source_dir will be used. If it is provided and is
different from source_dir, the source_dir also be copied to the target_dir
location on this machine.
different from source_dir, the source_dir will also be copied to the
target_dir location on this machine.
"""
target_dir = source_dir if target_dir is None else target_dir
source_dir = os.path.abspath(source_dir)
@@ -63,8 +67,10 @@ def copy_directory(source_dir, target_dir=None):
source_basename = os.path.basename(source_dir)
target_basename = os.path.basename(target_dir)
if source_basename != target_basename:
raise Exception("The source_dir and target_dir must have the same base name, {} != {}".format(source_basename, target_basename))
raise Exception("The source_dir and target_dir must have the same base "
"name, {} != {}".format(source_basename, target_basename))
tarred_bytes = tarred_directory_as_bytes(source_dir)
def f(worker_info):
if worker_info["counter"] == 0:
tarred_bytes_to_directory(tarred_bytes, os.path.dirname(target_dir))
+3 -1
View File
@@ -2,4 +2,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from .global_scheduler_services import *
from .global_scheduler_services import start_global_scheduler
__all__ = ["start_global_scheduler"]
@@ -6,6 +6,7 @@ import os
import subprocess
import time
def start_global_scheduler(redis_address, use_valgrind=False,
use_profiler=False, stdout_file=None,
stderr_file=None):
@@ -15,8 +16,8 @@ def start_global_scheduler(redis_address, use_valgrind=False,
redis_address (str): The address of the Redis instance.
use_valgrind (bool): True if the global scheduler should be started inside
of valgrind. If this is True, use_profiler must be False.
use_profiler (bool): True if the global scheduler should be started inside a
profiler. If this is True, use_valgrind must be False.
use_profiler (bool): True if the global scheduler should be started inside
a profiler. If this is True, use_valgrind must be False.
stdout_file: A file handle opened for writing to redirect stdout to. If no
redirection should happen, then this should be None.
stderr_file: A file handle opened for writing to redirect stderr to. If no
@@ -27,7 +28,9 @@ def start_global_scheduler(redis_address, use_valgrind=False,
"""
if use_valgrind and use_profiler:
raise Exception("Cannot use valgrind and profiler at the same time.")
global_scheduler_executable = os.path.join(os.path.abspath(os.path.dirname(__file__)), "../core/src/global_scheduler/global_scheduler")
global_scheduler_executable = os.path.join(
os.path.abspath(os.path.dirname(__file__)),
"../core/src/global_scheduler/global_scheduler")
command = [global_scheduler_executable, "-r", redis_address]
if use_valgrind:
pid = subprocess.Popen(["valgrind",
@@ -35,7 +38,7 @@ def start_global_scheduler(redis_address, use_valgrind=False,
"--leak-check=full",
"--show-leak-kinds=all",
"--error-exitcode=1"] + command,
stdout=stdout_file, stderr=stderr_file)
stdout=stdout_file, stderr=stderr_file)
time.sleep(1.0)
elif use_profiler:
pid = subprocess.Popen(["valgrind", "--tool=callgrind"] + command,
+57 -31
View File
@@ -7,16 +7,14 @@ import os
import random
import redis
import signal
import subprocess
import sys
import threading
import time
import unittest
import ray.global_scheduler as global_scheduler
import ray.local_scheduler as local_scheduler
import ray.plasma as plasma
from ray.plasma.utils import random_object_id, generate_metadata, write_to_data_buffer, create_object_with_id, create_object
from ray.plasma.utils import create_object
from ray import services
@@ -39,21 +37,27 @@ TASK_STATUS_DONE = 16
DB_CLIENT_PREFIX = "CL:"
TASK_PREFIX = "TT:"
def random_driver_id():
  """Return a ``local_scheduler.ObjectID`` made of ``ID_SIZE`` random bytes."""
  raw_id = np.random.bytes(ID_SIZE)
  return local_scheduler.ObjectID(raw_id)
def random_task_id():
  """Return a ``local_scheduler.ObjectID`` made of ``ID_SIZE`` random bytes."""
  raw_id = np.random.bytes(ID_SIZE)
  return local_scheduler.ObjectID(raw_id)
def random_function_id():
  """Return a ``local_scheduler.ObjectID`` made of ``ID_SIZE`` random bytes."""
  raw_id = np.random.bytes(ID_SIZE)
  return local_scheduler.ObjectID(raw_id)
def random_object_id():
  """Return a ``local_scheduler.ObjectID`` made of ``ID_SIZE`` random bytes."""
  raw_id = np.random.bytes(ID_SIZE)
  return local_scheduler.ObjectID(raw_id)
def new_port():
  """Return a random integer port number in the range [10000, 65535]."""
  # The upper bound of randrange is exclusive, so 65536 keeps 65535 reachable.
  port = random.randrange(10000, 65536)
  return port
class TestGlobalScheduler(unittest.TestCase):
def setUp(self):
@@ -62,9 +66,11 @@ class TestGlobalScheduler(unittest.TestCase):
redis_port, self.redis_process = services.start_redis(cleanup=False)
redis_address = services.address(node_ip_address, redis_port)
# Create a Redis client.
self.redis_client = redis.StrictRedis(host=node_ip_address, port=redis_port)
self.redis_client = redis.StrictRedis(host=node_ip_address,
port=redis_port)
# Start one global scheduler.
self.p1 = global_scheduler.start_global_scheduler(redis_address, use_valgrind=USE_VALGRIND)
self.p1 = global_scheduler.start_global_scheduler(
redis_address, use_valgrind=USE_VALGRIND)
self.plasma_store_pids = []
self.plasma_manager_pids = []
self.local_scheduler_pids = []
@@ -76,11 +82,15 @@ class TestGlobalScheduler(unittest.TestCase):
plasma_store_name, p2 = plasma.start_plasma_store()
self.plasma_store_pids.append(p2)
# Start the Plasma manager.
# Assumption: Plasma manager name and port are randomly generated by the plasma module.
plasma_manager_name, p3, plasma_manager_port = plasma.start_plasma_manager(plasma_store_name, redis_address)
# Assumption: Plasma manager name and port are randomly generated by the
# plasma module.
manager_info = plasma.start_plasma_manager(plasma_store_name,
redis_address)
plasma_manager_name, p3, plasma_manager_port = manager_info
self.plasma_manager_pids.append(p3)
plasma_address = "{}:{}".format(node_ip_address, plasma_manager_port)
plasma_client = plasma.PlasmaClient(plasma_store_name, plasma_manager_name)
plasma_client = plasma.PlasmaClient(plasma_store_name,
plasma_manager_name)
self.plasma_clients.append(plasma_client)
# Start the local scheduler.
local_scheduler_name, p4 = local_scheduler.start_local_scheduler(
@@ -91,7 +101,7 @@ class TestGlobalScheduler(unittest.TestCase):
static_resource_list=[10, 0])
# Connect to the scheduler.
local_scheduler_client = local_scheduler.LocalSchedulerClient(
local_scheduler_name, NIL_ACTOR_ID, False)
local_scheduler_name, NIL_ACTOR_ID, False)
self.local_scheduler_clients.append(local_scheduler_client)
self.local_scheduler_pids.append(p4)
@@ -149,11 +159,13 @@ class TestGlobalScheduler(unittest.TestCase):
return db_client_id
def test_task_default_resources(self):
task1 = local_scheduler.Task(random_driver_id(), random_function_id(), [random_object_id()], 0, random_task_id(), 0)
task1 = local_scheduler.Task(random_driver_id(), random_function_id(),
[random_object_id()], 0, random_task_id(), 0)
self.assertEqual(task1.required_resources(), [1.0, 0.0])
task2 = local_scheduler.Task(random_driver_id(), random_function_id(),
[random_object_id()], 0, random_task_id(), 0,
local_scheduler.ObjectID(NIL_ACTOR_ID), 0, [1.0, 2.0])
local_scheduler.ObjectID(NIL_ACTOR_ID), 0,
[1.0, 2.0])
self.assertEqual(task2.required_resources(), [1.0, 2.0])
def test_redis_only_single_task(self):
@@ -162,31 +174,37 @@ class TestGlobalScheduler(unittest.TestCase):
task state transitions in Redis only. TODO(atumanov): implement.
"""
# Check precondition for this test:
# There should be 2n+1 db clients: the global scheduler + one local scheduler and one plasma per node.
self.assertEqual(len(self.redis_client.keys("{}*".format(DB_CLIENT_PREFIX))),
2 * NUM_CLUSTER_NODES + 1)
# There should be 2n+1 db clients: the global scheduler + one local
# scheduler and one plasma per node.
self.assertEqual(
len(self.redis_client.keys("{}*".format(DB_CLIENT_PREFIX))),
2 * NUM_CLUSTER_NODES + 1)
db_client_id = self.get_plasma_manager_id()
assert(db_client_id != None)
assert(db_client_id is not None)
assert(db_client_id.startswith(b"CL:"))
db_client_id = db_client_id[len(b"CL:"):] # Remove the CL: prefix.
db_client_id = db_client_id[len(b"CL:"):] # Remove the CL: prefix.
def test_integration_single_task(self):
# There should be three db clients, the global scheduler, the local
# scheduler, and the plasma manager.
self.assertEqual(len(self.redis_client.keys("{}*".format(DB_CLIENT_PREFIX))),
2 * NUM_CLUSTER_NODES + 1)
self.assertEqual(
len(self.redis_client.keys("{}*".format(DB_CLIENT_PREFIX))),
2 * NUM_CLUSTER_NODES + 1)
num_return_vals = [0, 1, 2, 3, 5, 10]
# Insert the object into Redis.
data_size = 0xf1f0
metadata_size = 0x40
plasma_client = self.plasma_clients[0]
object_dep, memory_buffer, metadata = create_object(plasma_client, data_size, metadata_size, seal=True)
object_dep, memory_buffer, metadata = create_object(
plasma_client, data_size, metadata_size, seal=True)
# Sleep before submitting task to local scheduler.
time.sleep(0.1)
# Submit a task to Redis.
task = local_scheduler.Task(random_driver_id(), random_function_id(), [local_scheduler.ObjectID(object_dep)], num_return_vals[0], random_task_id(), 0)
task = local_scheduler.Task(random_driver_id(), random_function_id(),
[local_scheduler.ObjectID(object_dep)],
num_return_vals[0], random_task_id(), 0)
self.local_scheduler_clients[0].submit(task)
time.sleep(0.1)
# There should now be a task in Redis, and it should get assigned to the
@@ -217,8 +235,9 @@ class TestGlobalScheduler(unittest.TestCase):
def integration_many_tasks_helper(self, timesync=True):
# There should be three db clients, the global scheduler, the local
# scheduler, and the plasma manager.
self.assertEqual(len(self.redis_client.keys("{}*".format(DB_CLIENT_PREFIX))),
2 * NUM_CLUSTER_NODES + 1)
self.assertEqual(
len(self.redis_client.keys("{}*".format(DB_CLIENT_PREFIX))),
2 * NUM_CLUSTER_NODES + 1)
num_return_vals = [0, 1, 2, 3, 5, 10]
# Submit a bunch of tasks to Redis.
@@ -228,11 +247,16 @@ class TestGlobalScheduler(unittest.TestCase):
data_size = np.random.randint(1 << 20)
metadata_size = np.random.randint(1 << 10)
plasma_client = self.plasma_clients[0]
object_dep, memory_buffer, metadata = create_object(plasma_client, data_size, metadata_size, seal=True)
object_dep, memory_buffer, metadata = create_object(plasma_client,
data_size,
metadata_size,
seal=True)
if timesync:
# Give 10ms for object info handler to fire (long enough to yield CPU).
time.sleep(0.010)
task = local_scheduler.Task(random_driver_id(), random_function_id(), [local_scheduler.ObjectID(object_dep)], num_return_vals[0], random_task_id(), 0)
task = local_scheduler.Task(random_driver_id(), random_function_id(),
[local_scheduler.ObjectID(object_dep)],
num_return_vals[0], random_task_id(), 0)
self.local_scheduler_clients[0].submit(task)
# Check that there are the correct number of tasks in Redis and that they
# all get assigned to the local scheduler.
@@ -243,17 +267,18 @@ class TestGlobalScheduler(unittest.TestCase):
self.assertLessEqual(len(task_entries), num_tasks)
# First, check if all tasks made it to Redis.
if len(task_entries) == num_tasks:
task_contents = [self.redis_client.hgetall(task_entries[i]) for i in range(len(task_entries))]
task_contents = [self.redis_client.hgetall(task_entries[i])
for i in range(len(task_entries))]
task_statuses = [int(contents[b"state"]) for contents in task_contents]
self.assertTrue(all([
status in [TASK_STATUS_WAITING,
TASK_STATUS_SCHEDULED,
TASK_STATUS_QUEUED] for status in task_statuses
]))
self.assertTrue(all([status in [TASK_STATUS_WAITING,
TASK_STATUS_SCHEDULED,
TASK_STATUS_QUEUED]
for status in task_statuses]))
num_tasks_done = task_statuses.count(TASK_STATUS_QUEUED)
num_tasks_scheduled = task_statuses.count(TASK_STATUS_SCHEDULED)
num_tasks_waiting = task_statuses.count(TASK_STATUS_WAITING)
print("tasks in Redis = {}, tasks waiting = {}, tasks scheduled = {}, tasks queued = {}, retries left = {}"
print("tasks in Redis = {}, tasks waiting = {}, tasks scheduled = {}, "
"tasks queued = {}, retries left = {}"
.format(len(task_entries), num_tasks_waiting,
num_tasks_scheduled, num_tasks_done, num_retries))
if all([status == TASK_STATUS_QUEUED for status in task_statuses]):
@@ -275,6 +300,7 @@ class TestGlobalScheduler(unittest.TestCase):
# notifications.
self.integration_many_tasks_helper(timesync=False)
if __name__ == "__main__":
if len(sys.argv) > 1:
# Pop the argument so we don't mess with unittest's own argument parser.
+7 -2
View File
@@ -2,5 +2,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ray.core.src.local_scheduler.liblocal_scheduler_library import *
from .local_scheduler_services import *
from ray.core.src.local_scheduler.liblocal_scheduler_library import (
Task, LocalSchedulerClient, ObjectID, check_simple_value, task_from_string,
task_to_string)
from .local_scheduler_services import start_local_scheduler
__all__ = ["Task", "LocalSchedulerClient", "ObjectID", "check_simple_value",
"task_from_string", "task_to_string", "start_local_scheduler"]
@@ -7,9 +7,11 @@ import random
import subprocess
import time
def random_name():
  """Return a random integer in [0, 99999999] rendered as a decimal string."""
  # The upper bound of randrange is exclusive, so 10 ** 8 keeps 99999999
  # reachable.
  number = random.randrange(0, 10 ** 8)
  return str(number)
def start_local_scheduler(plasma_store_name,
plasma_manager_name=None,
worker_path=None,
@@ -38,8 +40,8 @@ def start_local_scheduler(plasma_store_name,
running on.
redis_address (str): The address of the Redis instance to connect to. If
this is not provided, then the local scheduler will not connect to Redis.
use_valgrind (bool): True if the local scheduler should be started inside of
valgrind. If this is True, use_profiler must be False.
use_valgrind (bool): True if the local scheduler should be started inside
of valgrind. If this is True, use_profiler must be False.
use_profiler (bool): True if the local scheduler should be started inside a
profiler. If this is True, use_valgrind must be False.
stdout_file: A file handle opened for writing to redirect stdout to. If no
@@ -56,11 +58,14 @@ def start_local_scheduler(plasma_store_name,
A tuple of the name of the local scheduler socket and the process ID of the
local scheduler process.
"""
if (plasma_manager_name == None) != (redis_address == None):
raise Exception("If one of the plasma_manager_name and the redis_address is provided, then both must be provided.")
if (plasma_manager_name is None) != (redis_address is None):
raise Exception("If one of the plasma_manager_name and the redis_address "
"is provided, then both must be provided.")
if use_valgrind and use_profiler:
raise Exception("Cannot use valgrind and profiler at the same time.")
local_scheduler_executable = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../core/src/local_scheduler/local_scheduler")
local_scheduler_executable = os.path.join(os.path.dirname(
os.path.abspath(__file__)),
"../core/src/local_scheduler/local_scheduler")
local_scheduler_name = "/tmp/scheduler{}".format(random_name())
command = [local_scheduler_executable,
"-s", local_scheduler_name,
@@ -90,8 +95,10 @@ def start_local_scheduler(plasma_store_name,
if plasma_address is not None:
command += ["-a", plasma_address]
if static_resource_list is not None:
assert all([isinstance(resource, int) or isinstance(resource, float) for resource in static_resource_list])
command += ["-c", ",".join([str(resource) for resource in static_resource_list])]
assert all([isinstance(resource, int) or isinstance(resource, float)
for resource in static_resource_list])
command += ["-c", ",".join([str(resource) for resource
in static_resource_list])]
if use_valgrind:
pid = subprocess.Popen(["valgrind",
@@ -99,7 +106,7 @@ def start_local_scheduler(plasma_store_name,
"--leak-check=full",
"--show-leak-kinds=all",
"--error-exitcode=1"] + command,
stdout=stdout_file, stderr=stderr_file)
stdout=stdout_file, stderr=stderr_file)
time.sleep(1.0)
elif use_profiler:
pid = subprocess.Popen(["valgrind", "--tool=callgrind"] + command,
+46 -35
View File
@@ -4,9 +4,7 @@ from __future__ import print_function
import numpy as np
import os
import random
import signal
import subprocess
import sys
import threading
import time
@@ -20,18 +18,23 @@ ID_SIZE = 20
NIL_ACTOR_ID = 20 * b"\xff"
def random_object_id():
  """Return a ``local_scheduler.ObjectID`` made of ``ID_SIZE`` random bytes."""
  id_bytes = np.random.bytes(ID_SIZE)
  return local_scheduler.ObjectID(id_bytes)
def random_driver_id():
  """Return a ``local_scheduler.ObjectID`` made of ``ID_SIZE`` random bytes."""
  id_bytes = np.random.bytes(ID_SIZE)
  return local_scheduler.ObjectID(id_bytes)
def random_task_id():
  """Return a ``local_scheduler.ObjectID`` made of ``ID_SIZE`` random bytes."""
  id_bytes = np.random.bytes(ID_SIZE)
  return local_scheduler.ObjectID(id_bytes)
def random_function_id():
  """Return a ``local_scheduler.ObjectID`` made of ``ID_SIZE`` random bytes."""
  id_bytes = np.random.bytes(ID_SIZE)
  return local_scheduler.ObjectID(id_bytes)
class TestLocalSchedulerClient(unittest.TestCase):
def setUp(self):
@@ -39,10 +42,11 @@ class TestLocalSchedulerClient(unittest.TestCase):
plasma_store_name, self.p1 = plasma.start_plasma_store()
self.plasma_client = plasma.PlasmaClient(plasma_store_name)
# Start a local scheduler.
scheduler_name, self.p2 = local_scheduler.start_local_scheduler(plasma_store_name, use_valgrind=USE_VALGRIND)
scheduler_name, self.p2 = local_scheduler.start_local_scheduler(
plasma_store_name, use_valgrind=USE_VALGRIND)
# Connect to the scheduler.
self.local_scheduler_client = local_scheduler.LocalSchedulerClient(
scheduler_name, NIL_ACTOR_ID, False)
scheduler_name, NIL_ACTOR_ID, False)
def tearDown(self):
# Check that the processes are still alive.
@@ -70,37 +74,38 @@ class TestLocalSchedulerClient(unittest.TestCase):
self.plasma_client.seal(object_id.id())
# Define some arguments to use for the tasks.
args_list = [
[],
#{},
#(),
1 * [1],
10 * [1],
100 * [1],
1000 * [1],
1 * ["a"],
10 * ["a"],
100 * ["a"],
1000 * ["a"],
[1, 1.3, 1 << 100, "hi", u"hi", [1, 2]],
object_ids[:1],
object_ids[:2],
object_ids[:3],
object_ids[:4],
object_ids[:5],
object_ids[:10],
object_ids[:100],
object_ids[:256],
[1, object_ids[0]],
[object_ids[0], "a"],
[1, object_ids[0], "a"],
[object_ids[0], 1, object_ids[1], "a"],
object_ids[:3] + [1, "hi", 2.3] + object_ids[:5],
object_ids + 100 * ["a"] + object_ids
[],
[{}],
[()],
1 * [1],
10 * [1],
100 * [1],
1000 * [1],
1 * ["a"],
10 * ["a"],
100 * ["a"],
1000 * ["a"],
[1, 1.3, 1 << 100, "hi", u"hi", [1, 2]],
object_ids[:1],
object_ids[:2],
object_ids[:3],
object_ids[:4],
object_ids[:5],
object_ids[:10],
object_ids[:100],
object_ids[:256],
[1, object_ids[0]],
[object_ids[0], "a"],
[1, object_ids[0], "a"],
[object_ids[0], 1, object_ids[1], "a"],
object_ids[:3] + [1, "hi", 2.3] + object_ids[:5],
object_ids + 100 * ["a"] + object_ids
]
for args in args_list:
for num_return_vals in [0, 1, 2, 3, 5, 10, 100]:
task = local_scheduler.Task(random_driver_id(), function_id, args, num_return_vals, random_task_id(), 0)
task = local_scheduler.Task(random_driver_id(), function_id, args,
num_return_vals, random_task_id(), 0)
# Submit a task.
self.local_scheduler_client.submit(task)
# Get the task.
@@ -119,7 +124,8 @@ class TestLocalSchedulerClient(unittest.TestCase):
# Submit all of the tasks.
for args in args_list:
for num_return_vals in [0, 1, 2, 3, 5, 10, 100]:
task = local_scheduler.Task(random_driver_id(), function_id, args, num_return_vals, random_task_id(), 0)
task = local_scheduler.Task(random_driver_id(), function_id, args,
num_return_vals, random_task_id(), 0)
self.local_scheduler_client.submit(task)
# Get all of the tasks.
for args in args_list:
@@ -129,8 +135,10 @@ class TestLocalSchedulerClient(unittest.TestCase):
def test_scheduling_when_objects_ready(self):
# Create a task and submit it.
object_id = random_object_id()
task = local_scheduler.Task(random_driver_id(), random_function_id(), [object_id], 0, random_task_id(), 0)
task = local_scheduler.Task(random_driver_id(), random_function_id(),
[object_id], 0, random_task_id(), 0)
self.local_scheduler_client.submit(task)
# Launch a thread to get the task.
def get_task():
self.local_scheduler_client.get_task()
@@ -149,7 +157,9 @@ class TestLocalSchedulerClient(unittest.TestCase):
# Create a task with two dependencies and submit it.
object_id1 = random_object_id()
object_id2 = random_object_id()
task = local_scheduler.Task(random_driver_id(), random_function_id(), [object_id1, object_id2], 0, random_task_id(), 0)
task = local_scheduler.Task(random_driver_id(), random_function_id(),
[object_id1, object_id2], 0, random_task_id(),
0)
self.local_scheduler_client.submit(task)
# Launch a thread to get the task.
@@ -196,9 +206,10 @@ class TestLocalSchedulerClient(unittest.TestCase):
# Wait until the thread finishes so that we know the task was scheduled.
t.join()
if __name__ == "__main__":
if len(sys.argv) > 1:
# pop the argument so we don't mess with unittest's own argument parser
# Pop the argument so we don't mess with unittest's own argument parser.
if sys.argv[-1] == "valgrind":
arg = sys.argv.pop()
USE_VALGRIND = True
+9 -5
View File
@@ -3,13 +3,13 @@ from __future__ import division
from __future__ import print_function
import argparse
import os
import redis
import time
from ray.services import get_ip_address
from ray.services import get_port
class LogMonitor(object):
"""A monitor process for monitoring Ray log files.
@@ -27,15 +27,17 @@ class LogMonitor(object):
def __init__(self, redis_ip_address, redis_port, node_ip_address):
"""Initialize the log monitor object."""
self.node_ip_address = node_ip_address
self.redis_client = redis.StrictRedis(host=redis_ip_address, port=redis_port)
self.redis_client = redis.StrictRedis(host=redis_ip_address,
port=redis_port)
self.log_files = {}
self.log_file_handles = {}
def update_log_filenames(self):
"""Get the most up-to-date list of log files to monitor from Redis."""
num_current_log_files = len(self.log_files)
new_log_filenames = self.redis_client.lrange("LOG_FILENAMES:{}".format(self.node_ip_address),
num_current_log_files, -1)
new_log_filenames = self.redis_client.lrange(
"LOG_FILENAMES:{}".format(self.node_ip_address),
num_current_log_files, -1)
for log_filename in new_log_filenames:
print("Beginning to track file {}".format(log_filename))
assert log_filename not in self.log_files
@@ -50,7 +52,8 @@ class LogMonitor(object):
# If there are any new lines, cache them and also push them to Redis.
if len(new_lines) > 0:
self.log_files[log_filename] += new_lines
redis_key = "LOGFILE:{}:{}".format(self.node_ip_address, log_filename.decode("ascii"))
redis_key = "LOGFILE:{}:{}".format(self.node_ip_address,
log_filename.decode("ascii"))
self.redis_client.rpush(redis_key, *new_lines)
else:
try:
@@ -69,6 +72,7 @@ class LogMonitor(object):
self.check_log_files_and_push_updates()
time.sleep(1)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description=("Parse Redis server for the "
"log monitor to connect to."))
+22 -17
View File
@@ -3,7 +3,6 @@ from __future__ import division
from __future__ import print_function
import argparse
import binascii
from collections import Counter
import logging
import redis
@@ -13,7 +12,8 @@ from ray.services import get_ip_address
from ray.services import get_port
# Import flatbuffer bindings.
from ray.core.generated.SubscribeToDBClientTableReply import SubscribeToDBClientTableReply
from ray.core.generated.SubscribeToDBClientTableReply \
import SubscribeToDBClientTableReply
from ray.core.generated.TaskReply import TaskReply
# These variables must be kept in sync with the C codebase.
@@ -41,6 +41,7 @@ logging.basicConfig()
log = logging.getLogger()
log.setLevel(logging.WARN)
class Monitor(object):
"""A monitor for Ray processes.
@@ -92,7 +93,8 @@ class Monitor(object):
TASK_STATUS_LOST. A local scheduler is deemed dead if it is in
self.dead_local_schedulers.
"""
task_ids = self.redis.scan_iter(match="{prefix}*".format(prefix=TASK_PREFIX))
task_ids = self.redis.scan_iter(
match="{prefix}*".format(prefix=TASK_PREFIX))
num_tasks_updated = 0
for task_id in task_ids:
task_id = task_id[len(TASK_PREFIX):]
@@ -123,11 +125,13 @@ class Monitor(object):
"""
# TODO(swang): Also kill the associated plasma store, since it's no longer
# reachable without a plasma manager.
object_ids = self.redis.scan_iter(match="{prefix}*".format(prefix=OBJECT_PREFIX))
object_ids = self.redis.scan_iter(
match="{prefix}*".format(prefix=OBJECT_PREFIX))
num_objects_removed = 0
for object_id in object_ids:
object_id = object_id[len(OBJECT_PREFIX):]
managers = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", object_id)
managers = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
object_id)
for manager in managers:
if manager in self.dead_plasma_managers:
# If the object was on a dead plasma manager, remove that location
@@ -149,7 +153,8 @@ class Monitor(object):
not miss any notifications for deleted clients that occurred before we
subscribed.
"""
db_client_keys = self.redis.keys("{prefix}*".format(prefix=DB_CLIENT_PREFIX))
db_client_keys = self.redis.keys(
"{prefix}*".format(prefix=DB_CLIENT_PREFIX))
for db_client_key in db_client_keys:
db_client_id = db_client_key[len(DB_CLIENT_PREFIX):]
client_type, deleted = self.redis.hmget(db_client_key,
@@ -175,10 +180,10 @@ class Monitor(object):
flatbuffer. Deletions are processed, insertions are ignored. Cleanup of the
associated state in the state tables should be handled by the caller.
"""
notification_object = SubscribeToDBClientTableReply.GetRootAsSubscribeToDBClientTableReply(data, 0)
notification_object = (SubscribeToDBClientTableReply
.GetRootAsSubscribeToDBClientTableReply(data, 0))
db_client_id = notification_object.DbClientId()
client_type = notification_object.ClientType()
auxiliary_address = notification_object.AuxAddress()
is_insertion = notification_object.IsInsertion()
# If the update was an insertion, we ignore it.
@@ -227,7 +232,6 @@ class Monitor(object):
if not self.subscribed[channel]:
# If the data was an integer, then the message was a response to an
# initial subscription request.
is_subscribe = int(data)
message_handler = self.subscribe_handler
elif channel == PLASMA_MANAGER_HEARTBEAT_CHANNEL:
assert(self.subscribed[channel])
@@ -264,12 +268,11 @@ class Monitor(object):
self.cleanup_task_table()
if len(self.dead_plasma_managers) > 0:
self.cleanup_object_table()
log.debug("{} dead local schedulers, {} plasma "
"managers total, {} dead plasma managers".format(
len(self.dead_local_schedulers),
len(self.live_plasma_managers) + len(self.dead_plasma_managers),
len(self.dead_plasma_managers)
))
log.debug("{} dead local schedulers, {} plasma managers total, {} dead "
"plasma managers".format(len(self.dead_local_schedulers),
(len(self.live_plasma_managers) +
len(self.dead_plasma_managers)),
len(self.dead_plasma_managers)))
# Handle messages from the subscription channels.
while True:
@@ -289,7 +292,8 @@ class Monitor(object):
# Handle plasma managers that timed out during this round.
plasma_manager_ids = list(self.live_plasma_managers.keys())
for plasma_manager_id in plasma_manager_ids:
if self.live_plasma_managers[plasma_manager_id] >= NUM_HEARTBEATS_TIMEOUT:
if ((self.live_plasma_managers
[plasma_manager_id]) >= NUM_HEARTBEATS_TIMEOUT):
log.warn("Timed out {}".format(PLASMA_MANAGER_CLIENT_TYPE))
# Remove the plasma manager from the managers whose heartbeats we're
# tracking.
@@ -299,7 +303,8 @@ class Monitor(object):
# receive the notification for this db_client deletion.
self.redis.execute_command("RAY.DISCONNECT", plasma_manager_id)
# Increment the number of heartbeats that we've missed from each plasma manager.
# Increment the number of heartbeats that we've missed from each plasma
# manager.
for plasma_manager_id in self.live_plasma_managers:
self.live_plasma_managers[plasma_manager_id] += 1
+18 -4
View File
@@ -10,15 +10,29 @@ If you are using Anaconda, try fixing this problem by running:
conda install libgcc
"""
__all__ = ["deserialize_list", "numbuf_error",
"numbuf_plasma_object_exists_error", "read_from_buffer",
"register_callbacks", "retrieve_list", "serialize_list",
"store_list", "write_to_buffer"]
try:
from ray.core.src.numbuf.libnumbuf import *
from ray.core.src.numbuf.libnumbuf import (deserialize_list, numbuf_error,
numbuf_plasma_object_exists_error,
read_from_buffer,
register_callbacks, retrieve_list,
serialize_list, store_list,
write_to_buffer)
except ImportError as e:
if hasattr(e, "msg") and isinstance(e.msg, str) and ("libstdc++" in e.msg or "CXX" in e.msg):
if (hasattr(e, "msg") and isinstance(e.msg, str) and ("libstdc++" in e.msg or
"CXX" in e.msg)):
# This code path should be taken with Python 3.
e.msg += helpful_message
elif hasattr(e, "message") and isinstance(e.message, str) and ("libstdc++" in e.message or "CXX" in e.message):
elif (hasattr(e, "message") and isinstance(e.message, str) and
("libstdc++" in e.message or "CXX" in e.message)):
# This code path should be taken with Python 2.
if hasattr(e, "args") and isinstance(e.args, tuple) and len(e.args) == 1 and isinstance(e.args[0], str):
condition = (hasattr(e, "args") and isinstance(e.args, tuple) and
len(e.args) == 1 and isinstance(e.args[0], str))
if condition:
e.args = (e.args[0] + helpful_message,)
else:
if not hasattr(e, "args"):
+34 -14
View File
@@ -1,55 +1,74 @@
# Note that a little bit of code here is taken and slightly modified from the pickler because it was not possible to change its behavior otherwise.
# Note that a little bit of code here is taken and slightly modified from the
# pickler because it was not possible to change its behavior otherwise.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
from ctypes import c_void_p
from cloudpickle import pickle, cloudpickle, CloudPickler, load, loads
__all__ = ["load", "loads", "dump", "dumps"]
try:
from ctypes import pythonapi
pythonapi.PyCell_Set # Make sure this exists
except:
pythonapi = None
def dump(obj, file, protocol=2):
return BetterPickler(file, protocol).dump(obj)
def dumps(obj):
stringio = cloudpickle.StringIO()
dump(obj, stringio)
return stringio.getvalue()
def _make_skel_func(code, closure, base_globals = None):
""" Creates a skeleton function object that contains just the provided
code and the correct number of cells in func_closure. All other
func attributes (e.g. func_globals) are empty.
def _make_skel_func(code, closure, base_globals=None):
"""Create a skeleton function object.
Creates a skeleton function object that contains just the provided code and
the correct number of cells in func_closure. All other func attributes
(e.g. func_globals) are empty.
"""
if base_globals is None: base_globals = {}
base_globals['__builtins__'] = __builtins__
return _make_skel_func.__class__(code, base_globals, None, None, tuple(closure))
if base_globals is None:
base_globals = {}
base_globals["__builtins__"] = __builtins__
return _make_skel_func.__class__(code, base_globals, None, None,
tuple(closure))
def _fill_function(func, globals, defaults, closure, dict):
""" Fills in the rest of function data into the skeleton function object
that were created via _make_skel_func(), including closures.
"""Fill in the rest of the function data.
This fills in the rest of function data into the skeleton function object
that was created via _make_skel_func(), including closures.
"""
result = cloudpickle._fill_function(func, globals, defaults, dict)
if pythonapi is not None:
for i, v in enumerate(closure):
pythonapi.PyCell_Set(c_void_p(id(result.__closure__[i])), c_void_p(id(v)))
pythonapi.PyCell_Set(c_void_p(id(result.__closure__[i])),
c_void_p(id(v)))
return result
class BetterPickler(CloudPickler):
def save_function_tuple(self, func):
code, f_globals, defaults, closure, dct, base_globals = self.extract_func_data(func)
(code, f_globals, defaults,
closure, dct, base_globals) = self.extract_func_data(func)
self.save(_fill_function)
self.write(pickle.MARK)
self.save(_make_skel_func if pythonapi else cloudpickle._make_skel_func)
self.save((code, map(lambda _: cloudpickle._make_cell(None), closure) if closure and pythonapi is not None else closure, base_globals))
self.save((code,
(map(lambda _: cloudpickle._make_cell(None), closure)
if closure and pythonapi is not None
else closure),
base_globals))
self.write(pickle.REDUCE)
self.memoize(func)
@@ -59,6 +78,7 @@ class BetterPickler(CloudPickler):
self.save(dct)
self.write(pickle.TUPLE)
self.write(pickle.REDUCE)
def save_cell(self, obj):
self.save(cloudpickle._make_cell)
self.save((obj.cell_contents,))
+10 -1
View File
@@ -2,4 +2,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ray.plasma.plasma import *
from ray.plasma.plasma import (PlasmaBuffer, buffers_equal, PlasmaClient,
start_plasma_store, start_plasma_manager,
plasma_object_exists_error,
plasma_out_of_memory_error,
DEFAULT_PLASMA_STORE_MEMORY)
__all__ = ["PlasmaBuffer", "buffers_equal", "PlasmaClient",
"start_plasma_store", "start_plasma_manager",
"plasma_object_exists_error", "plasma_out_of_memory_error",
"DEFAULT_PLASMA_STORE_MEMORY"]
+55 -29
View File
@@ -12,9 +12,14 @@ import ray.core.src.plasma.libplasma as libplasma
from ray.core.src.plasma.libplasma import plasma_object_exists_error
from ray.core.src.plasma.libplasma import plasma_out_of_memory_error
PLASMA_ID_SIZE = 20
__all__ = ["PlasmaBuffer", "buffers_equal", "PlasmaClient",
"start_plasma_store", "start_plasma_manager",
"plasma_object_exists_error", "plasma_out_of_memory_error",
"DEFAULT_PLASMA_STORE_MEMORY"]
PLASMA_WAIT_TIMEOUT = 2 ** 30
class PlasmaBuffer(object):
"""This is the type of objects returned by calls to get with a PlasmaClient.
@@ -47,8 +52,8 @@ class PlasmaBuffer(object):
"""Read from the PlasmaBuffer as if it were just a regular buffer."""
# We currently don't allow slicing plasma buffers. We should handle this
# better, but it requires some care because the slice may be backed by the
# same memory in the object store, but the original plasma buffer may go out
# of scope causing the memory to no longer be accessible.
# same memory in the object store, but the original plasma buffer may go
# out of scope causing the memory to no longer be accessible.
assert not isinstance(index, slice)
value = self.buffer[index]
if sys.version_info >= (3, 0) and not isinstance(index, slice):
@@ -62,8 +67,8 @@ class PlasmaBuffer(object):
"""
# We currently don't allow slicing plasma buffers. We should handle this
# better, but it requires some care because the slice may be backed by the
# same memory in the object store, but the original plasma buffer may go out
# of scope causing the memory to no longer be accessible.
# same memory in the object store, but the original plasma buffer may go
# out of scope causing the memory to no longer be accessible.
assert not isinstance(index, slice)
if sys.version_info >= (3, 0) and not isinstance(index, slice):
value = ord(value)
@@ -73,48 +78,56 @@ class PlasmaBuffer(object):
"""Return the length of the buffer."""
return len(self.buffer)
def buffers_equal(buff1, buff2):
"""Compare two buffers. These buffers may be PlasmaBuffer objects.
This method should only be used in the tests. We implement a special helper
method for doing this because doing comparisons by slicing is much faster, but
we don't want to expose slicing of PlasmaBuffer objects because it currently
is not safe.
method for doing this because doing comparisons by slicing is much faster,
but we don't want to expose slicing of PlasmaBuffer objects because it
currently is not safe.
"""
buff1_to_compare = buff1.buffer if isinstance(buff1, PlasmaBuffer) else buff1
buff2_to_compare = buff2.buffer if isinstance(buff2, PlasmaBuffer) else buff2
return buff1_to_compare[:] == buff2_to_compare[:]
class PlasmaClient(object):
"""The PlasmaClient is used to interface with a plasma store and a plasma manager.
"""The PlasmaClient is used to interface with a plasma store and manager.
The PlasmaClient can ask the PlasmaStore to allocate a new buffer, seal a
buffer, and get a buffer. Buffers are referred to by object IDs, which are
strings.
"""
def __init__(self, store_socket_name, manager_socket_name=None, release_delay=64):
def __init__(self, store_socket_name, manager_socket_name=None,
release_delay=64):
"""Initialize the PlasmaClient.
Args:
store_socket_name (str): Name of the socket the plasma store is listening at.
manager_socket_name (str): Name of the socket the plasma manager is listening at.
store_socket_name (str): Name of the socket the plasma store is listening
at.
manager_socket_name (str): Name of the socket the plasma manager is
listening at.
release_delay (int): The maximum number of objects that the client will
keep and delay releasing (for caching reasons).
"""
self.store_socket_name = store_socket_name
self.manager_socket_name = manager_socket_name
self.alive = True
if manager_socket_name is not None:
self.conn = libplasma.connect(store_socket_name, manager_socket_name, release_delay)
self.conn = libplasma.connect(store_socket_name, manager_socket_name,
release_delay)
else:
self.conn = libplasma.connect(store_socket_name, "", release_delay)
def shutdown(self):
"""Shutdown the client so that it does not send messages.
If we kill the Plasma store and Plasma manager that this client is connected
to, then we can use this method to prevent the client from trying to send
messages to the killed processes.
If we kill the Plasma store and Plasma manager that this client is
connected to, then we can use this method to prevent the client from trying
to send messages to the killed processes.
"""
if self.alive:
libplasma.disconnect(self.conn)
@@ -169,8 +182,8 @@ class PlasmaClient(object):
def get_metadata(self, object_ids, timeout_ms=-1):
"""Create a buffer from the PlasmaStore based on object ID.
If the object has not been sealed yet, this call will block until the object
has been sealed. The retrieved buffer is immutable.
If the object has not been sealed yet, this call will block until the
object has been sealed. The retrieved buffer is immutable.
Args:
object_ids (List[str]): A list of strings used to identify some objects.
@@ -275,7 +288,8 @@ class PlasmaClient(object):
# currently crashes if given duplicate object IDs.
if len(object_ids) != len(set(object_ids)):
raise Exception("Wait requires a list of unique object IDs.")
ready_ids, waiting_ids = libplasma.wait(self.conn, object_ids, timeout, num_returns)
ready_ids, waiting_ids = libplasma.wait(self.conn, object_ids, timeout,
num_returns)
return ready_ids, list(waiting_ids)
def subscribe(self):
@@ -286,11 +300,14 @@ class PlasmaClient(object):
"""Get the next notification from the notification socket."""
return libplasma.receive_notification(self.notification_fd)
DEFAULT_PLASMA_STORE_MEMORY = 10 ** 9
def random_name():
return str(random.randint(0, 99999999))
def start_plasma_store(plasma_store_memory=DEFAULT_PLASMA_STORE_MEMORY,
use_valgrind=False, use_profiler=False,
stdout_file=None, stderr_file=None):
@@ -312,16 +329,20 @@ def start_plasma_store(plasma_store_memory=DEFAULT_PLASMA_STORE_MEMORY,
"""
if use_valgrind and use_profiler:
raise Exception("Cannot use valgrind and profiler at the same time.")
plasma_store_executable = os.path.join(os.path.abspath(os.path.dirname(__file__)), "../core/src/plasma/plasma_store")
plasma_store_executable = os.path.join(os.path.abspath(
os.path.dirname(__file__)),
"../core/src/plasma/plasma_store")
plasma_store_name = "/tmp/plasma_store{}".format(random_name())
command = [plasma_store_executable, "-s", plasma_store_name, "-m", str(plasma_store_memory)]
command = [plasma_store_executable,
"-s", plasma_store_name,
"-m", str(plasma_store_memory)]
if use_valgrind:
pid = subprocess.Popen(["valgrind",
"--track-origins=yes",
"--leak-check=full",
"--show-leak-kinds=all",
"--error-exitcode=1"] + command,
stdout=stdout_file, stderr=stderr_file)
stdout=stdout_file, stderr=stderr_file)
time.sleep(1.0)
elif use_profiler:
pid = subprocess.Popen(["valgrind", "--tool=callgrind"] + command,
@@ -332,13 +353,16 @@ def start_plasma_store(plasma_store_memory=DEFAULT_PLASMA_STORE_MEMORY,
time.sleep(0.1)
return plasma_store_name, pid
def new_port():
return random.randint(10000, 65535)
def start_plasma_manager(store_name, redis_address, node_ip_address="127.0.0.1",
plasma_manager_port=None, num_retries=20,
use_valgrind=False, run_profiler=False,
stdout_file=None, stderr_file=None):
def start_plasma_manager(store_name, redis_address,
node_ip_address="127.0.0.1", plasma_manager_port=None,
num_retries=20, use_valgrind=False,
run_profiler=False, stdout_file=None,
stderr_file=None):
"""Start a plasma manager and return the ports it listens on.
Args:
@@ -361,7 +385,9 @@ def start_plasma_manager(store_name, redis_address, node_ip_address="127.0.0.1",
Raises:
Exception: An exception is raised if the manager could not be started.
"""
plasma_manager_executable = os.path.join(os.path.abspath(os.path.dirname(__file__)), "../core/src/plasma/plasma_manager")
plasma_manager_executable = os.path.join(
os.path.abspath(os.path.dirname(__file__)),
"../core/src/plasma/plasma_manager")
plasma_manager_name = "/tmp/plasma_manager{}".format(random_name())
if plasma_manager_port is not None:
if num_retries != 1:
@@ -385,7 +411,7 @@ def start_plasma_manager(store_name, redis_address, node_ip_address="127.0.0.1",
"--leak-check=full",
"--show-leak-kinds=all",
"--error-exitcode=1"] + command,
stdout=stdout_file, stderr=stderr_file)
stdout=stdout_file, stderr=stderr_file)
elif run_profiler:
process = subprocess.Popen(["valgrind", "--tool=callgrind"] + command,
stdout=stdout_file, stderr=stderr_file)
@@ -396,7 +422,7 @@ def start_plasma_manager(store_name, redis_address, node_ip_address="127.0.0.1",
# port is already in use, then we need it to fail within 0.1 seconds.
time.sleep(0.1)
# See if the process has terminated
if process.poll() == None:
if process.poll() is None:
return plasma_manager_name, process, plasma_manager_port
# Generate a new port and try again.
plasma_manager_port = new_port()
+111 -65
View File
@@ -6,23 +6,22 @@ import numpy as np
import os
import random
import signal
import socket
import struct
import subprocess
import sys
import tempfile
import threading
import time
import unittest
import ray.plasma as plasma
from ray.plasma.utils import random_object_id, generate_metadata, write_to_data_buffer, create_object_with_id, create_object
from ray.plasma.utils import (random_object_id, generate_metadata,
create_object_with_id, create_object)
from ray import services
USE_VALGRIND = False
PLASMA_STORE_MEMORY = 1000000000
def assert_get_object_equal(unit_test, client1, client2, object_id, memory_buffer=None, metadata=None):
def assert_get_object_equal(unit_test, client1, client2, object_id,
memory_buffer=None, metadata=None):
client1_buff = client1.get([object_id])[0]
client2_buff = client2.get([object_id])[0]
client1_metadata = client1.get_metadata([object_id])[0]
@@ -32,7 +31,8 @@ def assert_get_object_equal(unit_test, client1, client2, object_id, memory_buffe
# Check that the buffers from the two clients are the same.
unit_test.assertTrue(plasma.buffers_equal(client1_buff, client2_buff))
# Check that the metadata buffers from the two clients are the same.
unit_test.assertTrue(plasma.buffers_equal(client1_metadata, client2_metadata))
unit_test.assertTrue(plasma.buffers_equal(client1_metadata,
client2_metadata))
# If a reference buffer was provided, check that it is the same as well.
if memory_buffer is not None:
unit_test.assertTrue(plasma.buffers_equal(memory_buffer, client1_buff))
@@ -40,11 +40,13 @@ def assert_get_object_equal(unit_test, client1, client2, object_id, memory_buffe
if metadata is not None:
unit_test.assertTrue(plasma.buffers_equal(metadata, client1_metadata))
class TestPlasmaClient(unittest.TestCase):
def setUp(self):
# Start Plasma store.
plasma_store_name, self.p = plasma.start_plasma_store(use_valgrind=USE_VALGRIND)
plasma_store_name, self.p = plasma.start_plasma_store(
use_valgrind=USE_VALGRIND)
# Connect to Plasma.
self.plasma_client = plasma.PlasmaClient(plasma_store_name, None, 64)
# For the eviction test
@@ -107,7 +109,7 @@ class TestPlasmaClient(unittest.TestCase):
object_id = random_object_id()
self.plasma_client.create(object_id, length, generate_metadata(length))
try:
val = self.plasma_client.create(object_id, length, generate_metadata(length))
self.plasma_client.create(object_id, length, generate_metadata(length))
except plasma.plasma_object_exists_error as e:
pass
else:
@@ -125,20 +127,24 @@ class TestPlasmaClient(unittest.TestCase):
metadata_buffers = []
for i in range(num_object_ids):
if i % 2 == 0:
data_buffer, metadata_buffer = create_object_with_id(self.plasma_client, object_ids[i], 2000, 2000)
data_buffer, metadata_buffer = create_object_with_id(
self.plasma_client, object_ids[i], 2000, 2000)
data_buffers.append(data_buffer)
metadata_buffers.append(metadata_buffer)
# Test timing out from some but not all get calls with various timeouts.
for timeout in [0, 10, 100, 1000]:
data_results = self.plasma_client.get(object_ids, timeout_ms=timeout)
metadata_results = self.plasma_client.get(object_ids, timeout_ms=timeout)
# metadata_results = self.plasma_client.get_metadata(object_ids,
# timeout_ms=timeout)
for i in range(num_object_ids):
if i % 2 == 0:
self.assertTrue(plasma.buffers_equal(data_buffers[i // 2], data_results[i]))
# TODO(rkn): We should compare the metadata as well. But currently the
# types are different (e.g., memoryview versus bytearray).
# self.assertTrue(plasma.buffers_equal(metadata_buffers[i // 2], metadata_results[i]))
self.assertTrue(plasma.buffers_equal(data_buffers[i // 2],
data_results[i]))
# TODO(rkn): We should compare the metadata as well. But currently
# the types are different (e.g., memoryview versus bytearray).
# self.assertTrue(plasma.buffers_equal(metadata_buffers[i // 2],
# metadata_results[i]))
else:
self.assertIsNone(results[i])
@@ -148,7 +154,9 @@ class TestPlasmaClient(unittest.TestCase):
def assert_create_raises_plasma_full(unit_test, size):
partial_size = np.random.randint(size)
try:
_, memory_buffer, _ = create_object(unit_test.plasma_client, partial_size, size - partial_size)
_, memory_buffer, _ = create_object(unit_test.plasma_client,
partial_size,
size - partial_size)
except plasma.plasma_out_of_memory_error as e:
pass
else:
@@ -215,7 +223,7 @@ class TestPlasmaClient(unittest.TestCase):
real_object_ids = [random_object_id() for _ in range(100)]
for object_id in real_object_ids:
self.assertFalse(self.plasma_client.contains(object_id))
memory_buffer = self.plasma_client.create(object_id, 100)
self.plasma_client.create(object_id, 100)
self.plasma_client.seal(object_id)
self.assertTrue(self.plasma_client.contains(object_id))
for object_id in fake_object_ids:
@@ -226,7 +234,7 @@ class TestPlasmaClient(unittest.TestCase):
def test_hash(self):
# Check the hash of an object that doesn't exist.
object_id1 = random_object_id()
h = self.plasma_client.hash(object_id1)
self.plasma_client.hash(object_id1)
length = 1000
# Create a random object, and check that the hash function always returns
@@ -357,7 +365,7 @@ class TestPlasmaClient(unittest.TestCase):
length = 1000
memory_buffer = self.plasma_client.create(object_id, length)
# Make sure we cannot access memory out of bounds.
self.assertRaises(Exception, lambda : memory_buffer[length])
self.assertRaises(Exception, lambda: memory_buffer[length])
# Seal the object.
self.plasma_client.seal(object_id)
# This test is commented out because it currently fails.
@@ -367,6 +375,7 @@ class TestPlasmaClient(unittest.TestCase):
# self.assertRaises(Exception, illegal_assignment)
# Get the object.
memory_buffer = self.plasma_client.get([object_id])[0]
# Make sure the object is read only.
def illegal_assignment():
memory_buffer[0] = chr(0)
@@ -413,18 +422,20 @@ class TestPlasmaClient(unittest.TestCase):
def test_subscribe(self):
# Subscribe to notifications from the Plasma Store.
sock = self.plasma_client.subscribe()
self.plasma_client.subscribe()
for i in [1, 10, 100, 1000, 10000, 100000]:
object_ids = [random_object_id() for _ in range(i)]
metadata_sizes = [np.random.randint(1000) for _ in range(i)]
data_sizes = [np.random.randint(1000) for _ in range(i)]
for j in range(i):
self.plasma_client.create(object_ids[j], size=data_sizes[j],
metadata=bytearray(np.random.bytes(metadata_sizes[j])))
self.plasma_client.create(
object_ids[j], size=data_sizes[j],
metadata=bytearray(np.random.bytes(metadata_sizes[j])))
self.plasma_client.seal(object_ids[j])
# Check that we received notifications for all of the objects.
for j in range(i):
recv_objid, recv_dsize, recv_msize = self.plasma_client.get_next_notification()
notification_info = self.plasma_client.get_next_notification()
recv_objid, recv_dsize, recv_msize = notification_info
self.assertEqual(object_ids[j], recv_objid)
self.assertEqual(data_sizes[j], recv_dsize)
self.assertEqual(metadata_sizes[j], recv_msize)
@@ -432,20 +443,22 @@ class TestPlasmaClient(unittest.TestCase):
def test_subscribe_deletions(self):
# Subscribe to notifications from the Plasma Store. We use plasma_client2
# to make sure that all used objects will get evicted properly.
sock = self.plasma_client2.subscribe()
self.plasma_client2.subscribe()
for i in [1, 10, 100, 1000, 10000, 100000]:
object_ids = [random_object_id() for _ in range(i)]
# Add 1 to the sizes to make sure we have nonzero object sizes.
metadata_sizes = [np.random.randint(1000) + 1 for _ in range(i)]
data_sizes = [np.random.randint(1000) + 1 for _ in range(i)]
for j in range(i):
x = self.plasma_client2.create(object_ids[j], size=data_sizes[j],
metadata=bytearray(np.random.bytes(metadata_sizes[j])))
x = self.plasma_client2.create(
object_ids[j], size=data_sizes[j],
metadata=bytearray(np.random.bytes(metadata_sizes[j])))
self.plasma_client2.seal(object_ids[j])
del x
# Check that we received notifications for creating all of the objects.
for j in range(i):
recv_objid, recv_dsize, recv_msize = self.plasma_client2.get_next_notification()
notification_info = self.plasma_client2.get_next_notification()
recv_objid, recv_dsize, recv_msize = notification_info
self.assertEqual(object_ids[j], recv_objid)
self.assertEqual(data_sizes[j], recv_dsize)
self.assertEqual(metadata_sizes[j], recv_msize)
@@ -453,8 +466,10 @@ class TestPlasmaClient(unittest.TestCase):
# Check that we receive notifications for deleting all objects, as we
# evict them.
for j in range(i):
self.assertEqual(self.plasma_client2.evict(1), data_sizes[j] + metadata_sizes[j])
recv_objid, recv_dsize, recv_msize = self.plasma_client2.get_next_notification()
self.assertEqual(self.plasma_client2.evict(1),
data_sizes[j] + metadata_sizes[j])
notification_info = self.plasma_client2.get_next_notification()
recv_objid, recv_dsize, recv_msize = notification_info
self.assertEqual(object_ids[j], recv_objid)
self.assertEqual(-1, recv_dsize)
self.assertEqual(-1, recv_msize)
@@ -469,18 +484,22 @@ class TestPlasmaClient(unittest.TestCase):
metadata_sizes.append(np.random.randint(1000))
data_sizes.append(np.random.randint(1000))
for i in range(num_object_ids):
x = self.plasma_client2.create(object_ids[i], size=data_sizes[i],
metadata=bytearray(np.random.bytes(metadata_sizes[i])))
x = self.plasma_client2.create(
object_ids[i], size=data_sizes[i],
metadata=bytearray(np.random.bytes(metadata_sizes[i])))
self.plasma_client2.seal(object_ids[i])
del x
for i in range(num_object_ids):
recv_objid, recv_dsize, recv_msize = self.plasma_client2.get_next_notification()
notification_info = self.plasma_client2.get_next_notification()
recv_objid, recv_dsize, recv_msize = notification_info
self.assertEqual(object_ids[i], recv_objid)
self.assertEqual(data_sizes[i], recv_dsize)
self.assertEqual(metadata_sizes[i], recv_msize)
self.assertEqual(self.plasma_client2.evict(1), data_sizes[-1] + metadata_sizes[-1])
self.assertEqual(self.plasma_client2.evict(1),
data_sizes[-1] + metadata_sizes[-1])
for i in range(num_object_ids):
recv_objid, recv_dsize, recv_msize = self.plasma_client2.get_next_notification()
notification_info = self.plasma_client2.get_next_notification()
recv_objid, recv_dsize, recv_msize = notification_info
self.assertEqual(object_ids[i], recv_objid)
self.assertEqual(-1, recv_dsize)
self.assertEqual(-1, recv_msize)
@@ -495,8 +514,10 @@ class TestPlasmaManager(unittest.TestCase):
# Start a Redis server.
redis_address = services.address("127.0.0.1", services.start_redis()[0])
# Start two PlasmaManagers.
manager_name1, self.p4, self.port1 = plasma.start_plasma_manager(store_name1, redis_address, use_valgrind=USE_VALGRIND)
manager_name2, self.p5, self.port2 = plasma.start_plasma_manager(store_name2, redis_address, use_valgrind=USE_VALGRIND)
manager_name1, self.p4, self.port1 = plasma.start_plasma_manager(
store_name1, redis_address, use_valgrind=USE_VALGRIND)
manager_name2, self.p5, self.port2 = plasma.start_plasma_manager(
store_name2, redis_address, use_valgrind=USE_VALGRIND)
# Connect two PlasmaClients.
self.client1 = plasma.PlasmaClient(store_name1, manager_name1)
self.client2 = plasma.PlasmaClient(store_name2, manager_name2)
@@ -513,7 +534,8 @@ class TestPlasmaManager(unittest.TestCase):
# Kill the Plasma store and Plasma manager processes.
if USE_VALGRIND:
time.sleep(1) # give processes opportunity to finish work
# Give processes opportunity to finish work.
time.sleep(1)
for process in self.processes_to_kill:
process.send_signal(signal.SIGTERM)
process.wait()
@@ -530,7 +552,8 @@ class TestPlasmaManager(unittest.TestCase):
def test_fetch(self):
for _ in range(10):
# Create an object.
object_id1, memory_buffer1, metadata1 = create_object(self.client1, 2000, 2000)
object_id1, memory_buffer1, metadata1 = create_object(self.client1, 2000,
2000)
self.client1.fetch([object_id1])
self.assertEqual(self.client1.contains(object_id1), True)
self.assertEqual(self.client2.contains(object_id1), False)
@@ -546,7 +569,8 @@ class TestPlasmaManager(unittest.TestCase):
object_id2 = random_object_id()
self.client1.fetch([object_id2])
self.assertEqual(self.client1.contains(object_id2), False)
memory_buffer2, metadata2 = create_object_with_id(self.client2, object_id2, 2000, 2000)
memory_buffer2, metadata2 = create_object_with_id(self.client2, object_id2,
2000, 2000)
# # Check that the object has been fetched.
# self.assertEqual(self.client1.contains(object_id2), True)
# Compare the two buffers.
@@ -560,11 +584,12 @@ class TestPlasmaManager(unittest.TestCase):
for _ in range(10):
self.client1.fetch([object_id3])
self.client2.fetch([object_id3])
memory_buffer3, metadata3 = create_object_with_id(self.client1, object_id3, 2000, 2000)
memory_buffer3, metadata3 = create_object_with_id(self.client1, object_id3,
2000, 2000)
for _ in range(10):
self.client1.fetch([object_id3])
self.client2.fetch([object_id3])
#TODO(rkn): Right now we must wait for the object table to be updated.
# TODO(rkn): Right now we must wait for the object table to be updated.
while not self.client2.contains(object_id3):
self.client2.fetch([object_id3])
assert_get_object_equal(self, self.client1, self.client2, object_id3,
@@ -573,14 +598,17 @@ class TestPlasmaManager(unittest.TestCase):
def test_fetch_multiple(self):
for _ in range(20):
# Create two objects and a third fake one that doesn't exist.
object_id1, memory_buffer1, metadata1 = create_object(self.client1, 2000, 2000)
object_id1, memory_buffer1, metadata1 = create_object(self.client1, 2000,
2000)
missing_object_id = random_object_id()
object_id2, memory_buffer2, metadata2 = create_object(self.client1, 2000, 2000)
object_id2, memory_buffer2, metadata2 = create_object(self.client1, 2000,
2000)
object_ids = [object_id1, missing_object_id, object_id2]
# Fetch the objects from the other plasma store. The second object ID
# should timeout since it does not exist.
# TODO(rkn): Right now we must wait for the object table to be updated.
while (not self.client2.contains(object_id1)) or (not self.client2.contains(object_id2)):
while ((not self.client2.contains(object_id1)) or
(not self.client2.contains(object_id2))):
self.client2.fetch(object_ids)
# Compare the buffers of the objects that do exist.
assert_get_object_equal(self, self.client1, self.client2, object_id1,
@@ -597,7 +625,8 @@ class TestPlasmaManager(unittest.TestCase):
# Check that we can call fetch with duplicated object IDs.
object_id3 = random_object_id()
self.client1.fetch([object_id3, object_id3])
object_id4, memory_buffer4, metadata4 = create_object(self.client1, 2000, 2000)
object_id4, memory_buffer4, metadata4 = create_object(self.client1, 2000,
2000)
time.sleep(0.1)
# TODO(rkn): Right now we must wait for the object table to be updated.
while not self.client2.contains(object_id4):
@@ -623,7 +652,8 @@ class TestPlasmaManager(unittest.TestCase):
obj_id2 = random_object_id()
self.client1.create(obj_id2, 1000)
# Don't seal.
ready, waiting = self.client1.wait([obj_id2, obj_id1], timeout=100, num_returns=1)
ready, waiting = self.client1.wait([obj_id2, obj_id1], timeout=100,
num_returns=1)
self.assertEqual(set(ready), set([obj_id1]))
self.assertEqual(set(waiting), set([obj_id2]))
@@ -636,11 +666,13 @@ class TestPlasmaManager(unittest.TestCase):
t = threading.Timer(0.1, finish)
t.start()
ready, waiting = self.client1.wait([obj_id3, obj_id2, obj_id1], timeout=1000, num_returns=2)
ready, waiting = self.client1.wait([obj_id3, obj_id2, obj_id1],
timeout=1000, num_returns=2)
self.assertEqual(set(ready), set([obj_id1, obj_id3]))
self.assertEqual(set(waiting), set([obj_id2]))
# Test if the appropriate number of objects is shown if some objects are not ready
# Test if the appropriate number of objects is shown if some objects are
# not ready.
ready, waiting = self.client1.wait([obj_id3, obj_id2, obj_id1], 100, 3)
self.assertEqual(set(ready), set([obj_id1, obj_id3]))
self.assertEqual(set(waiting), set([obj_id2]))
@@ -651,9 +683,9 @@ class TestPlasmaManager(unittest.TestCase):
# Test calling wait a bunch of times.
object_ids = []
# TODO(rkn): Increasing n to 100 (or larger) will cause failures. The
# problem appears to be that the number of timers added to the manager event
# loop slow down the manager so much that some of the asynchronous Redis
# commands timeout triggering fatal failure callbacks.
# problem appears to be that the number of timers added to the manager
# event loop slows down the manager so much that some of the asynchronous
# Redis commands time out, triggering fatal failure callbacks.
n = 40
for i in range(n * (n + 1) // 2):
if i % 2 == 0:
@@ -669,7 +701,8 @@ class TestPlasmaManager(unittest.TestCase):
self.assertEqual(len(ready), i)
retrieved += ready
self.assertEqual(set(retrieved), set(object_ids))
ready, waiting = self.client1.wait(object_ids, timeout=1000, num_returns=len(object_ids))
ready, waiting = self.client1.wait(object_ids, timeout=1000,
num_returns=len(object_ids))
self.assertEqual(set(ready), set(object_ids))
self.assertEqual(waiting, [])
# Try waiting for all of the object IDs on the second client.
@@ -680,7 +713,8 @@ class TestPlasmaManager(unittest.TestCase):
self.assertEqual(len(ready), i)
retrieved += ready
self.assertEqual(set(retrieved), set(object_ids))
ready, waiting = self.client2.wait(object_ids, timeout=1000, num_returns=len(object_ids))
ready, waiting = self.client2.wait(object_ids, timeout=1000,
num_returns=len(object_ids))
self.assertEqual(set(ready), set(object_ids))
self.assertEqual(waiting, [])
@@ -702,7 +736,8 @@ class TestPlasmaManager(unittest.TestCase):
num_attempts = 100
for _ in range(100):
# Create an object.
object_id1, memory_buffer1, metadata1 = create_object(self.client1, 2000, 2000)
object_id1, memory_buffer1, metadata1 = create_object(self.client1, 2000,
2000)
# Transfer the buffer to the other Plasma store. There is a race
# condition on the create and transfer of the object, so keep trying
# until the object appears on the second Plasma store.
@@ -721,10 +756,12 @@ class TestPlasmaManager(unittest.TestCase):
# self.client1.transfer("127.0.0.1", self.port2, object_id1)
# # Compare the two buffers.
# assert_get_object_equal(self, self.client1, self.client2, object_id1,
# memory_buffer=memory_buffer1, metadata=metadata1)
# memory_buffer=memory_buffer1,
# metadata=metadata1)
# Create an object.
object_id2, memory_buffer2, metadata2 = create_object(self.client2, 20000, 20000)
object_id2, memory_buffer2, metadata2 = create_object(self.client2,
20000, 20000)
# Transfer the buffer to the other Plasma store. There is a race
# condition on the create and transfer of the object, so keep trying
# until the object appears on the second Plasma store.
@@ -742,17 +779,19 @@ class TestPlasmaManager(unittest.TestCase):
def test_illegal_functionality(self):
# Create an object id string.
object_id = random_object_id()
# object_id = random_object_id()
# Create a new buffer.
# memory_buffer = self.client1.create(object_id, 20000)
# This test is commented out because it currently fails.
# # Transferring the buffer before sealing it should fail.
# self.assertRaises(Exception, lambda : self.manager1.transfer(1, object_id))
# self.assertRaises(Exception,
# lambda : self.manager1.transfer(1, object_id))
pass
def test_stresstest(self):
a = time.time()
object_ids = []
for i in range(10000): # TODO(pcm): increase this to 100000
for i in range(10000): # TODO(pcm): increase this to 100000.
object_id = random_object_id()
object_ids.append(object_id)
self.client1.create(object_id, 1)
@@ -763,13 +802,16 @@ class TestPlasmaManager(unittest.TestCase):
print("it took", b, "seconds to put and transfer the objects")
class TestPlasmaManagerRecovery(unittest.TestCase):
def setUp(self):
# Start a Plasma store.
self.store_name, self.p2 = plasma.start_plasma_store(use_valgrind=USE_VALGRIND)
self.store_name, self.p2 = plasma.start_plasma_store(
use_valgrind=USE_VALGRIND)
# Start a Redis server.
self.redis_address = services.address("127.0.0.1", services.start_redis()[0])
self.redis_address = services.address("127.0.0.1",
services.start_redis()[0])
# Start a PlasmaManager.
manager_name, self.p3, self.port1 = plasma.start_plasma_manager(
self.store_name,
@@ -791,7 +833,8 @@ class TestPlasmaManagerRecovery(unittest.TestCase):
# Kill the Plasma store and Plasma manager processes.
if USE_VALGRIND:
time.sleep(1) # give processes opportunity to finish work
# Give processes opportunity to finish work.
time.sleep(1)
for process in self.processes_to_kill:
process.send_signal(signal.SIGTERM)
process.wait()
@@ -818,23 +861,26 @@ class TestPlasmaManagerRecovery(unittest.TestCase):
self.assertEqual(waiting, [])
# Start a second plasma manager attached to the same store.
manager_name, self.p5, self.port2 = plasma.start_plasma_manager(self.store_name, self.redis_address, use_valgrind=USE_VALGRIND)
manager_name, self.p5, self.port2 = plasma.start_plasma_manager(
self.store_name, self.redis_address, use_valgrind=USE_VALGRIND)
self.processes_to_kill = [self.p5] + self.processes_to_kill
# Check that the second manager knows about existing objects.
client2 = plasma.PlasmaClient(self.store_name, manager_name)
ready, waiting = [], object_ids
while True:
ready, waiting = client2.wait(object_ids, num_returns=num_objects, timeout=0)
ready, waiting = client2.wait(object_ids, num_returns=num_objects,
timeout=0)
if len(ready) == len(object_ids):
break
self.assertEqual(set(ready), set(object_ids))
self.assertEqual(waiting, [])
if __name__ == "__main__":
if len(sys.argv) > 1:
# pop the argument so we don't mess with unittest's own argument parser
# Pop the argument so we don't mess with unittest's own argument parser.
if sys.argv[-1] == "valgrind":
arg = sys.argv.pop()
USE_VALGRIND = True
+9 -2
View File
@@ -5,9 +5,11 @@ from __future__ import print_function
import numpy as np
import random
def random_object_id():
  """Return 20 random bytes drawn from NumPy's global random state."""
  id_bytes = np.random.bytes(20)
  return id_bytes
def generate_metadata(length):
metadata_buffer = bytearray(length)
if length > 0:
@@ -17,6 +19,7 @@ def generate_metadata(length):
metadata_buffer[random.randint(0, length - 1)] = random.randint(0, 255)
return metadata_buffer
def write_to_data_buffer(buff, length):
if length > 0:
buff[0] = chr(random.randint(0, 255))
@@ -24,7 +27,9 @@ def write_to_data_buffer(buff, length):
for _ in range(100):
buff[random.randint(0, length - 1)] = chr(random.randint(0, 255))
def create_object_with_id(client, object_id, data_size, metadata_size, seal=True):
def create_object_with_id(client, object_id, data_size, metadata_size,
seal=True):
metadata = generate_metadata(metadata_size)
memory_buffer = client.create(object_id, data_size, metadata)
write_to_data_buffer(memory_buffer, data_size)
@@ -32,7 +37,9 @@ def create_object_with_id(client, object_id, data_size, metadata_size, seal=True
client.seal(object_id)
return memory_buffer, metadata
def create_object(client, data_size, metadata_size, seal=True):
  """Create an object under a freshly generated random object ID.

  The arguments are forwarded to create_object_with_id together with the new
  ID.

  Returns:
    A tuple of the object ID, the memory buffer, and the metadata.
  """
  new_id = random_object_id()
  data_buffer, meta = create_object_with_id(client, new_id, data_size,
                                            metadata_size, seal=seal)
  return new_id, data_buffer, meta
+44 -10
View File
@@ -7,6 +7,7 @@ import numpy as np
import ray.numbuf
import ray.pickling as pickling
def check_serializable(cls):
"""Throws an exception if Ray cannot serialize this class efficiently.
@@ -21,15 +22,30 @@ def check_serializable(cls):
# This case works.
return
if not hasattr(cls, "__new__"):
raise Exception("The class {} does not have a '__new__' attribute, and is probably an old-style class. We do not support this. Please either make it a new-style class by inheriting from 'object', or use 'ray.register_class(cls, pickle=True)'. However, note that pickle is inefficient.".format(cls))
raise Exception("The class {} does not have a '__new__' attribute, and is "
"probably an old-style class. We do not support this. "
"Please either make it a new-style class by inheriting "
"from 'object', or use "
"'ray.register_class(cls, pickle=True)'. However, note "
"that pickle is inefficient.".format(cls))
try:
obj = cls.__new__(cls)
except:
raise Exception("The class {} has overridden '__new__', so Ray may not be able to serialize it efficiently. Try using 'ray.register_class(cls, pickle=True)'. However, note that pickle is inefficient.".format(cls))
raise Exception("The class {} has overridden '__new__', so Ray may not be "
"able to serialize it efficiently. Try using "
"'ray.register_class(cls, pickle=True)'. However, note "
"that pickle is inefficient.".format(cls))
if not hasattr(obj, "__dict__"):
raise Exception("Objects of the class {} do not have a `__dict__` attribute, so Ray cannot serialize it efficiently. Try using 'ray.register_class(cls, pickle=True)'. However, note that pickle is inefficient.".format(cls))
raise Exception("Objects of the class {} do not have a `__dict__` "
"attribute, so Ray cannot serialize it efficiently. Try "
"using 'ray.register_class(cls, pickle=True)'. However, "
"note that pickle is inefficient.".format(cls))
if hasattr(obj, "__slots__"):
raise Exception("The class {} uses '__slots__', so Ray may not be able to serialize it efficiently. Try using 'ray.register_class(cls, pickle=True)'. However, note that pickle is inefficient.".format(cls))
raise Exception("The class {} uses '__slots__', so Ray may not be able to "
"serialize it efficiently. Try using "
"'ray.register_class(cls, pickle=True)'. However, note "
"that pickle is inefficient.".format(cls))
# This field keeps track of a whitelisted set of classes that Ray will
# serialize.
@@ -38,10 +54,12 @@ classes_to_pickle = set()
custom_serializers = {}
custom_deserializers = {}
def class_identifier(typ):
  """Return a string that identifies this type.

  The identifier is the type's module name and its class name joined by a dot.
  """
  module_name = typ.__module__
  type_name = typ.__name__
  return "{}.{}".format(module_name, type_name)
def is_named_tuple(cls):
"""Return True if cls is a namedtuple and False otherwise."""
b = cls.__bases__
@@ -52,7 +70,9 @@ def is_named_tuple(cls):
return False
return all(type(n) == str for n in f)
def add_class_to_whitelist(cls, pickle=False, custom_serializer=None, custom_deserializer=None):
def add_class_to_whitelist(cls, pickle=False, custom_serializer=None,
custom_deserializer=None):
"""Add cls to the list of classes that we can serialize.
Args:
@@ -72,13 +92,21 @@ def add_class_to_whitelist(cls, pickle=False, custom_serializer=None, custom_des
custom_serializers[class_id] = custom_serializer
custom_deserializers[class_id] = custom_deserializer
# Here we define a custom serializer and deserializer for handling numpy
# arrays that contain objects.
def array_custom_serializer(obj):
  """Turn an array into a pair of its list form and its dtype string."""
  contents = obj.tolist()
  dtype_str = obj.dtype.str
  return contents, dtype_str
def array_custom_deserializer(serialized_obj):
  """Rebuild an array from a pair of list contents and a dtype string."""
  contents = serialized_obj[0]
  dtype_str = serialized_obj[1]
  return np.array(contents, dtype=np.dtype(dtype_str))
add_class_to_whitelist(np.ndarray, pickle=False, custom_serializer=array_custom_serializer, custom_deserializer=array_custom_deserializer)
add_class_to_whitelist(np.ndarray, pickle=False,
custom_serializer=array_custom_serializer,
custom_deserializer=array_custom_deserializer)
def serialize(obj):
"""This is the callback that will be used by numbuf.
@@ -94,7 +122,9 @@ def serialize(obj):
"""
class_id = class_identifier(type(obj))
if class_id not in whitelisted_classes:
raise Exception("Ray does not know how to serialize objects of type {}. To fix this, call 'ray.register_class' with this class.".format(type(obj)))
raise Exception("Ray does not know how to serialize objects of type {}. "
"To fix this, call 'ray.register_class' with this class."
.format(type(obj)))
if class_id in classes_to_pickle:
serialized_obj = {"data": pickling.dumps(obj)}
elif class_id in custom_serializers.keys():
@@ -107,10 +137,12 @@ def serialize(obj):
elif hasattr(obj, "__dict__"):
serialized_obj = obj.__dict__
else:
raise Exception("We do not know how to serialize the object '{}'".format(obj))
raise Exception("We do not know how to serialize the object '{}'"
.format(obj))
result = dict(serialized_obj, **{"_pytype_": class_id})
return result
def deserialize(serialized_obj):
"""This is the callback that will be used by numbuf.
@@ -139,11 +171,13 @@ def deserialize(serialized_obj):
obj.__dict__.update(serialized_obj)
return obj
def set_callbacks():
  """Register the custom callbacks with numbuf.

  The serialize callback is used to serialize objects that numbuf does not know
  how to serialize (for example custom Python classes). The deserialize
  callback is used to deserialize objects that were serialized by the
  serialize callback.
  """
  ray.numbuf.register_callbacks(serialize, deserialize)
+162 -103
View File
@@ -10,7 +10,6 @@ import random
import redis
import signal
import socket
import string
import subprocess
import sys
import time
@@ -60,16 +59,20 @@ ObjectStoreAddress = namedtuple("ObjectStoreAddress", ["name",
"manager_name",
"manager_port"])
def address(ip_address, port):
  """Combine an IP address and a port into a single "ip:port" string."""
  return ":".join([ip_address, str(port)])
def get_ip_address(address):
  """Return the IP address part of an "ip:port" address string.

  Args:
    address: A string of the form "ip:port". Everything before the first colon
      is returned; if the string contains no colon, the whole string is
      returned.

  Returns:
    The part of the address that precedes the first colon.

  Raises:
    Exception: An exception is raised if the address cannot be split, for
      example because it is not a string.
  """
  try:
    ip_address = address.split(":")[0]
  except Exception:
    # Catch Exception instead of using a bare except so that KeyboardInterrupt
    # and SystemExit are not turned into a parse error.
    raise Exception("Unable to parse IP address from address "
                    "{}".format(address))
  return ip_address
def get_port(address):
try:
port = int(address.split(":")[1])
@@ -77,12 +80,15 @@ def get_port(address):
raise Exception("Unable to parse port from address {}".format(address))
return port
def new_port():
  """Pick a random port number between 10000 and 65535, inclusive."""
  lowest, highest = 10000, 65535
  return random.randint(lowest, highest)
def random_name():
  """Return a random integer between 0 and 99999999 as a decimal string."""
  number = random.randint(0, 99999999)
  return str(number)
def kill_process(p):
"""Kill a process.
@@ -92,11 +98,15 @@ def kill_process(p):
Returns:
True if the process was killed successfully and false otherwise.
"""
if p.poll() is not None: # process has already terminated
if p.poll() is not None:
# The process has already terminated.
return True
if RUN_LOCAL_SCHEDULER_PROFILER or RUN_PLASMA_MANAGER_PROFILER or RUN_PLASMA_STORE_PROFILER:
os.kill(p.pid, signal.SIGINT) # Give process signal to write profiler data.
time.sleep(0.1) # Wait for profiling data to be written.
if any([RUN_LOCAL_SCHEDULER_PROFILER, RUN_PLASMA_MANAGER_PROFILER,
RUN_PLASMA_STORE_PROFILER]):
# Give process signal to write profiler data.
os.kill(p.pid, signal.SIGINT)
# Wait for profiling data to be written.
time.sleep(0.1)
# Allow the process one second to exit gracefully.
p.terminate()
@@ -118,6 +128,7 @@ def kill_process(p):
# The process was not killed for some reason.
return False
def cleanup():
"""When running in local mode, shutdown the Ray processes.
@@ -138,6 +149,7 @@ def cleanup():
if not successfully_shut_down:
print("Ray did not shut down properly.")
def all_processes_alive(exclude=[]):
"""Check if all of the processes are still alive.
@@ -147,10 +159,13 @@ def all_processes_alive(exclude=[]):
for process_type, processes in all_processes.items():
# Note that p.poll() returns the exit code that the process exited with, so
# an exit code of None indicates that the process is still alive.
if not all([p.poll() is None for p in processes]) and process_type not in exclude:
processes_alive = [p.poll() is None for p in processes]
if (not all(processes_alive) and process_type not in exclude):
print("A process of type {} has dead.".format(process_type))
return False
return True
def get_node_ip_address(address="8.8.8.8:53"):
"""Determine the IP address of the local node.
@@ -166,6 +181,7 @@ def get_node_ip_address(address="8.8.8.8:53"):
s.connect((ip_address, int(port)))
return s.getsockname()[0]
def record_log_files_in_redis(redis_address, node_ip_address, log_files):
"""Record in Redis that a new log file has been created.
@@ -187,6 +203,7 @@ def record_log_files_in_redis(redis_address, node_ip_address, log_files):
log_file_list_key = "LOG_FILENAMES:{}".format(node_ip_address)
redis_client.rpush(log_file_list_key, log_file.name)
def wait_for_redis_to_start(redis_ip_address, redis_port, num_retries=5):
"""Wait for a Redis server to be available.
@@ -208,7 +225,8 @@ def wait_for_redis_to_start(redis_ip_address, redis_port, num_retries=5):
while counter < num_retries:
try:
# Run some random command and see if it worked.
print("Waiting for redis server at {}:{} to respond...".format(redis_ip_address, redis_port))
print("Waiting for redis server at {}:{} to respond..."
.format(redis_ip_address, redis_port))
redis_client.client_list()
except redis.ConnectionError as e:
# Wait a little bit.
@@ -218,7 +236,10 @@ def wait_for_redis_to_start(redis_ip_address, redis_port, num_retries=5):
else:
break
if counter == num_retries:
raise Exception("Unable to connect to Redis. If the Redis instance is on a different machine, check that your firewall is configured properly.")
raise Exception("Unable to connect to Redis. If the Redis instance is on "
"a different machine, check that your firewall is "
"configured properly.")
def start_redis(node_ip_address="127.0.0.1", port=None, num_retries=20,
stdout_file=None, stderr_file=None, cleanup=True):
@@ -239,13 +260,18 @@ def start_redis(node_ip_address="127.0.0.1", port=None, num_retries=20,
Returns:
A tuple of the port used by Redis and a handle to the process that was
started. If a port is passed in, then the returned port value is the same.
started. If a port is passed in, then the returned port value is the
same.
Raises:
Exception: An exception is raised if Redis could not be started.
"""
redis_filepath = os.path.join(os.path.dirname(os.path.abspath(__file__)), "./core/src/common/thirdparty/redis/src/redis-server")
redis_module = os.path.join(os.path.dirname(os.path.abspath(__file__)), "./core/src/common/redis_module/libray_redis_module.so")
redis_filepath = os.path.join(
os.path.dirname(os.path.abspath(__file__)),
"./core/src/common/thirdparty/redis/src/redis-server")
redis_module = os.path.join(
os.path.dirname(os.path.abspath(__file__)),
"./core/src/common/redis_module/libray_redis_module.so")
assert os.path.isfile(redis_filepath)
assert os.path.isfile(redis_module)
counter = 0
@@ -261,7 +287,7 @@ def start_redis(node_ip_address="127.0.0.1", port=None, num_retries=20,
"--port", str(port),
"--loglevel", "warning",
"--loadmodule", redis_module],
stdout=stdout_file, stderr=stderr_file)
stdout=stdout_file, stderr=stderr_file)
time.sleep(0.1)
# Check if Redis successfully started (or at least if the executable did
# not exit within 0.1 seconds).
@@ -291,6 +317,7 @@ def start_redis(node_ip_address="127.0.0.1", port=None, num_retries=20,
[stdout_file, stderr_file])
return port, p
def start_log_monitor(redis_address, node_ip_address, stdout_file=None,
stderr_file=None, cleanup=cleanup):
"""Start a log monitor process.
@@ -307,7 +334,9 @@ def start_log_monitor(redis_address, node_ip_address, stdout_file=None,
this process will be killed by services.cleanup() when the Python process
that imported services exits.
"""
log_monitor_filepath = os.path.join(os.path.dirname(os.path.abspath(__file__)), "log_monitor.py")
log_monitor_filepath = os.path.join(
os.path.dirname(os.path.abspath(__file__)),
"log_monitor.py")
p = subprocess.Popen(["python", log_monitor_filepath,
"--redis-address", redis_address,
"--node-ip-address", node_ip_address],
@@ -317,13 +346,15 @@ def start_log_monitor(redis_address, node_ip_address, stdout_file=None,
record_log_files_in_redis(redis_address, node_ip_address,
[stdout_file, stderr_file])
def start_global_scheduler(redis_address, node_ip_address, stdout_file=None,
stderr_file=None, cleanup=True):
"""Start a global scheduler process.
Args:
redis_address (str): The address of the Redis instance.
node_ip_address: The IP address of the node that this scheduler will run on.
node_ip_address: The IP address of the node that this scheduler will run
on.
stdout_file: A file handle opened for writing to redirect stdout to. If no
redirection should happen, then this should be None.
stderr_file: A file handle opened for writing to redirect stderr to. If no
@@ -340,6 +371,7 @@ def start_global_scheduler(redis_address, node_ip_address, stdout_file=None,
record_log_files_in_redis(redis_address, node_ip_address,
[stdout_file, stderr_file])
def start_webui(redis_address, node_ip_address, backend_stdout_file=None,
backend_stderr_file=None, polymer_stdout_file=None,
polymer_stderr_file=None, cleanup=True):
@@ -367,8 +399,11 @@ def start_webui(redis_address, node_ip_address, backend_stdout_file=None,
Return:
True if the web UI was successfully started, otherwise false.
"""
webui_backend_filepath = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../webui/backend/ray_ui.py")
webui_directory = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../webui/")
webui_backend_filepath = os.path.join(
os.path.dirname(os.path.abspath(__file__)),
"../../webui/backend/ray_ui.py")
webui_directory = os.path.join(os.path.dirname(os.path.abspath(__file__)),
"../../webui/")
if sys.version_info >= (3, 0):
python_executable = "python"
@@ -376,7 +411,8 @@ def start_webui(redis_address, node_ip_address, backend_stdout_file=None,
# If the user is using Python 2, it is still possible to run the webserver
# separately with Python 3, so try to find a Python 3 executable.
try:
python_executable = subprocess.check_output(["which", "python3"]).decode("ascii").strip()
python_executable = subprocess.check_output(
["which", "python3"]).decode("ascii").strip()
except Exception as e:
print("Not starting the web UI because the web UI requires Python 3.")
return False
@@ -394,8 +430,8 @@ def start_webui(redis_address, node_ip_address, backend_stdout_file=None,
return False
# Try to start polymer. If this fails, it may be that port 8080 is already in
# use. It'd be nice to test for this, but doing so by calling "bind" may start
# using the port and prevent polymer from using it.
# use. It'd be nice to test for this, but doing so by calling "bind" may
# start using the port and prevent polymer from using it.
try:
polymer_process = subprocess.Popen(["polymer", "serve", "--port", "8080"],
cwd=webui_directory,
@@ -433,6 +469,7 @@ def start_webui(redis_address, node_ip_address, backend_stdout_file=None,
return True
def start_local_scheduler(redis_address,
node_ip_address,
plasma_store_name,
@@ -472,8 +509,8 @@ def start_local_scheduler(redis_address,
The name of the local scheduler socket.
"""
if num_cpus is None:
# By default, use the number of hardware execution threads for the number of
# cores.
# By default, use the number of hardware execution threads for the number
# of cores.
num_cpus = multiprocessing.cpu_count()
if num_gpus is None:
# By default, assume this node has no GPUs.
@@ -496,6 +533,7 @@ def start_local_scheduler(redis_address,
[stdout_file, stderr_file])
return local_scheduler_name
def start_objstore(node_ip_address, redis_address, object_manager_port=None,
store_stdout_file=None, store_stderr_file=None,
manager_stdout_file=None, manager_stderr_file=None,
@@ -511,10 +549,10 @@ def start_objstore(node_ip_address, redis_address, object_manager_port=None,
If no redirection should happen, then this should be None.
store_stderr_file: A file handle opened for writing to redirect stderr to.
If no redirection should happen, then this should be None.
manager_stdout_file: A file handle opened for writing to redirect stdout to.
If no redirection should happen, then this should be None.
manager_stderr_file: A file handle opened for writing to redirect stderr to.
If no redirection should happen, then this should be None.
manager_stdout_file: A file handle opened for writing to redirect stdout
to. If no redirection should happen, then this should be None.
manager_stderr_file: A file handle opened for writing to redirect stderr
to. If no redirection should happen, then this should be None.
cleanup (bool): True if using Ray in local mode. If cleanup is true, then
this process will be killed by services.cleanup() when the Python process
that imported services exits.
@@ -522,8 +560,8 @@ def start_objstore(node_ip_address, redis_address, object_manager_port=None,
with.
Return:
A tuple of the Plasma store socket name, the Plasma manager socket name, and
the plasma manager port.
A tuple of the Plasma store socket name, the Plasma manager socket name,
and the plasma manager port.
"""
if objstore_memory is None:
# Compute a fraction of the system memory for the Plasma store to use.
@@ -559,7 +597,8 @@ def start_objstore(node_ip_address, redis_address, object_manager_port=None,
stderr_file=store_stderr_file)
# Start the plasma manager.
if object_manager_port is not None:
plasma_manager_name, p2, plasma_manager_port = ray.plasma.start_plasma_manager(
(plasma_manager_name, p2,
plasma_manager_port) = ray.plasma.start_plasma_manager(
plasma_store_name,
redis_address,
plasma_manager_port=object_manager_port,
@@ -570,7 +609,8 @@ def start_objstore(node_ip_address, redis_address, object_manager_port=None,
stderr_file=manager_stderr_file)
assert plasma_manager_port == object_manager_port
else:
plasma_manager_name, p2, plasma_manager_port = ray.plasma.start_plasma_manager(
(plasma_manager_name, p2,
plasma_manager_port) = ray.plasma.start_plasma_manager(
plasma_store_name,
redis_address,
node_ip_address=node_ip_address,
@@ -587,6 +627,7 @@ def start_objstore(node_ip_address, redis_address, object_manager_port=None,
return ObjectStoreAddress(plasma_store_name, plasma_manager_name,
plasma_manager_port)
def start_worker(node_ip_address, object_store_name, object_store_manager_name,
local_scheduler_name, redis_address, worker_path,
stdout_file=None, stderr_file=None, cleanup=True):
@@ -599,8 +640,8 @@ def start_worker(node_ip_address, object_store_name, object_store_manager_name,
object_store_manager_name (str): The name of the object store manager.
local_scheduler_name (str): The name of the local scheduler.
redis_address (str): The address that the Redis server is listening on.
worker_path (str): The path of the source code which the worker process will
run.
worker_path (str): The path of the source code which the worker process
will run.
stdout_file: A file handle opened for writing to redirect stdout to. If no
redirection should happen, then this should be None.
stderr_file: A file handle opened for writing to redirect stderr to. If no
@@ -622,6 +663,7 @@ def start_worker(node_ip_address, object_store_name, object_store_manager_name,
record_log_files_in_redis(redis_address, node_ip_address,
[stdout_file, stderr_file])
def start_monitor(redis_address, node_ip_address, stdout_file=None,
stderr_file=None, cleanup=True):
"""Run a process to monitor the other processes.
@@ -637,7 +679,8 @@ def start_monitor(redis_address, node_ip_address, stdout_file=None,
this process will be killed by services.cleanup() when the Python process
that imported services exits. This is True by default.
"""
monitor_path= os.path.join(os.path.dirname(os.path.abspath(__file__)), "monitor.py")
monitor_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
"monitor.py")
command = ["python",
monitor_path,
"--redis-address=" + str(redis_address)]
@@ -647,6 +690,7 @@ def start_monitor(redis_address, node_ip_address, stdout_file=None,
record_log_files_in_redis(redis_address, node_ip_address,
[stdout_file, stderr_file])
def start_ray_processes(address_info=None,
node_ip_address="127.0.0.1",
num_workers=0,
@@ -685,8 +729,9 @@ def start_ray_processes(address_info=None,
start a global scheduler process.
include_redis (bool): If include_redis is True, then start a Redis server
process.
include_log_monitor (bool): If True, then start a log monitor to monitor the
log files for all processes on this node and push their contents to Redis.
include_log_monitor (bool): If True, then start a log monitor to monitor
the log files for all processes on this node and push their contents to
Redis.
include_webui (bool): If True, then attempt to start the web UI. Note that
this is only possible with Python 3.
start_workers_from_local_scheduler (bool): If this flag is True, then start
@@ -713,7 +758,8 @@ def start_ray_processes(address_info=None,
address_info["node_ip_address"] = node_ip_address
if worker_path is None:
worker_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "workers/default_worker.py")
worker_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
"workers/default_worker.py")
# Start Redis if there isn't already an instance running. TODO(rkn): We are
# suppressing the output of Redis because on Linux it prints a bunch of
@@ -721,7 +767,8 @@ def start_ray_processes(address_info=None,
# should address the warnings.
redis_address = address_info.get("redis_address")
if include_redis:
redis_stdout_file, redis_stderr_file = new_log_files("redis", redirect_output)
redis_stdout_file, redis_stderr_file = new_log_files("redis",
redirect_output)
if redis_address is None:
# Start a Redis server. The start_redis method will choose a random port.
redis_port, _ = start_redis(node_ip_address,
@@ -735,7 +782,6 @@ def start_ray_processes(address_info=None,
# A Redis address was provided, so start a Redis server with the given
# port. TODO(rkn): We should check that the IP address corresponds to the
# machine that this method is running on.
redis_ip_address = get_ip_address(redis_address)
redis_port = get_port(redis_address)
new_redis_port, _ = start_redis(port=int(redis_port),
num_retries=1,
@@ -744,7 +790,8 @@ def start_ray_processes(address_info=None,
cleanup=cleanup)
assert redis_port == new_redis_port
# Start monitoring the processes.
monitor_stdout_file, monitor_stderr_file = new_log_files("monitor", redirect_output)
monitor_stdout_file, monitor_stderr_file = new_log_files("monitor",
redirect_output)
start_monitor(redis_address,
node_ip_address,
stdout_file=monitor_stdout_file,
@@ -755,8 +802,8 @@ def start_ray_processes(address_info=None,
# Start the log monitor, if necessary.
if include_log_monitor:
log_monitor_stdout_file, log_monitor_stderr_file = new_log_files("log_monitor",
redirect_output=True)
log_monitor_stdout_file, log_monitor_stderr_file = new_log_files(
"log_monitor", redirect_output=True)
start_log_monitor(redis_address,
node_ip_address,
stdout_file=log_monitor_stdout_file,
@@ -765,7 +812,8 @@ def start_ray_processes(address_info=None,
# Start the global scheduler, if necessary.
if include_global_scheduler:
global_scheduler_stdout_file, global_scheduler_stderr_file = new_log_files("global_scheduler", redirect_output)
global_scheduler_stdout_file, global_scheduler_stderr_file = new_log_files(
"global_scheduler", redirect_output)
start_global_scheduler(redis_address,
node_ip_address,
stdout_file=global_scheduler_stdout_file,
@@ -781,7 +829,8 @@ def start_ray_processes(address_info=None,
local_scheduler_socket_names = address_info["local_scheduler_socket_names"]
# Get the ports to use for the object managers if any are provided.
object_manager_ports = address_info["object_manager_ports"] if "object_manager_ports" in address_info else None
object_manager_ports = (address_info["object_manager_ports"]
if "object_manager_ports" in address_info else None)
if not isinstance(object_manager_ports, list):
object_manager_ports = num_local_schedulers * [object_manager_ports]
assert len(object_manager_ports) == num_local_schedulers
@@ -789,23 +838,26 @@ def start_ray_processes(address_info=None,
# Start any object stores that do not yet exist.
for i in range(num_local_schedulers - len(object_store_addresses)):
# Start Plasma.
plasma_store_stdout_file, plasma_store_stderr_file = new_log_files("plasma_store_{}".format(i), redirect_output)
plasma_manager_stdout_file, plasma_manager_stderr_file = new_log_files("plasma_manager_{}".format(i), redirect_output)
object_store_address = start_objstore(node_ip_address,
redis_address,
object_manager_port=object_manager_ports[i],
store_stdout_file=plasma_store_stdout_file,
store_stderr_file=plasma_store_stderr_file,
manager_stdout_file=plasma_manager_stdout_file,
manager_stderr_file=plasma_manager_stderr_file,
cleanup=cleanup)
plasma_store_stdout_file, plasma_store_stderr_file = new_log_files(
"plasma_store_{}".format(i), redirect_output)
plasma_manager_stdout_file, plasma_manager_stderr_file = new_log_files(
"plasma_manager_{}".format(i), redirect_output)
object_store_address = start_objstore(
node_ip_address,
redis_address,
object_manager_port=object_manager_ports[i],
store_stdout_file=plasma_store_stdout_file,
store_stderr_file=plasma_store_stderr_file,
manager_stdout_file=plasma_manager_stdout_file,
manager_stderr_file=plasma_manager_stderr_file,
cleanup=cleanup)
object_store_addresses.append(object_store_address)
time.sleep(0.1)
# Determine how many workers to start for each local scheduler.
num_workers_per_local_scheduler = [0] * num_local_schedulers
workers_per_local_scheduler = [0] * num_local_schedulers
for i in range(num_workers):
num_workers_per_local_scheduler[i % num_local_schedulers] += 1
workers_per_local_scheduler[i % num_local_schedulers] += 1
# Start any local schedulers that do not yet exist.
for i in range(len(local_scheduler_socket_names), num_local_schedulers):
@@ -815,26 +867,28 @@ def start_ray_processes(address_info=None,
object_store_address.manager_port)
# Determine how many workers this local scheduler should start.
if start_workers_from_local_scheduler:
num_local_scheduler_workers = num_workers_per_local_scheduler[i]
num_workers_per_local_scheduler[i] = 0
num_local_scheduler_workers = workers_per_local_scheduler[i]
workers_per_local_scheduler[i] = 0
else:
# If we're starting the workers from Python, the local scheduler should
# not start any workers.
num_local_scheduler_workers = 0
# Start the local scheduler.
local_scheduler_stdout_file, local_scheduler_stderr_file = new_log_files("local_scheduler_{}".format(i), redirect_output)
local_scheduler_name = start_local_scheduler(redis_address,
node_ip_address,
object_store_address.name,
object_store_address.manager_name,
worker_path,
plasma_address=plasma_address,
stdout_file=local_scheduler_stdout_file,
stderr_file=local_scheduler_stderr_file,
cleanup=cleanup,
num_cpus=num_cpus[i],
num_gpus=num_gpus[i],
num_workers=num_local_scheduler_workers)
local_scheduler_stdout_file, local_scheduler_stderr_file = new_log_files(
"local_scheduler_{}".format(i), redirect_output)
local_scheduler_name = start_local_scheduler(
redis_address,
node_ip_address,
object_store_address.name,
object_store_address.manager_name,
worker_path,
plasma_address=plasma_address,
stdout_file=local_scheduler_stdout_file,
stderr_file=local_scheduler_stderr_file,
cleanup=cleanup,
num_cpus=num_cpus[i],
num_gpus=num_gpus[i],
num_workers=num_local_scheduler_workers)
local_scheduler_socket_names.append(local_scheduler_name)
time.sleep(0.1)
@@ -844,11 +898,12 @@ def start_ray_processes(address_info=None,
assert len(local_scheduler_socket_names) == num_local_schedulers
# Start any workers that the local scheduler has not already started.
for i, num_local_scheduler_workers in enumerate(num_workers_per_local_scheduler):
for i, num_local_scheduler_workers in enumerate(workers_per_local_scheduler):
object_store_address = object_store_addresses[i]
local_scheduler_name = local_scheduler_socket_names[i]
for j in range(num_local_scheduler_workers):
worker_stdout_file, worker_stderr_file = new_log_files("worker_{}_{}".format(i, j), redirect_output)
worker_stdout_file, worker_stderr_file = new_log_files(
"worker_{}_{}".format(i, j), redirect_output)
start_worker(node_ip_address,
object_store_address.name,
object_store_address.manager_name,
@@ -858,17 +913,17 @@ def start_ray_processes(address_info=None,
stdout_file=worker_stdout_file,
stderr_file=worker_stderr_file,
cleanup=cleanup)
num_workers_per_local_scheduler[i] -= 1
workers_per_local_scheduler[i] -= 1
# Make sure that we've started all the workers.
assert(sum(num_workers_per_local_scheduler) == 0)
assert(sum(workers_per_local_scheduler) == 0)
# Try to start the web UI.
if include_webui:
backend_stdout_file, backend_stderr_file = new_log_files("webui_backend",
redirect_output=True)
polymer_stdout_file, polymer_stderr_file = new_log_files("webui_polymer",
redirect_output=True)
backend_stdout_file, backend_stderr_file = new_log_files(
"webui_backend", redirect_output=True)
polymer_stdout_file, polymer_stderr_file = new_log_files(
"webui_polymer", redirect_output=True)
successfully_started = start_webui(redis_address,
node_ip_address,
backend_stdout_file=backend_stdout_file,
@@ -883,6 +938,7 @@ def start_ray_processes(address_info=None,
# Return the addresses of the relevant processes.
return address_info
def start_ray_node(node_ip_address,
redis_address,
object_manager_ports=None,
@@ -905,8 +961,8 @@ def start_ray_node(node_ip_address,
managers. There should be one per object manager being started on this
node (typically just one).
num_workers (int): The number of workers to start.
num_local_schedulers (int): The number of local schedulers to start. This is
also the number of plasma stores and plasma managers to start.
num_local_schedulers (int): The number of local schedulers to start. This
is also the number of plasma stores and plasma managers to start.
worker_path (str): The path of the source code that will be run by the
worker.
cleanup (bool): If cleanup is true, then the processes started here will be
@@ -919,10 +975,8 @@ def start_ray_node(node_ip_address,
A dictionary of the address information for the processes that were
started.
"""
address_info = {
"redis_address": redis_address,
"object_manager_ports": object_manager_ports,
}
address_info = {"redis_address": redis_address,
"object_manager_ports": object_manager_ports}
return start_ray_processes(address_info=address_info,
node_ip_address=node_ip_address,
num_workers=num_workers,
@@ -934,6 +988,7 @@ def start_ray_node(node_ip_address,
num_cpus=num_cpus,
num_gpus=num_gpus)
def start_ray_head(address_info=None,
node_ip_address="127.0.0.1",
num_workers=0,
@@ -974,33 +1029,37 @@ def start_ray_head(address_info=None,
A dictionary of the address information for the processes that were
started.
"""
return start_ray_processes(address_info=address_info,
node_ip_address=node_ip_address,
num_workers=num_workers,
num_local_schedulers=num_local_schedulers,
worker_path=worker_path,
cleanup=cleanup,
redirect_output=redirect_output,
include_global_scheduler=True,
include_log_monitor=True,
include_redis=True,
include_webui=True,
start_workers_from_local_scheduler=start_workers_from_local_scheduler,
num_cpus=num_cpus,
num_gpus=num_gpus)
return start_ray_processes(
address_info=address_info,
node_ip_address=node_ip_address,
num_workers=num_workers,
num_local_schedulers=num_local_schedulers,
worker_path=worker_path,
cleanup=cleanup,
redirect_output=redirect_output,
include_global_scheduler=True,
include_log_monitor=True,
include_redis=True,
include_webui=True,
start_workers_from_local_scheduler=start_workers_from_local_scheduler,
num_cpus=num_cpus,
num_gpus=num_gpus)
def new_log_files(name, redirect_output):
"""Generate partially randomized filenames for log files.
Args:
name (str): descriptive string for this log file.
redirect_output (bool): True if files should be generated for logging stdout
and stderr and false if stdout and stderr should not be redirected.
redirect_output (bool): True if files should be generated for logging
stdout and stderr and false if stdout and stderr should not be
redirected.
Returns:
If redirect_output is true, this will return a tuple of two filehandles. The
first is for redirecting stdout and the second is for redirecting stderr.
If redirect_output is false, this will return a tuple of two None objects.
If redirect_output is true, this will return a tuple of two filehandles.
The first is for redirecting stdout and the second is for redirecting
stderr. If redirect_output is false, this will return a tuple of two None
objects.
"""
if not redirect_output:
return None, None
+17
View File
@@ -8,44 +8,53 @@ import numpy as np
# Test simple functionality
@ray.remote(num_return_vals=2)
def handle_int(a, b):
  """Increment each argument by one.

  Args:
    a: First value; must support ``+ 1``.
    b: Second value; must support ``+ 1``.

  Returns:
    The pair ``(a + 1, b + 1)``. The decorator is configured with
    num_return_vals=2, matching the two values returned here.
  """
  incremented_a = a + 1
  incremented_b = b + 1
  return incremented_a, incremented_b
# Test timing
@ray.remote
def empty_function():
  """Do no work and return None (listed under the timing tests)."""
  return None
@ray.remote
def trivial_function():
  """Return the constant 1 (listed under the timing tests)."""
  result = 1
  return result
# Test keyword arguments
@ray.remote
def keyword_fct1(a, b="hello"):
  """Format both arguments into one space-separated string.

  Args:
    a: Value placed first.
    b: Value placed second; defaults to "hello".

  Returns:
    The string produced by formatting ``a`` and ``b`` separated by a space.
  """
  template = "{} {}"
  return template.format(a, b)
@ray.remote
def keyword_fct2(a="hello", b="world"):
  """Format both arguments into one space-separated string.

  Args:
    a: Value placed first; defaults to "hello".
    b: Value placed second; defaults to "world".

  Returns:
    The string produced by formatting ``a`` and ``b`` separated by a space.
  """
  template = "{} {}"
  return template.format(a, b)
@ray.remote
def keyword_fct3(a, b, c="hello", d="world"):
  """Format all four arguments into one space-separated string.

  Args:
    a: Value placed first.
    b: Value placed second.
    c: Value placed third; defaults to "hello".
    d: Value placed fourth; defaults to "world".

  Returns:
    The string produced by formatting the four values separated by spaces.
  """
  template = "{} {} {} {}"
  return template.format(a, b, c, d)
# Test variable numbers of arguments
@ray.remote
def varargs_fct1(*a):
  """Join the string forms of all positional arguments with single spaces.

  Args:
    *a: Any number of values; each is converted with ``str``.

  Returns:
    The space-joined string (empty when no arguments are given).
  """
  pieces = [str(item) for item in a]
  return " ".join(pieces)
@ray.remote
def varargs_fct2(a, *b):
  """Join the string forms of the extra positional arguments with spaces.

  Args:
    a: Required first argument; it is not used in the result.
    *b: Any number of values; each is converted with ``str``.

  Returns:
    The space-joined string built from ``b`` only.
  """
  pieces = [str(item) for item in b]
  return " ".join(pieces)
try:
@ray.remote
def kwargs_throw_exception(**c):
@@ -64,24 +73,29 @@ except:
# test throwing an exception
@ray.remote
def throw_exception_fct1():
  """Always fail, for the exception-throwing tests.

  Raises:
    Exception: Unconditionally, with a fixed message.
  """
  message = "Test function 1 intentionally failed."
  raise Exception(message)
@ray.remote
def throw_exception_fct2():
  """Always fail, for the exception-throwing tests.

  Raises:
    Exception: Unconditionally, with a fixed message.
  """
  message = "Test function 2 intentionally failed."
  raise Exception(message)
@ray.remote(num_return_vals=3)
def throw_exception_fct3(x):
  """Always fail, even though the decorator declares three return values.

  Args:
    x: Accepted but never used.

  Raises:
    Exception: Unconditionally, with a fixed message.
  """
  message = "Test function 3 intentionally failed."
  raise Exception(message)
# test Python mode
@ray.remote
def python_mode_f():
  """Return a newly created two-element NumPy array holding zeros."""
  values = [0, 0]
  return np.array(values)
@ray.remote
def python_mode_g(x):
x[0] = 1
@@ -89,14 +103,17 @@ def python_mode_g(x):
# test no return values
@ray.remote
def no_op():
  """Do nothing and return None (listed under the no-return-values test)."""
  return None
class TestClass(object):
  """Minimal user-defined class carrying a single integer attribute."""

  def __init__(self):
    # The only instance state is ``a``, fixed at 5.
    initial_value = 5
    self.a = initial_value
@ray.remote
def test_unknown_type():
  """Return a freshly constructed TestClass instance."""
  instance = TestClass()
  return instance
+439 -272
View File
File diff suppressed because it is too large Load Diff
+31 -17
View File
@@ -3,25 +3,33 @@ from __future__ import division
from __future__ import print_function
import argparse
import binascii
import numpy as np
import redis
import traceback
import sys
import binascii
import ray
parser = argparse.ArgumentParser(description="Parse addresses for the worker to connect to.")
parser.add_argument("--node-ip-address", required=True, type=str, help="the ip address of the worker's node")
parser.add_argument("--redis-address", required=True, type=str, help="the address to use for Redis")
parser.add_argument("--object-store-name", required=True, type=str, help="the object store's name")
parser.add_argument("--object-store-manager-name", required=True, type=str, help="the object store manager's name")
parser.add_argument("--local-scheduler-name", required=True, type=str, help="the local scheduler's name")
parser.add_argument("--actor-id", required=False, type=str, help="the actor ID of this worker")
parser = argparse.ArgumentParser(description=("Parse addresses for the worker "
"to connect to."))
parser.add_argument("--node-ip-address", required=True, type=str,
help="the ip address of the worker's node")
parser.add_argument("--redis-address", required=True, type=str,
help="the address to use for Redis")
parser.add_argument("--object-store-name", required=True, type=str,
help="the object store's name")
parser.add_argument("--object-store-manager-name", required=True, type=str,
help="the object store manager's name")
parser.add_argument("--local-scheduler-name", required=True, type=str,
help="the local scheduler's name")
parser.add_argument("--actor-id", required=False, type=str,
help="the actor ID of this worker")
def random_string():
  """Return 20 random bytes drawn via ``np.random.bytes``.

  Returns:
    A ``bytes`` object of length 20.
  """
  num_bytes = 20
  return np.random.bytes(num_bytes)
if __name__ == "__main__":
args = parser.parse_args()
info = {"node_ip_address": args.node_ip_address,
@@ -30,7 +38,10 @@ if __name__ == "__main__":
"manager_socket_name": args.object_store_manager_name,
"local_scheduler_socket_name": args.local_scheduler_name}
actor_id = binascii.unhexlify(args.actor_id) if not args.actor_id is None else ray.worker.NIL_ACTOR_ID
if args.actor_id is not None:
actor_id = binascii.unhexlify(args.actor_id)
else:
actor_id = ray.worker.NIL_ACTOR_ID
ray.worker.connect(info, mode=ray.WORKER_MODE, actor_id=actor_id)
@@ -43,24 +54,27 @@ being caught in "python/ray/workers/default_worker.py".
while True:
try:
# This call to main_loop should never return if things are working. Most
# exceptions that are thrown (e.g., inside the execution of a task) should
# be caught and handled inside of the call to main_loop. If an exception
# is thrown here, then that means that there is some error that we didn't
# anticipate.
# exceptions that are thrown (e.g., inside the execution of a task)
# should be caught and handled inside of the call to main_loop. If an
# exception is thrown here, then that means that there is some error that
# we didn't anticipate.
ray.worker.main_loop()
except Exception as e:
traceback_str = traceback.format_exc() + error_explanation
DRIVER_ID_LENGTH = 20
# We use a driver ID of all zeros to push an error message to all drivers.
# We use a driver ID of all zeros to push an error message to all
# drivers.
driver_id = DRIVER_ID_LENGTH * b"\x00"
error_key = b"Error:" + driver_id + b":" + random_string()
redis_ip_address, redis_port = args.redis_address.split(":")
# For this command to work, some other client (on the same machine as
# Redis) must have run "CONFIG SET protected-mode no".
redis_client = redis.StrictRedis(host=redis_ip_address, port=int(redis_port))
redis_client = redis.StrictRedis(host=redis_ip_address,
port=int(redis_port))
redis_client.hmset(error_key, {"type": "worker_crash",
"message": traceback_str,
"note": "This error is unexpected and should not have happened."})
"note": ("This error is unexpected and "
"should not have happened.")})
redis_client.rpush("ErrorKeys", error_key)
# TODO(rkn): Note that if the worker was in the middle of executing a
# task, then any worker or driver that is blocking in a get call and
+16 -11
View File
@@ -2,12 +2,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import subprocess
from setuptools import setup, find_packages
import setuptools.command.install as _install
class install(_install.install):
def run(self):
subprocess.check_call(["../build.sh"])
@@ -16,19 +16,24 @@ class install(_install.install):
# setuptools. So, calling do_egg_install() manually here.
self.do_egg_install()
package_data = {
"ray": ["core/src/common/thirdparty/redis/src/redis-server",
"core/src/common/redis_module/libray_redis_module.so",
"core/src/plasma/plasma_store",
"core/src/plasma/plasma_manager",
"core/src/plasma/libplasma.so",
"core/src/local_scheduler/local_scheduler",
"core/src/local_scheduler/liblocal_scheduler_library.so",
"core/src/numbuf/libarrow.so",
"core/src/numbuf/libnumbuf.so",
"core/src/global_scheduler/global_scheduler"]
}
setup(name="ray",
version="0.0.1",
packages=find_packages(),
package_data={"ray": ["core/src/common/thirdparty/redis/src/redis-server",
"core/src/common/redis_module/libray_redis_module.so",
"core/src/plasma/plasma_store",
"core/src/plasma/plasma_manager",
"core/src/plasma/libplasma.so",
"core/src/local_scheduler/local_scheduler",
"core/src/local_scheduler/liblocal_scheduler_library.so",
"core/src/numbuf/libarrow.so",
"core/src/numbuf/libnumbuf.so",
"core/src/global_scheduler/global_scheduler"]},
package_data=package_data,
cmdclass={"install": install},
install_requires=["numpy",
"funcsigs",
+42 -26
View File
@@ -7,15 +7,25 @@ import redis
import ray.services as services
parser = argparse.ArgumentParser(description="Parse addresses for the worker to connect to.")
parser.add_argument("--node-ip-address", required=False, type=str, help="the IP address of the worker's node")
parser.add_argument("--redis-address", required=False, type=str, help="the address to use for connecting to Redis")
parser.add_argument("--redis-port", required=False, type=str, help="the port to use for starting Redis")
parser.add_argument("--object-manager-port", required=False, type=int, help="the port to use for starting the object manager")
parser.add_argument("--num-workers", default=10, required=False, type=int, help="the number of workers to start on this node")
parser.add_argument("--num-cpus", required=False, type=int, help="the number of CPUs on this node")
parser.add_argument("--num-gpus", required=False, type=int, help="the number of GPUs on this node")
parser.add_argument("--head", action="store_true", help="provide this argument for the head node")
parser = argparse.ArgumentParser(
description="Parse addresses for the worker to connect to.")
parser.add_argument("--node-ip-address", required=False, type=str,
help="the IP address of the worker's node")
parser.add_argument("--redis-address", required=False, type=str,
help="the address to use for connecting to Redis")
parser.add_argument("--redis-port", required=False, type=str,
help="the port to use for starting Redis")
parser.add_argument("--object-manager-port", required=False, type=int,
help="the port to use for starting the object manager")
parser.add_argument("--num-workers", default=10, required=False, type=int,
help="the number of workers to start on this node")
parser.add_argument("--num-cpus", required=False, type=int,
help="the number of CPUs on this node")
parser.add_argument("--num-gpus", required=False, type=int,
help="the number of GPUs on this node")
parser.add_argument("--head", action="store_true",
help="provide this argument for the head node")
def check_no_existing_redis_clients(node_ip_address, redis_address):
redis_ip_address, redis_port = redis_address.split(":")
@@ -39,7 +49,9 @@ def check_no_existing_redis_clients(node_ip_address, redis_address):
continue
if info[b"node_ip_address"].decode("ascii") == node_ip_address:
raise Exception("This Redis instance is already connected to clients with this IP address.")
raise Exception("This Redis instance is already connected to clients "
"with this IP address.")
if __name__ == "__main__":
args = parser.parse_args()
@@ -52,7 +64,8 @@ if __name__ == "__main__":
if args.head:
# Start Ray on the head node.
if args.redis_address is not None:
raise Exception("If --head is passed in, a Redis server will be started, so a Redis address should not be provided.")
raise Exception("If --head is passed in, a Redis server will be "
"started, so a Redis address should not be provided.")
# Get the node IP address if one is not provided.
if args.node_ip_address is None:
@@ -82,25 +95,27 @@ if __name__ == "__main__":
print(address_info)
print("\nStarted Ray with {} workers on this node. A different number of "
"workers can be set with the --num-workers flag (but you have to "
"first terminate the existing cluster). You can add additional nodes "
"to the cluster by calling\n\n"
"first terminate the existing cluster). You can add additional "
"nodes to the cluster by calling\n\n"
" ./scripts/start_ray.sh --redis-address {}\n\n"
"from the node you wish to add. You can connect a driver to the "
"cluster from Python by running\n\n"
" import ray\n"
" ray.init(redis_address=\"{}\")\n\n"
"If you have trouble connecting from a different machine, check that "
"your firewall is configured properly. If you wish to terminate the "
"processes that have been started, run\n\n"
"If you have trouble connecting from a different machine, check "
"that your firewall is configured properly. If you wish to "
"terminate the processes that have been started, run\n\n"
" ./scripts/stop_ray.sh".format(args.num_workers,
address_info["redis_address"],
address_info["redis_address"]))
else:
# Start Ray on a non-head node.
if args.redis_port is not None:
raise Exception("If --head is not passed in, --redis-port is not allowed")
raise Exception("If --head is not passed in, --redis-port is not "
"allowed")
if args.redis_address is None:
raise Exception("If --head is not passed in, --redis-address must be provided.")
raise Exception("If --head is not passed in, --redis-address must be "
"provided.")
redis_ip_address, redis_port = args.redis_address.split(":")
# Wait for the Redis server to be started, and throw an exception if we
# can't connect to it.
@@ -115,14 +130,15 @@ if __name__ == "__main__":
# connected with this Redis instance. This raises an exception if the Redis
# server already has clients on this node.
check_no_existing_redis_clients(node_ip_address, args.redis_address)
address_info = services.start_ray_node(node_ip_address=node_ip_address,
redis_address=args.redis_address,
object_manager_ports=[args.object_manager_port],
num_workers=args.num_workers,
cleanup=False,
redirect_output=True,
num_cpus=args.num_cpus,
num_gpus=args.num_gpus)
address_info = services.start_ray_node(
node_ip_address=node_ip_address,
redis_address=args.redis_address,
object_manager_ports=[args.object_manager_port],
num_workers=args.num_workers,
cleanup=False,
redirect_output=True,
num_cpus=args.num_cpus,
num_gpus=args.num_gpus)
print(address_info)
print("\nStarted {} workers on this node. A different number of workers "
"can be set with the --num-workers flag (but you have to first "
-29
View File
@@ -1,29 +0,0 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# See https://github.com/ray-project/ray/issues/131.
helpful_message = """
If you are using Anaconda, try fixing this problem by running:
conda install libgcc
"""
# Import everything from the compiled libnumbuf extension. If the import fails
# with an ImportError whose text mentions GLIBCXX, attach helpful_message (the
# Anaconda "conda install libgcc" hint) to the exception before re-raising it.
try:
from .libnumbuf import *
except ImportError as e:
if hasattr(e, "msg") and isinstance(e.msg, str) and "GLIBCXX" in e.msg:
# This code path should be taken with Python 3.
e.msg += helpful_message
elif hasattr(e, "message") and isinstance(e.message, str) and "GLIBCXX" in e.message:
# This code path should be taken with Python 2.
# When args is a one-element tuple holding a string, extend that string so
# the hint becomes part of the existing message.
if hasattr(e, "args") and isinstance(e.args, tuple) and len(e.args) == 1 and isinstance(e.args[0], str):
e.args = (e.args[0] + helpful_message,)
else:
# Otherwise coerce args into a tuple and append the hint as an extra
# element.
if not hasattr(e, "args"):
e.args = ()
elif not isinstance(e.args, tuple):
e.args = (e.args,)
e.args += (helpful_message,)
# Re-raise the ImportError so the import failure is still reported.
raise
+10 -5
View File
@@ -9,19 +9,21 @@ from numpy.testing import assert_equal
import os
import sys
TEST_OBJECTS = [{(1,2) : 1}, {() : 2}, [1, "hello", 3.0], 42, 43, "hello world",
TEST_OBJECTS = [{(1, 2): 1}, {(): 2}, [1, "hello", 3.0], 42, 43,
"hello world",
u"x", u"\u262F", 42.0,
1 << 62, (1.0, "hi"),
None, (None, None), ("hello", None),
True, False, (True, False), "hello",
{True: "hello", False: "world"},
{"hello" : "world", 1: 42, 1.0: 45}, {},
{"hello": "world", 1: 42, 2.5: 45}, {},
np.int8(3), np.int32(4), np.int64(5),
np.uint8(3), np.uint32(4), np.uint64(5),
np.float32(1.0), np.float64(1.0)]
if sys.version_info < (3, 0):
TEST_OBJECTS += [long(42), long(1 << 62)]
TEST_OBJECTS += [long(42), long(1 << 62)] # noqa: F821
class SerializationTests(unittest.TestCase):
@@ -47,14 +49,16 @@ class SerializationTests(unittest.TestCase):
self.roundTripTest([{"hello": [1, 2, 3]}])
self.roundTripTest([{"hello": [1, [2, 3]]}])
self.roundTripTest([{"hello": (None, 2, [3, 4])}])
self.roundTripTest([{"hello": (None, 2, [3, 4], np.array([1.0, 2.0, 3.0]))}])
self.roundTripTest(
[{"hello": (None, 2, [3, 4], np.array([1.0, 2.0, 3.0]))}])
def numpyTest(self, t):
a = np.random.randint(0, 10, size=(100, 100)).astype(t)
self.roundTripTest([a])
def testArrays(self):
for t in ["int8", "uint8", "int16", "uint16", "int32", "uint32", "float32", "float64"]:
for t in ["int8", "uint8", "int16", "uint16", "int32", "uint32", "float32",
"float64"]:
self.numpyTest(t)
def testRay(self):
@@ -165,5 +169,6 @@ class SerializationTests(unittest.TestCase):
print("Not running testArrowLimits on Travis because of the test's "
"memory requirements.")
if __name__ == "__main__":
unittest.main(verbosity=2)
+5 -3
View File
@@ -7,11 +7,13 @@ import setuptools.command.install as _install
import subprocess
class install(_install.install):
def run(self):
subprocess.check_call(["make"])
subprocess.check_call(["cp", "build/plasma_store", "plasma/plasma_store"])
subprocess.check_call(["cp", "build/plasma_manager", "plasma/plasma_manager"])
subprocess.check_call(["cp", "build/plasma_manager",
"plasma/plasma_manager"])
subprocess.check_call(["cmake", ".."], cwd="./build")
subprocess.check_call(["make", "install"], cwd="./build")
# Calling _install.install.run(self) does not fetch required packages and
@@ -19,14 +21,14 @@ class install(_install.install):
# setuptools. So, calling do_egg_install() manually here.
self.do_egg_install()
setup(name="Plasma",
version="0.0.1",
description="Plasma client for Python",
packages=find_packages(),
package_data={"plasma": ["plasma_store",
"plasma_manager",
"libplasma.so"],
},
"libplasma.so"]},
cmdclass={"install": install},
include_package_data=True,
zip_safe=False)
+105 -46
View File
@@ -2,11 +2,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import unittest
import numpy as np
import time
import unittest
import ray
class ActorAPI(unittest.TestCase):
def testKeywordArgs(self):
@@ -18,6 +19,7 @@ class ActorAPI(unittest.TestCase):
self.arg0 = arg0
self.arg1 = arg1
self.arg2 = arg2
def get_values(self, arg0, arg1=2, arg2="b"):
return self.arg0 + arg0, self.arg1 + arg1, self.arg2 + arg2
@@ -53,6 +55,7 @@ class ActorAPI(unittest.TestCase):
self.arg0 = arg0
self.arg1 = arg1
self.args = args
def get_values(self, arg0, arg1=2, *args):
return self.arg0 + arg0, self.arg1 + arg1, self.args, args
@@ -63,10 +66,12 @@ class ActorAPI(unittest.TestCase):
self.assertEqual(ray.get(actor.get_values(2, 3)), (3, 5, (), ()))
actor = Actor(1, 2, "c")
self.assertEqual(ray.get(actor.get_values(2, 3, "d")), (3, 5, ("c",), ("d",)))
self.assertEqual(ray.get(actor.get_values(2, 3, "d")),
(3, 5, ("c",), ("d",)))
actor = Actor(1, 2, "a", "b", "c", "d")
self.assertEqual(ray.get(actor.get_values(2, 3, 1, 2, 3, 4)), (3, 5, ("a", "b", "c", "d"), (1, 2, 3, 4)))
self.assertEqual(ray.get(actor.get_values(2, 3, 1, 2, 3, 4)),
(3, 5, ("a", "b", "c", "d"), (1, 2, 3, 4)))
ray.worker.cleanup()
@@ -77,6 +82,7 @@ class ActorAPI(unittest.TestCase):
class Actor(object):
def __init__(self):
pass
def get_values(self):
pass
@@ -112,8 +118,10 @@ class ActorAPI(unittest.TestCase):
def __init__(self, f2):
self.f1 = Foo(1)
self.f2 = f2
def get_values1(self):
return self.f1, self.f2
def get_values2(self, f3):
return self.f1, self.f2, f3
@@ -144,38 +152,39 @@ class ActorAPI(unittest.TestCase):
# This is an invalid way of using the actor decorator.
with self.assertRaises(Exception):
@ray.actor(invalid_kwarg=0)
@ray.actor(invalid_kwarg=0) # noqa: F811
class Actor(object):
def __init__(self):
pass
# This is an invalid way of using the actor decorator.
with self.assertRaises(Exception):
@ray.actor(num_cpus=0, invalid_kwarg=0)
@ray.actor(num_cpus=0, invalid_kwarg=0) # noqa: F811
class Actor(object):
def __init__(self):
pass
# This is a valid way of using the decorator.
@ray.actor(num_cpus=1)
@ray.actor(num_cpus=1) # noqa: F811
class Actor(object):
def __init__(self):
pass
# This is a valid way of using the decorator.
@ray.actor(num_gpus=1)
@ray.actor(num_gpus=1) # noqa: F811
class Actor(object):
def __init__(self):
pass
# This is a valid way of using the decorator.
@ray.actor(num_cpus=1, num_gpus=1)
@ray.actor(num_cpus=1, num_gpus=1) # noqa: F811
class Actor(object):
def __init__(self):
pass
ray.worker.cleanup()
class ActorMethods(unittest.TestCase):
def testDefineActor(self):
@@ -185,6 +194,7 @@ class ActorMethods(unittest.TestCase):
class Test(object):
def __init__(self, x):
self.x = x
def f(self, y):
return self.x + y
@@ -200,8 +210,10 @@ class ActorMethods(unittest.TestCase):
class Counter(object):
def __init__(self):
self.value = 0
def increase(self):
self.value += 1
def value(self):
return self.value
@@ -224,9 +236,11 @@ class ActorMethods(unittest.TestCase):
class Counter(object):
def __init__(self, value):
self.value = value
def increase(self):
self.value += 1
return self.value
def reset(self):
self.value = 0
@@ -240,7 +254,9 @@ class ActorMethods(unittest.TestCase):
results += [actors[i].increase() for _ in range(num_increases)]
result_values = ray.get(results)
for i in range(num_actors):
self.assertEqual(result_values[(num_increases * i):(num_increases * (i + 1))], list(range(i + 1, num_increases + i + 1)))
self.assertEqual(
result_values[(num_increases * i):(num_increases * (i + 1))],
list(range(i + 1, num_increases + i + 1)))
# Reset the actor values.
[actor.reset() for actor in actors]
@@ -251,10 +267,12 @@ class ActorMethods(unittest.TestCase):
results += [actor.increase() for actor in actors]
result_values = ray.get(results)
for j in range(num_increases):
self.assertEqual(result_values[(num_actors * j):(num_actors * (j + 1))], num_actors * [j + 1])
self.assertEqual(result_values[(num_actors * j):(num_actors * (j + 1))],
num_actors * [j + 1])
ray.worker.cleanup()
class ActorNesting(unittest.TestCase):
def testRemoteFunctionWithinActor(self):
@@ -302,7 +320,8 @@ class ActorNesting(unittest.TestCase):
self.assertEqual(ray.get(ray.get(actor.f())), list(range(1, 6)))
self.assertEqual(ray.get(actor.g()), list(range(1, 6)))
self.assertEqual(ray.get(actor.h([f.remote(i) for i in range(5)])), list(range(1, 6)))
self.assertEqual(ray.get(actor.h([f.remote(i) for i in range(5)])),
list(range(1, 6)))
ray.worker.cleanup()
@@ -320,6 +339,7 @@ class ActorNesting(unittest.TestCase):
class Actor2(object):
def __init__(self, x):
self.x = x
def get_value(self):
return self.x
self.actor2 = Actor2(z)
@@ -370,13 +390,15 @@ class ActorNesting(unittest.TestCase):
class Actor1(object):
def __init__(self, x):
self.x = x
def get_value(self):
return self.x
actor = Actor1(x)
return ray.get([actor.get_value() for _ in range(n)])
self.assertEqual(ray.get(f.remote(3, 1)), [3])
self.assertEqual(ray.get([f.remote(i, 20) for i in range(10)]), [20 * [i] for i in range(10)])
self.assertEqual(ray.get([f.remote(i, 20) for i in range(10)]),
[20 * [i] for i in range(10)])
ray.worker.cleanup()
@@ -421,8 +443,10 @@ class ActorNesting(unittest.TestCase):
def __init__(self):
# This should use the last version of f.
self.x = ray.get(f.remote())
def get_val(self):
return self.x
actor = Actor()
return ray.get(actor.get_val())
@@ -430,6 +454,7 @@ class ActorNesting(unittest.TestCase):
ray.worker.cleanup()
class ActorInheritance(unittest.TestCase):
def testInheritActorFromClass(self):
@@ -440,8 +465,10 @@ class ActorInheritance(unittest.TestCase):
class Foo(object):
def __init__(self, x):
self.x = x
def f(self):
return self.x
def g(self, y):
return self.x + y
@@ -449,6 +476,7 @@ class ActorInheritance(unittest.TestCase):
class Actor(Foo):
def __init__(self, x):
Foo.__init__(self, x)
def get_value(self):
return self.f()
@@ -458,6 +486,7 @@ class ActorInheritance(unittest.TestCase):
ray.worker.cleanup()
class ActorSchedulingProperties(unittest.TestCase):
def testRemoteFunctionsNotScheduledOnActors(self):
@@ -469,7 +498,7 @@ class ActorSchedulingProperties(unittest.TestCase):
def __init__(self):
pass
actor = Actor()
Actor()
@ray.remote
def f():
@@ -477,22 +506,26 @@ class ActorSchedulingProperties(unittest.TestCase):
# Make sure that f cannot be scheduled on the worker created for the actor.
# The wait call should time out.
ready_ids, remaining_ids = ray.wait([f.remote() for _ in range(10)], timeout=3000)
ready_ids, remaining_ids = ray.wait([f.remote() for _ in range(10)],
timeout=3000)
self.assertEqual(ready_ids, [])
self.assertEqual(len(remaining_ids), 10)
ray.worker.cleanup()
class ActorsOnMultipleNodes(unittest.TestCase):
def testActorLoadBalancing(self):
num_local_schedulers = 3
ray.worker._init(start_ray_local=True, num_workers=0, num_local_schedulers=num_local_schedulers)
ray.worker._init(start_ray_local=True, num_workers=0,
num_local_schedulers=num_local_schedulers)
@ray.actor
class Actor1(object):
def __init__(self):
pass
def get_location(self):
return ray.worker.global_worker.plasma_client.store_socket_name
@@ -509,7 +542,8 @@ class ActorsOnMultipleNodes(unittest.TestCase):
names = set(locations)
counts = [locations.count(name) for name in names]
print("Counts are {}.".format(counts))
if len(names) == num_local_schedulers and all([count >= minimum_count for count in counts]):
if len(names) == num_local_schedulers and all([count >= minimum_count
for count in counts]):
break
attempts += 1
self.assertLess(attempts, num_attempts)
@@ -523,26 +557,32 @@ class ActorsOnMultipleNodes(unittest.TestCase):
ray.worker.cleanup()
class ActorsWithGPUs(unittest.TestCase):
def testActorGPUs(self):
num_local_schedulers = 3
num_gpus_per_scheduler = 4
ray.worker._init(start_ray_local=True, num_workers=0,
num_local_schedulers=num_local_schedulers,
num_gpus=(num_local_schedulers * [num_gpus_per_scheduler]))
ray.worker._init(
start_ray_local=True, num_workers=0,
num_local_schedulers=num_local_schedulers,
num_gpus=(num_local_schedulers * [num_gpus_per_scheduler]))
@ray.actor(num_gpus=1)
class Actor1(object):
def __init__(self):
self.gpu_ids = ray.get_gpu_ids()
def get_location_and_ids(self):
return ray.worker.global_worker.plasma_client.store_socket_name, tuple(self.gpu_ids)
return (ray.worker.global_worker.plasma_client.store_socket_name,
tuple(self.gpu_ids))
# Create one actor per GPU.
actors = [Actor1() for _ in range(num_local_schedulers * num_gpus_per_scheduler)]
actors = [Actor1() for _
in range(num_local_schedulers * num_gpus_per_scheduler)]
# Make sure that no two actors are assigned to the same GPU.
locations_and_ids = ray.get([actor.get_location_and_ids() for actor in actors])
locations_and_ids = ray.get([actor.get_location_and_ids()
for actor in actors])
node_names = set([location for location, gpu_id in locations_and_ids])
self.assertEqual(len(node_names), num_local_schedulers)
location_actor_combinations = []
@@ -553,28 +593,32 @@ class ActorsWithGPUs(unittest.TestCase):
# Creating a new actor should fail because all of the GPUs are being used.
with self.assertRaises(Exception):
a = Actor1()
Actor1()
ray.worker.cleanup()
def testActorMultipleGPUs(self):
num_local_schedulers = 3
num_gpus_per_scheduler = 5
ray.worker._init(start_ray_local=True, num_workers=0,
num_local_schedulers=num_local_schedulers,
num_gpus=(num_local_schedulers * [num_gpus_per_scheduler]))
ray.worker._init(
start_ray_local=True, num_workers=0,
num_local_schedulers=num_local_schedulers,
num_gpus=(num_local_schedulers * [num_gpus_per_scheduler]))
@ray.actor(num_gpus=2)
class Actor1(object):
def __init__(self):
self.gpu_ids = ray.get_gpu_ids()
def get_location_and_ids(self):
return ray.worker.global_worker.plasma_client.store_socket_name, tuple(self.gpu_ids)
return (ray.worker.global_worker.plasma_client.store_socket_name,
tuple(self.gpu_ids))
# Create some actors.
actors = [Actor1() for _ in range(num_local_schedulers * 2)]
# Make sure that no two actors are assigned to the same GPU.
locations_and_ids = ray.get([actor.get_location_and_ids() for actor in actors])
locations_and_ids = ray.get([actor.get_location_and_ids()
for actor in actors])
node_names = set([location for location, gpu_id in locations_and_ids])
self.assertEqual(len(node_names), num_local_schedulers)
location_actor_combinations = []
@@ -585,20 +629,23 @@ class ActorsWithGPUs(unittest.TestCase):
# Creating a new actor should fail because all of the GPUs are being used.
with self.assertRaises(Exception):
a = Actor1()
Actor1()
# We should be able to create more actors that use only a single GPU.
@ray.actor(num_gpus=1)
class Actor2(object):
def __init__(self):
self.gpu_ids = ray.get_gpu_ids()
def get_location_and_ids(self):
return ray.worker.global_worker.plasma_client.store_socket_name, tuple(self.gpu_ids)
return (ray.worker.global_worker.plasma_client.store_socket_name,
tuple(self.gpu_ids))
# Create some actors.
actors = [Actor2() for _ in range(num_local_schedulers)]
# Make sure that no two actors are assigned to the same GPU.
locations_and_ids = ray.get([actor.get_location_and_ids() for actor in actors])
locations_and_ids = ray.get([actor.get_location_and_ids()
for actor in actors])
node_names = set([location for location, gpu_id in locations_and_ids])
self.assertEqual(len(node_names), num_local_schedulers)
location_actor_combinations = []
@@ -608,13 +655,13 @@ class ActorsWithGPUs(unittest.TestCase):
# Creating a new actor should fail because all of the GPUs are being used.
with self.assertRaises(Exception):
a = Actor2()
Actor2()
ray.worker.cleanup()
def testActorDifferentNumbersOfGPUs(self):
# Test that we can create actors on two nodes that have different numbers of
# GPUs.
# Test that we can create actors on two nodes that have different numbers
# of GPUs.
ray.worker._init(start_ray_local=True, num_workers=0,
num_local_schedulers=3, num_gpus=[0, 5, 10])
@@ -622,32 +669,38 @@ class ActorsWithGPUs(unittest.TestCase):
class Actor1(object):
def __init__(self):
self.gpu_ids = ray.get_gpu_ids()
def get_location_and_ids(self):
return ray.worker.global_worker.plasma_client.store_socket_name, tuple(self.gpu_ids)
return (ray.worker.global_worker.plasma_client.store_socket_name,
tuple(self.gpu_ids))
# Create some actors.
actors = [Actor1() for _ in range(0 + 5 + 10)]
# Make sure that no two actors are assigned to the same GPU.
locations_and_ids = ray.get([actor.get_location_and_ids() for actor in actors])
locations_and_ids = ray.get([actor.get_location_and_ids()
for actor in actors])
node_names = set([location for location, gpu_id in locations_and_ids])
self.assertEqual(len(node_names), 2)
for node_name in node_names:
node_gpu_ids = [gpu_id for location, gpu_id in locations_and_ids if location == node_name]
node_gpu_ids = [gpu_id for location, gpu_id in locations_and_ids
if location == node_name]
self.assertIn(len(node_gpu_ids), [5, 10])
self.assertEqual(set(node_gpu_ids), set([(i,) for i in range(len(node_gpu_ids))]))
self.assertEqual(set(node_gpu_ids),
set([(i,) for i in range(len(node_gpu_ids))]))
# Creating a new actor should fail because all of the GPUs are being used.
with self.assertRaises(Exception):
a = Actor1()
Actor1()
ray.worker.cleanup()
def testActorMultipleGPUsFromMultipleTasks(self):
num_local_schedulers = 10
num_gpus_per_scheduler = 10
ray.worker._init(start_ray_local=True, num_workers=0,
num_local_schedulers=num_local_schedulers,
num_gpus=(num_local_schedulers * [num_gpus_per_scheduler]))
ray.worker._init(
start_ray_local=True, num_workers=0,
num_local_schedulers=num_local_schedulers,
num_gpus=(num_local_schedulers * [num_gpus_per_scheduler]))
@ray.remote
def create_actors(n):
@@ -655,20 +708,25 @@ class ActorsWithGPUs(unittest.TestCase):
class Actor(object):
def __init__(self):
self.gpu_ids = ray.get_gpu_ids()
def get_location_and_ids(self):
return ray.worker.global_worker.plasma_client.store_socket_name, tuple(self.gpu_ids)
return (ray.worker.global_worker.plasma_client.store_socket_name,
tuple(self.gpu_ids))
# Create n actors.
for _ in range(n):
Actor()
ray.get([create_actors.remote(num_gpus_per_scheduler) for _ in range(num_local_schedulers)])
ray.get([create_actors.remote(num_gpus_per_scheduler)
for _ in range(num_local_schedulers)])
@ray.actor(num_gpus=1)
class Actor(object):
def __init__(self):
self.gpu_ids = ray.get_gpu_ids()
def get_location_and_ids(self):
return ray.worker.global_worker.plasma_client.store_socket_name, tuple(self.gpu_ids)
return (ray.worker.global_worker.plasma_client.store_socket_name,
tuple(self.gpu_ids))
# All the GPUs should be used up now.
with self.assertRaises(Exception):
@@ -676,5 +734,6 @@ class ActorsWithGPUs(unittest.TestCase):
ray.worker.cleanup()
if __name__ == "__main__":
unittest.main(verbosity=2)
+55 -27
View File
@@ -5,20 +5,21 @@ from __future__ import print_function
import unittest
import ray
import numpy as np
import time
from numpy.testing import assert_equal, assert_almost_equal
import sys
if sys.version_info >= (3, 0):
from importlib import reload
import ray.experimental.array.remote as ra
import ray.experimental.array.distributed as da
if sys.version_info >= (3, 0):
from importlib import reload
class RemoteArrayTest(unittest.TestCase):
def testMethods(self):
for module in [ra.core, ra.random, ra.linalg, da.core, da.random, da.linalg]:
for module in [ra.core, ra.random, ra.linalg, da.core, da.random,
da.linalg]:
reload(module)
ray.init(num_workers=1)
@@ -49,24 +50,30 @@ class RemoteArrayTest(unittest.TestCase):
ray.worker.cleanup()
class DistributedArrayTest(unittest.TestCase):
def testAssemble(self):
  """Assemble a two-block DistArray and compare it with the NumPy result."""
  # Reload the array modules before initializing Ray.
  modules = [ra.core, ra.random, ra.linalg, da.core, da.random, da.linalg]
  for module in modules:
    reload(module)
  ray.init(num_workers=1)

  # Build a distributed array made of a block of ones stacked on top of a
  # block of zeros.
  a = ra.ones.remote([da.BLOCK_SIZE, da.BLOCK_SIZE])
  b = ra.zeros.remote([da.BLOCK_SIZE, da.BLOCK_SIZE])
  x = da.DistArray([2 * da.BLOCK_SIZE, da.BLOCK_SIZE], np.array([[a], [b]]))

  # The assembled array should match the same stacking done locally.
  expected = np.vstack([np.ones([da.BLOCK_SIZE, da.BLOCK_SIZE]),
                        np.zeros([da.BLOCK_SIZE, da.BLOCK_SIZE])])
  assert_equal(x.assemble(), expected)

  ray.worker.cleanup()
def testMethods(self):
for module in [ra.core, ra.random, ra.linalg, da.core, da.random, da.linalg]:
for module in [ra.core, ra.random, ra.linalg, da.core, da.random,
da.linalg]:
reload(module)
ray.worker._init(start_ray_local=True, num_workers=10, num_local_schedulers=2, num_cpus=[10, 10])
ray.worker._init(start_ray_local=True, num_workers=10,
num_local_schedulers=2, num_cpus=[10, 10])
x = da.zeros.remote([9, 25, 51], "float")
assert_equal(ray.get(da.assemble.remote(x)), np.zeros([9, 25, 51]))
@@ -76,18 +83,21 @@ class DistributedArrayTest(unittest.TestCase):
x = da.random.normal.remote([11, 25, 49])
y = da.copy.remote(x)
assert_equal(ray.get(da.assemble.remote(x)), ray.get(da.assemble.remote(y)))
assert_equal(ray.get(da.assemble.remote(x)),
ray.get(da.assemble.remote(y)))
x = da.eye.remote(25, dtype_name="float")
assert_equal(ray.get(da.assemble.remote(x)), np.eye(25))
x = da.random.normal.remote([25, 49])
y = da.triu.remote(x)
assert_equal(ray.get(da.assemble.remote(y)), np.triu(ray.get(da.assemble.remote(x))))
assert_equal(ray.get(da.assemble.remote(y)),
np.triu(ray.get(da.assemble.remote(x))))
x = da.random.normal.remote([25, 49])
y = da.tril.remote(x)
assert_equal(ray.get(da.assemble.remote(y)), np.tril(ray.get(da.assemble.remote(x))))
assert_equal(ray.get(da.assemble.remote(y)),
np.tril(ray.get(da.assemble.remote(x))))
x = da.random.normal.remote([25, 49])
y = da.random.normal.remote([49, 18])
@@ -102,29 +112,37 @@ class DistributedArrayTest(unittest.TestCase):
x = da.random.normal.remote([23, 42])
y = da.random.normal.remote([23, 42])
z = da.add.remote(x, y)
assert_almost_equal(ray.get(da.assemble.remote(z)), ray.get(da.assemble.remote(x)) + ray.get(da.assemble.remote(y)))
assert_almost_equal(ray.get(da.assemble.remote(z)),
ray.get(da.assemble.remote(x)) +
ray.get(da.assemble.remote(y)))
# test subtract
x = da.random.normal.remote([33, 40])
y = da.random.normal.remote([33, 40])
z = da.subtract.remote(x, y)
assert_almost_equal(ray.get(da.assemble.remote(z)), ray.get(da.assemble.remote(x)) - ray.get(da.assemble.remote(y)))
assert_almost_equal(ray.get(da.assemble.remote(z)),
ray.get(da.assemble.remote(x)) -
ray.get(da.assemble.remote(y)))
# test transpose
x = da.random.normal.remote([234, 432])
y = da.transpose.remote(x)
assert_equal(ray.get(da.assemble.remote(x)).T, ray.get(da.assemble.remote(y)))
assert_equal(ray.get(da.assemble.remote(x)).T,
ray.get(da.assemble.remote(y)))
# test numpy_to_dist
x = da.random.normal.remote([23, 45])
y = da.assemble.remote(x)
z = da.numpy_to_dist.remote(y)
w = da.assemble.remote(z)
assert_equal(ray.get(da.assemble.remote(x)), ray.get(da.assemble.remote(z)))
assert_equal(ray.get(da.assemble.remote(x)),
ray.get(da.assemble.remote(z)))
assert_equal(ray.get(y), ray.get(w))
# test da.tsqr
for shape in [[123, da.BLOCK_SIZE], [7, da.BLOCK_SIZE], [da.BLOCK_SIZE, da.BLOCK_SIZE], [da.BLOCK_SIZE, 7], [10 * da.BLOCK_SIZE, da.BLOCK_SIZE]]:
for shape in [[123, da.BLOCK_SIZE], [7, da.BLOCK_SIZE],
[da.BLOCK_SIZE, da.BLOCK_SIZE], [da.BLOCK_SIZE, 7],
[10 * da.BLOCK_SIZE, da.BLOCK_SIZE]]:
x = da.random.normal.remote(shape)
K = min(shape)
q, r = da.linalg.tsqr.remote(x)
@@ -138,23 +156,26 @@ class DistributedArrayTest(unittest.TestCase):
# test da.linalg.modified_lu
def test_modified_lu(d1, d2):
  """Check da.linalg.modified_lu on the Q factor of a random d1 x d2 matrix.

  Asserts that q - s = l * u (with s placed on the diagonal of a d1 x d2
  matrix), that u is upper triangular, and that l is lower triangular.

  Args:
    d1: Number of rows of the random matrix. Must satisfy d1 >= d2.
    d2: Number of columns of the random matrix.
  """
  print("testing dist_modified_lu with d1 = " + str(d1) +
        ", d2 = " + str(d2))
  assert d1 >= d2
  m = ra.random.normal.remote([d1, d2])
  q, r = ra.linalg.qr.remote(m)
  l_id, u_id, s_id = da.linalg.modified_lu.remote(da.numpy_to_dist.remote(q))
  q_val = ray.get(q)
  # The R factor is fetched but its value is not needed for the checks below.
  ray.get(r)
  l_val = ray.get(da.assemble.remote(l_id))
  u_val = ray.get(u_id)
  s_val = ray.get(s_id)
  # Place the entries of s on the diagonal of a d1 x d2 matrix.
  s_mat = np.zeros((d1, d2))
  for i, s_i in enumerate(s_val):
    s_mat[i, i] = s_i
  # Check that q - s = l * u.
  assert_almost_equal(q_val - s_mat, np.dot(l_val, u_val))
  # Check that u is upper triangular.
  assert_equal(np.triu(u_val), u_val)
  # Check that l is lower triangular.
  assert_equal(np.tril(l_val), l_val)
for d1, d2 in [(100, 100), (99, 98), (7, 5), (7, 7), (20, 7), (20, 10)]:
test_modified_lu(d1, d2)
@@ -172,10 +193,14 @@ class DistributedArrayTest(unittest.TestCase):
tall_eye = np.zeros((d1, min(d1, d2)))
np.fill_diagonal(tall_eye, 1)
q = tall_eye - np.dot(y_val, np.dot(t_val, y_top_val.T))
assert_almost_equal(np.dot(q.T, q), np.eye(min(d1, d2))) # check that q.T * q = I
assert_almost_equal(np.dot(q, r_val), a_val) # check that a = (I - y * t * y_top.T) * r
# Check that q.T * q = I.
assert_almost_equal(np.dot(q.T, q), np.eye(min(d1, d2)))
# Check that a = (I - y * t * y_top.T) * r.
assert_almost_equal(np.dot(q, r_val), a_val)
for d1, d2 in [(123, da.BLOCK_SIZE), (7, da.BLOCK_SIZE), (da.BLOCK_SIZE, da.BLOCK_SIZE), (da.BLOCK_SIZE, 7), (10 * da.BLOCK_SIZE, da.BLOCK_SIZE)]:
for d1, d2 in [(123, da.BLOCK_SIZE), (7, da.BLOCK_SIZE),
(da.BLOCK_SIZE, da.BLOCK_SIZE), (da.BLOCK_SIZE, 7),
(10 * da.BLOCK_SIZE, da.BLOCK_SIZE)]:
test_dist_tsqr_hr(d1, d2)
def test_dist_qr(d1, d2):
@@ -192,7 +217,9 @@ class DistributedArrayTest(unittest.TestCase):
assert_equal(r_val, np.triu(r_val))
assert_almost_equal(a_val, np.dot(q_val, r_val))
for d1, d2 in [(123, da.BLOCK_SIZE), (7, da.BLOCK_SIZE), (da.BLOCK_SIZE, da.BLOCK_SIZE), (da.BLOCK_SIZE, 7), (13, 21), (34, 35), (8, 7)]:
for d1, d2 in [(123, da.BLOCK_SIZE), (7, da.BLOCK_SIZE),
(da.BLOCK_SIZE, da.BLOCK_SIZE), (da.BLOCK_SIZE, 7),
(13, 21), (34, 35), (8, 7)]:
test_dist_qr(d1, d2)
test_dist_qr(d2, d1)
for _ in range(20):
@@ -202,5 +229,6 @@ class DistributedArrayTest(unittest.TestCase):
ray.worker.cleanup()
if __name__ == "__main__":
unittest.main(verbosity=2)
+46 -32
View File
@@ -3,10 +3,10 @@ from __future__ import division
from __future__ import print_function
import ray
import sys
import time
import unittest
class ComponentFailureTest(unittest.TestCase):
def tearDown(self):
@@ -16,6 +16,7 @@ class ComponentFailureTest(unittest.TestCase):
# store and manager will not die.
def testDyingWorkerGet(self):
obj_id = 20 * b"a"
@ray.remote
def f():
ray.worker.global_worker.plasma_client.get(obj_id)
@@ -40,12 +41,14 @@ class ComponentFailureTest(unittest.TestCase):
time.sleep(0.1)
# Make sure that nothing has died.
self.assertTrue(ray.services.all_processes_alive(exclude=[ray.services.PROCESS_TYPE_WORKER]))
self.assertTrue(ray.services.all_processes_alive(
exclude=[ray.services.PROCESS_TYPE_WORKER]))
# This test checks that when a worker dies in the middle of a wait, the plasma
# store and manager will not die.
# This test checks that when a worker dies in the middle of a wait, the
# plasma store and manager will not die.
def testDyingWorkerWait(self):
obj_id = 20 * b"a"
@ray.remote
def f():
ray.worker.global_worker.plasma_client.wait([obj_id])
@@ -70,7 +73,8 @@ class ComponentFailureTest(unittest.TestCase):
time.sleep(0.1)
# Make sure that nothing has died.
self.assertTrue(ray.services.all_processes_alive(exclude=[ray.services.PROCESS_TYPE_WORKER]))
self.assertTrue(ray.services.all_processes_alive(
exclude=[ray.services.PROCESS_TYPE_WORKER]))
def _testWorkerFailed(self, num_local_schedulers):
@ray.remote
@@ -86,7 +90,8 @@ class ComponentFailureTest(unittest.TestCase):
num_cpus=[num_initial_workers] * num_local_schedulers)
# Submit more tasks than there are workers so that all workers and cores
# are utilized.
object_ids = [f.remote(i) for i in range(num_initial_workers * num_local_schedulers)]
object_ids = [f.remote(i) for i
in range(num_initial_workers * num_local_schedulers)]
object_ids += [f.remote(object_id) for object_id in object_ids]
# Allow the tasks some time to begin executing.
time.sleep(0.1)
@@ -94,7 +99,8 @@ class ComponentFailureTest(unittest.TestCase):
for worker in ray.services.all_processes[ray.services.PROCESS_TYPE_WORKER]:
worker.terminate()
time.sleep(0.1)
# Make sure that we can still get the objects after the executing tasks died.
# Make sure that we can still get the objects after the executing tasks
# died.
ray.get(object_ids)
def testWorkerFailed(self):
@@ -104,8 +110,7 @@ class ComponentFailureTest(unittest.TestCase):
self._testWorkerFailed(4)
def _testComponentFailed(self, component_type):
"""Kill a component on all worker nodes and check that workload succeeds.
"""
"""Kill a component on all worker nodes and check workload succeeds."""
@ray.remote
def f(x, j):
time.sleep(0.2)
@@ -114,14 +119,16 @@ class ComponentFailureTest(unittest.TestCase):
# Start 4 local schedulers, each with 8 workers and 8 cores.
num_local_schedulers = 4
num_workers_per_scheduler = 8
address_info = ray.worker._init(num_workers=num_local_schedulers * num_workers_per_scheduler,
num_local_schedulers=num_local_schedulers,
start_ray_local=True,
num_cpus=[num_workers_per_scheduler] * num_local_schedulers)
ray.worker._init(
num_workers=num_local_schedulers * num_workers_per_scheduler,
num_local_schedulers=num_local_schedulers,
start_ray_local=True,
num_cpus=[num_workers_per_scheduler] * num_local_schedulers)
# Submit more tasks than there are workers so that all workers and cores are
# utilized.
object_ids = [f.remote(i, 0) for i in range(num_workers_per_scheduler * num_local_schedulers)]
# Submit more tasks than there are workers so that all workers and cores
# are utilized.
object_ids = [f.remote(i, 0) for i
in range(num_workers_per_scheduler * num_local_schedulers)]
object_ids += [f.remote(object_id, 1) for object_id in object_ids]
object_ids += [f.remote(object_id, 2) for object_id in object_ids]
@@ -140,7 +147,8 @@ class ComponentFailureTest(unittest.TestCase):
# Make sure that we can still get the objects after the executing tasks
# died.
results = ray.get(object_ids)
expected_results = 4 * list(range(num_workers_per_scheduler * num_local_schedulers))
expected_results = 4 * list(range(
num_workers_per_scheduler * num_local_schedulers))
self.assertEqual(results, expected_results)
def check_components_alive(self, component_type, check_component_alive):
@@ -161,7 +169,8 @@ class ComponentFailureTest(unittest.TestCase):
# nodes.
self.check_components_alive(ray.services.PROCESS_TYPE_PLASMA_STORE, True)
self.check_components_alive(ray.services.PROCESS_TYPE_PLASMA_MANAGER, True)
self.check_components_alive(ray.services.PROCESS_TYPE_LOCAL_SCHEDULER, False)
self.check_components_alive(ray.services.PROCESS_TYPE_LOCAL_SCHEDULER,
False)
def testPlasmaManagerFailed(self):
# Kill all plasma managers on worker nodes.
@@ -170,8 +179,10 @@ class ComponentFailureTest(unittest.TestCase):
# The plasma stores should still be alive (but unreachable) on the worker
# nodes.
self.check_components_alive(ray.services.PROCESS_TYPE_PLASMA_STORE, True)
self.check_components_alive(ray.services.PROCESS_TYPE_PLASMA_MANAGER, False)
self.check_components_alive(ray.services.PROCESS_TYPE_LOCAL_SCHEDULER, False)
self.check_components_alive(ray.services.PROCESS_TYPE_PLASMA_MANAGER,
False)
self.check_components_alive(ray.services.PROCESS_TYPE_LOCAL_SCHEDULER,
False)
def testPlasmaStoreFailed(self):
# Kill all plasma stores on worker nodes.
@@ -179,17 +190,19 @@ class ComponentFailureTest(unittest.TestCase):
# No processes should be left alive on the worker nodes.
self.check_components_alive(ray.services.PROCESS_TYPE_PLASMA_STORE, False)
self.check_components_alive(ray.services.PROCESS_TYPE_PLASMA_MANAGER, False)
self.check_components_alive(ray.services.PROCESS_TYPE_LOCAL_SCHEDULER, False)
self.check_components_alive(ray.services.PROCESS_TYPE_PLASMA_MANAGER,
False)
self.check_components_alive(ray.services.PROCESS_TYPE_LOCAL_SCHEDULER,
False)
def testDriverLivesSequential(self):
ray.worker.init()
all_processes = ray.services.all_processes
processes = [
ray.services.all_processes[ray.services.PROCESS_TYPE_PLASMA_STORE][0],
ray.services.all_processes[ray.services.PROCESS_TYPE_PLASMA_MANAGER][0],
ray.services.all_processes[ray.services.PROCESS_TYPE_LOCAL_SCHEDULER][0],
ray.services.all_processes[ray.services.PROCESS_TYPE_GLOBAL_SCHEDULER][0],
]
all_processes[ray.services.PROCESS_TYPE_PLASMA_STORE][0],
all_processes[ray.services.PROCESS_TYPE_PLASMA_MANAGER][0],
all_processes[ray.services.PROCESS_TYPE_LOCAL_SCHEDULER][0],
all_processes[ray.services.PROCESS_TYPE_GLOBAL_SCHEDULER][0]]
# Kill all the components sequentially.
for process in processes:
@@ -202,12 +215,12 @@ class ComponentFailureTest(unittest.TestCase):
def testDriverLivesParallel(self):
ray.worker.init()
all_processes = ray.services.all_processes
processes = [
ray.services.all_processes[ray.services.PROCESS_TYPE_PLASMA_STORE][0],
ray.services.all_processes[ray.services.PROCESS_TYPE_PLASMA_MANAGER][0],
ray.services.all_processes[ray.services.PROCESS_TYPE_LOCAL_SCHEDULER][0],
ray.services.all_processes[ray.services.PROCESS_TYPE_GLOBAL_SCHEDULER][0],
]
all_processes[ray.services.PROCESS_TYPE_PLASMA_STORE][0],
all_processes[ray.services.PROCESS_TYPE_PLASMA_MANAGER][0],
all_processes[ray.services.PROCESS_TYPE_LOCAL_SCHEDULER][0],
all_processes[ray.services.PROCESS_TYPE_GLOBAL_SCHEDULER][0]]
# Kill all the components in parallel.
for process in processes:
@@ -222,5 +235,6 @@ class ComponentFailureTest(unittest.TestCase):
# If the driver can reach the tearDown method, then it is still alive.
if __name__ == "__main__":
unittest.main(verbosity=2)
+54 -24
View File
@@ -9,14 +9,16 @@ import tempfile
import time
import unittest
import ray.test.test_functions as test_functions
if sys.version_info >= (3, 0):
from importlib import reload
import ray.test.test_functions as test_functions
def relevant_errors(error_type):
  """Return the entries of ray.error_info() whose b"type" equals error_type.

  Args:
    error_type: The value to compare against each entry's b"type" field.

  Returns:
    A list of the matching entries, in the order ray.error_info() gave them.
  """
  matching = []
  for entry in ray.error_info():
    if entry[b"type"] == error_type:
      matching.append(entry)
  return matching
def wait_for_errors(error_type, num_errors, timeout=10):
start_time = time.time()
while time.time() - start_time < timeout:
@@ -25,6 +27,7 @@ def wait_for_errors(error_type, num_errors, timeout=10):
time.sleep(0.1)
print("Timing out of wait.")
class FailureTest(unittest.TestCase):
def testUnknownSerialization(self):
reload(test_functions)
@@ -32,32 +35,35 @@ class FailureTest(unittest.TestCase):
test_functions.test_unknown_type.remote()
wait_for_errors(b"task", 1)
error_info = ray.error_info()
self.assertEqual(len(relevant_errors(b"task")), 1)
ray.worker.cleanup()
class TaskSerializationTest(unittest.TestCase):
  """Tests for remote functions that return or receive an unknown type."""

  def testReturnAndPassUnknownType(self):
    """Returning or passing an instance of a local class should raise."""
    ray.init(num_workers=1, driver_mode=ray.SILENT_MODE)

    class Foo(object):
      pass

    # Check that returning an unknown type from a remote function raises an
    # exception.
    @ray.remote
    def f():
      return Foo()
    self.assertRaises(Exception, lambda: ray.get(f.remote()))

    # Check that passing an unknown type into a remote function raises an
    # exception.
    @ray.remote
    def g(x):
      return 1
    self.assertRaises(Exception, lambda: g.remote(Foo()))

    ray.worker.cleanup()
class TaskStatusTest(unittest.TestCase):
def testFailedTask(self):
reload(test_functions)
@@ -66,10 +72,10 @@ class TaskStatusTest(unittest.TestCase):
test_functions.throw_exception_fct1.remote()
test_functions.throw_exception_fct1.remote()
wait_for_errors(b"task", 2)
result = ray.error_info()
self.assertEqual(len(relevant_errors(b"task")), 2)
for task in relevant_errors(b"task"):
self.assertIn(b"Test function 1 intentionally failed.", task.get(b"message"))
self.assertIn(b"Test function 1 intentionally failed.",
task.get(b"message"))
x = test_functions.throw_exception_fct2.remote()
try:
@@ -77,7 +83,8 @@ class TaskStatusTest(unittest.TestCase):
except Exception as e:
self.assertIn("Test function 2 intentionally failed.", str(e))
else:
self.assertTrue(False) # ray.get should throw an exception
# ray.get should throw an exception.
self.assertTrue(False)
x, y, z = test_functions.throw_exception_fct3.remote(1.0)
for ref in [x, y, z]:
@@ -86,7 +93,8 @@ class TaskStatusTest(unittest.TestCase):
except Exception as e:
self.assertIn("Test function 3 intentionally failed.", str(e))
else:
self.assertTrue(False) # ray.get should throw an exception
# ray.get should throw an exception.
self.assertTrue(False)
ray.worker.cleanup()
@@ -108,8 +116,8 @@ def temporary_helper_function():
sys.path.append(directory)
module = __import__(module_name)
# Define a function that closes over this temporary module. This should fail
# when it is unpickled.
# Define a function that closes over this temporary module. This should
# fail when it is unpickled.
@ray.remote
def g():
return module.temporary_python_file()
@@ -121,7 +129,7 @@ def temporary_helper_function():
# Check that if we try to call the function it throws an exception and does
# not hang.
for _ in range(10):
self.assertRaises(Exception, lambda : ray.get(g.remote()))
self.assertRaises(Exception, lambda: ray.get(g.remote()))
f.close()
@@ -150,16 +158,19 @@ def temporary_helper_function():
def initializer():
return 0
def reinitializer(foo):
raise Exception("The reinitializer failed.")
ray.env.foo = ray.EnvironmentVariable(initializer, reinitializer)
@ray.remote
def use_foo():
ray.env.foo
use_foo.remote()
wait_for_errors(b"reinitialize_environment_variable", 1)
# Check that the error message is in the task info.
self.assertIn(b"The reinitializer failed.", ray.error_info()[0][b"message"])
self.assertIn(b"The reinitializer failed.",
ray.error_info()[0][b"message"])
ray.worker.cleanup()
@@ -202,6 +213,7 @@ def temporary_helper_function():
class Foo(object):
def __init__(self):
self.x = module.temporary_python_file()
def get_val(self):
return 1
@@ -217,7 +229,8 @@ def temporary_helper_function():
# Wait for the error from when the __init__ tries to run.
wait_for_errors(b"task", 1)
self.assertIn(b"failed to be imported, and so cannot execute this method", ray.error_info()[1][b"message"])
self.assertIn(b"failed to be imported, and so cannot execute this method",
ray.error_info()[1][b"message"])
# Check that if we try to get the function it throws an exception and does
# not hang.
@@ -226,7 +239,8 @@ def temporary_helper_function():
# Wait for the error from the call to get_val.
wait_for_errors(b"task", 2)
self.assertIn(b"failed to be imported, and so cannot execute this method", ray.error_info()[2][b"message"])
self.assertIn(b"failed to be imported, and so cannot execute this method",
ray.error_info()[2][b"message"])
f.close()
@@ -234,6 +248,7 @@ def temporary_helper_function():
sys.path.pop(-1)
ray.worker.cleanup()
class ActorTest(unittest.TestCase):
def testFailedActorInit(self):
@@ -241,12 +256,15 @@ class ActorTest(unittest.TestCase):
error_message1 = "actor constructor failed"
error_message2 = "actor method failed"
@ray.actor
class FailedActor(object):
def __init__(self):
raise Exception(error_message1)
def get_val(self):
return 1
def fail_method(self):
raise Exception(error_message2)
@@ -255,13 +273,15 @@ class ActorTest(unittest.TestCase):
# Make sure that we get errors from a failed constructor.
wait_for_errors(b"task", 1)
self.assertEqual(len(ray.error_info()), 1)
self.assertIn(error_message1, ray.error_info()[0][b"message"].decode("ascii"))
self.assertIn(error_message1,
ray.error_info()[0][b"message"].decode("ascii"))
# Make sure that we get errors from a failed method.
a.fail_method()
wait_for_errors(b"task", 2)
self.assertEqual(len(ray.error_info()), 2)
self.assertIn(error_message2, ray.error_info()[1][b"message"].decode("ascii"))
self.assertIn(error_message2,
ray.error_info()[1][b"message"].decode("ascii"))
ray.worker.cleanup()
@@ -272,6 +292,7 @@ class ActorTest(unittest.TestCase):
class Actor(object):
def __init__(self, missing_variable_name):
pass
def get_val(self, x):
pass
@@ -284,18 +305,22 @@ class ActorTest(unittest.TestCase):
wait_for_errors(b"task", 1)
self.assertEqual(len(ray.error_info()), 1)
if sys.version_info >= (3, 0):
self.assertIn("missing 1 required", ray.error_info()[0][b"message"].decode("ascii"))
self.assertIn("missing 1 required",
ray.error_info()[0][b"message"].decode("ascii"))
else:
self.assertIn("takes exactly 2 arguments", ray.error_info()[0][b"message"].decode("ascii"))
self.assertIn("takes exactly 2 arguments",
ray.error_info()[0][b"message"].decode("ascii"))
# Create an actor with too many arguments.
a = Actor(1, 2)
wait_for_errors(b"task", 2)
self.assertEqual(len(ray.error_info()), 2)
if sys.version_info >= (3, 0):
self.assertIn("but 3 were given", ray.error_info()[1][b"message"].decode("ascii"))
self.assertIn("but 3 were given",
ray.error_info()[1][b"message"].decode("ascii"))
else:
self.assertIn("takes exactly 2 arguments", ray.error_info()[1][b"message"].decode("ascii"))
self.assertIn("takes exactly 2 arguments",
ray.error_info()[1][b"message"].decode("ascii"))
# Create an actor with the correct number of arguments.
a = Actor(1)
@@ -305,23 +330,28 @@ class ActorTest(unittest.TestCase):
wait_for_errors(b"task", 3)
self.assertEqual(len(ray.error_info()), 3)
if sys.version_info >= (3, 0):
self.assertIn("missing 1 required", ray.error_info()[2][b"message"].decode("ascii"))
self.assertIn("missing 1 required",
ray.error_info()[2][b"message"].decode("ascii"))
else:
self.assertIn("takes exactly 2 arguments", ray.error_info()[2][b"message"].decode("ascii"))
self.assertIn("takes exactly 2 arguments",
ray.error_info()[2][b"message"].decode("ascii"))
# Call a method with too many arguments.
a.get_val(1, 2)
wait_for_errors(b"task", 4)
self.assertEqual(len(ray.error_info()), 4)
if sys.version_info >= (3, 0):
self.assertIn("but 3 were given", ray.error_info()[3][b"message"].decode("ascii"))
self.assertIn("but 3 were given",
ray.error_info()[3][b"message"].decode("ascii"))
else:
self.assertIn("takes exactly 2 arguments", ray.error_info()[3][b"message"].decode("ascii"))
self.assertIn("takes exactly 2 arguments",
ray.error_info()[3][b"message"].decode("ascii"))
# Call a method that doesn't exist.
with self.assertRaises(AttributeError):
a.nonexistent_method()
ray.worker.cleanup()
if __name__ == "__main__":
unittest.main(verbosity=2)
+26 -18
View File
@@ -7,7 +7,7 @@ import os
import re
import subprocess
import sys
import time
def wait_for_output(proc):
"""This is a convenience method to parse a process's stdout and stderr.
@@ -19,10 +19,13 @@ def wait_for_output(proc):
A tuple of the stdout and stderr of the process as strings.
"""
stdout_data, stderr_data = proc.communicate()
stdout_data = stdout_data.decode("ascii") if stdout_data is not None else None
stderr_data = stderr_data.decode("ascii") if stderr_data is not None else None
stdout_data = (stdout_data.decode("ascii") if stdout_data is not None
else None)
stderr_data = (stderr_data.decode("ascii") if stderr_data is not None
else None)
return stdout_data, stderr_data
class DockerRunner(object):
"""This class manages the logistics of running multiple nodes in Docker.
@@ -34,8 +37,8 @@ class DockerRunner(object):
head_container_id: The ID of the docker container that runs the head node.
worker_container_ids: A list of the docker container IDs of the Ray worker
nodes.
head_container_ip: The IP address of the docker container that runs the head
node.
head_container_ip: The IP address of the docker container that runs the
head node.
"""
def __init__(self):
"""Initialize the DockerRunner."""
@@ -47,8 +50,8 @@ class DockerRunner(object):
"""Parse the docker container ID from stdout_data.
Args:
stdout_data: This should be a string with the standard output of a call to
a docker command.
stdout_data: This should be a string with the standard output of a call
to a docker command.
Returns:
The container ID of the docker container.
@@ -70,7 +73,8 @@ class DockerRunner(object):
The IP address of the container.
"""
proc = subprocess.Popen(["docker", "inspect",
"--format={{.NetworkSettings.Networks.bridge.IPAddress}}",
"--format={{.NetworkSettings.Networks.bridge"
".IPAddress}}",
container_id],
stdout=subprocess.PIPE, stderr=subprocess.PIPE)
stdout_data, _ = wait_for_output(proc)
@@ -86,9 +90,10 @@ class DockerRunner(object):
"""Start the Ray head node inside a docker container."""
mem_arg = ["--memory=" + mem_size] if mem_size else []
shm_arg = ["--shm-size=" + shm_size] if shm_size else []
volume_arg = ["-v",
"{}:{}".format(os.path.dirname(os.path.realpath(__file__)),
"/ray/test/jenkins_tests")] if development_mode else []
volume_arg = (["-v",
"{}:{}".format(os.path.dirname(os.path.realpath(__file__)),
"/ray/test/jenkins_tests")]
if development_mode else [])
proc = subprocess.Popen(["docker", "run", "-d"] + mem_arg + shm_arg +
volume_arg +
[docker_image, "/ray/scripts/start_ray.sh",
@@ -113,7 +118,8 @@ class DockerRunner(object):
proc = subprocess.Popen(["docker", "run", "-d"] + mem_arg + shm_arg +
["--shm-size=" + shm_size, docker_image,
"/ray/scripts/start_ray.sh",
"--redis-address={:s}:6379".format(self.head_container_ip)],
"--redis-address={:s}:6379".format(
self.head_container_ip)],
stdout=subprocess.PIPE, stderr=subprocess.PIPE)
stdout_data, _ = wait_for_output(proc)
container_id = self._get_container_id(stdout_data)
@@ -136,10 +142,10 @@ class DockerRunner(object):
mem_size: The amount of memory to start each docker container with. This
will be passed into `docker run` as the --memory flag. If this is None,
then no --memory flag will be used.
shm_size: The amount of shared memory to start each docker container with.
This will be passed into `docker run` as the `--shm-size` flag.
num_nodes: The number of nodes to use in the cluster (this counts the head
node as well).
shm_size: The amount of shared memory to start each docker container
with. This will be passed into `docker run` as the `--shm-size` flag.
num_nodes: The number of nodes to use in the cluster (this counts the
head node as well).
development_mode: True if you want to mount the local copy of
test/jenkins_tests on the head node so we can avoid rebuilding docker
images during development.
@@ -163,7 +169,7 @@ class DockerRunner(object):
stdout=subprocess.PIPE, stderr=subprocess.PIPE)
stdout_data, _ = wait_for_output(proc)
removed_container_id = self._get_container_id(stdout_data)
if not container_id == stopped_container_id:
if not container_id == removed_container_id:
raise Exception("Failed to remove container {}.".format(container_id))
print("stop_node", {"container_id": container_id,
@@ -202,8 +208,10 @@ class DockerRunner(object):
print(stderr_data)
return {"success": proc.returncode == 0, "return_code": proc.returncode}
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run multinode tests in Docker.")
parser = argparse.ArgumentParser(
description="Run multinode tests in Docker.")
parser.add_argument("--docker-image", default="ray-project/deploy",
help="docker image")
parser.add_argument("--mem-size", help="memory size")
@@ -3,11 +3,13 @@ import time
import ray
@ray.remote
def f():
  """Pause briefly, then return this node's IP address.

  Returns:
    Whatever ray.services.get_node_ip_address() reports in the process that
      executes the task.
  """
  # NOTE(review): the short sleep presumably keeps each task busy long enough
  # for concurrently submitted tasks to spread across nodes -- confirm.
  time.sleep(0.1)
  return ray.services.get_node_ip_address()
if __name__ == "__main__":
ray.init(redis_address=os.environ["RAY_REDIS_ADDRESS"])
# Check that tasks are scheduled on all nodes.
+19 -12
View File
@@ -9,10 +9,11 @@ import sys
import time
import numpy as np
import ray.test.test_functions as test_functions
if sys.version_info >= (3, 0):
from importlib import reload
import ray.test.test_functions as test_functions
class MicroBenchmarkTest(unittest.TestCase):
@@ -20,7 +21,7 @@ class MicroBenchmarkTest(unittest.TestCase):
reload(test_functions)
ray.init(num_workers=3)
# measure the time required to submit a remote task to the scheduler
# Measure the time required to submit a remote task to the scheduler.
elapsed_times = []
for _ in range(1000):
start_time = time.time()
@@ -34,9 +35,10 @@ class MicroBenchmarkTest(unittest.TestCase):
print(" 90th percentile: {}".format(elapsed_times[900]))
print(" 99th percentile: {}".format(elapsed_times[990]))
print(" worst: {}".format(elapsed_times[999]))
# average_elapsed_time should be about 0.00038
# average_elapsed_time should be about 0.00038.
# measure the time required to submit a remote task to the scheduler (where the remote task returns one value)
# Measure the time required to submit a remote task to the scheduler
# (where the remote task returns one value).
elapsed_times = []
for _ in range(1000):
start_time = time.time()
@@ -50,9 +52,10 @@ class MicroBenchmarkTest(unittest.TestCase):
print(" 90th percentile: {}".format(elapsed_times[900]))
print(" 99th percentile: {}".format(elapsed_times[990]))
print(" worst: {}".format(elapsed_times[999]))
# average_elapsed_time should be about 0.001
# average_elapsed_time should be about 0.001.
# measure the time required to submit a remote task to the scheduler and get the result
# Measure the time required to submit a remote task to the scheduler and
# get the result.
elapsed_times = []
for _ in range(1000):
start_time = time.time()
@@ -62,14 +65,15 @@ class MicroBenchmarkTest(unittest.TestCase):
elapsed_times.append(end_time - start_time)
elapsed_times = np.sort(elapsed_times)
average_elapsed_time = sum(elapsed_times) / 1000
print("Time required to submit a trivial function call and get the result:")
print("Time required to submit a trivial function call and get the "
"result:")
print(" Average: {}".format(average_elapsed_time))
print(" 90th percentile: {}".format(elapsed_times[900]))
print(" 99th percentile: {}".format(elapsed_times[990]))
print(" worst: {}".format(elapsed_times[999]))
# average_elapsed_time should be about 0.0013
# average_elapsed_time should be about 0.0013.
# measure the time required to do do a put
# Measure the time required to do a put.
elapsed_times = []
for _ in range(1000):
start_time = time.time()
@@ -83,7 +87,7 @@ class MicroBenchmarkTest(unittest.TestCase):
print(" 90th percentile: {}".format(elapsed_times[900]))
print(" 99th percentile: {}".format(elapsed_times[990]))
print(" worst: {}".format(elapsed_times[999]))
# average_elapsed_time should be about 0.00087
# average_elapsed_time should be about 0.00087.
ray.worker.cleanup()
@@ -105,11 +109,14 @@ class MicroBenchmarkTest(unittest.TestCase):
if d > 1.5 * b:
if os.getenv("TRAVIS") is None:
raise Exception("The caching test was too slow. d = {}, b = {}".format(d, b))
raise Exception("The caching test was too slow. "
"d = {}, b = {}".format(d, b))
else:
print("WARNING: The caching test was too slow. d = {}, b = {}".format(d, b))
print("WARNING: The caching test was too slow. "
"d = {}, b = {}".format(d, b))
ray.worker.cleanup()
if __name__ == "__main__":
unittest.main(verbosity=2)
+16 -9
View File
@@ -2,17 +2,18 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import os
import unittest
import ray
import subprocess
import sys
import tempfile
import time
start_ray_script = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../scripts/start_ray.sh")
stop_ray_script = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../scripts/stop_ray.sh")
start_ray_script = os.path.join(os.path.dirname(os.path.abspath(__file__)),
"../scripts/start_ray.sh")
stop_ray_script = os.path.join(os.path.dirname(os.path.abspath(__file__)),
"../scripts/stop_ray.sh")
class MultiNodeTest(unittest.TestCase):
@@ -21,7 +22,8 @@ class MultiNodeTest(unittest.TestCase):
out = subprocess.check_output([start_ray_script, "--head"]).decode("ascii")
# Get the redis address from the output.
redis_substring_prefix = "redis_address=\""
redis_address_location = out.find(redis_substring_prefix) + len(redis_substring_prefix)
redis_address_location = (out.find(redis_substring_prefix) +
len(redis_substring_prefix))
redis_address = out[redis_address_location:]
self.redis_address = redis_address.split("\"")[0]
@@ -54,7 +56,8 @@ class MultiNodeTest(unittest.TestCase):
# Make sure we got the error.
self.assertEqual(len(ray.error_info()), 1)
self.assertIn(error_string1, ray.error_info()[0][b"message"].decode("ascii"))
self.assertIn(error_string1,
ray.error_info()[0][b"message"].decode("ascii"))
# Start another driver and make sure that it does not receive this error.
# Make the other driver throw an error, and make sure it receives that
@@ -98,7 +101,8 @@ print("success")
# Make sure that the other error message doesn't show up for this driver.
self.assertEqual(len(ray.error_info()), 1)
self.assertIn(error_string1, ray.error_info()[0][b"message"].decode("ascii"))
self.assertIn(error_string1,
ray.error_info()[0][b"message"].decode("ascii"))
ray.worker.cleanup()
@@ -149,6 +153,7 @@ print("success")
ray.worker.cleanup()
class StartRayScriptTest(unittest.TestCase):
def testCallingStartRayHead(self):
@@ -157,11 +162,12 @@ class StartRayScriptTest(unittest.TestCase):
# the non-head node code path.
# Test starting Ray with no arguments.
out = subprocess.check_output([start_ray_script, "--head"]).decode("ascii")
subprocess.check_output([start_ray_script, "--head"]).decode("ascii")
subprocess.Popen([stop_ray_script]).wait()
# Test starting Ray with a number of workers specified.
subprocess.check_output([start_ray_script, "--head", "--num-workers", "20"])
subprocess.check_output([start_ray_script, "--head", "--num-workers",
"20"])
subprocess.Popen([stop_ray_script]).wait()
# Test starting Ray with a redis port specified.
@@ -204,5 +210,6 @@ class StartRayScriptTest(unittest.TestCase):
"--redis-address", "127.0.0.1:6379"])
subprocess.Popen([stop_ray_script]).wait()
if __name__ == "__main__":
unittest.main(verbosity=2)
+147 -60
View File
@@ -12,24 +12,33 @@ import string
import sys
from collections import namedtuple
import ray.test.test_functions as test_functions
if sys.version_info >= (3, 0):
from importlib import reload
import ray.test.test_functions as test_functions
import ray.experimental.array.remote as ra
import ray.experimental.array.distributed as da
def assert_equal(obj1, obj2):
if type(obj1).__module__ == np.__name__ or type(obj2).__module__ == np.__name__:
if (hasattr(obj1, "shape") and obj1.shape == ()) or (hasattr(obj2, "shape") and obj2.shape == ()):
module_numpy = (type(obj1).__module__ == np.__name__ or
type(obj2).__module__ == np.__name__)
if module_numpy:
empty_shape = ((hasattr(obj1, "shape") and obj1.shape == ()) or
(hasattr(obj2, "shape") and obj2.shape == ()))
if empty_shape:
# This is a special case because currently np.testing.assert_equal fails
# because we do not properly handle different numerical types.
assert obj1 == obj2, "Objects {} and {} are different.".format(obj1, obj2)
assert obj1 == obj2, ("Objects {} and {} are "
"different.".format(obj1, obj2))
else:
np.testing.assert_equal(obj1, obj2)
elif hasattr(obj1, "__dict__") and hasattr(obj2, "__dict__"):
special_keys = ["_pytype_"]
assert set(list(obj1.__dict__.keys()) + special_keys) == set(list(obj2.__dict__.keys()) + special_keys), "Objects {} and {} are different.".format(obj1, obj2)
assert (set(list(obj1.__dict__.keys()) + special_keys) ==
set(list(obj2.__dict__.keys()) + special_keys)), ("Objects {} and "
"{} are "
"different."
.format(obj1,
obj2))
for key in obj1.__dict__.keys():
if key not in special_keys:
assert_equal(obj1.__dict__[key], obj2.__dict__[key])
@@ -38,24 +47,29 @@ def assert_equal(obj1, obj2):
for key in obj1.keys():
assert_equal(obj1[key], obj2[key])
elif type(obj1) is list or type(obj2) is list:
assert len(obj1) == len(obj2), "Objects {} and {} are lists with different lengths.".format(obj1, obj2)
assert len(obj1) == len(obj2), ("Objects {} and {} are lists with "
"different lengths.".format(obj1, obj2))
for i in range(len(obj1)):
assert_equal(obj1[i], obj2[i])
elif type(obj1) is tuple or type(obj2) is tuple:
assert len(obj1) == len(obj2), "Objects {} and {} are tuples with different lengths.".format(obj1, obj2)
assert len(obj1) == len(obj2), ("Objects {} and {} are tuples with "
"different lengths.".format(obj1, obj2))
for i in range(len(obj1)):
assert_equal(obj1[i], obj2[i])
elif ray.serialization.is_named_tuple(type(obj1)) or ray.serialization.is_named_tuple(type(obj2)):
assert len(obj1) == len(obj2), "Objects {} and {} are named tuples with different lengths.".format(obj1, obj2)
elif (ray.serialization.is_named_tuple(type(obj1)) or
ray.serialization.is_named_tuple(type(obj2))):
assert len(obj1) == len(obj2), ("Objects {} and {} are named tuples with "
"different lengths.".format(obj1, obj2))
for i in range(len(obj1)):
assert_equal(obj1[i], obj2[i])
else:
assert obj1 == obj2, "Objects {} and {} are different.".format(obj1, obj2)
if sys.version_info >= (3, 0):
long_extras = [0, np.array([["hi", u"hi"], [1.3, 1]])]
else:
long_extras = [long(0), np.array([["hi", u"hi"], [1.3, long(1)]])]
long_extras = [long(0), np.array([["hi", u"hi"], [1.3, long(1)]])] # noqa: E501,F821
PRIMITIVE_OBJECTS = [0, 0.0, 0.9, 1 << 62, "a", string.printable, "\u262F",
u"hello world", u"\xff\xfe\x9c\x001\x000\x00", None, True,
@@ -65,45 +79,55 @@ PRIMITIVE_OBJECTS = [0, 0.0, 0.9, 1 << 62, "a", string.printable, "\u262F",
np.random.normal(size=[100, 100]), np.array(["hi", 3]),
np.array(["hi", 3], dtype=object)] + long_extras
COMPLEX_OBJECTS = [[[[[[[[[[[[[]]]]]]]]]]]],
{"obj{}".format(i): np.random.normal(size=[100, 100]) for i in range(10)},
#{(): {(): {(): {(): {(): {(): {(): {(): {(): {(): {(): {(): {}}}}}}}}}}}}},
((((((((((),),),),),),),),),),
{"a": {"b": {"c": {"d": {}}}}}
]
COMPLEX_OBJECTS = [
[[[[[[[[[[[[]]]]]]]]]]]],
{"obj{}".format(i): np.random.normal(size=[100, 100]) for i in range(10)},
# {(): {(): {(): {(): {(): {(): {(): {(): {(): {(): {
# (): {(): {}}}}}}}}}}}}},
((((((((((),),),),),),),),),),
{"a": {"b": {"c": {"d": {}}}}}]
class Foo(object):
  """Test fixture: a custom class whose instances carry no attributes."""

  def __init__(self):
    # Intentionally empty; the instance has an empty __dict__.
    pass
class Bar(object):
  """Test fixture: a custom class with many attributes of varied types.

  Each entry of PRIMITIVE_OBJECTS + COMPLEX_OBJECTS is stored on the instance
  under the attribute name "field<index>", in list order.
  """

  def __init__(self):
    field_values = PRIMITIVE_OBJECTS + COMPLEX_OBJECTS
    for index, value in enumerate(field_values):
      attribute_name = "field{}".format(index)
      setattr(self, attribute_name, value)
class Baz(object):
  """Test fixture: a custom class whose attributes are other custom objects.

  Attributes are a Foo instance (self.foo) and a Bar instance (self.bar).
  """

  def __init__(self):
    self.foo = Foo()
    self.bar = Bar()

  def method(self, arg):
    # No-op; gives the class a method in addition to its data attributes.
    pass
class Qux(object):
  """Test fixture: a custom class holding a list of other custom objects.

  self.objs contains one fresh instance each of Foo, Bar and Baz, in that
  order.
  """

  def __init__(self):
    self.objs = [fixture_class() for fixture_class in (Foo, Bar, Baz)]
class SubQux(Qux):
  """Test fixture: a subclass of Qux that adds nothing of its own."""

  def __init__(self):
    # Delegate to the parent initializer explicitly.
    Qux.__init__(self)
class CustomError(Exception):
  """Test fixture: an Exception subclass with no additional behavior."""

  pass
Point = namedtuple("Point", ["x", "y"])
NamedTupleExample = namedtuple("Example", "field1, field2, field3, field4, field5")
NamedTupleExample = namedtuple("Example",
"field1, field2, field3, field4, field5")
CUSTOM_OBJECTS = [Exception("Test object."), CustomError(), Point(11, y=22),
Foo(), Bar(), Baz(), # Qux(), SubQux(),
Foo(), Bar(), Baz(), # Qux(), SubQux(),
NamedTupleExample(1, 1.0, "hi", np.zeros([3, 5]), [1, 2, 3])]
BASE_OBJECTS = PRIMITIVE_OBJECTS + COMPLEX_OBJECTS + CUSTOM_OBJECTS
@@ -112,8 +136,9 @@ LIST_OBJECTS = [[obj] for obj in BASE_OBJECTS]
TUPLE_OBJECTS = [(obj,) for obj in BASE_OBJECTS]
# The check that type(obj).__module__ != "numpy" should be unnecessary, but
# otherwise this seems to fail on Mac OS X on Travis.
DICT_OBJECTS = ([{obj: obj} for obj in PRIMITIVE_OBJECTS if obj.__hash__ is not None and type(obj).__module__ != "numpy"] +
# DICT_OBJECTS = ([{obj: obj} for obj in BASE_OBJECTS if obj.__hash__ is not None] +
DICT_OBJECTS = ([{obj: obj} for obj in PRIMITIVE_OBJECTS
if (obj.__hash__ is not None and
type(obj).__module__ != "numpy")] +
[{0: obj} for obj in BASE_OBJECTS])
RAY_TEST_OBJECTS = BASE_OBJECTS + LIST_OBJECTS + TUPLE_OBJECTS + DICT_OBJECTS
@@ -124,7 +149,10 @@ try:
cloudpickle.dumps(Point)
except AttributeError:
cloudpickle_command = "pip install --upgrade cloudpickle"
raise Exception("You have an older version of cloudpickle that is not able to serialize namedtuples. Try running \n\n{}\n\n".format(cloudpickle_command))
raise Exception("You have an older version of cloudpickle that is not able "
"to serialize namedtuples. Try running "
"\n\n{}\n\n".format(cloudpickle_command))
class SerializationTest(unittest.TestCase):
@@ -155,7 +183,7 @@ class SerializationTest(unittest.TestCase):
# Check that exceptions are thrown when we serialize the recursive objects.
for obj in recursive_objects:
self.assertRaises(Exception, lambda : ray.put(obj))
self.assertRaises(Exception, lambda: ray.put(obj))
ray.worker.cleanup()
@@ -181,6 +209,7 @@ class SerializationTest(unittest.TestCase):
ray.worker.cleanup()
class WorkerTest(unittest.TestCase):
def testPythonWorkers(self):
@@ -228,6 +257,7 @@ class WorkerTest(unittest.TestCase):
ray.worker.cleanup()
class APITest(unittest.TestCase):
def testRegisterClass(self):
@@ -237,10 +267,10 @@ class APITest(unittest.TestCase):
# throws an exception.
class TempClass(object):
pass
self.assertRaises(Exception, lambda : ray.put(Foo))
self.assertRaises(Exception, lambda: ray.put(Foo))
# Check that registering a class that Ray cannot serialize efficiently
# raises an exception.
self.assertRaises(Exception, lambda : ray.register_class(type(True)))
self.assertRaises(Exception, lambda: ray.register_class(type(True)))
# Check that registering the same class with pickle works.
ray.register_class(type(float), pickle=True)
self.assertEqual(ray.get(ray.put(float)), float)
@@ -328,7 +358,9 @@ class APITest(unittest.TestCase):
print("Still using old definition of f, trying again.")
# Test that we can close over plain old data.
data = [np.zeros([3, 5]), (1, 2, "a"), [0.0, 1.0, 1 << 62], 1 << 60, {"a": np.zeros(3)}]
data = [np.zeros([3, 5]), (1, 2, "a"), [0.0, 1.0, 1 << 62], 1 << 60,
{"a": np.zeros(3)}]
@ray.remote
def g():
return data
@@ -339,18 +371,22 @@ class APITest(unittest.TestCase):
def h():
return np.zeros([3, 5])
assert_equal(ray.get(h.remote()), np.zeros([3, 5]))
@ray.remote
def j():
return time.time()
ray.get(j.remote())
# Test that we can define remote functions that call other remote functions.
# Test that we can define remote functions that call other remote
# functions.
@ray.remote
def k(x):
return x + 1
@ray.remote
def l(x):
return ray.get(k.remote(x))
@ray.remote
def m(x):
return ray.get(l.remote(x))
@@ -398,7 +434,7 @@ class APITest(unittest.TestCase):
# Verify that calling wait with duplicate object IDs throws an exception.
x = ray.put(1)
self.assertRaises(Exception, lambda : ray.wait([x, x]))
self.assertRaises(Exception, lambda: ray.wait([x, x]))
ray.worker.cleanup()
@@ -435,11 +471,14 @@ class APITest(unittest.TestCase):
ray.worker.cleanup()
def testCachingEnvironmentVariables(self):
# Test that we can define environment variables before the driver is connected.
# Test that we can define environment variables before the driver is
# connected.
def foo_initializer():
return 1
def bar_initializer():
return []
def bar_reinitializer(bar):
return []
ray.env.foo = ray.EnvironmentVariable(foo_initializer)
@@ -448,6 +487,7 @@ class APITest(unittest.TestCase):
@ray.remote
def use_foo():
return ray.env.foo
@ray.remote
def use_bar():
ray.env.bar.append(1)
@@ -463,16 +503,20 @@ class APITest(unittest.TestCase):
ray.worker.cleanup()
def testCachingFunctionsToRun(self):
# Test that we export functions to run on all workers before the driver is connected.
# Test that we export functions to run on all workers before the driver is
# connected.
def f(worker_info):
sys.path.append(1)
ray.worker.global_worker.run_function_on_all_workers(f)
def f(worker_info):
sys.path.append(2)
ray.worker.global_worker.run_function_on_all_workers(f)
def g(worker_info):
sys.path.append(3)
ray.worker.global_worker.run_function_on_all_workers(g)
def f(worker_info):
sys.path.append(4)
ray.worker.global_worker.run_function_on_all_workers(f)
@@ -505,13 +549,16 @@ class APITest(unittest.TestCase):
def f(worker_info):
sys.path.append("fake_directory")
ray.worker.global_worker.run_function_on_all_workers(f)
@ray.remote
def get_path1():
return sys.path
self.assertEqual("fake_directory", ray.get(get_path1.remote())[-1])
def f(worker_info):
sys.path.pop(-1)
ray.worker.global_worker.run_function_on_all_workers(f)
# Create a second remote function to guarantee that when we call
# get_path2.remote(), the second function to run will have been run on the
# worker.
@@ -528,6 +575,7 @@ class APITest(unittest.TestCase):
def f(worker_info):
sys.path.append(worker_info)
ray.worker.global_worker.run_function_on_all_workers(f)
@ray.remote
def get_path():
time.sleep(1)
@@ -542,6 +590,7 @@ class APITest(unittest.TestCase):
counters = [worker_info["counter"] for worker_info in worker_infos]
# We use range(11) because the driver also runs the function.
self.assertEqual(set(counters), set(range(11)))
# Clean up the worker paths.
def f(worker_info):
sys.path.pop(-1)
@@ -555,7 +604,8 @@ class APITest(unittest.TestCase):
def events():
# This is a hack for getting the event log. It is not part of the API.
keys = ray.worker.global_worker.redis_client.keys("event_log:*")
return [ray.worker.global_worker.redis_client.lrange(key, 0, -1) for key in keys]
return [ray.worker.global_worker.redis_client.lrange(key, 0, -1)
for key in keys]
def wait_for_num_events(num_events, timeout=10):
start_time = time.time()
@@ -604,25 +654,28 @@ class APITest(unittest.TestCase):
# accidentally call an older version.
ray.init(num_workers=2)
num_remote_functions = 100
num_calls = 200
@ray.remote
def f():
return 1
results1 = [f.remote() for _ in range(num_calls)]
@ray.remote
def f():
return 2
results2 = [f.remote() for _ in range(num_calls)]
@ray.remote
def f():
return 3
results3 = [f.remote() for _ in range(num_calls)]
@ray.remote
def f():
return 4
results4 = [f.remote() for _ in range(num_calls)]
@ray.remote
def f():
return 5
@@ -637,16 +690,20 @@ class APITest(unittest.TestCase):
@ray.remote
def g():
return 1
@ray.remote
@ray.remote # noqa: F811
def g():
return 2
@ray.remote
@ray.remote # noqa: F811
def g():
return 3
@ray.remote
@ray.remote # noqa: F811
def g():
return 4
@ray.remote
@ray.remote # noqa: F811
def g():
return 5
@@ -668,6 +725,7 @@ class APITest(unittest.TestCase):
ray.worker.cleanup()
class PythonModeTest(unittest.TestCase):
def testPythonMode(self):
@@ -678,17 +736,21 @@ class PythonModeTest(unittest.TestCase):
def f():
return np.ones([3, 4, 5])
xref = f.remote()
assert_equal(xref, np.ones([3, 4, 5])) # remote functions should return by value
assert_equal(xref, ray.get(xref)) # ray.get should be the identity
# Remote functions should return by value.
assert_equal(xref, np.ones([3, 4, 5]))
# Check that ray.get is the identity.
assert_equal(xref, ray.get(xref))
y = np.random.normal(size=[11, 12])
assert_equal(y, ray.put(y)) # ray.put should be the identity
# Check that ray.put is the identity.
assert_equal(y, ray.put(y))
# make sure objects are immutable, this example is why we need to copy
# Make sure objects are immutable. This example is why we need to copy
# arguments before passing them into remote functions in python mode.
aref = test_functions.python_mode_f.remote()
assert_equal(aref, np.array([0, 0]))
bref = test_functions.python_mode_g.remote(aref)
assert_equal(aref, np.array([0, 0])) # python_mode_g should not mutate aref
# Make sure python_mode_g does not mutate aref.
assert_equal(aref, np.array([0, 0]))
assert_equal(bref, np.array([1, 0]))
ray.worker.cleanup()
@@ -699,6 +761,7 @@ class PythonModeTest(unittest.TestCase):
def l_init():
return []
def l_reinit(l):
return []
ray.env.l = ray.EnvironmentVariable(l_init, l_reinit)
@@ -717,7 +780,8 @@ class PythonModeTest(unittest.TestCase):
assert_equal(ray.get(use_l.remote()), [1])
assert_equal(ray.get(use_l.remote()), [1])
# Make sure the local copy of the environment variable has not been mutated.
# Make sure the local copy of the environment variable has not been
# mutated.
assert_equal(l, [])
l = ray.env.l
assert_equal(l, [])
@@ -730,6 +794,7 @@ class PythonModeTest(unittest.TestCase):
ray.worker.cleanup()
class EnvironmentVariablesTest(unittest.TestCase):
def testEnvironmentVariables(self):
@@ -739,6 +804,7 @@ class EnvironmentVariablesTest(unittest.TestCase):
def foo_initializer():
return 1
def foo_reinitializer(foo):
return foo
@@ -752,7 +818,8 @@ class EnvironmentVariablesTest(unittest.TestCase):
self.assertEqual(ray.get(use_foo.remote()), 1)
self.assertEqual(ray.get(use_foo.remote()), 1)
# Test that we can add a variable to the key-value store, mutate it, and reset it.
# Test that we can add a variable to the key-value store, mutate it, and
# reset it.
def bar_initializer():
return [1, 2, 3]
@@ -771,6 +838,7 @@ class EnvironmentVariablesTest(unittest.TestCase):
def baz_initializer():
return np.zeros([4])
def baz_reinitializer(baz):
for i in range(len(baz)):
baz[i] = 0
@@ -794,6 +862,7 @@ class EnvironmentVariablesTest(unittest.TestCase):
def qux_initializer():
return 0
def qux_reinitializer(x):
return x + 1
@@ -815,6 +884,7 @@ class EnvironmentVariablesTest(unittest.TestCase):
def foo_initializer():
return []
def foo_reinitializer(foo):
return []
@@ -846,6 +916,7 @@ class EnvironmentVariablesTest(unittest.TestCase):
ray.worker.cleanup()
class UtilsTest(unittest.TestCase):
def testCopyingDirectory(self):
@@ -894,6 +965,7 @@ class UtilsTest(unittest.TestCase):
ray.worker.cleanup()
class ResourcesTest(unittest.TestCase):
def testResourceConstraints(self):
@@ -901,13 +973,16 @@ class ResourcesTest(unittest.TestCase):
ray.init(num_workers=num_workers, num_cpus=10, num_gpus=2)
# Attempt to wait for all of the workers to start up.
ray.worker.global_worker.run_function_on_all_workers(lambda worker_info: sys.path.append(worker_info["counter"]))
ray.worker.global_worker.run_function_on_all_workers(
lambda worker_info: sys.path.append(worker_info["counter"]))
@ray.remote(num_cpus=0)
def get_worker_id():
time.sleep(1)
return sys.path[-1]
while True:
if len(set(ray.get([get_worker_id.remote() for _ in range(num_workers)]))) == num_workers:
if len(set(ray.get([get_worker_id.remote()
for _ in range(num_workers)]))) == num_workers:
break
time_buffer = 0.3
@@ -974,13 +1049,16 @@ class ResourcesTest(unittest.TestCase):
ray.init(num_workers=num_workers, num_cpus=10, num_gpus=10)
# Attempt to wait for all of the workers to start up.
ray.worker.global_worker.run_function_on_all_workers(lambda worker_info: sys.path.append(worker_info["counter"]))
ray.worker.global_worker.run_function_on_all_workers(
lambda worker_info: sys.path.append(worker_info["counter"]))
@ray.remote(num_cpus=0)
def get_worker_id():
time.sleep(1)
return sys.path[-1]
while True:
if len(set(ray.get([get_worker_id.remote() for _ in range(num_workers)]))) == num_workers:
if len(set(ray.get([get_worker_id.remote()
for _ in range(num_workers)]))) == num_workers:
break
@ray.remote(num_cpus=1, num_gpus=9)
@@ -1021,8 +1099,8 @@ class ResourcesTest(unittest.TestCase):
def testMultipleLocalSchedulers(self):
# This test will define a bunch of tasks that can only be assigned to
# specific local schedulers, and we will check that they are assigned to the
# correct local schedulers.
# specific local schedulers, and we will check that they are assigned to
# the correct local schedulers.
address_info = ray.worker._init(start_ray_local=True,
num_local_schedulers=3,
num_cpus=[100, 5, 10],
@@ -1088,7 +1166,8 @@ class ResourcesTest(unittest.TestCase):
results.append(run_on_0_2.remote())
return names, results
store_names = [object_store_address.name for object_store_address in address_info["object_store_addresses"]]
store_names = [object_store_address.name for object_store_address
in address_info["object_store_addresses"]]
def validate_names_and_results(names, results):
for name, result in zip(names, ray.get(results)):
@@ -1099,7 +1178,8 @@ class ResourcesTest(unittest.TestCase):
elif name == "run_on_2":
self.assertIn(result, [store_names[2]])
elif name == "run_on_0_1_2":
self.assertIn(result, [store_names[0], store_names[1], store_names[2]])
self.assertIn(result, [store_names[0], store_names[1],
store_names[2]])
elif name == "run_on_1_2":
self.assertIn(result, [store_names[1], store_names[2]])
elif name == "run_on_0_2":
@@ -1128,6 +1208,7 @@ class ResourcesTest(unittest.TestCase):
ray.worker.cleanup()
class WorkerPoolTests(unittest.TestCase):
def tearDown(self):
@@ -1177,6 +1258,7 @@ class WorkerPoolTests(unittest.TestCase):
ray.worker.cleanup()
class SchedulingAlgorithm(unittest.TestCase):
def attempt_to_load_balance(self, remote_function, args, total_tasks,
@@ -1184,21 +1266,24 @@ class SchedulingAlgorithm(unittest.TestCase):
num_attempts=20):
attempts = 0
while attempts < num_attempts:
locations = ray.get([remote_function.remote(*args) for _ in range(total_tasks)])
locations = ray.get([remote_function.remote(*args)
for _ in range(total_tasks)])
names = set(locations)
counts = [locations.count(name) for name in names]
print("Counts are {}.".format(counts))
if len(names) == num_local_schedulers and all([count >= minimum_count for count in counts]):
if len(names) == num_local_schedulers and all([count >= minimum_count
for count in counts]):
break
attempts += 1
self.assertLess(attempts, num_attempts)
def testLoadBalancing(self):
# This test ensures that tasks are being assigned to all local schedulers in
# a roughly equal manner.
# This test ensures that tasks are being assigned to all local schedulers
# in a roughly equal manner.
num_workers = 21
num_local_schedulers = 3
ray.worker._init(start_ray_local=True, num_workers=num_workers, num_local_schedulers=num_local_schedulers)
ray.worker._init(start_ray_local=True, num_workers=num_workers,
num_local_schedulers=num_local_schedulers)
@ray.remote
def f():
@@ -1211,11 +1296,12 @@ class SchedulingAlgorithm(unittest.TestCase):
ray.worker.cleanup()
def testLoadBalancingWithDependencies(self):
# This test ensures that tasks are being assigned to all local schedulers in
# a roughly equal manner even when the tasks have dependencies.
# This test ensures that tasks are being assigned to all local schedulers
# in a roughly equal manner even when the tasks have dependencies.
num_workers = 3
num_local_schedulers = 3
ray.worker._init(start_ray_local=True, num_workers=num_workers, num_local_schedulers=num_local_schedulers)
ray.worker._init(start_ray_local=True, num_workers=num_workers,
num_local_schedulers=num_local_schedulers)
@ray.remote
def f(x):
@@ -1229,5 +1315,6 @@ class SchedulingAlgorithm(unittest.TestCase):
ray.worker.cleanup()
if __name__ == "__main__":
unittest.main(verbosity=2)
+49 -38
View File
@@ -11,6 +11,7 @@ import redis
# Import flatbuffer bindings.
from ray.core.generated.TaskReply import TaskReply
class TaskTests(unittest.TestCase):
def testSubmittingTasks(self):
@@ -93,7 +94,7 @@ class TaskTests(unittest.TestCase):
def f():
return 1
n = 10 ** 4 # TODO(pcm): replace by 10 ** 5 once this is faster
n = 10 ** 4 # TODO(pcm): replace by 10 ** 5 once this is faster.
l = ray.get([f.remote() for _ in range(n)])
self.assertEqual(l, n * [1])
@@ -123,12 +124,14 @@ class TaskTests(unittest.TestCase):
time.sleep(x)
for i in range(1, 5):
x_ids = [g.remote(np.random.uniform(0, i)) for _ in range(2 * num_workers)]
x_ids = [g.remote(np.random.uniform(0, i))
for _ in range(2 * num_workers)]
ray.wait(x_ids, num_returns=len(x_ids))
self.assertTrue(ray.services.all_processes_alive())
ray.worker.cleanup()
class ReconstructionTests(unittest.TestCase):
num_local_schedulers = 1
@@ -144,14 +147,10 @@ class ReconstructionTests(unittest.TestCase):
plasma_addresses = []
objstore_memory = (self.plasma_store_memory // self.num_local_schedulers)
for i in range(self.num_local_schedulers):
plasma_addresses.append(
ray.services.start_objstore(node_ip_address, redis_address,
objstore_memory=objstore_memory)
)
address_info = {
"redis_address": redis_address,
"object_store_addresses": plasma_addresses,
}
plasma_addresses.append(ray.services.start_objstore(
node_ip_address, redis_address, objstore_memory=objstore_memory))
address_info = {"redis_address": redis_address,
"object_store_addresses": plasma_addresses}
# Start the rest of the services in the Ray cluster.
ray.worker._init(address_info=address_info, start_ray_local=True,
@@ -180,7 +179,8 @@ class ReconstructionTests(unittest.TestCase):
# total number of local schedulers to account for NIL_LOCAL_SCHEDULER_ID.
# This is the local scheduler ID associated with the driver task, since it
# is not scheduled by a particular local scheduler.
self.assertEqual(len(set(local_scheduler_ids)), self.num_local_schedulers + 1)
self.assertEqual(len(set(local_scheduler_ids)),
self.num_local_schedulers + 1)
# Clean up the Ray cluster.
ray.worker.cleanup()
@@ -218,7 +218,7 @@ class ReconstructionTests(unittest.TestCase):
num_chunks = 4 * self.num_local_schedulers
chunk = num_objects // num_chunks
for i in range(num_chunks):
values = ray.get(args[i * chunk : (i + 1) * chunk])
values = ray.get(args[i * chunk:(i + 1) * chunk])
del values
def testRecursive(self):
@@ -261,14 +261,14 @@ class ReconstructionTests(unittest.TestCase):
self.assertEqual(value[0], i)
# Get 10 values randomly.
for _ in range(10):
i = np.random.randint(num_objects)
i = np.random.randint(num_objects)
value = ray.get(args[i])
self.assertEqual(value[0], i)
# Get values sequentially, in chunks.
num_chunks = 4 * self.num_local_schedulers
chunk = num_objects // num_chunks
for i in range(num_chunks):
values = ray.get(args[i * chunk : (i + 1) * chunk])
values = ray.get(args[i * chunk:(i + 1) * chunk])
del values
def testMultipleRecursive(self):
@@ -316,7 +316,7 @@ class ReconstructionTests(unittest.TestCase):
self.assertEqual(value[0], i)
# Get 10 values randomly.
for _ in range(10):
i = np.random.randint(num_objects)
i = np.random.randint(num_objects)
value = ray.get(args[i])
self.assertEqual(value[0], i)
@@ -391,7 +391,8 @@ class ReconstructionTests(unittest.TestCase):
return len(errors) >= min_errors
errors = self.wait_for_errors(error_check)
# Make sure all the errors have the correct type.
self.assertTrue(all(error[b"type"] == b"object_hash_mismatch" for error in errors))
self.assertTrue(all(error[b"type"] == b"object_hash_mismatch"
for error in errors))
# Make sure all the errors have the correct function name.
self.assertTrue(all(error[b"data"] == b"__main__.foo" for error in errors))
@@ -462,20 +463,26 @@ class ReconstructionTests(unittest.TestCase):
self.assertEqual(value[0], i)
put_arg_task.remote(size)
def error_check(errors):
return len(errors) > 1
errors = self.wait_for_errors(error_check)
# Make sure all the errors have the correct type.
self.assertTrue(all(error[b"type"] == b"put_reconstruction" for error in errors))
self.assertTrue(all(error[b"data"] == b"__main__.put_arg_task" for error in errors))
self.assertTrue(all(error[b"type"] == b"put_reconstruction"
for error in errors))
self.assertTrue(all(error[b"data"] == b"__main__.put_arg_task"
for error in errors))
put_task.remote(size)
def error_check(errors):
return any(error[b"data"] == b"__main__.put_task" for error in errors)
errors = self.wait_for_errors(error_check)
# Make sure all the errors have the correct type.
self.assertTrue(all(error[b"type"] == b"put_reconstruction" for error in errors))
self.assertTrue(any(error[b"data"] == b"__main__.put_task" for error in errors))
self.assertTrue(all(error[b"type"] == b"put_reconstruction"
for error in errors))
self.assertTrue(any(error[b"data"] == b"__main__.put_task"
for error in errors))
def testDriverPutErrors(self):
# Define the size of one task's return argument so that the combined sum of
@@ -511,11 +518,14 @@ class ReconstructionTests(unittest.TestCase):
# were evicted and whose originating tasks are still running, this
# for-loop should hang on its first iteration and push an error to the
# driver.
ray.worker.global_worker.local_scheduler_client.reconstruct_object(args[0].id())
ray.worker.global_worker.local_scheduler_client.reconstruct_object(
args[0].id())
def error_check(errors):
return len(errors) > 1
errors = self.wait_for_errors(error_check)
self.assertTrue(all(error[b"type"] == b"put_reconstruction" for error in errors))
self.assertTrue(all(error[b"type"] == b"put_reconstruction"
for error in errors))
self.assertTrue(all(error[b"data"] == b"Driver" for error in errors))
@@ -526,26 +536,27 @@ class ReconstructionTestsMultinode(ReconstructionTests):
num_local_schedulers = 4
# NOTE(swang): This test tries to launch 1000 workers and breaks.
#class WorkerPoolTests(unittest.TestCase):
# class WorkerPoolTests(unittest.TestCase):
#
# def tearDown(self):
# ray.worker.cleanup()
# def tearDown(self):
# ray.worker.cleanup()
#
# def testBlockingTasks(self):
# @ray.remote
# def f(i, j):
# return (i, j)
# def testBlockingTasks(self):
# @ray.remote
# def f(i, j):
# return (i, j)
#
# @ray.remote
# def g(i):
# # Each instance of g submits and blocks on the result of another remote
# # task.
# object_ids = [f.remote(i, j) for j in range(10)]
# return ray.get(object_ids)
# @ray.remote
# def g(i):
# # Each instance of g submits and blocks on the result of another remote
# # task.
# object_ids = [f.remote(i, j) for j in range(10)]
# return ray.get(object_ids)
#
# ray.init(num_workers=1)
# ray.get([g.remote(i) for i in range(1000)])
# ray.worker.cleanup()
# ray.init(num_workers=1)
# ray.get([g.remote(i) for i in range(1000)])
# ray.worker.cleanup()
if __name__ == "__main__":
unittest.main(verbosity=2)
+27 -34
View File
@@ -2,11 +2,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import unittest
import uuid
import tensorflow as tf
import ray
from numpy.testing import assert_almost_equal
import tensorflow as tf
import unittest
import ray
def make_linear_network(w_name=None, b_name=None):
# Define the inputs.
@@ -17,7 +18,9 @@ def make_linear_network(w_name=None, b_name=None):
b = tf.Variable(tf.zeros([1]), name=b_name)
y = w * x_data + b
# Return the loss and weight initializer.
return tf.reduce_mean(tf.square(y - y_data)), tf.global_variables_initializer(), x_data, y_data
return (tf.reduce_mean(tf.square(y - y_data)),
tf.global_variables_initializer(), x_data, y_data)
class NetActor(object):
@@ -40,6 +43,7 @@ class NetActor(object):
def get_weights(self):
return self.values[0].get_weights()
class TrainActor(object):
def __init__(self):
@@ -57,11 +61,13 @@ class TrainActor(object):
def training_step(self, weights):
_, variables, _, sess, grads, _, placeholders = self.values
variables.set_weights(weights)
return sess.run([grad[0] for grad in grads], feed_dict=dict(zip(placeholders, [[1]*100, [2]*100])))
return sess.run([grad[0] for grad in grads],
feed_dict=dict(zip(placeholders, [[1] * 100, [2] * 100])))
def get_weights(self):
return self.values[1].get_weights()
class TensorFlowTest(unittest.TestCase):
def testTensorFlowVariables(self):
@@ -113,9 +119,6 @@ class TensorFlowTest(unittest.TestCase):
net1 = NetActor()
net2 = NetActor()
net_vars1, init1, sess1 = net1.values
net_vars2, init2, sess2 = net2.values
# This is checking that the variable names of the two nets are the same,
# i.e. that the names in the weight dictionaries are the same
net1.values[0].set_weights(net2.values[0].get_weights())
@@ -125,7 +128,8 @@ class TensorFlowTest(unittest.TestCase):
# Test that different networks on the same worker are independent and
# we can get/set their weights without any interaction.
def testNetworksIndependent(self):
# Note we use only one worker to ensure that all of the remote functions run on the same worker.
# Note we use only one worker to ensure that all of the remote functions
# run on the same worker.
ray.init(num_workers=1)
net1 = NetActor()
net2 = NetActor()
@@ -151,15 +155,15 @@ class TensorFlowTest(unittest.TestCase):
ray.worker.cleanup()
# This test creates an additional network on the driver so that the tensorflow
# variables on the driver and the worker differ.
# This test creates an additional network on the driver so that the
# tensorflow variables on the driver and the worker differ.
def testNetworkDriverWorkerIndependent(self):
ray.init(num_workers=1)
# Create a network on the driver locally.
sess1 = tf.Session()
loss1, init1, _, _ = make_linear_network()
net_vars1 = ray.experimental.TensorFlowVariables(loss1, sess1)
ray.experimental.TensorFlowVariables(loss1, sess1)
sess1.run(init1)
net2 = ray.actor(NetActor)()
@@ -194,39 +198,28 @@ class TensorFlowTest(unittest.TestCase):
ray.worker.cleanup()
def testRemoteTrainingLoss(self):
  # Apply three rounds of gradients, each averaged over two remote
  # training_step calls, to a driver-side network and compare the loss
  # evaluated before and after those updates.
  ray.init(num_workers=2)

  remote_net = ray.actor(TrainActor)()
  loss, variables, _, sess, grads, train, placeholders = TrainActor().values

  def evaluate_loss():
    # Feed [2] * 100 and [4] * 100 to the placeholders, in order.
    batch = dict(zip(placeholders, [[2] * 100, [4] * 100]))
    return sess.run(loss, feed_dict=batch)

  before_acc = evaluate_loss()

  for _ in range(3):
    # Both remote calls receive the driver's current weights.
    pending = [remote_net.training_step(variables.get_weights())
               for _ in range(2)]
    gradients_list = ray.get(pending)
    num_results = len(gradients_list)

    # Average the i-th gradient across all of the returned gradient lists.
    mean_grads = []
    for i in range(len(gradients_list[0])):
      total = sum([gradients[i] for gradients in gradients_list])
      mean_grads.append(total / num_results)

    feed_dict = {}
    for grad, mean_grad in zip(grads, mean_grads):
      feed_dict[grad[0]] = mean_grad
    sess.run(train, feed_dict=feed_dict)

  after_acc = evaluate_loss()
  # NOTE(review): this asserts that the evaluated loss is larger after the
  # updates than before -- confirm that this is the intended direction.
  self.assertTrue(before_acc < after_acc)
  ray.worker.cleanup()
def testVariablesControlDependencies(self):
ray.init(num_workers=1)
# Creates a network and appends a momentum optimizer.
sess = tf.Session()
loss, init, _, _ = make_linear_network()
minimizer = tf.train.MomentumOptimizer(0.9, 0.9).minimize(loss)
net_vars = ray.experimental.TensorFlowVariables(minimizer, sess)
sess.run(init)
# Tests if all variables are properly retrieved, 2 variables and 2 momentum
# variables.
self.assertEqual(len(net_vars.variables.items()), 4)
ray.worker.cleanup()
if __name__ == "__main__":
unittest.main(verbosity=2)
+111 -58
View File
@@ -6,17 +6,17 @@ import collections
import datetime
import json
import numpy as np
import os
import redis
import sys
import time
import websockets
# Import flatbuffer bindings.
from ray.core.generated.LocalSchedulerInfoMessage import LocalSchedulerInfoMessage
from ray.core.generated.LocalSchedulerInfoMessage import \
LocalSchedulerInfoMessage
parser = argparse.ArgumentParser(description="parse information for the web ui")
parser.add_argument("--redis-address", required=True, type=str, help="the address to use for redis")
parser = argparse.ArgumentParser(
description="parse information for the web ui")
parser.add_argument("--redis-address", required=True, type=str,
help="the address to use for redis")
loop = asyncio.get_event_loop()
@@ -25,27 +25,36 @@ IDENTIFIER_LENGTH = 20
# This prefix must match the value defined in ray_redis_module.cc.
DB_CLIENT_PREFIX = b"CL:"
def hex_identifier(identifier):
  """Return the hexadecimal string form of a binary identifier.

  Args:
    identifier: A bytes-like object accepted by binascii.hexlify.

  Returns:
    The hex digits of identifier as a str.
  """
  hex_bytes = binascii.hexlify(identifier)
  return hex_bytes.decode()
def identifier(hex_identifier):
  """Return the binary identifier encoded by a hexadecimal string.

  Args:
    hex_identifier: Hex digits in a form accepted by binascii.unhexlify.

  Returns:
    The decoded identifier as bytes.
  """
  raw_identifier = binascii.unhexlify(hex_identifier)
  return raw_identifier
def key_to_hex_identifier(key):
  """Return the hex form of the identifier following the first b":" in key.

  Args:
    key: A bytes key of the form prefix:identifier, where the identifier
      occupies IDENTIFIER_LENGTH bytes.

  Returns:
    The identifier as a hex string.
  """
  separator = key.index(b":")
  start = separator + 1
  stop = separator + IDENTIFIER_LENGTH + 1
  return hex_identifier(key[start:stop])
def timestamp_to_date_string(timestamp):
  """Format a time.time()-style timestamp as "YYYY/MM/DD HH:MM:SS".

  Args:
    timestamp: Seconds since the epoch, as returned by time.time().

  Returns:
    The corresponding local date and time as a string.
  """
  moment = datetime.datetime.fromtimestamp(timestamp)
  return moment.strftime("%Y/%m/%d %H:%M:%S")
def key_to_hex_identifiers(key):
  """Split a key of the form prefix:worker_id:task_id into hex identifiers.

  Args:
    key: A bytes key in which the worker ID and the task ID each occupy
      IDENTIFIER_LENGTH bytes, with one separator byte between them.

  Returns:
    A (worker_id, task_id) tuple of hex strings.
  """
  # The worker ID starts right after the first colon.
  worker_start = key.index(b":") + 1
  worker_stop = worker_start + IDENTIFIER_LENGTH
  # Skip the single separator byte between the two identifiers.
  task_start = worker_stop + 1
  task_stop = task_start + IDENTIFIER_LENGTH
  worker_id = hex_identifier(key[worker_start:worker_stop])
  task_id = hex_identifier(key[task_start:task_stop])
  return worker_id, task_id
async def hgetall_as_dict(redis_conn, key):
  """Run "hgetall" on key and return the reply as a dictionary.

  The reply is read as a flat sequence that alternates field, value,
  field, value, ...; a trailing element without a partner is ignored.

  Args:
    redis_conn: A connection object providing an awaitable execute method.
    key: The key to pass to the "hgetall" command.

  Returns:
    A dict mapping each field in the reply to the value that follows it.
  """
  flat_reply = await redis_conn.execute("hgetall", key)
  return dict(zip(flat_reply[0::2], flat_reply[1::2]))
@@ -55,6 +64,7 @@ async def hgetall_as_dict(redis_conn, key):
local_schedulers = {}
errors = []
def duration_to_string(duration):
"""Format a duration in seconds as a string.
@@ -79,8 +89,10 @@ def duration_to_string(duration):
duration_str = "{} microseconds".format(int(duration * 1000000))
return duration_str
async def handle_get_statistics(websocket, redis_conn):
cluster_start_time = float(await redis_conn.execute("get", "redis_start_time"))
cluster_start_time = float(await redis_conn.execute("get",
"redis_start_time"))
start_date = timestamp_to_date_string(cluster_start_time)
uptime = duration_to_string(time.time() - cluster_start_time)
@@ -90,7 +102,9 @@ async def handle_get_statistics(websocket, redis_conn):
for client_key in client_keys:
client_fields = await hgetall_as_dict(redis_conn, client_key)
clients.append(client_fields)
ip_addresses = list(set([client[b"node_ip_address"].decode("ascii") for client in clients if client[b"client_type"] == b"local_scheduler"]))
ip_addresses = list(set([client[b"node_ip_address"].decode("ascii")
for client in clients
if client[b"client_type"] == b"local_scheduler"]))
num_nodes = len(ip_addresses)
reply = {"uptime": uptime,
"start_date": start_date,
@@ -98,18 +112,22 @@ async def handle_get_statistics(websocket, redis_conn):
"addresses": ip_addresses}
await websocket.send(json.dumps(reply))
async def handle_get_drivers(websocket, redis_conn):
keys = await redis_conn.execute("keys", "Drivers:*")
drivers = []
for key in keys:
driver_fields = await hgetall_as_dict(redis_conn, key)
driver_info = {"node ip address": driver_fields[b"node_ip_address"].decode("ascii"),
"name": driver_fields[b"name"].decode("ascii")}
driver_info = {
"node ip address": driver_fields[b"node_ip_address"].decode("ascii"),
"name": driver_fields[b"name"].decode("ascii")}
driver_info["start time"] = timestamp_to_date_string(float(driver_fields[b"start_time"]))
driver_info["start time"] = timestamp_to_date_string(
float(driver_fields[b"start_time"]))
if b"end_time" in driver_fields:
duration = float(driver_fields[b"end_time"]) - float(driver_fields[b"start_time"])
duration = (float(driver_fields[b"end_time"]) -
float(driver_fields[b"start_time"]))
else:
duration = time.time() - float(driver_fields[b"start_time"])
driver_info["duration"] = duration_to_string(duration)
@@ -129,17 +147,20 @@ async def handle_get_drivers(websocket, redis_conn):
reply = sorted(drivers, key=(lambda driver: driver["start time"]))[::-1]
await websocket.send(json.dumps(reply))
async def listen_for_errors(redis_ip_address, redis_port):
pubsub_conn = await aioredis.create_connection((redis_ip_address, redis_port), loop=loop)
data_conn = await aioredis.create_connection((redis_ip_address, redis_port), loop=loop)
pubsub_conn = await aioredis.create_connection(
(redis_ip_address, redis_port), loop=loop)
data_conn = await aioredis.create_connection((redis_ip_address, redis_port),
loop=loop)
error_pattern = "__keyspace@0__:ErrorKeys"
psub = await pubsub_conn.execute_pubsub("psubscribe", error_pattern)
await pubsub_conn.execute_pubsub("psubscribe", error_pattern)
channel = pubsub_conn.pubsub_patterns[error_pattern]
print("Listening for error messages...")
index = 0
while (await channel.wait_message()):
msg = await channel.get()
await channel.get()
info = await data_conn.execute("lrange", "ErrorKeys", index, -1)
for error_key in info:
@@ -154,6 +175,7 @@ async def listen_for_errors(redis_ip_address, redis_port):
"error": result})
index += 1
async def handle_get_errors(websocket):
  """Send the collected error messages to the frontend as JSON.

  Args:
    websocket: The connection whose send method receives the JSON-encoded
      contents of the module-level errors list.
  """
  payload = json.dumps(errors)
  await websocket.send(payload)
@@ -161,6 +183,7 @@ async def handle_get_errors(websocket):
node_info = collections.OrderedDict()
worker_info = collections.OrderedDict()
async def handle_get_recent_tasks(websocket, redis_conn, num_tasks):
# First update the cache of worker information.
worker_keys = await redis_conn.execute("keys", "Workers:*")
@@ -168,7 +191,8 @@ async def handle_get_recent_tasks(websocket, redis_conn, num_tasks):
worker_id = hex_identifier(key[len("Workers:"):])
if worker_id not in worker_info:
worker_info[worker_id] = await hgetall_as_dict(redis_conn, key)
node_ip_address = worker_info[worker_id][b"node_ip_address"].decode("ascii")
node_ip_address = (worker_info[worker_id][b"node_ip_address"]
.decode("ascii"))
if node_ip_address not in node_info:
node_info[node_ip_address] = {"workers": []}
node_info[node_ip_address]["workers"].append(worker_id)
@@ -183,7 +207,8 @@ async def handle_get_recent_tasks(websocket, redis_conn, num_tasks):
for key in keys:
content = await redis_conn.execute("lrange", key, "0", "-1")
contents.append(json.loads(content[0].decode()))
timestamps += [timestamp for (timestamp, task, kind, info) in contents[-1] if task == "ray:task"]
timestamps += [timestamp for (timestamp, task, kind, info)
in contents[-1] if task == "ray:task"]
timestamps.sort()
time_cutoff = timestamps[(-2 * num_tasks):][0]
@@ -197,36 +222,49 @@ async def handle_get_recent_tasks(websocket, redis_conn, num_tasks):
num_tasks = 0
task_data = [{"task_data": [],
"num_workers": len(node_info[node_ip_address]["workers"])} for node_ip_address in node_ip_addresses]
"num_workers": len(node_info[node_ip_address]["workers"])}
for node_ip_address in node_ip_addresses]
for i in range(len(keys)):
worker_id, task_id = key_to_hex_identifiers(keys[i])
data = contents[i]
if worker_id not in worker_ids:
# This case should be extremely rare.
raise Exception("A worker ID was not present in the list of worker IDs.")
node_ip_address = worker_info[worker_id][b"node_ip_address"].decode("ascii")
raise Exception("A worker ID was not present in the list of worker "
"IDs.")
node_ip_address = (worker_info[worker_id][b"node_ip_address"]
.decode("ascii"))
worker_index = node_info[node_ip_address]["workers"].index(worker_id)
node_index = node_ip_addresses.index(node_ip_address)
task_times = [timestamp for (timestamp, task, kind, info) in data if task == "ray:task"]
task_times = [timestamp for (timestamp, task, kind, info) in data
if task == "ray:task"]
if task_times[1] <= time_cutoff:
continue
task_get_arguments_times = [timestamp for (timestamp, task, kind, info) in data if task == "ray:task:get_arguments"]
task_execute_times = [timestamp for (timestamp, task, kind, info) in data if task == "ray:task:execute"]
task_store_outputs_times = [timestamp for (timestamp, task, kind, info) in data if task == "ray:task:store_outputs"]
task_info = {"task": task_times,
"get_arguments": task_get_arguments_times,
"execute": task_execute_times,
"store_outputs": task_store_outputs_times,
"worker_index": worker_index,
"node_ip_address": node_ip_address,
"task_formatted_time": duration_to_string(task_times[1] - task_times[0]),
"get_arguments_formatted_time": duration_to_string(task_get_arguments_times[1] - task_get_arguments_times[0])}
task_get_arguments_times = [timestamp for (timestamp, task, kind, info)
in data if task == "ray:task:get_arguments"]
task_execute_times = [timestamp for (timestamp, task, kind, info)
in data if task == "ray:task:execute"]
task_store_outputs_times = [timestamp for (timestamp, task, kind, info)
in data if task == "ray:task:store_outputs"]
task_info = {
"task": task_times,
"get_arguments": task_get_arguments_times,
"execute": task_execute_times,
"store_outputs": task_store_outputs_times,
"worker_index": worker_index,
"node_ip_address": node_ip_address,
"task_formatted_time": duration_to_string(task_times[1] -
task_times[0]),
"get_arguments_formatted_time":
duration_to_string(task_get_arguments_times[1] -
task_get_arguments_times[0])}
if len(task_execute_times) == 2:
task_info["execute_formatted_time"] = duration_to_string(task_execute_times[1] - task_execute_times[0])
task_info["execute_formatted_time"] = duration_to_string(
task_execute_times[1] - task_execute_times[0])
if len(task_store_outputs_times) == 2:
task_info["store_outputs_formatted_time"] = duration_to_string(task_store_outputs_times[1] - task_store_outputs_times[0])
task_info["store_outputs_formatted_time"] = duration_to_string(
task_store_outputs_times[1] - task_store_outputs_times[0])
task_data[node_index]["task_data"].append(task_info)
num_tasks += 1
reply = {"min_time": min_time,
@@ -235,34 +273,41 @@ async def handle_get_recent_tasks(websocket, redis_conn, num_tasks):
"task_data": task_data}
await websocket.send(json.dumps(reply))
async def send_heartbeat_payload(websocket):
  """Push local scheduler heartbeat ages to the frontend every half second.

  Runs forever. Each round builds one entry per known local scheduler and
  sends the whole list as a single JSON message.

  Args:
    websocket: The connection whose send method receives each payload.
  """
  while True:
    payload = []
    for scheduler_id, scheduler in local_schedulers.items():
      # The clock is read once per scheduler, so the formatted and the
      # numeric age of a given entry always agree.
      elapsed = time.time() - scheduler["last_heartbeat"]
      payload.append({
          "local scheduler ID": scheduler_id,
          "time since heartbeat": duration_to_string(elapsed),
          "time since heartbeat numeric": str(elapsed),
          "node ip address": scheduler["node_ip_address"]})
    # Send the payload to the frontend.
    await websocket.send(json.dumps(payload))
    # Pause between rounds so as not to overwhelm the frontend.
    await asyncio.sleep(0.5)
async def send_heartbeats(websocket, redis_conn):
# First update the local scheduler info locally.
client_keys = await redis_conn.execute("keys", "CL:*")
clients = []
for client_key in client_keys:
client_fields = await hgetall_as_dict(redis_conn, client_key)
if client_fields[b"client_type"] == b"local_scheduler":
local_scheduler_id = hex_identifier(client_fields[b"ray_client_id"])
local_schedulers[local_scheduler_id] = {"node_ip_address": client_fields[b"node_ip_address"].decode("ascii"),
"local_scheduler_socket_name": client_fields[b"local_scheduler_socket_name"].decode("ascii"),
"aux_address": client_fields[b"aux_address"].decode("ascii"),
"last_heartbeat": -1 * np.inf}
local_schedulers[local_scheduler_id] = {
"node_ip_address": client_fields[b"node_ip_address"].decode("ascii"),
"local_scheduler_socket_name":
client_fields[b"local_scheduler_socket_name"].decode("ascii"),
"aux_address": client_fields[b"aux_address"].decode("ascii"),
"last_heartbeat": -1 * np.inf}
# Subscribe to local scheduler heartbeats.
await redis_conn.execute_pubsub("subscribe", "local_schedulers")
@@ -272,7 +317,8 @@ async def send_heartbeats(websocket, redis_conn):
while True:
msg = await redis_conn.pubsub_channels["local_schedulers"].get()
heartbeat = LocalSchedulerInfoMessage.GetRootAsLocalSchedulerInfoMessage(msg, 0)
heartbeat = LocalSchedulerInfoMessage.GetRootAsLocalSchedulerInfoMessage(
msg, 0)
local_scheduler_id_bytes = heartbeat.DbClientId()
local_scheduler_id = hex_identifier(local_scheduler_id_bytes)
if local_scheduler_id not in local_schedulers:
@@ -281,6 +327,7 @@ async def send_heartbeats(websocket, redis_conn):
continue
local_schedulers[local_scheduler_id]["last_heartbeat"] = time.time()
async def cache_data_from_redis(redis_ip_address, redis_port):
"""Open up ports to listen for new updates from Redis."""
# TODO(richard): A lot of code needs to be ported in order to open new
@@ -288,6 +335,7 @@ async def cache_data_from_redis(redis_ip_address, redis_port):
asyncio.ensure_future(listen_for_errors(redis_ip_address, redis_port))
async def handle_get_log_files(websocket, redis_conn):
reply = {}
# First get all keys for the log file lists.
@@ -296,9 +344,11 @@ async def handle_get_log_files(websocket, redis_conn):
node_ip_address = log_file_list_key.decode("ascii").split(":")[1]
reply[node_ip_address] = {}
# Get all of the log filenames for this node IP address.
log_filenames = await redis_conn.execute("lrange", log_file_list_key, 0, -1)
log_filenames = await redis_conn.execute("lrange", log_file_list_key, 0,
-1)
for log_filename in log_filenames:
log_filename_key = "LOGFILE:{}:{}".format(node_ip_address, log_filename.decode("ascii"))
log_filename_key = "LOGFILE:{}:{}".format(node_ip_address,
log_filename.decode("ascii"))
logfile = await redis_conn.execute("lrange", log_filename_key, 0, -1)
logfile = [line.decode("ascii") for line in logfile]
reply[node_ip_address][log_filename.decode("ascii")] = logfile
@@ -306,8 +356,10 @@ async def handle_get_log_files(websocket, redis_conn):
# Send the reply back to the front end.
await websocket.send(json.dumps(reply))
async def serve_requests(websocket, path):
redis_conn = await aioredis.create_connection((redis_ip_address, redis_port), loop=loop)
redis_conn = await aioredis.create_connection((redis_ip_address, redis_port),
loop=loop)
while True:
command = json.loads(await websocket.recv())
print("received command {}".format(command))
@@ -352,10 +404,10 @@ async def serve_requests(websocket, path):
"data_size": content[5].decode()})
await websocket.send(json.dumps(result))
elif command["command"] == "get-object-info":
# TODO(pcm): Get the object here (have to connect to ray) and ship content
# and type back to webclient. One challenge here is that the naive
# implementation will block the web ui backend, which is not ok if it is
# serving multiple users.
# TODO(pcm): Get the object here (have to connect to ray) and ship
# content and type back to webclient. One challenge here is that the
# naive implementation will block the web ui backend, which is not ok if
# it is serving multiple users.
await websocket.send(json.dumps({"object_id": "none"}))
elif command["command"] == "get-tasks":
result = []
@@ -372,7 +424,8 @@ async def serve_requests(websocket, path):
worker_id, task_id = key_to_hex_identifiers(key)
content = await redis_conn.execute("lrange", key, "0", "-1")
data = json.loads(content[0].decode())
begin_and_end_time = [timestamp for (timestamp, task, kind, info) in data if task == "ray:task"]
begin_and_end_time = [timestamp for (timestamp, task, kind, info)
in data if task == "ray:task"]
tasks[worker_id].append({"task_id": task_id,
"start_task": min(begin_and_end_time),
"end_task": max(begin_and_end_time)})
@@ -396,8 +449,8 @@ if __name__ == "__main__":
redis_ip_address, redis_port = redis_address[0], int(redis_address[1])
# The port here must match the value used by the frontend to connect over
# websockets. TODO(richard): Automatically increment the port if it is already
# taken.
# websockets. TODO(richard): Automatically increment the port if it is
# already taken.
port = 8888
loop.run_until_complete(cache_data_from_redis(redis_ip_address, redis_port))