mirror of
https://github.com/wassname/ray.git
synced 2026-07-02 07:38:49 +08:00
Experimental asyncio support (#2015)
* Init commit for async plasma client * Create an eventloop model for ray/plasma * Implement a poll-like selector base on `ray.wait`. Huge improvements. * Allow choosing workers & selectors * remove original design * initial implementation of epoll-like selector for plasma * Add a param for `worker` used in `PlasmaSelectorEventLoop` * Allow accepting a `Future` which returns object_id * Do not need `io.py` anymore * Create a basic testing model * fix: `ray.wait` returns tuple of lists * fix a few bugs * improving performance & bug fixing * add test * several improvements & fixing * fix relative import * [async] change code format, remove old files * [async] Create context wrapper for the eventloop * [async] fix: context should return a value * [async] Implement futures grouping * [async] Fix bugs & replace old functions * [async] Fix bugs found in tests * [async] Implement `PlasmaEpoll` * [async] Make test faster, add tests for epoll * [async] Fix code format * [async] Add comments for main code. * [async] Fix import path. * [async] Fix test. * [async] Compatibility. * [async] less verbose to not annoy the CI. * [async] Add test for new API * [async] Allow showing debug info in some of the test. * [async] Fix test. * [async] Proper shutdown. * [async] Lint~ * [async] Move files to experimental and create API * [async] Use async/await syntax * [async] Fix names & styles * [async] comments * [async] bug fixing & use pytest * [async] bug fixing & change tests * [async] use logger * [async] add tests * [async] lint * [async] type checking * [async] add more tests * [async] fix bugs on waiting a future while timeout. Add more docs. * [async] Formal docs. * [async] Add typing info since these codes are compatible with py3.5+. * [async] Documents. * [async] Lint. * [async] Fix deprecated call. * [async] Fix deprecated call. * [async] Implement a more reasonable way for dealing with pending inputs. * [async] Fix docs * [async] Lint * [async] Fix bug: Type for time * [async] Set our eventloop as the default eventloop so that we can get it through `asyncio.get_event_loop()`. * [async] Update test & docs. * [async] Lint. * [async] Temporarily print more debug info. * [async] Use `Poll` as a default option. * [async] Limit resources. * new async implementation for Ray * implement linked list * bug fix * update * support seamless async operations * update * update API * fix tests * lint * bug fix * refactor names * improve doc * properly shutdown async_api * doc * Change the table on the index page. * Adjust table size. * Only keeps `as_future`. * change how we init connection * init connection in `ray.worker.connect` * doc * fix * Move initialization code into the module. * Fix docs & code * Update pyarrow version. * lint * Restore index.rst * Add known issues. * Apply suggestions from code review Co-Authored-By: suquark <suquark@gmail.com> * rename * Update async_api.rst * Update async_api.py * Update async_api.rst * Update async_api.py * Update worker.py * Update async_api.rst * fix tests * lint * lint * replace the magic number
This commit is contained in:
@@ -0,0 +1,62 @@
|
||||
# Note: asyncio is only compatible with Python 3
|
||||
|
||||
import asyncio
|
||||
import ray
|
||||
from ray.experimental.async_plasma import PlasmaProtocol, PlasmaEventHandler
|
||||
|
||||
handler = None
|
||||
transport = None
|
||||
protocol = None
|
||||
|
||||
|
||||
async def _async_init():
|
||||
global handler, transport, protocol
|
||||
if handler is None:
|
||||
worker = ray.worker.global_worker
|
||||
loop = asyncio.get_event_loop()
|
||||
worker.plasma_client.subscribe()
|
||||
rsock = worker.plasma_client.get_notification_socket()
|
||||
handler = PlasmaEventHandler(loop, worker)
|
||||
transport, protocol = await loop.create_connection(
|
||||
lambda: PlasmaProtocol(worker.plasma_client, handler), sock=rsock)
|
||||
|
||||
|
||||
def init():
|
||||
"""
|
||||
Initialize synchronously.
|
||||
"""
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_running():
|
||||
raise Exception("You must initialize the Ray async API by calling "
|
||||
"async_api.init() or async_api.as_future(obj) before "
|
||||
"the event loop starts.")
|
||||
else:
|
||||
asyncio.get_event_loop().run_until_complete(_async_init())
|
||||
|
||||
|
||||
def as_future(object_id):
|
||||
"""Turn an object_id into a Future object.
|
||||
|
||||
Args:
|
||||
object_id: A Ray object_id.
|
||||
|
||||
Returns:
|
||||
PlasmaObjectFuture: A future object that waits the object_id.
|
||||
"""
|
||||
if handler is None:
|
||||
init()
|
||||
return handler.as_future(object_id)
|
||||
|
||||
|
||||
def shutdown():
|
||||
"""Manually shutdown the async API.
|
||||
|
||||
Cancels all related tasks and all the socket transportation.
|
||||
"""
|
||||
global handler, transport, protocol
|
||||
if handler is not None:
|
||||
handler.close()
|
||||
transport.close()
|
||||
handler = None
|
||||
transport = None
|
||||
protocol = None
|
||||
@@ -0,0 +1,237 @@
|
||||
import asyncio
|
||||
import ctypes
|
||||
import sys
|
||||
|
||||
import pyarrow.plasma as plasma
|
||||
|
||||
import ray
|
||||
from ray.services import logger
|
||||
|
||||
INT64_SIZE = ctypes.sizeof(ctypes.c_int64)
|
||||
|
||||
|
||||
def _release_waiter(waiter, *_):
|
||||
if not waiter.done():
|
||||
waiter.set_result(None)
|
||||
|
||||
|
||||
class PlasmaProtocol(asyncio.Protocol):
|
||||
"""Protocol control for the asyncio connection."""
|
||||
|
||||
def __init__(self, plasma_client, plasma_event_handler):
|
||||
self.plasma_client = plasma_client
|
||||
self.plasma_event_handler = plasma_event_handler
|
||||
self.transport = None
|
||||
self._buffer = b""
|
||||
|
||||
def connection_made(self, transport):
|
||||
self.transport = transport
|
||||
|
||||
def data_received(self, data):
|
||||
self._buffer += data
|
||||
messages = []
|
||||
i = 0
|
||||
while i + INT64_SIZE <= len(self._buffer):
|
||||
msg_len = int.from_bytes(self._buffer[i:i + INT64_SIZE],
|
||||
sys.byteorder)
|
||||
if i + INT64_SIZE + msg_len > len(self._buffer):
|
||||
break
|
||||
i += INT64_SIZE
|
||||
segment = self._buffer[i:i + msg_len]
|
||||
i += msg_len
|
||||
messages.append(self.plasma_client.decode_notification(segment))
|
||||
|
||||
self._buffer = self._buffer[i:]
|
||||
self.plasma_event_handler.process_notifications(messages)
|
||||
|
||||
def connection_lost(self, exc):
|
||||
# The socket has been closed
|
||||
logger.debug("PlasmaProtocol - connection lost.")
|
||||
|
||||
def eof_received(self):
|
||||
logger.debug("PlasmaProtocol - EOF received.")
|
||||
self.transport.close()
|
||||
|
||||
|
||||
class PlasmaObjectFuture(asyncio.Future):
|
||||
"""This class manages the lifecycle of a Future contains an object_id.
|
||||
|
||||
Note:
|
||||
This Future is an item in an linked list.
|
||||
|
||||
Attributes:
|
||||
object_id: The object_id this Future contains.
|
||||
"""
|
||||
|
||||
def __init__(self, loop, object_id):
|
||||
super().__init__(loop=loop)
|
||||
self.object_id = object_id
|
||||
self.prev = None
|
||||
self.next = None
|
||||
|
||||
@property
|
||||
def ray_object_id(self):
|
||||
return ray.ObjectID(self.object_id.binary())
|
||||
|
||||
def __repr__(self):
|
||||
return super().__repr__() + "{object_id=%s}" % self.object_id
|
||||
|
||||
|
||||
class PlasmaObjectLinkedList(asyncio.Future):
|
||||
"""This class is a doubly-linked list.
|
||||
It holds a ObjectID and maintains futures assigned to the ObjectID.
|
||||
|
||||
Args:
|
||||
loop: an event loop.
|
||||
plain_object_id (plasma.ObjectID):
|
||||
The plasma ObjectID this class holds.
|
||||
"""
|
||||
|
||||
def __init__(self, loop, plain_object_id):
|
||||
super().__init__(loop=loop)
|
||||
assert isinstance(plain_object_id, plasma.ObjectID)
|
||||
self.object_id = plain_object_id
|
||||
self.head = None
|
||||
self.tail = None
|
||||
|
||||
def append(self, future):
|
||||
"""Append an object to the linked list.
|
||||
|
||||
Args:
|
||||
future (PlasmaObjectFuture): A PlasmaObjectFuture instance.
|
||||
"""
|
||||
future.prev = self.tail
|
||||
if self.tail is None:
|
||||
assert self.head is None
|
||||
self.head = future
|
||||
else:
|
||||
self.tail.next = future
|
||||
self.tail = future
|
||||
# Once done, it will be removed from the list.
|
||||
future.add_done_callback(self.remove)
|
||||
|
||||
def remove(self, future):
|
||||
"""Remove an object from the linked list.
|
||||
|
||||
Args:
|
||||
future (PlasmaObjectFuture): A PlasmaObjectFuture instance.
|
||||
"""
|
||||
if self._loop.get_debug():
|
||||
logger.debug("Removing %s from the linked list.", future)
|
||||
if future.prev is None:
|
||||
assert future is self.head
|
||||
self.head = future.next
|
||||
if self.head is None:
|
||||
self.tail = None
|
||||
if not self.cancelled():
|
||||
self.set_result(None)
|
||||
else:
|
||||
self.head.prev = None
|
||||
elif future.next is None:
|
||||
assert future is self.tail
|
||||
self.tail = future.prev
|
||||
if self.tail is None:
|
||||
self.head = None
|
||||
if not self.cancelled():
|
||||
self.set_result(None)
|
||||
else:
|
||||
self.tail.prev = None
|
||||
|
||||
def cancel(self, *args, **kwargs):
|
||||
"""Manually cancel all tasks assigned to this event loop."""
|
||||
# Because remove all futures will trigger `set_result`,
|
||||
# we cancel itself first.
|
||||
super().cancel()
|
||||
for future in self.traverse():
|
||||
# All cancelled futures should have callbacks to removed itself
|
||||
# from this linked list. However, these callbacks are scheduled in
|
||||
# an event loop, so we could still find them in our list.
|
||||
if not future.cancelled():
|
||||
future.cancel()
|
||||
|
||||
def set_result(self, result):
|
||||
"""Complete all tasks. """
|
||||
for future in self.traverse():
|
||||
# All cancelled futures should have callbacks to removed itself
|
||||
# from this linked list. However, these callbacks are scheduled in
|
||||
# an event loop, so we could still find them in our list.
|
||||
future.set_result(result)
|
||||
if not self.done():
|
||||
super().set_result(result)
|
||||
|
||||
def traverse(self):
|
||||
"""Traverse this linked list.
|
||||
|
||||
Yields:
|
||||
PlasmaObjectFuture: PlasmaObjectFuture instances.
|
||||
"""
|
||||
current = self.head
|
||||
while current is not None:
|
||||
yield current
|
||||
current = current.next
|
||||
|
||||
|
||||
class PlasmaEventHandler:
|
||||
"""This class is an event handler for Plasma."""
|
||||
|
||||
def __init__(self, loop, worker):
|
||||
super().__init__()
|
||||
self._loop = loop
|
||||
self._worker = worker
|
||||
self._waiting_dict = {}
|
||||
|
||||
def process_notifications(self, messages):
|
||||
"""Process notifications."""
|
||||
for object_id, object_size, metadata_size in messages:
|
||||
if object_size > 0 and object_id in self._waiting_dict:
|
||||
linked_list = self._waiting_dict[object_id]
|
||||
self._complete_future(linked_list)
|
||||
|
||||
def close(self):
|
||||
"""Clean up this handler."""
|
||||
for linked_list in self._waiting_dict.values():
|
||||
linked_list.cancel()
|
||||
# All cancelled linked lists should have callbacks to removed itself
|
||||
# from the waiting dict. However, these callbacks are scheduled in
|
||||
# an event loop, so we don't check them now.
|
||||
|
||||
def _unregister_callback(self, fut):
|
||||
del self._waiting_dict[fut.object_id]
|
||||
|
||||
def _complete_future(self, fut):
|
||||
obj = self._worker.retrieve_and_deserialize([fut.object_id], 0)[0]
|
||||
fut.set_result(obj)
|
||||
|
||||
def as_future(self, object_id, check_ready=True):
|
||||
"""Turn an object_id into a Future object.
|
||||
|
||||
Args:
|
||||
object_id: A Ray's object_id.
|
||||
check_ready (bool): If true, check if the object_id is ready.
|
||||
|
||||
Returns:
|
||||
PlasmaObjectFuture: A future object that waits the object_id.
|
||||
"""
|
||||
if not isinstance(object_id, ray.ObjectID):
|
||||
raise TypeError("Input should be an ObjectID.")
|
||||
|
||||
plain_object_id = plasma.ObjectID(object_id.id())
|
||||
fut = PlasmaObjectFuture(loop=self._loop, object_id=plain_object_id)
|
||||
|
||||
if check_ready:
|
||||
ready, _ = ray.wait([object_id], timeout=0)
|
||||
if ready:
|
||||
if self._loop.get_debug():
|
||||
logger.debug("%s has been ready.", plain_object_id)
|
||||
self._complete_future(fut)
|
||||
return fut
|
||||
|
||||
if plain_object_id not in self._waiting_dict:
|
||||
linked_list = PlasmaObjectLinkedList(self._loop, plain_object_id)
|
||||
linked_list.add_done_callback(self._unregister_callback)
|
||||
self._waiting_dict[plain_object_id] = linked_list
|
||||
self._waiting_dict[plain_object_id].append(fut)
|
||||
if self._loop.get_debug():
|
||||
logger.debug("%s added to the waiting list.", fut)
|
||||
|
||||
return fut
|
||||
@@ -0,0 +1,150 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
import ray
|
||||
from ray.experimental import async_api
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def init():
|
||||
ray.init(num_cpus=4)
|
||||
async_api.init()
|
||||
asyncio.get_event_loop().set_debug(False)
|
||||
yield
|
||||
async_api.shutdown()
|
||||
ray.shutdown()
|
||||
|
||||
|
||||
def gen_tasks(time_scale=0.1):
|
||||
@ray.remote
|
||||
def f(n):
|
||||
time.sleep(n * time_scale)
|
||||
return n
|
||||
|
||||
tasks = [f.remote(i) for i in range(5)]
|
||||
return tasks
|
||||
|
||||
|
||||
def test_simple(init):
|
||||
@ray.remote
|
||||
def f():
|
||||
time.sleep(1)
|
||||
return {"key1": ["value"]}
|
||||
|
||||
future = async_api.as_future(f.remote())
|
||||
result = asyncio.get_event_loop().run_until_complete(future)
|
||||
assert result["key1"] == ["value"]
|
||||
|
||||
|
||||
def test_gather(init):
|
||||
loop = asyncio.get_event_loop()
|
||||
tasks = gen_tasks()
|
||||
futures = [async_api.as_future(obj_id) for obj_id in tasks]
|
||||
results = loop.run_until_complete(asyncio.gather(*futures))
|
||||
assert all(a == b for a, b in zip(results, ray.get(tasks)))
|
||||
|
||||
|
||||
def test_gather_benchmark(init):
|
||||
@ray.remote
|
||||
def f(n):
|
||||
time.sleep(0.001 * n)
|
||||
return 42
|
||||
|
||||
async def test_async():
|
||||
sum_time = 0.
|
||||
for _ in range(50):
|
||||
tasks = [f.remote(n) for n in range(20)]
|
||||
start = time.time()
|
||||
futures = [async_api.as_future(obj_id) for obj_id in tasks]
|
||||
await asyncio.gather(*futures)
|
||||
sum_time += time.time() - start
|
||||
return sum_time
|
||||
|
||||
def baseline():
|
||||
sum_time = 0.
|
||||
for _ in range(50):
|
||||
tasks = [f.remote(n) for n in range(20)]
|
||||
start = time.time()
|
||||
ray.get(tasks)
|
||||
sum_time += time.time() - start
|
||||
return sum_time
|
||||
|
||||
# warm up
|
||||
baseline()
|
||||
# async get
|
||||
sum_time_1 = asyncio.get_event_loop().run_until_complete(test_async())
|
||||
# get
|
||||
sum_time_2 = baseline()
|
||||
|
||||
# Ensure the new implementation is not too slow.
|
||||
assert sum_time_2 * 1.2 > sum_time_1
|
||||
|
||||
|
||||
def test_wait(init):
|
||||
loop = asyncio.get_event_loop()
|
||||
tasks = gen_tasks()
|
||||
futures = [async_api.as_future(obj_id) for obj_id in tasks]
|
||||
results, _ = loop.run_until_complete(asyncio.wait(futures))
|
||||
assert set(results) == set(futures)
|
||||
|
||||
|
||||
def test_wait_timeout(init):
|
||||
loop = asyncio.get_event_loop()
|
||||
tasks = gen_tasks(10)
|
||||
futures = [async_api.as_future(obj_id) for obj_id in tasks]
|
||||
fut = asyncio.wait(futures, timeout=5)
|
||||
results, _ = loop.run_until_complete(fut)
|
||||
assert list(results)[0] == futures[0]
|
||||
|
||||
|
||||
def test_gather_mixup(init):
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
@ray.remote
|
||||
def f(n):
|
||||
time.sleep(n * 0.1)
|
||||
return n
|
||||
|
||||
async def g(n):
|
||||
await asyncio.sleep(n * 0.1)
|
||||
return n
|
||||
|
||||
tasks = [
|
||||
async_api.as_future(f.remote(1)),
|
||||
g(2),
|
||||
async_api.as_future(f.remote(3)),
|
||||
g(4)
|
||||
]
|
||||
results = loop.run_until_complete(asyncio.gather(*tasks))
|
||||
assert results == [1, 2, 3, 4]
|
||||
|
||||
|
||||
def test_wait_mixup(init):
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
@ray.remote
|
||||
def f(n):
|
||||
time.sleep(n)
|
||||
return n
|
||||
|
||||
def g(n):
|
||||
async def _g(_n):
|
||||
await asyncio.sleep(_n)
|
||||
return _n
|
||||
|
||||
return asyncio.ensure_future(_g(n))
|
||||
|
||||
tasks = [
|
||||
async_api.as_future(f.remote(0.1)),
|
||||
g(7),
|
||||
async_api.as_future(f.remote(5)),
|
||||
g(2)
|
||||
]
|
||||
ready, _ = loop.run_until_complete(asyncio.wait(tasks, timeout=4))
|
||||
assert set(ready) == {tasks[0], tasks[-1]}
|
||||
Reference in New Issue
Block a user