implement key value store for sharing reusable variables

This commit is contained in:
Robert Nishihara
2016-07-21 00:16:19 -07:00
parent baa4b7cae3
commit 03f1830cd0
10 changed files with 347 additions and 8 deletions
+1
View File
@@ -18,5 +18,6 @@ import config
import libraylib as lib
import serialization
from worker import scheduler_info, visualize_computation_graph, task_info, register_module, connect, disconnect, get, put, remote, kill_workers, restart_workers_local
from worker import Reusable, reusables
from libraylib import ObjRef
import internal
+6
View File
@@ -1,5 +1,11 @@
import cloudpickle
def serialize(function):
return cloudpickle.dumps(function)
def deserialize(serialized_function):
return cloudpickle.loads(serialized_function)
def dumps(func, arg_types, return_types):
return cloudpickle.dumps((func, arg_types, return_types))
+165 -6
View File
@@ -102,6 +102,135 @@ class RayDealloc(object):
"""Deallocate the relevant segment to avoid a memory leak."""
ray.lib.unmap_object(self.handle, self.segmentid)
class Reusable(object):
"""An Python object that can be shared between tasks.
Attributes:
initializer (Callable[[], object]): A function used to create and initialize
the reusable variable.
reinitializer (Optional[Callable[[object], object]]): An optional function
used to reinitialize the reusable variable after it has been used. This
argument can be used as an optimization if there is a fast way to
reinitialize the state of the variable other than rerunning the
initializer.
"""
def __init__(self, initializer, reinitializer=None):
"""Initialize a Reusable object."""
if not isinstance(initializer, typing.Callable):
raise Exception("When creating a RayReusable, initializer must be a function.")
self.initializer = initializer
if reinitializer is None:
# If no reinitializer is passed in, use a wrapped version of the initializer.
reinitializer = lambda value: initializer()
if not isinstance(reinitializer, typing.Callable):
raise Exception("When creating a RayReusable, reinitializer must be a function.")
self.reinitializer = reinitializer
class RayReusables(object):
"""An object used to store Python variables that are shared between tasks.
Each worker process will have a single RayReusables object. This class serves
two purposes. First, some objects are not serializable, and so the code that
creates those objects must be run on the worker that uses them. This class is
responsible for running the code that creates those objects. Second, some of
these objects are expensive to create, and so they should be shared between
tasks. However, if a task mutates a variable that is shared between tasks,
then the behavior of the overall program may be nondeterministic (it could
depend on scheduling decisions). To fix this, if a task uses a one of these
shared objects, then that shared object will be reinitialized after the task
finishes. Since the initialization may be expensive, the user can pass in
custom reinitialization code that resets the state of the shared variable to
the way it was after initialization. If the reinitialization code does not do
this, then the behavior of the overall program is undefined.
Attributes:
_names (List[str]): A list of the names of all the reusable variables.
_initializers (dict[str, [Callable[[], object]])]: A dictionary mapping the
names of the reusable variables to the code for initializing them.
_reinitializers (Dict[str, Callable[[object], object]]): A dictionary
mapping the names of the reusable variables to the code for reinitializing
them. For reusable variables for which reinitializer code is not provided,
the reinitializer here essentially wraps the initializer.
_used (List[str]): A list of the names of all the reusable variables that
have been accessed within the scope of the current task. This is reset to
the empty list after each task.
"""
def __init__(self):
"""Initialize a RayReusables object."""
self._names = set()
self._reusables = {}
self._used = set()
self._slots = ("_names", "_reusables", "_used", "_slots", "_reinitialize", "__getattribute__", "__setattr__", "__delattr__")
# CHECKPOINT: Any attributes assigned before _here_ will be protected from rewrite or deletion
def _reinitialize(self):
"""Reinitialize the reusable variables that the current task used."""
for name in self._used:
current_value = getattr(self, name)
new_value = self._reusables[name].reinitializer(current_value)
object.__setattr__(self, name, new_value)
self._used.clear() # Reset the _used list.
def __getattribute__(self, name):
"""Get an attribute. This handles reusable variables as a special case.
When __getattribute__ is called with the name of a reusable variable, that
name is added to the list of variables that were used in the current task.
Args:
name (str): The name of the attribute to get.
"""
if name == "_slots":
return object.__getattribute__(self, name)
if name in self._slots:
return object.__getattribute__(self, name)
if name in self._names and name not in self._used:
self._used.add(name)
return object.__getattribute__(self, name)
def __setattr__(self, name, value):
"""Set an attribute. This handles reusable variables as a special case.
This is used to create reusable variables. When it is called, it runs the
function for initializing the variable to create the variable. If this is
called on the driver, then the functions for initializing and reinitializing
the variable are shipped to the workers.
Args:
name (str): The name of the attribute to set. This is either a whitelisted
name or it is treated as the name of a reusable variable.
value: If name is a whitelisted name, then value can be any value. If name
is the name of a reusable variable, then this is either the serialized
initializer code or it is a tuple of the serialized initializer and
reinitializer code.
"""
try:
slots = self._slots
except AttributeError:
slots = ()
if slots == ():
return object.__setattr__(self, name, value)
if name in slots:
raise AttributeError("Illegal assignment to {} object attribute {}".format(self.__class__.__name__, name))
reusable = value
if not issubclass(type(reusable), Reusable):
raise Exception("To set a reusable variable, you must pass in a Reusable object")
self._names.add(name)
self._reusables[name] = reusable
if _mode() in [ray.SHELL_MODE, ray.SCRIPT_MODE]:
_export_reusable_variable(name, reusable)
object.__setattr__(self, name, reusable.initializer())
def __delattr__(self, name):
"""We do not allow attributes of RayReusables to be deleted.
Args:
name (str): The name of the attribute to delete.
"""
raise Exception("Attempted deletion of attribute {}. Attributes of a RayReusable object may not be deleted.".format(name))
class Worker(object):
"""A class used to define the control flow of a worker process.
@@ -114,7 +243,6 @@ class Worker(object):
function to the remote function itself. This is the set of remote
functions that can be executed by this worker.
handle (worker capsule): A Python object wrapping a C++ Worker object.
"""
def __init__(self):
@@ -250,6 +378,16 @@ We use a global Worker object to ensure that there is a single worker object
per worker process.
"""
reusables = RayReusables()
"""RayReusables: The reusable variables that are shared between tasks.
Each worker process has its own RayReusables object, and these objects should be
the same in all workers. This is used for storing variables that are not
serializable but must be used by remote tasks. In addition, it is used to
reinitialize these variables after they are used so that changes to their state
made by one task do not affect other tasks.
"""
def print_failed_task(task_status):
"""Print information about failed tasks.
@@ -511,13 +649,16 @@ def main_loop(worker=global_worker):
else:
store_outputs_in_objstore(return_objrefs, outputs, worker) # store output in local object store
ray.lib.notify_task_completed(worker.handle, True, "") # notify the scheduler that the task completed successfully
finally:
# Reinitialize the values of reusable variables that were used in the task
# above so that changes made to their state do not affect other tasks.
ray.reusables._reinitialize()
while True:
(task, function) = ray.lib.wait_for_next_message(worker.handle)
(task, function, reusable_variable) = ray.lib.wait_for_next_message(worker.handle)
try:
# Currently the schedule does not ask the worker to execute a task and
# import a function at the same time.
assert task is None or function is None
if task is None and function is None:
# Only one of task, function, and reusable_variable should be not None.
assert sum([obj is not None for obj in [task, function, reusable_variable]]) <= 1
if task is None and function is None and reusable_variable is None:
# We use this as a mechanism to allow the scheduler to kill workers. When
# the scheduler wants to kill a worker, it gives the worker a null task,
# causing the worker program to exit the main loop here.
@@ -526,12 +667,18 @@ def main_loop(worker=global_worker):
(function, arg_types, return_types) = pickling.loads(function)
if function.__module__ is None: function.__module__ = "__main__"
worker.register_function(remote(arg_types, return_types, worker)(function))
if reusable_variable is not None:
name, initializer_str, reinitializer_str = reusable_variable
initializer = pickling.deserialize(initializer_str)
reinitializer = pickling.deserialize(reinitializer_str)
reusables.__setattr__(name, Reusable(initializer, reinitializer))
if task is not None:
process_task(task)
finally:
# Allow releasing the variables BEFORE we wait for the next message or exit the block
del task
del function
del reusable_variable
def _submit_task(func_name, args, worker=global_worker):
"""This is a wrapper around worker.submit_task.
@@ -553,6 +700,18 @@ def _mode(worker=global_worker):
"""
return worker.mode
def _export_reusable_variable(name, reusable, worker=global_worker):
"""Export a reusable variable to the workers. This is only called by a driver.
Args:
name (str): The name of the variable to export.
reusable (Reusable): The reusable object containing code for initializing
and reinitializing the variable.
"""
if _mode(worker) not in [ray.SHELL_MODE, ray.SCRIPT_MODE]:
raise Exception("_export_reusable_variable can only be called on a driver.")
ray.lib.export_reusable_variable(worker.handle, name, pickling.serialize(reusable.initializer), pickling.serialize(reusable.reinitializer))
def remote(arg_types, return_types, worker=global_worker):
"""This decorator is used to create remote functions.