diff --git a/python/ray/function_manager.py b/python/ray/function_manager.py index e72fe64f1..3be8fb5b4 100644 --- a/python/ray/function_manager.py +++ b/python/ray/function_manager.py @@ -3,9 +3,11 @@ from __future__ import division from __future__ import print_function import hashlib +import importlib import inspect import json import logging +import six import sys import time import traceback @@ -87,9 +89,9 @@ class FunctionDescriptor(object): return FunctionDescriptor.for_driver_task() elif (len(function_descriptor_list) == 3 or len(function_descriptor_list) == 4): - module_name = function_descriptor_list[0].decode() - class_name = function_descriptor_list[1].decode() - function_name = function_descriptor_list[2].decode() + module_name = six.ensure_str(function_descriptor_list[0]) + class_name = six.ensure_str(function_descriptor_list[1]) + function_name = six.ensure_str(function_descriptor_list[2]) if len(function_descriptor_list) == 4: return cls(module_name, function_name, class_name, function_descriptor_list[3]) @@ -256,6 +258,14 @@ class FunctionDescriptor(object): descriptor_list.append(self._function_source_hash) return descriptor_list + def is_actor_method(self): + """Whether this function descriptor is an actor method. + + Returns: + True if it's an actor method, False if it's a normal function. + """ + return len(self._class_name) > 0 + class FunctionActorManager(object): """A class used to export/load remote functions and actors. @@ -289,13 +299,18 @@ class FunctionActorManager(object): # import thread. It is safe to convert this worker into an actor of # these types. 
self.imported_actor_classes = set() + self._loaded_actor_classes = {} def increase_task_counter(self, driver_id, function_descriptor): function_id = function_descriptor.function_id + if self._worker.load_code_from_local: + driver_id = ray.DriverID.nil() self._num_task_executions[driver_id][function_id] += 1 def get_task_counter(self, driver_id, function_descriptor): function_id = function_descriptor.function_id + if self._worker.load_code_from_local: + driver_id = ray.DriverID.nil() return self._num_task_executions[driver_id][function_id] def export_cached(self): @@ -336,6 +351,8 @@ class FunctionActorManager(object): Args: remote_function: the RemoteFunction object. """ + if self._worker.load_code_from_local: + return # Work around limitations of Python pickling. function = remote_function._function function_name_global_valid = function.__name__ in function.__globals__ @@ -436,16 +453,24 @@ class FunctionActorManager(object): Returns: A FunctionExecutionInfo object. """ - function_id = function_descriptor.function_id - - # Wait until the function to be executed has actually been - # registered on this worker. We will push warnings to the user if - # we spend too long in this loop. - # The driver function may not be found in sys.path. Try to load - # the function from GCS. - with profiling.profile("wait_for_function"): - self._wait_for_function(function_descriptor, driver_id) + if self._worker.load_code_from_local: + # Load function from local code. + # Currently, we don't support isolating code by drivers, + # thus always set driver ID to NIL here. + driver_id = ray.DriverID.nil() + if not function_descriptor.is_actor_method(): + self._load_function_from_local(driver_id, function_descriptor) + else: + # Load function from GCS. + # Wait until the function to be executed has actually been + # registered on this worker. We will push warnings to the user if + # we spend too long in this loop. + # The driver function may not be found in sys.path. 
Try to load + # the function from GCS. + with profiling.profile("wait_for_function"): + self._wait_for_function(function_descriptor, driver_id) try: + function_id = function_descriptor.function_id info = self._function_execution_info[driver_id][function_id] except KeyError as e: message = ("Error occurs in get_execution_info: " @@ -454,6 +479,33 @@ class FunctionActorManager(object): raise KeyError(message) return info + def _load_function_from_local(self, driver_id, function_descriptor): + assert not function_descriptor.is_actor_method() + function_id = function_descriptor.function_id + if (driver_id in self._function_execution_info + and function_id in self._function_execution_info[driver_id]): + return + module_name, function_name = ( + function_descriptor.module_name, + function_descriptor.function_name, + ) + try: + module = importlib.import_module(module_name) + function = getattr(module, function_name)._function + self._function_execution_info[driver_id][function_id] = ( + FunctionExecutionInfo( + function=function, + function_name=function_name, + max_calls=0, + )) + self._num_task_executions[driver_id][function_id] = 0 + except Exception: + logger.exception( + "Failed to load function %s.", function_name) + raise Exception( + "Function {} failed to be loaded from local code.".format( + function_descriptor)) + def _wait_for_function(self, function_descriptor, driver_id, timeout=10): """Wait until the function to be executed is present on this worker. @@ -513,6 +565,8 @@ class FunctionActorManager(object): self._worker.redis_client.rpush("Exports", key) def export_actor_class(self, Class, actor_method_names): + if self._worker.load_code_from_local: + return function_descriptor = FunctionDescriptor.from_class(Class) # `task_driver_id` shouldn't be NIL, unless: # 1) This worker isn't an actor; @@ -553,7 +607,87 @@ class FunctionActorManager(object): # within tasks. 
I tried to disable this, but it may be necessary # because of https://github.com/ray-project/ray/issues/1146. - def load_actor(self, driver_id, function_descriptor): + def load_actor_class(self, driver_id, function_descriptor): + """Load the actor class. + + Args: + driver_id: Driver ID of the actor. + function_descriptor: Function descriptor of the actor constructor. + + Returns: + The actor class. + """ + function_id = function_descriptor.function_id + # Check if the actor class already exists in the cache. + actor_class = self._loaded_actor_classes.get(function_id, None) + if actor_class is None: + # Load actor class. + if self._worker.load_code_from_local: + driver_id = ray.DriverID.nil() + # Load actor class from local code. + actor_class = self._load_actor_from_local( + driver_id, function_descriptor) + else: + # Load actor class from GCS. + actor_class = self._load_actor_class_from_gcs( + driver_id, function_descriptor) + # Save the loaded actor class in cache. + self._loaded_actor_classes[function_id] = actor_class + + # Generate execution info for the methods of this actor class. 
+ module_name = function_descriptor.module_name + actor_class_name = function_descriptor.class_name + actor_methods = inspect.getmembers( + actor_class, predicate=is_function_or_method) + for actor_method_name, actor_method in actor_methods: + method_descriptor = FunctionDescriptor( + module_name, actor_method_name, actor_class_name) + method_id = method_descriptor.function_id + executor = self._make_actor_method_executor( + actor_method_name, + actor_method, + actor_imported=True, + ) + self._function_execution_info[driver_id][method_id] = ( + FunctionExecutionInfo( + function=executor, + function_name=actor_method_name, + max_calls=0, + )) + self._num_task_executions[driver_id][method_id] = 0 + self._num_task_executions[driver_id][function_id] = 0 + return actor_class + + def _load_actor_from_local(self, driver_id, function_descriptor): + """Load actor class from local code.""" + module_name, class_name = (function_descriptor.module_name, + function_descriptor.class_name) + try: + module = importlib.import_module(module_name) + return getattr(module, class_name)._modified_class + except Exception: + logger.exception( + "Failed to load actor_class %s.", class_name) + raise Exception( + "Actor {} failed to be imported from local code.".format( + class_name)) + + def _create_fake_actor_class(self, actor_class_name, actor_method_names): + class TemporaryActor(object): + pass + + def temporary_actor_method(*xs): + raise Exception( + "The actor with name {} failed to be imported, " + "and so cannot execute this method.".format(actor_class_name)) + + for method in actor_method_names: + setattr(TemporaryActor, method, temporary_actor_method) + + return TemporaryActor + + def _load_actor_class_from_gcs(self, driver_id, function_descriptor): + """Load actor class from GCS.""" key = (b"ActorClass:" + driver_id.binary() + b":" + function_descriptor.function_id.binary()) # Wait for the actor class key to have been imported by the @@ -562,74 +696,32 @@ class 
FunctionActorManager(object): # the driver if too much time is spent here. while key not in self.imported_actor_classes: time.sleep(0.001) - with self._worker.lock: - self.fetch_and_register_actor(key) - def fetch_and_register_actor(self, actor_class_key): - """Import an actor. - - This will be called by the worker's import thread when the worker - receives the actor_class export, assuming that the worker is an actor - for that class. - - Args: - actor_class_key: The key in Redis to use to fetch the actor. - """ - actor_id = self._worker.actor_id + # Fetch raw data from GCS. (driver_id_str, class_name, module, pickled_class, actor_method_names) = self._worker.redis_client.hmget( - actor_class_key, [ + key, [ "driver_id", "class_name", "module", "class", "actor_method_names" ]) - class_name = decode(class_name) - module = decode(module) + class_name = six.ensure_str(class_name) + module_name = six.ensure_str(module) driver_id = ray.DriverID(driver_id_str) - actor_method_names = json.loads(decode(actor_method_names)) - - # In Python 2, json loads strings as unicode, so convert them back to - # strings. - if sys.version_info < (3, 0): - actor_method_names = [ - method_name.encode("ascii") - for method_name in actor_method_names - ] - - # Create a temporary actor with some temporary methods so that if - # the actor fails to be unpickled, the temporary actor can be used - # (just to produce error messages and to prevent the driver from - # hanging). - class TemporaryActor(object): - pass - - self._worker.actors[actor_id] = TemporaryActor() - - def temporary_actor_method(*xs): - raise Exception( - "The actor with name {} failed to be imported, " - "and so cannot execute this method".format(class_name)) - - # Register the actor method executors. 
- for actor_method_name in actor_method_names: - function_descriptor = FunctionDescriptor(module, actor_method_name, - class_name) - function_id = function_descriptor.function_id - temporary_executor = self._make_actor_method_executor( - actor_method_name, - temporary_actor_method, - actor_imported=False) - self._function_execution_info[driver_id][function_id] = ( - FunctionExecutionInfo( - function=temporary_executor, - function_name=actor_method_name, - max_calls=0)) - self._num_task_executions[driver_id][function_id] = 0 + actor_method_names = json.loads(six.ensure_str(actor_method_names)) + actor_class = None try: - unpickled_class = pickle.loads(pickled_class) - self._worker.actor_class = unpickled_class + with self._worker.lock: + actor_class = pickle.loads(pickled_class) except Exception: + logger.exception( + "Failed to load actor class %s.", class_name) + # The actor class failed to be unpickled, create a fake actor + # class instead (just to produce error messages and to prevent + # the driver from hanging). + actor_class = self._create_fake_actor_class( + class_name, actor_method_names) # If an exception was thrown when the actor was imported, we record # the traceback and notify the scheduler of the failure. traceback_str = ray.utils.format_error_message( @@ -638,33 +730,20 @@ class FunctionActorManager(object): push_error_to_driver( self._worker, ray_constants.REGISTER_ACTOR_PUSH_ERROR, "Failed to unpickle actor class '{}' for actor ID {}. " - "Traceback:\n{}".format(class_name, actor_id.hex(), + "Traceback:\n{}".format(class_name, + self._worker.actor_id.hex(), traceback_str), driver_id) # TODO(rkn): In the future, it might make sense to have the worker # exit here. However, currently that would lead to hanging if # someone calls ray.get on a method invoked on the actor. - else: - # TODO(pcm): Why is the below line necessary? 
- unpickled_class.__module__ = module - self._worker.actors[actor_id] = unpickled_class.__new__( - unpickled_class) - actor_methods = inspect.getmembers( - unpickled_class, predicate=is_function_or_method) - for actor_method_name, actor_method in actor_methods: - function_descriptor = FunctionDescriptor( - module, actor_method_name, class_name) - function_id = function_descriptor.function_id - executor = self._make_actor_method_executor( - actor_method_name, actor_method, actor_imported=True) - self._function_execution_info[driver_id][function_id] = ( - FunctionExecutionInfo( - function=executor, - function_name=actor_method_name, - max_calls=0)) - # We do not set function_properties[driver_id][function_id] - # because we currently do need the actor worker to submit new - # tasks for the actor. + # The below line is necessary. Because in the driver process, + # if the function is defined in the file where the python script + # was started from, its module is `__main__`. + # However in the worker process, the `__main__` module is a + # different module, which is `default_worker.py` + actor_class.__module__ = module_name + return actor_class def _make_actor_method_executor(self, method_name, method, actor_imported): """Make an executor that wraps a user-defined actor method. 
diff --git a/python/ray/node.py b/python/ray/node.py index c5aa55fef..cee9e6fd0 100644 --- a/python/ray/node.py +++ b/python/ray/node.py @@ -355,6 +355,7 @@ class Node(object): config=self._config, include_java=self._ray_params.include_java, java_worker_options=self._ray_params.java_worker_options, + load_code_from_local=self._ray_params.load_code_from_local, ) assert ray_constants.PROCESS_TYPE_RAYLET not in self.all_processes self.all_processes[ray_constants.PROCESS_TYPE_RAYLET] = [process_info] diff --git a/python/ray/parameter.py b/python/ray/parameter.py index 49d2ace4e..5eab3cf75 100644 --- a/python/ray/parameter.py +++ b/python/ray/parameter.py @@ -73,6 +73,7 @@ class RayParams(object): include_java (bool): If True, the raylet backend can also support Java worker. java_worker_options (str): The command options for Java worker. + load_code_from_local: Whether to load code from local files or from the GCS. _internal_config (str): JSON configuration for overriding RayConfig defaults. For testing purposes ONLY. """ @@ -110,6 +111,7 @@ class RayParams(object): autoscaling_config=None, include_java=False, java_worker_options=None, + load_code_from_local=False, _internal_config=None): self.object_id_seed = object_id_seed self.redis_address = redis_address @@ -141,6 +143,7 @@ class RayParams(object): self.autoscaling_config = autoscaling_config self.include_java = include_java self.java_worker_options = java_worker_options + self.load_code_from_local = load_code_from_local self._internal_config = _internal_config self._check_usage() diff --git a/python/ray/scripts/scripts.py b/python/ray/scripts/scripts.py index 138a933f0..0fa2bf829 100644 --- a/python/ray/scripts/scripts.py +++ b/python/ray/scripts/scripts.py @@ -208,6 +208,11 @@ def cli(logging_level, logging_format): default=None, type=str, help="Do NOT use this. 
This is for debugging/development purposes ONLY.") +@click.option( + "--load-code-from-local", + is_flag=True, + default=False, + help="Specify whether load code from local file or GCS serialization.") def start(node_ip_address, redis_address, redis_port, num_redis_shards, redis_max_clients, redis_password, redis_shard_ports, object_manager_port, node_manager_port, object_store_memory, @@ -215,7 +220,7 @@ def start(node_ip_address, redis_address, redis_port, num_redis_shards, plasma_directory, huge_pages, autoscaling_config, no_redirect_worker_output, no_redirect_output, plasma_store_socket_name, raylet_socket_name, temp_dir, include_java, - java_worker_options, internal_config): + java_worker_options, load_code_from_local, internal_config): # Convert hostnames to numerical IP address. if node_ip_address is not None: node_ip_address = services.address_to_ip(node_ip_address) @@ -250,6 +255,7 @@ def start(node_ip_address, redis_address, redis_port, num_redis_shards, temp_dir=temp_dir, include_java=include_java, java_worker_options=java_worker_options, + load_code_from_local=load_code_from_local, _internal_config=internal_config) if head: diff --git a/python/ray/services.py b/python/ray/services.py index c6debb905..e529420b1 100644 --- a/python/ray/services.py +++ b/python/ray/services.py @@ -973,7 +973,8 @@ def start_raylet(redis_address, stderr_file=None, config=None, include_java=False, - java_worker_options=None): + java_worker_options=None, + load_code_from_local=False): """Start a raylet, which is a combined local scheduler and object manager. 
Args: @@ -1068,6 +1069,9 @@ def start_raylet(redis_address, if node_manager_port is None: node_manager_port = 0 + if load_code_from_local: + start_worker_command += " --load-code-from-local " + command = [ RAYLET_EXECUTABLE, raylet_name, diff --git a/python/ray/worker.py b/python/ray/worker.py index 00e8cd959..427356043 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -152,6 +152,7 @@ class Worker(object): # A dictionary that maps from driver id to SerializationContext # TODO: clean up the SerializationContext once the job finished. self.serialization_context_map = {} + self.load_code_from_local = False self.function_actor_manager = FunctionActorManager(self) # Identity of the driver that this worker is processing. # It is a DriverID. @@ -915,8 +916,9 @@ class Worker(object): assert self.actor_id.is_nil() self.actor_id = task.actor_creation_id() self.actor_creation_task_id = task.task_id() - self.function_actor_manager.load_actor(driver_id, - function_descriptor) + actor_class = self.function_actor_manager.load_actor_class( + driver_id, function_descriptor) + self.actors[self.actor_id] = actor_class.__new__(actor_class) self.actor_checkpoint_info[self.actor_id] = ActorCheckpointInfo( num_tasks_since_last_checkpoint=0, last_checkpoint_timestamp=int(1000 * time.time()), @@ -1271,6 +1273,7 @@ def init(redis_address=None, plasma_store_socket_name=None, raylet_socket_name=None, temp_dir=None, + load_code_from_local=False, _internal_config=None): """Connect to an existing Ray cluster or start one and connect to it. @@ -1346,6 +1349,8 @@ def init(redis_address=None, used by the raylet process. temp_dir (str): If provided, it will specify the root temporary directory for the Ray process. + load_code_from_local: Whether code should be loaded from a local module + or from the GCS. _internal_config (str): JSON configuration for overriding RayConfig defaults. For testing purposes ONLY. 
@@ -1427,6 +1432,7 @@ def init(redis_address=None, plasma_store_socket_name=plasma_store_socket_name, raylet_socket_name=raylet_socket_name, temp_dir=temp_dir, + load_code_from_local=load_code_from_local, _internal_config=_internal_config, ) # Start the Ray processes. We set shutdown_at_exit=False because we @@ -1516,7 +1522,8 @@ def init(redis_address=None, mode=driver_mode, log_to_driver=log_to_driver, worker=global_worker, - driver_id=driver_id) + driver_id=driver_id, + load_code_from_local=load_code_from_local) for hook in _post_init_hooks: hook() @@ -1745,7 +1752,8 @@ def connect(info, mode=WORKER_MODE, log_to_driver=False, worker=global_worker, - driver_id=None): + driver_id=None, + load_code_from_local=False): """Connect this worker to the local scheduler, to Plasma, and to Redis. Args: @@ -1802,6 +1810,7 @@ def connect(info, worker.actor_id = ActorID.nil() worker.connected = True worker.set_mode(mode) + worker.load_code_from_local = load_code_from_local # If running Ray in LOCAL_MODE, there is no need to create call # create_worker or to start the worker service. 
diff --git a/python/ray/workers/default_worker.py b/python/ray/workers/default_worker.py index c73167630..71b3a5f26 100644 --- a/python/ray/workers/default_worker.py +++ b/python/ray/workers/default_worker.py @@ -57,6 +57,11 @@ parser.add_argument( type=str, default=None, help="Specify the path of the temporary directory use by Ray process.") +parser.add_argument( + "--load-code-from-local", + default=False, + action='store_true', + help="True if code is loaded from local files, as opposed to the GCS.") if __name__ == "__main__": args = parser.parse_args() @@ -77,7 +82,8 @@ if __name__ == "__main__": redis_password=args.redis_password, plasma_store_socket_name=args.object_store_name, raylet_socket_name=args.raylet_name, - temp_dir=args.temp_dir) + temp_dir=args.temp_dir, + load_code_from_local=args.load_code_from_local) node = ray.node.Node( ray_params, head=False, shutdown_at_exit=False, connect_only=True) @@ -85,7 +91,10 @@ if __name__ == "__main__": # TODO(suquark): Use "node" as the input of "connect". ray.worker.connect( - info, redis_password=args.redis_password, mode=ray.WORKER_MODE) + info, + redis_password=args.redis_password, + mode=ray.WORKER_MODE, + load_code_from_local=args.load_code_from_local) error_explanation = """ This error is unexpected and should not have happened. Somehow a worker diff --git a/python/setup.py b/python/setup.py index 18af248c9..1bdbb9021 100644 --- a/python/setup.py +++ b/python/setup.py @@ -148,8 +148,7 @@ requires = [ "pytest", "pyyaml", "redis", - # The six module is required by pyarrow. - "six >= 1.0.0", + "six >= 1.12.0", # The typing module is required by modin. 
"typing", "flatbuffers", diff --git a/test/runtest.py b/test/runtest.py index 6efd6ff08..1ad95233b 100644 --- a/test/runtest.py +++ b/test/runtest.py @@ -156,9 +156,9 @@ def test_complex_serialization(ray_start): assert_equal(obj1[i], obj2[i]) elif (ray.serialization.is_named_tuple(type(obj1)) or ray.serialization.is_named_tuple(type(obj2))): - assert len(obj1) == len(obj2), ("Objects {} and {} are named " - "tuples with different lengths." - .format(obj1, obj2)) + assert len(obj1) == len(obj2), ( + "Objects {} and {} are named " + "tuples with different lengths.".format(obj1, obj2)) for i in range(len(obj1)): assert_equal(obj1[i], obj2[i]) else: @@ -2843,3 +2843,58 @@ def test_non_ascii_comment(ray_start): return 1 assert ray.get(f.remote()) == 1 + + +@ray.remote +def echo(x): + return x + + +@ray.remote +class WithConstructor(object): + def __init__(self, data): + self.data = data + + def get_data(self): + return self.data + + +@ray.remote +class WithoutConstructor(object): + def set_data(self, data): + self.data = data + + def get_data(self): + return self.data + + +class BaseClass(object): + def __init__(self, data): + self.data = data + + +@ray.remote +class DerivedClass(BaseClass): + def __init__(self, data): + # Due to different behaviors of super in Python 2 and Python 3, + # we use BaseClass directly here. + BaseClass.__init__(self, data) + + def get_data(self): + return self.data + + +def test_load_code_from_local(shutdown_only): + ray.init(load_code_from_local=True, num_cpus=4) + # Test normal function. + assert ray.get(echo.remote("foo")) == "foo" + # Test actor class with constructor. + actor = WithConstructor.remote(1) + assert ray.get(actor.get_data.remote()) == 1 + # Test actor class without constructor. + actor = WithoutConstructor.remote() + actor.set_data.remote(1) + assert ray.get(actor.get_data.remote()) == 1 + # Test derived actor class. + actor = DerivedClass.remote(1) + assert ray.get(actor.get_data.remote()) == 1