From 3c32343c6399399841c23806c315b0ef37efeaba Mon Sep 17 00:00:00 2001
From: Ion
Date: Mon, 11 Feb 2019 20:14:48 +0200
Subject: [PATCH] Ray signal (#3624)

---
 .travis.yml                       |   1 +
 python/ray/experimental/signal.py | 179 +++++++++++++++++++
 python/ray/worker.py              |   4 +
 test/test_signal.py               | 281 ++++++++++++++++++++++++++++++
 4 files changed, 465 insertions(+)
 create mode 100644 python/ray/experimental/signal.py
 create mode 100644 test/test_signal.py

diff --git a/.travis.yml b/.travis.yml
index 1c0e2eef5..25812baf2 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -193,6 +193,7 @@ script:
   - python -m pytest -v --durations=10 test/cython_test.py
   - python -m pytest -v --durations=10 test/credis_test.py
   - python -m pytest -v --durations=10 test/node_manager_test.py
+  - python -m pytest -v --durations=10 test/test_signal.py
   # TODO(yuhguo): object_manager_test.py requires a lot of CPU/memory, and
   # better be put in Jenkins. However, it fails frequently in Jenkins, but
   # works well in Travis. We should consider moving it back to Jenkins once
diff --git a/python/ray/experimental/signal.py b/python/ray/experimental/signal.py
new file mode 100644
index 000000000..e70eb7a00
--- /dev/null
+++ b/python/ray/experimental/signal.py
@@ -0,0 +1,179 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from collections import defaultdict
+
+import ray
+import ray.cloudpickle as cloudpickle
+
+
+class Signal(object):
+    """Base class for Ray signals."""
+    pass
+
+
+class ErrorSignal(Signal):
+    """Signal raised if an exception happens in a task or actor method."""
+
+    def __init__(self, error):
+        self.error = error
+
+
+def _get_task_id(source):
+    """Return the task id associated to the generic source of the signal.
+
+    Args:
+        source: source of the signal, it can be either an object id, task id,
+            or actor handle.
+
+    Returns:
+        - If source is an object id, return id of task which created object.
+        - If source is an actor handle, return id of actor's task creator.
+        - If source is a task id, return same task id.
+    """
+    if isinstance(source, ray.actor.ActorHandle):
+        return ray._raylet.compute_task_id(
+            source._ray_actor_creation_dummy_object_id)
+    if isinstance(source, ray.TaskID):
+        return source
+    return ray._raylet.compute_task_id(source)
+
+
+def send(signal):
+    """Send signal.
+
+    The signal has a unique identifier that is computed from (1) the id
+    of the actor or task sending this signal (i.e., the actor or task calling
+    this function), and (2) an index that is incremented every time this
+    source sends a signal. This index starts from 1.
+
+    Args:
+        signal: Signal to be sent.
+    """
+    if hasattr(ray.worker.global_worker, "actor_creation_task_id"):
+        global_worker = ray.worker.global_worker
+        source_key = global_worker.actor_creation_task_id.hex()
+    else:
+        # No actors; this function must have been called from a task
+        source_key = ray.worker.global_worker.current_task_id.hex()
+
+    encoded_signal = ray.utils.binary_to_hex(cloudpickle.dumps(signal))
+    ray.worker.global_worker.redis_client.execute_command(
+        "XADD " + source_key + " * signal " + encoded_signal)
+
+
+def receive(sources, timeout=10**12):
+    """Get all outstanding signals from sources.
+
+    A source can be either (1) an object id returned by the task (we want
+    to receive signals from), or (2) an actor handle.
+
+    For each source S, this function returns all signals associated to S
+    since the last receive() or forget() were invoked on S. If this is the
+    first call on S, this function returns all past signals generated by S
+    so far.
+
+    Args:
+        sources: list of sources from which caller waits for signals.
+            A source is either an object id identifying the task returning
+            the object, or an actor handle.
+        timeout: Time (in seconds) this function waits to get a signal from
+            a source in sources. If no signal arrives, return when the
+            timeout expires. A timeout below 1 ms makes the call
+            non-blocking.
+
+    Returns:
+        The list of signals generated by each source in sources.
+        This list contains pairs (source, signal). There can be
+        more than one signal associated with the same source. The list is
+        empty if sources is empty or if no signal was received.
+    """
+    if not hasattr(ray.worker.global_worker, "signal_counters"):
+        ray.worker.global_worker.signal_counters = defaultdict(lambda: b"0")
+
+    signal_counters = ray.worker.global_worker.signal_counters
+
+    # Group the sources by the redis stream they map to, so that each stream
+    # is queried once even if the same source is listed several times.
+    stream_keys = []
+    task_ids = {}
+    stream_sources = defaultdict(list)
+    for source in sources:
+        task_id = _get_task_id(source)
+        stream_key = task_id.hex()
+        if stream_key not in task_ids:
+            stream_keys.append(stream_key)
+            task_ids[stream_key] = task_id
+        stream_sources[stream_key].append(source)
+    if not stream_keys:
+        # XREAD needs at least one stream to read from.
+        return []
+
+    # Construct the redis query.
+    query = "XREAD "
+    # Multiply by 1000x since timeout is in sec and redis expects an integer
+    # number of ms.
+    timeout_ms = int(1000 * timeout)
+    if timeout_ms > 0:
+        # "BLOCK 0" makes redis block forever, so only block when there is
+        # an actual timeout and do a non-blocking read otherwise.
+        query += "BLOCK " + str(timeout_ms) + " "
+    query += "STREAMS "
+    query += " ".join(stream_keys)
+    query += " "
+    query += " ".join([
+        ray.utils.decode(signal_counters[task_ids[stream_key]])
+        for stream_key in stream_keys
+    ])
+
+    # NOTE(review): the command is sent as one pre-formatted string; check
+    # how the redis client parses replies before switching to separate
+    # arguments, since the decoding below relies on the raw reply layout.
+    answers = ray.worker.global_worker.redis_client.execute_command(query)
+    if not answers:
+        return []
+
+    results = []
+    # Redis only reports the streams that have new entries, so match every
+    # answer to its sources by stream key rather than by position.
+    for answer in answers:
+        stream_key = ray.utils.decode(answer[0])
+        task_id = task_ids[stream_key]
+        # The list of results for that stream is stored in answer[1]
+        for r in answer[1]:
+            # Now it gets tricky: r[0] is the redis internal sequence id
+            signal_counters[task_id] = r[0]
+            # r[1] contains a list with elements (key, value), in our case
+            # we only have one key "signal" and the value is the signal.
+            signal = cloudpickle.loads(ray.utils.hex_to_binary(r[1][1]))
+            for source in stream_sources[stream_key]:
+                results.append((source, signal))
+
+    return results
+
+
+def forget(sources):
+    """Ignore all previous signals associated with each source S in sources.
+
+    The index of the next expected signal from S is set to the index of
+    the last signal that S sent plus 1. This means that the next receive()
+    on S will only get the signals generated after this function was invoked.
+
+    Args:
+        sources: list of sources whose past signals are forgotten.
+    """
+    # Just read all signals sent by all sources so far.
+    # This will result in ignoring these signals.
+    receive(sources, timeout=0)
+
+
+def reset():
+    """
+    Reset the worker state associated with any signals that this worker
+    has received so far.
+
+    If the worker calls receive() on a source next, it will get all the
+    signals generated by that source starting with index = 1.
+    """
+    ray.worker.global_worker.signal_counters = defaultdict(lambda: b"0")
diff --git a/python/ray/worker.py b/python/ray/worker.py
index f66616b38..aef0b7ae9 100644
--- a/python/ray/worker.py
+++ b/python/ray/worker.py
@@ -23,6 +23,7 @@ import traceback
 import pyarrow
 import pyarrow.plasma as plasma
 import ray.cloudpickle as pickle
+import ray.experimental.signal as ray_signal
 import ray.experimental.state as state
 import ray.gcs_utils
 import ray.memory_monitor as memory_monitor
@@ -879,6 +880,8 @@ class Worker(object):
         # Mark the actor init as failed
         if not self.actor_id.is_nil() and function_name == "__init__":
             self.mark_actor_init_failed(error)
+        # Send signal with the error.
+        ray_signal.send(ray_signal.ErrorSignal(error))

     def _wait_for_and_process_task(self, task):
         """Wait for a task to be ready and process the task.
@@ -895,6 +898,7 @@ class Worker(object):
         if not task.actor_creation_id().is_nil():
             assert self.actor_id.is_nil()
             self.actor_id = task.actor_creation_id()
+            self.actor_creation_task_id = task.task_id()
             self.function_actor_manager.load_actor(driver_id,
                                                    function_descriptor)

diff --git a/test/test_signal.py b/test/test_signal.py
new file mode 100644
index 000000000..e86812c8e
--- /dev/null
+++ b/test/test_signal.py
@@ -0,0 +1,281 @@
+import pytest
+
+import ray
+import ray.experimental.signal as signal
+
+
+class UserSignal(signal.Signal):
+    def __init__(self, value):
+        self.value = value
+
+
+@pytest.fixture
+def ray_start():
+    # Start the Ray processes.
+    ray.init(num_cpus=4)
+    yield None
+    # The code after the yield will run as teardown code.
+    ray.shutdown()
+
+
+def receive_all_signals(sources, timeout):
+    # Get all signals from sources, until there is no signal for a time
+    # period of timeout.
+
+    results = []
+    while True:
+        r = signal.receive(sources, timeout=timeout)
+        if len(r) == 0:
+            return results
+        else:
+            results.extend(r)
+
+
+def test_task_to_driver(ray_start):
+    # Send a signal from a task to the driver.
+
+    @ray.remote
+    def task_send_signal(value):
+        signal.send(UserSignal(value))
+        return
+
+    signal_value = "simple signal"
+    object_id = task_send_signal.remote(signal_value)
+    result_list = signal.receive([object_id], timeout=10)
+    assert len(result_list) == 1
+    assert result_list[0][1].value == signal_value
+
+
+def test_send_signal_from_actor_to_driver(ray_start):
+    # Send several signals from an actor, and receive them in the driver.
+
+    @ray.remote
+    class ActorSendSignal(object):
+        def __init__(self):
+            pass
+
+        def send_signal(self, value):
+            signal.send(UserSignal(value))
+
+    a = ActorSendSignal.remote()
+    signal_value = "simple signal"
+    count = 6
+    for i in range(count):
+        ray.get(a.send_signal.remote(signal_value + str(i)))
+
+    result_list = receive_all_signals([a], timeout=5)
+    assert len(result_list) == count
+    for i in range(count):
+        assert signal_value + str(i) == result_list[i][1].value
+
+
+def test_send_signals_from_actor_to_driver(ray_start):
+    # Send "count" signals from an actor without waiting for the call to
+    # finish, and get these signals in the driver.
+
+    @ray.remote
+    class ActorSendSignals(object):
+        def __init__(self):
+            pass
+
+        def send_signals(self, value, count):
+            for i in range(count):
+                signal.send(UserSignal(value + str(i)))
+
+    a = ActorSendSignals.remote()
+    signal_value = "simple signal"
+    count = 20
+    a.send_signals.remote(signal_value, count)
+    received_count = 0
+    while received_count < count:
+        result_list = signal.receive([a], timeout=5)
+        if not result_list:
+            break  # Timed out: fail below instead of looping forever.
+        received_count += len(result_list)
+    assert received_count == count
+
+
+def test_task_crash(ray_start):
+    # Get an error when ray.get() is called on the return of a failed task.
+
+    @ray.remote
+    def crashing_function():
+        raise Exception("exception message")
+
+    object_id = crashing_function.remote()
+    # The failed task must raise through ray.get() ...
+    with pytest.raises(ray.worker.RayTaskError):
+        ray.get(object_id)
+
+    # ... and must also publish the error as a signal.
+    result_list = signal.receive([object_id], timeout=5)
+    assert len(result_list) == 1
+    assert isinstance(result_list[0][1], signal.ErrorSignal)
+
+
+def test_task_crash_without_get(ray_start):
+    # Get an error when task failed.
+
+    @ray.remote
+    def crashing_function():
+        raise Exception("exception message")
+
+    object_id = crashing_function.remote()
+    result_list = signal.receive([object_id], timeout=5)
+    assert len(result_list) == 1
+    assert isinstance(result_list[0][1], signal.ErrorSignal)
+
+
+def test_actor_crash(ray_start):
+    # Get an error when ray.get() is called on a return parameter
+    # of a method that failed.
+
+    @ray.remote
+    class Actor(object):
+        def __init__(self):
+            pass
+
+        def crash(self):
+            raise Exception("exception message")
+
+    a = Actor.remote()
+    # The failed method must raise through ray.get() ...
+    with pytest.raises(ray.worker.RayTaskError):
+        ray.get(a.crash.remote())
+
+    # ... and must also publish the error as a signal.
+    result_list = signal.receive([a], timeout=5)
+    assert len(result_list) == 1
+    assert isinstance(result_list[0][1], signal.ErrorSignal)
+
+
+def test_actor_crash_init(ray_start):
+    # Get an error when an actor's __init__ failed.
+
+    @ray.remote
+    class ActorCrashInit(object):
+        def __init__(self):
+            raise Exception("exception message")
+
+        def m(self):
+            return 1
+
+    # Do not catch the exception in the __init__.
+    a = ActorCrashInit.remote()
+    result_list = signal.receive([a], timeout=5)
+    assert len(result_list) == 1
+    assert isinstance(result_list[0][1], signal.ErrorSignal)
+
+
+def test_actor_crash_init2(ray_start):
+    # Get errors when (1) __init__ fails, and (2) subsequently when
+    # ray.get() is called on the return parameter of another method
+    # of the actor.
+
+    @ray.remote
+    class ActorCrashInit(object):
+        def __init__(self):
+            raise Exception("exception message")
+
+        def method(self):
+            return 1
+
+    a = ActorCrashInit.remote()
+    # The method call on the broken actor must raise through ray.get() ...
+    with pytest.raises(ray.worker.RayTaskError):
+        ray.get(a.method.remote())
+
+    # ... and both failures must have been published as signals.
+    result_list = receive_all_signals([a], timeout=5)
+    assert len(result_list) == 2
+    assert isinstance(result_list[0][1], signal.ErrorSignal)
+
+
+def test_actor_crash_init3(ray_start):
+    # Get errors when (1) __init__ fails, and (2) subsequently when
+    # another method of the actor is invoked.
+
+    @ray.remote
+    class ActorCrashInit(object):
+        def __init__(self):
+            raise Exception("exception message")
+
+        def method(self):
+            return 1
+
+    a = ActorCrashInit.remote()
+    a.method.remote()
+    result_list = signal.receive([a], timeout=10)
+    assert len(result_list) == 1
+    assert isinstance(result_list[0][1], signal.ErrorSignal)
+
+
+def test_send_signals_from_actor_to_actor(ray_start):
+    # Send "count" signals from each of two actors and get
+    # these signals in another actor.
+
+    @ray.remote
+    class ActorSendSignals(object):
+        def __init__(self):
+            pass
+
+        def send_signals(self, value, count):
+            for i in range(count):
+                signal.send(UserSignal(value + str(i)))
+
+    @ray.remote
+    class ActorGetSignalsAll(object):
+        def __init__(self):
+            self.received_signals = []
+
+        def register_handle(self, handle):
+            self.this_actor = handle
+
+        def get_signals(self, source_ids, count):
+            new_signals = receive_all_signals(source_ids, timeout=5)
+            for s in new_signals:
+                self.received_signals.append(s)
+            if len(self.received_signals) < count:
+                self.this_actor.get_signals.remote(source_ids, count)
+            else:
+                return
+
+        def get_count(self):
+            return len(self.received_signals)
+
+    a1 = ActorSendSignals.remote()
+    a2 = ActorSendSignals.remote()
+    signal_value = "simple signal"
+    count = 20
+    ray.get(a1.send_signals.remote(signal_value, count))
+    ray.get(a2.send_signals.remote(signal_value, count))
+
+    b = ActorGetSignalsAll.remote()
+    ray.get(b.register_handle.remote(b))
+    b.get_signals.remote([a1, a2], count)
+    received_count = ray.get(b.get_count.remote())
+    assert received_count == 2 * count
+
+
+def test_forget(ray_start):
+    # Send "count" signals on behalf of an actor, then ignore all these
+    # signals, and then send another "count" signals on behalf of the same
+    # actor. Then show that the driver only gets the last "count" signals.
+
+    @ray.remote
+    class ActorSendSignals(object):
+        def __init__(self):
+            pass
+
+        def send_signals(self, value, count):
+            for i in range(count):
+                signal.send(UserSignal(value + str(i)))
+
+    a = ActorSendSignals.remote()
+    signal_value = "simple signal"
+    count = 5
+    ray.get(a.send_signals.remote(signal_value, count))
+    signal.forget([a])
+    ray.get(a.send_signals.remote(signal_value, count))
+    result_list = receive_all_signals([a], timeout=5)
+    assert len(result_list) == count