mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 20:22:39 +08:00
[core] Switch Async Callback to C++ [WIP] (#9228)
Co-authored-by: simon-mo <simon.mo@hey.com>
This commit is contained in:
+20
-50
@@ -85,7 +85,7 @@ from ray.includes.global_state_accessor cimport CGlobalStateAccessor
|
||||
|
||||
import ray
|
||||
from ray.async_compat import (
|
||||
sync_to_async, AsyncGetResponse, get_new_event_loop)
|
||||
sync_to_async, get_new_event_loop)
|
||||
import ray.memory_monitor as memory_monitor
|
||||
import ray.ray_constants as ray_constants
|
||||
from ray import profiling
|
||||
@@ -559,19 +559,6 @@ cdef CRayStatus task_execution_handler(
|
||||
|
||||
return CRayStatus.OK()
|
||||
|
||||
cdef void async_plasma_callback(CObjectID object_id,
|
||||
int64_t data_size,
|
||||
int64_t metadata_size) with gil:
|
||||
core_worker = ray.worker.global_worker.core_worker
|
||||
event_handler = core_worker.get_plasma_event_handler()
|
||||
if event_handler is not None:
|
||||
obj_id = ObjectID(object_id.Binary())
|
||||
if data_size > 0 and obj_id:
|
||||
# This must be asynchronous to allow objects to avoid blocking
|
||||
# the IO thread.
|
||||
event_handler._loop.call_soon_threadsafe(
|
||||
event_handler._complete_future, obj_id)
|
||||
|
||||
cdef c_bool kill_main_task() nogil:
|
||||
with gil:
|
||||
if setproctitle.getproctitle() != "ray::IDLE":
|
||||
@@ -731,15 +718,6 @@ cdef class CoreWorker:
|
||||
def set_actor_title(self, title):
|
||||
CCoreWorkerProcess.GetCoreWorker().SetActorTitle(title)
|
||||
|
||||
def set_plasma_added_callback(self, plasma_event_handler):
|
||||
self.plasma_event_handler = plasma_event_handler
|
||||
CCoreWorkerProcess.GetCoreWorker().SetPlasmaAddedCallback(
|
||||
async_plasma_callback)
|
||||
|
||||
def subscribe_to_plasma_object(self, ObjectID object_id):
|
||||
CCoreWorkerProcess.GetCoreWorker().SubscribeToPlasmaAdd(
|
||||
object_id.native())
|
||||
|
||||
def get_plasma_event_handler(self):
|
||||
return self.plasma_event_handler
|
||||
|
||||
@@ -1198,10 +1176,6 @@ cdef class CoreWorker:
|
||||
if self.async_event_loop is None:
|
||||
self.async_event_loop = get_new_event_loop()
|
||||
asyncio.set_event_loop(self.async_event_loop)
|
||||
# Initialize the async plasma connection.
|
||||
# Delayed import because async_api depends on _raylet.
|
||||
from ray.experimental.async_api import init as plasma_async_init
|
||||
plasma_async_init()
|
||||
|
||||
if self.async_thread is None:
|
||||
self.async_thread = threading.Thread(
|
||||
@@ -1263,12 +1237,12 @@ cdef class CoreWorker:
|
||||
|
||||
return ref_counts
|
||||
|
||||
def in_memory_store_get_async(self, ObjectID object_id, future):
|
||||
def get_async(self, ObjectID object_id, future):
|
||||
cpython.Py_INCREF(future)
|
||||
CCoreWorkerProcess.GetCoreWorker().GetAsync(
|
||||
object_id.native(),
|
||||
async_set_result_callback,
|
||||
async_retry_with_plasma_callback,
|
||||
<void*>future)
|
||||
object_id.native(),
|
||||
async_set_result,
|
||||
<void*>future)
|
||||
|
||||
def push_error(self, JobID job_id, error_type, error_message,
|
||||
double timestamp):
|
||||
@@ -1302,12 +1276,11 @@ cdef class CoreWorker:
|
||||
resource_name.encode("ascii"), capacity,
|
||||
CClientID.FromBinary(client_id.binary()))
|
||||
|
||||
cdef void async_set_result_callback(shared_ptr[CRayObject] obj,
|
||||
CObjectID object_id,
|
||||
void *future) with gil:
|
||||
cdef void async_set_result(shared_ptr[CRayObject] obj,
|
||||
CObjectID object_id,
|
||||
void *future) with gil:
|
||||
cdef:
|
||||
c_vector[shared_ptr[CRayObject]] objects_to_deserialize
|
||||
|
||||
py_future = <object>(future)
|
||||
loop = py_future._loop
|
||||
|
||||
@@ -1317,18 +1290,15 @@ cdef void async_set_result_callback(shared_ptr[CRayObject] obj,
|
||||
data_metadata_pairs = RayObjectsToDataMetadataPairs(
|
||||
objects_to_deserialize)
|
||||
ids_to_deserialize = [ObjectID(object_id.Binary())]
|
||||
objects = ray.worker.global_worker.deserialize_objects(
|
||||
data_metadata_pairs, ids_to_deserialize)
|
||||
loop.call_soon_threadsafe(lambda: py_future.set_result(
|
||||
AsyncGetResponse(
|
||||
plasma_fallback_id=None, result=objects[0])))
|
||||
result = ray.worker.global_worker.deserialize_objects(
|
||||
data_metadata_pairs, ids_to_deserialize)[0]
|
||||
|
||||
cdef void async_retry_with_plasma_callback(shared_ptr[CRayObject] obj,
|
||||
CObjectID object_id,
|
||||
void *future) with gil:
|
||||
py_future = <object>(future)
|
||||
loop = py_future._loop
|
||||
loop.call_soon_threadsafe(lambda: py_future.set_result(
|
||||
AsyncGetResponse(
|
||||
plasma_fallback_id=ObjectID(object_id.Binary()),
|
||||
result=None)))
|
||||
def set_future():
|
||||
if isinstance(result, RayTaskError):
|
||||
ray.worker.last_task_error_raise_time = time.time()
|
||||
py_future.set_exception(result.as_instanceof_cause())
|
||||
else:
|
||||
py_future.set_result(result)
|
||||
cpython.Py_DECREF(py_future)
|
||||
|
||||
loop.call_soon_threadsafe(set_future)
|
||||
|
||||
@@ -3,8 +3,6 @@ This file should only be imported from Python 3.
|
||||
It will raise SyntaxError when importing from Python 2.
|
||||
"""
|
||||
import asyncio
|
||||
from collections import namedtuple
|
||||
import time
|
||||
import inspect
|
||||
|
||||
try:
|
||||
@@ -35,84 +33,13 @@ def sync_to_async(func):
|
||||
return wrapper
|
||||
|
||||
|
||||
# Class that encapsulates the get result from direct actor.
|
||||
# Case 1: plasma_fallback_id=None, result=<Object>
|
||||
# Case 2: plasma_fallback_id=ObjectID, result=None
|
||||
AsyncGetResponse = namedtuple("AsyncGetResponse",
|
||||
["plasma_fallback_id", "result"])
|
||||
|
||||
|
||||
def get_async(object_id):
|
||||
"""Asyncio compatible version of ray.get"""
|
||||
# Delayed import because raylet imports this file and
|
||||
# it creates circular imports.
|
||||
from ray.experimental.async_api import init as async_api_init, as_future
|
||||
from ray.experimental.async_plasma import PlasmaObjectFuture
|
||||
|
||||
assert isinstance(object_id, ray.ObjectID), "Batched get is not supported."
|
||||
|
||||
# Setup
|
||||
async_api_init()
|
||||
"""C++ Asyncio version of ray.get"""
|
||||
loop = asyncio.get_event_loop()
|
||||
core_worker = ray.worker.global_worker.core_worker
|
||||
|
||||
# Here's the callback used to implement async get logic.
|
||||
# What we want:
|
||||
# - If direct call, first try to get it from in memory store.
|
||||
# If the object is promoted to plasma, retry it from plasma API.
|
||||
# - If not direct call, directly use plasma API to get it.
|
||||
user_future = loop.create_future()
|
||||
|
||||
# We have three future objects here.
|
||||
# user_future is directly returned to the user from this function.
|
||||
# and it will be eventually fulfilled by the final result.
|
||||
# inner_future is the first attempt to retrieve the object. It can be
|
||||
# fulfilled by either core_worker.get_async or plasma_api.as_future.
|
||||
# When inner_future completes, done_callback will be invoked. This
|
||||
# callback sets the final object in user_future if the object hasn't
|
||||
# been promoted by plasma, otherwise it will retry from plasma.
|
||||
# retry_plasma_future is only created when we are getting objects that are
|
||||
# promoted to plasma. It will also invoke the done_callback when it's
|
||||
# fulfilled.
|
||||
|
||||
def done_callback(future):
|
||||
result = future.result()
|
||||
# Result from async plasma, transparently pass it to user future
|
||||
if isinstance(future, PlasmaObjectFuture):
|
||||
if isinstance(result, ray.exceptions.RayTaskError):
|
||||
ray.worker.last_task_error_raise_time = time.time()
|
||||
user_future.set_exception(result.as_instanceof_cause())
|
||||
else:
|
||||
user_future.set_result(result)
|
||||
else:
|
||||
# Result from direct call.
|
||||
assert isinstance(result, AsyncGetResponse), result
|
||||
if result.plasma_fallback_id is None:
|
||||
# If this future has result set already, we just need to
|
||||
# skip the set result/exception procedure.
|
||||
if user_future.done():
|
||||
return
|
||||
|
||||
if isinstance(result.result, ray.exceptions.RayTaskError):
|
||||
ray.worker.last_task_error_raise_time = time.time()
|
||||
user_future.set_exception(
|
||||
result.result.as_instanceof_cause())
|
||||
else:
|
||||
user_future.set_result(result.result)
|
||||
else:
|
||||
# Schedule plasma to async get, use the same callback.
|
||||
retry_plasma_future = as_future(result.plasma_fallback_id)
|
||||
retry_plasma_future.add_done_callback(done_callback)
|
||||
# A hack to keep reference to the future so it doesn't get GC.
|
||||
user_future.retry_plasma_future = retry_plasma_future
|
||||
|
||||
inner_future = loop.create_future()
|
||||
# We must add the done_callback before sending to in_memory_store_get
|
||||
inner_future.add_done_callback(done_callback)
|
||||
core_worker.in_memory_store_get_async(object_id, inner_future)
|
||||
# A hack to keep reference to inner_future so it doesn't get GC.
|
||||
user_future.inner_future = inner_future
|
||||
future = loop.create_future()
|
||||
core_worker.get_async(object_id, future)
|
||||
# A hack to keep a reference to the object ID for ref counting.
|
||||
user_future.object_id = object_id
|
||||
|
||||
return user_future
|
||||
future.object_id = object_id
|
||||
return future
|
||||
|
||||
@@ -1,45 +0,0 @@
|
||||
import asyncio
|
||||
|
||||
import ray
|
||||
from ray.experimental.async_plasma import PlasmaEventHandler
|
||||
from ray.services import logger
|
||||
|
||||
handler = None
|
||||
|
||||
|
||||
def init():
    """Create the module-level plasma event handler if it does not exist yet.

    Requires Ray to be initialized. On the first call a ``PlasmaEventHandler``
    bound to the current asyncio event loop and the global worker is stored in
    the module-level ``handler`` and registered with the core worker as its
    plasma-added callback. Later calls leave the existing handler untouched.
    """
    assert ray.is_initialized(), "Please call ray.init before async_api.init"

    global handler
    if handler is not None:
        # Already set up; nothing to do.
        return

    current_worker = ray.worker.global_worker
    handler = PlasmaEventHandler(asyncio.get_event_loop(), current_worker)
    current_worker.core_worker.set_plasma_added_callback(handler)
    logger.debug("AsyncPlasma Connection Created!")
|
||||
|
||||
|
||||
def as_future(object_id):
    """Return a future that waits on ``object_id``.

    The module-level handler is created on demand through ``init()`` when it
    has not been set up yet; the call is then delegated to that handler.

    Args:
        object_id: A Ray object_id.

    Returns:
        PlasmaObjectFuture: A future object that waits on the object_id.
    """
    if handler is None:
        # Lazily set up the handler on first use.
        init()
    make_future = handler.as_future
    return make_future(object_id)
|
||||
|
||||
|
||||
def shutdown():
    """Manually shut down the async API.

    Closes the module-level handler, if one exists, and resets it to ``None``
    so that a later ``init()`` builds a fresh one. Calling this when no
    handler exists is a no-op.
    """
    global handler
    if handler is None:
        return
    handler.close()
    handler = None
|
||||
@@ -1,70 +0,0 @@
|
||||
import asyncio
|
||||
|
||||
import ray
|
||||
from ray.services import logger
|
||||
from collections import defaultdict
|
||||
|
||||
|
||||
class PlasmaObjectFuture(asyncio.Future):
    """An ``asyncio.Future`` subclass for futures that wait on a Plasma object.

    It adds no behavior of its own; it only gives these futures a distinct
    type.
    """
|
||||
|
||||
|
||||
class PlasmaEventHandler:
    """Fulfil asyncio futures once the Ray objects they wait on are available.

    For every object ID the handler keeps the list of futures still waiting
    on it. ``as_future`` registers a new future, and ``_complete_future``
    fetches the object and hands its value to every registered future.
    """

    def __init__(self, loop, worker):
        """Store the event loop and worker and start with no waiting futures.

        Args:
            loop: Event loop that futures created by ``as_future`` are
                bound to.
            worker: Object providing ``get_objects`` and ``core_worker``.
        """
        super().__init__()
        self._loop = loop
        self._worker = worker
        # object_id -> futures that have not been completed yet.
        self._waiting_dict = defaultdict(list)

    def _complete_future(self, ray_object_id):
        """Fetch ``ray_object_id`` and set it as the result of its futures.

        Does nothing when no future is registered for the ID. Futures that
        can no longer accept a result are skipped with a debug message.
        """
        # TODO(ilr): Consider race condition between popping from the
        # waiting_dict and as_future appending to the waiting_dict's list.
        logger.debug(
            "Completing plasma futures for object id {}".format(ray_object_id))
        # Use a membership test rather than indexing: indexing a defaultdict
        # would insert an empty list for an ID nobody is waiting on.
        if ray_object_id not in self._waiting_dict:
            return
        obj = self._worker.get_objects([ray_object_id], timeout=0)[0]
        futures = self._waiting_dict.pop(ray_object_id)
        for fut in futures:
            try:
                fut.set_result(obj)
            except asyncio.InvalidStateError:
                # Avoid issues where process_notifications
                # and check_immediately both get executed
                logger.debug("Failed to set result for future {}. "
                             "Most likely already set.".format(fut))

    def close(self):
        """Clean up this handler: cancel and forget every pending future."""
        for futures in self._waiting_dict.values():
            for fut in futures:
                fut.cancel()
        # Drop the cancelled futures so a late completion callback finds
        # nothing to do and the futures are not kept alive by this handler.
        self._waiting_dict.clear()

    def check_immediately(self, object_id):
        """Complete the waiting futures now if ``object_id`` is already ready.

        Readiness is probed with a zero-timeout ``ray.wait``.

        Returns:
            bool: True if the object was ready and its futures were handed
                to ``_complete_future``, False otherwise.
        """
        ready, _ = ray.wait([object_id], timeout=0)
        if not ready:
            return False
        self._complete_future(object_id)
        return True

    def as_future(self, object_id, check_ready=True):
        """Turn an object_id into a Future object.

        Args:
            object_id: A Ray object_id.
            check_ready (bool): If true, check whether the object_id is
                already ready and, if so, complete the future right away
                instead of subscribing.

        Returns:
            PlasmaObjectFuture: A future object that waits on the object_id.

        Raises:
            TypeError: If ``object_id`` is not a ``ray.ObjectID``.
        """
        if not isinstance(object_id, ray.ObjectID):
            raise TypeError("Input should be a Ray ObjectID.")

        future = PlasmaObjectFuture(loop=self._loop)
        waiters = self._waiting_dict[object_id]
        waiters.append(future)

        if check_ready and self.check_immediately(object_id):
            # Already fulfilled; the entry was popped by _complete_future,
            # so do not touch the dict again (indexing it would re-create
            # an empty entry) and do not subscribe.
            return future

        if len(waiters) == 1:
            # Only subscribe once per object: the first waiter triggers it.
            self._worker.core_worker.subscribe_to_plasma_object(object_id)
        return future
|
||||
@@ -172,7 +172,6 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
|
||||
|
||||
void GetAsync(const CObjectID &object_id,
|
||||
ray_callback_function success_callback,
|
||||
ray_callback_function fallback_callback,
|
||||
void* python_future)
|
||||
|
||||
CRayStatus PushError(const CJobID &job_id, const c_string &type,
|
||||
@@ -185,10 +184,6 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
|
||||
const double capacity,
|
||||
const CClientID &client_Id)
|
||||
|
||||
void SetPlasmaAddedCallback(plasma_callback_function callback)
|
||||
|
||||
void SubscribeToPlasmaAdd(const CObjectID &object_id)
|
||||
|
||||
cdef cppclass CCoreWorkerOptions "ray::CoreWorkerOptions":
|
||||
CWorkerType worker_type
|
||||
CLanguage language
|
||||
|
||||
@@ -472,3 +472,11 @@ py_test(
|
||||
tags = ["exclusive"],
|
||||
deps = ["//:ray_lib"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_async",
|
||||
size = "medium",
|
||||
srcs = SRCS + ["test_async.py"],
|
||||
tags = ["exclusive"],
|
||||
deps = ["//:ray_lib"],
|
||||
)
|
||||
|
||||
@@ -1,20 +1,19 @@
|
||||
import asyncio
|
||||
import sys
|
||||
import time
|
||||
import pytest
|
||||
|
||||
import numpy as np
|
||||
|
||||
import pytest
|
||||
|
||||
import ray
|
||||
from ray.experimental import async_api
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def init():
|
||||
ray.init(num_cpus=4)
|
||||
async_api.init()
|
||||
asyncio.get_event_loop().set_debug(False)
|
||||
yield
|
||||
async_api.shutdown()
|
||||
ray.shutdown()
|
||||
|
||||
|
||||
@@ -33,7 +32,7 @@ def test_simple(init):
|
||||
time.sleep(1)
|
||||
return np.zeros(1024 * 1024, dtype=np.uint8)
|
||||
|
||||
future = async_api.as_future(f.remote())
|
||||
future = f.remote().as_future()
|
||||
result = asyncio.get_event_loop().run_until_complete(future)
|
||||
assert isinstance(result, np.ndarray)
|
||||
|
||||
@@ -41,7 +40,7 @@ def test_simple(init):
|
||||
def test_gather(init):
|
||||
loop = asyncio.get_event_loop()
|
||||
tasks = gen_tasks()
|
||||
futures = [async_api.as_future(obj_id) for obj_id in tasks]
|
||||
futures = [obj_id.as_future() for obj_id in tasks]
|
||||
results = loop.run_until_complete(asyncio.gather(*futures))
|
||||
assert all(a[0] == b[0] for a, b in zip(results, ray.get(tasks)))
|
||||
|
||||
@@ -49,7 +48,7 @@ def test_gather(init):
|
||||
def test_wait(init):
|
||||
loop = asyncio.get_event_loop()
|
||||
tasks = gen_tasks()
|
||||
futures = [async_api.as_future(obj_id) for obj_id in tasks]
|
||||
futures = [obj_id.as_future() for obj_id in tasks]
|
||||
results, _ = loop.run_until_complete(asyncio.wait(futures))
|
||||
assert set(results) == set(futures)
|
||||
|
||||
@@ -57,7 +56,7 @@ def test_wait(init):
|
||||
def test_wait_timeout(init):
|
||||
loop = asyncio.get_event_loop()
|
||||
tasks = gen_tasks(10)
|
||||
futures = [async_api.as_future(obj_id) for obj_id in tasks]
|
||||
futures = [obj_id.as_future() for obj_id in tasks]
|
||||
fut = asyncio.wait(futures, timeout=5)
|
||||
results, _ = loop.run_until_complete(fut)
|
||||
assert list(results)[0] == futures[0]
|
||||
@@ -75,12 +74,7 @@ def test_gather_mixup(init):
|
||||
await asyncio.sleep(n * 0.1)
|
||||
return n, np.zeros(1024 * 1024, dtype=np.uint8)
|
||||
|
||||
tasks = [
|
||||
async_api.as_future(f.remote(1)),
|
||||
g(2),
|
||||
async_api.as_future(f.remote(3)),
|
||||
g(4)
|
||||
]
|
||||
tasks = [f.remote(1).as_future(), g(2), f.remote(3).as_future(), g(4)]
|
||||
results = loop.run_until_complete(asyncio.gather(*tasks))
|
||||
assert [result[0] for result in results] == [1, 2, 3, 4]
|
||||
|
||||
@@ -100,11 +94,31 @@ def test_wait_mixup(init):
|
||||
|
||||
return asyncio.ensure_future(_g(n))
|
||||
|
||||
tasks = [
|
||||
async_api.as_future(f.remote(0.1)),
|
||||
g(7),
|
||||
async_api.as_future(f.remote(5)),
|
||||
g(2)
|
||||
]
|
||||
tasks = [f.remote(0.1).as_future(), g(7), f.remote(5).as_future(), g(2)]
|
||||
ready, _ = loop.run_until_complete(asyncio.wait(tasks, timeout=4))
|
||||
assert set(ready) == {tasks[0], tasks[-1]}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"ray_start_regular_shared", [{
|
||||
"object_store_memory": 100 * 1024 * 1024,
|
||||
}],
|
||||
indirect=True)
|
||||
async def test_garbage_collection(ray_start_regular_shared):
|
||||
# This is a regression test for
|
||||
# https://github.com/ray-project/ray/issues/9134
|
||||
|
||||
@ray.remote
|
||||
def f():
|
||||
return np.zeros(40 * 1024 * 1024, dtype=np.uint8)
|
||||
|
||||
for _ in range(10):
|
||||
await f.remote()
|
||||
for _ in range(10):
|
||||
put_id = ray.put(np.zeros(40 * 1024 * 1024, dtype=np.uint8))
|
||||
await put_id
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(pytest.main(["-v", __file__]))
|
||||
@@ -1,8 +1,9 @@
|
||||
# coding: utf-8
|
||||
import asyncio
|
||||
import threading
|
||||
import pytest
|
||||
import sys
|
||||
import threading
|
||||
|
||||
import pytest
|
||||
|
||||
import ray
|
||||
from ray.test_utils import SignalActor
|
||||
@@ -113,10 +114,6 @@ async def test_asyncio_get(ray_start_regular_shared, event_loop):
|
||||
asyncio.set_event_loop(loop)
|
||||
loop.set_debug(True)
|
||||
|
||||
# This is needed for async plasma
|
||||
from ray.experimental.async_api import init
|
||||
init()
|
||||
|
||||
# Test Async Plasma
|
||||
@ray.remote
|
||||
def task():
|
||||
|
||||
Reference in New Issue
Block a user