Stop vendoring pyarrow (#7233)

This commit is contained in:
Simon Mo
2020-02-19 19:01:26 -08:00
committed by GitHub
parent 48c06f5042
commit b804d40c04
20 changed files with 178 additions and 387 deletions
+4 -69
View File
@@ -1,84 +1,22 @@
# Note: asyncio is only compatible with Python 3
import asyncio
import functools
import threading
import pyarrow.plasma as plasma
import ray
from ray.experimental.async_plasma import PlasmaProtocol, PlasmaEventHandler
from ray.experimental.async_plasma import PlasmaEventHandler
from ray.services import logger
handler = None
transport = None
protocol = None
class _ThreadSafeProxy:
"""This class is used to create a thread-safe proxy for a given object.
Every method call will be guarded with a lock.
Attributes:
orig_obj (object): the original object.
lock (threading.Lock): the lock object.
_wrapper_cache (dict): a cache from original object's methods to
the proxy methods.
"""
def __init__(self, orig_obj, lock):
self.orig_obj = orig_obj
self.lock = lock
self._wrapper_cache = {}
def __getattr__(self, attr):
orig_attr = getattr(self.orig_obj, attr)
if not callable(orig_attr):
# If the original attr is a field, just return it.
return orig_attr
else:
# If the orginal attr is a method,
# return a wrapper that guards the original method with a lock.
wrapper = self._wrapper_cache.get(attr)
if wrapper is None:
@functools.wraps(orig_attr)
def _wrapper(*args, **kwargs):
with self.lock:
return orig_attr(*args, **kwargs)
self._wrapper_cache[attr] = _wrapper
wrapper = _wrapper
return wrapper
def thread_safe_client(client, lock=None):
"""Create a thread-safe proxy which locks every method call
for the given client.
Args:
client: the client object to be guarded.
lock: the lock object that will be used to lock client's methods.
If None, a new lock will be used.
Returns:
A thread-safe proxy for the given client.
"""
if lock is None:
lock = threading.Lock()
return _ThreadSafeProxy(client, lock)
async def _async_init():
global handler, transport, protocol
global handler
if handler is None:
worker = ray.worker.global_worker
plasma_client = thread_safe_client(
plasma.connect(worker.node.plasma_store_socket_name, 300))
loop = asyncio.get_event_loop()
plasma_client.subscribe()
rsock = plasma_client.get_notification_socket()
handler = PlasmaEventHandler(loop, worker)
transport, protocol = await loop.create_connection(
lambda: PlasmaProtocol(plasma_client, handler), sock=rsock)
worker.core_worker.subscribe_to_plasma(handler)
logger.debug("AsyncPlasma Connection Created!")
@@ -126,10 +64,7 @@ def shutdown():
Cancels all related tasks and all the socket transportation.
"""
global handler, transport, protocol
global handler
if handler is not None:
handler.close()
transport.close()
handler = None
transport = None
protocol = None
+38 -203
View File
@@ -1,179 +1,13 @@
import asyncio
import ctypes
import sys
import pyarrow.plasma as plasma
import ray
from ray.services import logger
INT64_SIZE = ctypes.sizeof(ctypes.c_int64)
def _release_waiter(waiter, *_):
if not waiter.done():
waiter.set_result(None)
class PlasmaProtocol(asyncio.Protocol):
"""Protocol control for the asyncio connection."""
def __init__(self, plasma_client, plasma_event_handler):
self.plasma_client = plasma_client
self.plasma_event_handler = plasma_event_handler
self.transport = None
self._buffer = b""
def connection_made(self, transport):
self.transport = transport
def data_received(self, data):
self._buffer += data
messages = []
i = 0
while i + INT64_SIZE <= len(self._buffer):
msg_len = int.from_bytes(self._buffer[i:i + INT64_SIZE],
sys.byteorder)
if i + INT64_SIZE + msg_len > len(self._buffer):
break
i += INT64_SIZE
segment = self._buffer[i:i + msg_len]
i += msg_len
(object_ids, object_sizes,
metadata_sizes) = self.plasma_client.decode_notifications(segment)
assert len(object_ids) == len(object_sizes) == len(metadata_sizes)
for j in range(len(object_ids)):
messages.append((object_ids[j], object_sizes[j],
metadata_sizes[j]))
self._buffer = self._buffer[i:]
self.plasma_event_handler.process_notifications(messages)
def connection_lost(self, exc):
# The socket has been closed
logger.debug("PlasmaProtocol - connection lost.")
def eof_received(self):
logger.debug("PlasmaProtocol - EOF received.")
self.transport.close()
from collections import defaultdict
class PlasmaObjectFuture(asyncio.Future):
"""This class manages the lifecycle of a Future contains an object_id.
Note:
This Future is an item in an linked list.
Attributes:
object_id: The object_id this Future contains.
"""
def __init__(self, loop, object_id):
super().__init__(loop=loop)
self.object_id = object_id
self.prev = None
self.next = None
@property
def ray_object_id(self):
return ray.ObjectID(self.object_id.binary())
def __repr__(self):
return super().__repr__() + "{object_id=%s}" % self.object_id
class PlasmaObjectLinkedList(asyncio.Future):
"""This class is a doubly-linked list.
It holds a ObjectID and maintains futures assigned to the ObjectID.
Args:
loop: an event loop.
plain_object_id (plasma.ObjectID):
The plasma ObjectID this class holds.
"""
def __init__(self, loop, plain_object_id):
super().__init__(loop=loop)
assert isinstance(plain_object_id, plasma.ObjectID)
self.object_id = plain_object_id
self.head = None
self.tail = None
def append(self, future):
"""Append an object to the linked list.
Args:
future (PlasmaObjectFuture): A PlasmaObjectFuture instance.
"""
future.prev = self.tail
if self.tail is None:
assert self.head is None
self.head = future
else:
self.tail.next = future
self.tail = future
# Once done, it will be removed from the list.
future.add_done_callback(self.remove)
def remove(self, future):
"""Remove an object from the linked list.
Args:
future (PlasmaObjectFuture): A PlasmaObjectFuture instance.
"""
if self._loop.get_debug():
logger.debug("Removing %s from the linked list.", future)
if future.prev is None:
assert future is self.head
self.head = future.next
if self.head is None:
self.tail = None
if not self.cancelled():
self.set_result(None)
else:
self.head.prev = None
elif future.next is None:
assert future is self.tail
self.tail = future.prev
if self.tail is None:
self.head = None
if not self.cancelled():
self.set_result(None)
else:
self.tail.prev = None
def cancel(self, *args, **kwargs):
"""Manually cancel all tasks assigned to this event loop."""
# Because remove all futures will trigger `set_result`,
# we cancel itself first.
super().cancel()
for future in self.traverse():
# All cancelled futures should have callbacks to removed itself
# from this linked list. However, these callbacks are scheduled in
# an event loop, so we could still find them in our list.
if not future.cancelled():
future.cancel()
def set_result(self, result):
"""Complete all tasks. """
for future in self.traverse():
# All cancelled futures should have callbacks to removed itself
# from this linked list. However, these callbacks are scheduled in
# an event loop, so we could still find them in our list.
future.set_result(result)
if not self.done():
super().set_result(result)
def traverse(self):
"""Traverse this linked list.
Yields:
PlasmaObjectFuture: PlasmaObjectFuture instances.
"""
current = self.head
while current is not None:
yield current
current = current.next
"""This class is a wrapper for a Future on Plasma."""
pass
class PlasmaEventHandler:
@@ -183,30 +17,46 @@ class PlasmaEventHandler:
super().__init__()
self._loop = loop
self._worker = worker
self._waiting_dict = {}
self._waiting_dict = defaultdict(list)
def process_notifications(self, messages):
"""Process notifications."""
for object_id, object_size, metadata_size in messages:
if object_size > 0 and object_id in self._waiting_dict:
linked_list = self._waiting_dict[object_id]
self._complete_future(linked_list)
self._complete_future(object_id)
def close(self):
"""Clean up this handler."""
for linked_list in self._waiting_dict.values():
linked_list.cancel()
# All cancelled linked lists should have callbacks to removed itself
# from the waiting dict. However, these callbacks are scheduled in
# an event loop, so we don't check them now.
for futures in self._waiting_dict.values():
for fut in futures:
fut.cancel()
def _unregister_callback(self, fut):
del self._waiting_dict[fut.object_id]
def _complete_future(self, ray_object_id):
# TODO(ilr): Consider race condition between popping from the
# waiting_dict and as_future appending to the waiting_dict's list.
logger.debug(
"Completing plasma futures for object id {}".format(ray_object_id))
def _complete_future(self, fut):
obj = self._worker.get_objects([ray.ObjectID(
fut.object_id.binary())])[0]
fut.set_result(obj)
obj = self._worker.get_objects([ray_object_id])[0]
futures = self._waiting_dict.pop(ray_object_id)
for fut in futures:
loop = fut._loop
def complete_closure():
try:
fut.set_result(obj)
except asyncio.InvalidStateError:
# Avoid issues where process_notifications
# and check_ready both get executed
logger.debug("Failed to set result for future {}."
"Most likely already set.".format(fut))
loop.call_soon_threadsafe(complete_closure)
def check_immediately(self, object_id):
ready, _ = ray.wait([object_id], timeout=0)
if ready:
self._complete_future(object_id)
def as_future(self, object_id, check_ready=True):
"""Turn an object_id into a Future object.
@@ -219,25 +69,10 @@ class PlasmaEventHandler:
PlasmaObjectFuture: A future object that waits the object_id.
"""
if not isinstance(object_id, ray.ObjectID):
raise TypeError("Input should be an ObjectID.")
raise TypeError("Input should be a Ray ObjectID.")
plain_object_id = plasma.ObjectID(object_id.binary())
fut = PlasmaObjectFuture(loop=self._loop, object_id=plain_object_id)
future = PlasmaObjectFuture(loop=self._loop)
self._waiting_dict[object_id].append(future)
self.check_immediately(object_id)
if check_ready:
ready, _ = ray.wait([object_id], timeout=0)
if ready:
if self._loop.get_debug():
logger.debug("%s has been ready.", plain_object_id)
self._complete_future(fut)
return fut
if plain_object_id not in self._waiting_dict:
linked_list = PlasmaObjectLinkedList(self._loop, plain_object_id)
linked_list.add_done_callback(self._unregister_callback)
self._waiting_dict[plain_object_id] = linked_list
self._waiting_dict[plain_object_id].append(fut)
if self._loop.get_debug():
logger.debug("%s added to the waiting list.", fut)
return fut
return future
@@ -1,5 +1,6 @@
import asyncio
import time
import os
import pytest
@@ -9,6 +10,7 @@ from ray.experimental import async_api
@pytest.fixture
def init():
os.environ["RAY_FORCE_DIRECT"] = "0"
ray.init(num_cpus=4)
async_api.init()
asyncio.get_event_loop().set_debug(False)