diff --git a/lib/python/ray/worker.py b/lib/python/ray/worker.py index edbab95a5..9c184eacd 100644 --- a/lib/python/ray/worker.py +++ b/lib/python/ray/worker.py @@ -466,6 +466,7 @@ class Worker(object): be object IDs or they can be values. If they are values, they must be serializable objecs. """ + check_main_thread() # Put large or complex arguments that are passed by value in the object # store first. args_for_photon = [] @@ -503,6 +504,7 @@ class Worker(object): not take any arguments. If it returns anything, its return values will not be used. """ + check_main_thread() if self.mode not in [None, SCRIPT_MODE, SILENT_MODE, PYTHON_MODE]: raise Exception("run_function_on_all_workers can only be called on a driver.") # If ray.init has not been called yet, then cache the function and export it @@ -540,6 +542,16 @@ made by one task do not affect other tasks. class RayConnectionError(Exception): pass +def check_main_thread(): + """Check that we are currently on the main thread. + + Raises: + Exception: An exception is raised if this is called on a thread other than + the main thread. + """ + if threading.current_thread().getName() != "MainThread": + raise Exception("The Ray methods are not thread safe and must be called from the main thread. This method was called from thread {}.".format(threading.current_thread().getName())) + def check_connected(worker=global_worker): """Check if the worker is connected. @@ -565,6 +577,7 @@ def print_failed_task(task_status): def error_info(worker=global_worker): """Return information about failed tasks.""" + check_main_thread() check_connected(worker) result = {"TaskError": [], "RemoteFunctionImportError": [], @@ -627,6 +640,7 @@ def init(start_ray_local=False, num_workers=None, num_local_schedulers=1, driver Exception: An exception is raised if an inappropriate combination of arguments is passed in. """ + check_main_thread() if driver_mode == PYTHON_MODE: # If starting Ray in PYTHON_MODE, don't start any other processes. address_info = {} @@ -819,6 +833,7 @@ def connect(address_info, mode=WORKER_MODE, worker=global_worker): mode: The mode of the worker. One of SCRIPT_MODE, WORKER_MODE, PYTHON_MODE, and SILENT_MODE. """ + check_main_thread() worker.worker_id = random_string() worker.connected = True worker.set_mode(mode) @@ -952,6 +967,7 @@ def get(objectid, worker=global_worker): Returns: A Python object or a list of Python objects. """ + check_main_thread() check_connected(worker) if worker.mode == PYTHON_MODE: return objectid # In PYTHON_MODE, ray.get is the identity operation (the input will actually be a value not an objectid) @@ -977,6 +993,7 @@ def put(value, worker=global_worker): Returns: The object ID assigned to this value. """ + check_main_thread() check_connected(worker) if worker.mode == PYTHON_MODE: return value # In PYTHON_MODE, ray.put is the identity operation @@ -1006,6 +1023,7 @@ def wait(object_ids, num_returns=1, timeout=None, worker=global_worker): Returns: A list of object IDs that are ready and a list of the remaining object IDs. """ + check_main_thread() check_connected(worker) object_id_strs = [object_id.id() for object_id in object_ids] timeout = timeout if timeout is not None else 2 ** 30 @@ -1095,6 +1113,7 @@ def main_loop(worker=global_worker): "message": traceback_str}) worker.redis_client.rpush("ErrorKeys", error_key) + check_main_thread() while True: task = worker.photon_client.get_task() function_id = task.function_id() @@ -1145,6 +1164,7 @@ def _export_reusable_variable(name, reusable, worker=global_worker): reusable (Reusable): The reusable object containing code for initializing and reinitializing the variable. """ + check_main_thread() if _mode(worker) not in [SCRIPT_MODE, SILENT_MODE]: raise Exception("_export_reusable_variable can only be called on a driver.") reusable_variable_id = name @@ -1156,6 +1176,7 @@ def _export_reusable_variable(name, reusable, worker=global_worker): worker.driver_export_counter += 1 def export_remote_function(function_id, func_name, func, num_return_vals, worker=global_worker): + check_main_thread() if _mode(worker) not in [SCRIPT_MODE, SILENT_MODE]: raise Exception("export_remote_function can only be called on a driver.") key = "RemoteFunction:{}".format(function_id.id()) @@ -1188,6 +1209,7 @@ def remote(*args, **kwargs): def func_call(*args, **kwargs): """This gets run immediately when a worker calls a remote function.""" + check_main_thread() check_connected() args = list(args) args.extend([kwargs[keyword] if kwargs.has_key(keyword) else default for keyword, default in keyword_defaults[len(args):]]) # fill in the remaining arguments