Throw an exception if a Ray method is called from a thread that isn't the main thread. (#97)

This commit is contained in:
Robert Nishihara
2016-12-10 21:24:50 -08:00
committed by Philipp Moritz
parent 9474d03912
commit 0f7091099d
+22
View File
@@ -466,6 +466,7 @@ class Worker(object):
be object IDs or they can be values. If they are values, they
must be serializable objecs.
"""
check_main_thread()
# Put large or complex arguments that are passed by value in the object
# store first.
args_for_photon = []
@@ -503,6 +504,7 @@ class Worker(object):
not take any arguments. If it returns anything, its return values will
not be used.
"""
check_main_thread()
if self.mode not in [None, SCRIPT_MODE, SILENT_MODE, PYTHON_MODE]:
raise Exception("run_function_on_all_workers can only be called on a driver.")
# If ray.init has not been called yet, then cache the function and export it
@@ -540,6 +542,16 @@ made by one task do not affect other tasks.
class RayConnectionError(Exception):
pass
def check_main_thread():
"""Check that we are currently on the main thread.
Raises:
Exception: An exception is raised if this is called on a thread other than
the main thread.
"""
if threading.current_thread().getName() != "MainThread":
raise Exception("The Ray methods are not thread safe and must be called from the main thread. This method was called from thread {}.".format(threading.current_thread().getName()))
def check_connected(worker=global_worker):
"""Check if the worker is connected.
@@ -565,6 +577,7 @@ def print_failed_task(task_status):
def error_info(worker=global_worker):
"""Return information about failed tasks."""
check_main_thread()
check_connected(worker)
result = {"TaskError": [],
"RemoteFunctionImportError": [],
@@ -627,6 +640,7 @@ def init(start_ray_local=False, num_workers=None, num_local_schedulers=1, driver
Exception: An exception is raised if an inappropriate combination of
arguments is passed in.
"""
check_main_thread()
if driver_mode == PYTHON_MODE:
# If starting Ray in PYTHON_MODE, don't start any other processes.
address_info = {}
@@ -819,6 +833,7 @@ def connect(address_info, mode=WORKER_MODE, worker=global_worker):
mode: The mode of the worker. One of SCRIPT_MODE, WORKER_MODE, PYTHON_MODE,
and SILENT_MODE.
"""
check_main_thread()
worker.worker_id = random_string()
worker.connected = True
worker.set_mode(mode)
@@ -952,6 +967,7 @@ def get(objectid, worker=global_worker):
Returns:
A Python object or a list of Python objects.
"""
check_main_thread()
check_connected(worker)
if worker.mode == PYTHON_MODE:
return objectid # In PYTHON_MODE, ray.get is the identity operation (the input will actually be a value not an objectid)
@@ -977,6 +993,7 @@ def put(value, worker=global_worker):
Returns:
The object ID assigned to this value.
"""
check_main_thread()
check_connected(worker)
if worker.mode == PYTHON_MODE:
return value # In PYTHON_MODE, ray.put is the identity operation
@@ -1006,6 +1023,7 @@ def wait(object_ids, num_returns=1, timeout=None, worker=global_worker):
Returns:
A list of object IDs that are ready and a list of the remaining object IDs.
"""
check_main_thread()
check_connected(worker)
object_id_strs = [object_id.id() for object_id in object_ids]
timeout = timeout if timeout is not None else 2 ** 30
@@ -1095,6 +1113,7 @@ def main_loop(worker=global_worker):
"message": traceback_str})
worker.redis_client.rpush("ErrorKeys", error_key)
check_main_thread()
while True:
task = worker.photon_client.get_task()
function_id = task.function_id()
@@ -1145,6 +1164,7 @@ def _export_reusable_variable(name, reusable, worker=global_worker):
reusable (Reusable): The reusable object containing code for initializing
and reinitializing the variable.
"""
check_main_thread()
if _mode(worker) not in [SCRIPT_MODE, SILENT_MODE]:
raise Exception("_export_reusable_variable can only be called on a driver.")
reusable_variable_id = name
@@ -1156,6 +1176,7 @@ def _export_reusable_variable(name, reusable, worker=global_worker):
worker.driver_export_counter += 1
def export_remote_function(function_id, func_name, func, num_return_vals, worker=global_worker):
check_main_thread()
if _mode(worker) not in [SCRIPT_MODE, SILENT_MODE]:
raise Exception("export_remote_function can only be called on a driver.")
key = "RemoteFunction:{}".format(function_id.id())
@@ -1188,6 +1209,7 @@ def remote(*args, **kwargs):
def func_call(*args, **kwargs):
"""This gets run immediately when a worker calls a remote function."""
check_main_thread()
check_connected()
args = list(args)
args.extend([kwargs[keyword] if kwargs.has_key(keyword) else default for keyword, default in keyword_defaults[len(args):]]) # fill in the remaining arguments