mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 03:50:55 +08:00
261 lines
6.7 KiB
Python
261 lines
6.7 KiB
Python
# coding: utf-8
|
|
import asyncio
|
|
import sys
|
|
import threading
|
|
|
|
import pytest
|
|
|
|
import ray
|
|
from ray.test_utils import SignalActor, wait_for_condition
|
|
|
|
|
|
def test_asyncio_actor(ray_start_regular_shared):
    """Async actor methods can suspend until enough calls have arrived.

    Each ``add`` call waits on an ``asyncio.Event`` until the batch holds
    three items; the call that completes the batch sets the event, and every
    caller returns the same sorted batch.
    """

    @ray.remote
    class AsyncBatcher:
        def __init__(self):
            self.batch = []
            self.event = asyncio.Event()

        async def add(self, x):
            self.batch.append(x)
            # Early arrivals park on the event; the third one releases all.
            if len(self.batch) < 3:
                await self.event.wait()
            else:
                self.event.set()
            return sorted(self.batch)

    batcher = AsyncBatcher.remote()
    pending = [batcher.add.remote(value) for value in (1, 2, 3)]
    first, second, third = [ray.get(ref) for ref in pending]
    assert first == [1, 2, 3]
    assert first == second == third
|
|
|
|
|
|
def test_asyncio_actor_same_thread(ray_start_regular_shared):
    """Sync and async methods of one actor report the same thread ident."""

    @ray.remote
    class Actor:
        def sync_thread_id(self):
            current = threading.current_thread()
            return current.ident

        async def async_thread_id(self):
            current = threading.current_thread()
            return current.ident

    actor = Actor.remote()
    refs = [actor.sync_thread_id.remote(), actor.async_thread_id.remote()]
    sync_id, async_id = ray.get(refs)
    assert sync_id == async_id
|
|
|
|
|
|
def test_asyncio_actor_concurrency(ray_start_regular_shared):
    """With ``max_concurrency=1``, async calls on the actor do not interleave.

    Each ``do_work`` yields to the event loop between its two history
    entries. If a second call could start at that point, the history would
    show two consecutive "STARTED" entries.
    """

    @ray.remote
    class RecordOrder:
        def __init__(self):
            self.history = []

        async def do_work(self):
            self.history.append("STARTED")
            # Force a context switch
            await asyncio.sleep(0)
            self.history.append("ENDED")

        def get_history(self):
            return self.history

    num_calls = 10

    recorder = RecordOrder.options(max_concurrency=1).remote()
    ray.get([recorder.do_work.remote() for _ in range(num_calls)])
    history = ray.get(recorder.get_history.remote())

    # Only the strict STARTED/ENDED alternation is checked, because
    # coroutines may be executed out of enqueued order.
    expected = ["STARTED", "ENDED"] * num_calls
    assert history == expected
|
|
|
|
|
|
def test_asyncio_actor_high_concurrency(ray_start_regular_shared):
    """An async actor handles more pending calls than the recursion limit.

    ``batch_size`` is four times ``sys.getrecursionlimit()``. Every ``add``
    waits on one event until the final call arrives, so all calls are
    pending at the same time.
    """

    @ray.remote
    class AsyncConcurrencyBatcher:
        def __init__(self, batch_size):
            self.batch = []
            self.event = asyncio.Event()
            self.batch_size = batch_size

        async def add(self, x):
            self.batch.append(x)
            if len(self.batch) < self.batch_size:
                await self.event.wait()
            else:
                self.event.set()
            return sorted(self.batch)

    batch_size = sys.getrecursionlimit() * 4
    # The cap is above batch_size, so every call in the batch can be in
    # flight at once.
    concurrency_cap = batch_size * 2
    batcher = AsyncConcurrencyBatcher.options(
        max_concurrency=concurrency_cap).remote(batch_size)

    results = ray.get([batcher.add.remote(i) for i in range(batch_size)])
    expected = list(range(batch_size))
    assert results[0] == expected
    assert results[-1] == expected
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_asyncio_get(ray_start_regular_shared, event_loop):
    """Task and actor results can be awaited through asyncio.

    Covers a plain task result, task failure, small and large actor results
    (the large one exceeds the direct-call size noted below), actor method
    failure, and calls on a killed actor.
    """
    loop = event_loop
    asyncio.set_event_loop(loop)
    loop.set_debug(True)

    # Test Async Plasma
    @ray.remote
    def task():
        return 1

    assert await task.remote().as_future() == 1

    @ray.remote
    def task_throws():
        1 / 0

    with pytest.raises(ray.exceptions.RayTaskError):
        await task_throws.remote().as_future()

    # Test actor calls. The payload is sized above the 100Kb direct-call
    # limit mentioned in big_object().
    str_len = 200 * 1024

    @ray.remote
    class Actor:
        def echo(self, i):
            return i

        def big_object(self):
            # 100Kb is the limit for direct call
            return "a" * (str_len)

        def throw_error(self):
            1 / 0

    actor = Actor.remote()

    small_future = actor.echo.remote(2).as_future()
    echoed = await small_future
    assert echoed == 2

    large_future = actor.big_object.remote().as_future()
    large_value = await large_future
    assert large_value == "a" * str_len

    with pytest.raises(ray.exceptions.RayTaskError):
        await actor.throw_error.remote().as_future()

    # Awaiting a call on a killed actor must raise RayActorError.
    ray.kill(actor)
    with pytest.raises(ray.exceptions.RayActorError):
        await actor.echo.remote(1)
