From 03f1830cd07fd07deaf8ee70563148542d94f7fa Mon Sep 17 00:00:00 2001 From: Robert Nishihara Date: Thu, 21 Jul 2016 00:16:19 -0700 Subject: [PATCH] implement key value store for sharing reusable variables --- lib/python/ray/__init__.py | 1 + lib/python/ray/pickling.py | 6 ++ lib/python/ray/worker.py | 171 +++++++++++++++++++++++++++++++++++-- protos/ray.proto | 15 ++++ src/raylib.cc | 29 ++++++- src/scheduler.cc | 16 ++++ src/scheduler.h | 1 + src/worker.cc | 25 ++++++ src/worker.h | 12 ++- test/runtest.py | 79 +++++++++++++++++ 10 files changed, 347 insertions(+), 8 deletions(-) diff --git a/lib/python/ray/__init__.py b/lib/python/ray/__init__.py index 824d13831..8a561565b 100644 --- a/lib/python/ray/__init__.py +++ b/lib/python/ray/__init__.py @@ -18,5 +18,6 @@ import config import libraylib as lib import serialization from worker import scheduler_info, visualize_computation_graph, task_info, register_module, connect, disconnect, get, put, remote, kill_workers, restart_workers_local +from worker import Reusable, reusables from libraylib import ObjRef import internal diff --git a/lib/python/ray/pickling.py b/lib/python/ray/pickling.py index d9e34b0ff..1d3c27afd 100644 --- a/lib/python/ray/pickling.py +++ b/lib/python/ray/pickling.py @@ -1,5 +1,11 @@ import cloudpickle +def serialize(function): + return cloudpickle.dumps(function) + +def deserialize(serialized_function): + return cloudpickle.loads(serialized_function) + def dumps(func, arg_types, return_types): return cloudpickle.dumps((func, arg_types, return_types)) diff --git a/lib/python/ray/worker.py b/lib/python/ray/worker.py index 97b02339b..d5e676e0f 100644 --- a/lib/python/ray/worker.py +++ b/lib/python/ray/worker.py @@ -102,6 +102,135 @@ class RayDealloc(object): """Deallocate the relevant segment to avoid a memory leak.""" ray.lib.unmap_object(self.handle, self.segmentid) +class Reusable(object): + """An Python object that can be shared between tasks. + + Attributes: + initializer (Callable[[], object]): A function used to create and initialize + the reusable variable. + reinitializer (Optional[Callable[[object], object]]): An optional function + used to reinitialize the reusable variable after it has been used. This + argument can be used as an optimization if there is a fast way to + reinitialize the state of the variable other than rerunning the + initializer. + """ + + def __init__(self, initializer, reinitializer=None): + """Initialize a Reusable object.""" + if not isinstance(initializer, typing.Callable): + raise Exception("When creating a RayReusable, initializer must be a function.") + self.initializer = initializer + if reinitializer is None: + # If no reinitializer is passed in, use a wrapped version of the initializer. + reinitializer = lambda value: initializer() + if not isinstance(reinitializer, typing.Callable): + raise Exception("When creating a RayReusable, reinitializer must be a function.") + self.reinitializer = reinitializer + +class RayReusables(object): + """An object used to store Python variables that are shared between tasks. + + Each worker process will have a single RayReusables object. This class serves + two purposes. First, some objects are not serializable, and so the code that + creates those objects must be run on the worker that uses them. This class is + responsible for running the code that creates those objects. Second, some of + these objects are expensive to create, and so they should be shared between + tasks. However, if a task mutates a variable that is shared between tasks, + then the behavior of the overall program may be nondeterministic (it could + depend on scheduling decisions). To fix this, if a task uses a one of these + shared objects, then that shared object will be reinitialized after the task + finishes. Since the initialization may be expensive, the user can pass in + custom reinitialization code that resets the state of the shared variable to + the way it was after initialization. If the reinitialization code does not do + this, then the behavior of the overall program is undefined. + + Attributes: + _names (List[str]): A list of the names of all the reusable variables. + _initializers (dict[str, [Callable[[], object]])]: A dictionary mapping the + names of the reusable variables to the code for initializing them. + _reinitializers (Dict[str, Callable[[object], object]]): A dictionary + mapping the names of the reusable variables to the code for reinitializing + them. For reusable variables for which reinitializer code is not provided, + the reinitializer here essentially wraps the initializer. + _used (List[str]): A list of the names of all the reusable variables that + have been accessed within the scope of the current task. This is reset to + the empty list after each task. + """ + + def __init__(self): + """Initialize a RayReusables object.""" + self._names = set() + self._reusables = {} + self._used = set() + self._slots = ("_names", "_reusables", "_used", "_slots", "_reinitialize", "__getattribute__", "__setattr__", "__delattr__") + # CHECKPOINT: Any attributes assigned before _here_ will be protected from rewrite or deletion + + def _reinitialize(self): + """Reinitialize the reusable variables that the current task used.""" + for name in self._used: + current_value = getattr(self, name) + new_value = self._reusables[name].reinitializer(current_value) + object.__setattr__(self, name, new_value) + self._used.clear() # Reset the _used list. + + def __getattribute__(self, name): + """Get an attribute. This handles reusable variables as a special case. + + When __getattribute__ is called with the name of a reusable variable, that + name is added to the list of variables that were used in the current task. + + Args: + name (str): The name of the attribute to get. + """ + if name == "_slots": + return object.__getattribute__(self, name) + if name in self._slots: + return object.__getattribute__(self, name) + if name in self._names and name not in self._used: + self._used.add(name) + return object.__getattribute__(self, name) + + def __setattr__(self, name, value): + """Set an attribute. This handles reusable variables as a special case. + + This is used to create reusable variables. When it is called, it runs the + function for initializing the variable to create the variable. If this is + called on the driver, then the functions for initializing and reinitializing + the variable are shipped to the workers. + + Args: + name (str): The name of the attribute to set. This is either a whitelisted + name or it is treated as the name of a reusable variable. + value: If name is a whitelisted name, then value can be any value. If name + is the name of a reusable variable, then this is either the serialized + initializer code or it is a tuple of the serialized initializer and + reinitializer code. + """ + try: + slots = self._slots + except AttributeError: + slots = () + if slots == (): + return object.__setattr__(self, name, value) + if name in slots: + raise AttributeError("Illegal assignment to {} object attribute {}".format(self.__class__.__name__, name)) + reusable = value + if not issubclass(type(reusable), Reusable): + raise Exception("To set a reusable variable, you must pass in a Reusable object") + self._names.add(name) + self._reusables[name] = reusable + if _mode() in [ray.SHELL_MODE, ray.SCRIPT_MODE]: + _export_reusable_variable(name, reusable) + object.__setattr__(self, name, reusable.initializer()) + + def __delattr__(self, name): + """We do not allow attributes of RayReusables to be deleted. + + Args: + name (str): The name of the attribute to delete. + """ + raise Exception("Attempted deletion of attribute {}. Attributes of a RayReusable object may not be deleted.".format(name)) + class Worker(object): """A class used to define the control flow of a worker process. @@ -114,7 +243,6 @@ class Worker(object): function to the remote function itself. This is the set of remote functions that can be executed by this worker. handle (worker capsule): A Python object wrapping a C++ Worker object. - """ def __init__(self): @@ -250,6 +378,16 @@ We use a global Worker object to ensure that there is a single worker object per worker process. """ +reusables = RayReusables() +"""RayReusables: The reusable variables that are shared between tasks. + +Each worker process has its own RayReusables object, and these objects should be +the same in all workers. This is used for storing variables that are not +serializable but must be used by remote tasks. In addition, it is used to +reinitialize these variables after they are used so that changes to their state +made by one task do not affect other tasks. +""" + def print_failed_task(task_status): """Print information about failed tasks. @@ -511,13 +649,16 @@ def main_loop(worker=global_worker): else: store_outputs_in_objstore(return_objrefs, outputs, worker) # store output in local object store ray.lib.notify_task_completed(worker.handle, True, "") # notify the scheduler that the task completed successfully + finally: + # Reinitialize the values of reusable variables that were used in the task + # above so that changes made to their state do not affect other tasks. + ray.reusables._reinitialize() while True: - (task, function) = ray.lib.wait_for_next_message(worker.handle) + (task, function, reusable_variable) = ray.lib.wait_for_next_message(worker.handle) try: - # Currently the schedule does not ask the worker to execute a task and - # import a function at the same time. - assert task is None or function is None - if task is None and function is None: + # Only one of task, function, and reusable_variable should be not None. + assert sum([obj is not None for obj in [task, function, reusable_variable]]) <= 1 + if task is None and function is None and reusable_variable is None: # We use this as a mechanism to allow the scheduler to kill workers. When # the scheduler wants to kill a worker, it gives the worker a null task, # causing the worker program to exit the main loop here. @@ -526,12 +667,18 @@ def main_loop(worker=global_worker): (function, arg_types, return_types) = pickling.loads(function) if function.__module__ is None: function.__module__ = "__main__" worker.register_function(remote(arg_types, return_types, worker)(function)) + if reusable_variable is not None: + name, initializer_str, reinitializer_str = reusable_variable + initializer = pickling.deserialize(initializer_str) + reinitializer = pickling.deserialize(reinitializer_str) + reusables.__setattr__(name, Reusable(initializer, reinitializer)) if task is not None: process_task(task) finally: # Allow releasing the variables BEFORE we wait for the next message or exit the block del task del function + del reusable_variable def _submit_task(func_name, args, worker=global_worker): """This is a wrapper around worker.submit_task. @@ -553,6 +700,18 @@ def _mode(worker=global_worker): """ return worker.mode +def _export_reusable_variable(name, reusable, worker=global_worker): + """Export a reusable variable to the workers. This is only called by a driver. + + Args: + name (str): The name of the variable to export. + reusable (Reusable): The reusable object containing code for initializing + and reinitializing the variable. + """ + if _mode(worker) not in [ray.SHELL_MODE, ray.SCRIPT_MODE]: + raise Exception("_export_reusable_variable can only be called on a driver.") + ray.lib.export_reusable_variable(worker.handle, name, pickling.serialize(reusable.initializer), pickling.serialize(reusable.reinitializer)) + def remote(arg_types, return_types, worker=global_worker): """This decorator is used to create remote functions. diff --git a/protos/ray.proto b/protos/ray.proto index ae8818af3..6ef79ce24 100644 --- a/protos/ray.proto +++ b/protos/ray.proto @@ -54,6 +54,8 @@ service Scheduler { rpc KillWorkers(KillWorkersRequest) returns (KillWorkersReply); // Exports function to the workers rpc ExportFunction(ExportFunctionRequest) returns (ExportFunctionReply); + // Ship an initializer and reinitializer for a reusable variable to the workers + rpc ExportReusableVariable(ExportReusableVariableRequest) returns (AckReply); } message AckReply { @@ -237,6 +239,12 @@ message ExportFunctionRequest { message ExportFunctionReply { } +message ExportReusableVariableRequest { + string name = 1; // The name of the reusable variable. + Function initializer = 2; // A serialized version of the function that initializes the reusable variable. + Function reinitializer = 3; // A serialized version of the function that reinitializes the reusable variable. +} + // These messages are for getting information about the object store state message ObjStoreInfoRequest { @@ -253,6 +261,7 @@ message ObjStoreInfoReply { service WorkerService { rpc ExecuteTask(ExecuteTaskRequest) returns (ExecuteTaskReply); // Scheduler calls a function from the worker rpc ImportFunction(ImportFunctionRequest) returns (ImportFunctionReply); // Scheduler imports a function into the worker + rpc ImportReusableVariable(ImportReusableVariableRequest) returns (AckReply); // Scheduler imports a reusable variable into the worker rpc Die(DieRequest) returns (DieReply); // Kills this worker } @@ -270,6 +279,12 @@ message ImportFunctionRequest { message ImportFunctionReply { } +message ImportReusableVariableRequest { + string name = 1; // The name of the reusable variable. + Function initializer = 2; // A serialized version of the function that initializes the reusable variable. + Function reinitializer = 3; // A serialized version of the function that reinitializes the reusable variable. +} + message DieRequest { } diff --git a/src/raylib.cc b/src/raylib.cc index 1cbf3a36a..f5f4e1545 100644 --- a/src/raylib.cc +++ b/src/raylib.cc @@ -714,13 +714,21 @@ static PyObject* wait_for_next_message(PyObject* self, PyObject* args) { Worker* worker; PyObjectToWorker(worker_capsule, &worker); if (std::unique_ptr message = worker->receive_next_message()) { + PyObject* variable_info; + if (!message->reusable_variable.variable_name.empty()) { + variable_info = PyTuple_New(3); + PyTuple_SetItem(variable_info, 0, PyString_FromStringAndSize(message->reusable_variable.variable_name.data(), static_cast(message->reusable_variable.variable_name.size()))); + PyTuple_SetItem(variable_info, 1, PyString_FromStringAndSize(message->reusable_variable.initializer.data(), static_cast(message->reusable_variable.initializer.size()))); + PyTuple_SetItem(variable_info, 2, PyString_FromStringAndSize(message->reusable_variable.reinitializer.data(), static_cast(message->reusable_variable.reinitializer.size()))); + } // The tuple constructed below will take ownership of some None objects. // When the tuple goes out of scope, the reference count for None will be // decremented. Therefore, we need to increment the reference count for None // every time we put a None in the tuple. - PyObject* t = PyTuple_New(2); // We set the items of the tuple using PyTuple_SetItem, because that transfers ownership to the tuple. + PyObject* t = PyTuple_New(3); // We set the items of the tuple using PyTuple_SetItem, because that transfers ownership to the tuple. PyTuple_SetItem(t, 0, message->task.name().empty() ? Py_INCREF(Py_None), Py_None : deserialize_task(worker_capsule, &message->task)); PyTuple_SetItem(t, 1, message->function.empty() ? Py_INCREF(Py_None), Py_None : PyString_FromStringAndSize(message->function.data(), static_cast(message->function.size()))); + PyTuple_SetItem(t, 2, message->reusable_variable.variable_name.empty() ? Py_INCREF(Py_None), Py_None : variable_info); return t; } Py_RETURN_NONE; @@ -740,6 +748,24 @@ static PyObject* export_function(PyObject* self, PyObject* args) { } } +static PyObject* export_reusable_variable(PyObject* self, PyObject* args) { + Worker* worker; + const char* name; + int name_size; + const char* initializer; + int initializer_size; + const char* reinitializer; + int reinitializer_size; + if (!PyArg_ParseTuple(args, "O&s#s#s#", &PyObjectToWorker, &worker, &name, &name_size, &initializer, &initializer_size, &reinitializer, &reinitializer_size)) { + return NULL; + } + std::string name_str(name, static_cast(name_size)); + std::string initializer_str(initializer, static_cast(initializer_size)); + std::string reinitializer_str(reinitializer, static_cast(reinitializer_size)); + worker->export_reusable_variable(name_str, initializer_str, reinitializer_str); + Py_RETURN_NONE; +} + static PyObject* submit_task(PyObject* self, PyObject* args) { PyObject* worker_capsule; Task* task; @@ -997,6 +1023,7 @@ static PyMethodDef RayLibMethods[] = { { "scheduler_info", scheduler_info, METH_VARARGS, "get info about scheduler state" }, { "task_info", task_info, METH_VARARGS, "get task statuses" }, { "export_function", export_function, METH_VARARGS, "export function to workers" }, + { "export_reusable_variable", export_reusable_variable, METH_VARARGS, "export a reusable variable to the workers" }, { "dump_computation_graph", dump_computation_graph, METH_VARARGS, "dump the current computation graph to a file" }, { "set_log_config", set_log_config, METH_VARARGS, "set filename for raylib logging" }, { "kill_workers", kill_workers, METH_VARARGS, "kills all of the workers" }, diff --git a/src/scheduler.cc b/src/scheduler.cc index 9ff67eec7..5e64cbe71 100644 --- a/src/scheduler.cc +++ b/src/scheduler.cc @@ -305,6 +305,22 @@ Status SchedulerService::ExportFunction(ServerContext* context, const ExportFunc return Status::OK; } +Status SchedulerService::ExportReusableVariable(ServerContext* context, const ExportReusableVariableRequest* request, AckReply* reply) { + auto workers = workers_.get(); + for (size_t i = 0; i < workers->size(); ++i) { + ClientContext import_context; + ImportReusableVariableRequest import_request; + import_request.set_name(request->name()); + import_request.mutable_initializer()->set_implementation(request->initializer().implementation()); + import_request.mutable_reinitializer()->set_implementation(request->reinitializer().implementation()); + if ((*workers)[i].current_task != ROOT_OPERATION) { + AckReply import_reply; + (*workers)[i].worker_stub->ImportReusableVariable(&import_context, import_request, &import_reply); + } + } + return Status::OK; +} + void SchedulerService::deliver_object_async_if_necessary(ObjRef canonical_objref, ObjStoreId from, ObjStoreId to) { bool object_present_or_in_transit; { diff --git a/src/scheduler.h b/src/scheduler.h index 3f3ebe1d1..7f2a11d7e 100644 --- a/src/scheduler.h +++ b/src/scheduler.h @@ -71,6 +71,7 @@ public: Status TaskInfo(ServerContext* context, const TaskInfoRequest* request, TaskInfoReply* reply) override; Status KillWorkers(ServerContext* context, const KillWorkersRequest* request, KillWorkersReply* reply) override; Status ExportFunction(ServerContext* context, const ExportFunctionRequest* request, ExportFunctionReply* reply) override; + Status ExportReusableVariable(ServerContext* context, const ExportReusableVariableRequest* request, AckReply* reply) override; // This will ask an object store to send an object to another object store if // the object is not already present in that object store and is not already diff --git a/src/worker.cc b/src/worker.cc index e2b8bed88..a5a2fea40 100644 --- a/src/worker.cc +++ b/src/worker.cc @@ -40,6 +40,20 @@ Status WorkerServiceImpl::ImportFunction(ServerContext* context, const ImportFun return Status::OK; } +Status WorkerServiceImpl::ImportReusableVariable(ServerContext* context, const ImportReusableVariableRequest* request, AckReply* reply) { + std::unique_ptr message(new WorkerMessage()); + message->reusable_variable.variable_name = request->name(); + message->reusable_variable.initializer = request->initializer().implementation(); + message->reusable_variable.reinitializer = request->reinitializer().implementation(); + RAY_LOG(RAY_INFO, "importing reusable variable"); + { + WorkerMessage* message_ptr = message.get(); + RAY_CHECK(send_queue_.send(&message_ptr), "error sending over IPC"); + } + message.release(); + return Status::OK; +} + Status WorkerServiceImpl::Die(ServerContext* context, const DieRequest* request, DieReply* reply) { WorkerMessage* message_ptr = NULL; RAY_CHECK(send_queue_.send(&message_ptr), "error sending over IPC"); @@ -389,6 +403,17 @@ bool Worker::export_function(const std::string& function) { return true; } +void Worker::export_reusable_variable(const std::string& name, const std::string& initializer, const std::string& reinitializer) { + RAY_CHECK(connected_, "Attempted to export reusable variable but failed."); + ClientContext context; + ExportReusableVariableRequest request; + request.set_name(name); + request.mutable_initializer()->set_implementation(initializer); + request.mutable_reinitializer()->set_implementation(reinitializer); + AckReply reply; + Status status = scheduler_stub_->ExportReusableVariable(&context, request, &reply); +} + // Communication between the WorkerServer and the Worker happens via a message // queue. This is because the Python interpreter needs to be single threaded // (in our case running in the main thread), whereas the WorkerService will diff --git a/src/worker.h b/src/worker.h index 14d8141a4..0f5a43c2e 100644 --- a/src/worker.h +++ b/src/worker.h @@ -23,9 +23,16 @@ using grpc::Channel; using grpc::ClientContext; using grpc::ClientWriter; +struct ReusableVariable { + std::string variable_name; + std::string initializer; + std::string reinitializer; +}; + struct WorkerMessage { Task task; - std::string function; + std::string function; // Used for importing remote functions. + ReusableVariable reusable_variable; // Used for importing reusable variables. }; class WorkerServiceImpl final : public WorkerService::Service { @@ -34,6 +41,7 @@ public: Status ExecuteTask(ServerContext* context, const ExecuteTaskRequest* request, ExecuteTaskReply* reply) override; Status ImportFunction(ServerContext* context, const ImportFunctionRequest* request, ImportFunctionReply* reply) override; Status Die(ServerContext* context, const DieRequest* request, DieReply* reply) override; + Status ImportReusableVariable(ServerContext* context, const ImportReusableVariableRequest* request, AckReply* reply) override; private: std::string worker_address_; MessageQueue send_queue_; @@ -100,6 +108,8 @@ class Worker { void task_info(ClientContext &context, TaskInfoRequest &request, TaskInfoReply &reply); // export function to workers bool export_function(const std::string& function); + // export reusable variable to workers + void export_reusable_variable(const std::string& name, const std::string& initializer, const std::string& reinitializer); private: bool connected_; diff --git a/test/runtest.py b/test/runtest.py index f558fd330..7a78b1419 100644 --- a/test/runtest.py +++ b/test/runtest.py @@ -507,5 +507,84 @@ class PythonCExtensionTest(unittest.TestCase): ray.services.cleanup() +class ReusablesTest(unittest.TestCase): + + def testReusables(self): + worker_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_worker.py") + ray.services.start_ray_local(num_workers=1, worker_path=worker_path) + + # Test that we can add a variable to the key-value store. + + def foo_initializer(): + return 1 + def foo_reinitializer(foo): + return foo + + ray.reusables.foo = ray.Reusable(foo_initializer, foo_reinitializer) + self.assertEqual(ray.reusables.foo, 1) + + @ray.remote([], [int]) + def use_foo(): + return ray.reusables.foo + self.assertEqual(ray.get(use_foo()), 1) + self.assertEqual(ray.get(use_foo()), 1) + self.assertEqual(ray.get(use_foo()), 1) + + # Test that we can add a variable to the key-value store, mutate it, and reset it. + + def bar_initializer(): + return [1, 2, 3] + + ray.reusables.bar = ray.Reusable(bar_initializer) + + @ray.remote([], [list]) + def use_bar(): + ray.reusables.bar.append(4) + return ray.reusables.bar + self.assertEqual(ray.get(use_bar()), [1, 2, 3, 4]) + self.assertEqual(ray.get(use_bar()), [1, 2, 3, 4]) + self.assertEqual(ray.get(use_bar()), [1, 2, 3, 4]) + + # Test that we can use the reinitializer. + + def baz_initializer(): + return np.zeros([4]) + def baz_reinitializer(baz): + for i in range(len(baz)): + baz[i] = 0 + return baz + + ray.reusables.baz = ray.Reusable(baz_initializer, baz_reinitializer) + + @ray.remote([int], [np.ndarray]) + def use_baz(i): + baz = ray.reusables.baz + baz[i] = 1 + return baz + self.assertTrue(np.alltrue(ray.get(use_baz(0)) == np.array([1, 0, 0, 0]))) + self.assertTrue(np.alltrue(ray.get(use_baz(1)) == np.array([0, 1, 0, 0]))) + self.assertTrue(np.alltrue(ray.get(use_baz(2)) == np.array([0, 0, 1, 0]))) + self.assertTrue(np.alltrue(ray.get(use_baz(3)) == np.array([0, 0, 0, 1]))) + + # Make sure the reinitializer is actually getting called. Note that this is + # not the correct usage of a reinitializer because it does not reset qux to + # its original state. This is just for testing. + + def qux_initializer(): + return 0 + def qux_reinitializer(x): + return x + 1 + + ray.reusables.qux = ray.Reusable(qux_initializer, qux_reinitializer) + + @ray.remote([], [int]) + def use_qux(): + return ray.reusables.qux + self.assertEqual(ray.get(use_qux()), 0) + self.assertEqual(ray.get(use_qux()), 1) + self.assertEqual(ray.get(use_qux()), 2) + + ray.services.cleanup() + if __name__ == "__main__": unittest.main()