mirror of
https://github.com/wassname/ray.git
synced 2026-06-30 11:26:24 +08:00
Fix asyncio actor race condition (#7335)
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
|
||||
@@ -300,12 +301,6 @@ def test_dynamic_res_concurrent_res_increment(ray_start_cluster):
|
||||
num_nodes = 5
|
||||
TIMEOUT_DURATION = 1
|
||||
|
||||
# Create a object ID to have the task wait on
|
||||
WAIT_OBJECT_ID_STR = ("a" * 20).encode("ascii")
|
||||
|
||||
# Create a object ID to signal that the task is running
|
||||
TASK_RUNNING_OBJECT_ID_STR = ("b" * 20).encode("ascii")
|
||||
|
||||
for i in range(num_nodes):
|
||||
cluster.add_node()
|
||||
|
||||
@@ -325,29 +320,42 @@ def test_dynamic_res_concurrent_res_increment(ray_start_cluster):
|
||||
|
||||
# Task to hold the resource till the driver signals to finish
|
||||
@ray.remote
|
||||
def wait_func(running_oid, wait_oid):
|
||||
# Signal that the task is running
|
||||
ray.worker.global_worker.put_object(1, ray.ObjectID(running_oid))
|
||||
# Make the task wait till signalled by driver
|
||||
ray.get(ray.ObjectID(wait_oid))
|
||||
def wait_func(running_signal, finish_signal):
|
||||
# Signal that the task is running.
|
||||
ray.get(running_signal.send.remote())
|
||||
# Wait until signaled by driver.
|
||||
ray.get(finish_signal.wait.remote())
|
||||
|
||||
@ray.remote
|
||||
def test_func():
|
||||
return 1
|
||||
|
||||
@ray.remote(num_cpus=0)
|
||||
class Signal:
|
||||
def __init__(self):
|
||||
self.ready_event = asyncio.Event()
|
||||
|
||||
def send(self):
|
||||
self.ready_event.set()
|
||||
|
||||
async def wait(self):
|
||||
await self.ready_event.wait()
|
||||
|
||||
running_signal = Signal.remote()
|
||||
finish_signal = Signal.remote()
|
||||
|
||||
# Launch the task with resource requirement of 4, thus the new available
|
||||
# capacity becomes 1
|
||||
task = wait_func._remote(
|
||||
args=[TASK_RUNNING_OBJECT_ID_STR, WAIT_OBJECT_ID_STR],
|
||||
resources={res_name: 4})
|
||||
# Wait till wait_func is launched before updating resource
|
||||
ray.get(ray.ObjectID(TASK_RUNNING_OBJECT_ID_STR))
|
||||
args=[running_signal, finish_signal], resources={res_name: 4})
|
||||
# Wait until wait_func is launched before updating resource
|
||||
ray.get(running_signal.wait.remote())
|
||||
|
||||
# Update the resource capacity
|
||||
ray.get(set_res.remote(res_name, updated_capacity, target_node_id))
|
||||
|
||||
# Signal task to complete
|
||||
ray.worker.global_worker.put_object(1, ray.ObjectID(WAIT_OBJECT_ID_STR))
|
||||
ray.get(finish_signal.send.remote())
|
||||
ray.get(task)
|
||||
|
||||
# Check if scheduler state is consistent by launching a task requiring
|
||||
@@ -379,12 +387,6 @@ def test_dynamic_res_concurrent_res_decrement(ray_start_cluster):
|
||||
num_nodes = 5
|
||||
TIMEOUT_DURATION = 1
|
||||
|
||||
# Create a object ID to have the task wait on
|
||||
WAIT_OBJECT_ID_STR = ("a" * 20).encode("ascii")
|
||||
|
||||
# Create a object ID to signal that the task is running
|
||||
TASK_RUNNING_OBJECT_ID_STR = ("b" * 20).encode("ascii")
|
||||
|
||||
for i in range(num_nodes):
|
||||
cluster.add_node()
|
||||
|
||||
@@ -404,29 +406,42 @@ def test_dynamic_res_concurrent_res_decrement(ray_start_cluster):
|
||||
|
||||
# Task to hold the resource till the driver signals to finish
|
||||
@ray.remote
|
||||
def wait_func(running_oid, wait_oid):
|
||||
# Signal that the task is running
|
||||
ray.worker.global_worker.put_object(1, ray.ObjectID(running_oid))
|
||||
# Make the task wait till signalled by driver
|
||||
ray.get(ray.ObjectID(wait_oid))
|
||||
def wait_func(running_signal, finish_signal):
|
||||
# Signal that the task is running.
|
||||
ray.get(running_signal.send.remote())
|
||||
# Wait until signaled by driver.
|
||||
ray.get(finish_signal.wait.remote())
|
||||
|
||||
@ray.remote
|
||||
def test_func():
|
||||
return 1
|
||||
|
||||
@ray.remote(num_cpus=0)
|
||||
class Signal:
|
||||
def __init__(self):
|
||||
self.ready_event = asyncio.Event()
|
||||
|
||||
def send(self):
|
||||
self.ready_event.set()
|
||||
|
||||
async def wait(self):
|
||||
await self.ready_event.wait()
|
||||
|
||||
running_signal = Signal.remote()
|
||||
finish_signal = Signal.remote()
|
||||
|
||||
# Launch the task with resource requirement of 4, thus the new available
|
||||
# capacity becomes 1
|
||||
task = wait_func._remote(
|
||||
args=[TASK_RUNNING_OBJECT_ID_STR, WAIT_OBJECT_ID_STR],
|
||||
resources={res_name: 4})
|
||||
# Wait till wait_func is launched before updating resource
|
||||
ray.get(ray.ObjectID(TASK_RUNNING_OBJECT_ID_STR))
|
||||
args=[running_signal, finish_signal], resources={res_name: 4})
|
||||
# Wait until wait_func is launched before updating resource
|
||||
ray.get(running_signal.wait.remote())
|
||||
|
||||
# Decrease the resource capacity
|
||||
ray.get(set_res.remote(res_name, updated_capacity, target_node_id))
|
||||
|
||||
# Signal task to complete
|
||||
ray.worker.global_worker.put_object(1, ray.ObjectID(WAIT_OBJECT_ID_STR))
|
||||
ray.get(finish_signal.send.remote())
|
||||
ray.get(task)
|
||||
|
||||
# Check if scheduler state is consistent by launching a task requiring
|
||||
@@ -456,12 +471,6 @@ def test_dynamic_res_concurrent_res_delete(ray_start_cluster):
|
||||
num_nodes = 5
|
||||
TIMEOUT_DURATION = 1
|
||||
|
||||
# Create a object ID to have the task wait on
|
||||
WAIT_OBJECT_ID_STR = ("a" * 20).encode("ascii")
|
||||
|
||||
# Create a object ID to signal that the task is running
|
||||
TASK_RUNNING_OBJECT_ID_STR = ("b" * 20).encode("ascii")
|
||||
|
||||
for i in range(num_nodes):
|
||||
cluster.add_node()
|
||||
|
||||
@@ -486,29 +495,42 @@ def test_dynamic_res_concurrent_res_delete(ray_start_cluster):
|
||||
|
||||
# Task to hold the resource till the driver signals to finish
|
||||
@ray.remote
|
||||
def wait_func(running_oid, wait_oid):
|
||||
# Signal that the task is running
|
||||
ray.worker.global_worker.put_object(1, ray.ObjectID(running_oid))
|
||||
# Make the task wait till signalled by driver
|
||||
ray.get(ray.ObjectID(wait_oid))
|
||||
def wait_func(running_signal, finish_signal):
|
||||
# Signal that the task is running.
|
||||
ray.get(running_signal.send.remote())
|
||||
# Wait until signaled by driver.
|
||||
ray.get(finish_signal.wait.remote())
|
||||
|
||||
@ray.remote
|
||||
def test_func():
|
||||
return 1
|
||||
|
||||
@ray.remote(num_cpus=0)
|
||||
class Signal:
|
||||
def __init__(self):
|
||||
self.ready_event = asyncio.Event()
|
||||
|
||||
def send(self):
|
||||
self.ready_event.set()
|
||||
|
||||
async def wait(self):
|
||||
await self.ready_event.wait()
|
||||
|
||||
running_signal = Signal.remote()
|
||||
finish_signal = Signal.remote()
|
||||
|
||||
# Launch the task with resource requirement of 4, thus the new available
|
||||
# capacity becomes 1
|
||||
task = wait_func._remote(
|
||||
args=[TASK_RUNNING_OBJECT_ID_STR, WAIT_OBJECT_ID_STR],
|
||||
resources={res_name: 4})
|
||||
# Wait till wait_func is launched before updating resource
|
||||
ray.get(ray.ObjectID(TASK_RUNNING_OBJECT_ID_STR))
|
||||
args=[running_signal, finish_signal], resources={res_name: 4})
|
||||
# Wait until wait_func is launched before updating resource
|
||||
ray.get(running_signal.wait.remote())
|
||||
|
||||
# Delete the resource
|
||||
ray.get(delete_res.remote(res_name, target_node_id))
|
||||
|
||||
# Signal task to complete
|
||||
ray.worker.global_worker.put_object(1, ray.ObjectID(WAIT_OBJECT_ID_STR))
|
||||
ray.get(finish_signal.send.remote())
|
||||
ray.get(task)
|
||||
|
||||
# Check if scheduler state is consistent by launching a task requiring
|
||||
|
||||
Reference in New Issue
Block a user