class _ThreadSafeProxy(object):
    """Proxy that serializes every method call made on a wrapped object.

    Attribute access is forwarded to the wrapped object. Non-callable
    attributes are returned unchanged; callable attributes are returned
    wrapped in a function that holds ``lock`` while the call runs.

    Limitations that follow from this design:
      * Only the call itself is guarded. Whatever a guarded call returns is
        handed back as-is, without being wrapped in a proxy.
      * Forwarding is done in ``__getattr__``, which Python does not consult
        for implicit special-method lookups such as ``len(proxy)``.
      * One wrapper is created per attribute name and then reused, so it
        keeps calling the callable that was found on first access.

    Attributes:
        orig_obj (object): the original object.
        lock (threading.Lock): the lock object.
        _wrapper_cache (dict): a cache from attribute names of the original
            object to the lock-guarded proxy methods.
    """

    # Names that live on the proxy itself and must never be forwarded.
    _OWN_ATTRS = ("orig_obj", "lock", "_wrapper_cache")

    def __init__(self, orig_obj, lock):
        self.orig_obj = orig_obj
        self.lock = lock
        self._wrapper_cache = {}

    def __getattr__(self, attr):
        """Forward ``attr`` to the wrapped object, guarding callables.

        Args:
            attr (str): name of the attribute that was not found on the
                proxy itself.

        Returns:
            The wrapped object's attribute if it is not callable, otherwise
            a cached wrapper that calls it while holding ``self.lock``.

        Raises:
            AttributeError: if the wrapped object has no such attribute, or
                if ``attr`` is one of the proxy's own fields and the proxy
                has not been initialized yet.
        """
        # __getattr__ only runs after normal lookup has failed. If that
        # happens for one of our own fields, this instance was created
        # without __init__ (e.g. by copy or pickle). Forwarding would read
        # self.orig_obj, re-enter this method and recurse without bound, so
        # report the attribute as missing instead.
        if attr in self._OWN_ATTRS:
            raise AttributeError(attr)

        orig_attr = getattr(self.orig_obj, attr)
        if not callable(orig_attr):
            # If the original attr is a field, just return it.
            return orig_attr

        # If the original attr is a method, return a wrapper that guards
        # the original method with the lock.
        wrapper = self._wrapper_cache.get(attr)
        if wrapper is None:

            @functools.wraps(orig_attr)
            def wrapper(*args, **kwargs):
                with self.lock:
                    return orig_attr(*args, **kwargs)

            self._wrapper_cache[attr] = wrapper
        return wrapper


def thread_safe_client(client, lock=None):
    """Create a thread-safe proxy which locks every method call
    for the given client.

    Args:
        client: the client object to be guarded.
        lock: the lock object that will be used to lock client's methods.
            If None, a new lock will be used.

    Returns:
        A thread-safe proxy for the given client.
    """
    if lock is None:
        lock = threading.Lock()
    return _ThreadSafeProxy(client, lock)
