mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 19:32:11 +08:00
[ray_client] convert things registered for ray into ray_client (#13639)
This commit is contained in:
@@ -2,6 +2,9 @@ import os
|
||||
from contextlib import contextmanager
|
||||
from functools import wraps
|
||||
|
||||
# Attr set on func defs to mark they have been converted to client mode.
|
||||
RAY_CLIENT_MODE_ATTR = "__ray_client_mode_key__"
|
||||
|
||||
client_mode_enabled = os.environ.get("RAY_CLIENT_MODE", "0") == "1"
|
||||
|
||||
_client_hook_enabled = True
|
||||
@@ -34,16 +37,54 @@ def disable_client_hook():
|
||||
|
||||
|
||||
def client_mode_hook(func):
|
||||
"""
|
||||
Decorator for ray module methods to delegate to ray client
|
||||
"""
|
||||
"""Decorator for ray module methods to delegate to ray client"""
|
||||
from ray.util.client import ray
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
global _client_hook_enabled
|
||||
if client_mode_enabled and _client_hook_enabled:
|
||||
if client_mode_should_convert():
|
||||
return getattr(ray, func.__name__)(*args, **kwargs)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def client_mode_should_convert():
|
||||
global _client_hook_enabled
|
||||
return client_mode_enabled and _client_hook_enabled
|
||||
|
||||
|
||||
def client_mode_convert_function(func_cls, in_args, in_kwargs, **kwargs):
|
||||
"""Runs a preregistered ray RemoteFunction through the ray client.
|
||||
|
||||
The common case for this is to transparently convert that RemoteFunction
|
||||
to a ClientRemoteFunction. This happens in circumstances where the
|
||||
RemoteFunction is declared early, in a library and only then is Ray used in
|
||||
client mode -- nescessitating a conversion.
|
||||
"""
|
||||
from ray.util.client import ray
|
||||
|
||||
key = getattr(func_cls, RAY_CLIENT_MODE_ATTR, None)
|
||||
if key is None:
|
||||
key = ray._convert_function(func_cls)
|
||||
setattr(func_cls, RAY_CLIENT_MODE_ATTR, key)
|
||||
client_func = ray._get_converted(key)
|
||||
return client_func._remote(in_args, in_kwargs, **kwargs)
|
||||
|
||||
|
||||
def client_mode_convert_actor(actor_cls, in_args, in_kwargs, **kwargs):
|
||||
"""Runs a preregistered actor class on the ray client
|
||||
|
||||
The common case for this decorator is for instantiating an ActorClass
|
||||
transparently as a ClientActorClass. This happens in circumstances where
|
||||
the ActorClass is declared early, in a library and only then is Ray used in
|
||||
client mode -- nescessitating a conversion.
|
||||
"""
|
||||
from ray.util.client import ray
|
||||
|
||||
key = getattr(actor_cls, RAY_CLIENT_MODE_ATTR, None)
|
||||
if key is None:
|
||||
key = ray._convert_actor(actor_cls)
|
||||
setattr(actor_cls, RAY_CLIENT_MODE_ATTR, key)
|
||||
client_actor = ray._get_converted(key)
|
||||
return client_actor._remote(in_args, in_kwargs, **kwargs)
|
||||
|
||||
@@ -13,6 +13,8 @@ from ray.util.placement_group import (
|
||||
from ray import ActorClassID, Language
|
||||
from ray._raylet import PythonFunctionDescriptor
|
||||
from ray._private.client_mode_hook import client_mode_hook
|
||||
from ray._private.client_mode_hook import client_mode_should_convert
|
||||
from ray._private.client_mode_hook import client_mode_convert_actor
|
||||
from ray import cross_language
|
||||
from ray.util.inspect import (
|
||||
is_function_or_method,
|
||||
@@ -553,6 +555,29 @@ class ActorClass:
|
||||
if max_concurrency < 1:
|
||||
raise ValueError("max_concurrency must be >= 1")
|
||||
|
||||
if client_mode_should_convert():
|
||||
return client_mode_convert_actor(
|
||||
self,
|
||||
args,
|
||||
kwargs,
|
||||
num_cpus=num_cpus,
|
||||
num_gpus=num_gpus,
|
||||
memory=memory,
|
||||
object_store_memory=object_store_memory,
|
||||
resources=resources,
|
||||
accelerator_type=accelerator_type,
|
||||
max_concurrency=max_concurrency,
|
||||
max_restarts=max_restarts,
|
||||
max_task_retries=max_task_retries,
|
||||
name=name,
|
||||
lifetime=lifetime,
|
||||
placement_group=placement_group,
|
||||
placement_group_bundle_index=placement_group_bundle_index,
|
||||
placement_group_capture_child_tasks=(
|
||||
placement_group_capture_child_tasks),
|
||||
override_environment_variables=(
|
||||
override_environment_variables))
|
||||
|
||||
worker = ray.worker.global_worker
|
||||
worker.check_connected()
|
||||
|
||||
|
||||
@@ -4,6 +4,8 @@ from functools import wraps
|
||||
from ray import cloudpickle as pickle
|
||||
from ray._raylet import PythonFunctionDescriptor
|
||||
from ray import cross_language, Language
|
||||
from ray._private.client_mode_hook import client_mode_convert_function
|
||||
from ray._private.client_mode_hook import client_mode_should_convert
|
||||
from ray.util.placement_group import (
|
||||
PlacementGroup,
|
||||
check_placement_group_index,
|
||||
@@ -181,6 +183,26 @@ class RemoteFunction:
|
||||
override_environment_variables=None,
|
||||
name=""):
|
||||
"""Submit the remote function for execution."""
|
||||
if client_mode_should_convert():
|
||||
return client_mode_convert_function(
|
||||
self,
|
||||
args,
|
||||
kwargs,
|
||||
num_returns=num_returns,
|
||||
num_cpus=num_cpus,
|
||||
num_gpus=num_gpus,
|
||||
memory=memory,
|
||||
object_store_memory=object_store_memory,
|
||||
accelerator_type=accelerator_type,
|
||||
resources=resources,
|
||||
max_retries=max_retries,
|
||||
placement_group=placement_group,
|
||||
placement_group_bundle_index=placement_group_bundle_index,
|
||||
placement_group_capture_child_tasks=(
|
||||
placement_group_capture_child_tasks),
|
||||
override_environment_variables=override_environment_variables,
|
||||
name=name)
|
||||
|
||||
worker = ray.worker.global_worker
|
||||
worker.check_connected()
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
import pytest
|
||||
|
||||
import time
|
||||
import random
|
||||
import sys
|
||||
|
||||
import ray.util.client.server.server as ray_client_server
|
||||
@@ -9,6 +10,54 @@ import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||
|
||||
from ray.util.client import RayAPIStub
|
||||
|
||||
import ray
|
||||
|
||||
|
||||
@ray.remote
|
||||
def hello_world():
|
||||
c1 = complex_task.remote(random.randint(1, 10))
|
||||
c2 = complex_task.remote(random.randint(1, 10))
|
||||
return sum(ray.get([c1, c2]))
|
||||
|
||||
|
||||
@ray.remote
|
||||
def complex_task(value):
|
||||
time.sleep(1)
|
||||
return value * 10
|
||||
|
||||
|
||||
@ray.remote
|
||||
class C:
|
||||
def __init__(self, x):
|
||||
self.val = x
|
||||
|
||||
def double(self):
|
||||
self.val += self.val
|
||||
|
||||
def get(self):
|
||||
return self.val
|
||||
|
||||
|
||||
def test_basic_preregister():
|
||||
from ray.util.client import ray
|
||||
server, _ = ray_client_server.init_and_serve("localhost:50051")
|
||||
try:
|
||||
ray.connect("localhost:50051")
|
||||
val = ray.get(hello_world.remote())
|
||||
print(val)
|
||||
assert val >= 20
|
||||
assert val <= 200
|
||||
c = C.remote(3)
|
||||
x = c.double.remote()
|
||||
y = c.double.remote()
|
||||
ray.wait([x, y])
|
||||
val = ray.get(c.get.remote())
|
||||
assert val == 12
|
||||
finally:
|
||||
ray.disconnect()
|
||||
ray_client_server.shutdown_with_server(server)
|
||||
time.sleep(2)
|
||||
|
||||
|
||||
def test_num_clients():
|
||||
# Tests num clients reporting; useful if you want to build an app that
|
||||
|
||||
@@ -4,6 +4,8 @@ and the overall ray module API.
|
||||
from ray.util.client.runtime_context import ClientWorkerPropertyAPI
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from ray.actor import ActorClass
|
||||
from ray.remote_function import RemoteFunction
|
||||
from ray.util.client.common import ClientStub
|
||||
from ray.util.client.common import ClientActorHandle
|
||||
from ray.util.client.common import ClientObjectRef
|
||||
@@ -265,6 +267,18 @@ class ClientAPI:
|
||||
"""Hook for internal_kv._internal_kv_list."""
|
||||
return self.worker.internal_kv_list(as_bytes(prefix))
|
||||
|
||||
def _convert_actor(self, actor: "ActorClass") -> str:
|
||||
"""Register a ClientActorClass for the ActorClass and return a UUID"""
|
||||
return self.worker._convert_actor(actor)
|
||||
|
||||
def _convert_function(self, func: "RemoteFunction") -> str:
|
||||
"""Register a ClientRemoteFunc for the ActorClass and return a UUID"""
|
||||
return self.worker._convert_function(func)
|
||||
|
||||
def _get_converted(self, key: str) -> "ClientStub":
|
||||
"""Given a UUID, return the converted object"""
|
||||
return self.worker._get_converted(key)
|
||||
|
||||
def __getattr__(self, key: str):
|
||||
if not key.startswith("_"):
|
||||
raise NotImplementedError(
|
||||
|
||||
@@ -82,7 +82,11 @@ class ClientRemoteFunc(ClientStub):
|
||||
def options(self, **kwargs):
|
||||
return OptionWrapper(self, kwargs)
|
||||
|
||||
def _remote(self, args=[], kwargs={}, **option_args):
|
||||
def _remote(self, args=None, kwargs=None, **option_args):
|
||||
if args is None:
|
||||
args = []
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
return self.options(**option_args).remote(*args, **kwargs)
|
||||
|
||||
def __repr__(self):
|
||||
@@ -150,7 +154,11 @@ class ClientActorClass(ClientStub):
|
||||
def options(self, **kwargs):
|
||||
return ActorOptionWrapper(self, kwargs)
|
||||
|
||||
def _remote(self, args=[], kwargs={}, **option_args):
|
||||
def _remote(self, args=None, kwargs=None, **option_args):
|
||||
if args is None:
|
||||
args = []
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
return self.options(**option_args).remote(*args, **kwargs)
|
||||
|
||||
def __repr__(self):
|
||||
@@ -230,7 +238,11 @@ class ClientRemoteMethod(ClientStub):
|
||||
def options(self, **kwargs):
|
||||
return OptionWrapper(self, kwargs)
|
||||
|
||||
def _remote(self, args=[], kwargs={}, **option_args):
|
||||
def _remote(self, args=None, kwargs=None, **option_args):
|
||||
if args is None:
|
||||
args = []
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
return self.options(**option_args).remote(*args, **kwargs)
|
||||
|
||||
def _prepare_client_task(self) -> ray_client_pb2.ClientTask:
|
||||
|
||||
@@ -46,9 +46,10 @@ def validate_options(
|
||||
raise TypeError(f"Invalid option passed to remote(): {k}")
|
||||
validator = options[k]
|
||||
if len(validator) != 0:
|
||||
if not isinstance(v, validator[0]):
|
||||
raise ValueError(validator[2])
|
||||
if not validator[1](v):
|
||||
raise ValueError(validator[2])
|
||||
if v is not None:
|
||||
if not isinstance(v, validator[0]):
|
||||
raise ValueError(validator[2])
|
||||
if not validator[1](v):
|
||||
raise ValueError(validator[2])
|
||||
out[k] = v
|
||||
return out
|
||||
|
||||
@@ -13,6 +13,7 @@ from typing import Dict
|
||||
from typing import List
|
||||
from typing import Tuple
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import grpc
|
||||
|
||||
@@ -22,12 +23,19 @@ import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
|
||||
from ray.util.client.client_pickler import convert_to_arg
|
||||
from ray.util.client.client_pickler import dumps_from_client
|
||||
from ray.util.client.client_pickler import loads_from_server
|
||||
from ray.util.client.common import ClientStub
|
||||
from ray.util.client.common import ClientActorHandle
|
||||
from ray.util.client.common import ClientActorClass
|
||||
from ray.util.client.common import ClientRemoteFunc
|
||||
from ray.util.client.common import ClientActorRef
|
||||
from ray.util.client.common import ClientObjectRef
|
||||
from ray.util.client.dataclient import DataClient
|
||||
from ray.util.client.logsclient import LogstreamClient
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ray.actor import ActorClass
|
||||
from ray.remote_function import RemoteFunction
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
INITIAL_TIMEOUT_SEC = 5
|
||||
@@ -62,6 +70,7 @@ class Worker:
|
||||
self.channel = None
|
||||
self._conn_state = grpc.ChannelConnectivity.IDLE
|
||||
self._client_id = make_client_id()
|
||||
self._converted: Dict[str, ClientStub] = {}
|
||||
if secure:
|
||||
credentials = grpc.ssl_channel_credentials()
|
||||
self.channel = grpc.secure_channel(conn_str, credentials)
|
||||
@@ -371,6 +380,47 @@ class Worker:
|
||||
def is_connected(self) -> bool:
|
||||
return self._conn_state == grpc.ChannelConnectivity.READY
|
||||
|
||||
def _convert_actor(self, actor: "ActorClass") -> str:
|
||||
"""Register a ClientActorClass for the ActorClass and return a UUID"""
|
||||
key = uuid.uuid4().hex
|
||||
md = actor.__ray_metadata__
|
||||
cls = md.modified_class
|
||||
self._converted[key] = ClientActorClass(
|
||||
cls,
|
||||
options={
|
||||
"max_restarts": md.max_restarts,
|
||||
"max_task_retries": md.max_task_retries,
|
||||
"num_cpus": md.num_cpus,
|
||||
"num_gpus": md.num_gpus,
|
||||
"memory": md.memory,
|
||||
"object_store_memory": md.object_store_memory,
|
||||
"resources": md.resources,
|
||||
"accelerator_type": md.accelerator_type,
|
||||
})
|
||||
return key
|
||||
|
||||
def _convert_function(self, func: "RemoteFunction") -> str:
|
||||
"""Register a ClientRemoteFunc for the ActorClass and return a UUID"""
|
||||
key = uuid.uuid4().hex
|
||||
f = func._function
|
||||
self._converted[key] = ClientRemoteFunc(
|
||||
f,
|
||||
options={
|
||||
"num_cpus": func._num_cpus,
|
||||
"num_gpus": func._num_gpus,
|
||||
"max_calls": func._max_calls,
|
||||
"max_retries": func._max_retries,
|
||||
"resources": func._resources,
|
||||
"accelerator_type": func._accelerator_type,
|
||||
"num_returns": func._num_returns,
|
||||
"memory": func._memory
|
||||
})
|
||||
return key
|
||||
|
||||
def _get_converted(self, key: str) -> "ClientStub":
|
||||
"""Given a UUID, return the converted object"""
|
||||
return self._converted[key]
|
||||
|
||||
|
||||
def make_client_id() -> str:
|
||||
id = uuid.uuid4()
|
||||
|
||||
@@ -1768,7 +1768,6 @@ def make_decorator(num_returns=None,
|
||||
return decorator
|
||||
|
||||
|
||||
@client_mode_hook
|
||||
def remote(*args, **kwargs):
|
||||
"""Defines a remote function or an actor class.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user