diff --git a/python/ray/actor.py b/python/ray/actor.py index 180b903ec..c012b58e5 100644 --- a/python/ray/actor.py +++ b/python/ray/actor.py @@ -17,6 +17,12 @@ from ray.utils import _random_string, is_cython, push_error_to_driver DEFAULT_ACTOR_METHOD_NUM_RETURN_VALS = 1 +def is_classmethod(f): + """Returns whether the given method is a classmethod.""" + + return hasattr(f, "__self__") and f.__self__ is not None + + def compute_actor_handle_id(actor_handle_id, num_forks): """Deterministically compute an actor handle ID. @@ -242,7 +248,10 @@ def make_actor_method_executor(worker, method_name, method, actor_imported): # Execute the assigned method and save a checkpoint if necessary. try: - method_returns = method(actor, *args) + if is_classmethod(method): + method_returns = method(*args) + else: + method_returns = method(actor, *args) except Exception: # Save the checkpoint before allowing the method exception to be # thrown. @@ -500,7 +509,7 @@ class ActorClass(object): # don't support, there may not be much the user can do about it. signature.check_signature_supported(method, warn=True) self._method_signatures[method_name] = signature.extract_signature( - method, ignore_first=True) + method, ignore_first=not is_classmethod(method)) # Set the default number of return values for this method. if hasattr(method, "__ray_num_return_vals__"): diff --git a/test/actor_test.py b/test/actor_test.py index 4509748cf..91b11d2e3 100644 --- a/test/actor_test.py +++ b/test/actor_test.py @@ -420,6 +420,32 @@ class ActorMethods(unittest.TestCase): c2.increase.remote() self.assertEqual(ray.get(c2.value.remote()), 2) + def testActorClassMethods(self): + ray.init() + + class Foo(object): + x = 2 + + @classmethod + def as_remote(cls): + return ray.remote(cls) + + @classmethod + def f(cls): + return cls.x + + @classmethod + def g(cls, y): + return cls.x + y + + def echo(self, value): + return value + + a = Foo.as_remote().remote() + self.assertEqual(ray.get(a.echo.remote(2)), 2) + self.assertEqual(ray.get(a.f.remote()), 2) + self.assertEqual(ray.get(a.g.remote(2)), 4) + def testMultipleActors(self): # Create a bunch of actors and call a bunch of methods on all of them. ray.init(num_workers=0)