Simplify imports and exports and provide driver isolation for remote functions. (#288)

* Remove import counter and export counter.

* Provide isolation between drivers for remote functions.

* Add test for driver function isolation.

* Hash source code into function ID to reduce likelihood of collisions.

* Fix failure test example.

* Replace assertTrue with assertIn to improve failure messages in tests.

* Fix failure test.
This commit is contained in:
Robert Nishihara
2017-02-16 11:30:35 -08:00
committed by Philipp Moritz
parent 883f945db4
commit 88a5b4e77b
6 changed files with 252 additions and 181 deletions
+21 -23
View File
@@ -27,18 +27,19 @@ def get_actor_method_function_id(attr):
Returns:
Function ID corresponding to the method.
"""
function_id = hashlib.sha1()
function_id.update(attr.encode("ascii"))
return photon.ObjectID(function_id.digest())
function_id_hash = hashlib.sha1()
function_id_hash.update(attr.encode("ascii"))
function_id = function_id_hash.digest()
assert len(function_id) == 20
return photon.ObjectID(function_id)
def fetch_and_register_actor(key, worker):
"""Import an actor."""
driver_id, actor_id_str, actor_name, module, pickled_class, class_export_counter = \
worker.redis_client.hmget(key, ["driver_id", "actor_id", "name", "module", "class", "class_export_counter"])
driver_id, actor_id_str, actor_name, module, pickled_class = \
worker.redis_client.hmget(key, ["driver_id", "actor_id", "name", "module", "class"])
actor_id = photon.ObjectID(actor_id_str)
actor_name = actor_name.decode("ascii")
module = module.decode("ascii")
class_export_counter = int(class_export_counter)
try:
unpickled_class = pickling.loads(pickled_class)
except:
@@ -49,16 +50,17 @@ def fetch_and_register_actor(key, worker):
worker.actors[actor_id_str] = unpickled_class.__new__(unpickled_class)
for (k, v) in inspect.getmembers(unpickled_class, predicate=(lambda x: inspect.isfunction(x) or inspect.ismethod(x))):
function_id = get_actor_method_function_id(k).id()
worker.function_names[function_id] = k
worker.functions[function_id] = v
worker.functions[driver_id][function_id] = (k, v)
# We do not set worker.function_properties[driver_id][function_id] because
# we currently do not need the actor worker to submit new tasks for the actor.
def export_actor(actor_id, Class, worker):
def export_actor(actor_id, Class, actor_method_names, worker):
"""Export an actor to redis.
Args:
actor_id: The ID of the actor.
Class: The class to be exported as an actor.
worker: The worker class
actor_method_names (list): A list of the names of this actor's methods.
"""
ray.worker.check_main_thread()
if worker.mode is None:
@@ -66,28 +68,25 @@ def export_actor(actor_id, Class, worker):
key = "Actor:{}".format(actor_id.id())
pickled_class = pickling.dumps(Class)
# For now, all actor methods have 1 return value and require 0 CPUs and GPUs.
driver_id = worker.task_driver_id.id()
for actor_method_name in actor_method_names:
function_id = get_actor_method_function_id(actor_method_name).id()
worker.function_properties[driver_id][function_id] = (1, 0, 0)
# Select a local scheduler for the actor.
local_schedulers = state.get_local_schedulers()
local_scheduler_id = random.choice(local_schedulers)
worker.redis_client.publish("actor_notifications", actor_id.id() + local_scheduler_id)
# The export counter is computed differently depending on whether we are
# currently in a driver or a worker.
if worker.mode in [ray.SCRIPT_MODE, ray.SILENT_MODE]:
export_counter = worker.driver_export_counter
elif worker.mode == ray.WORKER_MODE:
# We don't actually need export counters for actors.
export_counter = 0
d = {"driver_id": worker.task_driver_id.id(),
d = {"driver_id": driver_id,
"actor_id": actor_id.id(),
"name": Class.__name__,
"module": Class.__module__,
"class": pickled_class,
"class_export_counter": export_counter}
"class": pickled_class}
worker.redis_client.hmset(key, d)
worker.redis_client.rpush("Exports", key)
worker.driver_export_counter += 1
def actor(Class):
# The function actor_method_call gets called if somebody tries to call a
@@ -105,7 +104,6 @@ def actor(Class):
num_cpus = 0
num_gpus = 0
object_ids = ray.worker.global_worker.submit_task(function_id, "", args,
num_cpus, num_gpus,
actor_id=actor_id)
if len(object_ids) == 1:
return object_ids[0]
@@ -116,7 +114,7 @@ def actor(Class):
def __init__(self, *args, **kwargs):
self._ray_actor_id = random_actor_id()
self._ray_actor_methods = {k: v for (k, v) in inspect.getmembers(Class, predicate=(lambda x: inspect.isfunction(x) or inspect.ismethod(x)))}
export_actor(self._ray_actor_id, Class, ray.worker.global_worker)
export_actor(self._ray_actor_id, Class, self._ray_actor_methods, ray.worker.global_worker)
# Call __init__ as a remote function.
if "__init__" in self._ray_actor_methods.keys():
actor_method_call(self._ray_actor_id, "__init__", *args, **kwargs)
+79 -127
View File
@@ -2,22 +2,23 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import hashlib
import os
import sys
import time
import traceback
import copy
import collections
import funcsigs
import numpy as np
import colorama
import atexit
import collections
import colorama
import copy
import funcsigs
import hashlib
import inspect
import json
import numpy as np
import os
import random
import redis
import threading
import string
import sys
import threading
import time
import traceback
# Ray modules
import ray.pickling as pickling
@@ -373,61 +374,27 @@ class Worker(object):
that connect has been called already.
cached_functions_to_run (List): A list of functions to run on all of the
workers that should be exported as soon as connect is called.
driver_export_counter (int): The number of exports that the driver has
exported. This is only used on the driver.
worker_import_counter (int): The number of exports that the worker has
imported so far. This is only used on the workers.
"""
def __init__(self):
"""Initialize a Worker object."""
self.functions = {}
# Use a defaultdict for the number of return values. If this is accessed
# with a missing key, the default value of 1 is returned, and that key value
# pair is added to the dict.
self.num_return_vals = collections.defaultdict(lambda: 1)
self.function_names = {}
self.function_export_counters = {}
# The functions field is a dictionary that maps a driver ID to a dictionary
# of functions that have been registered for that driver (this inner
# dictionary maps function IDs to a tuple of the function name and the
# function itself). This should only be used on workers that execute remote
# functions.
self.functions = collections.defaultdict(lambda: {})
# The function_properties field is a dictionary that maps a driver ID to a
# dictionary of functions that have been registered for that driver (this
# inner dictionary maps function IDs to a tuple of the number of values
# returned by that function, the number of CPUs required by that function,
# and the number of GPUs required by that function). This is used when
# submitting a function (which can be done both on workers and on drivers).
self.function_properties = collections.defaultdict(lambda: {})
self.connected = False
self.mode = None
self.cached_remote_functions = []
self.cached_functions_to_run = []
# The driver_export_counter and worker_import_counter are used to make sure
# that no task executes before everything it needs is present. For example,
# if we define a remote function f, a worker cannot execute a task for f
# until the worker has imported the function f.
# - When a remote function, a reusable variable, or a function to run is
# exported, the driver_export_counter is incremented. These exports must
# take place from the driver.
# - When an actor is created, the driver_export_counter is NOT
# incremented. Note that an actor can be created from a driver or from
# any worker.
# - When a worker imports a remote function, a reusable variable, or a
# function to run, its worker_import_counter is incremented.
# - Notably, when an actor is imported, its worker_import_counter is NOT
# incremented.
# - Whenever a remote function is DEFINED on the driver, it records the
# value of the driver_export_counter and a worker will not execute that
# remote function until it has imported that many exports (excluding
# actors).
# - When an actor is defined.
# a) If the actor is created on a driver, it records the
# driver_export_counter.
# b) If the actor is created inside a task on a regular worker, it
# records the driver_export_counter associated with the function in
# task creating the actor.
# c) If the actor is created inside a task on an actor worker, it
# records
# The worker that ultimately runs the actor will not execute any tasks
# until it has imported that many imports.
#
# TODO(rkn): These counters must be tracked separately for each driver.
# TODO(rkn): Maybe none of these counters are necessary? When executing a
# regular task, workers can just wait until the function ID is present. When
# executing an actor task, the actor worker can just wait until the actor
# has been defined.
self.driver_export_counter = 0
self.worker_import_counter = 0
self.fetch_and_register = {}
self.actors = {}
# Use a defaultdict for the actor counts. If this is accessed with a missing
@@ -526,7 +493,7 @@ class Worker(object):
assert final_results[i][0] == object_ids[i].id()
return [result[1][0] for result in final_results]
def submit_task(self, function_id, func_name, args, num_cpus, num_gpus, actor_id=photon.ObjectID(NIL_ACTOR_ID)):
def submit_task(self, function_id, func_name, args, actor_id=photon.ObjectID(NIL_ACTOR_ID)):
"""Submit a remote task to the scheduler.
Tell the scheduler to schedule the execution of the function with name
@@ -538,8 +505,6 @@ class Worker(object):
args (List[Any]): The arguments to pass into the function. Arguments can
be object IDs or they can be values. If they are values, they
must be serializable objects.
num_cpus (int): The number of cpu cores this task requires to run.
num_gpus (int): The number of gpus this task requires to run.
"""
with log_span("ray:submit_task", worker=self):
check_main_thread()
@@ -554,11 +519,14 @@ class Worker(object):
else:
args_for_photon.append(put(arg))
# Look up the various function properties.
num_return_vals, num_cpus, num_gpus = self.function_properties[self.task_driver_id.id()][function_id.id()]
# Submit the task to Photon.
task = photon.Task(self.task_driver_id,
photon.ObjectID(function_id.id()),
args_for_photon,
self.num_return_vals[function_id.id()],
num_return_vals,
self.current_task_id,
self.task_index,
actor_id, self.actor_counters[actor_id],
@@ -604,7 +572,6 @@ class Worker(object):
"function_id": function_to_run_id,
"function": pickling.dumps(function)})
self.redis_client.rpush("Exports", key)
self.driver_export_counter += 1
def push_error_to_driver(self, driver_id, error_type, message, data=None):
"""Push an error message to the driver to be printed in the background.
@@ -963,8 +930,6 @@ def cleanup(worker=global_worker):
disconnect(worker)
worker.set_mode(None)
worker.driver_export_counter = 0
worker.worker_import_counter = 0
if hasattr(worker, "plasma_client"):
worker.plasma_client.shutdown()
services.cleanup()
@@ -1037,14 +1002,13 @@ If this driver is hanging, start a new one with
def fetch_and_register_remote_function(key, worker=global_worker):
"""Import a remote function."""
driver_id, function_id_str, function_name, serialized_function, num_return_vals, module, function_export_counter, num_cpus, num_gpus = \
driver_id, function_id_str, function_name, serialized_function, num_return_vals, module, num_cpus, num_gpus = \
worker.redis_client.hmget(key, ["driver_id",
"function_id",
"name",
"function",
"num_return_vals",
"module",
"function_export_counter",
"num_cpus",
"num_gpus"])
function_id = photon.ObjectID(function_id_str)
@@ -1053,19 +1017,14 @@ def fetch_and_register_remote_function(key, worker=global_worker):
num_cpus = int(num_cpus)
num_gpus = int(num_gpus)
module = module.decode("ascii")
function_export_counter = int(function_export_counter)
worker.function_names[function_id.id()] = function_name
worker.num_return_vals[function_id.id()] = num_return_vals
worker.function_export_counters[function_id.id()] = function_export_counter
# This is a placeholder in case the function can't be unpickled. This will be
# overwritten if the function is unpickled successfully.
# overwritten if the function is successfully registered.
def f():
raise Exception("This function was not imported properly.")
worker.functions[function_id.id()] = remote(num_return_vals=num_return_vals,
function_id=function_id,
num_cpus=num_cpus,
num_gpus=num_gpus)(lambda *xs: f())
remote_f_placeholder = remote(function_id=function_id)(lambda *xs: f())
worker.functions[driver_id][function_id.id()] = (function_name, remote_f_placeholder)
worker.function_properties[driver_id][function_id.id()] = (num_return_vals, num_cpus, num_gpus)
try:
function = pickling.loads(serialized_function)
@@ -1081,10 +1040,7 @@ def fetch_and_register_remote_function(key, worker=global_worker):
else:
# TODO(rkn): Why is the below line necessary?
function.__module__ = module
worker.functions[function_id.id()] = remote(num_return_vals=num_return_vals,
function_id=function_id,
num_cpus=num_cpus,
num_gpus=num_gpus)(function)
worker.functions[driver_id][function_id.id()] = (function_name, remote(function_id=function_id)(function))
# Add the function to the function table.
worker.redis_client.rpush("FunctionTable:{}".format(function_id.id()), worker.worker_id)
@@ -1133,10 +1089,7 @@ def import_thread(worker):
# in the loop.
worker.import_pubsub_client.psubscribe("__keyspace@0__:Exports")
worker_info_key = "WorkerInfo:{}".format(worker.worker_id)
worker.redis_client.hset(worker_info_key, "export_counter", 0)
worker.worker_import_counter = 0
# The number of imports is similar to the worker_import_counter except that it
# also counts actors.
# Keep track of the number of imports that we've imported.
num_imported = 0
# Get the exports that occurred before the call to psubscribe.
@@ -1157,10 +1110,6 @@ def import_thread(worker):
worker.fetch_and_register["Actor"](key, worker)
else:
raise Exception("This code should be unreachable.")
# Actors do not contribute to the import counter.
if not key.startswith(b"Actor"):
worker.redis_client.hincrby(worker_info_key, "export_counter", 1)
worker.worker_import_counter += 1
num_imported += 1
for msg in worker.import_pubsub_client.listen():
@@ -1189,10 +1138,6 @@ def import_thread(worker):
worker.fetch_and_register["Actor"](key, worker)
else:
raise Exception("This code should be unreachable.")
# Actors do not contribute to the import counter.
if not key.startswith(b"Actor"):
worker.redis_client.hincrby(worker_info_key, "export_counter", 1)
worker.worker_import_counter += 1
num_imported += 1
def connect(info, object_id_seed=None, mode=WORKER_MODE, worker=global_worker, actor_id=NIL_ACTOR_ID):
@@ -1522,13 +1467,12 @@ def wait(object_ids, num_returns=1, timeout=None, worker=global_worker):
remaining_ids = [photon.ObjectID(object_id) for object_id in remaining_ids]
return ready_ids, remaining_ids
def wait_for_valid_import_counter(function_id, driver_id, timeout=5, worker=global_worker):
"""Wait until this worker has imported enough to execute the function.
def wait_for_function(function_id, driver_id, timeout=5, worker=global_worker):
"""Wait until the function to be executed is present on this worker.
This method will simply loop until the import thread has imported enough of
the exports to execute the function. If we spend too long in this loop, that
may indicate a problem somewhere and we will push an error message to the
user.
This method will simply loop until the import thread has imported the relevant
function. If we spend too long in this loop, that may indicate a problem
somewhere and we will push an error message to the user.
If this worker is an actor, then this will wait until the actor has been
defined.
@@ -1545,17 +1489,14 @@ def wait_for_valid_import_counter(function_id, driver_id, timeout=5, worker=glob
num_warnings_sent = 0
while True:
with worker.lock:
if worker.actor_id == NIL_ACTOR_ID and function_id.id() in worker.functions and (worker.function_export_counters[function_id.id()] <= worker.worker_import_counter):
if worker.actor_id == NIL_ACTOR_ID and function_id.id() in worker.functions[driver_id]:
break
elif worker.actor_id != NIL_ACTOR_ID and worker.actor_id in worker.actors:
break
if time.time() - start_time > timeout * (num_warnings_sent + 1):
if function_id.id() not in worker.functions:
warning_message = "This worker was asked to execute a function that it does not have registered. You may have to restart Ray."
else:
warning_message = "This worker's import counter is too small."
warning_message = "This worker was asked to execute a function that it does not have registered. You may have to restart Ray."
if not warning_sent:
worker.push_error_to_driver(driver_id, "import_counter",
worker.push_error_to_driver(driver_id, "wait_for_function",
warning_message)
warning_sent = True
time.sleep(0.001)
@@ -1614,18 +1555,18 @@ def main_loop(worker=global_worker):
function_id = task.function_id()
args = task.arguments()
return_object_ids = task.returns()
function_name = worker.function_names[function_id.id()]
function_name, function_executor = worker.functions[worker.task_driver_id.id()][function_id.id()]
# Get task arguments from the object store.
with log_span("ray:task:get_arguments", worker=worker):
arguments = get_arguments_for_execution(worker.functions[function_id.id()], args, worker)
arguments = get_arguments_for_execution(function_name, args, worker)
# Execute the task.
with log_span("ray:task:execute", worker=worker):
if task.actor_id().id() == NIL_ACTOR_ID:
outputs = worker.functions[task.function_id().id()].executor(arguments)
outputs = function_executor.executor(arguments)
else:
outputs = worker.functions[task.function_id().id()](worker.actors[task.actor_id().id()], *arguments)
outputs = function_executor(worker.actors[task.actor_id().id()], *arguments)
# Store the outputs in the local object store.
with log_span("ray:task:store_outputs", worker=worker):
@@ -1680,11 +1621,11 @@ def main_loop(worker=global_worker):
task = worker.photon_client.get_task()
function_id = task.function_id()
# Check that the number of imports we have is at least as great as the
# export counter for the task. If not, wait until we have imported enough.
# We will push warnings to the user if we spend too long in this loop.
with log_span("ray:wait_for_import_counter", worker=worker):
wait_for_valid_import_counter(function_id, task.driver_id().id(), worker=worker)
# Wait until the function to be executed has actually been registered on
# this worker. We will push warnings to the user if we spend too long in
# this loop.
with log_span("ray:wait_for_function", worker=worker):
wait_for_function(function_id, task.driver_id().id(), worker=worker)
# Execute the task.
# TODO(rkn): Consider acquiring this lock with a timeout and pushing a
@@ -1695,7 +1636,8 @@ def main_loop(worker=global_worker):
with worker.lock:
log(event_type="ray:acquire_lock", kind=LOG_SPAN_END, worker=worker)
contents = {"function_name": worker.function_names[function_id.id()],
function_name, _ = worker.functions[task.driver_id().id()][function_id.id()]
contents = {"function_name": function_name,
"task_id": task.task_id().hex()}
with log_span("ray:task", contents=contents, worker=worker):
process_task(task)
@@ -1703,7 +1645,7 @@ def main_loop(worker=global_worker):
# Push all of the log events to the global state store.
flush_log()
def _submit_task(function_id, func_name, args, num_cpus, num_gpus, worker=global_worker):
def _submit_task(function_id, func_name, args, worker=global_worker):
"""This is a wrapper around worker.submit_task.
We use this wrapper so that in the remote decorator, we can call _submit_task
@@ -1711,7 +1653,7 @@ def _submit_task(function_id, func_name, args, num_cpus, num_gpus, worker=global
serialize remote functions, we don't attempt to serialize the worker object,
which cannot be serialized.
"""
return worker.submit_task(function_id, func_name, args, num_cpus, num_gpus)
return worker.submit_task(function_id, func_name, args)
def _mode(worker=global_worker):
"""This is a wrapper around worker.mode.
@@ -1751,14 +1693,13 @@ def _export_environment_variable(name, environment_variable, worker=global_worke
"initializer": pickling.dumps(environment_variable.initializer),
"reinitializer": pickling.dumps(environment_variable.reinitializer)})
worker.redis_client.rpush("Exports", key)
worker.driver_export_counter += 1
def export_remote_function(function_id, func_name, func, num_return_vals, num_cpus, num_gpus, worker=global_worker):
check_main_thread()
if _mode(worker) not in [SCRIPT_MODE, SILENT_MODE]:
raise Exception("export_remote_function can only be called on a driver.")
worker.function_properties[worker.task_driver_id.id()][function_id.id()] = (num_return_vals, num_cpus, num_gpus)
key = "RemoteFunction:{}".format(function_id.id())
worker.num_return_vals[function_id.id()] = num_return_vals
pickled_func = pickling.dumps(func)
worker.redis_client.hmset(key, {"driver_id": worker.task_driver_id.id(),
"function_id": function_id.id(),
@@ -1766,11 +1707,9 @@ def export_remote_function(function_id, func_name, func, num_return_vals, num_cp
"module": func.__module__,
"function": pickled_func,
"num_return_vals": num_return_vals,
"function_export_counter": worker.driver_export_counter,
"num_cpus": num_cpus,
"num_gpus": num_gpus})
worker.redis_client.rpush("Exports", key)
worker.driver_export_counter += 1
def remote(*args, **kwargs):
"""This decorator is used to create remote functions.
@@ -1778,13 +1717,26 @@ def remote(*args, **kwargs):
Args:
num_return_vals (int): The number of object IDs that a call to this function
should return.
num_cpus (int): The number of CPUs needed to execute this function. This
should only be passed in when defining the remote function on the driver.
num_gpus (int): The number of GPUs needed to execute this function. This
should only be passed in when defining the remote function on the driver.
"""
worker = global_worker
def make_remote_decorator(num_return_vals, num_cpus, num_gpus, func_id=None):
def remote_decorator(func):
func_name = "{}.{}".format(func.__module__, func.__name__)
if func_id is None:
function_id = FunctionID((hashlib.sha256(func_name.encode("ascii")).digest())[:20])
# Compute the function ID as a hash of the function name as well as the
# source code. We could in principle hash in the values in the closure
# of the function, but that is likely to introduce non-determinism in
# the computation of the function ID.
function_id_hash = hashlib.sha1()
function_id_hash.update(func_name.encode("ascii"))
function_id_hash.update(inspect.getsource(func).encode("ascii"))
function_id = function_id_hash.digest()
assert len(function_id) == 20
function_id = FunctionID(function_id)
else:
function_id = func_id
@@ -1807,7 +1759,7 @@ def remote(*args, **kwargs):
_env()._reinitialize()
_env()._running_remote_function_locally = False
return result
objectids = _submit_task(function_id, func_name, args, num_cpus, num_gpus)
objectids = _submit_task(function_id, func_name, args)
if len(objectids) == 1:
return objectids[0]
elif len(objectids) > 1:
@@ -1909,7 +1861,7 @@ def check_signature_supported(has_kwargs_param, has_vararg_param, keyword_defaul
if has_vararg_param and any([d != funcsigs._empty for _, d in keyword_defaults]):
raise "Function {} has a *args argument as well as a keyword argument, which is currently not supported.".format(name)
def get_arguments_for_execution(function, serialized_args, worker=global_worker):
def get_arguments_for_execution(function_name, serialized_args, worker=global_worker):
"""Retrieve the arguments for the remote function.
This retrieves the values for the arguments to the remote function that were
@@ -1917,8 +1869,8 @@ def get_arguments_for_execution(function, serialized_args, worker=global_worker)
This is called by the worker that is executing the remote function.
Args:
function (Callable): The remote function whose arguments are being
retrieved.
function_name (str): The name of the remote function whose arguments are
being retrieved.
serialized_args (List): The arguments to the function. These are either
strings representing serialized objects passed by value or they are
ObjectIDs.
@@ -1939,7 +1891,7 @@ def get_arguments_for_execution(function, serialized_args, worker=global_worker)
if isinstance(argument, RayTaskError):
# If the result is a RayTaskError, then the task that created this
# object failed, and we should propagate the error message here.
raise RayGetArgumentError(function.__name__, i, arg, argument)
raise RayGetArgumentError(function_name, i, arg, argument)
else:
# pass the argument by value
argument = arg
+1 -1
View File
@@ -709,7 +709,7 @@ void queue_task_locally(local_scheduler_state *state,
/**
* Give a task directly to another local scheduler. This is currently only used
* for assigning actor tasks to the local scheduer responsible for that actor.
* for assigning actor tasks to the local scheduler responsible for that actor.
*
* @param state The scheduler state.
* @param algorithm_state The scheduling algorithm state.
+38 -25
View File
@@ -2,10 +2,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import unittest
import os
import ray
import sys
import tempfile
import time
import unittest
if sys.version_info >= (3, 0):
from importlib import reload
@@ -67,13 +69,13 @@ class TaskStatusTest(unittest.TestCase):
result = ray.error_info()
self.assertEqual(len(relevant_errors(b"task")), 2)
for task in relevant_errors(b"task"):
self.assertTrue(b"Test function 1 intentionally failed." in task.get(b"message"))
self.assertIn(b"Test function 1 intentionally failed.", task.get(b"message"))
x = test_functions.throw_exception_fct2.remote()
try:
ray.get(x)
except Exception as e:
self.assertTrue("Test function 2 intentionally failed." in str(e))
self.assertIn("Test function 2 intentionally failed.", str(e))
else:
self.assertTrue(False) # ray.get should throw an exception
@@ -82,7 +84,7 @@ class TaskStatusTest(unittest.TestCase):
try:
ray.get(ref)
except Exception as e:
self.assertTrue("Test function 3 intentionally failed." in str(e))
self.assertIn("Test function 3 intentionally failed.", str(e))
else:
self.assertTrue(False) # ray.get should throw an exception
@@ -91,29 +93,40 @@ class TaskStatusTest(unittest.TestCase):
def testFailImportingRemoteFunction(self):
ray.init(num_workers=2, driver_mode=ray.SILENT_MODE)
# This example is somewhat contrived. It should be successfully pickled, and
# then it should throw an exception when it is unpickled. This may depend a
# bit on the specifics of our pickler.
def reducer(*args):
raise Exception("There is a problem here.")
class Foo(object):
def __init__(self):
self.__name__ = "Foo_object"
self.func_doc = ""
self.__globals__ = {}
def __reduce__(self):
return reducer, ()
def __call__(self):
return
f = ray.remote(Foo())
# Create the contents of a temporary Python file.
temporary_python_file = """
def temporary_helper_function():
return 1
"""
f = tempfile.NamedTemporaryFile(suffix=".py")
f.write(temporary_python_file.encode("ascii"))
f.flush()
directory = os.path.dirname(f.name)
# Get the module name and strip ".py" from the end.
module_name = os.path.basename(f.name)[:-3]
sys.path.append(directory)
module = __import__(module_name)
# Define a function that closes over this temporary module. This should fail
# when it is unpickled.
@ray.remote
def g():
return module.temporary_python_file()
wait_for_errors(b"register_remote_function", 2)
self.assertTrue(b"There is a problem here." in ray.error_info()[0][b"message"])
self.assertIn(b"No module named", ray.error_info()[0][b"message"])
self.assertIn(b"No module named", ray.error_info()[1][b"message"])
# Check that if we try to call the function it throws an exception and does
# not hang.
for _ in range(10):
self.assertRaises(Exception, lambda : ray.get(f.remote()))
self.assertRaises(Exception, lambda : ray.get(g.remote()))
f.close()
# Clean up the junk we added to sys.path.
sys.path.pop(-1)
ray.worker.cleanup()
def testFailImportingEnvironmentVariable(self):
@@ -128,7 +141,7 @@ class TaskStatusTest(unittest.TestCase):
ray.env.foo = ray.EnvironmentVariable(initializer)
wait_for_errors(b"register_environment_variable", 2)
# Check that the error message is in the task info.
self.assertTrue(b"The initializer failed." in ray.error_info()[0][b"message"])
self.assertIn(b"The initializer failed.", ray.error_info()[0][b"message"])
ray.worker.cleanup()
@@ -146,7 +159,7 @@ class TaskStatusTest(unittest.TestCase):
use_foo.remote()
wait_for_errors(b"reinitialize_environment_variable", 1)
# Check that the error message is in the task info.
self.assertTrue(b"The reinitializer failed." in ray.error_info()[0][b"message"])
self.assertIn(b"The reinitializer failed.", ray.error_info()[0][b"message"])
ray.worker.cleanup()
@@ -160,8 +173,8 @@ class TaskStatusTest(unittest.TestCase):
wait_for_errors(b"function_to_run", 2)
# Check that the error message is in the task info.
self.assertEqual(len(ray.error_info()), 2)
self.assertTrue(b"Function to run failed." in ray.error_info()[0][b"message"])
self.assertTrue(b"Function to run failed." in ray.error_info()[1][b"message"])
self.assertIn(b"Function to run failed.", ray.error_info()[0][b"message"])
self.assertIn(b"Function to run failed.", ray.error_info()[1][b"message"])
ray.worker.cleanup()
+57 -5
View File
@@ -16,16 +16,22 @@ stop_ray_script = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../s
class MultiNodeTest(unittest.TestCase):
def testErrorIsolation(self):
def setUp(self):
# Start the Ray processes on this machine.
out = subprocess.check_output([start_ray_script, "--head"]).decode("ascii")
# Get the redis address from the output.
redis_substring_prefix = "redis_address=\""
redis_address_location = out.find(redis_substring_prefix) + len(redis_substring_prefix)
redis_address = out[redis_address_location:]
redis_address = redis_address.split("\"")[0]
self.redis_address = redis_address.split("\"")[0]
def tearDown(self):
# Kill the Ray cluster.
# The stop script is launched as a child process and we block on wait()
# until it has exited before returning.
subprocess.Popen([stop_ray_script]).wait()
def testErrorIsolation(self):
# Connect a driver to the Ray cluster.
ray.init(redis_address=redis_address, driver_mode=ray.SILENT_MODE)
ray.init(redis_address=self.redis_address, driver_mode=ray.SILENT_MODE)
# There shouldn't be any errors yet.
self.assertEqual(len(ray.error_info()), 0)
@@ -79,7 +85,7 @@ assert len(ray.error_info()) == 1
assert "{}" in ray.error_info()[0][b"message"].decode("ascii")
print("success")
""".format(redis_address, error_string2, error_string2)
""".format(self.redis_address, error_string2, error_string2)
# Save the driver script as a file so we can call it using subprocess.
with tempfile.NamedTemporaryFile() as f:
@@ -95,7 +101,53 @@ print("success")
self.assertIn(error_string1, ray.error_info()[0][b"message"].decode("ascii"))
ray.worker.cleanup()
subprocess.Popen([stop_ray_script]).wait()
def testRemoteFunctionIsolation(self):
# This test will run multiple remote functions with the same names in two
# different drivers.
# Connect a driver to the Ray cluster.
ray.init(redis_address=self.redis_address, driver_mode=ray.SILENT_MODE)
# Start another driver and make sure that it can define and call its own
# remote functions with the same names. The script's f and g return 3 and 4
# (and its g takes two arguments), whereas the f and g defined further down
# in this driver return 1 and 2 (and g takes one argument).
driver_script = """
import ray
import time
ray.init(redis_address="{}")
@ray.remote
def f():
return 3
@ray.remote
def g(x, y):
return 4
for _ in range(10000):
result = ray.get([f.remote(), g.remote(0, 0)])
assert result == [3, 4]
print("success")
""".format(self.redis_address)
# Save the driver script as a file so we can call it using subprocess.
with tempfile.NamedTemporaryFile() as f:
f.write(driver_script.encode("ascii"))
# Flush so the child interpreter sees the full script on disk.
f.flush()
# check_output blocks until the script exits and raises if its exit status
# is non-zero, so a failed assert in the other driver fails this test here.
out = subprocess.check_output(["python", f.name]).decode("ascii")
# NOTE(review): the name f is reused below -- first for the temporary file,
# then for the remote function.
@ray.remote
def f():
return 1
@ray.remote
def g(x):
return 2
# Call this driver's own f and g and check that we get this driver's return
# values, not the ones from the other driver's functions of the same names.
for _ in range(10000):
result = ray.get([f.remote(), g.remote(0)])
self.assertEqual(result, [1, 2])
# Make sure the other driver succeeded.
self.assertIn("success", out)
ray.worker.cleanup()
class StartRayScriptTest(unittest.TestCase):
+56
View File
@@ -599,6 +599,62 @@ class APITest(unittest.TestCase):
ray.worker.cleanup()
def testIdenticalFunctionNames(self):
# Define a bunch of remote functions and make sure that we don't
# accidentally call an older version.
ray.init(num_workers=2)
# NOTE(review): num_remote_functions is assigned but never read in this test.
num_remote_functions = 100
num_calls = 200
# Part 1: redefine f five times, submitting num_calls tasks after each
# definition. The results are only fetched after all five definitions have
# been submitted, and each batch must return the value of the definition
# that was current when that batch was submitted.
@ray.remote
def f():
return 1
results1 = [f.remote() for _ in range(num_calls)]
@ray.remote
def f():
return 2
results2 = [f.remote() for _ in range(num_calls)]
@ray.remote
def f():
return 3
results3 = [f.remote() for _ in range(num_calls)]
@ray.remote
def f():
return 4
results4 = [f.remote() for _ in range(num_calls)]
@ray.remote
def f():
return 5
results5 = [f.remote() for _ in range(num_calls)]
self.assertEqual(ray.get(results1), num_calls * [1])
self.assertEqual(ray.get(results2), num_calls * [2])
self.assertEqual(ray.get(results3), num_calls * [3])
self.assertEqual(ray.get(results4), num_calls * [4])
self.assertEqual(ray.get(results5), num_calls * [5])
# Part 2: redefine g five times without calling it in between. Every call
# made afterwards must run the last definition (the one returning 5).
@ray.remote
def g():
return 1
@ray.remote
def g():
return 2
@ray.remote
def g():
return 3
@ray.remote
def g():
return 4
@ray.remote
def g():
return 5
result_values = ray.get([g.remote() for _ in range(num_calls)])
self.assertEqual(result_values, num_calls * [5])
ray.worker.cleanup()
class PythonModeTest(unittest.TestCase):
def testPythonMode(self):