Files
ray/python/ray/experimental/signal.py
T
2019-02-11 10:14:48 -08:00

156 lines
5.5 KiB
Python

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import defaultdict
import ray
import ray.cloudpickle as cloudpickle
class Signal(object):
    """Common base type for every signal exchanged through Ray.

    Subclass this to define application-specific signals; instances are
    serialized with cloudpickle when sent.
    """
class ErrorSignal(Signal):
    """Signal emitted when a task or an actor method raises an exception.

    Attributes:
        error: the error object handed to the constructor.
    """

    def __init__(self, error):
        # Stored as-is so that receivers can inspect what went wrong.
        self.error = error
def _get_task_id(source):
    """Return the task id associated with a generic signal source.

    Args:
        source: source of the signal; it can be an object id, a task id,
            or an actor handle.

    Returns:
        - If source is an actor handle, the task id computed from the
          actor's creation dummy object id (i.e. the actor creation task).
        - If source is a task id, that same task id.
        - Otherwise (an object id), the id of the task that created the
          object.
    """
    if isinstance(source, ray.actor.ActorHandle):
        # Actors send signals under their actor creation task id (see
        # send()), so resolve the handle to that task.
        return ray._raylet.compute_task_id(
            source._ray_actor_creation_dummy_object_id)
    if isinstance(source, ray.TaskID):
        return source
    # Anything else is treated as an object id.
    return ray._raylet.compute_task_id(source)
def send(signal):
    """Publish a signal from the calling task or actor.

    The signal is cloudpickled, hex-encoded and appended (XADD) to a Redis
    stream whose name is the hex id of the sender: the actor creation task
    id when called from an actor worker, otherwise the current task id.
    Redis assigns the entry id ("*"), which receive() later uses to track
    which signals have already been read.

    Args:
        signal: Signal to be sent.
    """
    worker = ray.worker.global_worker
    if not hasattr(worker, "actor_creation_task_id"):
        # No actors; this function must have been called from a task
        stream_id = worker.current_task_id
    else:
        stream_id = worker.actor_creation_task_id
    stream_name = stream_id.hex()
    payload = ray.utils.binary_to_hex(cloudpickle.dumps(signal))
    command = " ".join(["XADD", stream_name, "*", "signal", payload])
    worker.redis_client.execute_command(command)
def receive(sources, timeout=10**12):
    """Get all outstanding signals from sources.

    A source can be either (1) an object id returned by the task we want
    to receive signals from, or (2) an actor handle.

    For each source S, this function returns all signals associated with S
    since the last receive() or forget() was invoked on S. If this is the
    first call on S, this function returns all past signals generated by S
    so far.

    Args:
        sources: list of sources from which the caller waits for signals.
            A source is either an object id identifying the task returning
            the object, or an actor handle.
        timeout: Maximum time (in seconds) to wait for a signal from any
            source in sources. If no signal arrives before it expires, an
            empty list is returned. A timeout of 0 polls without blocking.

    Returns:
        The list of signals generated by the sources, as (source, signal)
        pairs. There can be more than one signal for the same source.

    Raises:
        ValueError: if timeout is negative.
    """
    if timeout < 0:
        raise ValueError("The 'timeout' argument cannot be negative.")

    if not hasattr(ray.worker.global_worker, "signal_counters"):
        ray.worker.global_worker.signal_counters = defaultdict(lambda: b"0")
    signal_counters = ray.worker.global_worker.signal_counters

    # With no sources there is nothing to read (and the query below would
    # be malformed).
    if not sources:
        return []

    # Group sources by the stream (hex task id) they read from. Several
    # sources may resolve to the same task, and Redis only returns streams
    # that have new entries, so replies must be matched by stream name
    # rather than by position.
    stream_keys = []  # hex task ids in query order, without duplicates
    stream_task_ids = {}  # hex task id -> task id (signal_counters key)
    stream_sources = {}  # hex task id -> list of sources on that stream
    for source in sources:
        task_id = _get_task_id(source)
        key = task_id.hex()
        if key not in stream_sources:
            stream_keys.append(key)
            stream_task_ids[key] = task_id
            stream_sources[key] = []
        stream_sources[key].append(source)

    # Redis expects an integer number of milliseconds. "BLOCK 0" means
    # "block forever" in Redis, so a zero timeout is sent as a plain
    # non-blocking XREAD instead.
    timeout_ms = int(1000 * timeout)
    query = "XREAD "
    if timeout_ms > 0:
        query += "BLOCK " + str(timeout_ms) + " "
    query += "STREAMS "
    query += " ".join(stream_keys)
    query += " "
    query += " ".join(
        ray.utils.decode(signal_counters[stream_task_ids[key]])
        for key in stream_keys)

    # NOTE(review): the command is passed as one string, as in send(); the
    # reply shape decoded below may depend on how redis-py handles that, so
    # verify before splitting it into separate arguments.
    answers = ray.worker.global_worker.redis_client.execute_command(query)
    if not answers:
        return []

    results = []
    for answer in answers:
        # answer[0] is the stream name (hex task id), answer[1] its entries.
        key = ray.utils.decode(answer[0])
        task_id = stream_task_ids[key]
        for entry in answer[1]:
            # entry[0] is the Redis stream entry id; remember it so that the
            # next receive() on this stream starts after it.
            signal_counters[task_id] = entry[0]
            # entry[1] is a flat [field, value] list; send() writes a single
            # field, "signal", whose value is the encoded signal.
            signal = cloudpickle.loads(ray.utils.hex_to_binary(entry[1][1]))
            for source in stream_sources[key]:
                results.append((source, signal))
    return results
def forget(sources):
    """Discard every pending signal from each source in sources.

    After this call, the next receive() on a source S only returns signals
    that S sends from now on; anything S sent earlier is skipped.

    Args:
        sources: list of sources whose past signals are forgotten.
    """
    # Reading (and throwing away) everything pending advances the read
    # positions that receive() keeps per source.
    # NOTE(review): relies on receive() treating timeout=0 as a
    # non-blocking poll.
    receive(sources=sources, timeout=0)
def reset():
    """Drop the per-source read positions this worker has accumulated.

    The next receive() on any source behaves like a first call and returns
    every signal that source has sent so far.
    """
    def _start_of_stream():
        # Stream entry id "0" sorts before every real entry, so reading
        # after it yields the whole stream.
        return b"0"

    worker = ray.worker.global_worker
    worker.signal_counters = defaultdict(_start_of_stream)