Experimental asyncio support (#2015)

* Init commit for async plasma client

* Create an eventloop model for ray/plasma

* Implement a poll-like selector base on `ray.wait`. Huge improvements.

* Allow choosing workers & selectors

* remove original design

* initial implementation of epoll-like selector for plasma

* Add a param for `worker` used in `PlasmaSelectorEventLoop`

* Allow accepting a `Future` which returns object_id

* Do not need `io.py` anymore

* Create a basic testing model

* fix: `ray.wait` returns tuple of lists

* fix a few bugs

* improving performance & bug fixing

* add test

* several improvements & fixing

* fix relative import

* [async] change code format, remove old files

* [async] Create context wrapper for the eventloop

* [async] fix: context should return a value

* [async] Implement futures grouping

* [async] Fix bugs & replace old functions

* [async] Fix bugs found in tests

* [async] Implement `PlasmaEpoll`

* [async] Make test faster, add tests for epoll

* [async] Fix code format

* [async] Add comments for main code.

* [async] Fix import path.

* [async] Fix test.

* [async] Compatibility.

* [async] less verbose to not annoy the CI.

* [async] Add test for new API

* [async] Allow showing debug info in some of the test.

* [async] Fix test.

* [async] Proper shutdown.

* [async] Lint~

* [async] Move files to experimental and create API

* [async] Use async/await syntax

* [async] Fix names & styles

* [async] comments

* [async] bug fixing & use pytest

* [async] bug fixing & change tests

* [async] use logger

* [async] add tests

* [async] lint

* [async] type checking

* [async] add more tests

* [async] fix bugs on waiting a future while timeout. Add more docs.

* [async] Formal docs.

* [async] Add typing info since these codes are compatible with py3.5+.

* [async] Documents.

* [async] Lint.

* [async] Fix deprecated call.

* [async] Fix deprecated call.

* [async] Implement a more reasonable way for dealing with pending inputs.

* [async] Fix docs

* [async] Lint

* [async] Fix bug: Type for time

* [async] Set our eventloop as the default eventloop so that we can get it through `asyncio.get_event_loop()`.

* [async] Update test & docs.

* [async] Lint.

* [async] Temporarily print more debug info.

* [async] Use `Poll` as a default option.

* [async] Limit resources.

* new async implementation for Ray

* implement linked list

* bug fix

* update

* support seamless async operations

* update

* update API

* fix tests

* lint

* bug fix

* refactor names

* improve doc

* properly shutdown async_api

* doc

* Change the table on the index page.

* Adjust table size.

* Only keeps `as_future`.

* change how we init connection

* init connection in `ray.worker.connect`

* doc

* fix

* Move initialization code into the module.

* Fix docs & code

* Update pyarrow version.

* lint

* Restore index.rst

* Add known issues.

* Apply suggestions from code review

Co-Authored-By: suquark <suquark@gmail.com>

* rename

* Update async_api.rst

* Update async_api.py

* Update async_api.rst

* Update async_api.py

* Update worker.py

* Update async_api.rst

* fix tests

* lint

* lint

* replace the magic number
This commit is contained in:
Si-Yuan
2018-12-06 17:39:05 -08:00
committed by Philipp Moritz
parent 970babf31a
commit c2c501bbe6
7 changed files with 542 additions and 2 deletions
+62
View File
@@ -0,0 +1,62 @@
# Note: asyncio is only compatible with Python 3
import asyncio
import ray
from ray.experimental.async_plasma import PlasmaProtocol, PlasmaEventHandler
handler = None
transport = None
protocol = None
async def _async_init():
global handler, transport, protocol
if handler is None:
worker = ray.worker.global_worker
loop = asyncio.get_event_loop()
worker.plasma_client.subscribe()
rsock = worker.plasma_client.get_notification_socket()
handler = PlasmaEventHandler(loop, worker)
transport, protocol = await loop.create_connection(
lambda: PlasmaProtocol(worker.plasma_client, handler), sock=rsock)
def init():
"""
Initialize synchronously.
"""
loop = asyncio.get_event_loop()
if loop.is_running():
raise Exception("You must initialize the Ray async API by calling "
"async_api.init() or async_api.as_future(obj) before "
"the event loop starts.")
else:
asyncio.get_event_loop().run_until_complete(_async_init())
def as_future(object_id):
"""Turn an object_id into a Future object.
Args:
object_id: A Ray object_id.
Returns:
PlasmaObjectFuture: A future object that waits the object_id.
"""
if handler is None:
init()
return handler.as_future(object_id)
def shutdown():
"""Manually shutdown the async API.
Cancels all related tasks and all the socket transportation.
"""
global handler, transport, protocol
if handler is not None:
handler.close()
transport.close()
handler = None
transport = None
protocol = None
+237
View File
@@ -0,0 +1,237 @@
import asyncio
import ctypes
import sys
import pyarrow.plasma as plasma
import ray
from ray.services import logger
INT64_SIZE = ctypes.sizeof(ctypes.c_int64)
def _release_waiter(waiter, *_):
if not waiter.done():
waiter.set_result(None)
class PlasmaProtocol(asyncio.Protocol):
"""Protocol control for the asyncio connection."""
def __init__(self, plasma_client, plasma_event_handler):
self.plasma_client = plasma_client
self.plasma_event_handler = plasma_event_handler
self.transport = None
self._buffer = b""
def connection_made(self, transport):
self.transport = transport
def data_received(self, data):
self._buffer += data
messages = []
i = 0
while i + INT64_SIZE <= len(self._buffer):
msg_len = int.from_bytes(self._buffer[i:i + INT64_SIZE],
sys.byteorder)
if i + INT64_SIZE + msg_len > len(self._buffer):
break
i += INT64_SIZE
segment = self._buffer[i:i + msg_len]
i += msg_len
messages.append(self.plasma_client.decode_notification(segment))
self._buffer = self._buffer[i:]
self.plasma_event_handler.process_notifications(messages)
def connection_lost(self, exc):
# The socket has been closed
logger.debug("PlasmaProtocol - connection lost.")
def eof_received(self):
logger.debug("PlasmaProtocol - EOF received.")
self.transport.close()
class PlasmaObjectFuture(asyncio.Future):
"""This class manages the lifecycle of a Future contains an object_id.
Note:
This Future is an item in an linked list.
Attributes:
object_id: The object_id this Future contains.
"""
def __init__(self, loop, object_id):
super().__init__(loop=loop)
self.object_id = object_id
self.prev = None
self.next = None
@property
def ray_object_id(self):
return ray.ObjectID(self.object_id.binary())
def __repr__(self):
return super().__repr__() + "{object_id=%s}" % self.object_id
class PlasmaObjectLinkedList(asyncio.Future):
"""This class is a doubly-linked list.
It holds a ObjectID and maintains futures assigned to the ObjectID.
Args:
loop: an event loop.
plain_object_id (plasma.ObjectID):
The plasma ObjectID this class holds.
"""
def __init__(self, loop, plain_object_id):
super().__init__(loop=loop)
assert isinstance(plain_object_id, plasma.ObjectID)
self.object_id = plain_object_id
self.head = None
self.tail = None
def append(self, future):
"""Append an object to the linked list.
Args:
future (PlasmaObjectFuture): A PlasmaObjectFuture instance.
"""
future.prev = self.tail
if self.tail is None:
assert self.head is None
self.head = future
else:
self.tail.next = future
self.tail = future
# Once done, it will be removed from the list.
future.add_done_callback(self.remove)
def remove(self, future):
"""Remove an object from the linked list.
Args:
future (PlasmaObjectFuture): A PlasmaObjectFuture instance.
"""
if self._loop.get_debug():
logger.debug("Removing %s from the linked list.", future)
if future.prev is None:
assert future is self.head
self.head = future.next
if self.head is None:
self.tail = None
if not self.cancelled():
self.set_result(None)
else:
self.head.prev = None
elif future.next is None:
assert future is self.tail
self.tail = future.prev
if self.tail is None:
self.head = None
if not self.cancelled():
self.set_result(None)
else:
self.tail.prev = None
def cancel(self, *args, **kwargs):
"""Manually cancel all tasks assigned to this event loop."""
# Because remove all futures will trigger `set_result`,
# we cancel itself first.
super().cancel()
for future in self.traverse():
# All cancelled futures should have callbacks to removed itself
# from this linked list. However, these callbacks are scheduled in
# an event loop, so we could still find them in our list.
if not future.cancelled():
future.cancel()
def set_result(self, result):
"""Complete all tasks. """
for future in self.traverse():
# All cancelled futures should have callbacks to removed itself
# from this linked list. However, these callbacks are scheduled in
# an event loop, so we could still find them in our list.
future.set_result(result)
if not self.done():
super().set_result(result)
def traverse(self):
"""Traverse this linked list.
Yields:
PlasmaObjectFuture: PlasmaObjectFuture instances.
"""
current = self.head
while current is not None:
yield current
current = current.next
class PlasmaEventHandler:
"""This class is an event handler for Plasma."""
def __init__(self, loop, worker):
super().__init__()
self._loop = loop
self._worker = worker
self._waiting_dict = {}
def process_notifications(self, messages):
"""Process notifications."""
for object_id, object_size, metadata_size in messages:
if object_size > 0 and object_id in self._waiting_dict:
linked_list = self._waiting_dict[object_id]
self._complete_future(linked_list)
def close(self):
"""Clean up this handler."""
for linked_list in self._waiting_dict.values():
linked_list.cancel()
# All cancelled linked lists should have callbacks to removed itself
# from the waiting dict. However, these callbacks are scheduled in
# an event loop, so we don't check them now.
def _unregister_callback(self, fut):
del self._waiting_dict[fut.object_id]
def _complete_future(self, fut):
obj = self._worker.retrieve_and_deserialize([fut.object_id], 0)[0]
fut.set_result(obj)
def as_future(self, object_id, check_ready=True):
"""Turn an object_id into a Future object.
Args:
object_id: A Ray's object_id.
check_ready (bool): If true, check if the object_id is ready.
Returns:
PlasmaObjectFuture: A future object that waits the object_id.
"""
if not isinstance(object_id, ray.ObjectID):
raise TypeError("Input should be an ObjectID.")
plain_object_id = plasma.ObjectID(object_id.id())
fut = PlasmaObjectFuture(loop=self._loop, object_id=plain_object_id)
if check_ready:
ready, _ = ray.wait([object_id], timeout=0)
if ready:
if self._loop.get_debug():
logger.debug("%s has been ready.", plain_object_id)
self._complete_future(fut)
return fut
if plain_object_id not in self._waiting_dict:
linked_list = PlasmaObjectLinkedList(self._loop, plain_object_id)
linked_list.add_done_callback(self._unregister_callback)
self._waiting_dict[plain_object_id] = linked_list
self._waiting_dict[plain_object_id].append(fut)
if self._loop.get_debug():
logger.debug("%s added to the waiting list.", fut)
return fut
+150
View File
@@ -0,0 +1,150 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import asyncio
import time
import pytest
import ray
from ray.experimental import async_api
@pytest.fixture
def init():
ray.init(num_cpus=4)
async_api.init()
asyncio.get_event_loop().set_debug(False)
yield
async_api.shutdown()
ray.shutdown()
def gen_tasks(time_scale=0.1):
@ray.remote
def f(n):
time.sleep(n * time_scale)
return n
tasks = [f.remote(i) for i in range(5)]
return tasks
def test_simple(init):
@ray.remote
def f():
time.sleep(1)
return {"key1": ["value"]}
future = async_api.as_future(f.remote())
result = asyncio.get_event_loop().run_until_complete(future)
assert result["key1"] == ["value"]
def test_gather(init):
loop = asyncio.get_event_loop()
tasks = gen_tasks()
futures = [async_api.as_future(obj_id) for obj_id in tasks]
results = loop.run_until_complete(asyncio.gather(*futures))
assert all(a == b for a, b in zip(results, ray.get(tasks)))
def test_gather_benchmark(init):
@ray.remote
def f(n):
time.sleep(0.001 * n)
return 42
async def test_async():
sum_time = 0.
for _ in range(50):
tasks = [f.remote(n) for n in range(20)]
start = time.time()
futures = [async_api.as_future(obj_id) for obj_id in tasks]
await asyncio.gather(*futures)
sum_time += time.time() - start
return sum_time
def baseline():
sum_time = 0.
for _ in range(50):
tasks = [f.remote(n) for n in range(20)]
start = time.time()
ray.get(tasks)
sum_time += time.time() - start
return sum_time
# warm up
baseline()
# async get
sum_time_1 = asyncio.get_event_loop().run_until_complete(test_async())
# get
sum_time_2 = baseline()
# Ensure the new implementation is not too slow.
assert sum_time_2 * 1.2 > sum_time_1
def test_wait(init):
loop = asyncio.get_event_loop()
tasks = gen_tasks()
futures = [async_api.as_future(obj_id) for obj_id in tasks]
results, _ = loop.run_until_complete(asyncio.wait(futures))
assert set(results) == set(futures)
def test_wait_timeout(init):
loop = asyncio.get_event_loop()
tasks = gen_tasks(10)
futures = [async_api.as_future(obj_id) for obj_id in tasks]
fut = asyncio.wait(futures, timeout=5)
results, _ = loop.run_until_complete(fut)
assert list(results)[0] == futures[0]
def test_gather_mixup(init):
loop = asyncio.get_event_loop()
@ray.remote
def f(n):
time.sleep(n * 0.1)
return n
async def g(n):
await asyncio.sleep(n * 0.1)
return n
tasks = [
async_api.as_future(f.remote(1)),
g(2),
async_api.as_future(f.remote(3)),
g(4)
]
results = loop.run_until_complete(asyncio.gather(*tasks))
assert results == [1, 2, 3, 4]
def test_wait_mixup(init):
loop = asyncio.get_event_loop()
@ray.remote
def f(n):
time.sleep(n)
return n
def g(n):
async def _g(_n):
await asyncio.sleep(_n)
return _n
return asyncio.ensure_future(_g(n))
tasks = [
async_api.as_future(f.remote(0.1)),
g(7),
async_api.as_future(f.remote(5)),
g(2)
]
ready, _ = loop.run_until_complete(asyncio.wait(tasks, timeout=4))
assert set(ready) == {tasks[0], tasks[-1]}