Merge pull request #12 from amplab/fix

make ObjStore.GetObj wait until the object is present in the object s…
This commit is contained in:
Philipp Moritz
2016-03-15 00:16:14 -07:00
9 changed files with 106 additions and 54 deletions
+3 -5
View File
@@ -10,7 +10,7 @@ typedef size_t ObjStoreId;
class FnInfo {
size_t num_return_vals_;
std::vector<WorkerId> workers_;
std::vector<WorkerId> workers_; // `workers_` is a sorted vector
public:
void set_num_return_vals(size_t num) {
num_return_vals_ = num;
@@ -19,14 +19,12 @@ public:
return num_return_vals_;
}
void add_worker(WorkerId workerid) {
workers_.push_back(workerid);
// insert `workerid` into `workers_` so that `workers_` stays sorted
workers_.insert(std::lower_bound(workers_.begin(), workers_.end(), workerid), workerid);
}
size_t num_workers() const {
return workers_.size();
}
ObjRef worker(size_t i) const {
return workers_[i];
}
const std::vector<WorkerId>& workers() const {
return workers_;
}
+1 -3
View File
@@ -131,7 +131,6 @@ def get_arguments_for_execution(function, args, worker=global_worker):
"""
for (i, arg) in enumerate(args):
print "Pulling argument {} for function {}.".format(i, function.__name__)
if i < len(function.arg_types) - 1:
expected_type = function.arg_types[i]
elif i == len(function.arg_types) - 1 and function.arg_types[-1] is not None:
@@ -141,10 +140,9 @@ def get_arguments_for_execution(function, args, worker=global_worker):
else:
assert False, "This code should be unreachable."
argument = worker.get_object(arg) if type(arg) == orchpy.lib.ObjRef else arg
if type(arg) == orchpy.lib.ObjRef:
# get the object from the local object store
# TODO(rkn): Do we know that it is already there? Maybe we should call pull(arg, worker).
print "Getting argument {} for function {}.".format(i, function.__name__)
argument = worker.get_object(arg)
else:
# pass the argument by value
+12 -2
View File
@@ -1,4 +1,6 @@
#include "objstore.h"
#include <thread>
#include <chrono>
const size_t ObjStoreClient::CHUNK_SIZE = 8 * 1024;
@@ -84,13 +86,21 @@ Status ObjStoreService::GetObj(ServerContext* context, const GetObjRequest* requ
// TODO(pcm): There is one remaining case where this can fail, i.e. if an object is
// to be delivered from another store but hasn't yet arrived
ObjRef objref = request->objref();
memory_lock_.lock();
while (true) {
// if the object has not been sent to the objstore, this has the potential to lead to an infinite loop
std::this_thread::sleep_for(std::chrono::milliseconds(10));
ORCH_LOG(ORCH_DEBUG, "looping in objstore " << objstoreid_ << " waiting for objref " << objref);
std::lock_guard<std::mutex> memory_lock(memory_lock_);
if (memory_.find(objref) != memory_.end()) {
break;
}
}
std::lock_guard<std::mutex> memory_lock(memory_lock_);
shared_object& object = memory_[objref];
reply->set_bucket(object.name);
auto handle = object.memory->get_handle_from_address(object.ptr.data);
reply->set_handle(handle);
reply->set_size(object.ptr.len);
memory_lock_.unlock();
return Status::OK;
}
+1 -1
View File
@@ -393,7 +393,7 @@ PyObject* pull_object(PyObject* self, PyObject* args) {
if (!PyArg_ParseTuple(args, "O&O&", &PyObjectToWorker, &worker, &PyObjectToObjRef, &objref)) {
return NULL;
}
slice s = worker->get_object(objref);
slice s = worker->pull_object(objref);
Obj* obj = new Obj(); // TODO: Make sure this will get deleted
obj->ParseFromString(std::string(s.data, s.len));
return PyCapsule_New(static_cast<void*>(obj), "obj", NULL);
+63 -25
View File
@@ -16,9 +16,9 @@ Status SchedulerService::RemoteCall(ServerContext* context, const RemoteCallRequ
task->add_result(result);
}
tasks_lock_.lock();
tasks_.emplace_back(std::move(task));
tasks_lock_.unlock();
task_queue_lock_.lock();
task_queue_.emplace_back(std::move(task));
task_queue_lock_.unlock();
schedule();
return Status::OK;
@@ -33,18 +33,25 @@ Status SchedulerService::PushObj(ServerContext* context, const PushObjRequest* r
}
Status SchedulerService::PullObj(ServerContext* context, const PullObjRequest* request, AckReply* reply) {
std::lock_guard<std::mutex> objtable_lock(objtable_lock_);
objtable_lock_.lock();
size_t size = objtable_.size();
objtable_lock_.unlock();
ObjRef objref = request->objref();
if (objref >= objtable_.size() || objtable_[objref].size() == 0) {
if (objref >= size) {
ORCH_LOG(ORCH_FATAL, "internal error: no object with objref " << objref << " exists");
}
ObjStoreId objstoreid = pick_objstore(objref);
deliver_object(objref, objstoreid, get_store(request->workerid()));
pull_queue_lock_.lock();
pull_queue_.push_back(std::make_pair(request->workerid(), objref));
pull_queue_lock_.unlock();
schedule();
return Status::OK;
}
Status SchedulerService::RegisterObjStore(ServerContext* context, const RegisterObjStoreRequest* request, RegisterObjStoreReply* reply) {
std::lock_guard<std::mutex> lock();
std::lock_guard<std::mutex> objstore_lock(objstores_lock_);
ObjStoreId objstoreid = objstores_.size();
auto channel = grpc::CreateChannel(request->address(), grpc::InsecureChannelCredentials());
objstores_.push_back(ObjStoreHandle());
@@ -93,6 +100,9 @@ Status SchedulerService::SchedulerDebugInfo(ServerContext* context, const Schedu
}
void SchedulerService::deliver_object(ObjRef objref, ObjStoreId from, ObjStoreId to) {
if (from == to) {
ORCH_LOG(ORCH_FATAL, "attempting to deliver objref " << objref << " from objstore " << from << " to itself.");
}
ClientContext context;
AckReply reply;
DeliverObjRequest request;
@@ -104,21 +114,49 @@ void SchedulerService::deliver_object(ObjRef objref, ObjStoreId from, ObjStoreId
void SchedulerService::schedule() {
// TODO: don't recheck if nothing changed
std::lock_guard<std::mutex> avail_workers_lock(avail_workers_lock_);
std::lock_guard<std::mutex> fntable_lock(fntable_lock_);
std::lock_guard<std::mutex> tasks_lock(tasks_lock_);
for (int i = 0; i < avail_workers_.size(); ++i) {
WorkerId workerid = avail_workers_[i];
for (auto it = tasks_.begin(); it != tasks_.end(); ++it) {
const Call& task = *(*it);
auto& workers = fntable_[task.name()].workers();
if (std::binary_search(workers.begin(), workers.end(), workerid) && can_run(task)) {
submit_task(std::move(*it), workerid);
tasks_.erase(it);
std::swap(avail_workers_[i], avail_workers_[avail_workers_.size() - 1]);
avail_workers_.pop_back();
{
std::lock_guard<std::mutex> objtable_lock(objtable_lock_);
std::lock_guard<std::mutex> pull_queue_lock(pull_queue_lock_);
// Complete all pull tasks that can be completed.
for (int i = 0; i < pull_queue_.size(); ++i) {
const std::pair<WorkerId, ObjRef>& pull = pull_queue_[i];
WorkerId workerid = pull.first;
ObjRef objref = pull.second;
if (objtable_[objref].size() > 0) {
if (!std::binary_search(objtable_[objref].begin(), objtable_[objref].end(), get_store(workerid))) {
// The worker's local object store does not already contain objref, so ship
// it there from an object store that does have it.
ObjStoreId objstoreid = pick_objstore(objref);
deliver_object(objref, objstoreid, get_store(workerid));
}
// Remove the pull task from the queue
std::swap(pull_queue_[i], pull_queue_[pull_queue_.size() - 1]);
pull_queue_.pop_back();
i -= 1;
break;
}
}
}
{
std::lock_guard<std::mutex> fntable_lock(fntable_lock_);
std::lock_guard<std::mutex> avail_workers_lock(avail_workers_lock_);
std::lock_guard<std::mutex> task_queue_lock(task_queue_lock_);
for (int i = 0; i < avail_workers_.size(); ++i) {
// Submit all tasks whose arguments are ready.
WorkerId workerid = avail_workers_[i];
for (auto it = task_queue_.begin(); it != task_queue_.end(); ++it) {
// The use of erase(it) below invalidates the iterator, but we
// immediately break out of the inner loop, so the iterator is not used
// after the erase
const Call& task = *(*it);
auto& workers = fntable_[task.name()].workers();
if (std::binary_search(workers.begin(), workers.end(), workerid) && can_run(task)) {
submit_task(std::move(*it), workerid);
task_queue_.erase(it);
std::swap(avail_workers_[i], avail_workers_[avail_workers_.size() - 1]);
avail_workers_.pop_back();
i -= 1;
break;
}
}
}
}
@@ -231,12 +269,12 @@ void SchedulerService::debug_info(const SchedulerDebugInfoRequest& request, Sche
}
}
fntable_lock_.unlock();
tasks_lock_.lock();
for (const auto& entry : tasks_) {
task_queue_lock_.lock();
for (const auto& entry : task_queue_) {
Call* call = reply->add_task();
call->CopyFrom(*entry);
}
tasks_lock_.unlock();
task_queue_lock_.unlock();
avail_workers_lock_.lock();
for (const WorkerId& entry : avail_workers_) {
reply->add_avail_worker(entry);
+5 -2
View File
@@ -89,8 +89,11 @@ private:
FnTable fntable_;
std::mutex fntable_lock_;
// List of pending tasks.
std::deque<std::unique_ptr<Call> > tasks_;
std::mutex tasks_lock_;
std::deque<std::unique_ptr<Call> > task_queue_;
std::mutex task_queue_lock_;
// List of pending pull calls.
std::vector<std::pair<WorkerId, ObjRef> > pull_queue_;
std::mutex pull_queue_lock_;
};
#endif
+19 -5
View File
@@ -1,6 +1,9 @@
# include "worker.h"
Status WorkerServiceImpl::InvokeCall(ServerContext* context, const InvokeCallRequest* request, InvokeCallReply* reply) {
// TODO(rkn): This method opens a message_queue, which may consume a
// filehandle. This should be changed to only open a queue once in the
// constructor.
call_ = request->call(); // Copy call
ORCH_LOG(ORCH_INFO, "invoked task " << request->call().name());
try {
@@ -13,10 +16,24 @@ Status WorkerServiceImpl::InvokeCall(ServerContext* context, const InvokeCallReq
std::cout << ex.what() << std::endl;
// TODO: return Status;
}
message_queue::remove(worker_address_.c_str());
return Status::OK;
}
Worker::Worker(const std::string& worker_address, std::shared_ptr<Channel> scheduler_channel, std::shared_ptr<Channel> objstore_channel)
: worker_address_(worker_address),
scheduler_stub_(Scheduler::NewStub(scheduler_channel)),
objstore_stub_(ObjStore::NewStub(objstore_channel)) {
try {
// This creates the receive message queue.
const char* message_queue_name = worker_address_.c_str();
message_queue::remove(message_queue_name);
receive_queue_ = std::unique_ptr<message_queue>(new message_queue(create_only, message_queue_name, 1, sizeof(Call*)));
}
catch(interprocess_exception &ex) {
std::cout << ex.what() << std::endl;
}
}
RemoteCallReply Worker::remote_call(RemoteCallRequest* request) {
RemoteCallReply reply;
ClientContext context;
@@ -108,18 +125,15 @@ void Worker::register_function(const std::string& name, size_t num_return_vals)
Call* Worker::receive_next_task() {
const char* message_queue_name = worker_address_.c_str();
try {
message_queue::remove(message_queue_name);
message_queue mq(create_only, message_queue_name, 1, sizeof(Call*));
unsigned int priority;
message_queue::size_type recvd_size;
Call* call;
while (true) {
mq.receive(&call, sizeof(Call*), recvd_size, priority);
receive_queue_->receive(&call, sizeof(Call*), recvd_size, priority);
return call;
}
}
catch(interprocess_exception &ex){
message_queue::remove(message_queue_name);
std::cout << ex.what() << std::endl;
}
}
+2 -5
View File
@@ -37,11 +37,7 @@ private:
class Worker {
public:
Worker(const std::string& worker_address, std::shared_ptr<Channel> scheduler_channel, std::shared_ptr<Channel> objstore_channel)
: worker_address_(worker_address),
scheduler_stub_(Scheduler::NewStub(scheduler_channel)),
objstore_stub_(ObjStore::NewStub(objstore_channel))
{}
Worker(const std::string& worker_address, std::shared_ptr<Channel> scheduler_channel, std::shared_ptr<Channel> objstore_channel);
// submit a remote call to the scheduler
RemoteCallReply remote_call(RemoteCallRequest* request);
@@ -71,6 +67,7 @@ class Worker {
std::unique_ptr<ObjStore::Stub> objstore_stub_;
std::thread worker_server_thread_;
std::thread other_thread_;
std::unique_ptr<message_queue> receive_queue_;
managed_shared_memory segment_;
WorkerId workerid_;
std::string worker_address_;
-6
View File
@@ -73,30 +73,24 @@ class ArraysSingleTest(unittest.TestCase):
# test eye
ref = single.eye(3)
time.sleep(0.2)
val = orchpy.pull(ref)
self.assertTrue(np.alltrue(val == np.eye(3)))
# test zeros
ref = single.zeros([3, 4, 5])
time.sleep(0.2)
val = orchpy.pull(ref)
self.assertTrue(np.alltrue(val == np.zeros([3, 4, 5])))
# test qr - pass by value
val_a = np.random.normal(size=[10, 13])
time.sleep(0.2)
ref_q, ref_r = single.linalg.qr(val_a)
time.sleep(0.2)
val_q = orchpy.pull(ref_q)
val_r = orchpy.pull(ref_r)
self.assertTrue(np.allclose(np.dot(val_q, val_r), val_a))
# test qr - pass by objref
a = single.random.normal([10, 13])
time.sleep(0.2) # TODO(rkn): fails without this sleep
ref_q, ref_r = single.linalg.qr(a)
time.sleep(0.2)
val_a = orchpy.pull(a)
val_q = orchpy.pull(ref_q)
val_r = orchpy.pull(ref_r)