diff --git a/python/ray/_private/client_mode_hook.py b/python/ray/_private/client_mode_hook.py index 3ceef7316..74682f1cf 100644 --- a/python/ray/_private/client_mode_hook.py +++ b/python/ray/_private/client_mode_hook.py @@ -2,6 +2,9 @@ import os from contextlib import contextmanager from functools import wraps +# Attr set on func defs to mark they have been converted to client mode. +RAY_CLIENT_MODE_ATTR = "__ray_client_mode_key__" + client_mode_enabled = os.environ.get("RAY_CLIENT_MODE", "0") == "1" _client_hook_enabled = True @@ -34,16 +37,54 @@ def disable_client_hook(): def client_mode_hook(func): - """ - Decorator for ray module methods to delegate to ray client - """ + """Decorator for ray module methods to delegate to ray client""" from ray.util.client import ray @wraps(func) def wrapper(*args, **kwargs): - global _client_hook_enabled - if client_mode_enabled and _client_hook_enabled: + if client_mode_should_convert(): return getattr(ray, func.__name__)(*args, **kwargs) return func(*args, **kwargs) return wrapper + + +def client_mode_should_convert(): + global _client_hook_enabled + return client_mode_enabled and _client_hook_enabled + + +def client_mode_convert_function(func_cls, in_args, in_kwargs, **kwargs): + """Runs a preregistered ray RemoteFunction through the ray client. + + The common case for this is to transparently convert that RemoteFunction + to a ClientRemoteFunction. This happens in circumstances where the + RemoteFunction is declared early, in a library and only then is Ray used in + client mode -- nescessitating a conversion. + """ + from ray.util.client import ray + + key = getattr(func_cls, RAY_CLIENT_MODE_ATTR, None) + if key is None: + key = ray._convert_function(func_cls) + setattr(func_cls, RAY_CLIENT_MODE_ATTR, key) + client_func = ray._get_converted(key) + return client_func._remote(in_args, in_kwargs, **kwargs) + + +def client_mode_convert_actor(actor_cls, in_args, in_kwargs, **kwargs): + """Runs a preregistered actor class on the ray client + + The common case for this decorator is for instantiating an ActorClass + transparently as a ClientActorClass. This happens in circumstances where + the ActorClass is declared early, in a library and only then is Ray used in + client mode -- nescessitating a conversion. + """ + from ray.util.client import ray + + key = getattr(actor_cls, RAY_CLIENT_MODE_ATTR, None) + if key is None: + key = ray._convert_actor(actor_cls) + setattr(actor_cls, RAY_CLIENT_MODE_ATTR, key) + client_actor = ray._get_converted(key) + return client_actor._remote(in_args, in_kwargs, **kwargs) diff --git a/python/ray/actor.py b/python/ray/actor.py index 7ff9f1f33..b24c04a10 100644 --- a/python/ray/actor.py +++ b/python/ray/actor.py @@ -13,6 +13,8 @@ from ray.util.placement_group import ( from ray import ActorClassID, Language from ray._raylet import PythonFunctionDescriptor from ray._private.client_mode_hook import client_mode_hook +from ray._private.client_mode_hook import client_mode_should_convert +from ray._private.client_mode_hook import client_mode_convert_actor from ray import cross_language from ray.util.inspect import ( is_function_or_method, @@ -553,6 +555,29 @@ class ActorClass: if max_concurrency < 1: raise ValueError("max_concurrency must be >= 1") + if client_mode_should_convert(): + return client_mode_convert_actor( + self, + args, + kwargs, + num_cpus=num_cpus, + num_gpus=num_gpus, + memory=memory, + object_store_memory=object_store_memory, + resources=resources, + accelerator_type=accelerator_type, + max_concurrency=max_concurrency, + max_restarts=max_restarts, + max_task_retries=max_task_retries, + name=name, + lifetime=lifetime, + placement_group=placement_group, + placement_group_bundle_index=placement_group_bundle_index, + placement_group_capture_child_tasks=( + placement_group_capture_child_tasks), + override_environment_variables=( + override_environment_variables)) + worker = ray.worker.global_worker worker.check_connected() diff --git a/python/ray/remote_function.py b/python/ray/remote_function.py index e717e2d28..3b8b42062 100644 --- a/python/ray/remote_function.py +++ b/python/ray/remote_function.py @@ -4,6 +4,8 @@ from functools import wraps from ray import cloudpickle as pickle from ray._raylet import PythonFunctionDescriptor from ray import cross_language, Language +from ray._private.client_mode_hook import client_mode_convert_function +from ray._private.client_mode_hook import client_mode_should_convert from ray.util.placement_group import ( PlacementGroup, check_placement_group_index, @@ -181,6 +183,26 @@ class RemoteFunction: override_environment_variables=None, name=""): """Submit the remote function for execution.""" + if client_mode_should_convert(): + return client_mode_convert_function( + self, + args, + kwargs, + num_returns=num_returns, + num_cpus=num_cpus, + num_gpus=num_gpus, + memory=memory, + object_store_memory=object_store_memory, + accelerator_type=accelerator_type, + resources=resources, + max_retries=max_retries, + placement_group=placement_group, + placement_group_bundle_index=placement_group_bundle_index, + placement_group_capture_child_tasks=( + placement_group_capture_child_tasks), + override_environment_variables=override_environment_variables, + name=name) + worker = ray.worker.global_worker worker.check_connected() diff --git a/python/ray/tests/test_client_init.py b/python/ray/tests/test_client_init.py index 5e43ac631..9528f1d20 100644 --- a/python/ray/tests/test_client_init.py +++ b/python/ray/tests/test_client_init.py @@ -2,6 +2,7 @@ import pytest import time +import random import sys import ray.util.client.server.server as ray_client_server @@ -9,6 +10,54 @@ import ray.core.generated.ray_client_pb2 as ray_client_pb2 from ray.util.client import RayAPIStub +import ray + + +@ray.remote +def hello_world(): + c1 = complex_task.remote(random.randint(1, 10)) + c2 = complex_task.remote(random.randint(1, 10)) + return sum(ray.get([c1, c2])) + + +@ray.remote +def complex_task(value): + time.sleep(1) + return value * 10 + + +@ray.remote +class C: + def __init__(self, x): + self.val = x + + def double(self): + self.val += self.val + + def get(self): + return self.val + + +def test_basic_preregister(): + from ray.util.client import ray + server, _ = ray_client_server.init_and_serve("localhost:50051") + try: + ray.connect("localhost:50051") + val = ray.get(hello_world.remote()) + print(val) + assert val >= 20 + assert val <= 200 + c = C.remote(3) + x = c.double.remote() + y = c.double.remote() + ray.wait([x, y]) + val = ray.get(c.get.remote()) + assert val == 12 + finally: + ray.disconnect() + ray_client_server.shutdown_with_server(server) + time.sleep(2) + def test_num_clients(): # Tests num clients reporting; useful if you want to build an app that diff --git a/python/ray/util/client/api.py b/python/ray/util/client/api.py index 7d8576d1f..5b1ae881e 100644 --- a/python/ray/util/client/api.py +++ b/python/ray/util/client/api.py @@ -4,6 +4,8 @@ and the overall ray module API. from ray.util.client.runtime_context import ClientWorkerPropertyAPI from typing import TYPE_CHECKING if TYPE_CHECKING: + from ray.actor import ActorClass + from ray.remote_function import RemoteFunction from ray.util.client.common import ClientStub from ray.util.client.common import ClientActorHandle from ray.util.client.common import ClientObjectRef @@ -265,6 +267,18 @@ class ClientAPI: """Hook for internal_kv._internal_kv_list.""" return self.worker.internal_kv_list(as_bytes(prefix)) + def _convert_actor(self, actor: "ActorClass") -> str: + """Register a ClientActorClass for the ActorClass and return a UUID""" + return self.worker._convert_actor(actor) + + def _convert_function(self, func: "RemoteFunction") -> str: + """Register a ClientRemoteFunc for the ActorClass and return a UUID""" + return self.worker._convert_function(func) + + def _get_converted(self, key: str) -> "ClientStub": + """Given a UUID, return the converted object""" + return self.worker._get_converted(key) + def __getattr__(self, key: str): if not key.startswith("_"): raise NotImplementedError( diff --git a/python/ray/util/client/common.py b/python/ray/util/client/common.py index 2bcd14f3f..8eac0983a 100644 --- a/python/ray/util/client/common.py +++ b/python/ray/util/client/common.py @@ -82,7 +82,11 @@ class ClientRemoteFunc(ClientStub): def options(self, **kwargs): return OptionWrapper(self, kwargs) - def _remote(self, args=[], kwargs={}, **option_args): + def _remote(self, args=None, kwargs=None, **option_args): + if args is None: + args = [] + if kwargs is None: + kwargs = {} return self.options(**option_args).remote(*args, **kwargs) def __repr__(self): @@ -150,7 +154,11 @@ class ClientActorClass(ClientStub): def options(self, **kwargs): return ActorOptionWrapper(self, kwargs) - def _remote(self, args=[], kwargs={}, **option_args): + def _remote(self, args=None, kwargs=None, **option_args): + if args is None: + args = [] + if kwargs is None: + kwargs = {} return self.options(**option_args).remote(*args, **kwargs) def __repr__(self): @@ -230,7 +238,11 @@ class ClientRemoteMethod(ClientStub): def options(self, **kwargs): return OptionWrapper(self, kwargs) - def _remote(self, args=[], kwargs={}, **option_args): + def _remote(self, args=None, kwargs=None, **option_args): + if args is None: + args = [] + if kwargs is None: + kwargs = {} return self.options(**option_args).remote(*args, **kwargs) def _prepare_client_task(self) -> ray_client_pb2.ClientTask: diff --git a/python/ray/util/client/options.py b/python/ray/util/client/options.py index 79727b126..b2f1dae81 100644 --- a/python/ray/util/client/options.py +++ b/python/ray/util/client/options.py @@ -46,9 +46,10 @@ def validate_options( raise TypeError(f"Invalid option passed to remote(): {k}") validator = options[k] if len(validator) != 0: - if not isinstance(v, validator[0]): - raise ValueError(validator[2]) - if not validator[1](v): - raise ValueError(validator[2]) + if v is not None: + if not isinstance(v, validator[0]): + raise ValueError(validator[2]) + if not validator[1](v): + raise ValueError(validator[2]) out[k] = v return out diff --git a/python/ray/util/client/worker.py b/python/ray/util/client/worker.py index 535ec5ab7..3f04c80a4 100644 --- a/python/ray/util/client/worker.py +++ b/python/ray/util/client/worker.py @@ -13,6 +13,7 @@ from typing import Dict from typing import List from typing import Tuple from typing import Optional +from typing import TYPE_CHECKING import grpc @@ -22,12 +23,19 @@ import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc from ray.util.client.client_pickler import convert_to_arg from ray.util.client.client_pickler import dumps_from_client from ray.util.client.client_pickler import loads_from_server +from ray.util.client.common import ClientStub from ray.util.client.common import ClientActorHandle +from ray.util.client.common import ClientActorClass +from ray.util.client.common import ClientRemoteFunc from ray.util.client.common import ClientActorRef from ray.util.client.common import ClientObjectRef from ray.util.client.dataclient import DataClient from ray.util.client.logsclient import LogstreamClient +if TYPE_CHECKING: + from ray.actor import ActorClass + from ray.remote_function import RemoteFunction + logger = logging.getLogger(__name__) INITIAL_TIMEOUT_SEC = 5 @@ -62,6 +70,7 @@ class Worker: self.channel = None self._conn_state = grpc.ChannelConnectivity.IDLE self._client_id = make_client_id() + self._converted: Dict[str, ClientStub] = {} if secure: credentials = grpc.ssl_channel_credentials() self.channel = grpc.secure_channel(conn_str, credentials) @@ -371,6 +380,47 @@ class Worker: def is_connected(self) -> bool: return self._conn_state == grpc.ChannelConnectivity.READY + def _convert_actor(self, actor: "ActorClass") -> str: + """Register a ClientActorClass for the ActorClass and return a UUID""" + key = uuid.uuid4().hex + md = actor.__ray_metadata__ + cls = md.modified_class + self._converted[key] = ClientActorClass( + cls, + options={ + "max_restarts": md.max_restarts, + "max_task_retries": md.max_task_retries, + "num_cpus": md.num_cpus, + "num_gpus": md.num_gpus, + "memory": md.memory, + "object_store_memory": md.object_store_memory, + "resources": md.resources, + "accelerator_type": md.accelerator_type, + }) + return key + + def _convert_function(self, func: "RemoteFunction") -> str: + """Register a ClientRemoteFunc for the ActorClass and return a UUID""" + key = uuid.uuid4().hex + f = func._function + self._converted[key] = ClientRemoteFunc( + f, + options={ + "num_cpus": func._num_cpus, + "num_gpus": func._num_gpus, + "max_calls": func._max_calls, + "max_retries": func._max_retries, + "resources": func._resources, + "accelerator_type": func._accelerator_type, + "num_returns": func._num_returns, + "memory": func._memory + }) + return key + + def _get_converted(self, key: str) -> "ClientStub": + """Given a UUID, return the converted object""" + return self._converted[key] + def make_client_id() -> str: id = uuid.uuid4() diff --git a/python/ray/worker.py b/python/ray/worker.py index 337b4ffc9..00d99930c 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -1768,7 +1768,6 @@ def make_decorator(num_returns=None, return decorator -@client_mode_hook def remote(*args, **kwargs): """Defines a remote function or an actor class.