diff --git a/python/ray/function_manager.py b/python/ray/function_manager.py index 3be8fb5b4..814fe4cd9 100644 --- a/python/ray/function_manager.py +++ b/python/ray/function_manager.py @@ -664,7 +664,11 @@ class FunctionActorManager(object): function_descriptor.class_name) try: module = importlib.import_module(module_name) - return getattr(module, class_name)._modified_class + actor_class = getattr(module, class_name) + if isinstance(actor_class, ray.actor.ActorClass): + return actor_class._modified_class + else: + return actor_class except Exception: logger.exception( "Failed to load actor_class %s.".format(class_name)) diff --git a/python/ray/tests/test_basic.py b/python/ray/tests/test_basic.py index 9791643ac..2ce07b305 100644 --- a/python/ray/tests/test_basic.py +++ b/python/ray/tests/test_basic.py @@ -2822,6 +2822,9 @@ class BaseClass(object): def __init__(self, data): self.data = data + def get_data(self): + return self.data + @ray.remote class DerivedClass(BaseClass): @@ -2830,14 +2833,12 @@ class DerivedClass(BaseClass): # we use BaseClass directly here. BaseClass.__init__(self, data) - def get_data(self): - return self.data - def test_load_code_from_local(shutdown_only): ray.init(load_code_from_local=True, num_cpus=4) + message = "foo" # Test normal function. - assert ray.get(echo.remote("foo")) == "foo" + assert ray.get(echo.remote(message)) == message # Test actor class with constructor. actor = WithConstructor.remote(1) assert ray.get(actor.get_data.remote()) == 1 @@ -2848,3 +2849,7 @@ def test_load_code_from_local(shutdown_only): # Test derived actor class. actor = DerivedClass.remote(1) assert ray.get(actor.get_data.remote()) == 1 + # Test using ray.remote decorator on raw classes. + base_actor_class = ray.remote(num_cpus=1)(BaseClass) + base_actor = base_actor_class.remote(message) + assert ray.get(base_actor.get_data.remote()) == message