diff --git a/python/ray/async_compat.py b/python/ray/async_compat.py index ff1cd595a..1428e765c 100644 --- a/python/ray/async_compat.py +++ b/python/ray/async_compat.py @@ -75,6 +75,11 @@ def get_async(object_id): # Result from direct call. assert isinstance(result, AsyncGetResponse), result if result.plasma_fallback_id is None: + # If this future has result set already, we just need to + # skip the set result/exception procedure. + if user_future.done(): + return + if isinstance(result.result, ray.exceptions.RayTaskError): ray.worker.last_task_error_raise_time = time.time() user_future.set_exception( diff --git a/python/ray/tests/test_asyncio.py b/python/ray/tests/test_asyncio.py index 657a68664..da3dac49d 100644 --- a/python/ray/tests/test_asyncio.py +++ b/python/ray/tests/test_asyncio.py @@ -5,6 +5,7 @@ import pytest import sys import ray +from ray.test_utils import SignalActor def test_asyncio_actor(ray_start_regular_shared): @@ -177,6 +178,26 @@ def test_asyncio_actor_async_get(ray_start_regular_shared): assert ray.get(getter.plasma_get.remote([plasma_object])) == 2 +@pytest.mark.asyncio +async def test_asyncio_double_await(ray_start_regular_shared): + # This is a regression test for + # https://github.com/ray-project/ray/issues/8841 + + signal = SignalActor.remote() + waiting = signal.wait.remote() + + future = waiting.as_future() + with pytest.raises(asyncio.TimeoutError): + await asyncio.wait_for(future, timeout=0.1) + assert future.cancelled() + + # We are explicitly waiting multiple times here to test asyncio state + # override. + await signal.send.remote() + await waiting + await waiting + + if __name__ == "__main__": import pytest sys.exit(pytest.main(["-v", __file__]))