[ray_client] actors v0 (#12388)

This commit is contained in:
Barak Michener
2020-12-01 13:12:08 -08:00
committed by GitHub
parent 0e892908f7
commit 6412dfaf38
15 changed files with 421 additions and 210 deletions
+9 -5
View File
@@ -12,6 +12,11 @@ from ray.util.placement_group import (
from ray import ActorClassID, Language
from ray._raylet import PythonFunctionDescriptor
from ray import cross_language
from ray.util.inspect import (
is_function_or_method,
is_class_method,
is_static_method,
)
logger = logging.getLogger(__name__)
@@ -195,7 +200,7 @@ class ActorClassMethodMetadata(object):
self = cls.__new__(cls)
actor_methods = inspect.getmembers(modified_class,
ray.utils.is_function_or_method)
is_function_or_method)
self.methods = dict(actor_methods)
# Extract the signatures of each of the methods. This will be used
@@ -208,9 +213,8 @@ class ActorClassMethodMetadata(object):
# Whether or not this method requires binding of its first
# argument. For class and static methods, we do not want to bind
# the first argument, but we do for instance methods
is_bound = (ray.utils.is_class_method(method)
or ray.utils.is_static_method(modified_class,
method_name))
is_bound = (is_class_method(method)
or is_static_method(modified_class, method_name))
# Print a warning message if the method signature is not
# supported. We don't raise an exception because if the actor
@@ -956,7 +960,7 @@ def modify_class(cls):
Class.__module__ = cls.__module__
Class.__name__ = cls.__name__
if not ray.utils.is_function_or_method(getattr(Class, "__init__", None)):
if not is_function_or_method(getattr(Class, "__init__", None)):
# Add __init__ if it does not exist.
# Actor creation will be executed with __init__ together.
@@ -1,6 +1,7 @@
from ray.experimental.client.api import ClientAPI
from ray.experimental.client.api import APIImpl
from typing import Optional, List, Tuple
from contextlib import contextmanager
import logging
@@ -14,6 +15,16 @@ logger = logging.getLogger(__name__)
_client_api: Optional[APIImpl] = None
@contextmanager
def stash_api_for_tests(in_test: bool):
api = None
if in_test:
api = stash_api()
yield api
if in_test:
restore_api(api)
def stash_api() -> Optional[APIImpl]:
global _client_api
a = _client_api
+3 -3
View File
@@ -31,7 +31,7 @@ class APIImpl(ABC):
pass
@abstractmethod
def call_remote(self, f, *args, **kwargs):
def call_remote(self, f, kind, *args, **kwargs):
pass
@abstractmethod
@@ -55,8 +55,8 @@ class ClientAPI(APIImpl):
def remote(self, *args, **kwargs):
return self.worker.remote(*args, **kwargs)
def call_remote(self, f, *args, **kwargs):
return self.worker.call_remote(f, *args, **kwargs)
def call_remote(self, f, kind, *args, **kwargs):
return self.worker.call_remote(f, kind, *args, **kwargs)
def close(self, *args, **kwargs):
return self.worker.close()
@@ -1,8 +1,30 @@
from ray.experimental.client import ray
from typing import Tuple
ray.connect("localhost:50051")
@ray.remote
class HelloActor:
def __init__(self):
self.count = 0
def say_hello(self, whom: str) -> Tuple[str, int]:
self.count += 1
return ("Hello " + whom, self.count)
actor = HelloActor.remote()
s, count = ray.get(actor.say_hello.remote("you"))
print(s, count)
assert s == "Hello you"
assert count == 1
s, count = ray.get(actor.say_hello.remote("world"))
print(s, count)
assert s == "Hello world"
assert count == 2
@ray.remote
def plus2(x):
return x + 2
+66 -3
View File
@@ -4,17 +4,28 @@ from typing import Any
from ray import cloudpickle
class ClientObjectRef:
class ClientBaseRef:
def __init__(self, id):
self.id = id
def __repr__(self):
return "ClientObjectRef(%s)" % self.id.hex()
return "%s(%s)" % (
type(self).__name__,
self.id.hex(),
)
def __eq__(self, other):
return self.id == other.id
class ClientObjectRef(ClientBaseRef):
pass
class ClientActorRef(ClientBaseRef):
pass
class ClientRemoteFunc:
def __init__(self, f):
self._func = f
@@ -27,12 +38,64 @@ class ClientRemoteFunc:
"Use {self._name}.remote method instead")
def remote(self, *args, **kwargs):
return ray.call_remote(self, *args, **kwargs)
return ray.call_remote(self, ray_client_pb2.ClientTask.FUNCTION, *args,
**kwargs)
def __repr__(self):
return "ClientRemoteFunc(%s, %s)" % (self._name, self.id)
class ClientActorClass:
def __init__(self, actor_cls):
self.actor_cls = actor_cls
self._name = actor_cls.__name__
def __call__(self, *args, **kwargs):
raise TypeError(f"Remote actor cannot be instantiated directly. "
"Use {self._name}.remote() instead")
def remote(self, *args, **kwargs):
# Actually instantiate the actor
ref = ray.call_remote(self, ray_client_pb2.ClientTask.ACTOR, *args,
**kwargs)
return ClientActorHandle(ref, self)
def __repr__(self):
return "ClientRemoteActor(%s, %s)" % (self._name, self.id)
def __getattr__(self, key):
raise NotImplementedError("static methods")
class ClientActorHandle:
def __init__(self, actor_id: ClientActorRef,
actor_class: ClientActorClass):
self.actor_id = actor_id
self.actor_class = actor_class
def __getattr__(self, key):
return ClientRemoteMethod(self, key)
class ClientRemoteMethod:
def __init__(self, actor_handle: ClientActorHandle, method_name: str):
self.actor_handle = actor_handle
self.method_name = method_name
self._name = "%s.%s" % (self.actor_handle.actor_class._name,
self.method_name)
def __call__(self, *args, **kwargs):
raise TypeError(f"Remote method cannot be called directly. "
"Use {self._name}.remote() instead")
def remote(self, *args, **kwargs):
return ray.call_remote(self, ray_client_pb2.ClientTask.METHOD, *args,
**kwargs)
def __repr__(self):
return "ClientRemoteMethod(%s, %s)" % (self._name, self.actor_id)
def convert_from_arg(pb) -> Any:
if pb.local == ray_client_pb2.Arg.Locality.REFERENCE:
return ClientObjectRef(pb.reference_id)
@@ -26,7 +26,7 @@ class CoreRayAPI(APIImpl):
def remote(self, *args, **kwargs):
return ray.remote(*args, **kwargs)
def call_remote(self, f: ClientRemoteFunc, *args, **kwargs):
def call_remote(self, f: ClientRemoteFunc, kind: int, *args, **kwargs):
if f._raylet_remote_func is None:
f._raylet_remote_func = ray.remote(f._func)
return f._raylet_remote_func.remote(*args, **kwargs)
+56 -12
View File
@@ -6,7 +6,8 @@ import ray
import ray.core.generated.ray_client_pb2 as ray_client_pb2
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
import time
import ray.experimental.client as client_init
import inspect
from ray.experimental.client import stash_api_for_tests
from ray.experimental.client.common import convert_from_arg
from ray.experimental.client.common import ClientObjectRef
from ray.experimental.client.common import ClientRemoteFunc
@@ -18,6 +19,8 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
def __init__(self, test_mode=False):
self.object_refs = {}
self.function_refs = {}
self.actor_refs = {}
self.registered_actor_classes = {}
self._test_mode = test_mode
def GetObject(self, request, context=None):
@@ -67,25 +70,66 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
ready_object_ids=ready_object_ids,
remaining_object_ids=remaining_object_ids)
def Schedule(self, task, context=None):
logger.info("schedule: %s" % task)
def Schedule(self, task, context=None) -> ray_client_pb2.ClientTaskTicket:
logger.info("schedule: %s %s" %
(task.name,
ray_client_pb2.ClientTask.RemoteExecType.Name(task.type)))
if task.type == ray_client_pb2.ClientTask.FUNCTION:
return self._schedule_function(task, context)
elif task.type == ray_client_pb2.ClientTask.ACTOR:
return self._schedule_actor(task, context)
elif task.type == ray_client_pb2.ClientTask.METHOD:
return self._schedule_method(task, context)
else:
raise NotImplementedError(
"Unimplemented Schedule task type: %s" %
ray_client_pb2.ClientTask.RemoteExecType.Name(task.type))
def _schedule_method(self, task: ray_client_pb2.ClientTask,
context=None) -> ray_client_pb2.ClientTaskTicket:
actor_handle = self.actor_refs.get(task.payload_id)
if actor_handle is None:
raise Exception(
"Can't run an actor the server doesn't have a handle for")
arglist = _convert_args(task.args)
with stash_api_for_tests(self._test_mode):
output = getattr(actor_handle, task.name).remote(*arglist)
self.object_refs[output.binary()] = output
return ray_client_pb2.ClientTaskTicket(return_id=output.binary())
def _schedule_actor(self, task: ray_client_pb2.ClientTask,
context=None) -> ray_client_pb2.ClientTaskTicket:
with stash_api_for_tests(self._test_mode):
if task.payload_id not in self.registered_actor_classes:
actor_class_ref = self.object_refs[task.payload_id]
actor_class = ray.get(actor_class_ref)
if not inspect.isclass(actor_class):
raise Exception("Attempting to schedule actor that "
"isn't a ClientActorClass.")
reg_class = ray.remote(actor_class)
self.registered_actor_classes[task.payload_id] = reg_class
remote_class = self.registered_actor_classes[task.payload_id]
arglist = _convert_args(task.args)
actor = remote_class.remote(*arglist)
actor_ref = actor._actor_id
self.actor_refs[actor_ref.binary()] = actor
return ray_client_pb2.ClientTaskTicket(return_id=actor_ref.binary())
def _schedule_function(self, task: ray_client_pb2.ClientTask,
context=None) -> ray_client_pb2.ClientTaskTicket:
if task.payload_id not in self.function_refs:
funcref = self.object_refs[task.payload_id]
func = ray.get(funcref)
if not isinstance(func, ClientRemoteFunc):
raise Exception("Attempting to schedule something that "
"isn't a ClientRemoteFunc")
raise Exception("Attempting to schedule function that "
"isn't a ClientRemoteFunc.")
self.function_refs[task.payload_id] = func
remote_func = self.function_refs[task.payload_id]
arglist = _convert_args(task.args)
# Prepare call if we're in a test
api = None
if self._test_mode:
api = client_init.stash_api()
output = remote_func.remote(*arglist)
if self._test_mode:
client_init.restore_api(api)
self.object_refs[output.binary()] = output
with stash_api_for_tests(self._test_mode):
output = remote_func.remote(*arglist)
self.object_refs[output.binary()] = output
return ray_client_pb2.ClientTaskTicket(return_id=output.binary())
+58 -11
View File
@@ -2,15 +2,21 @@
It implements the Ray API functions that are forwarded through grpc calls
to the server.
"""
from typing import List, Tuple
import inspect
from typing import List
from typing import Tuple
import ray.cloudpickle as cloudpickle
from ray.util.inspect import is_cython
import grpc
import ray.core.generated.ray_client_pb2 as ray_client_pb2
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
from ray.experimental.client.common import convert_to_arg
from ray.experimental.client.common import ClientObjectRef
from ray.experimental.client.common import ClientActorRef
from ray.experimental.client.common import ClientActorClass
from ray.experimental.client.common import ClientRemoteMethod
from ray.experimental.client.common import ClientRemoteFunc
@@ -87,7 +93,7 @@ class Worker:
*,
num_returns: int = 1,
timeout: float = None
) -> (List[ClientObjectRef], List[ClientObjectRef]):
) -> Tuple[List[ClientObjectRef], List[ClientObjectRef]]:
assert isinstance(object_refs, list)
for ref in object_refs:
assert isinstance(ref, ClientObjectRef)
@@ -112,21 +118,62 @@ class Worker:
return (client_ready_object_ids, client_remaining_object_ids)
def remote(self, func):
return ClientRemoteFunc(func)
def remote(self, function_or_class, *args, **kwargs):
# TODO(barakmich): Arguments to ray.remote
# get captured here.
if (inspect.isfunction(function_or_class)
or is_cython(function_or_class)):
return ClientRemoteFunc(function_or_class)
elif inspect.isclass(function_or_class):
return ClientActorClass(function_or_class)
else:
raise TypeError("The @ray.remote decorator must be applied to "
"either a function or to a class.")
def call_remote(self, func, *args, **kwargs):
if not isinstance(func, ClientRemoteFunc):
raise TypeError("Client not passing a ClientRemoteFunc stub")
func_ref = self._put(func)
def call_remote(self, instance, kind, *args, **kwargs):
ticket = None
if kind == ray_client_pb2.ClientTask.FUNCTION:
ticket = self._put_and_schedule(instance, kind, *args, **kwargs)
elif kind == ray_client_pb2.ClientTask.ACTOR:
ticket = self._put_and_schedule(instance, kind, *args, **kwargs)
return ClientActorRef(ticket.return_id)
elif kind == ray_client_pb2.ClientTask.METHOD:
ticket = self._call_method(instance, *args, **kwargs)
if ticket is None:
raise Exception(
"Couldn't call_remote on %s for type %s" % (instance, kind))
return ClientObjectRef(ticket.return_id)
def _call_method(self, instance: ClientRemoteMethod, *args, **kwargs):
if not isinstance(instance, ClientRemoteMethod):
raise TypeError("Client not passing a ClientRemoteMethod stub")
task = ray_client_pb2.ClientTask()
task.name = func._name
task.payload_id = func_ref.id
task.type = ray_client_pb2.ClientTask.METHOD
task.name = instance.method_name
task.payload_id = instance.actor_handle.actor_id.id
for arg in args:
pb_arg = convert_to_arg(arg)
task.args.append(pb_arg)
ticket = self.server.Schedule(task, metadata=self.metadata)
return ClientObjectRef(ticket.return_id)
return ticket
def _put_and_schedule(self, item, task_type, *args, **kwargs):
if isinstance(item, ClientRemoteFunc):
ref = self._put(item)
elif isinstance(item, ClientActorClass):
ref = self._put(item.actor_cls)
else:
raise TypeError("Client not passing a ClientRemoteFunc stub")
task = ray_client_pb2.ClientTask()
task.type = task_type
task.name = item._name
task.payload_id = ref.id
for arg in args:
pb_arg = convert_to_arg(arg)
task.args.append(pb_arg)
ticket = self.server.Schedule(task, metadata=self.metadata)
return ticket
def close(self):
self.channel.close()
+5 -3
View File
@@ -19,15 +19,17 @@ from ray import ray_constants
from ray import cloudpickle as pickle
from ray._raylet import PythonFunctionDescriptor
from ray.utils import (
is_function_or_method,
is_class_method,
is_static_method,
check_oversized_pickle,
decode,
ensure_str,
format_error_message,
push_error_to_driver,
)
from ray.util.inspect import (
is_function_or_method,
is_class_method,
is_static_method,
)
FunctionExecutionInfo = namedtuple("FunctionExecutionInfo",
["function", "function_name", "max_calls"])
+1 -1
View File
@@ -2,7 +2,7 @@ import inspect
from inspect import Parameter
import logging
from ray.utils import is_cython
from ray.util.inspect import is_cython
# Logger for this module. It should be configured at the entry point
# into the program using Ray. Ray provides a default configuration at
+129 -118
View File
@@ -1,162 +1,173 @@
import pytest
from contextlib import contextmanager
import ray.experimental.client.server.server as ray_client_server
from ray.experimental.client import ray
from ray.experimental.client.common import ClientObjectRef
def test_real_ray_fallback(ray_start_regular_shared):
@contextmanager
def ray_start_client_server():
server = ray_client_server.serve("localhost:50051", test_mode=True)
ray.connect("localhost:50051")
@ray.remote
def get_nodes_real():
import ray as real_ray
return real_ray.nodes()
nodes = ray.get(get_nodes_real.remote())
assert len(nodes) == 1, nodes
@ray.remote
def get_nodes():
return ray.nodes() # Can access the full Ray API in remote methods.
nodes = ray.get(get_nodes.remote())
assert len(nodes) == 1, nodes
with pytest.raises(NotImplementedError):
print(ray.nodes())
yield ray
ray.disconnect()
server.stop(0)
def test_real_ray_fallback(ray_start_regular_shared):
with ray_start_client_server() as ray:
@ray.remote
def get_nodes_real():
import ray as real_ray
return real_ray.nodes()
nodes = ray.get(get_nodes_real.remote())
assert len(nodes) == 1, nodes
@ray.remote
def get_nodes():
# Can access the full Ray API in remote methods.
return ray.nodes()
nodes = ray.get(get_nodes.remote())
assert len(nodes) == 1, nodes
with pytest.raises(NotImplementedError):
print(ray.nodes())
def test_nested_function(ray_start_regular_shared):
server = ray_client_server.serve("localhost:50051", test_mode=True)
ray.connect("localhost:50051")
with ray_start_client_server() as ray:
@ray.remote
def g():
@ray.remote
def f():
return "OK"
def g():
@ray.remote
def f():
return "OK"
return ray.get(f.remote())
return ray.get(f.remote())
assert ray.get(g.remote()) == "OK"
ray.disconnect()
server.stop(0)
assert ray.get(g.remote()) == "OK"
def test_put_get(ray_start_regular_shared):
server = ray_client_server.serve("localhost:50051", test_mode=True)
ray.connect("localhost:50051")
with ray_start_client_server() as ray:
objectref = ray.put("hello world")
print(objectref)
objectref = ray.put("hello world")
print(objectref)
retval = ray.get(objectref)
assert retval == "hello world"
ray.disconnect()
server.stop(0)
retval = ray.get(objectref)
assert retval == "hello world"
def test_wait(ray_start_regular_shared):
server = ray_client_server.serve("localhost:50051", test_mode=True)
ray.connect("localhost:50051")
with ray_start_client_server() as ray:
objectref = ray.put("hello world")
ready, remaining = ray.wait([objectref])
assert remaining == []
retval = ray.get(ready[0])
assert retval == "hello world"
objectref = ray.put("hello world")
ready, remaining = ray.wait([objectref])
assert remaining == []
retval = ray.get(ready[0])
assert retval == "hello world"
objectref2 = ray.put(5)
ready, remaining = ray.wait([objectref, objectref2])
assert (ready, remaining) == ([objectref], [objectref2]) or \
(ready, remaining) == ([objectref2], [objectref])
ready_retval = ray.get(ready[0])
remaining_retval = ray.get(remaining[0])
assert (ready_retval, remaining_retval) == ("hello world", 5) \
or (ready_retval, remaining_retval) == (5, "hello world")
objectref2 = ray.put(5)
ready, remaining = ray.wait([objectref, objectref2])
assert (ready, remaining) == ([objectref], [objectref2]) or \
(ready, remaining) == ([objectref2], [objectref])
ready_retval = ray.get(ready[0])
remaining_retval = ray.get(remaining[0])
assert (ready_retval, remaining_retval) == ("hello world", 5) \
or (ready_retval, remaining_retval) == (5, "hello world")
with pytest.raises(Exception):
# Reference not in the object store.
ray.wait([ClientObjectRef("blabla")])
with pytest.raises(AssertionError):
ray.wait("blabla")
with pytest.raises(AssertionError):
ray.wait(ClientObjectRef("blabla"))
with pytest.raises(AssertionError):
ray.wait(["blabla"])
ray.disconnect()
server.stop(0)
with pytest.raises(Exception):
# Reference not in the object store.
ray.wait([ClientObjectRef("blabla")])
with pytest.raises(AssertionError):
ray.wait("blabla")
with pytest.raises(AssertionError):
ray.wait(ClientObjectRef("blabla"))
with pytest.raises(AssertionError):
ray.wait(["blabla"])
def test_remote_functions(ray_start_regular_shared):
server = ray_client_server.serve("localhost:50051", test_mode=True)
ray.connect("localhost:50051")
with ray_start_client_server() as ray:
@ray.remote
def plus2(x):
return x + 2
@ray.remote
def plus2(x):
return x + 2
@ray.remote
def fact(x):
print(x, type(fact))
if x <= 0:
return 1
# This hits the "nested tasks" issue
# https://github.com/ray-project/ray/issues/3644
# So we're on the right track!
return ray.get(fact.remote(x - 1)) * x
@ray.remote
def fact(x):
print(x, type(fact))
if x <= 0:
return 1
# This hits the "nested tasks" issue
# https://github.com/ray-project/ray/issues/3644
# So we're on the right track!
return ray.get(fact.remote(x - 1)) * x
ref2 = plus2.remote(234)
# `236`
assert ray.get(ref2) == 236
ref2 = plus2.remote(234)
# `236`
assert ray.get(ref2) == 236
ref3 = fact.remote(20)
# `2432902008176640000`
assert ray.get(ref3) == 2_432_902_008_176_640_000
ref3 = fact.remote(20)
# `2432902008176640000`
assert ray.get(ref3) == 2_432_902_008_176_640_000
# Reuse the cached ClientRemoteFunc object
ref4 = fact.remote(5)
assert ray.get(ref4) == 120
# Reuse the cached ClientRemoteFunc object
ref4 = fact.remote(5)
assert ray.get(ref4) == 120
# Test ray.wait()
ref5 = fact.remote(10)
# should return ref2, ref3, ref4
res = ray.wait([ref5, ref2, ref3, ref4], num_returns=3)
assert [ref2, ref3, ref4] == res[0]
assert [ref5] == res[1]
assert ray.get(res[0]) == [236, 2_432_902_008_176_640_000, 120]
# should return ref2, ref3, ref4, ref5
res = ray.wait([ref2, ref3, ref4, ref5], num_returns=4)
assert [ref2, ref3, ref4, ref5] == res[0]
assert [] == res[1]
assert ray.get(res[0]) == [236, 2_432_902_008_176_640_000, 120, 3628800]
ray.disconnect()
server.stop(0)
# Test ray.wait()
ref5 = fact.remote(10)
# should return ref2, ref3, ref4
res = ray.wait([ref5, ref2, ref3, ref4], num_returns=3)
assert [ref2, ref3, ref4] == res[0]
assert [ref5] == res[1]
assert ray.get(res[0]) == [236, 2_432_902_008_176_640_000, 120]
# should return ref2, ref3, ref4, ref5
res = ray.wait([ref2, ref3, ref4, ref5], num_returns=4)
assert [ref2, ref3, ref4, ref5] == res[0]
assert [] == res[1]
all_vals = ray.get(res[0])
assert all_vals == [236, 2_432_902_008_176_640_000, 120, 3628800]
def test_function_calling_function(ray_start_regular_shared):
server = ray_client_server.serve("localhost:50051", test_mode=True)
ray.connect("localhost:50051")
with ray_start_client_server() as ray:
@ray.remote
def g():
return "OK"
@ray.remote
def g():
return "OK"
@ray.remote
def f():
print(f, f._name, g._name, g)
return ray.get(g.remote())
@ray.remote
def f():
print(f, f._name, g._name, g)
return ray.get(g.remote())
print(f, type(f))
assert ray.get(f.remote()) == "OK"
ray.disconnect()
server.stop(0)
print(f, type(f))
assert ray.get(f.remote()) == "OK"
def test_basic_actor(ray_start_regular_shared):
with ray_start_client_server() as ray:
@ray.remote
class HelloActor:
def __init__(self):
self.count = 0
def say_hello(self, whom):
self.count += 1
return ("Hello " + whom, self.count)
actor = HelloActor.remote()
s, count = ray.get(actor.say_hello.remote("you"))
assert s == "Hello you"
assert count == 1
s, count = ray.get(actor.say_hello.remote("world"))
assert s == "Hello world"
assert count == 2
if __name__ == "__main__":
+48
View File
@@ -0,0 +1,48 @@
import inspect
def is_cython(obj):
"""Check if an object is a Cython function or method"""
# TODO(suo): We could split these into two functions, one for Cython
# functions and another for Cython methods.
# TODO(suo): There doesn't appear to be a Cython function 'type' we can
# check against via isinstance. Please correct me if I'm wrong.
def check_cython(x):
return type(x).__name__ == "cython_function_or_method"
# Check if function or method, respectively
return check_cython(obj) or \
(hasattr(obj, "__func__") and check_cython(obj.__func__))
def is_function_or_method(obj):
"""Check if an object is a function or method.
Args:
obj: The Python object in question.
Returns:
True if the object is an function or method.
"""
return inspect.isfunction(obj) or inspect.ismethod(obj) or is_cython(obj)
def is_class_method(f):
"""Returns whether the given method is a class_method."""
return hasattr(f, "__self__") and f.__self__ is not None
def is_static_method(cls, f_name):
"""Returns whether the class has a static method with the given name.
Args:
cls: The Python class (i.e. object of type `type`) to
search for the method in.
f_name: The name of the method to look up in this class
and check whether or not it is static.
"""
for cls in inspect.getmro(cls):
if f_name in cls.__dict__:
return isinstance(cls.__dict__[f_name], staticmethod)
return False
-48
View File
@@ -1,7 +1,6 @@
import binascii
import errno
import hashlib
import inspect
import logging
import multiprocessing
import numpy as np
@@ -129,53 +128,6 @@ def push_error_to_driver_through_redis(redis_client,
pubsub_msg.SerializeToString())
def is_cython(obj):
"""Check if an object is a Cython function or method"""
# TODO(suo): We could split these into two functions, one for Cython
# functions and another for Cython methods.
# TODO(suo): There doesn't appear to be a Cython function 'type' we can
# check against via isinstance. Please correct me if I'm wrong.
def check_cython(x):
return type(x).__name__ == "cython_function_or_method"
# Check if function or method, respectively
return check_cython(obj) or \
(hasattr(obj, "__func__") and check_cython(obj.__func__))
def is_function_or_method(obj):
"""Check if an object is a function or method.
Args:
obj: The Python object in question.
Returns:
True if the object is an function or method.
"""
return inspect.isfunction(obj) or inspect.ismethod(obj) or is_cython(obj)
def is_class_method(f):
"""Returns whether the given method is a class_method."""
return hasattr(f, "__self__") and f.__self__ is not None
def is_static_method(cls, f_name):
"""Returns whether the class has a static method with the given name.
Args:
cls: The Python class (i.e. object of type `type`) to
search for the method in.
f_name: The name of the method to look up in this class
and check whether or not it is static.
"""
for cls in inspect.getmro(cls):
if f_name in cls.__dict__:
return isinstance(cls.__dict__[f_name], staticmethod)
return False
def random_string():
"""Generate a random string to use as an ID.
+2 -1
View File
@@ -48,7 +48,8 @@ from ray.exceptions import (
)
from ray.function_manager import FunctionActorManager
from ray.ray_logging import setup_logger
from ray.utils import _random_string, check_oversized_pickle, is_cython
from ray.utils import _random_string, check_oversized_pickle
from ray.util.inspect import is_cython
SCRIPT_MODE = 0
WORKER_MODE = 1