diff --git a/python/ray/actor.py b/python/ray/actor.py index e6fda03e5..1b2458e8c 100644 --- a/python/ray/actor.py +++ b/python/ray/actor.py @@ -285,15 +285,18 @@ def actor(*args, **kwargs): "use @ray.remote.") -def make_actor(Class, num_cpus, num_gpus): +def make_actor(cls, num_cpus, num_gpus): # Modify the class to have an additional method that will be used for # terminating the worker. - class Class(Class): + class Class(cls): def __ray_terminate__(self): ray.worker.global_worker.local_scheduler_client.disconnect() import os os._exit(0) + Class.__module__ = cls.__module__ + Class.__name__ = cls.__name__ + class_id = random_actor_class_id() # The list exported will have length 0 if the class has not been exported # yet, and length one if it has. This is just implementing a bool, but we diff --git a/test/actor_test.py b/test/actor_test.py index c250652f9..0c522116a 100644 --- a/test/actor_test.py +++ b/test/actor_test.py @@ -234,6 +234,25 @@ class ActorAPI(unittest.TestCase): ray.worker.cleanup() + def testActorClassName(self): + ray.init(num_workers=0) + + @ray.remote + class Foo(object): + def __init__(self): + pass + + Foo.remote() + + r = ray.worker.global_worker.redis_client + actor_keys = r.keys("ActorClass*") + self.assertEqual(len(actor_keys), 1) + actor_class_info = r.hgetall(actor_keys[0]) + self.assertEqual(actor_class_info[b"class_name"], b"Foo") + self.assertEqual(actor_class_info[b"module"], b"__main__") + + ray.worker.cleanup() + class ActorMethods(unittest.TestCase):