diff --git a/include/orchestra/orchestra.h b/include/orchestra/orchestra.h index f9bd9f8c6..65688d93e 100644 --- a/include/orchestra/orchestra.h +++ b/include/orchestra/orchestra.h @@ -10,7 +10,7 @@ typedef size_t ObjStoreId; class FnInfo { size_t num_return_vals_; - std::vector workers_; + std::vector workers_; // `workers_` is a sorted vector public: void set_num_return_vals(size_t num) { num_return_vals_ = num; @@ -19,14 +19,12 @@ public: return num_return_vals_; } void add_worker(WorkerId workerid) { - workers_.push_back(workerid); + // insert `workerid` into `workers_` so that `workers_` stays sorted + workers_.insert(std::lower_bound(workers_.begin(), workers_.end(), workerid), workerid); } size_t num_workers() const { return workers_.size(); } - ObjRef worker(size_t i) const { - return workers_[i]; - } const std::vector& workers() const { return workers_; } diff --git a/lib/orchpy/orchpy/worker.py b/lib/orchpy/orchpy/worker.py index aca90688b..87c3e4479 100644 --- a/lib/orchpy/orchpy/worker.py +++ b/lib/orchpy/orchpy/worker.py @@ -131,7 +131,6 @@ def get_arguments_for_execution(function, args, worker=global_worker): """ for (i, arg) in enumerate(args): - print "Pulling argument {} for function {}.".format(i, function.__name__) if i < len(function.arg_types) - 1: expected_type = function.arg_types[i] elif i == len(function.arg_types) - 1 and function.arg_types[-1] is not None: @@ -141,10 +140,9 @@ def get_arguments_for_execution(function, args, worker=global_worker): else: assert False, "This code should be unreachable." - argument = worker.get_object(arg) if type(arg) == orchpy.lib.ObjRef else arg if type(arg) == orchpy.lib.ObjRef: # get the object from the local object store - # TODO(rkn): Do we know that it is already there? Maybe we should call pull(arg, worker). + print "Getting argument {} for function {}.".format(i, function.__name__) argument = worker.get_object(arg) else: # pass the argument by value diff --git a/src/objstore.cc b/src/objstore.cc index b617eb3ac..0991bf43a 100644 --- a/src/objstore.cc +++ b/src/objstore.cc @@ -1,4 +1,6 @@ #include "objstore.h" +#include +#include const size_t ObjStoreClient::CHUNK_SIZE = 8 * 1024; @@ -84,13 +86,21 @@ Status ObjStoreService::GetObj(ServerContext* context, const GetObjRequest* requ // TODO(pcm): There is one remaining case where this can fail, i.e. if an object is // to be delivered from another store but hasn't yet arrived ObjRef objref = request->objref(); - memory_lock_.lock(); + while (true) { + // if the object has not been sent to the objstore, this has the potential to lead to an infinite loop + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + ORCH_LOG(ORCH_DEBUG, "looping in objstore " << objstoreid_ << " waiting for objref " << objref); + std::lock_guard memory_lock(memory_lock_); + if (memory_.find(objref) != memory_.end()) { + break; + } + } + std::lock_guard memory_lock(memory_lock_); shared_object& object = memory_[objref]; reply->set_bucket(object.name); auto handle = object.memory->get_handle_from_address(object.ptr.data); reply->set_handle(handle); reply->set_size(object.ptr.len); - memory_lock_.unlock(); return Status::OK; } diff --git a/src/orchpylib.cc b/src/orchpylib.cc index 04d0ae6a3..eb8312222 100644 --- a/src/orchpylib.cc +++ b/src/orchpylib.cc @@ -393,7 +393,7 @@ PyObject* pull_object(PyObject* self, PyObject* args) { if (!PyArg_ParseTuple(args, "O&O&", &PyObjectToWorker, &worker, &PyObjectToObjRef, &objref)) { return NULL; } - slice s = worker->get_object(objref); + slice s = worker->pull_object(objref); Obj* obj = new Obj(); // TODO: Make sure this will get deleted obj->ParseFromString(std::string(s.data, s.len)); return PyCapsule_New(static_cast(obj), "obj", NULL); diff --git a/src/scheduler.cc b/src/scheduler.cc index 672d57581..e8c1e8503 100644 --- a/src/scheduler.cc +++ b/src/scheduler.cc @@ -16,9 +16,9 @@ Status SchedulerService::RemoteCall(ServerContext* context, const RemoteCallRequ task->add_result(result); } - tasks_lock_.lock(); - tasks_.emplace_back(std::move(task)); - tasks_lock_.unlock(); + task_queue_lock_.lock(); + task_queue_.emplace_back(std::move(task)); + task_queue_lock_.unlock(); schedule(); return Status::OK; @@ -33,18 +33,25 @@ Status SchedulerService::PushObj(ServerContext* context, const PushObjRequest* r } Status SchedulerService::PullObj(ServerContext* context, const PullObjRequest* request, AckReply* reply) { - std::lock_guard objtable_lock(objtable_lock_); + objtable_lock_.lock(); + size_t size = objtable_.size(); + objtable_lock_.unlock(); + ObjRef objref = request->objref(); - if (objref >= objtable_.size() || objtable_[objref].size() == 0) { + if (objref >= size) { ORCH_LOG(ORCH_FATAL, "internal error: no object with objref " << objref << " exists"); } - ObjStoreId objstoreid = pick_objstore(objref); - deliver_object(objref, objstoreid, get_store(request->workerid())); + + pull_queue_lock_.lock(); + pull_queue_.push_back(std::make_pair(request->workerid(), objref)); + pull_queue_lock_.unlock(); + + schedule(); return Status::OK; } Status SchedulerService::RegisterObjStore(ServerContext* context, const RegisterObjStoreRequest* request, RegisterObjStoreReply* reply) { - std::lock_guard lock(); + std::lock_guard objstore_lock(objstores_lock_); ObjStoreId objstoreid = objstores_.size(); auto channel = grpc::CreateChannel(request->address(), grpc::InsecureChannelCredentials()); objstores_.push_back(ObjStoreHandle()); @@ -93,6 +100,9 @@ Status SchedulerService::SchedulerDebugInfo(ServerContext* context, const Schedu } void SchedulerService::deliver_object(ObjRef objref, ObjStoreId from, ObjStoreId to) { + if (from == to) { + ORCH_LOG(ORCH_FATAL, "attempting to deliver objref " << objref << " from objstore " << from << " to itself."); + } ClientContext context; AckReply reply; DeliverObjRequest request; @@ -104,21 +114,49 @@ void SchedulerService::deliver_object(ObjRef objref, ObjStoreId from, ObjStoreId void SchedulerService::schedule() { // TODO: don't recheck if nothing changed - std::lock_guard avail_workers_lock(avail_workers_lock_); - std::lock_guard fntable_lock(fntable_lock_); - std::lock_guard tasks_lock(tasks_lock_); - for (int i = 0; i < avail_workers_.size(); ++i) { - WorkerId workerid = avail_workers_[i]; - for (auto it = tasks_.begin(); it != tasks_.end(); ++it) { - const Call& task = *(*it); - auto& workers = fntable_[task.name()].workers(); - if (std::binary_search(workers.begin(), workers.end(), workerid) && can_run(task)) { - submit_task(std::move(*it), workerid); - tasks_.erase(it); - std::swap(avail_workers_[i], avail_workers_[avail_workers_.size() - 1]); - avail_workers_.pop_back(); + { + std::lock_guard objtable_lock(objtable_lock_); + std::lock_guard pull_queue_lock(pull_queue_lock_); + // Complete all pull tasks that can be completed. + for (int i = 0; i < pull_queue_.size(); ++i) { + const std::pair& pull = pull_queue_[i]; + WorkerId workerid = pull.first; + ObjRef objref = pull.second; + if (objtable_[objref].size() > 0) { + if (!std::binary_search(objtable_[objref].begin(), objtable_[objref].end(), get_store(workerid))) { + // The worker's local object store does not already contain objref, so ship + // it there from an object store that does have it. + ObjStoreId objstoreid = pick_objstore(objref); + deliver_object(objref, objstoreid, get_store(workerid)); + } + // Remove the pull task from the queue + std::swap(pull_queue_[i], pull_queue_[pull_queue_.size() - 1]); + pull_queue_.pop_back(); i -= 1; - break; + } + } + } + { + std::lock_guard fntable_lock(fntable_lock_); + std::lock_guard avail_workers_lock(avail_workers_lock_); + std::lock_guard task_queue_lock(task_queue_lock_); + for (int i = 0; i < avail_workers_.size(); ++i) { + // Submit all tasks whose arguments are ready. + WorkerId workerid = avail_workers_[i]; + for (auto it = task_queue_.begin(); it != task_queue_.end(); ++it) { + // The use of erase(it) below invalidates the iterator, but we + // immediately break out of the inner loop, so the iterator is not used + // after the erase + const Call& task = *(*it); + auto& workers = fntable_[task.name()].workers(); + if (std::binary_search(workers.begin(), workers.end(), workerid) && can_run(task)) { + submit_task(std::move(*it), workerid); + task_queue_.erase(it); + std::swap(avail_workers_[i], avail_workers_[avail_workers_.size() - 1]); + avail_workers_.pop_back(); + i -= 1; + break; + } } } } @@ -231,12 +269,12 @@ void SchedulerService::debug_info(const SchedulerDebugInfoRequest& request, Sche } } fntable_lock_.unlock(); - tasks_lock_.lock(); - for (const auto& entry : tasks_) { + task_queue_lock_.lock(); + for (const auto& entry : task_queue_) { Call* call = reply->add_task(); call->CopyFrom(*entry); } - tasks_lock_.unlock(); + task_queue_lock_.unlock(); avail_workers_lock_.lock(); for (const WorkerId& entry : avail_workers_) { reply->add_avail_worker(entry); diff --git a/src/scheduler.h b/src/scheduler.h index 782f5bebe..afe6ccb0c 100644 --- a/src/scheduler.h +++ b/src/scheduler.h @@ -89,8 +89,11 @@ private: FnTable fntable_; std::mutex fntable_lock_; // List of pending tasks. - std::deque > tasks_; - std::mutex tasks_lock_; + std::deque > task_queue_; + std::mutex task_queue_lock_; + // List of pending pull calls. + std::vector > pull_queue_; + std::mutex pull_queue_lock_; }; #endif diff --git a/src/worker.cc b/src/worker.cc index 194f625bc..387165fff 100644 --- a/src/worker.cc +++ b/src/worker.cc @@ -1,6 +1,9 @@ # include "worker.h" Status WorkerServiceImpl::InvokeCall(ServerContext* context, const InvokeCallRequest* request, InvokeCallReply* reply) { + // TODO(rkn): This method opens a message_queue, which may consume a + // filehandle. This should be changed to only open a queue once in the + // constructor. call_ = request->call(); // Copy call ORCH_LOG(ORCH_INFO, "invoked task " << request->call().name()); try { @@ -13,10 +16,24 @@ Status WorkerServiceImpl::InvokeCall(ServerContext* context, const InvokeCallReq std::cout << ex.what() << std::endl; // TODO: return Status; } - message_queue::remove(worker_address_.c_str()); return Status::OK; } +Worker::Worker(const std::string& worker_address, std::shared_ptr scheduler_channel, std::shared_ptr objstore_channel) + : worker_address_(worker_address), + scheduler_stub_(Scheduler::NewStub(scheduler_channel)), + objstore_stub_(ObjStore::NewStub(objstore_channel)) { + try { + // This creates the receive message queue. + const char* message_queue_name = worker_address_.c_str(); + message_queue::remove(message_queue_name); + receive_queue_ = std::unique_ptr(new message_queue(create_only, message_queue_name, 1, sizeof(Call*))); + } + catch(interprocess_exception &ex) { + std::cout << ex.what() << std::endl; + } +} + RemoteCallReply Worker::remote_call(RemoteCallRequest* request) { RemoteCallReply reply; ClientContext context; @@ -108,18 +125,15 @@ void Worker::register_function(const std::string& name, size_t num_return_vals) Call* Worker::receive_next_task() { const char* message_queue_name = worker_address_.c_str(); try { - message_queue::remove(message_queue_name); - message_queue mq(create_only, message_queue_name, 1, sizeof(Call*)); unsigned int priority; message_queue::size_type recvd_size; Call* call; while (true) { - mq.receive(&call, sizeof(Call*), recvd_size, priority); + receive_queue_->receive(&call, sizeof(Call*), recvd_size, priority); return call; } } catch(interprocess_exception &ex){ - message_queue::remove(message_queue_name); std::cout << ex.what() << std::endl; } } diff --git a/src/worker.h b/src/worker.h index 0ef15487d..50b0d8a31 100644 --- a/src/worker.h +++ b/src/worker.h @@ -37,11 +37,7 @@ private: class Worker { public: - Worker(const std::string& worker_address, std::shared_ptr scheduler_channel, std::shared_ptr objstore_channel) - : worker_address_(worker_address), - scheduler_stub_(Scheduler::NewStub(scheduler_channel)), - objstore_stub_(ObjStore::NewStub(objstore_channel)) - {} + Worker(const std::string& worker_address, std::shared_ptr scheduler_channel, std::shared_ptr objstore_channel); // submit a remote call to the scheduler RemoteCallReply remote_call(RemoteCallRequest* request); @@ -71,6 +67,7 @@ class Worker { std::unique_ptr objstore_stub_; std::thread worker_server_thread_; std::thread other_thread_; + std::unique_ptr receive_queue_; managed_shared_memory segment_; WorkerId workerid_; std::string worker_address_; diff --git a/test/arrays_test.py b/test/arrays_test.py index cebac1017..289d050ac 100644 --- a/test/arrays_test.py +++ b/test/arrays_test.py @@ -73,30 +73,24 @@ class ArraysSingleTest(unittest.TestCase): # test eye ref = single.eye(3) - time.sleep(0.2) val = orchpy.pull(ref) self.assertTrue(np.alltrue(val == np.eye(3))) # test zeros ref = single.zeros([3, 4, 5]) - time.sleep(0.2) val = orchpy.pull(ref) self.assertTrue(np.alltrue(val == np.zeros([3, 4, 5]))) # test qr - pass by value val_a = np.random.normal(size=[10, 13]) - time.sleep(0.2) ref_q, ref_r = single.linalg.qr(val_a) - time.sleep(0.2) val_q = orchpy.pull(ref_q) val_r = orchpy.pull(ref_r) self.assertTrue(np.allclose(np.dot(val_q, val_r), val_a)) # test qr - pass by objref a = single.random.normal([10, 13]) - time.sleep(0.2) # TODO(rkn): fails without this sleep ref_q, ref_r = single.linalg.qr(a) - time.sleep(0.2) val_a = orchpy.pull(a) val_q = orchpy.pull(ref_q) val_r = orchpy.pull(ref_r)