diff --git a/python/ray/actor.py b/python/ray/actor.py index d3fa34ff8..b8981ca3d 100644 --- a/python/ray/actor.py +++ b/python/ray/actor.py @@ -12,6 +12,11 @@ from ray.util.placement_group import ( from ray import ActorClassID, Language from ray._raylet import PythonFunctionDescriptor from ray import cross_language +from ray.util.inspect import ( + is_function_or_method, + is_class_method, + is_static_method, +) logger = logging.getLogger(__name__) @@ -195,7 +200,7 @@ class ActorClassMethodMetadata(object): self = cls.__new__(cls) actor_methods = inspect.getmembers(modified_class, - ray.utils.is_function_or_method) + is_function_or_method) self.methods = dict(actor_methods) # Extract the signatures of each of the methods. This will be used @@ -208,9 +213,8 @@ class ActorClassMethodMetadata(object): # Whether or not this method requires binding of its first # argument. For class and static methods, we do not want to bind # the first argument, but we do for instance methods - is_bound = (ray.utils.is_class_method(method) - or ray.utils.is_static_method(modified_class, - method_name)) + is_bound = (is_class_method(method) + or is_static_method(modified_class, method_name)) # Print a warning message if the method signature is not # supported. We don't raise an exception because if the actor @@ -956,7 +960,7 @@ def modify_class(cls): Class.__module__ = cls.__module__ Class.__name__ = cls.__name__ - if not ray.utils.is_function_or_method(getattr(Class, "__init__", None)): + if not is_function_or_method(getattr(Class, "__init__", None)): # Add __init__ if it does not exist. # Actor creation will be executed with __init__ together. diff --git a/python/ray/experimental/client/__init__.py b/python/ray/experimental/client/__init__.py index 8d1267d24..36d19ba56 100644 --- a/python/ray/experimental/client/__init__.py +++ b/python/ray/experimental/client/__init__.py @@ -1,6 +1,7 @@ from ray.experimental.client.api import ClientAPI from ray.experimental.client.api import APIImpl from typing import Optional, List, Tuple +from contextlib import contextmanager import logging @@ -14,6 +15,16 @@ logger = logging.getLogger(__name__) _client_api: Optional[APIImpl] = None +@contextmanager +def stash_api_for_tests(in_test: bool): + api = None + if in_test: + api = stash_api() + yield api + if in_test: + restore_api(api) + + def stash_api() -> Optional[APIImpl]: global _client_api a = _client_api diff --git a/python/ray/experimental/client/api.py b/python/ray/experimental/client/api.py index a91111bde..17d0d6a97 100644 --- a/python/ray/experimental/client/api.py +++ b/python/ray/experimental/client/api.py @@ -31,7 +31,7 @@ class APIImpl(ABC): pass @abstractmethod - def call_remote(self, f, *args, **kwargs): + def call_remote(self, f, kind, *args, **kwargs): pass @abstractmethod @@ -55,8 +55,8 @@ class ClientAPI(APIImpl): def remote(self, *args, **kwargs): return self.worker.remote(*args, **kwargs) - def call_remote(self, f, *args, **kwargs): - return self.worker.call_remote(f, *args, **kwargs) + def call_remote(self, f, kind, *args, **kwargs): + return self.worker.call_remote(f, kind, *args, **kwargs) def close(self, *args, **kwargs): return self.worker.close() diff --git a/python/ray/experimental/client/client_app.py b/python/ray/experimental/client/client_app.py index fe30e17ff..41942341c 100644 --- a/python/ray/experimental/client/client_app.py +++ b/python/ray/experimental/client/client_app.py @@ -1,8 +1,30 @@ from ray.experimental.client import ray +from typing import Tuple ray.connect("localhost:50051") +@ray.remote +class HelloActor: + def __init__(self): + self.count = 0 + + def say_hello(self, whom: str) -> Tuple[str, int]: + self.count += 1 + return ("Hello " + whom, self.count) + + +actor = HelloActor.remote() +s, count = ray.get(actor.say_hello.remote("you")) +print(s, count) +assert s == "Hello you" +assert count == 1 +s, count = ray.get(actor.say_hello.remote("world")) +print(s, count) +assert s == "Hello world" +assert count == 2 + + @ray.remote def plus2(x): return x + 2 diff --git a/python/ray/experimental/client/common.py b/python/ray/experimental/client/common.py index d44df6413..d2ec7e041 100644 --- a/python/ray/experimental/client/common.py +++ b/python/ray/experimental/client/common.py @@ -4,17 +4,28 @@ from typing import Any from ray import cloudpickle -class ClientObjectRef: +class ClientBaseRef: def __init__(self, id): self.id = id def __repr__(self): - return "ClientObjectRef(%s)" % self.id.hex() + return "%s(%s)" % ( + type(self).__name__, + self.id.hex(), + ) def __eq__(self, other): return self.id == other.id +class ClientObjectRef(ClientBaseRef): + pass + + +class ClientActorRef(ClientBaseRef): + pass + + class ClientRemoteFunc: def __init__(self, f): self._func = f @@ -27,12 +38,64 @@ class ClientRemoteFunc: "Use {self._name}.remote method instead") def remote(self, *args, **kwargs): - return ray.call_remote(self, *args, **kwargs) + return ray.call_remote(self, ray_client_pb2.ClientTask.FUNCTION, *args, + **kwargs) def __repr__(self): return "ClientRemoteFunc(%s, %s)" % (self._name, self.id) +class ClientActorClass: + def __init__(self, actor_cls): + self.actor_cls = actor_cls + self._name = actor_cls.__name__ + + def __call__(self, *args, **kwargs): + raise TypeError(f"Remote actor cannot be instantiated directly. " + "Use {self._name}.remote() instead") + + def remote(self, *args, **kwargs): + # Actually instantiate the actor + ref = ray.call_remote(self, ray_client_pb2.ClientTask.ACTOR, *args, + **kwargs) + return ClientActorHandle(ref, self) + + def __repr__(self): + return "ClientRemoteActor(%s, %s)" % (self._name, self.id) + + def __getattr__(self, key): + raise NotImplementedError("static methods") + + +class ClientActorHandle: + def __init__(self, actor_id: ClientActorRef, + actor_class: ClientActorClass): + self.actor_id = actor_id + self.actor_class = actor_class + + def __getattr__(self, key): + return ClientRemoteMethod(self, key) + + +class ClientRemoteMethod: + def __init__(self, actor_handle: ClientActorHandle, method_name: str): + self.actor_handle = actor_handle + self.method_name = method_name + self._name = "%s.%s" % (self.actor_handle.actor_class._name, + self.method_name) + + def __call__(self, *args, **kwargs): + raise TypeError(f"Remote method cannot be called directly. " + "Use {self._name}.remote() instead") + + def remote(self, *args, **kwargs): + return ray.call_remote(self, ray_client_pb2.ClientTask.METHOD, *args, + **kwargs) + + def __repr__(self): + return "ClientRemoteMethod(%s, %s)" % (self._name, self.actor_id) + + def convert_from_arg(pb) -> Any: if pb.local == ray_client_pb2.Arg.Locality.REFERENCE: return ClientObjectRef(pb.reference_id) diff --git a/python/ray/experimental/client/server/core_ray_api.py b/python/ray/experimental/client/server/core_ray_api.py index 32d3b8ccf..3ebb36c32 100644 --- a/python/ray/experimental/client/server/core_ray_api.py +++ b/python/ray/experimental/client/server/core_ray_api.py @@ -26,7 +26,7 @@ class CoreRayAPI(APIImpl): def remote(self, *args, **kwargs): return ray.remote(*args, **kwargs) - def call_remote(self, f: ClientRemoteFunc, *args, **kwargs): + def call_remote(self, f: ClientRemoteFunc, kind: int, *args, **kwargs): if f._raylet_remote_func is None: f._raylet_remote_func = ray.remote(f._func) return f._raylet_remote_func.remote(*args, **kwargs) diff --git a/python/ray/experimental/client/server/server.py b/python/ray/experimental/client/server/server.py index b0b68221e..e42ea8db4 100644 --- a/python/ray/experimental/client/server/server.py +++ b/python/ray/experimental/client/server/server.py @@ -6,7 +6,8 @@ import ray import ray.core.generated.ray_client_pb2 as ray_client_pb2 import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc import time -import ray.experimental.client as client_init +import inspect +from ray.experimental.client import stash_api_for_tests from ray.experimental.client.common import convert_from_arg from ray.experimental.client.common import ClientObjectRef from ray.experimental.client.common import ClientRemoteFunc @@ -18,6 +19,8 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer): def __init__(self, test_mode=False): self.object_refs = {} self.function_refs = {} + self.actor_refs = {} + self.registered_actor_classes = {} self._test_mode = test_mode def GetObject(self, request, context=None): @@ -67,25 +70,66 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer): ready_object_ids=ready_object_ids, remaining_object_ids=remaining_object_ids) - def Schedule(self, task, context=None): - logger.info("schedule: %s" % task) + def Schedule(self, task, context=None) -> ray_client_pb2.ClientTaskTicket: + logger.info("schedule: %s %s" % + (task.name, + ray_client_pb2.ClientTask.RemoteExecType.Name(task.type))) + if task.type == ray_client_pb2.ClientTask.FUNCTION: + return self._schedule_function(task, context) + elif task.type == ray_client_pb2.ClientTask.ACTOR: + return self._schedule_actor(task, context) + elif task.type == ray_client_pb2.ClientTask.METHOD: + return self._schedule_method(task, context) + else: + raise NotImplementedError( + "Unimplemented Schedule task type: %s" % + ray_client_pb2.ClientTask.RemoteExecType.Name(task.type)) + + def _schedule_method(self, task: ray_client_pb2.ClientTask, + context=None) -> ray_client_pb2.ClientTaskTicket: + actor_handle = self.actor_refs.get(task.payload_id) + if actor_handle is None: + raise Exception( + "Can't run an actor the server doesn't have a handle for") + arglist = _convert_args(task.args) + with stash_api_for_tests(self._test_mode): + output = getattr(actor_handle, task.name).remote(*arglist) + self.object_refs[output.binary()] = output + return ray_client_pb2.ClientTaskTicket(return_id=output.binary()) + + def _schedule_actor(self, task: ray_client_pb2.ClientTask, + context=None) -> ray_client_pb2.ClientTaskTicket: + with stash_api_for_tests(self._test_mode): + if task.payload_id not in self.registered_actor_classes: + actor_class_ref = self.object_refs[task.payload_id] + actor_class = ray.get(actor_class_ref) + if not inspect.isclass(actor_class): + raise Exception("Attempting to schedule actor that " + "isn't a ClientActorClass.") + reg_class = ray.remote(actor_class) + self.registered_actor_classes[task.payload_id] = reg_class + remote_class = self.registered_actor_classes[task.payload_id] + arglist = _convert_args(task.args) + actor = remote_class.remote(*arglist) + actor_ref = actor._actor_id + self.actor_refs[actor_ref.binary()] = actor + return ray_client_pb2.ClientTaskTicket(return_id=actor_ref.binary()) + + def _schedule_function(self, task: ray_client_pb2.ClientTask, + context=None) -> ray_client_pb2.ClientTaskTicket: if task.payload_id not in self.function_refs: funcref = self.object_refs[task.payload_id] func = ray.get(funcref) if not isinstance(func, ClientRemoteFunc): - raise Exception("Attempting to schedule something that " - "isn't a ClientRemoteFunc") + raise Exception("Attempting to schedule function that " + "isn't a ClientRemoteFunc.") self.function_refs[task.payload_id] = func remote_func = self.function_refs[task.payload_id] arglist = _convert_args(task.args) # Prepare call if we're in a test - api = None - if self._test_mode: - api = client_init.stash_api() - output = remote_func.remote(*arglist) - if self._test_mode: - client_init.restore_api(api) - self.object_refs[output.binary()] = output + with stash_api_for_tests(self._test_mode): + output = remote_func.remote(*arglist) + self.object_refs[output.binary()] = output return ray_client_pb2.ClientTaskTicket(return_id=output.binary()) diff --git a/python/ray/experimental/client/worker.py b/python/ray/experimental/client/worker.py index f63171bdc..8c01bea34 100644 --- a/python/ray/experimental/client/worker.py +++ b/python/ray/experimental/client/worker.py @@ -2,15 +2,21 @@ It implements the Ray API functions that are forwarded through grpc calls to the server. """ -from typing import List, Tuple +import inspect +from typing import List +from typing import Tuple import ray.cloudpickle as cloudpickle +from ray.util.inspect import is_cython import grpc import ray.core.generated.ray_client_pb2 as ray_client_pb2 import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc from ray.experimental.client.common import convert_to_arg from ray.experimental.client.common import ClientObjectRef +from ray.experimental.client.common import ClientActorRef +from ray.experimental.client.common import ClientActorClass +from ray.experimental.client.common import ClientRemoteMethod from ray.experimental.client.common import ClientRemoteFunc @@ -87,7 +93,7 @@ class Worker: *, num_returns: int = 1, timeout: float = None - ) -> (List[ClientObjectRef], List[ClientObjectRef]): + ) -> Tuple[List[ClientObjectRef], List[ClientObjectRef]]: assert isinstance(object_refs, list) for ref in object_refs: assert isinstance(ref, ClientObjectRef) @@ -112,21 +118,62 @@ class Worker: return (client_ready_object_ids, client_remaining_object_ids) - def remote(self, func): - return ClientRemoteFunc(func) + def remote(self, function_or_class, *args, **kwargs): + # TODO(barakmich): Arguments to ray.remote + # get captured here. + if (inspect.isfunction(function_or_class) + or is_cython(function_or_class)): + return ClientRemoteFunc(function_or_class) + elif inspect.isclass(function_or_class): + return ClientActorClass(function_or_class) + else: + raise TypeError("The @ray.remote decorator must be applied to " + "either a function or to a class.") - def call_remote(self, func, *args, **kwargs): - if not isinstance(func, ClientRemoteFunc): - raise TypeError("Client not passing a ClientRemoteFunc stub") - func_ref = self._put(func) + def call_remote(self, instance, kind, *args, **kwargs): + ticket = None + if kind == ray_client_pb2.ClientTask.FUNCTION: + ticket = self._put_and_schedule(instance, kind, *args, **kwargs) + elif kind == ray_client_pb2.ClientTask.ACTOR: + ticket = self._put_and_schedule(instance, kind, *args, **kwargs) + return ClientActorRef(ticket.return_id) + elif kind == ray_client_pb2.ClientTask.METHOD: + ticket = self._call_method(instance, *args, **kwargs) + + if ticket is None: + raise Exception( + "Couldn't call_remote on %s for type %s" % (instance, kind)) + return ClientObjectRef(ticket.return_id) + + def _call_method(self, instance: ClientRemoteMethod, *args, **kwargs): + if not isinstance(instance, ClientRemoteMethod): + raise TypeError("Client not passing a ClientRemoteMethod stub") task = ray_client_pb2.ClientTask() - task.name = func._name - task.payload_id = func_ref.id + task.type = ray_client_pb2.ClientTask.METHOD + task.name = instance.method_name + task.payload_id = instance.actor_handle.actor_id.id for arg in args: pb_arg = convert_to_arg(arg) task.args.append(pb_arg) ticket = self.server.Schedule(task, metadata=self.metadata) - return ClientObjectRef(ticket.return_id) + return ticket + + def _put_and_schedule(self, item, task_type, *args, **kwargs): + if isinstance(item, ClientRemoteFunc): + ref = self._put(item) + elif isinstance(item, ClientActorClass): + ref = self._put(item.actor_cls) + else: + raise TypeError("Client not passing a ClientRemoteFunc stub") + task = ray_client_pb2.ClientTask() + task.type = task_type + task.name = item._name + task.payload_id = ref.id + for arg in args: + pb_arg = convert_to_arg(arg) + task.args.append(pb_arg) + ticket = self.server.Schedule(task, metadata=self.metadata) + return ticket def close(self): self.channel.close() diff --git a/python/ray/function_manager.py b/python/ray/function_manager.py index b4ae0b104..39b983d47 100644 --- a/python/ray/function_manager.py +++ b/python/ray/function_manager.py @@ -19,15 +19,17 @@ from ray import ray_constants from ray import cloudpickle as pickle from ray._raylet import PythonFunctionDescriptor from ray.utils import ( - is_function_or_method, - is_class_method, - is_static_method, check_oversized_pickle, decode, ensure_str, format_error_message, push_error_to_driver, ) +from ray.util.inspect import ( + is_function_or_method, + is_class_method, + is_static_method, +) FunctionExecutionInfo = namedtuple("FunctionExecutionInfo", ["function", "function_name", "max_calls"]) diff --git a/python/ray/signature.py b/python/ray/signature.py index ae92ec61e..2aaf4d023 100644 --- a/python/ray/signature.py +++ b/python/ray/signature.py @@ -2,7 +2,7 @@ import inspect from inspect import Parameter import logging -from ray.utils import is_cython +from ray.util.inspect import is_cython # Logger for this module. It should be configured at the entry point # into the program using Ray. Ray provides a default configuration at diff --git a/python/ray/tests/test_experimental_client.py b/python/ray/tests/test_experimental_client.py index 03e375eee..8fc07590e 100644 --- a/python/ray/tests/test_experimental_client.py +++ b/python/ray/tests/test_experimental_client.py @@ -1,162 +1,173 @@ import pytest +from contextlib import contextmanager + import ray.experimental.client.server.server as ray_client_server from ray.experimental.client import ray from ray.experimental.client.common import ClientObjectRef -def test_real_ray_fallback(ray_start_regular_shared): +@contextmanager +def ray_start_client_server(): server = ray_client_server.serve("localhost:50051", test_mode=True) ray.connect("localhost:50051") - - @ray.remote - def get_nodes_real(): - import ray as real_ray - return real_ray.nodes() - - nodes = ray.get(get_nodes_real.remote()) - assert len(nodes) == 1, nodes - - @ray.remote - def get_nodes(): - return ray.nodes() # Can access the full Ray API in remote methods. - - nodes = ray.get(get_nodes.remote()) - assert len(nodes) == 1, nodes - - with pytest.raises(NotImplementedError): - print(ray.nodes()) - + yield ray ray.disconnect() server.stop(0) +def test_real_ray_fallback(ray_start_regular_shared): + with ray_start_client_server() as ray: + + @ray.remote + def get_nodes_real(): + import ray as real_ray + return real_ray.nodes() + + nodes = ray.get(get_nodes_real.remote()) + assert len(nodes) == 1, nodes + + @ray.remote + def get_nodes(): + # Can access the full Ray API in remote methods. + return ray.nodes() + + nodes = ray.get(get_nodes.remote()) + assert len(nodes) == 1, nodes + + with pytest.raises(NotImplementedError): + print(ray.nodes()) + + def test_nested_function(ray_start_regular_shared): - server = ray_client_server.serve("localhost:50051", test_mode=True) - ray.connect("localhost:50051") + with ray_start_client_server() as ray: - @ray.remote - def g(): @ray.remote - def f(): - return "OK" + def g(): + @ray.remote + def f(): + return "OK" - return ray.get(f.remote()) + return ray.get(f.remote()) - assert ray.get(g.remote()) == "OK" - ray.disconnect() - server.stop(0) + assert ray.get(g.remote()) == "OK" def test_put_get(ray_start_regular_shared): - server = ray_client_server.serve("localhost:50051", test_mode=True) - ray.connect("localhost:50051") + with ray_start_client_server() as ray: + objectref = ray.put("hello world") + print(objectref) - objectref = ray.put("hello world") - print(objectref) - - retval = ray.get(objectref) - assert retval == "hello world" - ray.disconnect() - server.stop(0) + retval = ray.get(objectref) + assert retval == "hello world" def test_wait(ray_start_regular_shared): - server = ray_client_server.serve("localhost:50051", test_mode=True) - ray.connect("localhost:50051") + with ray_start_client_server() as ray: + objectref = ray.put("hello world") + ready, remaining = ray.wait([objectref]) + assert remaining == [] + retval = ray.get(ready[0]) + assert retval == "hello world" - objectref = ray.put("hello world") - ready, remaining = ray.wait([objectref]) - assert remaining == [] - retval = ray.get(ready[0]) - assert retval == "hello world" + objectref2 = ray.put(5) + ready, remaining = ray.wait([objectref, objectref2]) + assert (ready, remaining) == ([objectref], [objectref2]) or \ + (ready, remaining) == ([objectref2], [objectref]) + ready_retval = ray.get(ready[0]) + remaining_retval = ray.get(remaining[0]) + assert (ready_retval, remaining_retval) == ("hello world", 5) \ + or (ready_retval, remaining_retval) == (5, "hello world") - objectref2 = ray.put(5) - ready, remaining = ray.wait([objectref, objectref2]) - assert (ready, remaining) == ([objectref], [objectref2]) or \ - (ready, remaining) == ([objectref2], [objectref]) - ready_retval = ray.get(ready[0]) - remaining_retval = ray.get(remaining[0]) - assert (ready_retval, remaining_retval) == ("hello world", 5) \ - or (ready_retval, remaining_retval) == (5, "hello world") - - with pytest.raises(Exception): - # Reference not in the object store. - ray.wait([ClientObjectRef("blabla")]) - with pytest.raises(AssertionError): - ray.wait("blabla") - with pytest.raises(AssertionError): - ray.wait(ClientObjectRef("blabla")) - with pytest.raises(AssertionError): - ray.wait(["blabla"]) - - ray.disconnect() - server.stop(0) + with pytest.raises(Exception): + # Reference not in the object store. + ray.wait([ClientObjectRef("blabla")]) + with pytest.raises(AssertionError): + ray.wait("blabla") + with pytest.raises(AssertionError): + ray.wait(ClientObjectRef("blabla")) + with pytest.raises(AssertionError): + ray.wait(["blabla"]) def test_remote_functions(ray_start_regular_shared): - server = ray_client_server.serve("localhost:50051", test_mode=True) - ray.connect("localhost:50051") + with ray_start_client_server() as ray: - @ray.remote - def plus2(x): - return x + 2 + @ray.remote + def plus2(x): + return x + 2 - @ray.remote - def fact(x): - print(x, type(fact)) - if x <= 0: - return 1 - # This hits the "nested tasks" issue - # https://github.com/ray-project/ray/issues/3644 - # So we're on the right track! - return ray.get(fact.remote(x - 1)) * x + @ray.remote + def fact(x): + print(x, type(fact)) + if x <= 0: + return 1 + # This hits the "nested tasks" issue + # https://github.com/ray-project/ray/issues/3644 + # So we're on the right track! + return ray.get(fact.remote(x - 1)) * x - ref2 = plus2.remote(234) - # `236` - assert ray.get(ref2) == 236 + ref2 = plus2.remote(234) + # `236` + assert ray.get(ref2) == 236 - ref3 = fact.remote(20) - # `2432902008176640000` - assert ray.get(ref3) == 2_432_902_008_176_640_000 + ref3 = fact.remote(20) + # `2432902008176640000` + assert ray.get(ref3) == 2_432_902_008_176_640_000 - # Reuse the cached ClientRemoteFunc object - ref4 = fact.remote(5) - assert ray.get(ref4) == 120 + # Reuse the cached ClientRemoteFunc object + ref4 = fact.remote(5) + assert ray.get(ref4) == 120 - # Test ray.wait() - ref5 = fact.remote(10) - # should return ref2, ref3, ref4 - res = ray.wait([ref5, ref2, ref3, ref4], num_returns=3) - assert [ref2, ref3, ref4] == res[0] - assert [ref5] == res[1] - assert ray.get(res[0]) == [236, 2_432_902_008_176_640_000, 120] - # should return ref2, ref3, ref4, ref5 - res = ray.wait([ref2, ref3, ref4, ref5], num_returns=4) - assert [ref2, ref3, ref4, ref5] == res[0] - assert [] == res[1] - assert ray.get(res[0]) == [236, 2_432_902_008_176_640_000, 120, 3628800] - - ray.disconnect() - server.stop(0) + # Test ray.wait() + ref5 = fact.remote(10) + # should return ref2, ref3, ref4 + res = ray.wait([ref5, ref2, ref3, ref4], num_returns=3) + assert [ref2, ref3, ref4] == res[0] + assert [ref5] == res[1] + assert ray.get(res[0]) == [236, 2_432_902_008_176_640_000, 120] + # should return ref2, ref3, ref4, ref5 + res = ray.wait([ref2, ref3, ref4, ref5], num_returns=4) + assert [ref2, ref3, ref4, ref5] == res[0] + assert [] == res[1] + all_vals = ray.get(res[0]) + assert all_vals == [236, 2_432_902_008_176_640_000, 120, 3628800] def test_function_calling_function(ray_start_regular_shared): - server = ray_client_server.serve("localhost:50051", test_mode=True) - ray.connect("localhost:50051") + with ray_start_client_server() as ray: - @ray.remote - def g(): - return "OK" + @ray.remote + def g(): + return "OK" - @ray.remote - def f(): - print(f, f._name, g._name, g) - return ray.get(g.remote()) + @ray.remote + def f(): + print(f, f._name, g._name, g) + return ray.get(g.remote()) - print(f, type(f)) - assert ray.get(f.remote()) == "OK" - ray.disconnect() - server.stop(0) + print(f, type(f)) + assert ray.get(f.remote()) == "OK" + + +def test_basic_actor(ray_start_regular_shared): + with ray_start_client_server() as ray: + + @ray.remote + class HelloActor: + def __init__(self): + self.count = 0 + + def say_hello(self, whom): + self.count += 1 + return ("Hello " + whom, self.count) + + actor = HelloActor.remote() + s, count = ray.get(actor.say_hello.remote("you")) + assert s == "Hello you" + assert count == 1 + s, count = ray.get(actor.say_hello.remote("world")) + assert s == "Hello world" + assert count == 2 if __name__ == "__main__": diff --git a/python/ray/util/inspect.py b/python/ray/util/inspect.py new file mode 100644 index 000000000..c2e82f654 --- /dev/null +++ b/python/ray/util/inspect.py @@ -0,0 +1,48 @@ +import inspect + + +def is_cython(obj): + """Check if an object is a Cython function or method""" + + # TODO(suo): We could split these into two functions, one for Cython + # functions and another for Cython methods. + # TODO(suo): There doesn't appear to be a Cython function 'type' we can + # check against via isinstance. Please correct me if I'm wrong. + def check_cython(x): + return type(x).__name__ == "cython_function_or_method" + + # Check if function or method, respectively + return check_cython(obj) or \ + (hasattr(obj, "__func__") and check_cython(obj.__func__)) + + +def is_function_or_method(obj): + """Check if an object is a function or method. + + Args: + obj: The Python object in question. + + Returns: + True if the object is an function or method. + """ + return inspect.isfunction(obj) or inspect.ismethod(obj) or is_cython(obj) + + +def is_class_method(f): + """Returns whether the given method is a class_method.""" + return hasattr(f, "__self__") and f.__self__ is not None + + +def is_static_method(cls, f_name): + """Returns whether the class has a static method with the given name. + + Args: + cls: The Python class (i.e. object of type `type`) to + search for the method in. + f_name: The name of the method to look up in this class + and check whether or not it is static. + """ + for cls in inspect.getmro(cls): + if f_name in cls.__dict__: + return isinstance(cls.__dict__[f_name], staticmethod) + return False diff --git a/python/ray/utils.py b/python/ray/utils.py index fb8a1964c..6659d7eb9 100644 --- a/python/ray/utils.py +++ b/python/ray/utils.py @@ -1,7 +1,6 @@ import binascii import errno import hashlib -import inspect import logging import multiprocessing import numpy as np @@ -129,53 +128,6 @@ def push_error_to_driver_through_redis(redis_client, pubsub_msg.SerializeToString()) -def is_cython(obj): - """Check if an object is a Cython function or method""" - - # TODO(suo): We could split these into two functions, one for Cython - # functions and another for Cython methods. - # TODO(suo): There doesn't appear to be a Cython function 'type' we can - # check against via isinstance. Please correct me if I'm wrong. - def check_cython(x): - return type(x).__name__ == "cython_function_or_method" - - # Check if function or method, respectively - return check_cython(obj) or \ - (hasattr(obj, "__func__") and check_cython(obj.__func__)) - - -def is_function_or_method(obj): - """Check if an object is a function or method. - - Args: - obj: The Python object in question. - - Returns: - True if the object is an function or method. - """ - return inspect.isfunction(obj) or inspect.ismethod(obj) or is_cython(obj) - - -def is_class_method(f): - """Returns whether the given method is a class_method.""" - return hasattr(f, "__self__") and f.__self__ is not None - - -def is_static_method(cls, f_name): - """Returns whether the class has a static method with the given name. - - Args: - cls: The Python class (i.e. object of type `type`) to - search for the method in. - f_name: The name of the method to look up in this class - and check whether or not it is static. - """ - for cls in inspect.getmro(cls): - if f_name in cls.__dict__: - return isinstance(cls.__dict__[f_name], staticmethod) - return False - - def random_string(): """Generate a random string to use as an ID. diff --git a/python/ray/worker.py b/python/ray/worker.py index 9960f811a..d5093e360 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -48,7 +48,8 @@ from ray.exceptions import ( ) from ray.function_manager import FunctionActorManager from ray.ray_logging import setup_logger -from ray.utils import _random_string, check_oversized_pickle, is_cython +from ray.utils import _random_string, check_oversized_pickle +from ray.util.inspect import is_cython SCRIPT_MODE = 0 WORKER_MODE = 1 diff --git a/src/ray/protobuf/ray_client.proto b/src/ray/protobuf/ray_client.proto index c4fd65555..fd8fe5345 100644 --- a/src/ray/protobuf/ray_client.proto +++ b/src/ray/protobuf/ray_client.proto @@ -30,10 +30,16 @@ message Arg { } message ClientTask { - // Optionally Provided Task Name - string name = 1; - bytes payload_id = 2; - repeated Arg args = 3; + enum RemoteExecType { + FUNCTION = 0; + ACTOR = 1; + METHOD = 2; + STATIC_METHOD = 3; + } + RemoteExecType type = 1; + string name = 2; + bytes payload_id = 3; + repeated Arg args = 4; } message ClientTaskTicket {