From c4e273920f517b18c99fbabca49135dd6e30e683 Mon Sep 17 00:00:00 2001 From: Barak Michener Date: Tue, 22 Dec 2020 22:51:45 -0800 Subject: [PATCH] [ray_client]: Insert decorators into the real ray module to allow for client mode (#13031) --- .travis.yml | 2 +- ci/travis/ci.sh | 7 + python/ray/_private/client_mode_hook.py | 47 ++++ python/ray/_raylet.pyx | 7 + python/ray/experimental/client/__init__.py | 200 +++++++---------- python/ray/experimental/client/api.py | 212 +++++++----------- .../ray/experimental/client/client_pickler.py | 13 +- python/ray/experimental/client/common.py | 21 +- python/ray/experimental/client/dataclient.py | 3 +- .../experimental/client/examples/run_tune.py | 7 + python/ray/experimental/client/logsclient.py | 14 +- .../experimental/client/ray_client_helpers.py | 7 +- .../client/server/core_ray_api.py | 81 ------- .../experimental/client/server/logservicer.py | 6 +- .../ray/experimental/client/server/server.py | 132 ++++++----- .../client/server/server_pickler.py | 17 +- python/ray/experimental/client/worker.py | 42 +--- python/ray/state.py | 4 + python/ray/test_utils.py | 2 +- python/ray/tests/BUILD | 2 +- python/ray/tests/conftest.py | 24 +- python/ray/tests/test_actor.py | 9 +- python/ray/tests/test_advanced.py | 2 + python/ray/tests/test_basic.py | 5 +- python/ray/tests/test_experimental_client.py | 30 ++- .../test_experimental_client_metadata.py | 3 +- .../test_experimental_client_references.py | 6 +- python/ray/worker.py | 11 + 28 files changed, 419 insertions(+), 497 deletions(-) create mode 100644 python/ray/_private/client_mode_hook.py create mode 100644 python/ray/experimental/client/examples/run_tune.py delete mode 100644 python/ray/experimental/client/server/core_ray_api.py diff --git a/.travis.yml b/.travis.yml index dc133de49..8173ec1ac 100644 --- a/.travis.yml +++ b/.travis.yml @@ -46,7 +46,7 @@ matrix: script: # bazel python tests for medium size tests. Used for parallelization. - if [ $RAY_CI_PYTHON_AFFECTED == "1" ]; then ./ci/keep_alive bazel test --config=ci $(./scripts/bazel_export_options) --test_tag_filters=-jenkins_only,medium_size_python_tests_a_to_j python/ray/tests/...; fi - - if [ $RAY_CI_PYTHON_AFFECTED == "1" ]; then ./ci/keep_alive bazel test --config=ci $(./scripts/bazel_export_options) --test_tag_filters=-jenkins_only,client_tests --test_env=RAY_TEST_CLIENT_MODE=1 python/ray/tests/...; fi + - if [ $RAY_CI_PYTHON_AFFECTED == "1" ]; then ./ci/keep_alive bazel test --config=ci $(./scripts/bazel_export_options) --test_tag_filters=-jenkins_only,client_tests --test_env=RAY_CLIENT_MODE=1 python/ray/tests/...; fi - os: linux env: diff --git a/ci/travis/ci.sh b/ci/travis/ci.sh index 843515400..e4d9741cd 100755 --- a/ci/travis/ci.sh +++ b/ci/travis/ci.sh @@ -262,6 +262,11 @@ _bazel_build_before_install() { bazel build "${target}" } + +_bazel_build_protobuf() { + bazel build "//:install_py_proto" +} + install_ray() { # TODO(mehrdadn): This function should be unified with the one in python/build-wheel-windows.sh. ( @@ -457,6 +462,8 @@ init() { build() { if [ "${LINT-}" != 1 ]; then _bazel_build_before_install + else + _bazel_build_protobuf fi if ! need_wheels; then diff --git a/python/ray/_private/client_mode_hook.py b/python/ray/_private/client_mode_hook.py new file mode 100644 index 000000000..4fbc568c8 --- /dev/null +++ b/python/ray/_private/client_mode_hook.py @@ -0,0 +1,47 @@ +import os +from contextlib import contextmanager + +client_mode_enabled = os.environ.get("RAY_CLIENT_MODE", "0") == "1" + +_client_hook_enabled = True + + +def _enable_client_hook(val: bool): + global _client_hook_enabled + _client_hook_enabled = val + + +def _disable_client_hook(): + global _client_hook_enabled + out = _client_hook_enabled + _client_hook_enabled = False + return out + + +def _explicitly_enable_client_mode(): + global client_mode_enabled + client_mode_enabled = True + + +@contextmanager +def disable_client_hook(): + val = _disable_client_hook() + try: + yield None + finally: + _enable_client_hook(val) + + +def client_mode_hook(func): + """ + Decorator for ray module methods to delegate to ray client + """ + from ray.experimental.client import ray + + def wrapper(*args, **kwargs): + global _client_hook_enabled + if client_mode_enabled and _client_hook_enabled: + return getattr(ray, func.__name__)(*args, **kwargs) + return func(*args, **kwargs) + + return wrapper diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index 356222bb9..1360c96ce 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -107,6 +107,10 @@ from ray.exceptions import ( TaskCancelledError ) from ray.utils import decode +from ray._private.client_mode_hook import ( + _enable_client_hook, + _disable_client_hook, +) import msgpack cimport cpython @@ -558,6 +562,7 @@ cdef CRayStatus task_execution_handler( with gil: try: + client_was_enabled = _disable_client_hook() try: # The call to execute_task should never raise an exception. If # it does, that indicates that there was an internal error. @@ -582,6 +587,8 @@ cdef CRayStatus task_execution_handler( else: logger.exception("SystemExit was raised from the worker") return CRayStatus.UnexpectedSystemExit() + finally: + _enable_client_hook(client_was_enabled) return CRayStatus.OK() diff --git a/python/ray/experimental/client/__init__.py b/python/ray/experimental/client/__init__.py index ed1983528..674dfa7f7 100644 --- a/python/ray/experimental/client/__init__.py +++ b/python/ray/experimental/client/__init__.py @@ -1,148 +1,104 @@ -from ray.experimental.client.api import ClientAPI -from ray.experimental.client.api import APIImpl -from typing import Optional, List, Tuple -from contextlib import contextmanager +from typing import List, Tuple import logging -import os logger = logging.getLogger(__name__) -# About these global variables: Ray 1.0 uses exported module functions to -# provide its API, and we need to match that. However, we want different -# behaviors depending on where, exactly, in the client stack this is running. -# -# The reason for these differences depends on what's being pickled and passed -# to functions, or functions inside functions. So there are three cases to care -# about -# -# (Python Client)-->(Python ClientServer)-->(Internal Raylet Process) -# -# * _client_api should be set if we're inside the client -# * _server_api should be set if we're inside the clientserver -# * Both will be set if we're running both (as in a test) -# * Neither should be set if we're inside the raylet (but we still need to shim -# from the client API surface to the Ray API) -# -# The job of RayAPIStub (below) delegates to the appropriate one of these -# depending on what's set or not. Then, all users importing the ray object -# from this package get the stub which routes them to the appropriate APIImpl. -_client_api: Optional[APIImpl] = None -_server_api: Optional[APIImpl] = None - -# The reason for _is_server is a hack around the above comment while running -# tests. If we have both a client and a server trying to control these static -# variables then we need a way to decide which to use. In this case, both -# _client_api and _server_api are set. -# This boolean flips between the two -_is_server: bool = False - - -@contextmanager -def stash_api_for_tests(in_test: bool): - global _is_server - is_server = _is_server - if in_test: - _is_server = True - try: - yield _server_api - finally: - if in_test: - _is_server = is_server - - -def _set_client_api(val: Optional[APIImpl]): - global _client_api - global _is_server - if _client_api is not None: - raise Exception("Trying to set more than one client API") - _client_api = val - _is_server = False - - -def _set_server_api(val: Optional[APIImpl]): - global _server_api - global _is_server - if _server_api is not None: - raise Exception("Trying to set more than one server API") - _server_api = val - _is_server = True - - -def reset_api(): - global _client_api - global _server_api - global _is_server - _client_api = None - _server_api = None - _is_server = False - - -def _get_client_api() -> APIImpl: - global _client_api - return _client_api - - -def _get_server_instance(): - """Used inside tests to inspect the running server. - """ - global _server_api - if _server_api is not None: - return _server_api.server - class RayAPIStub: + """This class stands in as the replacement API for the `import ray` module. + + Much like the ray module, this mostly delegates the work to the + _client_worker. As parts of the ray API are covered, they are piped through + here or on the client worker API. + """ + + def __init__(self): + from ray.experimental.client.api import ClientAPI + self.api = ClientAPI() + self.client_worker = None + self._server = None + self._connected_with_init = False + self._inside_client_test = False + def connect(self, conn_str: str, secure: bool = False, - metadata: List[Tuple[str, str]] = None, - stub=None) -> None: + metadata: List[Tuple[str, str]] = None) -> None: + """Connect the Ray Client to a server. + + Args: + conn_str: Connection string, in the form "[host]:port" + secure: Whether to use a TLS secured gRPC channel + metadata: gRPC metadata to send on connect + """ + # Delay imports until connect to avoid circular imports. from ray.experimental.client.worker import Worker - _client_worker = Worker(conn_str, secure=secure, metadata=metadata) - _set_client_api(ClientAPI(_client_worker)) + import ray._private.client_mode_hook + if self.client_worker is not None: + if self._connected_with_init: + return + raise Exception( + "ray.connect() called, but ray client is already connected") + if not self._inside_client_test: + # If we're calling a client connect specifically and we're not + # currently in client mode, ensure we are. + ray._private.client_mode_hook._explicitly_enable_client_mode() + self.client_worker = Worker(conn_str, secure=secure, metadata=metadata) + self.api.worker = self.client_worker def disconnect(self): - global _client_api - if _client_api is not None: - _client_api.close() - _client_api = None + """Disconnect the Ray Client. + """ + if self.client_worker is not None: + self.client_worker.close() + self.client_worker = None + + # remote can be called outside of a connection, which is why it + # exists on the same API layer as connect() itself. + def remote(self, *args, **kwargs): + """remote is the hook stub passed on to replace `ray.remote`. + + This sets up remote functions or actors, as the decorator, + but does not execute them. + + Args: + args: opaque arguments + kwargs: opaque keyword arguments + """ + return self.api.remote(*args, **kwargs) def __getattr__(self, key: str): - global _get_client_api - api = _get_client_api() - return getattr(api, key) + if not self.is_connected(): + raise Exception("Ray Client is not connected. " + "Please connect by calling `ray.connect`.") + return getattr(self.api, key) def is_connected(self) -> bool: - global _client_api - return _client_api is not None + return self.api is not None def init(self, *args, **kwargs): - if _is_client_test_env(): - global _test_server - import ray.experimental.client.server.server as ray_client_server - _test_server, address_info = ray_client_server.init_and_serve( - "localhost:50051", test_mode=True, *args, **kwargs) - self.connect("localhost:50051") - return address_info - else: - raise NotImplementedError( - "Please call ray.connect() in client mode") + if self._server is not None: + raise Exception("Trying to start two instances of ray via client") + import ray.experimental.client.server.server as ray_client_server + self._server, address_info = ray_client_server.init_and_serve( + "localhost:50051", *args, **kwargs) + self.connect("localhost:50051") + self._connected_with_init = True + return address_info + + def shutdown(self, _exiting_interpreter=False): + self.disconnect() + import ray.experimental.client.server.server as ray_client_server + if self._server is None: + return + ray_client_server.shutdown_with_server(self._server, + _exiting_interpreter) + self._server = None ray = RayAPIStub() -_test_server = None - - -def _stop_test_server(*args): - global _test_server - _test_server.stop(*args) - - -def _is_client_test_env() -> bool: - return os.environ.get("RAY_TEST_CLIENT_MODE") == "1" - - # Someday we might add methods in this module so that someone who # tries to `import ray_client as ray` -- as a module, instead of # `from ray_client import ray` -- as the API stub diff --git a/python/ray/experimental/client/api.py b/python/ray/experimental/client/api.py index 93da6382f..58680bf9f 100644 --- a/python/ray/experimental/client/api.py +++ b/python/ray/experimental/client/api.py @@ -1,74 +1,51 @@ -# This file defines an interface and client-side API stub -# for referring either to the core Ray API or the same interface -# from the Ray client. -# -# In tandem with __init__.py, we want to expose an API that's -# close to `python/ray/__init__.py` but with more than one implementation. -# The stubs in __init__ should call into a well-defined interface. -# Only the core Ray API implementation should actually `import ray` -# (and thus import all the raylet worker C bindings and such). -# But to make sure that we're matching these calls, we define this API. - -from abc import ABC -from abc import abstractmethod -from typing import TYPE_CHECKING, Any, Union, Optional -import ray.core.generated.ray_client_pb2 as ray_client_pb2 +"""This file defines the interface between the ray client worker +and the overall ray module API. +""" +from typing import TYPE_CHECKING if TYPE_CHECKING: - from ray.experimental.client.common import ClientActorHandle from ray.experimental.client.common import ClientStub + from ray.experimental.client.common import ClientActorHandle from ray.experimental.client.common import ClientObjectRef - from ray._raylet import ObjectRef - - # Use the imports for type checking. This is a python 3.6 limitation. - # See https://www.python.org/dev/peps/pep-0563/ - PutType = Union[ClientObjectRef, ObjectRef] -class APIImpl(ABC): - """ - APIImpl is the interface to implement for whichever version of the core - Ray API that needs abstracting when run in client mode. +class ClientAPI: + """The Client-side methods corresponding to the ray API. Delegates + to the Client Worker that contains the connection to the ClientServer. """ - @abstractmethod - def get(self, vals, *, timeout: Optional[float] = None) -> Any: - """ - get is the hook stub passed on to replace `ray.get` + def __init__(self, worker=None): + self.worker = worker + + def get(self, vals, *, timeout=None): + """get is the hook stub passed on to replace `ray.get` Args: vals: [Client]ObjectRef or list of these refs to retrieve. timeout: Optional timeout in milliseconds """ - pass + return self.worker.get(vals, timeout=timeout) - @abstractmethod - def put(self, vals: Any, *args, - **kwargs) -> Union["ClientObjectRef", "ObjectRef"]: - """ - put is the hook stub passed on to replace `ray.put` + def put(self, *args, **kwargs): + """put is the hook stub passed on to replace `ray.put` Args: vals: The value or list of values to `put`. args: opaque arguments kwargs: opaque keyword arguments """ - pass + return self.worker.put(*args, **kwargs) - @abstractmethod def wait(self, *args, **kwargs): - """ - wait is the hook stub passed on to replace `ray.wait` + """wait is the hook stub passed on to replace `ray.wait` Args: args: opaque arguments kwargs: opaque keyword arguments """ - pass + return self.worker.wait(*args, **kwargs) - @abstractmethod def remote(self, *args, **kwargs): - """ - remote is the hook stub passed on to replace `ray.remote`. + """remote is the hook stub passed on to replace `ray.remote`. This sets up remote functions or actors, as the decorator, but does not execute them. @@ -77,12 +54,24 @@ class APIImpl(ABC): args: opaque arguments kwargs: opaque keyword arguments """ - pass + # Delayed import to avoid a cyclic import + from ray.experimental.client.common import remote_decorator + if len(args) == 1 and len(kwargs) == 0 and callable(args[0]): + # This is the case where the decorator is just @ray.remote. + return remote_decorator(options=None)(args[0]) + error_string = ("The @ray.remote decorator must be applied either " + "with no arguments and no parentheses, for example " + "'@ray.remote', or it must be applied using some of " + "the arguments 'num_returns', 'num_cpus', 'num_gpus', " + "'memory', 'object_store_memory', 'resources', " + "'max_calls', or 'max_restarts', like " + "'@ray.remote(num_returns=2, " + "resources={\"CustomResource\": 1})'.") + assert len(args) == 0 and len(kwargs) > 0, error_string + return remote_decorator(options=kwargs) - @abstractmethod def call_remote(self, instance: "ClientStub", *args, **kwargs): - """ - call_remote is called by stub objects to execute them remotely. + """call_remote is called by stub objects to execute them remotely. This is used by stub objects in situations where they're called with .remote, eg, `f.remote()` or `actor_cls.remote()`. @@ -95,31 +84,57 @@ class APIImpl(ABC): args: opaque arguments kwargs: opaque keyword arguments """ - pass + return self.worker.call_remote(instance, *args, **kwargs) - @abstractmethod - def close(self) -> None: + def call_release(self, id: bytes) -> None: + """Attempts to release an object reference. + + When client references are destructed, they release their reference, + which can opportunistically send a notification through the datachannel + to release the reference being held for that object on the server. + + Args: + id: The id of the reference to release on the server side. """ - close cleans up an API connection by closing any channels or + return self.worker.call_release(id) + + def call_retain(self, id: bytes) -> None: + """Attempts to retain a client object reference. + + Increments the reference count on the client side, to prevent + the client worker from attempting to release the server reference. + + Args: + id: The id of the reference to retain on the client side. + """ + return self.worker.call_retain(id) + + def close(self) -> None: + """close cleans up an API connection by closing any channels or shutting down any servers gracefully. """ - pass + return self.worker.close() - @abstractmethod - def kill(self, actor, *, no_restart=True): + def get_actor(self, name: str) -> "ClientActorHandle": + """Returns a handle to an actor by name. + + Args: + name: The name passed to this actor by + Actor.options(name="name").remote() """ - kill forcibly stops an actor running in the cluster + return self.worker.get_actor(name) + + def kill(self, actor: "ClientActorHandle", *, no_restart=True): + """kill forcibly stops an actor running in the cluster Args: no_restart: Whether this actor should be restarted if it's a restartable actor. """ - pass + return self.worker.terminate_actor(actor, no_restart) - @abstractmethod - def cancel(self, obj, *, force=False, recursive=True): - """ - Cancels a task on the cluster. + def cancel(self, obj: "ClientObjectRef", *, force=False, recursive=True): + """Cancels a task on the cluster. If the specified task is pending execution, it will not be executed. If the task is currently executing, the behavior depends on the ``force`` @@ -136,80 +151,11 @@ class APIImpl(ABC): recursive (boolean): Whether to try to cancel tasks submitted by the task specified. """ - pass - - @abstractmethod - def call_release(self, id: bytes) -> None: - """ - Attempts to release an object reference. - - When client references are destructed, they release their reference, - which can opportunistically send a notification through the datachannel - to release the reference being held for that object on the server. - - Args: - id: The id of the reference to release on the server side. - """ - - @abstractmethod - def call_retain(self, id: bytes) -> None: - """ - Attempts to retain a client object reference. - - Increments the reference count on the client side, to prevent - the client worker from attempting to release the server reference. - - Args: - id: The id of the reference to retain on the client side. - """ - - -class ClientAPI(APIImpl): - """ - The Client-side methods corresponding to the ray API. Delegates - to the Client Worker that contains the connection to the ClientServer. - """ - - def __init__(self, worker): - self.worker = worker - - def get(self, vals, *, timeout=None): - return self.worker.get(vals, timeout=timeout) - - def put(self, *args, **kwargs): - return self.worker.put(*args, **kwargs) - - def wait(self, *args, **kwargs): - return self.worker.wait(*args, **kwargs) - - def remote(self, *args, **kwargs): - return self.worker.remote(*args, **kwargs) - - def call_remote(self, instance: "ClientStub", *args, **kwargs): - return self.worker.call_remote(instance, *args, **kwargs) - - def call_release(self, id: bytes) -> None: - return self.worker.call_release(id) - - def call_retain(self, id: bytes) -> None: - return self.worker.call_retain(id) - - def close(self) -> None: - return self.worker.close() - - def get_actor(self, name: str) -> "ClientActorHandle": - return self.worker.get_actor(name) - - def kill(self, actor: "ClientActorHandle", *, no_restart=True): - return self.worker.terminate_actor(actor, no_restart) - - def cancel(self, obj: "ClientObjectRef", *, force=False, recursive=True): return self.worker.terminate_task(obj, force, recursive) # Various metadata methods for the client that are defined in the protocol. def is_initialized(self) -> bool: - """ True if our client is connected, and if the server is initialized. - + """True if our client is connected, and if the server is initialized. Returns: A boolean determining if the client is connected and server initialized. @@ -222,6 +168,8 @@ class ClientAPI(APIImpl): Returns: Information about the Ray clients in the cluster. """ + # This should be imported here, otherwise, it will error doc build. + import ray.core.generated.ray_client_pb2 as ray_client_pb2 return self.worker.get_cluster_info( ray_client_pb2.ClusterInfoType.NODES) @@ -235,6 +183,8 @@ class ClientAPI(APIImpl): A dictionary mapping resource name to the total quantity of that resource in the cluster. """ + # This should be imported here, otherwise, it will error doc build. + import ray.core.generated.ray_client_pb2 as ray_client_pb2 return self.worker.get_cluster_info( ray_client_pb2.ClusterInfoType.CLUSTER_RESOURCES) @@ -250,6 +200,8 @@ class ClientAPI(APIImpl): A dictionary mapping resource name to the total quantity of that resource in the cluster. """ + # This should be imported here, otherwise, it will error doc build. + import ray.core.generated.ray_client_pb2 as ray_client_pb2 return self.worker.get_cluster_info( ray_client_pb2.ClusterInfoType.AVAILABLE_RESOURCES) diff --git a/python/ray/experimental/client/client_pickler.py b/python/ray/experimental/client/client_pickler.py index 7ba83b3ac..863884687 100644 --- a/python/ray/experimental/client/client_pickler.py +++ b/python/ray/experimental/client/client_pickler.py @@ -1,5 +1,4 @@ -""" -Implements the client side of the client/server pickling protocol. +"""Implements the client side of the client/server pickling protocol. All ray client client/server data transfer happens through this pickling protocol. The model is as follows: @@ -41,6 +40,7 @@ from ray.experimental.client.common import ClientRemoteMethod from ray.experimental.client.common import OptionWrapper from ray.experimental.client.common import SelfReferenceSentinel import ray.core.generated.ray_client_pb2 as ray_client_pb2 +from ray._private.client_mode_hook import disable_client_hook if sys.version_info < (3, 8): try: @@ -155,10 +155,11 @@ class ServerUnpickler(pickle.Unpickler): def dumps_from_client(obj: Any, client_id: str, protocol=None) -> bytes: - with io.BytesIO() as file: - cp = ClientPickler(client_id, file, protocol=protocol) - cp.dump(obj) - return file.getvalue() + with disable_client_hook(): + with io.BytesIO() as file: + cp = ClientPickler(client_id, file, protocol=protocol) + cp.dump(obj) + return file.getvalue() def loads_from_server(data: bytes, diff --git a/python/ray/experimental/client/common.py b/python/ray/experimental/client/common.py index f68b26e2c..18708f279 100644 --- a/python/ray/experimental/client/common.py +++ b/python/ray/experimental/client/common.py @@ -2,6 +2,8 @@ import ray.core.generated.ray_client_pb2 as ray_client_pb2 from ray.experimental.client import ray from ray.experimental.client.options import validate_options +import inspect +from ray.util.inspect import is_cython import json import threading from typing import Any @@ -52,8 +54,7 @@ class ClientStub: class ClientRemoteFunc(ClientStub): - """ - A stub created on the Ray Client to represent a remote + """A stub created on the Ray Client to represent a remote function that can be exectued on the cluster. This class is allowed to be passed around between remote functions. @@ -112,7 +113,7 @@ class ClientRemoteFunc(ClientStub): class ClientActorClass(ClientStub): - """ A stub created on the Ray Client to represent an actor class. + """A stub created on the Ray Client to represent an actor class. It is wrapped by ray.remote and can be executed on the cluster. @@ -294,3 +295,17 @@ class DataEncodingSentinel: class SelfReferenceSentinel(DataEncodingSentinel): pass + + +def remote_decorator(options: Optional[Dict[str, Any]]): + def decorator(function_or_class) -> ClientStub: + if (inspect.isfunction(function_or_class) + or is_cython(function_or_class)): + return ClientRemoteFunc(function_or_class, options=options) + elif inspect.isclass(function_or_class): + return ClientActorClass(function_or_class, options=options) + else: + raise TypeError("The @ray.remote decorator must be applied to " + "either a function or to a class.") + + return decorator diff --git a/python/ray/experimental/client/dataclient.py b/python/ray/experimental/client/dataclient.py index b0dda0a1b..add66c82c 100644 --- a/python/ray/experimental/client/dataclient.py +++ b/python/ray/experimental/client/dataclient.py @@ -1,5 +1,4 @@ -""" -This file implements a threaded stream controller to abstract a data stream +"""This file implements a threaded stream controller to abstract a data stream back to the ray clientserver. """ import logging diff --git a/python/ray/experimental/client/examples/run_tune.py b/python/ray/experimental/client/examples/run_tune.py new file mode 100644 index 000000000..9e0592c1e --- /dev/null +++ b/python/ray/experimental/client/examples/run_tune.py @@ -0,0 +1,7 @@ +from ray.experimental.client import ray + +from ray.tune import tune + +ray.connect("localhost:50051") + +tune.run("PG", config={"env": "CartPole-v0"}) diff --git a/python/ray/experimental/client/logsclient.py b/python/ray/experimental/client/logsclient.py index f26417e7e..acf2619c9 100644 --- a/python/ray/experimental/client/logsclient.py +++ b/python/ray/experimental/client/logsclient.py @@ -1,5 +1,4 @@ -""" -This file implements a threaded stream controller to return logs back from +"""This file implements a threaded stream controller to return logs back from the ray clientserver. """ import sys @@ -12,6 +11,10 @@ import ray.core.generated.ray_client_pb2 as ray_client_pb2 import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc logger = logging.getLogger(__name__) +# TODO(barakmich): Running a logger in a logger causes loopback. +# The client logger need its own root -- possibly this one. +# For the moment, let's just not propogate beyond this point. +logger.propagate = False class LogstreamClient: @@ -45,8 +48,7 @@ class LogstreamClient: raise e def log(self, level: int, msg: str): - """ - Log the message from the log stream. + """Log the message from the log stream. By default, calls logger.log but this can be overridden. Args: @@ -56,8 +58,7 @@ class LogstreamClient: logger.log(level=level, msg=msg) def stdstream(self, level: int, msg: str): - """ - Log the stdout/stderr entry from the log stream. + """Log the stdout/stderr entry from the log stream. By default, calls print but this can be overridden. Args: @@ -68,6 +69,7 @@ class LogstreamClient: print(msg, file=print_file) def set_logstream_level(self, level: int): + logger.setLevel(level) req = ray_client_pb2.LogSettingsRequest() req.enabled = True req.loglevel = level diff --git a/python/ray/experimental/client/ray_client_helpers.py b/python/ray/experimental/client/ray_client_helpers.py index ab9d7408a..975918cef 100644 --- a/python/ray/experimental/client/ray_client_helpers.py +++ b/python/ray/experimental/client/ray_client_helpers.py @@ -1,16 +1,17 @@ from contextlib import contextmanager import ray.experimental.client.server.server as ray_client_server -from ray.experimental.client import ray, reset_api +from ray.experimental.client import ray @contextmanager def ray_start_client_server(): - server = ray_client_server.serve("localhost:50051", test_mode=True) + ray._inside_client_test = True + server = ray_client_server.serve("localhost:50051") ray.connect("localhost:50051") try: yield ray finally: + ray._inside_client_test = False ray.disconnect() server.stop(0) - reset_api() diff --git a/python/ray/experimental/client/server/core_ray_api.py b/python/ray/experimental/client/server/core_ray_api.py deleted file mode 100644 index 0762cd0b1..000000000 --- a/python/ray/experimental/client/server/core_ray_api.py +++ /dev/null @@ -1,81 +0,0 @@ -# Along with `api.py` this is the stub that interfaces with -# the real (C-binding, raylet) ray core. -# -# Ideally, the first import line is the only time we actually -# import ray in this library (excluding the main function for the server) -# -# While the stub is trivial, it allows us to check that the calls we're -# making into the core-ray module are contained and well-defined. - -from typing import Any -from typing import Optional -from typing import Union - -import logging -import ray - -from ray.experimental.client.api import APIImpl -from ray.experimental.client.common import ClientObjectRef -from ray.experimental.client.common import ClientStub - -logger = logging.getLogger(__name__) - - -class CoreRayAPI(APIImpl): - """ - Implements the equivalent client-side Ray API by simply passing along to - the Core Ray API. Primarily used inside of Ray Workers as a trampoline back - to core ray when passed client stubs. - """ - - def get(self, vals, *, timeout: Optional[float] = None) -> Any: - return ray.get(vals, timeout=timeout) - - def put(self, vals: Any, *args, - **kwargs) -> Union[ClientObjectRef, ray._raylet.ObjectRef]: - return ray.put(vals, *args, **kwargs) - - def wait(self, *args, **kwargs): - return ray.wait(*args, **kwargs) - - def remote(self, *args, **kwargs): - return ray.remote(*args, **kwargs) - - def call_remote(self, instance: ClientStub, *args, **kwargs): - raise NotImplementedError( - "Should not attempt execution of a client stub inside the raylet") - - def close(self) -> None: - return None - - def kill(self, actor, *, no_restart=True): - return ray.kill(actor, no_restart=no_restart) - - def cancel(self, obj, *, force=False, recursive=True): - return ray.cancel(obj, force=force, recursive=recursive) - - def is_initialized(self) -> bool: - return ray.is_initialized() - - def call_release(self, id: bytes) -> None: - return None - - def call_retain(self, id: bytes) -> None: - return None - - # Allow for generic fallback to ray.* in remote methods. This allows calls - # like ray.nodes() to be run in remote functions even though the client - # doesn't currently support them. - def __getattr__(self, key: str): - return getattr(ray, key) - - -class RayServerAPI(CoreRayAPI): - """ - Ray Client server-side API shim. By default, simply calls the default Core - Ray API calls, but also accepts scheduling calls from functions running - inside of other remote functions that need to create more work. - """ - - def __init__(self, server_instance): - self.server = server_instance diff --git a/python/ray/experimental/client/server/logservicer.py b/python/ray/experimental/client/server/logservicer.py index 9b2fa24bf..25e4ccbd5 100644 --- a/python/ray/experimental/client/server/logservicer.py +++ b/python/ray/experimental/client/server/logservicer.py @@ -1,5 +1,4 @@ -""" -This file responds to log stream requests and forwards logs +"""This file responds to log stream requests and forwards logs with its handler. """ import io @@ -70,6 +69,9 @@ def log_status_change_thread(log_queue, request_iterator): std_handler.register_global() root_logger.addHandler(current_handler) root_logger.setLevel(req.loglevel) + except grpc.RpcError as e: + logger.debug(f"closing log thread " + f"grpc error reading request_iterator: {e}") finally: if current_handler is not None: root_logger.setLevel(default_level) diff --git a/python/ray/experimental/client/server/server.py b/python/ray/experimental/client/server/server.py index 7cc286de8..c1b7d6be8 100644 --- a/python/ray/experimental/client/server/server.py +++ b/python/ray/experimental/client/server/server.py @@ -17,27 +17,25 @@ import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc import time import inspect import json -from ray.experimental.client import stash_api_for_tests, _set_server_api from ray.experimental.client.server.server_pickler import convert_from_arg from ray.experimental.client.server.server_pickler import dumps_from_server from ray.experimental.client.server.server_pickler import loads_from_client -from ray.experimental.client.server.core_ray_api import RayServerAPI from ray.experimental.client.server.dataservicer import DataServicer from ray.experimental.client.server.logservicer import LogstreamServicer from ray.experimental.client.server.server_stubs import current_remote +from ray._private.client_mode_hook import disable_client_hook logger = logging.getLogger(__name__) class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer): - def __init__(self, test_mode=False): + def __init__(self): self.object_refs: Dict[str, Dict[bytes, ray.ObjectRef]] = defaultdict( dict) self.function_refs = {} self.actor_refs: Dict[bytes, ray.ActorHandle] = {} self.actor_owners: Dict[str, Set[bytes]] = defaultdict(set) self.registered_actor_classes = {} - self._test_mode = test_mode self._current_function_stub = None def ClusterInfo(self, request, @@ -45,7 +43,8 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer): resp = ray_client_pb2.ClusterInfoResponse() resp.type = request.type if request.type == ray_client_pb2.ClusterInfoType.CLUSTER_RESOURCES: - resources = ray.cluster_resources() + with disable_client_hook(): + resources = ray.cluster_resources() # Normalize resources into floats # (the function may return values that are ints) float_resources = {k: float(v) for k, v in resources.items()} @@ -54,7 +53,8 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer): table=float_resources)) elif request.type == \ ray_client_pb2.ClusterInfoType.AVAILABLE_RESOURCES: - resources = ray.available_resources() + with disable_client_hook(): + resources = ray.available_resources() # Normalize resources into floats # (the function may return values that are ints) float_resources = {k: float(v) for k, v in resources.items()} @@ -62,7 +62,8 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer): ray_client_pb2.ClusterInfoResponse.ResourceTable( table=float_resources)) else: - resp.json = self._return_debug_cluster_info(request, context) + with disable_client_hook(): + resp.json = self._return_debug_cluster_info(request, context) return resp def _return_debug_cluster_info(self, request, context=None) -> str: @@ -118,16 +119,18 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer): try: object_ref = \ self.object_refs[req.client_id][req.task_object.id] - ray.cancel( - object_ref, - force=req.task_object.force, - recursive=req.task_object.recursive) + with disable_client_hook(): + ray.cancel( + object_ref, + force=req.task_object.force, + recursive=req.task_object.recursive) except Exception as e: return_exception_in_context(e, context) elif req.WhichOneof("terminate_type") == "actor": try: actor_ref = self.actor_refs[req.actor.id] - ray.kill(actor_ref, no_restart=req.actor.no_restart) + with disable_client_hook(): + ray.kill(actor_ref, no_restart=req.actor.no_restart) except Exception as e: return_exception_in_context(e, context) else: @@ -145,7 +148,8 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer): objectref = self.object_refs[client_id][request.id] logger.debug("get: %s" % objectref) try: - item = ray.get(objectref, timeout=request.timeout) + with disable_client_hook(): + item = ray.get(objectref, timeout=request.timeout) except Exception as e: return ray_client_pb2.GetResponse( valid=False, error=cloudpickle.dumps(e)) @@ -171,7 +175,8 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer): context: gRPC context. """ obj = loads_from_client(request.data, self) - objectref = ray.put(obj) + with disable_client_hook(): + objectref = ray.put(obj) self.object_refs[client_id][objectref.binary()] = objectref logger.debug("put: %s" % objectref) return ray_client_pb2.PutResponse(id=objectref.binary()) @@ -187,11 +192,12 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer): num_returns = request.num_returns timeout = request.timeout try: - ready_object_refs, remaining_object_refs = ray.wait( - object_refs, - num_returns=num_returns, - timeout=timeout if timeout != -1 else None, - ) + with disable_client_hook(): + ready_object_refs, remaining_object_refs = ray.wait( + object_refs, + num_returns=num_returns, + timeout=timeout if timeout != -1 else None, + ) except Exception as e: # TODO(ameer): improve exception messages. logger.error(f"Exception {e}") @@ -215,8 +221,8 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer): "schedule: %s %s" % (task.name, ray_client_pb2.ClientTask.RemoteExecType.Name( task.type))) - with stash_api_for_tests(self._test_mode): - try: + try: + with disable_client_hook(): if task.type == ray_client_pb2.ClientTask.FUNCTION: result = self._schedule_function(task, context) elif task.type == ray_client_pb2.ClientTask.ACTOR: @@ -232,11 +238,11 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer): task.type)) result.valid = True return result - except Exception as e: - logger.error(f"Caught schedule exception {e}") - raise e - return ray_client_pb2.ClientTaskTicket( - valid=False, error=cloudpickle.dumps(e)) + except Exception as e: + logger.error(f"Caught schedule exception {e}") + raise e + return ray_client_pb2.ClientTaskTicket( + valid=False, error=cloudpickle.dumps(e)) def _schedule_method(self, task: ray_client_pb2.ClientTask, context=None) -> ray_client_pb2.ClientTaskTicket: @@ -307,31 +313,33 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer): def lookup_or_register_func( self, id: bytes, client_id: str, options: Optional[Dict]) -> ray.remote_function.RemoteFunction: - if id not in self.function_refs: - funcref = self.object_refs[client_id][id] - func = ray.get(funcref) - if not inspect.isfunction(func): - raise Exception("Attempting to register function that " - "isn't a function.") - if options is None or len(options) == 0: - self.function_refs[id] = ray.remote(func) - else: - self.function_refs[id] = ray.remote(**options)(func) + with disable_client_hook(): + if id not in self.function_refs: + funcref = self.object_refs[client_id][id] + func = ray.get(funcref) + if not inspect.isfunction(func): + raise Exception("Attempting to register function that " + "isn't a function.") + if options is None or len(options) == 0: + self.function_refs[id] = ray.remote(func) + else: + self.function_refs[id] = ray.remote(**options)(func) return self.function_refs[id] def lookup_or_register_actor(self, id: bytes, client_id: str, options: Optional[Dict]): - if id not in self.registered_actor_classes: - actor_class_ref = self.object_refs[client_id][id] - actor_class = ray.get(actor_class_ref) - if not inspect.isclass(actor_class): - raise Exception("Attempting to schedule actor that " - "isn't a class.") - if options is None or len(options) == 0: - reg_class = ray.remote(actor_class) - else: - reg_class = ray.remote(**options)(actor_class) - self.registered_actor_classes[id] = reg_class + with disable_client_hook(): + if id not in self.registered_actor_classes: + actor_class_ref = self.object_refs[client_id][id] + actor_class = ray.get(actor_class_ref) + if not inspect.isclass(actor_class): + raise Exception("Attempting to schedule actor that " + "isn't a class.") + if options is None or len(options) == 0: + reg_class = ray.remote(actor_class) + else: + reg_class = ray.remote(**options)(actor_class) + self.registered_actor_classes[id] = reg_class return self.registered_actor_classes[id] @@ -369,12 +377,22 @@ def decode_options( return opts -def serve(connection_str, test_mode=False): +_current_servicer: Optional[RayletServicer] = None + + +# Used by tests to peek inside the servicer +def _get_current_servicer(): + global _current_servicer + return _current_servicer + + +def serve(connection_str): server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) - task_servicer = RayletServicer(test_mode=test_mode) + task_servicer = RayletServicer() data_servicer = DataServicer(task_servicer) logs_servicer = LogstreamServicer() - _set_server_api(RayServerAPI(task_servicer)) + global _current_servicer + _current_servicer = task_servicer ray_client_pb2_grpc.add_RayletDriverServicer_to_server( task_servicer, server) ray_client_pb2_grpc.add_RayletDataStreamerServicer_to_server( @@ -386,12 +404,20 @@ def serve(connection_str, test_mode=False): return server -def init_and_serve(connection_str, test_mode=False, *args, **kwargs): - info = ray.init(*args, **kwargs) - server = serve(connection_str, test_mode) +def init_and_serve(connection_str, *args, **kwargs): + with disable_client_hook(): + # Disable client mode inside the worker's environment + info = ray.init(*args, **kwargs) + server = serve(connection_str) return (server, info) +def shutdown_with_server(server, _exiting_interpreter=False): + server.stop(1) + with disable_client_hook(): + ray.shutdown(_exiting_interpreter) + + if __name__ == "__main__": logging.basicConfig(level="INFO") # TODO(barakmich): Perhaps wrap ray init diff --git a/python/ray/experimental/client/server/server_pickler.py b/python/ray/experimental/client/server/server_pickler.py index 10da70cc1..4f25d728f 100644 --- a/python/ray/experimental/client/server/server_pickler.py +++ b/python/ray/experimental/client/server/server_pickler.py @@ -1,5 +1,4 @@ -""" -Implements the client side of the client/server pickling protocol. +"""Implements the client side of the client/server pickling protocol. These picklers are aware of the server internals and can find the references held for the client within the server. @@ -20,6 +19,7 @@ import ray from typing import Any from typing import TYPE_CHECKING +from ray._private.client_mode_hook import disable_client_hook from ray.experimental.client.client_pickler import PickleStub from ray.experimental.client.server.server_stubs import ( ServerSelfReferenceSentinel) @@ -121,12 +121,13 @@ def loads_from_client(data: bytes, fix_imports=True, encoding="ASCII", errors="strict") -> Any: - if isinstance(data, str): - raise TypeError("Can't load pickle from unicode string") - file = io.BytesIO(data) - return ClientUnpickler( - server_instance, file, fix_imports=fix_imports, - encoding=encoding).load() + with disable_client_hook(): + if isinstance(data, str): + raise TypeError("Can't load pickle from unicode string") + file = io.BytesIO(data) + return ClientUnpickler( + server_instance, file, fix_imports=fix_imports, + encoding=encoding).load() def convert_from_arg(pb: "ray_client_pb2.Arg", diff --git a/python/ray/experimental/client/worker.py b/python/ray/experimental/client/worker.py index 8ed41bff4..b9124a9a7 100644 --- a/python/ray/experimental/client/worker.py +++ b/python/ray/experimental/client/worker.py @@ -3,7 +3,6 @@ It implements the Ray API functions that are forwarded through grpc calls to the server. """ import base64 -import inspect import json import logging import uuid @@ -14,7 +13,6 @@ from typing import List from typing import Tuple from typing import Optional -from ray.util.inspect import is_cython import grpc import ray.cloudpickle as cloudpickle @@ -23,12 +21,9 @@ import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc from ray.experimental.client.client_pickler import convert_to_arg from ray.experimental.client.client_pickler import dumps_from_client from ray.experimental.client.client_pickler import loads_from_server -from ray.experimental.client.common import ClientActorClass from ray.experimental.client.common import ClientActorHandle from ray.experimental.client.common import ClientActorRef from ray.experimental.client.common import ClientObjectRef -from ray.experimental.client.common import ClientRemoteFunc -from ray.experimental.client.common import ClientStub from ray.experimental.client.dataclient import DataClient from ray.experimental.client.logsclient import LogstreamClient @@ -61,6 +56,7 @@ class Worker: self.log_client = LogstreamClient(self.channel) self.log_client.set_logstream_level(logging.INFO) + self.closed = False def get(self, vals, *, timeout: Optional[float] = None) -> Any: to_get = [] @@ -153,21 +149,6 @@ class Worker: return (client_ready_object_ids, client_remaining_object_ids) - def remote(self, *args, **kwargs): - if len(args) == 1 and len(kwargs) == 0 and callable(args[0]): - # This is the case where the decorator is just @ray.remote. - return remote_decorator(options=None)(args[0]) - error_string = ("The @ray.remote decorator must be applied either " - "with no arguments and no parentheses, for example " - "'@ray.remote', or it must be applied using some of " - "the arguments 'num_returns', 'num_cpus', 'num_gpus', " - "'memory', 'object_store_memory', 'resources', " - "'max_calls', or 'max_restarts', like " - "'@ray.remote(num_returns=2, " - "resources={\"CustomResource\": 1})'.") - assert len(args) == 0 and len(kwargs) > 0, error_string - return remote_decorator(options=kwargs) - def call_remote(self, instance, *args, **kwargs) -> List[bytes]: task = instance._prepare_client_task() for arg in args: @@ -190,6 +171,8 @@ class Worker: return ticket.return_ids def call_release(self, id: bytes) -> None: + if self.closed: + return self.reference_count[id] -= 1 if self.reference_count[id] == 0: self._release_server(id) @@ -212,6 +195,7 @@ class Worker: self.channel.close() self.channel = None self.server = None + self.closed = True def get_actor(self, name: str) -> ClientActorHandle: task = ray_client_pb2.ClientTask() @@ -258,7 +242,9 @@ class Worker: req.type = type resp = self.server.ClusterInfo(req) if resp.WhichOneof("response_type") == "resource_table": - return resp.resource_table.table + # translate from a proto map to a python dict + output_dict = {k: v for k, v in resp.resource_table.table.items()} + return output_dict return json.loads(resp.json) def is_initialized(self) -> bool: @@ -268,20 +254,6 @@ class Worker: return False -def remote_decorator(options: Optional[Dict[str, Any]]): - def decorator(function_or_class) -> ClientStub: - if (inspect.isfunction(function_or_class) - or is_cython(function_or_class)): - return ClientRemoteFunc(function_or_class, options=options) - elif inspect.isclass(function_or_class): - return ClientActorClass(function_or_class, options=options) - else: - raise TypeError("The @ray.remote decorator must be applied to " - "either a function or to a class.") - - return decorator - - def make_client_id() -> str: id = uuid.uuid4() return id.hex diff --git a/python/ray/state.py b/python/ray/state.py index 6d9df7870..aa3488e20 100644 --- a/python/ray/state.py +++ b/python/ray/state.py @@ -9,6 +9,7 @@ import ray from ray import gcs_utils from google.protobuf.json_format import MessageToDict from ray._private import services +from ray._private.client_mode_hook import client_mode_hook from ray.utils import (decode, binary_to_hex, hex_to_binary) from ray._raylet import GlobalStateAccessor @@ -851,6 +852,7 @@ def jobs(): return state.job_table() +@client_mode_hook def nodes(): """Get a list of the nodes in the cluster (for debugging only). @@ -964,6 +966,7 @@ def object_transfer_timeline(filename=None): return state.chrome_tracing_object_transfer_dump(filename=filename) +@client_mode_hook def cluster_resources(): """Get the current total cluster resources. @@ -977,6 +980,7 @@ def cluster_resources(): return state.cluster_resources() +@client_mode_hook def available_resources(): """Get the current available cluster resources. diff --git a/python/ray/test_utils.py b/python/ray/test_utils.py index a479903ff..4185d3f0c 100644 --- a/python/ray/test_utils.py +++ b/python/ray/test_utils.py @@ -446,4 +446,4 @@ def new_scheduler_enabled(): def client_test_enabled() -> bool: - return os.environ.get("RAY_TEST_CLIENT_MODE") == "1" + return os.environ.get("RAY_CLIENT_MODE") == "1" diff --git a/python/ray/tests/BUILD b/python/ray/tests/BUILD index 7e552e616..903377ec8 100644 --- a/python/ray/tests/BUILD +++ b/python/ray/tests/BUILD @@ -167,7 +167,7 @@ py_test_module_list( name_suffix = "_client_mode", # TODO(barakmich): py_test will support env in Bazel 4.0.0... # Until then, we can use tags. - #env = {"RAY_TEST_CLIENT_MODE": "true"}, + #env = {"RAY_CLIENT_MODE": "1"}, tags = ["exclusive", "client_tests"], deps = ["//:ray_lib"], ) diff --git a/python/ray/tests/conftest.py b/python/ray/tests/conftest.py index 05cd9d8ca..4fdfe68c6 100644 --- a/python/ray/tests/conftest.py +++ b/python/ray/tests/conftest.py @@ -9,18 +9,12 @@ import subprocess import ray from ray.cluster_utils import Cluster from ray.test_utils import init_error_pubsub -from ray.test_utils import client_test_enabled -import ray.experimental.client as ray_client @pytest.fixture def shutdown_only(): yield None # The code after the yield will run as teardown code. - if client_test_enabled(): - ray_client.ray.disconnect() - ray_client._stop_test_server(1) - ray_client.reset_api() ray.shutdown() @@ -49,17 +43,10 @@ def _ray_start(**kwargs): init_kwargs = get_default_fixture_ray_kwargs() init_kwargs.update(kwargs) # Start the Ray processes. - if client_test_enabled(): - address_info = ray_client.ray.init(**init_kwargs) - else: - address_info = ray.init(**init_kwargs) + address_info = ray.init(**init_kwargs) yield address_info # The code after the yield will run as teardown code. - if client_test_enabled(): - ray_client.ray.disconnect() - ray_client._stop_test_server(1) - ray_client.reset_api() ray.shutdown() @@ -144,16 +131,9 @@ def _ray_start_cluster(**kwargs): # We assume driver will connect to the head (first node), # so ray init will be invoked if do_init is true if len(remote_nodes) == 1 and do_init: - if client_test_enabled(): - ray_client.ray.init(address=cluster.address) - else: - ray.init(address=cluster.address) + ray.init(address=cluster.address) yield cluster # The code after the yield will run as teardown code. - if client_test_enabled(): - ray_client.ray.disconnect() - ray_client._stop_test_server(1) - ray_client.reset_api() ray.shutdown() cluster.shutdown() diff --git a/python/ray/tests/test_actor.py b/python/ray/tests/test_actor.py index 3ba2ed7eb..4db4bdd4b 100644 --- a/python/ray/tests/test_actor.py +++ b/python/ray/tests/test_actor.py @@ -16,18 +16,12 @@ from ray.test_utils import wait_for_condition from ray.test_utils import wait_for_pid_to_exit from ray.tests.client_test_utils import create_remote_signal_actor -if client_test_enabled(): - from ray.experimental.client import ray -else: - import ray +import ray # NOTE: We have to import setproctitle after ray because we bundle setproctitle # with ray. import setproctitle # noqa -@pytest.mark.skipif( - client_test_enabled(), - reason="defining early, no ray package injection yet") def test_caching_actors(shutdown_only): # Test defining actors before ray.init() has been called. @@ -705,7 +699,6 @@ def test_options_num_returns(ray_start_regular_shared): assert ray.get([obj1, obj2]) == [1, 2] -@pytest.mark.skipif(client_test_enabled(), reason="remote args") def test_options_name(ray_start_regular_shared): @ray.remote class Foo: diff --git a/python/ray/tests/test_advanced.py b/python/ray/tests/test_advanced.py index ea2a6c693..50c27b07a 100644 --- a/python/ray/tests/test_advanced.py +++ b/python/ray/tests/test_advanced.py @@ -354,6 +354,8 @@ def test_illegal_api_calls(ray_start_regular): ray.get(3) +@pytest.mark.skipif( + client_test_enabled(), reason="grpc interaction with releasing resources") def test_multithreading(ray_start_2_cpus): # This test requires at least 2 CPUs to finish since the worker does not # release resources when joining the threads. diff --git a/python/ray/tests/test_basic.py b/python/ray/tests/test_basic.py index 38330645b..7d0e7ae83 100644 --- a/python/ray/tests/test_basic.py +++ b/python/ray/tests/test_basic.py @@ -15,10 +15,7 @@ from ray.test_utils import ( wait_for_pid_to_exit, ) -if client_test_enabled(): - from ray.experimental.client import ray -else: - import ray +import ray logger = logging.getLogger(__name__) diff --git a/python/ray/tests/test_experimental_client.py b/python/ray/tests/test_experimental_client.py index 131954ede..c01030e58 100644 --- a/python/ray/tests/test_experimental_client.py +++ b/python/ray/tests/test_experimental_client.py @@ -160,8 +160,7 @@ def test_basic_actor(ray_start_regular_shared): def test_pass_handles(ray_start_regular_shared): - """ - Test that passing client handles to actors and functions to remote actors + """Test that passing client handles to actors and functions to remote actors in functions (on the server or raylet side) works transparently to the caller. """ @@ -264,9 +263,32 @@ def test_stdout_log_stream(ray_start_regular_shared): assert all((msg.find("Hello world") for msg in log_msgs)) -def test_basic_named_actor(ray_start_regular_shared): +def test_create_remote_before_start(ray_start_regular_shared): + """Creates remote objects (as though in a library) before + starting the client. """ - Test that ray.get_actor() can create and return a detached actor. + from ray.experimental.client import ray + + @ray.remote + class Returner: + def doit(self): + return "foo" + + @ray.remote + def f(x): + return x + 20 + + # Prints in verbose tests + print("Created remote functions") + + with ray_start_client_server() as ray: + assert ray.get(f.remote(3)) == 23 + a = Returner.remote() + assert ray.get(a.doit.remote()) == "foo" + + +def test_basic_named_actor(ray_start_regular_shared): + """Test that ray.get_actor() can create and return a detached actor. """ with ray_start_client_server() as ray: diff --git a/python/ray/tests/test_experimental_client_metadata.py b/python/ray/tests/test_experimental_client_metadata.py index f5a65cd66..a35f01649 100644 --- a/python/ray/tests/test_experimental_client_metadata.py +++ b/python/ray/tests/test_experimental_client_metadata.py @@ -2,8 +2,7 @@ from ray.experimental.client.ray_client_helpers import ray_start_client_server def test_get_ray_metadata(ray_start_regular_shared): - """ - Test the ClusterInfo client data pathway and API surface + """Test the ClusterInfo client data pathway and API surface """ with ray_start_client_server() as ray: ip_address = ray_start_regular_shared["node_ip_address"] diff --git a/python/ray/tests/test_experimental_client_references.py b/python/ray/tests/test_experimental_client_references.py index 4875d1ae0..7e5b4d184 100644 --- a/python/ray/tests/test_experimental_client_references.py +++ b/python/ray/tests/test_experimental_client_references.py @@ -2,11 +2,11 @@ from ray.experimental.client.ray_client_helpers import ray_start_client_server from ray.test_utils import wait_for_condition import ray as real_ray from ray.core.generated.gcs_pb2 import ActorTableData -from ray.experimental.client import _get_server_instance +from ray.experimental.client.server.server import _get_current_servicer def server_object_ref_count(n): - server = _get_server_instance() + server = _get_current_servicer() assert server is not None def test_cond(): @@ -20,7 +20,7 @@ def server_object_ref_count(n): def server_actor_ref_count(n): - server = _get_server_instance() + server = _get_current_servicer() assert server is not None def test_cond(): diff --git a/python/ray/worker.py b/python/ray/worker.py index a3d07e5ee..888cf680b 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -51,6 +51,7 @@ from ray.ray_logging import setup_logger from ray.ray_logging import global_worker_stdstream_dispatcher from ray.utils import _random_string, check_oversized_pickle from ray.util.inspect import is_cython +from ray._private.client_mode_hook import client_mode_hook SCRIPT_MODE = 0 WORKER_MODE = 1 @@ -469,6 +470,7 @@ _global_node = None """ray.node.Node: The global node object that is created by ray.init().""" +@client_mode_hook def init( address=None, *, @@ -781,6 +783,7 @@ def init( _post_init_hooks = [] +@client_mode_hook def shutdown(_exiting_interpreter=False): """Disconnect the worker, and terminate processes started by ray.init(). @@ -1044,6 +1047,7 @@ def listen_error_messages_raylet(worker, task_error_queue, threads_stopped): worker.error_message_pubsub_client.close() +@client_mode_hook def is_initialized(): """Check if ray.init has been called yet. @@ -1322,6 +1326,7 @@ def show_in_dashboard(message, key="", dtype="text"): blocking_get_inside_async_warned = False +@client_mode_hook def get(object_refs, *, timeout=None): """Get a remote object or a list of remote objects from the object store. @@ -1400,6 +1405,7 @@ def get(object_refs, *, timeout=None): return values +@client_mode_hook def put(value): """Store an object in the object store. @@ -1428,6 +1434,7 @@ def put(value): blocking_wait_inside_async_warned = False +@client_mode_hook def wait(object_refs, *, num_returns=1, timeout=None, fetch_local=True): """Return a list of IDs that are ready and a list of IDs that are not. @@ -1528,6 +1535,7 @@ def wait(object_refs, *, num_returns=1, timeout=None, fetch_local=True): return ready_ids, remaining_ids +@client_mode_hook def get_actor(name): """Get a handle to a detached actor. @@ -1548,6 +1556,7 @@ def get_actor(name): return handle +@client_mode_hook def kill(actor, *, no_restart=True): """Kill an actor forcefully. @@ -1575,6 +1584,7 @@ def kill(actor, *, no_restart=True): worker.core_worker.kill_actor(actor._ray_actor_id, no_restart) +@client_mode_hook def cancel(object_ref, *, force=False, recursive=True): """Cancels a task according to the following conditions. @@ -1691,6 +1701,7 @@ def make_decorator(num_returns=None, return decorator +@client_mode_hook def remote(*args, **kwargs): """Defines a remote function or an actor class.