mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 11:01:06 +08:00
Leave ray.wait calls open until the task or actor exits (#5234)
* Regression test * Split TaskDependencyManager::SubscribeDependencies into ray.get and ray.wait dependencies - Some initial implementation * unit test * Improve unit tests for TaskDependencyManager * Implement SubscribeWaitDependencies and UnsubscribeWaitDependencies, unit tests passing * Add ray.wait python test for drivers that exit early * Add WorkerID to Worker * Update test to use two nodes * Regression test for ray.wait passes * Extend regression test to include ray.wait from an actor * Fix ClientID and WorkerIDs * lint * lint * Remove unnecessary ray_get argument * fix build
This commit is contained in:
@@ -8,6 +8,10 @@ import random
|
||||
import numpy as np
|
||||
import os
|
||||
import pytest
|
||||
try:
|
||||
import pytest_timeout
|
||||
except ImportError:
|
||||
pytest_timeout = None
|
||||
import signal
|
||||
import sys
|
||||
import time
|
||||
@@ -2647,3 +2651,78 @@ def test_decorated_method(ray_start_regular):
|
||||
assert isinstance(object_id, ray.ObjectID)
|
||||
assert extra == {"kwarg": 3}
|
||||
assert ray.get(object_id) == 7 # 2 * 3 + 1
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
pytest_timeout is None,
|
||||
reason="Timeout package not installed; skipping test that may hang.")
|
||||
@pytest.mark.timeout(10)
|
||||
@pytest.mark.parametrize(
|
||||
"ray_start_cluster", [{
|
||||
"num_cpus": 1,
|
||||
"num_nodes": 2,
|
||||
}], indirect=True)
|
||||
def test_ray_wait_dead_actor(ray_start_cluster):
|
||||
"""Tests that methods completed by dead actors are returned as ready"""
|
||||
cluster = ray_start_cluster
|
||||
|
||||
@ray.remote(num_cpus=1)
|
||||
class Actor(object):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def local_plasma(self):
|
||||
return ray.worker.global_worker.plasma_client.store_socket_name
|
||||
|
||||
def ping(self):
|
||||
time.sleep(1)
|
||||
|
||||
# Create some actors and wait for them to initialize.
|
||||
num_nodes = len(cluster.list_all_nodes())
|
||||
actors = [Actor.remote() for _ in range(num_nodes)]
|
||||
ray.get([actor.ping.remote() for actor in actors])
|
||||
|
||||
# Ping the actors and make sure the tasks complete.
|
||||
ping_ids = [actor.ping.remote() for actor in actors]
|
||||
ray.get(ping_ids)
|
||||
# Evict the result from the node that we're about to kill.
|
||||
remote_node = cluster.list_all_nodes()[-1]
|
||||
remote_ping_id = None
|
||||
for i, actor in enumerate(actors):
|
||||
if ray.get(actor.local_plasma.remote()
|
||||
) == remote_node.plasma_store_socket_name:
|
||||
remote_ping_id = ping_ids[i]
|
||||
ray.internal.free([remote_ping_id], local_only=True)
|
||||
cluster.remove_node(remote_node)
|
||||
|
||||
# Repeatedly call ray.wait until the exception for the dead actor is
|
||||
# received.
|
||||
unready = ping_ids[:]
|
||||
while unready:
|
||||
_, unready = ray.wait(unready, timeout=0)
|
||||
time.sleep(1)
|
||||
|
||||
with pytest.raises(ray.exceptions.RayActorError):
|
||||
ray.get(ping_ids)
|
||||
|
||||
# Evict the result from the dead node.
|
||||
ray.internal.free([remote_ping_id], local_only=True)
|
||||
# Create an actor on the local node that will call ray.wait in a loop.
|
||||
head_node_resource = "HEAD_NODE"
|
||||
ray.experimental.set_resource(head_node_resource, 1)
|
||||
|
||||
@ray.remote(num_cpus=0, resources={head_node_resource: 1})
|
||||
class ParentActor(object):
|
||||
def __init__(self, ping_ids):
|
||||
self.unready = ping_ids
|
||||
|
||||
def wait(self):
|
||||
_, self.unready = ray.wait(self.unready, timeout=0)
|
||||
return len(self.unready) == 0
|
||||
|
||||
# Repeatedly call ray.wait through the local actor until the exception for
|
||||
# the dead actor is received.
|
||||
parent_actor = ParentActor.remote(ping_ids)
|
||||
failure_detected = False
|
||||
while not failure_detected:
|
||||
failure_detected = ray.get(parent_actor.wait.remote())
|
||||
|
||||
@@ -411,7 +411,7 @@ def test_driver_exiting_when_worker_blocked(call_ray_start):
|
||||
ray.init(redis_address=redis_address)
|
||||
|
||||
# Define a driver that creates two tasks, one that runs forever and the
|
||||
# other blocked on the first.
|
||||
# other blocked on the first in a `ray.get`.
|
||||
driver_script = """
|
||||
import time
|
||||
import ray
|
||||
@@ -425,6 +425,30 @@ def g():
|
||||
g.remote()
|
||||
time.sleep(1)
|
||||
print("success")
|
||||
""".format(redis_address)
|
||||
|
||||
# Create some drivers and let them exit and make sure everything is
|
||||
# still alive.
|
||||
for _ in range(3):
|
||||
out = run_string_as_driver(driver_script)
|
||||
# Make sure the first driver ran to completion.
|
||||
assert "success" in out
|
||||
|
||||
# Define a driver that creates two tasks, one that runs forever and the
|
||||
# other blocked on the first in a `ray.wait`.
|
||||
driver_script = """
|
||||
import time
|
||||
import ray
|
||||
ray.init(redis_address="{}")
|
||||
@ray.remote
|
||||
def f():
|
||||
time.sleep(10**6)
|
||||
@ray.remote
|
||||
def g():
|
||||
ray.wait([f.remote()])
|
||||
g.remote()
|
||||
time.sleep(1)
|
||||
print("success")
|
||||
""".format(redis_address)
|
||||
|
||||
# Create some drivers and let them exit and make sure everything is
|
||||
@@ -448,6 +472,31 @@ def g(x):
|
||||
g.remote(ray.ObjectID(ray.utils.hex_to_binary("{}")))
|
||||
time.sleep(1)
|
||||
print("success")
|
||||
""".format(redis_address, nonexistent_id_hex)
|
||||
|
||||
# Create some drivers and let them exit and make sure everything is
|
||||
# still alive.
|
||||
for _ in range(3):
|
||||
out = run_string_as_driver(driver_script)
|
||||
# Simulate the nonexistent dependency becoming available.
|
||||
ray.worker.global_worker.put_object(
|
||||
ray.ObjectID(nonexistent_id_bytes), None)
|
||||
# Make sure the first driver ran to completion.
|
||||
assert "success" in out
|
||||
|
||||
nonexistent_id_bytes = _random_string()
|
||||
nonexistent_id_hex = ray.utils.binary_to_hex(nonexistent_id_bytes)
|
||||
# Define a driver that calls `ray.wait` on a nonexistent object.
|
||||
driver_script = """
|
||||
import time
|
||||
import ray
|
||||
ray.init(redis_address="{}")
|
||||
@ray.remote
|
||||
def g():
|
||||
ray.wait(ray.ObjectID(ray.utils.hex_to_binary("{}")))
|
||||
g.remote()
|
||||
time.sleep(1)
|
||||
print("success")
|
||||
""".format(redis_address, nonexistent_id_hex)
|
||||
|
||||
# Create some drivers and let them exit and make sure everything is
|
||||
|
||||
Reference in New Issue
Block a user