Files
ray/python/ray/experimental/test/async_test.py
T
2020-02-19 19:01:26 -08:00

113 lines
2.5 KiB
Python

import asyncio
import time
import os
import pytest
import ray
from ray.experimental import async_api
@pytest.fixture
def init():
os.environ["RAY_FORCE_DIRECT"] = "0"
ray.init(num_cpus=4)
async_api.init()
asyncio.get_event_loop().set_debug(False)
yield
async_api.shutdown()
ray.shutdown()
def gen_tasks(time_scale=0.1):
@ray.remote
def f(n):
time.sleep(n * time_scale)
return n
tasks = [f.remote(i) for i in range(5)]
return tasks
def test_simple(init):
@ray.remote
def f():
time.sleep(1)
return {"key1": ["value"]}
future = async_api.as_future(f.remote())
result = asyncio.get_event_loop().run_until_complete(future)
assert result["key1"] == ["value"]
def test_gather(init):
loop = asyncio.get_event_loop()
tasks = gen_tasks()
futures = [async_api.as_future(obj_id) for obj_id in tasks]
results = loop.run_until_complete(asyncio.gather(*futures))
assert all(a == b for a, b in zip(results, ray.get(tasks)))
def test_wait(init):
loop = asyncio.get_event_loop()
tasks = gen_tasks()
futures = [async_api.as_future(obj_id) for obj_id in tasks]
results, _ = loop.run_until_complete(asyncio.wait(futures))
assert set(results) == set(futures)
def test_wait_timeout(init):
loop = asyncio.get_event_loop()
tasks = gen_tasks(10)
futures = [async_api.as_future(obj_id) for obj_id in tasks]
fut = asyncio.wait(futures, timeout=5)
results, _ = loop.run_until_complete(fut)
assert list(results)[0] == futures[0]
def test_gather_mixup(init):
loop = asyncio.get_event_loop()
@ray.remote
def f(n):
time.sleep(n * 0.1)
return n
async def g(n):
await asyncio.sleep(n * 0.1)
return n
tasks = [
async_api.as_future(f.remote(1)),
g(2),
async_api.as_future(f.remote(3)),
g(4)
]
results = loop.run_until_complete(asyncio.gather(*tasks))
assert results == [1, 2, 3, 4]
def test_wait_mixup(init):
loop = asyncio.get_event_loop()
@ray.remote
def f(n):
time.sleep(n)
return n
def g(n):
async def _g(_n):
await asyncio.sleep(_n)
return _n
return asyncio.ensure_future(_g(n))
tasks = [
async_api.as_future(f.remote(0.1)),
g(7),
async_api.as_future(f.remote(5)),
g(2)
]
ready, _ = loop.run_until_complete(asyncio.wait(tasks, timeout=4))
assert set(ready) == {tasks[0], tasks[-1]}