mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 23:09:51 +08:00
Throw an exception if a Ray method is called from a thread that isn't the main thread. (#97)
This commit is contained in:
committed by
Philipp Moritz
parent
9474d03912
commit
0f7091099d
@@ -466,6 +466,7 @@ class Worker(object):
|
||||
be object IDs or they can be values. If they are values, they
|
||||
must be serializable objecs.
|
||||
"""
|
||||
check_main_thread()
|
||||
# Put large or complex arguments that are passed by value in the object
|
||||
# store first.
|
||||
args_for_photon = []
|
||||
@@ -503,6 +504,7 @@ class Worker(object):
|
||||
not take any arguments. If it returns anything, its return values will
|
||||
not be used.
|
||||
"""
|
||||
check_main_thread()
|
||||
if self.mode not in [None, SCRIPT_MODE, SILENT_MODE, PYTHON_MODE]:
|
||||
raise Exception("run_function_on_all_workers can only be called on a driver.")
|
||||
# If ray.init has not been called yet, then cache the function and export it
|
||||
@@ -540,6 +542,16 @@ made by one task do not affect other tasks.
|
||||
class RayConnectionError(Exception):
|
||||
pass
|
||||
|
||||
def check_main_thread():
|
||||
"""Check that we are currently on the main thread.
|
||||
|
||||
Raises:
|
||||
Exception: An exception is raised if this is called on a thread other than
|
||||
the main thread.
|
||||
"""
|
||||
if threading.current_thread().getName() != "MainThread":
|
||||
raise Exception("The Ray methods are not thread safe and must be called from the main thread. This method was called from thread {}.".format(threading.current_thread().getName()))
|
||||
|
||||
def check_connected(worker=global_worker):
|
||||
"""Check if the worker is connected.
|
||||
|
||||
@@ -565,6 +577,7 @@ def print_failed_task(task_status):
|
||||
|
||||
def error_info(worker=global_worker):
|
||||
"""Return information about failed tasks."""
|
||||
check_main_thread()
|
||||
check_connected(worker)
|
||||
result = {"TaskError": [],
|
||||
"RemoteFunctionImportError": [],
|
||||
@@ -627,6 +640,7 @@ def init(start_ray_local=False, num_workers=None, num_local_schedulers=1, driver
|
||||
Exception: An exception is raised if an inappropriate combination of
|
||||
arguments is passed in.
|
||||
"""
|
||||
check_main_thread()
|
||||
if driver_mode == PYTHON_MODE:
|
||||
# If starting Ray in PYTHON_MODE, don't start any other processes.
|
||||
address_info = {}
|
||||
@@ -819,6 +833,7 @@ def connect(address_info, mode=WORKER_MODE, worker=global_worker):
|
||||
mode: The mode of the worker. One of SCRIPT_MODE, WORKER_MODE, PYTHON_MODE,
|
||||
and SILENT_MODE.
|
||||
"""
|
||||
check_main_thread()
|
||||
worker.worker_id = random_string()
|
||||
worker.connected = True
|
||||
worker.set_mode(mode)
|
||||
@@ -952,6 +967,7 @@ def get(objectid, worker=global_worker):
|
||||
Returns:
|
||||
A Python object or a list of Python objects.
|
||||
"""
|
||||
check_main_thread()
|
||||
check_connected(worker)
|
||||
if worker.mode == PYTHON_MODE:
|
||||
return objectid # In PYTHON_MODE, ray.get is the identity operation (the input will actually be a value not an objectid)
|
||||
@@ -977,6 +993,7 @@ def put(value, worker=global_worker):
|
||||
Returns:
|
||||
The object ID assigned to this value.
|
||||
"""
|
||||
check_main_thread()
|
||||
check_connected(worker)
|
||||
if worker.mode == PYTHON_MODE:
|
||||
return value # In PYTHON_MODE, ray.put is the identity operation
|
||||
@@ -1006,6 +1023,7 @@ def wait(object_ids, num_returns=1, timeout=None, worker=global_worker):
|
||||
Returns:
|
||||
A list of object IDs that are ready and a list of the remaining object IDs.
|
||||
"""
|
||||
check_main_thread()
|
||||
check_connected(worker)
|
||||
object_id_strs = [object_id.id() for object_id in object_ids]
|
||||
timeout = timeout if timeout is not None else 2 ** 30
|
||||
@@ -1095,6 +1113,7 @@ def main_loop(worker=global_worker):
|
||||
"message": traceback_str})
|
||||
worker.redis_client.rpush("ErrorKeys", error_key)
|
||||
|
||||
check_main_thread()
|
||||
while True:
|
||||
task = worker.photon_client.get_task()
|
||||
function_id = task.function_id()
|
||||
@@ -1145,6 +1164,7 @@ def _export_reusable_variable(name, reusable, worker=global_worker):
|
||||
reusable (Reusable): The reusable object containing code for initializing
|
||||
and reinitializing the variable.
|
||||
"""
|
||||
check_main_thread()
|
||||
if _mode(worker) not in [SCRIPT_MODE, SILENT_MODE]:
|
||||
raise Exception("_export_reusable_variable can only be called on a driver.")
|
||||
reusable_variable_id = name
|
||||
@@ -1156,6 +1176,7 @@ def _export_reusable_variable(name, reusable, worker=global_worker):
|
||||
worker.driver_export_counter += 1
|
||||
|
||||
def export_remote_function(function_id, func_name, func, num_return_vals, worker=global_worker):
|
||||
check_main_thread()
|
||||
if _mode(worker) not in [SCRIPT_MODE, SILENT_MODE]:
|
||||
raise Exception("export_remote_function can only be called on a driver.")
|
||||
key = "RemoteFunction:{}".format(function_id.id())
|
||||
@@ -1188,6 +1209,7 @@ def remote(*args, **kwargs):
|
||||
|
||||
def func_call(*args, **kwargs):
|
||||
"""This gets run immediately when a worker calls a remote function."""
|
||||
check_main_thread()
|
||||
check_connected()
|
||||
args = list(args)
|
||||
args.extend([kwargs[keyword] if kwargs.has_key(keyword) else default for keyword, default in keyword_defaults[len(args):]]) # fill in the remaining arguments
|
||||
|
||||
Reference in New Issue
Block a user