From 743f843524d055c6f358ebd0ceb803925d688cb3 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Tue, 1 Mar 2016 01:02:08 -0800 Subject: [PATCH] before refactoring --- CMakeLists.txt | 4 +- include/orchestra/orchestra.h | 3 + lib/orchpy/orchpy/services.py | 4 +- lib/orchpy/orchpy/unison.pyx | 48 ++++++++++ lib/orchpy/orchpy/worker.pyx | 154 ++++++++++++++++++++++++++++--- protos/orchestra.proto | 7 +- protos/types.proto | 7 +- src/objstore.cc | 75 +++++++++++++++- src/objstore.h | 96 ++++---------------- src/orchlib.cc | 12 +-- src/orchlib.h | 5 +- src/scheduler.cc | 165 ++++++++++++++++++++++++++++++++++ src/scheduler.h | 130 +++++++-------------------- src/scheduler_server.cc | 2 +- src/scheduler_server.h | 6 +- src/worker.cc | 134 +++++++++++++++++++++++++++ src/worker.h | 133 +++++++-------------------- test/runtest.py | 147 +++++++++++++----------------- 18 files changed, 729 insertions(+), 403 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 6a3a903dd..c768e77b8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -70,7 +70,7 @@ if (UNIX AND NOT APPLE) endif() add_executable(objstore src/objstore.cc ${GENERATED_PROTOBUF_FILES}) -add_executable(scheduler_server src/scheduler_server.cc ${GENERATED_PROTOBUF_FILES}) -add_library(orchlib SHARED src/orchlib.cc ${GENERATED_PROTOBUF_FILES}) +add_executable(scheduler_server src/scheduler_server.cc src/scheduler.cc ${GENERATED_PROTOBUF_FILES}) +add_library(orchlib SHARED src/orchlib.cc src/worker.cc ${GENERATED_PROTOBUF_FILES}) install(TARGETS objstore scheduler_server orchlib DESTINATION ${CMAKE_SOURCE_DIR}/lib/orchpy/orchpy) diff --git a/include/orchestra/orchestra.h b/include/orchestra/orchestra.h index 4b78905c8..cb0b96341 100644 --- a/include/orchestra/orchestra.h +++ b/include/orchestra/orchestra.h @@ -27,6 +27,9 @@ public: ObjRef worker(size_t i) const { return workers_[i]; } + const std::vector& workers() const { + return workers_; + } }; typedef std::vector > ObjTable; diff --git a/lib/orchpy/orchpy/services.py b/lib/orchpy/orchpy/services.py index 048d21cdd..d22e4ecc7 100644 --- a/lib/orchpy/orchpy/services.py +++ b/lib/orchpy/orchpy/services.py @@ -8,16 +8,18 @@ _services_path = os.path.dirname(os.path.abspath(__file__)) all_processes = [] def cleanup(): + global all_processes timeout_sec = 5 for p in all_processes: p_sec = 0 for second in range(timeout_sec): if p.poll() == None: - time.sleep(1) + time.sleep(0.1) p_sec += 1 if p_sec >= timeout_sec: p.kill() # supported from python 2.6 print 'helper processes shut down!' + all_processes = [] atexit.register(cleanup) diff --git a/lib/orchpy/orchpy/unison.pyx b/lib/orchpy/orchpy/unison.pyx index f0c1c9ee9..0046e95a2 100644 --- a/lib/orchpy/orchpy/unison.pyx +++ b/lib/orchpy/orchpy/unison.pyx @@ -10,6 +10,14 @@ try: except: import pickle +cdef extern from "../../../build/generated/types.pb.h": + + cdef cppclass Call: + Value* add_arg(); + void set_name(const char* value) + Value* mutable_arg(int index); + int arg_size() const; + cdef extern from "../../../build/generated/types.pb.h": ctypedef enum DataType: INT32 @@ -89,6 +97,15 @@ cdef class ObjWrapper: # TODO: unify with the above def get_value(self): return self.thisptr +cdef class PythonCall: + cdef Call* thisptr + def __cinit__(self): + self.thisptr = new Call() + def __dealloc__(self): + del self.thisptr + def get_value(self): + return self.thisptr + cdef class ObjRef: cdef size_t _id cdef object type @@ -136,6 +153,10 @@ cpdef serialize(val): cpdef serialize_args_into(args, valsptr): cdef uintptr_t ptr = valsptr cdef Values* vals = ptr + serialize_args_into_vals(args, vals) + +# this code is a mess right now, will be improved in the C++ version +cdef serialize_args_into_vals(args, Values* vals): cdef Value* val cdef Obj* obj for arg in args: @@ -194,6 +215,33 @@ cpdef deserialize_args(PyValues args): result.append(pickle.loads(data)) return result +# todo: unify with the above, at the moment this is copied +cdef deserialize_args_from_call(Call* call): + cdef Value* val + cdef Obj* obj + result = [] + for i in range(call[0].arg_size()): + val = call[0].mutable_arg(i) + if not val.has_obj(): + result.append(ObjRef(val.ref(), None)) # TODO: fix this + else: + obj = val[0].mutable_obj() + if obj[0].has_string_data(): + result.append(obj[0].mutable_string_data()[0].mutable_data()[0]) + elif obj[0].has_int_data(): + result.append(obj[0].mutable_int_data()[0].data()) + elif obj[0].has_double_data(): + result.append(obj[0].mutable_double_data()[0].data()) + else: + data = obj[0].mutable_pyobj_data()[0].mutable_data()[0] + result.append(pickle.loads(data)) + return result + +cpdef deserialize_call(PythonCall pycall): + cdef Call* call = pycall.thisptr + return deserialize_args_from_call(call) + + cdef int numpy_dtype_to_proto(dtype): if dtype == np.dtype('int32'): return INT32 diff --git a/lib/orchpy/orchpy/worker.pyx b/lib/orchpy/orchpy/worker.pyx index 53f81a2c5..41bc18618 100644 --- a/lib/orchpy/orchpy/worker.pyx +++ b/lib/orchpy/orchpy/worker.pyx @@ -5,6 +5,11 @@ from libc.stdint cimport uint64_t, int64_t, uintptr_t from libcpp cimport bool from libcpp.string cimport string +try: + import cPickle as pickle +except: + import pickle + cdef struct Slice: char* ptr size_t size @@ -13,7 +18,7 @@ cdef extern void* orch_create_context(const char* server_addr, const char* worke cdef extern void orch_register_function(void* worker, const char* name, size_t num_return_vals) cdef extern size_t orch_remote_call(void* context, void* request); cdef extern size_t orch_push(void* context, void* value); -cdef extern void orch_main_loop(void* context); +cdef extern void* orch_main_loop(void* context); cdef extern Slice orch_get_serialized_obj(void* context, size_t objref); cdef extern from "Python.h": @@ -24,15 +29,21 @@ cdef extern from "Python.h": int PyByteArray_Resize(object self, Py_ssize_t size) except -1 char* PyByteArray_AS_STRING(object bytearray) -# cdef extern from "../../../build/generated/orchestra.pb.h": -# cdef cppclass RemoteCallRequest: -# RemoteCallRequest() -# void set_name(const char* value) -# Call* mutable_call() +cdef extern from "../../../build/generated/orchestra.pb.h": + cdef cppclass RemoteCallRequest: + RemoteCallRequest() + Call* mutable_call() + int arg_size() const; cdef extern from "../../../build/generated/types.pb.h": cdef cppclass Values + cdef cppclass Call: + Value* add_arg(); + void set_name(const char* value) + Value* mutable_arg(int index); + int arg_size() const; + ctypedef enum DataType: INT32 INT64 @@ -101,6 +112,30 @@ cdef serialize_into(val, Obj* obj): # pyobj_data = obj[0].mutable_pyobj_data() # pyobj_data[0].set_data(data, len(data)) +""" +cpdef deserialize_call(PyValues args): + cdef Values* vals = args.thisptr + cdef Value* val + cdef Obj* obj + result = [] + for i in range(vals[0].value_size()): + val = vals[0].mutable_value(i) + if not val.has_obj(): + result.append(ObjRef(val.ref(), None)) # TODO: fix this + else: + obj = val[0].mutable_obj() + if obj[0].has_string_data(): + result.append(obj[0].mutable_string_data()[0].mutable_data()[0]) + elif obj[0].has_int_data(): + result.append(obj[0].mutable_int_data()[0].data()) + elif obj[0].has_double_data(): + result.append(obj[0].mutable_double_data()[0].data()) + else: + data = obj[0].mutable_pyobj_data()[0].mutable_data()[0] + result.append(pickle.loads(data)) + return result +""" + cdef class ObjWrapper: # TODO: unify with the above cdef Obj *thisptr def __cinit__(self): @@ -130,6 +165,51 @@ cpdef serialize_into_2(val, objptr): # pyobj_data = obj[0].mutable_pyobj_data() # pyobj_data[0].set_data(data, len(data)) +cdef serialize_args_into_call(args, Call* call): + cdef Value* val + cdef Obj* obj + for arg in args: + val = call.add_arg() + if type(arg) == unison.ObjRef: + val[0].set_ref(arg.get_id()) + else: + obj = val[0].mutable_obj() + objptr = obj + unison.serialize_into(arg, objptr) + +cdef deserialize_obj(Obj* obj): + if obj[0].has_string_data(): + return obj[0].mutable_string_data()[0].mutable_data()[0] + elif obj[0].has_int_data(): + return obj[0].mutable_int_data()[0].data() + elif obj[0].has_double_data(): + return obj[0].mutable_double_data()[0].data() + else: + data = obj[0].mutable_pyobj_data()[0].mutable_data()[0] + return pickle.loads(data) + +# todo: unify with the above, at the moment this is copied +cdef deserialize_args_from_call(Call* call): + cdef Value* val + cdef Obj* obj + result = [] + for i in range(call[0].arg_size()): + val = call[0].mutable_arg(i) + if not val.has_obj(): + result.append(unison.ObjRef(val.ref(), None)) # TODO: fix this + else: + obj = val[0].mutable_obj() + if obj[0].has_string_data(): + result.append(obj[0].mutable_string_data()[0].mutable_data()[0]) + elif obj[0].has_int_data(): + result.append(obj[0].mutable_int_data()[0].data()) + elif obj[0].has_double_data(): + result.append(obj[0].mutable_double_data()[0].data()) + else: + data = obj[0].mutable_pyobj_data()[0].mutable_data()[0] + result.append(pickle.loads(data)) + return result + cdef class Worker: cdef void* context @@ -139,13 +219,13 @@ cdef class Worker: def connect(self, server_addr, worker_addr, objstore_addr): self.context = orch_create_context(server_addr, worker_addr, objstore_addr) -# cpdef call(self, name, args): -# cdef RemoteCallRequest* result = new RemoteCallRequest() -# result[0].set_name(name) -# unison.serialize_args_into(args, result[0].mutable_arg()) -# for i in range(10): -# orch_remote_call(self.context, result) -# # return result + cpdef call(self, name, args): + cdef RemoteCallRequest* result = new RemoteCallRequest() + cdef Call* call = result[0].mutable_call() + call.set_name(name) + serialize_args_into_call(args, call) + orch_remote_call(self.context, result) + # return result cpdef do_call(self, ptr): return orch_remote_call(self.context, ptr) @@ -170,16 +250,62 @@ cdef class Worker: data = PyBytes_FromStringAndSize(slice.ptr, slice.size) return data + cpdef do_pull(self, objref): + cdef Slice slice = orch_get_serialized_obj(self.context, objref) + cpdef pull(self, objref): cdef Slice slice = orch_get_serialized_obj(self.context, objref) + data = PyBytes_FromStringAndSize(slice.ptr, slice.size) + return unison.deserialize_from_string(data) cpdef register_function(self, func_name, num_args): orch_register_function(self.context, func_name, num_args) cpdef main_loop(self): - orch_main_loop(self.context) + result = [] + cdef Call* call = orch_main_loop(self.context) + cdef int size = call[0].arg_size() + cdef Obj* obj + print "hello from python" + print "size", size + return deserialize_args_from_call(call) global_worker = Worker() +def distributed(types, return_type, worker=global_worker): + def distributed_decorator(func): + # deserialize arguments, execute function and serialize result + def func_executor(args): + arguments = [] + protoargs = unison.deserialize_call(args, types) + for (i, proto) in enumerate(protoargs): + if type(proto) == unison.ObjRef: + if i < len(types) - 1: + arguments.append(worker.get_object(proto, types[i])) + elif i == len(types) - 1 and types[-1] is not None: + arguments.append(global_worker.get_object(proto, types[i])) + elif types[-1] is None: + arguments.append(worker.get_object(proto, types[-2])) + else: + raise Exception("Passed in " + str(len(args)) + " arguments to function " + func.__name__ + ", which takes only " + str(len(types)) + " arguments.") + else: + arguments.append(proto) + buf = bytearray() + result = func(*arguments) + if unison.unison_type(result) != return_type: + raise Exception("Return type of " + func.func_name + " does not match the return type specified in the @distributed decorator, was expecting " + str(return_type) + " but received " + str(unison.unison_type(result))) + unison.serialize(buf, result) + return memoryview(buf).tobytes() + # for remotely executing the function + def func_call(*args, typecheck=False): + return worker.call(func_call.func_name, func_call.module_name, args) + func_call.func_name = func.__name__.encode() # why do we call encode()? + func_call.module_name = func.__module__.encode() # why do we call encode()? + func_call.is_distributed = True + func_call.executor = func_executor + func_call.types = types + return func_call + return distributed_decorator + def pull(objref, worker=global_worker): return 1 diff --git a/protos/orchestra.proto b/protos/orchestra.proto index 11d4057ed..f468be196 100644 --- a/protos/orchestra.proto +++ b/protos/orchestra.proto @@ -54,16 +54,17 @@ message ChangeCountRequest { } message GetDebugInfoRequest { - + bool do_scheduling = 1; } message FnTableEntry { - uint64 workerid = 1; + repeated uint64 workerid = 1; uint64 num_return_vals = 2; } message GetDebugInfoReply { repeated Call task = 1; + repeated uint64 avail_worker = 3; map function_table = 2; } @@ -122,7 +123,7 @@ service ObjStore { } message InvokeCallRequest { - + Call call = 1; } message InvokeCallReply { diff --git a/protos/types.proto b/protos/types.proto index f3adea98d..a04ec066d 100644 --- a/protos/types.proto +++ b/protos/types.proto @@ -33,14 +33,9 @@ message Value { Obj obj = 2; // for pass by value } -// Deprecated -message Values { - repeated Value value = 1; -} - message Call { string name = 1; - repeated Value args = 2; + repeated Value arg = 2; repeated uint64 result = 3; } diff --git a/src/objstore.cc b/src/objstore.cc index fd842ba43..7e2551b5c 100644 --- a/src/objstore.cc +++ b/src/objstore.cc @@ -21,7 +21,7 @@ Status ObjStoreClient::upload_data_to(slice data, ObjRef objref, ObjStore::Stub& return writer->Finish(); } -void ObjStoreServiceImpl::allocate_memory(ObjRef objref, size_t size) { +void ObjStoreServer::allocate_memory(ObjRef objref, size_t size) { std::ostringstream stream; stream << "obj-" << memory_names_.size(); std::string name = stream.str(); @@ -37,8 +37,79 @@ void ObjStoreServiceImpl::allocate_memory(ObjRef objref, size_t size) { object.ptr.len = size; } +ObjStore::Stub& ObjStoreServer::get_objstore_stub(const std::string& objstore_address) { + auto iter = objstores_.find(objstore_address); + if (iter != objstores_.end()) + return *(iter->second); + auto channel = grpc::CreateChannel(objstore_address, grpc::InsecureChannelCredentials()); + objstores_.emplace(objstore_address, ObjStore::NewStub(channel)); + return *objstores_[objstore_address]; +} + +Status ObjStoreServer::DeliverObj(ServerContext* context, const DeliverObjRequest* request, AckReply* reply) { + ObjStore::Stub& stub = get_objstore_stub(request->objstore_address()); + ObjRef objref = request->objref(); + // TODO: Have to introduce wait condition + return ObjStoreClient::upload_data_to(memory_[objref].ptr, objref, stub); +} + +Status ObjStoreServer::DebugInfo(ServerContext* context, const DebugInfoRequest* request, DebugInfoReply* reply) { + for (const auto& entry : memory_) { + reply->add_objref(entry.first); + } + return Status::OK; +} + +Status ObjStoreServer::GetObj(ServerContext* context, const GetObjRequest* request, GetObjReply* reply) { + ObjRef objref = request->objref(); + std::cout << "getobj lock"; + memory_lock_.lock(); + shared_object& object = memory_[objref]; + reply->set_bucket(object.name); + auto handle = object.memory->get_handle_from_address(object.ptr.data); + reply->set_handle(handle); + reply->set_size(object.ptr.len); + memory_lock_.unlock(); + std::cout << "getobj unlock"; + return Status::OK; +} + +Status ObjStoreServer::StreamObj(ServerContext* context, ServerReader* reader, AckReply* reply) { + std::cout << "stream obj lock" << std::endl; + memory_lock_.lock(); + ObjChunk chunk; + ObjRef objref = 0; + size_t totalsize = 0; + if (reader->Read(&chunk)) { + objref = chunk.objref(); + totalsize = chunk.totalsize(); + allocate_memory(objref, totalsize); + } + size_t num_bytes = 0; + char* data = memory_[objref].ptr.data; + + std::cout << "before loop " << totalsize << std::endl; + + do { + if (num_bytes + chunk.data().size() > totalsize) { + std::cout << "cancelled" << std::endl; + memory_lock_.unlock(); + return Status::CANCELLED; + } + std::memcpy(data, chunk.data().c_str(), chunk.data().size()); + data += chunk.data().size(); + num_bytes += chunk.data().size(); + std::cout << "looping " << num_bytes << std::endl; + } while (reader->Read(&chunk)); + + std::cout << "finished" << std::endl; + memory_lock_.unlock(); + std::cout << "stream obj unlock" << std::endl; + return Status::OK; +} + void start_objstore(const char* objstore_address) { - ObjStoreServiceImpl service; + ObjStoreServer service; ServerBuilder builder; builder.AddListeningPort(std::string(objstore_address), grpc::InsecureServerCredentials()); diff --git a/src/objstore.h b/src/objstore.h index 6651113dc..e11ea9cd0 100644 --- a/src/objstore.h +++ b/src/objstore.h @@ -37,97 +37,33 @@ struct shared_object { slice ptr; }; -class ObjStoreServiceImpl final : public ObjStore::Service { - std::vector memory_names_; - std::unordered_map memory_; - std::mutex memory_lock_; - size_t page_size = mapped_region::get_page_size(); - std::unordered_map> objstores_; - - void allocate_memory(ObjRef objref, size_t size); - - // check if we already connected to the other objstore, if yes, return reference to connection, otherwise connect - ObjStore::Stub& get_objstore_stub(const std::string& objstore_address) { - auto iter = objstores_.find(objstore_address); - if (iter != objstores_.end()) - return *(iter->second); - auto channel = grpc::CreateChannel(objstore_address, grpc::InsecureChannelCredentials()); - objstores_.emplace(objstore_address, ObjStore::NewStub(channel)); - return *objstores_[objstore_address]; - } - +class ObjStoreServer final : public ObjStore::Service { public: - ObjStoreServiceImpl() {} + ObjStoreServer() {} - ~ObjStoreServiceImpl() { + ~ObjStoreServer() { for (const auto& segment_name : memory_names_) { shared_memory_object::remove(segment_name.c_str()); } } - Status DeliverObj(ServerContext* context, const DeliverObjRequest* request, AckReply* reply) override { - ObjStore::Stub& stub = get_objstore_stub(request->objstore_address()); - ObjRef objref = request->objref(); + Status DeliverObj(ServerContext* context, const DeliverObjRequest* request, AckReply* reply) override; - // TODO: Have to introduce wait condition + Status DebugInfo(ServerContext* context, const DebugInfoRequest* request, DebugInfoReply* reply) override; - return ObjStoreClient::upload_data_to(memory_[objref].ptr, objref, stub); - } + Status GetObj(ServerContext* context, const GetObjRequest* request, GetObjReply* reply) override; - Status DebugInfo(ServerContext* context, const DebugInfoRequest* request, DebugInfoReply* reply) override { - for (const auto& entry : memory_) { - reply->add_objref(entry.first); - } - return Status::OK; - } + Status StreamObj(ServerContext* context, ServerReader* reader, AckReply* reply) override; +private: + void allocate_memory(ObjRef objref, size_t size); + // check if we already connected to the other objstore, if yes, return reference to connection, otherwise connect + ObjStore::Stub& get_objstore_stub(const std::string& objstore_address); - Status GetObj(ServerContext* context, const GetObjRequest* request, GetObjReply* reply) override { - ObjRef objref = request->objref(); - std::cout << "getobj lock"; - memory_lock_.lock(); - shared_object& object = memory_[objref]; - reply->set_bucket(object.name); - auto handle = object.memory->get_handle_from_address(object.ptr.data); - reply->set_handle(handle); - reply->set_size(object.ptr.len); - memory_lock_.unlock(); - std::cout << "getobj unlock"; - return Status::OK; - } - - Status StreamObj(ServerContext* context, ServerReader* reader, AckReply* reply) override { - std::cout << "stream obj lock" << std::endl; - memory_lock_.lock(); - ObjChunk chunk; - ObjRef objref = 0; - size_t totalsize = 0; - if (reader->Read(&chunk)) { - objref = chunk.objref(); - totalsize = chunk.totalsize(); - allocate_memory(objref, totalsize); - } - size_t num_bytes = 0; - char* data = memory_[objref].ptr.data; - - std::cout << "before loop " << totalsize << std::endl; - - do { - if (num_bytes + chunk.data().size() > totalsize) { - std::cout << "cancelled" << std::endl; - memory_lock_.unlock(); - return Status::CANCELLED; - } - std::memcpy(data, chunk.data().c_str(), chunk.data().size()); - data += chunk.data().size(); - num_bytes += chunk.data().size(); - std::cout << "looping " << num_bytes << std::endl; - } while (reader->Read(&chunk)); - - std::cout << "finished" << std::endl; - memory_lock_.unlock(); - std::cout << "stream obj unlock" << std::endl; - return Status::OK; - } + std::vector memory_names_; + std::unordered_map memory_; + std::mutex memory_lock_; + size_t page_size = mapped_region::get_page_size(); + std::unordered_map> objstores_; }; #endif diff --git a/src/orchlib.cc b/src/orchlib.cc index 1ca84c873..2e11a42db 100644 --- a/src/orchlib.cc +++ b/src/orchlib.cc @@ -3,25 +3,25 @@ Worker* orch_create_context(const char* server_addr, const char* worker_addr, const char* objstore_addr) { auto server_channel = grpc::CreateChannel(server_addr, grpc::InsecureChannelCredentials()); auto objstore_channel = grpc::CreateChannel(objstore_addr, grpc::InsecureChannelCredentials()); - Worker* worker = new Worker(server_channel, objstore_channel); + Worker* worker = new Worker(std::string(worker_addr), server_channel, objstore_channel); worker->register_worker(std::string(worker_addr), std::string(objstore_addr)); return worker; } size_t orch_remote_call(Worker* worker, RemoteCallRequest* request) { - return worker->RemoteCall(request); + return worker->remote_call(request); } -void orch_main_loop(Worker* worker) { - worker->MainLoop(); +Call* orch_main_loop(Worker* worker) { + return worker->main_loop(); } size_t orch_push(Worker* worker, Obj* obj) { - return worker->PushObj(obj); + return worker->push_obj(obj); } slice orch_get_serialized_obj(Worker* worker, ObjRef objref) { - return worker->GetSerializedObj(objref); + return worker->get_serialized_obj(objref); } void orch_register_function(Worker* worker, const char* name, size_t num_return_vals) { diff --git a/src/orchlib.h b/src/orchlib.h index 070890745..3ddcf9201 100644 --- a/src/orchlib.h +++ b/src/orchlib.h @@ -1,5 +1,3 @@ - - extern "C" { struct slice { @@ -14,7 +12,8 @@ struct Value; Worker* orch_create_context(const char* server_addr, const char* worker_addr, const char* objstore_addr); size_t orch_remote_call(Worker* context, RemoteCallRequest* request); size_t orch_push(Worker* context, Obj* value); -void orch_main_loop(Worker* worker); +Call* orch_main_loop(Worker* worker); slice orch_get_serialized_obj(Worker* worker, size_t objref); void orch_register_function(Worker* worker, const char* name, size_t num_return_vals); + } diff --git a/src/scheduler.cc b/src/scheduler.cc index e69de29bb..56fc781e4 100644 --- a/src/scheduler.cc +++ b/src/scheduler.cc @@ -0,0 +1,165 @@ +#include "scheduler.h" + +size_t Scheduler::add_task(const Call& task) { + fntable_lock_.lock(); + size_t num_return_vals = fntable_[task.name()].num_return_vals(); + fntable_lock_.unlock(); + std::unique_ptr task_ptr(new Call(task)); + tasks_lock_.lock(); + tasks_.emplace_back(std::move(task_ptr)); + tasks_lock_.unlock(); + return num_return_vals; +} + +void Scheduler::schedule() { + // TODO: work out a better strategy here + WorkerId workerid = 0; + { + std::lock_guard lock(avail_workers_lock_); + if (avail_workers_.size() == 0) + return; + workerid = avail_workers_.back(); + std::cout << "got available worker" << workerid << std::endl; + avail_workers_.pop_back(); + } + // TODO: think about locking here + for (auto it = tasks_.begin(); it != tasks_.end(); ++it) { + const Call& task = *(*it); + auto& workers = fntable_[task.name()].workers(); + if (std::binary_search(workers.begin(), workers.end(), workerid) && can_run(task)) { + submit_task(std::move(*it), workerid); + tasks_.erase(it); + return; + } + } +} + +void Scheduler::submit_task(std::unique_ptr call, WorkerId workerid) { + ClientContext context; + InvokeCallRequest request; + InvokeCallReply reply; + std::cout << "sending arguments now" << std::endl; + for (size_t i = 0; i < call->arg_size(); ++i) { + if (!call->arg(i).has_obj()) { + std::cout << "need to send object ref" << call->arg(i).ref() << std::endl; + std::lock_guard objtable_lock(objtable_lock_); + auto &objstores = objtable_[call->arg(i).ref()]; + std::lock_guard workers_lock(workers_lock_); + if (!std::binary_search(objstores.begin(), objstores.end(), workers_[workerid].objstoreid)) { + std::cout << "have to send" << std::endl; + std::exit(1); + } + // if (objstoreid != workers_[workerid].objstoreid) { + // std::lock_guard objstores_lock(objstores_lock_); + // objstores_. + // } + } + } + request.set_allocated_call(call.release()); // protobuf object takes ownership + Status status = workers_[workerid].worker_stub->InvokeCall(&context, request, &reply); +} + +bool Scheduler::can_run(const Call& task) { + std::lock_guard lock(objtable_lock_); + for (int i = 0; i < task.arg_size(); ++i) { + if (!task.arg(i).has_obj()) { + if (objtable_[task.arg(i).ref()].size() == 0) { + return false; + } + } + } + return true; +} + +WorkerId Scheduler::register_worker(const std::string& worker_address, const std::string& objstore_address) { + ObjStoreId objstoreid = std::numeric_limits::max(); + objstores_lock_.lock(); + for (size_t i = 0; i < objstores_.size(); ++i) { + std::cout << "adress: " << objstores_[i].address << std::endl; + std::cout << "my adress: " << objstore_address << std::endl; + if (objstores_[i].address == objstore_address) { + objstoreid = i; + } + } + if (objstoreid == std::numeric_limits::max()) { + // register objstore + objstoreid = objstores_.size(); + auto channel = grpc::CreateChannel(objstore_address, grpc::InsecureChannelCredentials()); + objstores_.push_back(ObjStoreHandle()); + objstores_[objstoreid].address = objstore_address; + objstores_[objstoreid].channel = channel; + objstores_[objstoreid].objstore_stub = ObjStore::NewStub(channel); + } + objstores_lock_.unlock(); + workers_lock_.lock(); + WorkerId workerid = workers_.size(); + workers_.push_back(WorkerHandle()); + auto channel = grpc::CreateChannel(worker_address, grpc::InsecureChannelCredentials()); + workers_[workerid].channel = channel; + workers_[workerid].objstoreid = objstoreid; + workers_[workerid].worker_stub = WorkerServer::NewStub(channel); + workers_lock_.unlock(); + avail_workers_lock_.lock(); + avail_workers_.push_back(workerid); + avail_workers_lock_.unlock(); + return workerid; +} + +ObjRef Scheduler::register_new_object() { + objtable_lock_.lock(); + ObjRef result = objtable_.size(); + objtable_.push_back(std::vector()); + objtable_lock_.unlock(); + return result; +} + +void Scheduler::add_location(ObjRef objref, ObjStoreId objstoreid) { + objtable_lock_.lock(); + // do a binary search + auto pos = std::lower_bound(objtable_[objref].begin(), objtable_[objref].end(), objstoreid); + if (pos == objtable_[objref].end() || objstoreid < *pos) { + objtable_[objref].insert(pos, objstoreid); + } + objtable_lock_.unlock(); +} + +ObjStoreId Scheduler::get_store(WorkerId workerid) { + workers_lock_.lock(); + ObjStoreId result = workers_[workerid].objstoreid; + workers_lock_.unlock(); + return result; +} + +void Scheduler::register_function(const std::string& name, WorkerId workerid, size_t num_return_vals) { + fntable_lock_.lock(); + FnInfo& info = fntable_[name]; + info.set_num_return_vals(num_return_vals); + info.add_worker(workerid); + fntable_lock_.unlock(); +} + +void Scheduler::debug_info(const GetDebugInfoRequest& request, GetDebugInfoReply* reply) { + if (request.do_scheduling()) { + schedule(); + } + fntable_lock_.lock(); + auto function_table = reply->mutable_function_table(); + for (const auto& entry : fntable_) { + (*function_table)[entry.first].set_num_return_vals(entry.second.num_return_vals()); + for (const WorkerId& worker : entry.second.workers()) { + (*function_table)[entry.first].add_workerid(worker); + } + } + fntable_lock_.unlock(); + tasks_lock_.lock(); + for (const auto& entry : tasks_) { + Call* call = reply->add_task(); + call->CopyFrom(*entry); + } + tasks_lock_.unlock(); + avail_workers_lock_.lock(); + for (const WorkerId& entry : avail_workers_) { + reply->add_avail_worker(entry); + } + avail_workers_lock_.unlock(); +} diff --git a/src/scheduler.h b/src/scheduler.h index 2d21b34c5..13046e464 100644 --- a/src/scheduler.h +++ b/src/scheduler.h @@ -3,6 +3,8 @@ #include #include +#include +#include #include @@ -16,25 +18,52 @@ using grpc::ServerReader; using grpc::ServerContext; using grpc::Status; +using grpc::ClientContext; + using grpc::Channel; struct WorkerHandle { std::shared_ptr channel; + std::unique_ptr worker_stub; ObjStoreId objstoreid; }; struct ObjStoreHandle { std::shared_ptr channel; + std::unique_ptr objstore_stub; std::string address; }; class Scheduler { +public: + // returns number of return values of task + size_t add_task(const Call& task); + // assign a task to a worker + void schedule(); + // execute a task on a worker and ship required object references + void submit_task(std::unique_ptr call, WorkerId workerid); + // checks if the dependencies of the task are met + bool can_run(const Call& task); + // register a worker and its object store (if it has not been registered yet) + WorkerId register_worker(const std::string& worker_address, const std::string& objstore_address); + // register a new object with the scheduler and return its object reference + ObjRef register_new_object(); + // register the location of the object reference in the object table + void add_location(ObjRef objref, ObjStoreId objstoreid); + // get object store associated with a workerid + ObjStoreId get_store(WorkerId workerid); + // register a function with the scheduler + void register_function(const std::string& name, WorkerId workerid, size_t num_return_vals); + // get debugging information for the scheduler + void debug_info(const GetDebugInfoRequest& request, GetDebugInfoReply* reply); +private: // Vector of all workers registered in the system. Their index in this vector // is the workerid. std::vector workers_; std::mutex workers_lock_; // Vector of all workers that are currently idle. - std::vector available_workers_; + std::vector avail_workers_; + std::mutex avail_workers_lock_; // Vector of all object stores registered in the system. Their index in this // vector is the objstoreid. std::vector objstores_; @@ -48,105 +77,6 @@ class Scheduler { // List of pending tasks. std::deque > tasks_; std::mutex tasks_lock_; -public: - // returns number of return values of task - size_t add_task(const Call& task) { - fntable_lock_.lock(); - size_t num_return_vals = fntable_[task.name()].num_return_vals(); - fntable_lock_.unlock(); - std::unique_ptr task_ptr(new Call(task)); // TODO: perform copy outside - tasks_lock_.lock(); - tasks_.emplace_back(std::move(task_ptr)); - tasks_lock_.unlock(); - return num_return_vals; - } - WorkerId register_worker(const std::string& worker_address, const std::string& objstore_address) { - ObjStoreId objstoreid = std::numeric_limits::max(); - objstores_lock_.lock(); - for (size_t i = 0; i < objstores_.size(); ++i) { - std::cout << "adress: " << objstores_[i].address << std::endl; - std::cout << "my adress: " << objstore_address << std::endl; - if (objstores_[i].address == objstore_address) { - objstoreid = i; - } - } - if (objstoreid == std::numeric_limits::max()) { - // throw objstore_not_registered_error("objectstore not registered"); - std::cout << "bad bad bad" << std::endl; - } - objstores_lock_.unlock(); - workers_lock_.lock(); - WorkerId result = workers_.size(); - workers_.push_back(WorkerHandle()); - workers_[result].channel = grpc::CreateChannel(worker_address, grpc::InsecureChannelCredentials()); - workers_[result].objstoreid = objstoreid; - workers_lock_.unlock(); - return result; - } - ObjStoreId register_objstore(const std::string& objstore_address) { - // auto handle = ObjStoreHandle(objstore_address); - // auto handlecopy = handle; - // auto handle = ObjStoreHandle("0.0.0.0:22222"); - objstores_lock_.lock(); - std::cout << "capacity" << objstores_.capacity() << std::endl; - ObjStoreId result = objstores_.size(); - // auto handle = ObjStoreHandle(objstore_address); - // objstores_.emplace_back(objstore_address); - objstores_.push_back(ObjStoreHandle()); - - objstores_[result].channel = grpc::CreateChannel(objstore_address, grpc::InsecureChannelCredentials()); - objstores_[result].address = std::string(objstore_address); - - // auto handlecopy = handle; - // auto handle = grpc::CreateChannel(objstore_address, grpc::InsecureChannelCredentials()); - // auto handlecopy = grpc::CreateChannel(objstore_address, grpc::InsecureChannelCredentials()); - objstores_lock_.unlock(); - return result; - } - ObjRef register_new_object() { - objtable_lock_.lock(); - ObjRef result = objtable_.size(); - objtable_.push_back(std::vector()); - objtable_lock_.unlock(); - return result; - } - void add_objstore_to_obj(ObjRef objref, ObjStoreId objstoreid) { - objtable_lock_.lock(); - // do a binary search - auto pos = std::lower_bound(objtable_[objref].begin(), objtable_[objref].end(), objstoreid); - if (pos == objtable_[objref].end() || objstoreid < *pos) { - objtable_[objref].insert(pos, objstoreid); - } - objtable_lock_.unlock(); - } - ObjStoreId get_store(WorkerId workerid) { - workers_lock_.lock(); - ObjStoreId result = workers_[workerid].objstoreid; - workers_lock_.unlock(); - return result; - } - void register_function(const std::string& name, WorkerId workerid, size_t num_return_vals) { - fntable_lock_.lock(); - FnInfo& info = fntable_[name]; - info.set_num_return_vals(num_return_vals); - info.add_worker(workerid); - fntable_lock_.unlock(); - } - void debug_info(GetDebugInfoReply* debug_info) { - fntable_lock_.lock(); - for (const auto& entry : fntable_) { - auto function_table = debug_info->mutable_function_table(); - (*function_table)[entry.first].set_num_return_vals(entry.second.num_return_vals()); - // TODO: set workerid - } - fntable_lock_.unlock(); - tasks_lock_.lock(); - for (const auto& entry : tasks_) { - Call* call = debug_info->add_task(); - call->CopyFrom(*entry); - } - tasks_lock_.unlock(); - } }; #endif diff --git a/src/scheduler_server.cc b/src/scheduler_server.cc index 2021db4a2..43bbf8846 100644 --- a/src/scheduler_server.cc +++ b/src/scheduler_server.cc @@ -12,7 +12,7 @@ Status SchedulerServerServiceImpl::RemoteCall(ServerContext* context, const Remo Status SchedulerServerServiceImpl::PushObj(ServerContext* context, const PushObjRequest* request, PushObjReply* reply) { ObjRef objref = scheduler_->register_new_object(); ObjStoreId objstoreid = scheduler_->get_store(request->workerid()); - scheduler_->add_objstore_to_obj(objref, objstoreid); + scheduler_->add_location(objref, objstoreid); reply->set_objref(objref); return Status::OK; } diff --git a/src/scheduler_server.h b/src/scheduler_server.h index 0fad71a00..d6d921cc3 100644 --- a/src/scheduler_server.h +++ b/src/scheduler_server.h @@ -22,9 +22,11 @@ public: } Status RegisterWorker(ServerContext* context, const RegisterWorkerRequest* request, RegisterWorkerReply* reply) override { WorkerId workerid = scheduler_->register_worker(request->worker_address(), request->objstore_address()); + std::cout << "registered worker with workerid" << workerid << std::endl; reply->set_workerid(workerid); return Status::OK; } + /* Status RegisterObjStore(ServerContext* context, const RegisterObjStoreRequest* request, RegisterObjStoreReply* reply) override { try { reply->set_objstoreid(scheduler_->register_objstore(request->address())); @@ -33,12 +35,14 @@ public: } return Status::OK; } + */ Status RegisterFunction(ServerContext* context, const RegisterFunctionRequest* request, AckReply* reply) override { + std::cout << "RegisterFunction: workerid is" << request->workerid() << std::endl; scheduler_->register_function(request->fnname(), request->workerid(), request->num_return_vals()); return Status::OK; } Status GetDebugInfo(ServerContext* context, const GetDebugInfoRequest* request, GetDebugInfoReply* reply) override { - scheduler_->debug_info(reply); + scheduler_->debug_info(*request, reply); return Status::OK; } }; diff --git a/src/worker.cc b/src/worker.cc index e69de29bb..addf703a0 100644 --- a/src/worker.cc +++ b/src/worker.cc @@ -0,0 +1,134 @@ +# include "worker.h" + +Status WorkerServiceImpl::InvokeCall(ServerContext* context, const InvokeCallRequest* request, InvokeCallReply* reply) { + call_ = request->call(); + std::cout << "invoke call request" << std::endl; + try { + Call* callptr = &call_; + message_queue mq(open_only, worker_address_.c_str()); + std::cout << "before send: num args" << call_.arg_size() << std::endl; + mq.send(&callptr, sizeof(Call*), 0); + } + catch(interprocess_exception &ex){ + message_queue::remove(worker_address_.c_str()); + std::cout << ex.what() << std::endl; + // TODO: return Status; + } + message_queue::remove(worker_address_.c_str()); + std::cout << "notified server" << std::endl; + return Status::OK; +} + +size_t Worker::remote_call(RemoteCallRequest* request) { + RemoteCallReply reply; + ClientContext context; + Status status = scheduler_stub_->RemoteCall(&context, *request, &reply); + // TODO: Return results: return reply.result(0); +} + +void Worker::register_worker(const std::string& worker_address, const std::string& objstore_address) { + RegisterWorkerRequest request; + request.set_worker_address(worker_address); + request.set_objstore_address(objstore_address); + RegisterWorkerReply reply; + ClientContext context; + Status status = scheduler_stub_->RegisterWorker(&context, request, &reply); + workerid_ = reply.workerid(); + return; +} + +ObjRef Worker::push_obj(Obj* obj) { + // first get objref for the new object + PushObjRequest push_request; + PushObjReply push_reply; + ClientContext push_context; + Status push_status = scheduler_stub_->PushObj(&push_context, push_request, &push_reply); + ObjRef objref = push_reply.objref(); + // then stream the object to the object store + ObjChunk chunk; + std::string data; + obj->SerializeToString(&data); + size_t totalsize = data.size(); + ClientContext context; + AckReply reply; + std::unique_ptr > writer( + objstore_stub_->StreamObj(&context, &reply)); + const char* head = data.c_str(); + for (size_t i = 0; i < data.length(); i += CHUNK_SIZE) { + chunk.set_objref(objref); + std::cout << "chunk totalsize" << std::endl; + chunk.set_totalsize(totalsize); + chunk.set_data(head + i, std::min(CHUNK_SIZE, data.length() - i)); + if (!writer->Write(chunk)) { + std::cout << "write failed" << std::endl; + // TODO: Better error handling: throw std::runtime_error("write failed"); + } + } + writer->WritesDone(); + Status status = writer->Finish(); + return objref; +} + +slice Worker::get_serialized_obj(ObjRef objref) { + ClientContext context; + GetObjRequest request; + request.set_objref(objref); + GetObjReply reply; + objstore_stub_->GetObj(&context, request, &reply); + segment_ = managed_shared_memory(open_only, reply.bucket().c_str()); + slice slice; + slice.data = static_cast(segment_.get_address_from_handle(reply.handle())); + slice.len = reply.size(); + return slice; +} + +void Worker::register_function(const std::string& name, size_t num_return_vals) { + ClientContext context; + RegisterFunctionRequest request; + request.set_fnname(name); + request.set_num_return_vals(num_return_vals); + request.set_workerid(workerid_); + AckReply reply; + scheduler_stub_->RegisterFunction(&context, request, &reply); +} + +void start_worker_server(const char* server_address) { + WorkerServiceImpl service(server_address); + ServerBuilder builder; + builder.AddListeningPort(server_address, grpc::InsecureServerCredentials()); + builder.RegisterService(&service); + std::unique_ptr server(builder.BuildAndStart()); + std::cout << "Server listening on " << server_address << std::endl; + server->Wait(); +} + +// Communication between the WorkerServer and the Worker happens via a message +// queue. This is because the Python interpreter needs to be single threaded +// (in our case running in the main thread), whereas the WorkerService will +// run in a separate thread and potentially utilize multiple threads. +Call* Worker::main_loop() { + // start the worker server + worker_server_thread_ = std::thread(start_worker_server, worker_address_.c_str()); + // process the next call + return receive(worker_address_.c_str()); +} + +Call* receive(const char* message_queue_name) { + try { + message_queue::remove(message_queue_name); + message_queue mq(create_only, message_queue_name, 1, sizeof(Call*)); + unsigned int priority; + message_queue::size_type recvd_size; + Call* call; + while (true) { + mq.receive(&call, sizeof(Call*), recvd_size, priority); + std::cout << "got call" << call << std::endl; + std::cout << "after send: num args" << call->arg_size() << std::endl; + return call; + } + } + catch(interprocess_exception &ex){ + message_queue::remove(message_queue_name); + std::cout << ex.what() << std::endl; + } +} diff --git a/src/worker.h b/src/worker.h index 58aa05c29..e798d1b73 100644 --- a/src/worker.h +++ b/src/worker.h @@ -7,6 +7,8 @@ #include #include +#include + using namespace boost::interprocess; #include @@ -26,118 +28,49 @@ using grpc::ClientContext; using grpc::ClientWriter; class WorkerServiceImpl final : public WorkerServer::Service { - Status InvokeCall(ServerContext* context, const InvokeCallRequest* request, - InvokeCallReply* reply) override { - std::cout << "invoke call request" << std::endl; - return Status::OK; - } +public: + WorkerServiceImpl(const std::string& worker_address) + : worker_address_(worker_address) {} + Status InvokeCall(ServerContext* context, const InvokeCallRequest* request, InvokeCallReply* reply) override; +private: + std::string worker_address_; + Call call_; // copy of the current call }; -void start_server() { - std::string server_address("0.0.0.0:50053"); - WorkerServiceImpl service; - ServerBuilder builder; - builder.AddListeningPort(server_address, grpc::InsecureServerCredentials()); - builder.RegisterService(&service); - std::unique_ptr server(builder.BuildAndStart()); - std::cout << "Server listening on " << server_address << std::endl; - server->Wait(); -} +void start_worker_server(const char* worker_addr); + +Call* receive(const char* worker_addr); class Worker { - managed_shared_memory segment; public: - Worker(std::shared_ptr scheduler_channel, std::shared_ptr objstore_channel) - : scheduler_stub_(SchedulerServer::NewStub(scheduler_channel)), + Worker(const std::string& worker_address, std::shared_ptr scheduler_channel, std::shared_ptr objstore_channel) + : worker_address_(worker_address), + scheduler_stub_(SchedulerServer::NewStub(scheduler_channel)), objstore_stub_(ObjStore::NewStub(objstore_channel)) {} - size_t RemoteCall(RemoteCallRequest* request) { - // RemoteCallReply reply; - // ClientContext context; - - // Status status = stub_->RemoteCall(&context, *request, &reply); - - // return reply.result(); - return 42; - } - - void register_worker(const std::string& worker_address, const std::string& objstore_address) { - RegisterWorkerRequest request; - request.set_worker_address(worker_address); - request.set_objstore_address(objstore_address); - RegisterWorkerReply reply; - ClientContext context; - Status status = scheduler_stub_->RegisterWorker(&context, request, &reply); - return; - } - - const size_t CHUNK_SIZE = 8 * 1024; - - ObjRef PushObj(Obj* obj) { - // first get objref for the new object - PushObjRequest push_request; - PushObjReply push_reply; - ClientContext push_context; - Status push_status = scheduler_stub_->PushObj(&push_context, push_request, &push_reply); - ObjRef objref = push_reply.objref(); - ObjChunk chunk; - std::string data; - obj->SerializeToString(&data); - size_t totalsize = data.size(); - ClientContext context; - AckReply reply; - std::unique_ptr > writer( - objstore_stub_->StreamObj(&context, &reply)); - const char* head = data.c_str(); - for (size_t i = 0; i < data.length(); i += CHUNK_SIZE) { - chunk.set_objref(objref); - std::cout << "chunk totalsize" << std::endl; - chunk.set_totalsize(totalsize); - chunk.set_data(head + i, std::min(CHUNK_SIZE, data.length() - i)); - if (!writer->Write(chunk)) { - std::cout << "write failed" << std::endl; - // throw std::runtime_error("write failed"); - } - } - writer->WritesDone(); - Status status = writer->Finish(); - return objref; - } - - slice GetSerializedObj(ObjRef objref) { - ClientContext context; - GetObjRequest request; - request.set_objref(objref); - GetObjReply reply; - objstore_stub_->GetObj(&context, request, &reply); - segment = managed_shared_memory(open_only, reply.bucket().c_str()); - slice slice; - slice.data = static_cast(segment.get_address_from_handle(reply.handle())); - slice.len = reply.size(); - return slice; - } - - void register_function(const std::string& name, size_t num_return_vals) { - ClientContext context; - RegisterFunctionRequest request; - request.set_fnname(name); - request.set_num_return_vals(num_return_vals); - AckReply reply; - scheduler_stub_->RegisterFunction(&context, request, &reply); - } - - void MainLoop() { - scheduler_thread = std::thread(start_server); - - } - - + // submit a remote call to the scheduler + size_t remote_call(RemoteCallRequest* request); + // send request to the scheduler to register this worker + void register_worker(const std::string& worker_address, const std::string& objstore_address); + // push object to local object store, register it with the server and return object reference + ObjRef push_obj(Obj* obj); + // retrieve serialized object from local object store + slice get_serialized_obj(ObjRef objref); + // register function with scheduler + void register_function(const std::string& name, size_t num_return_vals); + // start the main loop + Call* main_loop(); private: + const size_t CHUNK_SIZE = 8 * 1024; std::unique_ptr scheduler_stub_; std::unique_ptr objstore_stub_; - std::thread scheduler_thread; + std::thread worker_server_thread_; + std::thread other_thread_; + managed_shared_memory segment_; + WorkerId workerid_; + std::string worker_address_; }; #endif diff --git a/test/runtest.py b/test/runtest.py index 119b5d5b2..67e2cf27a 100644 --- a/test/runtest.py +++ b/test/runtest.py @@ -4,6 +4,10 @@ import orchpy.services as services import orchpy.worker as worker import numpy as np import time +import subprocess32 as subprocess +import os + +from google.protobuf.text_format import * from grpc.beta import implementations import orchestra_pb2 @@ -27,12 +31,25 @@ class UnisonTest(unittest.TestCase): b = unison.deserialize_args(res) self.assertTrue((a == b).all()) + a = [unison.ObjRef(42, int)] + res = unison.serialize_args(a) + b = unison.deserialize_args(res) + self.assertEqual(a, b) + TIMEOUT_SECONDS = 5 def produce_data(num_chunks): for _ in range(num_chunks): yield orchestra_pb2.ObjChunk(objref=1, totalsize=1000, data=b"hello world") +def connect_to_scheduler(host, port): + channel = implementations.insecure_channel(host, port) + return orchestra_pb2.beta_create_SchedulerServer_stub(channel) + +def connect_to_objstore(host, port): + channel = implementations.insecure_channel(host, port) + return orchestra_pb2.beta_create_ObjStore_stub(channel) + class ObjStoreTest(unittest.TestCase): """Test setting up object stores, transfering data between them and retrieving data to a client""" @@ -40,123 +57,85 @@ class ObjStoreTest(unittest.TestCase): services.start_scheduler("0.0.0.0:22221") services.start_objstore("0.0.0.0:22222") services.start_objstore("0.0.0.0:22223") - time.sleep(0.5) - scheduler_channel = implementations.insecure_channel('localhost', 22221) - scheduler_stub = orchestra_pb2.beta_create_SchedulerServer_stub(scheduler_channel) - objstore1_channel = implementations.insecure_channel('localhost', 22222) - objstore1_stub = orchestra_pb2.beta_create_ObjStore_stub(objstore1_channel) - objstore2_channel = implementations.insecure_channel('localhost', 22223) - objstore2_stub = orchestra_pb2.beta_create_ObjStore_stub(objstore2_channel) + time.sleep(0.2) - scheduler_stub.RegisterObjStore(orchestra_pb2.RegisterObjStoreRequest(address="127.0.0.1:22222"), TIMEOUT_SECONDS) - scheduler_stub.RegisterObjStore(orchestra_pb2.RegisterObjStoreRequest(address="127.0.0.1:22223"), TIMEOUT_SECONDS) + scheduler_stub = connect_to_scheduler('localhost', 22221) + objstore1_stub = connect_to_objstore('localhost', 22222) + objstore2_stub = connect_to_objstore('localhost', 22223) - worker.global_worker.connect("127.0.0.1:22221", "127.0.0.1:40000", "127.0.0.1:22222") + worker1 = worker.Worker() + worker1.connect("127.0.0.1:22221", "127.0.0.1:40000", "127.0.0.1:22222") - other_worker = worker.Worker() - other_worker.connect("127.0.0.1:22221", "127.0.0.1:40001", "127.0.0.1:22223") + worker2 = worker.Worker() + worker2.connect("127.0.0.1:22221", "127.0.0.1:40001", "127.0.0.1:22223") - # import IPython - # IPython.embed() - - for i in range(1, 10): + for i in range(1, 100): l = i * 100 * "h" - objref = worker.global_worker.do_push(l) - # time.sleep(5.0) + objref = worker1.do_push(l) response = objstore1_stub.DeliverObj(orchestra_pb2.DeliverObjRequest(objref=objref, objstore_address="0.0.0.0:22223"), TIMEOUT_SECONDS) - # time.sleep(5.0) - str = other_worker.get_serialized(objref) - result = worker.unison.deserialize_from_string(str) - # import IPython - # IPython.embed() + s = worker2.get_serialized(objref) + result = worker.unison.deserialize_from_string(s) self.assertEqual(len(result), 100 * i) -class SchedulerTest(unittest.TestCase): + services.cleanup() - def testRegister(self): - scheduler_channel = implementations.insecure_channel('localhost', 22221) - scheduler_stub = orchestra_pb2.beta_create_SchedulerServer_stub(scheduler_channel) - w = worker.Worker() - w.connect("127.0.0.1:22221", "127.0.0.1:40002", "127.0.0.1:22222") - w.register_function("hello_world", 2) - reply = scheduler_stub.GetDebugInfo(orchestra_pb2.GetDebugInfoRequest(), TIMEOUT_SECONDS) - self.assertEqual(reply.function_table.items()[0][0], u'hello_world') +class SchedulerTest(unittest.TestCase): def testCall(self): - scheduler_channel = implementations.insecure_channel('localhost', 22221) - scheduler_stub = orchestra_pb2.beta_create_SchedulerServer_stub(scheduler_channel) - w = worker.Worker() - w.connect("127.0.0.1:22221", "127.0.0.1:40003", "127.0.0.1:22222") - - -""" -class SchedulerTest(unittest.TestCase): - - def testServer(self): services.start_scheduler("0.0.0.0:22221") services.start_objstore("0.0.0.0:22222") - services.start_objstore("0.0.0.0:22223") - time.sleep(1.0) - scheduler_channel = implementations.insecure_channel('localhost', 22221) - scheduler_stub = orchestra_pb2.beta_create_SchedulerServer_stub(scheduler_channel) - objstore_channel = implementations.insecure_channel('localhost', 22222) - objstore_stub = orchestra_pb2.beta_create_ObjStore_stub(objstore_channel) - objstore_channel2 = implementations.insecure_channel('localhost', 22223) - objstore_stub2 = orchestra_pb2.beta_create_ObjStore_stub(objstore_channel2) + time.sleep(0.2) - # call = types_pb2.Call(name="test") - # response = scheduler_stub.RemoteCall(orchestra_pb2.RemoteCallRequest(call=call), TIMEOUT_SECONDS) - # response = scheduler_stub.RegisterFunction(orchestra_pb2.RegisterFunctionRequest(workerid=1, fnname="hello"), TIMEOUT_SECONDS) + scheduler_stub = connect_to_scheduler('localhost', 22221) + objstore_stub = connect_to_objstore('localhost', 22222) - response2 = scheduler_stub.RegisterObjStore(orchestra_pb2.RegisterObjStoreRequest(address="127.0.0.1:22222"), TIMEOUT_SECONDS) - response2 = scheduler_stub.RegisterObjStore(orchestra_pb2.RegisterObjStoreRequest(address="127.0.0.1:22223"), TIMEOUT_SECONDS) + time.sleep(0.2) - # response2 = scheduler_stub.RegisterObjStore(orchestra_pb2.RegisterObjStoreRequest(address="127.0.0.1:22222"), TIMEOUT_SECONDS) - # response3 = scheduler_stub.RegisterObjStore(orchestra_pb2.RegisterObjStoreRequest(address="127.0.0.1:22223"), TIMEOUT_SECONDS) + w = worker.Worker() + w.connect("127.0.0.1:22221", "127.0.0.1:40003", "127.0.0.1:22222") + w2 = worker.Worker() + w2.connect("127.0.0.1:22221", "127.0.0.1:40004", "127.0.0.1:22222") - # objstore_stub.StreamObj(produce_data(100), TIMEOUT_SECONDS) + time.sleep(0.2) - worker.global_worker.connect("127.0.0.1:22221", "127.0.0.1:40000", "127.0.0.1:22222") + w.register_function("hello_world", 2) + w2.register_function("hello_world", 2) - l = [1, 2, 3, 4] - worker.global_worker.do_push(l) + time.sleep(0.1) - ## res = scheduler_stub.PushObj(orchestra_pb2.PushObjRequest(workerid=0), TIMEOUT_SECONDS) + w.call("hello_world", ["hi"]) - response = objstore_stub.DeliverObj(orchestra_pb2.DeliverObjRequest(objref=0, objstore_address="0.0.0.0:22223"), TIMEOUT_SECONDS) + time.sleep(0.1) - # res = objstore_stub2.DebugInfo(orchestra_pb2.DebugInfoRequest(), TIMEOUT_SECONDS) + reply = scheduler_stub.GetDebugInfo(orchestra_pb2.GetDebugInfoRequest(), TIMEOUT_SECONDS) - response = objstore_stub.GetObj(orchestra_pb2.GetObjRequest(objref=0), TIMEOUT_SECONDS) + self.assertEqual(reply.task[0].name, u'hello_world') - worker.global_worker.get_serialized(0) + test_path = os.path.dirname(os.path.abspath(__file__)) - import IPython - IPython.embed() + p = subprocess.Popen(["python", os.path.join(test_path, "testrecv.py")]) - l = [1, 2, 3, 4] - worker.global_worker.do_push(l) + time.sleep(0.2) - response = objstore_stub.DeliverObj(orchestra_pb2.DeliverObjRequest(), TIMEOUT_SECONDS) + scheduler_stub.PushObj(orchestra_pb2.PushObjRequest(workerid=0), TIMEOUT_SECONDS) - # response = objstore_stub.DebugInfo(orchestra_pb2.DebugInfoRequest(), TIMEOUT_SECONDS) + reply = scheduler_stub.GetDebugInfo(orchestra_pb2.GetDebugInfoRequest(do_scheduling=True), TIMEOUT_SECONDS) - # import IPython - # IPython.embed() + self.assertEqual(p.wait(), 0, "argument was not received by the test program") - # worker.global_worker.connect("127.0.0.1:22221", "127.0.0.1:22222") - # l = [1, 2, 3, 4] - # worker.global_worker.do_push(l) + # w.main_loop() + # w2.main_loop() + # + # reply = scheduler_stub.GetDebugInfo(orchestra_pb2.GetDebugInfoRequest(do_scheduling=True), TIMEOUT_SECONDS) + # time.sleep(0.1) + # reply = scheduler_stub.GetDebugInfo(orchestra_pb2.GetDebugInfoRequest(), TIMEOUT_SECONDS) + # + # self.assertEqual(list(reply.task), []) + # + # services.cleanup() - # import IPython - # IPython.embed() - # response = objstore_stub.DeliverObj(orchestra_pb2.DeliverObjRequest()) - # print "Greeter client received: " + response.message - # import IPython - # IPython.embed() -""" if __name__ == '__main__': unittest.main()