mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 03:02:56 +08:00
44 lines
1.0 KiB
Python
44 lines
1.0 KiB
Python
from contextlib import contextmanager
|
|
|
|
import ray as real_ray
|
|
import ray.util.client.server.server as ray_client_server
|
|
from ray.util.client import ray
|
|
|
|
|
|
@contextmanager
def ray_start_client_server():
    """Run a local client/server pair and yield just the connected client API.

    This is a convenience wrapper over ``ray_start_client_server_pair`` for
    tests that never need to touch the server object; setup and teardown are
    entirely delegated to that context manager.
    """
    with ray_start_client_server_pair() as (client_api, _unused_server):
        yield client_api
|
|
|
|
|
|
@contextmanager
def ray_start_client_server_pair():
    """Start a Ray client server on localhost:50051 and connect the client to it.

    Sets ``ray._inside_client_test`` for the duration of the context, starts
    the server via ``ray_client_server.serve`` and connects the
    ``ray.util.client`` API object to it.

    Yields:
        A ``(ray, server)`` tuple: the connected client API object and the
        server object returned by ``ray_client_server.serve``.

    Teardown always runs, even if setup fails partway:
      * the flag is reset,
      * the client is disconnected (only if ``connect`` succeeded),
      * the server is stopped (only if ``serve`` succeeded),
    so a failed setup or a failing ``disconnect`` does not leave the server
    running or the flag set for later tests.
    """
    ray._inside_client_test = True
    try:
        server = ray_client_server.serve("localhost:50051")
        try:
            ray.connect("localhost:50051")
            try:
                yield ray, server
            finally:
                # Same order as before: clear the flag, then disconnect.
                ray._inside_client_test = False
                ray.disconnect()
        finally:
            # Runs even if connect() or disconnect() raised.
            # NOTE(review): if serve() returns a gRPC server, 0 means "no grace
            # period" -- confirm.
            server.stop(0)
    finally:
        # Covers the paths where serve() or connect() raised before the inner
        # cleanup ran; a harmless repeat otherwise.
        ray._inside_client_test = False
|
|
|
|
|
|
@contextmanager
def ray_start_cluster_client_server_pair(address):
    """Start a Ray client server wired to an existing cluster and connect to it.

    Args:
        address: Cluster address. The connect handler passed to the server
            forwards it to ``real_ray.init(address=address)``.

    Yields:
        A ``(ray, server)`` tuple: the connected ``ray.util.client`` API object
        and the server object returned by ``ray_client_server.serve``.

    Teardown always runs, even if setup fails partway:
      * the ``_inside_client_test`` flag is reset,
      * the client is disconnected (only if ``connect`` succeeded),
      * the server is stopped (only if ``serve`` succeeded).
    """
    ray._inside_client_test = True

    def ray_connect_handler():
        # Handed to serve() rather than called here. When invoked, it
        # initializes the real ``ray`` module against the given cluster address.
        real_ray.init(address=address)

    try:
        server = ray_client_server.serve(
            "localhost:50051", ray_connect_handler=ray_connect_handler)
        try:
            ray.connect("localhost:50051")
            try:
                yield ray, server
            finally:
                # Same order as before: clear the flag, then disconnect.
                ray._inside_client_test = False
                ray.disconnect()
        finally:
            # Runs even if connect() or disconnect() raised.
            # NOTE(review): if serve() returns a gRPC server, 0 means "no grace
            # period" -- confirm.
            server.stop(0)
    finally:
        # Covers the paths where serve() or connect() raised before the inner
        # cleanup ran; a harmless repeat otherwise.
        ray._inside_client_test = False
|