mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 21:23:10 +08:00
Run flake8 in Travis and make code PEP8 compliant. (#387)
This commit is contained in:
committed by
Philipp Moritz
parent
083e7a28ad
commit
ba02fc0eb0
+22
-12
@@ -2,19 +2,29 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
# Ray version string
|
||||
__version__ = "0.01"
|
||||
|
||||
import ctypes
|
||||
# Windows only
|
||||
if hasattr(ctypes, "windll"):
|
||||
# Makes sure that all child processes die when we die
|
||||
# Also makes sure that fatal crashes result in process termination rather than an error dialog (the latter is annoying since we have a lot of processes)
|
||||
# This is done by associating all child processes with a "job" object that imposes this behavior.
|
||||
(lambda kernel32: (lambda job: (lambda n: kernel32.SetInformationJobObject(job, 9, "\0" * 17 + chr(0x8 | 0x4 | 0x20) + "\0" * (n - 18), n))(0x90 if ctypes.sizeof(ctypes.c_void_p) > ctypes.sizeof(ctypes.c_int) else 0x70) and kernel32.AssignProcessToJobObject(job, ctypes.c_void_p(kernel32.GetCurrentProcess())))(ctypes.c_void_p(kernel32.CreateJobObjectW(None, None))) if kernel32 is not None else None)(ctypes.windll.kernel32)
|
||||
|
||||
from ray.worker import register_class, error_info, init, connect, disconnect, get, put, wait, remote, log_event, log_span, flush_log
|
||||
from ray.worker import (register_class, error_info, init, connect, disconnect,
|
||||
get, put, wait, remote, log_event, log_span,
|
||||
flush_log)
|
||||
from ray.actor import actor
|
||||
from ray.actor import get_gpu_ids
|
||||
from ray.worker import EnvironmentVariable, env
|
||||
from ray.worker import SCRIPT_MODE, WORKER_MODE, PYTHON_MODE, SILENT_MODE
|
||||
|
||||
# Ray version string
|
||||
__version__ = "0.01"
|
||||
|
||||
__all__ = ["register_class", "error_info", "init", "connect", "disconnect",
|
||||
"get", "put", "wait", "remote", "log_event", "log_span",
|
||||
"flush_log", "actor", "get_gpu_ids", "EnvironmentVariable", "env",
|
||||
"SCRIPT_MODE", "WORKER_MODE", "PYTHON_MODE", "SILENT_MODE",
|
||||
"__version__"]
|
||||
|
||||
import ctypes
|
||||
# Windows only
|
||||
if hasattr(ctypes, "windll"):
|
||||
# Makes sure that all child processes die when we die. Also makes sure that
|
||||
# fatal crashes result in process termination rather than an error dialog
|
||||
# (the latter is annoying since we have a lot of processes). This is done by
|
||||
# associating all child processes with a "job" object that imposes this
|
||||
# behavior.
|
||||
(lambda kernel32: (lambda job: (lambda n: kernel32.SetInformationJobObject(job, 9, "\0" * 17 + chr(0x8 | 0x4 | 0x20) + "\0" * (n - 18), n))(0x90 if ctypes.sizeof(ctypes.c_void_p) > ctypes.sizeof(ctypes.c_int) else 0x70) and kernel32.AssignProcessToJobObject(job, ctypes.c_void_p(kernel32.GetCurrentProcess())))(ctypes.c_void_p(kernel32.CreateJobObjectW(None, None))) if kernel32 is not None else None)(ctypes.windll.kernel32) # noqa: E501
|
||||
|
||||
+60
-21
@@ -18,6 +18,7 @@ import ray.experimental.state as state
|
||||
# the worker is currently allowed to use.
|
||||
gpu_ids = []
|
||||
|
||||
|
||||
def get_gpu_ids():
|
||||
"""Get the IDs of the GPU that are available to the worker.
|
||||
|
||||
@@ -26,12 +27,15 @@ def get_gpu_ids():
|
||||
"""
|
||||
return gpu_ids
|
||||
|
||||
|
||||
def random_string():
|
||||
return np.random.bytes(20)
|
||||
|
||||
|
||||
def random_actor_id():
|
||||
return ray.local_scheduler.ObjectID(random_string())
|
||||
|
||||
|
||||
def get_actor_method_function_id(attr):
|
||||
"""Get the function ID corresponding to an actor method.
|
||||
|
||||
@@ -47,10 +51,14 @@ def get_actor_method_function_id(attr):
|
||||
assert len(function_id) == 20
|
||||
return ray.local_scheduler.ObjectID(function_id)
|
||||
|
||||
|
||||
def fetch_and_register_actor(key, worker):
|
||||
"""Import an actor."""
|
||||
driver_id, actor_id_str, actor_name, module, pickled_class, assigned_gpu_ids, actor_method_names = \
|
||||
worker.redis_client.hmget(key, ["driver_id", "actor_id", "name", "module", "class", "gpu_ids", "actor_method_names"])
|
||||
(driver_id, actor_id_str, actor_name,
|
||||
module, pickled_class, assigned_gpu_ids,
|
||||
actor_method_names) = worker.redis_client.hmget(
|
||||
key, ["driver_id", "actor_id", "name", "module", "class", "gpu_ids",
|
||||
"actor_method_names"])
|
||||
actor_id = ray.local_scheduler.ObjectID(actor_id_str)
|
||||
actor_name = actor_name.decode("ascii")
|
||||
module = module.decode("ascii")
|
||||
@@ -64,12 +72,14 @@ def fetch_and_register_actor(key, worker):
|
||||
class TemporaryActor(object):
|
||||
pass
|
||||
worker.actors[actor_id_str] = TemporaryActor()
|
||||
|
||||
def temporary_actor_method(*xs):
|
||||
raise Exception("The actor with name {} failed to be imported, and so "
|
||||
"cannot execute this method".format(actor_name))
|
||||
for actor_method_name in actor_method_names:
|
||||
function_id = get_actor_method_function_id(actor_method_name).id()
|
||||
worker.functions[driver_id][function_id] = (actor_method_name, temporary_actor_method)
|
||||
worker.functions[driver_id][function_id] = (actor_method_name,
|
||||
temporary_actor_method)
|
||||
|
||||
try:
|
||||
unpickled_class = pickling.loads(pickled_class)
|
||||
@@ -84,11 +94,15 @@ def fetch_and_register_actor(key, worker):
|
||||
# TODO(pcm): Why is the below line necessary?
|
||||
unpickled_class.__module__ = module
|
||||
worker.actors[actor_id_str] = unpickled_class.__new__(unpickled_class)
|
||||
for (k, v) in inspect.getmembers(unpickled_class, predicate=(lambda x: inspect.isfunction(x) or inspect.ismethod(x))):
|
||||
for (k, v) in inspect.getmembers(
|
||||
unpickled_class, predicate=(lambda x: (inspect.isfunction(x) or
|
||||
inspect.ismethod(x)))):
|
||||
function_id = get_actor_method_function_id(k).id()
|
||||
worker.functions[driver_id][function_id] = (k, v)
|
||||
# We do not set worker.function_properties[driver_id][function_id] because
|
||||
# we currently do need the actor worker to submit new tasks for the actor.
|
||||
# We do not set worker.function_properties[driver_id][function_id]
|
||||
# because we currently do need the actor worker to submit new tasks for
|
||||
# the actor.
|
||||
|
||||
|
||||
def select_local_scheduler(local_schedulers, num_gpus, worker):
|
||||
"""Select a local scheduler to assign this actor to.
|
||||
@@ -119,15 +133,19 @@ def select_local_scheduler(local_schedulers, num_gpus, worker):
|
||||
# Loop through all of the local schedulers.
|
||||
for local_scheduler in local_schedulers:
|
||||
# See if there are enough available GPUs on this local scheduler.
|
||||
local_scheduler_total_gpus = int(float(local_scheduler[b"num_gpus"].decode("ascii")))
|
||||
gpus_in_use = worker.redis_client.hget(local_scheduler[b"ray_client_id"], b"gpus_in_use")
|
||||
local_scheduler_total_gpus = int(float(
|
||||
local_scheduler[b"num_gpus"].decode("ascii")))
|
||||
gpus_in_use = worker.redis_client.hget(local_scheduler[b"ray_client_id"],
|
||||
b"gpus_in_use")
|
||||
gpus_in_use = 0 if gpus_in_use is None else int(gpus_in_use)
|
||||
if gpus_in_use + num_gpus <= local_scheduler_total_gpus:
|
||||
# Attempt to reserve some GPUs for this actor.
|
||||
new_gpus_in_use = worker.redis_client.hincrby(local_scheduler[b"ray_client_id"], b"gpus_in_use", num_gpus)
|
||||
new_gpus_in_use = worker.redis_client.hincrby(
|
||||
local_scheduler[b"ray_client_id"], b"gpus_in_use", num_gpus)
|
||||
if new_gpus_in_use > local_scheduler_total_gpus:
|
||||
# If we failed to reserve the GPUs, undo the increment.
|
||||
worker.redis_client.hincrby(local_scheduler[b"ray_client_id"], b"gpus_in_use", num_gpus)
|
||||
worker.redis_client.hincrby(local_scheduler[b"ray_client_id"],
|
||||
b"gpus_in_use", num_gpus)
|
||||
else:
|
||||
# We succeeded at reserving the GPUs, so we are done.
|
||||
local_scheduler_id = local_scheduler[b"ray_client_id"]
|
||||
@@ -135,10 +153,13 @@ def select_local_scheduler(local_schedulers, num_gpus, worker):
|
||||
break
|
||||
if local_scheduler_id is None:
|
||||
raise Exception("Could not find a node with enough GPUs to create this "
|
||||
"actor. The local scheduler information is {}.".format(local_schedulers))
|
||||
"actor. The local scheduler information is {}."
|
||||
.format(local_schedulers))
|
||||
return local_scheduler_id, gpu_ids
|
||||
|
||||
def export_actor(actor_id, Class, actor_method_names, num_cpus, num_gpus, worker):
|
||||
|
||||
def export_actor(actor_id, Class, actor_method_names, num_cpus, num_gpus,
|
||||
worker):
|
||||
"""Export an actor to redis.
|
||||
|
||||
Args:
|
||||
@@ -158,13 +179,16 @@ def export_actor(actor_id, Class, actor_method_names, num_cpus, num_gpus, worker
|
||||
driver_id = worker.task_driver_id.id()
|
||||
for actor_method_name in actor_method_names:
|
||||
function_id = get_actor_method_function_id(actor_method_name).id()
|
||||
worker.function_properties[driver_id][function_id] = (1, num_cpus, num_gpus)
|
||||
worker.function_properties[driver_id][function_id] = (1, num_cpus,
|
||||
num_gpus)
|
||||
|
||||
# Select a local scheduler for the actor.
|
||||
local_schedulers = state.get_local_schedulers(worker)
|
||||
local_scheduler_id, gpu_ids = select_local_scheduler(local_schedulers, num_gpus, worker)
|
||||
local_scheduler_id, gpu_ids = select_local_scheduler(local_schedulers,
|
||||
num_gpus, worker)
|
||||
|
||||
worker.redis_client.publish("actor_notifications", actor_id.id() + local_scheduler_id)
|
||||
worker.redis_client.publish("actor_notifications",
|
||||
actor_id.id() + local_scheduler_id)
|
||||
|
||||
d = {"driver_id": driver_id,
|
||||
"actor_id": actor_id.id(),
|
||||
@@ -176,6 +200,7 @@ def export_actor(actor_id, Class, actor_method_names, num_cpus, num_gpus, worker
|
||||
worker.redis_client.hmset(key, d)
|
||||
worker.redis_client.rpush("Exports", key)
|
||||
|
||||
|
||||
def actor(*args, **kwargs):
|
||||
def make_actor_decorator(num_cpus=1, num_gpus=0):
|
||||
def make_actor(Class):
|
||||
@@ -189,7 +214,8 @@ def actor(*args, **kwargs):
|
||||
raise Exception("Actors currently do not support **kwargs.")
|
||||
function_id = get_actor_method_function_id(attr)
|
||||
# TODO(pcm): Extend args with keyword args.
|
||||
object_ids = ray.worker.global_worker.submit_task(function_id, "", args,
|
||||
object_ids = ray.worker.global_worker.submit_task(function_id, "",
|
||||
args,
|
||||
actor_id=actor_id)
|
||||
if len(object_ids) == 1:
|
||||
return object_ids[0]
|
||||
@@ -199,24 +225,34 @@ def actor(*args, **kwargs):
|
||||
class NewClass(object):
|
||||
def __init__(self, *args, **kwargs):
|
||||
self._ray_actor_id = random_actor_id()
|
||||
self._ray_actor_methods = {k: v for (k, v) in inspect.getmembers(Class, predicate=(lambda x: inspect.isfunction(x) or inspect.ismethod(x)))}
|
||||
export_actor(self._ray_actor_id, Class, self._ray_actor_methods.keys(), num_cpus, num_gpus, ray.worker.global_worker)
|
||||
self._ray_actor_methods = {
|
||||
k: v for (k, v) in inspect.getmembers(
|
||||
Class, predicate=(lambda x: (inspect.isfunction(x) or
|
||||
inspect.ismethod(x))))}
|
||||
export_actor(self._ray_actor_id, Class,
|
||||
self._ray_actor_methods.keys(), num_cpus, num_gpus,
|
||||
ray.worker.global_worker)
|
||||
# Call __init__ as a remote function.
|
||||
if "__init__" in self._ray_actor_methods.keys():
|
||||
actor_method_call(self._ray_actor_id, "__init__", *args, **kwargs)
|
||||
else:
|
||||
print("WARNING: this object has no __init__ method.")
|
||||
|
||||
# Make tab completion work.
|
||||
def __dir__(self):
|
||||
return self._ray_actor_methods
|
||||
|
||||
def __getattribute__(self, attr):
|
||||
# The following is needed so we can still access self.actor_methods.
|
||||
if attr in ["_ray_actor_id", "_ray_actor_methods"]:
|
||||
return super(NewClass, self).__getattribute__(attr)
|
||||
if attr in self._ray_actor_methods.keys():
|
||||
return lambda *args, **kwargs: actor_method_call(self._ray_actor_id, attr, *args, **kwargs)
|
||||
return lambda *args, **kwargs: actor_method_call(
|
||||
self._ray_actor_id, attr, *args, **kwargs)
|
||||
# There is no method with this name, so raise an exception.
|
||||
raise AttributeError("'{}' Actor object has no attribute '{}'".format(Class, attr))
|
||||
raise AttributeError("'{}' Actor object has no attribute '{}'"
|
||||
.format(Class, attr))
|
||||
|
||||
def __repr__(self):
|
||||
return "Actor(" + self._ray_actor_id.hex() + ")"
|
||||
|
||||
@@ -230,7 +266,9 @@ def actor(*args, **kwargs):
|
||||
return make_actor_decorator(num_cpus=1, num_gpus=0)(Class)
|
||||
|
||||
# In this case, the actor decorator is something like @ray.actor(num_gpus=1).
|
||||
if len(args) == 0 and len(kwargs) > 0 and all([key in ["num_cpus", "num_gpus"] for key in kwargs.keys()]):
|
||||
if len(args) == 0 and len(kwargs) > 0 and all([key
|
||||
in ["num_cpus", "num_gpus"]
|
||||
for key in kwargs.keys()]):
|
||||
num_cpus = kwargs["num_cpus"] if "num_cpus" in kwargs.keys() else 1
|
||||
num_gpus = kwargs["num_gpus"] if "num_gpus" in kwargs.keys() else 0
|
||||
return make_actor_decorator(num_cpus=num_cpus, num_gpus=num_gpus)
|
||||
@@ -240,4 +278,5 @@ def actor(*args, **kwargs):
|
||||
"some of the arguments 'num_cpus' or 'num_gpus' as in "
|
||||
"'ray.actor(num_gpus=1)'.")
|
||||
|
||||
|
||||
ray.worker.global_worker.fetch_and_register["Actor"] = fetch_and_register_actor
|
||||
|
||||
@@ -2,17 +2,16 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import random
|
||||
import subprocess
|
||||
import redis
|
||||
import sys
|
||||
import time
|
||||
import unittest
|
||||
import redis
|
||||
|
||||
import ray.services
|
||||
|
||||
# Import flatbuffer bindings.
|
||||
from ray.core.generated.SubscribeToNotificationsReply import SubscribeToNotificationsReply
|
||||
from ray.core.generated.SubscribeToNotificationsReply \
|
||||
import SubscribeToNotificationsReply
|
||||
from ray.core.generated.TaskReply import TaskReply
|
||||
from ray.core.generated.ResultTableReply import ResultTableReply
|
||||
|
||||
@@ -22,6 +21,7 @@ OBJECT_SUBSCRIBE_PREFIX = "OS:"
|
||||
TASK_PREFIX = "TT:"
|
||||
OBJECT_CHANNEL_PREFIX = "OC:"
|
||||
|
||||
|
||||
def integerToAsciiHex(num, numbytes):
|
||||
retstr = b""
|
||||
# Support 32 and 64 bit architecture.
|
||||
@@ -36,6 +36,7 @@ def integerToAsciiHex(num, numbytes):
|
||||
|
||||
return retstr
|
||||
|
||||
|
||||
def get_next_message(pubsub_client, timeout_seconds=10):
|
||||
"""Block until the next message is available on the pubsub channel."""
|
||||
start_time = time.time()
|
||||
@@ -47,6 +48,7 @@ def get_next_message(pubsub_client, timeout_seconds=10):
|
||||
if time.time() - start_time > timeout_seconds:
|
||||
raise Exception("Timed out while waiting for next message.")
|
||||
|
||||
|
||||
class TestGlobalStateStore(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
@@ -57,102 +59,144 @@ class TestGlobalStateStore(unittest.TestCase):
|
||||
ray.services.cleanup()
|
||||
|
||||
def testInvalidObjectTableAdd(self):
|
||||
# Check that Redis returns an error when RAY.OBJECT_TABLE_ADD is called with
|
||||
# the wrong arguments.
|
||||
# Check that Redis returns an error when RAY.OBJECT_TABLE_ADD is called
|
||||
# with the wrong arguments.
|
||||
with self.assertRaises(redis.ResponseError):
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD")
|
||||
with self.assertRaises(redis.ResponseError):
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "hello")
|
||||
with self.assertRaises(redis.ResponseError):
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id2", "one", "hash2", "manager_id1")
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id2", "one",
|
||||
"hash2", "manager_id1")
|
||||
with self.assertRaises(redis.ResponseError):
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id2", 1, "hash2", "manager_id1", "extra argument")
|
||||
# Check that Redis returns an error when RAY.OBJECT_TABLE_ADD adds an object
|
||||
# ID that is already present with a different hash.
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, "hash1", "manager_id1")
|
||||
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", "object_id1")
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id2", 1,
|
||||
"hash2", "manager_id1", "extra argument")
|
||||
# Check that Redis returns an error when RAY.OBJECT_TABLE_ADD adds an
|
||||
# object ID that is already present with a different hash.
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
|
||||
"hash1", "manager_id1")
|
||||
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
|
||||
"object_id1")
|
||||
self.assertEqual(set(response), {b"manager_id1"})
|
||||
with self.assertRaises(redis.ResponseError):
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, "hash2", "manager_id2")
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
|
||||
"hash2", "manager_id2")
|
||||
# Check that the second manager was added, even though the hash was
|
||||
# mismatched.
|
||||
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", "object_id1")
|
||||
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
|
||||
"object_id1")
|
||||
self.assertEqual(set(response), {b"manager_id1", b"manager_id2"})
|
||||
# Check that it is fine if we add the same object ID multiple times with the
|
||||
# most recent hash.
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, "hash2", "manager_id1")
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, "hash2", "manager_id1")
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, "hash2", "manager_id2")
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 2, "hash2", "manager_id2")
|
||||
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", "object_id1")
|
||||
# Check that it is fine if we add the same object ID multiple times with
|
||||
# the most recent hash.
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
|
||||
"hash2", "manager_id1")
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
|
||||
"hash2", "manager_id1")
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
|
||||
"hash2", "manager_id2")
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 2,
|
||||
"hash2", "manager_id2")
|
||||
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
|
||||
"object_id1")
|
||||
self.assertEqual(set(response), {b"manager_id1", b"manager_id2"})
|
||||
|
||||
def testObjectTableAddAndLookup(self):
|
||||
# Try calling RAY.OBJECT_TABLE_LOOKUP with an object ID that has not been
|
||||
# added yet.
|
||||
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", "object_id1")
|
||||
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
|
||||
"object_id1")
|
||||
self.assertEqual(response, None)
|
||||
# Add some managers and try again.
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, "hash1", "manager_id1")
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, "hash1", "manager_id2")
|
||||
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", "object_id1")
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
|
||||
"hash1", "manager_id1")
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
|
||||
"hash1", "manager_id2")
|
||||
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
|
||||
"object_id1")
|
||||
self.assertEqual(set(response), {b"manager_id1", b"manager_id2"})
|
||||
# Add a manager that already exists again and try again.
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, "hash1", "manager_id2")
|
||||
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", "object_id1")
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
|
||||
"hash1", "manager_id2")
|
||||
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
|
||||
"object_id1")
|
||||
self.assertEqual(set(response), {b"manager_id1", b"manager_id2"})
|
||||
# Check that we properly handle NULL characters. In the past, NULL
|
||||
# characters were handled improperly causing a "hash mismatch" error if two
|
||||
# object IDs that agreed up to the NULL character were inserted with
|
||||
# different hashes.
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "\x00object_id3", 1, "hash1", "manager_id1")
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "\x00object_id4", 1, "hash2", "manager_id1")
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "\x00object_id3", 1,
|
||||
"hash1", "manager_id1")
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "\x00object_id4", 1,
|
||||
"hash2", "manager_id1")
|
||||
# Check that NULL characters in the hash are handled properly.
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", 1, "\x00hash1", "manager_id1")
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", 1,
|
||||
"\x00hash1", "manager_id1")
|
||||
with self.assertRaises(redis.ResponseError):
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", 1, "\x00hash2", "manager_id1")
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", 1,
|
||||
"\x00hash2", "manager_id1")
|
||||
|
||||
def testObjectTableAddAndRemove(self):
|
||||
# Try removing a manager from an object ID that has not been added yet.
|
||||
with self.assertRaises(redis.ResponseError):
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1", "manager_id1")
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1",
|
||||
"manager_id1")
|
||||
# Try calling RAY.OBJECT_TABLE_LOOKUP with an object ID that has not been
|
||||
# added yet.
|
||||
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", "object_id1")
|
||||
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
|
||||
"object_id1")
|
||||
self.assertEqual(response, None)
|
||||
# Add some managers and try again.
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, "hash1", "manager_id1")
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, "hash1", "manager_id2")
|
||||
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", "object_id1")
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
|
||||
"hash1", "manager_id1")
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
|
||||
"hash1", "manager_id2")
|
||||
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
|
||||
"object_id1")
|
||||
self.assertEqual(set(response), {b"manager_id1", b"manager_id2"})
|
||||
# Remove a manager that doesn't exist, and make sure we still have the same set.
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1", "manager_id3")
|
||||
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", "object_id1")
|
||||
# Remove a manager that doesn't exist, and make sure we still have the same
|
||||
# set.
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1",
|
||||
"manager_id3")
|
||||
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
|
||||
"object_id1")
|
||||
self.assertEqual(set(response), {b"manager_id1", b"manager_id2"})
|
||||
# Remove a manager that does exist. Make sure it gets removed the first
|
||||
# time and does nothing the second time.
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1", "manager_id1")
|
||||
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", "object_id1")
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1",
|
||||
"manager_id1")
|
||||
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
|
||||
"object_id1")
|
||||
self.assertEqual(set(response), {b"manager_id2"})
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1", "manager_id1")
|
||||
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", "object_id1")
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1",
|
||||
"manager_id1")
|
||||
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
|
||||
"object_id1")
|
||||
self.assertEqual(set(response), {b"manager_id2"})
|
||||
# Remove the last manager, and make sure we have an empty set.
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1", "manager_id2")
|
||||
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", "object_id1")
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1",
|
||||
"manager_id2")
|
||||
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
|
||||
"object_id1")
|
||||
self.assertEqual(set(response), set())
|
||||
# Remove a manager from an empty set, and make sure we now have an empty set.
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1", "manager_id3")
|
||||
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", "object_id1")
|
||||
# Remove a manager from an empty set, and make sure we now have an empty
|
||||
# set.
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1",
|
||||
"manager_id3")
|
||||
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
|
||||
"object_id1")
|
||||
self.assertEqual(set(response), set())
|
||||
|
||||
def testObjectTableSubscribeToNotifications(self):
|
||||
# Define a helper method for checking the contents of object notifications.
|
||||
def check_object_notification(notification_message, object_id, object_size, manager_ids):
|
||||
notification_object = SubscribeToNotificationsReply.GetRootAsSubscribeToNotificationsReply(notification_message, 0)
|
||||
def check_object_notification(notification_message, object_id, object_size,
|
||||
manager_ids):
|
||||
notification_object = (SubscribeToNotificationsReply
|
||||
.GetRootAsSubscribeToNotificationsReply(
|
||||
notification_message, 0))
|
||||
self.assertEqual(notification_object.ObjectId(), object_id)
|
||||
self.assertEqual(notification_object.ObjectSize(), object_size)
|
||||
self.assertEqual(notification_object.ManagerIdsLength(), len(manager_ids))
|
||||
self.assertEqual(notification_object.ManagerIdsLength(),
|
||||
len(manager_ids))
|
||||
for i in range(len(manager_ids)):
|
||||
self.assertEqual(notification_object.ManagerIds(i), manager_ids[i])
|
||||
|
||||
@@ -160,37 +204,45 @@ class TestGlobalStateStore(unittest.TestCase):
|
||||
p = self.redis.pubsub()
|
||||
# Subscribe to an object ID.
|
||||
p.psubscribe("{}manager_id1".format(OBJECT_CHANNEL_PREFIX))
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", data_size, "hash1", "manager_id2")
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", data_size,
|
||||
"hash1", "manager_id2")
|
||||
# Receive the acknowledgement message.
|
||||
self.assertEqual(get_next_message(p)["data"], 1)
|
||||
# Request a notification and receive the data.
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS", "manager_id1", "object_id1")
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS",
|
||||
"manager_id1", "object_id1")
|
||||
# Verify that the notification is correct.
|
||||
check_object_notification(get_next_message(p)["data"],
|
||||
b"object_id1",
|
||||
data_size,
|
||||
[b"manager_id2"])
|
||||
|
||||
# Request a notification for an object that isn't there. Then add the object
|
||||
# and receive the data. Only the first call to RAY.OBJECT_TABLE_ADD should
|
||||
# trigger notifications.
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS", "manager_id1", "object_id2", "object_id3")
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", data_size, "hash1", "manager_id1")
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", data_size, "hash1", "manager_id2")
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", data_size, "hash1", "manager_id3")
|
||||
# Request a notification for an object that isn't there. Then add the
|
||||
# object and receive the data. Only the first call to RAY.OBJECT_TABLE_ADD
|
||||
# should trigger notifications.
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS",
|
||||
"manager_id1", "object_id2", "object_id3")
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", data_size,
|
||||
"hash1", "manager_id1")
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", data_size,
|
||||
"hash1", "manager_id2")
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", data_size,
|
||||
"hash1", "manager_id3")
|
||||
# Verify that the notification is correct.
|
||||
check_object_notification(get_next_message(p)["data"],
|
||||
b"object_id3",
|
||||
data_size,
|
||||
[b"manager_id1"])
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id2", data_size, "hash1", "manager_id3")
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id2", data_size,
|
||||
"hash1", "manager_id3")
|
||||
# Verify that the notification is correct.
|
||||
check_object_notification(get_next_message(p)["data"],
|
||||
b"object_id2",
|
||||
data_size,
|
||||
[b"manager_id3"])
|
||||
# Request notifications for object_id3 again.
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS", "manager_id1", "object_id3")
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS",
|
||||
"manager_id1", "object_id3")
|
||||
# Verify that the notification is correct.
|
||||
check_object_notification(get_next_message(p)["data"],
|
||||
b"object_id3",
|
||||
@@ -199,29 +251,38 @@ class TestGlobalStateStore(unittest.TestCase):
|
||||
|
||||
def testResultTableAddAndLookup(self):
|
||||
def check_result_table_entry(message, task_id, is_put):
|
||||
result_table_reply = ResultTableReply.GetRootAsResultTableReply(message, 0)
|
||||
result_table_reply = ResultTableReply.GetRootAsResultTableReply(message,
|
||||
0)
|
||||
self.assertEqual(result_table_reply.TaskId(), task_id)
|
||||
self.assertEqual(result_table_reply.IsPut(), is_put)
|
||||
|
||||
# Try looking up something in the result table before anything is added.
|
||||
response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP", "object_id1")
|
||||
response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP",
|
||||
"object_id1")
|
||||
self.assertIsNone(response)
|
||||
# Adding the object to the object table should have no effect.
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, "hash1", "manager_id1")
|
||||
response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP", "object_id1")
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
|
||||
"hash1", "manager_id1")
|
||||
response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP",
|
||||
"object_id1")
|
||||
self.assertIsNone(response)
|
||||
# Add the result to the result table. The lookup now returns the task ID.
|
||||
task_id = b"task_id1"
|
||||
self.redis.execute_command("RAY.RESULT_TABLE_ADD", "object_id1", task_id, 0)
|
||||
response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP", "object_id1")
|
||||
self.redis.execute_command("RAY.RESULT_TABLE_ADD", "object_id1", task_id,
|
||||
0)
|
||||
response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP",
|
||||
"object_id1")
|
||||
check_result_table_entry(response, task_id, False)
|
||||
# Doing it again should still work.
|
||||
response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP", "object_id1")
|
||||
response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP",
|
||||
"object_id1")
|
||||
check_result_table_entry(response, task_id, False)
|
||||
# Try another result table lookup. This should succeed.
|
||||
task_id = b"task_id2"
|
||||
self.redis.execute_command("RAY.RESULT_TABLE_ADD", "object_id2", task_id, 1)
|
||||
response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP", "object_id2")
|
||||
self.redis.execute_command("RAY.RESULT_TABLE_ADD", "object_id2", task_id,
|
||||
1)
|
||||
response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP",
|
||||
"object_id2")
|
||||
check_result_table_entry(response, task_id, True)
|
||||
|
||||
def testInvalidTaskTableAdd(self):
|
||||
@@ -251,7 +312,8 @@ class TestGlobalStateStore(unittest.TestCase):
|
||||
task_status, local_scheduler_id, task_spec = task_args
|
||||
task_reply_object = TaskReply.GetRootAsTaskReply(message, 0)
|
||||
self.assertEqual(task_reply_object.State(), task_status)
|
||||
self.assertEqual(task_reply_object.LocalSchedulerId(), local_scheduler_id)
|
||||
self.assertEqual(task_reply_object.LocalSchedulerId(),
|
||||
local_scheduler_id)
|
||||
self.assertEqual(task_reply_object.TaskSpec(), task_spec)
|
||||
self.assertEqual(task_reply_object.Updated(), updated)
|
||||
|
||||
@@ -263,7 +325,8 @@ class TestGlobalStateStore(unittest.TestCase):
|
||||
check_task_reply(response, task_args)
|
||||
|
||||
task_args[0] = TASK_STATUS_SCHEDULED
|
||||
self.redis.execute_command("RAY.TASK_TABLE_UPDATE", "task_id", *task_args[:2])
|
||||
self.redis.execute_command("RAY.TASK_TABLE_UPDATE", "task_id",
|
||||
*task_args[:2])
|
||||
response = self.redis.execute_command("RAY.TASK_TABLE_GET", "task_id")
|
||||
check_task_reply(response, task_args)
|
||||
|
||||
@@ -331,9 +394,12 @@ class TestGlobalStateStore(unittest.TestCase):
|
||||
# Subscribe to the task table.
|
||||
p = self.redis.pubsub()
|
||||
p.psubscribe("{prefix}*:*".format(prefix=TASK_PREFIX))
|
||||
p.psubscribe("{prefix}*:{state}".format(prefix=TASK_PREFIX, state=scheduling_state))
|
||||
p.psubscribe("{prefix}{local_scheduler_id}:*".format(prefix=TASK_PREFIX, local_scheduler_id=local_scheduler_id))
|
||||
task_args = [b"task_id", scheduling_state, local_scheduler_id.encode("ascii"), b"task_spec"]
|
||||
p.psubscribe("{prefix}*:{state}".format(
|
||||
prefix=TASK_PREFIX, state=scheduling_state))
|
||||
p.psubscribe("{prefix}{local_scheduler_id}:*".format(
|
||||
prefix=TASK_PREFIX, local_scheduler_id=local_scheduler_id))
|
||||
task_args = [b"task_id", scheduling_state,
|
||||
local_scheduler_id.encode("ascii"), b"task_spec"]
|
||||
self.redis.execute_command("RAY.TASK_TABLE_ADD", *task_args)
|
||||
# Receive the acknowledgement message.
|
||||
self.assertEqual(get_next_message(p)["data"], 1)
|
||||
@@ -346,8 +412,10 @@ class TestGlobalStateStore(unittest.TestCase):
|
||||
notification_object = TaskReply.GetRootAsTaskReply(message, 0)
|
||||
self.assertEqual(notification_object.TaskId(), b"task_id")
|
||||
self.assertEqual(notification_object.State(), scheduling_state)
|
||||
self.assertEqual(notification_object.LocalSchedulerId(), local_scheduler_id.encode("ascii"))
|
||||
self.assertEqual(notification_object.LocalSchedulerId(),
|
||||
local_scheduler_id.encode("ascii"))
|
||||
self.assertEqual(notification_object.TaskSpec(), b"task_spec")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=2)
|
||||
|
||||
@@ -11,25 +11,29 @@ import ray.local_scheduler as local_scheduler
|
||||
|
||||
ID_SIZE = 20
|
||||
|
||||
|
||||
def random_object_id():
|
||||
return local_scheduler.ObjectID(np.random.bytes(ID_SIZE))
|
||||
|
||||
|
||||
def random_function_id():
|
||||
return local_scheduler.ObjectID(np.random.bytes(ID_SIZE))
|
||||
|
||||
|
||||
def random_driver_id():
|
||||
return local_scheduler.ObjectID(np.random.bytes(ID_SIZE))
|
||||
|
||||
|
||||
def random_task_id():
|
||||
return local_scheduler.ObjectID(np.random.bytes(ID_SIZE))
|
||||
|
||||
|
||||
BASE_SIMPLE_OBJECTS = [
|
||||
0, 1, 100000, 0.0, 0.5, 0.9, 100000.1, (), [], {},
|
||||
"", 990 * "h", u"", 990 * u"h"
|
||||
]
|
||||
0, 1, 100000, 0.0, 0.5, 0.9, 100000.1, (), [], {}, "", 990 * "h", u"",
|
||||
990 * u"h"]
|
||||
|
||||
if sys.version_info < (3, 0):
|
||||
BASE_SIMPLE_OBJECTS += [long(0), long(1), long(100000), long(1 << 100)]
|
||||
BASE_SIMPLE_OBJECTS += [long(0), long(1), long(100000), long(1 << 100)] # noqa: E501,F821
|
||||
|
||||
LIST_SIMPLE_OBJECTS = [[obj] for obj in BASE_SIMPLE_OBJECTS]
|
||||
TUPLE_SIMPLE_OBJECTS = [(obj,) for obj in BASE_SIMPLE_OBJECTS]
|
||||
@@ -45,11 +49,14 @@ SIMPLE_OBJECTS = (BASE_SIMPLE_OBJECTS +
|
||||
l = []
|
||||
l.append(l)
|
||||
|
||||
|
||||
class Foo(object):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
BASE_COMPLEX_OBJECTS = [999 * "h", 999 * u"h", l, Foo(), 10 * [10 * [10 * [1]]]]
|
||||
|
||||
BASE_COMPLEX_OBJECTS = [999 * "h", 999 * u"h", l, Foo(),
|
||||
10 * [10 * [10 * [1]]]]
|
||||
|
||||
LIST_COMPLEX_OBJECTS = [[obj] for obj in BASE_COMPLEX_OBJECTS]
|
||||
TUPLE_COMPLEX_OBJECTS = [(obj,) for obj in BASE_COMPLEX_OBJECTS]
|
||||
@@ -60,6 +67,7 @@ COMPLEX_OBJECTS = (BASE_COMPLEX_OBJECTS +
|
||||
TUPLE_COMPLEX_OBJECTS +
|
||||
DICT_COMPLEX_OBJECTS)
|
||||
|
||||
|
||||
class TestSerialization(unittest.TestCase):
|
||||
|
||||
def test_serialize_by_value(self):
|
||||
@@ -69,27 +77,31 @@ class TestSerialization(unittest.TestCase):
|
||||
for val in COMPLEX_OBJECTS:
|
||||
self.assertFalse(local_scheduler.check_simple_value(val))
|
||||
|
||||
|
||||
class TestObjectID(unittest.TestCase):
|
||||
|
||||
def test_create_object_id(self):
|
||||
object_id = random_object_id()
|
||||
random_object_id()
|
||||
|
||||
def test_cannot_pickle_object_ids(self):
|
||||
object_ids = [random_object_id() for _ in range(256)]
|
||||
|
||||
def f():
|
||||
return object_ids
|
||||
|
||||
def g(val=object_ids):
|
||||
return 1
|
||||
|
||||
def h():
|
||||
x = object_ids[0]
|
||||
object_ids[0]
|
||||
return 1
|
||||
# Make sure that object IDs cannot be pickled (including functions that
|
||||
# close over object IDs).
|
||||
self.assertRaises(Exception, lambda : pickling.dumps(object_ids[0]))
|
||||
self.assertRaises(Exception, lambda : pickling.dumps(object_ids))
|
||||
self.assertRaises(Exception, lambda : pickling.dumps(f))
|
||||
self.assertRaises(Exception, lambda : pickling.dumps(g))
|
||||
self.assertRaises(Exception, lambda : pickling.dumps(h))
|
||||
self.assertRaises(Exception, lambda: pickle.dumps(object_ids[0]))
|
||||
self.assertRaises(Exception, lambda: pickle.dumps(object_ids))
|
||||
self.assertRaises(Exception, lambda: pickle.dumps(f))
|
||||
self.assertRaises(Exception, lambda: pickle.dumps(g))
|
||||
self.assertRaises(Exception, lambda: pickle.dumps(h))
|
||||
|
||||
def test_equality_comparisons(self):
|
||||
x1 = local_scheduler.ObjectID(ID_SIZE * b"a")
|
||||
@@ -101,8 +113,10 @@ class TestObjectID(unittest.TestCase):
|
||||
self.assertNotEqual(x1, y1)
|
||||
|
||||
random_strings = [np.random.bytes(ID_SIZE) for _ in range(256)]
|
||||
object_ids1 = [local_scheduler.ObjectID(random_strings[i]) for i in range(256)]
|
||||
object_ids2 = [local_scheduler.ObjectID(random_strings[i]) for i in range(256)]
|
||||
object_ids1 = [local_scheduler.ObjectID(random_strings[i])
|
||||
for i in range(256)]
|
||||
object_ids2 = [local_scheduler.ObjectID(random_strings[i])
|
||||
for i in range(256)]
|
||||
self.assertEqual(len(set(object_ids1)), 256)
|
||||
self.assertEqual(len(set(object_ids1 + object_ids2)), 256)
|
||||
self.assertEqual(set(object_ids1), set(object_ids2))
|
||||
@@ -113,6 +127,7 @@ class TestObjectID(unittest.TestCase):
|
||||
{x: y}
|
||||
set([x, y])
|
||||
|
||||
|
||||
class TestTask(unittest.TestCase):
|
||||
|
||||
def check_task(self, task, function_id, num_return_vals, args):
|
||||
@@ -127,44 +142,46 @@ class TestTask(unittest.TestCase):
|
||||
self.assertEqual(retrieved_args[i], args[i])
|
||||
|
||||
def test_create_and_serialize_task(self):
|
||||
# TODO(rkn): The function ID should be a FunctionID object, not an ObjectID.
|
||||
# TODO(rkn): The function ID should be a FunctionID object, not an
|
||||
# ObjectID.
|
||||
driver_id = random_driver_id()
|
||||
parent_id = random_task_id()
|
||||
function_id = random_function_id()
|
||||
object_ids = [random_object_id() for _ in range(256)]
|
||||
args_list = [
|
||||
[],
|
||||
1 * [1],
|
||||
10 * [1],
|
||||
100 * [1],
|
||||
1000 * [1],
|
||||
1 * ["a"],
|
||||
10 * ["a"],
|
||||
100 * ["a"],
|
||||
1000 * ["a"],
|
||||
[1, 1.3, 2, 1 << 100, "hi", u"hi", [1, 2]],
|
||||
object_ids[:1],
|
||||
object_ids[:2],
|
||||
object_ids[:3],
|
||||
object_ids[:4],
|
||||
object_ids[:5],
|
||||
object_ids[:10],
|
||||
object_ids[:100],
|
||||
object_ids[:256],
|
||||
[1, object_ids[0]],
|
||||
[object_ids[0], "a"],
|
||||
[1, object_ids[0], "a"],
|
||||
[object_ids[0], 1, object_ids[1], "a"],
|
||||
object_ids[:3] + [1, "hi", 2.3] + object_ids[:5],
|
||||
object_ids + 100 * ["a"] + object_ids
|
||||
]
|
||||
[],
|
||||
1 * [1],
|
||||
10 * [1],
|
||||
100 * [1],
|
||||
1000 * [1],
|
||||
1 * ["a"],
|
||||
10 * ["a"],
|
||||
100 * ["a"],
|
||||
1000 * ["a"],
|
||||
[1, 1.3, 2, 1 << 100, "hi", u"hi", [1, 2]],
|
||||
object_ids[:1],
|
||||
object_ids[:2],
|
||||
object_ids[:3],
|
||||
object_ids[:4],
|
||||
object_ids[:5],
|
||||
object_ids[:10],
|
||||
object_ids[:100],
|
||||
object_ids[:256],
|
||||
[1, object_ids[0]],
|
||||
[object_ids[0], "a"],
|
||||
[1, object_ids[0], "a"],
|
||||
[object_ids[0], 1, object_ids[1], "a"],
|
||||
object_ids[:3] + [1, "hi", 2.3] + object_ids[:5],
|
||||
object_ids + 100 * ["a"] + object_ids]
|
||||
for args in args_list:
|
||||
for num_return_vals in [0, 1, 2, 3, 5, 10, 100]:
|
||||
task = local_scheduler.Task(driver_id, function_id, args, num_return_vals, parent_id, 0)
|
||||
task = local_scheduler.Task(driver_id, function_id, args,
|
||||
num_return_vals, parent_id, 0)
|
||||
self.check_task(task, function_id, num_return_vals, args)
|
||||
data = local_scheduler.task_to_string(task)
|
||||
task2 = local_scheduler.task_from_string(data)
|
||||
self.check_task(task2, function_id, num_return_vals, args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=2)
|
||||
|
||||
@@ -4,3 +4,5 @@ from __future__ import print_function
|
||||
|
||||
from .utils import copy_directory
|
||||
from .tfutils import TensorFlowVariables
|
||||
|
||||
__all__ = ["copy_directory", "TensorFlowVariables"]
|
||||
|
||||
@@ -4,4 +4,10 @@ from __future__ import print_function
|
||||
|
||||
from . import random
|
||||
from . import linalg
|
||||
from .core import *
|
||||
from .core import (BLOCK_SIZE, DistArray, assemble, zeros, ones, copy, eye,
|
||||
triu, tril, blockwise_dot, dot, transpose, add, subtract,
|
||||
numpy_to_dist, subblocks)
|
||||
|
||||
__all__ = ["random", "linalg", "BLOCK_SIZE", "DistArray", "assemble", "zeros",
|
||||
"ones", "copy", "eye", "triu", "tril", "blockwise_dot", "dot",
|
||||
"transpose", "add", "subtract", "numpy_to_dist", "subblocks"]
|
||||
|
||||
@@ -6,30 +6,38 @@ import numpy as np
|
||||
import ray.experimental.array.remote as ra
|
||||
import ray
|
||||
|
||||
__all__ = ["BLOCK_SIZE", "DistArray", "assemble", "zeros", "ones", "copy",
|
||||
"eye", "triu", "tril", "blockwise_dot", "dot", "transpose", "add", "subtract", "numpy_to_dist", "subblocks"]
|
||||
|
||||
BLOCK_SIZE = 10
|
||||
|
||||
|
||||
class DistArray(object):
|
||||
def __init__(self, shape, objectids=None):
|
||||
self.shape = shape
|
||||
self.ndim = len(shape)
|
||||
self.num_blocks = [int(np.ceil(1.0 * a / BLOCK_SIZE)) for a in self.shape]
|
||||
self.objectids = objectids if objectids is not None else np.empty(self.num_blocks, dtype=object)
|
||||
if objectids is not None:
|
||||
self.objectids = objectids
|
||||
else:
|
||||
self.objectids = np.empty(self.num_blocks, dtype=object)
|
||||
if self.num_blocks != list(self.objectids.shape):
|
||||
raise Exception("The fields `num_blocks` and `objectids` are inconsistent, `num_blocks` is {} and `objectids` has shape {}".format(self.num_blocks, list(self.objectids.shape)))
|
||||
raise Exception("The fields `num_blocks` and `objectids` are "
|
||||
"inconsistent, `num_blocks` is {} and `objectids` has "
|
||||
"shape {}".format(self.num_blocks,
|
||||
list(self.objectids.shape)))
|
||||
|
||||
@staticmethod
|
||||
def compute_block_lower(index, shape):
|
||||
if len(index) != len(shape):
|
||||
raise Exception("The fields `index` and `shape` must have the same length, but `index` is {} and `shape` is {}.".format(index, shape))
|
||||
raise Exception("The fields `index` and `shape` must have the same "
|
||||
"length, but `index` is {} and `shape` is "
|
||||
"{}.".format(index, shape))
|
||||
return [elem * BLOCK_SIZE for elem in index]
|
||||
|
||||
@staticmethod
|
||||
def compute_block_upper(index, shape):
|
||||
if len(index) != len(shape):
|
||||
raise Exception("The fields `index` and `shape` must have the same length, but `index` is {} and `shape` is {}.".format(index, shape))
|
||||
raise Exception("The fields `index` and `shape` must have the same "
|
||||
"length, but `index` is {} and `shape` is "
|
||||
"{}.".format(index, shape))
|
||||
upper = []
|
||||
for i in range(len(shape)):
|
||||
upper.append(min((index[i] + 1) * BLOCK_SIZE, shape[i]))
|
||||
@@ -46,59 +54,73 @@ class DistArray(object):
|
||||
return [int(np.ceil(1.0 * a / BLOCK_SIZE)) for a in shape]
|
||||
|
||||
def assemble(self):
|
||||
"""Assemble an array on this node from a distributed array of object IDs."""
|
||||
"""Assemble an array from a distributed array of object IDs."""
|
||||
first_block = ray.get(self.objectids[(0,) * self.ndim])
|
||||
dtype = first_block.dtype
|
||||
result = np.zeros(self.shape, dtype=dtype)
|
||||
for index in np.ndindex(*self.num_blocks):
|
||||
lower = DistArray.compute_block_lower(index, self.shape)
|
||||
upper = DistArray.compute_block_upper(index, self.shape)
|
||||
result[[slice(l, u) for (l, u) in zip(lower, upper)]] = ray.get(self.objectids[index])
|
||||
result[[slice(l, u) for (l, u) in zip(lower, upper)]] = ray.get(
|
||||
self.objectids[index])
|
||||
return result
|
||||
|
||||
def __getitem__(self, sliced):
|
||||
# TODO(rkn): fix this, this is just a placeholder that should work but is inefficient
|
||||
# TODO(rkn): Fix this, this is just a placeholder that should work but is
|
||||
# inefficient.
|
||||
a = self.assemble()
|
||||
return a[sliced]
|
||||
|
||||
|
||||
# Register the DistArray class with Ray so that it knows how to serialize it.
|
||||
ray.register_class(DistArray)
|
||||
|
||||
|
||||
@ray.remote
|
||||
def assemble(a):
|
||||
return a.assemble()
|
||||
|
||||
# TODO(rkn): what should we call this method
|
||||
|
||||
# TODO(rkn): What should we call this method?
|
||||
@ray.remote
|
||||
def numpy_to_dist(a):
|
||||
result = DistArray(a.shape)
|
||||
for index in np.ndindex(*result.num_blocks):
|
||||
lower = DistArray.compute_block_lower(index, a.shape)
|
||||
upper = DistArray.compute_block_upper(index, a.shape)
|
||||
result.objectids[index] = ray.put(a[[slice(l, u) for (l, u) in zip(lower, upper)]])
|
||||
result.objectids[index] = ray.put(a[[slice(l, u) for (l, u)
|
||||
in zip(lower, upper)]])
|
||||
return result
|
||||
|
||||
|
||||
@ray.remote
|
||||
def zeros(shape, dtype_name="float"):
|
||||
result = DistArray(shape)
|
||||
for index in np.ndindex(*result.num_blocks):
|
||||
result.objectids[index] = ra.zeros.remote(DistArray.compute_block_shape(index, shape), dtype_name=dtype_name)
|
||||
result.objectids[index] = ra.zeros.remote(
|
||||
DistArray.compute_block_shape(index, shape), dtype_name=dtype_name)
|
||||
return result
|
||||
|
||||
|
||||
@ray.remote
|
||||
def ones(shape, dtype_name="float"):
|
||||
result = DistArray(shape)
|
||||
for index in np.ndindex(*result.num_blocks):
|
||||
result.objectids[index] = ra.ones.remote(DistArray.compute_block_shape(index, shape), dtype_name=dtype_name)
|
||||
result.objectids[index] = ra.ones.remote(
|
||||
DistArray.compute_block_shape(index, shape), dtype_name=dtype_name)
|
||||
return result
|
||||
|
||||
|
||||
@ray.remote
|
||||
def copy(a):
|
||||
result = DistArray(a.shape)
|
||||
for index in np.ndindex(*result.num_blocks):
|
||||
result.objectids[index] = a.objectids[index] # We don't need to actually copy the objects because cluster-level objects are assumed to be immutable.
|
||||
# We don't need to actually copy the objects because remote objects are
|
||||
# immutable.
|
||||
result.objectids[index] = a.objectids[index]
|
||||
return result
|
||||
|
||||
|
||||
@ray.remote
|
||||
def eye(dim1, dim2=-1, dtype_name="float"):
|
||||
dim2 = dim1 if dim2 == -1 else dim2
|
||||
@@ -107,15 +129,19 @@ def eye(dim1, dim2=-1, dtype_name="float"):
|
||||
for (i, j) in np.ndindex(*result.num_blocks):
|
||||
block_shape = DistArray.compute_block_shape([i, j], shape)
|
||||
if i == j:
|
||||
result.objectids[i, j] = ra.eye.remote(block_shape[0], block_shape[1], dtype_name=dtype_name)
|
||||
result.objectids[i, j] = ra.eye.remote(block_shape[0], block_shape[1],
|
||||
dtype_name=dtype_name)
|
||||
else:
|
||||
result.objectids[i, j] = ra.zeros.remote(block_shape, dtype_name=dtype_name)
|
||||
result.objectids[i, j] = ra.zeros.remote(block_shape,
|
||||
dtype_name=dtype_name)
|
||||
return result
|
||||
|
||||
|
||||
@ray.remote
|
||||
def triu(a):
|
||||
if a.ndim != 2:
|
||||
raise Exception("Input must have 2 dimensions, but a.ndim is " + str(a.ndim))
|
||||
raise Exception("Input must have 2 dimensions, but a.ndim is "
|
||||
"{}.".format(a.ndim))
|
||||
result = DistArray(a.shape)
|
||||
for (i, j) in np.ndindex(*result.num_blocks):
|
||||
if i < j:
|
||||
@@ -126,10 +152,12 @@ def triu(a):
|
||||
result.objectids[i, j] = ra.zeros_like.remote(a.objectids[i, j])
|
||||
return result
|
||||
|
||||
|
||||
@ray.remote
|
||||
def tril(a):
|
||||
if a.ndim != 2:
|
||||
raise Exception("Input must have 2 dimensions, but a.ndim is " + str(a.ndim))
|
||||
raise Exception("Input must have 2 dimensions, but a.ndim is "
|
||||
"{}.".format(a.ndim))
|
||||
result = DistArray(a.shape)
|
||||
for (i, j) in np.ndindex(*result.num_blocks):
|
||||
if i > j:
|
||||
@@ -140,25 +168,31 @@ def tril(a):
|
||||
result.objectids[i, j] = ra.zeros_like.remote(a.objectids[i, j])
|
||||
return result
|
||||
|
||||
|
||||
@ray.remote
|
||||
def blockwise_dot(*matrices):
|
||||
n = len(matrices)
|
||||
if n % 2 != 0:
|
||||
raise Exception("blockwise_dot expects an even number of arguments, but len(matrices) is {}.".format(n))
|
||||
raise Exception("blockwise_dot expects an even number of arguments, but "
|
||||
"len(matrices) is {}.".format(n))
|
||||
shape = (matrices[0].shape[0], matrices[n // 2].shape[1])
|
||||
result = np.zeros(shape)
|
||||
for i in range(n // 2):
|
||||
result += np.dot(matrices[i], matrices[n // 2 + i])
|
||||
return result
|
||||
|
||||
|
||||
@ray.remote
|
||||
def dot(a, b):
|
||||
if a.ndim != 2:
|
||||
raise Exception("dot expects its arguments to be 2-dimensional, but a.ndim = {}.".format(a.ndim))
|
||||
raise Exception("dot expects its arguments to be 2-dimensional, but "
|
||||
"a.ndim = {}.".format(a.ndim))
|
||||
if b.ndim != 2:
|
||||
raise Exception("dot expects its arguments to be 2-dimensional, but b.ndim = {}.".format(b.ndim))
|
||||
raise Exception("dot expects its arguments to be 2-dimensional, but "
|
||||
"b.ndim = {}.".format(b.ndim))
|
||||
if a.shape[1] != b.shape[0]:
|
||||
raise Exception("dot expects a.shape[1] to equal b.shape[0], but a.shape = {} and b.shape = {}.".format(a.shape, b.shape))
|
||||
raise Exception("dot expects a.shape[1] to equal b.shape[0], but a.shape "
|
||||
"= {} and b.shape = {}.".format(a.shape, b.shape))
|
||||
shape = [a.shape[0], b.shape[1]]
|
||||
result = DistArray(shape)
|
||||
for (i, j) in np.ndindex(*result.num_blocks):
|
||||
@@ -166,10 +200,12 @@ def dot(a, b):
|
||||
result.objectids[i, j] = blockwise_dot.remote(*args)
|
||||
return result
|
||||
|
||||
|
||||
@ray.remote
|
||||
def subblocks(a, *ranges):
|
||||
"""
|
||||
This function produces a distributed array from a subset of the blocks in the `a`. The result and `a` will have the same number of dimensions.For example,
|
||||
This function produces a distributed array from a subset of the blocks in the
|
||||
`a`. The result and `a` will have the same number of dimensions.For example,
|
||||
subblocks(a, [0, 1], [2, 4])
|
||||
will produce a DistArray whose objectids are
|
||||
[[a.objectids[0, 2], a.objectids[0, 4]],
|
||||
@@ -178,50 +214,71 @@ def subblocks(a, *ranges):
|
||||
"""
|
||||
ranges = list(ranges)
|
||||
if len(ranges) != a.ndim:
|
||||
raise Exception("sub_blocks expects to receive a number of ranges equal to a.ndim, but it received {} ranges and a.ndim = {}.".format(len(ranges), a.ndim))
|
||||
raise Exception("sub_blocks expects to receive a number of ranges equal "
|
||||
"to a.ndim, but it received {} ranges and a.ndim = "
|
||||
"{}.".format(len(ranges), a.ndim))
|
||||
for i in range(len(ranges)):
|
||||
if ranges[i] == []: # We allow the user to pass in an empty list to indicate the full range
|
||||
# We allow the user to pass in an empty list to indicate the full range.
|
||||
if ranges[i] == []:
|
||||
ranges[i] = range(a.num_blocks[i])
|
||||
if not np.alltrue(ranges[i] == np.sort(ranges[i])):
|
||||
raise Exception("Ranges passed to sub_blocks must be sorted, but the {}th range is {}.".format(i, ranges[i]))
|
||||
raise Exception("Ranges passed to sub_blocks must be sorted, but the "
|
||||
"{}th range is {}.".format(i, ranges[i]))
|
||||
if ranges[i][0] < 0:
|
||||
raise Exception("Values in the ranges passed to sub_blocks must be at least 0, but the {}th range is {}.".format(i, ranges[i]))
|
||||
raise Exception("Values in the ranges passed to sub_blocks must be at "
|
||||
"least 0, but the {}th range is {}.".format(i,
|
||||
ranges[i]))
|
||||
if ranges[i][-1] >= a.num_blocks[i]:
|
||||
raise Exception("Values in the ranges passed to sub_blocks must be less than the relevant number of blocks, but the {}th range is {}, and a.num_blocks = {}.".format(i, ranges[i], a.num_blocks))
|
||||
raise Exception("Values in the ranges passed to sub_blocks must be less "
|
||||
"than the relevant number of blocks, but the {}th range "
|
||||
"is {}, and a.num_blocks = {}.".format(i, ranges[i],
|
||||
a.num_blocks))
|
||||
last_index = [r[-1] for r in ranges]
|
||||
last_block_shape = DistArray.compute_block_shape(last_index, a.shape)
|
||||
shape = [(len(ranges[i]) - 1) * BLOCK_SIZE + last_block_shape[i] for i in range(a.ndim)]
|
||||
shape = [(len(ranges[i]) - 1) * BLOCK_SIZE + last_block_shape[i]
|
||||
for i in range(a.ndim)]
|
||||
result = DistArray(shape)
|
||||
for index in np.ndindex(*result.num_blocks):
|
||||
result.objectids[index] = a.objectids[tuple([ranges[i][index[i]] for i in range(a.ndim)])]
|
||||
result.objectids[index] = a.objectids[tuple([ranges[i][index[i]]
|
||||
for i in range(a.ndim)])]
|
||||
return result
|
||||
|
||||
|
||||
@ray.remote
|
||||
def transpose(a):
|
||||
if a.ndim != 2:
|
||||
raise Exception("transpose expects its argument to be 2-dimensional, but a.ndim = {}, a.shape = {}.".format(a.ndim, a.shape))
|
||||
raise Exception("transpose expects its argument to be 2-dimensional, but "
|
||||
"a.ndim = {}, a.shape = {}.".format(a.ndim, a.shape))
|
||||
result = DistArray([a.shape[1], a.shape[0]])
|
||||
for i in range(result.num_blocks[0]):
|
||||
for j in range(result.num_blocks[1]):
|
||||
result.objectids[i, j] = ra.transpose.remote(a.objectids[j, i])
|
||||
return result
|
||||
|
||||
|
||||
# TODO(rkn): support broadcasting?
|
||||
@ray.remote
|
||||
def add(x1, x2):
|
||||
if x1.shape != x2.shape:
|
||||
raise Exception("add expects arguments `x1` and `x2` to have the same shape, but x1.shape = {}, and x2.shape = {}.".format(x1.shape, x2.shape))
|
||||
raise Exception("add expects arguments `x1` and `x2` to have the same "
|
||||
"shape, but x1.shape = {}, and x2.shape = {}."
|
||||
.format(x1.shape, x2.shape))
|
||||
result = DistArray(x1.shape)
|
||||
for index in np.ndindex(*result.num_blocks):
|
||||
result.objectids[index] = ra.add.remote(x1.objectids[index], x2.objectids[index])
|
||||
result.objectids[index] = ra.add.remote(x1.objectids[index],
|
||||
x2.objectids[index])
|
||||
return result
|
||||
|
||||
|
||||
# TODO(rkn): support broadcasting?
|
||||
@ray.remote
|
||||
def subtract(x1, x2):
|
||||
if x1.shape != x2.shape:
|
||||
raise Exception("subtract expects arguments `x1` and `x2` to have the same shape, but x1.shape = {}, and x2.shape = {}.".format(x1.shape, x2.shape))
|
||||
raise Exception("subtract expects arguments `x1` and `x2` to have the "
|
||||
"same shape, but x1.shape = {}, and x2.shape = {}."
|
||||
.format(x1.shape, x2.shape))
|
||||
result = DistArray(x1.shape)
|
||||
for index in np.ndindex(*result.num_blocks):
|
||||
result.objectids[index] = ra.subtract.remote(x1.objectids[index], x2.objectids[index])
|
||||
result.objectids[index] = ra.subtract.remote(x1.objectids[index],
|
||||
x2.objectids[index])
|
||||
return result
|
||||
|
||||
@@ -6,30 +6,32 @@ import numpy as np
|
||||
import ray.experimental.array.remote as ra
|
||||
import ray
|
||||
|
||||
from .core import *
|
||||
from . import core
|
||||
|
||||
__all__ = ["tsqr", "modified_lu", "tsqr_hr", "qr"]
|
||||
|
||||
|
||||
@ray.remote(num_return_vals=2)
|
||||
def tsqr(a):
|
||||
"""
|
||||
arguments:
|
||||
a: a distributed matrix
|
||||
Suppose that
|
||||
a.shape == (M, N)
|
||||
K == min(M, N)
|
||||
return values:
|
||||
q: DistArray, if q_full = ray.get(DistArray, q).assemble(), then
|
||||
q_full.shape == (M, K)
|
||||
np.allclose(np.dot(q_full.T, q_full), np.eye(K)) == True
|
||||
r: np.ndarray, if r_val = ray.get(np.ndarray, r), then
|
||||
r_val.shape == (K, N)
|
||||
np.allclose(r, np.triu(r)) == True
|
||||
"""Perform a QR decomposition of a tall-skinny matrix.
|
||||
|
||||
Args:
|
||||
a: A distributed matrix with shape MxN (suppose K = min(M, N)).
|
||||
|
||||
Returns:
|
||||
A tuple of q (a DistArray) and r (a numpy array) satisfying the following.
|
||||
- If q_full = ray.get(DistArray, q).assemble(), then
|
||||
q_full.shape == (M, K).
|
||||
- np.allclose(np.dot(q_full.T, q_full), np.eye(K)) == True.
|
||||
- If r_val = ray.get(np.ndarray, r), then r_val.shape == (K, N).
|
||||
- np.allclose(r, np.triu(r)) == True.
|
||||
"""
|
||||
if len(a.shape) != 2:
|
||||
raise Exception("tsqr requires len(a.shape) == 2, but a.shape is {}".format(a.shape))
|
||||
raise Exception("tsqr requires len(a.shape) == 2, but a.shape is "
|
||||
"{}".format(a.shape))
|
||||
if a.num_blocks[1] != 1:
|
||||
raise Exception("tsqr requires a.num_blocks[1] == 1, but a.num_blocks is {}".format(a.num_blocks))
|
||||
raise Exception("tsqr requires a.num_blocks[1] == 1, but a.num_blocks is "
|
||||
"{}".format(a.num_blocks))
|
||||
|
||||
num_blocks = a.num_blocks[0]
|
||||
K = int(np.ceil(np.log2(num_blocks))) + 1
|
||||
@@ -57,9 +59,9 @@ def tsqr(a):
|
||||
q_shape = a.shape
|
||||
else:
|
||||
q_shape = [a.shape[0], a.shape[0]]
|
||||
q_num_blocks = DistArray.compute_num_blocks(q_shape)
|
||||
q_num_blocks = core.DistArray.compute_num_blocks(q_shape)
|
||||
q_objectids = np.empty(q_num_blocks, dtype=object)
|
||||
q_result = DistArray(q_shape, q_objectids)
|
||||
q_result = core.DistArray(q_shape, q_objectids)
|
||||
|
||||
# reconstruct output
|
||||
for i in range(num_blocks):
|
||||
@@ -68,28 +70,35 @@ def tsqr(a):
|
||||
for j in range(1, K):
|
||||
if np.mod(ith_index, 2) == 0:
|
||||
lower = [0, 0]
|
||||
upper = [a.shape[1], BLOCK_SIZE]
|
||||
upper = [a.shape[1], core.BLOCK_SIZE]
|
||||
else:
|
||||
lower = [a.shape[1], 0]
|
||||
upper = [2 * a.shape[1], BLOCK_SIZE]
|
||||
upper = [2 * a.shape[1], core.BLOCK_SIZE]
|
||||
ith_index //= 2
|
||||
q_block_current = ra.dot.remote(q_block_current, ra.subarray.remote(q_tree[ith_index, j], lower, upper))
|
||||
q_block_current = ra.dot.remote(q_block_current,
|
||||
ra.subarray.remote(q_tree[ith_index, j],
|
||||
lower, upper))
|
||||
q_result.objectids[i] = q_block_current
|
||||
r = current_rs[0]
|
||||
return q_result, ray.get(r)
|
||||
|
||||
|
||||
# TODO(rkn): This is unoptimized, we really want a block version of this.
|
||||
# This is Algorithm 5 from
|
||||
# http://www.eecs.berkeley.edu/Pubs/TechRpts/2013/EECS-2013-175.pdf.
|
||||
@ray.remote(num_return_vals=3)
|
||||
def modified_lu(q):
|
||||
"""
|
||||
Algorithm 5 from http://www.eecs.berkeley.edu/Pubs/TechRpts/2013/EECS-2013-175.pdf
|
||||
takes a matrix q with orthonormal columns, returns l, u, s such that q - s = l * u
|
||||
arguments:
|
||||
q: a two dimensional orthonormal q
|
||||
return values:
|
||||
l: lower triangular
|
||||
u: upper triangular
|
||||
s: a diagonal matrix represented by its diagonal
|
||||
"""Perform a modified LU decomposition of a matrix.
|
||||
|
||||
This takes a matrix q with orthonormal columns, returns l, u, s such that
|
||||
q - s = l * u.
|
||||
|
||||
Args:
|
||||
q: A two dimensional orthonormal matrix q.
|
||||
|
||||
Returns:
|
||||
A tuple of a lower triangular matrix l, an upper triangular matrix u, and a
|
||||
a vector representing a diagonal matrix s such that q - s = l * u.
|
||||
"""
|
||||
q = q.assemble()
|
||||
m, b = q.shape[0], q.shape[1]
|
||||
@@ -100,14 +109,19 @@ def modified_lu(q):
|
||||
for i in range(b):
|
||||
S[i] = -1 * np.sign(q_work[i, i])
|
||||
q_work[i, i] -= S[i]
|
||||
q_work[(i + 1):m, i] /= q_work[i, i] # scale ith column of L by diagonal element
|
||||
q_work[(i + 1):m, (i + 1):b] -= np.outer(q_work[(i + 1):m, i], q_work[i, (i + 1):b]) # perform Schur complement update
|
||||
# Scale ith column of L by diagonal element.
|
||||
q_work[(i + 1):m, i] /= q_work[i, i]
|
||||
# Perform Schur complement update.
|
||||
q_work[(i + 1):m, (i + 1):b] -= np.outer(q_work[(i + 1):m, i],
|
||||
q_work[i, (i + 1):b])
|
||||
|
||||
L = np.tril(q_work)
|
||||
for i in range(b):
|
||||
L[i, i] = 1
|
||||
U = np.triu(q_work)[:b, :]
|
||||
return ray.get(numpy_to_dist.remote(ray.put(L))), U, S # TODO(rkn): get rid of put
|
||||
# TODO(rkn): Get rid of the put below.
|
||||
return ray.get(core.numpy_to_dist.remote(ray.put(L))), U, S
|
||||
|
||||
|
||||
@ray.remote(num_return_vals=2)
|
||||
def tsqr_hr_helper1(u, s, y_top_block, b):
|
||||
@@ -116,45 +130,60 @@ def tsqr_hr_helper1(u, s, y_top_block, b):
|
||||
t = -1 * np.dot(u, np.dot(s_full, np.linalg.inv(y_top).T))
|
||||
return t, y_top
|
||||
|
||||
|
||||
@ray.remote
|
||||
def tsqr_hr_helper2(s, r_temp):
|
||||
s_full = np.diag(s)
|
||||
return np.dot(s_full, r_temp)
|
||||
|
||||
|
||||
# This is Algorithm 6 from
|
||||
# http://www.eecs.berkeley.edu/Pubs/TechRpts/2013/EECS-2013-175.pdf.
|
||||
@ray.remote(num_return_vals=4)
|
||||
def tsqr_hr(a):
|
||||
"""Algorithm 6 from http://www.eecs.berkeley.edu/Pubs/TechRpts/2013/EECS-2013-175.pdf"""
|
||||
q, r_temp = tsqr.remote(a)
|
||||
y, u, s = modified_lu.remote(q)
|
||||
y_blocked = ray.get(y)
|
||||
t, y_top = tsqr_hr_helper1.remote(u, s, y_blocked.objectids[0, 0], a.shape[1])
|
||||
t, y_top = tsqr_hr_helper1.remote(u, s, y_blocked.objectids[0, 0],
|
||||
a.shape[1])
|
||||
r = tsqr_hr_helper2.remote(s, r_temp)
|
||||
return ray.get(y), ray.get(t), ray.get(y_top), ray.get(r)
|
||||
|
||||
|
||||
@ray.remote
|
||||
def qr_helper1(a_rc, y_ri, t, W_c):
|
||||
return a_rc - np.dot(y_ri, np.dot(t.T, W_c))
|
||||
|
||||
|
||||
@ray.remote
|
||||
def qr_helper2(y_ri, a_rc):
|
||||
return np.dot(y_ri.T, a_rc)
|
||||
|
||||
|
||||
# This is Algorithm 7 from
|
||||
# http://www.eecs.berkeley.edu/Pubs/TechRpts/2013/EECS-2013-175.pdf.
|
||||
@ray.remote(num_return_vals=2)
|
||||
def qr(a):
|
||||
"""Algorithm 7 from http://www.eecs.berkeley.edu/Pubs/TechRpts/2013/EECS-2013-175.pdf"""
|
||||
|
||||
m, n = a.shape[0], a.shape[1]
|
||||
k = min(m, n)
|
||||
|
||||
# we will store our scratch work in a_work
|
||||
a_work = DistArray(a.shape, np.copy(a.objectids))
|
||||
a_work = core.DistArray(a.shape, np.copy(a.objectids))
|
||||
|
||||
result_dtype = np.linalg.qr(ray.get(a.objectids[0, 0]))[0].dtype.name
|
||||
r_res = ray.get(zeros.remote([k, n], result_dtype)) # TODO(rkn): It would be preferable not to get this right after creating it.
|
||||
y_res = ray.get(zeros.remote([m, k], result_dtype)) # TODO(rkn): It would be preferable not to get this right after creating it.
|
||||
# TODO(rkn): It would be preferable not to get this right after creating it.
|
||||
r_res = ray.get(core.zeros.remote([k, n], result_dtype))
|
||||
# TODO(rkn): It would be preferable not to get this right after creating it.
|
||||
y_res = ray.get(core.zeros.remote([m, k], result_dtype))
|
||||
Ts = []
|
||||
|
||||
for i in range(min(a.num_blocks[0], a.num_blocks[1])): # this differs from the paper, which says "for i in range(a.num_blocks[1])", but that doesn't seem to make any sense when a.num_blocks[1] > a.num_blocks[0]
|
||||
sub_dist_array = subblocks.remote(a_work, list(range(i, a_work.num_blocks[0])), [i])
|
||||
# The for loop differs from the paper, which says
|
||||
# "for i in range(a.num_blocks[1])", but that doesn't seem to make any sense
|
||||
# when a.num_blocks[1] > a.num_blocks[0].
|
||||
for i in range(min(a.num_blocks[0], a.num_blocks[1])):
|
||||
sub_dist_array = core.subblocks.remote(
|
||||
a_work, list(range(i, a_work.num_blocks[0])), [i])
|
||||
y, t, _, R = tsqr_hr.remote(sub_dist_array)
|
||||
y_val = ray.get(y)
|
||||
|
||||
@@ -167,7 +196,7 @@ def qr(a):
|
||||
r_res.objectids[i, i] = ra.dot.remote(eye_temp, R)
|
||||
else:
|
||||
r_res.objectids[i, i] = R
|
||||
Ts.append(numpy_to_dist.remote(t))
|
||||
Ts.append(core.numpy_to_dist.remote(t))
|
||||
|
||||
for c in range(i + 1, a.num_blocks[1]):
|
||||
W_rcs = []
|
||||
@@ -182,9 +211,14 @@ def qr(a):
|
||||
r_res.objectids[i, c] = a_work.objectids[i, c]
|
||||
|
||||
# construct q_res from Ys and Ts
|
||||
q = eye.remote(m, k, dtype_name=result_dtype)
|
||||
q = core.eye.remote(m, k, dtype_name=result_dtype)
|
||||
for i in range(len(Ts))[::-1]:
|
||||
y_col_block = subblocks.remote(y_res, [], [i])
|
||||
q = subtract.remote(q, dot.remote(y_col_block, dot.remote(Ts[i], dot.remote(transpose.remote(y_col_block), q))))
|
||||
y_col_block = core.subblocks.remote(y_res, [], [i])
|
||||
q = core.subtract.remote(
|
||||
q, core.dot.remote(
|
||||
y_col_block,
|
||||
core.dot.remote(Ts[i],
|
||||
core.dot.remote(core.transpose.remote(y_col_block),
|
||||
q))))
|
||||
|
||||
return ray.get(q), r_res
|
||||
|
||||
@@ -6,13 +6,15 @@ import numpy as np
|
||||
import ray.experimental.array.remote as ra
|
||||
import ray
|
||||
|
||||
from .core import *
|
||||
from .core import DistArray
|
||||
|
||||
|
||||
@ray.remote
|
||||
def normal(shape):
|
||||
num_blocks = DistArray.compute_num_blocks(shape)
|
||||
objectids = np.empty(num_blocks, dtype=object)
|
||||
for index in np.ndindex(*num_blocks):
|
||||
objectids[index] = ra.random.normal.remote(DistArray.compute_block_shape(index, shape))
|
||||
objectids[index] = ra.random.normal.remote(
|
||||
DistArray.compute_block_shape(index, shape))
|
||||
result = DistArray(shape, objectids)
|
||||
return result
|
||||
|
||||
@@ -4,4 +4,10 @@ from __future__ import print_function
|
||||
|
||||
from . import random
|
||||
from . import linalg
|
||||
from .core import *
|
||||
from .core import (zeros, zeros_like, ones, eye, dot, vstack, hstack, subarray,
|
||||
copy, tril, triu, diag, transpose, add, subtract, sum,
|
||||
shape, sum_list)
|
||||
|
||||
__all__ = ["random", "linalg", "zeros", "zeros_like", "ones", "eye", "dot",
|
||||
"vstack", "hstack", "subarray", "copy", "tril", "triu", "diag",
|
||||
"transpose", "add", "subtract", "sum", "shape", "sum_list"]
|
||||
|
||||
@@ -5,82 +5,97 @@ from __future__ import print_function
|
||||
import numpy as np
|
||||
import ray
|
||||
|
||||
__all__ = ["zeros", "zeros_like", "ones", "eye", "dot", "vstack", "hstack", "subarray", "copy", "tril", "triu", "diag", "transpose", "add", "subtract", "sum", "shape", "sum_list"]
|
||||
|
||||
@ray.remote
|
||||
def zeros(shape, dtype_name="float", order="C"):
|
||||
return np.zeros(shape, dtype=np.dtype(dtype_name), order=order)
|
||||
|
||||
|
||||
@ray.remote
|
||||
def zeros_like(a, dtype_name="None", order="K", subok=True):
|
||||
dtype_val = None if dtype_name == "None" else np.dtype(dtype_name)
|
||||
return np.zeros_like(a, dtype=dtype_val, order=order, subok=subok)
|
||||
|
||||
|
||||
@ray.remote
|
||||
def ones(shape, dtype_name="float", order="C"):
|
||||
return np.ones(shape, dtype=np.dtype(dtype_name), order=order)
|
||||
|
||||
|
||||
@ray.remote
|
||||
def eye(N, M=-1, k=0, dtype_name="float"):
|
||||
M = N if M == -1 else M
|
||||
return np.eye(N, M=M, k=k, dtype=np.dtype(dtype_name))
|
||||
|
||||
|
||||
@ray.remote
|
||||
def dot(a, b):
|
||||
return np.dot(a, b)
|
||||
|
||||
|
||||
@ray.remote
|
||||
def vstack(*xs):
|
||||
return np.vstack(xs)
|
||||
|
||||
|
||||
@ray.remote
|
||||
def hstack(*xs):
|
||||
return np.hstack(xs)
|
||||
|
||||
# TODO(rkn): instead of this, consider implementing slicing
|
||||
|
||||
# TODO(rkn): Instead of this, consider implementing slicing.
|
||||
# TODO(rkn): Be consistent about using "index" versus "indices".
|
||||
@ray.remote
|
||||
def subarray(a, lower_indices, upper_indices): # TODO(rkn): be consistent about using "index" versus "indices"
|
||||
def subarray(a, lower_indices, upper_indices):
|
||||
return a[[slice(l, u) for (l, u) in zip(lower_indices, upper_indices)]]
|
||||
|
||||
|
||||
@ray.remote
|
||||
def copy(a, order="K"):
|
||||
return np.copy(a, order=order)
|
||||
|
||||
|
||||
@ray.remote
|
||||
def tril(m, k=0):
|
||||
return np.tril(m, k=k)
|
||||
|
||||
|
||||
@ray.remote
|
||||
def triu(m, k=0):
|
||||
return np.triu(m, k=k)
|
||||
|
||||
|
||||
@ray.remote
|
||||
def diag(v, k=0):
|
||||
return np.diag(v, k=k)
|
||||
|
||||
|
||||
@ray.remote
|
||||
def transpose(a, axes=[]):
|
||||
axes = None if axes == [] else axes
|
||||
return np.transpose(a, axes=axes)
|
||||
|
||||
|
||||
@ray.remote
|
||||
def add(x1, x2):
|
||||
return np.add(x1, x2)
|
||||
|
||||
|
||||
@ray.remote
|
||||
def subtract(x1, x2):
|
||||
return np.subtract(x1, x2)
|
||||
|
||||
|
||||
@ray.remote
|
||||
def sum(x, axis=-1):
|
||||
return np.sum(x, axis=axis if axis != -1 else None)
|
||||
|
||||
|
||||
@ray.remote
|
||||
def shape(a):
|
||||
return np.shape(a)
|
||||
|
||||
# We use Any to allow different numerical types as well as numpy arrays.
|
||||
# TODO(rkn):this isn't in the numpy API, so be careful about exposing this.
|
||||
|
||||
@ray.remote
|
||||
def sum_list(*xs):
|
||||
return np.sum(xs, axis=0)
|
||||
|
||||
@@ -8,84 +8,104 @@ import ray
|
||||
__all__ = ["matrix_power", "solve", "tensorsolve", "tensorinv", "inv",
|
||||
"cholesky", "eigvals", "eigvalsh", "pinv", "slogdet", "det",
|
||||
"svd", "eig", "eigh", "lstsq", "norm", "qr", "cond", "matrix_rank",
|
||||
"LinAlgError", "multi_dot"]
|
||||
"multi_dot"]
|
||||
|
||||
|
||||
@ray.remote
|
||||
def matrix_power(M, n):
|
||||
return np.linalg.matrix_power(M, n)
|
||||
|
||||
|
||||
@ray.remote
|
||||
def solve(a, b):
|
||||
return np.linalg.solve(a, b)
|
||||
|
||||
|
||||
@ray.remote(num_return_vals=2)
|
||||
def tensorsolve(a):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@ray.remote(num_return_vals=2)
|
||||
def tensorinv(a):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@ray.remote
|
||||
def inv(a):
|
||||
return np.linalg.inv(a)
|
||||
|
||||
|
||||
@ray.remote
|
||||
def cholesky(a):
|
||||
return np.linalg.cholesky(a)
|
||||
|
||||
|
||||
@ray.remote
|
||||
def eigvals(a):
|
||||
return np.linalg.eigvals(a)
|
||||
|
||||
|
||||
@ray.remote
|
||||
def eigvalsh(a):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@ray.remote
|
||||
def pinv(a):
|
||||
return np.linalg.pinv(a)
|
||||
|
||||
|
||||
@ray.remote
|
||||
def slogdet(a):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@ray.remote
|
||||
def det(a):
|
||||
return np.linalg.det(a)
|
||||
|
||||
|
||||
@ray.remote(num_return_vals=3)
|
||||
def svd(a):
|
||||
return np.linalg.svd(a)
|
||||
|
||||
|
||||
@ray.remote(num_return_vals=2)
|
||||
def eig(a):
|
||||
return np.linalg.eig(a)
|
||||
|
||||
|
||||
@ray.remote(num_return_vals=2)
|
||||
def eigh(a):
|
||||
return np.linalg.eigh(a)
|
||||
|
||||
|
||||
@ray.remote(num_return_vals=4)
|
||||
def lstsq(a, b):
|
||||
return np.linalg.lstsq(a)
|
||||
|
||||
|
||||
@ray.remote
|
||||
def norm(x):
|
||||
return np.linalg.norm(x)
|
||||
|
||||
|
||||
@ray.remote(num_return_vals=2)
|
||||
def qr(a):
|
||||
return np.linalg.qr(a)
|
||||
|
||||
|
||||
@ray.remote
|
||||
def cond(x):
|
||||
return np.linalg.cond(x)
|
||||
|
||||
|
||||
@ray.remote
|
||||
def matrix_rank(M):
|
||||
return np.linalg.matrix_rank(M)
|
||||
|
||||
|
||||
@ray.remote
|
||||
def multi_dot(*a):
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -5,6 +5,7 @@ from __future__ import print_function
|
||||
import numpy as np
|
||||
import ray
|
||||
|
||||
|
||||
@ray.remote
|
||||
def normal(shape):
|
||||
return np.random.normal(size=shape)
|
||||
|
||||
@@ -2,6 +2,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
|
||||
def get_local_schedulers(worker):
|
||||
local_schedulers = []
|
||||
for client in worker.redis_client.keys("CL:*"):
|
||||
|
||||
@@ -4,6 +4,7 @@ from __future__ import print_function
|
||||
import numpy as np
|
||||
from collections import deque, OrderedDict
|
||||
|
||||
|
||||
def unflatten(vector, shapes):
|
||||
i = 0
|
||||
arrays = []
|
||||
@@ -15,6 +16,7 @@ def unflatten(vector, shapes):
|
||||
assert len(vector) == i, "Passed weight does not have the correct shape."
|
||||
return arrays
|
||||
|
||||
|
||||
class TensorFlowVariables(object):
|
||||
"""An object used to extract variables from a loss function.
|
||||
|
||||
@@ -38,11 +40,11 @@ class TensorFlowVariables(object):
|
||||
variable_names = []
|
||||
explored_inputs = set([loss])
|
||||
|
||||
# We do a BFS on the dependency graph of the input function to find
|
||||
# We do a BFS on the dependency graph of the input function to find
|
||||
# the variables.
|
||||
while len(queue) != 0:
|
||||
tf_obj = queue.popleft()
|
||||
|
||||
|
||||
# The object put into the queue is not necessarily an operation, so we
|
||||
# want the op attribute to get the operation underlying the object.
|
||||
# Only operations contain the inputs that we can explore.
|
||||
@@ -61,14 +63,16 @@ class TensorFlowVariables(object):
|
||||
if "Variable" in tf_obj.node_def.op:
|
||||
variable_names.append(tf_obj.node_def.name)
|
||||
self.variables = OrderedDict()
|
||||
for v in [v for v in tf.global_variables() if v.op.node_def.name in variable_names]:
|
||||
for v in [v for v in tf.global_variables()
|
||||
if v.op.node_def.name in variable_names]:
|
||||
self.variables[v.op.node_def.name] = v
|
||||
self.placeholders = dict()
|
||||
self.assignment_nodes = []
|
||||
|
||||
# Create new placeholders to put in custom weights.
|
||||
for k, var in self.variables.items():
|
||||
self.placeholders[k] = tf.placeholder(var.value().dtype, var.get_shape().as_list())
|
||||
self.placeholders[k] = tf.placeholder(var.value().dtype,
|
||||
var.get_shape().as_list())
|
||||
self.assignment_nodes.append(var.assign(self.placeholders[k]))
|
||||
|
||||
def set_session(self, sess):
|
||||
@@ -76,24 +80,30 @@ class TensorFlowVariables(object):
|
||||
self.sess = sess
|
||||
|
||||
def get_flat_size(self):
|
||||
return sum([np.prod(v.get_shape().as_list()) for v in self.variables.values()])
|
||||
return sum([np.prod(v.get_shape().as_list())
|
||||
for v in self.variables.values()])
|
||||
|
||||
def _check_sess(self):
|
||||
"""Checks if the session is set, and if not throw an error message."""
|
||||
assert self.sess is not None, "The session is not set. Set the session either by passing it into the TensorFlowVariables constructor or by calling set_session(sess)."
|
||||
|
||||
assert self.sess is not None, ("The session is not set. Set the session "
|
||||
"either by passing it into the "
|
||||
"TensorFlowVariables constructor or by "
|
||||
"calling set_session(sess).")
|
||||
|
||||
def get_flat(self):
|
||||
"""Gets the weights and returns them as a flat array."""
|
||||
self._check_sess()
|
||||
return np.concatenate([v.eval(session=self.sess).flatten() for v in self.variables.values()])
|
||||
|
||||
return np.concatenate([v.eval(session=self.sess).flatten()
|
||||
for v in self.variables.values()])
|
||||
|
||||
def set_flat(self, new_weights):
|
||||
"""Sets the weights to new_weights, converting from a flat array."""
|
||||
self._check_sess()
|
||||
shapes = [v.get_shape().as_list() for v in self.variables.values()]
|
||||
arrays = unflatten(new_weights, shapes)
|
||||
placeholders = [self.placeholders[k] for k, v in self.variables.items()]
|
||||
self.sess.run(self.assignment_nodes, feed_dict=dict(zip(placeholders,arrays)))
|
||||
self.sess.run(self.assignment_nodes,
|
||||
feed_dict=dict(zip(placeholders, arrays)))
|
||||
|
||||
def get_weights(self):
|
||||
"""Returns the weights of the variables of the loss function in a list."""
|
||||
@@ -103,4 +113,7 @@ class TensorFlowVariables(object):
|
||||
def set_weights(self, new_weights):
|
||||
"""Sets the weights to new_weights."""
|
||||
self._check_sess()
|
||||
self.sess.run(self.assignment_nodes, feed_dict={self.placeholders[name]: value for (name, value) in new_weights.items() if name in self.placeholders})
|
||||
self.sess.run(self.assignment_nodes,
|
||||
feed_dict={self.placeholders[name]: value
|
||||
for (name, value) in new_weights.items()
|
||||
if name in self.placeholders})
|
||||
|
||||
@@ -9,6 +9,7 @@ import sys
|
||||
|
||||
import ray
|
||||
|
||||
|
||||
def tarred_directory_as_bytes(source_dir):
|
||||
"""Tar a directory and return it as a byte string.
|
||||
|
||||
@@ -26,6 +27,7 @@ def tarred_directory_as_bytes(source_dir):
|
||||
string_file.seek(0)
|
||||
return string_file.read()
|
||||
|
||||
|
||||
def tarred_bytes_to_directory(tarred_bytes, target_dir):
|
||||
"""Take a byte string and untar it.
|
||||
|
||||
@@ -38,6 +40,7 @@ def tarred_bytes_to_directory(tarred_bytes, target_dir):
|
||||
with tarfile.open(fileobj=string_file) as tar:
|
||||
tar.extractall(path=target_dir)
|
||||
|
||||
|
||||
def copy_directory(source_dir, target_dir=None):
|
||||
"""Copy a local directory to each machine in the Ray cluster.
|
||||
|
||||
@@ -45,17 +48,18 @@ def copy_directory(source_dir, target_dir=None):
|
||||
example, source_dir can be /a/b/c and target_dir can be /d/e/c. In this case,
|
||||
the directory /d/e will be added to the Python path of each worker.
|
||||
|
||||
Note that this method is not completely safe to use. For example, workers that
|
||||
do not do the copying and only set their paths (only one worker per node does
|
||||
the copying) may try to execute functions that use the files in the directory
|
||||
being copied before the directory being copied has finished untarring.
|
||||
Note that this method is not completely safe to use. For example, workers
|
||||
that do not do the copying and only set their paths (only one worker per node
|
||||
does the copying) may try to execute functions that use the files in the
|
||||
directory being copied before the directory being copied has finished
|
||||
untarring.
|
||||
|
||||
Args:
|
||||
source_dir (str): The directory to copy.
|
||||
target_dir (str): The location to copy it to on the other machines. If this
|
||||
is not provided, the source_dir will be used. If it is provided and is
|
||||
different from source_dir, the source_dir also be copied to the target_dir
|
||||
location on this machine.
|
||||
different from source_dir, the source_dir also be copied to the
|
||||
target_dir location on this machine.
|
||||
"""
|
||||
target_dir = source_dir if target_dir is None else target_dir
|
||||
source_dir = os.path.abspath(source_dir)
|
||||
@@ -63,8 +67,10 @@ def copy_directory(source_dir, target_dir=None):
|
||||
source_basename = os.path.basename(source_dir)
|
||||
target_basename = os.path.basename(target_dir)
|
||||
if source_basename != target_basename:
|
||||
raise Exception("The source_dir and target_dir must have the same base name, {} != {}".format(source_basename, target_basename))
|
||||
raise Exception("The source_dir and target_dir must have the same base "
|
||||
"name, {} != {}".format(source_basename, target_basename))
|
||||
tarred_bytes = tarred_directory_as_bytes(source_dir)
|
||||
|
||||
def f(worker_info):
|
||||
if worker_info["counter"] == 0:
|
||||
tarred_bytes_to_directory(tarred_bytes, os.path.dirname(target_dir))
|
||||
|
||||
@@ -2,4 +2,6 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from .global_scheduler_services import *
|
||||
from .global_scheduler_services import start_global_scheduler
|
||||
|
||||
__all__ = ["start_global_scheduler"]
|
||||
|
||||
@@ -6,6 +6,7 @@ import os
|
||||
import subprocess
|
||||
import time
|
||||
|
||||
|
||||
def start_global_scheduler(redis_address, use_valgrind=False,
|
||||
use_profiler=False, stdout_file=None,
|
||||
stderr_file=None):
|
||||
@@ -15,8 +16,8 @@ def start_global_scheduler(redis_address, use_valgrind=False,
|
||||
redis_address (str): The address of the Redis instance.
|
||||
use_valgrind (bool): True if the global scheduler should be started inside
|
||||
of valgrind. If this is True, use_profiler must be False.
|
||||
use_profiler (bool): True if the global scheduler should be started inside a
|
||||
profiler. If this is True, use_valgrind must be False.
|
||||
use_profiler (bool): True if the global scheduler should be started inside
|
||||
a profiler. If this is True, use_valgrind must be False.
|
||||
stdout_file: A file handle opened for writing to redirect stdout to. If no
|
||||
redirection should happen, then this should be None.
|
||||
stderr_file: A file handle opened for writing to redirect stderr to. If no
|
||||
@@ -27,7 +28,9 @@ def start_global_scheduler(redis_address, use_valgrind=False,
|
||||
"""
|
||||
if use_valgrind and use_profiler:
|
||||
raise Exception("Cannot use valgrind and profiler at the same time.")
|
||||
global_scheduler_executable = os.path.join(os.path.abspath(os.path.dirname(__file__)), "../core/src/global_scheduler/global_scheduler")
|
||||
global_scheduler_executable = os.path.join(
|
||||
os.path.abspath(os.path.dirname(__file__)),
|
||||
"../core/src/global_scheduler/global_scheduler")
|
||||
command = [global_scheduler_executable, "-r", redis_address]
|
||||
if use_valgrind:
|
||||
pid = subprocess.Popen(["valgrind",
|
||||
@@ -35,7 +38,7 @@ def start_global_scheduler(redis_address, use_valgrind=False,
|
||||
"--leak-check=full",
|
||||
"--show-leak-kinds=all",
|
||||
"--error-exitcode=1"] + command,
|
||||
stdout=stdout_file, stderr=stderr_file)
|
||||
stdout=stdout_file, stderr=stderr_file)
|
||||
time.sleep(1.0)
|
||||
elif use_profiler:
|
||||
pid = subprocess.Popen(["valgrind", "--tool=callgrind"] + command,
|
||||
|
||||
@@ -7,16 +7,14 @@ import os
|
||||
import random
|
||||
import redis
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import unittest
|
||||
|
||||
import ray.global_scheduler as global_scheduler
|
||||
import ray.local_scheduler as local_scheduler
|
||||
import ray.plasma as plasma
|
||||
from ray.plasma.utils import random_object_id, generate_metadata, write_to_data_buffer, create_object_with_id, create_object
|
||||
from ray.plasma.utils import create_object
|
||||
|
||||
from ray import services
|
||||
|
||||
@@ -39,21 +37,27 @@ TASK_STATUS_DONE = 16
|
||||
DB_CLIENT_PREFIX = "CL:"
|
||||
TASK_PREFIX = "TT:"
|
||||
|
||||
|
||||
def random_driver_id():
|
||||
return local_scheduler.ObjectID(np.random.bytes(ID_SIZE))
|
||||
|
||||
|
||||
def random_task_id():
|
||||
return local_scheduler.ObjectID(np.random.bytes(ID_SIZE))
|
||||
|
||||
|
||||
def random_function_id():
|
||||
return local_scheduler.ObjectID(np.random.bytes(ID_SIZE))
|
||||
|
||||
|
||||
def random_object_id():
|
||||
return local_scheduler.ObjectID(np.random.bytes(ID_SIZE))
|
||||
|
||||
|
||||
def new_port():
|
||||
return random.randint(10000, 65535)
|
||||
|
||||
|
||||
class TestGlobalScheduler(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
@@ -62,9 +66,11 @@ class TestGlobalScheduler(unittest.TestCase):
|
||||
redis_port, self.redis_process = services.start_redis(cleanup=False)
|
||||
redis_address = services.address(node_ip_address, redis_port)
|
||||
# Create a Redis client.
|
||||
self.redis_client = redis.StrictRedis(host=node_ip_address, port=redis_port)
|
||||
self.redis_client = redis.StrictRedis(host=node_ip_address,
|
||||
port=redis_port)
|
||||
# Start one global scheduler.
|
||||
self.p1 = global_scheduler.start_global_scheduler(redis_address, use_valgrind=USE_VALGRIND)
|
||||
self.p1 = global_scheduler.start_global_scheduler(
|
||||
redis_address, use_valgrind=USE_VALGRIND)
|
||||
self.plasma_store_pids = []
|
||||
self.plasma_manager_pids = []
|
||||
self.local_scheduler_pids = []
|
||||
@@ -76,11 +82,15 @@ class TestGlobalScheduler(unittest.TestCase):
|
||||
plasma_store_name, p2 = plasma.start_plasma_store()
|
||||
self.plasma_store_pids.append(p2)
|
||||
# Start the Plasma manager.
|
||||
# Assumption: Plasma manager name and port are randomly generated by the plasma module.
|
||||
plasma_manager_name, p3, plasma_manager_port = plasma.start_plasma_manager(plasma_store_name, redis_address)
|
||||
# Assumption: Plasma manager name and port are randomly generated by the
|
||||
# plasma module.
|
||||
manager_info = plasma.start_plasma_manager(plasma_store_name,
|
||||
redis_address)
|
||||
plasma_manager_name, p3, plasma_manager_port = manager_info
|
||||
self.plasma_manager_pids.append(p3)
|
||||
plasma_address = "{}:{}".format(node_ip_address, plasma_manager_port)
|
||||
plasma_client = plasma.PlasmaClient(plasma_store_name, plasma_manager_name)
|
||||
plasma_client = plasma.PlasmaClient(plasma_store_name,
|
||||
plasma_manager_name)
|
||||
self.plasma_clients.append(plasma_client)
|
||||
# Start the local scheduler.
|
||||
local_scheduler_name, p4 = local_scheduler.start_local_scheduler(
|
||||
@@ -91,7 +101,7 @@ class TestGlobalScheduler(unittest.TestCase):
|
||||
static_resource_list=[10, 0])
|
||||
# Connect to the scheduler.
|
||||
local_scheduler_client = local_scheduler.LocalSchedulerClient(
|
||||
local_scheduler_name, NIL_ACTOR_ID, False)
|
||||
local_scheduler_name, NIL_ACTOR_ID, False)
|
||||
self.local_scheduler_clients.append(local_scheduler_client)
|
||||
self.local_scheduler_pids.append(p4)
|
||||
|
||||
@@ -149,11 +159,13 @@ class TestGlobalScheduler(unittest.TestCase):
|
||||
return db_client_id
|
||||
|
||||
def test_task_default_resources(self):
|
||||
task1 = local_scheduler.Task(random_driver_id(), random_function_id(), [random_object_id()], 0, random_task_id(), 0)
|
||||
task1 = local_scheduler.Task(random_driver_id(), random_function_id(),
|
||||
[random_object_id()], 0, random_task_id(), 0)
|
||||
self.assertEqual(task1.required_resources(), [1.0, 0.0])
|
||||
task2 = local_scheduler.Task(random_driver_id(), random_function_id(),
|
||||
[random_object_id()], 0, random_task_id(), 0,
|
||||
local_scheduler.ObjectID(NIL_ACTOR_ID), 0, [1.0, 2.0])
|
||||
local_scheduler.ObjectID(NIL_ACTOR_ID), 0,
|
||||
[1.0, 2.0])
|
||||
self.assertEqual(task2.required_resources(), [1.0, 2.0])
|
||||
|
||||
def test_redis_only_single_task(self):
|
||||
@@ -162,31 +174,37 @@ class TestGlobalScheduler(unittest.TestCase):
|
||||
task state transitions in Redis only. TODO(atumanov): implement.
|
||||
"""
|
||||
# Check precondition for this test:
|
||||
# There should be 2n+1 db clients: the global scheduler + one local scheduler and one plasma per node.
|
||||
self.assertEqual(len(self.redis_client.keys("{}*".format(DB_CLIENT_PREFIX))),
|
||||
2 * NUM_CLUSTER_NODES + 1)
|
||||
# There should be 2n+1 db clients: the global scheduler + one local
|
||||
# scheduler and one plasma per node.
|
||||
self.assertEqual(
|
||||
len(self.redis_client.keys("{}*".format(DB_CLIENT_PREFIX))),
|
||||
2 * NUM_CLUSTER_NODES + 1)
|
||||
db_client_id = self.get_plasma_manager_id()
|
||||
assert(db_client_id != None)
|
||||
assert(db_client_id is not None)
|
||||
assert(db_client_id.startswith(b"CL:"))
|
||||
db_client_id = db_client_id[len(b"CL:"):] # Remove the CL: prefix.
|
||||
db_client_id = db_client_id[len(b"CL:"):] # Remove the CL: prefix.
|
||||
|
||||
def test_integration_single_task(self):
|
||||
# There should be three db clients, the global scheduler, the local
|
||||
# scheduler, and the plasma manager.
|
||||
self.assertEqual(len(self.redis_client.keys("{}*".format(DB_CLIENT_PREFIX))),
|
||||
2 * NUM_CLUSTER_NODES + 1)
|
||||
self.assertEqual(
|
||||
len(self.redis_client.keys("{}*".format(DB_CLIENT_PREFIX))),
|
||||
2 * NUM_CLUSTER_NODES + 1)
|
||||
|
||||
num_return_vals = [0, 1, 2, 3, 5, 10]
|
||||
# Insert the object into Redis.
|
||||
data_size = 0xf1f0
|
||||
metadata_size = 0x40
|
||||
plasma_client = self.plasma_clients[0]
|
||||
object_dep, memory_buffer, metadata = create_object(plasma_client, data_size, metadata_size, seal=True)
|
||||
object_dep, memory_buffer, metadata = create_object(
|
||||
plasma_client, data_size, metadata_size, seal=True)
|
||||
|
||||
# Sleep before submitting task to local scheduler.
|
||||
time.sleep(0.1)
|
||||
# Submit a task to Redis.
|
||||
task = local_scheduler.Task(random_driver_id(), random_function_id(), [local_scheduler.ObjectID(object_dep)], num_return_vals[0], random_task_id(), 0)
|
||||
task = local_scheduler.Task(random_driver_id(), random_function_id(),
|
||||
[local_scheduler.ObjectID(object_dep)],
|
||||
num_return_vals[0], random_task_id(), 0)
|
||||
self.local_scheduler_clients[0].submit(task)
|
||||
time.sleep(0.1)
|
||||
# There should now be a task in Redis, and it should get assigned to the
|
||||
@@ -217,8 +235,9 @@ class TestGlobalScheduler(unittest.TestCase):
|
||||
def integration_many_tasks_helper(self, timesync=True):
|
||||
# There should be three db clients, the global scheduler, the local
|
||||
# scheduler, and the plasma manager.
|
||||
self.assertEqual(len(self.redis_client.keys("{}*".format(DB_CLIENT_PREFIX))),
|
||||
2 * NUM_CLUSTER_NODES + 1)
|
||||
self.assertEqual(
|
||||
len(self.redis_client.keys("{}*".format(DB_CLIENT_PREFIX))),
|
||||
2 * NUM_CLUSTER_NODES + 1)
|
||||
num_return_vals = [0, 1, 2, 3, 5, 10]
|
||||
|
||||
# Submit a bunch of tasks to Redis.
|
||||
@@ -228,11 +247,16 @@ class TestGlobalScheduler(unittest.TestCase):
|
||||
data_size = np.random.randint(1 << 20)
|
||||
metadata_size = np.random.randint(1 << 10)
|
||||
plasma_client = self.plasma_clients[0]
|
||||
object_dep, memory_buffer, metadata = create_object(plasma_client, data_size, metadata_size, seal=True)
|
||||
object_dep, memory_buffer, metadata = create_object(plasma_client,
|
||||
data_size,
|
||||
metadata_size,
|
||||
seal=True)
|
||||
if timesync:
|
||||
# Give 10ms for object info handler to fire (long enough to yield CPU).
|
||||
time.sleep(0.010)
|
||||
task = local_scheduler.Task(random_driver_id(), random_function_id(), [local_scheduler.ObjectID(object_dep)], num_return_vals[0], random_task_id(), 0)
|
||||
task = local_scheduler.Task(random_driver_id(), random_function_id(),
|
||||
[local_scheduler.ObjectID(object_dep)],
|
||||
num_return_vals[0], random_task_id(), 0)
|
||||
self.local_scheduler_clients[0].submit(task)
|
||||
# Check that there are the correct number of tasks in Redis and that they
|
||||
# all get assigned to the local scheduler.
|
||||
@@ -243,17 +267,18 @@ class TestGlobalScheduler(unittest.TestCase):
|
||||
self.assertLessEqual(len(task_entries), num_tasks)
|
||||
# First, check if all tasks made it to Redis.
|
||||
if len(task_entries) == num_tasks:
|
||||
task_contents = [self.redis_client.hgetall(task_entries[i]) for i in range(len(task_entries))]
|
||||
task_contents = [self.redis_client.hgetall(task_entries[i])
|
||||
for i in range(len(task_entries))]
|
||||
task_statuses = [int(contents[b"state"]) for contents in task_contents]
|
||||
self.assertTrue(all([
|
||||
status in [TASK_STATUS_WAITING,
|
||||
TASK_STATUS_SCHEDULED,
|
||||
TASK_STATUS_QUEUED] for status in task_statuses
|
||||
]))
|
||||
self.assertTrue(all([status in [TASK_STATUS_WAITING,
|
||||
TASK_STATUS_SCHEDULED,
|
||||
TASK_STATUS_QUEUED]
|
||||
for status in task_statuses]))
|
||||
num_tasks_done = task_statuses.count(TASK_STATUS_QUEUED)
|
||||
num_tasks_scheduled = task_statuses.count(TASK_STATUS_SCHEDULED)
|
||||
num_tasks_waiting = task_statuses.count(TASK_STATUS_WAITING)
|
||||
print("tasks in Redis = {}, tasks waiting = {}, tasks scheduled = {}, tasks queued = {}, retries left = {}"
|
||||
print("tasks in Redis = {}, tasks waiting = {}, tasks scheduled = {}, "
|
||||
"tasks queued = {}, retries left = {}"
|
||||
.format(len(task_entries), num_tasks_waiting,
|
||||
num_tasks_scheduled, num_tasks_done, num_retries))
|
||||
if all([status == TASK_STATUS_QUEUED for status in task_statuses]):
|
||||
@@ -275,6 +300,7 @@ class TestGlobalScheduler(unittest.TestCase):
|
||||
# notifications.
|
||||
self.integration_many_tasks_helper(timesync=False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) > 1:
|
||||
# Pop the argument so we don't mess with unittest's own argument parser.
|
||||
|
||||
@@ -2,5 +2,10 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from ray.core.src.local_scheduler.liblocal_scheduler_library import *
|
||||
from .local_scheduler_services import *
|
||||
from ray.core.src.local_scheduler.liblocal_scheduler_library import (
|
||||
Task, LocalSchedulerClient, ObjectID, check_simple_value, task_from_string,
|
||||
task_to_string)
|
||||
from .local_scheduler_services import start_local_scheduler
|
||||
|
||||
__all__ = ["Task", "LocalSchedulerClient", "ObjectID", "check_simple_value",
|
||||
"task_from_string", "task_to_string", "start_local_scheduler"]
|
||||
|
||||
@@ -7,9 +7,11 @@ import random
|
||||
import subprocess
|
||||
import time
|
||||
|
||||
|
||||
def random_name():
|
||||
return str(random.randint(0, 99999999))
|
||||
|
||||
|
||||
def start_local_scheduler(plasma_store_name,
|
||||
plasma_manager_name=None,
|
||||
worker_path=None,
|
||||
@@ -38,8 +40,8 @@ def start_local_scheduler(plasma_store_name,
|
||||
running on.
|
||||
redis_address (str): The address of the Redis instance to connect to. If
|
||||
this is not provided, then the local scheduler will not connect to Redis.
|
||||
use_valgrind (bool): True if the local scheduler should be started inside of
|
||||
valgrind. If this is True, use_profiler must be False.
|
||||
use_valgrind (bool): True if the local scheduler should be started inside
|
||||
of valgrind. If this is True, use_profiler must be False.
|
||||
use_profiler (bool): True if the local scheduler should be started inside a
|
||||
profiler. If this is True, use_valgrind must be False.
|
||||
stdout_file: A file handle opened for writing to redirect stdout to. If no
|
||||
@@ -56,11 +58,14 @@ def start_local_scheduler(plasma_store_name,
|
||||
A tuple of the name of the local scheduler socket and the process ID of the
|
||||
local scheduler process.
|
||||
"""
|
||||
if (plasma_manager_name == None) != (redis_address == None):
|
||||
raise Exception("If one of the plasma_manager_name and the redis_address is provided, then both must be provided.")
|
||||
if (plasma_manager_name is None) != (redis_address is None):
|
||||
raise Exception("If one of the plasma_manager_name and the redis_address "
|
||||
"is provided, then both must be provided.")
|
||||
if use_valgrind and use_profiler:
|
||||
raise Exception("Cannot use valgrind and profiler at the same time.")
|
||||
local_scheduler_executable = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../core/src/local_scheduler/local_scheduler")
|
||||
local_scheduler_executable = os.path.join(os.path.dirname(
|
||||
os.path.abspath(__file__)),
|
||||
"../core/src/local_scheduler/local_scheduler")
|
||||
local_scheduler_name = "/tmp/scheduler{}".format(random_name())
|
||||
command = [local_scheduler_executable,
|
||||
"-s", local_scheduler_name,
|
||||
@@ -90,8 +95,10 @@ def start_local_scheduler(plasma_store_name,
|
||||
if plasma_address is not None:
|
||||
command += ["-a", plasma_address]
|
||||
if static_resource_list is not None:
|
||||
assert all([isinstance(resource, int) or isinstance(resource, float) for resource in static_resource_list])
|
||||
command += ["-c", ",".join([str(resource) for resource in static_resource_list])]
|
||||
assert all([isinstance(resource, int) or isinstance(resource, float)
|
||||
for resource in static_resource_list])
|
||||
command += ["-c", ",".join([str(resource) for resource
|
||||
in static_resource_list])]
|
||||
|
||||
if use_valgrind:
|
||||
pid = subprocess.Popen(["valgrind",
|
||||
@@ -99,7 +106,7 @@ def start_local_scheduler(plasma_store_name,
|
||||
"--leak-check=full",
|
||||
"--show-leak-kinds=all",
|
||||
"--error-exitcode=1"] + command,
|
||||
stdout=stdout_file, stderr=stderr_file)
|
||||
stdout=stdout_file, stderr=stderr_file)
|
||||
time.sleep(1.0)
|
||||
elif use_profiler:
|
||||
pid = subprocess.Popen(["valgrind", "--tool=callgrind"] + command,
|
||||
|
||||
@@ -4,9 +4,7 @@ from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import os
|
||||
import random
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
@@ -20,18 +18,23 @@ ID_SIZE = 20
|
||||
|
||||
NIL_ACTOR_ID = 20 * b"\xff"
|
||||
|
||||
|
||||
def random_object_id():
|
||||
return local_scheduler.ObjectID(np.random.bytes(ID_SIZE))
|
||||
|
||||
|
||||
def random_driver_id():
|
||||
return local_scheduler.ObjectID(np.random.bytes(ID_SIZE))
|
||||
|
||||
|
||||
def random_task_id():
|
||||
return local_scheduler.ObjectID(np.random.bytes(ID_SIZE))
|
||||
|
||||
|
||||
def random_function_id():
|
||||
return local_scheduler.ObjectID(np.random.bytes(ID_SIZE))
|
||||
|
||||
|
||||
class TestLocalSchedulerClient(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
@@ -39,10 +42,11 @@ class TestLocalSchedulerClient(unittest.TestCase):
|
||||
plasma_store_name, self.p1 = plasma.start_plasma_store()
|
||||
self.plasma_client = plasma.PlasmaClient(plasma_store_name)
|
||||
# Start a local scheduler.
|
||||
scheduler_name, self.p2 = local_scheduler.start_local_scheduler(plasma_store_name, use_valgrind=USE_VALGRIND)
|
||||
scheduler_name, self.p2 = local_scheduler.start_local_scheduler(
|
||||
plasma_store_name, use_valgrind=USE_VALGRIND)
|
||||
# Connect to the scheduler.
|
||||
self.local_scheduler_client = local_scheduler.LocalSchedulerClient(
|
||||
scheduler_name, NIL_ACTOR_ID, False)
|
||||
scheduler_name, NIL_ACTOR_ID, False)
|
||||
|
||||
def tearDown(self):
|
||||
# Check that the processes are still alive.
|
||||
@@ -70,37 +74,38 @@ class TestLocalSchedulerClient(unittest.TestCase):
|
||||
self.plasma_client.seal(object_id.id())
|
||||
# Define some arguments to use for the tasks.
|
||||
args_list = [
|
||||
[],
|
||||
#{},
|
||||
#(),
|
||||
1 * [1],
|
||||
10 * [1],
|
||||
100 * [1],
|
||||
1000 * [1],
|
||||
1 * ["a"],
|
||||
10 * ["a"],
|
||||
100 * ["a"],
|
||||
1000 * ["a"],
|
||||
[1, 1.3, 1 << 100, "hi", u"hi", [1, 2]],
|
||||
object_ids[:1],
|
||||
object_ids[:2],
|
||||
object_ids[:3],
|
||||
object_ids[:4],
|
||||
object_ids[:5],
|
||||
object_ids[:10],
|
||||
object_ids[:100],
|
||||
object_ids[:256],
|
||||
[1, object_ids[0]],
|
||||
[object_ids[0], "a"],
|
||||
[1, object_ids[0], "a"],
|
||||
[object_ids[0], 1, object_ids[1], "a"],
|
||||
object_ids[:3] + [1, "hi", 2.3] + object_ids[:5],
|
||||
object_ids + 100 * ["a"] + object_ids
|
||||
[],
|
||||
[{}],
|
||||
[()],
|
||||
1 * [1],
|
||||
10 * [1],
|
||||
100 * [1],
|
||||
1000 * [1],
|
||||
1 * ["a"],
|
||||
10 * ["a"],
|
||||
100 * ["a"],
|
||||
1000 * ["a"],
|
||||
[1, 1.3, 1 << 100, "hi", u"hi", [1, 2]],
|
||||
object_ids[:1],
|
||||
object_ids[:2],
|
||||
object_ids[:3],
|
||||
object_ids[:4],
|
||||
object_ids[:5],
|
||||
object_ids[:10],
|
||||
object_ids[:100],
|
||||
object_ids[:256],
|
||||
[1, object_ids[0]],
|
||||
[object_ids[0], "a"],
|
||||
[1, object_ids[0], "a"],
|
||||
[object_ids[0], 1, object_ids[1], "a"],
|
||||
object_ids[:3] + [1, "hi", 2.3] + object_ids[:5],
|
||||
object_ids + 100 * ["a"] + object_ids
|
||||
]
|
||||
|
||||
for args in args_list:
|
||||
for num_return_vals in [0, 1, 2, 3, 5, 10, 100]:
|
||||
task = local_scheduler.Task(random_driver_id(), function_id, args, num_return_vals, random_task_id(), 0)
|
||||
task = local_scheduler.Task(random_driver_id(), function_id, args,
|
||||
num_return_vals, random_task_id(), 0)
|
||||
# Submit a task.
|
||||
self.local_scheduler_client.submit(task)
|
||||
# Get the task.
|
||||
@@ -119,7 +124,8 @@ class TestLocalSchedulerClient(unittest.TestCase):
|
||||
# Submit all of the tasks.
|
||||
for args in args_list:
|
||||
for num_return_vals in [0, 1, 2, 3, 5, 10, 100]:
|
||||
task = local_scheduler.Task(random_driver_id(), function_id, args, num_return_vals, random_task_id(), 0)
|
||||
task = local_scheduler.Task(random_driver_id(), function_id, args,
|
||||
num_return_vals, random_task_id(), 0)
|
||||
self.local_scheduler_client.submit(task)
|
||||
# Get all of the tasks.
|
||||
for args in args_list:
|
||||
@@ -129,8 +135,10 @@ class TestLocalSchedulerClient(unittest.TestCase):
|
||||
def test_scheduling_when_objects_ready(self):
|
||||
# Create a task and submit it.
|
||||
object_id = random_object_id()
|
||||
task = local_scheduler.Task(random_driver_id(), random_function_id(), [object_id], 0, random_task_id(), 0)
|
||||
task = local_scheduler.Task(random_driver_id(), random_function_id(),
|
||||
[object_id], 0, random_task_id(), 0)
|
||||
self.local_scheduler_client.submit(task)
|
||||
|
||||
# Launch a thread to get the task.
|
||||
def get_task():
|
||||
self.local_scheduler_client.get_task()
|
||||
@@ -149,7 +157,9 @@ class TestLocalSchedulerClient(unittest.TestCase):
|
||||
# Create a task with two dependencies and submit it.
|
||||
object_id1 = random_object_id()
|
||||
object_id2 = random_object_id()
|
||||
task = local_scheduler.Task(random_driver_id(), random_function_id(), [object_id1, object_id2], 0, random_task_id(), 0)
|
||||
task = local_scheduler.Task(random_driver_id(), random_function_id(),
|
||||
[object_id1, object_id2], 0, random_task_id(),
|
||||
0)
|
||||
self.local_scheduler_client.submit(task)
|
||||
|
||||
# Launch a thread to get the task.
|
||||
@@ -196,9 +206,10 @@ class TestLocalSchedulerClient(unittest.TestCase):
|
||||
# Wait until the thread finishes so that we know the task was scheduled.
|
||||
t.join()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) > 1:
|
||||
# pop the argument so we don't mess with unittest's own argument parser
|
||||
# Pop the argument so we don't mess with unittest's own argument parser.
|
||||
if sys.argv[-1] == "valgrind":
|
||||
arg = sys.argv.pop()
|
||||
USE_VALGRIND = True
|
||||
|
||||
@@ -3,13 +3,13 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import redis
|
||||
import time
|
||||
|
||||
from ray.services import get_ip_address
|
||||
from ray.services import get_port
|
||||
|
||||
|
||||
class LogMonitor(object):
|
||||
"""A monitor process for monitoring Ray log files.
|
||||
|
||||
@@ -27,15 +27,17 @@ class LogMonitor(object):
|
||||
def __init__(self, redis_ip_address, redis_port, node_ip_address):
|
||||
"""Initialize the log monitor object."""
|
||||
self.node_ip_address = node_ip_address
|
||||
self.redis_client = redis.StrictRedis(host=redis_ip_address, port=redis_port)
|
||||
self.redis_client = redis.StrictRedis(host=redis_ip_address,
|
||||
port=redis_port)
|
||||
self.log_files = {}
|
||||
self.log_file_handles = {}
|
||||
|
||||
def update_log_filenames(self):
|
||||
"""Get the most up-to-date list of log files to monitor from Redis."""
|
||||
num_current_log_files = len(self.log_files)
|
||||
new_log_filenames = self.redis_client.lrange("LOG_FILENAMES:{}".format(self.node_ip_address),
|
||||
num_current_log_files, -1)
|
||||
new_log_filenames = self.redis_client.lrange(
|
||||
"LOG_FILENAMES:{}".format(self.node_ip_address),
|
||||
num_current_log_files, -1)
|
||||
for log_filename in new_log_filenames:
|
||||
print("Beginning to track file {}".format(log_filename))
|
||||
assert log_filename not in self.log_files
|
||||
@@ -50,7 +52,8 @@ class LogMonitor(object):
|
||||
# If there are any new lines, cache them and also push them to Redis.
|
||||
if len(new_lines) > 0:
|
||||
self.log_files[log_filename] += new_lines
|
||||
redis_key = "LOGFILE:{}:{}".format(self.node_ip_address, log_filename.decode("ascii"))
|
||||
redis_key = "LOGFILE:{}:{}".format(self.node_ip_address,
|
||||
log_filename.decode("ascii"))
|
||||
self.redis_client.rpush(redis_key, *new_lines)
|
||||
else:
|
||||
try:
|
||||
@@ -69,6 +72,7 @@ class LogMonitor(object):
|
||||
self.check_log_files_and_push_updates()
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description=("Parse Redis server for the "
|
||||
"log monitor to connect to."))
|
||||
|
||||
+22
-17
@@ -3,7 +3,6 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import binascii
|
||||
from collections import Counter
|
||||
import logging
|
||||
import redis
|
||||
@@ -13,7 +12,8 @@ from ray.services import get_ip_address
|
||||
from ray.services import get_port
|
||||
|
||||
# Import flatbuffer bindings.
|
||||
from ray.core.generated.SubscribeToDBClientTableReply import SubscribeToDBClientTableReply
|
||||
from ray.core.generated.SubscribeToDBClientTableReply \
|
||||
import SubscribeToDBClientTableReply
|
||||
from ray.core.generated.TaskReply import TaskReply
|
||||
|
||||
# These variables must be kept in sync with the C codebase.
|
||||
@@ -41,6 +41,7 @@ logging.basicConfig()
|
||||
log = logging.getLogger()
|
||||
log.setLevel(logging.WARN)
|
||||
|
||||
|
||||
class Monitor(object):
|
||||
"""A monitor for Ray processes.
|
||||
|
||||
@@ -92,7 +93,8 @@ class Monitor(object):
|
||||
TASK_STATUS_LOST. A local scheduler is deemed dead if it is in
|
||||
self.dead_local_schedulers.
|
||||
"""
|
||||
task_ids = self.redis.scan_iter(match="{prefix}*".format(prefix=TASK_PREFIX))
|
||||
task_ids = self.redis.scan_iter(
|
||||
match="{prefix}*".format(prefix=TASK_PREFIX))
|
||||
num_tasks_updated = 0
|
||||
for task_id in task_ids:
|
||||
task_id = task_id[len(TASK_PREFIX):]
|
||||
@@ -123,11 +125,13 @@ class Monitor(object):
|
||||
"""
|
||||
# TODO(swang): Also kill the associated plasma store, since it's no longer
|
||||
# reachable without a plasma manager.
|
||||
object_ids = self.redis.scan_iter(match="{prefix}*".format(prefix=OBJECT_PREFIX))
|
||||
object_ids = self.redis.scan_iter(
|
||||
match="{prefix}*".format(prefix=OBJECT_PREFIX))
|
||||
num_objects_removed = 0
|
||||
for object_id in object_ids:
|
||||
object_id = object_id[len(OBJECT_PREFIX):]
|
||||
managers = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", object_id)
|
||||
managers = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
|
||||
object_id)
|
||||
for manager in managers:
|
||||
if manager in self.dead_plasma_managers:
|
||||
# If the object was on a dead plasma manager, remove that location
|
||||
@@ -149,7 +153,8 @@ class Monitor(object):
|
||||
not miss any notifications for deleted clients that occurred before we
|
||||
subscribed.
|
||||
"""
|
||||
db_client_keys = self.redis.keys("{prefix}*".format(prefix=DB_CLIENT_PREFIX))
|
||||
db_client_keys = self.redis.keys(
|
||||
"{prefix}*".format(prefix=DB_CLIENT_PREFIX))
|
||||
for db_client_key in db_client_keys:
|
||||
db_client_id = db_client_key[len(DB_CLIENT_PREFIX):]
|
||||
client_type, deleted = self.redis.hmget(db_client_key,
|
||||
@@ -175,10 +180,10 @@ class Monitor(object):
|
||||
flatbuffer. Deletions are processed, insertions are ignored. Cleanup of the
|
||||
associated state in the state tables should be handled by the caller.
|
||||
"""
|
||||
notification_object = SubscribeToDBClientTableReply.GetRootAsSubscribeToDBClientTableReply(data, 0)
|
||||
notification_object = (SubscribeToDBClientTableReply
|
||||
.GetRootAsSubscribeToDBClientTableReply(data, 0))
|
||||
db_client_id = notification_object.DbClientId()
|
||||
client_type = notification_object.ClientType()
|
||||
auxiliary_address = notification_object.AuxAddress()
|
||||
is_insertion = notification_object.IsInsertion()
|
||||
|
||||
# If the update was an insertion, we ignore it.
|
||||
@@ -227,7 +232,6 @@ class Monitor(object):
|
||||
if not self.subscribed[channel]:
|
||||
# If the data was an integer, then the message was a response to an
|
||||
# initial subscription request.
|
||||
is_subscribe = int(data)
|
||||
message_handler = self.subscribe_handler
|
||||
elif channel == PLASMA_MANAGER_HEARTBEAT_CHANNEL:
|
||||
assert(self.subscribed[channel])
|
||||
@@ -264,12 +268,11 @@ class Monitor(object):
|
||||
self.cleanup_task_table()
|
||||
if len(self.dead_plasma_managers) > 0:
|
||||
self.cleanup_object_table()
|
||||
log.debug("{} dead local schedulers, {} plasma "
|
||||
"managers total, {} dead plasma managers".format(
|
||||
len(self.dead_local_schedulers),
|
||||
len(self.live_plasma_managers) + len(self.dead_plasma_managers),
|
||||
len(self.dead_plasma_managers)
|
||||
))
|
||||
log.debug("{} dead local schedulers, {} plasma managers total, {} dead "
|
||||
"plasma managers".format(len(self.dead_local_schedulers),
|
||||
(len(self.live_plasma_managers) +
|
||||
len(self.dead_plasma_managers)),
|
||||
len(self.dead_plasma_managers)))
|
||||
|
||||
# Handle messages from the subscription channels.
|
||||
while True:
|
||||
@@ -289,7 +292,8 @@ class Monitor(object):
|
||||
# Handle plasma managers that timed out during this round.
|
||||
plasma_manager_ids = list(self.live_plasma_managers.keys())
|
||||
for plasma_manager_id in plasma_manager_ids:
|
||||
if self.live_plasma_managers[plasma_manager_id] >= NUM_HEARTBEATS_TIMEOUT:
|
||||
if ((self.live_plasma_managers
|
||||
[plasma_manager_id]) >= NUM_HEARTBEATS_TIMEOUT):
|
||||
log.warn("Timed out {}".format(PLASMA_MANAGER_CLIENT_TYPE))
|
||||
# Remove the plasma manager from the managers whose heartbeats we're
|
||||
# tracking.
|
||||
@@ -299,7 +303,8 @@ class Monitor(object):
|
||||
# receive the notification for this db_client deletion.
|
||||
self.redis.execute_command("RAY.DISCONNECT", plasma_manager_id)
|
||||
|
||||
# Increment the number of heartbeats that we've missed from each plasma manager.
|
||||
# Increment the number of heartbeats that we've missed from each plasma
|
||||
# manager.
|
||||
for plasma_manager_id in self.live_plasma_managers:
|
||||
self.live_plasma_managers[plasma_manager_id] += 1
|
||||
|
||||
|
||||
@@ -10,15 +10,29 @@ If you are using Anaconda, try fixing this problem by running:
|
||||
conda install libgcc
|
||||
"""
|
||||
|
||||
__all__ = ["deserialize_list", "numbuf_error",
|
||||
"numbuf_plasma_object_exists_error", "read_from_buffer",
|
||||
"register_callbacks", "retrieve_list", "serialize_list",
|
||||
"store_list", "write_to_buffer"]
|
||||
|
||||
try:
|
||||
from ray.core.src.numbuf.libnumbuf import *
|
||||
from ray.core.src.numbuf.libnumbuf import (deserialize_list, numbuf_error,
|
||||
numbuf_plasma_object_exists_error,
|
||||
read_from_buffer,
|
||||
register_callbacks, retrieve_list,
|
||||
serialize_list, store_list,
|
||||
write_to_buffer)
|
||||
except ImportError as e:
|
||||
if hasattr(e, "msg") and isinstance(e.msg, str) and ("libstdc++" in e.msg or "CXX" in e.msg):
|
||||
if (hasattr(e, "msg") and isinstance(e.msg, str) and ("libstdc++" in e.msg or
|
||||
"CXX" in e.msg)):
|
||||
# This code path should be taken with Python 3.
|
||||
e.msg += helpful_message
|
||||
elif hasattr(e, "message") and isinstance(e.message, str) and ("libstdc++" in e.message or "CXX" in e.message):
|
||||
elif (hasattr(e, "message") and isinstance(e.message, str) and
|
||||
("libstdc++" in e.message or "CXX" in e.message)):
|
||||
# This code path should be taken with Python 2.
|
||||
if hasattr(e, "args") and isinstance(e.args, tuple) and len(e.args) == 1 and isinstance(e.args[0], str):
|
||||
condition = (hasattr(e, "args") and isinstance(e.args, tuple) and
|
||||
len(e.args) == 1 and isinstance(e.args[0], str))
|
||||
if condition:
|
||||
e.args = (e.args[0] + helpful_message,)
|
||||
else:
|
||||
if not hasattr(e, "args"):
|
||||
|
||||
+34
-14
@@ -1,55 +1,74 @@
|
||||
# Note that a little bit of code here is taken and slightly modified from the pickler because it was not possible to change its behavior otherwise.
|
||||
# Note that a little bit of code here is taken and slightly modified from the
|
||||
# pickler because it was not possible to change its behavior otherwise.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import sys
|
||||
from ctypes import c_void_p
|
||||
from cloudpickle import pickle, cloudpickle, CloudPickler, load, loads
|
||||
|
||||
__all__ = ["load", "loads", "dump", "dumps"]
|
||||
|
||||
try:
|
||||
from ctypes import pythonapi
|
||||
pythonapi.PyCell_Set # Make sure this exists
|
||||
except:
|
||||
pythonapi = None
|
||||
|
||||
|
||||
def dump(obj, file, protocol=2):
|
||||
return BetterPickler(file, protocol).dump(obj)
|
||||
|
||||
|
||||
def dumps(obj):
|
||||
stringio = cloudpickle.StringIO()
|
||||
dump(obj, stringio)
|
||||
return stringio.getvalue()
|
||||
|
||||
def _make_skel_func(code, closure, base_globals = None):
|
||||
""" Creates a skeleton function object that contains just the provided
|
||||
code and the correct number of cells in func_closure. All other
|
||||
func attributes (e.g. func_globals) are empty.
|
||||
|
||||
def _make_skel_func(code, closure, base_globals=None):
|
||||
"""Create a skeleton function object.
|
||||
|
||||
Creates a skeleton function object that contains just the provided code and
|
||||
the correct number of cells in func_closure. All other func attributes
|
||||
(e.g. func_globals) are empty.
|
||||
"""
|
||||
if base_globals is None: base_globals = {}
|
||||
base_globals['__builtins__'] = __builtins__
|
||||
return _make_skel_func.__class__(code, base_globals, None, None, tuple(closure))
|
||||
if base_globals is None:
|
||||
base_globals = {}
|
||||
base_globals["__builtins__"] = __builtins__
|
||||
return _make_skel_func.__class__(code, base_globals, None, None,
|
||||
tuple(closure))
|
||||
|
||||
|
||||
def _fill_function(func, globals, defaults, closure, dict):
|
||||
""" Fills in the rest of function data into the skeleton function object
|
||||
that were created via _make_skel_func(), including closures.
|
||||
"""Fill in the resst of the function data.
|
||||
|
||||
This fills in the rest of function data into the skeleton function object
|
||||
that were created via _make_skel_func(), including closures.
|
||||
"""
|
||||
result = cloudpickle._fill_function(func, globals, defaults, dict)
|
||||
if pythonapi is not None:
|
||||
for i, v in enumerate(closure):
|
||||
pythonapi.PyCell_Set(c_void_p(id(result.__closure__[i])), c_void_p(id(v)))
|
||||
pythonapi.PyCell_Set(c_void_p(id(result.__closure__[i])),
|
||||
c_void_p(id(v)))
|
||||
return result
|
||||
|
||||
|
||||
class BetterPickler(CloudPickler):
|
||||
def save_function_tuple(self, func):
|
||||
code, f_globals, defaults, closure, dct, base_globals = self.extract_func_data(func)
|
||||
(code, f_globals, defaults,
|
||||
closure, dct, base_globals) = self.extract_func_data(func)
|
||||
|
||||
self.save(_fill_function)
|
||||
self.write(pickle.MARK)
|
||||
|
||||
self.save(_make_skel_func if pythonapi else cloudpickle._make_skel_func)
|
||||
self.save((code, map(lambda _: cloudpickle._make_cell(None), closure) if closure and pythonapi is not None else closure, base_globals))
|
||||
self.save((code,
|
||||
(map(lambda _: cloudpickle._make_cell(None), closure)
|
||||
if closure and pythonapi is not None
|
||||
else closure),
|
||||
base_globals))
|
||||
self.write(pickle.REDUCE)
|
||||
self.memoize(func)
|
||||
|
||||
@@ -59,6 +78,7 @@ class BetterPickler(CloudPickler):
|
||||
self.save(dct)
|
||||
self.write(pickle.TUPLE)
|
||||
self.write(pickle.REDUCE)
|
||||
|
||||
def save_cell(self, obj):
|
||||
self.save(cloudpickle._make_cell)
|
||||
self.save((obj.cell_contents,))
|
||||
|
||||
@@ -2,4 +2,13 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from ray.plasma.plasma import *
|
||||
from ray.plasma.plasma import (PlasmaBuffer, buffers_equal, PlasmaClient,
|
||||
start_plasma_store, start_plasma_manager,
|
||||
plasma_object_exists_error,
|
||||
plasma_out_of_memory_error,
|
||||
DEFAULT_PLASMA_STORE_MEMORY)
|
||||
|
||||
__all__ = ["PlasmaBuffer", "buffers_equal", "PlasmaClient",
|
||||
"start_plasma_store", "start_plasma_manager",
|
||||
"plasma_object_exists_error", "plasma_out_of_memory_error",
|
||||
"DEFAULT_PLASMA_STORE_MEMORY"]
|
||||
|
||||
+55
-29
@@ -12,9 +12,14 @@ import ray.core.src.plasma.libplasma as libplasma
|
||||
from ray.core.src.plasma.libplasma import plasma_object_exists_error
|
||||
from ray.core.src.plasma.libplasma import plasma_out_of_memory_error
|
||||
|
||||
PLASMA_ID_SIZE = 20
|
||||
__all__ = ["PlasmaBuffer", "buffers_equal", "PlasmaClient",
|
||||
"start_plasma_store", "start_plasma_manager",
|
||||
"plasma_object_exists_error", "plasma_out_of_memory_error",
|
||||
"DEFAULT_PLASMA_STORE_MEMORY"]
|
||||
|
||||
PLASMA_WAIT_TIMEOUT = 2 ** 30
|
||||
|
||||
|
||||
class PlasmaBuffer(object):
|
||||
"""This is the type of objects returned by calls to get with a PlasmaClient.
|
||||
|
||||
@@ -47,8 +52,8 @@ class PlasmaBuffer(object):
|
||||
"""Read from the PlasmaBuffer as if it were just a regular buffer."""
|
||||
# We currently don't allow slicing plasma buffers. We should handle this
|
||||
# better, but it requires some care because the slice may be backed by the
|
||||
# same memory in the object store, but the original plasma buffer may go out
|
||||
# of scope causing the memory to no longer be accessible.
|
||||
# same memory in the object store, but the original plasma buffer may go
|
||||
# out of scope causing the memory to no longer be accessible.
|
||||
assert not isinstance(index, slice)
|
||||
value = self.buffer[index]
|
||||
if sys.version_info >= (3, 0) and not isinstance(index, slice):
|
||||
@@ -62,8 +67,8 @@ class PlasmaBuffer(object):
|
||||
"""
|
||||
# We currently don't allow slicing plasma buffers. We should handle this
|
||||
# better, but it requires some care because the slice may be backed by the
|
||||
# same memory in the object store, but the original plasma buffer may go out
|
||||
# of scope causing the memory to no longer be accessible.
|
||||
# same memory in the object store, but the original plasma buffer may go
|
||||
# out of scope causing the memory to no longer be accessible.
|
||||
assert not isinstance(index, slice)
|
||||
if sys.version_info >= (3, 0) and not isinstance(index, slice):
|
||||
value = ord(value)
|
||||
@@ -73,48 +78,56 @@ class PlasmaBuffer(object):
|
||||
"""Return the length of the buffer."""
|
||||
return len(self.buffer)
|
||||
|
||||
|
||||
def buffers_equal(buff1, buff2):
|
||||
"""Compare two buffers. These buffers may be PlasmaBuffer objects.
|
||||
|
||||
This method should only be used in the tests. We implement a special helper
|
||||
method for doing this because doing comparisons by slicing is much faster, but
|
||||
we don't want to expose slicing of PlasmaBuffer objects because it currently
|
||||
is not safe.
|
||||
method for doing this because doing comparisons by slicing is much faster,
|
||||
but we don't want to expose slicing of PlasmaBuffer objects because it
|
||||
currently is not safe.
|
||||
"""
|
||||
buff1_to_compare = buff1.buffer if isinstance(buff1, PlasmaBuffer) else buff1
|
||||
buff2_to_compare = buff2.buffer if isinstance(buff2, PlasmaBuffer) else buff2
|
||||
return buff1_to_compare[:] == buff2_to_compare[:]
|
||||
|
||||
|
||||
class PlasmaClient(object):
|
||||
"""The PlasmaClient is used to interface with a plasma store and a plasma manager.
|
||||
"""The PlasmaClient is used to interface with a plasma store and manager.
|
||||
|
||||
The PlasmaClient can ask the PlasmaStore to allocate a new buffer, seal a
|
||||
buffer, and get a buffer. Buffers are referred to by object IDs, which are
|
||||
strings.
|
||||
"""
|
||||
|
||||
def __init__(self, store_socket_name, manager_socket_name=None, release_delay=64):
|
||||
def __init__(self, store_socket_name, manager_socket_name=None,
|
||||
release_delay=64):
|
||||
"""Initialize the PlasmaClient.
|
||||
|
||||
Args:
|
||||
store_socket_name (str): Name of the socket the plasma store is listening at.
|
||||
manager_socket_name (str): Name of the socket the plasma manager is listening at.
|
||||
store_socket_name (str): Name of the socket the plasma store is listening
|
||||
at.
|
||||
manager_socket_name (str): Name of the socket the plasma manager is
|
||||
listening at.
|
||||
release_delay (int): The maximum number of objects that the client will
|
||||
keep and delay releasing (for caching reasons).
|
||||
"""
|
||||
self.store_socket_name = store_socket_name
|
||||
self.manager_socket_name = manager_socket_name
|
||||
self.alive = True
|
||||
|
||||
if manager_socket_name is not None:
|
||||
self.conn = libplasma.connect(store_socket_name, manager_socket_name, release_delay)
|
||||
self.conn = libplasma.connect(store_socket_name, manager_socket_name,
|
||||
release_delay)
|
||||
else:
|
||||
self.conn = libplasma.connect(store_socket_name, "", release_delay)
|
||||
|
||||
def shutdown(self):
|
||||
"""Shutdown the client so that it does not send messages.
|
||||
|
||||
If we kill the Plasma store and Plasma manager that this client is connected
|
||||
to, then we can use this method to prevent the client from trying to send
|
||||
messages to the killed processes.
|
||||
If we kill the Plasma store and Plasma manager that this client is
|
||||
connected to, then we can use this method to prevent the client from trying
|
||||
to send messages to the killed processes.
|
||||
"""
|
||||
if self.alive:
|
||||
libplasma.disconnect(self.conn)
|
||||
@@ -169,8 +182,8 @@ class PlasmaClient(object):
|
||||
def get_metadata(self, object_ids, timeout_ms=-1):
|
||||
"""Create a buffer from the PlasmaStore based on object ID.
|
||||
|
||||
If the object has not been sealed yet, this call will block until the object
|
||||
has been sealed. The retrieved buffer is immutable.
|
||||
If the object has not been sealed yet, this call will block until the
|
||||
object has been sealed. The retrieved buffer is immutable.
|
||||
|
||||
Args:
|
||||
object_ids (List[str]): A list of strings used to identify some objects.
|
||||
@@ -275,7 +288,8 @@ class PlasmaClient(object):
|
||||
# currently crashes if given duplicate object IDs.
|
||||
if len(object_ids) != len(set(object_ids)):
|
||||
raise Exception("Wait requires a list of unique object IDs.")
|
||||
ready_ids, waiting_ids = libplasma.wait(self.conn, object_ids, timeout, num_returns)
|
||||
ready_ids, waiting_ids = libplasma.wait(self.conn, object_ids, timeout,
|
||||
num_returns)
|
||||
return ready_ids, list(waiting_ids)
|
||||
|
||||
def subscribe(self):
|
||||
@@ -286,11 +300,14 @@ class PlasmaClient(object):
|
||||
"""Get the next notification from the notification socket."""
|
||||
return libplasma.receive_notification(self.notification_fd)
|
||||
|
||||
|
||||
DEFAULT_PLASMA_STORE_MEMORY = 10 ** 9
|
||||
|
||||
|
||||
def random_name():
|
||||
return str(random.randint(0, 99999999))
|
||||
|
||||
|
||||
def start_plasma_store(plasma_store_memory=DEFAULT_PLASMA_STORE_MEMORY,
|
||||
use_valgrind=False, use_profiler=False,
|
||||
stdout_file=None, stderr_file=None):
|
||||
@@ -312,16 +329,20 @@ def start_plasma_store(plasma_store_memory=DEFAULT_PLASMA_STORE_MEMORY,
|
||||
"""
|
||||
if use_valgrind and use_profiler:
|
||||
raise Exception("Cannot use valgrind and profiler at the same time.")
|
||||
plasma_store_executable = os.path.join(os.path.abspath(os.path.dirname(__file__)), "../core/src/plasma/plasma_store")
|
||||
plasma_store_executable = os.path.join(os.path.abspath(
|
||||
os.path.dirname(__file__)),
|
||||
"../core/src/plasma/plasma_store")
|
||||
plasma_store_name = "/tmp/plasma_store{}".format(random_name())
|
||||
command = [plasma_store_executable, "-s", plasma_store_name, "-m", str(plasma_store_memory)]
|
||||
command = [plasma_store_executable,
|
||||
"-s", plasma_store_name,
|
||||
"-m", str(plasma_store_memory)]
|
||||
if use_valgrind:
|
||||
pid = subprocess.Popen(["valgrind",
|
||||
"--track-origins=yes",
|
||||
"--leak-check=full",
|
||||
"--show-leak-kinds=all",
|
||||
"--error-exitcode=1"] + command,
|
||||
stdout=stdout_file, stderr=stderr_file)
|
||||
stdout=stdout_file, stderr=stderr_file)
|
||||
time.sleep(1.0)
|
||||
elif use_profiler:
|
||||
pid = subprocess.Popen(["valgrind", "--tool=callgrind"] + command,
|
||||
@@ -332,13 +353,16 @@ def start_plasma_store(plasma_store_memory=DEFAULT_PLASMA_STORE_MEMORY,
|
||||
time.sleep(0.1)
|
||||
return plasma_store_name, pid
|
||||
|
||||
|
||||
def new_port():
|
||||
return random.randint(10000, 65535)
|
||||
|
||||
def start_plasma_manager(store_name, redis_address, node_ip_address="127.0.0.1",
|
||||
plasma_manager_port=None, num_retries=20,
|
||||
use_valgrind=False, run_profiler=False,
|
||||
stdout_file=None, stderr_file=None):
|
||||
|
||||
def start_plasma_manager(store_name, redis_address,
|
||||
node_ip_address="127.0.0.1", plasma_manager_port=None,
|
||||
num_retries=20, use_valgrind=False,
|
||||
run_profiler=False, stdout_file=None,
|
||||
stderr_file=None):
|
||||
"""Start a plasma manager and return the ports it listens on.
|
||||
|
||||
Args:
|
||||
@@ -361,7 +385,9 @@ def start_plasma_manager(store_name, redis_address, node_ip_address="127.0.0.1",
|
||||
Raises:
|
||||
Exception: An exception is raised if the manager could not be started.
|
||||
"""
|
||||
plasma_manager_executable = os.path.join(os.path.abspath(os.path.dirname(__file__)), "../core/src/plasma/plasma_manager")
|
||||
plasma_manager_executable = os.path.join(
|
||||
os.path.abspath(os.path.dirname(__file__)),
|
||||
"../core/src/plasma/plasma_manager")
|
||||
plasma_manager_name = "/tmp/plasma_manager{}".format(random_name())
|
||||
if plasma_manager_port is not None:
|
||||
if num_retries != 1:
|
||||
@@ -385,7 +411,7 @@ def start_plasma_manager(store_name, redis_address, node_ip_address="127.0.0.1",
|
||||
"--leak-check=full",
|
||||
"--show-leak-kinds=all",
|
||||
"--error-exitcode=1"] + command,
|
||||
stdout=stdout_file, stderr=stderr_file)
|
||||
stdout=stdout_file, stderr=stderr_file)
|
||||
elif run_profiler:
|
||||
process = subprocess.Popen(["valgrind", "--tool=callgrind"] + command,
|
||||
stdout=stdout_file, stderr=stderr_file)
|
||||
@@ -396,7 +422,7 @@ def start_plasma_manager(store_name, redis_address, node_ip_address="127.0.0.1",
|
||||
# port is already in use, then we need it to fail within 0.1 seconds.
|
||||
time.sleep(0.1)
|
||||
# See if the process has terminated
|
||||
if process.poll() == None:
|
||||
if process.poll() is None:
|
||||
return plasma_manager_name, process, plasma_manager_port
|
||||
# Generate a new port and try again.
|
||||
plasma_manager_port = new_port()
|
||||
|
||||
+111
-65
@@ -6,23 +6,22 @@ import numpy as np
|
||||
import os
|
||||
import random
|
||||
import signal
|
||||
import socket
|
||||
import struct
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
import unittest
|
||||
|
||||
import ray.plasma as plasma
|
||||
from ray.plasma.utils import random_object_id, generate_metadata, write_to_data_buffer, create_object_with_id, create_object
|
||||
from ray.plasma.utils import (random_object_id, generate_metadata,
|
||||
create_object_with_id, create_object)
|
||||
from ray import services
|
||||
|
||||
USE_VALGRIND = False
|
||||
PLASMA_STORE_MEMORY = 1000000000
|
||||
|
||||
def assert_get_object_equal(unit_test, client1, client2, object_id, memory_buffer=None, metadata=None):
|
||||
|
||||
def assert_get_object_equal(unit_test, client1, client2, object_id,
|
||||
memory_buffer=None, metadata=None):
|
||||
client1_buff = client1.get([object_id])[0]
|
||||
client2_buff = client2.get([object_id])[0]
|
||||
client1_metadata = client1.get_metadata([object_id])[0]
|
||||
@@ -32,7 +31,8 @@ def assert_get_object_equal(unit_test, client1, client2, object_id, memory_buffe
|
||||
# Check that the buffers from the two clients are the same.
|
||||
unit_test.assertTrue(plasma.buffers_equal(client1_buff, client2_buff))
|
||||
# Check that the metadata buffers from the two clients are the same.
|
||||
unit_test.assertTrue(plasma.buffers_equal(client1_metadata, client2_metadata))
|
||||
unit_test.assertTrue(plasma.buffers_equal(client1_metadata,
|
||||
client2_metadata))
|
||||
# If a reference buffer was provided, check that it is the same as well.
|
||||
if memory_buffer is not None:
|
||||
unit_test.assertTrue(plasma.buffers_equal(memory_buffer, client1_buff))
|
||||
@@ -40,11 +40,13 @@ def assert_get_object_equal(unit_test, client1, client2, object_id, memory_buffe
|
||||
if metadata is not None:
|
||||
unit_test.assertTrue(plasma.buffers_equal(metadata, client1_metadata))
|
||||
|
||||
|
||||
class TestPlasmaClient(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
# Start Plasma store.
|
||||
plasma_store_name, self.p = plasma.start_plasma_store(use_valgrind=USE_VALGRIND)
|
||||
plasma_store_name, self.p = plasma.start_plasma_store(
|
||||
use_valgrind=USE_VALGRIND)
|
||||
# Connect to Plasma.
|
||||
self.plasma_client = plasma.PlasmaClient(plasma_store_name, None, 64)
|
||||
# For the eviction test
|
||||
@@ -107,7 +109,7 @@ class TestPlasmaClient(unittest.TestCase):
|
||||
object_id = random_object_id()
|
||||
self.plasma_client.create(object_id, length, generate_metadata(length))
|
||||
try:
|
||||
val = self.plasma_client.create(object_id, length, generate_metadata(length))
|
||||
self.plasma_client.create(object_id, length, generate_metadata(length))
|
||||
except plasma.plasma_object_exists_error as e:
|
||||
pass
|
||||
else:
|
||||
@@ -125,20 +127,24 @@ class TestPlasmaClient(unittest.TestCase):
|
||||
metadata_buffers = []
|
||||
for i in range(num_object_ids):
|
||||
if i % 2 == 0:
|
||||
data_buffer, metadata_buffer = create_object_with_id(self.plasma_client, object_ids[i], 2000, 2000)
|
||||
data_buffer, metadata_buffer = create_object_with_id(
|
||||
self.plasma_client, object_ids[i], 2000, 2000)
|
||||
data_buffers.append(data_buffer)
|
||||
metadata_buffers.append(metadata_buffer)
|
||||
|
||||
# Test timing out from some but not all get calls with various timeouts.
|
||||
for timeout in [0, 10, 100, 1000]:
|
||||
data_results = self.plasma_client.get(object_ids, timeout_ms=timeout)
|
||||
metadata_results = self.plasma_client.get(object_ids, timeout_ms=timeout)
|
||||
# metadata_results = self.plasma_client.get_metadata(object_ids,
|
||||
# timeout_ms=timeout)
|
||||
for i in range(num_object_ids):
|
||||
if i % 2 == 0:
|
||||
self.assertTrue(plasma.buffers_equal(data_buffers[i // 2], data_results[i]))
|
||||
# TODO(rkn): We should compare the metadata as well. But currently the
|
||||
# types are different (e.g., memoryview versus bytearray).
|
||||
# self.assertTrue(plasma.buffers_equal(metadata_buffers[i // 2], metadata_results[i]))
|
||||
self.assertTrue(plasma.buffers_equal(data_buffers[i // 2],
|
||||
data_results[i]))
|
||||
# TODO(rkn): We should compare the metadata as well. But currently
|
||||
# the types are different (e.g., memoryview versus bytearray).
|
||||
# self.assertTrue(plasma.buffers_equal(metadata_buffers[i // 2],
|
||||
# metadata_results[i]))
|
||||
else:
|
||||
self.assertIsNone(results[i])
|
||||
|
||||
@@ -148,7 +154,9 @@ class TestPlasmaClient(unittest.TestCase):
|
||||
def assert_create_raises_plasma_full(unit_test, size):
|
||||
partial_size = np.random.randint(size)
|
||||
try:
|
||||
_, memory_buffer, _ = create_object(unit_test.plasma_client, partial_size, size - partial_size)
|
||||
_, memory_buffer, _ = create_object(unit_test.plasma_client,
|
||||
partial_size,
|
||||
size - partial_size)
|
||||
except plasma.plasma_out_of_memory_error as e:
|
||||
pass
|
||||
else:
|
||||
@@ -215,7 +223,7 @@ class TestPlasmaClient(unittest.TestCase):
|
||||
real_object_ids = [random_object_id() for _ in range(100)]
|
||||
for object_id in real_object_ids:
|
||||
self.assertFalse(self.plasma_client.contains(object_id))
|
||||
memory_buffer = self.plasma_client.create(object_id, 100)
|
||||
self.plasma_client.create(object_id, 100)
|
||||
self.plasma_client.seal(object_id)
|
||||
self.assertTrue(self.plasma_client.contains(object_id))
|
||||
for object_id in fake_object_ids:
|
||||
@@ -226,7 +234,7 @@ class TestPlasmaClient(unittest.TestCase):
|
||||
def test_hash(self):
|
||||
# Check the hash of an object that doesn't exist.
|
||||
object_id1 = random_object_id()
|
||||
h = self.plasma_client.hash(object_id1)
|
||||
self.plasma_client.hash(object_id1)
|
||||
|
||||
length = 1000
|
||||
# Create a random object, and check that the hash function always returns
|
||||
@@ -357,7 +365,7 @@ class TestPlasmaClient(unittest.TestCase):
|
||||
length = 1000
|
||||
memory_buffer = self.plasma_client.create(object_id, length)
|
||||
# Make sure we cannot access memory out of bounds.
|
||||
self.assertRaises(Exception, lambda : memory_buffer[length])
|
||||
self.assertRaises(Exception, lambda: memory_buffer[length])
|
||||
# Seal the object.
|
||||
self.plasma_client.seal(object_id)
|
||||
# This test is commented out because it currently fails.
|
||||
@@ -367,6 +375,7 @@ class TestPlasmaClient(unittest.TestCase):
|
||||
# self.assertRaises(Exception, illegal_assignment)
|
||||
# Get the object.
|
||||
memory_buffer = self.plasma_client.get([object_id])[0]
|
||||
|
||||
# Make sure the object is read only.
|
||||
def illegal_assignment():
|
||||
memory_buffer[0] = chr(0)
|
||||
@@ -413,18 +422,20 @@ class TestPlasmaClient(unittest.TestCase):
|
||||
|
||||
def test_subscribe(self):
|
||||
# Subscribe to notifications from the Plasma Store.
|
||||
sock = self.plasma_client.subscribe()
|
||||
self.plasma_client.subscribe()
|
||||
for i in [1, 10, 100, 1000, 10000, 100000]:
|
||||
object_ids = [random_object_id() for _ in range(i)]
|
||||
metadata_sizes = [np.random.randint(1000) for _ in range(i)]
|
||||
data_sizes = [np.random.randint(1000) for _ in range(i)]
|
||||
for j in range(i):
|
||||
self.plasma_client.create(object_ids[j], size=data_sizes[j],
|
||||
metadata=bytearray(np.random.bytes(metadata_sizes[j])))
|
||||
self.plasma_client.create(
|
||||
object_ids[j], size=data_sizes[j],
|
||||
metadata=bytearray(np.random.bytes(metadata_sizes[j])))
|
||||
self.plasma_client.seal(object_ids[j])
|
||||
# Check that we received notifications for all of the objects.
|
||||
for j in range(i):
|
||||
recv_objid, recv_dsize, recv_msize = self.plasma_client.get_next_notification()
|
||||
notification_info = self.plasma_client.get_next_notification()
|
||||
recv_objid, recv_dsize, recv_msize = notification_info
|
||||
self.assertEqual(object_ids[j], recv_objid)
|
||||
self.assertEqual(data_sizes[j], recv_dsize)
|
||||
self.assertEqual(metadata_sizes[j], recv_msize)
|
||||
@@ -432,20 +443,22 @@ class TestPlasmaClient(unittest.TestCase):
|
||||
def test_subscribe_deletions(self):
|
||||
# Subscribe to notifications from the Plasma Store. We use plasma_client2
|
||||
# to make sure that all used objects will get evicted properly.
|
||||
sock = self.plasma_client2.subscribe()
|
||||
self.plasma_client2.subscribe()
|
||||
for i in [1, 10, 100, 1000, 10000, 100000]:
|
||||
object_ids = [random_object_id() for _ in range(i)]
|
||||
# Add 1 to the sizes to make sure we have nonzero object sizes.
|
||||
metadata_sizes = [np.random.randint(1000) + 1 for _ in range(i)]
|
||||
data_sizes = [np.random.randint(1000) + 1 for _ in range(i)]
|
||||
for j in range(i):
|
||||
x = self.plasma_client2.create(object_ids[j], size=data_sizes[j],
|
||||
metadata=bytearray(np.random.bytes(metadata_sizes[j])))
|
||||
x = self.plasma_client2.create(
|
||||
object_ids[j], size=data_sizes[j],
|
||||
metadata=bytearray(np.random.bytes(metadata_sizes[j])))
|
||||
self.plasma_client2.seal(object_ids[j])
|
||||
del x
|
||||
# Check that we received notifications for creating all of the objects.
|
||||
for j in range(i):
|
||||
recv_objid, recv_dsize, recv_msize = self.plasma_client2.get_next_notification()
|
||||
notification_info = self.plasma_client2.get_next_notification()
|
||||
recv_objid, recv_dsize, recv_msize = notification_info
|
||||
self.assertEqual(object_ids[j], recv_objid)
|
||||
self.assertEqual(data_sizes[j], recv_dsize)
|
||||
self.assertEqual(metadata_sizes[j], recv_msize)
|
||||
@@ -453,8 +466,10 @@ class TestPlasmaClient(unittest.TestCase):
|
||||
# Check that we receive notifications for deleting all objects, as we
|
||||
# evict them.
|
||||
for j in range(i):
|
||||
self.assertEqual(self.plasma_client2.evict(1), data_sizes[j] + metadata_sizes[j])
|
||||
recv_objid, recv_dsize, recv_msize = self.plasma_client2.get_next_notification()
|
||||
self.assertEqual(self.plasma_client2.evict(1),
|
||||
data_sizes[j] + metadata_sizes[j])
|
||||
notification_info = self.plasma_client2.get_next_notification()
|
||||
recv_objid, recv_dsize, recv_msize = notification_info
|
||||
self.assertEqual(object_ids[j], recv_objid)
|
||||
self.assertEqual(-1, recv_dsize)
|
||||
self.assertEqual(-1, recv_msize)
|
||||
@@ -469,18 +484,22 @@ class TestPlasmaClient(unittest.TestCase):
|
||||
metadata_sizes.append(np.random.randint(1000))
|
||||
data_sizes.append(np.random.randint(1000))
|
||||
for i in range(num_object_ids):
|
||||
x = self.plasma_client2.create(object_ids[i], size=data_sizes[i],
|
||||
metadata=bytearray(np.random.bytes(metadata_sizes[i])))
|
||||
x = self.plasma_client2.create(
|
||||
object_ids[i], size=data_sizes[i],
|
||||
metadata=bytearray(np.random.bytes(metadata_sizes[i])))
|
||||
self.plasma_client2.seal(object_ids[i])
|
||||
del x
|
||||
for i in range(num_object_ids):
|
||||
recv_objid, recv_dsize, recv_msize = self.plasma_client2.get_next_notification()
|
||||
notification_info = self.plasma_client2.get_next_notification()
|
||||
recv_objid, recv_dsize, recv_msize = notification_info
|
||||
self.assertEqual(object_ids[i], recv_objid)
|
||||
self.assertEqual(data_sizes[i], recv_dsize)
|
||||
self.assertEqual(metadata_sizes[i], recv_msize)
|
||||
self.assertEqual(self.plasma_client2.evict(1), data_sizes[-1] + metadata_sizes[-1])
|
||||
self.assertEqual(self.plasma_client2.evict(1),
|
||||
data_sizes[-1] + metadata_sizes[-1])
|
||||
for i in range(num_object_ids):
|
||||
recv_objid, recv_dsize, recv_msize = self.plasma_client2.get_next_notification()
|
||||
notification_info = self.plasma_client2.get_next_notification()
|
||||
recv_objid, recv_dsize, recv_msize = notification_info
|
||||
self.assertEqual(object_ids[i], recv_objid)
|
||||
self.assertEqual(-1, recv_dsize)
|
||||
self.assertEqual(-1, recv_msize)
|
||||
@@ -495,8 +514,10 @@ class TestPlasmaManager(unittest.TestCase):
|
||||
# Start a Redis server.
|
||||
redis_address = services.address("127.0.0.1", services.start_redis()[0])
|
||||
# Start two PlasmaManagers.
|
||||
manager_name1, self.p4, self.port1 = plasma.start_plasma_manager(store_name1, redis_address, use_valgrind=USE_VALGRIND)
|
||||
manager_name2, self.p5, self.port2 = plasma.start_plasma_manager(store_name2, redis_address, use_valgrind=USE_VALGRIND)
|
||||
manager_name1, self.p4, self.port1 = plasma.start_plasma_manager(
|
||||
store_name1, redis_address, use_valgrind=USE_VALGRIND)
|
||||
manager_name2, self.p5, self.port2 = plasma.start_plasma_manager(
|
||||
store_name2, redis_address, use_valgrind=USE_VALGRIND)
|
||||
# Connect two PlasmaClients.
|
||||
self.client1 = plasma.PlasmaClient(store_name1, manager_name1)
|
||||
self.client2 = plasma.PlasmaClient(store_name2, manager_name2)
|
||||
@@ -513,7 +534,8 @@ class TestPlasmaManager(unittest.TestCase):
|
||||
|
||||
# Kill the Plasma store and Plasma manager processes.
|
||||
if USE_VALGRIND:
|
||||
time.sleep(1) # give processes opportunity to finish work
|
||||
# Give processes opportunity to finish work.
|
||||
time.sleep(1)
|
||||
for process in self.processes_to_kill:
|
||||
process.send_signal(signal.SIGTERM)
|
||||
process.wait()
|
||||
@@ -530,7 +552,8 @@ class TestPlasmaManager(unittest.TestCase):
|
||||
def test_fetch(self):
|
||||
for _ in range(10):
|
||||
# Create an object.
|
||||
object_id1, memory_buffer1, metadata1 = create_object(self.client1, 2000, 2000)
|
||||
object_id1, memory_buffer1, metadata1 = create_object(self.client1, 2000,
|
||||
2000)
|
||||
self.client1.fetch([object_id1])
|
||||
self.assertEqual(self.client1.contains(object_id1), True)
|
||||
self.assertEqual(self.client2.contains(object_id1), False)
|
||||
@@ -546,7 +569,8 @@ class TestPlasmaManager(unittest.TestCase):
|
||||
object_id2 = random_object_id()
|
||||
self.client1.fetch([object_id2])
|
||||
self.assertEqual(self.client1.contains(object_id2), False)
|
||||
memory_buffer2, metadata2 = create_object_with_id(self.client2, object_id2, 2000, 2000)
|
||||
memory_buffer2, metadata2 = create_object_with_id(self.client2, object_id2,
|
||||
2000, 2000)
|
||||
# # Check that the object has been fetched.
|
||||
# self.assertEqual(self.client1.contains(object_id2), True)
|
||||
# Compare the two buffers.
|
||||
@@ -560,11 +584,12 @@ class TestPlasmaManager(unittest.TestCase):
|
||||
for _ in range(10):
|
||||
self.client1.fetch([object_id3])
|
||||
self.client2.fetch([object_id3])
|
||||
memory_buffer3, metadata3 = create_object_with_id(self.client1, object_id3, 2000, 2000)
|
||||
memory_buffer3, metadata3 = create_object_with_id(self.client1, object_id3,
|
||||
2000, 2000)
|
||||
for _ in range(10):
|
||||
self.client1.fetch([object_id3])
|
||||
self.client2.fetch([object_id3])
|
||||
#TODO(rkn): Right now we must wait for the object table to be updated.
|
||||
# TODO(rkn): Right now we must wait for the object table to be updated.
|
||||
while not self.client2.contains(object_id3):
|
||||
self.client2.fetch([object_id3])
|
||||
assert_get_object_equal(self, self.client1, self.client2, object_id3,
|
||||
@@ -573,14 +598,17 @@ class TestPlasmaManager(unittest.TestCase):
|
||||
def test_fetch_multiple(self):
|
||||
for _ in range(20):
|
||||
# Create two objects and a third fake one that doesn't exist.
|
||||
object_id1, memory_buffer1, metadata1 = create_object(self.client1, 2000, 2000)
|
||||
object_id1, memory_buffer1, metadata1 = create_object(self.client1, 2000,
|
||||
2000)
|
||||
missing_object_id = random_object_id()
|
||||
object_id2, memory_buffer2, metadata2 = create_object(self.client1, 2000, 2000)
|
||||
object_id2, memory_buffer2, metadata2 = create_object(self.client1, 2000,
|
||||
2000)
|
||||
object_ids = [object_id1, missing_object_id, object_id2]
|
||||
# Fetch the objects from the other plasma store. The second object ID
|
||||
# should timeout since it does not exist.
|
||||
# TODO(rkn): Right now we must wait for the object table to be updated.
|
||||
while (not self.client2.contains(object_id1)) or (not self.client2.contains(object_id2)):
|
||||
while ((not self.client2.contains(object_id1)) or
|
||||
(not self.client2.contains(object_id2))):
|
||||
self.client2.fetch(object_ids)
|
||||
# Compare the buffers of the objects that do exist.
|
||||
assert_get_object_equal(self, self.client1, self.client2, object_id1,
|
||||
@@ -597,7 +625,8 @@ class TestPlasmaManager(unittest.TestCase):
|
||||
# Check that we can call fetch with duplicated object IDs.
|
||||
object_id3 = random_object_id()
|
||||
self.client1.fetch([object_id3, object_id3])
|
||||
object_id4, memory_buffer4, metadata4 = create_object(self.client1, 2000, 2000)
|
||||
object_id4, memory_buffer4, metadata4 = create_object(self.client1, 2000,
|
||||
2000)
|
||||
time.sleep(0.1)
|
||||
# TODO(rkn): Right now we must wait for the object table to be updated.
|
||||
while not self.client2.contains(object_id4):
|
||||
@@ -623,7 +652,8 @@ class TestPlasmaManager(unittest.TestCase):
|
||||
obj_id2 = random_object_id()
|
||||
self.client1.create(obj_id2, 1000)
|
||||
# Don't seal.
|
||||
ready, waiting = self.client1.wait([obj_id2, obj_id1], timeout=100, num_returns=1)
|
||||
ready, waiting = self.client1.wait([obj_id2, obj_id1], timeout=100,
|
||||
num_returns=1)
|
||||
self.assertEqual(set(ready), set([obj_id1]))
|
||||
self.assertEqual(set(waiting), set([obj_id2]))
|
||||
|
||||
@@ -636,11 +666,13 @@ class TestPlasmaManager(unittest.TestCase):
|
||||
|
||||
t = threading.Timer(0.1, finish)
|
||||
t.start()
|
||||
ready, waiting = self.client1.wait([obj_id3, obj_id2, obj_id1], timeout=1000, num_returns=2)
|
||||
ready, waiting = self.client1.wait([obj_id3, obj_id2, obj_id1],
|
||||
timeout=1000, num_returns=2)
|
||||
self.assertEqual(set(ready), set([obj_id1, obj_id3]))
|
||||
self.assertEqual(set(waiting), set([obj_id2]))
|
||||
|
||||
# Test if the appropriate number of objects is shown if some objects are not ready
|
||||
# Test if the appropriate number of objects is shown if some objects are
|
||||
# not ready.
|
||||
ready, waiting = self.client1.wait([obj_id3, obj_id2, obj_id1], 100, 3)
|
||||
self.assertEqual(set(ready), set([obj_id1, obj_id3]))
|
||||
self.assertEqual(set(waiting), set([obj_id2]))
|
||||
@@ -651,9 +683,9 @@ class TestPlasmaManager(unittest.TestCase):
|
||||
# Test calling wait a bunch of times.
|
||||
object_ids = []
|
||||
# TODO(rkn): Increasing n to 100 (or larger) will cause failures. The
|
||||
# problem appears to be that the number of timers added to the manager event
|
||||
# loop slow down the manager so much that some of the asynchronous Redis
|
||||
# commands timeout triggering fatal failure callbacks.
|
||||
# problem appears to be that the number of timers added to the manager
|
||||
# event loop slow down the manager so much that some of the asynchronous
|
||||
# Redis commands timeout triggering fatal failure callbacks.
|
||||
n = 40
|
||||
for i in range(n * (n + 1) // 2):
|
||||
if i % 2 == 0:
|
||||
@@ -669,7 +701,8 @@ class TestPlasmaManager(unittest.TestCase):
|
||||
self.assertEqual(len(ready), i)
|
||||
retrieved += ready
|
||||
self.assertEqual(set(retrieved), set(object_ids))
|
||||
ready, waiting = self.client1.wait(object_ids, timeout=1000, num_returns=len(object_ids))
|
||||
ready, waiting = self.client1.wait(object_ids, timeout=1000,
|
||||
num_returns=len(object_ids))
|
||||
self.assertEqual(set(ready), set(object_ids))
|
||||
self.assertEqual(waiting, [])
|
||||
# Try waiting for all of the object IDs on the second client.
|
||||
@@ -680,7 +713,8 @@ class TestPlasmaManager(unittest.TestCase):
|
||||
self.assertEqual(len(ready), i)
|
||||
retrieved += ready
|
||||
self.assertEqual(set(retrieved), set(object_ids))
|
||||
ready, waiting = self.client2.wait(object_ids, timeout=1000, num_returns=len(object_ids))
|
||||
ready, waiting = self.client2.wait(object_ids, timeout=1000,
|
||||
num_returns=len(object_ids))
|
||||
self.assertEqual(set(ready), set(object_ids))
|
||||
self.assertEqual(waiting, [])
|
||||
|
||||
@@ -702,7 +736,8 @@ class TestPlasmaManager(unittest.TestCase):
|
||||
num_attempts = 100
|
||||
for _ in range(100):
|
||||
# Create an object.
|
||||
object_id1, memory_buffer1, metadata1 = create_object(self.client1, 2000, 2000)
|
||||
object_id1, memory_buffer1, metadata1 = create_object(self.client1, 2000,
|
||||
2000)
|
||||
# Transfer the buffer to the the other Plasma store. There is a race
|
||||
# condition on the create and transfer of the object, so keep trying
|
||||
# until the object appears on the second Plasma store.
|
||||
@@ -721,10 +756,12 @@ class TestPlasmaManager(unittest.TestCase):
|
||||
# self.client1.transfer("127.0.0.1", self.port2, object_id1)
|
||||
# # Compare the two buffers.
|
||||
# assert_get_object_equal(self, self.client1, self.client2, object_id1,
|
||||
# memory_buffer=memory_buffer1, metadata=metadata1)
|
||||
# memory_buffer=memory_buffer1,
|
||||
# metadata=metadata1)
|
||||
|
||||
# Create an object.
|
||||
object_id2, memory_buffer2, metadata2 = create_object(self.client2, 20000, 20000)
|
||||
object_id2, memory_buffer2, metadata2 = create_object(self.client2,
|
||||
20000, 20000)
|
||||
# Transfer the buffer to the the other Plasma store. There is a race
|
||||
# condition on the create and transfer of the object, so keep trying
|
||||
# until the object appears on the second Plasma store.
|
||||
@@ -742,17 +779,19 @@ class TestPlasmaManager(unittest.TestCase):
|
||||
|
||||
def test_illegal_functionality(self):
|
||||
# Create an object id string.
|
||||
object_id = random_object_id()
|
||||
# object_id = random_object_id()
|
||||
# Create a new buffer.
|
||||
# memory_buffer = self.client1.create(object_id, 20000)
|
||||
# This test is commented out because it currently fails.
|
||||
# # Transferring the buffer before sealing it should fail.
|
||||
# self.assertRaises(Exception, lambda : self.manager1.transfer(1, object_id))
|
||||
# self.assertRaises(Exception,
|
||||
# lambda : self.manager1.transfer(1, object_id))
|
||||
pass
|
||||
|
||||
def test_stresstest(self):
|
||||
a = time.time()
|
||||
object_ids = []
|
||||
for i in range(10000): # TODO(pcm): increase this to 100000
|
||||
for i in range(10000): # TODO(pcm): increase this to 100000.
|
||||
object_id = random_object_id()
|
||||
object_ids.append(object_id)
|
||||
self.client1.create(object_id, 1)
|
||||
@@ -763,13 +802,16 @@ class TestPlasmaManager(unittest.TestCase):
|
||||
|
||||
print("it took", b, "seconds to put and transfer the objects")
|
||||
|
||||
|
||||
class TestPlasmaManagerRecovery(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
# Start a Plasma store.
|
||||
self.store_name, self.p2 = plasma.start_plasma_store(use_valgrind=USE_VALGRIND)
|
||||
self.store_name, self.p2 = plasma.start_plasma_store(
|
||||
use_valgrind=USE_VALGRIND)
|
||||
# Start a Redis server.
|
||||
self.redis_address = services.address("127.0.0.1", services.start_redis()[0])
|
||||
self.redis_address = services.address("127.0.0.1",
|
||||
services.start_redis()[0])
|
||||
# Start a PlasmaManagers.
|
||||
manager_name, self.p3, self.port1 = plasma.start_plasma_manager(
|
||||
self.store_name,
|
||||
@@ -791,7 +833,8 @@ class TestPlasmaManagerRecovery(unittest.TestCase):
|
||||
|
||||
# Kill the Plasma store and Plasma manager processes.
|
||||
if USE_VALGRIND:
|
||||
time.sleep(1) # give processes opportunity to finish work
|
||||
# Give processes opportunity to finish work.
|
||||
time.sleep(1)
|
||||
for process in self.processes_to_kill:
|
||||
process.send_signal(signal.SIGTERM)
|
||||
process.wait()
|
||||
@@ -818,23 +861,26 @@ class TestPlasmaManagerRecovery(unittest.TestCase):
|
||||
self.assertEqual(waiting, [])
|
||||
|
||||
# Start a second plasma manager attached to the same store.
|
||||
manager_name, self.p5, self.port2 = plasma.start_plasma_manager(self.store_name, self.redis_address, use_valgrind=USE_VALGRIND)
|
||||
manager_name, self.p5, self.port2 = plasma.start_plasma_manager(
|
||||
self.store_name, self.redis_address, use_valgrind=USE_VALGRIND)
|
||||
self.processes_to_kill = [self.p5] + self.processes_to_kill
|
||||
|
||||
# Check that the second manager knows about existing objects.
|
||||
client2 = plasma.PlasmaClient(self.store_name, manager_name)
|
||||
ready, waiting = [], object_ids
|
||||
while True:
|
||||
ready, waiting = client2.wait(object_ids, num_returns=num_objects, timeout=0)
|
||||
ready, waiting = client2.wait(object_ids, num_returns=num_objects,
|
||||
timeout=0)
|
||||
if len(ready) == len(object_ids):
|
||||
break
|
||||
|
||||
self.assertEqual(set(ready), set(object_ids))
|
||||
self.assertEqual(waiting, [])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) > 1:
|
||||
# pop the argument so we don't mess with unittest's own argument parser
|
||||
# Pop the argument so we don't mess with unittest's own argument parser.
|
||||
if sys.argv[-1] == "valgrind":
|
||||
arg = sys.argv.pop()
|
||||
USE_VALGRIND = True
|
||||
|
||||
@@ -5,9 +5,11 @@ from __future__ import print_function
|
||||
import numpy as np
|
||||
import random
|
||||
|
||||
|
||||
def random_object_id():
|
||||
return np.random.bytes(20)
|
||||
|
||||
|
||||
def generate_metadata(length):
|
||||
metadata_buffer = bytearray(length)
|
||||
if length > 0:
|
||||
@@ -17,6 +19,7 @@ def generate_metadata(length):
|
||||
metadata_buffer[random.randint(0, length - 1)] = random.randint(0, 255)
|
||||
return metadata_buffer
|
||||
|
||||
|
||||
def write_to_data_buffer(buff, length):
|
||||
if length > 0:
|
||||
buff[0] = chr(random.randint(0, 255))
|
||||
@@ -24,7 +27,9 @@ def write_to_data_buffer(buff, length):
|
||||
for _ in range(100):
|
||||
buff[random.randint(0, length - 1)] = chr(random.randint(0, 255))
|
||||
|
||||
def create_object_with_id(client, object_id, data_size, metadata_size, seal=True):
|
||||
|
||||
def create_object_with_id(client, object_id, data_size, metadata_size,
|
||||
seal=True):
|
||||
metadata = generate_metadata(metadata_size)
|
||||
memory_buffer = client.create(object_id, data_size, metadata)
|
||||
write_to_data_buffer(memory_buffer, data_size)
|
||||
@@ -32,7 +37,9 @@ def create_object_with_id(client, object_id, data_size, metadata_size, seal=True
|
||||
client.seal(object_id)
|
||||
return memory_buffer, metadata
|
||||
|
||||
|
||||
def create_object(client, data_size, metadata_size, seal=True):
|
||||
object_id = random_object_id()
|
||||
memory_buffer, metadata = create_object_with_id(client, object_id, data_size, metadata_size, seal=seal)
|
||||
memory_buffer, metadata = create_object_with_id(client, object_id, data_size,
|
||||
metadata_size, seal=seal)
|
||||
return object_id, memory_buffer, metadata
|
||||
|
||||
+44
-10
@@ -7,6 +7,7 @@ import numpy as np
|
||||
import ray.numbuf
|
||||
import ray.pickling as pickling
|
||||
|
||||
|
||||
def check_serializable(cls):
|
||||
"""Throws an exception if Ray cannot serialize this class efficiently.
|
||||
|
||||
@@ -21,15 +22,30 @@ def check_serializable(cls):
|
||||
# This case works.
|
||||
return
|
||||
if not hasattr(cls, "__new__"):
|
||||
raise Exception("The class {} does not have a '__new__' attribute, and is probably an old-style class. We do not support this. Please either make it a new-style class by inheriting from 'object', or use 'ray.register_class(cls, pickle=True)'. However, note that pickle is inefficient.".format(cls))
|
||||
raise Exception("The class {} does not have a '__new__' attribute, and is "
|
||||
"probably an old-style class. We do not support this. "
|
||||
"Please either make it a new-style class by inheriting "
|
||||
"from 'object', or use "
|
||||
"'ray.register_class(cls, pickle=True)'. However, note "
|
||||
"that pickle is inefficient.".format(cls))
|
||||
try:
|
||||
obj = cls.__new__(cls)
|
||||
except:
|
||||
raise Exception("The class {} has overridden '__new__', so Ray may not be able to serialize it efficiently. Try using 'ray.register_class(cls, pickle=True)'. However, note that pickle is inefficient.".format(cls))
|
||||
raise Exception("The class {} has overridden '__new__', so Ray may not be "
|
||||
"able to serialize it efficiently. Try using "
|
||||
"'ray.register_class(cls, pickle=True)'. However, note "
|
||||
"that pickle is inefficient.".format(cls))
|
||||
if not hasattr(obj, "__dict__"):
|
||||
raise Exception("Objects of the class {} do not have a `__dict__` attribute, so Ray cannot serialize it efficiently. Try using 'ray.register_class(cls, pickle=True)'. However, note that pickle is inefficient.".format(cls))
|
||||
raise Exception("Objects of the class {} do not have a `__dict__` "
|
||||
"attribute, so Ray cannot serialize it efficiently. Try "
|
||||
"using 'ray.register_class(cls, pickle=True)'. However, "
|
||||
"note that pickle is inefficient.".format(cls))
|
||||
if hasattr(obj, "__slots__"):
|
||||
raise Exception("The class {} uses '__slots__', so Ray may not be able to serialize it efficiently. Try using 'ray.register_class(cls, pickle=True)'. However, note that pickle is inefficient.".format(cls))
|
||||
raise Exception("The class {} uses '__slots__', so Ray may not be able to "
|
||||
"serialize it efficiently. Try using "
|
||||
"'ray.register_class(cls, pickle=True)'. However, note "
|
||||
"that pickle is inefficient.".format(cls))
|
||||
|
||||
|
||||
# This field keeps track of a whitelisted set of classes that Ray will
|
||||
# serialize.
|
||||
@@ -38,10 +54,12 @@ classes_to_pickle = set()
|
||||
custom_serializers = {}
|
||||
custom_deserializers = {}
|
||||
|
||||
|
||||
def class_identifier(typ):
|
||||
"""Return a string that identifies this type."""
|
||||
return "{}.{}".format(typ.__module__, typ.__name__)
|
||||
|
||||
|
||||
def is_named_tuple(cls):
|
||||
"""Return True if cls is a namedtuple and False otherwise."""
|
||||
b = cls.__bases__
|
||||
@@ -52,7 +70,9 @@ def is_named_tuple(cls):
|
||||
return False
|
||||
return all(type(n) == str for n in f)
|
||||
|
||||
def add_class_to_whitelist(cls, pickle=False, custom_serializer=None, custom_deserializer=None):
|
||||
|
||||
def add_class_to_whitelist(cls, pickle=False, custom_serializer=None,
|
||||
custom_deserializer=None):
|
||||
"""Add cls to the list of classes that we can serialize.
|
||||
|
||||
Args:
|
||||
@@ -72,13 +92,21 @@ def add_class_to_whitelist(cls, pickle=False, custom_serializer=None, custom_des
|
||||
custom_serializers[class_id] = custom_serializer
|
||||
custom_deserializers[class_id] = custom_deserializer
|
||||
|
||||
|
||||
# Here we define a custom serializer and deserializer for handling numpy
|
||||
# arrays that contain objects.
|
||||
def array_custom_serializer(obj):
|
||||
return obj.tolist(), obj.dtype.str
|
||||
|
||||
|
||||
def array_custom_deserializer(serialized_obj):
|
||||
return np.array(serialized_obj[0], dtype=np.dtype(serialized_obj[1]))
|
||||
add_class_to_whitelist(np.ndarray, pickle=False, custom_serializer=array_custom_serializer, custom_deserializer=array_custom_deserializer)
|
||||
|
||||
|
||||
add_class_to_whitelist(np.ndarray, pickle=False,
|
||||
custom_serializer=array_custom_serializer,
|
||||
custom_deserializer=array_custom_deserializer)
|
||||
|
||||
|
||||
def serialize(obj):
|
||||
"""This is the callback that will be used by numbuf.
|
||||
@@ -94,7 +122,9 @@ def serialize(obj):
|
||||
"""
|
||||
class_id = class_identifier(type(obj))
|
||||
if class_id not in whitelisted_classes:
|
||||
raise Exception("Ray does not know how to serialize objects of type {}. To fix this, call 'ray.register_class' with this class.".format(type(obj)))
|
||||
raise Exception("Ray does not know how to serialize objects of type {}. "
|
||||
"To fix this, call 'ray.register_class' with this class."
|
||||
.format(type(obj)))
|
||||
if class_id in classes_to_pickle:
|
||||
serialized_obj = {"data": pickling.dumps(obj)}
|
||||
elif class_id in custom_serializers.keys():
|
||||
@@ -107,10 +137,12 @@ def serialize(obj):
|
||||
elif hasattr(obj, "__dict__"):
|
||||
serialized_obj = obj.__dict__
|
||||
else:
|
||||
raise Exception("We do not know how to serialize the object '{}'".format(obj))
|
||||
raise Exception("We do not know how to serialize the object '{}'"
|
||||
.format(obj))
|
||||
result = dict(serialized_obj, **{"_pytype_": class_id})
|
||||
return result
|
||||
|
||||
|
||||
def deserialize(serialized_obj):
|
||||
"""This is the callback that will be used by numbuf.
|
||||
|
||||
@@ -139,11 +171,13 @@ def deserialize(serialized_obj):
|
||||
obj.__dict__.update(serialized_obj)
|
||||
return obj
|
||||
|
||||
|
||||
def set_callbacks():
|
||||
"""Register the custom callbacks with numbuf.
|
||||
|
||||
The serialize callback is used to serialize objects that numbuf does not know
|
||||
how to serialize (for example custom Python classes). The deserialize callback
|
||||
is used to serialize objects that were serialized by the serialize callback.
|
||||
how to serialize (for example custom Python classes). The deserialize
|
||||
callback is used to serialize objects that were serialized by the serialize
|
||||
callback.
|
||||
"""
|
||||
ray.numbuf.register_callbacks(serialize, deserialize)
|
||||
|
||||
+162
-103
@@ -10,7 +10,6 @@ import random
|
||||
import redis
|
||||
import signal
|
||||
import socket
|
||||
import string
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
@@ -60,16 +59,20 @@ ObjectStoreAddress = namedtuple("ObjectStoreAddress", ["name",
|
||||
"manager_name",
|
||||
"manager_port"])
|
||||
|
||||
|
||||
def address(ip_address, port):
|
||||
return ip_address + ":" + str(port)
|
||||
|
||||
|
||||
def get_ip_address(address):
|
||||
try:
|
||||
ip_address = address.split(":")[0]
|
||||
except:
|
||||
raise Exception("Unable to parse IP address from address {}".format(address))
|
||||
raise Exception("Unable to parse IP address from address "
|
||||
"{}".format(address))
|
||||
return ip_address
|
||||
|
||||
|
||||
def get_port(address):
|
||||
try:
|
||||
port = int(address.split(":")[1])
|
||||
@@ -77,12 +80,15 @@ def get_port(address):
|
||||
raise Exception("Unable to parse port from address {}".format(address))
|
||||
return port
|
||||
|
||||
|
||||
def new_port():
|
||||
return random.randint(10000, 65535)
|
||||
|
||||
|
||||
def random_name():
|
||||
return str(random.randint(0, 99999999))
|
||||
|
||||
|
||||
def kill_process(p):
|
||||
"""Kill a process.
|
||||
|
||||
@@ -92,11 +98,15 @@ def kill_process(p):
|
||||
Returns:
|
||||
True if the process was killed successfully and false otherwise.
|
||||
"""
|
||||
if p.poll() is not None: # process has already terminated
|
||||
if p.poll() is not None:
|
||||
# The process has already terminated.
|
||||
return True
|
||||
if RUN_LOCAL_SCHEDULER_PROFILER or RUN_PLASMA_MANAGER_PROFILER or RUN_PLASMA_STORE_PROFILER:
|
||||
os.kill(p.pid, signal.SIGINT) # Give process signal to write profiler data.
|
||||
time.sleep(0.1) # Wait for profiling data to be written.
|
||||
if any([RUN_LOCAL_SCHEDULER_PROFILER, RUN_PLASMA_MANAGER_PROFILER,
|
||||
RUN_PLASMA_STORE_PROFILER]):
|
||||
# Give process signal to write profiler data.
|
||||
os.kill(p.pid, signal.SIGINT)
|
||||
# Wait for profiling data to be written.
|
||||
time.sleep(0.1)
|
||||
|
||||
# Allow the process one second to exit gracefully.
|
||||
p.terminate()
|
||||
@@ -118,6 +128,7 @@ def kill_process(p):
|
||||
# The process was not killed for some reason.
|
||||
return False
|
||||
|
||||
|
||||
def cleanup():
|
||||
"""When running in local mode, shutdown the Ray processes.
|
||||
|
||||
@@ -138,6 +149,7 @@ def cleanup():
|
||||
if not successfully_shut_down:
|
||||
print("Ray did not shut down properly.")
|
||||
|
||||
|
||||
def all_processes_alive(exclude=[]):
|
||||
"""Check if all of the processes are still alive.
|
||||
|
||||
@@ -147,10 +159,13 @@ def all_processes_alive(exclude=[]):
|
||||
for process_type, processes in all_processes.items():
|
||||
# Note that p.poll() returns the exit code that the process exited with, so
|
||||
# an exit code of None indicates that the process is still alive.
|
||||
if not all([p.poll() is None for p in processes]) and process_type not in exclude:
|
||||
processes_alive = [p.poll() is None for p in processes]
|
||||
if (not all(processes_alive) and process_type not in exclude):
|
||||
print("A process of type {} has dead.".format(process_type))
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def get_node_ip_address(address="8.8.8.8:53"):
|
||||
"""Determine the IP address of the local node.
|
||||
|
||||
@@ -166,6 +181,7 @@ def get_node_ip_address(address="8.8.8.8:53"):
|
||||
s.connect((ip_address, int(port)))
|
||||
return s.getsockname()[0]
|
||||
|
||||
|
||||
def record_log_files_in_redis(redis_address, node_ip_address, log_files):
|
||||
"""Record in Redis that a new log file has been created.
|
||||
|
||||
@@ -187,6 +203,7 @@ def record_log_files_in_redis(redis_address, node_ip_address, log_files):
|
||||
log_file_list_key = "LOG_FILENAMES:{}".format(node_ip_address)
|
||||
redis_client.rpush(log_file_list_key, log_file.name)
|
||||
|
||||
|
||||
def wait_for_redis_to_start(redis_ip_address, redis_port, num_retries=5):
|
||||
"""Wait for a Redis server to be available.
|
||||
|
||||
@@ -208,7 +225,8 @@ def wait_for_redis_to_start(redis_ip_address, redis_port, num_retries=5):
|
||||
while counter < num_retries:
|
||||
try:
|
||||
# Run some random command and see if it worked.
|
||||
print("Waiting for redis server at {}:{} to respond...".format(redis_ip_address, redis_port))
|
||||
print("Waiting for redis server at {}:{} to respond..."
|
||||
.format(redis_ip_address, redis_port))
|
||||
redis_client.client_list()
|
||||
except redis.ConnectionError as e:
|
||||
# Wait a little bit.
|
||||
@@ -218,7 +236,10 @@ def wait_for_redis_to_start(redis_ip_address, redis_port, num_retries=5):
|
||||
else:
|
||||
break
|
||||
if counter == num_retries:
|
||||
raise Exception("Unable to connect to Redis. If the Redis instance is on a different machine, check that your firewall is configured properly.")
|
||||
raise Exception("Unable to connect to Redis. If the Redis instance is on "
|
||||
"a different machine, check that your firewall is "
|
||||
"configured properly.")
|
||||
|
||||
|
||||
def start_redis(node_ip_address="127.0.0.1", port=None, num_retries=20,
|
||||
stdout_file=None, stderr_file=None, cleanup=True):
|
||||
@@ -239,13 +260,18 @@ def start_redis(node_ip_address="127.0.0.1", port=None, num_retries=20,
|
||||
|
||||
Returns:
|
||||
A tuple of the port used by Redis and a handle to the process that was
|
||||
started. If a port is passed in, then the returned port value is the same.
|
||||
started. If a port is passed in, then the returned port value is the
|
||||
same.
|
||||
|
||||
Raises:
|
||||
Exception: An exception is raised if Redis could not be started.
|
||||
"""
|
||||
redis_filepath = os.path.join(os.path.dirname(os.path.abspath(__file__)), "./core/src/common/thirdparty/redis/src/redis-server")
|
||||
redis_module = os.path.join(os.path.dirname(os.path.abspath(__file__)), "./core/src/common/redis_module/libray_redis_module.so")
|
||||
redis_filepath = os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)),
|
||||
"./core/src/common/thirdparty/redis/src/redis-server")
|
||||
redis_module = os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)),
|
||||
"./core/src/common/redis_module/libray_redis_module.so")
|
||||
assert os.path.isfile(redis_filepath)
|
||||
assert os.path.isfile(redis_module)
|
||||
counter = 0
|
||||
@@ -261,7 +287,7 @@ def start_redis(node_ip_address="127.0.0.1", port=None, num_retries=20,
|
||||
"--port", str(port),
|
||||
"--loglevel", "warning",
|
||||
"--loadmodule", redis_module],
|
||||
stdout=stdout_file, stderr=stderr_file)
|
||||
stdout=stdout_file, stderr=stderr_file)
|
||||
time.sleep(0.1)
|
||||
# Check if Redis successfully started (or at least if it the executable did
|
||||
# not exit within 0.1 seconds).
|
||||
@@ -291,6 +317,7 @@ def start_redis(node_ip_address="127.0.0.1", port=None, num_retries=20,
|
||||
[stdout_file, stderr_file])
|
||||
return port, p
|
||||
|
||||
|
||||
def start_log_monitor(redis_address, node_ip_address, stdout_file=None,
|
||||
stderr_file=None, cleanup=cleanup):
|
||||
"""Start a log monitor process.
|
||||
@@ -307,7 +334,9 @@ def start_log_monitor(redis_address, node_ip_address, stdout_file=None,
|
||||
this process will be killed by services.cleanup() when the Python process
|
||||
that imported services exits.
|
||||
"""
|
||||
log_monitor_filepath = os.path.join(os.path.dirname(os.path.abspath(__file__)), "log_monitor.py")
|
||||
log_monitor_filepath = os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)),
|
||||
"log_monitor.py")
|
||||
p = subprocess.Popen(["python", log_monitor_filepath,
|
||||
"--redis-address", redis_address,
|
||||
"--node-ip-address", node_ip_address],
|
||||
@@ -317,13 +346,15 @@ def start_log_monitor(redis_address, node_ip_address, stdout_file=None,
|
||||
record_log_files_in_redis(redis_address, node_ip_address,
|
||||
[stdout_file, stderr_file])
|
||||
|
||||
|
||||
def start_global_scheduler(redis_address, node_ip_address, stdout_file=None,
|
||||
stderr_file=None, cleanup=True):
|
||||
"""Start a global scheduler process.
|
||||
|
||||
Args:
|
||||
redis_address (str): The address of the Redis instance.
|
||||
node_ip_address: The IP address of the node that this scheduler will run on.
|
||||
node_ip_address: The IP address of the node that this scheduler will run
|
||||
on.
|
||||
stdout_file: A file handle opened for writing to redirect stdout to. If no
|
||||
redirection should happen, then this should be None.
|
||||
stderr_file: A file handle opened for writing to redirect stderr to. If no
|
||||
@@ -340,6 +371,7 @@ def start_global_scheduler(redis_address, node_ip_address, stdout_file=None,
|
||||
record_log_files_in_redis(redis_address, node_ip_address,
|
||||
[stdout_file, stderr_file])
|
||||
|
||||
|
||||
def start_webui(redis_address, node_ip_address, backend_stdout_file=None,
|
||||
backend_stderr_file=None, polymer_stdout_file=None,
|
||||
polymer_stderr_file=None, cleanup=True):
|
||||
@@ -367,8 +399,11 @@ def start_webui(redis_address, node_ip_address, backend_stdout_file=None,
|
||||
Return:
|
||||
True if the web UI was successfully started, otherwise false.
|
||||
"""
|
||||
webui_backend_filepath = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../webui/backend/ray_ui.py")
|
||||
webui_directory = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../webui/")
|
||||
webui_backend_filepath = os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)),
|
||||
"../../webui/backend/ray_ui.py")
|
||||
webui_directory = os.path.join(os.path.dirname(os.path.abspath(__file__)),
|
||||
"../../webui/")
|
||||
|
||||
if sys.version_info >= (3, 0):
|
||||
python_executable = "python"
|
||||
@@ -376,7 +411,8 @@ def start_webui(redis_address, node_ip_address, backend_stdout_file=None,
|
||||
# If the user is using Python 2, it is still possible to run the webserver
|
||||
# separately with Python 3, so try to find a Python 3 executable.
|
||||
try:
|
||||
python_executable = subprocess.check_output(["which", "python3"]).decode("ascii").strip()
|
||||
python_executable = subprocess.check_output(
|
||||
["which", "python3"]).decode("ascii").strip()
|
||||
except Exception as e:
|
||||
print("Not starting the web UI because the web UI requires Python 3.")
|
||||
return False
|
||||
@@ -394,8 +430,8 @@ def start_webui(redis_address, node_ip_address, backend_stdout_file=None,
|
||||
return False
|
||||
|
||||
# Try to start polymer. If this fails, it may that port 8080 is already in
|
||||
# use. It'd be nice to test for this, but doing so by calling "bind" may start
|
||||
# using the port and prevent polymer from using it.
|
||||
# use. It'd be nice to test for this, but doing so by calling "bind" may
|
||||
# start using the port and prevent polymer from using it.
|
||||
try:
|
||||
polymer_process = subprocess.Popen(["polymer", "serve", "--port", "8080"],
|
||||
cwd=webui_directory,
|
||||
@@ -433,6 +469,7 @@ def start_webui(redis_address, node_ip_address, backend_stdout_file=None,
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def start_local_scheduler(redis_address,
|
||||
node_ip_address,
|
||||
plasma_store_name,
|
||||
@@ -472,8 +509,8 @@ def start_local_scheduler(redis_address,
|
||||
The name of the local scheduler socket.
|
||||
"""
|
||||
if num_cpus is None:
|
||||
# By default, use the number of hardware execution threads for the number of
|
||||
# cores.
|
||||
# By default, use the number of hardware execution threads for the number
|
||||
# of cores.
|
||||
num_cpus = multiprocessing.cpu_count()
|
||||
if num_gpus is None:
|
||||
# By default, assume this node has no GPUs.
|
||||
@@ -496,6 +533,7 @@ def start_local_scheduler(redis_address,
|
||||
[stdout_file, stderr_file])
|
||||
return local_scheduler_name
|
||||
|
||||
|
||||
def start_objstore(node_ip_address, redis_address, object_manager_port=None,
|
||||
store_stdout_file=None, store_stderr_file=None,
|
||||
manager_stdout_file=None, manager_stderr_file=None,
|
||||
@@ -511,10 +549,10 @@ def start_objstore(node_ip_address, redis_address, object_manager_port=None,
|
||||
If no redirection should happen, then this should be None.
|
||||
store_stderr_file: A file handle opened for writing to redirect stderr to.
|
||||
If no redirection should happen, then this should be None.
|
||||
manager_stdout_file: A file handle opened for writing to redirect stdout to.
|
||||
If no redirection should happen, then this should be None.
|
||||
manager_stderr_file: A file handle opened for writing to redirect stderr to.
|
||||
If no redirection should happen, then this should be None.
|
||||
manager_stdout_file: A file handle opened for writing to redirect stdout
|
||||
to. If no redirection should happen, then this should be None.
|
||||
manager_stderr_file: A file handle opened for writing to redirect stderr
|
||||
to. If no redirection should happen, then this should be None.
|
||||
cleanup (bool): True if using Ray in local mode. If cleanup is true, then
|
||||
this process will be killed by serices.cleanup() when the Python process
|
||||
that imported services exits.
|
||||
@@ -522,8 +560,8 @@ def start_objstore(node_ip_address, redis_address, object_manager_port=None,
|
||||
with.
|
||||
|
||||
Return:
|
||||
A tuple of the Plasma store socket name, the Plasma manager socket name, and
|
||||
the plasma manager port.
|
||||
A tuple of the Plasma store socket name, the Plasma manager socket name,
|
||||
and the plasma manager port.
|
||||
"""
|
||||
if objstore_memory is None:
|
||||
# Compute a fraction of the system memory for the Plasma store to use.
|
||||
@@ -559,7 +597,8 @@ def start_objstore(node_ip_address, redis_address, object_manager_port=None,
|
||||
stderr_file=store_stderr_file)
|
||||
# Start the plasma manager.
|
||||
if object_manager_port is not None:
|
||||
plasma_manager_name, p2, plasma_manager_port = ray.plasma.start_plasma_manager(
|
||||
(plasma_manager_name, p2,
|
||||
plasma_manager_port) = ray.plasma.start_plasma_manager(
|
||||
plasma_store_name,
|
||||
redis_address,
|
||||
plasma_manager_port=object_manager_port,
|
||||
@@ -570,7 +609,8 @@ def start_objstore(node_ip_address, redis_address, object_manager_port=None,
|
||||
stderr_file=manager_stderr_file)
|
||||
assert plasma_manager_port == object_manager_port
|
||||
else:
|
||||
plasma_manager_name, p2, plasma_manager_port = ray.plasma.start_plasma_manager(
|
||||
(plasma_manager_name, p2,
|
||||
plasma_manager_port) = ray.plasma.start_plasma_manager(
|
||||
plasma_store_name,
|
||||
redis_address,
|
||||
node_ip_address=node_ip_address,
|
||||
@@ -587,6 +627,7 @@ def start_objstore(node_ip_address, redis_address, object_manager_port=None,
|
||||
return ObjectStoreAddress(plasma_store_name, plasma_manager_name,
|
||||
plasma_manager_port)
|
||||
|
||||
|
||||
def start_worker(node_ip_address, object_store_name, object_store_manager_name,
|
||||
local_scheduler_name, redis_address, worker_path,
|
||||
stdout_file=None, stderr_file=None, cleanup=True):
|
||||
@@ -599,8 +640,8 @@ def start_worker(node_ip_address, object_store_name, object_store_manager_name,
|
||||
object_store_manager_name (str): The name of the object store manager.
|
||||
local_scheduler_name (str): The name of the local scheduler.
|
||||
redis_address (str): The address that the Redis server is listening on.
|
||||
worker_path (str): The path of the source code which the worker process will
|
||||
run.
|
||||
worker_path (str): The path of the source code which the worker process
|
||||
will run.
|
||||
stdout_file: A file handle opened for writing to redirect stdout to. If no
|
||||
redirection should happen, then this should be None.
|
||||
stderr_file: A file handle opened for writing to redirect stderr to. If no
|
||||
@@ -622,6 +663,7 @@ def start_worker(node_ip_address, object_store_name, object_store_manager_name,
|
||||
record_log_files_in_redis(redis_address, node_ip_address,
|
||||
[stdout_file, stderr_file])
|
||||
|
||||
|
||||
def start_monitor(redis_address, node_ip_address, stdout_file=None,
|
||||
stderr_file=None, cleanup=True):
|
||||
"""Run a process to monitor the other processes.
|
||||
@@ -637,7 +679,8 @@ def start_monitor(redis_address, node_ip_address, stdout_file=None,
|
||||
this process will be killed by services.cleanup() when the Python process
|
||||
that imported services exits. This is True by default.
|
||||
"""
|
||||
monitor_path= os.path.join(os.path.dirname(os.path.abspath(__file__)), "monitor.py")
|
||||
monitor_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
|
||||
"monitor.py")
|
||||
command = ["python",
|
||||
monitor_path,
|
||||
"--redis-address=" + str(redis_address)]
|
||||
@@ -647,6 +690,7 @@ def start_monitor(redis_address, node_ip_address, stdout_file=None,
|
||||
record_log_files_in_redis(redis_address, node_ip_address,
|
||||
[stdout_file, stderr_file])
|
||||
|
||||
|
||||
def start_ray_processes(address_info=None,
|
||||
node_ip_address="127.0.0.1",
|
||||
num_workers=0,
|
||||
@@ -685,8 +729,9 @@ def start_ray_processes(address_info=None,
|
||||
start a global scheduler process.
|
||||
include_redis (bool): If include_redis is True, then start a Redis server
|
||||
process.
|
||||
include_log_monitor (bool): If True, then start a log monitor to monitor the
|
||||
log files for all processes on this node and push their contents to Redis.
|
||||
include_log_monitor (bool): If True, then start a log monitor to monitor
|
||||
the log files for all processes on this node and push their contents to
|
||||
Redis.
|
||||
include_webui (bool): If True, then attempt to start the web UI. Note that
|
||||
this is only possible with Python 3.
|
||||
start_workers_from_local_scheduler (bool): If this flag is True, then start
|
||||
@@ -713,7 +758,8 @@ def start_ray_processes(address_info=None,
|
||||
address_info["node_ip_address"] = node_ip_address
|
||||
|
||||
if worker_path is None:
|
||||
worker_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "workers/default_worker.py")
|
||||
worker_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
|
||||
"workers/default_worker.py")
|
||||
|
||||
# Start Redis if there isn't already an instance running. TODO(rkn): We are
|
||||
# suppressing the output of Redis because on Linux it prints a bunch of
|
||||
@@ -721,7 +767,8 @@ def start_ray_processes(address_info=None,
|
||||
# should address the warnings.
|
||||
redis_address = address_info.get("redis_address")
|
||||
if include_redis:
|
||||
redis_stdout_file, redis_stderr_file = new_log_files("redis", redirect_output)
|
||||
redis_stdout_file, redis_stderr_file = new_log_files("redis",
|
||||
redirect_output)
|
||||
if redis_address is None:
|
||||
# Start a Redis server. The start_redis method will choose a random port.
|
||||
redis_port, _ = start_redis(node_ip_address,
|
||||
@@ -735,7 +782,6 @@ def start_ray_processes(address_info=None,
|
||||
# A Redis address was provided, so start a Redis server with the given
|
||||
# port. TODO(rkn): We should check that the IP address corresponds to the
|
||||
# machine that this method is running on.
|
||||
redis_ip_address = get_ip_address(redis_address)
|
||||
redis_port = get_port(redis_address)
|
||||
new_redis_port, _ = start_redis(port=int(redis_port),
|
||||
num_retries=1,
|
||||
@@ -744,7 +790,8 @@ def start_ray_processes(address_info=None,
|
||||
cleanup=cleanup)
|
||||
assert redis_port == new_redis_port
|
||||
# Start monitoring the processes.
|
||||
monitor_stdout_file, monitor_stderr_file = new_log_files("monitor", redirect_output)
|
||||
monitor_stdout_file, monitor_stderr_file = new_log_files("monitor",
|
||||
redirect_output)
|
||||
start_monitor(redis_address,
|
||||
node_ip_address,
|
||||
stdout_file=monitor_stdout_file,
|
||||
@@ -755,8 +802,8 @@ def start_ray_processes(address_info=None,
|
||||
|
||||
# Start the log monitor, if necessary.
|
||||
if include_log_monitor:
|
||||
log_monitor_stdout_file, log_monitor_stderr_file = new_log_files("log_monitor",
|
||||
redirect_output=True)
|
||||
log_monitor_stdout_file, log_monitor_stderr_file = new_log_files(
|
||||
"log_monitor", redirect_output=True)
|
||||
start_log_monitor(redis_address,
|
||||
node_ip_address,
|
||||
stdout_file=log_monitor_stdout_file,
|
||||
@@ -765,7 +812,8 @@ def start_ray_processes(address_info=None,
|
||||
|
||||
# Start the global scheduler, if necessary.
|
||||
if include_global_scheduler:
|
||||
global_scheduler_stdout_file, global_scheduler_stderr_file = new_log_files("global_scheduler", redirect_output)
|
||||
global_scheduler_stdout_file, global_scheduler_stderr_file = new_log_files(
|
||||
"global_scheduler", redirect_output)
|
||||
start_global_scheduler(redis_address,
|
||||
node_ip_address,
|
||||
stdout_file=global_scheduler_stdout_file,
|
||||
@@ -781,7 +829,8 @@ def start_ray_processes(address_info=None,
|
||||
local_scheduler_socket_names = address_info["local_scheduler_socket_names"]
|
||||
|
||||
# Get the ports to use for the object managers if any are provided.
|
||||
object_manager_ports = address_info["object_manager_ports"] if "object_manager_ports" in address_info else None
|
||||
object_manager_ports = (address_info["object_manager_ports"]
|
||||
if "object_manager_ports" in address_info else None)
|
||||
if not isinstance(object_manager_ports, list):
|
||||
object_manager_ports = num_local_schedulers * [object_manager_ports]
|
||||
assert len(object_manager_ports) == num_local_schedulers
|
||||
@@ -789,23 +838,26 @@ def start_ray_processes(address_info=None,
|
||||
# Start any object stores that do not yet exist.
|
||||
for i in range(num_local_schedulers - len(object_store_addresses)):
|
||||
# Start Plasma.
|
||||
plasma_store_stdout_file, plasma_store_stderr_file = new_log_files("plasma_store_{}".format(i), redirect_output)
|
||||
plasma_manager_stdout_file, plasma_manager_stderr_file = new_log_files("plasma_manager_{}".format(i), redirect_output)
|
||||
object_store_address = start_objstore(node_ip_address,
|
||||
redis_address,
|
||||
object_manager_port=object_manager_ports[i],
|
||||
store_stdout_file=plasma_store_stdout_file,
|
||||
store_stderr_file=plasma_store_stderr_file,
|
||||
manager_stdout_file=plasma_manager_stdout_file,
|
||||
manager_stderr_file=plasma_manager_stderr_file,
|
||||
cleanup=cleanup)
|
||||
plasma_store_stdout_file, plasma_store_stderr_file = new_log_files(
|
||||
"plasma_store_{}".format(i), redirect_output)
|
||||
plasma_manager_stdout_file, plasma_manager_stderr_file = new_log_files(
|
||||
"plasma_manager_{}".format(i), redirect_output)
|
||||
object_store_address = start_objstore(
|
||||
node_ip_address,
|
||||
redis_address,
|
||||
object_manager_port=object_manager_ports[i],
|
||||
store_stdout_file=plasma_store_stdout_file,
|
||||
store_stderr_file=plasma_store_stderr_file,
|
||||
manager_stdout_file=plasma_manager_stdout_file,
|
||||
manager_stderr_file=plasma_manager_stderr_file,
|
||||
cleanup=cleanup)
|
||||
object_store_addresses.append(object_store_address)
|
||||
time.sleep(0.1)
|
||||
|
||||
# Determine how many workers to start for each local scheduler.
|
||||
num_workers_per_local_scheduler = [0] * num_local_schedulers
|
||||
workers_per_local_scheduler = [0] * num_local_schedulers
|
||||
for i in range(num_workers):
|
||||
num_workers_per_local_scheduler[i % num_local_schedulers] += 1
|
||||
workers_per_local_scheduler[i % num_local_schedulers] += 1
|
||||
|
||||
# Start any local schedulers that do not yet exist.
|
||||
for i in range(len(local_scheduler_socket_names), num_local_schedulers):
|
||||
@@ -815,26 +867,28 @@ def start_ray_processes(address_info=None,
|
||||
object_store_address.manager_port)
|
||||
# Determine how many workers this local scheduler should start.
|
||||
if start_workers_from_local_scheduler:
|
||||
num_local_scheduler_workers = num_workers_per_local_scheduler[i]
|
||||
num_workers_per_local_scheduler[i] = 0
|
||||
num_local_scheduler_workers = workers_per_local_scheduler[i]
|
||||
workers_per_local_scheduler[i] = 0
|
||||
else:
|
||||
# If we're starting the workers from Python, the local scheduler should
|
||||
# not start any workers.
|
||||
num_local_scheduler_workers = 0
|
||||
# Start the local scheduler.
|
||||
local_scheduler_stdout_file, local_scheduler_stderr_file = new_log_files("local_scheduler_{}".format(i), redirect_output)
|
||||
local_scheduler_name = start_local_scheduler(redis_address,
|
||||
node_ip_address,
|
||||
object_store_address.name,
|
||||
object_store_address.manager_name,
|
||||
worker_path,
|
||||
plasma_address=plasma_address,
|
||||
stdout_file=local_scheduler_stdout_file,
|
||||
stderr_file=local_scheduler_stderr_file,
|
||||
cleanup=cleanup,
|
||||
num_cpus=num_cpus[i],
|
||||
num_gpus=num_gpus[i],
|
||||
num_workers=num_local_scheduler_workers)
|
||||
local_scheduler_stdout_file, local_scheduler_stderr_file = new_log_files(
|
||||
"local_scheduler_{}".format(i), redirect_output)
|
||||
local_scheduler_name = start_local_scheduler(
|
||||
redis_address,
|
||||
node_ip_address,
|
||||
object_store_address.name,
|
||||
object_store_address.manager_name,
|
||||
worker_path,
|
||||
plasma_address=plasma_address,
|
||||
stdout_file=local_scheduler_stdout_file,
|
||||
stderr_file=local_scheduler_stderr_file,
|
||||
cleanup=cleanup,
|
||||
num_cpus=num_cpus[i],
|
||||
num_gpus=num_gpus[i],
|
||||
num_workers=num_local_scheduler_workers)
|
||||
local_scheduler_socket_names.append(local_scheduler_name)
|
||||
time.sleep(0.1)
|
||||
|
||||
@@ -844,11 +898,12 @@ def start_ray_processes(address_info=None,
|
||||
assert len(local_scheduler_socket_names) == num_local_schedulers
|
||||
|
||||
# Start any workers that the local scheduler has not already started.
|
||||
for i, num_local_scheduler_workers in enumerate(num_workers_per_local_scheduler):
|
||||
for i, num_local_scheduler_workers in enumerate(workers_per_local_scheduler):
|
||||
object_store_address = object_store_addresses[i]
|
||||
local_scheduler_name = local_scheduler_socket_names[i]
|
||||
for j in range(num_local_scheduler_workers):
|
||||
worker_stdout_file, worker_stderr_file = new_log_files("worker_{}_{}".format(i, j), redirect_output)
|
||||
worker_stdout_file, worker_stderr_file = new_log_files(
|
||||
"worker_{}_{}".format(i, j), redirect_output)
|
||||
start_worker(node_ip_address,
|
||||
object_store_address.name,
|
||||
object_store_address.manager_name,
|
||||
@@ -858,17 +913,17 @@ def start_ray_processes(address_info=None,
|
||||
stdout_file=worker_stdout_file,
|
||||
stderr_file=worker_stderr_file,
|
||||
cleanup=cleanup)
|
||||
num_workers_per_local_scheduler[i] -= 1
|
||||
workers_per_local_scheduler[i] -= 1
|
||||
|
||||
# Make sure that we've started all the workers.
|
||||
assert(sum(num_workers_per_local_scheduler) == 0)
|
||||
assert(sum(workers_per_local_scheduler) == 0)
|
||||
|
||||
# Try to start the web UI.
|
||||
if include_webui:
|
||||
backend_stdout_file, backend_stderr_file = new_log_files("webui_backend",
|
||||
redirect_output=True)
|
||||
polymer_stdout_file, polymer_stderr_file = new_log_files("webui_polymer",
|
||||
redirect_output=True)
|
||||
backend_stdout_file, backend_stderr_file = new_log_files(
|
||||
"webui_backend", redirect_output=True)
|
||||
polymer_stdout_file, polymer_stderr_file = new_log_files(
|
||||
"webui_polymer", redirect_output=True)
|
||||
successfully_started = start_webui(redis_address,
|
||||
node_ip_address,
|
||||
backend_stdout_file=backend_stdout_file,
|
||||
@@ -883,6 +938,7 @@ def start_ray_processes(address_info=None,
|
||||
# Return the addresses of the relevant processes.
|
||||
return address_info
|
||||
|
||||
|
||||
def start_ray_node(node_ip_address,
|
||||
redis_address,
|
||||
object_manager_ports=None,
|
||||
@@ -905,8 +961,8 @@ def start_ray_node(node_ip_address,
|
||||
managers. There should be one per object manager being started on this
|
||||
node (typically just one).
|
||||
num_workers (int): The number of workers to start.
|
||||
num_local_schedulers (int): The number of local schedulers to start. This is
|
||||
also the number of plasma stores and plasma managers to start.
|
||||
num_local_schedulers (int): The number of local schedulers to start. This
|
||||
is also the number of plasma stores and plasma managers to start.
|
||||
worker_path (str): The path of the source code that will be run by the
|
||||
worker.
|
||||
cleanup (bool): If cleanup is true, then the processes started here will be
|
||||
@@ -919,10 +975,8 @@ def start_ray_node(node_ip_address,
|
||||
A dictionary of the address information for the processes that were
|
||||
started.
|
||||
"""
|
||||
address_info = {
|
||||
"redis_address": redis_address,
|
||||
"object_manager_ports": object_manager_ports,
|
||||
}
|
||||
address_info = {"redis_address": redis_address,
|
||||
"object_manager_ports": object_manager_ports}
|
||||
return start_ray_processes(address_info=address_info,
|
||||
node_ip_address=node_ip_address,
|
||||
num_workers=num_workers,
|
||||
@@ -934,6 +988,7 @@ def start_ray_node(node_ip_address,
|
||||
num_cpus=num_cpus,
|
||||
num_gpus=num_gpus)
|
||||
|
||||
|
||||
def start_ray_head(address_info=None,
|
||||
node_ip_address="127.0.0.1",
|
||||
num_workers=0,
|
||||
@@ -974,33 +1029,37 @@ def start_ray_head(address_info=None,
|
||||
A dictionary of the address information for the processes that were
|
||||
started.
|
||||
"""
|
||||
return start_ray_processes(address_info=address_info,
|
||||
node_ip_address=node_ip_address,
|
||||
num_workers=num_workers,
|
||||
num_local_schedulers=num_local_schedulers,
|
||||
worker_path=worker_path,
|
||||
cleanup=cleanup,
|
||||
redirect_output=redirect_output,
|
||||
include_global_scheduler=True,
|
||||
include_log_monitor=True,
|
||||
include_redis=True,
|
||||
include_webui=True,
|
||||
start_workers_from_local_scheduler=start_workers_from_local_scheduler,
|
||||
num_cpus=num_cpus,
|
||||
num_gpus=num_gpus)
|
||||
return start_ray_processes(
|
||||
address_info=address_info,
|
||||
node_ip_address=node_ip_address,
|
||||
num_workers=num_workers,
|
||||
num_local_schedulers=num_local_schedulers,
|
||||
worker_path=worker_path,
|
||||
cleanup=cleanup,
|
||||
redirect_output=redirect_output,
|
||||
include_global_scheduler=True,
|
||||
include_log_monitor=True,
|
||||
include_redis=True,
|
||||
include_webui=True,
|
||||
start_workers_from_local_scheduler=start_workers_from_local_scheduler,
|
||||
num_cpus=num_cpus,
|
||||
num_gpus=num_gpus)
|
||||
|
||||
|
||||
def new_log_files(name, redirect_output):
|
||||
"""Generate partially randomized filenames for log files.
|
||||
|
||||
Args:
|
||||
name (str): descriptive string for this log file.
|
||||
redirect_output (bool): True if files should be generated for logging stdout
|
||||
and stderr and false if stdout and stderr should not be redirected.
|
||||
redirect_output (bool): True if files should be generated for logging
|
||||
stdout and stderr and false if stdout and stderr should not be
|
||||
redirected.
|
||||
|
||||
Returns:
|
||||
If redirect_output is true, this will return a tuple of two filehandles. The
|
||||
first is for redirecting stdout and the second is for redirecting stderr.
|
||||
If redirect_output is false, this will return a tuple of two None objects.
|
||||
If redirect_output is true, this will return a tuple of two filehandles.
|
||||
The first is for redirecting stdout and the second is for redirecting
|
||||
stderr. If redirect_output is false, this will return a tuple of two None
|
||||
objects.
|
||||
"""
|
||||
if not redirect_output:
|
||||
return None, None
|
||||
|
||||
@@ -8,44 +8,53 @@ import numpy as np
|
||||
|
||||
# Test simple functionality
|
||||
|
||||
|
||||
@ray.remote(num_return_vals=2)
|
||||
def handle_int(a, b):
|
||||
return a + 1, b + 1
|
||||
|
||||
# Test timing
|
||||
|
||||
|
||||
@ray.remote
|
||||
def empty_function():
|
||||
pass
|
||||
|
||||
|
||||
@ray.remote
|
||||
def trivial_function():
|
||||
return 1
|
||||
|
||||
# Test keyword arguments
|
||||
|
||||
|
||||
@ray.remote
|
||||
def keyword_fct1(a, b="hello"):
|
||||
return "{} {}".format(a, b)
|
||||
|
||||
|
||||
@ray.remote
|
||||
def keyword_fct2(a="hello", b="world"):
|
||||
return "{} {}".format(a, b)
|
||||
|
||||
|
||||
@ray.remote
|
||||
def keyword_fct3(a, b, c="hello", d="world"):
|
||||
return "{} {} {} {}".format(a, b, c, d)
|
||||
|
||||
# Test variable numbers of arguments
|
||||
|
||||
|
||||
@ray.remote
|
||||
def varargs_fct1(*a):
|
||||
return " ".join(map(str, a))
|
||||
|
||||
|
||||
@ray.remote
|
||||
def varargs_fct2(a, *b):
|
||||
return " ".join(map(str, b))
|
||||
|
||||
|
||||
try:
|
||||
@ray.remote
|
||||
def kwargs_throw_exception(**c):
|
||||
@@ -64,24 +73,29 @@ except:
|
||||
|
||||
# test throwing an exception
|
||||
|
||||
|
||||
@ray.remote
|
||||
def throw_exception_fct1():
|
||||
raise Exception("Test function 1 intentionally failed.")
|
||||
|
||||
|
||||
@ray.remote
|
||||
def throw_exception_fct2():
|
||||
raise Exception("Test function 2 intentionally failed.")
|
||||
|
||||
|
||||
@ray.remote(num_return_vals=3)
|
||||
def throw_exception_fct3(x):
|
||||
raise Exception("Test function 3 intentionally failed.")
|
||||
|
||||
# test Python mode
|
||||
|
||||
|
||||
@ray.remote
|
||||
def python_mode_f():
|
||||
return np.array([0, 0])
|
||||
|
||||
|
||||
@ray.remote
|
||||
def python_mode_g(x):
|
||||
x[0] = 1
|
||||
@@ -89,14 +103,17 @@ def python_mode_g(x):
|
||||
|
||||
# test no return values
|
||||
|
||||
|
||||
@ray.remote
|
||||
def no_op():
|
||||
pass
|
||||
|
||||
|
||||
class TestClass(object):
|
||||
def __init__(self):
|
||||
self.a = 5
|
||||
|
||||
|
||||
@ray.remote
|
||||
def test_unknown_type():
|
||||
return TestClass()
|
||||
|
||||
+439
-272
File diff suppressed because it is too large
Load Diff
@@ -3,25 +3,33 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import binascii
|
||||
import numpy as np
|
||||
import redis
|
||||
import traceback
|
||||
import sys
|
||||
import binascii
|
||||
|
||||
import ray
|
||||
|
||||
parser = argparse.ArgumentParser(description="Parse addresses for the worker to connect to.")
|
||||
parser.add_argument("--node-ip-address", required=True, type=str, help="the ip address of the worker's node")
|
||||
parser.add_argument("--redis-address", required=True, type=str, help="the address to use for Redis")
|
||||
parser.add_argument("--object-store-name", required=True, type=str, help="the object store's name")
|
||||
parser.add_argument("--object-store-manager-name", required=True, type=str, help="the object store manager's name")
|
||||
parser.add_argument("--local-scheduler-name", required=True, type=str, help="the local scheduler's name")
|
||||
parser.add_argument("--actor-id", required=False, type=str, help="the actor ID of this worker")
|
||||
parser = argparse.ArgumentParser(description=("Parse addresses for the worker "
|
||||
"to connect to."))
|
||||
parser.add_argument("--node-ip-address", required=True, type=str,
|
||||
help="the ip address of the worker's node")
|
||||
parser.add_argument("--redis-address", required=True, type=str,
|
||||
help="the address to use for Redis")
|
||||
parser.add_argument("--object-store-name", required=True, type=str,
|
||||
help="the object store's name")
|
||||
parser.add_argument("--object-store-manager-name", required=True, type=str,
|
||||
help="the object store manager's name")
|
||||
parser.add_argument("--local-scheduler-name", required=True, type=str,
|
||||
help="the local scheduler's name")
|
||||
parser.add_argument("--actor-id", required=False, type=str,
|
||||
help="the actor ID of this worker")
|
||||
|
||||
|
||||
def random_string():
|
||||
return np.random.bytes(20)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
info = {"node_ip_address": args.node_ip_address,
|
||||
@@ -30,7 +38,10 @@ if __name__ == "__main__":
|
||||
"manager_socket_name": args.object_store_manager_name,
|
||||
"local_scheduler_socket_name": args.local_scheduler_name}
|
||||
|
||||
actor_id = binascii.unhexlify(args.actor_id) if not args.actor_id is None else ray.worker.NIL_ACTOR_ID
|
||||
if args.actor_id is not None:
|
||||
actor_id = binascii.unhexlify(args.actor_id)
|
||||
else:
|
||||
actor_id = ray.worker.NIL_ACTOR_ID
|
||||
|
||||
ray.worker.connect(info, mode=ray.WORKER_MODE, actor_id=actor_id)
|
||||
|
||||
@@ -43,24 +54,27 @@ being caught in "python/ray/workers/default_worker.py".
|
||||
while True:
|
||||
try:
|
||||
# This call to main_loop should never return if things are working. Most
|
||||
# exceptions that are thrown (e.g., inside the execution of a task) should
|
||||
# be caught and handled inside of the call to main_loop. If an exception
|
||||
# is thrown here, then that means that there is some error that we didn't
|
||||
# anticipate.
|
||||
# exceptions that are thrown (e.g., inside the execution of a task)
|
||||
# should be caught and handled inside of the call to main_loop. If an
|
||||
# exception is thrown here, then that means that there is some error that
|
||||
# we didn't anticipate.
|
||||
ray.worker.main_loop()
|
||||
except Exception as e:
|
||||
traceback_str = traceback.format_exc() + error_explanation
|
||||
DRIVER_ID_LENGTH = 20
|
||||
# We use a driver ID of all zeros to push an error message to all drivers.
|
||||
# We use a driver ID of all zeros to push an error message to all
|
||||
# drivers.
|
||||
driver_id = DRIVER_ID_LENGTH * b"\x00"
|
||||
error_key = b"Error:" + driver_id + b":" + random_string()
|
||||
redis_ip_address, redis_port = args.redis_address.split(":")
|
||||
# For this command to work, some other client (on the same machine as
|
||||
# Redis) must have run "CONFIG SET protected-mode no".
|
||||
redis_client = redis.StrictRedis(host=redis_ip_address, port=int(redis_port))
|
||||
redis_client = redis.StrictRedis(host=redis_ip_address,
|
||||
port=int(redis_port))
|
||||
redis_client.hmset(error_key, {"type": "worker_crash",
|
||||
"message": traceback_str,
|
||||
"note": "This error is unexpected and should not have happened."})
|
||||
"note": ("This error is unexpected and "
|
||||
"should not have happened.")})
|
||||
redis_client.rpush("ErrorKeys", error_key)
|
||||
# TODO(rkn): Note that if the worker was in the middle of executing a
|
||||
# task, the any worker or driver that is blocking in a get call and
|
||||
|
||||
+16
-11
@@ -2,12 +2,12 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
|
||||
from setuptools import setup, find_packages
|
||||
import setuptools.command.install as _install
|
||||
|
||||
|
||||
class install(_install.install):
|
||||
def run(self):
|
||||
subprocess.check_call(["../build.sh"])
|
||||
@@ -16,19 +16,24 @@ class install(_install.install):
|
||||
# setuptools. So, calling do_egg_install() manually here.
|
||||
self.do_egg_install()
|
||||
|
||||
|
||||
package_data = {
|
||||
"ray": ["core/src/common/thirdparty/redis/src/redis-server",
|
||||
"core/src/common/redis_module/libray_redis_module.so",
|
||||
"core/src/plasma/plasma_store",
|
||||
"core/src/plasma/plasma_manager",
|
||||
"core/src/plasma/libplasma.so",
|
||||
"core/src/local_scheduler/local_scheduler",
|
||||
"core/src/local_scheduler/liblocal_scheduler_library.so",
|
||||
"core/src/numbuf/libarrow.so",
|
||||
"core/src/numbuf/libnumbuf.so",
|
||||
"core/src/global_scheduler/global_scheduler"]
|
||||
}
|
||||
|
||||
setup(name="ray",
|
||||
version="0.0.1",
|
||||
packages=find_packages(),
|
||||
package_data={"ray": ["core/src/common/thirdparty/redis/src/redis-server",
|
||||
"core/src/common/redis_module/libray_redis_module.so",
|
||||
"core/src/plasma/plasma_store",
|
||||
"core/src/plasma/plasma_manager",
|
||||
"core/src/plasma/libplasma.so",
|
||||
"core/src/local_scheduler/local_scheduler",
|
||||
"core/src/local_scheduler/liblocal_scheduler_library.so",
|
||||
"core/src/numbuf/libarrow.so",
|
||||
"core/src/numbuf/libnumbuf.so",
|
||||
"core/src/global_scheduler/global_scheduler"]},
|
||||
package_data=package_data,
|
||||
cmdclass={"install": install},
|
||||
install_requires=["numpy",
|
||||
"funcsigs",
|
||||
|
||||
Reference in New Issue
Block a user