mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 21:38:18 +08:00
[ray_client]: Insert decorators into the real ray module to allow for client mode (#13031)
This commit is contained in:
@@ -0,0 +1,47 @@
|
||||
import os
|
||||
from contextlib import contextmanager
|
||||
|
||||
client_mode_enabled = os.environ.get("RAY_CLIENT_MODE", "0") == "1"
|
||||
|
||||
_client_hook_enabled = True
|
||||
|
||||
|
||||
def _enable_client_hook(val: bool):
|
||||
global _client_hook_enabled
|
||||
_client_hook_enabled = val
|
||||
|
||||
|
||||
def _disable_client_hook():
|
||||
global _client_hook_enabled
|
||||
out = _client_hook_enabled
|
||||
_client_hook_enabled = False
|
||||
return out
|
||||
|
||||
|
||||
def _explicitly_enable_client_mode():
|
||||
global client_mode_enabled
|
||||
client_mode_enabled = True
|
||||
|
||||
|
||||
@contextmanager
|
||||
def disable_client_hook():
|
||||
val = _disable_client_hook()
|
||||
try:
|
||||
yield None
|
||||
finally:
|
||||
_enable_client_hook(val)
|
||||
|
||||
|
||||
def client_mode_hook(func):
|
||||
"""
|
||||
Decorator for ray module methods to delegate to ray client
|
||||
"""
|
||||
from ray.experimental.client import ray
|
||||
|
||||
def wrapper(*args, **kwargs):
|
||||
global _client_hook_enabled
|
||||
if client_mode_enabled and _client_hook_enabled:
|
||||
return getattr(ray, func.__name__)(*args, **kwargs)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
@@ -107,6 +107,10 @@ from ray.exceptions import (
|
||||
TaskCancelledError
|
||||
)
|
||||
from ray.utils import decode
|
||||
from ray._private.client_mode_hook import (
|
||||
_enable_client_hook,
|
||||
_disable_client_hook,
|
||||
)
|
||||
import msgpack
|
||||
|
||||
cimport cpython
|
||||
@@ -558,6 +562,7 @@ cdef CRayStatus task_execution_handler(
|
||||
|
||||
with gil:
|
||||
try:
|
||||
client_was_enabled = _disable_client_hook()
|
||||
try:
|
||||
# The call to execute_task should never raise an exception. If
|
||||
# it does, that indicates that there was an internal error.
|
||||
@@ -582,6 +587,8 @@ cdef CRayStatus task_execution_handler(
|
||||
else:
|
||||
logger.exception("SystemExit was raised from the worker")
|
||||
return CRayStatus.UnexpectedSystemExit()
|
||||
finally:
|
||||
_enable_client_hook(client_was_enabled)
|
||||
|
||||
return CRayStatus.OK()
|
||||
|
||||
|
||||
@@ -1,148 +1,104 @@
|
||||
from ray.experimental.client.api import ClientAPI
|
||||
from ray.experimental.client.api import APIImpl
|
||||
from typing import Optional, List, Tuple
|
||||
from contextlib import contextmanager
|
||||
from typing import List, Tuple
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# About these global variables: Ray 1.0 uses exported module functions to
|
||||
# provide its API, and we need to match that. However, we want different
|
||||
# behaviors depending on where, exactly, in the client stack this is running.
|
||||
#
|
||||
# The reason for these differences depends on what's being pickled and passed
|
||||
# to functions, or functions inside functions. So there are three cases to care
|
||||
# about
|
||||
#
|
||||
# (Python Client)-->(Python ClientServer)-->(Internal Raylet Process)
|
||||
#
|
||||
# * _client_api should be set if we're inside the client
|
||||
# * _server_api should be set if we're inside the clientserver
|
||||
# * Both will be set if we're running both (as in a test)
|
||||
# * Neither should be set if we're inside the raylet (but we still need to shim
|
||||
# from the client API surface to the Ray API)
|
||||
#
|
||||
# The job of RayAPIStub (below) delegates to the appropriate one of these
|
||||
# depending on what's set or not. Then, all users importing the ray object
|
||||
# from this package get the stub which routes them to the appropriate APIImpl.
|
||||
_client_api: Optional[APIImpl] = None
|
||||
_server_api: Optional[APIImpl] = None
|
||||
|
||||
# The reason for _is_server is a hack around the above comment while running
|
||||
# tests. If we have both a client and a server trying to control these static
|
||||
# variables then we need a way to decide which to use. In this case, both
|
||||
# _client_api and _server_api are set.
|
||||
# This boolean flips between the two
|
||||
_is_server: bool = False
|
||||
|
||||
|
||||
@contextmanager
|
||||
def stash_api_for_tests(in_test: bool):
|
||||
global _is_server
|
||||
is_server = _is_server
|
||||
if in_test:
|
||||
_is_server = True
|
||||
try:
|
||||
yield _server_api
|
||||
finally:
|
||||
if in_test:
|
||||
_is_server = is_server
|
||||
|
||||
|
||||
def _set_client_api(val: Optional[APIImpl]):
|
||||
global _client_api
|
||||
global _is_server
|
||||
if _client_api is not None:
|
||||
raise Exception("Trying to set more than one client API")
|
||||
_client_api = val
|
||||
_is_server = False
|
||||
|
||||
|
||||
def _set_server_api(val: Optional[APIImpl]):
|
||||
global _server_api
|
||||
global _is_server
|
||||
if _server_api is not None:
|
||||
raise Exception("Trying to set more than one server API")
|
||||
_server_api = val
|
||||
_is_server = True
|
||||
|
||||
|
||||
def reset_api():
|
||||
global _client_api
|
||||
global _server_api
|
||||
global _is_server
|
||||
_client_api = None
|
||||
_server_api = None
|
||||
_is_server = False
|
||||
|
||||
|
||||
def _get_client_api() -> APIImpl:
|
||||
global _client_api
|
||||
return _client_api
|
||||
|
||||
|
||||
def _get_server_instance():
|
||||
"""Used inside tests to inspect the running server.
|
||||
"""
|
||||
global _server_api
|
||||
if _server_api is not None:
|
||||
return _server_api.server
|
||||
|
||||
|
||||
class RayAPIStub:
|
||||
"""This class stands in as the replacement API for the `import ray` module.
|
||||
|
||||
Much like the ray module, this mostly delegates the work to the
|
||||
_client_worker. As parts of the ray API are covered, they are piped through
|
||||
here or on the client worker API.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
from ray.experimental.client.api import ClientAPI
|
||||
self.api = ClientAPI()
|
||||
self.client_worker = None
|
||||
self._server = None
|
||||
self._connected_with_init = False
|
||||
self._inside_client_test = False
|
||||
|
||||
def connect(self,
|
||||
conn_str: str,
|
||||
secure: bool = False,
|
||||
metadata: List[Tuple[str, str]] = None,
|
||||
stub=None) -> None:
|
||||
metadata: List[Tuple[str, str]] = None) -> None:
|
||||
"""Connect the Ray Client to a server.
|
||||
|
||||
Args:
|
||||
conn_str: Connection string, in the form "[host]:port"
|
||||
secure: Whether to use a TLS secured gRPC channel
|
||||
metadata: gRPC metadata to send on connect
|
||||
"""
|
||||
# Delay imports until connect to avoid circular imports.
|
||||
from ray.experimental.client.worker import Worker
|
||||
_client_worker = Worker(conn_str, secure=secure, metadata=metadata)
|
||||
_set_client_api(ClientAPI(_client_worker))
|
||||
import ray._private.client_mode_hook
|
||||
if self.client_worker is not None:
|
||||
if self._connected_with_init:
|
||||
return
|
||||
raise Exception(
|
||||
"ray.connect() called, but ray client is already connected")
|
||||
if not self._inside_client_test:
|
||||
# If we're calling a client connect specifically and we're not
|
||||
# currently in client mode, ensure we are.
|
||||
ray._private.client_mode_hook._explicitly_enable_client_mode()
|
||||
self.client_worker = Worker(conn_str, secure=secure, metadata=metadata)
|
||||
self.api.worker = self.client_worker
|
||||
|
||||
def disconnect(self):
|
||||
global _client_api
|
||||
if _client_api is not None:
|
||||
_client_api.close()
|
||||
_client_api = None
|
||||
"""Disconnect the Ray Client.
|
||||
"""
|
||||
if self.client_worker is not None:
|
||||
self.client_worker.close()
|
||||
self.client_worker = None
|
||||
|
||||
# remote can be called outside of a connection, which is why it
|
||||
# exists on the same API layer as connect() itself.
|
||||
def remote(self, *args, **kwargs):
|
||||
"""remote is the hook stub passed on to replace `ray.remote`.
|
||||
|
||||
This sets up remote functions or actors, as the decorator,
|
||||
but does not execute them.
|
||||
|
||||
Args:
|
||||
args: opaque arguments
|
||||
kwargs: opaque keyword arguments
|
||||
"""
|
||||
return self.api.remote(*args, **kwargs)
|
||||
|
||||
def __getattr__(self, key: str):
|
||||
global _get_client_api
|
||||
api = _get_client_api()
|
||||
return getattr(api, key)
|
||||
if not self.is_connected():
|
||||
raise Exception("Ray Client is not connected. "
|
||||
"Please connect by calling `ray.connect`.")
|
||||
return getattr(self.api, key)
|
||||
|
||||
def is_connected(self) -> bool:
|
||||
global _client_api
|
||||
return _client_api is not None
|
||||
return self.api is not None
|
||||
|
||||
def init(self, *args, **kwargs):
|
||||
if _is_client_test_env():
|
||||
global _test_server
|
||||
import ray.experimental.client.server.server as ray_client_server
|
||||
_test_server, address_info = ray_client_server.init_and_serve(
|
||||
"localhost:50051", test_mode=True, *args, **kwargs)
|
||||
self.connect("localhost:50051")
|
||||
return address_info
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Please call ray.connect() in client mode")
|
||||
if self._server is not None:
|
||||
raise Exception("Trying to start two instances of ray via client")
|
||||
import ray.experimental.client.server.server as ray_client_server
|
||||
self._server, address_info = ray_client_server.init_and_serve(
|
||||
"localhost:50051", *args, **kwargs)
|
||||
self.connect("localhost:50051")
|
||||
self._connected_with_init = True
|
||||
return address_info
|
||||
|
||||
def shutdown(self, _exiting_interpreter=False):
|
||||
self.disconnect()
|
||||
import ray.experimental.client.server.server as ray_client_server
|
||||
if self._server is None:
|
||||
return
|
||||
ray_client_server.shutdown_with_server(self._server,
|
||||
_exiting_interpreter)
|
||||
self._server = None
|
||||
|
||||
|
||||
ray = RayAPIStub()
|
||||
|
||||
_test_server = None
|
||||
|
||||
|
||||
def _stop_test_server(*args):
|
||||
global _test_server
|
||||
_test_server.stop(*args)
|
||||
|
||||
|
||||
def _is_client_test_env() -> bool:
|
||||
return os.environ.get("RAY_TEST_CLIENT_MODE") == "1"
|
||||
|
||||
|
||||
# Someday we might add methods in this module so that someone who
|
||||
# tries to `import ray_client as ray` -- as a module, instead of
|
||||
# `from ray_client import ray` -- as the API stub
|
||||
|
||||
@@ -1,74 +1,51 @@
|
||||
# This file defines an interface and client-side API stub
|
||||
# for referring either to the core Ray API or the same interface
|
||||
# from the Ray client.
|
||||
#
|
||||
# In tandem with __init__.py, we want to expose an API that's
|
||||
# close to `python/ray/__init__.py` but with more than one implementation.
|
||||
# The stubs in __init__ should call into a well-defined interface.
|
||||
# Only the core Ray API implementation should actually `import ray`
|
||||
# (and thus import all the raylet worker C bindings and such).
|
||||
# But to make sure that we're matching these calls, we define this API.
|
||||
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, Union, Optional
|
||||
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||
"""This file defines the interface between the ray client worker
|
||||
and the overall ray module API.
|
||||
"""
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from ray.experimental.client.common import ClientActorHandle
|
||||
from ray.experimental.client.common import ClientStub
|
||||
from ray.experimental.client.common import ClientActorHandle
|
||||
from ray.experimental.client.common import ClientObjectRef
|
||||
from ray._raylet import ObjectRef
|
||||
|
||||
# Use the imports for type checking. This is a python 3.6 limitation.
|
||||
# See https://www.python.org/dev/peps/pep-0563/
|
||||
PutType = Union[ClientObjectRef, ObjectRef]
|
||||
|
||||
|
||||
class APIImpl(ABC):
|
||||
"""
|
||||
APIImpl is the interface to implement for whichever version of the core
|
||||
Ray API that needs abstracting when run in client mode.
|
||||
class ClientAPI:
|
||||
"""The Client-side methods corresponding to the ray API. Delegates
|
||||
to the Client Worker that contains the connection to the ClientServer.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get(self, vals, *, timeout: Optional[float] = None) -> Any:
|
||||
"""
|
||||
get is the hook stub passed on to replace `ray.get`
|
||||
def __init__(self, worker=None):
|
||||
self.worker = worker
|
||||
|
||||
def get(self, vals, *, timeout=None):
|
||||
"""get is the hook stub passed on to replace `ray.get`
|
||||
|
||||
Args:
|
||||
vals: [Client]ObjectRef or list of these refs to retrieve.
|
||||
timeout: Optional timeout in milliseconds
|
||||
"""
|
||||
pass
|
||||
return self.worker.get(vals, timeout=timeout)
|
||||
|
||||
@abstractmethod
|
||||
def put(self, vals: Any, *args,
|
||||
**kwargs) -> Union["ClientObjectRef", "ObjectRef"]:
|
||||
"""
|
||||
put is the hook stub passed on to replace `ray.put`
|
||||
def put(self, *args, **kwargs):
|
||||
"""put is the hook stub passed on to replace `ray.put`
|
||||
|
||||
Args:
|
||||
vals: The value or list of values to `put`.
|
||||
args: opaque arguments
|
||||
kwargs: opaque keyword arguments
|
||||
"""
|
||||
pass
|
||||
return self.worker.put(*args, **kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def wait(self, *args, **kwargs):
|
||||
"""
|
||||
wait is the hook stub passed on to replace `ray.wait`
|
||||
"""wait is the hook stub passed on to replace `ray.wait`
|
||||
|
||||
Args:
|
||||
args: opaque arguments
|
||||
kwargs: opaque keyword arguments
|
||||
"""
|
||||
pass
|
||||
return self.worker.wait(*args, **kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def remote(self, *args, **kwargs):
|
||||
"""
|
||||
remote is the hook stub passed on to replace `ray.remote`.
|
||||
"""remote is the hook stub passed on to replace `ray.remote`.
|
||||
|
||||
This sets up remote functions or actors, as the decorator,
|
||||
but does not execute them.
|
||||
@@ -77,12 +54,24 @@ class APIImpl(ABC):
|
||||
args: opaque arguments
|
||||
kwargs: opaque keyword arguments
|
||||
"""
|
||||
pass
|
||||
# Delayed import to avoid a cyclic import
|
||||
from ray.experimental.client.common import remote_decorator
|
||||
if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
|
||||
# This is the case where the decorator is just @ray.remote.
|
||||
return remote_decorator(options=None)(args[0])
|
||||
error_string = ("The @ray.remote decorator must be applied either "
|
||||
"with no arguments and no parentheses, for example "
|
||||
"'@ray.remote', or it must be applied using some of "
|
||||
"the arguments 'num_returns', 'num_cpus', 'num_gpus', "
|
||||
"'memory', 'object_store_memory', 'resources', "
|
||||
"'max_calls', or 'max_restarts', like "
|
||||
"'@ray.remote(num_returns=2, "
|
||||
"resources={\"CustomResource\": 1})'.")
|
||||
assert len(args) == 0 and len(kwargs) > 0, error_string
|
||||
return remote_decorator(options=kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def call_remote(self, instance: "ClientStub", *args, **kwargs):
|
||||
"""
|
||||
call_remote is called by stub objects to execute them remotely.
|
||||
"""call_remote is called by stub objects to execute them remotely.
|
||||
|
||||
This is used by stub objects in situations where they're called
|
||||
with .remote, eg, `f.remote()` or `actor_cls.remote()`.
|
||||
@@ -95,31 +84,57 @@ class APIImpl(ABC):
|
||||
args: opaque arguments
|
||||
kwargs: opaque keyword arguments
|
||||
"""
|
||||
pass
|
||||
return self.worker.call_remote(instance, *args, **kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def close(self) -> None:
|
||||
def call_release(self, id: bytes) -> None:
|
||||
"""Attempts to release an object reference.
|
||||
|
||||
When client references are destructed, they release their reference,
|
||||
which can opportunistically send a notification through the datachannel
|
||||
to release the reference being held for that object on the server.
|
||||
|
||||
Args:
|
||||
id: The id of the reference to release on the server side.
|
||||
"""
|
||||
close cleans up an API connection by closing any channels or
|
||||
return self.worker.call_release(id)
|
||||
|
||||
def call_retain(self, id: bytes) -> None:
|
||||
"""Attempts to retain a client object reference.
|
||||
|
||||
Increments the reference count on the client side, to prevent
|
||||
the client worker from attempting to release the server reference.
|
||||
|
||||
Args:
|
||||
id: The id of the reference to retain on the client side.
|
||||
"""
|
||||
return self.worker.call_retain(id)
|
||||
|
||||
def close(self) -> None:
|
||||
"""close cleans up an API connection by closing any channels or
|
||||
shutting down any servers gracefully.
|
||||
"""
|
||||
pass
|
||||
return self.worker.close()
|
||||
|
||||
@abstractmethod
|
||||
def kill(self, actor, *, no_restart=True):
|
||||
def get_actor(self, name: str) -> "ClientActorHandle":
|
||||
"""Returns a handle to an actor by name.
|
||||
|
||||
Args:
|
||||
name: The name passed to this actor by
|
||||
Actor.options(name="name").remote()
|
||||
"""
|
||||
kill forcibly stops an actor running in the cluster
|
||||
return self.worker.get_actor(name)
|
||||
|
||||
def kill(self, actor: "ClientActorHandle", *, no_restart=True):
|
||||
"""kill forcibly stops an actor running in the cluster
|
||||
|
||||
Args:
|
||||
no_restart: Whether this actor should be restarted if it's a
|
||||
restartable actor.
|
||||
"""
|
||||
pass
|
||||
return self.worker.terminate_actor(actor, no_restart)
|
||||
|
||||
@abstractmethod
|
||||
def cancel(self, obj, *, force=False, recursive=True):
|
||||
"""
|
||||
Cancels a task on the cluster.
|
||||
def cancel(self, obj: "ClientObjectRef", *, force=False, recursive=True):
|
||||
"""Cancels a task on the cluster.
|
||||
|
||||
If the specified task is pending execution, it will not be executed. If
|
||||
the task is currently executing, the behavior depends on the ``force``
|
||||
@@ -136,80 +151,11 @@ class APIImpl(ABC):
|
||||
recursive (boolean): Whether to try to cancel tasks submitted by
|
||||
the task specified.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def call_release(self, id: bytes) -> None:
|
||||
"""
|
||||
Attempts to release an object reference.
|
||||
|
||||
When client references are destructed, they release their reference,
|
||||
which can opportunistically send a notification through the datachannel
|
||||
to release the reference being held for that object on the server.
|
||||
|
||||
Args:
|
||||
id: The id of the reference to release on the server side.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def call_retain(self, id: bytes) -> None:
|
||||
"""
|
||||
Attempts to retain a client object reference.
|
||||
|
||||
Increments the reference count on the client side, to prevent
|
||||
the client worker from attempting to release the server reference.
|
||||
|
||||
Args:
|
||||
id: The id of the reference to retain on the client side.
|
||||
"""
|
||||
|
||||
|
||||
class ClientAPI(APIImpl):
|
||||
"""
|
||||
The Client-side methods corresponding to the ray API. Delegates
|
||||
to the Client Worker that contains the connection to the ClientServer.
|
||||
"""
|
||||
|
||||
def __init__(self, worker):
|
||||
self.worker = worker
|
||||
|
||||
def get(self, vals, *, timeout=None):
|
||||
return self.worker.get(vals, timeout=timeout)
|
||||
|
||||
def put(self, *args, **kwargs):
|
||||
return self.worker.put(*args, **kwargs)
|
||||
|
||||
def wait(self, *args, **kwargs):
|
||||
return self.worker.wait(*args, **kwargs)
|
||||
|
||||
def remote(self, *args, **kwargs):
|
||||
return self.worker.remote(*args, **kwargs)
|
||||
|
||||
def call_remote(self, instance: "ClientStub", *args, **kwargs):
|
||||
return self.worker.call_remote(instance, *args, **kwargs)
|
||||
|
||||
def call_release(self, id: bytes) -> None:
|
||||
return self.worker.call_release(id)
|
||||
|
||||
def call_retain(self, id: bytes) -> None:
|
||||
return self.worker.call_retain(id)
|
||||
|
||||
def close(self) -> None:
|
||||
return self.worker.close()
|
||||
|
||||
def get_actor(self, name: str) -> "ClientActorHandle":
|
||||
return self.worker.get_actor(name)
|
||||
|
||||
def kill(self, actor: "ClientActorHandle", *, no_restart=True):
|
||||
return self.worker.terminate_actor(actor, no_restart)
|
||||
|
||||
def cancel(self, obj: "ClientObjectRef", *, force=False, recursive=True):
|
||||
return self.worker.terminate_task(obj, force, recursive)
|
||||
|
||||
# Various metadata methods for the client that are defined in the protocol.
|
||||
def is_initialized(self) -> bool:
|
||||
""" True if our client is connected, and if the server is initialized.
|
||||
|
||||
"""True if our client is connected, and if the server is initialized.
|
||||
Returns:
|
||||
A boolean determining if the client is connected and
|
||||
server initialized.
|
||||
@@ -222,6 +168,8 @@ class ClientAPI(APIImpl):
|
||||
Returns:
|
||||
Information about the Ray clients in the cluster.
|
||||
"""
|
||||
# This should be imported here, otherwise, it will error doc build.
|
||||
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||
return self.worker.get_cluster_info(
|
||||
ray_client_pb2.ClusterInfoType.NODES)
|
||||
|
||||
@@ -235,6 +183,8 @@ class ClientAPI(APIImpl):
|
||||
A dictionary mapping resource name to the total quantity of that
|
||||
resource in the cluster.
|
||||
"""
|
||||
# This should be imported here, otherwise, it will error doc build.
|
||||
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||
return self.worker.get_cluster_info(
|
||||
ray_client_pb2.ClusterInfoType.CLUSTER_RESOURCES)
|
||||
|
||||
@@ -250,6 +200,8 @@ class ClientAPI(APIImpl):
|
||||
A dictionary mapping resource name to the total quantity of that
|
||||
resource in the cluster.
|
||||
"""
|
||||
# This should be imported here, otherwise, it will error doc build.
|
||||
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||
return self.worker.get_cluster_info(
|
||||
ray_client_pb2.ClusterInfoType.AVAILABLE_RESOURCES)
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
"""
|
||||
Implements the client side of the client/server pickling protocol.
|
||||
"""Implements the client side of the client/server pickling protocol.
|
||||
|
||||
All ray client client/server data transfer happens through this pickling
|
||||
protocol. The model is as follows:
|
||||
@@ -41,6 +40,7 @@ from ray.experimental.client.common import ClientRemoteMethod
|
||||
from ray.experimental.client.common import OptionWrapper
|
||||
from ray.experimental.client.common import SelfReferenceSentinel
|
||||
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||
from ray._private.client_mode_hook import disable_client_hook
|
||||
|
||||
if sys.version_info < (3, 8):
|
||||
try:
|
||||
@@ -155,10 +155,11 @@ class ServerUnpickler(pickle.Unpickler):
|
||||
|
||||
|
||||
def dumps_from_client(obj: Any, client_id: str, protocol=None) -> bytes:
|
||||
with io.BytesIO() as file:
|
||||
cp = ClientPickler(client_id, file, protocol=protocol)
|
||||
cp.dump(obj)
|
||||
return file.getvalue()
|
||||
with disable_client_hook():
|
||||
with io.BytesIO() as file:
|
||||
cp = ClientPickler(client_id, file, protocol=protocol)
|
||||
cp.dump(obj)
|
||||
return file.getvalue()
|
||||
|
||||
|
||||
def loads_from_server(data: bytes,
|
||||
|
||||
@@ -2,6 +2,8 @@ import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||
from ray.experimental.client import ray
|
||||
from ray.experimental.client.options import validate_options
|
||||
|
||||
import inspect
|
||||
from ray.util.inspect import is_cython
|
||||
import json
|
||||
import threading
|
||||
from typing import Any
|
||||
@@ -52,8 +54,7 @@ class ClientStub:
|
||||
|
||||
|
||||
class ClientRemoteFunc(ClientStub):
|
||||
"""
|
||||
A stub created on the Ray Client to represent a remote
|
||||
"""A stub created on the Ray Client to represent a remote
|
||||
function that can be exectued on the cluster.
|
||||
|
||||
This class is allowed to be passed around between remote functions.
|
||||
@@ -112,7 +113,7 @@ class ClientRemoteFunc(ClientStub):
|
||||
|
||||
|
||||
class ClientActorClass(ClientStub):
|
||||
""" A stub created on the Ray Client to represent an actor class.
|
||||
"""A stub created on the Ray Client to represent an actor class.
|
||||
|
||||
It is wrapped by ray.remote and can be executed on the cluster.
|
||||
|
||||
@@ -294,3 +295,17 @@ class DataEncodingSentinel:
|
||||
|
||||
class SelfReferenceSentinel(DataEncodingSentinel):
|
||||
pass
|
||||
|
||||
|
||||
def remote_decorator(options: Optional[Dict[str, Any]]):
|
||||
def decorator(function_or_class) -> ClientStub:
|
||||
if (inspect.isfunction(function_or_class)
|
||||
or is_cython(function_or_class)):
|
||||
return ClientRemoteFunc(function_or_class, options=options)
|
||||
elif inspect.isclass(function_or_class):
|
||||
return ClientActorClass(function_or_class, options=options)
|
||||
else:
|
||||
raise TypeError("The @ray.remote decorator must be applied to "
|
||||
"either a function or to a class.")
|
||||
|
||||
return decorator
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
"""
|
||||
This file implements a threaded stream controller to abstract a data stream
|
||||
"""This file implements a threaded stream controller to abstract a data stream
|
||||
back to the ray clientserver.
|
||||
"""
|
||||
import logging
|
||||
|
||||
@@ -0,0 +1,7 @@
|
||||
from ray.experimental.client import ray
|
||||
|
||||
from ray.tune import tune
|
||||
|
||||
ray.connect("localhost:50051")
|
||||
|
||||
tune.run("PG", config={"env": "CartPole-v0"})
|
||||
@@ -1,5 +1,4 @@
|
||||
"""
|
||||
This file implements a threaded stream controller to return logs back from
|
||||
"""This file implements a threaded stream controller to return logs back from
|
||||
the ray clientserver.
|
||||
"""
|
||||
import sys
|
||||
@@ -12,6 +11,10 @@ import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
# TODO(barakmich): Running a logger in a logger causes loopback.
|
||||
# The client logger need its own root -- possibly this one.
|
||||
# For the moment, let's just not propogate beyond this point.
|
||||
logger.propagate = False
|
||||
|
||||
|
||||
class LogstreamClient:
|
||||
@@ -45,8 +48,7 @@ class LogstreamClient:
|
||||
raise e
|
||||
|
||||
def log(self, level: int, msg: str):
|
||||
"""
|
||||
Log the message from the log stream.
|
||||
"""Log the message from the log stream.
|
||||
By default, calls logger.log but this can be overridden.
|
||||
|
||||
Args:
|
||||
@@ -56,8 +58,7 @@ class LogstreamClient:
|
||||
logger.log(level=level, msg=msg)
|
||||
|
||||
def stdstream(self, level: int, msg: str):
|
||||
"""
|
||||
Log the stdout/stderr entry from the log stream.
|
||||
"""Log the stdout/stderr entry from the log stream.
|
||||
By default, calls print but this can be overridden.
|
||||
|
||||
Args:
|
||||
@@ -68,6 +69,7 @@ class LogstreamClient:
|
||||
print(msg, file=print_file)
|
||||
|
||||
def set_logstream_level(self, level: int):
|
||||
logger.setLevel(level)
|
||||
req = ray_client_pb2.LogSettingsRequest()
|
||||
req.enabled = True
|
||||
req.loglevel = level
|
||||
|
||||
@@ -1,16 +1,17 @@
|
||||
from contextlib import contextmanager
|
||||
|
||||
import ray.experimental.client.server.server as ray_client_server
|
||||
from ray.experimental.client import ray, reset_api
|
||||
from ray.experimental.client import ray
|
||||
|
||||
|
||||
@contextmanager
|
||||
def ray_start_client_server():
|
||||
server = ray_client_server.serve("localhost:50051", test_mode=True)
|
||||
ray._inside_client_test = True
|
||||
server = ray_client_server.serve("localhost:50051")
|
||||
ray.connect("localhost:50051")
|
||||
try:
|
||||
yield ray
|
||||
finally:
|
||||
ray._inside_client_test = False
|
||||
ray.disconnect()
|
||||
server.stop(0)
|
||||
reset_api()
|
||||
|
||||
@@ -1,81 +0,0 @@
|
||||
# Along with `api.py` this is the stub that interfaces with
|
||||
# the real (C-binding, raylet) ray core.
|
||||
#
|
||||
# Ideally, the first import line is the only time we actually
|
||||
# import ray in this library (excluding the main function for the server)
|
||||
#
|
||||
# While the stub is trivial, it allows us to check that the calls we're
|
||||
# making into the core-ray module are contained and well-defined.
|
||||
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
from typing import Union
|
||||
|
||||
import logging
|
||||
import ray
|
||||
|
||||
from ray.experimental.client.api import APIImpl
|
||||
from ray.experimental.client.common import ClientObjectRef
|
||||
from ray.experimental.client.common import ClientStub
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CoreRayAPI(APIImpl):
|
||||
"""
|
||||
Implements the equivalent client-side Ray API by simply passing along to
|
||||
the Core Ray API. Primarily used inside of Ray Workers as a trampoline back
|
||||
to core ray when passed client stubs.
|
||||
"""
|
||||
|
||||
def get(self, vals, *, timeout: Optional[float] = None) -> Any:
|
||||
return ray.get(vals, timeout=timeout)
|
||||
|
||||
def put(self, vals: Any, *args,
|
||||
**kwargs) -> Union[ClientObjectRef, ray._raylet.ObjectRef]:
|
||||
return ray.put(vals, *args, **kwargs)
|
||||
|
||||
def wait(self, *args, **kwargs):
|
||||
return ray.wait(*args, **kwargs)
|
||||
|
||||
def remote(self, *args, **kwargs):
|
||||
return ray.remote(*args, **kwargs)
|
||||
|
||||
def call_remote(self, instance: ClientStub, *args, **kwargs):
|
||||
raise NotImplementedError(
|
||||
"Should not attempt execution of a client stub inside the raylet")
|
||||
|
||||
def close(self) -> None:
|
||||
return None
|
||||
|
||||
def kill(self, actor, *, no_restart=True):
|
||||
return ray.kill(actor, no_restart=no_restart)
|
||||
|
||||
def cancel(self, obj, *, force=False, recursive=True):
|
||||
return ray.cancel(obj, force=force, recursive=recursive)
|
||||
|
||||
def is_initialized(self) -> bool:
|
||||
return ray.is_initialized()
|
||||
|
||||
def call_release(self, id: bytes) -> None:
|
||||
return None
|
||||
|
||||
def call_retain(self, id: bytes) -> None:
|
||||
return None
|
||||
|
||||
# Allow for generic fallback to ray.* in remote methods. This allows calls
|
||||
# like ray.nodes() to be run in remote functions even though the client
|
||||
# doesn't currently support them.
|
||||
def __getattr__(self, key: str):
|
||||
return getattr(ray, key)
|
||||
|
||||
|
||||
class RayServerAPI(CoreRayAPI):
|
||||
"""
|
||||
Ray Client server-side API shim. By default, simply calls the default Core
|
||||
Ray API calls, but also accepts scheduling calls from functions running
|
||||
inside of other remote functions that need to create more work.
|
||||
"""
|
||||
|
||||
def __init__(self, server_instance):
|
||||
self.server = server_instance
|
||||
@@ -1,5 +1,4 @@
|
||||
"""
|
||||
This file responds to log stream requests and forwards logs
|
||||
"""This file responds to log stream requests and forwards logs
|
||||
with its handler.
|
||||
"""
|
||||
import io
|
||||
@@ -70,6 +69,9 @@ def log_status_change_thread(log_queue, request_iterator):
|
||||
std_handler.register_global()
|
||||
root_logger.addHandler(current_handler)
|
||||
root_logger.setLevel(req.loglevel)
|
||||
except grpc.RpcError as e:
|
||||
logger.debug(f"closing log thread "
|
||||
f"grpc error reading request_iterator: {e}")
|
||||
finally:
|
||||
if current_handler is not None:
|
||||
root_logger.setLevel(default_level)
|
||||
|
||||
@@ -17,27 +17,25 @@ import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
|
||||
import time
|
||||
import inspect
|
||||
import json
|
||||
from ray.experimental.client import stash_api_for_tests, _set_server_api
|
||||
from ray.experimental.client.server.server_pickler import convert_from_arg
|
||||
from ray.experimental.client.server.server_pickler import dumps_from_server
|
||||
from ray.experimental.client.server.server_pickler import loads_from_client
|
||||
from ray.experimental.client.server.core_ray_api import RayServerAPI
|
||||
from ray.experimental.client.server.dataservicer import DataServicer
|
||||
from ray.experimental.client.server.logservicer import LogstreamServicer
|
||||
from ray.experimental.client.server.server_stubs import current_remote
|
||||
from ray._private.client_mode_hook import disable_client_hook
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
|
||||
def __init__(self, test_mode=False):
|
||||
def __init__(self):
|
||||
self.object_refs: Dict[str, Dict[bytes, ray.ObjectRef]] = defaultdict(
|
||||
dict)
|
||||
self.function_refs = {}
|
||||
self.actor_refs: Dict[bytes, ray.ActorHandle] = {}
|
||||
self.actor_owners: Dict[str, Set[bytes]] = defaultdict(set)
|
||||
self.registered_actor_classes = {}
|
||||
self._test_mode = test_mode
|
||||
self._current_function_stub = None
|
||||
|
||||
def ClusterInfo(self, request,
|
||||
@@ -45,7 +43,8 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
|
||||
resp = ray_client_pb2.ClusterInfoResponse()
|
||||
resp.type = request.type
|
||||
if request.type == ray_client_pb2.ClusterInfoType.CLUSTER_RESOURCES:
|
||||
resources = ray.cluster_resources()
|
||||
with disable_client_hook():
|
||||
resources = ray.cluster_resources()
|
||||
# Normalize resources into floats
|
||||
# (the function may return values that are ints)
|
||||
float_resources = {k: float(v) for k, v in resources.items()}
|
||||
@@ -54,7 +53,8 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
|
||||
table=float_resources))
|
||||
elif request.type == \
|
||||
ray_client_pb2.ClusterInfoType.AVAILABLE_RESOURCES:
|
||||
resources = ray.available_resources()
|
||||
with disable_client_hook():
|
||||
resources = ray.available_resources()
|
||||
# Normalize resources into floats
|
||||
# (the function may return values that are ints)
|
||||
float_resources = {k: float(v) for k, v in resources.items()}
|
||||
@@ -62,7 +62,8 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
|
||||
ray_client_pb2.ClusterInfoResponse.ResourceTable(
|
||||
table=float_resources))
|
||||
else:
|
||||
resp.json = self._return_debug_cluster_info(request, context)
|
||||
with disable_client_hook():
|
||||
resp.json = self._return_debug_cluster_info(request, context)
|
||||
return resp
|
||||
|
||||
def _return_debug_cluster_info(self, request, context=None) -> str:
|
||||
@@ -118,16 +119,18 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
|
||||
try:
|
||||
object_ref = \
|
||||
self.object_refs[req.client_id][req.task_object.id]
|
||||
ray.cancel(
|
||||
object_ref,
|
||||
force=req.task_object.force,
|
||||
recursive=req.task_object.recursive)
|
||||
with disable_client_hook():
|
||||
ray.cancel(
|
||||
object_ref,
|
||||
force=req.task_object.force,
|
||||
recursive=req.task_object.recursive)
|
||||
except Exception as e:
|
||||
return_exception_in_context(e, context)
|
||||
elif req.WhichOneof("terminate_type") == "actor":
|
||||
try:
|
||||
actor_ref = self.actor_refs[req.actor.id]
|
||||
ray.kill(actor_ref, no_restart=req.actor.no_restart)
|
||||
with disable_client_hook():
|
||||
ray.kill(actor_ref, no_restart=req.actor.no_restart)
|
||||
except Exception as e:
|
||||
return_exception_in_context(e, context)
|
||||
else:
|
||||
@@ -145,7 +148,8 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
|
||||
objectref = self.object_refs[client_id][request.id]
|
||||
logger.debug("get: %s" % objectref)
|
||||
try:
|
||||
item = ray.get(objectref, timeout=request.timeout)
|
||||
with disable_client_hook():
|
||||
item = ray.get(objectref, timeout=request.timeout)
|
||||
except Exception as e:
|
||||
return ray_client_pb2.GetResponse(
|
||||
valid=False, error=cloudpickle.dumps(e))
|
||||
@@ -171,7 +175,8 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
|
||||
context: gRPC context.
|
||||
"""
|
||||
obj = loads_from_client(request.data, self)
|
||||
objectref = ray.put(obj)
|
||||
with disable_client_hook():
|
||||
objectref = ray.put(obj)
|
||||
self.object_refs[client_id][objectref.binary()] = objectref
|
||||
logger.debug("put: %s" % objectref)
|
||||
return ray_client_pb2.PutResponse(id=objectref.binary())
|
||||
@@ -187,11 +192,12 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
|
||||
num_returns = request.num_returns
|
||||
timeout = request.timeout
|
||||
try:
|
||||
ready_object_refs, remaining_object_refs = ray.wait(
|
||||
object_refs,
|
||||
num_returns=num_returns,
|
||||
timeout=timeout if timeout != -1 else None,
|
||||
)
|
||||
with disable_client_hook():
|
||||
ready_object_refs, remaining_object_refs = ray.wait(
|
||||
object_refs,
|
||||
num_returns=num_returns,
|
||||
timeout=timeout if timeout != -1 else None,
|
||||
)
|
||||
except Exception as e:
|
||||
# TODO(ameer): improve exception messages.
|
||||
logger.error(f"Exception {e}")
|
||||
@@ -215,8 +221,8 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
|
||||
"schedule: %s %s" % (task.name,
|
||||
ray_client_pb2.ClientTask.RemoteExecType.Name(
|
||||
task.type)))
|
||||
with stash_api_for_tests(self._test_mode):
|
||||
try:
|
||||
try:
|
||||
with disable_client_hook():
|
||||
if task.type == ray_client_pb2.ClientTask.FUNCTION:
|
||||
result = self._schedule_function(task, context)
|
||||
elif task.type == ray_client_pb2.ClientTask.ACTOR:
|
||||
@@ -232,11 +238,11 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
|
||||
task.type))
|
||||
result.valid = True
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Caught schedule exception {e}")
|
||||
raise e
|
||||
return ray_client_pb2.ClientTaskTicket(
|
||||
valid=False, error=cloudpickle.dumps(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Caught schedule exception {e}")
|
||||
raise e
|
||||
return ray_client_pb2.ClientTaskTicket(
|
||||
valid=False, error=cloudpickle.dumps(e))
|
||||
|
||||
def _schedule_method(self, task: ray_client_pb2.ClientTask,
|
||||
context=None) -> ray_client_pb2.ClientTaskTicket:
|
||||
@@ -307,31 +313,33 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
|
||||
def lookup_or_register_func(
|
||||
self, id: bytes, client_id: str,
|
||||
options: Optional[Dict]) -> ray.remote_function.RemoteFunction:
|
||||
if id not in self.function_refs:
|
||||
funcref = self.object_refs[client_id][id]
|
||||
func = ray.get(funcref)
|
||||
if not inspect.isfunction(func):
|
||||
raise Exception("Attempting to register function that "
|
||||
"isn't a function.")
|
||||
if options is None or len(options) == 0:
|
||||
self.function_refs[id] = ray.remote(func)
|
||||
else:
|
||||
self.function_refs[id] = ray.remote(**options)(func)
|
||||
with disable_client_hook():
|
||||
if id not in self.function_refs:
|
||||
funcref = self.object_refs[client_id][id]
|
||||
func = ray.get(funcref)
|
||||
if not inspect.isfunction(func):
|
||||
raise Exception("Attempting to register function that "
|
||||
"isn't a function.")
|
||||
if options is None or len(options) == 0:
|
||||
self.function_refs[id] = ray.remote(func)
|
||||
else:
|
||||
self.function_refs[id] = ray.remote(**options)(func)
|
||||
return self.function_refs[id]
|
||||
|
||||
def lookup_or_register_actor(self, id: bytes, client_id: str,
|
||||
options: Optional[Dict]):
|
||||
if id not in self.registered_actor_classes:
|
||||
actor_class_ref = self.object_refs[client_id][id]
|
||||
actor_class = ray.get(actor_class_ref)
|
||||
if not inspect.isclass(actor_class):
|
||||
raise Exception("Attempting to schedule actor that "
|
||||
"isn't a class.")
|
||||
if options is None or len(options) == 0:
|
||||
reg_class = ray.remote(actor_class)
|
||||
else:
|
||||
reg_class = ray.remote(**options)(actor_class)
|
||||
self.registered_actor_classes[id] = reg_class
|
||||
with disable_client_hook():
|
||||
if id not in self.registered_actor_classes:
|
||||
actor_class_ref = self.object_refs[client_id][id]
|
||||
actor_class = ray.get(actor_class_ref)
|
||||
if not inspect.isclass(actor_class):
|
||||
raise Exception("Attempting to schedule actor that "
|
||||
"isn't a class.")
|
||||
if options is None or len(options) == 0:
|
||||
reg_class = ray.remote(actor_class)
|
||||
else:
|
||||
reg_class = ray.remote(**options)(actor_class)
|
||||
self.registered_actor_classes[id] = reg_class
|
||||
|
||||
return self.registered_actor_classes[id]
|
||||
|
||||
@@ -369,12 +377,22 @@ def decode_options(
|
||||
return opts
|
||||
|
||||
|
||||
def serve(connection_str, test_mode=False):
|
||||
_current_servicer: Optional[RayletServicer] = None
|
||||
|
||||
|
||||
# Used by tests to peek inside the servicer
|
||||
def _get_current_servicer():
|
||||
global _current_servicer
|
||||
return _current_servicer
|
||||
|
||||
|
||||
def serve(connection_str):
|
||||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
|
||||
task_servicer = RayletServicer(test_mode=test_mode)
|
||||
task_servicer = RayletServicer()
|
||||
data_servicer = DataServicer(task_servicer)
|
||||
logs_servicer = LogstreamServicer()
|
||||
_set_server_api(RayServerAPI(task_servicer))
|
||||
global _current_servicer
|
||||
_current_servicer = task_servicer
|
||||
ray_client_pb2_grpc.add_RayletDriverServicer_to_server(
|
||||
task_servicer, server)
|
||||
ray_client_pb2_grpc.add_RayletDataStreamerServicer_to_server(
|
||||
@@ -386,12 +404,20 @@ def serve(connection_str, test_mode=False):
|
||||
return server
|
||||
|
||||
|
||||
def init_and_serve(connection_str, test_mode=False, *args, **kwargs):
|
||||
info = ray.init(*args, **kwargs)
|
||||
server = serve(connection_str, test_mode)
|
||||
def init_and_serve(connection_str, *args, **kwargs):
|
||||
with disable_client_hook():
|
||||
# Disable client mode inside the worker's environment
|
||||
info = ray.init(*args, **kwargs)
|
||||
server = serve(connection_str)
|
||||
return (server, info)
|
||||
|
||||
|
||||
def shutdown_with_server(server, _exiting_interpreter=False):
|
||||
server.stop(1)
|
||||
with disable_client_hook():
|
||||
ray.shutdown(_exiting_interpreter)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level="INFO")
|
||||
# TODO(barakmich): Perhaps wrap ray init
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
"""
|
||||
Implements the client side of the client/server pickling protocol.
|
||||
"""Implements the client side of the client/server pickling protocol.
|
||||
|
||||
These picklers are aware of the server internals and can find the
|
||||
references held for the client within the server.
|
||||
@@ -20,6 +19,7 @@ import ray
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ray._private.client_mode_hook import disable_client_hook
|
||||
from ray.experimental.client.client_pickler import PickleStub
|
||||
from ray.experimental.client.server.server_stubs import (
|
||||
ServerSelfReferenceSentinel)
|
||||
@@ -121,12 +121,13 @@ def loads_from_client(data: bytes,
|
||||
fix_imports=True,
|
||||
encoding="ASCII",
|
||||
errors="strict") -> Any:
|
||||
if isinstance(data, str):
|
||||
raise TypeError("Can't load pickle from unicode string")
|
||||
file = io.BytesIO(data)
|
||||
return ClientUnpickler(
|
||||
server_instance, file, fix_imports=fix_imports,
|
||||
encoding=encoding).load()
|
||||
with disable_client_hook():
|
||||
if isinstance(data, str):
|
||||
raise TypeError("Can't load pickle from unicode string")
|
||||
file = io.BytesIO(data)
|
||||
return ClientUnpickler(
|
||||
server_instance, file, fix_imports=fix_imports,
|
||||
encoding=encoding).load()
|
||||
|
||||
|
||||
def convert_from_arg(pb: "ray_client_pb2.Arg",
|
||||
|
||||
@@ -3,7 +3,6 @@ It implements the Ray API functions that are forwarded through grpc calls
|
||||
to the server.
|
||||
"""
|
||||
import base64
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
@@ -14,7 +13,6 @@ from typing import List
|
||||
from typing import Tuple
|
||||
from typing import Optional
|
||||
|
||||
from ray.util.inspect import is_cython
|
||||
import grpc
|
||||
|
||||
import ray.cloudpickle as cloudpickle
|
||||
@@ -23,12 +21,9 @@ import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
|
||||
from ray.experimental.client.client_pickler import convert_to_arg
|
||||
from ray.experimental.client.client_pickler import dumps_from_client
|
||||
from ray.experimental.client.client_pickler import loads_from_server
|
||||
from ray.experimental.client.common import ClientActorClass
|
||||
from ray.experimental.client.common import ClientActorHandle
|
||||
from ray.experimental.client.common import ClientActorRef
|
||||
from ray.experimental.client.common import ClientObjectRef
|
||||
from ray.experimental.client.common import ClientRemoteFunc
|
||||
from ray.experimental.client.common import ClientStub
|
||||
from ray.experimental.client.dataclient import DataClient
|
||||
from ray.experimental.client.logsclient import LogstreamClient
|
||||
|
||||
@@ -61,6 +56,7 @@ class Worker:
|
||||
|
||||
self.log_client = LogstreamClient(self.channel)
|
||||
self.log_client.set_logstream_level(logging.INFO)
|
||||
self.closed = False
|
||||
|
||||
def get(self, vals, *, timeout: Optional[float] = None) -> Any:
|
||||
to_get = []
|
||||
@@ -153,21 +149,6 @@ class Worker:
|
||||
|
||||
return (client_ready_object_ids, client_remaining_object_ids)
|
||||
|
||||
def remote(self, *args, **kwargs):
|
||||
if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
|
||||
# This is the case where the decorator is just @ray.remote.
|
||||
return remote_decorator(options=None)(args[0])
|
||||
error_string = ("The @ray.remote decorator must be applied either "
|
||||
"with no arguments and no parentheses, for example "
|
||||
"'@ray.remote', or it must be applied using some of "
|
||||
"the arguments 'num_returns', 'num_cpus', 'num_gpus', "
|
||||
"'memory', 'object_store_memory', 'resources', "
|
||||
"'max_calls', or 'max_restarts', like "
|
||||
"'@ray.remote(num_returns=2, "
|
||||
"resources={\"CustomResource\": 1})'.")
|
||||
assert len(args) == 0 and len(kwargs) > 0, error_string
|
||||
return remote_decorator(options=kwargs)
|
||||
|
||||
def call_remote(self, instance, *args, **kwargs) -> List[bytes]:
|
||||
task = instance._prepare_client_task()
|
||||
for arg in args:
|
||||
@@ -190,6 +171,8 @@ class Worker:
|
||||
return ticket.return_ids
|
||||
|
||||
def call_release(self, id: bytes) -> None:
|
||||
if self.closed:
|
||||
return
|
||||
self.reference_count[id] -= 1
|
||||
if self.reference_count[id] == 0:
|
||||
self._release_server(id)
|
||||
@@ -212,6 +195,7 @@ class Worker:
|
||||
self.channel.close()
|
||||
self.channel = None
|
||||
self.server = None
|
||||
self.closed = True
|
||||
|
||||
def get_actor(self, name: str) -> ClientActorHandle:
|
||||
task = ray_client_pb2.ClientTask()
|
||||
@@ -258,7 +242,9 @@ class Worker:
|
||||
req.type = type
|
||||
resp = self.server.ClusterInfo(req)
|
||||
if resp.WhichOneof("response_type") == "resource_table":
|
||||
return resp.resource_table.table
|
||||
# translate from a proto map to a python dict
|
||||
output_dict = {k: v for k, v in resp.resource_table.table.items()}
|
||||
return output_dict
|
||||
return json.loads(resp.json)
|
||||
|
||||
def is_initialized(self) -> bool:
|
||||
@@ -268,20 +254,6 @@ class Worker:
|
||||
return False
|
||||
|
||||
|
||||
def remote_decorator(options: Optional[Dict[str, Any]]):
|
||||
def decorator(function_or_class) -> ClientStub:
|
||||
if (inspect.isfunction(function_or_class)
|
||||
or is_cython(function_or_class)):
|
||||
return ClientRemoteFunc(function_or_class, options=options)
|
||||
elif inspect.isclass(function_or_class):
|
||||
return ClientActorClass(function_or_class, options=options)
|
||||
else:
|
||||
raise TypeError("The @ray.remote decorator must be applied to "
|
||||
"either a function or to a class.")
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def make_client_id() -> str:
|
||||
id = uuid.uuid4()
|
||||
return id.hex
|
||||
|
||||
@@ -9,6 +9,7 @@ import ray
|
||||
from ray import gcs_utils
|
||||
from google.protobuf.json_format import MessageToDict
|
||||
from ray._private import services
|
||||
from ray._private.client_mode_hook import client_mode_hook
|
||||
from ray.utils import (decode, binary_to_hex, hex_to_binary)
|
||||
|
||||
from ray._raylet import GlobalStateAccessor
|
||||
@@ -851,6 +852,7 @@ def jobs():
|
||||
return state.job_table()
|
||||
|
||||
|
||||
@client_mode_hook
|
||||
def nodes():
|
||||
"""Get a list of the nodes in the cluster (for debugging only).
|
||||
|
||||
@@ -964,6 +966,7 @@ def object_transfer_timeline(filename=None):
|
||||
return state.chrome_tracing_object_transfer_dump(filename=filename)
|
||||
|
||||
|
||||
@client_mode_hook
|
||||
def cluster_resources():
|
||||
"""Get the current total cluster resources.
|
||||
|
||||
@@ -977,6 +980,7 @@ def cluster_resources():
|
||||
return state.cluster_resources()
|
||||
|
||||
|
||||
@client_mode_hook
|
||||
def available_resources():
|
||||
"""Get the current available cluster resources.
|
||||
|
||||
|
||||
@@ -446,4 +446,4 @@ def new_scheduler_enabled():
|
||||
|
||||
|
||||
def client_test_enabled() -> bool:
|
||||
return os.environ.get("RAY_TEST_CLIENT_MODE") == "1"
|
||||
return os.environ.get("RAY_CLIENT_MODE") == "1"
|
||||
|
||||
@@ -167,7 +167,7 @@ py_test_module_list(
|
||||
name_suffix = "_client_mode",
|
||||
# TODO(barakmich): py_test will support env in Bazel 4.0.0...
|
||||
# Until then, we can use tags.
|
||||
#env = {"RAY_TEST_CLIENT_MODE": "true"},
|
||||
#env = {"RAY_CLIENT_MODE": "1"},
|
||||
tags = ["exclusive", "client_tests"],
|
||||
deps = ["//:ray_lib"],
|
||||
)
|
||||
|
||||
@@ -9,18 +9,12 @@ import subprocess
|
||||
import ray
|
||||
from ray.cluster_utils import Cluster
|
||||
from ray.test_utils import init_error_pubsub
|
||||
from ray.test_utils import client_test_enabled
|
||||
import ray.experimental.client as ray_client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def shutdown_only():
|
||||
yield None
|
||||
# The code after the yield will run as teardown code.
|
||||
if client_test_enabled():
|
||||
ray_client.ray.disconnect()
|
||||
ray_client._stop_test_server(1)
|
||||
ray_client.reset_api()
|
||||
ray.shutdown()
|
||||
|
||||
|
||||
@@ -49,17 +43,10 @@ def _ray_start(**kwargs):
|
||||
init_kwargs = get_default_fixture_ray_kwargs()
|
||||
init_kwargs.update(kwargs)
|
||||
# Start the Ray processes.
|
||||
if client_test_enabled():
|
||||
address_info = ray_client.ray.init(**init_kwargs)
|
||||
else:
|
||||
address_info = ray.init(**init_kwargs)
|
||||
address_info = ray.init(**init_kwargs)
|
||||
|
||||
yield address_info
|
||||
# The code after the yield will run as teardown code.
|
||||
if client_test_enabled():
|
||||
ray_client.ray.disconnect()
|
||||
ray_client._stop_test_server(1)
|
||||
ray_client.reset_api()
|
||||
ray.shutdown()
|
||||
|
||||
|
||||
@@ -144,16 +131,9 @@ def _ray_start_cluster(**kwargs):
|
||||
# We assume driver will connect to the head (first node),
|
||||
# so ray init will be invoked if do_init is true
|
||||
if len(remote_nodes) == 1 and do_init:
|
||||
if client_test_enabled():
|
||||
ray_client.ray.init(address=cluster.address)
|
||||
else:
|
||||
ray.init(address=cluster.address)
|
||||
ray.init(address=cluster.address)
|
||||
yield cluster
|
||||
# The code after the yield will run as teardown code.
|
||||
if client_test_enabled():
|
||||
ray_client.ray.disconnect()
|
||||
ray_client._stop_test_server(1)
|
||||
ray_client.reset_api()
|
||||
ray.shutdown()
|
||||
cluster.shutdown()
|
||||
|
||||
|
||||
@@ -16,18 +16,12 @@ from ray.test_utils import wait_for_condition
|
||||
from ray.test_utils import wait_for_pid_to_exit
|
||||
from ray.tests.client_test_utils import create_remote_signal_actor
|
||||
|
||||
if client_test_enabled():
|
||||
from ray.experimental.client import ray
|
||||
else:
|
||||
import ray
|
||||
import ray
|
||||
# NOTE: We have to import setproctitle after ray because we bundle setproctitle
|
||||
# with ray.
|
||||
import setproctitle # noqa
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
client_test_enabled(),
|
||||
reason="defining early, no ray package injection yet")
|
||||
def test_caching_actors(shutdown_only):
|
||||
# Test defining actors before ray.init() has been called.
|
||||
|
||||
@@ -705,7 +699,6 @@ def test_options_num_returns(ray_start_regular_shared):
|
||||
assert ray.get([obj1, obj2]) == [1, 2]
|
||||
|
||||
|
||||
@pytest.mark.skipif(client_test_enabled(), reason="remote args")
|
||||
def test_options_name(ray_start_regular_shared):
|
||||
@ray.remote
|
||||
class Foo:
|
||||
|
||||
@@ -354,6 +354,8 @@ def test_illegal_api_calls(ray_start_regular):
|
||||
ray.get(3)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
client_test_enabled(), reason="grpc interaction with releasing resources")
|
||||
def test_multithreading(ray_start_2_cpus):
|
||||
# This test requires at least 2 CPUs to finish since the worker does not
|
||||
# release resources when joining the threads.
|
||||
|
||||
@@ -15,10 +15,7 @@ from ray.test_utils import (
|
||||
wait_for_pid_to_exit,
|
||||
)
|
||||
|
||||
if client_test_enabled():
|
||||
from ray.experimental.client import ray
|
||||
else:
|
||||
import ray
|
||||
import ray
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -160,8 +160,7 @@ def test_basic_actor(ray_start_regular_shared):
|
||||
|
||||
|
||||
def test_pass_handles(ray_start_regular_shared):
|
||||
"""
|
||||
Test that passing client handles to actors and functions to remote actors
|
||||
"""Test that passing client handles to actors and functions to remote actors
|
||||
in functions (on the server or raylet side) works transparently to the
|
||||
caller.
|
||||
"""
|
||||
@@ -264,9 +263,32 @@ def test_stdout_log_stream(ray_start_regular_shared):
|
||||
assert all((msg.find("Hello world") for msg in log_msgs))
|
||||
|
||||
|
||||
def test_basic_named_actor(ray_start_regular_shared):
|
||||
def test_create_remote_before_start(ray_start_regular_shared):
|
||||
"""Creates remote objects (as though in a library) before
|
||||
starting the client.
|
||||
"""
|
||||
Test that ray.get_actor() can create and return a detached actor.
|
||||
from ray.experimental.client import ray
|
||||
|
||||
@ray.remote
|
||||
class Returner:
|
||||
def doit(self):
|
||||
return "foo"
|
||||
|
||||
@ray.remote
|
||||
def f(x):
|
||||
return x + 20
|
||||
|
||||
# Prints in verbose tests
|
||||
print("Created remote functions")
|
||||
|
||||
with ray_start_client_server() as ray:
|
||||
assert ray.get(f.remote(3)) == 23
|
||||
a = Returner.remote()
|
||||
assert ray.get(a.doit.remote()) == "foo"
|
||||
|
||||
|
||||
def test_basic_named_actor(ray_start_regular_shared):
|
||||
"""Test that ray.get_actor() can create and return a detached actor.
|
||||
"""
|
||||
with ray_start_client_server() as ray:
|
||||
|
||||
|
||||
@@ -2,8 +2,7 @@ from ray.experimental.client.ray_client_helpers import ray_start_client_server
|
||||
|
||||
|
||||
def test_get_ray_metadata(ray_start_regular_shared):
|
||||
"""
|
||||
Test the ClusterInfo client data pathway and API surface
|
||||
"""Test the ClusterInfo client data pathway and API surface
|
||||
"""
|
||||
with ray_start_client_server() as ray:
|
||||
ip_address = ray_start_regular_shared["node_ip_address"]
|
||||
|
||||
@@ -2,11 +2,11 @@ from ray.experimental.client.ray_client_helpers import ray_start_client_server
|
||||
from ray.test_utils import wait_for_condition
|
||||
import ray as real_ray
|
||||
from ray.core.generated.gcs_pb2 import ActorTableData
|
||||
from ray.experimental.client import _get_server_instance
|
||||
from ray.experimental.client.server.server import _get_current_servicer
|
||||
|
||||
|
||||
def server_object_ref_count(n):
|
||||
server = _get_server_instance()
|
||||
server = _get_current_servicer()
|
||||
assert server is not None
|
||||
|
||||
def test_cond():
|
||||
@@ -20,7 +20,7 @@ def server_object_ref_count(n):
|
||||
|
||||
|
||||
def server_actor_ref_count(n):
|
||||
server = _get_server_instance()
|
||||
server = _get_current_servicer()
|
||||
assert server is not None
|
||||
|
||||
def test_cond():
|
||||
|
||||
@@ -51,6 +51,7 @@ from ray.ray_logging import setup_logger
|
||||
from ray.ray_logging import global_worker_stdstream_dispatcher
|
||||
from ray.utils import _random_string, check_oversized_pickle
|
||||
from ray.util.inspect import is_cython
|
||||
from ray._private.client_mode_hook import client_mode_hook
|
||||
|
||||
SCRIPT_MODE = 0
|
||||
WORKER_MODE = 1
|
||||
@@ -469,6 +470,7 @@ _global_node = None
|
||||
"""ray.node.Node: The global node object that is created by ray.init()."""
|
||||
|
||||
|
||||
@client_mode_hook
|
||||
def init(
|
||||
address=None,
|
||||
*,
|
||||
@@ -781,6 +783,7 @@ def init(
|
||||
_post_init_hooks = []
|
||||
|
||||
|
||||
@client_mode_hook
|
||||
def shutdown(_exiting_interpreter=False):
|
||||
"""Disconnect the worker, and terminate processes started by ray.init().
|
||||
|
||||
@@ -1044,6 +1047,7 @@ def listen_error_messages_raylet(worker, task_error_queue, threads_stopped):
|
||||
worker.error_message_pubsub_client.close()
|
||||
|
||||
|
||||
@client_mode_hook
|
||||
def is_initialized():
|
||||
"""Check if ray.init has been called yet.
|
||||
|
||||
@@ -1322,6 +1326,7 @@ def show_in_dashboard(message, key="", dtype="text"):
|
||||
blocking_get_inside_async_warned = False
|
||||
|
||||
|
||||
@client_mode_hook
|
||||
def get(object_refs, *, timeout=None):
|
||||
"""Get a remote object or a list of remote objects from the object store.
|
||||
|
||||
@@ -1400,6 +1405,7 @@ def get(object_refs, *, timeout=None):
|
||||
return values
|
||||
|
||||
|
||||
@client_mode_hook
|
||||
def put(value):
|
||||
"""Store an object in the object store.
|
||||
|
||||
@@ -1428,6 +1434,7 @@ def put(value):
|
||||
blocking_wait_inside_async_warned = False
|
||||
|
||||
|
||||
@client_mode_hook
|
||||
def wait(object_refs, *, num_returns=1, timeout=None, fetch_local=True):
|
||||
"""Return a list of IDs that are ready and a list of IDs that are not.
|
||||
|
||||
@@ -1528,6 +1535,7 @@ def wait(object_refs, *, num_returns=1, timeout=None, fetch_local=True):
|
||||
return ready_ids, remaining_ids
|
||||
|
||||
|
||||
@client_mode_hook
|
||||
def get_actor(name):
|
||||
"""Get a handle to a detached actor.
|
||||
|
||||
@@ -1548,6 +1556,7 @@ def get_actor(name):
|
||||
return handle
|
||||
|
||||
|
||||
@client_mode_hook
|
||||
def kill(actor, *, no_restart=True):
|
||||
"""Kill an actor forcefully.
|
||||
|
||||
@@ -1575,6 +1584,7 @@ def kill(actor, *, no_restart=True):
|
||||
worker.core_worker.kill_actor(actor._ray_actor_id, no_restart)
|
||||
|
||||
|
||||
@client_mode_hook
|
||||
def cancel(object_ref, *, force=False, recursive=True):
|
||||
"""Cancels a task according to the following conditions.
|
||||
|
||||
@@ -1691,6 +1701,7 @@ def make_decorator(num_returns=None,
|
||||
return decorator
|
||||
|
||||
|
||||
@client_mode_hook
|
||||
def remote(*args, **kwargs):
|
||||
"""Defines a remote function or an actor class.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user