Fix support for actor classmethods (#2146)

This commit is contained in:
Eric Liang
2018-05-28 17:43:23 -07:00
committed by Robert Nishihara
parent eb1d7ac4bc
commit bc2a83e698
2 changed files with 37 additions and 2 deletions
+11 -2
View File
@@ -17,6 +17,12 @@ from ray.utils import _random_string, is_cython, push_error_to_driver
DEFAULT_ACTOR_METHOD_NUM_RETURN_VALS = 1
def is_classmethod(f):
"""Returns whether the given method is a classmethod."""
return hasattr(f, "__self__") and f.__self__ is not None
def compute_actor_handle_id(actor_handle_id, num_forks):
"""Deterministically compute an actor handle ID.
@@ -242,7 +248,10 @@ def make_actor_method_executor(worker, method_name, method, actor_imported):
# Execute the assigned method and save a checkpoint if necessary.
try:
method_returns = method(actor, *args)
if is_classmethod(method):
method_returns = method(*args)
else:
method_returns = method(actor, *args)
except Exception:
# Save the checkpoint before allowing the method exception to be
# thrown.
@@ -500,7 +509,7 @@ class ActorClass(object):
# don't support, there may not be much the user can do about it.
signature.check_signature_supported(method, warn=True)
self._method_signatures[method_name] = signature.extract_signature(
method, ignore_first=True)
method, ignore_first=not is_classmethod(method))
# Set the default number of return values for this method.
if hasattr(method, "__ray_num_return_vals__"):
+26
View File
@@ -420,6 +420,32 @@ class ActorMethods(unittest.TestCase):
c2.increase.remote()
self.assertEqual(ray.get(c2.value.remote()), 2)
def testActorClassMethods(self):
ray.init()
class Foo(object):
x = 2
@classmethod
def as_remote(cls):
return ray.remote(cls)
@classmethod
def f(cls):
return cls.x
@classmethod
def g(cls, y):
return cls.x + y
def echo(self, value):
return value
a = Foo.as_remote().remote()
self.assertEqual(ray.get(a.echo.remote(2)), 2)
self.assertEqual(ray.get(a.f.remote()), 2)
self.assertEqual(ray.get(a.g.remote(2)), 4)
def testMultipleActors(self):
# Create a bunch of actors and call a bunch of methods on all of them.
ray.init(num_workers=0)