mirror of
https://github.com/wassname/ray.git
synced 2026-07-04 16:31:38 +08:00
implement key value store for sharing reusable variables
This commit is contained in:
@@ -18,5 +18,6 @@ import config
|
||||
import libraylib as lib
|
||||
import serialization
|
||||
from worker import scheduler_info, visualize_computation_graph, task_info, register_module, connect, disconnect, get, put, remote, kill_workers, restart_workers_local
|
||||
from worker import Reusable, reusables
|
||||
from libraylib import ObjRef
|
||||
import internal
|
||||
|
||||
@@ -1,5 +1,11 @@
|
||||
import cloudpickle
|
||||
|
||||
def serialize(function):
|
||||
return cloudpickle.dumps(function)
|
||||
|
||||
def deserialize(serialized_function):
|
||||
return cloudpickle.loads(serialized_function)
|
||||
|
||||
def dumps(func, arg_types, return_types):
|
||||
return cloudpickle.dumps((func, arg_types, return_types))
|
||||
|
||||
|
||||
+165
-6
@@ -102,6 +102,135 @@ class RayDealloc(object):
|
||||
"""Deallocate the relevant segment to avoid a memory leak."""
|
||||
ray.lib.unmap_object(self.handle, self.segmentid)
|
||||
|
||||
class Reusable(object):
|
||||
"""An Python object that can be shared between tasks.
|
||||
|
||||
Attributes:
|
||||
initializer (Callable[[], object]): A function used to create and initialize
|
||||
the reusable variable.
|
||||
reinitializer (Optional[Callable[[object], object]]): An optional function
|
||||
used to reinitialize the reusable variable after it has been used. This
|
||||
argument can be used as an optimization if there is a fast way to
|
||||
reinitialize the state of the variable other than rerunning the
|
||||
initializer.
|
||||
"""
|
||||
|
||||
def __init__(self, initializer, reinitializer=None):
|
||||
"""Initialize a Reusable object."""
|
||||
if not isinstance(initializer, typing.Callable):
|
||||
raise Exception("When creating a RayReusable, initializer must be a function.")
|
||||
self.initializer = initializer
|
||||
if reinitializer is None:
|
||||
# If no reinitializer is passed in, use a wrapped version of the initializer.
|
||||
reinitializer = lambda value: initializer()
|
||||
if not isinstance(reinitializer, typing.Callable):
|
||||
raise Exception("When creating a RayReusable, reinitializer must be a function.")
|
||||
self.reinitializer = reinitializer
|
||||
|
||||
class RayReusables(object):
|
||||
"""An object used to store Python variables that are shared between tasks.
|
||||
|
||||
Each worker process will have a single RayReusables object. This class serves
|
||||
two purposes. First, some objects are not serializable, and so the code that
|
||||
creates those objects must be run on the worker that uses them. This class is
|
||||
responsible for running the code that creates those objects. Second, some of
|
||||
these objects are expensive to create, and so they should be shared between
|
||||
tasks. However, if a task mutates a variable that is shared between tasks,
|
||||
then the behavior of the overall program may be nondeterministic (it could
|
||||
depend on scheduling decisions). To fix this, if a task uses a one of these
|
||||
shared objects, then that shared object will be reinitialized after the task
|
||||
finishes. Since the initialization may be expensive, the user can pass in
|
||||
custom reinitialization code that resets the state of the shared variable to
|
||||
the way it was after initialization. If the reinitialization code does not do
|
||||
this, then the behavior of the overall program is undefined.
|
||||
|
||||
Attributes:
|
||||
_names (List[str]): A list of the names of all the reusable variables.
|
||||
_initializers (dict[str, [Callable[[], object]])]: A dictionary mapping the
|
||||
names of the reusable variables to the code for initializing them.
|
||||
_reinitializers (Dict[str, Callable[[object], object]]): A dictionary
|
||||
mapping the names of the reusable variables to the code for reinitializing
|
||||
them. For reusable variables for which reinitializer code is not provided,
|
||||
the reinitializer here essentially wraps the initializer.
|
||||
_used (List[str]): A list of the names of all the reusable variables that
|
||||
have been accessed within the scope of the current task. This is reset to
|
||||
the empty list after each task.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize a RayReusables object."""
|
||||
self._names = set()
|
||||
self._reusables = {}
|
||||
self._used = set()
|
||||
self._slots = ("_names", "_reusables", "_used", "_slots", "_reinitialize", "__getattribute__", "__setattr__", "__delattr__")
|
||||
# CHECKPOINT: Any attributes assigned before _here_ will be protected from rewrite or deletion
|
||||
|
||||
def _reinitialize(self):
|
||||
"""Reinitialize the reusable variables that the current task used."""
|
||||
for name in self._used:
|
||||
current_value = getattr(self, name)
|
||||
new_value = self._reusables[name].reinitializer(current_value)
|
||||
object.__setattr__(self, name, new_value)
|
||||
self._used.clear() # Reset the _used list.
|
||||
|
||||
def __getattribute__(self, name):
|
||||
"""Get an attribute. This handles reusable variables as a special case.
|
||||
|
||||
When __getattribute__ is called with the name of a reusable variable, that
|
||||
name is added to the list of variables that were used in the current task.
|
||||
|
||||
Args:
|
||||
name (str): The name of the attribute to get.
|
||||
"""
|
||||
if name == "_slots":
|
||||
return object.__getattribute__(self, name)
|
||||
if name in self._slots:
|
||||
return object.__getattribute__(self, name)
|
||||
if name in self._names and name not in self._used:
|
||||
self._used.add(name)
|
||||
return object.__getattribute__(self, name)
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
"""Set an attribute. This handles reusable variables as a special case.
|
||||
|
||||
This is used to create reusable variables. When it is called, it runs the
|
||||
function for initializing the variable to create the variable. If this is
|
||||
called on the driver, then the functions for initializing and reinitializing
|
||||
the variable are shipped to the workers.
|
||||
|
||||
Args:
|
||||
name (str): The name of the attribute to set. This is either a whitelisted
|
||||
name or it is treated as the name of a reusable variable.
|
||||
value: If name is a whitelisted name, then value can be any value. If name
|
||||
is the name of a reusable variable, then this is either the serialized
|
||||
initializer code or it is a tuple of the serialized initializer and
|
||||
reinitializer code.
|
||||
"""
|
||||
try:
|
||||
slots = self._slots
|
||||
except AttributeError:
|
||||
slots = ()
|
||||
if slots == ():
|
||||
return object.__setattr__(self, name, value)
|
||||
if name in slots:
|
||||
raise AttributeError("Illegal assignment to {} object attribute {}".format(self.__class__.__name__, name))
|
||||
reusable = value
|
||||
if not issubclass(type(reusable), Reusable):
|
||||
raise Exception("To set a reusable variable, you must pass in a Reusable object")
|
||||
self._names.add(name)
|
||||
self._reusables[name] = reusable
|
||||
if _mode() in [ray.SHELL_MODE, ray.SCRIPT_MODE]:
|
||||
_export_reusable_variable(name, reusable)
|
||||
object.__setattr__(self, name, reusable.initializer())
|
||||
|
||||
def __delattr__(self, name):
|
||||
"""We do not allow attributes of RayReusables to be deleted.
|
||||
|
||||
Args:
|
||||
name (str): The name of the attribute to delete.
|
||||
"""
|
||||
raise Exception("Attempted deletion of attribute {}. Attributes of a RayReusable object may not be deleted.".format(name))
|
||||
|
||||
class Worker(object):
|
||||
"""A class used to define the control flow of a worker process.
|
||||
|
||||
@@ -114,7 +243,6 @@ class Worker(object):
|
||||
function to the remote function itself. This is the set of remote
|
||||
functions that can be executed by this worker.
|
||||
handle (worker capsule): A Python object wrapping a C++ Worker object.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
@@ -250,6 +378,16 @@ We use a global Worker object to ensure that there is a single worker object
|
||||
per worker process.
|
||||
"""
|
||||
|
||||
reusables = RayReusables()
|
||||
"""RayReusables: The reusable variables that are shared between tasks.
|
||||
|
||||
Each worker process has its own RayReusables object, and these objects should be
|
||||
the same in all workers. This is used for storing variables that are not
|
||||
serializable but must be used by remote tasks. In addition, it is used to
|
||||
reinitialize these variables after they are used so that changes to their state
|
||||
made by one task do not affect other tasks.
|
||||
"""
|
||||
|
||||
def print_failed_task(task_status):
|
||||
"""Print information about failed tasks.
|
||||
|
||||
@@ -511,13 +649,16 @@ def main_loop(worker=global_worker):
|
||||
else:
|
||||
store_outputs_in_objstore(return_objrefs, outputs, worker) # store output in local object store
|
||||
ray.lib.notify_task_completed(worker.handle, True, "") # notify the scheduler that the task completed successfully
|
||||
finally:
|
||||
# Reinitialize the values of reusable variables that were used in the task
|
||||
# above so that changes made to their state do not affect other tasks.
|
||||
ray.reusables._reinitialize()
|
||||
while True:
|
||||
(task, function) = ray.lib.wait_for_next_message(worker.handle)
|
||||
(task, function, reusable_variable) = ray.lib.wait_for_next_message(worker.handle)
|
||||
try:
|
||||
# Currently the schedule does not ask the worker to execute a task and
|
||||
# import a function at the same time.
|
||||
assert task is None or function is None
|
||||
if task is None and function is None:
|
||||
# Only one of task, function, and reusable_variable should be not None.
|
||||
assert sum([obj is not None for obj in [task, function, reusable_variable]]) <= 1
|
||||
if task is None and function is None and reusable_variable is None:
|
||||
# We use this as a mechanism to allow the scheduler to kill workers. When
|
||||
# the scheduler wants to kill a worker, it gives the worker a null task,
|
||||
# causing the worker program to exit the main loop here.
|
||||
@@ -526,12 +667,18 @@ def main_loop(worker=global_worker):
|
||||
(function, arg_types, return_types) = pickling.loads(function)
|
||||
if function.__module__ is None: function.__module__ = "__main__"
|
||||
worker.register_function(remote(arg_types, return_types, worker)(function))
|
||||
if reusable_variable is not None:
|
||||
name, initializer_str, reinitializer_str = reusable_variable
|
||||
initializer = pickling.deserialize(initializer_str)
|
||||
reinitializer = pickling.deserialize(reinitializer_str)
|
||||
reusables.__setattr__(name, Reusable(initializer, reinitializer))
|
||||
if task is not None:
|
||||
process_task(task)
|
||||
finally:
|
||||
# Allow releasing the variables BEFORE we wait for the next message or exit the block
|
||||
del task
|
||||
del function
|
||||
del reusable_variable
|
||||
|
||||
def _submit_task(func_name, args, worker=global_worker):
|
||||
"""This is a wrapper around worker.submit_task.
|
||||
@@ -553,6 +700,18 @@ def _mode(worker=global_worker):
|
||||
"""
|
||||
return worker.mode
|
||||
|
||||
def _export_reusable_variable(name, reusable, worker=global_worker):
|
||||
"""Export a reusable variable to the workers. This is only called by a driver.
|
||||
|
||||
Args:
|
||||
name (str): The name of the variable to export.
|
||||
reusable (Reusable): The reusable object containing code for initializing
|
||||
and reinitializing the variable.
|
||||
"""
|
||||
if _mode(worker) not in [ray.SHELL_MODE, ray.SCRIPT_MODE]:
|
||||
raise Exception("_export_reusable_variable can only be called on a driver.")
|
||||
ray.lib.export_reusable_variable(worker.handle, name, pickling.serialize(reusable.initializer), pickling.serialize(reusable.reinitializer))
|
||||
|
||||
def remote(arg_types, return_types, worker=global_worker):
|
||||
"""This decorator is used to create remote functions.
|
||||
|
||||
|
||||
@@ -54,6 +54,8 @@ service Scheduler {
|
||||
rpc KillWorkers(KillWorkersRequest) returns (KillWorkersReply);
|
||||
// Exports function to the workers
|
||||
rpc ExportFunction(ExportFunctionRequest) returns (ExportFunctionReply);
|
||||
// Ship an initializer and reinitializer for a reusable variable to the workers
|
||||
rpc ExportReusableVariable(ExportReusableVariableRequest) returns (AckReply);
|
||||
}
|
||||
|
||||
message AckReply {
|
||||
@@ -237,6 +239,12 @@ message ExportFunctionRequest {
|
||||
message ExportFunctionReply {
|
||||
}
|
||||
|
||||
message ExportReusableVariableRequest {
|
||||
string name = 1; // The name of the reusable variable.
|
||||
Function initializer = 2; // A serialized version of the function that initializes the reusable variable.
|
||||
Function reinitializer = 3; // A serialized version of the function that reinitializes the reusable variable.
|
||||
}
|
||||
|
||||
// These messages are for getting information about the object store state
|
||||
|
||||
message ObjStoreInfoRequest {
|
||||
@@ -253,6 +261,7 @@ message ObjStoreInfoReply {
|
||||
service WorkerService {
|
||||
rpc ExecuteTask(ExecuteTaskRequest) returns (ExecuteTaskReply); // Scheduler calls a function from the worker
|
||||
rpc ImportFunction(ImportFunctionRequest) returns (ImportFunctionReply); // Scheduler imports a function into the worker
|
||||
rpc ImportReusableVariable(ImportReusableVariableRequest) returns (AckReply); // Scheduler imports a reusable variable into the worker
|
||||
rpc Die(DieRequest) returns (DieReply); // Kills this worker
|
||||
}
|
||||
|
||||
@@ -270,6 +279,12 @@ message ImportFunctionRequest {
|
||||
message ImportFunctionReply {
|
||||
}
|
||||
|
||||
message ImportReusableVariableRequest {
|
||||
string name = 1; // The name of the reusable variable.
|
||||
Function initializer = 2; // A serialized version of the function that initializes the reusable variable.
|
||||
Function reinitializer = 3; // A serialized version of the function that reinitializes the reusable variable.
|
||||
}
|
||||
|
||||
message DieRequest {
|
||||
}
|
||||
|
||||
|
||||
+28
-1
@@ -714,13 +714,21 @@ static PyObject* wait_for_next_message(PyObject* self, PyObject* args) {
|
||||
Worker* worker;
|
||||
PyObjectToWorker(worker_capsule, &worker);
|
||||
if (std::unique_ptr<WorkerMessage> message = worker->receive_next_message()) {
|
||||
PyObject* variable_info;
|
||||
if (!message->reusable_variable.variable_name.empty()) {
|
||||
variable_info = PyTuple_New(3);
|
||||
PyTuple_SetItem(variable_info, 0, PyString_FromStringAndSize(message->reusable_variable.variable_name.data(), static_cast<ssize_t>(message->reusable_variable.variable_name.size())));
|
||||
PyTuple_SetItem(variable_info, 1, PyString_FromStringAndSize(message->reusable_variable.initializer.data(), static_cast<ssize_t>(message->reusable_variable.initializer.size())));
|
||||
PyTuple_SetItem(variable_info, 2, PyString_FromStringAndSize(message->reusable_variable.reinitializer.data(), static_cast<ssize_t>(message->reusable_variable.reinitializer.size())));
|
||||
}
|
||||
// The tuple constructed below will take ownership of some None objects.
|
||||
// When the tuple goes out of scope, the reference count for None will be
|
||||
// decremented. Therefore, we need to increment the reference count for None
|
||||
// every time we put a None in the tuple.
|
||||
PyObject* t = PyTuple_New(2); // We set the items of the tuple using PyTuple_SetItem, because that transfers ownership to the tuple.
|
||||
PyObject* t = PyTuple_New(3); // We set the items of the tuple using PyTuple_SetItem, because that transfers ownership to the tuple.
|
||||
PyTuple_SetItem(t, 0, message->task.name().empty() ? Py_INCREF(Py_None), Py_None : deserialize_task(worker_capsule, &message->task));
|
||||
PyTuple_SetItem(t, 1, message->function.empty() ? Py_INCREF(Py_None), Py_None : PyString_FromStringAndSize(message->function.data(), static_cast<ssize_t>(message->function.size())));
|
||||
PyTuple_SetItem(t, 2, message->reusable_variable.variable_name.empty() ? Py_INCREF(Py_None), Py_None : variable_info);
|
||||
return t;
|
||||
}
|
||||
Py_RETURN_NONE;
|
||||
@@ -740,6 +748,24 @@ static PyObject* export_function(PyObject* self, PyObject* args) {
|
||||
}
|
||||
}
|
||||
|
||||
static PyObject* export_reusable_variable(PyObject* self, PyObject* args) {
|
||||
Worker* worker;
|
||||
const char* name;
|
||||
int name_size;
|
||||
const char* initializer;
|
||||
int initializer_size;
|
||||
const char* reinitializer;
|
||||
int reinitializer_size;
|
||||
if (!PyArg_ParseTuple(args, "O&s#s#s#", &PyObjectToWorker, &worker, &name, &name_size, &initializer, &initializer_size, &reinitializer, &reinitializer_size)) {
|
||||
return NULL;
|
||||
}
|
||||
std::string name_str(name, static_cast<size_t>(name_size));
|
||||
std::string initializer_str(initializer, static_cast<size_t>(initializer_size));
|
||||
std::string reinitializer_str(reinitializer, static_cast<size_t>(reinitializer_size));
|
||||
worker->export_reusable_variable(name_str, initializer_str, reinitializer_str);
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
static PyObject* submit_task(PyObject* self, PyObject* args) {
|
||||
PyObject* worker_capsule;
|
||||
Task* task;
|
||||
@@ -997,6 +1023,7 @@ static PyMethodDef RayLibMethods[] = {
|
||||
{ "scheduler_info", scheduler_info, METH_VARARGS, "get info about scheduler state" },
|
||||
{ "task_info", task_info, METH_VARARGS, "get task statuses" },
|
||||
{ "export_function", export_function, METH_VARARGS, "export function to workers" },
|
||||
{ "export_reusable_variable", export_reusable_variable, METH_VARARGS, "export a reusable variable to the workers" },
|
||||
{ "dump_computation_graph", dump_computation_graph, METH_VARARGS, "dump the current computation graph to a file" },
|
||||
{ "set_log_config", set_log_config, METH_VARARGS, "set filename for raylib logging" },
|
||||
{ "kill_workers", kill_workers, METH_VARARGS, "kills all of the workers" },
|
||||
|
||||
@@ -305,6 +305,22 @@ Status SchedulerService::ExportFunction(ServerContext* context, const ExportFunc
|
||||
return Status::OK;
|
||||
}
|
||||
|
||||
Status SchedulerService::ExportReusableVariable(ServerContext* context, const ExportReusableVariableRequest* request, AckReply* reply) {
|
||||
auto workers = workers_.get();
|
||||
for (size_t i = 0; i < workers->size(); ++i) {
|
||||
ClientContext import_context;
|
||||
ImportReusableVariableRequest import_request;
|
||||
import_request.set_name(request->name());
|
||||
import_request.mutable_initializer()->set_implementation(request->initializer().implementation());
|
||||
import_request.mutable_reinitializer()->set_implementation(request->reinitializer().implementation());
|
||||
if ((*workers)[i].current_task != ROOT_OPERATION) {
|
||||
AckReply import_reply;
|
||||
(*workers)[i].worker_stub->ImportReusableVariable(&import_context, import_request, &import_reply);
|
||||
}
|
||||
}
|
||||
return Status::OK;
|
||||
}
|
||||
|
||||
void SchedulerService::deliver_object_async_if_necessary(ObjRef canonical_objref, ObjStoreId from, ObjStoreId to) {
|
||||
bool object_present_or_in_transit;
|
||||
{
|
||||
|
||||
@@ -71,6 +71,7 @@ public:
|
||||
Status TaskInfo(ServerContext* context, const TaskInfoRequest* request, TaskInfoReply* reply) override;
|
||||
Status KillWorkers(ServerContext* context, const KillWorkersRequest* request, KillWorkersReply* reply) override;
|
||||
Status ExportFunction(ServerContext* context, const ExportFunctionRequest* request, ExportFunctionReply* reply) override;
|
||||
Status ExportReusableVariable(ServerContext* context, const ExportReusableVariableRequest* request, AckReply* reply) override;
|
||||
|
||||
// This will ask an object store to send an object to another object store if
|
||||
// the object is not already present in that object store and is not already
|
||||
|
||||
@@ -40,6 +40,20 @@ Status WorkerServiceImpl::ImportFunction(ServerContext* context, const ImportFun
|
||||
return Status::OK;
|
||||
}
|
||||
|
||||
Status WorkerServiceImpl::ImportReusableVariable(ServerContext* context, const ImportReusableVariableRequest* request, AckReply* reply) {
|
||||
std::unique_ptr<WorkerMessage> message(new WorkerMessage());
|
||||
message->reusable_variable.variable_name = request->name();
|
||||
message->reusable_variable.initializer = request->initializer().implementation();
|
||||
message->reusable_variable.reinitializer = request->reinitializer().implementation();
|
||||
RAY_LOG(RAY_INFO, "importing reusable variable");
|
||||
{
|
||||
WorkerMessage* message_ptr = message.get();
|
||||
RAY_CHECK(send_queue_.send(&message_ptr), "error sending over IPC");
|
||||
}
|
||||
message.release();
|
||||
return Status::OK;
|
||||
}
|
||||
|
||||
Status WorkerServiceImpl::Die(ServerContext* context, const DieRequest* request, DieReply* reply) {
|
||||
WorkerMessage* message_ptr = NULL;
|
||||
RAY_CHECK(send_queue_.send(&message_ptr), "error sending over IPC");
|
||||
@@ -389,6 +403,17 @@ bool Worker::export_function(const std::string& function) {
|
||||
return true;
|
||||
}
|
||||
|
||||
void Worker::export_reusable_variable(const std::string& name, const std::string& initializer, const std::string& reinitializer) {
|
||||
RAY_CHECK(connected_, "Attempted to export reusable variable but failed.");
|
||||
ClientContext context;
|
||||
ExportReusableVariableRequest request;
|
||||
request.set_name(name);
|
||||
request.mutable_initializer()->set_implementation(initializer);
|
||||
request.mutable_reinitializer()->set_implementation(reinitializer);
|
||||
AckReply reply;
|
||||
Status status = scheduler_stub_->ExportReusableVariable(&context, request, &reply);
|
||||
}
|
||||
|
||||
// Communication between the WorkerServer and the Worker happens via a message
|
||||
// queue. This is because the Python interpreter needs to be single threaded
|
||||
// (in our case running in the main thread), whereas the WorkerService will
|
||||
|
||||
+11
-1
@@ -23,9 +23,16 @@ using grpc::Channel;
|
||||
using grpc::ClientContext;
|
||||
using grpc::ClientWriter;
|
||||
|
||||
struct ReusableVariable {
|
||||
std::string variable_name;
|
||||
std::string initializer;
|
||||
std::string reinitializer;
|
||||
};
|
||||
|
||||
struct WorkerMessage {
|
||||
Task task;
|
||||
std::string function;
|
||||
std::string function; // Used for importing remote functions.
|
||||
ReusableVariable reusable_variable; // Used for importing reusable variables.
|
||||
};
|
||||
|
||||
class WorkerServiceImpl final : public WorkerService::Service {
|
||||
@@ -34,6 +41,7 @@ public:
|
||||
Status ExecuteTask(ServerContext* context, const ExecuteTaskRequest* request, ExecuteTaskReply* reply) override;
|
||||
Status ImportFunction(ServerContext* context, const ImportFunctionRequest* request, ImportFunctionReply* reply) override;
|
||||
Status Die(ServerContext* context, const DieRequest* request, DieReply* reply) override;
|
||||
Status ImportReusableVariable(ServerContext* context, const ImportReusableVariableRequest* request, AckReply* reply) override;
|
||||
private:
|
||||
std::string worker_address_;
|
||||
MessageQueue<WorkerMessage*> send_queue_;
|
||||
@@ -100,6 +108,8 @@ class Worker {
|
||||
void task_info(ClientContext &context, TaskInfoRequest &request, TaskInfoReply &reply);
|
||||
// export function to workers
|
||||
bool export_function(const std::string& function);
|
||||
// export reusable variable to workers
|
||||
void export_reusable_variable(const std::string& name, const std::string& initializer, const std::string& reinitializer);
|
||||
|
||||
private:
|
||||
bool connected_;
|
||||
|
||||
@@ -507,5 +507,84 @@ class PythonCExtensionTest(unittest.TestCase):
|
||||
|
||||
ray.services.cleanup()
|
||||
|
||||
class ReusablesTest(unittest.TestCase):
|
||||
|
||||
def testReusables(self):
|
||||
worker_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_worker.py")
|
||||
ray.services.start_ray_local(num_workers=1, worker_path=worker_path)
|
||||
|
||||
# Test that we can add a variable to the key-value store.
|
||||
|
||||
def foo_initializer():
|
||||
return 1
|
||||
def foo_reinitializer(foo):
|
||||
return foo
|
||||
|
||||
ray.reusables.foo = ray.Reusable(foo_initializer, foo_reinitializer)
|
||||
self.assertEqual(ray.reusables.foo, 1)
|
||||
|
||||
@ray.remote([], [int])
|
||||
def use_foo():
|
||||
return ray.reusables.foo
|
||||
self.assertEqual(ray.get(use_foo()), 1)
|
||||
self.assertEqual(ray.get(use_foo()), 1)
|
||||
self.assertEqual(ray.get(use_foo()), 1)
|
||||
|
||||
# Test that we can add a variable to the key-value store, mutate it, and reset it.
|
||||
|
||||
def bar_initializer():
|
||||
return [1, 2, 3]
|
||||
|
||||
ray.reusables.bar = ray.Reusable(bar_initializer)
|
||||
|
||||
@ray.remote([], [list])
|
||||
def use_bar():
|
||||
ray.reusables.bar.append(4)
|
||||
return ray.reusables.bar
|
||||
self.assertEqual(ray.get(use_bar()), [1, 2, 3, 4])
|
||||
self.assertEqual(ray.get(use_bar()), [1, 2, 3, 4])
|
||||
self.assertEqual(ray.get(use_bar()), [1, 2, 3, 4])
|
||||
|
||||
# Test that we can use the reinitializer.
|
||||
|
||||
def baz_initializer():
|
||||
return np.zeros([4])
|
||||
def baz_reinitializer(baz):
|
||||
for i in range(len(baz)):
|
||||
baz[i] = 0
|
||||
return baz
|
||||
|
||||
ray.reusables.baz = ray.Reusable(baz_initializer, baz_reinitializer)
|
||||
|
||||
@ray.remote([int], [np.ndarray])
|
||||
def use_baz(i):
|
||||
baz = ray.reusables.baz
|
||||
baz[i] = 1
|
||||
return baz
|
||||
self.assertTrue(np.alltrue(ray.get(use_baz(0)) == np.array([1, 0, 0, 0])))
|
||||
self.assertTrue(np.alltrue(ray.get(use_baz(1)) == np.array([0, 1, 0, 0])))
|
||||
self.assertTrue(np.alltrue(ray.get(use_baz(2)) == np.array([0, 0, 1, 0])))
|
||||
self.assertTrue(np.alltrue(ray.get(use_baz(3)) == np.array([0, 0, 0, 1])))
|
||||
|
||||
# Make sure the reinitializer is actually getting called. Note that this is
|
||||
# not the correct usage of a reinitializer because it does not reset qux to
|
||||
# its original state. This is just for testing.
|
||||
|
||||
def qux_initializer():
|
||||
return 0
|
||||
def qux_reinitializer(x):
|
||||
return x + 1
|
||||
|
||||
ray.reusables.qux = ray.Reusable(qux_initializer, qux_reinitializer)
|
||||
|
||||
@ray.remote([], [int])
|
||||
def use_qux():
|
||||
return ray.reusables.qux
|
||||
self.assertEqual(ray.get(use_qux()), 0)
|
||||
self.assertEqual(ray.get(use_qux()), 1)
|
||||
self.assertEqual(ray.get(use_qux()), 2)
|
||||
|
||||
ray.services.cleanup()
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user