+ state_lock (Lock): + Used to lock worker's non-thread-safe internal states: + 1) task_index increment: make sure we generate unique task ids; + 2) Object reconstruction: because the node manager will + recycle/return the worker's resources before/after reconstruction, + it's unsafe for multiple threads to call object + reconstruction simultaneously. """ def __init__(self): @@ -236,6 +244,7 @@ class Worker(object): # CUDA_VISIBLE_DEVICES environment variable. self.original_gpu_ids = ray.utils.get_cuda_visible_devices() self.profiler = profiling.Profiler(self) + self.state_lock = threading.Lock() def check_connected(self): """Check if the worker is connected. @@ -365,7 +374,7 @@ class Worker(object): # Serialize and put the object in the object store. try: self.store_and_register(object_id, value) - except pyarrow.PlasmaObjectExists as e: + except pyarrow.PlasmaObjectExists: # The object already exists in the object store, so there is no # need to add it again. TODO(rkn): We need to compare the hashes # and make sure that the objects are in fact the same. We also @@ -393,7 +402,7 @@ class Worker(object): i + ray._config.worker_get_request_size())], timeout, self.serialization_context) return results - except pyarrow.lib.ArrowInvalid as e: + except pyarrow.lib.ArrowInvalid: # TODO(ekl): the local scheduler could include relevant # metadata in the task kill case for a better error message invalid_error = RayTaskError( @@ -401,7 +410,7 @@ class Worker(object): "Invalid return value: likely worker died or was killed " "while executing the task.") return [invalid_error] * len(object_ids) - except pyarrow.DeserializationCallbackError as e: + except pyarrow.DeserializationCallbackError: # Wait a little bit for the import thread to import the class. # If we currently have the worker lock, we need to release it # so that the import thread can acquire it. 
@@ -466,52 +475,59 @@ class Worker(object): for (i, val) in enumerate(final_results) if val is plasma.ObjectNotAvailable } - was_blocked = (len(unready_ids) > 0) - # Try reconstructing any objects we haven't gotten yet. Try to get them - # until at least get_timeout_milliseconds milliseconds passes, then - # repeat. - while len(unready_ids) > 0: - for unready_id in unready_ids: - if not self.use_raylet: - self.local_scheduler_client.reconstruct_objects( - [ray.ObjectID(unready_id)], False) - # Do another fetch for objects that aren't available locally yet, - # in case they were evicted since the last fetch. We divide the - # fetch into smaller fetches so as to not block the manager for a - # prolonged period of time in a single call. - object_ids_to_fetch = list( - map(plasma.ObjectID, unready_ids.keys())) - ray_object_ids_to_fetch = list( - map(ray.ObjectID, unready_ids.keys())) - for i in range(0, len(object_ids_to_fetch), - ray._config.worker_fetch_request_size()): - if not self.use_raylet: - self.plasma_client.fetch(object_ids_to_fetch[i:( - i + ray._config.worker_fetch_request_size())]) - else: - self.local_scheduler_client.reconstruct_objects( - ray_object_ids_to_fetch[i:( - i + ray._config.worker_fetch_request_size())], - False) - results = self.retrieve_and_deserialize( - object_ids_to_fetch, - max([ - ray._config.get_timeout_milliseconds(), - int(0.01 * len(unready_ids)) - ])) - # Remove any entries for objects we received during this iteration - # so we don't retrieve the same object twice. - for i, val in enumerate(results): - if val is not plasma.ObjectNotAvailable: - object_id = object_ids_to_fetch[i].binary() - index = unready_ids[object_id] - final_results[index] = val - unready_ids.pop(object_id) - # If there were objects that we weren't able to get locally, let the - # local scheduler know that we're now unblocked. 
- if was_blocked: - self.local_scheduler_client.notify_unblocked() + if len(unready_ids) > 0: + with self.state_lock: + # Try reconstructing any objects we haven't gotten yet. Try to + # get them until at least get_timeout_milliseconds + # milliseconds passes, then repeat. + while len(unready_ids) > 0: + for unready_id in unready_ids: + if not self.use_raylet: + self.local_scheduler_client.reconstruct_objects( + [ray.ObjectID(unready_id)], False) + # Do another fetch for objects that aren't available + # locally yet, in case they were evicted since the last + # fetch. We divide the fetch into smaller fetches so as + # to not block the manager for a prolonged period of time + # in a single call. + object_ids_to_fetch = [ + plasma.ObjectID(unready_id) + for unready_id in unready_ids.keys() + ] + ray_object_ids_to_fetch = [ + ray.ObjectID(unready_id) + for unready_id in unready_ids.keys() + ] + fetch_request_size = ( + ray._config.worker_fetch_request_size()) + for i in range(0, len(object_ids_to_fetch), + fetch_request_size): + if not self.use_raylet: + self.plasma_client.fetch(object_ids_to_fetch[i:( + i + fetch_request_size)]) + else: + self.local_scheduler_client.reconstruct_objects( + ray_object_ids_to_fetch[i:( + i + fetch_request_size)], False) + results = self.retrieve_and_deserialize( + object_ids_to_fetch, + max([ + ray._config.get_timeout_milliseconds(), + int(0.01 * len(unready_ids)) + ])) + # Remove any entries for objects we received during this + # iteration so we don't retrieve the same object twice. + for i, val in enumerate(results): + if val is not plasma.ObjectNotAvailable: + object_id = object_ids_to_fetch[i].binary() + index = unready_ids[object_id] + final_results[index] = val + unready_ids.pop(object_id) + + # If there were objects that we weren't able to get locally, + # let the local scheduler know that we're now unblocked. 
+ self.local_scheduler_client.notify_unblocked() assert len(final_results) == len(object_ids) return final_results @@ -563,7 +579,6 @@ class Worker(object): The return object IDs for this task. """ with profiling.profile("submit_task", worker=self): - check_main_thread() if actor_id is None: assert actor_handle_id is None actor_id = ray.ObjectID(NIL_ACTOR_ID) @@ -607,17 +622,19 @@ class Worker(object): raise ValueError( "Resource quantities must all be whole numbers.") + with self.state_lock: + # Increment the worker's task index to track how many tasks + # have been submitted by the current task so far. + task_index = self.task_index + self.task_index += 1 # Submit the task to local scheduler. task = ray.local_scheduler.Task( driver_id, ray.ObjectID( function_id.id()), args_for_local_scheduler, - num_return_vals, self.current_task_id, self.task_index, + num_return_vals, self.current_task_id, task_index, actor_creation_id, actor_creation_dummy_object_id, actor_id, actor_handle_id, actor_counter, is_actor_checkpoint_method, execution_dependencies, resources, self.use_raylet) - # Increment the worker's task index to track how many tasks have - # been submitted by the current task so far. - self.task_index += 1 self.local_scheduler_client.submit(task) return task.returns() @@ -635,7 +652,6 @@ class Worker(object): decorated_function: The decorated function (this is used to enable the remote function to recursively call itself). """ - check_main_thread() if self.mode not in [SCRIPT_MODE, SILENT_MODE]: raise Exception("export_remote_function can only be called on a " "driver.") @@ -687,7 +703,6 @@ class Worker(object): should not take any arguments. If it returns anything, its return values will not be used. """ - check_main_thread() # If ray.init has not been called yet, then cache the function and # export it when connect is called. Otherwise, run the function on all # workers. 
@@ -1041,7 +1056,6 @@ class Worker(object): signal.signal(signal.SIGTERM, exit) - check_main_thread() while True: task = self._get_next_task_from_local_scheduler() self._wait_for_and_process_task(task) @@ -1143,20 +1157,6 @@ class RayConnectionError(Exception): pass -def check_main_thread(): - """Check that we are currently on the main thread. - - Raises: - Exception: An exception is raised if this is called on a thread other - than the main thread. - """ - if threading.current_thread().getName() != "MainThread": - raise Exception("The Ray methods are not thread safe and must be " - "called from the main thread. This method was called " - "from thread {}." - .format(threading.current_thread().getName())) - - def print_failed_task(task_status): """Print information about failed tasks. @@ -1191,12 +1191,9 @@ def error_applies_to_driver(error_key, worker=global_worker): def error_info(worker=global_worker): """Return information about failed tasks.""" worker.check_connected() - check_main_thread() - if worker.use_raylet: return (global_state.error_messages(job_id=worker.task_driver_id) + global_state.error_messages(job_id=ray_constants.NIL_JOB_ID)) - error_keys = worker.redis_client.lrange("ErrorKeys", 0, -1) errors = [] for error_key in error_keys: @@ -1388,7 +1385,7 @@ def get_address_info_from_redis(redis_address, try: return get_address_info_from_redis_helper( redis_address, node_ip_address, use_raylet=use_raylet) - except Exception as e: + except Exception: if counter == num_retries: raise # Some of the information may not be in Redis yet, so wait a little @@ -1521,7 +1518,6 @@ def _init(address_info=None, Exception: An exception is raised if an inappropriate combination of arguments is passed in. """ - check_main_thread() if driver_mode not in [SCRIPT_MODE, LOCAL_MODE, SILENT_MODE]: raise Exception("Driver_mode must be in [ray.SCRIPT_MODE, " "ray.LOCAL_MODE, ray.SILENT_MODE].") @@ -1988,7 +1984,6 @@ def connect(info, LOCAL_MODE, and SILENT_MODE. 
use_raylet: True if the new raylet code path should be used. """ - check_main_thread() # Do some basic checking to make sure we didn't call ray.init twice. error_message = "Perhaps you called ray.init twice by accident?" assert not worker.connected, error_message @@ -2021,8 +2016,8 @@ def connect(info, # Create a Redis client. redis_ip_address, redis_port = info["redis_address"].split(":") - worker.redis_client = redis.StrictRedis( - host=redis_ip_address, port=int(redis_port)) + worker.redis_client = thread_safe_client( + redis.StrictRedis(host=redis_ip_address, port=int(redis_port))) # For driver's check that the version information matches the version # information that the Ray cluster was started with. @@ -2102,11 +2097,12 @@ def connect(info, # Create an object store client. if not worker.use_raylet: - worker.plasma_client = plasma.connect(info["store_socket_name"], - info["manager_socket_name"], 64) + worker.plasma_client = thread_safe_client( + plasma.connect(info["store_socket_name"], + info["manager_socket_name"], 64)) else: - worker.plasma_client = plasma.connect(info["store_socket_name"], "", - 64) + worker.plasma_client = thread_safe_client( + plasma.connect(info["store_socket_name"], "", 64)) if not worker.use_raylet: local_scheduler_socket = info["local_scheduler_socket_name"] @@ -2348,7 +2344,7 @@ def register_custom_serializer(cls, # worker. However, determinism is not guaranteed, and the result # may be different on different workers. class_id = _try_to_compute_deterministic_class_id(cls) - except Exception as e: + except Exception: raise serialization.CloudPickleError("Failed to pickle class " "'{}'".format(cls)) else: @@ -2399,8 +2395,6 @@ def get(object_ids, worker=global_worker): """ worker.check_connected() with profiling.profile("ray.get", worker=worker): - check_main_thread() - if worker.mode == LOCAL_MODE: # In LOCAL_MODE, ray.get is the identity operation (the input will # actually be a value not an objectid). 
def testMultithreading(self):
    """Exercise the Ray API from several threads at once.

    The same multi-threaded workload is run twice: once directly in the
    driver and once inside a remote task, i.e. in a worker process.
    """
    self.init_ray(driver_mode=ray.SILENT_MODE)

    @ray.remote
    def f():
        pass

    def g(n):
        # Submit and fetch tasks in batches of n, then put and wait on a
        # batch of objects.
        for _ in range(1000 // n):
            ray.get([f.remote() for _ in range(n)])
        res = [ray.put(i) for i in range(1000 // n)]
        ray.wait(res, len(res))

    def test_multi_threading():
        # An exception raised inside a Thread's target does not propagate
        # to the thread that calls join(), so without this list a failing
        # Ray call in any thread would go unnoticed and the test would
        # still pass.
        errors = []

        def run_and_record(n):
            try:
                g(n)
            except Exception as e:
                errors.append(e)

        threads = [
            threading.Thread(target=run_and_record, args=(n, ))
            for n in [1, 5, 10, 100, 1000]
        ]

        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()

        # Surface the first failure in the calling thread. A plain raise is
        # used rather than self.fail so that this closure does not capture
        # the TestCase when it is shipped to a worker below.
        if errors:
            raise errors[0]

    @ray.remote
    def test_multi_threading_in_worker():
        test_multi_threading()

    # test multi-threading in the driver
    test_multi_threading()
    # test multi-threading in the worker
    ray.get(test_multi_threading_in_worker.remote())