diff --git a/.travis.yml b/.travis.yml index 414221f5a..de3ad1090 100644 --- a/.travis.yml +++ b/.travis.yml @@ -35,6 +35,10 @@ matrix: - cd doc - pip install -r requirements-doc.txt - sphinx-build -W -b html -d _build/doctrees source _build/html + - cd .. + # Run Python linting. + - flake8 --ignore=E111,E114 + --exclude=python/ray/core/src/common/flatbuffers_ep-prefix/,python/ray/core/generated/,src/numbuf/thirdparty/,src/common/format/,examples/,doc/source/conf.py - os: linux dist: trusty env: VALGRIND=1 PYTHON=2.7 diff --git a/.travis/install-dependencies.sh b/.travis/install-dependencies.sh index 83077febe..94c0506f5 100755 --- a/.travis/install-dependencies.sh +++ b/.travis/install-dependencies.sh @@ -64,6 +64,9 @@ elif [[ "$LINT" == "1" ]]; then # Install miniconda. wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh bash miniconda.sh -b -p $HOME/miniconda + export PATH="$HOME/miniconda/bin:$PATH" + # Install Python linting tools. + pip install flake8 else echo "Unrecognized environment." exit 1 diff --git a/python/ray/__init__.py b/python/ray/__init__.py index e25b8b4fa..a3ef64a60 100644 --- a/python/ray/__init__.py +++ b/python/ray/__init__.py @@ -2,19 +2,29 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# Ray version string -__version__ = "0.01" - -import ctypes -# Windows only -if hasattr(ctypes, "windll"): - # Makes sure that all child processes die when we die - # Also makes sure that fatal crashes result in process termination rather than an error dialog (the latter is annoying since we have a lot of processes) - # This is done by associating all child processes with a "job" object that imposes this behavior. - (lambda kernel32: (lambda job: (lambda n: kernel32.SetInformationJobObject(job, 9, "\0" * 17 + chr(0x8 | 0x4 | 0x20) + "\0" * (n - 18), n))(0x90 if ctypes.sizeof(ctypes.c_void_p) > ctypes.sizeof(ctypes.c_int) else 0x70) and kernel32.AssignProcessToJobObject(job, ctypes.c_void_p(kernel32.GetCurrentProcess())))(ctypes.c_void_p(kernel32.CreateJobObjectW(None, None))) if kernel32 is not None else None)(ctypes.windll.kernel32) - -from ray.worker import register_class, error_info, init, connect, disconnect, get, put, wait, remote, log_event, log_span, flush_log +from ray.worker import (register_class, error_info, init, connect, disconnect, + get, put, wait, remote, log_event, log_span, + flush_log) from ray.actor import actor from ray.actor import get_gpu_ids from ray.worker import EnvironmentVariable, env from ray.worker import SCRIPT_MODE, WORKER_MODE, PYTHON_MODE, SILENT_MODE + +# Ray version string +__version__ = "0.01" + +__all__ = ["register_class", "error_info", "init", "connect", "disconnect", + "get", "put", "wait", "remote", "log_event", "log_span", + "flush_log", "actor", "get_gpu_ids", "EnvironmentVariable", "env", + "SCRIPT_MODE", "WORKER_MODE", "PYTHON_MODE", "SILENT_MODE", + "__version__"] + +import ctypes +# Windows only +if hasattr(ctypes, "windll"): + # Makes sure that all child processes die when we die. Also makes sure that + # fatal crashes result in process termination rather than an error dialog + # (the latter is annoying since we have a lot of processes). This is done by + # associating all child processes with a "job" object that imposes this + # behavior. + (lambda kernel32: (lambda job: (lambda n: kernel32.SetInformationJobObject(job, 9, "\0" * 17 + chr(0x8 | 0x4 | 0x20) + "\0" * (n - 18), n))(0x90 if ctypes.sizeof(ctypes.c_void_p) > ctypes.sizeof(ctypes.c_int) else 0x70) and kernel32.AssignProcessToJobObject(job, ctypes.c_void_p(kernel32.GetCurrentProcess())))(ctypes.c_void_p(kernel32.CreateJobObjectW(None, None))) if kernel32 is not None else None)(ctypes.windll.kernel32) # noqa: E501 diff --git a/python/ray/actor.py b/python/ray/actor.py index daae1fa0a..04f476e0a 100644 --- a/python/ray/actor.py +++ b/python/ray/actor.py @@ -18,6 +18,7 @@ import ray.experimental.state as state # the worker is currently allowed to use. gpu_ids = [] + def get_gpu_ids(): """Get the IDs of the GPU that are available to the worker. @@ -26,12 +27,15 @@ def get_gpu_ids(): """ return gpu_ids + def random_string(): return np.random.bytes(20) + def random_actor_id(): return ray.local_scheduler.ObjectID(random_string()) + def get_actor_method_function_id(attr): """Get the function ID corresponding to an actor method. @@ -47,10 +51,14 @@ def get_actor_method_function_id(attr): assert len(function_id) == 20 return ray.local_scheduler.ObjectID(function_id) + def fetch_and_register_actor(key, worker): """Import an actor.""" - driver_id, actor_id_str, actor_name, module, pickled_class, assigned_gpu_ids, actor_method_names = \ - worker.redis_client.hmget(key, ["driver_id", "actor_id", "name", "module", "class", "gpu_ids", "actor_method_names"]) + (driver_id, actor_id_str, actor_name, + module, pickled_class, assigned_gpu_ids, + actor_method_names) = worker.redis_client.hmget( + key, ["driver_id", "actor_id", "name", "module", "class", "gpu_ids", + "actor_method_names"]) actor_id = ray.local_scheduler.ObjectID(actor_id_str) actor_name = actor_name.decode("ascii") module = module.decode("ascii") @@ -64,12 +72,14 @@ def fetch_and_register_actor(key, worker): class TemporaryActor(object): pass worker.actors[actor_id_str] = TemporaryActor() + def temporary_actor_method(*xs): raise Exception("The actor with name {} failed to be imported, and so " "cannot execute this method".format(actor_name)) for actor_method_name in actor_method_names: function_id = get_actor_method_function_id(actor_method_name).id() - worker.functions[driver_id][function_id] = (actor_method_name, temporary_actor_method) + worker.functions[driver_id][function_id] = (actor_method_name, + temporary_actor_method) try: unpickled_class = pickling.loads(pickled_class) @@ -84,11 +94,15 @@ def fetch_and_register_actor(key, worker): # TODO(pcm): Why is the below line necessary? unpickled_class.__module__ = module worker.actors[actor_id_str] = unpickled_class.__new__(unpickled_class) - for (k, v) in inspect.getmembers(unpickled_class, predicate=(lambda x: inspect.isfunction(x) or inspect.ismethod(x))): + for (k, v) in inspect.getmembers( + unpickled_class, predicate=(lambda x: (inspect.isfunction(x) or + inspect.ismethod(x)))): function_id = get_actor_method_function_id(k).id() worker.functions[driver_id][function_id] = (k, v) - # We do not set worker.function_properties[driver_id][function_id] because - # we currently do need the actor worker to submit new tasks for the actor. + # We do not set worker.function_properties[driver_id][function_id] + # because we currently do need the actor worker to submit new tasks for + # the actor. + def select_local_scheduler(local_schedulers, num_gpus, worker): """Select a local scheduler to assign this actor to. @@ -119,15 +133,19 @@ def select_local_scheduler(local_schedulers, num_gpus, worker): # Loop through all of the local schedulers. for local_scheduler in local_schedulers: # See if there are enough available GPUs on this local scheduler. - local_scheduler_total_gpus = int(float(local_scheduler[b"num_gpus"].decode("ascii"))) - gpus_in_use = worker.redis_client.hget(local_scheduler[b"ray_client_id"], b"gpus_in_use") + local_scheduler_total_gpus = int(float( + local_scheduler[b"num_gpus"].decode("ascii"))) + gpus_in_use = worker.redis_client.hget(local_scheduler[b"ray_client_id"], + b"gpus_in_use") gpus_in_use = 0 if gpus_in_use is None else int(gpus_in_use) if gpus_in_use + num_gpus <= local_scheduler_total_gpus: # Attempt to reserve some GPUs for this actor. - new_gpus_in_use = worker.redis_client.hincrby(local_scheduler[b"ray_client_id"], b"gpus_in_use", num_gpus) + new_gpus_in_use = worker.redis_client.hincrby( + local_scheduler[b"ray_client_id"], b"gpus_in_use", num_gpus) if new_gpus_in_use > local_scheduler_total_gpus: # If we failed to reserve the GPUs, undo the increment. - worker.redis_client.hincrby(local_scheduler[b"ray_client_id"], b"gpus_in_use", num_gpus) + worker.redis_client.hincrby(local_scheduler[b"ray_client_id"], + b"gpus_in_use", num_gpus) else: # We succeeded at reserving the GPUs, so we are done. local_scheduler_id = local_scheduler[b"ray_client_id"] @@ -135,10 +153,13 @@ def select_local_scheduler(local_schedulers, num_gpus, worker): break if local_scheduler_id is None: raise Exception("Could not find a node with enough GPUs to create this " - "actor. The local scheduler information is {}.".format(local_schedulers)) + "actor. The local scheduler information is {}." + .format(local_schedulers)) return local_scheduler_id, gpu_ids -def export_actor(actor_id, Class, actor_method_names, num_cpus, num_gpus, worker): + +def export_actor(actor_id, Class, actor_method_names, num_cpus, num_gpus, + worker): """Export an actor to redis. Args: @@ -158,13 +179,16 @@ def export_actor(actor_id, Class, actor_method_names, num_cpus, num_gpus, worker driver_id = worker.task_driver_id.id() for actor_method_name in actor_method_names: function_id = get_actor_method_function_id(actor_method_name).id() - worker.function_properties[driver_id][function_id] = (1, num_cpus, num_gpus) + worker.function_properties[driver_id][function_id] = (1, num_cpus, + num_gpus) # Select a local scheduler for the actor. local_schedulers = state.get_local_schedulers(worker) - local_scheduler_id, gpu_ids = select_local_scheduler(local_schedulers, num_gpus, worker) + local_scheduler_id, gpu_ids = select_local_scheduler(local_schedulers, + num_gpus, worker) - worker.redis_client.publish("actor_notifications", actor_id.id() + local_scheduler_id) + worker.redis_client.publish("actor_notifications", + actor_id.id() + local_scheduler_id) d = {"driver_id": driver_id, "actor_id": actor_id.id(), @@ -176,6 +200,7 @@ def export_actor(actor_id, Class, actor_method_names, num_cpus, num_gpus, worker worker.redis_client.hmset(key, d) worker.redis_client.rpush("Exports", key) + def actor(*args, **kwargs): def make_actor_decorator(num_cpus=1, num_gpus=0): def make_actor(Class): @@ -189,7 +214,8 @@ def actor(*args, **kwargs): raise Exception("Actors currently do not support **kwargs.") function_id = get_actor_method_function_id(attr) # TODO(pcm): Extend args with keyword args. - object_ids = ray.worker.global_worker.submit_task(function_id, "", args, + object_ids = ray.worker.global_worker.submit_task(function_id, "", + args, actor_id=actor_id) if len(object_ids) == 1: return object_ids[0] @@ -199,24 +225,34 @@ def actor(*args, **kwargs): class NewClass(object): def __init__(self, *args, **kwargs): self._ray_actor_id = random_actor_id() - self._ray_actor_methods = {k: v for (k, v) in inspect.getmembers(Class, predicate=(lambda x: inspect.isfunction(x) or inspect.ismethod(x)))} - export_actor(self._ray_actor_id, Class, self._ray_actor_methods.keys(), num_cpus, num_gpus, ray.worker.global_worker) + self._ray_actor_methods = { + k: v for (k, v) in inspect.getmembers( + Class, predicate=(lambda x: (inspect.isfunction(x) or + inspect.ismethod(x))))} + export_actor(self._ray_actor_id, Class, + self._ray_actor_methods.keys(), num_cpus, num_gpus, + ray.worker.global_worker) # Call __init__ as a remote function. if "__init__" in self._ray_actor_methods.keys(): actor_method_call(self._ray_actor_id, "__init__", *args, **kwargs) else: print("WARNING: this object has no __init__ method.") + # Make tab completion work. def __dir__(self): return self._ray_actor_methods + def __getattribute__(self, attr): # The following is needed so we can still access self.actor_methods. if attr in ["_ray_actor_id", "_ray_actor_methods"]: return super(NewClass, self).__getattribute__(attr) if attr in self._ray_actor_methods.keys(): - return lambda *args, **kwargs: actor_method_call(self._ray_actor_id, attr, *args, **kwargs) + return lambda *args, **kwargs: actor_method_call( + self._ray_actor_id, attr, *args, **kwargs) # There is no method with this name, so raise an exception. - raise AttributeError("'{}' Actor object has no attribute '{}'".format(Class, attr)) + raise AttributeError("'{}' Actor object has no attribute '{}'" + .format(Class, attr)) + def __repr__(self): return "Actor(" + self._ray_actor_id.hex() + ")" @@ -230,7 +266,9 @@ def actor(*args, **kwargs): return make_actor_decorator(num_cpus=1, num_gpus=0)(Class) # In this case, the actor decorator is something like @ray.actor(num_gpus=1). - if len(args) == 0 and len(kwargs) > 0 and all([key in ["num_cpus", "num_gpus"] for key in kwargs.keys()]): + if len(args) == 0 and len(kwargs) > 0 and all([key + in ["num_cpus", "num_gpus"] + for key in kwargs.keys()]): num_cpus = kwargs["num_cpus"] if "num_cpus" in kwargs.keys() else 1 num_gpus = kwargs["num_gpus"] if "num_gpus" in kwargs.keys() else 0 return make_actor_decorator(num_cpus=num_cpus, num_gpus=num_gpus) @@ -240,4 +278,5 @@ def actor(*args, **kwargs): "some of the arguments 'num_cpus' or 'num_gpus' as in " "'ray.actor(num_gpus=1)'.") + ray.worker.global_worker.fetch_and_register["Actor"] = fetch_and_register_actor diff --git a/python/ray/common/redis_module/runtest.py b/python/ray/common/redis_module/runtest.py index c82838a1b..493af9c50 100644 --- a/python/ray/common/redis_module/runtest.py +++ b/python/ray/common/redis_module/runtest.py @@ -2,17 +2,16 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os -import random -import subprocess +import redis import sys import time import unittest -import redis + import ray.services # Import flatbuffer bindings. -from ray.core.generated.SubscribeToNotificationsReply import SubscribeToNotificationsReply +from ray.core.generated.SubscribeToNotificationsReply \ + import SubscribeToNotificationsReply from ray.core.generated.TaskReply import TaskReply from ray.core.generated.ResultTableReply import ResultTableReply @@ -22,6 +21,7 @@ OBJECT_SUBSCRIBE_PREFIX = "OS:" TASK_PREFIX = "TT:" OBJECT_CHANNEL_PREFIX = "OC:" + def integerToAsciiHex(num, numbytes): retstr = b"" # Support 32 and 64 bit architecture. @@ -36,6 +36,7 @@ def integerToAsciiHex(num, numbytes): return retstr + def get_next_message(pubsub_client, timeout_seconds=10): """Block until the next message is available on the pubsub channel.""" start_time = time.time() @@ -47,6 +48,7 @@ def get_next_message(pubsub_client, timeout_seconds=10): if time.time() - start_time > timeout_seconds: raise Exception("Timed out while waiting for next message.") + class TestGlobalStateStore(unittest.TestCase): def setUp(self): @@ -57,102 +59,144 @@ class TestGlobalStateStore(unittest.TestCase): ray.services.cleanup() def testInvalidObjectTableAdd(self): - # Check that Redis returns an error when RAY.OBJECT_TABLE_ADD is called with - # the wrong arguments. + # Check that Redis returns an error when RAY.OBJECT_TABLE_ADD is called + # with the wrong arguments. with self.assertRaises(redis.ResponseError): self.redis.execute_command("RAY.OBJECT_TABLE_ADD") with self.assertRaises(redis.ResponseError): self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "hello") with self.assertRaises(redis.ResponseError): - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id2", "one", "hash2", "manager_id1") + self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id2", "one", + "hash2", "manager_id1") with self.assertRaises(redis.ResponseError): - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id2", 1, "hash2", "manager_id1", "extra argument") - # Check that Redis returns an error when RAY.OBJECT_TABLE_ADD adds an object - # ID that is already present with a different hash. - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, "hash1", "manager_id1") - response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", "object_id1") + self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id2", 1, + "hash2", "manager_id1", "extra argument") + # Check that Redis returns an error when RAY.OBJECT_TABLE_ADD adds an + # object ID that is already present with a different hash. + self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, + "hash1", "manager_id1") + response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", + "object_id1") self.assertEqual(set(response), {b"manager_id1"}) with self.assertRaises(redis.ResponseError): - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, "hash2", "manager_id2") + self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, + "hash2", "manager_id2") # Check that the second manager was added, even though the hash was # mismatched. - response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", "object_id1") + response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", + "object_id1") self.assertEqual(set(response), {b"manager_id1", b"manager_id2"}) - # Check that it is fine if we add the same object ID multiple times with the - # most recent hash. - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, "hash2", "manager_id1") - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, "hash2", "manager_id1") - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, "hash2", "manager_id2") - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 2, "hash2", "manager_id2") - response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", "object_id1") + # Check that it is fine if we add the same object ID multiple times with + # the most recent hash. + self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, + "hash2", "manager_id1") + self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, + "hash2", "manager_id1") + self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, + "hash2", "manager_id2") + self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 2, + "hash2", "manager_id2") + response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", + "object_id1") self.assertEqual(set(response), {b"manager_id1", b"manager_id2"}) def testObjectTableAddAndLookup(self): # Try calling RAY.OBJECT_TABLE_LOOKUP with an object ID that has not been # added yet. - response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", "object_id1") + response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", + "object_id1") self.assertEqual(response, None) # Add some managers and try again. - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, "hash1", "manager_id1") - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, "hash1", "manager_id2") - response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", "object_id1") + self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, + "hash1", "manager_id1") + self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, + "hash1", "manager_id2") + response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", + "object_id1") self.assertEqual(set(response), {b"manager_id1", b"manager_id2"}) # Add a manager that already exists again and try again. - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, "hash1", "manager_id2") - response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", "object_id1") + self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, + "hash1", "manager_id2") + response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", + "object_id1") self.assertEqual(set(response), {b"manager_id1", b"manager_id2"}) # Check that we properly handle NULL characters. In the past, NULL # characters were handled improperly causing a "hash mismatch" error if two # object IDs that agreed up to the NULL character were inserted with # different hashes. - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "\x00object_id3", 1, "hash1", "manager_id1") - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "\x00object_id4", 1, "hash2", "manager_id1") + self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "\x00object_id3", 1, + "hash1", "manager_id1") + self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "\x00object_id4", 1, + "hash2", "manager_id1") # Check that NULL characters in the hash are handled properly. - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", 1, "\x00hash1", "manager_id1") + self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", 1, + "\x00hash1", "manager_id1") with self.assertRaises(redis.ResponseError): - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", 1, "\x00hash2", "manager_id1") + self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", 1, + "\x00hash2", "manager_id1") def testObjectTableAddAndRemove(self): # Try removing a manager from an object ID that has not been added yet. with self.assertRaises(redis.ResponseError): - self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1", "manager_id1") + self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1", + "manager_id1") # Try calling RAY.OBJECT_TABLE_LOOKUP with an object ID that has not been # added yet. - response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", "object_id1") + response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", + "object_id1") self.assertEqual(response, None) # Add some managers and try again. - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, "hash1", "manager_id1") - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, "hash1", "manager_id2") - response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", "object_id1") + self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, + "hash1", "manager_id1") + self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, + "hash1", "manager_id2") + response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", + "object_id1") self.assertEqual(set(response), {b"manager_id1", b"manager_id2"}) - # Remove a manager that doesn't exist, and make sure we still have the same set. - self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1", "manager_id3") - response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", "object_id1") + # Remove a manager that doesn't exist, and make sure we still have the same + # set. + self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1", + "manager_id3") + response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", + "object_id1") self.assertEqual(set(response), {b"manager_id1", b"manager_id2"}) # Remove a manager that does exist. Make sure it gets removed the first # time and does nothing the second time. - self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1", "manager_id1") - response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", "object_id1") + self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1", + "manager_id1") + response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", + "object_id1") self.assertEqual(set(response), {b"manager_id2"}) - self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1", "manager_id1") - response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", "object_id1") + self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1", + "manager_id1") + response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", + "object_id1") self.assertEqual(set(response), {b"manager_id2"}) # Remove the last manager, and make sure we have an empty set. - self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1", "manager_id2") - response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", "object_id1") + self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1", + "manager_id2") + response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", + "object_id1") self.assertEqual(set(response), set()) - # Remove a manager from an empty set, and make sure we now have an empty set. - self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1", "manager_id3") - response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", "object_id1") + # Remove a manager from an empty set, and make sure we now have an empty + # set. + self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1", + "manager_id3") + response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", + "object_id1") self.assertEqual(set(response), set()) def testObjectTableSubscribeToNotifications(self): # Define a helper method for checking the contents of object notifications. - def check_object_notification(notification_message, object_id, object_size, manager_ids): - notification_object = SubscribeToNotificationsReply.GetRootAsSubscribeToNotificationsReply(notification_message, 0) + def check_object_notification(notification_message, object_id, object_size, + manager_ids): + notification_object = (SubscribeToNotificationsReply + .GetRootAsSubscribeToNotificationsReply( + notification_message, 0)) self.assertEqual(notification_object.ObjectId(), object_id) self.assertEqual(notification_object.ObjectSize(), object_size) - self.assertEqual(notification_object.ManagerIdsLength(), len(manager_ids)) + self.assertEqual(notification_object.ManagerIdsLength(), + len(manager_ids)) for i in range(len(manager_ids)): self.assertEqual(notification_object.ManagerIds(i), manager_ids[i]) @@ -160,37 +204,45 @@ class TestGlobalStateStore(unittest.TestCase): p = self.redis.pubsub() # Subscribe to an object ID. p.psubscribe("{}manager_id1".format(OBJECT_CHANNEL_PREFIX)) - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", data_size, "hash1", "manager_id2") + self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", data_size, + "hash1", "manager_id2") # Receive the acknowledgement message. self.assertEqual(get_next_message(p)["data"], 1) # Request a notification and receive the data. - self.redis.execute_command("RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS", "manager_id1", "object_id1") + self.redis.execute_command("RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS", + "manager_id1", "object_id1") # Verify that the notification is correct. check_object_notification(get_next_message(p)["data"], b"object_id1", data_size, [b"manager_id2"]) - # Request a notification for an object that isn't there. Then add the object - # and receive the data. Only the first call to RAY.OBJECT_TABLE_ADD should - # trigger notifications. - self.redis.execute_command("RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS", "manager_id1", "object_id2", "object_id3") - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", data_size, "hash1", "manager_id1") - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", data_size, "hash1", "manager_id2") - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", data_size, "hash1", "manager_id3") + # Request a notification for an object that isn't there. Then add the + # object and receive the data. Only the first call to RAY.OBJECT_TABLE_ADD + # should trigger notifications. + self.redis.execute_command("RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS", + "manager_id1", "object_id2", "object_id3") + self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", data_size, + "hash1", "manager_id1") + self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", data_size, + "hash1", "manager_id2") + self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", data_size, + "hash1", "manager_id3") # Verify that the notification is correct. check_object_notification(get_next_message(p)["data"], b"object_id3", data_size, [b"manager_id1"]) - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id2", data_size, "hash1", "manager_id3") + self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id2", data_size, + "hash1", "manager_id3") # Verify that the notification is correct. check_object_notification(get_next_message(p)["data"], b"object_id2", data_size, [b"manager_id3"]) # Request notifications for object_id3 again. - self.redis.execute_command("RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS", "manager_id1", "object_id3") + self.redis.execute_command("RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS", + "manager_id1", "object_id3") # Verify that the notification is correct. check_object_notification(get_next_message(p)["data"], b"object_id3", @@ -199,29 +251,38 @@ class TestGlobalStateStore(unittest.TestCase): def testResultTableAddAndLookup(self): def check_result_table_entry(message, task_id, is_put): - result_table_reply = ResultTableReply.GetRootAsResultTableReply(message, 0) + result_table_reply = ResultTableReply.GetRootAsResultTableReply(message, + 0) self.assertEqual(result_table_reply.TaskId(), task_id) self.assertEqual(result_table_reply.IsPut(), is_put) # Try looking up something in the result table before anything is added. - response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP", "object_id1") + response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP", + "object_id1") self.assertIsNone(response) # Adding the object to the object table should have no effect. - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, "hash1", "manager_id1") - response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP", "object_id1") + self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, + "hash1", "manager_id1") + response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP", + "object_id1") self.assertIsNone(response) # Add the result to the result table. The lookup now returns the task ID. task_id = b"task_id1" - self.redis.execute_command("RAY.RESULT_TABLE_ADD", "object_id1", task_id, 0) - response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP", "object_id1") + self.redis.execute_command("RAY.RESULT_TABLE_ADD", "object_id1", task_id, + 0) + response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP", + "object_id1") check_result_table_entry(response, task_id, False) # Doing it again should still work. - response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP", "object_id1") + response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP", + "object_id1") check_result_table_entry(response, task_id, False) # Try another result table lookup. This should succeed. task_id = b"task_id2" - self.redis.execute_command("RAY.RESULT_TABLE_ADD", "object_id2", task_id, 1) - response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP", "object_id2") + self.redis.execute_command("RAY.RESULT_TABLE_ADD", "object_id2", task_id, + 1) + response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP", + "object_id2") check_result_table_entry(response, task_id, True) def testInvalidTaskTableAdd(self): @@ -251,7 +312,8 @@ class TestGlobalStateStore(unittest.TestCase): task_status, local_scheduler_id, task_spec = task_args task_reply_object = TaskReply.GetRootAsTaskReply(message, 0) self.assertEqual(task_reply_object.State(), task_status) - self.assertEqual(task_reply_object.LocalSchedulerId(), local_scheduler_id) + self.assertEqual(task_reply_object.LocalSchedulerId(), + local_scheduler_id) self.assertEqual(task_reply_object.TaskSpec(), task_spec) self.assertEqual(task_reply_object.Updated(), updated) @@ -263,7 +325,8 @@ class TestGlobalStateStore(unittest.TestCase): check_task_reply(response, task_args) task_args[0] = TASK_STATUS_SCHEDULED - self.redis.execute_command("RAY.TASK_TABLE_UPDATE", "task_id", *task_args[:2]) + self.redis.execute_command("RAY.TASK_TABLE_UPDATE", "task_id", + *task_args[:2]) response = self.redis.execute_command("RAY.TASK_TABLE_GET", "task_id") check_task_reply(response, task_args) @@ -331,9 +394,12 @@ class TestGlobalStateStore(unittest.TestCase): # Subscribe to the task table. p = self.redis.pubsub() p.psubscribe("{prefix}*:*".format(prefix=TASK_PREFIX)) - p.psubscribe("{prefix}*:{state}".format(prefix=TASK_PREFIX, state=scheduling_state)) - p.psubscribe("{prefix}{local_scheduler_id}:*".format(prefix=TASK_PREFIX, local_scheduler_id=local_scheduler_id)) - task_args = [b"task_id", scheduling_state, local_scheduler_id.encode("ascii"), b"task_spec"] + p.psubscribe("{prefix}*:{state}".format( + prefix=TASK_PREFIX, state=scheduling_state)) + p.psubscribe("{prefix}{local_scheduler_id}:*".format( + prefix=TASK_PREFIX, local_scheduler_id=local_scheduler_id)) + task_args = [b"task_id", scheduling_state, + local_scheduler_id.encode("ascii"), b"task_spec"] self.redis.execute_command("RAY.TASK_TABLE_ADD", *task_args) # Receive the acknowledgement message. self.assertEqual(get_next_message(p)["data"], 1) @@ -346,8 +412,10 @@ class TestGlobalStateStore(unittest.TestCase): notification_object = TaskReply.GetRootAsTaskReply(message, 0) self.assertEqual(notification_object.TaskId(), b"task_id") self.assertEqual(notification_object.State(), scheduling_state) - self.assertEqual(notification_object.LocalSchedulerId(), local_scheduler_id.encode("ascii")) + self.assertEqual(notification_object.LocalSchedulerId(), + local_scheduler_id.encode("ascii")) self.assertEqual(notification_object.TaskSpec(), b"task_spec") + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/python/ray/common/test/test.py b/python/ray/common/test/test.py index 05d59eec2..6e0a62433 100644 --- a/python/ray/common/test/test.py +++ b/python/ray/common/test/test.py @@ -11,25 +11,29 @@ import ray.local_scheduler as local_scheduler ID_SIZE = 20 + def random_object_id(): return local_scheduler.ObjectID(np.random.bytes(ID_SIZE)) + def random_function_id(): return local_scheduler.ObjectID(np.random.bytes(ID_SIZE)) + def random_driver_id(): return local_scheduler.ObjectID(np.random.bytes(ID_SIZE)) + def random_task_id(): return local_scheduler.ObjectID(np.random.bytes(ID_SIZE)) + BASE_SIMPLE_OBJECTS = [ - 0, 1, 100000, 0.0, 0.5, 0.9, 100000.1, (), [], {}, - "", 990 * "h", u"", 990 * u"h" -] + 0, 1, 100000, 0.0, 0.5, 0.9, 100000.1, (), [], {}, "", 990 * "h", u"", + 990 * u"h"] if sys.version_info < (3, 0): - BASE_SIMPLE_OBJECTS += [long(0), long(1), long(100000), long(1 << 100)] + BASE_SIMPLE_OBJECTS += [long(0), long(1), long(100000), long(1 << 100)] # noqa: E501,F821 LIST_SIMPLE_OBJECTS = [[obj] for obj in BASE_SIMPLE_OBJECTS] TUPLE_SIMPLE_OBJECTS = [(obj,) for obj in BASE_SIMPLE_OBJECTS] @@ -45,11 +49,14 @@ SIMPLE_OBJECTS = (BASE_SIMPLE_OBJECTS + l = [] l.append(l) + class Foo(object): def __init__(self): pass -BASE_COMPLEX_OBJECTS = [999 * "h", 999 * u"h", l, Foo(), 10 * [10 * [10 * [1]]]] + +BASE_COMPLEX_OBJECTS = [999 * "h", 999 * u"h", l, Foo(), + 10 * [10 * [10 * [1]]]] LIST_COMPLEX_OBJECTS = [[obj] for obj in BASE_COMPLEX_OBJECTS] TUPLE_COMPLEX_OBJECTS = [(obj,) for obj in BASE_COMPLEX_OBJECTS] @@ -60,6 +67,7 @@ COMPLEX_OBJECTS = (BASE_COMPLEX_OBJECTS + TUPLE_COMPLEX_OBJECTS + DICT_COMPLEX_OBJECTS) + class TestSerialization(unittest.TestCase): def test_serialize_by_value(self): @@ -69,27 +77,31 @@ class TestSerialization(unittest.TestCase): for val in COMPLEX_OBJECTS: self.assertFalse(local_scheduler.check_simple_value(val)) + class TestObjectID(unittest.TestCase): def test_create_object_id(self): - object_id = random_object_id() + random_object_id() def test_cannot_pickle_object_ids(self): object_ids = [random_object_id() for _ in range(256)] + def f(): return object_ids + def g(val=object_ids): return 1 + def h(): - x = object_ids[0] + object_ids[0] return 1 # Make sure that object IDs cannot be pickled (including functions that # close over object IDs). - self.assertRaises(Exception, lambda : pickling.dumps(object_ids[0])) - self.assertRaises(Exception, lambda : pickling.dumps(object_ids)) - self.assertRaises(Exception, lambda : pickling.dumps(f)) - self.assertRaises(Exception, lambda : pickling.dumps(g)) - self.assertRaises(Exception, lambda : pickling.dumps(h)) + self.assertRaises(Exception, lambda: pickle.dumps(object_ids[0])) + self.assertRaises(Exception, lambda: pickle.dumps(object_ids)) + self.assertRaises(Exception, lambda: pickle.dumps(f)) + self.assertRaises(Exception, lambda: pickle.dumps(g)) + self.assertRaises(Exception, lambda: pickle.dumps(h)) def test_equality_comparisons(self): x1 = local_scheduler.ObjectID(ID_SIZE * b"a") @@ -101,8 +113,10 @@ class TestObjectID(unittest.TestCase): self.assertNotEqual(x1, y1) random_strings = [np.random.bytes(ID_SIZE) for _ in range(256)] - object_ids1 = [local_scheduler.ObjectID(random_strings[i]) for i in range(256)] - object_ids2 = [local_scheduler.ObjectID(random_strings[i]) for i in range(256)] + object_ids1 = [local_scheduler.ObjectID(random_strings[i]) + for i in range(256)] + object_ids2 = [local_scheduler.ObjectID(random_strings[i]) + for i in range(256)] self.assertEqual(len(set(object_ids1)), 256) self.assertEqual(len(set(object_ids1 + object_ids2)), 256) self.assertEqual(set(object_ids1), set(object_ids2)) @@ -113,6 +127,7 @@ class TestObjectID(unittest.TestCase): {x: y} set([x, y]) + class TestTask(unittest.TestCase): def check_task(self, task, function_id, num_return_vals, args): @@ -127,44 +142,46 @@ class TestTask(unittest.TestCase): self.assertEqual(retrieved_args[i], args[i]) def test_create_and_serialize_task(self): - # TODO(rkn): The function ID should be a FunctionID object, not an ObjectID. + # TODO(rkn): The function ID should be a FunctionID object, not an + # ObjectID. driver_id = random_driver_id() parent_id = random_task_id() function_id = random_function_id() object_ids = [random_object_id() for _ in range(256)] args_list = [ - [], - 1 * [1], - 10 * [1], - 100 * [1], - 1000 * [1], - 1 * ["a"], - 10 * ["a"], - 100 * ["a"], - 1000 * ["a"], - [1, 1.3, 2, 1 << 100, "hi", u"hi", [1, 2]], - object_ids[:1], - object_ids[:2], - object_ids[:3], - object_ids[:4], - object_ids[:5], - object_ids[:10], - object_ids[:100], - object_ids[:256], - [1, object_ids[0]], - [object_ids[0], "a"], - [1, object_ids[0], "a"], - [object_ids[0], 1, object_ids[1], "a"], - object_ids[:3] + [1, "hi", 2.3] + object_ids[:5], - object_ids + 100 * ["a"] + object_ids - ] + [], + 1 * [1], + 10 * [1], + 100 * [1], + 1000 * [1], + 1 * ["a"], + 10 * ["a"], + 100 * ["a"], + 1000 * ["a"], + [1, 1.3, 2, 1 << 100, "hi", u"hi", [1, 2]], + object_ids[:1], + object_ids[:2], + object_ids[:3], + object_ids[:4], + object_ids[:5], + object_ids[:10], + object_ids[:100], + object_ids[:256], + [1, object_ids[0]], + [object_ids[0], "a"], + [1, object_ids[0], "a"], + [object_ids[0], 1, object_ids[1], "a"], + object_ids[:3] + [1, "hi", 2.3] + object_ids[:5], + object_ids + 100 * ["a"] + object_ids] for args in args_list: for num_return_vals in [0, 1, 2, 3, 5, 10, 100]: - task = local_scheduler.Task(driver_id, function_id, args, num_return_vals, parent_id, 0) + task = local_scheduler.Task(driver_id, function_id, args, + num_return_vals, parent_id, 0) self.check_task(task, function_id, num_return_vals, args) data = local_scheduler.task_to_string(task) task2 = local_scheduler.task_from_string(data) self.check_task(task2, function_id, num_return_vals, args) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/python/ray/experimental/__init__.py b/python/ray/experimental/__init__.py index 66fbb4886..94ab6498d 100644 --- a/python/ray/experimental/__init__.py +++ b/python/ray/experimental/__init__.py @@ -4,3 +4,5 @@ from __future__ import print_function from .utils import copy_directory from .tfutils import TensorFlowVariables + +__all__ = ["copy_directory", "TensorFlowVariables"] diff --git a/python/ray/experimental/array/distributed/__init__.py b/python/ray/experimental/array/distributed/__init__.py index 2c0ddfffc..61e20f7b7 100644 --- a/python/ray/experimental/array/distributed/__init__.py +++ b/python/ray/experimental/array/distributed/__init__.py @@ -4,4 +4,10 @@ from __future__ import print_function from . import random from . import linalg -from .core import * +from .core import (BLOCK_SIZE, DistArray, assemble, zeros, ones, copy, eye, + triu, tril, blockwise_dot, dot, transpose, add, subtract, + numpy_to_dist, subblocks) + +__all__ = ["random", "linalg", "BLOCK_SIZE", "DistArray", "assemble", "zeros", + "ones", "copy", "eye", "triu", "tril", "blockwise_dot", "dot", + "transpose", "add", "subtract", "numpy_to_dist", "subblocks"] diff --git a/python/ray/experimental/array/distributed/core.py b/python/ray/experimental/array/distributed/core.py index d885218c4..7d10345f5 100644 --- a/python/ray/experimental/array/distributed/core.py +++ b/python/ray/experimental/array/distributed/core.py @@ -6,30 +6,38 @@ import numpy as np import ray.experimental.array.remote as ra import ray -__all__ = ["BLOCK_SIZE", "DistArray", "assemble", "zeros", "ones", "copy", - "eye", "triu", "tril", "blockwise_dot", "dot", "transpose", "add", "subtract", "numpy_to_dist", "subblocks"] - BLOCK_SIZE = 10 + class DistArray(object): def __init__(self, shape, objectids=None): self.shape = shape self.ndim = len(shape) self.num_blocks = [int(np.ceil(1.0 * a / BLOCK_SIZE)) for a in self.shape] - self.objectids = objectids if objectids is not None else np.empty(self.num_blocks, dtype=object) + if objectids is not None: + self.objectids = objectids + else: + self.objectids = np.empty(self.num_blocks, dtype=object) if self.num_blocks != list(self.objectids.shape): - raise Exception("The fields `num_blocks` and `objectids` are inconsistent, `num_blocks` is {} and `objectids` has shape {}".format(self.num_blocks, list(self.objectids.shape))) + raise Exception("The fields `num_blocks` and `objectids` are " + "inconsistent, `num_blocks` is {} and `objectids` has " + "shape {}".format(self.num_blocks, + list(self.objectids.shape))) @staticmethod def compute_block_lower(index, shape): if len(index) != len(shape): - raise Exception("The fields `index` and `shape` must have the same length, but `index` is {} and `shape` is {}.".format(index, shape)) + raise Exception("The fields `index` and `shape` must have the same " + "length, but `index` is {} and `shape` is " + "{}.".format(index, shape)) return [elem * BLOCK_SIZE for elem in index] @staticmethod def compute_block_upper(index, shape): if len(index) != len(shape): - raise Exception("The fields `index` and `shape` must have the same length, but `index` is {} and `shape` is {}.".format(index, shape)) + raise Exception("The fields `index` and `shape` must have the same " + "length, but `index` is {} and `shape` is " + "{}.".format(index, shape)) upper = [] for i in range(len(shape)): upper.append(min((index[i] + 1) * BLOCK_SIZE, shape[i])) @@ -46,59 +54,73 @@ class DistArray(object): return [int(np.ceil(1.0 * a / BLOCK_SIZE)) for a in shape] def assemble(self): - """Assemble an array on this node from a distributed array of object IDs.""" + """Assemble an array from a distributed array of object IDs.""" first_block = ray.get(self.objectids[(0,) * self.ndim]) dtype = first_block.dtype result = np.zeros(self.shape, dtype=dtype) for index in np.ndindex(*self.num_blocks): lower = DistArray.compute_block_lower(index, self.shape) upper = DistArray.compute_block_upper(index, self.shape) - result[[slice(l, u) for (l, u) in zip(lower, upper)]] = ray.get(self.objectids[index]) + result[[slice(l, u) for (l, u) in zip(lower, upper)]] = ray.get( + self.objectids[index]) return result def __getitem__(self, sliced): - # TODO(rkn): fix this, this is just a placeholder that should work but is inefficient + # TODO(rkn): Fix this, this is just a placeholder that should work but is + # inefficient. a = self.assemble() return a[sliced] + # Register the DistArray class with Ray so that it knows how to serialize it. ray.register_class(DistArray) + @ray.remote def assemble(a): return a.assemble() -# TODO(rkn): what should we call this method + +# TODO(rkn): What should we call this method? @ray.remote def numpy_to_dist(a): result = DistArray(a.shape) for index in np.ndindex(*result.num_blocks): lower = DistArray.compute_block_lower(index, a.shape) upper = DistArray.compute_block_upper(index, a.shape) - result.objectids[index] = ray.put(a[[slice(l, u) for (l, u) in zip(lower, upper)]]) + result.objectids[index] = ray.put(a[[slice(l, u) for (l, u) + in zip(lower, upper)]]) return result + @ray.remote def zeros(shape, dtype_name="float"): result = DistArray(shape) for index in np.ndindex(*result.num_blocks): - result.objectids[index] = ra.zeros.remote(DistArray.compute_block_shape(index, shape), dtype_name=dtype_name) + result.objectids[index] = ra.zeros.remote( + DistArray.compute_block_shape(index, shape), dtype_name=dtype_name) return result + @ray.remote def ones(shape, dtype_name="float"): result = DistArray(shape) for index in np.ndindex(*result.num_blocks): - result.objectids[index] = ra.ones.remote(DistArray.compute_block_shape(index, shape), dtype_name=dtype_name) + result.objectids[index] = ra.ones.remote( + DistArray.compute_block_shape(index, shape), dtype_name=dtype_name) return result + @ray.remote def copy(a): result = DistArray(a.shape) for index in np.ndindex(*result.num_blocks): - result.objectids[index] = a.objectids[index] # We don't need to actually copy the objects because cluster-level objects are assumed to be immutable. + # We don't need to actually copy the objects because remote objects are + # immutable. + result.objectids[index] = a.objectids[index] return result + @ray.remote def eye(dim1, dim2=-1, dtype_name="float"): dim2 = dim1 if dim2 == -1 else dim2 @@ -107,15 +129,19 @@ def eye(dim1, dim2=-1, dtype_name="float"): for (i, j) in np.ndindex(*result.num_blocks): block_shape = DistArray.compute_block_shape([i, j], shape) if i == j: - result.objectids[i, j] = ra.eye.remote(block_shape[0], block_shape[1], dtype_name=dtype_name) + result.objectids[i, j] = ra.eye.remote(block_shape[0], block_shape[1], + dtype_name=dtype_name) else: - result.objectids[i, j] = ra.zeros.remote(block_shape, dtype_name=dtype_name) + result.objectids[i, j] = ra.zeros.remote(block_shape, + dtype_name=dtype_name) return result + @ray.remote def triu(a): if a.ndim != 2: - raise Exception("Input must have 2 dimensions, but a.ndim is " + str(a.ndim)) + raise Exception("Input must have 2 dimensions, but a.ndim is " + "{}.".format(a.ndim)) result = DistArray(a.shape) for (i, j) in np.ndindex(*result.num_blocks): if i < j: @@ -126,10 +152,12 @@ def triu(a): result.objectids[i, j] = ra.zeros_like.remote(a.objectids[i, j]) return result + @ray.remote def tril(a): if a.ndim != 2: - raise Exception("Input must have 2 dimensions, but a.ndim is " + str(a.ndim)) + raise Exception("Input must have 2 dimensions, but a.ndim is " + "{}.".format(a.ndim)) result = DistArray(a.shape) for (i, j) in np.ndindex(*result.num_blocks): if i > j: @@ -140,25 +168,31 @@ def tril(a): result.objectids[i, j] = ra.zeros_like.remote(a.objectids[i, j]) return result + @ray.remote def blockwise_dot(*matrices): n = len(matrices) if n % 2 != 0: - raise Exception("blockwise_dot expects an even number of arguments, but len(matrices) is {}.".format(n)) + raise Exception("blockwise_dot expects an even number of arguments, but " + "len(matrices) is {}.".format(n)) shape = (matrices[0].shape[0], matrices[n // 2].shape[1]) result = np.zeros(shape) for i in range(n // 2): result += np.dot(matrices[i], matrices[n // 2 + i]) return result + @ray.remote def dot(a, b): if a.ndim != 2: - raise Exception("dot expects its arguments to be 2-dimensional, but a.ndim = {}.".format(a.ndim)) + raise Exception("dot expects its arguments to be 2-dimensional, but " + "a.ndim = {}.".format(a.ndim)) if b.ndim != 2: - raise Exception("dot expects its arguments to be 2-dimensional, but b.ndim = {}.".format(b.ndim)) + raise Exception("dot expects its arguments to be 2-dimensional, but " + "b.ndim = {}.".format(b.ndim)) if a.shape[1] != b.shape[0]: - raise Exception("dot expects a.shape[1] to equal b.shape[0], but a.shape = {} and b.shape = {}.".format(a.shape, b.shape)) + raise Exception("dot expects a.shape[1] to equal b.shape[0], but a.shape " + "= {} and b.shape = {}.".format(a.shape, b.shape)) shape = [a.shape[0], b.shape[1]] result = DistArray(shape) for (i, j) in np.ndindex(*result.num_blocks): @@ -166,10 +200,12 @@ def dot(a, b): result.objectids[i, j] = blockwise_dot.remote(*args) return result + @ray.remote def subblocks(a, *ranges): """ - This function produces a distributed array from a subset of the blocks in the `a`. The result and `a` will have the same number of dimensions.For example, + This function produces a distributed array from a subset of the blocks in the + `a`. The result and `a` will have the same number of dimensions.For example, subblocks(a, [0, 1], [2, 4]) will produce a DistArray whose objectids are [[a.objectids[0, 2], a.objectids[0, 4]], @@ -178,50 +214,71 @@ def subblocks(a, *ranges): """ ranges = list(ranges) if len(ranges) != a.ndim: - raise Exception("sub_blocks expects to receive a number of ranges equal to a.ndim, but it received {} ranges and a.ndim = {}.".format(len(ranges), a.ndim)) + raise Exception("sub_blocks expects to receive a number of ranges equal " + "to a.ndim, but it received {} ranges and a.ndim = " + "{}.".format(len(ranges), a.ndim)) for i in range(len(ranges)): - if ranges[i] == []: # We allow the user to pass in an empty list to indicate the full range + # We allow the user to pass in an empty list to indicate the full range. + if ranges[i] == []: ranges[i] = range(a.num_blocks[i]) if not np.alltrue(ranges[i] == np.sort(ranges[i])): - raise Exception("Ranges passed to sub_blocks must be sorted, but the {}th range is {}.".format(i, ranges[i])) + raise Exception("Ranges passed to sub_blocks must be sorted, but the " + "{}th range is {}.".format(i, ranges[i])) if ranges[i][0] < 0: - raise Exception("Values in the ranges passed to sub_blocks must be at least 0, but the {}th range is {}.".format(i, ranges[i])) + raise Exception("Values in the ranges passed to sub_blocks must be at " + "least 0, but the {}th range is {}.".format(i, + ranges[i])) if ranges[i][-1] >= a.num_blocks[i]: - raise Exception("Values in the ranges passed to sub_blocks must be less than the relevant number of blocks, but the {}th range is {}, and a.num_blocks = {}.".format(i, ranges[i], a.num_blocks)) + raise Exception("Values in the ranges passed to sub_blocks must be less " + "than the relevant number of blocks, but the {}th range " + "is {}, and a.num_blocks = {}.".format(i, ranges[i], + a.num_blocks)) last_index = [r[-1] for r in ranges] last_block_shape = DistArray.compute_block_shape(last_index, a.shape) - shape = [(len(ranges[i]) - 1) * BLOCK_SIZE + last_block_shape[i] for i in range(a.ndim)] + shape = [(len(ranges[i]) - 1) * BLOCK_SIZE + last_block_shape[i] + for i in range(a.ndim)] result = DistArray(shape) for index in np.ndindex(*result.num_blocks): - result.objectids[index] = a.objectids[tuple([ranges[i][index[i]] for i in range(a.ndim)])] + result.objectids[index] = a.objectids[tuple([ranges[i][index[i]] + for i in range(a.ndim)])] return result + @ray.remote def transpose(a): if a.ndim != 2: - raise Exception("transpose expects its argument to be 2-dimensional, but a.ndim = {}, a.shape = {}.".format(a.ndim, a.shape)) + raise Exception("transpose expects its argument to be 2-dimensional, but " + "a.ndim = {}, a.shape = {}.".format(a.ndim, a.shape)) result = DistArray([a.shape[1], a.shape[0]]) for i in range(result.num_blocks[0]): for j in range(result.num_blocks[1]): result.objectids[i, j] = ra.transpose.remote(a.objectids[j, i]) return result + # TODO(rkn): support broadcasting? @ray.remote def add(x1, x2): if x1.shape != x2.shape: - raise Exception("add expects arguments `x1` and `x2` to have the same shape, but x1.shape = {}, and x2.shape = {}.".format(x1.shape, x2.shape)) + raise Exception("add expects arguments `x1` and `x2` to have the same " + "shape, but x1.shape = {}, and x2.shape = {}." + .format(x1.shape, x2.shape)) result = DistArray(x1.shape) for index in np.ndindex(*result.num_blocks): - result.objectids[index] = ra.add.remote(x1.objectids[index], x2.objectids[index]) + result.objectids[index] = ra.add.remote(x1.objectids[index], + x2.objectids[index]) return result + # TODO(rkn): support broadcasting? @ray.remote def subtract(x1, x2): if x1.shape != x2.shape: - raise Exception("subtract expects arguments `x1` and `x2` to have the same shape, but x1.shape = {}, and x2.shape = {}.".format(x1.shape, x2.shape)) + raise Exception("subtract expects arguments `x1` and `x2` to have the " + "same shape, but x1.shape = {}, and x2.shape = {}." + .format(x1.shape, x2.shape)) result = DistArray(x1.shape) for index in np.ndindex(*result.num_blocks): - result.objectids[index] = ra.subtract.remote(x1.objectids[index], x2.objectids[index]) + result.objectids[index] = ra.subtract.remote(x1.objectids[index], + x2.objectids[index]) return result diff --git a/python/ray/experimental/array/distributed/linalg.py b/python/ray/experimental/array/distributed/linalg.py index 45341dddd..86a66fd95 100644 --- a/python/ray/experimental/array/distributed/linalg.py +++ b/python/ray/experimental/array/distributed/linalg.py @@ -6,30 +6,32 @@ import numpy as np import ray.experimental.array.remote as ra import ray -from .core import * +from . import core __all__ = ["tsqr", "modified_lu", "tsqr_hr", "qr"] + @ray.remote(num_return_vals=2) def tsqr(a): - """ - arguments: - a: a distributed matrix - Suppose that - a.shape == (M, N) - K == min(M, N) - return values: - q: DistArray, if q_full = ray.get(DistArray, q).assemble(), then - q_full.shape == (M, K) - np.allclose(np.dot(q_full.T, q_full), np.eye(K)) == True - r: np.ndarray, if r_val = ray.get(np.ndarray, r), then - r_val.shape == (K, N) - np.allclose(r, np.triu(r)) == True + """Perform a QR decomposition of a tall-skinny matrix. + + Args: + a: A distributed matrix with shape MxN (suppose K = min(M, N)). + + Returns: + A tuple of q (a DistArray) and r (a numpy array) satisfying the following. + - If q_full = ray.get(DistArray, q).assemble(), then + q_full.shape == (M, K). + - np.allclose(np.dot(q_full.T, q_full), np.eye(K)) == True. + - If r_val = ray.get(np.ndarray, r), then r_val.shape == (K, N). + - np.allclose(r, np.triu(r)) == True. """ if len(a.shape) != 2: - raise Exception("tsqr requires len(a.shape) == 2, but a.shape is {}".format(a.shape)) + raise Exception("tsqr requires len(a.shape) == 2, but a.shape is " + "{}".format(a.shape)) if a.num_blocks[1] != 1: - raise Exception("tsqr requires a.num_blocks[1] == 1, but a.num_blocks is {}".format(a.num_blocks)) + raise Exception("tsqr requires a.num_blocks[1] == 1, but a.num_blocks is " + "{}".format(a.num_blocks)) num_blocks = a.num_blocks[0] K = int(np.ceil(np.log2(num_blocks))) + 1 @@ -57,9 +59,9 @@ def tsqr(a): q_shape = a.shape else: q_shape = [a.shape[0], a.shape[0]] - q_num_blocks = DistArray.compute_num_blocks(q_shape) + q_num_blocks = core.DistArray.compute_num_blocks(q_shape) q_objectids = np.empty(q_num_blocks, dtype=object) - q_result = DistArray(q_shape, q_objectids) + q_result = core.DistArray(q_shape, q_objectids) # reconstruct output for i in range(num_blocks): @@ -68,28 +70,35 @@ def tsqr(a): for j in range(1, K): if np.mod(ith_index, 2) == 0: lower = [0, 0] - upper = [a.shape[1], BLOCK_SIZE] + upper = [a.shape[1], core.BLOCK_SIZE] else: lower = [a.shape[1], 0] - upper = [2 * a.shape[1], BLOCK_SIZE] + upper = [2 * a.shape[1], core.BLOCK_SIZE] ith_index //= 2 - q_block_current = ra.dot.remote(q_block_current, ra.subarray.remote(q_tree[ith_index, j], lower, upper)) + q_block_current = ra.dot.remote(q_block_current, + ra.subarray.remote(q_tree[ith_index, j], + lower, upper)) q_result.objectids[i] = q_block_current r = current_rs[0] return q_result, ray.get(r) + # TODO(rkn): This is unoptimized, we really want a block version of this. +# This is Algorithm 5 from +# http://www.eecs.berkeley.edu/Pubs/TechRpts/2013/EECS-2013-175.pdf. @ray.remote(num_return_vals=3) def modified_lu(q): - """ - Algorithm 5 from http://www.eecs.berkeley.edu/Pubs/TechRpts/2013/EECS-2013-175.pdf - takes a matrix q with orthonormal columns, returns l, u, s such that q - s = l * u - arguments: - q: a two dimensional orthonormal q - return values: - l: lower triangular - u: upper triangular - s: a diagonal matrix represented by its diagonal + """Perform a modified LU decomposition of a matrix. + + This takes a matrix q with orthonormal columns, returns l, u, s such that + q - s = l * u. + + Args: + q: A two dimensional orthonormal matrix q. + + Returns: + A tuple of a lower triangular matrix l, an upper triangular matrix u, and a + a vector representing a diagonal matrix s such that q - s = l * u. """ q = q.assemble() m, b = q.shape[0], q.shape[1] @@ -100,14 +109,19 @@ def modified_lu(q): for i in range(b): S[i] = -1 * np.sign(q_work[i, i]) q_work[i, i] -= S[i] - q_work[(i + 1):m, i] /= q_work[i, i] # scale ith column of L by diagonal element - q_work[(i + 1):m, (i + 1):b] -= np.outer(q_work[(i + 1):m, i], q_work[i, (i + 1):b]) # perform Schur complement update + # Scale ith column of L by diagonal element. + q_work[(i + 1):m, i] /= q_work[i, i] + # Perform Schur complement update. + q_work[(i + 1):m, (i + 1):b] -= np.outer(q_work[(i + 1):m, i], + q_work[i, (i + 1):b]) L = np.tril(q_work) for i in range(b): L[i, i] = 1 U = np.triu(q_work)[:b, :] - return ray.get(numpy_to_dist.remote(ray.put(L))), U, S # TODO(rkn): get rid of put + # TODO(rkn): Get rid of the put below. + return ray.get(core.numpy_to_dist.remote(ray.put(L))), U, S + @ray.remote(num_return_vals=2) def tsqr_hr_helper1(u, s, y_top_block, b): @@ -116,45 +130,60 @@ def tsqr_hr_helper1(u, s, y_top_block, b): t = -1 * np.dot(u, np.dot(s_full, np.linalg.inv(y_top).T)) return t, y_top + @ray.remote def tsqr_hr_helper2(s, r_temp): s_full = np.diag(s) return np.dot(s_full, r_temp) + +# This is Algorithm 6 from +# http://www.eecs.berkeley.edu/Pubs/TechRpts/2013/EECS-2013-175.pdf. @ray.remote(num_return_vals=4) def tsqr_hr(a): - """Algorithm 6 from http://www.eecs.berkeley.edu/Pubs/TechRpts/2013/EECS-2013-175.pdf""" q, r_temp = tsqr.remote(a) y, u, s = modified_lu.remote(q) y_blocked = ray.get(y) - t, y_top = tsqr_hr_helper1.remote(u, s, y_blocked.objectids[0, 0], a.shape[1]) + t, y_top = tsqr_hr_helper1.remote(u, s, y_blocked.objectids[0, 0], + a.shape[1]) r = tsqr_hr_helper2.remote(s, r_temp) return ray.get(y), ray.get(t), ray.get(y_top), ray.get(r) + @ray.remote def qr_helper1(a_rc, y_ri, t, W_c): return a_rc - np.dot(y_ri, np.dot(t.T, W_c)) + @ray.remote def qr_helper2(y_ri, a_rc): return np.dot(y_ri.T, a_rc) + +# This is Algorithm 7 from +# http://www.eecs.berkeley.edu/Pubs/TechRpts/2013/EECS-2013-175.pdf. @ray.remote(num_return_vals=2) def qr(a): - """Algorithm 7 from http://www.eecs.berkeley.edu/Pubs/TechRpts/2013/EECS-2013-175.pdf""" + m, n = a.shape[0], a.shape[1] k = min(m, n) # we will store our scratch work in a_work - a_work = DistArray(a.shape, np.copy(a.objectids)) + a_work = core.DistArray(a.shape, np.copy(a.objectids)) result_dtype = np.linalg.qr(ray.get(a.objectids[0, 0]))[0].dtype.name - r_res = ray.get(zeros.remote([k, n], result_dtype)) # TODO(rkn): It would be preferable not to get this right after creating it. - y_res = ray.get(zeros.remote([m, k], result_dtype)) # TODO(rkn): It would be preferable not to get this right after creating it. + # TODO(rkn): It would be preferable not to get this right after creating it. + r_res = ray.get(core.zeros.remote([k, n], result_dtype)) + # TODO(rkn): It would be preferable not to get this right after creating it. + y_res = ray.get(core.zeros.remote([m, k], result_dtype)) Ts = [] - for i in range(min(a.num_blocks[0], a.num_blocks[1])): # this differs from the paper, which says "for i in range(a.num_blocks[1])", but that doesn't seem to make any sense when a.num_blocks[1] > a.num_blocks[0] - sub_dist_array = subblocks.remote(a_work, list(range(i, a_work.num_blocks[0])), [i]) + # The for loop differs from the paper, which says + # "for i in range(a.num_blocks[1])", but that doesn't seem to make any sense + # when a.num_blocks[1] > a.num_blocks[0]. + for i in range(min(a.num_blocks[0], a.num_blocks[1])): + sub_dist_array = core.subblocks.remote( + a_work, list(range(i, a_work.num_blocks[0])), [i]) y, t, _, R = tsqr_hr.remote(sub_dist_array) y_val = ray.get(y) @@ -167,7 +196,7 @@ def qr(a): r_res.objectids[i, i] = ra.dot.remote(eye_temp, R) else: r_res.objectids[i, i] = R - Ts.append(numpy_to_dist.remote(t)) + Ts.append(core.numpy_to_dist.remote(t)) for c in range(i + 1, a.num_blocks[1]): W_rcs = [] @@ -182,9 +211,14 @@ def qr(a): r_res.objectids[i, c] = a_work.objectids[i, c] # construct q_res from Ys and Ts - q = eye.remote(m, k, dtype_name=result_dtype) + q = core.eye.remote(m, k, dtype_name=result_dtype) for i in range(len(Ts))[::-1]: - y_col_block = subblocks.remote(y_res, [], [i]) - q = subtract.remote(q, dot.remote(y_col_block, dot.remote(Ts[i], dot.remote(transpose.remote(y_col_block), q)))) + y_col_block = core.subblocks.remote(y_res, [], [i]) + q = core.subtract.remote( + q, core.dot.remote( + y_col_block, + core.dot.remote(Ts[i], + core.dot.remote(core.transpose.remote(y_col_block), + q)))) return ray.get(q), r_res diff --git a/python/ray/experimental/array/distributed/random.py b/python/ray/experimental/array/distributed/random.py index 61aabe21d..e3e25ff8c 100644 --- a/python/ray/experimental/array/distributed/random.py +++ b/python/ray/experimental/array/distributed/random.py @@ -6,13 +6,15 @@ import numpy as np import ray.experimental.array.remote as ra import ray -from .core import * +from .core import DistArray + @ray.remote def normal(shape): num_blocks = DistArray.compute_num_blocks(shape) objectids = np.empty(num_blocks, dtype=object) for index in np.ndindex(*num_blocks): - objectids[index] = ra.random.normal.remote(DistArray.compute_block_shape(index, shape)) + objectids[index] = ra.random.normal.remote( + DistArray.compute_block_shape(index, shape)) result = DistArray(shape, objectids) return result diff --git a/python/ray/experimental/array/remote/__init__.py b/python/ray/experimental/array/remote/__init__.py index 2c0ddfffc..c2fd32a5a 100644 --- a/python/ray/experimental/array/remote/__init__.py +++ b/python/ray/experimental/array/remote/__init__.py @@ -4,4 +4,10 @@ from __future__ import print_function from . import random from . import linalg -from .core import * +from .core import (zeros, zeros_like, ones, eye, dot, vstack, hstack, subarray, + copy, tril, triu, diag, transpose, add, subtract, sum, + shape, sum_list) + +__all__ = ["random", "linalg", "zeros", "zeros_like", "ones", "eye", "dot", + "vstack", "hstack", "subarray", "copy", "tril", "triu", "diag", + "transpose", "add", "subtract", "sum", "shape", "sum_list"] diff --git a/python/ray/experimental/array/remote/core.py b/python/ray/experimental/array/remote/core.py index 7d8e69e52..e7f142f1b 100644 --- a/python/ray/experimental/array/remote/core.py +++ b/python/ray/experimental/array/remote/core.py @@ -5,82 +5,97 @@ from __future__ import print_function import numpy as np import ray -__all__ = ["zeros", "zeros_like", "ones", "eye", "dot", "vstack", "hstack", "subarray", "copy", "tril", "triu", "diag", "transpose", "add", "subtract", "sum", "shape", "sum_list"] @ray.remote def zeros(shape, dtype_name="float", order="C"): return np.zeros(shape, dtype=np.dtype(dtype_name), order=order) + @ray.remote def zeros_like(a, dtype_name="None", order="K", subok=True): dtype_val = None if dtype_name == "None" else np.dtype(dtype_name) return np.zeros_like(a, dtype=dtype_val, order=order, subok=subok) + @ray.remote def ones(shape, dtype_name="float", order="C"): return np.ones(shape, dtype=np.dtype(dtype_name), order=order) + @ray.remote def eye(N, M=-1, k=0, dtype_name="float"): M = N if M == -1 else M return np.eye(N, M=M, k=k, dtype=np.dtype(dtype_name)) + @ray.remote def dot(a, b): return np.dot(a, b) + @ray.remote def vstack(*xs): return np.vstack(xs) + @ray.remote def hstack(*xs): return np.hstack(xs) -# TODO(rkn): instead of this, consider implementing slicing + +# TODO(rkn): Instead of this, consider implementing slicing. +# TODO(rkn): Be consistent about using "index" versus "indices". @ray.remote -def subarray(a, lower_indices, upper_indices): # TODO(rkn): be consistent about using "index" versus "indices" +def subarray(a, lower_indices, upper_indices): return a[[slice(l, u) for (l, u) in zip(lower_indices, upper_indices)]] + @ray.remote def copy(a, order="K"): return np.copy(a, order=order) + @ray.remote def tril(m, k=0): return np.tril(m, k=k) + @ray.remote def triu(m, k=0): return np.triu(m, k=k) + @ray.remote def diag(v, k=0): return np.diag(v, k=k) + @ray.remote def transpose(a, axes=[]): axes = None if axes == [] else axes return np.transpose(a, axes=axes) + @ray.remote def add(x1, x2): return np.add(x1, x2) + @ray.remote def subtract(x1, x2): return np.subtract(x1, x2) + @ray.remote def sum(x, axis=-1): return np.sum(x, axis=axis if axis != -1 else None) + @ray.remote def shape(a): return np.shape(a) -# We use Any to allow different numerical types as well as numpy arrays. -# TODO(rkn):this isn't in the numpy API, so be careful about exposing this. + @ray.remote def sum_list(*xs): return np.sum(xs, axis=0) diff --git a/python/ray/experimental/array/remote/linalg.py b/python/ray/experimental/array/remote/linalg.py index 6ddb5fa89..b1436648c 100644 --- a/python/ray/experimental/array/remote/linalg.py +++ b/python/ray/experimental/array/remote/linalg.py @@ -8,84 +8,104 @@ import ray __all__ = ["matrix_power", "solve", "tensorsolve", "tensorinv", "inv", "cholesky", "eigvals", "eigvalsh", "pinv", "slogdet", "det", "svd", "eig", "eigh", "lstsq", "norm", "qr", "cond", "matrix_rank", - "LinAlgError", "multi_dot"] + "multi_dot"] + @ray.remote def matrix_power(M, n): return np.linalg.matrix_power(M, n) + @ray.remote def solve(a, b): return np.linalg.solve(a, b) + @ray.remote(num_return_vals=2) def tensorsolve(a): raise NotImplementedError + @ray.remote(num_return_vals=2) def tensorinv(a): raise NotImplementedError + @ray.remote def inv(a): return np.linalg.inv(a) + @ray.remote def cholesky(a): return np.linalg.cholesky(a) + @ray.remote def eigvals(a): return np.linalg.eigvals(a) + @ray.remote def eigvalsh(a): raise NotImplementedError + @ray.remote def pinv(a): return np.linalg.pinv(a) + @ray.remote def slogdet(a): raise NotImplementedError + @ray.remote def det(a): return np.linalg.det(a) + @ray.remote(num_return_vals=3) def svd(a): return np.linalg.svd(a) + @ray.remote(num_return_vals=2) def eig(a): return np.linalg.eig(a) + @ray.remote(num_return_vals=2) def eigh(a): return np.linalg.eigh(a) + @ray.remote(num_return_vals=4) def lstsq(a, b): return np.linalg.lstsq(a) + @ray.remote def norm(x): return np.linalg.norm(x) + @ray.remote(num_return_vals=2) def qr(a): return np.linalg.qr(a) + @ray.remote def cond(x): return np.linalg.cond(x) + @ray.remote def matrix_rank(M): return np.linalg.matrix_rank(M) + @ray.remote def multi_dot(*a): raise NotImplementedError diff --git a/python/ray/experimental/array/remote/random.py b/python/ray/experimental/array/remote/random.py index 635b9c136..bea781f2b 100644 --- a/python/ray/experimental/array/remote/random.py +++ b/python/ray/experimental/array/remote/random.py @@ -5,6 +5,7 @@ from __future__ import print_function import numpy as np import ray + @ray.remote def normal(shape): return np.random.normal(size=shape) diff --git a/python/ray/experimental/state.py b/python/ray/experimental/state.py index 147fe39d8..6e29e838a 100644 --- a/python/ray/experimental/state.py +++ b/python/ray/experimental/state.py @@ -2,6 +2,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function + def get_local_schedulers(worker): local_schedulers = [] for client in worker.redis_client.keys("CL:*"): diff --git a/python/ray/experimental/tfutils.py b/python/ray/experimental/tfutils.py index 7ce146c7c..8ebcd60ed 100644 --- a/python/ray/experimental/tfutils.py +++ b/python/ray/experimental/tfutils.py @@ -4,6 +4,7 @@ from __future__ import print_function import numpy as np from collections import deque, OrderedDict + def unflatten(vector, shapes): i = 0 arrays = [] @@ -15,6 +16,7 @@ def unflatten(vector, shapes): assert len(vector) == i, "Passed weight does not have the correct shape." return arrays + class TensorFlowVariables(object): """An object used to extract variables from a loss function. @@ -38,11 +40,11 @@ class TensorFlowVariables(object): variable_names = [] explored_inputs = set([loss]) - # We do a BFS on the dependency graph of the input function to find + # We do a BFS on the dependency graph of the input function to find # the variables. while len(queue) != 0: tf_obj = queue.popleft() - + # The object put into the queue is not necessarily an operation, so we # want the op attribute to get the operation underlying the object. # Only operations contain the inputs that we can explore. @@ -61,14 +63,16 @@ class TensorFlowVariables(object): if "Variable" in tf_obj.node_def.op: variable_names.append(tf_obj.node_def.name) self.variables = OrderedDict() - for v in [v for v in tf.global_variables() if v.op.node_def.name in variable_names]: + for v in [v for v in tf.global_variables() + if v.op.node_def.name in variable_names]: self.variables[v.op.node_def.name] = v self.placeholders = dict() self.assignment_nodes = [] # Create new placeholders to put in custom weights. for k, var in self.variables.items(): - self.placeholders[k] = tf.placeholder(var.value().dtype, var.get_shape().as_list()) + self.placeholders[k] = tf.placeholder(var.value().dtype, + var.get_shape().as_list()) self.assignment_nodes.append(var.assign(self.placeholders[k])) def set_session(self, sess): @@ -76,24 +80,30 @@ class TensorFlowVariables(object): self.sess = sess def get_flat_size(self): - return sum([np.prod(v.get_shape().as_list()) for v in self.variables.values()]) + return sum([np.prod(v.get_shape().as_list()) + for v in self.variables.values()]) def _check_sess(self): """Checks if the session is set, and if not throw an error message.""" - assert self.sess is not None, "The session is not set. Set the session either by passing it into the TensorFlowVariables constructor or by calling set_session(sess)." - + assert self.sess is not None, ("The session is not set. Set the session " + "either by passing it into the " + "TensorFlowVariables constructor or by " + "calling set_session(sess).") + def get_flat(self): """Gets the weights and returns them as a flat array.""" self._check_sess() - return np.concatenate([v.eval(session=self.sess).flatten() for v in self.variables.values()]) - + return np.concatenate([v.eval(session=self.sess).flatten() + for v in self.variables.values()]) + def set_flat(self, new_weights): """Sets the weights to new_weights, converting from a flat array.""" self._check_sess() shapes = [v.get_shape().as_list() for v in self.variables.values()] arrays = unflatten(new_weights, shapes) placeholders = [self.placeholders[k] for k, v in self.variables.items()] - self.sess.run(self.assignment_nodes, feed_dict=dict(zip(placeholders,arrays))) + self.sess.run(self.assignment_nodes, + feed_dict=dict(zip(placeholders, arrays))) def get_weights(self): """Returns the weights of the variables of the loss function in a list.""" @@ -103,4 +113,7 @@ class TensorFlowVariables(object): def set_weights(self, new_weights): """Sets the weights to new_weights.""" self._check_sess() - self.sess.run(self.assignment_nodes, feed_dict={self.placeholders[name]: value for (name, value) in new_weights.items() if name in self.placeholders}) + self.sess.run(self.assignment_nodes, + feed_dict={self.placeholders[name]: value + for (name, value) in new_weights.items() + if name in self.placeholders}) diff --git a/python/ray/experimental/utils.py b/python/ray/experimental/utils.py index 7c8083214..166ae0904 100644 --- a/python/ray/experimental/utils.py +++ b/python/ray/experimental/utils.py @@ -9,6 +9,7 @@ import sys import ray + def tarred_directory_as_bytes(source_dir): """Tar a directory and return it as a byte string. @@ -26,6 +27,7 @@ def tarred_directory_as_bytes(source_dir): string_file.seek(0) return string_file.read() + def tarred_bytes_to_directory(tarred_bytes, target_dir): """Take a byte string and untar it. @@ -38,6 +40,7 @@ def tarred_bytes_to_directory(tarred_bytes, target_dir): with tarfile.open(fileobj=string_file) as tar: tar.extractall(path=target_dir) + def copy_directory(source_dir, target_dir=None): """Copy a local directory to each machine in the Ray cluster. @@ -45,17 +48,18 @@ def copy_directory(source_dir, target_dir=None): example, source_dir can be /a/b/c and target_dir can be /d/e/c. In this case, the directory /d/e will be added to the Python path of each worker. - Note that this method is not completely safe to use. For example, workers that - do not do the copying and only set their paths (only one worker per node does - the copying) may try to execute functions that use the files in the directory - being copied before the directory being copied has finished untarring. + Note that this method is not completely safe to use. For example, workers + that do not do the copying and only set their paths (only one worker per node + does the copying) may try to execute functions that use the files in the + directory being copied before the directory being copied has finished + untarring. Args: source_dir (str): The directory to copy. target_dir (str): The location to copy it to on the other machines. If this is not provided, the source_dir will be used. If it is provided and is - different from source_dir, the source_dir also be copied to the target_dir - location on this machine. + different from source_dir, the source_dir also be copied to the + target_dir location on this machine. """ target_dir = source_dir if target_dir is None else target_dir source_dir = os.path.abspath(source_dir) @@ -63,8 +67,10 @@ def copy_directory(source_dir, target_dir=None): source_basename = os.path.basename(source_dir) target_basename = os.path.basename(target_dir) if source_basename != target_basename: - raise Exception("The source_dir and target_dir must have the same base name, {} != {}".format(source_basename, target_basename)) + raise Exception("The source_dir and target_dir must have the same base " + "name, {} != {}".format(source_basename, target_basename)) tarred_bytes = tarred_directory_as_bytes(source_dir) + def f(worker_info): if worker_info["counter"] == 0: tarred_bytes_to_directory(tarred_bytes, os.path.dirname(target_dir)) diff --git a/python/ray/global_scheduler/__init__.py b/python/ray/global_scheduler/__init__.py index cde7ce231..25e4d2cf6 100644 --- a/python/ray/global_scheduler/__init__.py +++ b/python/ray/global_scheduler/__init__.py @@ -2,4 +2,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from .global_scheduler_services import * +from .global_scheduler_services import start_global_scheduler + +__all__ = ["start_global_scheduler"] diff --git a/python/ray/global_scheduler/global_scheduler_services.py b/python/ray/global_scheduler/global_scheduler_services.py index ae9b8b1e7..36913801e 100644 --- a/python/ray/global_scheduler/global_scheduler_services.py +++ b/python/ray/global_scheduler/global_scheduler_services.py @@ -6,6 +6,7 @@ import os import subprocess import time + def start_global_scheduler(redis_address, use_valgrind=False, use_profiler=False, stdout_file=None, stderr_file=None): @@ -15,8 +16,8 @@ def start_global_scheduler(redis_address, use_valgrind=False, redis_address (str): The address of the Redis instance. use_valgrind (bool): True if the global scheduler should be started inside of valgrind. If this is True, use_profiler must be False. - use_profiler (bool): True if the global scheduler should be started inside a - profiler. If this is True, use_valgrind must be False. + use_profiler (bool): True if the global scheduler should be started inside + a profiler. If this is True, use_valgrind must be False. stdout_file: A file handle opened for writing to redirect stdout to. If no redirection should happen, then this should be None. stderr_file: A file handle opened for writing to redirect stderr to. If no @@ -27,7 +28,9 @@ def start_global_scheduler(redis_address, use_valgrind=False, """ if use_valgrind and use_profiler: raise Exception("Cannot use valgrind and profiler at the same time.") - global_scheduler_executable = os.path.join(os.path.abspath(os.path.dirname(__file__)), "../core/src/global_scheduler/global_scheduler") + global_scheduler_executable = os.path.join( + os.path.abspath(os.path.dirname(__file__)), + "../core/src/global_scheduler/global_scheduler") command = [global_scheduler_executable, "-r", redis_address] if use_valgrind: pid = subprocess.Popen(["valgrind", @@ -35,7 +38,7 @@ def start_global_scheduler(redis_address, use_valgrind=False, "--leak-check=full", "--show-leak-kinds=all", "--error-exitcode=1"] + command, - stdout=stdout_file, stderr=stderr_file) + stdout=stdout_file, stderr=stderr_file) time.sleep(1.0) elif use_profiler: pid = subprocess.Popen(["valgrind", "--tool=callgrind"] + command, diff --git a/python/ray/global_scheduler/test/test.py b/python/ray/global_scheduler/test/test.py index 3f9bbc731..aeaff66fe 100644 --- a/python/ray/global_scheduler/test/test.py +++ b/python/ray/global_scheduler/test/test.py @@ -7,16 +7,14 @@ import os import random import redis import signal -import subprocess import sys -import threading import time import unittest import ray.global_scheduler as global_scheduler import ray.local_scheduler as local_scheduler import ray.plasma as plasma -from ray.plasma.utils import random_object_id, generate_metadata, write_to_data_buffer, create_object_with_id, create_object +from ray.plasma.utils import create_object from ray import services @@ -39,21 +37,27 @@ TASK_STATUS_DONE = 16 DB_CLIENT_PREFIX = "CL:" TASK_PREFIX = "TT:" + def random_driver_id(): return local_scheduler.ObjectID(np.random.bytes(ID_SIZE)) + def random_task_id(): return local_scheduler.ObjectID(np.random.bytes(ID_SIZE)) + def random_function_id(): return local_scheduler.ObjectID(np.random.bytes(ID_SIZE)) + def random_object_id(): return local_scheduler.ObjectID(np.random.bytes(ID_SIZE)) + def new_port(): return random.randint(10000, 65535) + class TestGlobalScheduler(unittest.TestCase): def setUp(self): @@ -62,9 +66,11 @@ class TestGlobalScheduler(unittest.TestCase): redis_port, self.redis_process = services.start_redis(cleanup=False) redis_address = services.address(node_ip_address, redis_port) # Create a Redis client. - self.redis_client = redis.StrictRedis(host=node_ip_address, port=redis_port) + self.redis_client = redis.StrictRedis(host=node_ip_address, + port=redis_port) # Start one global scheduler. - self.p1 = global_scheduler.start_global_scheduler(redis_address, use_valgrind=USE_VALGRIND) + self.p1 = global_scheduler.start_global_scheduler( + redis_address, use_valgrind=USE_VALGRIND) self.plasma_store_pids = [] self.plasma_manager_pids = [] self.local_scheduler_pids = [] @@ -76,11 +82,15 @@ class TestGlobalScheduler(unittest.TestCase): plasma_store_name, p2 = plasma.start_plasma_store() self.plasma_store_pids.append(p2) # Start the Plasma manager. - # Assumption: Plasma manager name and port are randomly generated by the plasma module. - plasma_manager_name, p3, plasma_manager_port = plasma.start_plasma_manager(plasma_store_name, redis_address) + # Assumption: Plasma manager name and port are randomly generated by the + # plasma module. + manager_info = plasma.start_plasma_manager(plasma_store_name, + redis_address) + plasma_manager_name, p3, plasma_manager_port = manager_info self.plasma_manager_pids.append(p3) plasma_address = "{}:{}".format(node_ip_address, plasma_manager_port) - plasma_client = plasma.PlasmaClient(plasma_store_name, plasma_manager_name) + plasma_client = plasma.PlasmaClient(plasma_store_name, + plasma_manager_name) self.plasma_clients.append(plasma_client) # Start the local scheduler. local_scheduler_name, p4 = local_scheduler.start_local_scheduler( @@ -91,7 +101,7 @@ class TestGlobalScheduler(unittest.TestCase): static_resource_list=[10, 0]) # Connect to the scheduler. local_scheduler_client = local_scheduler.LocalSchedulerClient( - local_scheduler_name, NIL_ACTOR_ID, False) + local_scheduler_name, NIL_ACTOR_ID, False) self.local_scheduler_clients.append(local_scheduler_client) self.local_scheduler_pids.append(p4) @@ -149,11 +159,13 @@ class TestGlobalScheduler(unittest.TestCase): return db_client_id def test_task_default_resources(self): - task1 = local_scheduler.Task(random_driver_id(), random_function_id(), [random_object_id()], 0, random_task_id(), 0) + task1 = local_scheduler.Task(random_driver_id(), random_function_id(), + [random_object_id()], 0, random_task_id(), 0) self.assertEqual(task1.required_resources(), [1.0, 0.0]) task2 = local_scheduler.Task(random_driver_id(), random_function_id(), [random_object_id()], 0, random_task_id(), 0, - local_scheduler.ObjectID(NIL_ACTOR_ID), 0, [1.0, 2.0]) + local_scheduler.ObjectID(NIL_ACTOR_ID), 0, + [1.0, 2.0]) self.assertEqual(task2.required_resources(), [1.0, 2.0]) def test_redis_only_single_task(self): @@ -162,31 +174,37 @@ class TestGlobalScheduler(unittest.TestCase): task state transitions in Redis only. TODO(atumanov): implement. """ # Check precondition for this test: - # There should be 2n+1 db clients: the global scheduler + one local scheduler and one plasma per node. - self.assertEqual(len(self.redis_client.keys("{}*".format(DB_CLIENT_PREFIX))), - 2 * NUM_CLUSTER_NODES + 1) + # There should be 2n+1 db clients: the global scheduler + one local + # scheduler and one plasma per node. + self.assertEqual( + len(self.redis_client.keys("{}*".format(DB_CLIENT_PREFIX))), + 2 * NUM_CLUSTER_NODES + 1) db_client_id = self.get_plasma_manager_id() - assert(db_client_id != None) + assert(db_client_id is not None) assert(db_client_id.startswith(b"CL:")) - db_client_id = db_client_id[len(b"CL:"):] # Remove the CL: prefix. + db_client_id = db_client_id[len(b"CL:"):] # Remove the CL: prefix. def test_integration_single_task(self): # There should be three db clients, the global scheduler, the local # scheduler, and the plasma manager. - self.assertEqual(len(self.redis_client.keys("{}*".format(DB_CLIENT_PREFIX))), - 2 * NUM_CLUSTER_NODES + 1) + self.assertEqual( + len(self.redis_client.keys("{}*".format(DB_CLIENT_PREFIX))), + 2 * NUM_CLUSTER_NODES + 1) num_return_vals = [0, 1, 2, 3, 5, 10] # Insert the object into Redis. data_size = 0xf1f0 metadata_size = 0x40 plasma_client = self.plasma_clients[0] - object_dep, memory_buffer, metadata = create_object(plasma_client, data_size, metadata_size, seal=True) + object_dep, memory_buffer, metadata = create_object( + plasma_client, data_size, metadata_size, seal=True) # Sleep before submitting task to local scheduler. time.sleep(0.1) # Submit a task to Redis. - task = local_scheduler.Task(random_driver_id(), random_function_id(), [local_scheduler.ObjectID(object_dep)], num_return_vals[0], random_task_id(), 0) + task = local_scheduler.Task(random_driver_id(), random_function_id(), + [local_scheduler.ObjectID(object_dep)], + num_return_vals[0], random_task_id(), 0) self.local_scheduler_clients[0].submit(task) time.sleep(0.1) # There should now be a task in Redis, and it should get assigned to the @@ -217,8 +235,9 @@ class TestGlobalScheduler(unittest.TestCase): def integration_many_tasks_helper(self, timesync=True): # There should be three db clients, the global scheduler, the local # scheduler, and the plasma manager. - self.assertEqual(len(self.redis_client.keys("{}*".format(DB_CLIENT_PREFIX))), - 2 * NUM_CLUSTER_NODES + 1) + self.assertEqual( + len(self.redis_client.keys("{}*".format(DB_CLIENT_PREFIX))), + 2 * NUM_CLUSTER_NODES + 1) num_return_vals = [0, 1, 2, 3, 5, 10] # Submit a bunch of tasks to Redis. @@ -228,11 +247,16 @@ class TestGlobalScheduler(unittest.TestCase): data_size = np.random.randint(1 << 20) metadata_size = np.random.randint(1 << 10) plasma_client = self.plasma_clients[0] - object_dep, memory_buffer, metadata = create_object(plasma_client, data_size, metadata_size, seal=True) + object_dep, memory_buffer, metadata = create_object(plasma_client, + data_size, + metadata_size, + seal=True) if timesync: # Give 10ms for object info handler to fire (long enough to yield CPU). time.sleep(0.010) - task = local_scheduler.Task(random_driver_id(), random_function_id(), [local_scheduler.ObjectID(object_dep)], num_return_vals[0], random_task_id(), 0) + task = local_scheduler.Task(random_driver_id(), random_function_id(), + [local_scheduler.ObjectID(object_dep)], + num_return_vals[0], random_task_id(), 0) self.local_scheduler_clients[0].submit(task) # Check that there are the correct number of tasks in Redis and that they # all get assigned to the local scheduler. @@ -243,17 +267,18 @@ class TestGlobalScheduler(unittest.TestCase): self.assertLessEqual(len(task_entries), num_tasks) # First, check if all tasks made it to Redis. if len(task_entries) == num_tasks: - task_contents = [self.redis_client.hgetall(task_entries[i]) for i in range(len(task_entries))] + task_contents = [self.redis_client.hgetall(task_entries[i]) + for i in range(len(task_entries))] task_statuses = [int(contents[b"state"]) for contents in task_contents] - self.assertTrue(all([ - status in [TASK_STATUS_WAITING, - TASK_STATUS_SCHEDULED, - TASK_STATUS_QUEUED] for status in task_statuses - ])) + self.assertTrue(all([status in [TASK_STATUS_WAITING, + TASK_STATUS_SCHEDULED, + TASK_STATUS_QUEUED] + for status in task_statuses])) num_tasks_done = task_statuses.count(TASK_STATUS_QUEUED) num_tasks_scheduled = task_statuses.count(TASK_STATUS_SCHEDULED) num_tasks_waiting = task_statuses.count(TASK_STATUS_WAITING) - print("tasks in Redis = {}, tasks waiting = {}, tasks scheduled = {}, tasks queued = {}, retries left = {}" + print("tasks in Redis = {}, tasks waiting = {}, tasks scheduled = {}, " + "tasks queued = {}, retries left = {}" .format(len(task_entries), num_tasks_waiting, num_tasks_scheduled, num_tasks_done, num_retries)) if all([status == TASK_STATUS_QUEUED for status in task_statuses]): @@ -275,6 +300,7 @@ class TestGlobalScheduler(unittest.TestCase): # notifications. self.integration_many_tasks_helper(timesync=False) + if __name__ == "__main__": if len(sys.argv) > 1: # Pop the argument so we don't mess with unittest's own argument parser. diff --git a/python/ray/local_scheduler/__init__.py b/python/ray/local_scheduler/__init__.py index d3895317f..2264fb1b1 100644 --- a/python/ray/local_scheduler/__init__.py +++ b/python/ray/local_scheduler/__init__.py @@ -2,5 +2,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from ray.core.src.local_scheduler.liblocal_scheduler_library import * -from .local_scheduler_services import * +from ray.core.src.local_scheduler.liblocal_scheduler_library import ( + Task, LocalSchedulerClient, ObjectID, check_simple_value, task_from_string, + task_to_string) +from .local_scheduler_services import start_local_scheduler + +__all__ = ["Task", "LocalSchedulerClient", "ObjectID", "check_simple_value", + "task_from_string", "task_to_string", "start_local_scheduler"] diff --git a/python/ray/local_scheduler/local_scheduler_services.py b/python/ray/local_scheduler/local_scheduler_services.py index f338b22b5..03c213095 100644 --- a/python/ray/local_scheduler/local_scheduler_services.py +++ b/python/ray/local_scheduler/local_scheduler_services.py @@ -7,9 +7,11 @@ import random import subprocess import time + def random_name(): return str(random.randint(0, 99999999)) + def start_local_scheduler(plasma_store_name, plasma_manager_name=None, worker_path=None, @@ -38,8 +40,8 @@ def start_local_scheduler(plasma_store_name, running on. redis_address (str): The address of the Redis instance to connect to. If this is not provided, then the local scheduler will not connect to Redis. - use_valgrind (bool): True if the local scheduler should be started inside of - valgrind. If this is True, use_profiler must be False. + use_valgrind (bool): True if the local scheduler should be started inside + of valgrind. If this is True, use_profiler must be False. use_profiler (bool): True if the local scheduler should be started inside a profiler. If this is True, use_valgrind must be False. stdout_file: A file handle opened for writing to redirect stdout to. If no @@ -56,11 +58,14 @@ def start_local_scheduler(plasma_store_name, A tuple of the name of the local scheduler socket and the process ID of the local scheduler process. """ - if (plasma_manager_name == None) != (redis_address == None): - raise Exception("If one of the plasma_manager_name and the redis_address is provided, then both must be provided.") + if (plasma_manager_name is None) != (redis_address is None): + raise Exception("If one of the plasma_manager_name and the redis_address " + "is provided, then both must be provided.") if use_valgrind and use_profiler: raise Exception("Cannot use valgrind and profiler at the same time.") - local_scheduler_executable = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../core/src/local_scheduler/local_scheduler") + local_scheduler_executable = os.path.join(os.path.dirname( + os.path.abspath(__file__)), + "../core/src/local_scheduler/local_scheduler") local_scheduler_name = "/tmp/scheduler{}".format(random_name()) command = [local_scheduler_executable, "-s", local_scheduler_name, @@ -90,8 +95,10 @@ def start_local_scheduler(plasma_store_name, if plasma_address is not None: command += ["-a", plasma_address] if static_resource_list is not None: - assert all([isinstance(resource, int) or isinstance(resource, float) for resource in static_resource_list]) - command += ["-c", ",".join([str(resource) for resource in static_resource_list])] + assert all([isinstance(resource, int) or isinstance(resource, float) + for resource in static_resource_list]) + command += ["-c", ",".join([str(resource) for resource + in static_resource_list])] if use_valgrind: pid = subprocess.Popen(["valgrind", @@ -99,7 +106,7 @@ def start_local_scheduler(plasma_store_name, "--leak-check=full", "--show-leak-kinds=all", "--error-exitcode=1"] + command, - stdout=stdout_file, stderr=stderr_file) + stdout=stdout_file, stderr=stderr_file) time.sleep(1.0) elif use_profiler: pid = subprocess.Popen(["valgrind", "--tool=callgrind"] + command, diff --git a/python/ray/local_scheduler/test/test.py b/python/ray/local_scheduler/test/test.py index a89e84181..810051e1c 100644 --- a/python/ray/local_scheduler/test/test.py +++ b/python/ray/local_scheduler/test/test.py @@ -4,9 +4,7 @@ from __future__ import print_function import numpy as np import os -import random import signal -import subprocess import sys import threading import time @@ -20,18 +18,23 @@ ID_SIZE = 20 NIL_ACTOR_ID = 20 * b"\xff" + def random_object_id(): return local_scheduler.ObjectID(np.random.bytes(ID_SIZE)) + def random_driver_id(): return local_scheduler.ObjectID(np.random.bytes(ID_SIZE)) + def random_task_id(): return local_scheduler.ObjectID(np.random.bytes(ID_SIZE)) + def random_function_id(): return local_scheduler.ObjectID(np.random.bytes(ID_SIZE)) + class TestLocalSchedulerClient(unittest.TestCase): def setUp(self): @@ -39,10 +42,11 @@ class TestLocalSchedulerClient(unittest.TestCase): plasma_store_name, self.p1 = plasma.start_plasma_store() self.plasma_client = plasma.PlasmaClient(plasma_store_name) # Start a local scheduler. - scheduler_name, self.p2 = local_scheduler.start_local_scheduler(plasma_store_name, use_valgrind=USE_VALGRIND) + scheduler_name, self.p2 = local_scheduler.start_local_scheduler( + plasma_store_name, use_valgrind=USE_VALGRIND) # Connect to the scheduler. self.local_scheduler_client = local_scheduler.LocalSchedulerClient( - scheduler_name, NIL_ACTOR_ID, False) + scheduler_name, NIL_ACTOR_ID, False) def tearDown(self): # Check that the processes are still alive. @@ -70,37 +74,38 @@ class TestLocalSchedulerClient(unittest.TestCase): self.plasma_client.seal(object_id.id()) # Define some arguments to use for the tasks. args_list = [ - [], - #{}, - #(), - 1 * [1], - 10 * [1], - 100 * [1], - 1000 * [1], - 1 * ["a"], - 10 * ["a"], - 100 * ["a"], - 1000 * ["a"], - [1, 1.3, 1 << 100, "hi", u"hi", [1, 2]], - object_ids[:1], - object_ids[:2], - object_ids[:3], - object_ids[:4], - object_ids[:5], - object_ids[:10], - object_ids[:100], - object_ids[:256], - [1, object_ids[0]], - [object_ids[0], "a"], - [1, object_ids[0], "a"], - [object_ids[0], 1, object_ids[1], "a"], - object_ids[:3] + [1, "hi", 2.3] + object_ids[:5], - object_ids + 100 * ["a"] + object_ids + [], + [{}], + [()], + 1 * [1], + 10 * [1], + 100 * [1], + 1000 * [1], + 1 * ["a"], + 10 * ["a"], + 100 * ["a"], + 1000 * ["a"], + [1, 1.3, 1 << 100, "hi", u"hi", [1, 2]], + object_ids[:1], + object_ids[:2], + object_ids[:3], + object_ids[:4], + object_ids[:5], + object_ids[:10], + object_ids[:100], + object_ids[:256], + [1, object_ids[0]], + [object_ids[0], "a"], + [1, object_ids[0], "a"], + [object_ids[0], 1, object_ids[1], "a"], + object_ids[:3] + [1, "hi", 2.3] + object_ids[:5], + object_ids + 100 * ["a"] + object_ids ] for args in args_list: for num_return_vals in [0, 1, 2, 3, 5, 10, 100]: - task = local_scheduler.Task(random_driver_id(), function_id, args, num_return_vals, random_task_id(), 0) + task = local_scheduler.Task(random_driver_id(), function_id, args, + num_return_vals, random_task_id(), 0) # Submit a task. self.local_scheduler_client.submit(task) # Get the task. @@ -119,7 +124,8 @@ class TestLocalSchedulerClient(unittest.TestCase): # Submit all of the tasks. for args in args_list: for num_return_vals in [0, 1, 2, 3, 5, 10, 100]: - task = local_scheduler.Task(random_driver_id(), function_id, args, num_return_vals, random_task_id(), 0) + task = local_scheduler.Task(random_driver_id(), function_id, args, + num_return_vals, random_task_id(), 0) self.local_scheduler_client.submit(task) # Get all of the tasks. for args in args_list: @@ -129,8 +135,10 @@ class TestLocalSchedulerClient(unittest.TestCase): def test_scheduling_when_objects_ready(self): # Create a task and submit it. object_id = random_object_id() - task = local_scheduler.Task(random_driver_id(), random_function_id(), [object_id], 0, random_task_id(), 0) + task = local_scheduler.Task(random_driver_id(), random_function_id(), + [object_id], 0, random_task_id(), 0) self.local_scheduler_client.submit(task) + # Launch a thread to get the task. def get_task(): self.local_scheduler_client.get_task() @@ -149,7 +157,9 @@ class TestLocalSchedulerClient(unittest.TestCase): # Create a task with two dependencies and submit it. object_id1 = random_object_id() object_id2 = random_object_id() - task = local_scheduler.Task(random_driver_id(), random_function_id(), [object_id1, object_id2], 0, random_task_id(), 0) + task = local_scheduler.Task(random_driver_id(), random_function_id(), + [object_id1, object_id2], 0, random_task_id(), + 0) self.local_scheduler_client.submit(task) # Launch a thread to get the task. @@ -196,9 +206,10 @@ class TestLocalSchedulerClient(unittest.TestCase): # Wait until the thread finishes so that we know the task was scheduled. t.join() + if __name__ == "__main__": if len(sys.argv) > 1: - # pop the argument so we don't mess with unittest's own argument parser + # Pop the argument so we don't mess with unittest's own argument parser. if sys.argv[-1] == "valgrind": arg = sys.argv.pop() USE_VALGRIND = True diff --git a/python/ray/log_monitor.py b/python/ray/log_monitor.py index febcd9e10..db9321ea1 100644 --- a/python/ray/log_monitor.py +++ b/python/ray/log_monitor.py @@ -3,13 +3,13 @@ from __future__ import division from __future__ import print_function import argparse -import os import redis import time from ray.services import get_ip_address from ray.services import get_port + class LogMonitor(object): """A monitor process for monitoring Ray log files. @@ -27,15 +27,17 @@ class LogMonitor(object): def __init__(self, redis_ip_address, redis_port, node_ip_address): """Initialize the log monitor object.""" self.node_ip_address = node_ip_address - self.redis_client = redis.StrictRedis(host=redis_ip_address, port=redis_port) + self.redis_client = redis.StrictRedis(host=redis_ip_address, + port=redis_port) self.log_files = {} self.log_file_handles = {} def update_log_filenames(self): """Get the most up-to-date list of log files to monitor from Redis.""" num_current_log_files = len(self.log_files) - new_log_filenames = self.redis_client.lrange("LOG_FILENAMES:{}".format(self.node_ip_address), - num_current_log_files, -1) + new_log_filenames = self.redis_client.lrange( + "LOG_FILENAMES:{}".format(self.node_ip_address), + num_current_log_files, -1) for log_filename in new_log_filenames: print("Beginning to track file {}".format(log_filename)) assert log_filename not in self.log_files @@ -50,7 +52,8 @@ class LogMonitor(object): # If there are any new lines, cache them and also push them to Redis. if len(new_lines) > 0: self.log_files[log_filename] += new_lines - redis_key = "LOGFILE:{}:{}".format(self.node_ip_address, log_filename.decode("ascii")) + redis_key = "LOGFILE:{}:{}".format(self.node_ip_address, + log_filename.decode("ascii")) self.redis_client.rpush(redis_key, *new_lines) else: try: @@ -69,6 +72,7 @@ class LogMonitor(object): self.check_log_files_and_push_updates() time.sleep(1) + if __name__ == "__main__": parser = argparse.ArgumentParser(description=("Parse Redis server for the " "log monitor to connect to.")) diff --git a/python/ray/monitor.py b/python/ray/monitor.py index 3d524ecd9..38cb16768 100644 --- a/python/ray/monitor.py +++ b/python/ray/monitor.py @@ -3,7 +3,6 @@ from __future__ import division from __future__ import print_function import argparse -import binascii from collections import Counter import logging import redis @@ -13,7 +12,8 @@ from ray.services import get_ip_address from ray.services import get_port # Import flatbuffer bindings. -from ray.core.generated.SubscribeToDBClientTableReply import SubscribeToDBClientTableReply +from ray.core.generated.SubscribeToDBClientTableReply \ + import SubscribeToDBClientTableReply from ray.core.generated.TaskReply import TaskReply # These variables must be kept in sync with the C codebase. @@ -41,6 +41,7 @@ logging.basicConfig() log = logging.getLogger() log.setLevel(logging.WARN) + class Monitor(object): """A monitor for Ray processes. @@ -92,7 +93,8 @@ class Monitor(object): TASK_STATUS_LOST. A local scheduler is deemed dead if it is in self.dead_local_schedulers. """ - task_ids = self.redis.scan_iter(match="{prefix}*".format(prefix=TASK_PREFIX)) + task_ids = self.redis.scan_iter( + match="{prefix}*".format(prefix=TASK_PREFIX)) num_tasks_updated = 0 for task_id in task_ids: task_id = task_id[len(TASK_PREFIX):] @@ -123,11 +125,13 @@ class Monitor(object): """ # TODO(swang): Also kill the associated plasma store, since it's no longer # reachable without a plasma manager. - object_ids = self.redis.scan_iter(match="{prefix}*".format(prefix=OBJECT_PREFIX)) + object_ids = self.redis.scan_iter( + match="{prefix}*".format(prefix=OBJECT_PREFIX)) num_objects_removed = 0 for object_id in object_ids: object_id = object_id[len(OBJECT_PREFIX):] - managers = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", object_id) + managers = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", + object_id) for manager in managers: if manager in self.dead_plasma_managers: # If the object was on a dead plasma manager, remove that location @@ -149,7 +153,8 @@ class Monitor(object): not miss any notifications for deleted clients that occurred before we subscribed. """ - db_client_keys = self.redis.keys("{prefix}*".format(prefix=DB_CLIENT_PREFIX)) + db_client_keys = self.redis.keys( + "{prefix}*".format(prefix=DB_CLIENT_PREFIX)) for db_client_key in db_client_keys: db_client_id = db_client_key[len(DB_CLIENT_PREFIX):] client_type, deleted = self.redis.hmget(db_client_key, @@ -175,10 +180,10 @@ class Monitor(object): flatbuffer. Deletions are processed, insertions are ignored. Cleanup of the associated state in the state tables should be handled by the caller. """ - notification_object = SubscribeToDBClientTableReply.GetRootAsSubscribeToDBClientTableReply(data, 0) + notification_object = (SubscribeToDBClientTableReply + .GetRootAsSubscribeToDBClientTableReply(data, 0)) db_client_id = notification_object.DbClientId() client_type = notification_object.ClientType() - auxiliary_address = notification_object.AuxAddress() is_insertion = notification_object.IsInsertion() # If the update was an insertion, we ignore it. @@ -227,7 +232,6 @@ class Monitor(object): if not self.subscribed[channel]: # If the data was an integer, then the message was a response to an # initial subscription request. - is_subscribe = int(data) message_handler = self.subscribe_handler elif channel == PLASMA_MANAGER_HEARTBEAT_CHANNEL: assert(self.subscribed[channel]) @@ -264,12 +268,11 @@ class Monitor(object): self.cleanup_task_table() if len(self.dead_plasma_managers) > 0: self.cleanup_object_table() - log.debug("{} dead local schedulers, {} plasma " - "managers total, {} dead plasma managers".format( - len(self.dead_local_schedulers), - len(self.live_plasma_managers) + len(self.dead_plasma_managers), - len(self.dead_plasma_managers) - )) + log.debug("{} dead local schedulers, {} plasma managers total, {} dead " + "plasma managers".format(len(self.dead_local_schedulers), + (len(self.live_plasma_managers) + + len(self.dead_plasma_managers)), + len(self.dead_plasma_managers))) # Handle messages from the subscription channels. while True: @@ -289,7 +292,8 @@ class Monitor(object): # Handle plasma managers that timed out during this round. plasma_manager_ids = list(self.live_plasma_managers.keys()) for plasma_manager_id in plasma_manager_ids: - if self.live_plasma_managers[plasma_manager_id] >= NUM_HEARTBEATS_TIMEOUT: + if ((self.live_plasma_managers + [plasma_manager_id]) >= NUM_HEARTBEATS_TIMEOUT): log.warn("Timed out {}".format(PLASMA_MANAGER_CLIENT_TYPE)) # Remove the plasma manager from the managers whose heartbeats we're # tracking. @@ -299,7 +303,8 @@ class Monitor(object): # receive the notification for this db_client deletion. self.redis.execute_command("RAY.DISCONNECT", plasma_manager_id) - # Increment the number of heartbeats that we've missed from each plasma manager. + # Increment the number of heartbeats that we've missed from each plasma + # manager. for plasma_manager_id in self.live_plasma_managers: self.live_plasma_managers[plasma_manager_id] += 1 diff --git a/python/ray/numbuf/__init__.py b/python/ray/numbuf/__init__.py index b4b58bc1f..dc6a82e22 100644 --- a/python/ray/numbuf/__init__.py +++ b/python/ray/numbuf/__init__.py @@ -10,15 +10,29 @@ If you are using Anaconda, try fixing this problem by running: conda install libgcc """ +__all__ = ["deserialize_list", "numbuf_error", + "numbuf_plasma_object_exists_error", "read_from_buffer", + "register_callbacks", "retrieve_list", "serialize_list", + "store_list", "write_to_buffer"] + try: - from ray.core.src.numbuf.libnumbuf import * + from ray.core.src.numbuf.libnumbuf import (deserialize_list, numbuf_error, + numbuf_plasma_object_exists_error, + read_from_buffer, + register_callbacks, retrieve_list, + serialize_list, store_list, + write_to_buffer) except ImportError as e: - if hasattr(e, "msg") and isinstance(e.msg, str) and ("libstdc++" in e.msg or "CXX" in e.msg): + if (hasattr(e, "msg") and isinstance(e.msg, str) and ("libstdc++" in e.msg or + "CXX" in e.msg)): # This code path should be taken with Python 3. e.msg += helpful_message - elif hasattr(e, "message") and isinstance(e.message, str) and ("libstdc++" in e.message or "CXX" in e.message): + elif (hasattr(e, "message") and isinstance(e.message, str) and + ("libstdc++" in e.message or "CXX" in e.message)): # This code path should be taken with Python 2. - if hasattr(e, "args") and isinstance(e.args, tuple) and len(e.args) == 1 and isinstance(e.args[0], str): + condition = (hasattr(e, "args") and isinstance(e.args, tuple) and + len(e.args) == 1 and isinstance(e.args[0], str)) + if condition: e.args = (e.args[0] + helpful_message,) else: if not hasattr(e, "args"): diff --git a/python/ray/pickling.py b/python/ray/pickling.py index 40e63089f..9e6d68b68 100644 --- a/python/ray/pickling.py +++ b/python/ray/pickling.py @@ -1,55 +1,74 @@ -# Note that a little bit of code here is taken and slightly modified from the pickler because it was not possible to change its behavior otherwise. +# Note that a little bit of code here is taken and slightly modified from the +# pickler because it was not possible to change its behavior otherwise. from __future__ import absolute_import from __future__ import division from __future__ import print_function -import sys from ctypes import c_void_p from cloudpickle import pickle, cloudpickle, CloudPickler, load, loads +__all__ = ["load", "loads", "dump", "dumps"] + try: from ctypes import pythonapi pythonapi.PyCell_Set # Make sure this exists except: pythonapi = None + def dump(obj, file, protocol=2): return BetterPickler(file, protocol).dump(obj) + def dumps(obj): stringio = cloudpickle.StringIO() dump(obj, stringio) return stringio.getvalue() -def _make_skel_func(code, closure, base_globals = None): - """ Creates a skeleton function object that contains just the provided - code and the correct number of cells in func_closure. All other - func attributes (e.g. func_globals) are empty. + +def _make_skel_func(code, closure, base_globals=None): + """Create a skeleton function object. + + Creates a skeleton function object that contains just the provided code and + the correct number of cells in func_closure. All other func attributes + (e.g. func_globals) are empty. """ - if base_globals is None: base_globals = {} - base_globals['__builtins__'] = __builtins__ - return _make_skel_func.__class__(code, base_globals, None, None, tuple(closure)) + if base_globals is None: + base_globals = {} + base_globals["__builtins__"] = __builtins__ + return _make_skel_func.__class__(code, base_globals, None, None, + tuple(closure)) + def _fill_function(func, globals, defaults, closure, dict): - """ Fills in the rest of function data into the skeleton function object - that were created via _make_skel_func(), including closures. + """Fill in the resst of the function data. + + This fills in the rest of function data into the skeleton function object + that were created via _make_skel_func(), including closures. """ result = cloudpickle._fill_function(func, globals, defaults, dict) if pythonapi is not None: for i, v in enumerate(closure): - pythonapi.PyCell_Set(c_void_p(id(result.__closure__[i])), c_void_p(id(v))) + pythonapi.PyCell_Set(c_void_p(id(result.__closure__[i])), + c_void_p(id(v))) return result + class BetterPickler(CloudPickler): def save_function_tuple(self, func): - code, f_globals, defaults, closure, dct, base_globals = self.extract_func_data(func) + (code, f_globals, defaults, + closure, dct, base_globals) = self.extract_func_data(func) self.save(_fill_function) self.write(pickle.MARK) self.save(_make_skel_func if pythonapi else cloudpickle._make_skel_func) - self.save((code, map(lambda _: cloudpickle._make_cell(None), closure) if closure and pythonapi is not None else closure, base_globals)) + self.save((code, + (map(lambda _: cloudpickle._make_cell(None), closure) + if closure and pythonapi is not None + else closure), + base_globals)) self.write(pickle.REDUCE) self.memoize(func) @@ -59,6 +78,7 @@ class BetterPickler(CloudPickler): self.save(dct) self.write(pickle.TUPLE) self.write(pickle.REDUCE) + def save_cell(self, obj): self.save(cloudpickle._make_cell) self.save((obj.cell_contents,)) diff --git a/python/ray/plasma/__init__.py b/python/ray/plasma/__init__.py index b49ea4b9d..c5238bd50 100644 --- a/python/ray/plasma/__init__.py +++ b/python/ray/plasma/__init__.py @@ -2,4 +2,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from ray.plasma.plasma import * +from ray.plasma.plasma import (PlasmaBuffer, buffers_equal, PlasmaClient, + start_plasma_store, start_plasma_manager, + plasma_object_exists_error, + plasma_out_of_memory_error, + DEFAULT_PLASMA_STORE_MEMORY) + +__all__ = ["PlasmaBuffer", "buffers_equal", "PlasmaClient", + "start_plasma_store", "start_plasma_manager", + "plasma_object_exists_error", "plasma_out_of_memory_error", + "DEFAULT_PLASMA_STORE_MEMORY"] diff --git a/python/ray/plasma/plasma.py b/python/ray/plasma/plasma.py index c67eb4ef6..847f6ce2c 100644 --- a/python/ray/plasma/plasma.py +++ b/python/ray/plasma/plasma.py @@ -12,9 +12,14 @@ import ray.core.src.plasma.libplasma as libplasma from ray.core.src.plasma.libplasma import plasma_object_exists_error from ray.core.src.plasma.libplasma import plasma_out_of_memory_error -PLASMA_ID_SIZE = 20 +__all__ = ["PlasmaBuffer", "buffers_equal", "PlasmaClient", + "start_plasma_store", "start_plasma_manager", + "plasma_object_exists_error", "plasma_out_of_memory_error", + "DEFAULT_PLASMA_STORE_MEMORY"] + PLASMA_WAIT_TIMEOUT = 2 ** 30 + class PlasmaBuffer(object): """This is the type of objects returned by calls to get with a PlasmaClient. @@ -47,8 +52,8 @@ class PlasmaBuffer(object): """Read from the PlasmaBuffer as if it were just a regular buffer.""" # We currently don't allow slicing plasma buffers. We should handle this # better, but it requires some care because the slice may be backed by the - # same memory in the object store, but the original plasma buffer may go out - # of scope causing the memory to no longer be accessible. + # same memory in the object store, but the original plasma buffer may go + # out of scope causing the memory to no longer be accessible. assert not isinstance(index, slice) value = self.buffer[index] if sys.version_info >= (3, 0) and not isinstance(index, slice): @@ -62,8 +67,8 @@ class PlasmaBuffer(object): """ # We currently don't allow slicing plasma buffers. We should handle this # better, but it requires some care because the slice may be backed by the - # same memory in the object store, but the original plasma buffer may go out - # of scope causing the memory to no longer be accessible. + # same memory in the object store, but the original plasma buffer may go + # out of scope causing the memory to no longer be accessible. assert not isinstance(index, slice) if sys.version_info >= (3, 0) and not isinstance(index, slice): value = ord(value) @@ -73,48 +78,56 @@ class PlasmaBuffer(object): """Return the length of the buffer.""" return len(self.buffer) + def buffers_equal(buff1, buff2): """Compare two buffers. These buffers may be PlasmaBuffer objects. This method should only be used in the tests. We implement a special helper - method for doing this because doing comparisons by slicing is much faster, but - we don't want to expose slicing of PlasmaBuffer objects because it currently - is not safe. + method for doing this because doing comparisons by slicing is much faster, + but we don't want to expose slicing of PlasmaBuffer objects because it + currently is not safe. """ buff1_to_compare = buff1.buffer if isinstance(buff1, PlasmaBuffer) else buff1 buff2_to_compare = buff2.buffer if isinstance(buff2, PlasmaBuffer) else buff2 return buff1_to_compare[:] == buff2_to_compare[:] + class PlasmaClient(object): - """The PlasmaClient is used to interface with a plasma store and a plasma manager. + """The PlasmaClient is used to interface with a plasma store and manager. The PlasmaClient can ask the PlasmaStore to allocate a new buffer, seal a buffer, and get a buffer. Buffers are referred to by object IDs, which are strings. """ - def __init__(self, store_socket_name, manager_socket_name=None, release_delay=64): + def __init__(self, store_socket_name, manager_socket_name=None, + release_delay=64): """Initialize the PlasmaClient. Args: - store_socket_name (str): Name of the socket the plasma store is listening at. - manager_socket_name (str): Name of the socket the plasma manager is listening at. + store_socket_name (str): Name of the socket the plasma store is listening + at. + manager_socket_name (str): Name of the socket the plasma manager is + listening at. + release_delay (int): The maximum number of objects that the client will + keep and delay releasing (for caching reasons). """ self.store_socket_name = store_socket_name self.manager_socket_name = manager_socket_name self.alive = True if manager_socket_name is not None: - self.conn = libplasma.connect(store_socket_name, manager_socket_name, release_delay) + self.conn = libplasma.connect(store_socket_name, manager_socket_name, + release_delay) else: self.conn = libplasma.connect(store_socket_name, "", release_delay) def shutdown(self): """Shutdown the client so that it does not send messages. - If we kill the Plasma store and Plasma manager that this client is connected - to, then we can use this method to prevent the client from trying to send - messages to the killed processes. + If we kill the Plasma store and Plasma manager that this client is + connected to, then we can use this method to prevent the client from trying + to send messages to the killed processes. """ if self.alive: libplasma.disconnect(self.conn) @@ -169,8 +182,8 @@ class PlasmaClient(object): def get_metadata(self, object_ids, timeout_ms=-1): """Create a buffer from the PlasmaStore based on object ID. - If the object has not been sealed yet, this call will block until the object - has been sealed. The retrieved buffer is immutable. + If the object has not been sealed yet, this call will block until the + object has been sealed. The retrieved buffer is immutable. Args: object_ids (List[str]): A list of strings used to identify some objects. @@ -275,7 +288,8 @@ class PlasmaClient(object): # currently crashes if given duplicate object IDs. if len(object_ids) != len(set(object_ids)): raise Exception("Wait requires a list of unique object IDs.") - ready_ids, waiting_ids = libplasma.wait(self.conn, object_ids, timeout, num_returns) + ready_ids, waiting_ids = libplasma.wait(self.conn, object_ids, timeout, + num_returns) return ready_ids, list(waiting_ids) def subscribe(self): @@ -286,11 +300,14 @@ class PlasmaClient(object): """Get the next notification from the notification socket.""" return libplasma.receive_notification(self.notification_fd) + DEFAULT_PLASMA_STORE_MEMORY = 10 ** 9 + def random_name(): return str(random.randint(0, 99999999)) + def start_plasma_store(plasma_store_memory=DEFAULT_PLASMA_STORE_MEMORY, use_valgrind=False, use_profiler=False, stdout_file=None, stderr_file=None): @@ -312,16 +329,20 @@ def start_plasma_store(plasma_store_memory=DEFAULT_PLASMA_STORE_MEMORY, """ if use_valgrind and use_profiler: raise Exception("Cannot use valgrind and profiler at the same time.") - plasma_store_executable = os.path.join(os.path.abspath(os.path.dirname(__file__)), "../core/src/plasma/plasma_store") + plasma_store_executable = os.path.join(os.path.abspath( + os.path.dirname(__file__)), + "../core/src/plasma/plasma_store") plasma_store_name = "/tmp/plasma_store{}".format(random_name()) - command = [plasma_store_executable, "-s", plasma_store_name, "-m", str(plasma_store_memory)] + command = [plasma_store_executable, + "-s", plasma_store_name, + "-m", str(plasma_store_memory)] if use_valgrind: pid = subprocess.Popen(["valgrind", "--track-origins=yes", "--leak-check=full", "--show-leak-kinds=all", "--error-exitcode=1"] + command, - stdout=stdout_file, stderr=stderr_file) + stdout=stdout_file, stderr=stderr_file) time.sleep(1.0) elif use_profiler: pid = subprocess.Popen(["valgrind", "--tool=callgrind"] + command, @@ -332,13 +353,16 @@ def start_plasma_store(plasma_store_memory=DEFAULT_PLASMA_STORE_MEMORY, time.sleep(0.1) return plasma_store_name, pid + def new_port(): return random.randint(10000, 65535) -def start_plasma_manager(store_name, redis_address, node_ip_address="127.0.0.1", - plasma_manager_port=None, num_retries=20, - use_valgrind=False, run_profiler=False, - stdout_file=None, stderr_file=None): + +def start_plasma_manager(store_name, redis_address, + node_ip_address="127.0.0.1", plasma_manager_port=None, + num_retries=20, use_valgrind=False, + run_profiler=False, stdout_file=None, + stderr_file=None): """Start a plasma manager and return the ports it listens on. Args: @@ -361,7 +385,9 @@ def start_plasma_manager(store_name, redis_address, node_ip_address="127.0.0.1", Raises: Exception: An exception is raised if the manager could not be started. """ - plasma_manager_executable = os.path.join(os.path.abspath(os.path.dirname(__file__)), "../core/src/plasma/plasma_manager") + plasma_manager_executable = os.path.join( + os.path.abspath(os.path.dirname(__file__)), + "../core/src/plasma/plasma_manager") plasma_manager_name = "/tmp/plasma_manager{}".format(random_name()) if plasma_manager_port is not None: if num_retries != 1: @@ -385,7 +411,7 @@ def start_plasma_manager(store_name, redis_address, node_ip_address="127.0.0.1", "--leak-check=full", "--show-leak-kinds=all", "--error-exitcode=1"] + command, - stdout=stdout_file, stderr=stderr_file) + stdout=stdout_file, stderr=stderr_file) elif run_profiler: process = subprocess.Popen(["valgrind", "--tool=callgrind"] + command, stdout=stdout_file, stderr=stderr_file) @@ -396,7 +422,7 @@ def start_plasma_manager(store_name, redis_address, node_ip_address="127.0.0.1", # port is already in use, then we need it to fail within 0.1 seconds. time.sleep(0.1) # See if the process has terminated - if process.poll() == None: + if process.poll() is None: return plasma_manager_name, process, plasma_manager_port # Generate a new port and try again. plasma_manager_port = new_port() diff --git a/python/ray/plasma/test/test.py b/python/ray/plasma/test/test.py index a7c833057..bc24e4e93 100644 --- a/python/ray/plasma/test/test.py +++ b/python/ray/plasma/test/test.py @@ -6,23 +6,22 @@ import numpy as np import os import random import signal -import socket -import struct -import subprocess import sys -import tempfile import threading import time import unittest import ray.plasma as plasma -from ray.plasma.utils import random_object_id, generate_metadata, write_to_data_buffer, create_object_with_id, create_object +from ray.plasma.utils import (random_object_id, generate_metadata, + create_object_with_id, create_object) from ray import services USE_VALGRIND = False PLASMA_STORE_MEMORY = 1000000000 -def assert_get_object_equal(unit_test, client1, client2, object_id, memory_buffer=None, metadata=None): + +def assert_get_object_equal(unit_test, client1, client2, object_id, + memory_buffer=None, metadata=None): client1_buff = client1.get([object_id])[0] client2_buff = client2.get([object_id])[0] client1_metadata = client1.get_metadata([object_id])[0] @@ -32,7 +31,8 @@ def assert_get_object_equal(unit_test, client1, client2, object_id, memory_buffe # Check that the buffers from the two clients are the same. unit_test.assertTrue(plasma.buffers_equal(client1_buff, client2_buff)) # Check that the metadata buffers from the two clients are the same. - unit_test.assertTrue(plasma.buffers_equal(client1_metadata, client2_metadata)) + unit_test.assertTrue(plasma.buffers_equal(client1_metadata, + client2_metadata)) # If a reference buffer was provided, check that it is the same as well. if memory_buffer is not None: unit_test.assertTrue(plasma.buffers_equal(memory_buffer, client1_buff)) @@ -40,11 +40,13 @@ def assert_get_object_equal(unit_test, client1, client2, object_id, memory_buffe if metadata is not None: unit_test.assertTrue(plasma.buffers_equal(metadata, client1_metadata)) + class TestPlasmaClient(unittest.TestCase): def setUp(self): # Start Plasma store. - plasma_store_name, self.p = plasma.start_plasma_store(use_valgrind=USE_VALGRIND) + plasma_store_name, self.p = plasma.start_plasma_store( + use_valgrind=USE_VALGRIND) # Connect to Plasma. self.plasma_client = plasma.PlasmaClient(plasma_store_name, None, 64) # For the eviction test @@ -107,7 +109,7 @@ class TestPlasmaClient(unittest.TestCase): object_id = random_object_id() self.plasma_client.create(object_id, length, generate_metadata(length)) try: - val = self.plasma_client.create(object_id, length, generate_metadata(length)) + self.plasma_client.create(object_id, length, generate_metadata(length)) except plasma.plasma_object_exists_error as e: pass else: @@ -125,20 +127,24 @@ class TestPlasmaClient(unittest.TestCase): metadata_buffers = [] for i in range(num_object_ids): if i % 2 == 0: - data_buffer, metadata_buffer = create_object_with_id(self.plasma_client, object_ids[i], 2000, 2000) + data_buffer, metadata_buffer = create_object_with_id( + self.plasma_client, object_ids[i], 2000, 2000) data_buffers.append(data_buffer) metadata_buffers.append(metadata_buffer) # Test timing out from some but not all get calls with various timeouts. for timeout in [0, 10, 100, 1000]: data_results = self.plasma_client.get(object_ids, timeout_ms=timeout) - metadata_results = self.plasma_client.get(object_ids, timeout_ms=timeout) + # metadata_results = self.plasma_client.get_metadata(object_ids, + # timeout_ms=timeout) for i in range(num_object_ids): if i % 2 == 0: - self.assertTrue(plasma.buffers_equal(data_buffers[i // 2], data_results[i])) - # TODO(rkn): We should compare the metadata as well. But currently the - # types are different (e.g., memoryview versus bytearray). - # self.assertTrue(plasma.buffers_equal(metadata_buffers[i // 2], metadata_results[i])) + self.assertTrue(plasma.buffers_equal(data_buffers[i // 2], + data_results[i])) + # TODO(rkn): We should compare the metadata as well. But currently + # the types are different (e.g., memoryview versus bytearray). + # self.assertTrue(plasma.buffers_equal(metadata_buffers[i // 2], + # metadata_results[i])) else: self.assertIsNone(results[i]) @@ -148,7 +154,9 @@ class TestPlasmaClient(unittest.TestCase): def assert_create_raises_plasma_full(unit_test, size): partial_size = np.random.randint(size) try: - _, memory_buffer, _ = create_object(unit_test.plasma_client, partial_size, size - partial_size) + _, memory_buffer, _ = create_object(unit_test.plasma_client, + partial_size, + size - partial_size) except plasma.plasma_out_of_memory_error as e: pass else: @@ -215,7 +223,7 @@ class TestPlasmaClient(unittest.TestCase): real_object_ids = [random_object_id() for _ in range(100)] for object_id in real_object_ids: self.assertFalse(self.plasma_client.contains(object_id)) - memory_buffer = self.plasma_client.create(object_id, 100) + self.plasma_client.create(object_id, 100) self.plasma_client.seal(object_id) self.assertTrue(self.plasma_client.contains(object_id)) for object_id in fake_object_ids: @@ -226,7 +234,7 @@ class TestPlasmaClient(unittest.TestCase): def test_hash(self): # Check the hash of an object that doesn't exist. object_id1 = random_object_id() - h = self.plasma_client.hash(object_id1) + self.plasma_client.hash(object_id1) length = 1000 # Create a random object, and check that the hash function always returns @@ -357,7 +365,7 @@ class TestPlasmaClient(unittest.TestCase): length = 1000 memory_buffer = self.plasma_client.create(object_id, length) # Make sure we cannot access memory out of bounds. - self.assertRaises(Exception, lambda : memory_buffer[length]) + self.assertRaises(Exception, lambda: memory_buffer[length]) # Seal the object. self.plasma_client.seal(object_id) # This test is commented out because it currently fails. @@ -367,6 +375,7 @@ class TestPlasmaClient(unittest.TestCase): # self.assertRaises(Exception, illegal_assignment) # Get the object. memory_buffer = self.plasma_client.get([object_id])[0] + # Make sure the object is read only. def illegal_assignment(): memory_buffer[0] = chr(0) @@ -413,18 +422,20 @@ class TestPlasmaClient(unittest.TestCase): def test_subscribe(self): # Subscribe to notifications from the Plasma Store. - sock = self.plasma_client.subscribe() + self.plasma_client.subscribe() for i in [1, 10, 100, 1000, 10000, 100000]: object_ids = [random_object_id() for _ in range(i)] metadata_sizes = [np.random.randint(1000) for _ in range(i)] data_sizes = [np.random.randint(1000) for _ in range(i)] for j in range(i): - self.plasma_client.create(object_ids[j], size=data_sizes[j], - metadata=bytearray(np.random.bytes(metadata_sizes[j]))) + self.plasma_client.create( + object_ids[j], size=data_sizes[j], + metadata=bytearray(np.random.bytes(metadata_sizes[j]))) self.plasma_client.seal(object_ids[j]) # Check that we received notifications for all of the objects. for j in range(i): - recv_objid, recv_dsize, recv_msize = self.plasma_client.get_next_notification() + notification_info = self.plasma_client.get_next_notification() + recv_objid, recv_dsize, recv_msize = notification_info self.assertEqual(object_ids[j], recv_objid) self.assertEqual(data_sizes[j], recv_dsize) self.assertEqual(metadata_sizes[j], recv_msize) @@ -432,20 +443,22 @@ class TestPlasmaClient(unittest.TestCase): def test_subscribe_deletions(self): # Subscribe to notifications from the Plasma Store. We use plasma_client2 # to make sure that all used objects will get evicted properly. - sock = self.plasma_client2.subscribe() + self.plasma_client2.subscribe() for i in [1, 10, 100, 1000, 10000, 100000]: object_ids = [random_object_id() for _ in range(i)] # Add 1 to the sizes to make sure we have nonzero object sizes. metadata_sizes = [np.random.randint(1000) + 1 for _ in range(i)] data_sizes = [np.random.randint(1000) + 1 for _ in range(i)] for j in range(i): - x = self.plasma_client2.create(object_ids[j], size=data_sizes[j], - metadata=bytearray(np.random.bytes(metadata_sizes[j]))) + x = self.plasma_client2.create( + object_ids[j], size=data_sizes[j], + metadata=bytearray(np.random.bytes(metadata_sizes[j]))) self.plasma_client2.seal(object_ids[j]) del x # Check that we received notifications for creating all of the objects. for j in range(i): - recv_objid, recv_dsize, recv_msize = self.plasma_client2.get_next_notification() + notification_info = self.plasma_client2.get_next_notification() + recv_objid, recv_dsize, recv_msize = notification_info self.assertEqual(object_ids[j], recv_objid) self.assertEqual(data_sizes[j], recv_dsize) self.assertEqual(metadata_sizes[j], recv_msize) @@ -453,8 +466,10 @@ class TestPlasmaClient(unittest.TestCase): # Check that we receive notifications for deleting all objects, as we # evict them. for j in range(i): - self.assertEqual(self.plasma_client2.evict(1), data_sizes[j] + metadata_sizes[j]) - recv_objid, recv_dsize, recv_msize = self.plasma_client2.get_next_notification() + self.assertEqual(self.plasma_client2.evict(1), + data_sizes[j] + metadata_sizes[j]) + notification_info = self.plasma_client2.get_next_notification() + recv_objid, recv_dsize, recv_msize = notification_info self.assertEqual(object_ids[j], recv_objid) self.assertEqual(-1, recv_dsize) self.assertEqual(-1, recv_msize) @@ -469,18 +484,22 @@ class TestPlasmaClient(unittest.TestCase): metadata_sizes.append(np.random.randint(1000)) data_sizes.append(np.random.randint(1000)) for i in range(num_object_ids): - x = self.plasma_client2.create(object_ids[i], size=data_sizes[i], - metadata=bytearray(np.random.bytes(metadata_sizes[i]))) + x = self.plasma_client2.create( + object_ids[i], size=data_sizes[i], + metadata=bytearray(np.random.bytes(metadata_sizes[i]))) self.plasma_client2.seal(object_ids[i]) del x for i in range(num_object_ids): - recv_objid, recv_dsize, recv_msize = self.plasma_client2.get_next_notification() + notification_info = self.plasma_client2.get_next_notification() + recv_objid, recv_dsize, recv_msize = notification_info self.assertEqual(object_ids[i], recv_objid) self.assertEqual(data_sizes[i], recv_dsize) self.assertEqual(metadata_sizes[i], recv_msize) - self.assertEqual(self.plasma_client2.evict(1), data_sizes[-1] + metadata_sizes[-1]) + self.assertEqual(self.plasma_client2.evict(1), + data_sizes[-1] + metadata_sizes[-1]) for i in range(num_object_ids): - recv_objid, recv_dsize, recv_msize = self.plasma_client2.get_next_notification() + notification_info = self.plasma_client2.get_next_notification() + recv_objid, recv_dsize, recv_msize = notification_info self.assertEqual(object_ids[i], recv_objid) self.assertEqual(-1, recv_dsize) self.assertEqual(-1, recv_msize) @@ -495,8 +514,10 @@ class TestPlasmaManager(unittest.TestCase): # Start a Redis server. redis_address = services.address("127.0.0.1", services.start_redis()[0]) # Start two PlasmaManagers. - manager_name1, self.p4, self.port1 = plasma.start_plasma_manager(store_name1, redis_address, use_valgrind=USE_VALGRIND) - manager_name2, self.p5, self.port2 = plasma.start_plasma_manager(store_name2, redis_address, use_valgrind=USE_VALGRIND) + manager_name1, self.p4, self.port1 = plasma.start_plasma_manager( + store_name1, redis_address, use_valgrind=USE_VALGRIND) + manager_name2, self.p5, self.port2 = plasma.start_plasma_manager( + store_name2, redis_address, use_valgrind=USE_VALGRIND) # Connect two PlasmaClients. self.client1 = plasma.PlasmaClient(store_name1, manager_name1) self.client2 = plasma.PlasmaClient(store_name2, manager_name2) @@ -513,7 +534,8 @@ class TestPlasmaManager(unittest.TestCase): # Kill the Plasma store and Plasma manager processes. if USE_VALGRIND: - time.sleep(1) # give processes opportunity to finish work + # Give processes opportunity to finish work. + time.sleep(1) for process in self.processes_to_kill: process.send_signal(signal.SIGTERM) process.wait() @@ -530,7 +552,8 @@ class TestPlasmaManager(unittest.TestCase): def test_fetch(self): for _ in range(10): # Create an object. - object_id1, memory_buffer1, metadata1 = create_object(self.client1, 2000, 2000) + object_id1, memory_buffer1, metadata1 = create_object(self.client1, 2000, + 2000) self.client1.fetch([object_id1]) self.assertEqual(self.client1.contains(object_id1), True) self.assertEqual(self.client2.contains(object_id1), False) @@ -546,7 +569,8 @@ class TestPlasmaManager(unittest.TestCase): object_id2 = random_object_id() self.client1.fetch([object_id2]) self.assertEqual(self.client1.contains(object_id2), False) - memory_buffer2, metadata2 = create_object_with_id(self.client2, object_id2, 2000, 2000) + memory_buffer2, metadata2 = create_object_with_id(self.client2, object_id2, + 2000, 2000) # # Check that the object has been fetched. # self.assertEqual(self.client1.contains(object_id2), True) # Compare the two buffers. @@ -560,11 +584,12 @@ class TestPlasmaManager(unittest.TestCase): for _ in range(10): self.client1.fetch([object_id3]) self.client2.fetch([object_id3]) - memory_buffer3, metadata3 = create_object_with_id(self.client1, object_id3, 2000, 2000) + memory_buffer3, metadata3 = create_object_with_id(self.client1, object_id3, + 2000, 2000) for _ in range(10): self.client1.fetch([object_id3]) self.client2.fetch([object_id3]) - #TODO(rkn): Right now we must wait for the object table to be updated. + # TODO(rkn): Right now we must wait for the object table to be updated. while not self.client2.contains(object_id3): self.client2.fetch([object_id3]) assert_get_object_equal(self, self.client1, self.client2, object_id3, @@ -573,14 +598,17 @@ class TestPlasmaManager(unittest.TestCase): def test_fetch_multiple(self): for _ in range(20): # Create two objects and a third fake one that doesn't exist. - object_id1, memory_buffer1, metadata1 = create_object(self.client1, 2000, 2000) + object_id1, memory_buffer1, metadata1 = create_object(self.client1, 2000, + 2000) missing_object_id = random_object_id() - object_id2, memory_buffer2, metadata2 = create_object(self.client1, 2000, 2000) + object_id2, memory_buffer2, metadata2 = create_object(self.client1, 2000, + 2000) object_ids = [object_id1, missing_object_id, object_id2] # Fetch the objects from the other plasma store. The second object ID # should timeout since it does not exist. # TODO(rkn): Right now we must wait for the object table to be updated. - while (not self.client2.contains(object_id1)) or (not self.client2.contains(object_id2)): + while ((not self.client2.contains(object_id1)) or + (not self.client2.contains(object_id2))): self.client2.fetch(object_ids) # Compare the buffers of the objects that do exist. assert_get_object_equal(self, self.client1, self.client2, object_id1, @@ -597,7 +625,8 @@ class TestPlasmaManager(unittest.TestCase): # Check that we can call fetch with duplicated object IDs. object_id3 = random_object_id() self.client1.fetch([object_id3, object_id3]) - object_id4, memory_buffer4, metadata4 = create_object(self.client1, 2000, 2000) + object_id4, memory_buffer4, metadata4 = create_object(self.client1, 2000, + 2000) time.sleep(0.1) # TODO(rkn): Right now we must wait for the object table to be updated. while not self.client2.contains(object_id4): @@ -623,7 +652,8 @@ class TestPlasmaManager(unittest.TestCase): obj_id2 = random_object_id() self.client1.create(obj_id2, 1000) # Don't seal. - ready, waiting = self.client1.wait([obj_id2, obj_id1], timeout=100, num_returns=1) + ready, waiting = self.client1.wait([obj_id2, obj_id1], timeout=100, + num_returns=1) self.assertEqual(set(ready), set([obj_id1])) self.assertEqual(set(waiting), set([obj_id2])) @@ -636,11 +666,13 @@ class TestPlasmaManager(unittest.TestCase): t = threading.Timer(0.1, finish) t.start() - ready, waiting = self.client1.wait([obj_id3, obj_id2, obj_id1], timeout=1000, num_returns=2) + ready, waiting = self.client1.wait([obj_id3, obj_id2, obj_id1], + timeout=1000, num_returns=2) self.assertEqual(set(ready), set([obj_id1, obj_id3])) self.assertEqual(set(waiting), set([obj_id2])) - # Test if the appropriate number of objects is shown if some objects are not ready + # Test if the appropriate number of objects is shown if some objects are + # not ready. ready, waiting = self.client1.wait([obj_id3, obj_id2, obj_id1], 100, 3) self.assertEqual(set(ready), set([obj_id1, obj_id3])) self.assertEqual(set(waiting), set([obj_id2])) @@ -651,9 +683,9 @@ class TestPlasmaManager(unittest.TestCase): # Test calling wait a bunch of times. object_ids = [] # TODO(rkn): Increasing n to 100 (or larger) will cause failures. The - # problem appears to be that the number of timers added to the manager event - # loop slow down the manager so much that some of the asynchronous Redis - # commands timeout triggering fatal failure callbacks. + # problem appears to be that the number of timers added to the manager + # event loop slow down the manager so much that some of the asynchronous + # Redis commands timeout triggering fatal failure callbacks. n = 40 for i in range(n * (n + 1) // 2): if i % 2 == 0: @@ -669,7 +701,8 @@ class TestPlasmaManager(unittest.TestCase): self.assertEqual(len(ready), i) retrieved += ready self.assertEqual(set(retrieved), set(object_ids)) - ready, waiting = self.client1.wait(object_ids, timeout=1000, num_returns=len(object_ids)) + ready, waiting = self.client1.wait(object_ids, timeout=1000, + num_returns=len(object_ids)) self.assertEqual(set(ready), set(object_ids)) self.assertEqual(waiting, []) # Try waiting for all of the object IDs on the second client. @@ -680,7 +713,8 @@ class TestPlasmaManager(unittest.TestCase): self.assertEqual(len(ready), i) retrieved += ready self.assertEqual(set(retrieved), set(object_ids)) - ready, waiting = self.client2.wait(object_ids, timeout=1000, num_returns=len(object_ids)) + ready, waiting = self.client2.wait(object_ids, timeout=1000, + num_returns=len(object_ids)) self.assertEqual(set(ready), set(object_ids)) self.assertEqual(waiting, []) @@ -702,7 +736,8 @@ class TestPlasmaManager(unittest.TestCase): num_attempts = 100 for _ in range(100): # Create an object. - object_id1, memory_buffer1, metadata1 = create_object(self.client1, 2000, 2000) + object_id1, memory_buffer1, metadata1 = create_object(self.client1, 2000, + 2000) # Transfer the buffer to the the other Plasma store. There is a race # condition on the create and transfer of the object, so keep trying # until the object appears on the second Plasma store. @@ -721,10 +756,12 @@ class TestPlasmaManager(unittest.TestCase): # self.client1.transfer("127.0.0.1", self.port2, object_id1) # # Compare the two buffers. # assert_get_object_equal(self, self.client1, self.client2, object_id1, - # memory_buffer=memory_buffer1, metadata=metadata1) + # memory_buffer=memory_buffer1, + # metadata=metadata1) # Create an object. - object_id2, memory_buffer2, metadata2 = create_object(self.client2, 20000, 20000) + object_id2, memory_buffer2, metadata2 = create_object(self.client2, + 20000, 20000) # Transfer the buffer to the the other Plasma store. There is a race # condition on the create and transfer of the object, so keep trying # until the object appears on the second Plasma store. @@ -742,17 +779,19 @@ class TestPlasmaManager(unittest.TestCase): def test_illegal_functionality(self): # Create an object id string. - object_id = random_object_id() + # object_id = random_object_id() # Create a new buffer. # memory_buffer = self.client1.create(object_id, 20000) # This test is commented out because it currently fails. # # Transferring the buffer before sealing it should fail. - # self.assertRaises(Exception, lambda : self.manager1.transfer(1, object_id)) + # self.assertRaises(Exception, + # lambda : self.manager1.transfer(1, object_id)) + pass def test_stresstest(self): a = time.time() object_ids = [] - for i in range(10000): # TODO(pcm): increase this to 100000 + for i in range(10000): # TODO(pcm): increase this to 100000. object_id = random_object_id() object_ids.append(object_id) self.client1.create(object_id, 1) @@ -763,13 +802,16 @@ class TestPlasmaManager(unittest.TestCase): print("it took", b, "seconds to put and transfer the objects") + class TestPlasmaManagerRecovery(unittest.TestCase): def setUp(self): # Start a Plasma store. - self.store_name, self.p2 = plasma.start_plasma_store(use_valgrind=USE_VALGRIND) + self.store_name, self.p2 = plasma.start_plasma_store( + use_valgrind=USE_VALGRIND) # Start a Redis server. - self.redis_address = services.address("127.0.0.1", services.start_redis()[0]) + self.redis_address = services.address("127.0.0.1", + services.start_redis()[0]) # Start a PlasmaManagers. manager_name, self.p3, self.port1 = plasma.start_plasma_manager( self.store_name, @@ -791,7 +833,8 @@ class TestPlasmaManagerRecovery(unittest.TestCase): # Kill the Plasma store and Plasma manager processes. if USE_VALGRIND: - time.sleep(1) # give processes opportunity to finish work + # Give processes opportunity to finish work. + time.sleep(1) for process in self.processes_to_kill: process.send_signal(signal.SIGTERM) process.wait() @@ -818,23 +861,26 @@ class TestPlasmaManagerRecovery(unittest.TestCase): self.assertEqual(waiting, []) # Start a second plasma manager attached to the same store. - manager_name, self.p5, self.port2 = plasma.start_plasma_manager(self.store_name, self.redis_address, use_valgrind=USE_VALGRIND) + manager_name, self.p5, self.port2 = plasma.start_plasma_manager( + self.store_name, self.redis_address, use_valgrind=USE_VALGRIND) self.processes_to_kill = [self.p5] + self.processes_to_kill # Check that the second manager knows about existing objects. client2 = plasma.PlasmaClient(self.store_name, manager_name) ready, waiting = [], object_ids while True: - ready, waiting = client2.wait(object_ids, num_returns=num_objects, timeout=0) + ready, waiting = client2.wait(object_ids, num_returns=num_objects, + timeout=0) if len(ready) == len(object_ids): break self.assertEqual(set(ready), set(object_ids)) self.assertEqual(waiting, []) + if __name__ == "__main__": if len(sys.argv) > 1: - # pop the argument so we don't mess with unittest's own argument parser + # Pop the argument so we don't mess with unittest's own argument parser. if sys.argv[-1] == "valgrind": arg = sys.argv.pop() USE_VALGRIND = True diff --git a/python/ray/plasma/utils.py b/python/ray/plasma/utils.py index a6719dde0..e2b63f2de 100644 --- a/python/ray/plasma/utils.py +++ b/python/ray/plasma/utils.py @@ -5,9 +5,11 @@ from __future__ import print_function import numpy as np import random + def random_object_id(): return np.random.bytes(20) + def generate_metadata(length): metadata_buffer = bytearray(length) if length > 0: @@ -17,6 +19,7 @@ def generate_metadata(length): metadata_buffer[random.randint(0, length - 1)] = random.randint(0, 255) return metadata_buffer + def write_to_data_buffer(buff, length): if length > 0: buff[0] = chr(random.randint(0, 255)) @@ -24,7 +27,9 @@ def write_to_data_buffer(buff, length): for _ in range(100): buff[random.randint(0, length - 1)] = chr(random.randint(0, 255)) -def create_object_with_id(client, object_id, data_size, metadata_size, seal=True): + +def create_object_with_id(client, object_id, data_size, metadata_size, + seal=True): metadata = generate_metadata(metadata_size) memory_buffer = client.create(object_id, data_size, metadata) write_to_data_buffer(memory_buffer, data_size) @@ -32,7 +37,9 @@ def create_object_with_id(client, object_id, data_size, metadata_size, seal=True client.seal(object_id) return memory_buffer, metadata + def create_object(client, data_size, metadata_size, seal=True): object_id = random_object_id() - memory_buffer, metadata = create_object_with_id(client, object_id, data_size, metadata_size, seal=seal) + memory_buffer, metadata = create_object_with_id(client, object_id, data_size, + metadata_size, seal=seal) return object_id, memory_buffer, metadata diff --git a/python/ray/serialization.py b/python/ray/serialization.py index 1a707b269..dda62a60e 100644 --- a/python/ray/serialization.py +++ b/python/ray/serialization.py @@ -7,6 +7,7 @@ import numpy as np import ray.numbuf import ray.pickling as pickling + def check_serializable(cls): """Throws an exception if Ray cannot serialize this class efficiently. @@ -21,15 +22,30 @@ def check_serializable(cls): # This case works. return if not hasattr(cls, "__new__"): - raise Exception("The class {} does not have a '__new__' attribute, and is probably an old-style class. We do not support this. Please either make it a new-style class by inheriting from 'object', or use 'ray.register_class(cls, pickle=True)'. However, note that pickle is inefficient.".format(cls)) + raise Exception("The class {} does not have a '__new__' attribute, and is " + "probably an old-style class. We do not support this. " + "Please either make it a new-style class by inheriting " + "from 'object', or use " + "'ray.register_class(cls, pickle=True)'. However, note " + "that pickle is inefficient.".format(cls)) try: obj = cls.__new__(cls) except: - raise Exception("The class {} has overridden '__new__', so Ray may not be able to serialize it efficiently. Try using 'ray.register_class(cls, pickle=True)'. However, note that pickle is inefficient.".format(cls)) + raise Exception("The class {} has overridden '__new__', so Ray may not be " + "able to serialize it efficiently. Try using " + "'ray.register_class(cls, pickle=True)'. However, note " + "that pickle is inefficient.".format(cls)) if not hasattr(obj, "__dict__"): - raise Exception("Objects of the class {} do not have a `__dict__` attribute, so Ray cannot serialize it efficiently. Try using 'ray.register_class(cls, pickle=True)'. However, note that pickle is inefficient.".format(cls)) + raise Exception("Objects of the class {} do not have a `__dict__` " + "attribute, so Ray cannot serialize it efficiently. Try " + "using 'ray.register_class(cls, pickle=True)'. However, " + "note that pickle is inefficient.".format(cls)) if hasattr(obj, "__slots__"): - raise Exception("The class {} uses '__slots__', so Ray may not be able to serialize it efficiently. Try using 'ray.register_class(cls, pickle=True)'. However, note that pickle is inefficient.".format(cls)) + raise Exception("The class {} uses '__slots__', so Ray may not be able to " + "serialize it efficiently. Try using " + "'ray.register_class(cls, pickle=True)'. However, note " + "that pickle is inefficient.".format(cls)) + # This field keeps track of a whitelisted set of classes that Ray will # serialize. @@ -38,10 +54,12 @@ classes_to_pickle = set() custom_serializers = {} custom_deserializers = {} + def class_identifier(typ): """Return a string that identifies this type.""" return "{}.{}".format(typ.__module__, typ.__name__) + def is_named_tuple(cls): """Return True if cls is a namedtuple and False otherwise.""" b = cls.__bases__ @@ -52,7 +70,9 @@ def is_named_tuple(cls): return False return all(type(n) == str for n in f) -def add_class_to_whitelist(cls, pickle=False, custom_serializer=None, custom_deserializer=None): + +def add_class_to_whitelist(cls, pickle=False, custom_serializer=None, + custom_deserializer=None): """Add cls to the list of classes that we can serialize. Args: @@ -72,13 +92,21 @@ def add_class_to_whitelist(cls, pickle=False, custom_serializer=None, custom_des custom_serializers[class_id] = custom_serializer custom_deserializers[class_id] = custom_deserializer + # Here we define a custom serializer and deserializer for handling numpy # arrays that contain objects. def array_custom_serializer(obj): return obj.tolist(), obj.dtype.str + + def array_custom_deserializer(serialized_obj): return np.array(serialized_obj[0], dtype=np.dtype(serialized_obj[1])) -add_class_to_whitelist(np.ndarray, pickle=False, custom_serializer=array_custom_serializer, custom_deserializer=array_custom_deserializer) + + +add_class_to_whitelist(np.ndarray, pickle=False, + custom_serializer=array_custom_serializer, + custom_deserializer=array_custom_deserializer) + def serialize(obj): """This is the callback that will be used by numbuf. @@ -94,7 +122,9 @@ def serialize(obj): """ class_id = class_identifier(type(obj)) if class_id not in whitelisted_classes: - raise Exception("Ray does not know how to serialize objects of type {}. To fix this, call 'ray.register_class' with this class.".format(type(obj))) + raise Exception("Ray does not know how to serialize objects of type {}. " + "To fix this, call 'ray.register_class' with this class." + .format(type(obj))) if class_id in classes_to_pickle: serialized_obj = {"data": pickling.dumps(obj)} elif class_id in custom_serializers.keys(): @@ -107,10 +137,12 @@ def serialize(obj): elif hasattr(obj, "__dict__"): serialized_obj = obj.__dict__ else: - raise Exception("We do not know how to serialize the object '{}'".format(obj)) + raise Exception("We do not know how to serialize the object '{}'" + .format(obj)) result = dict(serialized_obj, **{"_pytype_": class_id}) return result + def deserialize(serialized_obj): """This is the callback that will be used by numbuf. @@ -139,11 +171,13 @@ def deserialize(serialized_obj): obj.__dict__.update(serialized_obj) return obj + def set_callbacks(): """Register the custom callbacks with numbuf. The serialize callback is used to serialize objects that numbuf does not know - how to serialize (for example custom Python classes). The deserialize callback - is used to serialize objects that were serialized by the serialize callback. + how to serialize (for example custom Python classes). The deserialize + callback is used to serialize objects that were serialized by the serialize + callback. """ ray.numbuf.register_callbacks(serialize, deserialize) diff --git a/python/ray/services.py b/python/ray/services.py index ae988586d..2484c0dcd 100644 --- a/python/ray/services.py +++ b/python/ray/services.py @@ -10,7 +10,6 @@ import random import redis import signal import socket -import string import subprocess import sys import time @@ -60,16 +59,20 @@ ObjectStoreAddress = namedtuple("ObjectStoreAddress", ["name", "manager_name", "manager_port"]) + def address(ip_address, port): return ip_address + ":" + str(port) + def get_ip_address(address): try: ip_address = address.split(":")[0] except: - raise Exception("Unable to parse IP address from address {}".format(address)) + raise Exception("Unable to parse IP address from address " + "{}".format(address)) return ip_address + def get_port(address): try: port = int(address.split(":")[1]) @@ -77,12 +80,15 @@ def get_port(address): raise Exception("Unable to parse port from address {}".format(address)) return port + def new_port(): return random.randint(10000, 65535) + def random_name(): return str(random.randint(0, 99999999)) + def kill_process(p): """Kill a process. @@ -92,11 +98,15 @@ def kill_process(p): Returns: True if the process was killed successfully and false otherwise. """ - if p.poll() is not None: # process has already terminated + if p.poll() is not None: + # The process has already terminated. return True - if RUN_LOCAL_SCHEDULER_PROFILER or RUN_PLASMA_MANAGER_PROFILER or RUN_PLASMA_STORE_PROFILER: - os.kill(p.pid, signal.SIGINT) # Give process signal to write profiler data. - time.sleep(0.1) # Wait for profiling data to be written. + if any([RUN_LOCAL_SCHEDULER_PROFILER, RUN_PLASMA_MANAGER_PROFILER, + RUN_PLASMA_STORE_PROFILER]): + # Give process signal to write profiler data. + os.kill(p.pid, signal.SIGINT) + # Wait for profiling data to be written. + time.sleep(0.1) # Allow the process one second to exit gracefully. p.terminate() @@ -118,6 +128,7 @@ def kill_process(p): # The process was not killed for some reason. return False + def cleanup(): """When running in local mode, shutdown the Ray processes. @@ -138,6 +149,7 @@ def cleanup(): if not successfully_shut_down: print("Ray did not shut down properly.") + def all_processes_alive(exclude=[]): """Check if all of the processes are still alive. @@ -147,10 +159,13 @@ def all_processes_alive(exclude=[]): for process_type, processes in all_processes.items(): # Note that p.poll() returns the exit code that the process exited with, so # an exit code of None indicates that the process is still alive. - if not all([p.poll() is None for p in processes]) and process_type not in exclude: + processes_alive = [p.poll() is None for p in processes] + if (not all(processes_alive) and process_type not in exclude): + print("A process of type {} has dead.".format(process_type)) return False return True + def get_node_ip_address(address="8.8.8.8:53"): """Determine the IP address of the local node. @@ -166,6 +181,7 @@ def get_node_ip_address(address="8.8.8.8:53"): s.connect((ip_address, int(port))) return s.getsockname()[0] + def record_log_files_in_redis(redis_address, node_ip_address, log_files): """Record in Redis that a new log file has been created. @@ -187,6 +203,7 @@ def record_log_files_in_redis(redis_address, node_ip_address, log_files): log_file_list_key = "LOG_FILENAMES:{}".format(node_ip_address) redis_client.rpush(log_file_list_key, log_file.name) + def wait_for_redis_to_start(redis_ip_address, redis_port, num_retries=5): """Wait for a Redis server to be available. @@ -208,7 +225,8 @@ def wait_for_redis_to_start(redis_ip_address, redis_port, num_retries=5): while counter < num_retries: try: # Run some random command and see if it worked. - print("Waiting for redis server at {}:{} to respond...".format(redis_ip_address, redis_port)) + print("Waiting for redis server at {}:{} to respond..." + .format(redis_ip_address, redis_port)) redis_client.client_list() except redis.ConnectionError as e: # Wait a little bit. @@ -218,7 +236,10 @@ def wait_for_redis_to_start(redis_ip_address, redis_port, num_retries=5): else: break if counter == num_retries: - raise Exception("Unable to connect to Redis. If the Redis instance is on a different machine, check that your firewall is configured properly.") + raise Exception("Unable to connect to Redis. If the Redis instance is on " + "a different machine, check that your firewall is " + "configured properly.") + def start_redis(node_ip_address="127.0.0.1", port=None, num_retries=20, stdout_file=None, stderr_file=None, cleanup=True): @@ -239,13 +260,18 @@ def start_redis(node_ip_address="127.0.0.1", port=None, num_retries=20, Returns: A tuple of the port used by Redis and a handle to the process that was - started. If a port is passed in, then the returned port value is the same. + started. If a port is passed in, then the returned port value is the + same. Raises: Exception: An exception is raised if Redis could not be started. """ - redis_filepath = os.path.join(os.path.dirname(os.path.abspath(__file__)), "./core/src/common/thirdparty/redis/src/redis-server") - redis_module = os.path.join(os.path.dirname(os.path.abspath(__file__)), "./core/src/common/redis_module/libray_redis_module.so") + redis_filepath = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "./core/src/common/thirdparty/redis/src/redis-server") + redis_module = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "./core/src/common/redis_module/libray_redis_module.so") assert os.path.isfile(redis_filepath) assert os.path.isfile(redis_module) counter = 0 @@ -261,7 +287,7 @@ def start_redis(node_ip_address="127.0.0.1", port=None, num_retries=20, "--port", str(port), "--loglevel", "warning", "--loadmodule", redis_module], - stdout=stdout_file, stderr=stderr_file) + stdout=stdout_file, stderr=stderr_file) time.sleep(0.1) # Check if Redis successfully started (or at least if it the executable did # not exit within 0.1 seconds). @@ -291,6 +317,7 @@ def start_redis(node_ip_address="127.0.0.1", port=None, num_retries=20, [stdout_file, stderr_file]) return port, p + def start_log_monitor(redis_address, node_ip_address, stdout_file=None, stderr_file=None, cleanup=cleanup): """Start a log monitor process. @@ -307,7 +334,9 @@ def start_log_monitor(redis_address, node_ip_address, stdout_file=None, this process will be killed by services.cleanup() when the Python process that imported services exits. """ - log_monitor_filepath = os.path.join(os.path.dirname(os.path.abspath(__file__)), "log_monitor.py") + log_monitor_filepath = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "log_monitor.py") p = subprocess.Popen(["python", log_monitor_filepath, "--redis-address", redis_address, "--node-ip-address", node_ip_address], @@ -317,13 +346,15 @@ def start_log_monitor(redis_address, node_ip_address, stdout_file=None, record_log_files_in_redis(redis_address, node_ip_address, [stdout_file, stderr_file]) + def start_global_scheduler(redis_address, node_ip_address, stdout_file=None, stderr_file=None, cleanup=True): """Start a global scheduler process. Args: redis_address (str): The address of the Redis instance. - node_ip_address: The IP address of the node that this scheduler will run on. + node_ip_address: The IP address of the node that this scheduler will run + on. stdout_file: A file handle opened for writing to redirect stdout to. If no redirection should happen, then this should be None. stderr_file: A file handle opened for writing to redirect stderr to. If no @@ -340,6 +371,7 @@ def start_global_scheduler(redis_address, node_ip_address, stdout_file=None, record_log_files_in_redis(redis_address, node_ip_address, [stdout_file, stderr_file]) + def start_webui(redis_address, node_ip_address, backend_stdout_file=None, backend_stderr_file=None, polymer_stdout_file=None, polymer_stderr_file=None, cleanup=True): @@ -367,8 +399,11 @@ def start_webui(redis_address, node_ip_address, backend_stdout_file=None, Return: True if the web UI was successfully started, otherwise false. """ - webui_backend_filepath = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../webui/backend/ray_ui.py") - webui_directory = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../webui/") + webui_backend_filepath = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "../../webui/backend/ray_ui.py") + webui_directory = os.path.join(os.path.dirname(os.path.abspath(__file__)), + "../../webui/") if sys.version_info >= (3, 0): python_executable = "python" @@ -376,7 +411,8 @@ def start_webui(redis_address, node_ip_address, backend_stdout_file=None, # If the user is using Python 2, it is still possible to run the webserver # separately with Python 3, so try to find a Python 3 executable. try: - python_executable = subprocess.check_output(["which", "python3"]).decode("ascii").strip() + python_executable = subprocess.check_output( + ["which", "python3"]).decode("ascii").strip() except Exception as e: print("Not starting the web UI because the web UI requires Python 3.") return False @@ -394,8 +430,8 @@ def start_webui(redis_address, node_ip_address, backend_stdout_file=None, return False # Try to start polymer. If this fails, it may that port 8080 is already in - # use. It'd be nice to test for this, but doing so by calling "bind" may start - # using the port and prevent polymer from using it. + # use. It'd be nice to test for this, but doing so by calling "bind" may + # start using the port and prevent polymer from using it. try: polymer_process = subprocess.Popen(["polymer", "serve", "--port", "8080"], cwd=webui_directory, @@ -433,6 +469,7 @@ def start_webui(redis_address, node_ip_address, backend_stdout_file=None, return True + def start_local_scheduler(redis_address, node_ip_address, plasma_store_name, @@ -472,8 +509,8 @@ def start_local_scheduler(redis_address, The name of the local scheduler socket. """ if num_cpus is None: - # By default, use the number of hardware execution threads for the number of - # cores. + # By default, use the number of hardware execution threads for the number + # of cores. num_cpus = multiprocessing.cpu_count() if num_gpus is None: # By default, assume this node has no GPUs. @@ -496,6 +533,7 @@ def start_local_scheduler(redis_address, [stdout_file, stderr_file]) return local_scheduler_name + def start_objstore(node_ip_address, redis_address, object_manager_port=None, store_stdout_file=None, store_stderr_file=None, manager_stdout_file=None, manager_stderr_file=None, @@ -511,10 +549,10 @@ def start_objstore(node_ip_address, redis_address, object_manager_port=None, If no redirection should happen, then this should be None. store_stderr_file: A file handle opened for writing to redirect stderr to. If no redirection should happen, then this should be None. - manager_stdout_file: A file handle opened for writing to redirect stdout to. - If no redirection should happen, then this should be None. - manager_stderr_file: A file handle opened for writing to redirect stderr to. - If no redirection should happen, then this should be None. + manager_stdout_file: A file handle opened for writing to redirect stdout + to. If no redirection should happen, then this should be None. + manager_stderr_file: A file handle opened for writing to redirect stderr + to. If no redirection should happen, then this should be None. cleanup (bool): True if using Ray in local mode. If cleanup is true, then this process will be killed by serices.cleanup() when the Python process that imported services exits. @@ -522,8 +560,8 @@ def start_objstore(node_ip_address, redis_address, object_manager_port=None, with. Return: - A tuple of the Plasma store socket name, the Plasma manager socket name, and - the plasma manager port. + A tuple of the Plasma store socket name, the Plasma manager socket name, + and the plasma manager port. """ if objstore_memory is None: # Compute a fraction of the system memory for the Plasma store to use. @@ -559,7 +597,8 @@ def start_objstore(node_ip_address, redis_address, object_manager_port=None, stderr_file=store_stderr_file) # Start the plasma manager. if object_manager_port is not None: - plasma_manager_name, p2, plasma_manager_port = ray.plasma.start_plasma_manager( + (plasma_manager_name, p2, + plasma_manager_port) = ray.plasma.start_plasma_manager( plasma_store_name, redis_address, plasma_manager_port=object_manager_port, @@ -570,7 +609,8 @@ def start_objstore(node_ip_address, redis_address, object_manager_port=None, stderr_file=manager_stderr_file) assert plasma_manager_port == object_manager_port else: - plasma_manager_name, p2, plasma_manager_port = ray.plasma.start_plasma_manager( + (plasma_manager_name, p2, + plasma_manager_port) = ray.plasma.start_plasma_manager( plasma_store_name, redis_address, node_ip_address=node_ip_address, @@ -587,6 +627,7 @@ def start_objstore(node_ip_address, redis_address, object_manager_port=None, return ObjectStoreAddress(plasma_store_name, plasma_manager_name, plasma_manager_port) + def start_worker(node_ip_address, object_store_name, object_store_manager_name, local_scheduler_name, redis_address, worker_path, stdout_file=None, stderr_file=None, cleanup=True): @@ -599,8 +640,8 @@ def start_worker(node_ip_address, object_store_name, object_store_manager_name, object_store_manager_name (str): The name of the object store manager. local_scheduler_name (str): The name of the local scheduler. redis_address (str): The address that the Redis server is listening on. - worker_path (str): The path of the source code which the worker process will - run. + worker_path (str): The path of the source code which the worker process + will run. stdout_file: A file handle opened for writing to redirect stdout to. If no redirection should happen, then this should be None. stderr_file: A file handle opened for writing to redirect stderr to. If no @@ -622,6 +663,7 @@ def start_worker(node_ip_address, object_store_name, object_store_manager_name, record_log_files_in_redis(redis_address, node_ip_address, [stdout_file, stderr_file]) + def start_monitor(redis_address, node_ip_address, stdout_file=None, stderr_file=None, cleanup=True): """Run a process to monitor the other processes. @@ -637,7 +679,8 @@ def start_monitor(redis_address, node_ip_address, stdout_file=None, this process will be killed by services.cleanup() when the Python process that imported services exits. This is True by default. """ - monitor_path= os.path.join(os.path.dirname(os.path.abspath(__file__)), "monitor.py") + monitor_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), + "monitor.py") command = ["python", monitor_path, "--redis-address=" + str(redis_address)] @@ -647,6 +690,7 @@ def start_monitor(redis_address, node_ip_address, stdout_file=None, record_log_files_in_redis(redis_address, node_ip_address, [stdout_file, stderr_file]) + def start_ray_processes(address_info=None, node_ip_address="127.0.0.1", num_workers=0, @@ -685,8 +729,9 @@ def start_ray_processes(address_info=None, start a global scheduler process. include_redis (bool): If include_redis is True, then start a Redis server process. - include_log_monitor (bool): If True, then start a log monitor to monitor the - log files for all processes on this node and push their contents to Redis. + include_log_monitor (bool): If True, then start a log monitor to monitor + the log files for all processes on this node and push their contents to + Redis. include_webui (bool): If True, then attempt to start the web UI. Note that this is only possible with Python 3. start_workers_from_local_scheduler (bool): If this flag is True, then start @@ -713,7 +758,8 @@ def start_ray_processes(address_info=None, address_info["node_ip_address"] = node_ip_address if worker_path is None: - worker_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "workers/default_worker.py") + worker_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), + "workers/default_worker.py") # Start Redis if there isn't already an instance running. TODO(rkn): We are # suppressing the output of Redis because on Linux it prints a bunch of @@ -721,7 +767,8 @@ def start_ray_processes(address_info=None, # should address the warnings. redis_address = address_info.get("redis_address") if include_redis: - redis_stdout_file, redis_stderr_file = new_log_files("redis", redirect_output) + redis_stdout_file, redis_stderr_file = new_log_files("redis", + redirect_output) if redis_address is None: # Start a Redis server. The start_redis method will choose a random port. redis_port, _ = start_redis(node_ip_address, @@ -735,7 +782,6 @@ def start_ray_processes(address_info=None, # A Redis address was provided, so start a Redis server with the given # port. TODO(rkn): We should check that the IP address corresponds to the # machine that this method is running on. - redis_ip_address = get_ip_address(redis_address) redis_port = get_port(redis_address) new_redis_port, _ = start_redis(port=int(redis_port), num_retries=1, @@ -744,7 +790,8 @@ def start_ray_processes(address_info=None, cleanup=cleanup) assert redis_port == new_redis_port # Start monitoring the processes. - monitor_stdout_file, monitor_stderr_file = new_log_files("monitor", redirect_output) + monitor_stdout_file, monitor_stderr_file = new_log_files("monitor", + redirect_output) start_monitor(redis_address, node_ip_address, stdout_file=monitor_stdout_file, @@ -755,8 +802,8 @@ def start_ray_processes(address_info=None, # Start the log monitor, if necessary. if include_log_monitor: - log_monitor_stdout_file, log_monitor_stderr_file = new_log_files("log_monitor", - redirect_output=True) + log_monitor_stdout_file, log_monitor_stderr_file = new_log_files( + "log_monitor", redirect_output=True) start_log_monitor(redis_address, node_ip_address, stdout_file=log_monitor_stdout_file, @@ -765,7 +812,8 @@ def start_ray_processes(address_info=None, # Start the global scheduler, if necessary. if include_global_scheduler: - global_scheduler_stdout_file, global_scheduler_stderr_file = new_log_files("global_scheduler", redirect_output) + global_scheduler_stdout_file, global_scheduler_stderr_file = new_log_files( + "global_scheduler", redirect_output) start_global_scheduler(redis_address, node_ip_address, stdout_file=global_scheduler_stdout_file, @@ -781,7 +829,8 @@ def start_ray_processes(address_info=None, local_scheduler_socket_names = address_info["local_scheduler_socket_names"] # Get the ports to use for the object managers if any are provided. - object_manager_ports = address_info["object_manager_ports"] if "object_manager_ports" in address_info else None + object_manager_ports = (address_info["object_manager_ports"] + if "object_manager_ports" in address_info else None) if not isinstance(object_manager_ports, list): object_manager_ports = num_local_schedulers * [object_manager_ports] assert len(object_manager_ports) == num_local_schedulers @@ -789,23 +838,26 @@ def start_ray_processes(address_info=None, # Start any object stores that do not yet exist. for i in range(num_local_schedulers - len(object_store_addresses)): # Start Plasma. - plasma_store_stdout_file, plasma_store_stderr_file = new_log_files("plasma_store_{}".format(i), redirect_output) - plasma_manager_stdout_file, plasma_manager_stderr_file = new_log_files("plasma_manager_{}".format(i), redirect_output) - object_store_address = start_objstore(node_ip_address, - redis_address, - object_manager_port=object_manager_ports[i], - store_stdout_file=plasma_store_stdout_file, - store_stderr_file=plasma_store_stderr_file, - manager_stdout_file=plasma_manager_stdout_file, - manager_stderr_file=plasma_manager_stderr_file, - cleanup=cleanup) + plasma_store_stdout_file, plasma_store_stderr_file = new_log_files( + "plasma_store_{}".format(i), redirect_output) + plasma_manager_stdout_file, plasma_manager_stderr_file = new_log_files( + "plasma_manager_{}".format(i), redirect_output) + object_store_address = start_objstore( + node_ip_address, + redis_address, + object_manager_port=object_manager_ports[i], + store_stdout_file=plasma_store_stdout_file, + store_stderr_file=plasma_store_stderr_file, + manager_stdout_file=plasma_manager_stdout_file, + manager_stderr_file=plasma_manager_stderr_file, + cleanup=cleanup) object_store_addresses.append(object_store_address) time.sleep(0.1) # Determine how many workers to start for each local scheduler. - num_workers_per_local_scheduler = [0] * num_local_schedulers + workers_per_local_scheduler = [0] * num_local_schedulers for i in range(num_workers): - num_workers_per_local_scheduler[i % num_local_schedulers] += 1 + workers_per_local_scheduler[i % num_local_schedulers] += 1 # Start any local schedulers that do not yet exist. for i in range(len(local_scheduler_socket_names), num_local_schedulers): @@ -815,26 +867,28 @@ def start_ray_processes(address_info=None, object_store_address.manager_port) # Determine how many workers this local scheduler should start. if start_workers_from_local_scheduler: - num_local_scheduler_workers = num_workers_per_local_scheduler[i] - num_workers_per_local_scheduler[i] = 0 + num_local_scheduler_workers = workers_per_local_scheduler[i] + workers_per_local_scheduler[i] = 0 else: # If we're starting the workers from Python, the local scheduler should # not start any workers. num_local_scheduler_workers = 0 # Start the local scheduler. - local_scheduler_stdout_file, local_scheduler_stderr_file = new_log_files("local_scheduler_{}".format(i), redirect_output) - local_scheduler_name = start_local_scheduler(redis_address, - node_ip_address, - object_store_address.name, - object_store_address.manager_name, - worker_path, - plasma_address=plasma_address, - stdout_file=local_scheduler_stdout_file, - stderr_file=local_scheduler_stderr_file, - cleanup=cleanup, - num_cpus=num_cpus[i], - num_gpus=num_gpus[i], - num_workers=num_local_scheduler_workers) + local_scheduler_stdout_file, local_scheduler_stderr_file = new_log_files( + "local_scheduler_{}".format(i), redirect_output) + local_scheduler_name = start_local_scheduler( + redis_address, + node_ip_address, + object_store_address.name, + object_store_address.manager_name, + worker_path, + plasma_address=plasma_address, + stdout_file=local_scheduler_stdout_file, + stderr_file=local_scheduler_stderr_file, + cleanup=cleanup, + num_cpus=num_cpus[i], + num_gpus=num_gpus[i], + num_workers=num_local_scheduler_workers) local_scheduler_socket_names.append(local_scheduler_name) time.sleep(0.1) @@ -844,11 +898,12 @@ def start_ray_processes(address_info=None, assert len(local_scheduler_socket_names) == num_local_schedulers # Start any workers that the local scheduler has not already started. - for i, num_local_scheduler_workers in enumerate(num_workers_per_local_scheduler): + for i, num_local_scheduler_workers in enumerate(workers_per_local_scheduler): object_store_address = object_store_addresses[i] local_scheduler_name = local_scheduler_socket_names[i] for j in range(num_local_scheduler_workers): - worker_stdout_file, worker_stderr_file = new_log_files("worker_{}_{}".format(i, j), redirect_output) + worker_stdout_file, worker_stderr_file = new_log_files( + "worker_{}_{}".format(i, j), redirect_output) start_worker(node_ip_address, object_store_address.name, object_store_address.manager_name, @@ -858,17 +913,17 @@ def start_ray_processes(address_info=None, stdout_file=worker_stdout_file, stderr_file=worker_stderr_file, cleanup=cleanup) - num_workers_per_local_scheduler[i] -= 1 + workers_per_local_scheduler[i] -= 1 # Make sure that we've started all the workers. - assert(sum(num_workers_per_local_scheduler) == 0) + assert(sum(workers_per_local_scheduler) == 0) # Try to start the web UI. if include_webui: - backend_stdout_file, backend_stderr_file = new_log_files("webui_backend", - redirect_output=True) - polymer_stdout_file, polymer_stderr_file = new_log_files("webui_polymer", - redirect_output=True) + backend_stdout_file, backend_stderr_file = new_log_files( + "webui_backend", redirect_output=True) + polymer_stdout_file, polymer_stderr_file = new_log_files( + "webui_polymer", redirect_output=True) successfully_started = start_webui(redis_address, node_ip_address, backend_stdout_file=backend_stdout_file, @@ -883,6 +938,7 @@ def start_ray_processes(address_info=None, # Return the addresses of the relevant processes. return address_info + def start_ray_node(node_ip_address, redis_address, object_manager_ports=None, @@ -905,8 +961,8 @@ def start_ray_node(node_ip_address, managers. There should be one per object manager being started on this node (typically just one). num_workers (int): The number of workers to start. - num_local_schedulers (int): The number of local schedulers to start. This is - also the number of plasma stores and plasma managers to start. + num_local_schedulers (int): The number of local schedulers to start. This + is also the number of plasma stores and plasma managers to start. worker_path (str): The path of the source code that will be run by the worker. cleanup (bool): If cleanup is true, then the processes started here will be @@ -919,10 +975,8 @@ def start_ray_node(node_ip_address, A dictionary of the address information for the processes that were started. """ - address_info = { - "redis_address": redis_address, - "object_manager_ports": object_manager_ports, - } + address_info = {"redis_address": redis_address, + "object_manager_ports": object_manager_ports} return start_ray_processes(address_info=address_info, node_ip_address=node_ip_address, num_workers=num_workers, @@ -934,6 +988,7 @@ def start_ray_node(node_ip_address, num_cpus=num_cpus, num_gpus=num_gpus) + def start_ray_head(address_info=None, node_ip_address="127.0.0.1", num_workers=0, @@ -974,33 +1029,37 @@ def start_ray_head(address_info=None, A dictionary of the address information for the processes that were started. """ - return start_ray_processes(address_info=address_info, - node_ip_address=node_ip_address, - num_workers=num_workers, - num_local_schedulers=num_local_schedulers, - worker_path=worker_path, - cleanup=cleanup, - redirect_output=redirect_output, - include_global_scheduler=True, - include_log_monitor=True, - include_redis=True, - include_webui=True, - start_workers_from_local_scheduler=start_workers_from_local_scheduler, - num_cpus=num_cpus, - num_gpus=num_gpus) + return start_ray_processes( + address_info=address_info, + node_ip_address=node_ip_address, + num_workers=num_workers, + num_local_schedulers=num_local_schedulers, + worker_path=worker_path, + cleanup=cleanup, + redirect_output=redirect_output, + include_global_scheduler=True, + include_log_monitor=True, + include_redis=True, + include_webui=True, + start_workers_from_local_scheduler=start_workers_from_local_scheduler, + num_cpus=num_cpus, + num_gpus=num_gpus) + def new_log_files(name, redirect_output): """Generate partially randomized filenames for log files. Args: name (str): descriptive string for this log file. - redirect_output (bool): True if files should be generated for logging stdout - and stderr and false if stdout and stderr should not be redirected. + redirect_output (bool): True if files should be generated for logging + stdout and stderr and false if stdout and stderr should not be + redirected. Returns: - If redirect_output is true, this will return a tuple of two filehandles. The - first is for redirecting stdout and the second is for redirecting stderr. - If redirect_output is false, this will return a tuple of two None objects. + If redirect_output is true, this will return a tuple of two filehandles. + The first is for redirecting stdout and the second is for redirecting + stderr. If redirect_output is false, this will return a tuple of two None + objects. """ if not redirect_output: return None, None diff --git a/python/ray/test/test_functions.py b/python/ray/test/test_functions.py index 3f7dec3ac..70bcd7ef3 100644 --- a/python/ray/test/test_functions.py +++ b/python/ray/test/test_functions.py @@ -8,44 +8,53 @@ import numpy as np # Test simple functionality + @ray.remote(num_return_vals=2) def handle_int(a, b): return a + 1, b + 1 # Test timing + @ray.remote def empty_function(): pass + @ray.remote def trivial_function(): return 1 # Test keyword arguments + @ray.remote def keyword_fct1(a, b="hello"): return "{} {}".format(a, b) + @ray.remote def keyword_fct2(a="hello", b="world"): return "{} {}".format(a, b) + @ray.remote def keyword_fct3(a, b, c="hello", d="world"): return "{} {} {} {}".format(a, b, c, d) # Test variable numbers of arguments + @ray.remote def varargs_fct1(*a): return " ".join(map(str, a)) + @ray.remote def varargs_fct2(a, *b): return " ".join(map(str, b)) + try: @ray.remote def kwargs_throw_exception(**c): @@ -64,24 +73,29 @@ except: # test throwing an exception + @ray.remote def throw_exception_fct1(): raise Exception("Test function 1 intentionally failed.") + @ray.remote def throw_exception_fct2(): raise Exception("Test function 2 intentionally failed.") + @ray.remote(num_return_vals=3) def throw_exception_fct3(x): raise Exception("Test function 3 intentionally failed.") # test Python mode + @ray.remote def python_mode_f(): return np.array([0, 0]) + @ray.remote def python_mode_g(x): x[0] = 1 @@ -89,14 +103,17 @@ def python_mode_g(x): # test no return values + @ray.remote def no_op(): pass + class TestClass(object): def __init__(self): self.a = 5 + @ray.remote def test_unknown_type(): return TestClass() diff --git a/python/ray/worker.py b/python/ray/worker.py index 5dda40441..5a5ea77cb 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -12,10 +12,8 @@ import inspect import json import numpy as np import os -import random import redis import signal -import string import sys import threading import time @@ -60,12 +58,15 @@ PUT_RECONSTRUCTION_ERROR_TYPE = b"put_reconstruction" # This must be kept in sync with the `scheduling_state` enum in common/task.h. TASK_STATUS_RUNNING = 8 + def random_string(): return np.random.bytes(20) + def random_object_id(): return ray.local_scheduler.ObjectID(random_string()) + class FunctionID(object): def __init__(self, function_id): self.function_id = function_id @@ -73,13 +74,16 @@ class FunctionID(object): def id(self): return self.function_id + contained_objectids = [] + + def numbuf_serialize(value): """This serializes a value and tracks the object IDs inside the value. We also define a custom ObjectID serializer which also closes over the global - variable contained_objectids, and whenever the custom serializer is called, it - adds the releevant ObjectID to the list contained_objectids. The list + variable contained_objectids, and whenever the custom serializer is called, + it adds the releevant ObjectID to the list contained_objectids. The list contained_objectids should be reset between calls to numbuf_serialize. Args: @@ -91,6 +95,7 @@ def numbuf_serialize(value): assert len(contained_objectids) == 0, "This should be unreachable." return ray.numbuf.serialize_list([value]) + class RayTaskError(Exception): """An object used internally to represent a task that threw an exception. @@ -113,7 +118,8 @@ class RayTaskError(Exception): def __init__(self, function_name, exception, traceback_str): """Initialize a RayTaskError.""" self.function_name = function_name - if isinstance(exception, RayGetError) or isinstance(exception, RayGetArgumentError): + if isinstance(exception, RayGetError) or isinstance(exception, + RayGetArgumentError): self.exception = exception else: self.exception = None @@ -123,10 +129,15 @@ class RayTaskError(Exception): """Format a RayTaskError as a string.""" if self.traceback_str is None: # This path is taken if getting the task arguments failed. - return "Remote function {}{}{} failed with:\n\n{}".format(colorama.Fore.RED, self.function_name, colorama.Fore.RESET, self.exception) + return ("Remote function {}{}{} failed with:\n\n{}" + .format(colorama.Fore.RED, self.function_name, + colorama.Fore.RESET, self.exception)) else: # This path is taken if the task execution failed. - return "Remote function {}{}{} failed with:\n\n{}".format(colorama.Fore.RED, self.function_name, colorama.Fore.RESET, self.traceback_str) + return ("Remote function {}{}{} failed with:\n\n{}" + .format(colorama.Fore.RED, self.function_name, + colorama.Fore.RESET, self.traceback_str)) + class RayGetError(Exception): """An exception used when get is called on an output of a failed task. @@ -144,7 +155,12 @@ class RayGetError(Exception): def __str__(self): """Format a RayGetError as a string.""" - return "Could not get objectid {}. It was created by remote function {}{}{} which failed with:\n\n{}".format(self.objectid, colorama.Fore.RED, self.task_error.function_name, colorama.Fore.RESET, self.task_error) + return ("Could not get objectid {}. It was created by remote function " + "{}{}{} which failed with:\n\n{}" + .format(self.objectid, colorama.Fore.RED, + self.task_error.function_name, colorama.Fore.RESET, + self.task_error)) + class RayGetArgumentError(Exception): """An exception used when a task's argument was produced by a failed task. @@ -167,18 +183,24 @@ class RayGetArgumentError(Exception): def __str__(self): """Format a RayGetArgumentError as a string.""" - return "Failed to get objectid {} as argument {} for remote function {}{}{}. It was created by remote function {}{}{} which failed with:\n{}".format(self.objectid, self.argument_index, colorama.Fore.RED, self.function_name, colorama.Fore.RESET, colorama.Fore.RED, self.task_error.function_name, colorama.Fore.RESET, self.task_error) + return ("Failed to get objectid {} as argument {} for remote function " + "{}{}{}. It was created by remote function {}{}{} which failed " + "with:\n{}".format(self.objectid, self.argument_index, + colorama.Fore.RED, self.function_name, + colorama.Fore.RESET, colorama.Fore.RED, + self.task_error.function_name, + colorama.Fore.RESET, self.task_error)) class EnvironmentVariable(object): """An Python object that can be shared between tasks. Attributes: - initializer (Callable[[], object]): A function used to create and initialize - the environment variable. + initializer (Callable[[], object]): A function used to create and + initialize the environment variable. reinitializer (Optional[Callable[[object], object]]): An optional function - used to reinitialize the environment variable after it has been used. This - argument can be used as an optimization if there is a fast way to + used to reinitialize the environment variable after it has been used. + This argument can be used as an optimization if there is a fast way to reinitialize the state of the variable other than rerunning the initializer. """ @@ -186,15 +208,20 @@ class EnvironmentVariable(object): def __init__(self, initializer, reinitializer=None): """Initialize an EnvironmentVariable object.""" if not callable(initializer): - raise Exception("When creating an EnvironmentVariable, initializer must be a function.") + raise Exception("When creating an EnvironmentVariable, initializer must " + "be a function.") self.initializer = initializer if reinitializer is None: - # If no reinitializer is passed in, use a wrapped version of the initializer. - reinitializer = lambda value: initializer() + # If no reinitializer is passed in, use a wrapped version of the + # initializer. + def reinitializer(value): + return initializer() if not callable(reinitializer): - raise Exception("When creating an EnvironmentVariable, reinitializer must be a function.") + raise Exception("When creating an EnvironmentVariable, reinitializer " + "must be a function.") self.reinitializer = reinitializer + class RayEnvironmentVariables(object): """An object used to store Python variables that are shared between tasks. @@ -217,9 +244,9 @@ class RayEnvironmentVariables(object): _names (List[str]): A list of the names of all the environment variables. _reinitializers (Dict[str, Callable]): A dictionary mapping the name of the environment variables to the corresponding reinitializer. - _running_remote_function_locally (bool): A flag used to indicate if a remote - function is running locally on the driver so that we can simulate the same - behavior as running a remote function remotely. + _running_remote_function_locally (bool): A flag used to indicate if a + remote function is running locally on the driver so that we can simulate + the same behavior as running a remote function remotely. _environment_variables: A dictionary mapping the name of an environment variable to the value of the environment variable. _local_mode_environment_variables: A copy of _environment_variables used on @@ -236,9 +263,9 @@ class RayEnvironmentVariables(object): object. This list is used to store environment variables that are defined before the driver is connected. Once the driver is connected, these variables will be exported. - _used (List[str]): A list of the names of all the environment variables that - have been accessed within the scope of the current task. This is reset to - the empty list after each task. + _used (List[str]): A list of the names of all the environment variables + that have been accessed within the scope of the current task. This is + reset to the empty list after each task. """ def __init__(self): @@ -250,8 +277,21 @@ class RayEnvironmentVariables(object): self._local_mode_environment_variables = {} self._cached_environment_variables = [] self._used = set() - self._slots = ("_names", "_reinitializers", "_running_remote_function_locally", "_environment_variables", "_local_mode_environment_variables", "_cached_environment_variables", "_used", "_slots", "_create_environment_variable", "_reinitialize", "__getattribute__", "__setattr__", "__delattr__") - # CHECKPOINT: Attributes must not be added after _slots. The above attributes are protected from deletion. + self._slots = ("_names", + "_reinitializers", + "_running_remote_function_locally", + "_environment_variables", + "_local_mode_environment_variables", + "_cached_environment_variables", + "_used", + "_slots", + "_create_environment_variable", + "_reinitialize", + "__getattribute__", + "__setattr__", + "__delattr__") + # CHECKPOINT: Attributes must not be added after _slots. The above + # attributes are protected from deletion. def _create_environment_variable(self, name, environment_variable): """Create an environment variable locally. @@ -265,10 +305,11 @@ class RayEnvironmentVariables(object): self._reinitializers[name] = environment_variable.reinitializer self._environment_variables[name] = environment_variable.initializer() # We create a second copy of the environment variable on the driver to use - # inside of remote functions that run locally. This occurs when we start Ray - # in PYTHON_MODE and when we call a remote function locally. + # inside of remote functions that run locally. This occurs when we start + # Ray in PYTHON_MODE and when we call a remote function locally. if _mode() in [SCRIPT_MODE, SILENT_MODE, PYTHON_MODE]: - self._local_mode_environment_variables[name] = environment_variable.initializer() + self._local_mode_environment_variables[name] = (environment_variable + .initializer()) def _reinitialize(self): """Reinitialize the environment variables that the current task used.""" @@ -282,7 +323,7 @@ class RayEnvironmentVariables(object): self._local_mode_environment_variables[name] = new_value else: self._environment_variables[name] = new_value - self._used.clear() # Reset the _used list. + self._used.clear() # Reset the _used list. def __getattribute__(self, name): """Get an attribute. This handles environment variables as a special case. @@ -312,20 +353,21 @@ class RayEnvironmentVariables(object): def __setattr__(self, name, value): """Set an attribute. This handles environment variables as a special case. - This is used to create environment variables. When it is called, it runs the - function for initializing the variable to create the variable. If this is - called on the driver, then the functions for initializing and reinitializing - the variable are shipped to the workers. + This is used to create environment variables. When it is called, it runs + the function for initializing the variable to create the variable. If this + is called on the driver, then the functions for initializing and + reinitializing the variable are shipped to the workers. If this is called before ray.init has been run, then the environment variable will be cached and it will be created and exported when connect is called. Args: - name (str): The name of the attribute to set. This is either a whitelisted - name or it is treated as the name of an environment variable. - value: If name is a whitelisted name, then value can be any value. If name - is the name of an environment variable, then this is an + name (str): The name of the attribute to set. This is either a + whitelisted name or it is treated as the name of an environment + variable. + value: If name is a whitelisted name, then value can be any value. If + name is the name of an environment variable, then this is an EnvironmentVariable object. """ try: @@ -338,7 +380,8 @@ class RayEnvironmentVariables(object): return object.__setattr__(self, name, value) environment_variable = value if not issubclass(type(environment_variable), EnvironmentVariable): - raise Exception("To set an environment variable, you must pass in an EnvironmentVariable object") + raise Exception("To set an environment variable, you must pass in an " + "EnvironmentVariable object") # If ray.init has not been called, cache the environment variable to export # later. Otherwise, export the environment variable to the workers and # define it locally. @@ -361,7 +404,10 @@ class RayEnvironmentVariables(object): Args: name (str): The name of the attribute to delete. """ - raise Exception("Attempted deletion of attribute {}. Attributes of a RayEnvironmentVariables object may not be deleted.".format(name)) + raise Exception("Attempted deletion of attribute {}. Attributes of a " + "RayEnvironmentVariables object may not be deleted." + .format(name)) + class Worker(object): """A class used to define the control flow of a worker process. @@ -382,8 +428,8 @@ class Worker(object): called connect. The first element is the name of the remote function, and the second element is the serialized remote function. When the worker eventually does call connect, if it is a driver, it will export these - functions to the scheduler. If cached_remote_functions is None, that means - that connect has been called already. + functions to the scheduler. If cached_remote_functions is None, that + means that connect has been called already. cached_functions_to_run (List): A list of functions to run on all of the workers that should be exported as soon as connect is called. """ @@ -409,25 +455,25 @@ class Worker(object): self.cached_functions_to_run = [] self.fetch_and_register = {} self.actors = {} - # Use a defaultdict for the actor counts. If this is accessed with a missing - # key, the default value of 0 is returned, and that key value pair is added - # to the dict. + # Use a defaultdict for the actor counts. If this is accessed with a + # missing key, the default value of 0 is returned, and that key value pair + # is added to the dict. self.actor_counters = collections.defaultdict(lambda: 0) def set_mode(self, mode): """Set the mode of the worker. - The mode SCRIPT_MODE should be used if this Worker is a driver that is being - run as a Python script or interactively in a shell. It will print + The mode SCRIPT_MODE should be used if this Worker is a driver that is + being run as a Python script or interactively in a shell. It will print information about task failures. The mode WORKER_MODE should be used if this Worker is not a driver. It will not print information about tasks. The mode PYTHON_MODE should be used if this Worker is a driver and if you - want to run the driver in a manner equivalent to serial Python for debugging - purposes. It will not send remote function calls to the scheduler and will - insead execute them in a blocking fashion. + want to run the driver in a manner equivalent to serial Python for + debugging purposes. It will not send remote function calls to the scheduler + and will insead execute them in a blocking fashion. The mode SILENT_MODE should be used only during testing. It does not print any information about errors because some of the tests intentionally fail. @@ -450,11 +496,11 @@ class Worker(object): """ # Make sure that the value is not an object ID. if isinstance(value, ray.local_scheduler.ObjectID): - raise Exception("Calling `put` on an ObjectID is not allowed (similarly, " - "returning an ObjectID from a remote function is not " - "allowed). If you really want to do this, you can wrap " - "the ObjectID in a list and call `put` on it (or return " - "it).") + raise Exception("Calling `put` on an ObjectID is not allowed " + "(similarly, returning an ObjectID from a remote " + "function is not allowed). If you really want to do " + "this, you can wrap the ObjectID in a list and call " + "`put` on it (or return it).") # Serialize and put the object in the object store. try: @@ -471,10 +517,11 @@ class Worker(object): contained_objectids = [] def get_object(self, object_ids): - """Get the value or values in the local object store associated with object_ids. + """Get the value or values in the object store associated with object_ids. - Return the values from the local object store for object_ids. This will block - until all the values for object_ids have been written to the local object store. + Return the values from the local object store for object_ids. This will + block until all the values for object_ids have been written to the local + object store. Args: object_ids (List[object_id.ObjectID]): A list of the object IDs whose @@ -507,8 +554,8 @@ class Worker(object): # they were evicted since the last fetch. self.plasma_client.fetch(list(unready_ids.keys())) results = ray.numbuf.retrieve_list(list(unready_ids.keys()), - self.plasma_client.conn, - GET_TIMEOUT_MILLISECONDS) + self.plasma_client.conn, + GET_TIMEOUT_MILLISECONDS) # Remove any entries for objects we received during this iteration so we # don't retrieve the same object twice. for object_id, val in results: @@ -543,7 +590,8 @@ class Worker(object): """ with log_span("ray:submit_task", worker=self): check_main_thread() - actor_id = ray.local_scheduler.ObjectID(NIL_ACTOR_ID) if actor_id is None else actor_id + actor_id = (ray.local_scheduler.ObjectID(NIL_ACTOR_ID) + if actor_id is None else actor_id) # Put large or complex arguments that are passed by value in the object # store first. args_for_local_scheduler = [] @@ -556,7 +604,8 @@ class Worker(object): args_for_local_scheduler.append(put(arg)) # Look up the various function properties. - num_return_vals, num_cpus, num_gpus = self.function_properties[self.task_driver_id.id()][function_id.id()] + num_return_vals, num_cpus, num_gpus = self.function_properties[ + self.task_driver_id.id()][function_id.id()] # Submit the task to local scheduler. task = ray.local_scheduler.Task( @@ -580,9 +629,9 @@ class Worker(object): """Run arbitrary code on all of the workers. This function will first be run on the driver, and then it will be exported - to all of the workers to be run. It will also be run on any new workers that - register later. If ray.init has not been called yet, then cache the function - and export it later. + to all of the workers to be run. It will also be run on any new workers + that register later. If ray.init has not been called yet, then cache the + function and export it later. Args: function (Callable): The function to run on all of the workers. It should @@ -591,9 +640,10 @@ class Worker(object): """ check_main_thread() if self.mode not in [None, SCRIPT_MODE, SILENT_MODE, PYTHON_MODE]: - raise Exception("run_function_on_all_workers can only be called on a driver.") - # If ray.init has not been called yet, then cache the function and export it - # when connect is called. Otherwise, run the function on all workers. + raise Exception("run_function_on_all_workers can only be called on a " + "driver.") + # If ray.init has not been called yet, then cache the function and export + # it when connect is called. Otherwise, run the function on all workers. if self.mode is None: self.cached_functions_to_run.append(function) else: @@ -628,6 +678,7 @@ class Worker(object): "data": data}) self.redis_client.rpush("ErrorKeys", error_key) + global_worker = Worker() """Worker: The global Worker object for this worker process. @@ -645,9 +696,11 @@ used to reinitialize these variables after they are used so that changes to their state made by one task do not affect other tasks. """ + class RayConnectionError(Exception): pass + def check_main_thread(): """Check that we are currently on the main thread. @@ -656,7 +709,10 @@ def check_main_thread(): the main thread. """ if threading.current_thread().getName() != "MainThread": - raise Exception("The Ray methods are not thread safe and must be called from the main thread. This method was called from thread {}.".format(threading.current_thread().getName())) + raise Exception("The Ray methods are not thread safe and must be called " + "from the main thread. This method was called from thread " + "{}.".format(threading.current_thread().getName())) + def check_connected(worker=global_worker): """Check if the worker is connected. @@ -665,7 +721,10 @@ def check_connected(worker=global_worker): Exception: An exception is raised if the worker is not connected. """ if not worker.connected: - raise RayConnectionError("This command cannot be called before Ray has been started. You can start Ray with 'ray.init(num_workers=10)'.") + raise RayConnectionError("This command cannot be called before Ray has " + "been started. You can start Ray with " + "'ray.init(num_workers=10)'.") + def print_failed_task(task_status): """Print information about failed tasks. @@ -679,18 +738,24 @@ def print_failed_task(task_status): Function Name: {} Task ID: {} Error Message: \n{} - """.format(task_status["function_name"], task_status["operationid"], task_status["error_message"])) + """.format(task_status["function_name"], task_status["operationid"], + task_status["error_message"])) + def error_applies_to_driver(error_key, worker=global_worker): """Return True if the error is for this driver and false otherwise.""" # TODO(rkn): Should probably check that this is only called on a driver. # Check that the error key is formatted as in push_error_to_driver. - assert len(error_key) == len(ERROR_KEY_PREFIX) + DRIVER_ID_LENGTH + 1 + ERROR_ID_LENGTH, error_key + assert len(error_key) == (len(ERROR_KEY_PREFIX) + DRIVER_ID_LENGTH + 1 + + ERROR_ID_LENGTH), error_key # If the driver ID in the error message is a sequence of all zeros, then the # message is intended for all drivers. generic_driver_id = DRIVER_ID_LENGTH * b"\x00" - driver_id = error_key[len(ERROR_KEY_PREFIX):(len(ERROR_KEY_PREFIX) + DRIVER_ID_LENGTH)] - return driver_id == worker.task_driver_id.id() or driver_id == generic_driver_id + driver_id = error_key[len(ERROR_KEY_PREFIX):(len(ERROR_KEY_PREFIX) + + DRIVER_ID_LENGTH)] + return (driver_id == worker.task_driver_id.id() or + driver_id == generic_driver_id) + def error_info(worker=global_worker): """Return information about failed tasks.""" @@ -704,18 +769,20 @@ def error_info(worker=global_worker): # If the error is an object hash mismatch, look up the function name for # the nondeterministic task. error_type = error_contents[b"type"] - if (error_type == OBJECT_HASH_MISMATCH_ERROR_TYPE or error_type == - PUT_RECONSTRUCTION_ERROR_TYPE): + if error_type in [OBJECT_HASH_MISMATCH_ERROR_TYPE, + PUT_RECONSTRUCTION_ERROR_TYPE]: function_id = error_contents[b"data"] if function_id == NIL_FUNCTION_ID: function_name = b"Driver" else: - function_name = worker.redis_client.hget("RemoteFunction:{}".format(function_id), "name") + function_name = worker.redis_client.hget( + "RemoteFunction:{}".format(function_id), "name") error_contents[b"data"] = function_name errors.append(error_contents) return errors + def initialize_numbuf(worker=global_worker): """Initialize the serialization library. @@ -723,17 +790,18 @@ def initialize_numbuf(worker=global_worker): serialize several exception classes that we define for error handling. """ ray.serialization.set_callbacks() + # Define a custom serializer and deserializer for handling Object IDs. def objectid_custom_serializer(obj): - class_identifier = serialization.class_identifier(type(obj)) contained_objectids.append(obj) return obj.id() + def objectid_custom_deserializer(serialized_obj): return ray.local_scheduler.ObjectID(serialized_obj) - serialization.add_class_to_whitelist(ray.local_scheduler.ObjectID, - pickle=False, - custom_serializer=objectid_custom_serializer, - custom_deserializer=objectid_custom_deserializer) + serialization.add_class_to_whitelist( + ray.local_scheduler.ObjectID, pickle=False, + custom_serializer=objectid_custom_serializer, + custom_deserializer=objectid_custom_deserializer) if worker.mode in [SCRIPT_MODE, SILENT_MODE]: # These should only be called on the driver because register_class will @@ -742,6 +810,7 @@ def initialize_numbuf(worker=global_worker): register_class(RayGetError) register_class(RayGetArgumentError) + def get_address_info_from_redis_helper(redis_address, node_ip_address): redis_ip_address, redis_port = redis_address.split(":") # For this command to work, some other client (on the same machine as Redis) @@ -774,11 +843,9 @@ def get_address_info_from_redis_helper(redis_address, node_ip_address): port = services.get_port(address) object_store_addresses.append( services.ObjectStoreAddress( - name=manager[b"store_socket_name"].decode("ascii"), - manager_name=manager[b"manager_socket_name"].decode("ascii"), - manager_port=port - ) - ) + name=manager[b"store_socket_name"].decode("ascii"), + manager_name=manager[b"manager_socket_name"].decode("ascii"), + manager_port=port)) scheduler_names = [scheduler[b"local_scheduler_socket_name"].decode("ascii") for scheduler in local_schedulers] client_info = {"node_ip_address": node_ip_address, @@ -788,6 +855,7 @@ def get_address_info_from_redis_helper(redis_address, node_ip_address): } return client_info + def get_address_info_from_redis(redis_address, node_ip_address, num_retries=5): counter = 0 while True: @@ -803,6 +871,7 @@ def get_address_info_from_redis(redis_address, node_ip_address, num_retries=5): time.sleep(1) counter += 1 + def _init(address_info=None, start_ray_local=False, object_id_seed=None, @@ -831,12 +900,12 @@ def _init(address_info=None, Ray cluster. object_id_seed (int): Used to seed the deterministic generation of object IDs. The same value can be used across multiple runs of the same job in - order to generate the object IDs in a consistent manner. However, the same - ID should not be used for different jobs. + order to generate the object IDs in a consistent manner. However, the + same ID should not be used for different jobs. num_workers (int): The number of workers to start. This is only provided if start_ray_local is True. - num_local_schedulers (int): The number of local schedulers to start. This is - only provided if start_ray_local is True. + num_local_schedulers (int): The number of local schedulers to start. This + is only provided if start_ray_local is True. driver_mode (bool): The mode in which to start the driver. This should be one of ray.SCRIPT_MODE, ray.PYTHON_MODE, and ray.SILENT_MODE. redirect_output (bool): True if stdout and stderr for all the processes @@ -858,7 +927,8 @@ def _init(address_info=None, """ check_main_thread() if driver_mode not in [SCRIPT_MODE, PYTHON_MODE, SILENT_MODE]: - raise Exception("Driver_mode must be in [ray.SCRIPT_MODE, ray.PYTHON_MODE, ray.SILENT_MODE].") + raise Exception("Driver_mode must be in [ray.SCRIPT_MODE, " + "ray.PYTHON_MODE, ray.SILENT_MODE].") # Get addresses of existing services. if address_info is None: @@ -873,11 +943,12 @@ def _init(address_info=None, # If starting Ray in PYTHON_MODE, don't start any other processes. pass elif start_ray_local: - # In this case, we launch a scheduler, a new object store, and some workers, - # and we connect to them. We do not launch any processes that are already - # registered in address_info. + # In this case, we launch a scheduler, a new object store, and some + # workers, and we connect to them. We do not launch any processes that are + # already registered in address_info. # Use the address 127.0.0.1 in local mode. - node_ip_address = "127.0.0.1" if node_ip_address is None else node_ip_address + node_ip_address = ("127.0.0.1" if node_ip_address is None + else node_ip_address) # Use 1 worker if num_workers is not provided. num_workers = 10 if num_workers is None else num_workers # Use 1 local scheduler if num_local_schedulers is not provided. If @@ -891,23 +962,28 @@ def _init(address_info=None, num_local_schedulers = 1 # Start the scheduler, object store, and some workers. These will be killed # by the call to cleanup(), which happens when the Python script exits. - address_info = services.start_ray_head(address_info=address_info, - node_ip_address=node_ip_address, - num_workers=num_workers, - num_local_schedulers=num_local_schedulers, - redirect_output=redirect_output, - start_workers_from_local_scheduler=start_workers_from_local_scheduler, - num_cpus=num_cpus, - num_gpus=num_gpus) + address_info = services.start_ray_head( + address_info=address_info, + node_ip_address=node_ip_address, + num_workers=num_workers, + num_local_schedulers=num_local_schedulers, + redirect_output=redirect_output, + start_workers_from_local_scheduler=start_workers_from_local_scheduler, + num_cpus=num_cpus, + num_gpus=num_gpus) else: if redis_address is None: - raise Exception("If start_ray_local=False, then redis_address must be provided.") + raise Exception("If start_ray_local=False, then redis_address must be " + "provided.") if num_workers is not None: - raise Exception("If start_ray_local=False, then num_workers must not be provided.") + raise Exception("If start_ray_local=False, then num_workers must not be " + "provided.") if num_local_schedulers is not None: - raise Exception("If start_ray_local=False, then num_local_schedulers must not be provided.") + raise Exception("If start_ray_local=False, then num_local_schedulers " + "must not be provided.") if num_cpus is not None or num_gpus is not None: - raise Exception("If start_ray_local=False, then num_cpus and num_gpus must not be provided.") + raise Exception("If start_ray_local=False, then num_cpus and num_gpus " + "must not be provided.") # Get the node IP address if one is not provided. if node_ip_address is None: node_ip_address = services.get_node_ip_address(redis_address) @@ -925,12 +1001,15 @@ def _init(address_info=None, "node_ip_address": node_ip_address, "redis_address": address_info["redis_address"], "store_socket_name": address_info["object_store_addresses"][0].name, - "manager_socket_name": address_info["object_store_addresses"][0].manager_name, - "local_scheduler_socket_name": address_info["local_scheduler_socket_names"][0], - } - connect(driver_address_info, object_id_seed=object_id_seed, mode=driver_mode, worker=global_worker, actor_id=NIL_ACTOR_ID) + "manager_socket_name": (address_info["object_store_addresses"][0] + .manager_name), + "local_scheduler_socket_name": (address_info + ["local_scheduler_socket_names"][0])} + connect(driver_address_info, object_id_seed=object_id_seed, mode=driver_mode, + worker=global_worker, actor_id=NIL_ACTOR_ID) return address_info + def init(redis_address=None, node_ip_address=None, object_id_seed=None, num_workers=None, driver_mode=SCRIPT_MODE, redirect_output=False, num_cpus=None, num_gpus=None): @@ -948,8 +1027,8 @@ def init(redis_address=None, node_ip_address=None, object_id_seed=None, workers. It will also kill these processes when Python exits. object_id_seed (int): Used to seed the deterministic generation of object IDs. The same value can be used across multiple runs of the same job in - order to generate the object IDs in a consistent manner. However, the same - ID should not be used for different jobs. + order to generate the object IDs in a consistent manner. However, the + same ID should not be used for different jobs. num_workers (int): The number of workers to start. This is only provided if redis_address is not provided. driver_mode (bool): The mode in which to start the driver. This should be @@ -968,15 +1047,14 @@ def init(redis_address=None, node_ip_address=None, object_id_seed=None, Exception: An exception is raised if an inappropriate combination of arguments is passed in. """ - info = { - "node_ip_address": node_ip_address, - "redis_address": redis_address, - } + info = {"node_ip_address": node_ip_address, + "redis_address": redis_address} return _init(address_info=info, start_ray_local=(redis_address is None), num_workers=num_workers, driver_mode=driver_mode, redirect_output=redirect_output, num_cpus=num_cpus, num_gpus=num_gpus) + def cleanup(worker=global_worker): """Disconnect the worker, and terminate any processes started in init. @@ -1008,11 +1086,14 @@ def cleanup(worker=global_worker): worker.set_mode(None) + atexit.register(cleanup) # Define a custom excepthook so that if the driver exits with an exception, we # can push that exception to Redis. normal_excepthook = sys.excepthook + + def custom_excepthook(type, value, tb): # If this is a driver, push the exception to redis. if global_worker.mode in [SCRIPT_MODE, SILENT_MODE]: @@ -1021,8 +1102,11 @@ def custom_excepthook(type, value, tb): {"exception": error_message}) # Call the normal excepthook. normal_excepthook(type, value, tb) + + sys.excepthook = custom_excepthook + def print_error_messages(worker): """Print error messages in the background on the driver. @@ -1055,7 +1139,8 @@ If this driver is hanging, start a new one with error_keys = worker.redis_client.lrange("ErrorKeys", 0, -1) for error_key in error_keys: if error_applies_to_driver(error_key, worker=worker): - error_message = worker.redis_client.hget(error_key, "message").decode("ascii") + error_message = worker.redis_client.hget(error_key, + "message").decode("ascii") print(error_message) print(helpful_message) num_errors_received += 1 @@ -1063,9 +1148,11 @@ If this driver is hanging, start a new one with try: for msg in worker.error_message_pubsub_client.listen(): with worker.lock: - for error_key in worker.redis_client.lrange("ErrorKeys", num_errors_received, -1): + for error_key in worker.redis_client.lrange("ErrorKeys", + num_errors_received, -1): if error_applies_to_driver(error_key, worker=worker): - error_message = worker.redis_client.hget(error_key, "message").decode("ascii") + error_message = worker.redis_client.hget(error_key, + "message").decode("ascii") print(error_message) print(helpful_message) num_errors_received += 1 @@ -1074,17 +1161,19 @@ If this driver is hanging, start a new one with # we catch here. pass + def fetch_and_register_remote_function(key, worker=global_worker): """Import a remote function.""" - driver_id, function_id_str, function_name, serialized_function, num_return_vals, module, num_cpus, num_gpus = \ - worker.redis_client.hmget(key, ["driver_id", - "function_id", - "name", - "function", - "num_return_vals", - "module", - "num_cpus", - "num_gpus"]) + (driver_id, function_id_str, function_name, serialized_function, + num_return_vals, module, num_cpus, num_gpus) = worker.redis_client.hmget( + key, ["driver_id", + "function_id", + "name", + "function", + "num_return_vals", + "module", + "num_cpus", + "num_gpus"]) function_id = ray.local_scheduler.ObjectID(function_id_str) function_name = function_name.decode("ascii") num_return_vals = int(num_return_vals) @@ -1097,8 +1186,11 @@ def fetch_and_register_remote_function(key, worker=global_worker): def f(): raise Exception("This function was not imported properly.") remote_f_placeholder = remote(function_id=function_id)(lambda *xs: f()) - worker.functions[driver_id][function_id.id()] = (function_name, remote_f_placeholder) - worker.function_properties[driver_id][function_id.id()] = (num_return_vals, num_cpus, num_gpus) + worker.functions[driver_id][function_id.id()] = (function_name, + remote_f_placeholder) + worker.function_properties[driver_id][function_id.id()] = (num_return_vals, + num_cpus, + num_gpus) try: function = pickling.loads(serialized_function) @@ -1114,18 +1206,24 @@ def fetch_and_register_remote_function(key, worker=global_worker): else: # TODO(rkn): Why is the below line necessary? function.__module__ = module - worker.functions[driver_id][function_id.id()] = (function_name, remote(function_id=function_id)(function)) + worker.functions[driver_id][function_id.id()] = ( + function_name, remote(function_id=function_id)(function)) # Add the function to the function table. - worker.redis_client.rpush("FunctionTable:{}".format(function_id.id()), worker.worker_id) + worker.redis_client.rpush("FunctionTable:{}".format(function_id.id()), + worker.worker_id) + def fetch_and_register_environment_variable(key, worker=global_worker): """Import an environment variable.""" - driver_id, environment_variable_name, serialized_initializer, serialized_reinitializer = worker.redis_client.hmget(key, ["driver_id", "name", "initializer", "reinitializer"]) + (driver_id, environment_variable_name, serialized_initializer, + serialized_reinitializer) = worker.redis_client.hmget( + key, ["driver_id", "name", "initializer", "reinitializer"]) environment_variable_name = environment_variable_name.decode("ascii") try: initializer = pickling.loads(serialized_initializer) reinitializer = pickling.loads(serialized_reinitializer) - env.__setattr__(environment_variable_name, EnvironmentVariable(initializer, reinitializer)) + env.__setattr__(environment_variable_name, + EnvironmentVariable(initializer, reinitializer)) except: # If an exception was thrown when the environment variable was imported, we # record the traceback and notify the scheduler of the failure. @@ -1135,9 +1233,11 @@ def fetch_and_register_environment_variable(key, worker=global_worker): traceback_str, data={"name": environment_variable_name}) + def fetch_and_execute_function_to_run(key, worker=global_worker): """Run on arbitrary function on the worker.""" - driver_id, serialized_function = worker.redis_client.hmget(key, ["driver_id", "function"]) + driver_id, serialized_function = worker.redis_client.hmget( + key, ["driver_id", "function"]) # Get the number of workers on this node that have already started executing # this remote function, and increment that value. Subtract 1 so the counter # starts at 0. @@ -1152,17 +1252,18 @@ def fetch_and_execute_function_to_run(key, worker=global_worker): # traceback and notify the scheduler of the failure. traceback_str = traceback.format_exc() # Log the error message. - name = function.__name__ if "function" in locals() and hasattr(function, "__name__") else "" + name = function.__name__ if ("function" in locals() and + hasattr(function, "__name__")) else "" worker.push_error_to_driver(driver_id, "function_to_run", traceback_str, data={"name": name}) + def import_thread(worker): worker.import_pubsub_client = worker.redis_client.pubsub() - # Exports that are published after the call to import_pubsub_client.psubscribe - # and before the call to import_pubsub_client.listen will still be processed - # in the loop. + # Exports that are published after the call to + # import_pubsub_client.psubscribe and before the call to + # import_pubsub_client.listen will still be processed in the loop. worker.import_pubsub_client.psubscribe("__keyspace@0__:Exports") - worker_info_key = "WorkerInfo:{}".format(worker.worker_id) # Keep track of the number of imports that we've imported. num_imported = 0 @@ -1214,7 +1315,9 @@ def import_thread(worker): raise Exception("This code should be unreachable.") num_imported += 1 -def connect(info, object_id_seed=None, mode=WORKER_MODE, worker=global_worker, actor_id=NIL_ACTOR_ID): + +def connect(info, object_id_seed=None, mode=WORKER_MODE, worker=global_worker, + actor_id=NIL_ACTOR_ID): """Connect this worker to the local scheduler, to Plasma, and to Redis. Args: @@ -1235,12 +1338,13 @@ def connect(info, object_id_seed=None, mode=WORKER_MODE, worker=global_worker, a worker.actor_id = actor_id worker.connected = True worker.set_mode(mode) - # The worker.events field is used to aggregate logging information and display - # it in the web UI. Note that Python lists protected by the GIL, which is - # important because we will append to this field from multiple threads. + # The worker.events field is used to aggregate logging information and + # display it in the web UI. Note that Python lists protected by the GIL, + # which is important because we will append to this field from multiple + # threads. worker.events = [] - # If running Ray in PYTHON_MODE, there is no need to create call create_worker - # or to start the worker service. + # If running Ray in PYTHON_MODE, there is no need to create call + # create_worker or to start the worker service. if mode == PYTHON_MODE: return # Set the node IP address. @@ -1248,7 +1352,8 @@ def connect(info, object_id_seed=None, mode=WORKER_MODE, worker=global_worker, a worker.redis_address = info["redis_address"] # Create a Redis client. redis_ip_address, redis_port = info["redis_address"].split(":") - worker.redis_client = redis.StrictRedis(host=redis_ip_address, port=int(redis_port)) + worker.redis_client = redis.StrictRedis(host=redis_ip_address, + port=int(redis_port)) worker.lock = threading.Lock() # Register the worker with Redis. @@ -1256,33 +1361,35 @@ def connect(info, object_id_seed=None, mode=WORKER_MODE, worker=global_worker, a # The concept of a driver is the same as the concept of a "job". Register # the driver/job with Redis here. import __main__ as main - driver_info = {"node_ip_address": worker.node_ip_address, - "driver_id": worker.worker_id, - "start_time": time.time(), - "plasma_store_socket": info["store_socket_name"], - "plasma_manager_socket": info["manager_socket_name"], - "local_scheduler_socket": info["local_scheduler_socket_name"]} - driver_info["name"] = main.__file__ if hasattr(main, "__file__") else "INTERACTIVE MODE" + driver_info = { + "node_ip_address": worker.node_ip_address, + "driver_id": worker.worker_id, + "start_time": time.time(), + "plasma_store_socket": info["store_socket_name"], + "plasma_manager_socket": info["manager_socket_name"], + "local_scheduler_socket": info["local_scheduler_socket_name"]} + driver_info["name"] = (main.__file__ if hasattr(main, "__file__") + else "INTERACTIVE MODE") worker.redis_client.hmset(b"Drivers:" + worker.worker_id, driver_info) is_worker = False elif mode == WORKER_MODE: # Register the worker with Redis. - worker.redis_client.hmset(b"Workers:" + worker.worker_id, - {"node_ip_address": worker.node_ip_address, - "plasma_store_socket": info["store_socket_name"], - "plasma_manager_socket": info["manager_socket_name"], - "local_scheduler_socket": info["local_scheduler_socket_name"]}) + worker.redis_client.hmset( + b"Workers:" + worker.worker_id, + {"node_ip_address": worker.node_ip_address, + "plasma_store_socket": info["store_socket_name"], + "plasma_manager_socket": info["manager_socket_name"], + "local_scheduler_socket": info["local_scheduler_socket_name"]}) is_worker = True else: raise Exception("This code should be unreachable.") # Create an object store client. - worker.plasma_client = ray.plasma.PlasmaClient(info["store_socket_name"], info["manager_socket_name"]) + worker.plasma_client = ray.plasma.PlasmaClient(info["store_socket_name"], + info["manager_socket_name"]) # Create the local scheduler client. worker.local_scheduler_client = ray.local_scheduler.LocalSchedulerClient( - info["local_scheduler_socket_name"], - worker.actor_id, - is_worker) + info["local_scheduler_socket_name"], worker.actor_id, is_worker) # If this is a driver, set the current task ID, the task driver ID, and set # the task index to 0. @@ -1315,15 +1422,15 @@ def connect(info, object_id_seed=None, mode=WORKER_MODE, worker=global_worker, a # an object that is later evicted, we should notify the user that we're # unable to reconstruct the object, since we cannot rerun the driver. driver_task = ray.local_scheduler.Task( - worker.task_driver_id, - ray.local_scheduler.ObjectID(NIL_FUNCTION_ID), - [], - 0, - worker.current_task_id, - worker.task_index, - ray.local_scheduler.ObjectID(NIL_ACTOR_ID), - worker.actor_counters[actor_id], - [0, 0]) + worker.task_driver_id, + ray.local_scheduler.ObjectID(NIL_FUNCTION_ID), + [], + 0, + worker.current_task_id, + worker.task_index, + ray.local_scheduler.ObjectID(NIL_ACTOR_ID), + worker.actor_counters[actor_id], + [0, 0]) worker.redis_client.execute_command( "RAY.TASK_TABLE_ADD", driver_task.task_id().id(), @@ -1343,10 +1450,10 @@ def connect(info, object_id_seed=None, mode=WORKER_MODE, worker=global_worker, a # If this is a driver running in SCRIPT_MODE, start a thread to print error # messages asynchronously in the background. Ideally the scheduler would push - # messages to the driver's worker service, but we ran into bugs when trying to - # properly shutdown the driver's worker service, so we are temporarily using - # this implementation which constantly queries the scheduler for new error - # messages. + # messages to the driver's worker service, but we ran into bugs when trying + # to properly shutdown the driver's worker service, so we are temporarily + # using this implementation which constantly queries the scheduler for new + # error messages. if mode == SCRIPT_MODE: t = threading.Thread(target=print_error_messages, args=(worker,)) # Making the thread a daemon causes it to exit when the main thread exits. @@ -1362,18 +1469,20 @@ def connect(info, object_id_seed=None, mode=WORKER_MODE, worker=global_worker, a # the same. script_directory = os.path.abspath(os.path.dirname(sys.argv[0])) current_directory = os.path.abspath(os.path.curdir) - worker.run_function_on_all_workers(lambda worker_info: sys.path.insert(1, script_directory)) - worker.run_function_on_all_workers(lambda worker_info: sys.path.insert(1, current_directory)) + worker.run_function_on_all_workers( + lambda worker_info: sys.path.insert(1, script_directory)) + worker.run_function_on_all_workers( + lambda worker_info: sys.path.insert(1, current_directory)) # TODO(rkn): Here we first export functions to run, then environment # variables, then remote functions. The order matters. For example, one of # the functions to run may set the Python path, which is needed to import a # module used to define an environment variable, which in turn is used - # inside a remote function. We may want to change the order to simply be the - # order in which the exports were defined on the driver. In addition, we - # will need to retain the ability to decide what the first few exports are - # (mostly to set the Python path). Additionally, note that the first exports - # to be defined on the driver will be the ones defined in separate modules - # that are imported by the driver. + # inside a remote function. We may want to change the order to simply be + # the order in which the exports were defined on the driver. In addition, + # we will need to retain the ability to decide what the first few exports + # are (mostly to set the Python path). Additionally, note that the first + # exports to be defined on the driver will be the ones defined in separate + # modules that are imported by the driver. # Export cached functions_to_run. for function in worker.cached_functions_to_run: worker.run_function_on_all_workers(function) @@ -1381,12 +1490,15 @@ def connect(info, object_id_seed=None, mode=WORKER_MODE, worker=global_worker, a for name, environment_variable in env._cached_environment_variables: env.__setattr__(name, environment_variable) # Export cached remote functions to the workers. - for function_id, func_name, func, num_return_vals, num_cpus, num_gpus in worker.cached_remote_functions: - export_remote_function(function_id, func_name, func, num_return_vals, num_cpus, num_gpus, worker) + for info in worker.cached_remote_functions: + function_id, func_name, func, num_return_vals, num_cpus, num_gpus = info + export_remote_function(function_id, func_name, func, num_return_vals, + num_cpus, num_gpus, worker) worker.cached_functions_to_run = None worker.cached_remote_functions = None env._cached_environment_variables = None + def disconnect(worker=global_worker): """Disconnect this worker from the scheduler and object store.""" # Reset the list of cached remote functions so that if more remote functions @@ -1397,12 +1509,13 @@ def disconnect(worker=global_worker): worker.cached_remote_functions = [] env._cached_environment_variables = [] + def register_class(cls, pickle=False, worker=global_worker): """Enable workers to serialize or deserialize objects of a particular class. This method runs the register_class function defined below on every worker, - which will enable numbuf to properly serialize and deserialize objects of this - class. + which will enable numbuf to properly serialize and deserialize objects of + this class. Args: cls (type): The class that numbuf should serialize. @@ -1422,10 +1535,12 @@ def register_class(cls, pickle=False, worker=global_worker): # Raise an exception if cls cannot be serialized efficiently by Ray. if not pickle: serialization.check_serializable(cls) + def register_class_for_serialization(worker_info): serialization.add_class_to_whitelist(cls, pickle=pickle) worker.run_function_on_all_workers(register_class_for_serialization) + class RayLogSpan(object): """An object used to enable logging a span of events with a with statement. @@ -1458,12 +1573,15 @@ class RayLogSpan(object): kind=LOG_SPAN_END, worker=self.worker) + def log_span(event_type, contents=None, worker=global_worker): return RayLogSpan(event_type, contents=contents, worker=worker) + def log_event(event_type, contents=None, worker=global_worker): log(event_type, kind=LOG_POINT, contents=contents, worker=worker) + def log(event_type, kind, contents=None, worker=global_worker): """Log an event to the global state store. @@ -1487,21 +1605,24 @@ def log(event_type, kind, contents=None, worker=global_worker): contents = {str(k): str(v) for k, v in contents.items()} worker.events.append((time.time(), event_type, kind, contents)) + def flush_log(worker=global_worker): """Send the logged worker events to the global state store.""" - event_log_key = b"event_log:" + worker.worker_id + b":" + worker.current_task_id.id() + event_log_key = (b"event_log:" + worker.worker_id + b":" + + worker.current_task_id.id()) event_log_value = json.dumps(worker.events) worker.local_scheduler_client.log_event(event_log_key, event_log_value) worker.events = [] + def get(object_ids, worker=global_worker): """Get a remote object or a list of remote objects from the object store. - This method blocks until the object corresponding to the object ID is available in - the local object store. If this object is not in the local object store, it - will be shipped from an object store that has it (once the object has been - created). If object_ids is a list, then the objects corresponding to each object - in the list will be returned. + This method blocks until the object corresponding to the object ID is + available in the local object store. If this object is not in the local + object store, it will be shipped from an object store that has it (once the + object has been created). If object_ids is a list, then the objects + corresponding to each object in the list will be returned. Args: object_ids: Object ID of the object to get or a list of object IDs to get. @@ -1514,7 +1635,8 @@ def get(object_ids, worker=global_worker): check_main_thread() if worker.mode == PYTHON_MODE: - # In PYTHON_MODE, ray.get is the identity operation (the input will actually be a value not an objectid) + # In PYTHON_MODE, ray.get is the identity operation (the input will + # actually be a value not an objectid). return object_ids if isinstance(object_ids, list): values = worker.get_object(object_ids) @@ -1525,11 +1647,12 @@ def get(object_ids, worker=global_worker): else: value = worker.get_object([object_ids])[0] if isinstance(value, RayTaskError): - # If the result is a RayTaskError, then the task that created this object - # failed, and we should propagate the error message here. + # If the result is a RayTaskError, then the task that created this + # object failed, and we should propagate the error message here. raise RayGetError(object_ids, value) return value + def put(value, worker=global_worker): """Store an object in the object store. @@ -1544,21 +1667,22 @@ def put(value, worker=global_worker): check_main_thread() if worker.mode == PYTHON_MODE: - # In PYTHON_MODE, ray.put is the identity operation + # In PYTHON_MODE, ray.put is the identity operation. return value object_id = worker.local_scheduler_client.compute_put_id( - worker.current_task_id, worker.put_index) + worker.current_task_id, worker.put_index) worker.put_object(object_id, value) worker.put_index += 1 return object_id + def wait(object_ids, num_returns=1, timeout=None, worker=global_worker): """Return a list of IDs that are ready and a list of IDs that are not ready. If timeout is set, the function returns either when the requested number of - IDs are ready or when the timeout is reached, whichever occurs first. If it is - not set, the function simply waits until that number of objects is ready and - returns that exact number of objectids. + IDs are ready or when the timeout is reached, whichever occurs first. If it + is not set, the function simply waits until that number of objects is ready + and returns that exact number of objectids. This method returns two lists. The first list consists of object IDs that correspond to objects that are stored in the object store. The second list @@ -1579,17 +1703,21 @@ def wait(object_ids, num_returns=1, timeout=None, worker=global_worker): check_main_thread() object_id_strs = [object_id.id() for object_id in object_ids] timeout = timeout if timeout is not None else 2 ** 30 - ready_ids, remaining_ids = worker.plasma_client.wait(object_id_strs, timeout, num_returns) - ready_ids = [ray.local_scheduler.ObjectID(object_id) for object_id in ready_ids] - remaining_ids = [ray.local_scheduler.ObjectID(object_id) for object_id in remaining_ids] + ready_ids, remaining_ids = worker.plasma_client.wait(object_id_strs, + timeout, num_returns) + ready_ids = [ray.local_scheduler.ObjectID(object_id) + for object_id in ready_ids] + remaining_ids = [ray.local_scheduler.ObjectID(object_id) + for object_id in remaining_ids] return ready_ids, remaining_ids + def wait_for_function(function_id, driver_id, timeout=5, worker=global_worker): """Wait until the function to be executed is present on this worker. - This method will simply loop until the import thread has imported the relevant - function. If we spend too long in this loop, that may indicate a problem - somewhere and we will push an error message to the user. + This method will simply loop until the import thread has imported the + relevant function. If we spend too long in this loop, that may indicate a + problem somewhere and we will push an error message to the user. If this worker is an actor, then this will wait until the actor has been defined. @@ -1606,18 +1734,23 @@ def wait_for_function(function_id, driver_id, timeout=5, worker=global_worker): num_warnings_sent = 0 while True: with worker.lock: - if worker.actor_id == NIL_ACTOR_ID and function_id.id() in worker.functions[driver_id]: + if worker.actor_id == NIL_ACTOR_ID and (function_id.id() in + worker.functions[driver_id]): break - elif worker.actor_id != NIL_ACTOR_ID and worker.actor_id in worker.actors: + elif worker.actor_id != NIL_ACTOR_ID and (worker.actor_id in + worker.actors): break if time.time() - start_time > timeout * (num_warnings_sent + 1): - warning_message = "This worker was asked to execute a function that it does not have registered. You may have to restart Ray." + warning_message = ("This worker was asked to execute a function that " + "it does not have registered. You may have to " + "restart Ray.") if not warning_sent: worker.push_error_to_driver(driver_id, "wait_for_function", warning_message) warning_sent = True time.sleep(0.001) + def format_error_message(exception_message, task_exception=False): """Improve the formatting of an exception thrown by a remote function. @@ -1639,14 +1772,16 @@ def format_error_message(exception_message, task_exception=False): lines = lines[0:1] + lines[5:] return "\n".join(lines) + def main_loop(worker=global_worker): """The main loop a worker runs to receive and execute tasks. This method is an infinite loop. It waits to receive commands from the scheduler. A command may consist of a task to execute, a remote function to - import, an environment variable to import, or an order to terminate the worker - process. The worker executes the command, notifies the scheduler of any errors - that occurred while executing the command, and waits for the next command. + import, an environment variable to import, or an order to terminate the + worker process. The worker executes the command, notifies the scheduler of + any errors that occurred while executing the command, and waits for the next + command. """ def exit(signum, frame): @@ -1655,7 +1790,7 @@ def main_loop(worker=global_worker): signal.signal(signal.SIGTERM, exit) - def process_task(task): # wrapping these lines in a function should cause the local variables to go out of scope more quickly, which is useful for inspecting reference counts + def process_task(task): """Execute a task assigned to this worker. This method deserializes a task from the scheduler, and attempts to execute @@ -1678,7 +1813,9 @@ def main_loop(worker=global_worker): function_id = task.function_id() args = task.arguments() return_object_ids = task.returns() - function_name, function_executor = worker.functions[worker.task_driver_id.id()][function_id.id()] + function_name, function_executor = (worker.functions + [worker.task_driver_id.id()] + [function_id.id()]) # Get task arguments from the object store. with log_span("ray:task:get_arguments", worker=worker): @@ -1689,7 +1826,8 @@ def main_loop(worker=global_worker): if task.actor_id().id() == NIL_ACTOR_ID: outputs = function_executor.executor(arguments) else: - outputs = function_executor(worker.actors[task.actor_id().id()], *arguments) + outputs = function_executor( + worker.actors[task.actor_id().id()], *arguments) # Store the outputs in the local object store. with log_span("ray:task:store_outputs", worker=worker): @@ -1705,7 +1843,8 @@ def main_loop(worker=global_worker): if "arguments" in locals() and "outputs" not in locals(): if task.actor_id().id() == NIL_ACTOR_ID: # The error occurred during the task execution. - traceback_str = format_error_message(traceback.format_exc(), task_exception=True) + traceback_str = format_error_message(traceback.format_exc(), + task_exception=True) else: # The error occurred during the execution of an actor task. traceback_str = format_error_message(traceback.format_exc()) @@ -1725,8 +1864,10 @@ def main_loop(worker=global_worker): "function_name": function_name}) try: # Reinitialize the values of environment variables that were used in the - # task above so that changes made to their state do not affect other tasks. - with log_span("ray:task:reinitialize_environment_variables", worker=worker): + # task above so that changes made to their state do not affect other + # tasks. + with log_span("ray:task:reinitialize_environment_variables", + worker=worker): env._reinitialize() except Exception as e: # The attempt to reinitialize the environment variables threw an @@ -1752,14 +1893,15 @@ def main_loop(worker=global_worker): # Execute the task. # TODO(rkn): Consider acquiring this lock with a timeout and pushing a - # warning to the user if we are waiting too long to acquire the lock because - # that may indicate that the system is hanging, and it'd be good to know - # where the system is hanging. + # warning to the user if we are waiting too long to acquire the lock + # because that may indicate that the system is hanging, and it'd be good to + # know where the system is hanging. log(event_type="ray:acquire_lock", kind=LOG_SPAN_START, worker=worker) with worker.lock: log(event_type="ray:acquire_lock", kind=LOG_SPAN_END, worker=worker) - function_name, _ = worker.functions[task.driver_id().id()][function_id.id()] + function_name, _ = (worker.functions[task.driver_id().id()] + [function_id.id()]) contents = {"function_name": function_name, "task_id": task.task_id().hex()} with log_span("ray:task", contents=contents, worker=worker): @@ -1768,6 +1910,7 @@ def main_loop(worker=global_worker): # Push all of the log events to the global state store. flush_log() + def _submit_task(function_id, func_name, args, worker=global_worker): """This is a wrapper around worker.submit_task. @@ -1778,6 +1921,7 @@ def _submit_task(function_id, func_name, args, worker=global_worker): """ return worker.submit_task(function_id, func_name, args) + def _mode(worker=global_worker): """This is a wrapper around worker.mode. @@ -1788,6 +1932,7 @@ def _mode(worker=global_worker): """ return worker.mode + def _env(): """Return the env object. @@ -1796,7 +1941,9 @@ def _env(): """ return env -def _export_environment_variable(name, environment_variable, worker=global_worker): + +def _export_environment_variable(name, environment_variable, + worker=global_worker): """Export an environment variable to the workers. This is only called by a driver. @@ -1808,20 +1955,26 @@ def _export_environment_variable(name, environment_variable, worker=global_worke """ check_main_thread() if _mode(worker) not in [SCRIPT_MODE, SILENT_MODE]: - raise Exception("_export_environment_variable can only be called on a driver.") + raise Exception("_export_environment_variable can only be called on a " + "driver.") environment_variable_id = name key = "EnvironmentVariables:{}".format(environment_variable_id) - worker.redis_client.hmset(key, {"driver_id": worker.task_driver_id.id(), - "name": name, - "initializer": pickling.dumps(environment_variable.initializer), - "reinitializer": pickling.dumps(environment_variable.reinitializer)}) + worker.redis_client.hmset(key, { + "driver_id": worker.task_driver_id.id(), + "name": name, + "initializer": pickling.dumps(environment_variable.initializer), + "reinitializer": pickling.dumps(environment_variable.reinitializer)}) worker.redis_client.rpush("Exports", key) -def export_remote_function(function_id, func_name, func, num_return_vals, num_cpus, num_gpus, worker=global_worker): + +def export_remote_function(function_id, func_name, func, num_return_vals, + num_cpus, num_gpus, worker=global_worker): check_main_thread() if _mode(worker) not in [SCRIPT_MODE, SILENT_MODE]: raise Exception("export_remote_function can only be called on a driver.") - worker.function_properties[worker.task_driver_id.id()][function_id.id()] = (num_return_vals, num_cpus, num_gpus) + + worker.function_properties[worker.task_driver_id.id()][function_id.id()] = ( + num_return_vals, num_cpus, num_gpus) key = "RemoteFunction:{}".format(function_id.id()) pickled_func = pickling.dumps(func) worker.redis_client.hmset(key, {"driver_id": worker.task_driver_id.id(), @@ -1834,18 +1987,20 @@ def export_remote_function(function_id, func_name, func, num_return_vals, num_cp "num_gpus": num_gpus}) worker.redis_client.rpush("Exports", key) + def remote(*args, **kwargs): """This decorator is used to create remote functions. Args: - num_return_vals (int): The number of object IDs that a call to this function - should return. + num_return_vals (int): The number of object IDs that a call to this + function should return. num_cpus (int): The number of CPUs needed to execute this function. This should only be passed in when defining the remote function on the driver. num_gpus (int): The number of GPUs needed to execute this function. This should only be passed in when defining the remote function on the driver. """ worker = global_worker + def make_remote_decorator(num_return_vals, num_cpus, num_gpus, func_id=None): def remote_decorator(func): func_name = "{}.{}".format(func.__module__, func.__name__) @@ -1868,13 +2023,16 @@ def remote(*args, **kwargs): check_connected() check_main_thread() args = list(args) - args.extend([kwargs[keyword] if keyword in kwargs else default for keyword, default in keyword_defaults[len(args):]]) # fill in the remaining arguments + # Fill in the remaining arguments. + args.extend([kwargs[keyword] if keyword in kwargs else default + for keyword, default in keyword_defaults[len(args):]]) if any([arg is funcsigs._empty for arg in args]): - raise Exception("Not enough arguments were provided to {}.".format(func_name)) + raise Exception("Not enough arguments were provided to {}." + .format(func_name)) if _mode() == PYTHON_MODE: - # In PYTHON_MODE, remote calls simply execute the function. We copy the - # arguments to prevent the function call from mutating them and to match - # the usual behavior of immutable remote objects. + # In PYTHON_MODE, remote calls simply execute the function. We copy + # the arguments to prevent the function call from mutating them and + # to match the usual behavior of immutable remote objects. try: _env()._running_remote_function_locally = True result = func(*copy.deepcopy(args)) @@ -1887,15 +2045,17 @@ def remote(*args, **kwargs): return objectids[0] elif len(objectids) > 1: return objectids + def func_executor(arguments): """This gets run when the remote function is executed.""" - start_time = time.time() result = func(*arguments) - end_time = time.time() return result + def func_invoker(*args, **kwargs): - """This is returned by the decorator and used to invoke the function.""" - raise Exception("Remote functions cannot be called directly. Instead of running '{}()', try '{}.remote()'.".format(func_name, func_name)) + """This is used to invoke the function.""" + raise Exception("Remote functions cannot be called directly. Instead " + "of running '{}()', try '{}.remote()'." + .format(func_name, func_name)) func_invoker.remote = func_call func_invoker.executor = func_executor func_invoker.is_remote = True @@ -1906,41 +2066,38 @@ def remote(*args, **kwargs): else: func_invoker.func_doc = func.func_doc - sig_params = [(k, v) for k, v in funcsigs.signature(func).parameters.items()] + sig_params = [(k, v) for k, v + in funcsigs.signature(func).parameters.items()] keyword_defaults = [(k, v.default) for k, v in sig_params] - has_vararg_param = any([v.kind == v.VAR_POSITIONAL for k, v in sig_params]) + has_vararg_param = any([v.kind == v.VAR_POSITIONAL + for k, v in sig_params]) func_invoker.has_vararg_param = has_vararg_param has_kwargs_param = any([v.kind == v.VAR_KEYWORD for k, v in sig_params]) - check_signature_supported(has_kwargs_param, has_vararg_param, keyword_defaults, func_name) + check_signature_supported(has_kwargs_param, has_vararg_param, + keyword_defaults, func_name) # Everything ready - export the function - if worker.mode in [None, SCRIPT_MODE, SILENT_MODE]: - func_name_global_valid = func.__name__ in func.__globals__ - func_name_global_value = func.__globals__.get(func.__name__) - # Set the function globally to make it refer to itself - func.__globals__[func.__name__] = func_invoker # Allow the function to reference itself as a global variable - try: - to_export = pickling.dumps((func, num_return_vals, func.__module__)) - finally: - # Undo our changes - if func_name_global_valid: func.__globals__[func.__name__] = func_name_global_value - else: del func.__globals__[func.__name__] if worker.mode in [SCRIPT_MODE, SILENT_MODE]: - export_remote_function(function_id, func_name, func, num_return_vals, num_cpus, num_gpus) + export_remote_function(function_id, func_name, func, num_return_vals, + num_cpus, num_gpus) elif worker.mode is None: - worker.cached_remote_functions.append((function_id, func_name, func, num_return_vals, num_cpus, num_gpus)) + worker.cached_remote_functions.append((function_id, func_name, func, + num_return_vals, num_cpus, + num_gpus)) return func_invoker return remote_decorator - num_return_vals = kwargs["num_return_vals"] if "num_return_vals" in kwargs.keys() else 1 + num_return_vals = (kwargs["num_return_vals"] if "num_return_vals" + in kwargs.keys() else 1) num_cpus = kwargs["num_cpus"] if "num_cpus" in kwargs.keys() else 1 num_gpus = kwargs["num_gpus"] if "num_gpus" in kwargs.keys() else 0 if _mode() == WORKER_MODE: if "function_id" in kwargs: function_id = kwargs["function_id"] - return make_remote_decorator(num_return_vals, num_cpus, num_gpus, function_id) + return make_remote_decorator(num_return_vals, num_cpus, num_gpus, + function_id) if len(args) == 1 and len(kwargs) == 0 and callable(args[0]): # This is the case where the decorator is just @ray.remote. @@ -1956,10 +2113,12 @@ def remote(*args, **kwargs): assert len(args) == 0 and ("num_return_vals" in kwargs or "num_cpus" in kwargs or "num_gpus" in kwargs), error_string - assert not "function_id" in kwargs + assert "function_id" not in kwargs return make_remote_decorator(num_return_vals, num_cpus, num_gpus) -def check_signature_supported(has_kwargs_param, has_vararg_param, keyword_defaults, name): + +def check_signature_supported(has_kwargs_param, has_vararg_param, + keyword_defaults, name): """Check if we support the signature of this function. We currently do not allow remote functions to have **kwargs. We also do not @@ -1977,14 +2136,20 @@ def check_signature_supported(has_kwargs_param, has_vararg_param, keyword_defaul Raises: Exception: An exception is raised if the signature is not supported. """ - # check if the user specified kwargs + # Check if the user specified kwargs. if has_kwargs_param: - raise "Function {} has a **kwargs argument, which is currently not supported.".format(name) - # check if the user specified a variable number of arguments and any keyword arguments - if has_vararg_param and any([d != funcsigs._empty for _, d in keyword_defaults]): - raise "Function {} has a *args argument as well as a keyword argument, which is currently not supported.".format(name) + raise ("Function {} has a **kwargs argument, which is currently not " + "supported.".format(name)) + # Check if the user specified a variable number of arguments and any keyword + # arguments. + if has_vararg_param and any([d != funcsigs._empty + for _, d in keyword_defaults]): + raise ("Function {} has a *args argument as well as a keyword argument, " + "which is currently not supported.".format(name)) -def get_arguments_for_execution(function_name, serialized_args, worker=global_worker): + +def get_arguments_for_execution(function_name, serialized_args, + worker=global_worker): """Retrieve the arguments for the remote function. This retrieves the values for the arguments to the remote function that were @@ -2022,13 +2187,15 @@ def get_arguments_for_execution(function_name, serialized_args, worker=global_wo arguments.append(argument) return arguments + def store_outputs_in_objstore(objectids, outputs, worker=global_worker): """Store the outputs of a remote function in the local object store. This stores the values that were returned by a remote function in the local object store. If any of the return values are object IDs, then these object - IDs are aliased with the object IDs that the scheduler assigned for the return - values. This is called by the worker that executes the remote function. + IDs are aliased with the object IDs that the scheduler assigned for the + return values. This is called by the worker that executes the remote + function. Note: The arguments objectids and outputs should have the same length. diff --git a/python/ray/workers/default_worker.py b/python/ray/workers/default_worker.py index c6457357c..009ae6138 100644 --- a/python/ray/workers/default_worker.py +++ b/python/ray/workers/default_worker.py @@ -3,25 +3,33 @@ from __future__ import division from __future__ import print_function import argparse +import binascii import numpy as np import redis import traceback -import sys -import binascii import ray -parser = argparse.ArgumentParser(description="Parse addresses for the worker to connect to.") -parser.add_argument("--node-ip-address", required=True, type=str, help="the ip address of the worker's node") -parser.add_argument("--redis-address", required=True, type=str, help="the address to use for Redis") -parser.add_argument("--object-store-name", required=True, type=str, help="the object store's name") -parser.add_argument("--object-store-manager-name", required=True, type=str, help="the object store manager's name") -parser.add_argument("--local-scheduler-name", required=True, type=str, help="the local scheduler's name") -parser.add_argument("--actor-id", required=False, type=str, help="the actor ID of this worker") +parser = argparse.ArgumentParser(description=("Parse addresses for the worker " + "to connect to.")) +parser.add_argument("--node-ip-address", required=True, type=str, + help="the ip address of the worker's node") +parser.add_argument("--redis-address", required=True, type=str, + help="the address to use for Redis") +parser.add_argument("--object-store-name", required=True, type=str, + help="the object store's name") +parser.add_argument("--object-store-manager-name", required=True, type=str, + help="the object store manager's name") +parser.add_argument("--local-scheduler-name", required=True, type=str, + help="the local scheduler's name") +parser.add_argument("--actor-id", required=False, type=str, + help="the actor ID of this worker") + def random_string(): return np.random.bytes(20) + if __name__ == "__main__": args = parser.parse_args() info = {"node_ip_address": args.node_ip_address, @@ -30,7 +38,10 @@ if __name__ == "__main__": "manager_socket_name": args.object_store_manager_name, "local_scheduler_socket_name": args.local_scheduler_name} - actor_id = binascii.unhexlify(args.actor_id) if not args.actor_id is None else ray.worker.NIL_ACTOR_ID + if args.actor_id is not None: + actor_id = binascii.unhexlify(args.actor_id) + else: + actor_id = ray.worker.NIL_ACTOR_ID ray.worker.connect(info, mode=ray.WORKER_MODE, actor_id=actor_id) @@ -43,24 +54,27 @@ being caught in "python/ray/workers/default_worker.py". while True: try: # This call to main_loop should never return if things are working. Most - # exceptions that are thrown (e.g., inside the execution of a task) should - # be caught and handled inside of the call to main_loop. If an exception - # is thrown here, then that means that there is some error that we didn't - # anticipate. + # exceptions that are thrown (e.g., inside the execution of a task) + # should be caught and handled inside of the call to main_loop. If an + # exception is thrown here, then that means that there is some error that + # we didn't anticipate. ray.worker.main_loop() except Exception as e: traceback_str = traceback.format_exc() + error_explanation DRIVER_ID_LENGTH = 20 - # We use a driver ID of all zeros to push an error message to all drivers. + # We use a driver ID of all zeros to push an error message to all + # drivers. driver_id = DRIVER_ID_LENGTH * b"\x00" error_key = b"Error:" + driver_id + b":" + random_string() redis_ip_address, redis_port = args.redis_address.split(":") # For this command to work, some other client (on the same machine as # Redis) must have run "CONFIG SET protected-mode no". - redis_client = redis.StrictRedis(host=redis_ip_address, port=int(redis_port)) + redis_client = redis.StrictRedis(host=redis_ip_address, + port=int(redis_port)) redis_client.hmset(error_key, {"type": "worker_crash", "message": traceback_str, - "note": "This error is unexpected and should not have happened."}) + "note": ("This error is unexpected and " + "should not have happened.")}) redis_client.rpush("ErrorKeys", error_key) # TODO(rkn): Note that if the worker was in the middle of executing a # task, the any worker or driver that is blocking in a get call and diff --git a/python/setup.py b/python/setup.py index c943ac04d..47840c441 100644 --- a/python/setup.py +++ b/python/setup.py @@ -2,12 +2,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os import subprocess from setuptools import setup, find_packages import setuptools.command.install as _install + class install(_install.install): def run(self): subprocess.check_call(["../build.sh"]) @@ -16,19 +16,24 @@ class install(_install.install): # setuptools. So, calling do_egg_install() manually here. self.do_egg_install() + +package_data = { + "ray": ["core/src/common/thirdparty/redis/src/redis-server", + "core/src/common/redis_module/libray_redis_module.so", + "core/src/plasma/plasma_store", + "core/src/plasma/plasma_manager", + "core/src/plasma/libplasma.so", + "core/src/local_scheduler/local_scheduler", + "core/src/local_scheduler/liblocal_scheduler_library.so", + "core/src/numbuf/libarrow.so", + "core/src/numbuf/libnumbuf.so", + "core/src/global_scheduler/global_scheduler"] +} + setup(name="ray", version="0.0.1", packages=find_packages(), - package_data={"ray": ["core/src/common/thirdparty/redis/src/redis-server", - "core/src/common/redis_module/libray_redis_module.so", - "core/src/plasma/plasma_store", - "core/src/plasma/plasma_manager", - "core/src/plasma/libplasma.so", - "core/src/local_scheduler/local_scheduler", - "core/src/local_scheduler/liblocal_scheduler_library.so", - "core/src/numbuf/libarrow.so", - "core/src/numbuf/libnumbuf.so", - "core/src/global_scheduler/global_scheduler"]}, + package_data=package_data, cmdclass={"install": install}, install_requires=["numpy", "funcsigs", diff --git a/scripts/start_ray.py b/scripts/start_ray.py index feb61a567..284179ca6 100644 --- a/scripts/start_ray.py +++ b/scripts/start_ray.py @@ -7,15 +7,25 @@ import redis import ray.services as services -parser = argparse.ArgumentParser(description="Parse addresses for the worker to connect to.") -parser.add_argument("--node-ip-address", required=False, type=str, help="the IP address of the worker's node") -parser.add_argument("--redis-address", required=False, type=str, help="the address to use for connecting to Redis") -parser.add_argument("--redis-port", required=False, type=str, help="the port to use for starting Redis") -parser.add_argument("--object-manager-port", required=False, type=int, help="the port to use for starting the object manager") -parser.add_argument("--num-workers", default=10, required=False, type=int, help="the number of workers to start on this node") -parser.add_argument("--num-cpus", required=False, type=int, help="the number of CPUs on this node") -parser.add_argument("--num-gpus", required=False, type=int, help="the number of GPUs on this node") -parser.add_argument("--head", action="store_true", help="provide this argument for the head node") +parser = argparse.ArgumentParser( + description="Parse addresses for the worker to connect to.") +parser.add_argument("--node-ip-address", required=False, type=str, + help="the IP address of the worker's node") +parser.add_argument("--redis-address", required=False, type=str, + help="the address to use for connecting to Redis") +parser.add_argument("--redis-port", required=False, type=str, + help="the port to use for starting Redis") +parser.add_argument("--object-manager-port", required=False, type=int, + help="the port to use for starting the object manager") +parser.add_argument("--num-workers", default=10, required=False, type=int, + help="the number of workers to start on this node") +parser.add_argument("--num-cpus", required=False, type=int, + help="the number of CPUs on this node") +parser.add_argument("--num-gpus", required=False, type=int, + help="the number of GPUs on this node") +parser.add_argument("--head", action="store_true", + help="provide this argument for the head node") + def check_no_existing_redis_clients(node_ip_address, redis_address): redis_ip_address, redis_port = redis_address.split(":") @@ -39,7 +49,9 @@ def check_no_existing_redis_clients(node_ip_address, redis_address): continue if info[b"node_ip_address"].decode("ascii") == node_ip_address: - raise Exception("This Redis instance is already connected to clients with this IP address.") + raise Exception("This Redis instance is already connected to clients " + "with this IP address.") + if __name__ == "__main__": args = parser.parse_args() @@ -52,7 +64,8 @@ if __name__ == "__main__": if args.head: # Start Ray on the head node. if args.redis_address is not None: - raise Exception("If --head is passed in, a Redis server will be started, so a Redis address should not be provided.") + raise Exception("If --head is passed in, a Redis server will be " + "started, so a Redis address should not be provided.") # Get the node IP address if one is not provided. if args.node_ip_address is None: @@ -82,25 +95,27 @@ if __name__ == "__main__": print(address_info) print("\nStarted Ray with {} workers on this node. A different number of " "workers can be set with the --num-workers flag (but you have to " - "first terminate the existing cluster). You can add additional nodes " - "to the cluster by calling\n\n" + "first terminate the existing cluster). You can add additional " + "nodes to the cluster by calling\n\n" " ./scripts/start_ray.sh --redis-address {}\n\n" "from the node you wish to add. You can connect a driver to the " "cluster from Python by running\n\n" " import ray\n" " ray.init(redis_address=\"{}\")\n\n" - "If you have trouble connecting from a different machine, check that " - "your firewall is configured properly. If you wish to terminate the " - "processes that have been started, run\n\n" + "If you have trouble connecting from a different machine, check " + "that your firewall is configured properly. If you wish to " + "terminate the processes that have been started, run\n\n" " ./scripts/stop_ray.sh".format(args.num_workers, address_info["redis_address"], address_info["redis_address"])) else: # Start Ray on a non-head node. if args.redis_port is not None: - raise Exception("If --head is not passed in, --redis-port is not allowed") + raise Exception("If --head is not passed in, --redis-port is not " + "allowed") if args.redis_address is None: - raise Exception("If --head is not passed in, --redis-address must be provided.") + raise Exception("If --head is not passed in, --redis-address must be " + "provided.") redis_ip_address, redis_port = args.redis_address.split(":") # Wait for the Redis server to be started. And throw an exception if we # can't connect to it. @@ -115,14 +130,15 @@ if __name__ == "__main__": # connected with this Redis instance. This raises an exception if the Redis # server already has clients on this node. check_no_existing_redis_clients(node_ip_address, args.redis_address) - address_info = services.start_ray_node(node_ip_address=node_ip_address, - redis_address=args.redis_address, - object_manager_ports=[args.object_manager_port], - num_workers=args.num_workers, - cleanup=False, - redirect_output=True, - num_cpus=args.num_cpus, - num_gpus=args.num_gpus) + address_info = services.start_ray_node( + node_ip_address=node_ip_address, + redis_address=args.redis_address, + object_manager_ports=[args.object_manager_port], + num_workers=args.num_workers, + cleanup=False, + redirect_output=True, + num_cpus=args.num_cpus, + num_gpus=args.num_gpus) print(address_info) print("\nStarted {} workers on this node. A different number of workers " "can be set with the --num-workers flag (but you have to first " diff --git a/src/numbuf/numbuf/__init__.py b/src/numbuf/numbuf/__init__.py deleted file mode 100644 index a692edb79..000000000 --- a/src/numbuf/numbuf/__init__.py +++ /dev/null @@ -1,29 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -# See https://github.com/ray-project/ray/issues/131. -helpful_message = """ - -If you are using Anaconda, try fixing this problem by running: - - conda install libgcc -""" - -try: - from .libnumbuf import * -except ImportError as e: - if hasattr(e, "msg") and isinstance(e.msg, str) and "GLIBCXX" in e.msg: - # This code path should be taken with Python 3. - e.msg += helpful_message - elif hasattr(e, "message") and isinstance(e.message, str) and "GLIBCXX" in e.message: - # This code path should be taken with Python 2. - if hasattr(e, "args") and isinstance(e.args, tuple) and len(e.args) == 1 and isinstance(e.args[0], str): - e.args = (e.args[0] + helpful_message,) - else: - if not hasattr(e, "args"): - e.args = () - elif not isinstance(e.args, tuple): - e.args = (e.args,) - e.args += (helpful_message,) - raise diff --git a/src/numbuf/python/test/runtest.py b/src/numbuf/python/test/runtest.py index f0c4daca6..1932bb58b 100644 --- a/src/numbuf/python/test/runtest.py +++ b/src/numbuf/python/test/runtest.py @@ -9,19 +9,21 @@ from numpy.testing import assert_equal import os import sys -TEST_OBJECTS = [{(1,2) : 1}, {() : 2}, [1, "hello", 3.0], 42, 43, "hello world", +TEST_OBJECTS = [{(1, 2): 1}, {(): 2}, [1, "hello", 3.0], 42, 43, + "hello world", u"x", u"\u262F", 42.0, 1 << 62, (1.0, "hi"), None, (None, None), ("hello", None), True, False, (True, False), "hello", {True: "hello", False: "world"}, - {"hello" : "world", 1: 42, 1.0: 45}, {}, + {"hello": "world", 1: 42, 2.5: 45}, {}, np.int8(3), np.int32(4), np.int64(5), np.uint8(3), np.uint32(4), np.uint64(5), np.float32(1.0), np.float64(1.0)] if sys.version_info < (3, 0): - TEST_OBJECTS += [long(42), long(1 << 62)] + TEST_OBJECTS += [long(42), long(1 << 62)] # noqa: F821 + class SerializationTests(unittest.TestCase): @@ -47,14 +49,16 @@ class SerializationTests(unittest.TestCase): self.roundTripTest([{"hello": [1, 2, 3]}]) self.roundTripTest([{"hello": [1, [2, 3]]}]) self.roundTripTest([{"hello": (None, 2, [3, 4])}]) - self.roundTripTest([{"hello": (None, 2, [3, 4], np.array([1.0, 2.0, 3.0]))}]) + self.roundTripTest( + [{"hello": (None, 2, [3, 4], np.array([1.0, 2.0, 3.0]))}]) def numpyTest(self, t): a = np.random.randint(0, 10, size=(100, 100)).astype(t) self.roundTripTest([a]) def testArrays(self): - for t in ["int8", "uint8", "int16", "uint16", "int32", "uint32", "float32", "float64"]: + for t in ["int8", "uint8", "int16", "uint16", "int32", "uint32", "float32", + "float64"]: self.numpyTest(t) def testRay(self): @@ -165,5 +169,6 @@ class SerializationTests(unittest.TestCase): print("Not running testArrowLimits on Travis because of the test's " "memory requirements.") + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/src/plasma/setup.py b/src/plasma/setup.py index 45627c03b..1df0d4b00 100644 --- a/src/plasma/setup.py +++ b/src/plasma/setup.py @@ -7,11 +7,13 @@ import setuptools.command.install as _install import subprocess + class install(_install.install): def run(self): subprocess.check_call(["make"]) subprocess.check_call(["cp", "build/plasma_store", "plasma/plasma_store"]) - subprocess.check_call(["cp", "build/plasma_manager", "plasma/plasma_manager"]) + subprocess.check_call(["cp", "build/plasma_manager", + "plasma/plasma_manager"]) subprocess.check_call(["cmake", ".."], cwd="./build") subprocess.check_call(["make", "install"], cwd="./build") # Calling _install.install.run(self) does not fetch required packages and @@ -19,14 +21,14 @@ class install(_install.install): # setuptools. So, calling do_egg_install() manually here. self.do_egg_install() + setup(name="Plasma", version="0.0.1", description="Plasma client for Python", packages=find_packages(), package_data={"plasma": ["plasma_store", "plasma_manager", - "libplasma.so"], - }, + "libplasma.so"]}, cmdclass={"install": install}, include_package_data=True, zip_safe=False) diff --git a/test/actor_test.py b/test/actor_test.py index 154d2adbd..066b07db4 100644 --- a/test/actor_test.py +++ b/test/actor_test.py @@ -2,11 +2,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import unittest import numpy as np -import time +import unittest + import ray + class ActorAPI(unittest.TestCase): def testKeywordArgs(self): @@ -18,6 +19,7 @@ class ActorAPI(unittest.TestCase): self.arg0 = arg0 self.arg1 = arg1 self.arg2 = arg2 + def get_values(self, arg0, arg1=2, arg2="b"): return self.arg0 + arg0, self.arg1 + arg1, self.arg2 + arg2 @@ -53,6 +55,7 @@ class ActorAPI(unittest.TestCase): self.arg0 = arg0 self.arg1 = arg1 self.args = args + def get_values(self, arg0, arg1=2, *args): return self.arg0 + arg0, self.arg1 + arg1, self.args, args @@ -63,10 +66,12 @@ class ActorAPI(unittest.TestCase): self.assertEqual(ray.get(actor.get_values(2, 3)), (3, 5, (), ())) actor = Actor(1, 2, "c") - self.assertEqual(ray.get(actor.get_values(2, 3, "d")), (3, 5, ("c",), ("d",))) + self.assertEqual(ray.get(actor.get_values(2, 3, "d")), + (3, 5, ("c",), ("d",))) actor = Actor(1, 2, "a", "b", "c", "d") - self.assertEqual(ray.get(actor.get_values(2, 3, 1, 2, 3, 4)), (3, 5, ("a", "b", "c", "d"), (1, 2, 3, 4))) + self.assertEqual(ray.get(actor.get_values(2, 3, 1, 2, 3, 4)), + (3, 5, ("a", "b", "c", "d"), (1, 2, 3, 4))) ray.worker.cleanup() @@ -77,6 +82,7 @@ class ActorAPI(unittest.TestCase): class Actor(object): def __init__(self): pass + def get_values(self): pass @@ -112,8 +118,10 @@ class ActorAPI(unittest.TestCase): def __init__(self, f2): self.f1 = Foo(1) self.f2 = f2 + def get_values1(self): return self.f1, self.f2 + def get_values2(self, f3): return self.f1, self.f2, f3 @@ -144,38 +152,39 @@ class ActorAPI(unittest.TestCase): # This is an invalid way of using the actor decorator. with self.assertRaises(Exception): - @ray.actor(invalid_kwarg=0) + @ray.actor(invalid_kwarg=0) # noqa: F811 class Actor(object): def __init__(self): pass # This is an invalid way of using the actor decorator. with self.assertRaises(Exception): - @ray.actor(num_cpus=0, invalid_kwarg=0) + @ray.actor(num_cpus=0, invalid_kwarg=0) # noqa: F811 class Actor(object): def __init__(self): pass # This is a valid way of using the decorator. - @ray.actor(num_cpus=1) + @ray.actor(num_cpus=1) # noqa: F811 class Actor(object): def __init__(self): pass # This is a valid way of using the decorator. - @ray.actor(num_gpus=1) + @ray.actor(num_gpus=1) # noqa: F811 class Actor(object): def __init__(self): pass # This is a valid way of using the decorator. - @ray.actor(num_cpus=1, num_gpus=1) + @ray.actor(num_cpus=1, num_gpus=1) # noqa: F811 class Actor(object): def __init__(self): pass ray.worker.cleanup() + class ActorMethods(unittest.TestCase): def testDefineActor(self): @@ -185,6 +194,7 @@ class ActorMethods(unittest.TestCase): class Test(object): def __init__(self, x): self.x = x + def f(self, y): return self.x + y @@ -200,8 +210,10 @@ class ActorMethods(unittest.TestCase): class Counter(object): def __init__(self): self.value = 0 + def increase(self): self.value += 1 + def value(self): return self.value @@ -224,9 +236,11 @@ class ActorMethods(unittest.TestCase): class Counter(object): def __init__(self, value): self.value = value + def increase(self): self.value += 1 return self.value + def reset(self): self.value = 0 @@ -240,7 +254,9 @@ class ActorMethods(unittest.TestCase): results += [actors[i].increase() for _ in range(num_increases)] result_values = ray.get(results) for i in range(num_actors): - self.assertEqual(result_values[(num_increases * i):(num_increases * (i + 1))], list(range(i + 1, num_increases + i + 1))) + self.assertEqual( + result_values[(num_increases * i):(num_increases * (i + 1))], + list(range(i + 1, num_increases + i + 1))) # Reset the actor values. [actor.reset() for actor in actors] @@ -251,10 +267,12 @@ class ActorMethods(unittest.TestCase): results += [actor.increase() for actor in actors] result_values = ray.get(results) for j in range(num_increases): - self.assertEqual(result_values[(num_actors * j):(num_actors * (j + 1))], num_actors * [j + 1]) + self.assertEqual(result_values[(num_actors * j):(num_actors * (j + 1))], + num_actors * [j + 1]) ray.worker.cleanup() + class ActorNesting(unittest.TestCase): def testRemoteFunctionWithinActor(self): @@ -302,7 +320,8 @@ class ActorNesting(unittest.TestCase): self.assertEqual(ray.get(ray.get(actor.f())), list(range(1, 6))) self.assertEqual(ray.get(actor.g()), list(range(1, 6))) - self.assertEqual(ray.get(actor.h([f.remote(i) for i in range(5)])), list(range(1, 6))) + self.assertEqual(ray.get(actor.h([f.remote(i) for i in range(5)])), + list(range(1, 6))) ray.worker.cleanup() @@ -320,6 +339,7 @@ class ActorNesting(unittest.TestCase): class Actor2(object): def __init__(self, x): self.x = x + def get_value(self): return self.x self.actor2 = Actor2(z) @@ -370,13 +390,15 @@ class ActorNesting(unittest.TestCase): class Actor1(object): def __init__(self, x): self.x = x + def get_value(self): return self.x actor = Actor1(x) return ray.get([actor.get_value() for _ in range(n)]) self.assertEqual(ray.get(f.remote(3, 1)), [3]) - self.assertEqual(ray.get([f.remote(i, 20) for i in range(10)]), [20 * [i] for i in range(10)]) + self.assertEqual(ray.get([f.remote(i, 20) for i in range(10)]), + [20 * [i] for i in range(10)]) ray.worker.cleanup() @@ -421,8 +443,10 @@ class ActorNesting(unittest.TestCase): def __init__(self): # This should use the last version of f. self.x = ray.get(f.remote()) + def get_val(self): return self.x + actor = Actor() return ray.get(actor.get_val()) @@ -430,6 +454,7 @@ class ActorNesting(unittest.TestCase): ray.worker.cleanup() + class ActorInheritance(unittest.TestCase): def testInheritActorFromClass(self): @@ -440,8 +465,10 @@ class ActorInheritance(unittest.TestCase): class Foo(object): def __init__(self, x): self.x = x + def f(self): return self.x + def g(self, y): return self.x + y @@ -449,6 +476,7 @@ class ActorInheritance(unittest.TestCase): class Actor(Foo): def __init__(self, x): Foo.__init__(self, x) + def get_value(self): return self.f() @@ -458,6 +486,7 @@ class ActorInheritance(unittest.TestCase): ray.worker.cleanup() + class ActorSchedulingProperties(unittest.TestCase): def testRemoteFunctionsNotScheduledOnActors(self): @@ -469,7 +498,7 @@ class ActorSchedulingProperties(unittest.TestCase): def __init__(self): pass - actor = Actor() + Actor() @ray.remote def f(): @@ -477,22 +506,26 @@ class ActorSchedulingProperties(unittest.TestCase): # Make sure that f cannot be scheduled on the worker created for the actor. # The wait call should time out. - ready_ids, remaining_ids = ray.wait([f.remote() for _ in range(10)], timeout=3000) + ready_ids, remaining_ids = ray.wait([f.remote() for _ in range(10)], + timeout=3000) self.assertEqual(ready_ids, []) self.assertEqual(len(remaining_ids), 10) ray.worker.cleanup() + class ActorsOnMultipleNodes(unittest.TestCase): def testActorLoadBalancing(self): num_local_schedulers = 3 - ray.worker._init(start_ray_local=True, num_workers=0, num_local_schedulers=num_local_schedulers) + ray.worker._init(start_ray_local=True, num_workers=0, + num_local_schedulers=num_local_schedulers) @ray.actor class Actor1(object): def __init__(self): pass + def get_location(self): return ray.worker.global_worker.plasma_client.store_socket_name @@ -509,7 +542,8 @@ class ActorsOnMultipleNodes(unittest.TestCase): names = set(locations) counts = [locations.count(name) for name in names] print("Counts are {}.".format(counts)) - if len(names) == num_local_schedulers and all([count >= minimum_count for count in counts]): + if len(names) == num_local_schedulers and all([count >= minimum_count + for count in counts]): break attempts += 1 self.assertLess(attempts, num_attempts) @@ -523,26 +557,32 @@ class ActorsOnMultipleNodes(unittest.TestCase): ray.worker.cleanup() + class ActorsWithGPUs(unittest.TestCase): def testActorGPUs(self): num_local_schedulers = 3 num_gpus_per_scheduler = 4 - ray.worker._init(start_ray_local=True, num_workers=0, - num_local_schedulers=num_local_schedulers, - num_gpus=(num_local_schedulers * [num_gpus_per_scheduler])) + ray.worker._init( + start_ray_local=True, num_workers=0, + num_local_schedulers=num_local_schedulers, + num_gpus=(num_local_schedulers * [num_gpus_per_scheduler])) @ray.actor(num_gpus=1) class Actor1(object): def __init__(self): self.gpu_ids = ray.get_gpu_ids() + def get_location_and_ids(self): - return ray.worker.global_worker.plasma_client.store_socket_name, tuple(self.gpu_ids) + return (ray.worker.global_worker.plasma_client.store_socket_name, + tuple(self.gpu_ids)) # Create one actor per GPU. - actors = [Actor1() for _ in range(num_local_schedulers * num_gpus_per_scheduler)] + actors = [Actor1() for _ + in range(num_local_schedulers * num_gpus_per_scheduler)] # Make sure that no two actors are assigned to the same GPU. - locations_and_ids = ray.get([actor.get_location_and_ids() for actor in actors]) + locations_and_ids = ray.get([actor.get_location_and_ids() + for actor in actors]) node_names = set([location for location, gpu_id in locations_and_ids]) self.assertEqual(len(node_names), num_local_schedulers) location_actor_combinations = [] @@ -553,28 +593,32 @@ class ActorsWithGPUs(unittest.TestCase): # Creating a new actor should fail because all of the GPUs are being used. with self.assertRaises(Exception): - a = Actor1() + Actor1() ray.worker.cleanup() def testActorMultipleGPUs(self): num_local_schedulers = 3 num_gpus_per_scheduler = 5 - ray.worker._init(start_ray_local=True, num_workers=0, - num_local_schedulers=num_local_schedulers, - num_gpus=(num_local_schedulers * [num_gpus_per_scheduler])) + ray.worker._init( + start_ray_local=True, num_workers=0, + num_local_schedulers=num_local_schedulers, + num_gpus=(num_local_schedulers * [num_gpus_per_scheduler])) @ray.actor(num_gpus=2) class Actor1(object): def __init__(self): self.gpu_ids = ray.get_gpu_ids() + def get_location_and_ids(self): - return ray.worker.global_worker.plasma_client.store_socket_name, tuple(self.gpu_ids) + return (ray.worker.global_worker.plasma_client.store_socket_name, + tuple(self.gpu_ids)) # Create some actors. actors = [Actor1() for _ in range(num_local_schedulers * 2)] # Make sure that no two actors are assigned to the same GPU. - locations_and_ids = ray.get([actor.get_location_and_ids() for actor in actors]) + locations_and_ids = ray.get([actor.get_location_and_ids() + for actor in actors]) node_names = set([location for location, gpu_id in locations_and_ids]) self.assertEqual(len(node_names), num_local_schedulers) location_actor_combinations = [] @@ -585,20 +629,23 @@ class ActorsWithGPUs(unittest.TestCase): # Creating a new actor should fail because all of the GPUs are being used. with self.assertRaises(Exception): - a = Actor1() + Actor1() # We should be able to create more actors that use only a single GPU. @ray.actor(num_gpus=1) class Actor2(object): def __init__(self): self.gpu_ids = ray.get_gpu_ids() + def get_location_and_ids(self): - return ray.worker.global_worker.plasma_client.store_socket_name, tuple(self.gpu_ids) + return (ray.worker.global_worker.plasma_client.store_socket_name, + tuple(self.gpu_ids)) # Create some actors. actors = [Actor2() for _ in range(num_local_schedulers)] # Make sure that no two actors are assigned to the same GPU. - locations_and_ids = ray.get([actor.get_location_and_ids() for actor in actors]) + locations_and_ids = ray.get([actor.get_location_and_ids() + for actor in actors]) node_names = set([location for location, gpu_id in locations_and_ids]) self.assertEqual(len(node_names), num_local_schedulers) location_actor_combinations = [] @@ -608,13 +655,13 @@ class ActorsWithGPUs(unittest.TestCase): # Creating a new actor should fail because all of the GPUs are being used. with self.assertRaises(Exception): - a = Actor2() + Actor2() ray.worker.cleanup() def testActorDifferentNumbersOfGPUs(self): - # Test that we can create actors on two nodes that have different numbers of - # GPUs. + # Test that we can create actors on two nodes that have different numbers + # of GPUs. ray.worker._init(start_ray_local=True, num_workers=0, num_local_schedulers=3, num_gpus=[0, 5, 10]) @@ -622,32 +669,38 @@ class ActorsWithGPUs(unittest.TestCase): class Actor1(object): def __init__(self): self.gpu_ids = ray.get_gpu_ids() + def get_location_and_ids(self): - return ray.worker.global_worker.plasma_client.store_socket_name, tuple(self.gpu_ids) + return (ray.worker.global_worker.plasma_client.store_socket_name, + tuple(self.gpu_ids)) # Create some actors. actors = [Actor1() for _ in range(0 + 5 + 10)] # Make sure that no two actors are assigned to the same GPU. - locations_and_ids = ray.get([actor.get_location_and_ids() for actor in actors]) + locations_and_ids = ray.get([actor.get_location_and_ids() + for actor in actors]) node_names = set([location for location, gpu_id in locations_and_ids]) self.assertEqual(len(node_names), 2) for node_name in node_names: - node_gpu_ids = [gpu_id for location, gpu_id in locations_and_ids if location == node_name] + node_gpu_ids = [gpu_id for location, gpu_id in locations_and_ids + if location == node_name] self.assertIn(len(node_gpu_ids), [5, 10]) - self.assertEqual(set(node_gpu_ids), set([(i,) for i in range(len(node_gpu_ids))])) + self.assertEqual(set(node_gpu_ids), + set([(i,) for i in range(len(node_gpu_ids))])) # Creating a new actor should fail because all of the GPUs are being used. with self.assertRaises(Exception): - a = Actor1() + Actor1() ray.worker.cleanup() def testActorMultipleGPUsFromMultipleTasks(self): num_local_schedulers = 10 num_gpus_per_scheduler = 10 - ray.worker._init(start_ray_local=True, num_workers=0, - num_local_schedulers=num_local_schedulers, - num_gpus=(num_local_schedulers * [num_gpus_per_scheduler])) + ray.worker._init( + start_ray_local=True, num_workers=0, + num_local_schedulers=num_local_schedulers, + num_gpus=(num_local_schedulers * [num_gpus_per_scheduler])) @ray.remote def create_actors(n): @@ -655,20 +708,25 @@ class ActorsWithGPUs(unittest.TestCase): class Actor(object): def __init__(self): self.gpu_ids = ray.get_gpu_ids() + def get_location_and_ids(self): - return ray.worker.global_worker.plasma_client.store_socket_name, tuple(self.gpu_ids) + return (ray.worker.global_worker.plasma_client.store_socket_name, + tuple(self.gpu_ids)) # Create n actors. for _ in range(n): Actor() - ray.get([create_actors.remote(num_gpus_per_scheduler) for _ in range(num_local_schedulers)]) + ray.get([create_actors.remote(num_gpus_per_scheduler) + for _ in range(num_local_schedulers)]) @ray.actor(num_gpus=1) class Actor(object): def __init__(self): self.gpu_ids = ray.get_gpu_ids() + def get_location_and_ids(self): - return ray.worker.global_worker.plasma_client.store_socket_name, tuple(self.gpu_ids) + return (ray.worker.global_worker.plasma_client.store_socket_name, + tuple(self.gpu_ids)) # All the GPUs should be used up now. with self.assertRaises(Exception): @@ -676,5 +734,6 @@ class ActorsWithGPUs(unittest.TestCase): ray.worker.cleanup() + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/test/array_test.py b/test/array_test.py index 3f0ed0297..bd5366744 100644 --- a/test/array_test.py +++ b/test/array_test.py @@ -5,20 +5,21 @@ from __future__ import print_function import unittest import ray import numpy as np -import time from numpy.testing import assert_equal, assert_almost_equal import sys -if sys.version_info >= (3, 0): - from importlib import reload - import ray.experimental.array.remote as ra import ray.experimental.array.distributed as da +if sys.version_info >= (3, 0): + from importlib import reload + + class RemoteArrayTest(unittest.TestCase): def testMethods(self): - for module in [ra.core, ra.random, ra.linalg, da.core, da.random, da.linalg]: + for module in [ra.core, ra.random, ra.linalg, da.core, da.random, + da.linalg]: reload(module) ray.init(num_workers=1) @@ -49,24 +50,30 @@ class RemoteArrayTest(unittest.TestCase): ray.worker.cleanup() + class DistributedArrayTest(unittest.TestCase): def testAssemble(self): - for module in [ra.core, ra.random, ra.linalg, da.core, da.random, da.linalg]: + for module in [ra.core, ra.random, ra.linalg, da.core, da.random, + da.linalg]: reload(module) ray.init(num_workers=1) a = ra.ones.remote([da.BLOCK_SIZE, da.BLOCK_SIZE]) b = ra.zeros.remote([da.BLOCK_SIZE, da.BLOCK_SIZE]) x = da.DistArray([2 * da.BLOCK_SIZE, da.BLOCK_SIZE], np.array([[a], [b]])) - assert_equal(x.assemble(), np.vstack([np.ones([da.BLOCK_SIZE, da.BLOCK_SIZE]), np.zeros([da.BLOCK_SIZE, da.BLOCK_SIZE])])) + assert_equal(x.assemble(), + np.vstack([np.ones([da.BLOCK_SIZE, da.BLOCK_SIZE]), + np.zeros([da.BLOCK_SIZE, da.BLOCK_SIZE])])) ray.worker.cleanup() def testMethods(self): - for module in [ra.core, ra.random, ra.linalg, da.core, da.random, da.linalg]: + for module in [ra.core, ra.random, ra.linalg, da.core, da.random, + da.linalg]: reload(module) - ray.worker._init(start_ray_local=True, num_workers=10, num_local_schedulers=2, num_cpus=[10, 10]) + ray.worker._init(start_ray_local=True, num_workers=10, + num_local_schedulers=2, num_cpus=[10, 10]) x = da.zeros.remote([9, 25, 51], "float") assert_equal(ray.get(da.assemble.remote(x)), np.zeros([9, 25, 51])) @@ -76,18 +83,21 @@ class DistributedArrayTest(unittest.TestCase): x = da.random.normal.remote([11, 25, 49]) y = da.copy.remote(x) - assert_equal(ray.get(da.assemble.remote(x)), ray.get(da.assemble.remote(y))) + assert_equal(ray.get(da.assemble.remote(x)), + ray.get(da.assemble.remote(y))) x = da.eye.remote(25, dtype_name="float") assert_equal(ray.get(da.assemble.remote(x)), np.eye(25)) x = da.random.normal.remote([25, 49]) y = da.triu.remote(x) - assert_equal(ray.get(da.assemble.remote(y)), np.triu(ray.get(da.assemble.remote(x)))) + assert_equal(ray.get(da.assemble.remote(y)), + np.triu(ray.get(da.assemble.remote(x)))) x = da.random.normal.remote([25, 49]) y = da.tril.remote(x) - assert_equal(ray.get(da.assemble.remote(y)), np.tril(ray.get(da.assemble.remote(x)))) + assert_equal(ray.get(da.assemble.remote(y)), + np.tril(ray.get(da.assemble.remote(x)))) x = da.random.normal.remote([25, 49]) y = da.random.normal.remote([49, 18]) @@ -102,29 +112,37 @@ class DistributedArrayTest(unittest.TestCase): x = da.random.normal.remote([23, 42]) y = da.random.normal.remote([23, 42]) z = da.add.remote(x, y) - assert_almost_equal(ray.get(da.assemble.remote(z)), ray.get(da.assemble.remote(x)) + ray.get(da.assemble.remote(y))) + assert_almost_equal(ray.get(da.assemble.remote(z)), + ray.get(da.assemble.remote(x)) + + ray.get(da.assemble.remote(y))) # test subtract x = da.random.normal.remote([33, 40]) y = da.random.normal.remote([33, 40]) z = da.subtract.remote(x, y) - assert_almost_equal(ray.get(da.assemble.remote(z)), ray.get(da.assemble.remote(x)) - ray.get(da.assemble.remote(y))) + assert_almost_equal(ray.get(da.assemble.remote(z)), + ray.get(da.assemble.remote(x)) - + ray.get(da.assemble.remote(y))) # test transpose x = da.random.normal.remote([234, 432]) y = da.transpose.remote(x) - assert_equal(ray.get(da.assemble.remote(x)).T, ray.get(da.assemble.remote(y))) + assert_equal(ray.get(da.assemble.remote(x)).T, + ray.get(da.assemble.remote(y))) # test numpy_to_dist x = da.random.normal.remote([23, 45]) y = da.assemble.remote(x) z = da.numpy_to_dist.remote(y) w = da.assemble.remote(z) - assert_equal(ray.get(da.assemble.remote(x)), ray.get(da.assemble.remote(z))) + assert_equal(ray.get(da.assemble.remote(x)), + ray.get(da.assemble.remote(z))) assert_equal(ray.get(y), ray.get(w)) # test da.tsqr - for shape in [[123, da.BLOCK_SIZE], [7, da.BLOCK_SIZE], [da.BLOCK_SIZE, da.BLOCK_SIZE], [da.BLOCK_SIZE, 7], [10 * da.BLOCK_SIZE, da.BLOCK_SIZE]]: + for shape in [[123, da.BLOCK_SIZE], [7, da.BLOCK_SIZE], + [da.BLOCK_SIZE, da.BLOCK_SIZE], [da.BLOCK_SIZE, 7], + [10 * da.BLOCK_SIZE, da.BLOCK_SIZE]]: x = da.random.normal.remote(shape) K = min(shape) q, r = da.linalg.tsqr.remote(x) @@ -138,23 +156,26 @@ class DistributedArrayTest(unittest.TestCase): # test da.linalg.modified_lu def test_modified_lu(d1, d2): - print("testing dist_modified_lu with d1 = " + str(d1) + ", d2 = " + str(d2)) + print("testing dist_modified_lu with d1 = " + str(d1) + + ", d2 = " + str(d2)) assert d1 >= d2 - k = min(d1, d2) m = ra.random.normal.remote([d1, d2]) q, r = ra.linalg.qr.remote(m) l, u, s = da.linalg.modified_lu.remote(da.numpy_to_dist.remote(q)) q_val = ray.get(q) - r_val = ray.get(r) + ray.get(r) l_val = ray.get(da.assemble.remote(l)) u_val = ray.get(u) s_val = ray.get(s) s_mat = np.zeros((d1, d2)) for i in range(len(s_val)): s_mat[i, i] = s_val[i] - assert_almost_equal(q_val - s_mat, np.dot(l_val, u_val)) # check that q - s = l * u - assert_equal(np.triu(u_val), u_val) # check that u is upper triangular - assert_equal(np.tril(l_val), l_val) # check that l is lower triangular + # Check that q - s = l * u. + assert_almost_equal(q_val - s_mat, np.dot(l_val, u_val)) + # Check that u is upper triangular. + assert_equal(np.triu(u_val), u_val) + # Check that l is lower triangular. + assert_equal(np.tril(l_val), l_val) for d1, d2 in [(100, 100), (99, 98), (7, 5), (7, 7), (20, 7), (20, 10)]: test_modified_lu(d1, d2) @@ -172,10 +193,14 @@ class DistributedArrayTest(unittest.TestCase): tall_eye = np.zeros((d1, min(d1, d2))) np.fill_diagonal(tall_eye, 1) q = tall_eye - np.dot(y_val, np.dot(t_val, y_top_val.T)) - assert_almost_equal(np.dot(q.T, q), np.eye(min(d1, d2))) # check that q.T * q = I - assert_almost_equal(np.dot(q, r_val), a_val) # check that a = (I - y * t * y_top.T) * r + # Check that q.T * q = I. + assert_almost_equal(np.dot(q.T, q), np.eye(min(d1, d2))) + # Check that a = (I - y * t * y_top.T) * r. + assert_almost_equal(np.dot(q, r_val), a_val) - for d1, d2 in [(123, da.BLOCK_SIZE), (7, da.BLOCK_SIZE), (da.BLOCK_SIZE, da.BLOCK_SIZE), (da.BLOCK_SIZE, 7), (10 * da.BLOCK_SIZE, da.BLOCK_SIZE)]: + for d1, d2 in [(123, da.BLOCK_SIZE), (7, da.BLOCK_SIZE), + (da.BLOCK_SIZE, da.BLOCK_SIZE), (da.BLOCK_SIZE, 7), + (10 * da.BLOCK_SIZE, da.BLOCK_SIZE)]: test_dist_tsqr_hr(d1, d2) def test_dist_qr(d1, d2): @@ -192,7 +217,9 @@ class DistributedArrayTest(unittest.TestCase): assert_equal(r_val, np.triu(r_val)) assert_almost_equal(a_val, np.dot(q_val, r_val)) - for d1, d2 in [(123, da.BLOCK_SIZE), (7, da.BLOCK_SIZE), (da.BLOCK_SIZE, da.BLOCK_SIZE), (da.BLOCK_SIZE, 7), (13, 21), (34, 35), (8, 7)]: + for d1, d2 in [(123, da.BLOCK_SIZE), (7, da.BLOCK_SIZE), + (da.BLOCK_SIZE, da.BLOCK_SIZE), (da.BLOCK_SIZE, 7), + (13, 21), (34, 35), (8, 7)]: test_dist_qr(d1, d2) test_dist_qr(d2, d1) for _ in range(20): @@ -202,5 +229,6 @@ class DistributedArrayTest(unittest.TestCase): ray.worker.cleanup() + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/test/component_failures_test.py b/test/component_failures_test.py index 49fc7d8ef..f52264dcf 100644 --- a/test/component_failures_test.py +++ b/test/component_failures_test.py @@ -3,10 +3,10 @@ from __future__ import division from __future__ import print_function import ray -import sys import time import unittest + class ComponentFailureTest(unittest.TestCase): def tearDown(self): @@ -16,6 +16,7 @@ class ComponentFailureTest(unittest.TestCase): # store and manager will not die. def testDyingWorkerGet(self): obj_id = 20 * b"a" + @ray.remote def f(): ray.worker.global_worker.plasma_client.get(obj_id) @@ -40,12 +41,14 @@ class ComponentFailureTest(unittest.TestCase): time.sleep(0.1) # Make sure that nothing has died. - self.assertTrue(ray.services.all_processes_alive(exclude=[ray.services.PROCESS_TYPE_WORKER])) + self.assertTrue(ray.services.all_processes_alive( + exclude=[ray.services.PROCESS_TYPE_WORKER])) - # This test checks that when a worker dies in the middle of a wait, the plasma - # store and manager will not die. + # This test checks that when a worker dies in the middle of a wait, the + # plasma store and manager will not die. def testDyingWorkerWait(self): obj_id = 20 * b"a" + @ray.remote def f(): ray.worker.global_worker.plasma_client.wait([obj_id]) @@ -70,7 +73,8 @@ class ComponentFailureTest(unittest.TestCase): time.sleep(0.1) # Make sure that nothing has died. - self.assertTrue(ray.services.all_processes_alive(exclude=[ray.services.PROCESS_TYPE_WORKER])) + self.assertTrue(ray.services.all_processes_alive( + exclude=[ray.services.PROCESS_TYPE_WORKER])) def _testWorkerFailed(self, num_local_schedulers): @ray.remote @@ -86,7 +90,8 @@ class ComponentFailureTest(unittest.TestCase): num_cpus=[num_initial_workers] * num_local_schedulers) # Submit more tasks than there are workers so that all workers and cores # are utilized. - object_ids = [f.remote(i) for i in range(num_initial_workers * num_local_schedulers)] + object_ids = [f.remote(i) for i + in range(num_initial_workers * num_local_schedulers)] object_ids += [f.remote(object_id) for object_id in object_ids] # Allow the tasks some time to begin executing. time.sleep(0.1) @@ -94,7 +99,8 @@ class ComponentFailureTest(unittest.TestCase): for worker in ray.services.all_processes[ray.services.PROCESS_TYPE_WORKER]: worker.terminate() time.sleep(0.1) - # Make sure that we can still get the objects after the executing tasks died. + # Make sure that we can still get the objects after the executing tasks + # died. ray.get(object_ids) def testWorkerFailed(self): @@ -104,8 +110,7 @@ class ComponentFailureTest(unittest.TestCase): self._testWorkerFailed(4) def _testComponentFailed(self, component_type): - """Kill a component on all worker nodes and check that workload succeeds. - """ + """Kill a component on all worker nodes and check workload succeeds.""" @ray.remote def f(x, j): time.sleep(0.2) @@ -114,14 +119,16 @@ class ComponentFailureTest(unittest.TestCase): # Start with 4 workers and 4 cores. num_local_schedulers = 4 num_workers_per_scheduler = 8 - address_info = ray.worker._init(num_workers=num_local_schedulers * num_workers_per_scheduler, - num_local_schedulers=num_local_schedulers, - start_ray_local=True, - num_cpus=[num_workers_per_scheduler] * num_local_schedulers) + ray.worker._init( + num_workers=num_local_schedulers * num_workers_per_scheduler, + num_local_schedulers=num_local_schedulers, + start_ray_local=True, + num_cpus=[num_workers_per_scheduler] * num_local_schedulers) - # Submit more tasks than there are workers so that all workers and cores are - # utilized. - object_ids = [f.remote(i, 0) for i in range(num_workers_per_scheduler * num_local_schedulers)] + # Submit more tasks than there are workers so that all workers and cores + # are utilized. + object_ids = [f.remote(i, 0) for i + in range(num_workers_per_scheduler * num_local_schedulers)] object_ids += [f.remote(object_id, 1) for object_id in object_ids] object_ids += [f.remote(object_id, 2) for object_id in object_ids] @@ -140,7 +147,8 @@ class ComponentFailureTest(unittest.TestCase): # Make sure that we can still get the objects after the executing tasks # died. results = ray.get(object_ids) - expected_results = 4 * list(range(num_workers_per_scheduler * num_local_schedulers)) + expected_results = 4 * list(range( + num_workers_per_scheduler * num_local_schedulers)) self.assertEqual(results, expected_results) def check_components_alive(self, component_type, check_component_alive): @@ -161,7 +169,8 @@ class ComponentFailureTest(unittest.TestCase): # nodes. self.check_components_alive(ray.services.PROCESS_TYPE_PLASMA_STORE, True) self.check_components_alive(ray.services.PROCESS_TYPE_PLASMA_MANAGER, True) - self.check_components_alive(ray.services.PROCESS_TYPE_LOCAL_SCHEDULER, False) + self.check_components_alive(ray.services.PROCESS_TYPE_LOCAL_SCHEDULER, + False) def testPlasmaManagerFailed(self): # Kill all plasma managers on worker nodes. @@ -170,8 +179,10 @@ class ComponentFailureTest(unittest.TestCase): # The plasma stores should still be alive (but unreachable) on the worker # nodes. self.check_components_alive(ray.services.PROCESS_TYPE_PLASMA_STORE, True) - self.check_components_alive(ray.services.PROCESS_TYPE_PLASMA_MANAGER, False) - self.check_components_alive(ray.services.PROCESS_TYPE_LOCAL_SCHEDULER, False) + self.check_components_alive(ray.services.PROCESS_TYPE_PLASMA_MANAGER, + False) + self.check_components_alive(ray.services.PROCESS_TYPE_LOCAL_SCHEDULER, + False) def testPlasmaStoreFailed(self): # Kill all plasma stores on worker nodes. @@ -179,17 +190,19 @@ class ComponentFailureTest(unittest.TestCase): # No processes should be left alive on the worker nodes. self.check_components_alive(ray.services.PROCESS_TYPE_PLASMA_STORE, False) - self.check_components_alive(ray.services.PROCESS_TYPE_PLASMA_MANAGER, False) - self.check_components_alive(ray.services.PROCESS_TYPE_LOCAL_SCHEDULER, False) + self.check_components_alive(ray.services.PROCESS_TYPE_PLASMA_MANAGER, + False) + self.check_components_alive(ray.services.PROCESS_TYPE_LOCAL_SCHEDULER, + False) def testDriverLivesSequential(self): ray.worker.init() + all_processes = ray.services.all_processes processes = [ - ray.services.all_processes[ray.services.PROCESS_TYPE_PLASMA_STORE][0], - ray.services.all_processes[ray.services.PROCESS_TYPE_PLASMA_MANAGER][0], - ray.services.all_processes[ray.services.PROCESS_TYPE_LOCAL_SCHEDULER][0], - ray.services.all_processes[ray.services.PROCESS_TYPE_GLOBAL_SCHEDULER][0], - ] + all_processes[ray.services.PROCESS_TYPE_PLASMA_STORE][0], + all_processes[ray.services.PROCESS_TYPE_PLASMA_MANAGER][0], + all_processes[ray.services.PROCESS_TYPE_LOCAL_SCHEDULER][0], + all_processes[ray.services.PROCESS_TYPE_GLOBAL_SCHEDULER][0]] # Kill all the components sequentially. for process in processes: @@ -202,12 +215,12 @@ class ComponentFailureTest(unittest.TestCase): def testDriverLivesParallel(self): ray.worker.init() + all_processes = ray.services.all_processes processes = [ - ray.services.all_processes[ray.services.PROCESS_TYPE_PLASMA_STORE][0], - ray.services.all_processes[ray.services.PROCESS_TYPE_PLASMA_MANAGER][0], - ray.services.all_processes[ray.services.PROCESS_TYPE_LOCAL_SCHEDULER][0], - ray.services.all_processes[ray.services.PROCESS_TYPE_GLOBAL_SCHEDULER][0], - ] + all_processes[ray.services.PROCESS_TYPE_PLASMA_STORE][0], + all_processes[ray.services.PROCESS_TYPE_PLASMA_MANAGER][0], + all_processes[ray.services.PROCESS_TYPE_LOCAL_SCHEDULER][0], + all_processes[ray.services.PROCESS_TYPE_GLOBAL_SCHEDULER][0]] # Kill all the components in parallel. for process in processes: @@ -222,5 +235,6 @@ class ComponentFailureTest(unittest.TestCase): # If the driver can reach the tearDown method, then it is still alive. + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/test/failure_test.py b/test/failure_test.py index e504072ca..d1fdb2554 100644 --- a/test/failure_test.py +++ b/test/failure_test.py @@ -9,14 +9,16 @@ import tempfile import time import unittest +import ray.test.test_functions as test_functions + if sys.version_info >= (3, 0): from importlib import reload -import ray.test.test_functions as test_functions def relevant_errors(error_type): return [info for info in ray.error_info() if info[b"type"] == error_type] + def wait_for_errors(error_type, num_errors, timeout=10): start_time = time.time() while time.time() - start_time < timeout: @@ -25,6 +27,7 @@ def wait_for_errors(error_type, num_errors, timeout=10): time.sleep(0.1) print("Timing out of wait.") + class FailureTest(unittest.TestCase): def testUnknownSerialization(self): reload(test_functions) @@ -32,32 +35,35 @@ class FailureTest(unittest.TestCase): test_functions.test_unknown_type.remote() wait_for_errors(b"task", 1) - error_info = ray.error_info() self.assertEqual(len(relevant_errors(b"task")), 1) ray.worker.cleanup() + class TaskSerializationTest(unittest.TestCase): def testReturnAndPassUnknownType(self): ray.init(num_workers=1, driver_mode=ray.SILENT_MODE) class Foo(object): pass + # Check that returning an unknown type from a remote function raises an # exception. @ray.remote def f(): return Foo() - self.assertRaises(Exception, lambda : ray.get(f.remote())) + self.assertRaises(Exception, lambda: ray.get(f.remote())) + # Check that passing an unknown type into a remote function raises an # exception. @ray.remote def g(x): return 1 - self.assertRaises(Exception, lambda : g.remote(Foo())) + self.assertRaises(Exception, lambda: g.remote(Foo())) ray.worker.cleanup() + class TaskStatusTest(unittest.TestCase): def testFailedTask(self): reload(test_functions) @@ -66,10 +72,10 @@ class TaskStatusTest(unittest.TestCase): test_functions.throw_exception_fct1.remote() test_functions.throw_exception_fct1.remote() wait_for_errors(b"task", 2) - result = ray.error_info() self.assertEqual(len(relevant_errors(b"task")), 2) for task in relevant_errors(b"task"): - self.assertIn(b"Test function 1 intentionally failed.", task.get(b"message")) + self.assertIn(b"Test function 1 intentionally failed.", + task.get(b"message")) x = test_functions.throw_exception_fct2.remote() try: @@ -77,7 +83,8 @@ class TaskStatusTest(unittest.TestCase): except Exception as e: self.assertIn("Test function 2 intentionally failed.", str(e)) else: - self.assertTrue(False) # ray.get should throw an exception + # ray.get should throw an exception. + self.assertTrue(False) x, y, z = test_functions.throw_exception_fct3.remote(1.0) for ref in [x, y, z]: @@ -86,7 +93,8 @@ class TaskStatusTest(unittest.TestCase): except Exception as e: self.assertIn("Test function 3 intentionally failed.", str(e)) else: - self.assertTrue(False) # ray.get should throw an exception + # ray.get should throw an exception. + self.assertTrue(False) ray.worker.cleanup() @@ -108,8 +116,8 @@ def temporary_helper_function(): sys.path.append(directory) module = __import__(module_name) - # Define a function that closes over this temporary module. This should fail - # when it is unpickled. + # Define a function that closes over this temporary module. This should + # fail when it is unpickled. @ray.remote def g(): return module.temporary_python_file() @@ -121,7 +129,7 @@ def temporary_helper_function(): # Check that if we try to call the function it throws an exception and does # not hang. for _ in range(10): - self.assertRaises(Exception, lambda : ray.get(g.remote())) + self.assertRaises(Exception, lambda: ray.get(g.remote())) f.close() @@ -150,16 +158,19 @@ def temporary_helper_function(): def initializer(): return 0 + def reinitializer(foo): raise Exception("The reinitializer failed.") ray.env.foo = ray.EnvironmentVariable(initializer, reinitializer) + @ray.remote def use_foo(): ray.env.foo use_foo.remote() wait_for_errors(b"reinitialize_environment_variable", 1) # Check that the error message is in the task info. - self.assertIn(b"The reinitializer failed.", ray.error_info()[0][b"message"]) + self.assertIn(b"The reinitializer failed.", + ray.error_info()[0][b"message"]) ray.worker.cleanup() @@ -202,6 +213,7 @@ def temporary_helper_function(): class Foo(object): def __init__(self): self.x = module.temporary_python_file() + def get_val(self): return 1 @@ -217,7 +229,8 @@ def temporary_helper_function(): # Wait for the error from when the __init__ tries to run. wait_for_errors(b"task", 1) - self.assertIn(b"failed to be imported, and so cannot execute this method", ray.error_info()[1][b"message"]) + self.assertIn(b"failed to be imported, and so cannot execute this method", + ray.error_info()[1][b"message"]) # Check that if we try to get the function it throws an exception and does # not hang. @@ -226,7 +239,8 @@ def temporary_helper_function(): # Wait for the error from when the call to get_val. wait_for_errors(b"task", 2) - self.assertIn(b"failed to be imported, and so cannot execute this method", ray.error_info()[2][b"message"]) + self.assertIn(b"failed to be imported, and so cannot execute this method", + ray.error_info()[2][b"message"]) f.close() @@ -234,6 +248,7 @@ def temporary_helper_function(): sys.path.pop(-1) ray.worker.cleanup() + class ActorTest(unittest.TestCase): def testFailedActorInit(self): @@ -241,12 +256,15 @@ class ActorTest(unittest.TestCase): error_message1 = "actor constructor failed" error_message2 = "actor method failed" + @ray.actor class FailedActor(object): def __init__(self): raise Exception(error_message1) + def get_val(self): return 1 + def fail_method(self): raise Exception(error_message2) @@ -255,13 +273,15 @@ class ActorTest(unittest.TestCase): # Make sure that we get errors from a failed constructor. wait_for_errors(b"task", 1) self.assertEqual(len(ray.error_info()), 1) - self.assertIn(error_message1, ray.error_info()[0][b"message"].decode("ascii")) + self.assertIn(error_message1, + ray.error_info()[0][b"message"].decode("ascii")) # Make sure that we get errors from a failed method. a.fail_method() wait_for_errors(b"task", 2) self.assertEqual(len(ray.error_info()), 2) - self.assertIn(error_message2, ray.error_info()[1][b"message"].decode("ascii")) + self.assertIn(error_message2, + ray.error_info()[1][b"message"].decode("ascii")) ray.worker.cleanup() @@ -272,6 +292,7 @@ class ActorTest(unittest.TestCase): class Actor(object): def __init__(self, missing_variable_name): pass + def get_val(self, x): pass @@ -284,18 +305,22 @@ class ActorTest(unittest.TestCase): wait_for_errors(b"task", 1) self.assertEqual(len(ray.error_info()), 1) if sys.version_info >= (3, 0): - self.assertIn("missing 1 required", ray.error_info()[0][b"message"].decode("ascii")) + self.assertIn("missing 1 required", + ray.error_info()[0][b"message"].decode("ascii")) else: - self.assertIn("takes exactly 2 arguments", ray.error_info()[0][b"message"].decode("ascii")) + self.assertIn("takes exactly 2 arguments", + ray.error_info()[0][b"message"].decode("ascii")) # Create an actor with too many arguments. a = Actor(1, 2) wait_for_errors(b"task", 2) self.assertEqual(len(ray.error_info()), 2) if sys.version_info >= (3, 0): - self.assertIn("but 3 were given", ray.error_info()[1][b"message"].decode("ascii")) + self.assertIn("but 3 were given", + ray.error_info()[1][b"message"].decode("ascii")) else: - self.assertIn("takes exactly 2 arguments", ray.error_info()[1][b"message"].decode("ascii")) + self.assertIn("takes exactly 2 arguments", + ray.error_info()[1][b"message"].decode("ascii")) # Create an actor the correct number of arguments. a = Actor(1) @@ -305,23 +330,28 @@ class ActorTest(unittest.TestCase): wait_for_errors(b"task", 3) self.assertEqual(len(ray.error_info()), 3) if sys.version_info >= (3, 0): - self.assertIn("missing 1 required", ray.error_info()[2][b"message"].decode("ascii")) + self.assertIn("missing 1 required", + ray.error_info()[2][b"message"].decode("ascii")) else: - self.assertIn("takes exactly 2 arguments", ray.error_info()[2][b"message"].decode("ascii")) + self.assertIn("takes exactly 2 arguments", + ray.error_info()[2][b"message"].decode("ascii")) # Call a method with too many arguments. a.get_val(1, 2) wait_for_errors(b"task", 4) self.assertEqual(len(ray.error_info()), 4) if sys.version_info >= (3, 0): - self.assertIn("but 3 were given", ray.error_info()[3][b"message"].decode("ascii")) + self.assertIn("but 3 were given", + ray.error_info()[3][b"message"].decode("ascii")) else: - self.assertIn("takes exactly 2 arguments", ray.error_info()[3][b"message"].decode("ascii")) + self.assertIn("takes exactly 2 arguments", + ray.error_info()[3][b"message"].decode("ascii")) # Call a method that doesn't exist. with self.assertRaises(AttributeError): a.nonexistent_method() ray.worker.cleanup() + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/test/jenkins_tests/multi_node_docker_test.py b/test/jenkins_tests/multi_node_docker_test.py index 1e406f213..2a26df6c4 100644 --- a/test/jenkins_tests/multi_node_docker_test.py +++ b/test/jenkins_tests/multi_node_docker_test.py @@ -7,7 +7,7 @@ import os import re import subprocess import sys -import time + def wait_for_output(proc): """This is a convenience method to parse a process's stdout and stderr. @@ -19,10 +19,13 @@ def wait_for_output(proc): A tuple of the stdout and stderr of the process as strings. """ stdout_data, stderr_data = proc.communicate() - stdout_data = stdout_data.decode("ascii") if stdout_data is not None else None - stderr_data = stderr_data.decode("ascii") if stderr_data is not None else None + stdout_data = (stdout_data.decode("ascii") if stdout_data is not None + else None) + stderr_data = (stderr_data.decode("ascii") if stderr_data is not None + else None) return stdout_data, stderr_data + class DockerRunner(object): """This class manages the logistics of running multiple nodes in Docker. @@ -34,8 +37,8 @@ class DockerRunner(object): head_container_id: The ID of the docker container that runs the head node. worker_container_ids: A list of the docker container IDs of the Ray worker nodes. - head_container_ip: The IP address of the docker container that runs the head - node. + head_container_ip: The IP address of the docker container that runs the + head node. """ def __init__(self): """Initialize the DockerRunner.""" @@ -47,8 +50,8 @@ class DockerRunner(object): """Parse the docker container ID from stdout_data. Args: - stdout_data: This should be a string with the standard output of a call to - a docker command. + stdout_data: This should be a string with the standard output of a call + to a docker command. Returns: The container ID of the docker container. @@ -70,7 +73,8 @@ class DockerRunner(object): The IP address of the container. """ proc = subprocess.Popen(["docker", "inspect", - "--format={{.NetworkSettings.Networks.bridge.IPAddress}}", + "--format={{.NetworkSettings.Networks.bridge" + ".IPAddress}}", container_id], stdout=subprocess.PIPE, stderr=subprocess.PIPE) stdout_data, _ = wait_for_output(proc) @@ -86,9 +90,10 @@ class DockerRunner(object): """Start the Ray head node inside a docker container.""" mem_arg = ["--memory=" + mem_size] if mem_size else [] shm_arg = ["--shm-size=" + shm_size] if shm_size else [] - volume_arg = ["-v", - "{}:{}".format(os.path.dirname(os.path.realpath(__file__)), - "/ray/test/jenkins_tests")] if development_mode else [] + volume_arg = (["-v", + "{}:{}".format(os.path.dirname(os.path.realpath(__file__)), + "/ray/test/jenkins_tests")] + if development_mode else []) proc = subprocess.Popen(["docker", "run", "-d"] + mem_arg + shm_arg + volume_arg + [docker_image, "/ray/scripts/start_ray.sh", @@ -113,7 +118,8 @@ class DockerRunner(object): proc = subprocess.Popen(["docker", "run", "-d"] + mem_arg + shm_arg + ["--shm-size=" + shm_size, docker_image, "/ray/scripts/start_ray.sh", - "--redis-address={:s}:6379".format(self.head_container_ip)], + "--redis-address={:s}:6379".format( + self.head_container_ip)], stdout=subprocess.PIPE, stderr=subprocess.PIPE) stdout_data, _ = wait_for_output(proc) container_id = self._get_container_id(stdout_data) @@ -136,10 +142,10 @@ class DockerRunner(object): mem_size: The amount of memory to start each docker container with. This will be passed into `docker run` as the --memory flag. If this is None, then no --memory flag will be used. - shm_size: The amount of shared memory to start each docker container with. - This will be passed into `docker run` as the `--shm-size` flag. - num_nodes: The number of nodes to use in the cluster (this counts the head - node as well). + shm_size: The amount of shared memory to start each docker container + with. This will be passed into `docker run` as the `--shm-size` flag. + num_nodes: The number of nodes to use in the cluster (this counts the + head node as well). development_mode: True if you want to mount the local copy of test/jenkins_test on the head node so we can avoid rebuilding docker images during development. @@ -163,7 +169,7 @@ class DockerRunner(object): stdout=subprocess.PIPE, stderr=subprocess.PIPE) stdout_data, _ = wait_for_output(proc) removed_container_id = self._get_container_id(stdout_data) - if not container_id == stopped_container_id: + if not container_id == removed_container_id: raise Exception("Failed to remove container {}.".format(container_id)) print("stop_node", {"container_id": container_id, @@ -202,8 +208,10 @@ class DockerRunner(object): print(stderr_data) return {"success": proc.returncode == 0, "return_code": proc.returncode} + if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Run multinode tests in Docker.") + parser = argparse.ArgumentParser( + description="Run multinode tests in Docker.") parser.add_argument("--docker-image", default="ray-project/deploy", help="docker image") parser.add_argument("--mem-size", help="memory size") diff --git a/test/jenkins_tests/multi_node_tests/test_0.py b/test/jenkins_tests/multi_node_tests/test_0.py index 30e6597ee..6b4d52890 100644 --- a/test/jenkins_tests/multi_node_tests/test_0.py +++ b/test/jenkins_tests/multi_node_tests/test_0.py @@ -3,11 +3,13 @@ import time import ray + @ray.remote def f(): time.sleep(0.1) return ray.services.get_node_ip_address() + if __name__ == "__main__": ray.init(redis_address=os.environ["RAY_REDIS_ADDRESS"]) # Check that tasks are scheduled on all nodes. diff --git a/test/microbenchmarks.py b/test/microbenchmarks.py index 85ec735e4..bb2537e16 100644 --- a/test/microbenchmarks.py +++ b/test/microbenchmarks.py @@ -9,10 +9,11 @@ import sys import time import numpy as np +import ray.test.test_functions as test_functions + if sys.version_info >= (3, 0): from importlib import reload -import ray.test.test_functions as test_functions class MicroBenchmarkTest(unittest.TestCase): @@ -20,7 +21,7 @@ class MicroBenchmarkTest(unittest.TestCase): reload(test_functions) ray.init(num_workers=3) - # measure the time required to submit a remote task to the scheduler + # Measure the time required to submit a remote task to the scheduler. elapsed_times = [] for _ in range(1000): start_time = time.time() @@ -34,9 +35,10 @@ class MicroBenchmarkTest(unittest.TestCase): print(" 90th percentile: {}".format(elapsed_times[900])) print(" 99th percentile: {}".format(elapsed_times[990])) print(" worst: {}".format(elapsed_times[999])) - # average_elapsed_time should be about 0.00038 + # average_elapsed_time should be about 0.00038. - # measure the time required to submit a remote task to the scheduler (where the remote task returns one value) + # Measure the time required to submit a remote task to the scheduler + # (where the remote task returns one value). elapsed_times = [] for _ in range(1000): start_time = time.time() @@ -50,9 +52,10 @@ class MicroBenchmarkTest(unittest.TestCase): print(" 90th percentile: {}".format(elapsed_times[900])) print(" 99th percentile: {}".format(elapsed_times[990])) print(" worst: {}".format(elapsed_times[999])) - # average_elapsed_time should be about 0.001 + # average_elapsed_time should be about 0.001. - # measure the time required to submit a remote task to the scheduler and get the result + # Measure the time required to submit a remote task to the scheduler and + # get the result. elapsed_times = [] for _ in range(1000): start_time = time.time() @@ -62,14 +65,15 @@ class MicroBenchmarkTest(unittest.TestCase): elapsed_times.append(end_time - start_time) elapsed_times = np.sort(elapsed_times) average_elapsed_time = sum(elapsed_times) / 1000 - print("Time required to submit a trivial function call and get the result:") + print("Time required to submit a trivial function call and get the " + "result:") print(" Average: {}".format(average_elapsed_time)) print(" 90th percentile: {}".format(elapsed_times[900])) print(" 99th percentile: {}".format(elapsed_times[990])) print(" worst: {}".format(elapsed_times[999])) - # average_elapsed_time should be about 0.0013 + # average_elapsed_time should be about 0.0013. - # measure the time required to do do a put + # Measure the time required to do do a put. elapsed_times = [] for _ in range(1000): start_time = time.time() @@ -83,7 +87,7 @@ class MicroBenchmarkTest(unittest.TestCase): print(" 90th percentile: {}".format(elapsed_times[900])) print(" 99th percentile: {}".format(elapsed_times[990])) print(" worst: {}".format(elapsed_times[999])) - # average_elapsed_time should be about 0.00087 + # average_elapsed_time should be about 0.00087. ray.worker.cleanup() @@ -105,11 +109,14 @@ class MicroBenchmarkTest(unittest.TestCase): if d > 1.5 * b: if os.getenv("TRAVIS") is None: - raise Exception("The caching test was too slow. d = {}, b = {}".format(d, b)) + raise Exception("The caching test was too slow. " + "d = {}, b = {}".format(d, b)) else: - print("WARNING: The caching test was too slow. d = {}, b = {}".format(d, b)) + print("WARNING: The caching test was too slow. " + "d = {}, b = {}".format(d, b)) ray.worker.cleanup() + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/test/multi_node_test.py b/test/multi_node_test.py index 1dabff337..307ad6932 100644 --- a/test/multi_node_test.py +++ b/test/multi_node_test.py @@ -2,17 +2,18 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import numpy as np import os import unittest import ray import subprocess -import sys import tempfile import time -start_ray_script = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../scripts/start_ray.sh") -stop_ray_script = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../scripts/stop_ray.sh") +start_ray_script = os.path.join(os.path.dirname(os.path.abspath(__file__)), + "../scripts/start_ray.sh") +stop_ray_script = os.path.join(os.path.dirname(os.path.abspath(__file__)), + "../scripts/stop_ray.sh") + class MultiNodeTest(unittest.TestCase): @@ -21,7 +22,8 @@ class MultiNodeTest(unittest.TestCase): out = subprocess.check_output([start_ray_script, "--head"]).decode("ascii") # Get the redis address from the output. redis_substring_prefix = "redis_address=\"" - redis_address_location = out.find(redis_substring_prefix) + len(redis_substring_prefix) + redis_address_location = (out.find(redis_substring_prefix) + + len(redis_substring_prefix)) redis_address = out[redis_address_location:] self.redis_address = redis_address.split("\"")[0] @@ -54,7 +56,8 @@ class MultiNodeTest(unittest.TestCase): # Make sure we got the error. self.assertEqual(len(ray.error_info()), 1) - self.assertIn(error_string1, ray.error_info()[0][b"message"].decode("ascii")) + self.assertIn(error_string1, + ray.error_info()[0][b"message"].decode("ascii")) # Start another driver and make sure that it does not receive this error. # Make the other driver throw an error, and make sure it receives that @@ -98,7 +101,8 @@ print("success") # Make sure that the other error message doesn't show up for this driver. self.assertEqual(len(ray.error_info()), 1) - self.assertIn(error_string1, ray.error_info()[0][b"message"].decode("ascii")) + self.assertIn(error_string1, + ray.error_info()[0][b"message"].decode("ascii")) ray.worker.cleanup() @@ -149,6 +153,7 @@ print("success") ray.worker.cleanup() + class StartRayScriptTest(unittest.TestCase): def testCallingStartRayHead(self): @@ -157,11 +162,12 @@ class StartRayScriptTest(unittest.TestCase): # the non-head node code path. # Test starting Ray with no arguments. - out = subprocess.check_output([start_ray_script, "--head"]).decode("ascii") + subprocess.check_output([start_ray_script, "--head"]).decode("ascii") subprocess.Popen([stop_ray_script]).wait() # Test starting Ray with a number of workers specified. - subprocess.check_output([start_ray_script, "--head", "--num-workers", "20"]) + subprocess.check_output([start_ray_script, "--head", "--num-workers", + "20"]) subprocess.Popen([stop_ray_script]).wait() # Test starting Ray with a redis port specified. @@ -204,5 +210,6 @@ class StartRayScriptTest(unittest.TestCase): "--redis-address", "127.0.0.1:6379"]) subprocess.Popen([stop_ray_script]).wait() + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/test/runtest.py b/test/runtest.py index 413558791..d26e219be 100644 --- a/test/runtest.py +++ b/test/runtest.py @@ -12,24 +12,33 @@ import string import sys from collections import namedtuple +import ray.test.test_functions as test_functions + if sys.version_info >= (3, 0): from importlib import reload -import ray.test.test_functions as test_functions -import ray.experimental.array.remote as ra -import ray.experimental.array.distributed as da def assert_equal(obj1, obj2): - if type(obj1).__module__ == np.__name__ or type(obj2).__module__ == np.__name__: - if (hasattr(obj1, "shape") and obj1.shape == ()) or (hasattr(obj2, "shape") and obj2.shape == ()): + module_numpy = (type(obj1).__module__ == np.__name__ or + type(obj2).__module__ == np.__name__) + if module_numpy: + empty_shape = ((hasattr(obj1, "shape") and obj1.shape == ()) or + (hasattr(obj2, "shape") and obj2.shape == ())) + if empty_shape: # This is a special case because currently np.testing.assert_equal fails # because we do not properly handle different numerical types. - assert obj1 == obj2, "Objects {} and {} are different.".format(obj1, obj2) + assert obj1 == obj2, ("Objects {} and {} are " + "different.".format(obj1, obj2)) else: np.testing.assert_equal(obj1, obj2) elif hasattr(obj1, "__dict__") and hasattr(obj2, "__dict__"): special_keys = ["_pytype_"] - assert set(list(obj1.__dict__.keys()) + special_keys) == set(list(obj2.__dict__.keys()) + special_keys), "Objects {} and {} are different.".format(obj1, obj2) + assert (set(list(obj1.__dict__.keys()) + special_keys) == + set(list(obj2.__dict__.keys()) + special_keys)), ("Objects {} and " + "{} are " + "different." + .format(obj1, + obj2)) for key in obj1.__dict__.keys(): if key not in special_keys: assert_equal(obj1.__dict__[key], obj2.__dict__[key]) @@ -38,24 +47,29 @@ def assert_equal(obj1, obj2): for key in obj1.keys(): assert_equal(obj1[key], obj2[key]) elif type(obj1) is list or type(obj2) is list: - assert len(obj1) == len(obj2), "Objects {} and {} are lists with different lengths.".format(obj1, obj2) + assert len(obj1) == len(obj2), ("Objects {} and {} are lists with " + "different lengths.".format(obj1, obj2)) for i in range(len(obj1)): assert_equal(obj1[i], obj2[i]) elif type(obj1) is tuple or type(obj2) is tuple: - assert len(obj1) == len(obj2), "Objects {} and {} are tuples with different lengths.".format(obj1, obj2) + assert len(obj1) == len(obj2), ("Objects {} and {} are tuples with " + "different lengths.".format(obj1, obj2)) for i in range(len(obj1)): assert_equal(obj1[i], obj2[i]) - elif ray.serialization.is_named_tuple(type(obj1)) or ray.serialization.is_named_tuple(type(obj2)): - assert len(obj1) == len(obj2), "Objects {} and {} are named tuples with different lengths.".format(obj1, obj2) + elif (ray.serialization.is_named_tuple(type(obj1)) or + ray.serialization.is_named_tuple(type(obj2))): + assert len(obj1) == len(obj2), ("Objects {} and {} are named tuples with " + "different lengths.".format(obj1, obj2)) for i in range(len(obj1)): assert_equal(obj1[i], obj2[i]) else: assert obj1 == obj2, "Objects {} and {} are different.".format(obj1, obj2) + if sys.version_info >= (3, 0): long_extras = [0, np.array([["hi", u"hi"], [1.3, 1]])] else: - long_extras = [long(0), np.array([["hi", u"hi"], [1.3, long(1)]])] + long_extras = [long(0), np.array([["hi", u"hi"], [1.3, long(1)]])] # noqa: E501,F821 PRIMITIVE_OBJECTS = [0, 0.0, 0.9, 1 << 62, "a", string.printable, "\u262F", u"hello world", u"\xff\xfe\x9c\x001\x000\x00", None, True, @@ -65,45 +79,55 @@ PRIMITIVE_OBJECTS = [0, 0.0, 0.9, 1 << 62, "a", string.printable, "\u262F", np.random.normal(size=[100, 100]), np.array(["hi", 3]), np.array(["hi", 3], dtype=object)] + long_extras -COMPLEX_OBJECTS = [[[[[[[[[[[[[]]]]]]]]]]]], - {"obj{}".format(i): np.random.normal(size=[100, 100]) for i in range(10)}, - #{(): {(): {(): {(): {(): {(): {(): {(): {(): {(): {(): {(): {}}}}}}}}}}}}}, - ((((((((((),),),),),),),),),), - {"a": {"b": {"c": {"d": {}}}}} - ] +COMPLEX_OBJECTS = [ + [[[[[[[[[[[[]]]]]]]]]]]], + {"obj{}".format(i): np.random.normal(size=[100, 100]) for i in range(10)}, + # {(): {(): {(): {(): {(): {(): {(): {(): {(): {(): { + # (): {(): {}}}}}}}}}}}}}, + ((((((((((),),),),),),),),),), + {"a": {"b": {"c": {"d": {}}}}}] + class Foo(object): def __init__(self): pass + class Bar(object): def __init__(self): for i, val in enumerate(PRIMITIVE_OBJECTS + COMPLEX_OBJECTS): setattr(self, "field{}".format(i), val) + class Baz(object): def __init__(self): self.foo = Foo() self.bar = Bar() + def method(self, arg): pass + class Qux(object): def __init__(self): self.objs = [Foo(), Bar(), Baz()] + class SubQux(Qux): def __init__(self): Qux.__init__(self) + class CustomError(Exception): pass + Point = namedtuple("Point", ["x", "y"]) -NamedTupleExample = namedtuple("Example", "field1, field2, field3, field4, field5") +NamedTupleExample = namedtuple("Example", + "field1, field2, field3, field4, field5") CUSTOM_OBJECTS = [Exception("Test object."), CustomError(), Point(11, y=22), - Foo(), Bar(), Baz(), # Qux(), SubQux(), + Foo(), Bar(), Baz(), # Qux(), SubQux(), NamedTupleExample(1, 1.0, "hi", np.zeros([3, 5]), [1, 2, 3])] BASE_OBJECTS = PRIMITIVE_OBJECTS + COMPLEX_OBJECTS + CUSTOM_OBJECTS @@ -112,8 +136,9 @@ LIST_OBJECTS = [[obj] for obj in BASE_OBJECTS] TUPLE_OBJECTS = [(obj,) for obj in BASE_OBJECTS] # The check that type(obj).__module__ != "numpy" should be unnecessary, but # otherwise this seems to fail on Mac OS X on Travis. -DICT_OBJECTS = ([{obj: obj} for obj in PRIMITIVE_OBJECTS if obj.__hash__ is not None and type(obj).__module__ != "numpy"] + -# DICT_OBJECTS = ([{obj: obj} for obj in BASE_OBJECTS if obj.__hash__ is not None] + +DICT_OBJECTS = ([{obj: obj} for obj in PRIMITIVE_OBJECTS + if (obj.__hash__ is not None and + type(obj).__module__ != "numpy")] + [{0: obj} for obj in BASE_OBJECTS]) RAY_TEST_OBJECTS = BASE_OBJECTS + LIST_OBJECTS + TUPLE_OBJECTS + DICT_OBJECTS @@ -124,7 +149,10 @@ try: cloudpickle.dumps(Point) except AttributeError: cloudpickle_command = "pip install --upgrade cloudpickle" - raise Exception("You have an older version of cloudpickle that is not able to serialize namedtuples. Try running \n\n{}\n\n".format(cloudpickle_command)) + raise Exception("You have an older version of cloudpickle that is not able " + "to serialize namedtuples. Try running " + "\n\n{}\n\n".format(cloudpickle_command)) + class SerializationTest(unittest.TestCase): @@ -155,7 +183,7 @@ class SerializationTest(unittest.TestCase): # Check that exceptions are thrown when we serialize the recursive objects. for obj in recursive_objects: - self.assertRaises(Exception, lambda : ray.put(obj)) + self.assertRaises(Exception, lambda: ray.put(obj)) ray.worker.cleanup() @@ -181,6 +209,7 @@ class SerializationTest(unittest.TestCase): ray.worker.cleanup() + class WorkerTest(unittest.TestCase): def testPythonWorkers(self): @@ -228,6 +257,7 @@ class WorkerTest(unittest.TestCase): ray.worker.cleanup() + class APITest(unittest.TestCase): def testRegisterClass(self): @@ -237,10 +267,10 @@ class APITest(unittest.TestCase): # throws an exception. class TempClass(object): pass - self.assertRaises(Exception, lambda : ray.put(Foo)) + self.assertRaises(Exception, lambda: ray.put(Foo)) # Check that registering a class that Ray cannot serialize efficiently # raises an exception. - self.assertRaises(Exception, lambda : ray.register_class(type(True))) + self.assertRaises(Exception, lambda: ray.register_class(type(True))) # Check that registering the same class with pickle works. ray.register_class(type(float), pickle=True) self.assertEqual(ray.get(ray.put(float)), float) @@ -328,7 +358,9 @@ class APITest(unittest.TestCase): print("Still using old definition of f, trying again.") # Test that we can close over plain old data. - data = [np.zeros([3, 5]), (1, 2, "a"), [0.0, 1.0, 1 << 62], 1 << 60, {"a": np.zeros(3)}] + data = [np.zeros([3, 5]), (1, 2, "a"), [0.0, 1.0, 1 << 62], 1 << 60, + {"a": np.zeros(3)}] + @ray.remote def g(): return data @@ -339,18 +371,22 @@ class APITest(unittest.TestCase): def h(): return np.zeros([3, 5]) assert_equal(ray.get(h.remote()), np.zeros([3, 5])) + @ray.remote def j(): return time.time() ray.get(j.remote()) - # Test that we can define remote functions that call other remote functions. + # Test that we can define remote functions that call other remote + # functions. @ray.remote def k(x): return x + 1 + @ray.remote def l(x): return ray.get(k.remote(x)) + @ray.remote def m(x): return ray.get(l.remote(x)) @@ -398,7 +434,7 @@ class APITest(unittest.TestCase): # Verify that calling wait with duplicate object IDs throws an exception. x = ray.put(1) - self.assertRaises(Exception, lambda : ray.wait([x, x])) + self.assertRaises(Exception, lambda: ray.wait([x, x])) ray.worker.cleanup() @@ -435,11 +471,14 @@ class APITest(unittest.TestCase): ray.worker.cleanup() def testCachingEnvironmentVariables(self): - # Test that we can define environment variables before the driver is connected. + # Test that we can define environment variables before the driver is + # connected. def foo_initializer(): return 1 + def bar_initializer(): return [] + def bar_reinitializer(bar): return [] ray.env.foo = ray.EnvironmentVariable(foo_initializer) @@ -448,6 +487,7 @@ class APITest(unittest.TestCase): @ray.remote def use_foo(): return ray.env.foo + @ray.remote def use_bar(): ray.env.bar.append(1) @@ -463,16 +503,20 @@ class APITest(unittest.TestCase): ray.worker.cleanup() def testCachingFunctionsToRun(self): - # Test that we export functions to run on all workers before the driver is connected. + # Test that we export functions to run on all workers before the driver is + # connected. def f(worker_info): sys.path.append(1) ray.worker.global_worker.run_function_on_all_workers(f) + def f(worker_info): sys.path.append(2) ray.worker.global_worker.run_function_on_all_workers(f) + def g(worker_info): sys.path.append(3) ray.worker.global_worker.run_function_on_all_workers(g) + def f(worker_info): sys.path.append(4) ray.worker.global_worker.run_function_on_all_workers(f) @@ -505,13 +549,16 @@ class APITest(unittest.TestCase): def f(worker_info): sys.path.append("fake_directory") ray.worker.global_worker.run_function_on_all_workers(f) + @ray.remote def get_path1(): return sys.path self.assertEqual("fake_directory", ray.get(get_path1.remote())[-1]) + def f(worker_info): sys.path.pop(-1) ray.worker.global_worker.run_function_on_all_workers(f) + # Create a second remote function to guarantee that when we call # get_path2.remote(), the second function to run will have been run on the # worker. @@ -528,6 +575,7 @@ class APITest(unittest.TestCase): def f(worker_info): sys.path.append(worker_info) ray.worker.global_worker.run_function_on_all_workers(f) + @ray.remote def get_path(): time.sleep(1) @@ -542,6 +590,7 @@ class APITest(unittest.TestCase): counters = [worker_info["counter"] for worker_info in worker_infos] # We use range(11) because the driver also runs the function. self.assertEqual(set(counters), set(range(11))) + # Clean up the worker paths. def f(worker_info): sys.path.pop(-1) @@ -555,7 +604,8 @@ class APITest(unittest.TestCase): def events(): # This is a hack for getting the event log. It is not part of the API. keys = ray.worker.global_worker.redis_client.keys("event_log:*") - return [ray.worker.global_worker.redis_client.lrange(key, 0, -1) for key in keys] + return [ray.worker.global_worker.redis_client.lrange(key, 0, -1) + for key in keys] def wait_for_num_events(num_events, timeout=10): start_time = time.time() @@ -604,25 +654,28 @@ class APITest(unittest.TestCase): # accidentally call an older version. ray.init(num_workers=2) - num_remote_functions = 100 num_calls = 200 @ray.remote def f(): return 1 results1 = [f.remote() for _ in range(num_calls)] + @ray.remote def f(): return 2 results2 = [f.remote() for _ in range(num_calls)] + @ray.remote def f(): return 3 results3 = [f.remote() for _ in range(num_calls)] + @ray.remote def f(): return 4 results4 = [f.remote() for _ in range(num_calls)] + @ray.remote def f(): return 5 @@ -637,16 +690,20 @@ class APITest(unittest.TestCase): @ray.remote def g(): return 1 - @ray.remote + + @ray.remote # noqa: F811 def g(): return 2 - @ray.remote + + @ray.remote # noqa: F811 def g(): return 3 - @ray.remote + + @ray.remote # noqa: F811 def g(): return 4 - @ray.remote + + @ray.remote # noqa: F811 def g(): return 5 @@ -668,6 +725,7 @@ class APITest(unittest.TestCase): ray.worker.cleanup() + class PythonModeTest(unittest.TestCase): def testPythonMode(self): @@ -678,17 +736,21 @@ class PythonModeTest(unittest.TestCase): def f(): return np.ones([3, 4, 5]) xref = f.remote() - assert_equal(xref, np.ones([3, 4, 5])) # remote functions should return by value - assert_equal(xref, ray.get(xref)) # ray.get should be the identity + # Remote functions should return by value. + assert_equal(xref, np.ones([3, 4, 5])) + # Check that ray.get is the identity. + assert_equal(xref, ray.get(xref)) y = np.random.normal(size=[11, 12]) - assert_equal(y, ray.put(y)) # ray.put should be the identity + # Check that ray.put is the identity. + assert_equal(y, ray.put(y)) - # make sure objects are immutable, this example is why we need to copy + # Make sure objects are immutable, this example is why we need to copy # arguments before passing them into remote functions in python mode aref = test_functions.python_mode_f.remote() assert_equal(aref, np.array([0, 0])) bref = test_functions.python_mode_g.remote(aref) - assert_equal(aref, np.array([0, 0])) # python_mode_g should not mutate aref + # Make sure python_mode_g does not mutate aref. + assert_equal(aref, np.array([0, 0])) assert_equal(bref, np.array([1, 0])) ray.worker.cleanup() @@ -699,6 +761,7 @@ class PythonModeTest(unittest.TestCase): def l_init(): return [] + def l_reinit(l): return [] ray.env.l = ray.EnvironmentVariable(l_init, l_reinit) @@ -717,7 +780,8 @@ class PythonModeTest(unittest.TestCase): assert_equal(ray.get(use_l.remote()), [1]) assert_equal(ray.get(use_l.remote()), [1]) - # Make sure the local copy of the environment variable has not been mutated. + # Make sure the local copy of the environment variable has not been + # mutated. assert_equal(l, []) l = ray.env.l assert_equal(l, []) @@ -730,6 +794,7 @@ class PythonModeTest(unittest.TestCase): ray.worker.cleanup() + class EnvironmentVariablesTest(unittest.TestCase): def testEnvironmentVariables(self): @@ -739,6 +804,7 @@ class EnvironmentVariablesTest(unittest.TestCase): def foo_initializer(): return 1 + def foo_reinitializer(foo): return foo @@ -752,7 +818,8 @@ class EnvironmentVariablesTest(unittest.TestCase): self.assertEqual(ray.get(use_foo.remote()), 1) self.assertEqual(ray.get(use_foo.remote()), 1) - # Test that we can add a variable to the key-value store, mutate it, and reset it. + # Test that we can add a variable to the key-value store, mutate it, and + # reset it. def bar_initializer(): return [1, 2, 3] @@ -771,6 +838,7 @@ class EnvironmentVariablesTest(unittest.TestCase): def baz_initializer(): return np.zeros([4]) + def baz_reinitializer(baz): for i in range(len(baz)): baz[i] = 0 @@ -794,6 +862,7 @@ class EnvironmentVariablesTest(unittest.TestCase): def qux_initializer(): return 0 + def qux_reinitializer(x): return x + 1 @@ -815,6 +884,7 @@ class EnvironmentVariablesTest(unittest.TestCase): def foo_initializer(): return [] + def foo_reinitializer(foo): return [] @@ -846,6 +916,7 @@ class EnvironmentVariablesTest(unittest.TestCase): ray.worker.cleanup() + class UtilsTest(unittest.TestCase): def testCopyingDirectory(self): @@ -894,6 +965,7 @@ class UtilsTest(unittest.TestCase): ray.worker.cleanup() + class ResourcesTest(unittest.TestCase): def testResourceConstraints(self): @@ -901,13 +973,16 @@ class ResourcesTest(unittest.TestCase): ray.init(num_workers=num_workers, num_cpus=10, num_gpus=2) # Attempt to wait for all of the workers to start up. - ray.worker.global_worker.run_function_on_all_workers(lambda worker_info: sys.path.append(worker_info["counter"])) + ray.worker.global_worker.run_function_on_all_workers( + lambda worker_info: sys.path.append(worker_info["counter"])) + @ray.remote(num_cpus=0) def get_worker_id(): time.sleep(1) return sys.path[-1] while True: - if len(set(ray.get([get_worker_id.remote() for _ in range(num_workers)]))) == num_workers: + if len(set(ray.get([get_worker_id.remote() + for _ in range(num_workers)]))) == num_workers: break time_buffer = 0.3 @@ -974,13 +1049,16 @@ class ResourcesTest(unittest.TestCase): ray.init(num_workers=num_workers, num_cpus=10, num_gpus=10) # Attempt to wait for all of the workers to start up. - ray.worker.global_worker.run_function_on_all_workers(lambda worker_info: sys.path.append(worker_info["counter"])) + ray.worker.global_worker.run_function_on_all_workers( + lambda worker_info: sys.path.append(worker_info["counter"])) + @ray.remote(num_cpus=0) def get_worker_id(): time.sleep(1) return sys.path[-1] while True: - if len(set(ray.get([get_worker_id.remote() for _ in range(num_workers)]))) == num_workers: + if len(set(ray.get([get_worker_id.remote() + for _ in range(num_workers)]))) == num_workers: break @ray.remote(num_cpus=1, num_gpus=9) @@ -1021,8 +1099,8 @@ class ResourcesTest(unittest.TestCase): def testMultipleLocalSchedulers(self): # This test will define a bunch of tasks that can only be assigned to - # specific local schedulers, and we will check that they are assigned to the - # correct local schedulers. + # specific local schedulers, and we will check that they are assigned to + # the correct local schedulers. address_info = ray.worker._init(start_ray_local=True, num_local_schedulers=3, num_cpus=[100, 5, 10], @@ -1088,7 +1166,8 @@ class ResourcesTest(unittest.TestCase): results.append(run_on_0_2.remote()) return names, results - store_names = [object_store_address.name for object_store_address in address_info["object_store_addresses"]] + store_names = [object_store_address.name for object_store_address + in address_info["object_store_addresses"]] def validate_names_and_results(names, results): for name, result in zip(names, ray.get(results)): @@ -1099,7 +1178,8 @@ class ResourcesTest(unittest.TestCase): elif name == "run_on_2": self.assertIn(result, [store_names[2]]) elif name == "run_on_0_1_2": - self.assertIn(result, [store_names[0], store_names[1], store_names[2]]) + self.assertIn(result, [store_names[0], store_names[1], + store_names[2]]) elif name == "run_on_1_2": self.assertIn(result, [store_names[1], store_names[2]]) elif name == "run_on_0_2": @@ -1128,6 +1208,7 @@ class ResourcesTest(unittest.TestCase): ray.worker.cleanup() + class WorkerPoolTests(unittest.TestCase): def tearDown(self): @@ -1177,6 +1258,7 @@ class WorkerPoolTests(unittest.TestCase): ray.worker.cleanup() + class SchedulingAlgorithm(unittest.TestCase): def attempt_to_load_balance(self, remote_function, args, total_tasks, @@ -1184,21 +1266,24 @@ class SchedulingAlgorithm(unittest.TestCase): num_attempts=20): attempts = 0 while attempts < num_attempts: - locations = ray.get([remote_function.remote(*args) for _ in range(total_tasks)]) + locations = ray.get([remote_function.remote(*args) + for _ in range(total_tasks)]) names = set(locations) counts = [locations.count(name) for name in names] print("Counts are {}.".format(counts)) - if len(names) == num_local_schedulers and all([count >= minimum_count for count in counts]): + if len(names) == num_local_schedulers and all([count >= minimum_count + for count in counts]): break attempts += 1 self.assertLess(attempts, num_attempts) def testLoadBalancing(self): - # This test ensures that tasks are being assigned to all local schedulers in - # a roughly equal manner. + # This test ensures that tasks are being assigned to all local schedulers + # in a roughly equal manner. num_workers = 21 num_local_schedulers = 3 - ray.worker._init(start_ray_local=True, num_workers=num_workers, num_local_schedulers=num_local_schedulers) + ray.worker._init(start_ray_local=True, num_workers=num_workers, + num_local_schedulers=num_local_schedulers) @ray.remote def f(): @@ -1211,11 +1296,12 @@ class SchedulingAlgorithm(unittest.TestCase): ray.worker.cleanup() def testLoadBalancingWithDependencies(self): - # This test ensures that tasks are being assigned to all local schedulers in - # a roughly equal manner even when the tasks have dependencies. + # This test ensures that tasks are being assigned to all local schedulers + # in a roughly equal manner even when the tasks have dependencies. num_workers = 3 num_local_schedulers = 3 - ray.worker._init(start_ray_local=True, num_workers=num_workers, num_local_schedulers=num_local_schedulers) + ray.worker._init(start_ray_local=True, num_workers=num_workers, + num_local_schedulers=num_local_schedulers) @ray.remote def f(x): @@ -1229,5 +1315,6 @@ class SchedulingAlgorithm(unittest.TestCase): ray.worker.cleanup() + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/test/stress_tests.py b/test/stress_tests.py index 04bbe5f40..48de88a24 100644 --- a/test/stress_tests.py +++ b/test/stress_tests.py @@ -11,6 +11,7 @@ import redis # Import flatbuffer bindings. from ray.core.generated.TaskReply import TaskReply + class TaskTests(unittest.TestCase): def testSubmittingTasks(self): @@ -93,7 +94,7 @@ class TaskTests(unittest.TestCase): def f(): return 1 - n = 10 ** 4 # TODO(pcm): replace by 10 ** 5 once this is faster + n = 10 ** 4 # TODO(pcm): replace by 10 ** 5 once this is faster. l = ray.get([f.remote() for _ in range(n)]) self.assertEqual(l, n * [1]) @@ -123,12 +124,14 @@ class TaskTests(unittest.TestCase): time.sleep(x) for i in range(1, 5): - x_ids = [g.remote(np.random.uniform(0, i)) for _ in range(2 * num_workers)] + x_ids = [g.remote(np.random.uniform(0, i)) + for _ in range(2 * num_workers)] ray.wait(x_ids, num_returns=len(x_ids)) self.assertTrue(ray.services.all_processes_alive()) ray.worker.cleanup() + class ReconstructionTests(unittest.TestCase): num_local_schedulers = 1 @@ -144,14 +147,10 @@ class ReconstructionTests(unittest.TestCase): plasma_addresses = [] objstore_memory = (self.plasma_store_memory // self.num_local_schedulers) for i in range(self.num_local_schedulers): - plasma_addresses.append( - ray.services.start_objstore(node_ip_address, redis_address, - objstore_memory=objstore_memory) - ) - address_info = { - "redis_address": redis_address, - "object_store_addresses": plasma_addresses, - } + plasma_addresses.append(ray.services.start_objstore( + node_ip_address, redis_address, objstore_memory=objstore_memory)) + address_info = {"redis_address": redis_address, + "object_store_addresses": plasma_addresses} # Start the rest of the services in the Ray cluster. ray.worker._init(address_info=address_info, start_ray_local=True, @@ -180,7 +179,8 @@ class ReconstructionTests(unittest.TestCase): # total number of local schedulers to account for NIL_LOCAL_SCHEDULER_ID. # This is the local scheduler ID associated with the driver task, since it # is not scheduled by a particular local scheduler. - self.assertEqual(len(set(local_scheduler_ids)), self.num_local_schedulers + 1) + self.assertEqual(len(set(local_scheduler_ids)), + self.num_local_schedulers + 1) # Clean up the Ray cluster. ray.worker.cleanup() @@ -218,7 +218,7 @@ class ReconstructionTests(unittest.TestCase): num_chunks = 4 * self.num_local_schedulers chunk = num_objects // num_chunks for i in range(num_chunks): - values = ray.get(args[i * chunk : (i + 1) * chunk]) + values = ray.get(args[i * chunk:(i + 1) * chunk]) del values def testRecursive(self): @@ -261,14 +261,14 @@ class ReconstructionTests(unittest.TestCase): self.assertEqual(value[0], i) # Get 10 values randomly. for _ in range(10): - i = np.random.randint(num_objects) + i = np.random.randint(num_objects) value = ray.get(args[i]) self.assertEqual(value[0], i) # Get values sequentially, in chunks. num_chunks = 4 * self.num_local_schedulers chunk = num_objects // num_chunks for i in range(num_chunks): - values = ray.get(args[i * chunk : (i + 1) * chunk]) + values = ray.get(args[i * chunk:(i + 1) * chunk]) del values def testMultipleRecursive(self): @@ -316,7 +316,7 @@ class ReconstructionTests(unittest.TestCase): self.assertEqual(value[0], i) # Get 10 values randomly. for _ in range(10): - i = np.random.randint(num_objects) + i = np.random.randint(num_objects) value = ray.get(args[i]) self.assertEqual(value[0], i) @@ -391,7 +391,8 @@ class ReconstructionTests(unittest.TestCase): return len(errors) >= min_errors errors = self.wait_for_errors(error_check) # Make sure all the errors have the correct type. - self.assertTrue(all(error[b"type"] == b"object_hash_mismatch" for error in errors)) + self.assertTrue(all(error[b"type"] == b"object_hash_mismatch" + for error in errors)) # Make sure all the errors have the correct function name. self.assertTrue(all(error[b"data"] == b"__main__.foo" for error in errors)) @@ -462,20 +463,26 @@ class ReconstructionTests(unittest.TestCase): self.assertEqual(value[0], i) put_arg_task.remote(size) + def error_check(errors): return len(errors) > 1 errors = self.wait_for_errors(error_check) # Make sure all the errors have the correct type. - self.assertTrue(all(error[b"type"] == b"put_reconstruction" for error in errors)) - self.assertTrue(all(error[b"data"] == b"__main__.put_arg_task" for error in errors)) + self.assertTrue(all(error[b"type"] == b"put_reconstruction" + for error in errors)) + self.assertTrue(all(error[b"data"] == b"__main__.put_arg_task" + for error in errors)) put_task.remote(size) + def error_check(errors): return any(error[b"data"] == b"__main__.put_task" for error in errors) errors = self.wait_for_errors(error_check) # Make sure all the errors have the correct type. - self.assertTrue(all(error[b"type"] == b"put_reconstruction" for error in errors)) - self.assertTrue(any(error[b"data"] == b"__main__.put_task" for error in errors)) + self.assertTrue(all(error[b"type"] == b"put_reconstruction" + for error in errors)) + self.assertTrue(any(error[b"data"] == b"__main__.put_task" + for error in errors)) def testDriverPutErrors(self): # Define the size of one task's return argument so that the combined sum of @@ -511,11 +518,14 @@ class ReconstructionTests(unittest.TestCase): # were evicted and whose originating tasks are still running, this # for-loop should hang on its first iteration and push an error to the # driver. - ray.worker.global_worker.local_scheduler_client.reconstruct_object(args[0].id()) + ray.worker.global_worker.local_scheduler_client.reconstruct_object( + args[0].id()) + def error_check(errors): return len(errors) > 1 errors = self.wait_for_errors(error_check) - self.assertTrue(all(error[b"type"] == b"put_reconstruction" for error in errors)) + self.assertTrue(all(error[b"type"] == b"put_reconstruction" + for error in errors)) self.assertTrue(all(error[b"data"] == b"Driver" for error in errors)) @@ -526,26 +536,27 @@ class ReconstructionTestsMultinode(ReconstructionTests): num_local_schedulers = 4 # NOTE(swang): This test tries to launch 1000 workers and breaks. -#class WorkerPoolTests(unittest.TestCase): +# class WorkerPoolTests(unittest.TestCase): # -# def tearDown(self): -# ray.worker.cleanup() +# def tearDown(self): +# ray.worker.cleanup() # -# def testBlockingTasks(self): -# @ray.remote -# def f(i, j): -# return (i, j) +# def testBlockingTasks(self): +# @ray.remote +# def f(i, j): +# return (i, j) # -# @ray.remote -# def g(i): -# # Each instance of g submits and blocks on the result of another remote -# # task. -# object_ids = [f.remote(i, j) for j in range(10)] -# return ray.get(object_ids) +# @ray.remote +# def g(i): +# # Each instance of g submits and blocks on the result of another remote +# # task. +# object_ids = [f.remote(i, j) for j in range(10)] +# return ray.get(object_ids) # -# ray.init(num_workers=1) -# ray.get([g.remote(i) for i in range(1000)]) -# ray.worker.cleanup() +# ray.init(num_workers=1) +# ray.get([g.remote(i) for i in range(1000)]) +# ray.worker.cleanup() + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/test/tensorflow_test.py b/test/tensorflow_test.py index f5c0b05e4..3cf101247 100644 --- a/test/tensorflow_test.py +++ b/test/tensorflow_test.py @@ -2,11 +2,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import unittest -import uuid -import tensorflow as tf -import ray from numpy.testing import assert_almost_equal +import tensorflow as tf +import unittest + +import ray + def make_linear_network(w_name=None, b_name=None): # Define the inputs. @@ -17,7 +18,9 @@ def make_linear_network(w_name=None, b_name=None): b = tf.Variable(tf.zeros([1]), name=b_name) y = w * x_data + b # Return the loss and weight initializer. - return tf.reduce_mean(tf.square(y - y_data)), tf.global_variables_initializer(), x_data, y_data + return (tf.reduce_mean(tf.square(y - y_data)), + tf.global_variables_initializer(), x_data, y_data) + class NetActor(object): @@ -40,6 +43,7 @@ class NetActor(object): def get_weights(self): return self.values[0].get_weights() + class TrainActor(object): def __init__(self): @@ -57,11 +61,13 @@ class TrainActor(object): def training_step(self, weights): _, variables, _, sess, grads, _, placeholders = self.values variables.set_weights(weights) - return sess.run([grad[0] for grad in grads], feed_dict=dict(zip(placeholders, [[1]*100, [2]*100]))) + return sess.run([grad[0] for grad in grads], + feed_dict=dict(zip(placeholders, [[1] * 100, [2] * 100]))) def get_weights(self): return self.values[1].get_weights() + class TensorFlowTest(unittest.TestCase): def testTensorFlowVariables(self): @@ -113,9 +119,6 @@ class TensorFlowTest(unittest.TestCase): net1 = NetActor() net2 = NetActor() - net_vars1, init1, sess1 = net1.values - net_vars2, init2, sess2 = net2.values - # This is checking that the variable names of the two nets are the same, # i.e. that the names in the weight dictionaries are the same net1.values[0].set_weights(net2.values[0].get_weights()) @@ -125,7 +128,8 @@ class TensorFlowTest(unittest.TestCase): # Test that different networks on the same worker are independent and # we can get/set their weights without any interaction. def testNetworksIndependent(self): - # Note we use only one worker to ensure that all of the remote functions run on the same worker. + # Note we use only one worker to ensure that all of the remote functions + # run on the same worker. ray.init(num_workers=1) net1 = NetActor() net2 = NetActor() @@ -151,15 +155,15 @@ class TensorFlowTest(unittest.TestCase): ray.worker.cleanup() - # This test creates an additional network on the driver so that the tensorflow - # variables on the driver and the worker differ. + # This test creates an additional network on the driver so that the + # tensorflow variables on the driver and the worker differ. def testNetworkDriverWorkerIndependent(self): ray.init(num_workers=1) # Create a network on the driver locally. sess1 = tf.Session() loss1, init1, _, _ = make_linear_network() - net_vars1 = ray.experimental.TensorFlowVariables(loss1, sess1) + ray.experimental.TensorFlowVariables(loss1, sess1) sess1.run(init1) net2 = ray.actor(NetActor)() @@ -194,39 +198,28 @@ class TensorFlowTest(unittest.TestCase): ray.worker.cleanup() - def testRemoteTrainingLoss(self): ray.init(num_workers=2) net = ray.actor(TrainActor)() loss, variables, _, sess, grads, train, placeholders = TrainActor().values - before_acc = sess.run(loss, feed_dict=dict(zip(placeholders, [[2]*100, [4]*100]))) + before_acc = sess.run(loss, feed_dict=dict(zip(placeholders, + [[2] * 100, [4] * 100]))) for _ in range(3): - gradients_list = ray.get([net.training_step(variables.get_weights()) for _ in range(2)]) - mean_grads = [sum([gradients[i] for gradients in gradients_list]) / len(gradients_list) for i in range(len(gradients_list[0]))] - feed_dict = {grad[0]: mean_grad for (grad, mean_grad) in zip(grads, mean_grads)} + gradients_list = ray.get([net.training_step(variables.get_weights()) + for _ in range(2)]) + mean_grads = [sum([gradients[i] for gradients in gradients_list]) / + len(gradients_list) for i in range(len(gradients_list[0]))] + feed_dict = {grad[0]: mean_grad for (grad, mean_grad) + in zip(grads, mean_grads)} sess.run(train, feed_dict=feed_dict) - after_acc = sess.run(loss, feed_dict=dict(zip(placeholders, [[2]*100, [4]*100]))) + after_acc = sess.run(loss, feed_dict=dict(zip(placeholders, + [[2] * 100, [4] * 100]))) self.assertTrue(before_acc < after_acc) ray.worker.cleanup() - def testVariablesControlDependencies(self): - ray.init(num_workers=1) - - # Creates a network and appends a momentum optimizer. - sess = tf.Session() - loss, init, _, _ = make_linear_network() - minimizer = tf.train.MomentumOptimizer(0.9, 0.9).minimize(loss) - net_vars = ray.experimental.TensorFlowVariables(minimizer, sess) - sess.run(init) - - # Tests if all variables are properly retrieved, 2 variables and 2 momentum - # variables. - self.assertEqual(len(net_vars.variables.items()), 4) - - ray.worker.cleanup() if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/webui/backend/ray_ui.py b/webui/backend/ray_ui.py index 3b7538ab5..ba62b8577 100644 --- a/webui/backend/ray_ui.py +++ b/webui/backend/ray_ui.py @@ -6,17 +6,17 @@ import collections import datetime import json import numpy as np -import os -import redis -import sys import time import websockets # Import flatbuffer bindings. -from ray.core.generated.LocalSchedulerInfoMessage import LocalSchedulerInfoMessage +from ray.core.generated.LocalSchedulerInfoMessage import \ + LocalSchedulerInfoMessage -parser = argparse.ArgumentParser(description="parse information for the web ui") -parser.add_argument("--redis-address", required=True, type=str, help="the address to use for redis") +parser = argparse.ArgumentParser( + description="parse information for the web ui") +parser.add_argument("--redis-address", required=True, type=str, + help="the address to use for redis") loop = asyncio.get_event_loop() @@ -25,27 +25,36 @@ IDENTIFIER_LENGTH = 20 # This prefix must match the value defined in ray_redis_module.cc. DB_CLIENT_PREFIX = b"CL:" + def hex_identifier(identifier): return binascii.hexlify(identifier).decode() + def identifier(hex_identifier): return binascii.unhexlify(hex_identifier) + def key_to_hex_identifier(key): - return hex_identifier(key[(key.index(b":") + 1):(key.index(b":") + IDENTIFIER_LENGTH + 1)]) + return hex_identifier( + key[(key.index(b":") + 1):(key.index(b":") + IDENTIFIER_LENGTH + 1)]) + def timestamp_to_date_string(timestamp): """Convert a time stamp returned by time.time() to a formatted string.""" - return datetime.datetime.fromtimestamp(timestamp).strftime("%Y/%m/%d %H:%M:%S") + return (datetime.datetime.fromtimestamp(timestamp) + .strftime("%Y/%m/%d %H:%M:%S")) + def key_to_hex_identifiers(key): - # Extract worker_id and task_id from key of the form prefix:worker_id:task_id. + # Extract worker_id and task_id from key of the form + # prefix:worker_id:task_id. offset = key.index(b":") + 1 worker_id = hex_identifier(key[offset:(offset + IDENTIFIER_LENGTH)]) offset += IDENTIFIER_LENGTH + 1 task_id = hex_identifier(key[offset:(offset + IDENTIFIER_LENGTH)]) return worker_id, task_id + async def hgetall_as_dict(redis_conn, key): fields = await redis_conn.execute("hgetall", key) return {fields[2 * i]: fields[2 * i + 1] for i in range(len(fields) // 2)} @@ -55,6 +64,7 @@ async def hgetall_as_dict(redis_conn, key): local_schedulers = {} errors = [] + def duration_to_string(duration): """Format a duration in seconds as a string. @@ -79,8 +89,10 @@ def duration_to_string(duration): duration_str = "{} microseconds".format(int(duration * 1000000)) return duration_str + async def handle_get_statistics(websocket, redis_conn): - cluster_start_time = float(await redis_conn.execute("get", "redis_start_time")) + cluster_start_time = float(await redis_conn.execute("get", + "redis_start_time")) start_date = timestamp_to_date_string(cluster_start_time) uptime = duration_to_string(time.time() - cluster_start_time) @@ -90,7 +102,9 @@ async def handle_get_statistics(websocket, redis_conn): for client_key in client_keys: client_fields = await hgetall_as_dict(redis_conn, client_key) clients.append(client_fields) - ip_addresses = list(set([client[b"node_ip_address"].decode("ascii") for client in clients if client[b"client_type"] == b"local_scheduler"])) + ip_addresses = list(set([client[b"node_ip_address"].decode("ascii") + for client in clients + if client[b"client_type"] == b"local_scheduler"])) num_nodes = len(ip_addresses) reply = {"uptime": uptime, "start_date": start_date, @@ -98,18 +112,22 @@ async def handle_get_statistics(websocket, redis_conn): "addresses": ip_addresses} await websocket.send(json.dumps(reply)) + async def handle_get_drivers(websocket, redis_conn): keys = await redis_conn.execute("keys", "Drivers:*") drivers = [] for key in keys: driver_fields = await hgetall_as_dict(redis_conn, key) - driver_info = {"node ip address": driver_fields[b"node_ip_address"].decode("ascii"), - "name": driver_fields[b"name"].decode("ascii")} + driver_info = { + "node ip address": driver_fields[b"node_ip_address"].decode("ascii"), + "name": driver_fields[b"name"].decode("ascii")} - driver_info["start time"] = timestamp_to_date_string(float(driver_fields[b"start_time"])) + driver_info["start time"] = timestamp_to_date_string( + float(driver_fields[b"start_time"])) if b"end_time" in driver_fields: - duration = float(driver_fields[b"end_time"]) - float(driver_fields[b"start_time"]) + duration = (float(driver_fields[b"end_time"]) - + float(driver_fields[b"start_time"])) else: duration = time.time() - float(driver_fields[b"start_time"]) driver_info["duration"] = duration_to_string(duration) @@ -129,17 +147,20 @@ async def handle_get_drivers(websocket, redis_conn): reply = sorted(drivers, key=(lambda driver: driver["start time"]))[::-1] await websocket.send(json.dumps(reply)) + async def listen_for_errors(redis_ip_address, redis_port): - pubsub_conn = await aioredis.create_connection((redis_ip_address, redis_port), loop=loop) - data_conn = await aioredis.create_connection((redis_ip_address, redis_port), loop=loop) + pubsub_conn = await aioredis.create_connection( + (redis_ip_address, redis_port), loop=loop) + data_conn = await aioredis.create_connection((redis_ip_address, redis_port), + loop=loop) error_pattern = "__keyspace@0__:ErrorKeys" - psub = await pubsub_conn.execute_pubsub("psubscribe", error_pattern) + await pubsub_conn.execute_pubsub("psubscribe", error_pattern) channel = pubsub_conn.pubsub_patterns[error_pattern] print("Listening for error messages...") index = 0 while (await channel.wait_message()): - msg = await channel.get() + await channel.get() info = await data_conn.execute("lrange", "ErrorKeys", index, -1) for error_key in info: @@ -154,6 +175,7 @@ async def listen_for_errors(redis_ip_address, redis_port): "error": result}) index += 1 + async def handle_get_errors(websocket): """Send error messages to the frontend.""" await websocket.send(json.dumps(errors)) @@ -161,6 +183,7 @@ async def handle_get_errors(websocket): node_info = collections.OrderedDict() worker_info = collections.OrderedDict() + async def handle_get_recent_tasks(websocket, redis_conn, num_tasks): # First update the cache of worker information. worker_keys = await redis_conn.execute("keys", "Workers:*") @@ -168,7 +191,8 @@ async def handle_get_recent_tasks(websocket, redis_conn, num_tasks): worker_id = hex_identifier(key[len("Workers:"):]) if worker_id not in worker_info: worker_info[worker_id] = await hgetall_as_dict(redis_conn, key) - node_ip_address = worker_info[worker_id][b"node_ip_address"].decode("ascii") + node_ip_address = (worker_info[worker_id][b"node_ip_address"] + .decode("ascii")) if node_ip_address not in node_info: node_info[node_ip_address] = {"workers": []} node_info[node_ip_address]["workers"].append(worker_id) @@ -183,7 +207,8 @@ async def handle_get_recent_tasks(websocket, redis_conn, num_tasks): for key in keys: content = await redis_conn.execute("lrange", key, "0", "-1") contents.append(json.loads(content[0].decode())) - timestamps += [timestamp for (timestamp, task, kind, info) in contents[-1] if task == "ray:task"] + timestamps += [timestamp for (timestamp, task, kind, info) + in contents[-1] if task == "ray:task"] timestamps.sort() time_cutoff = timestamps[(-2 * num_tasks):][0] @@ -197,36 +222,49 @@ async def handle_get_recent_tasks(websocket, redis_conn, num_tasks): num_tasks = 0 task_data = [{"task_data": [], - "num_workers": len(node_info[node_ip_address]["workers"])} for node_ip_address in node_ip_addresses] + "num_workers": len(node_info[node_ip_address]["workers"])} + for node_ip_address in node_ip_addresses] for i in range(len(keys)): worker_id, task_id = key_to_hex_identifiers(keys[i]) data = contents[i] if worker_id not in worker_ids: # This case should be extremely rare. - raise Exception("A worker ID was not present in the list of worker IDs.") - node_ip_address = worker_info[worker_id][b"node_ip_address"].decode("ascii") + raise Exception("A worker ID was not present in the list of worker " + "IDs.") + node_ip_address = (worker_info[worker_id][b"node_ip_address"] + .decode("ascii")) worker_index = node_info[node_ip_address]["workers"].index(worker_id) node_index = node_ip_addresses.index(node_ip_address) - task_times = [timestamp for (timestamp, task, kind, info) in data if task == "ray:task"] + task_times = [timestamp for (timestamp, task, kind, info) in data + if task == "ray:task"] if task_times[1] <= time_cutoff: continue - task_get_arguments_times = [timestamp for (timestamp, task, kind, info) in data if task == "ray:task:get_arguments"] - task_execute_times = [timestamp for (timestamp, task, kind, info) in data if task == "ray:task:execute"] - task_store_outputs_times = [timestamp for (timestamp, task, kind, info) in data if task == "ray:task:store_outputs"] - task_info = {"task": task_times, - "get_arguments": task_get_arguments_times, - "execute": task_execute_times, - "store_outputs": task_store_outputs_times, - "worker_index": worker_index, - "node_ip_address": node_ip_address, - "task_formatted_time": duration_to_string(task_times[1] - task_times[0]), - "get_arguments_formatted_time": duration_to_string(task_get_arguments_times[1] - task_get_arguments_times[0])} + task_get_arguments_times = [timestamp for (timestamp, task, kind, info) + in data if task == "ray:task:get_arguments"] + task_execute_times = [timestamp for (timestamp, task, kind, info) + in data if task == "ray:task:execute"] + task_store_outputs_times = [timestamp for (timestamp, task, kind, info) + in data if task == "ray:task:store_outputs"] + task_info = { + "task": task_times, + "get_arguments": task_get_arguments_times, + "execute": task_execute_times, + "store_outputs": task_store_outputs_times, + "worker_index": worker_index, + "node_ip_address": node_ip_address, + "task_formatted_time": duration_to_string(task_times[1] - + task_times[0]), + "get_arguments_formatted_time": + duration_to_string(task_get_arguments_times[1] - + task_get_arguments_times[0])} if len(task_execute_times) == 2: - task_info["execute_formatted_time"] = duration_to_string(task_execute_times[1] - task_execute_times[0]) + task_info["execute_formatted_time"] = duration_to_string( + task_execute_times[1] - task_execute_times[0]) if len(task_store_outputs_times) == 2: - task_info["store_outputs_formatted_time"] = duration_to_string(task_store_outputs_times[1] - task_store_outputs_times[0]) + task_info["store_outputs_formatted_time"] = duration_to_string( + task_store_outputs_times[1] - task_store_outputs_times[0]) task_data[node_index]["task_data"].append(task_info) num_tasks += 1 reply = {"min_time": min_time, @@ -235,34 +273,41 @@ async def handle_get_recent_tasks(websocket, redis_conn, num_tasks): "task_data": task_data} await websocket.send(json.dumps(reply)) + async def send_heartbeat_payload(websocket): """Send heartbeat updates to the frontend every half second.""" while True: reply = [] for local_scheduler_id, local_scheduler in local_schedulers.items(): current_time = time.time() - local_scheduler_info = {"local scheduler ID": local_scheduler_id, - "time since heartbeat": duration_to_string(current_time - local_scheduler["last_heartbeat"]), - "time since heartbeat numeric": str(current_time - local_scheduler["last_heartbeat"]), - "node ip address": local_scheduler["node_ip_address"]} + local_scheduler_info = { + "local scheduler ID": local_scheduler_id, + "time since heartbeat": + (duration_to_string(current_time - + local_scheduler["last_heartbeat"])), + "time since heartbeat numeric": + str(current_time - local_scheduler["last_heartbeat"]), + "node ip address": local_scheduler["node_ip_address"]} reply.append(local_scheduler_info) # Send the payload to the frontend. await websocket.send(json.dumps(reply)) # Wait for a little while so as not to overwhelm the frontend. await asyncio.sleep(0.5) + async def send_heartbeats(websocket, redis_conn): # First update the local scheduler info locally. client_keys = await redis_conn.execute("keys", "CL:*") - clients = [] for client_key in client_keys: client_fields = await hgetall_as_dict(redis_conn, client_key) if client_fields[b"client_type"] == b"local_scheduler": local_scheduler_id = hex_identifier(client_fields[b"ray_client_id"]) - local_schedulers[local_scheduler_id] = {"node_ip_address": client_fields[b"node_ip_address"].decode("ascii"), - "local_scheduler_socket_name": client_fields[b"local_scheduler_socket_name"].decode("ascii"), - "aux_address": client_fields[b"aux_address"].decode("ascii"), - "last_heartbeat": -1 * np.inf} + local_schedulers[local_scheduler_id] = { + "node_ip_address": client_fields[b"node_ip_address"].decode("ascii"), + "local_scheduler_socket_name": + client_fields[b"local_scheduler_socket_name"].decode("ascii"), + "aux_address": client_fields[b"aux_address"].decode("ascii"), + "last_heartbeat": -1 * np.inf} # Subscribe to local scheduler heartbeats. await redis_conn.execute_pubsub("subscribe", "local_schedulers") @@ -272,7 +317,8 @@ async def send_heartbeats(websocket, redis_conn): while True: msg = await redis_conn.pubsub_channels["local_schedulers"].get() - heartbeat = LocalSchedulerInfoMessage.GetRootAsLocalSchedulerInfoMessage(msg, 0) + heartbeat = LocalSchedulerInfoMessage.GetRootAsLocalSchedulerInfoMessage( + msg, 0) local_scheduler_id_bytes = heartbeat.DbClientId() local_scheduler_id = hex_identifier(local_scheduler_id_bytes) if local_scheduler_id not in local_schedulers: @@ -281,6 +327,7 @@ async def send_heartbeats(websocket, redis_conn): continue local_schedulers[local_scheduler_id]["last_heartbeat"] = time.time() + async def cache_data_from_redis(redis_ip_address, redis_port): """Open up ports to listen for new updates from Redis.""" # TODO(richard): A lot of code needs to be ported in order to open new @@ -288,6 +335,7 @@ async def cache_data_from_redis(redis_ip_address, redis_port): asyncio.ensure_future(listen_for_errors(redis_ip_address, redis_port)) + async def handle_get_log_files(websocket, redis_conn): reply = {} # First get all keys for the log file lists. @@ -296,9 +344,11 @@ async def handle_get_log_files(websocket, redis_conn): node_ip_address = log_file_list_key.decode("ascii").split(":")[1] reply[node_ip_address] = {} # Get all of the log filenames for this node IP address. - log_filenames = await redis_conn.execute("lrange", log_file_list_key, 0, -1) + log_filenames = await redis_conn.execute("lrange", log_file_list_key, 0, + -1) for log_filename in log_filenames: - log_filename_key = "LOGFILE:{}:{}".format(node_ip_address, log_filename.decode("ascii")) + log_filename_key = "LOGFILE:{}:{}".format(node_ip_address, + log_filename.decode("ascii")) logfile = await redis_conn.execute("lrange", log_filename_key, 0, -1) logfile = [line.decode("ascii") for line in logfile] reply[node_ip_address][log_filename.decode("ascii")] = logfile @@ -306,8 +356,10 @@ async def handle_get_log_files(websocket, redis_conn): # Send the reply back to the front end. await websocket.send(json.dumps(reply)) + async def serve_requests(websocket, path): - redis_conn = await aioredis.create_connection((redis_ip_address, redis_port), loop=loop) + redis_conn = await aioredis.create_connection((redis_ip_address, redis_port), + loop=loop) while True: command = json.loads(await websocket.recv()) print("received command {}".format(command)) @@ -352,10 +404,10 @@ async def serve_requests(websocket, path): "data_size": content[5].decode()}) await websocket.send(json.dumps(result)) elif command["command"] == "get-object-info": - # TODO(pcm): Get the object here (have to connect to ray) and ship content - # and type back to webclient. One challenge here is that the naive - # implementation will block the web ui backend, which is not ok if it is - # serving multiple users. + # TODO(pcm): Get the object here (have to connect to ray) and ship + # content and type back to webclient. One challenge here is that the + # naive implementation will block the web ui backend, which is not ok if + # it is serving multiple users. await websocket.send(json.dumps({"object_id": "none"})) elif command["command"] == "get-tasks": result = [] @@ -372,7 +424,8 @@ async def serve_requests(websocket, path): worker_id, task_id = key_to_hex_identifiers(key) content = await redis_conn.execute("lrange", key, "0", "-1") data = json.loads(content[0].decode()) - begin_and_end_time = [timestamp for (timestamp, task, kind, info) in data if task == "ray:task"] + begin_and_end_time = [timestamp for (timestamp, task, kind, info) + in data if task == "ray:task"] tasks[worker_id].append({"task_id": task_id, "start_task": min(begin_and_end_time), "end_task": max(begin_and_end_time)}) @@ -396,8 +449,8 @@ if __name__ == "__main__": redis_ip_address, redis_port = redis_address[0], int(redis_address[1]) # The port here must match the value used by the frontend to connect over - # websockets. TODO(richard): Automatically increment the port if it is already - # taken. + # websockets. TODO(richard): Automatically increment the port if it is + # already taken. port = 8888 loop.run_until_complete(cache_data_from_redis(redis_ip_address, redis_port))