|
|
|
|
|
|
def test_asyncio_actor_async_get(ray_start_regular_shared):
    """An async actor method can ``await`` ObjectRefs directly.

    The ``ray.put`` ref is passed wrapped in a list, and the method awaits
    element 0. This presumably relies on Ray not dereferencing refs nested
    inside arguments; confirm against the Ray object-ref docs.
    """

    @ray.remote
    def remote_task():
        return 1

    @ray.remote
    class AsyncGetter:
        async def get(self):
            task_ref = remote_task.remote()
            return await task_ref

        async def plasma_get(self, plasma_object):
            nested_ref = plasma_object[0]
            return await nested_ref

    plasma_object = ray.put(2)
    getter = AsyncGetter.remote()

    task_result = ray.get(getter.get.remote())
    assert task_result == 1

    put_result = ray.get(getter.plasma_get.remote([plasma_object]))
    assert put_result == 2
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_asyncio_double_await(ray_start_regular_shared):
    """A ref can still be awaited after an as_future() wait on it timed out.

    Regression test for https://github.com/ray-project/ray/issues/8841
    """
    signal = SignalActor.remote()
    waiting = signal.wait.remote()

    # Let one wrapped future time out; the test asserts it ends up cancelled.
    future = waiting.as_future()
    with pytest.raises(asyncio.TimeoutError):
        await asyncio.wait_for(future, timeout=0.1)
    assert future.cancelled()

    # We are explicitly waiting multiple times here to test asyncio state
    # override.
    await signal.send.remote()
    for _ in range(2):
        await waiting
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_asyncio_exit_actor(ray_start_regular_shared):
    """Calling exit_actor() from an async method terminates the actor.

    Regression test for https://github.com/ray-project/ray/issues/12649.
    Without the fix, the test hangs.
    """

    @ray.remote
    class Actor:
        async def exit(self):
            ray.actor.exit_actor()

        async def ping(self):
            return "pong"

        async def loop_forever(self):
            while True:
                await asyncio.sleep(5)

    actor = Actor.options(max_task_retries=0).remote()

    # This task never finishes. If exit waited for in-flight tasks,
    # the next ray.get would hang.
    actor.loop_forever.remote()
    # Make sure exit_actor exits immediately, not once all tasks completed.
    ray.get(actor.exit.remote())

    # Once the actor is gone, further calls raise RayActorError.
    with pytest.raises(ray.exceptions.RayActorError):
        ray.get(actor.ping.remote())
|
|
|
|
|
|
def test_async_callback(ray_start_regular_shared):
    """ObjectRef._on_completed callbacks run for put refs and task refs."""
    global_set = set()

    def marker(tag):
        # Build a completion callback that records `tag` in global_set.
        return lambda _: global_set.add(tag)

    # Callback registered on a ref produced by ray.put.
    ref = ray.put(None)
    ref._on_completed(marker("completed-1"))
    wait_for_condition(lambda: "completed-1" in global_set)

    signal = SignalActor.remote()

    @ray.remote
    def wait():
        ray.get(signal.wait.remote())

    # The task blocks until the signal is sent, so its callback must not
    # have fired before send() is called.
    ref = wait.remote()
    ref._on_completed(marker("completed-2"))
    assert "completed-2" not in global_set
    signal.send.remote()
    wait_for_condition(lambda: "completed-2" in global_set)
|
|
|
|
|
|
def test_async_function_errored(ray_start_regular_shared):
    """Getting the result of a remote ``async def`` function raises ValueError.

    This covers a plain remote function, not an actor method. Submitting the
    task happens outside the ``raises`` block, so only ``ray.get`` is
    expected to raise.
    """

    @ray.remote
    async def f():
        pass

    pending = f.remote()

    with pytest.raises(ValueError):
        ray.get(pending)
|
|
|
|
|
|
if __name__ == "__main__":
    # Run this module's tests verbosely and exit with pytest's status code.
    # pytest and sys are already imported at module level.
    sys.exit(pytest.main(["-v", __file__]))
|