Files
ray/python/ray/tests/test_async.py
T
2020-07-10 17:49:04 +08:00

125 lines
3.1 KiB
Python

import asyncio
import sys
import time
import numpy as np
import pytest
import ray
@pytest.fixture
def init():
    """Start a 4-CPU Ray instance for one test and shut it down afterwards.

    Debug mode is switched off on the current asyncio event loop before the
    test body runs.
    """
    ray.init(num_cpus=4)
    try:
        # Anything that fails from here on (including fetching the event
        # loop) must still release Ray, otherwise later tests would find it
        # already initialized.
        asyncio.get_event_loop().set_debug(False)
        yield
    finally:
        ray.shutdown()
def gen_tasks(time_scale=0.1):
    """Submit five remote tasks and return their object refs.

    Task ``i`` (for ``i`` in ``0..4``) sleeps ``i * time_scale`` seconds and
    then returns the tuple ``(i, array)``, where ``array`` is a zero-filled
    ``uint8`` NumPy array of ``1024 * 1024`` elements.

    Args:
        time_scale: Multiplier applied to the task index to obtain the
            sleep duration in seconds.

    Returns:
        A list with the results of the five ``.remote()`` calls, in
        submission order.
    """

    @ray.remote
    def f(index):
        delay = index * time_scale
        time.sleep(delay)
        payload = np.zeros(1024 * 1024, dtype=np.uint8)
        return index, payload

    return [f.remote(task_index) for task_index in range(5)]
def test_simple(init):
    """A single remote task can be driven to completion via ``as_future()``."""

    @ray.remote
    def f():
        time.sleep(1)
        return np.zeros(1024 * 1024, dtype=np.uint8)

    # Convert the remote call straight into a future, then block the event
    # loop on it.
    pending = f.remote().as_future()
    event_loop = asyncio.get_event_loop()
    outcome = event_loop.run_until_complete(pending)

    assert isinstance(outcome, np.ndarray)
def test_gather(init):
    """``asyncio.gather`` over Ray futures agrees with ``ray.get``.

    Only the first element of each result tuple (the task index) is compared.
    """
    event_loop = asyncio.get_event_loop()
    object_refs = gen_tasks()
    awaitables = [ref.as_future() for ref in object_refs]

    gathered = event_loop.run_until_complete(asyncio.gather(*awaitables))
    expected = ray.get(object_refs)

    assert all(got[0] == want[0] for got, want in zip(gathered, expected))
def test_wait(init):
    """``asyncio.wait`` without a timeout reports every Ray future as done."""
    event_loop = asyncio.get_event_loop()
    object_refs = gen_tasks()
    awaitables = [ref.as_future() for ref in object_refs]

    wait_all = asyncio.wait(awaitables)
    done, _ = event_loop.run_until_complete(wait_all)

    assert set(done) == set(awaitables)
def test_wait_timeout(init):
    """``asyncio.wait`` with a timeout returns only the futures that finished.

    With ``gen_tasks(10)`` task ``i`` sleeps ``10 * i`` seconds, so within the
    5-second timeout only task 0 (zero sleep) is expected to complete.
    """
    loop = asyncio.get_event_loop()
    tasks = gen_tasks(10)
    futures = [obj_ref.as_future() for obj_ref in tasks]
    fut = asyncio.wait(futures, timeout=5)
    done, _ = loop.run_until_complete(fut)
    # ``done`` is an unordered set: compare it as a whole rather than
    # indexing into it, so an unexpected extra completion fails the test
    # deterministically instead of depending on set iteration order.
    assert set(done) == {futures[0]}
def test_gather_mixup(init):
    """``asyncio.gather`` accepts Ray futures and plain coroutines together.

    Results must come back in submission order: 1, 2, 3, 4.
    """
    event_loop = asyncio.get_event_loop()

    @ray.remote
    def f(n):
        # Blocking sleep: this body is submitted through ``.remote()``.
        time.sleep(n * 0.1)
        return n, np.zeros(1024 * 1024, dtype=np.uint8)

    async def g(n):
        # Cooperative sleep: this coroutine runs on the local event loop.
        await asyncio.sleep(n * 0.1)
        return n, np.zeros(1024 * 1024, dtype=np.uint8)

    mixed = [
        f.remote(1).as_future(),
        g(2),
        f.remote(3).as_future(),
        g(4),
    ]
    outcomes = event_loop.run_until_complete(asyncio.gather(*mixed))

    indices = [outcome[0] for outcome in outcomes]
    assert indices == [1, 2, 3, 4]
def test_wait_mixup(init):
    """``asyncio.wait`` with a timeout over mixed Ray and asyncio futures.

    The four awaitables sleep 0.1 s, 7 s, 5 s and 2 s respectively; with a
    4-second timeout only the first and the last are expected to be ready.
    """
    event_loop = asyncio.get_event_loop()

    @ray.remote
    def f(n):
        time.sleep(n)
        return n, np.zeros(1024 * 1024, dtype=np.uint8)

    async def _sleep_then_return(delay):
        await asyncio.sleep(delay)
        return delay

    def g(n):
        # Schedule the coroutine right away so that the list below holds a
        # future object that can be compared against the "ready" set.
        return asyncio.ensure_future(_sleep_then_return(n))

    tasks = [
        f.remote(0.1).as_future(),
        g(7),
        f.remote(5).as_future(),
        g(2),
    ]
    waiter = asyncio.wait(tasks, timeout=4)
    ready, _ = event_loop.run_until_complete(waiter)

    expected_ready = {tasks[0], tasks[-1]}
    assert set(ready) == expected_ready
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "ray_start_regular_shared",
    [{"object_store_memory": 100 * 1024 * 1024}],
    indirect=True,
)
async def test_garbage_collection(ray_start_regular_shared):
    """Await many large objects against a small object store.

    Each object is a 40 * 1024 * 1024 element ``uint8`` array, and the
    fixture is parametrized with ``object_store_memory`` of 100 * 1024 * 1024.
    Presumably the test can only pass if objects awaited in earlier
    iterations are released.
    """
    # This is a regression test for
    # https://github.com/ray-project/ray/issues/9134

    @ray.remote
    def f():
        return np.zeros(40 * 1024 * 1024, dtype=np.uint8)

    # Phase 1: objects produced as remote task return values.
    for _round in range(10):
        await f.remote()

    # Phase 2: objects placed directly with ``ray.put``.
    for _round in range(10):
        stored_ref = ray.put(np.zeros(40 * 1024 * 1024, dtype=np.uint8))
        await stored_ref
if __name__ == "__main__":
    # Running the file directly hands it to pytest in verbose mode and
    # propagates pytest's return value as the process exit status.
    pytest_args = ["-v", __file__]
    sys.exit(pytest.main(pytest_args))