mirror of
https://github.com/wassname/ray.git
synced 2026-07-03 10:19:18 +08:00
Ray signal (#3624)
This commit is contained in:
@@ -0,0 +1,155 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from collections import defaultdict
|
||||
|
||||
import ray
|
||||
import ray.cloudpickle as cloudpickle
|
||||
|
||||
|
||||
class Signal(object):
    """Base class for Ray signals.

    It carries no state of its own; subclasses add whatever payload they
    need as instance attributes.
    """
|
||||
|
||||
|
||||
class ErrorSignal(Signal):
    """Signal raised if an exception happens in a task or actor method.

    Attributes:
        error: The error payload supplied by the sender, stored unchanged.
            NOTE(review): presumably the exception or its formatted
            traceback; confirm against the call site in the worker.
    """

    def __init__(self, error):
        self.error = error
|
||||
|
||||
|
||||
def _get_task_id(source):
    """Return the task id associated to the generic source of the signal.

    Args:
        source: source of the signal, it can be either an object id, task id,
            or actor handle.

    Returns:
        - If source is an object id, return id of task which created object.
        - If source is an actor handle, return id of actor's task creator.
        - If source is a task id, return same task id.
    """
    if isinstance(source, ray.actor.ActorHandle):
        # The creation task of an actor is derived from the dummy object
        # id recorded on its handle.
        return ray._raylet.compute_task_id(
            source._ray_actor_creation_dummy_object_id)
    if isinstance(source, ray.TaskID):
        return source
    # Anything else is treated as an object id.
    return ray._raylet.compute_task_id(source)
|
||||
|
||||
|
||||
def send(signal):
    """Send signal.

    The signal is pickled, hex-encoded and appended (Redis ``XADD``) to the
    stream whose key is the hex id of the sender: the actor creation task
    id when the current worker has one, otherwise the id of the task that
    is currently running. The entry id is ``*``, i.e. it is assigned by
    Redis.

    Args:
        signal: Signal to be sent.
    """
    worker = ray.worker.global_worker
    if hasattr(worker, "actor_creation_task_id"):
        source_key = worker.actor_creation_task_id.hex()
    else:
        # No actors; this function must have been called from a task
        source_key = worker.current_task_id.hex()

    encoded_signal = ray.utils.binary_to_hex(cloudpickle.dumps(signal))
    # Pass the command pieces as separate arguments rather than one
    # space-joined string, so correctness does not depend on the client
    # re-splitting the command on whitespace.
    worker.redis_client.execute_command("XADD", source_key, "*", "signal",
                                        encoded_signal)
|
||||
|
||||
|
||||
def receive(sources, timeout=10**12):
    """Get all outstanding signals from sources.

    A source can be either (1) an object id returned by the task (we want
    to receive signals from), or (2) an actor handle.

    For each source S, this function returns all signals associated to S
    since the last receive() or forget() were invoked on S. If this is the
    first call on S, this function returns all past signals generated by S
    so far.

    Args:
        sources: list of sources from which caller waits for signals.
            A source is either an object id identifying the task returning
            the object, or an actor handle.
        timeout: Time (in seconds) this function waits to get a signal from
            a source in sources. If no signal arrives, return when the
            timeout expires. A timeout of 0 (or below one millisecond)
            polls once without blocking.

    Returns:
        The list of signals generated by each source in sources.
        This list contains pairs (source, signal). There can be
        more than one signal associated with the same source. The list is
        empty if sources is empty or no signal was available.

    Raises:
        ValueError: If timeout is negative.
    """
    if timeout < 0:
        raise ValueError("The 'timeout' argument cannot be less than 0.")
    if not sources:
        # XREAD needs at least one stream; there is nothing to wait for.
        return []

    worker = ray.worker.global_worker
    if not hasattr(worker, "signal_counters"):
        worker.signal_counters = defaultdict(lambda: b"0")
    signal_counters = worker.signal_counters

    # Resolve every source to its stream key exactly once. Sources that
    # map to the same task share one stream, so the stream is queried once
    # and its signals are reported for each of those sources.
    stream_keys = []
    task_id_by_key = {}
    sources_by_key = {}
    for source in sources:
        task_id = _get_task_id(source)
        key = task_id.hex()
        if key not in sources_by_key:
            stream_keys.append(key)
            task_id_by_key[key] = task_id
            sources_by_key[key] = []
        sources_by_key[key].append(source)

    # Construct the redis query.
    query_parts = ["XREAD"]
    # Multiply by 1000x since timeout is in sec and redis expects an
    # integral number of ms.
    timeout_ms = int(1000 * timeout)
    if timeout_ms > 0:
        query_parts.extend(["BLOCK", str(timeout_ms)])
    # else: leave BLOCK out altogether. "BLOCK 0" means "wait forever" to
    # redis, which is the opposite of what a zero timeout asks for.
    query_parts.append("STREAMS")
    query_parts.extend(stream_keys)
    query_parts.extend(
        ray.utils.decode(signal_counters[task_id_by_key[key]])
        for key in stream_keys)
    query = " ".join(query_parts)

    answers = worker.redis_client.execute_command(query)
    if not answers:
        return []

    results = []
    # Redis only answers for streams that have new entries, so answers are
    # matched to sources by stream key rather than by position.
    for answer in answers:
        key = ray.utils.decode(answer[0])
        task_id = task_id_by_key[key]
        # The list of entries for that stream is stored in answer[1]
        for entry in answer[1]:
            # entry[0] is the redis internal sequence id; remember it so
            # the next receive() resumes after this entry.
            signal_counters[task_id] = entry[0]
            # entry[1] contains a list with elements (key, value), in our
            # case we only have one key "signal" and the value is the
            # signal.
            # NOTE(security): this unpickles whatever is stored in the
            # stream, so it is only safe while the redis instance is
            # trusted.
            signal = cloudpickle.loads(ray.utils.hex_to_binary(entry[1][1]))
            for source in sources_by_key[key]:
                results.append((source, signal))

    return results
|
||||
|
||||
|
||||
def forget(sources):
    """Ignore all previous signals associated with each source S in sources.

    The signals are consumed by calling receive() and discarding what it
    returns, so the next receive() on S will only get the signals generated
    after this function was invoked.

    Args:
        sources: list of sources whose past signals are forgotten.
    """
    # Just read all signals sent by all sources so far.
    # This will result in ignoring these signals.
    receive(sources, timeout=0)
|
||||
|
||||
|
||||
def reset():
    """Reset the worker state associated with received signals.

    Replaces the worker's per-source signal counters with a fresh mapping
    whose default is b"0". If the worker calls receive() on a source next,
    it will get all the signals generated by that source from the start.
    """
    ray.worker.global_worker.signal_counters = defaultdict(lambda: b"0")
|
||||
@@ -23,6 +23,7 @@ import traceback
|
||||
import pyarrow
|
||||
import pyarrow.plasma as plasma
|
||||
import ray.cloudpickle as pickle
|
||||
import ray.experimental.signal as ray_signal
|
||||
import ray.experimental.state as state
|
||||
import ray.gcs_utils
|
||||
import ray.memory_monitor as memory_monitor
|
||||
@@ -879,6 +880,8 @@ class Worker(object):
|
||||
# Mark the actor init as failed
|
||||
if not self.actor_id.is_nil() and function_name == "__init__":
|
||||
self.mark_actor_init_failed(error)
|
||||
# Send signal with the error.
|
||||
ray_signal.send(ray_signal.ErrorSignal(error))
|
||||
|
||||
def _wait_for_and_process_task(self, task):
|
||||
"""Wait for a task to be ready and process the task.
|
||||
@@ -895,6 +898,7 @@ class Worker(object):
|
||||
if not task.actor_creation_id().is_nil():
|
||||
assert self.actor_id.is_nil()
|
||||
self.actor_id = task.actor_creation_id()
|
||||
self.actor_creation_task_id = task.task_id()
|
||||
self.function_actor_manager.load_actor(driver_id,
|
||||
function_descriptor)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user