mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 01:46:10 +08:00
implement key value store for sharing reusable variables
This commit is contained in:
@@ -18,5 +18,6 @@ import config
|
||||
import libraylib as lib
|
||||
import serialization
|
||||
from worker import scheduler_info, visualize_computation_graph, task_info, register_module, connect, disconnect, get, put, remote, kill_workers, restart_workers_local
|
||||
from worker import Reusable, reusables
|
||||
from libraylib import ObjRef
|
||||
import internal
|
||||
|
||||
@@ -1,5 +1,11 @@
|
||||
import cloudpickle
|
||||
|
||||
def serialize(function):
|
||||
return cloudpickle.dumps(function)
|
||||
|
||||
def deserialize(serialized_function):
|
||||
return cloudpickle.loads(serialized_function)
|
||||
|
||||
def dumps(func, arg_types, return_types):
|
||||
return cloudpickle.dumps((func, arg_types, return_types))
|
||||
|
||||
|
||||
+165
-6
@@ -102,6 +102,135 @@ class RayDealloc(object):
|
||||
"""Deallocate the relevant segment to avoid a memory leak."""
|
||||
ray.lib.unmap_object(self.handle, self.segmentid)
|
||||
|
||||
class Reusable(object):
|
||||
"""An Python object that can be shared between tasks.
|
||||
|
||||
Attributes:
|
||||
initializer (Callable[[], object]): A function used to create and initialize
|
||||
the reusable variable.
|
||||
reinitializer (Optional[Callable[[object], object]]): An optional function
|
||||
used to reinitialize the reusable variable after it has been used. This
|
||||
argument can be used as an optimization if there is a fast way to
|
||||
reinitialize the state of the variable other than rerunning the
|
||||
initializer.
|
||||
"""
|
||||
|
||||
def __init__(self, initializer, reinitializer=None):
|
||||
"""Initialize a Reusable object."""
|
||||
if not isinstance(initializer, typing.Callable):
|
||||
raise Exception("When creating a RayReusable, initializer must be a function.")
|
||||
self.initializer = initializer
|
||||
if reinitializer is None:
|
||||
# If no reinitializer is passed in, use a wrapped version of the initializer.
|
||||
reinitializer = lambda value: initializer()
|
||||
if not isinstance(reinitializer, typing.Callable):
|
||||
raise Exception("When creating a RayReusable, reinitializer must be a function.")
|
||||
self.reinitializer = reinitializer
|
||||
|
||||
class RayReusables(object):
|
||||
"""An object used to store Python variables that are shared between tasks.
|
||||
|
||||
Each worker process will have a single RayReusables object. This class serves
|
||||
two purposes. First, some objects are not serializable, and so the code that
|
||||
creates those objects must be run on the worker that uses them. This class is
|
||||
responsible for running the code that creates those objects. Second, some of
|
||||
these objects are expensive to create, and so they should be shared between
|
||||
tasks. However, if a task mutates a variable that is shared between tasks,
|
||||
then the behavior of the overall program may be nondeterministic (it could
|
||||
depend on scheduling decisions). To fix this, if a task uses a one of these
|
||||
shared objects, then that shared object will be reinitialized after the task
|
||||
finishes. Since the initialization may be expensive, the user can pass in
|
||||
custom reinitialization code that resets the state of the shared variable to
|
||||
the way it was after initialization. If the reinitialization code does not do
|
||||
this, then the behavior of the overall program is undefined.
|
||||
|
||||
Attributes:
|
||||
_names (List[str]): A list of the names of all the reusable variables.
|
||||
_initializers (dict[str, [Callable[[], object]])]: A dictionary mapping the
|
||||
names of the reusable variables to the code for initializing them.
|
||||
_reinitializers (Dict[str, Callable[[object], object]]): A dictionary
|
||||
mapping the names of the reusable variables to the code for reinitializing
|
||||
them. For reusable variables for which reinitializer code is not provided,
|
||||
the reinitializer here essentially wraps the initializer.
|
||||
_used (List[str]): A list of the names of all the reusable variables that
|
||||
have been accessed within the scope of the current task. This is reset to
|
||||
the empty list after each task.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize a RayReusables object."""
|
||||
self._names = set()
|
||||
self._reusables = {}
|
||||
self._used = set()
|
||||
self._slots = ("_names", "_reusables", "_used", "_slots", "_reinitialize", "__getattribute__", "__setattr__", "__delattr__")
|
||||
# CHECKPOINT: Any attributes assigned before _here_ will be protected from rewrite or deletion
|
||||
|
||||
def _reinitialize(self):
|
||||
"""Reinitialize the reusable variables that the current task used."""
|
||||
for name in self._used:
|
||||
current_value = getattr(self, name)
|
||||
new_value = self._reusables[name].reinitializer(current_value)
|
||||
object.__setattr__(self, name, new_value)
|
||||
self._used.clear() # Reset the _used list.
|
||||
|
||||
def __getattribute__(self, name):
|
||||
"""Get an attribute. This handles reusable variables as a special case.
|
||||
|
||||
When __getattribute__ is called with the name of a reusable variable, that
|
||||
name is added to the list of variables that were used in the current task.
|
||||
|
||||
Args:
|
||||
name (str): The name of the attribute to get.
|
||||
"""
|
||||
if name == "_slots":
|
||||
return object.__getattribute__(self, name)
|
||||
if name in self._slots:
|
||||
return object.__getattribute__(self, name)
|
||||
if name in self._names and name not in self._used:
|
||||
self._used.add(name)
|
||||
return object.__getattribute__(self, name)
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
"""Set an attribute. This handles reusable variables as a special case.
|
||||
|
||||
This is used to create reusable variables. When it is called, it runs the
|
||||
function for initializing the variable to create the variable. If this is
|
||||
called on the driver, then the functions for initializing and reinitializing
|
||||
the variable are shipped to the workers.
|
||||
|
||||
Args:
|
||||
name (str): The name of the attribute to set. This is either a whitelisted
|
||||
name or it is treated as the name of a reusable variable.
|
||||
value: If name is a whitelisted name, then value can be any value. If name
|
||||
is the name of a reusable variable, then this is either the serialized
|
||||
initializer code or it is a tuple of the serialized initializer and
|
||||
reinitializer code.
|
||||
"""
|
||||
try:
|
||||
slots = self._slots
|
||||
except AttributeError:
|
||||
slots = ()
|
||||
if slots == ():
|
||||
return object.__setattr__(self, name, value)
|
||||
if name in slots:
|
||||
raise AttributeError("Illegal assignment to {} object attribute {}".format(self.__class__.__name__, name))
|
||||
reusable = value
|
||||
if not issubclass(type(reusable), Reusable):
|
||||
raise Exception("To set a reusable variable, you must pass in a Reusable object")
|
||||
self._names.add(name)
|
||||
self._reusables[name] = reusable
|
||||
if _mode() in [ray.SHELL_MODE, ray.SCRIPT_MODE]:
|
||||
_export_reusable_variable(name, reusable)
|
||||
object.__setattr__(self, name, reusable.initializer())
|
||||
|
||||
def __delattr__(self, name):
|
||||
"""We do not allow attributes of RayReusables to be deleted.
|
||||
|
||||
Args:
|
||||
name (str): The name of the attribute to delete.
|
||||
"""
|
||||
raise Exception("Attempted deletion of attribute {}. Attributes of a RayReusable object may not be deleted.".format(name))
|
||||
|
||||
class Worker(object):
|
||||
"""A class used to define the control flow of a worker process.
|
||||
|
||||
@@ -114,7 +243,6 @@ class Worker(object):
|
||||
function to the remote function itself. This is the set of remote
|
||||
functions that can be executed by this worker.
|
||||
handle (worker capsule): A Python object wrapping a C++ Worker object.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
@@ -250,6 +378,16 @@ We use a global Worker object to ensure that there is a single worker object
|
||||
per worker process.
|
||||
"""
|
||||
|
||||
reusables = RayReusables()
|
||||
"""RayReusables: The reusable variables that are shared between tasks.
|
||||
|
||||
Each worker process has its own RayReusables object, and these objects should be
|
||||
the same in all workers. This is used for storing variables that are not
|
||||
serializable but must be used by remote tasks. In addition, it is used to
|
||||
reinitialize these variables after they are used so that changes to their state
|
||||
made by one task do not affect other tasks.
|
||||
"""
|
||||
|
||||
def print_failed_task(task_status):
|
||||
"""Print information about failed tasks.
|
||||
|
||||
@@ -511,13 +649,16 @@ def main_loop(worker=global_worker):
|
||||
else:
|
||||
store_outputs_in_objstore(return_objrefs, outputs, worker) # store output in local object store
|
||||
ray.lib.notify_task_completed(worker.handle, True, "") # notify the scheduler that the task completed successfully
|
||||
finally:
|
||||
# Reinitialize the values of reusable variables that were used in the task
|
||||
# above so that changes made to their state do not affect other tasks.
|
||||
ray.reusables._reinitialize()
|
||||
while True:
|
||||
(task, function) = ray.lib.wait_for_next_message(worker.handle)
|
||||
(task, function, reusable_variable) = ray.lib.wait_for_next_message(worker.handle)
|
||||
try:
|
||||
# Currently the schedule does not ask the worker to execute a task and
|
||||
# import a function at the same time.
|
||||
assert task is None or function is None
|
||||
if task is None and function is None:
|
||||
# Only one of task, function, and reusable_variable should be not None.
|
||||
assert sum([obj is not None for obj in [task, function, reusable_variable]]) <= 1
|
||||
if task is None and function is None and reusable_variable is None:
|
||||
# We use this as a mechanism to allow the scheduler to kill workers. When
|
||||
# the scheduler wants to kill a worker, it gives the worker a null task,
|
||||
# causing the worker program to exit the main loop here.
|
||||
@@ -526,12 +667,18 @@ def main_loop(worker=global_worker):
|
||||
(function, arg_types, return_types) = pickling.loads(function)
|
||||
if function.__module__ is None: function.__module__ = "__main__"
|
||||
worker.register_function(remote(arg_types, return_types, worker)(function))
|
||||
if reusable_variable is not None:
|
||||
name, initializer_str, reinitializer_str = reusable_variable
|
||||
initializer = pickling.deserialize(initializer_str)
|
||||
reinitializer = pickling.deserialize(reinitializer_str)
|
||||
reusables.__setattr__(name, Reusable(initializer, reinitializer))
|
||||
if task is not None:
|
||||
process_task(task)
|
||||
finally:
|
||||
# Allow releasing the variables BEFORE we wait for the next message or exit the block
|
||||
del task
|
||||
del function
|
||||
del reusable_variable
|
||||
|
||||
def _submit_task(func_name, args, worker=global_worker):
|
||||
"""This is a wrapper around worker.submit_task.
|
||||
@@ -553,6 +700,18 @@ def _mode(worker=global_worker):
|
||||
"""
|
||||
return worker.mode
|
||||
|
||||
def _export_reusable_variable(name, reusable, worker=global_worker):
|
||||
"""Export a reusable variable to the workers. This is only called by a driver.
|
||||
|
||||
Args:
|
||||
name (str): The name of the variable to export.
|
||||
reusable (Reusable): The reusable object containing code for initializing
|
||||
and reinitializing the variable.
|
||||
"""
|
||||
if _mode(worker) not in [ray.SHELL_MODE, ray.SCRIPT_MODE]:
|
||||
raise Exception("_export_reusable_variable can only be called on a driver.")
|
||||
ray.lib.export_reusable_variable(worker.handle, name, pickling.serialize(reusable.initializer), pickling.serialize(reusable.reinitializer))
|
||||
|
||||
def remote(arg_types, return_types, worker=global_worker):
|
||||
"""This decorator is used to create remote functions.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user