mirror of
https://github.com/wassname/ray.git
synced 2026-07-03 15:10:19 +08:00
export remote functions and reusable variables that were defined before connect was called (#292)
This commit is contained in:
committed by
Philipp Moritz
parent
8e9f98c5ff
commit
3bae6f136b
@@ -5,6 +5,7 @@ SCRIPT_MODE = 0
|
||||
WORKER_MODE = 1
|
||||
SHELL_MODE = 2
|
||||
PYTHON_MODE = 3
|
||||
SILENT_MODE = 4 # This is only used during testing.
|
||||
|
||||
import ctypes
|
||||
# Windows only
|
||||
|
||||
@@ -83,14 +83,14 @@ def numpy_to_dist(a):
|
||||
result.objrefs[index] = ray.put(a[[slice(l, u) for (l, u) in zip(lower, upper)]])
|
||||
return result
|
||||
|
||||
@ray.remote([List[int], str], [DistArray])
|
||||
@ray.remote([List, str], [DistArray])
|
||||
def zeros(shape, dtype_name="float"):
|
||||
result = DistArray(shape)
|
||||
for index in np.ndindex(*result.num_blocks):
|
||||
result.objrefs[index] = ra.zeros(DistArray.compute_block_shape(index, shape), dtype_name=dtype_name)
|
||||
return result
|
||||
|
||||
@ray.remote([List[int], str], [DistArray])
|
||||
@ray.remote([List, str], [DistArray])
|
||||
def ones(shape, dtype_name="float"):
|
||||
result = DistArray(shape)
|
||||
for index in np.ndindex(*result.num_blocks):
|
||||
@@ -171,7 +171,7 @@ def dot(a, b):
|
||||
result.objrefs[i, j] = blockwise_dot(*args)
|
||||
return result
|
||||
|
||||
@ray.remote([DistArray, List[int]], [DistArray])
|
||||
@ray.remote([DistArray, List], [DistArray])
|
||||
def subblocks(a, *ranges):
|
||||
"""
|
||||
This function produces a distributed array from a subset of the blocks in `a`. The result and `a` will have the same number of dimensions. For example,
|
||||
|
||||
@@ -6,7 +6,7 @@ import ray
|
||||
|
||||
from core import *
|
||||
|
||||
@ray.remote([List[int]], [DistArray])
|
||||
@ray.remote([List], [DistArray])
|
||||
def normal(shape):
|
||||
num_blocks = DistArray.compute_num_blocks(shape)
|
||||
objrefs = np.empty(num_blocks, dtype=object)
|
||||
|
||||
@@ -4,7 +4,7 @@ import ray
|
||||
|
||||
__all__ = ["zeros", "zeros_like", "ones", "eye", "dot", "vstack", "hstack", "subarray", "copy", "tril", "triu", "diag", "transpose", "add", "subtract", "sum", "shape", "sum_list"]
|
||||
|
||||
@ray.remote([List[int], str, str], [np.ndarray])
|
||||
@ray.remote([List, str, str], [np.ndarray])
|
||||
def zeros(shape, dtype_name="float", order="C"):
|
||||
return np.zeros(shape, dtype=np.dtype(dtype_name), order=order)
|
||||
|
||||
@@ -13,7 +13,7 @@ def zeros_like(a, dtype_name="None", order="K", subok=True):
|
||||
dtype_val = None if dtype_name == "None" else np.dtype(dtype_name)
|
||||
return np.zeros_like(a, dtype=dtype_val, order=order, subok=subok)
|
||||
|
||||
@ray.remote([List[int], str, str], [np.ndarray])
|
||||
@ray.remote([List, str, str], [np.ndarray])
|
||||
def ones(shape, dtype_name="float", order="C"):
|
||||
return np.ones(shape, dtype=np.dtype(dtype_name), order=order)
|
||||
|
||||
@@ -35,7 +35,7 @@ def hstack(*xs):
|
||||
return np.hstack(xs)
|
||||
|
||||
# TODO(rkn): instead of this, consider implementing slicing
|
||||
@ray.remote([np.ndarray, List[int], List[int]], [np.ndarray])
|
||||
@ray.remote([np.ndarray, List, List], [np.ndarray])
|
||||
def subarray(a, lower_indices, upper_indices): # TODO(rkn): be consistent about using "index" versus "indices"
|
||||
return a[[slice(l, u) for (l, u) in zip(lower_indices, upper_indices)]]
|
||||
|
||||
@@ -55,7 +55,7 @@ def triu(m, k=0):
|
||||
def diag(v, k=0):
|
||||
return np.diag(v, k=k)
|
||||
|
||||
@ray.remote([np.ndarray, List[int]], [np.ndarray])
|
||||
@ray.remote([np.ndarray, List], [np.ndarray])
|
||||
def transpose(a, axes=[]):
|
||||
axes = None if axes == [] else axes
|
||||
return np.transpose(a, axes=axes)
|
||||
|
||||
@@ -2,6 +2,6 @@ from typing import List
|
||||
import numpy as np
|
||||
import ray
|
||||
|
||||
@ray.remote([List[int]], [np.ndarray])
|
||||
@ray.remote([List], [np.ndarray])
|
||||
def normal(shape):
|
||||
return np.random.normal(size=shape)
|
||||
|
||||
@@ -6,10 +6,10 @@ from ctypes import c_void_p
|
||||
from cloudpickle import pickle, cloudpickle, CloudPickler, load, loads
|
||||
|
||||
try:
|
||||
from ctypes import pythonapi
|
||||
pythonapi.PyCell_Set # Make sure this exists
|
||||
from ctypes import pythonapi
|
||||
pythonapi.PyCell_Set # Make sure this exists
|
||||
except:
|
||||
pythonapi = None
|
||||
pythonapi = None
|
||||
|
||||
def dump(obj, file, protocol=2):
|
||||
return BetterPickler(file, protocol).dump(obj)
|
||||
@@ -69,4 +69,4 @@ class BetterPickler(CloudPickler):
|
||||
self.write(pickle.REDUCE)
|
||||
dispatch = CloudPickler.dispatch.copy()
|
||||
dispatch[(lambda _: lambda: _)(0).__closure__[0].__class__] = save_cell
|
||||
dispatch[typing.GenericMeta] = save_type
|
||||
# dispatch[typing.GenericMeta] = save_type
|
||||
|
||||
@@ -56,8 +56,10 @@ def cleanup():
|
||||
global drivers
|
||||
for driver in drivers:
|
||||
ray.disconnect(driver)
|
||||
driver.set_mode(None)
|
||||
if len(drivers) == 0:
|
||||
ray.disconnect()
|
||||
ray.worker.global_worker.set_mode(None)
|
||||
drivers = []
|
||||
|
||||
global all_processes
|
||||
@@ -191,6 +193,8 @@ def start_ray_local(num_workers=0, worker_path=None, driver_mode=ray.SCRIPT_MODE
|
||||
equivalent to serial Python code. It should be ray.WORKER_MODE to suppress
|
||||
the printing of error messages.
|
||||
"""
|
||||
if worker_path is None:
|
||||
worker_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../scripts/default_worker.py")
|
||||
start_services_local(num_objstores=1, num_workers_per_objstore=num_workers, worker_path=worker_path, driver_mode=driver_mode)
|
||||
|
||||
# This is a helper method which is only used in the tests and should not be
|
||||
|
||||
+54
-23
@@ -147,12 +147,13 @@ class RayReusables(object):
|
||||
|
||||
Attributes:
|
||||
_names (List[str]): A list of the names of all the reusable variables.
|
||||
_initializers (dict[str, [Callable[[], object]])]: A dictionary mapping the
|
||||
names of the reusable variables to the code for initializing them.
|
||||
_reinitializers (Dict[str, Callable[[object], object]]): A dictionary
|
||||
mapping the names of the reusable variables to the code for reinitializing
|
||||
them. For reusable variables for which reinitializer code is not provided,
|
||||
the reinitializer here essentially wraps the initializer.
|
||||
_reusables (Dict[str, Reusable]): A dictionary mapping the name of the
|
||||
reusable variables to the corresponding Reusable object.
|
||||
_cached_reusables (List[Tuple[str, Reusable]]): A list of pairs. The first
|
||||
element of each pair is the name of a reusable variable, and the second
|
||||
element is the Reusable object. This list is used to store reusable
|
||||
variables that are defined before the driver is connected. Once the driver
|
||||
is connected, these variables will be exported.
|
||||
_used (List[str]): A list of the names of all the reusable variables that
|
||||
have been accessed within the scope of the current task. This is reset to
|
||||
the empty list after each task.
|
||||
@@ -162,9 +163,10 @@ class RayReusables(object):
|
||||
"""Initialize a RayReusables object."""
|
||||
self._names = set()
|
||||
self._reusables = {}
|
||||
self._cached_reusables = []
|
||||
self._used = set()
|
||||
self._slots = ("_names", "_reusables", "_used", "_slots", "_reinitialize", "__getattribute__", "__setattr__", "__delattr__")
|
||||
# CHECKPOINT: Any attributes assigned before _here_ will be protected from rewrite or deletion
|
||||
self._slots = ("_names", "_reusables", "_cached_reusables", "_used", "_slots", "_reinitialize", "__getattribute__", "__setattr__", "__delattr__")
|
||||
# CHECKPOINT: Attributes must not be added after _slots. The above attributes are protected from deletion.
|
||||
|
||||
def _reinitialize(self):
|
||||
"""Reinitialize the reusable variables that the current task used."""
|
||||
@@ -214,14 +216,16 @@ class RayReusables(object):
|
||||
if slots == ():
|
||||
return object.__setattr__(self, name, value)
|
||||
if name in slots:
|
||||
raise AttributeError("Illegal assignment to {} object attribute {}".format(self.__class__.__name__, name))
|
||||
return object.__setattr__(self, name, value)
|
||||
reusable = value
|
||||
if not issubclass(type(reusable), Reusable):
|
||||
raise Exception("To set a reusable variable, you must pass in a Reusable object")
|
||||
self._names.add(name)
|
||||
self._reusables[name] = reusable
|
||||
if _mode() in [ray.SHELL_MODE, ray.SCRIPT_MODE]:
|
||||
if _mode() in [ray.SHELL_MODE, ray.SCRIPT_MODE, ray.SILENT_MODE]:
|
||||
_export_reusable_variable(name, reusable)
|
||||
elif _mode() is None:
|
||||
self._cached_reusables.append((name, reusable))
|
||||
object.__setattr__(self, name, reusable.initializer())
|
||||
|
||||
def __delattr__(self, name):
|
||||
@@ -244,6 +248,13 @@ class Worker(object):
|
||||
function to the remote function itself. This is the set of remote
|
||||
functions that can be executed by this worker.
|
||||
handle (worker capsule): A Python object wrapping a C++ Worker object.
|
||||
mode: The mode of the worker. One of ray.SCRIPT_MODE, ray.SHELL_MODE,
|
||||
ray.PYTHON_MODE, ray.SILENT_MODE, and ray.WORKER_MODE.
|
||||
cached_remote_functions (List[str]): A list of serialized remote functions
|
||||
that were defined before the worker called connect. When the worker
|
||||
eventually does call connect, if it is a driver, it will export these
|
||||
functions to the scheduler. If cached_remote_functions is None, that means
|
||||
that connect has been called.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
@@ -251,6 +262,7 @@ class Worker(object):
|
||||
self.functions = {}
|
||||
self.handle = None
|
||||
self.mode = None
|
||||
self.cached_remote_functions = []
|
||||
|
||||
def set_mode(self, mode):
|
||||
"""Set the mode of the worker.
|
||||
@@ -270,9 +282,13 @@ class Worker(object):
|
||||
debugging purposes. It will not send remote function calls to the scheduler
|
||||
and will instead execute them in a blocking fashion.
|
||||
|
||||
The mode ray.SILENT_MODE should be used only during testing. It does not
|
||||
print any information about errors because some of the tests intentionally
|
||||
fail.
|
||||
|
||||
args:
|
||||
mode: One of ray.SCRIPT_MODE, ray.WORKER_MODE, ray.SHELL_MODE, and
|
||||
ray.PYTHON_MODE.
|
||||
mode: One of ray.SCRIPT_MODE, ray.WORKER_MODE, ray.SHELL_MODE,
|
||||
ray.PYTHON_MODE, and ray.SILENT_MODE.
|
||||
"""
|
||||
self.mode = mode
|
||||
colorama.init()
|
||||
@@ -381,6 +397,7 @@ class Worker(object):
|
||||
Args:
|
||||
function (Callable): The remote function that this worker can execute.
|
||||
"""
|
||||
logging.info("Registering function {}.".format(function.func_name))
|
||||
ray.lib.register_function(self.handle, function.func_name, len(function.return_types))
|
||||
self.functions[function.func_name] = function
|
||||
|
||||
@@ -399,7 +416,7 @@ class Worker(object):
|
||||
"""
|
||||
task_capsule = serialization.serialize_task(self.handle, func_name, args)
|
||||
objrefs = ray.lib.submit_task(self.handle, task_capsule)
|
||||
if self.mode == ray.SHELL_MODE or self.mode == ray.SCRIPT_MODE:
|
||||
if self.mode in [ray.SHELL_MODE, ray.SCRIPT_MODE]:
|
||||
print_task_info(ray.lib.task_info(self.handle), self.mode)
|
||||
return objrefs
|
||||
|
||||
@@ -530,7 +547,7 @@ def connect(scheduler_address, objstore_address, worker_address, is_driver=False
|
||||
be chosen arbitrarily.
|
||||
is_driver (bool): True if this worker is a driver and false otherwise.
|
||||
mode: The mode of the worker. One of ray.SCRIPT_MODE, ray.WORKER_MODE,
|
||||
ray.SHELL_MODE, and ray.PYTHON_MODE.
|
||||
ray.SHELL_MODE, ray.PYTHON_MODE, and ray.SILENT_MODE.
|
||||
"""
|
||||
if hasattr(worker, "handle"):
|
||||
del worker.handle
|
||||
@@ -542,11 +559,23 @@ def connect(scheduler_address, objstore_address, worker_address, is_driver=False
|
||||
FORMAT = "%(asctime)-15s %(message)s"
|
||||
logging.basicConfig(level=logging.DEBUG, format=FORMAT, filename=ray.config.get_log_file_path("-".join(["worker", worker_address]) + ".log"))
|
||||
ray.lib.set_log_config(ray.config.get_log_file_path("-".join(["worker", worker_address, "c++"]) + ".log"))
|
||||
if mode in [ray.SHELL_MODE, ray.SCRIPT_MODE, ray.SILENT_MODE]:
|
||||
for function_to_export in worker.cached_remote_functions:
|
||||
ray.lib.export_function(worker.handle, function_to_export)
|
||||
for name, reusable_variable in reusables._cached_reusables:
|
||||
_export_reusable_variable(name, reusable_variable)
|
||||
worker.cached_remote_functions = None
|
||||
reusables._cached_reusables = None
|
||||
|
||||
def disconnect(worker=global_worker):
|
||||
"""Disconnect this worker from the scheduler and object store."""
|
||||
if worker.handle is not None:
|
||||
ray.lib.disconnect(worker.handle)
|
||||
# Reset the list of cached remote functions so that if more remote functions
|
||||
# are defined and then connect is called again, the remote functions will be
|
||||
# exported. This is mostly relevant for the tests.
|
||||
worker.cached_remote_functions = []
|
||||
reusables._cached_reusables = []
|
||||
|
||||
def get(objref, worker=global_worker):
|
||||
"""Get a remote object from an object store.
|
||||
@@ -565,7 +594,7 @@ def get(objref, worker=global_worker):
|
||||
if worker.mode == ray.PYTHON_MODE:
|
||||
return objref # In ray.PYTHON_MODE, ray.get is the identity operation (the input will actually be a value not an objref)
|
||||
ray.lib.request_object(worker.handle, objref)
|
||||
if worker.mode == ray.SHELL_MODE or worker.mode == ray.SCRIPT_MODE:
|
||||
if worker.mode in [ray.SHELL_MODE, ray.SCRIPT_MODE]:
|
||||
print_task_info(ray.lib.task_info(worker.handle), worker.mode)
|
||||
value = worker.get_object(objref)
|
||||
if isinstance(value, RayFailedObject):
|
||||
@@ -585,7 +614,7 @@ def put(value, worker=global_worker):
|
||||
return value # In ray.PYTHON_MODE, ray.put is the identity operation
|
||||
objref = ray.lib.get_objref(worker.handle)
|
||||
worker.put_object(objref, value)
|
||||
if worker.mode == ray.SHELL_MODE or worker.mode == ray.SCRIPT_MODE:
|
||||
if worker.mode in [ray.SHELL_MODE, ray.SCRIPT_MODE]:
|
||||
print_task_info(ray.lib.task_info(worker.handle), worker.mode)
|
||||
return objref
|
||||
|
||||
@@ -692,8 +721,9 @@ def main_loop(worker=global_worker):
|
||||
# We use this as a mechanism to allow the scheduler to kill workers.
|
||||
break
|
||||
elif command == "function":
|
||||
(function, arg_types, return_types) = pickling.loads(command_args)
|
||||
if function.__module__ is None: function.__module__ = "__main__"
|
||||
(function, arg_types, return_types, module) = pickling.loads(command_args)
|
||||
# TODO(rkn): Why is the below line necessary?
|
||||
function.__module__ = module
|
||||
worker.register_function(remote(arg_types, return_types, worker)(function))
|
||||
elif command == "reusable_variable":
|
||||
name, initializer_str, reinitializer_str = command_args
|
||||
@@ -736,7 +766,7 @@ def _export_reusable_variable(name, reusable, worker=global_worker):
|
||||
reusable (Reusable): The reusable object containing code for initializing
|
||||
and reinitializing the variable.
|
||||
"""
|
||||
if _mode(worker) not in [ray.SHELL_MODE, ray.SCRIPT_MODE]:
|
||||
if _mode(worker) not in [ray.SHELL_MODE, ray.SCRIPT_MODE, ray.SILENT_MODE]:
|
||||
raise Exception("_export_reusable_variable can only be called on a driver.")
|
||||
ray.lib.export_reusable_variable(worker.handle, name, pickling.dumps(reusable.initializer), pickling.dumps(reusable.reinitializer))
|
||||
|
||||
@@ -787,20 +817,21 @@ def remote(arg_types, return_types, worker=global_worker):
|
||||
check_signature_supported(has_kwargs_param, has_vararg_param, keyword_defaults, func_name)
|
||||
|
||||
# Everything ready - export the function
|
||||
to_export = None
|
||||
if worker.mode in [ray.SHELL_MODE, ray.SCRIPT_MODE]:
|
||||
if worker.mode in [None, ray.SHELL_MODE, ray.SCRIPT_MODE, ray.SILENT_MODE]:
|
||||
func_name_global_valid = func.__name__ in func.__globals__
|
||||
func_name_global_value = func.__globals__.get(func.__name__)
|
||||
# Set the function globally to make it refer to itself
|
||||
func.__globals__[func.__name__] = func_call # Allow the function to reference itself as a global variable
|
||||
try:
|
||||
to_export = pickling.dumps((func, arg_types, return_types))
|
||||
to_export = pickling.dumps((func, arg_types, return_types, func.__module__))
|
||||
finally:
|
||||
# Undo our changes
|
||||
if func_name_global_valid: func.__globals__[func.__name__] = func_name_global_value
|
||||
else: del func.__globals__[func.__name__]
|
||||
if to_export:
|
||||
if worker.mode in [ray.SHELL_MODE, ray.SCRIPT_MODE, ray.SILENT_MODE]:
|
||||
ray.lib.export_function(worker.handle, to_export)
|
||||
elif worker.mode is None:
|
||||
worker.cached_remote_functions.append(to_export)
|
||||
return func_call
|
||||
return remote_decorator
|
||||
|
||||
|
||||
Reference in New Issue
Block a user