export remote functions and reusable variables that were defined before connect was called (#292)

This commit is contained in:
Robert Nishihara
2016-07-26 11:40:09 -07:00
committed by Philipp Moritz
parent 8e9f98c5ff
commit 3bae6f136b
15 changed files with 167 additions and 141 deletions
+1
View File
@@ -5,6 +5,7 @@ SCRIPT_MODE = 0
WORKER_MODE = 1
SHELL_MODE = 2
PYTHON_MODE = 3
SILENT_MODE = 4 # This is only used during testing.
import ctypes
# Windows only
+3 -3
View File
@@ -83,14 +83,14 @@ def numpy_to_dist(a):
result.objrefs[index] = ray.put(a[[slice(l, u) for (l, u) in zip(lower, upper)]])
return result
@ray.remote([List[int], str], [DistArray])
@ray.remote([List, str], [DistArray])
def zeros(shape, dtype_name="float"):
result = DistArray(shape)
for index in np.ndindex(*result.num_blocks):
result.objrefs[index] = ra.zeros(DistArray.compute_block_shape(index, shape), dtype_name=dtype_name)
return result
@ray.remote([List[int], str], [DistArray])
@ray.remote([List, str], [DistArray])
def ones(shape, dtype_name="float"):
result = DistArray(shape)
for index in np.ndindex(*result.num_blocks):
@@ -171,7 +171,7 @@ def dot(a, b):
result.objrefs[i, j] = blockwise_dot(*args)
return result
@ray.remote([DistArray, List[int]], [DistArray])
@ray.remote([DistArray, List], [DistArray])
def subblocks(a, *ranges):
"""
This function produces a distributed array from a subset of the blocks in `a`. The result and `a` will have the same number of dimensions. For example,
+1 -1
View File
@@ -6,7 +6,7 @@ import ray
from core import *
@ray.remote([List[int]], [DistArray])
@ray.remote([List], [DistArray])
def normal(shape):
num_blocks = DistArray.compute_num_blocks(shape)
objrefs = np.empty(num_blocks, dtype=object)
+4 -4
View File
@@ -4,7 +4,7 @@ import ray
__all__ = ["zeros", "zeros_like", "ones", "eye", "dot", "vstack", "hstack", "subarray", "copy", "tril", "triu", "diag", "transpose", "add", "subtract", "sum", "shape", "sum_list"]
@ray.remote([List[int], str, str], [np.ndarray])
@ray.remote([List, str, str], [np.ndarray])
def zeros(shape, dtype_name="float", order="C"):
return np.zeros(shape, dtype=np.dtype(dtype_name), order=order)
@@ -13,7 +13,7 @@ def zeros_like(a, dtype_name="None", order="K", subok=True):
dtype_val = None if dtype_name == "None" else np.dtype(dtype_name)
return np.zeros_like(a, dtype=dtype_val, order=order, subok=subok)
@ray.remote([List[int], str, str], [np.ndarray])
@ray.remote([List, str, str], [np.ndarray])
def ones(shape, dtype_name="float", order="C"):
return np.ones(shape, dtype=np.dtype(dtype_name), order=order)
@@ -35,7 +35,7 @@ def hstack(*xs):
return np.hstack(xs)
# TODO(rkn): instead of this, consider implementing slicing
@ray.remote([np.ndarray, List[int], List[int]], [np.ndarray])
@ray.remote([np.ndarray, List, List], [np.ndarray])
def subarray(a, lower_indices, upper_indices): # TODO(rkn): be consistent about using "index" versus "indices"
return a[[slice(l, u) for (l, u) in zip(lower_indices, upper_indices)]]
@@ -55,7 +55,7 @@ def triu(m, k=0):
def diag(v, k=0):
return np.diag(v, k=k)
@ray.remote([np.ndarray, List[int]], [np.ndarray])
@ray.remote([np.ndarray, List], [np.ndarray])
def transpose(a, axes=[]):
axes = None if axes == [] else axes
return np.transpose(a, axes=axes)
+1 -1
View File
@@ -2,6 +2,6 @@ from typing import List
import numpy as np
import ray
@ray.remote([List[int]], [np.ndarray])
@ray.remote([List], [np.ndarray])
def normal(shape):
return np.random.normal(size=shape)
+4 -4
View File
@@ -6,10 +6,10 @@ from ctypes import c_void_p
from cloudpickle import pickle, cloudpickle, CloudPickler, load, loads
try:
from ctypes import pythonapi
pythonapi.PyCell_Set # Make sure this exists
from ctypes import pythonapi
pythonapi.PyCell_Set # Make sure this exists
except:
pythonapi = None
pythonapi = None
def dump(obj, file, protocol=2):
return BetterPickler(file, protocol).dump(obj)
@@ -69,4 +69,4 @@ class BetterPickler(CloudPickler):
self.write(pickle.REDUCE)
dispatch = CloudPickler.dispatch.copy()
dispatch[(lambda _: lambda: _)(0).__closure__[0].__class__] = save_cell
dispatch[typing.GenericMeta] = save_type
# dispatch[typing.GenericMeta] = save_type
+4
View File
@@ -56,8 +56,10 @@ def cleanup():
global drivers
for driver in drivers:
ray.disconnect(driver)
driver.set_mode(None)
if len(drivers) == 0:
ray.disconnect()
ray.worker.global_worker.set_mode(None)
drivers = []
global all_processes
@@ -191,6 +193,8 @@ def start_ray_local(num_workers=0, worker_path=None, driver_mode=ray.SCRIPT_MODE
equivalent to serial Python code. It should be ray.WORKER_MODE to suppress
the printing of error messages.
"""
if worker_path is None:
worker_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../scripts/default_worker.py")
start_services_local(num_objstores=1, num_workers_per_objstore=num_workers, worker_path=worker_path, driver_mode=driver_mode)
# This is a helper method which is only used in the tests and should not be
+54 -23
View File
@@ -147,12 +147,13 @@ class RayReusables(object):
Attributes:
_names (List[str]): A list of the names of all the reusable variables.
_initializers (dict[str, [Callable[[], object]])]: A dictionary mapping the
names of the reusable variables to the code for initializing them.
_reinitializers (Dict[str, Callable[[object], object]]): A dictionary
mapping the names of the reusable variables to the code for reinitializing
them. For reusable variables for which reinitializer code is not provided,
the reinitializer here essentially wraps the initializer.
_reusables (Dict[str, Reusable]): A dictionary mapping the name of the
reusable variables to the corresponding Reusable object.
_cached_reusables (List[Tuple[str, Reusable]]): A list of pairs. The first
element of each pair is the name of a reusable variable, and the second
element is the Reusable object. This list is used to store reusable
variables that are defined before the driver is connected. Once the driver
is connected, these variables will be exported.
_used (List[str]): A list of the names of all the reusable variables that
have been accessed within the scope of the current task. This is reset to
the empty list after each task.
@@ -162,9 +163,10 @@ class RayReusables(object):
"""Initialize a RayReusables object."""
self._names = set()
self._reusables = {}
self._cached_reusables = []
self._used = set()
self._slots = ("_names", "_reusables", "_used", "_slots", "_reinitialize", "__getattribute__", "__setattr__", "__delattr__")
# CHECKPOINT: Any attributes assigned before _here_ will be protected from rewrite or deletion
self._slots = ("_names", "_reusables", "_cached_reusables", "_used", "_slots", "_reinitialize", "__getattribute__", "__setattr__", "__delattr__")
# CHECKPOINT: Attributes must not be added after _slots. The above attributes are protected from deletion.
def _reinitialize(self):
"""Reinitialize the reusable variables that the current task used."""
@@ -214,14 +216,16 @@ class RayReusables(object):
if slots == ():
return object.__setattr__(self, name, value)
if name in slots:
raise AttributeError("Illegal assignment to {} object attribute {}".format(self.__class__.__name__, name))
return object.__setattr__(self, name, value)
reusable = value
if not issubclass(type(reusable), Reusable):
raise Exception("To set a reusable variable, you must pass in a Reusable object")
self._names.add(name)
self._reusables[name] = reusable
if _mode() in [ray.SHELL_MODE, ray.SCRIPT_MODE]:
if _mode() in [ray.SHELL_MODE, ray.SCRIPT_MODE, ray.SILENT_MODE]:
_export_reusable_variable(name, reusable)
elif _mode() is None:
self._cached_reusables.append((name, reusable))
object.__setattr__(self, name, reusable.initializer())
def __delattr__(self, name):
@@ -244,6 +248,13 @@ class Worker(object):
function to the remote function itself. This is the set of remote
functions that can be executed by this worker.
handle (worker capsule): A Python object wrapping a C++ Worker object.
mode: The mode of the worker. One of ray.SCRIPT_MODE, ray.SHELL_MODE,
ray.PYTHON_MODE, ray.SILENT_MODE, and ray.WORKER_MODE.
cached_remote_functions (List[str]): A list of serialized remote functions
that were defined before the worker called connect. When the worker
eventually does call connect, if it is a driver, it will export these
functions to the scheduler. If cached_remote_functions is None, that means
that connect has been called.
"""
def __init__(self):
@@ -251,6 +262,7 @@ class Worker(object):
self.functions = {}
self.handle = None
self.mode = None
self.cached_remote_functions = []
def set_mode(self, mode):
"""Set the mode of the worker.
@@ -270,9 +282,13 @@ class Worker(object):
debugging purposes. It will not send remote function calls to the scheduler
and will instead execute them in a blocking fashion.
The mode ray.SILENT_MODE should be used only during testing. It does not
print any information about errors because some of the tests intentionally
fail.
Args:
mode: One of ray.SCRIPT_MODE, ray.WORKER_MODE, ray.SHELL_MODE, and
ray.PYTHON_MODE.
mode: One of ray.SCRIPT_MODE, ray.WORKER_MODE, ray.SHELL_MODE,
ray.PYTHON_MODE, and ray.SILENT_MODE.
"""
self.mode = mode
colorama.init()
@@ -381,6 +397,7 @@ class Worker(object):
Args:
function (Callable): The remote function that this worker can execute.
"""
logging.info("Registering function {}.".format(function.func_name))
ray.lib.register_function(self.handle, function.func_name, len(function.return_types))
self.functions[function.func_name] = function
@@ -399,7 +416,7 @@ class Worker(object):
"""
task_capsule = serialization.serialize_task(self.handle, func_name, args)
objrefs = ray.lib.submit_task(self.handle, task_capsule)
if self.mode == ray.SHELL_MODE or self.mode == ray.SCRIPT_MODE:
if self.mode in [ray.SHELL_MODE, ray.SCRIPT_MODE]:
print_task_info(ray.lib.task_info(self.handle), self.mode)
return objrefs
@@ -530,7 +547,7 @@ def connect(scheduler_address, objstore_address, worker_address, is_driver=False
be chosen arbitrarily.
is_driver (bool): True if this worker is a driver and false otherwise.
mode: The mode of the worker. One of ray.SCRIPT_MODE, ray.WORKER_MODE,
ray.SHELL_MODE, and ray.PYTHON_MODE.
ray.SHELL_MODE, ray.PYTHON_MODE, and ray.SILENT_MODE.
"""
if hasattr(worker, "handle"):
del worker.handle
@@ -542,11 +559,23 @@ def connect(scheduler_address, objstore_address, worker_address, is_driver=False
FORMAT = "%(asctime)-15s %(message)s"
logging.basicConfig(level=logging.DEBUG, format=FORMAT, filename=ray.config.get_log_file_path("-".join(["worker", worker_address]) + ".log"))
ray.lib.set_log_config(ray.config.get_log_file_path("-".join(["worker", worker_address, "c++"]) + ".log"))
if mode in [ray.SHELL_MODE, ray.SCRIPT_MODE, ray.SILENT_MODE]:
for function_to_export in worker.cached_remote_functions:
ray.lib.export_function(worker.handle, function_to_export)
for name, reusable_variable in reusables._cached_reusables:
_export_reusable_variable(name, reusable_variable)
worker.cached_remote_functions = None
reusables._cached_reusables = None
def disconnect(worker=global_worker):
"""Disconnect this worker from the scheduler and object store."""
if worker.handle is not None:
ray.lib.disconnect(worker.handle)
# Reset the list of cached remote functions so that if more remote functions
# are defined and then connect is called again, the remote functions will be
# exported. This is mostly relevant for the tests.
worker.cached_remote_functions = []
reusables._cached_reusables = []
def get(objref, worker=global_worker):
"""Get a remote object from an object store.
@@ -565,7 +594,7 @@ def get(objref, worker=global_worker):
if worker.mode == ray.PYTHON_MODE:
return objref # In ray.PYTHON_MODE, ray.get is the identity operation (the input will actually be a value not an objref)
ray.lib.request_object(worker.handle, objref)
if worker.mode == ray.SHELL_MODE or worker.mode == ray.SCRIPT_MODE:
if worker.mode in [ray.SHELL_MODE, ray.SCRIPT_MODE]:
print_task_info(ray.lib.task_info(worker.handle), worker.mode)
value = worker.get_object(objref)
if isinstance(value, RayFailedObject):
@@ -585,7 +614,7 @@ def put(value, worker=global_worker):
return value # In ray.PYTHON_MODE, ray.put is the identity operation
objref = ray.lib.get_objref(worker.handle)
worker.put_object(objref, value)
if worker.mode == ray.SHELL_MODE or worker.mode == ray.SCRIPT_MODE:
if worker.mode in [ray.SHELL_MODE, ray.SCRIPT_MODE]:
print_task_info(ray.lib.task_info(worker.handle), worker.mode)
return objref
@@ -692,8 +721,9 @@ def main_loop(worker=global_worker):
# We use this as a mechanism to allow the scheduler to kill workers.
break
elif command == "function":
(function, arg_types, return_types) = pickling.loads(command_args)
if function.__module__ is None: function.__module__ = "__main__"
(function, arg_types, return_types, module) = pickling.loads(command_args)
# TODO(rkn): Why is the below line necessary?
function.__module__ = module
worker.register_function(remote(arg_types, return_types, worker)(function))
elif command == "reusable_variable":
name, initializer_str, reinitializer_str = command_args
@@ -736,7 +766,7 @@ def _export_reusable_variable(name, reusable, worker=global_worker):
reusable (Reusable): The reusable object containing code for initializing
and reinitializing the variable.
"""
if _mode(worker) not in [ray.SHELL_MODE, ray.SCRIPT_MODE]:
if _mode(worker) not in [ray.SHELL_MODE, ray.SCRIPT_MODE, ray.SILENT_MODE]:
raise Exception("_export_reusable_variable can only be called on a driver.")
ray.lib.export_reusable_variable(worker.handle, name, pickling.dumps(reusable.initializer), pickling.dumps(reusable.reinitializer))
@@ -787,20 +817,21 @@ def remote(arg_types, return_types, worker=global_worker):
check_signature_supported(has_kwargs_param, has_vararg_param, keyword_defaults, func_name)
# Everything ready - export the function
to_export = None
if worker.mode in [ray.SHELL_MODE, ray.SCRIPT_MODE]:
if worker.mode in [None, ray.SHELL_MODE, ray.SCRIPT_MODE, ray.SILENT_MODE]:
func_name_global_valid = func.__name__ in func.__globals__
func_name_global_value = func.__globals__.get(func.__name__)
# Set the function globally to make it refer to itself
func.__globals__[func.__name__] = func_call # Allow the function to reference itself as a global variable
try:
to_export = pickling.dumps((func, arg_types, return_types))
to_export = pickling.dumps((func, arg_types, return_types, func.__module__))
finally:
# Undo our changes
if func_name_global_valid: func.__globals__[func.__name__] = func_name_global_value
else: del func.__globals__[func.__name__]
if to_export:
if worker.mode in [ray.SHELL_MODE, ray.SCRIPT_MODE, ray.SILENT_MODE]:
ray.lib.export_function(worker.handle, to_export)
elif worker.mode is None:
worker.cached_remote_functions.append(to_export)
return func_call
return remote_decorator