mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 15:56:55 +08:00
[ray_client] actors v0 (#12388)
This commit is contained in:
+9
-5
@@ -12,6 +12,11 @@ from ray.util.placement_group import (
|
||||
from ray import ActorClassID, Language
|
||||
from ray._raylet import PythonFunctionDescriptor
|
||||
from ray import cross_language
|
||||
from ray.util.inspect import (
|
||||
is_function_or_method,
|
||||
is_class_method,
|
||||
is_static_method,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -195,7 +200,7 @@ class ActorClassMethodMetadata(object):
|
||||
self = cls.__new__(cls)
|
||||
|
||||
actor_methods = inspect.getmembers(modified_class,
|
||||
ray.utils.is_function_or_method)
|
||||
is_function_or_method)
|
||||
self.methods = dict(actor_methods)
|
||||
|
||||
# Extract the signatures of each of the methods. This will be used
|
||||
@@ -208,9 +213,8 @@ class ActorClassMethodMetadata(object):
|
||||
# Whether or not this method requires binding of its first
|
||||
# argument. For class and static methods, we do not want to bind
|
||||
# the first argument, but we do for instance methods
|
||||
is_bound = (ray.utils.is_class_method(method)
|
||||
or ray.utils.is_static_method(modified_class,
|
||||
method_name))
|
||||
is_bound = (is_class_method(method)
|
||||
or is_static_method(modified_class, method_name))
|
||||
|
||||
# Print a warning message if the method signature is not
|
||||
# supported. We don't raise an exception because if the actor
|
||||
@@ -956,7 +960,7 @@ def modify_class(cls):
|
||||
Class.__module__ = cls.__module__
|
||||
Class.__name__ = cls.__name__
|
||||
|
||||
if not ray.utils.is_function_or_method(getattr(Class, "__init__", None)):
|
||||
if not is_function_or_method(getattr(Class, "__init__", None)):
|
||||
# Add __init__ if it does not exist.
|
||||
# Actor creation will be executed with __init__ together.
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from ray.experimental.client.api import ClientAPI
|
||||
from ray.experimental.client.api import APIImpl
|
||||
from typing import Optional, List, Tuple
|
||||
from contextlib import contextmanager
|
||||
|
||||
import logging
|
||||
|
||||
@@ -14,6 +15,16 @@ logger = logging.getLogger(__name__)
|
||||
_client_api: Optional[APIImpl] = None
|
||||
|
||||
|
||||
@contextmanager
|
||||
def stash_api_for_tests(in_test: bool):
|
||||
api = None
|
||||
if in_test:
|
||||
api = stash_api()
|
||||
yield api
|
||||
if in_test:
|
||||
restore_api(api)
|
||||
|
||||
|
||||
def stash_api() -> Optional[APIImpl]:
|
||||
global _client_api
|
||||
a = _client_api
|
||||
|
||||
@@ -31,7 +31,7 @@ class APIImpl(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def call_remote(self, f, *args, **kwargs):
|
||||
def call_remote(self, f, kind, *args, **kwargs):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@@ -55,8 +55,8 @@ class ClientAPI(APIImpl):
|
||||
def remote(self, *args, **kwargs):
|
||||
return self.worker.remote(*args, **kwargs)
|
||||
|
||||
def call_remote(self, f, *args, **kwargs):
|
||||
return self.worker.call_remote(f, *args, **kwargs)
|
||||
def call_remote(self, f, kind, *args, **kwargs):
|
||||
return self.worker.call_remote(f, kind, *args, **kwargs)
|
||||
|
||||
def close(self, *args, **kwargs):
|
||||
return self.worker.close()
|
||||
|
||||
@@ -1,8 +1,30 @@
|
||||
from ray.experimental.client import ray
|
||||
from typing import Tuple
|
||||
|
||||
ray.connect("localhost:50051")
|
||||
|
||||
|
||||
@ray.remote
|
||||
class HelloActor:
|
||||
def __init__(self):
|
||||
self.count = 0
|
||||
|
||||
def say_hello(self, whom: str) -> Tuple[str, int]:
|
||||
self.count += 1
|
||||
return ("Hello " + whom, self.count)
|
||||
|
||||
|
||||
actor = HelloActor.remote()
|
||||
s, count = ray.get(actor.say_hello.remote("you"))
|
||||
print(s, count)
|
||||
assert s == "Hello you"
|
||||
assert count == 1
|
||||
s, count = ray.get(actor.say_hello.remote("world"))
|
||||
print(s, count)
|
||||
assert s == "Hello world"
|
||||
assert count == 2
|
||||
|
||||
|
||||
@ray.remote
|
||||
def plus2(x):
|
||||
return x + 2
|
||||
|
||||
@@ -4,17 +4,28 @@ from typing import Any
|
||||
from ray import cloudpickle
|
||||
|
||||
|
||||
class ClientObjectRef:
|
||||
class ClientBaseRef:
|
||||
def __init__(self, id):
|
||||
self.id = id
|
||||
|
||||
def __repr__(self):
|
||||
return "ClientObjectRef(%s)" % self.id.hex()
|
||||
return "%s(%s)" % (
|
||||
type(self).__name__,
|
||||
self.id.hex(),
|
||||
)
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.id == other.id
|
||||
|
||||
|
||||
class ClientObjectRef(ClientBaseRef):
|
||||
pass
|
||||
|
||||
|
||||
class ClientActorRef(ClientBaseRef):
|
||||
pass
|
||||
|
||||
|
||||
class ClientRemoteFunc:
|
||||
def __init__(self, f):
|
||||
self._func = f
|
||||
@@ -27,12 +38,64 @@ class ClientRemoteFunc:
|
||||
"Use {self._name}.remote method instead")
|
||||
|
||||
def remote(self, *args, **kwargs):
|
||||
return ray.call_remote(self, *args, **kwargs)
|
||||
return ray.call_remote(self, ray_client_pb2.ClientTask.FUNCTION, *args,
|
||||
**kwargs)
|
||||
|
||||
def __repr__(self):
|
||||
return "ClientRemoteFunc(%s, %s)" % (self._name, self.id)
|
||||
|
||||
|
||||
class ClientActorClass:
|
||||
def __init__(self, actor_cls):
|
||||
self.actor_cls = actor_cls
|
||||
self._name = actor_cls.__name__
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
raise TypeError(f"Remote actor cannot be instantiated directly. "
|
||||
"Use {self._name}.remote() instead")
|
||||
|
||||
def remote(self, *args, **kwargs):
|
||||
# Actually instantiate the actor
|
||||
ref = ray.call_remote(self, ray_client_pb2.ClientTask.ACTOR, *args,
|
||||
**kwargs)
|
||||
return ClientActorHandle(ref, self)
|
||||
|
||||
def __repr__(self):
|
||||
return "ClientRemoteActor(%s, %s)" % (self._name, self.id)
|
||||
|
||||
def __getattr__(self, key):
|
||||
raise NotImplementedError("static methods")
|
||||
|
||||
|
||||
class ClientActorHandle:
|
||||
def __init__(self, actor_id: ClientActorRef,
|
||||
actor_class: ClientActorClass):
|
||||
self.actor_id = actor_id
|
||||
self.actor_class = actor_class
|
||||
|
||||
def __getattr__(self, key):
|
||||
return ClientRemoteMethod(self, key)
|
||||
|
||||
|
||||
class ClientRemoteMethod:
|
||||
def __init__(self, actor_handle: ClientActorHandle, method_name: str):
|
||||
self.actor_handle = actor_handle
|
||||
self.method_name = method_name
|
||||
self._name = "%s.%s" % (self.actor_handle.actor_class._name,
|
||||
self.method_name)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
raise TypeError(f"Remote method cannot be called directly. "
|
||||
"Use {self._name}.remote() instead")
|
||||
|
||||
def remote(self, *args, **kwargs):
|
||||
return ray.call_remote(self, ray_client_pb2.ClientTask.METHOD, *args,
|
||||
**kwargs)
|
||||
|
||||
def __repr__(self):
|
||||
return "ClientRemoteMethod(%s, %s)" % (self._name, self.actor_id)
|
||||
|
||||
|
||||
def convert_from_arg(pb) -> Any:
|
||||
if pb.local == ray_client_pb2.Arg.Locality.REFERENCE:
|
||||
return ClientObjectRef(pb.reference_id)
|
||||
|
||||
@@ -26,7 +26,7 @@ class CoreRayAPI(APIImpl):
|
||||
def remote(self, *args, **kwargs):
|
||||
return ray.remote(*args, **kwargs)
|
||||
|
||||
def call_remote(self, f: ClientRemoteFunc, *args, **kwargs):
|
||||
def call_remote(self, f: ClientRemoteFunc, kind: int, *args, **kwargs):
|
||||
if f._raylet_remote_func is None:
|
||||
f._raylet_remote_func = ray.remote(f._func)
|
||||
return f._raylet_remote_func.remote(*args, **kwargs)
|
||||
|
||||
@@ -6,7 +6,8 @@ import ray
|
||||
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
|
||||
import time
|
||||
import ray.experimental.client as client_init
|
||||
import inspect
|
||||
from ray.experimental.client import stash_api_for_tests
|
||||
from ray.experimental.client.common import convert_from_arg
|
||||
from ray.experimental.client.common import ClientObjectRef
|
||||
from ray.experimental.client.common import ClientRemoteFunc
|
||||
@@ -18,6 +19,8 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
|
||||
def __init__(self, test_mode=False):
|
||||
self.object_refs = {}
|
||||
self.function_refs = {}
|
||||
self.actor_refs = {}
|
||||
self.registered_actor_classes = {}
|
||||
self._test_mode = test_mode
|
||||
|
||||
def GetObject(self, request, context=None):
|
||||
@@ -67,25 +70,66 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
|
||||
ready_object_ids=ready_object_ids,
|
||||
remaining_object_ids=remaining_object_ids)
|
||||
|
||||
def Schedule(self, task, context=None):
|
||||
logger.info("schedule: %s" % task)
|
||||
def Schedule(self, task, context=None) -> ray_client_pb2.ClientTaskTicket:
|
||||
logger.info("schedule: %s %s" %
|
||||
(task.name,
|
||||
ray_client_pb2.ClientTask.RemoteExecType.Name(task.type)))
|
||||
if task.type == ray_client_pb2.ClientTask.FUNCTION:
|
||||
return self._schedule_function(task, context)
|
||||
elif task.type == ray_client_pb2.ClientTask.ACTOR:
|
||||
return self._schedule_actor(task, context)
|
||||
elif task.type == ray_client_pb2.ClientTask.METHOD:
|
||||
return self._schedule_method(task, context)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Unimplemented Schedule task type: %s" %
|
||||
ray_client_pb2.ClientTask.RemoteExecType.Name(task.type))
|
||||
|
||||
def _schedule_method(self, task: ray_client_pb2.ClientTask,
|
||||
context=None) -> ray_client_pb2.ClientTaskTicket:
|
||||
actor_handle = self.actor_refs.get(task.payload_id)
|
||||
if actor_handle is None:
|
||||
raise Exception(
|
||||
"Can't run an actor the server doesn't have a handle for")
|
||||
arglist = _convert_args(task.args)
|
||||
with stash_api_for_tests(self._test_mode):
|
||||
output = getattr(actor_handle, task.name).remote(*arglist)
|
||||
self.object_refs[output.binary()] = output
|
||||
return ray_client_pb2.ClientTaskTicket(return_id=output.binary())
|
||||
|
||||
def _schedule_actor(self, task: ray_client_pb2.ClientTask,
|
||||
context=None) -> ray_client_pb2.ClientTaskTicket:
|
||||
with stash_api_for_tests(self._test_mode):
|
||||
if task.payload_id not in self.registered_actor_classes:
|
||||
actor_class_ref = self.object_refs[task.payload_id]
|
||||
actor_class = ray.get(actor_class_ref)
|
||||
if not inspect.isclass(actor_class):
|
||||
raise Exception("Attempting to schedule actor that "
|
||||
"isn't a ClientActorClass.")
|
||||
reg_class = ray.remote(actor_class)
|
||||
self.registered_actor_classes[task.payload_id] = reg_class
|
||||
remote_class = self.registered_actor_classes[task.payload_id]
|
||||
arglist = _convert_args(task.args)
|
||||
actor = remote_class.remote(*arglist)
|
||||
actor_ref = actor._actor_id
|
||||
self.actor_refs[actor_ref.binary()] = actor
|
||||
return ray_client_pb2.ClientTaskTicket(return_id=actor_ref.binary())
|
||||
|
||||
def _schedule_function(self, task: ray_client_pb2.ClientTask,
|
||||
context=None) -> ray_client_pb2.ClientTaskTicket:
|
||||
if task.payload_id not in self.function_refs:
|
||||
funcref = self.object_refs[task.payload_id]
|
||||
func = ray.get(funcref)
|
||||
if not isinstance(func, ClientRemoteFunc):
|
||||
raise Exception("Attempting to schedule something that "
|
||||
"isn't a ClientRemoteFunc")
|
||||
raise Exception("Attempting to schedule function that "
|
||||
"isn't a ClientRemoteFunc.")
|
||||
self.function_refs[task.payload_id] = func
|
||||
remote_func = self.function_refs[task.payload_id]
|
||||
arglist = _convert_args(task.args)
|
||||
# Prepare call if we're in a test
|
||||
api = None
|
||||
if self._test_mode:
|
||||
api = client_init.stash_api()
|
||||
output = remote_func.remote(*arglist)
|
||||
if self._test_mode:
|
||||
client_init.restore_api(api)
|
||||
self.object_refs[output.binary()] = output
|
||||
with stash_api_for_tests(self._test_mode):
|
||||
output = remote_func.remote(*arglist)
|
||||
self.object_refs[output.binary()] = output
|
||||
return ray_client_pb2.ClientTaskTicket(return_id=output.binary())
|
||||
|
||||
|
||||
|
||||
@@ -2,15 +2,21 @@
|
||||
It implements the Ray API functions that are forwarded through grpc calls
|
||||
to the server.
|
||||
"""
|
||||
from typing import List, Tuple
|
||||
import inspect
|
||||
from typing import List
|
||||
from typing import Tuple
|
||||
|
||||
import ray.cloudpickle as cloudpickle
|
||||
from ray.util.inspect import is_cython
|
||||
import grpc
|
||||
|
||||
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
|
||||
from ray.experimental.client.common import convert_to_arg
|
||||
from ray.experimental.client.common import ClientObjectRef
|
||||
from ray.experimental.client.common import ClientActorRef
|
||||
from ray.experimental.client.common import ClientActorClass
|
||||
from ray.experimental.client.common import ClientRemoteMethod
|
||||
from ray.experimental.client.common import ClientRemoteFunc
|
||||
|
||||
|
||||
@@ -87,7 +93,7 @@ class Worker:
|
||||
*,
|
||||
num_returns: int = 1,
|
||||
timeout: float = None
|
||||
) -> (List[ClientObjectRef], List[ClientObjectRef]):
|
||||
) -> Tuple[List[ClientObjectRef], List[ClientObjectRef]]:
|
||||
assert isinstance(object_refs, list)
|
||||
for ref in object_refs:
|
||||
assert isinstance(ref, ClientObjectRef)
|
||||
@@ -112,21 +118,62 @@ class Worker:
|
||||
|
||||
return (client_ready_object_ids, client_remaining_object_ids)
|
||||
|
||||
def remote(self, func):
|
||||
return ClientRemoteFunc(func)
|
||||
def remote(self, function_or_class, *args, **kwargs):
|
||||
# TODO(barakmich): Arguments to ray.remote
|
||||
# get captured here.
|
||||
if (inspect.isfunction(function_or_class)
|
||||
or is_cython(function_or_class)):
|
||||
return ClientRemoteFunc(function_or_class)
|
||||
elif inspect.isclass(function_or_class):
|
||||
return ClientActorClass(function_or_class)
|
||||
else:
|
||||
raise TypeError("The @ray.remote decorator must be applied to "
|
||||
"either a function or to a class.")
|
||||
|
||||
def call_remote(self, func, *args, **kwargs):
|
||||
if not isinstance(func, ClientRemoteFunc):
|
||||
raise TypeError("Client not passing a ClientRemoteFunc stub")
|
||||
func_ref = self._put(func)
|
||||
def call_remote(self, instance, kind, *args, **kwargs):
|
||||
ticket = None
|
||||
if kind == ray_client_pb2.ClientTask.FUNCTION:
|
||||
ticket = self._put_and_schedule(instance, kind, *args, **kwargs)
|
||||
elif kind == ray_client_pb2.ClientTask.ACTOR:
|
||||
ticket = self._put_and_schedule(instance, kind, *args, **kwargs)
|
||||
return ClientActorRef(ticket.return_id)
|
||||
elif kind == ray_client_pb2.ClientTask.METHOD:
|
||||
ticket = self._call_method(instance, *args, **kwargs)
|
||||
|
||||
if ticket is None:
|
||||
raise Exception(
|
||||
"Couldn't call_remote on %s for type %s" % (instance, kind))
|
||||
return ClientObjectRef(ticket.return_id)
|
||||
|
||||
def _call_method(self, instance: ClientRemoteMethod, *args, **kwargs):
|
||||
if not isinstance(instance, ClientRemoteMethod):
|
||||
raise TypeError("Client not passing a ClientRemoteMethod stub")
|
||||
task = ray_client_pb2.ClientTask()
|
||||
task.name = func._name
|
||||
task.payload_id = func_ref.id
|
||||
task.type = ray_client_pb2.ClientTask.METHOD
|
||||
task.name = instance.method_name
|
||||
task.payload_id = instance.actor_handle.actor_id.id
|
||||
for arg in args:
|
||||
pb_arg = convert_to_arg(arg)
|
||||
task.args.append(pb_arg)
|
||||
ticket = self.server.Schedule(task, metadata=self.metadata)
|
||||
return ClientObjectRef(ticket.return_id)
|
||||
return ticket
|
||||
|
||||
def _put_and_schedule(self, item, task_type, *args, **kwargs):
|
||||
if isinstance(item, ClientRemoteFunc):
|
||||
ref = self._put(item)
|
||||
elif isinstance(item, ClientActorClass):
|
||||
ref = self._put(item.actor_cls)
|
||||
else:
|
||||
raise TypeError("Client not passing a ClientRemoteFunc stub")
|
||||
task = ray_client_pb2.ClientTask()
|
||||
task.type = task_type
|
||||
task.name = item._name
|
||||
task.payload_id = ref.id
|
||||
for arg in args:
|
||||
pb_arg = convert_to_arg(arg)
|
||||
task.args.append(pb_arg)
|
||||
ticket = self.server.Schedule(task, metadata=self.metadata)
|
||||
return ticket
|
||||
|
||||
def close(self):
|
||||
self.channel.close()
|
||||
|
||||
@@ -19,15 +19,17 @@ from ray import ray_constants
|
||||
from ray import cloudpickle as pickle
|
||||
from ray._raylet import PythonFunctionDescriptor
|
||||
from ray.utils import (
|
||||
is_function_or_method,
|
||||
is_class_method,
|
||||
is_static_method,
|
||||
check_oversized_pickle,
|
||||
decode,
|
||||
ensure_str,
|
||||
format_error_message,
|
||||
push_error_to_driver,
|
||||
)
|
||||
from ray.util.inspect import (
|
||||
is_function_or_method,
|
||||
is_class_method,
|
||||
is_static_method,
|
||||
)
|
||||
|
||||
FunctionExecutionInfo = namedtuple("FunctionExecutionInfo",
|
||||
["function", "function_name", "max_calls"])
|
||||
|
||||
@@ -2,7 +2,7 @@ import inspect
|
||||
from inspect import Parameter
|
||||
import logging
|
||||
|
||||
from ray.utils import is_cython
|
||||
from ray.util.inspect import is_cython
|
||||
|
||||
# Logger for this module. It should be configured at the entry point
|
||||
# into the program using Ray. Ray provides a default configuration at
|
||||
|
||||
@@ -1,162 +1,173 @@
|
||||
import pytest
|
||||
from contextlib import contextmanager
|
||||
|
||||
import ray.experimental.client.server.server as ray_client_server
|
||||
from ray.experimental.client import ray
|
||||
from ray.experimental.client.common import ClientObjectRef
|
||||
|
||||
|
||||
def test_real_ray_fallback(ray_start_regular_shared):
|
||||
@contextmanager
|
||||
def ray_start_client_server():
|
||||
server = ray_client_server.serve("localhost:50051", test_mode=True)
|
||||
ray.connect("localhost:50051")
|
||||
|
||||
@ray.remote
|
||||
def get_nodes_real():
|
||||
import ray as real_ray
|
||||
return real_ray.nodes()
|
||||
|
||||
nodes = ray.get(get_nodes_real.remote())
|
||||
assert len(nodes) == 1, nodes
|
||||
|
||||
@ray.remote
|
||||
def get_nodes():
|
||||
return ray.nodes() # Can access the full Ray API in remote methods.
|
||||
|
||||
nodes = ray.get(get_nodes.remote())
|
||||
assert len(nodes) == 1, nodes
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
print(ray.nodes())
|
||||
|
||||
yield ray
|
||||
ray.disconnect()
|
||||
server.stop(0)
|
||||
|
||||
|
||||
def test_real_ray_fallback(ray_start_regular_shared):
|
||||
with ray_start_client_server() as ray:
|
||||
|
||||
@ray.remote
|
||||
def get_nodes_real():
|
||||
import ray as real_ray
|
||||
return real_ray.nodes()
|
||||
|
||||
nodes = ray.get(get_nodes_real.remote())
|
||||
assert len(nodes) == 1, nodes
|
||||
|
||||
@ray.remote
|
||||
def get_nodes():
|
||||
# Can access the full Ray API in remote methods.
|
||||
return ray.nodes()
|
||||
|
||||
nodes = ray.get(get_nodes.remote())
|
||||
assert len(nodes) == 1, nodes
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
print(ray.nodes())
|
||||
|
||||
|
||||
def test_nested_function(ray_start_regular_shared):
|
||||
server = ray_client_server.serve("localhost:50051", test_mode=True)
|
||||
ray.connect("localhost:50051")
|
||||
with ray_start_client_server() as ray:
|
||||
|
||||
@ray.remote
|
||||
def g():
|
||||
@ray.remote
|
||||
def f():
|
||||
return "OK"
|
||||
def g():
|
||||
@ray.remote
|
||||
def f():
|
||||
return "OK"
|
||||
|
||||
return ray.get(f.remote())
|
||||
return ray.get(f.remote())
|
||||
|
||||
assert ray.get(g.remote()) == "OK"
|
||||
ray.disconnect()
|
||||
server.stop(0)
|
||||
assert ray.get(g.remote()) == "OK"
|
||||
|
||||
|
||||
def test_put_get(ray_start_regular_shared):
|
||||
server = ray_client_server.serve("localhost:50051", test_mode=True)
|
||||
ray.connect("localhost:50051")
|
||||
with ray_start_client_server() as ray:
|
||||
objectref = ray.put("hello world")
|
||||
print(objectref)
|
||||
|
||||
objectref = ray.put("hello world")
|
||||
print(objectref)
|
||||
|
||||
retval = ray.get(objectref)
|
||||
assert retval == "hello world"
|
||||
ray.disconnect()
|
||||
server.stop(0)
|
||||
retval = ray.get(objectref)
|
||||
assert retval == "hello world"
|
||||
|
||||
|
||||
def test_wait(ray_start_regular_shared):
|
||||
server = ray_client_server.serve("localhost:50051", test_mode=True)
|
||||
ray.connect("localhost:50051")
|
||||
with ray_start_client_server() as ray:
|
||||
objectref = ray.put("hello world")
|
||||
ready, remaining = ray.wait([objectref])
|
||||
assert remaining == []
|
||||
retval = ray.get(ready[0])
|
||||
assert retval == "hello world"
|
||||
|
||||
objectref = ray.put("hello world")
|
||||
ready, remaining = ray.wait([objectref])
|
||||
assert remaining == []
|
||||
retval = ray.get(ready[0])
|
||||
assert retval == "hello world"
|
||||
objectref2 = ray.put(5)
|
||||
ready, remaining = ray.wait([objectref, objectref2])
|
||||
assert (ready, remaining) == ([objectref], [objectref2]) or \
|
||||
(ready, remaining) == ([objectref2], [objectref])
|
||||
ready_retval = ray.get(ready[0])
|
||||
remaining_retval = ray.get(remaining[0])
|
||||
assert (ready_retval, remaining_retval) == ("hello world", 5) \
|
||||
or (ready_retval, remaining_retval) == (5, "hello world")
|
||||
|
||||
objectref2 = ray.put(5)
|
||||
ready, remaining = ray.wait([objectref, objectref2])
|
||||
assert (ready, remaining) == ([objectref], [objectref2]) or \
|
||||
(ready, remaining) == ([objectref2], [objectref])
|
||||
ready_retval = ray.get(ready[0])
|
||||
remaining_retval = ray.get(remaining[0])
|
||||
assert (ready_retval, remaining_retval) == ("hello world", 5) \
|
||||
or (ready_retval, remaining_retval) == (5, "hello world")
|
||||
|
||||
with pytest.raises(Exception):
|
||||
# Reference not in the object store.
|
||||
ray.wait([ClientObjectRef("blabla")])
|
||||
with pytest.raises(AssertionError):
|
||||
ray.wait("blabla")
|
||||
with pytest.raises(AssertionError):
|
||||
ray.wait(ClientObjectRef("blabla"))
|
||||
with pytest.raises(AssertionError):
|
||||
ray.wait(["blabla"])
|
||||
|
||||
ray.disconnect()
|
||||
server.stop(0)
|
||||
with pytest.raises(Exception):
|
||||
# Reference not in the object store.
|
||||
ray.wait([ClientObjectRef("blabla")])
|
||||
with pytest.raises(AssertionError):
|
||||
ray.wait("blabla")
|
||||
with pytest.raises(AssertionError):
|
||||
ray.wait(ClientObjectRef("blabla"))
|
||||
with pytest.raises(AssertionError):
|
||||
ray.wait(["blabla"])
|
||||
|
||||
|
||||
def test_remote_functions(ray_start_regular_shared):
|
||||
server = ray_client_server.serve("localhost:50051", test_mode=True)
|
||||
ray.connect("localhost:50051")
|
||||
with ray_start_client_server() as ray:
|
||||
|
||||
@ray.remote
|
||||
def plus2(x):
|
||||
return x + 2
|
||||
@ray.remote
|
||||
def plus2(x):
|
||||
return x + 2
|
||||
|
||||
@ray.remote
|
||||
def fact(x):
|
||||
print(x, type(fact))
|
||||
if x <= 0:
|
||||
return 1
|
||||
# This hits the "nested tasks" issue
|
||||
# https://github.com/ray-project/ray/issues/3644
|
||||
# So we're on the right track!
|
||||
return ray.get(fact.remote(x - 1)) * x
|
||||
@ray.remote
|
||||
def fact(x):
|
||||
print(x, type(fact))
|
||||
if x <= 0:
|
||||
return 1
|
||||
# This hits the "nested tasks" issue
|
||||
# https://github.com/ray-project/ray/issues/3644
|
||||
# So we're on the right track!
|
||||
return ray.get(fact.remote(x - 1)) * x
|
||||
|
||||
ref2 = plus2.remote(234)
|
||||
# `236`
|
||||
assert ray.get(ref2) == 236
|
||||
ref2 = plus2.remote(234)
|
||||
# `236`
|
||||
assert ray.get(ref2) == 236
|
||||
|
||||
ref3 = fact.remote(20)
|
||||
# `2432902008176640000`
|
||||
assert ray.get(ref3) == 2_432_902_008_176_640_000
|
||||
ref3 = fact.remote(20)
|
||||
# `2432902008176640000`
|
||||
assert ray.get(ref3) == 2_432_902_008_176_640_000
|
||||
|
||||
# Reuse the cached ClientRemoteFunc object
|
||||
ref4 = fact.remote(5)
|
||||
assert ray.get(ref4) == 120
|
||||
# Reuse the cached ClientRemoteFunc object
|
||||
ref4 = fact.remote(5)
|
||||
assert ray.get(ref4) == 120
|
||||
|
||||
# Test ray.wait()
|
||||
ref5 = fact.remote(10)
|
||||
# should return ref2, ref3, ref4
|
||||
res = ray.wait([ref5, ref2, ref3, ref4], num_returns=3)
|
||||
assert [ref2, ref3, ref4] == res[0]
|
||||
assert [ref5] == res[1]
|
||||
assert ray.get(res[0]) == [236, 2_432_902_008_176_640_000, 120]
|
||||
# should return ref2, ref3, ref4, ref5
|
||||
res = ray.wait([ref2, ref3, ref4, ref5], num_returns=4)
|
||||
assert [ref2, ref3, ref4, ref5] == res[0]
|
||||
assert [] == res[1]
|
||||
assert ray.get(res[0]) == [236, 2_432_902_008_176_640_000, 120, 3628800]
|
||||
|
||||
ray.disconnect()
|
||||
server.stop(0)
|
||||
# Test ray.wait()
|
||||
ref5 = fact.remote(10)
|
||||
# should return ref2, ref3, ref4
|
||||
res = ray.wait([ref5, ref2, ref3, ref4], num_returns=3)
|
||||
assert [ref2, ref3, ref4] == res[0]
|
||||
assert [ref5] == res[1]
|
||||
assert ray.get(res[0]) == [236, 2_432_902_008_176_640_000, 120]
|
||||
# should return ref2, ref3, ref4, ref5
|
||||
res = ray.wait([ref2, ref3, ref4, ref5], num_returns=4)
|
||||
assert [ref2, ref3, ref4, ref5] == res[0]
|
||||
assert [] == res[1]
|
||||
all_vals = ray.get(res[0])
|
||||
assert all_vals == [236, 2_432_902_008_176_640_000, 120, 3628800]
|
||||
|
||||
|
||||
def test_function_calling_function(ray_start_regular_shared):
|
||||
server = ray_client_server.serve("localhost:50051", test_mode=True)
|
||||
ray.connect("localhost:50051")
|
||||
with ray_start_client_server() as ray:
|
||||
|
||||
@ray.remote
|
||||
def g():
|
||||
return "OK"
|
||||
@ray.remote
|
||||
def g():
|
||||
return "OK"
|
||||
|
||||
@ray.remote
|
||||
def f():
|
||||
print(f, f._name, g._name, g)
|
||||
return ray.get(g.remote())
|
||||
@ray.remote
|
||||
def f():
|
||||
print(f, f._name, g._name, g)
|
||||
return ray.get(g.remote())
|
||||
|
||||
print(f, type(f))
|
||||
assert ray.get(f.remote()) == "OK"
|
||||
ray.disconnect()
|
||||
server.stop(0)
|
||||
print(f, type(f))
|
||||
assert ray.get(f.remote()) == "OK"
|
||||
|
||||
|
||||
def test_basic_actor(ray_start_regular_shared):
|
||||
with ray_start_client_server() as ray:
|
||||
|
||||
@ray.remote
|
||||
class HelloActor:
|
||||
def __init__(self):
|
||||
self.count = 0
|
||||
|
||||
def say_hello(self, whom):
|
||||
self.count += 1
|
||||
return ("Hello " + whom, self.count)
|
||||
|
||||
actor = HelloActor.remote()
|
||||
s, count = ray.get(actor.say_hello.remote("you"))
|
||||
assert s == "Hello you"
|
||||
assert count == 1
|
||||
s, count = ray.get(actor.say_hello.remote("world"))
|
||||
assert s == "Hello world"
|
||||
assert count == 2
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -0,0 +1,48 @@
|
||||
import inspect
|
||||
|
||||
|
||||
def is_cython(obj):
|
||||
"""Check if an object is a Cython function or method"""
|
||||
|
||||
# TODO(suo): We could split these into two functions, one for Cython
|
||||
# functions and another for Cython methods.
|
||||
# TODO(suo): There doesn't appear to be a Cython function 'type' we can
|
||||
# check against via isinstance. Please correct me if I'm wrong.
|
||||
def check_cython(x):
|
||||
return type(x).__name__ == "cython_function_or_method"
|
||||
|
||||
# Check if function or method, respectively
|
||||
return check_cython(obj) or \
|
||||
(hasattr(obj, "__func__") and check_cython(obj.__func__))
|
||||
|
||||
|
||||
def is_function_or_method(obj):
|
||||
"""Check if an object is a function or method.
|
||||
|
||||
Args:
|
||||
obj: The Python object in question.
|
||||
|
||||
Returns:
|
||||
True if the object is an function or method.
|
||||
"""
|
||||
return inspect.isfunction(obj) or inspect.ismethod(obj) or is_cython(obj)
|
||||
|
||||
|
||||
def is_class_method(f):
|
||||
"""Returns whether the given method is a class_method."""
|
||||
return hasattr(f, "__self__") and f.__self__ is not None
|
||||
|
||||
|
||||
def is_static_method(cls, f_name):
|
||||
"""Returns whether the class has a static method with the given name.
|
||||
|
||||
Args:
|
||||
cls: The Python class (i.e. object of type `type`) to
|
||||
search for the method in.
|
||||
f_name: The name of the method to look up in this class
|
||||
and check whether or not it is static.
|
||||
"""
|
||||
for cls in inspect.getmro(cls):
|
||||
if f_name in cls.__dict__:
|
||||
return isinstance(cls.__dict__[f_name], staticmethod)
|
||||
return False
|
||||
@@ -1,7 +1,6 @@
|
||||
import binascii
|
||||
import errno
|
||||
import hashlib
|
||||
import inspect
|
||||
import logging
|
||||
import multiprocessing
|
||||
import numpy as np
|
||||
@@ -129,53 +128,6 @@ def push_error_to_driver_through_redis(redis_client,
|
||||
pubsub_msg.SerializeToString())
|
||||
|
||||
|
||||
def is_cython(obj):
|
||||
"""Check if an object is a Cython function or method"""
|
||||
|
||||
# TODO(suo): We could split these into two functions, one for Cython
|
||||
# functions and another for Cython methods.
|
||||
# TODO(suo): There doesn't appear to be a Cython function 'type' we can
|
||||
# check against via isinstance. Please correct me if I'm wrong.
|
||||
def check_cython(x):
|
||||
return type(x).__name__ == "cython_function_or_method"
|
||||
|
||||
# Check if function or method, respectively
|
||||
return check_cython(obj) or \
|
||||
(hasattr(obj, "__func__") and check_cython(obj.__func__))
|
||||
|
||||
|
||||
def is_function_or_method(obj):
|
||||
"""Check if an object is a function or method.
|
||||
|
||||
Args:
|
||||
obj: The Python object in question.
|
||||
|
||||
Returns:
|
||||
True if the object is an function or method.
|
||||
"""
|
||||
return inspect.isfunction(obj) or inspect.ismethod(obj) or is_cython(obj)
|
||||
|
||||
|
||||
def is_class_method(f):
|
||||
"""Returns whether the given method is a class_method."""
|
||||
return hasattr(f, "__self__") and f.__self__ is not None
|
||||
|
||||
|
||||
def is_static_method(cls, f_name):
|
||||
"""Returns whether the class has a static method with the given name.
|
||||
|
||||
Args:
|
||||
cls: The Python class (i.e. object of type `type`) to
|
||||
search for the method in.
|
||||
f_name: The name of the method to look up in this class
|
||||
and check whether or not it is static.
|
||||
"""
|
||||
for cls in inspect.getmro(cls):
|
||||
if f_name in cls.__dict__:
|
||||
return isinstance(cls.__dict__[f_name], staticmethod)
|
||||
return False
|
||||
|
||||
|
||||
def random_string():
|
||||
"""Generate a random string to use as an ID.
|
||||
|
||||
|
||||
@@ -48,7 +48,8 @@ from ray.exceptions import (
|
||||
)
|
||||
from ray.function_manager import FunctionActorManager
|
||||
from ray.ray_logging import setup_logger
|
||||
from ray.utils import _random_string, check_oversized_pickle, is_cython
|
||||
from ray.utils import _random_string, check_oversized_pickle
|
||||
from ray.util.inspect import is_cython
|
||||
|
||||
SCRIPT_MODE = 0
|
||||
WORKER_MODE = 1
|
||||
|
||||
Reference in New Issue
Block a user