before refactoring

This commit is contained in:
Philipp Moritz
2016-03-01 01:02:08 -08:00
parent 07c6b010d9
commit 743f843524
18 changed files with 729 additions and 403 deletions
+2 -2
View File
@@ -70,7 +70,7 @@ if (UNIX AND NOT APPLE)
endif()
add_executable(objstore src/objstore.cc ${GENERATED_PROTOBUF_FILES})
add_executable(scheduler_server src/scheduler_server.cc ${GENERATED_PROTOBUF_FILES})
add_library(orchlib SHARED src/orchlib.cc ${GENERATED_PROTOBUF_FILES})
add_executable(scheduler_server src/scheduler_server.cc src/scheduler.cc ${GENERATED_PROTOBUF_FILES})
add_library(orchlib SHARED src/orchlib.cc src/worker.cc ${GENERATED_PROTOBUF_FILES})
install(TARGETS objstore scheduler_server orchlib DESTINATION ${CMAKE_SOURCE_DIR}/lib/orchpy/orchpy)
+3
View File
@@ -27,6 +27,9 @@ public:
ObjRef worker(size_t i) const {
return workers_[i];
}
// Read-only access to the whole list of worker ids held in workers_
// (the indexed accessor above returns a single element of the same vector).
const std::vector<WorkerId>& workers() const {
return workers_;
}
};
typedef std::vector<std::vector<ObjStoreId> > ObjTable;
+3 -1
View File
@@ -8,16 +8,18 @@ _services_path = os.path.dirname(os.path.abspath(__file__))
all_processes = []
def cleanup():
global all_processes
timeout_sec = 5
for p in all_processes:
p_sec = 0
for second in range(timeout_sec):
if p.poll() == None:
time.sleep(1)
time.sleep(0.1)
p_sec += 1
if p_sec >= timeout_sec:
p.kill() # supported from python 2.6
print 'helper processes shut down!'
all_processes = []
atexit.register(cleanup)
+48
View File
@@ -10,6 +10,14 @@ try:
except:
import pickle
# Cython view of the protobuf-generated `Call` message: only the accessors
# that the (de)serialization helpers in this module actually use are declared.
cdef extern from "../../../build/generated/types.pb.h":
    cdef cppclass Call:
        void set_name(const char* value)
        int arg_size() const
        Value* add_arg()
        Value* mutable_arg(int index)
cdef extern from "../../../build/generated/types.pb.h":
ctypedef enum DataType:
INT32
@@ -89,6 +97,15 @@ cdef class ObjWrapper: # TODO: unify with the above
def get_value(self):
return <uintptr_t>self.thisptr
cdef class PythonCall:
    """Python-visible owner of a heap-allocated protobuf ``Call`` message.

    The message is created together with the wrapper and destroyed with it.
    """
    cdef Call* thisptr

    def __cinit__(self):
        # One Call per wrapper; released again in __dealloc__.
        self.thisptr = new Call()

    def __dealloc__(self):
        del self.thisptr

    def get_value(self):
        """Return the address of the wrapped ``Call`` as an integer."""
        cdef uintptr_t address = <uintptr_t>self.thisptr
        return address
cdef class ObjRef:
cdef size_t _id
cdef object type
@@ -136,6 +153,10 @@ cpdef serialize(val):
cpdef serialize_args_into(args, valsptr):
    """Serialize ``args`` into the ``Values`` message located at ``valsptr``.

    ``valsptr`` is the address of a ``Values`` message passed as an integer;
    the actual work is delegated to ``serialize_args_into_vals``.
    """
    cdef uintptr_t address = valsptr
    cdef Values* target = <Values*>address
    serialize_args_into_vals(args, target)
# this code is a mess right now, will be improved in the C++ version
cdef serialize_args_into_vals(args, Values* vals):
cdef Value* val
cdef Obj* obj
for arg in args:
@@ -194,6 +215,33 @@ cpdef deserialize_args(PyValues args):
result.append(pickle.loads(data))
return result
# todo: unify with the above, at the moment this is copied
cdef deserialize_args_from_call(Call* call):
    """Turn the arguments stored in ``call`` into a list of Python values.

    Arguments that carry no inline object become ``ObjRef`` instances;
    inline objects are decoded from their string, int or double payload,
    and anything else is unpickled from the pyobj payload.
    """
    cdef Value* val
    cdef Obj* obj
    decoded = []
    for index in range(call[0].arg_size()):
        val = call[0].mutable_arg(index)
        if not val[0].has_obj():
            decoded.append(ObjRef(val[0].ref(), None))  # TODO: fix this
            continue
        obj = val[0].mutable_obj()
        if obj[0].has_string_data():
            item = obj[0].mutable_string_data()[0].mutable_data()[0]
        elif obj[0].has_int_data():
            item = obj[0].mutable_int_data()[0].data()
        elif obj[0].has_double_data():
            item = obj[0].mutable_double_data()[0].data()
        else:
            # Fallback payload: a pickled Python object.
            item = pickle.loads(obj[0].mutable_pyobj_data()[0].mutable_data()[0])
        decoded.append(item)
    return decoded
cpdef deserialize_call(PythonCall pycall):
    """Return the arguments of the ``Call`` wrapped by ``pycall`` as a list."""
    return deserialize_args_from_call(pycall.thisptr)
cdef int numpy_dtype_to_proto(dtype):
if dtype == np.dtype('int32'):
return INT32
+140 -14
View File
@@ -5,6 +5,11 @@ from libc.stdint cimport uint64_t, int64_t, uintptr_t
from libcpp cimport bool
from libcpp.string cimport string
try:
import cPickle as pickle
except:
import pickle
cdef struct Slice:
char* ptr
size_t size
@@ -13,7 +18,7 @@ cdef extern void* orch_create_context(const char* server_addr, const char* worke
cdef extern void orch_register_function(void* worker, const char* name, size_t num_return_vals)
cdef extern size_t orch_remote_call(void* context, void* request);
cdef extern size_t orch_push(void* context, void* value);
cdef extern void orch_main_loop(void* context);
cdef extern void* orch_main_loop(void* context);
cdef extern Slice orch_get_serialized_obj(void* context, size_t objref);
cdef extern from "Python.h":
@@ -24,15 +29,21 @@ cdef extern from "Python.h":
int PyByteArray_Resize(object self, Py_ssize_t size) except -1
char* PyByteArray_AS_STRING(object bytearray)
# cdef extern from "../../../build/generated/orchestra.pb.h":
# cdef cppclass RemoteCallRequest:
# RemoteCallRequest()
# void set_name(const char* value)
# Call* mutable_call()
cdef extern from "../../../build/generated/orchestra.pb.h":
cdef cppclass RemoteCallRequest:
RemoteCallRequest()
Call* mutable_call()
int arg_size() const;
cdef extern from "../../../build/generated/types.pb.h":
cdef cppclass Values
cdef cppclass Call:
Value* add_arg();
void set_name(const char* value)
Value* mutable_arg(int index);
int arg_size() const;
ctypedef enum DataType:
INT32
INT64
@@ -101,6 +112,30 @@ cdef serialize_into(val, Obj* obj):
# pyobj_data = obj[0].mutable_pyobj_data()
# pyobj_data[0].set_data(data, len(data))
"""
cpdef deserialize_call(PyValues args):
cdef Values* vals = args.thisptr
cdef Value* val
cdef Obj* obj
result = []
for i in range(vals[0].value_size()):
val = vals[0].mutable_value(i)
if not val.has_obj():
result.append(ObjRef(val.ref(), None)) # TODO: fix this
else:
obj = val[0].mutable_obj()
if obj[0].has_string_data():
result.append(obj[0].mutable_string_data()[0].mutable_data()[0])
elif obj[0].has_int_data():
result.append(obj[0].mutable_int_data()[0].data())
elif obj[0].has_double_data():
result.append(obj[0].mutable_double_data()[0].data())
else:
data = obj[0].mutable_pyobj_data()[0].mutable_data()[0]
result.append(pickle.loads(data))
return result
"""
cdef class ObjWrapper: # TODO: unify with the above
cdef Obj *thisptr
def __cinit__(self):
@@ -130,6 +165,51 @@ cpdef serialize_into_2(val, objptr):
# pyobj_data = obj[0].mutable_pyobj_data()
# pyobj_data[0].set_data(data, len(data))
cdef serialize_args_into_call(args, Call* call):
    """Append every element of ``args`` to ``call`` as a new argument.

    ``unison.ObjRef`` instances are stored by reference (their id); every
    other value is serialized by value through ``unison.serialize_into``.
    """
    cdef Value* slot
    cdef Obj* payload
    for arg in args:
        slot = call[0].add_arg()
        if type(arg) != unison.ObjRef:
            payload = slot[0].mutable_obj()
            # serialize_into receives the target message as an integer address.
            unison.serialize_into(arg, <uintptr_t>payload)
        else:
            slot[0].set_ref(arg.get_id())
cdef deserialize_obj(Obj* obj):
    """Decode a single ``Obj`` message into a Python value.

    The payloads are probed in the order string, int, double; when none of
    them is set the pyobj payload is unpickled.
    """
    if obj[0].has_string_data():
        return obj[0].mutable_string_data()[0].mutable_data()[0]
    if obj[0].has_int_data():
        return obj[0].mutable_int_data()[0].data()
    if obj[0].has_double_data():
        return obj[0].mutable_double_data()[0].data()
    pickled = obj[0].mutable_pyobj_data()[0].mutable_data()[0]
    return pickle.loads(pickled)
cdef deserialize_args_from_call(Call* call):
    """Turn the arguments stored in ``call`` into a list of Python values.

    Arguments passed by reference (no inline object) are returned as
    ``unison.ObjRef`` instances; arguments passed by value are decoded with
    ``deserialize_obj`` so the payload-decoding rules live in one place
    instead of being copied here.
    """
    cdef Value* val
    result = []
    for i in range(call[0].arg_size()):
        val = call[0].mutable_arg(i)
        if val[0].has_obj():
            result.append(deserialize_obj(val[0].mutable_obj()))
        else:
            result.append(unison.ObjRef(val[0].ref(), None))  # TODO: fix this
    return result
cdef class Worker:
cdef void* context
@@ -139,13 +219,13 @@ cdef class Worker:
def connect(self, server_addr, worker_addr, objstore_addr):
self.context = orch_create_context(server_addr, worker_addr, objstore_addr)
# cpdef call(self, name, args):
# cdef RemoteCallRequest* result = new RemoteCallRequest()
# result[0].set_name(name)
# unison.serialize_args_into(args, <uintptr_t>result[0].mutable_arg())
# for i in range(10):
# orch_remote_call(self.context, result)
# # return <uintptr_t>result
cpdef call(self, name, args):
    """Submit a remote call named ``name`` with arguments ``args``.

    A ``RemoteCallRequest`` is built, its embedded ``Call`` is filled with
    the function name and the serialized arguments, and the request is
    handed to ``orch_remote_call``. Nothing is returned.
    """
    cdef RemoteCallRequest* request = new RemoteCallRequest()
    cdef Call* call_msg
    try:
        call_msg = request[0].mutable_call()
        call_msg[0].set_name(name)
        serialize_args_into_call(args, call_msg)
        orch_remote_call(self.context, request)
    finally:
        # The request was allocated with `new` above and was previously
        # never freed. NOTE(review): this assumes orch_remote_call does not
        # keep the pointer after it returns -- confirm on the C++ side.
        del request
cpdef do_call(self, ptr):
return orch_remote_call(self.context, <void*>ptr)
@@ -170,16 +250,62 @@ cdef class Worker:
data = PyBytes_FromStringAndSize(slice.ptr, slice.size)
return data
cpdef do_pull(self, objref):
cdef Slice slice = orch_get_serialized_obj(self.context, objref)
cpdef pull(self, objref):
cdef Slice slice = orch_get_serialized_obj(self.context, objref)
data = PyBytes_FromStringAndSize(slice.ptr, slice.size)
return unison.deserialize_from_string(data)
cpdef register_function(self, func_name, num_args):
orch_register_function(self.context, func_name, num_args)
cpdef main_loop(self):
orch_main_loop(self.context)
result = []
cdef Call* call = <Call*>orch_main_loop(self.context)
cdef int size = call[0].arg_size()
cdef Obj* obj
print "hello from python"
print "size", size
return deserialize_args_from_call(call)
global_worker = Worker()
def distributed(types, return_type, worker=global_worker):
    """Decorator factory that turns a function into a remotely callable one.

    ``types`` lists the expected argument types; a trailing ``None`` means
    that the type before it is reused for any further arguments.
    ``return_type`` is compared against ``unison.unison_type`` of the result.
    ``worker`` is the worker used both to fetch object references and to
    submit calls.

    The decorated name is replaced by ``func_call`` (which submits the call
    through ``worker``); the local executor is attached as
    ``func_call.executor``.
    """
    def distributed_decorator(func):
        # deserialize arguments, execute function and serialize result
        def func_executor(args):
            arguments = []
            # NOTE(review): the deserialize_call visible in this commit takes a
            # single argument -- confirm which signature is intended.
            protoargs = unison.deserialize_call(args, types)
            for (i, proto) in enumerate(protoargs):
                if type(proto) != unison.ObjRef:
                    arguments.append(proto)
                    continue
                last = len(types) - 1
                if i < last or (i == last and types[-1] is not None):
                    # Always resolve through the worker this decorator was
                    # configured with, never the module-level default.
                    arguments.append(worker.get_object(proto, types[i]))
                elif types[-1] is None:
                    arguments.append(worker.get_object(proto, types[-2]))
                else:
                    # Count the deserialized arguments: `args` is the
                    # still-serialized call object handed to the executor.
                    raise Exception("Passed in " + str(len(protoargs)) + " arguments to function " + func.__name__ + ", which takes only " + str(len(types)) + " arguments.")
            buf = bytearray()
            result = func(*arguments)
            if unison.unison_type(result) != return_type:
                raise Exception("Return type of " + func.__name__ + " does not match the return type specified in the @distributed decorator, was expecting " + str(return_type) + " but received " + str(unison.unison_type(result)))
            unison.serialize(buf, result)
            return memoryview(buf).tobytes()

        # for remotely executing the function
        def func_call(*args, typecheck=False):
            # NOTE(review): Worker.call as visible in this commit accepts
            # (name, args) only -- confirm whether module_name belongs here.
            return worker.call(func_call.func_name, func_call.module_name, args)

        func_call.func_name = func.__name__.encode()  # why do we call encode()?
        func_call.module_name = func.__module__.encode()  # why do we call encode()?
        func_call.is_distributed = True
        func_call.executor = func_executor
        func_call.types = types
        return func_call
    return distributed_decorator
def pull(objref, worker=global_worker):
    """Placeholder: ignores both arguments and always returns 1."""
    return 1
+4 -3
View File
@@ -54,16 +54,17 @@ message ChangeCountRequest {
}
message GetDebugInfoRequest {
bool do_scheduling = 1;
}
message FnTableEntry {
uint64 workerid = 1;
repeated uint64 workerid = 1;
uint64 num_return_vals = 2;
}
message GetDebugInfoReply {
repeated Call task = 1;
repeated uint64 avail_worker = 3;
map<string, FnTableEntry> function_table = 2;
}
@@ -122,7 +123,7 @@ service ObjStore {
}
message InvokeCallRequest {
Call call = 1;
}
message InvokeCallReply {
+1 -6
View File
@@ -33,14 +33,9 @@ message Value {
Obj obj = 2; // for pass by value
}
// Deprecated
message Values {
repeated Value value = 1;
}
message Call {
string name = 1;
repeated Value args = 2;
repeated Value arg = 2;
repeated uint64 result = 3;
}
+73 -2
View File
@@ -21,7 +21,7 @@ Status ObjStoreClient::upload_data_to(slice data, ObjRef objref, ObjStore::Stub&
return writer->Finish();
}
void ObjStoreServiceImpl::allocate_memory(ObjRef objref, size_t size) {
void ObjStoreServer::allocate_memory(ObjRef objref, size_t size) {
std::ostringstream stream;
stream << "obj-" << memory_names_.size();
std::string name = stream.str();
@@ -37,8 +37,79 @@ void ObjStoreServiceImpl::allocate_memory(ObjRef objref, size_t size) {
object.ptr.len = size;
}
// Return the stub for the object store at `objstore_address`. A channel and
// stub are created and cached the first time an address is requested; later
// requests reuse the cached stub.
ObjStore::Stub& ObjStoreServer::get_objstore_stub(const std::string& objstore_address) {
  auto cached = objstores_.find(objstore_address);
  if (cached == objstores_.end()) {
    auto channel = grpc::CreateChannel(objstore_address, grpc::InsecureChannelCredentials());
    cached = objstores_.emplace(objstore_address, ObjStore::NewStub(channel)).first;
  }
  return *(cached->second);
}
// Upload the object named in the request to the object store whose address
// the request carries; the result of the upload is returned to the caller.
Status ObjStoreServer::DeliverObj(ServerContext* context, const DeliverObjRequest* request, AckReply* reply) {
  const ObjRef objref = request->objref();
  ObjStore::Stub& target = get_objstore_stub(request->objstore_address());
  // TODO: Have to introduce wait condition
  return ObjStoreClient::upload_data_to(memory_[objref].ptr, objref, target);
}
// Report every object reference that currently has an entry in memory_.
Status ObjStoreServer::DebugInfo(ServerContext* context, const DebugInfoRequest* request, DebugInfoReply* reply) {
  for (auto it = memory_.begin(); it != memory_.end(); ++it) {
    reply->add_objref(it->first);
  }
  return Status::OK;
}
// Fill `reply` with the bucket name, handle and size of the requested object
// so the caller can locate it. memory_ is accessed under memory_lock_.
Status ObjStoreServer::GetObj(ServerContext* context, const GetObjRequest* request, GetObjReply* reply) {
  ObjRef objref = request->objref();
  std::cout << "getobj lock";
  {
    // Scoped guard: the mutex is released even if one of the calls below throws.
    std::lock_guard<std::mutex> memory_lock(memory_lock_);
    // NOTE(review): operator[] creates an entry for an unknown objref --
    // confirm callers only ask for objects that were streamed in before.
    shared_object& object = memory_[objref];
    reply->set_bucket(object.name);
    auto handle = object.memory->get_handle_from_address(object.ptr.data);
    reply->set_handle(handle);
    reply->set_size(object.ptr.len);
  }
  std::cout << "getobj unlock";
  return Status::OK;
}
// Receive an object as a stream of chunks. The first chunk supplies the
// object reference and the total size used to allocate the destination;
// chunk payloads are then copied back to back into that memory.
// Returns CANCELLED when the stream is empty or when the payloads would
// exceed the announced total size, OK otherwise.
Status ObjStoreServer::StreamObj(ServerContext* context, ServerReader<ObjChunk>* reader, AckReply* reply) {
  std::cout << "stream obj lock" << std::endl;
  {
    // Scoped guard replaces the manual unlock on every exit path.
    std::lock_guard<std::mutex> memory_lock(memory_lock_);
    ObjChunk chunk;
    if (!reader->Read(&chunk)) {
      // Without a first chunk there is neither an object reference nor a
      // size; previously this fell through and touched memory_[0].
      std::cout << "cancelled" << std::endl;
      return Status::CANCELLED;
    }
    const ObjRef objref = chunk.objref();
    const size_t totalsize = chunk.totalsize();
    allocate_memory(objref, totalsize);
    size_t num_bytes = 0;
    char* data = memory_[objref].ptr.data;
    std::cout << "before loop " << totalsize << std::endl;
    do {
      const size_t chunk_size = chunk.data().size();
      if (num_bytes + chunk_size > totalsize) {
        // Refuse to write past the allocation announced by the first chunk.
        std::cout << "cancelled" << std::endl;
        return Status::CANCELLED;
      }
      std::memcpy(data, chunk.data().c_str(), chunk_size);
      data += chunk_size;
      num_bytes += chunk_size;
      std::cout << "looping " << num_bytes << std::endl;
    } while (reader->Read(&chunk));
    std::cout << "finished" << std::endl;
  }
  std::cout << "stream obj unlock" << std::endl;
  return Status::OK;
}
void start_objstore(const char* objstore_address) {
ObjStoreServiceImpl service;
ObjStoreServer service;
ServerBuilder builder;
builder.AddListeningPort(std::string(objstore_address), grpc::InsecureServerCredentials());
+16 -80
View File
@@ -37,97 +37,33 @@ struct shared_object {
slice ptr;
};
class ObjStoreServiceImpl final : public ObjStore::Service {
std::vector<std::string> memory_names_;
std::unordered_map<ObjRef, shared_object> memory_;
std::mutex memory_lock_;
size_t page_size = mapped_region::get_page_size();
std::unordered_map<std::string, std::unique_ptr<ObjStore::Stub>> objstores_;
void allocate_memory(ObjRef objref, size_t size);
// check if we already connected to the other objstore, if yes, return reference to connection, otherwise connect
ObjStore::Stub& get_objstore_stub(const std::string& objstore_address) {
auto iter = objstores_.find(objstore_address);
if (iter != objstores_.end())
return *(iter->second);
auto channel = grpc::CreateChannel(objstore_address, grpc::InsecureChannelCredentials());
objstores_.emplace(objstore_address, ObjStore::NewStub(channel));
return *objstores_[objstore_address];
}
class ObjStoreServer final : public ObjStore::Service {
public:
ObjStoreServiceImpl() {}
ObjStoreServer() {}
~ObjStoreServiceImpl() {
~ObjStoreServer() {
for (const auto& segment_name : memory_names_) {
shared_memory_object::remove(segment_name.c_str());
}
}
Status DeliverObj(ServerContext* context, const DeliverObjRequest* request, AckReply* reply) override {
ObjStore::Stub& stub = get_objstore_stub(request->objstore_address());
ObjRef objref = request->objref();
Status DeliverObj(ServerContext* context, const DeliverObjRequest* request, AckReply* reply) override;
// TODO: Have to introduce wait condition
Status DebugInfo(ServerContext* context, const DebugInfoRequest* request, DebugInfoReply* reply) override;
return ObjStoreClient::upload_data_to(memory_[objref].ptr, objref, stub);
}
Status GetObj(ServerContext* context, const GetObjRequest* request, GetObjReply* reply) override;
Status DebugInfo(ServerContext* context, const DebugInfoRequest* request, DebugInfoReply* reply) override {
for (const auto& entry : memory_) {
reply->add_objref(entry.first);
}
return Status::OK;
}
Status StreamObj(ServerContext* context, ServerReader<ObjChunk>* reader, AckReply* reply) override;
private:
void allocate_memory(ObjRef objref, size_t size);
// check if we already connected to the other objstore, if yes, return reference to connection, otherwise connect
ObjStore::Stub& get_objstore_stub(const std::string& objstore_address);
Status GetObj(ServerContext* context, const GetObjRequest* request, GetObjReply* reply) override {
ObjRef objref = request->objref();
std::cout << "getobj lock";
memory_lock_.lock();
shared_object& object = memory_[objref];
reply->set_bucket(object.name);
auto handle = object.memory->get_handle_from_address(object.ptr.data);
reply->set_handle(handle);
reply->set_size(object.ptr.len);
memory_lock_.unlock();
std::cout << "getobj unlock";
return Status::OK;
}
Status StreamObj(ServerContext* context, ServerReader<ObjChunk>* reader, AckReply* reply) override {
std::cout << "stream obj lock" << std::endl;
memory_lock_.lock();
ObjChunk chunk;
ObjRef objref = 0;
size_t totalsize = 0;
if (reader->Read(&chunk)) {
objref = chunk.objref();
totalsize = chunk.totalsize();
allocate_memory(objref, totalsize);
}
size_t num_bytes = 0;
char* data = memory_[objref].ptr.data;
std::cout << "before loop " << totalsize << std::endl;
do {
if (num_bytes + chunk.data().size() > totalsize) {
std::cout << "cancelled" << std::endl;
memory_lock_.unlock();
return Status::CANCELLED;
}
std::memcpy(data, chunk.data().c_str(), chunk.data().size());
data += chunk.data().size();
num_bytes += chunk.data().size();
std::cout << "looping " << num_bytes << std::endl;
} while (reader->Read(&chunk));
std::cout << "finished" << std::endl;
memory_lock_.unlock();
std::cout << "stream obj unlock" << std::endl;
return Status::OK;
}
std::vector<std::string> memory_names_;
std::unordered_map<ObjRef, shared_object> memory_;
std::mutex memory_lock_;
size_t page_size = mapped_region::get_page_size();
std::unordered_map<std::string, std::unique_ptr<ObjStore::Stub>> objstores_;
};
#endif
+6 -6
View File
@@ -3,25 +3,25 @@
Worker* orch_create_context(const char* server_addr, const char* worker_addr, const char* objstore_addr) {
auto server_channel = grpc::CreateChannel(server_addr, grpc::InsecureChannelCredentials());
auto objstore_channel = grpc::CreateChannel(objstore_addr, grpc::InsecureChannelCredentials());
Worker* worker = new Worker(server_channel, objstore_channel);
Worker* worker = new Worker(std::string(worker_addr), server_channel, objstore_channel);
worker->register_worker(std::string(worker_addr), std::string(objstore_addr));
return worker;
}
size_t orch_remote_call(Worker* worker, RemoteCallRequest* request) {
return worker->RemoteCall(request);
return worker->remote_call(request);
}
void orch_main_loop(Worker* worker) {
worker->MainLoop();
Call* orch_main_loop(Worker* worker) {
return worker->main_loop();
}
size_t orch_push(Worker* worker, Obj* obj) {
return worker->PushObj(obj);
return worker->push_obj(obj);
}
slice orch_get_serialized_obj(Worker* worker, ObjRef objref) {
return worker->GetSerializedObj(objref);
return worker->get_serialized_obj(objref);
}
void orch_register_function(Worker* worker, const char* name, size_t num_return_vals) {
+2 -3
View File
@@ -1,5 +1,3 @@
extern "C" {
struct slice {
@@ -14,7 +12,8 @@ struct Value;
Worker* orch_create_context(const char* server_addr, const char* worker_addr, const char* objstore_addr);
size_t orch_remote_call(Worker* context, RemoteCallRequest* request);
size_t orch_push(Worker* context, Obj* value);
void orch_main_loop(Worker* worker);
Call* orch_main_loop(Worker* worker);
slice orch_get_serialized_obj(Worker* worker, size_t objref);
void orch_register_function(Worker* worker, const char* name, size_t num_return_vals);
}
+165
View File
@@ -0,0 +1,165 @@
#include "scheduler.h"
size_t Scheduler::add_task(const Call& task) {
fntable_lock_.lock();
size_t num_return_vals = fntable_[task.name()].num_return_vals();
fntable_lock_.unlock();
std::unique_ptr<Call> task_ptr(new Call(task));
tasks_lock_.lock();
tasks_.emplace_back(std::move(task_ptr));
tasks_lock_.unlock();
return num_return_vals;
}
void Scheduler::schedule() {
// TODO: work out a better strategy here
WorkerId workerid = 0;
{
std::lock_guard<std::mutex> lock(avail_workers_lock_);
if (avail_workers_.size() == 0)
return;
workerid = avail_workers_.back();
std::cout << "got available worker" << workerid << std::endl;
avail_workers_.pop_back();
}
// TODO: think about locking here
for (auto it = tasks_.begin(); it != tasks_.end(); ++it) {
const Call& task = *(*it);
auto& workers = fntable_[task.name()].workers();
if (std::binary_search(workers.begin(), workers.end(), workerid) && can_run(task)) {
submit_task(std::move(*it), workerid);
tasks_.erase(it);
return;
}
}
}
void Scheduler::submit_task(std::unique_ptr<Call> call, WorkerId workerid) {
ClientContext context;
InvokeCallRequest request;
InvokeCallReply reply;
std::cout << "sending arguments now" << std::endl;
for (size_t i = 0; i < call->arg_size(); ++i) {
if (!call->arg(i).has_obj()) {
std::cout << "need to send object ref" << call->arg(i).ref() << std::endl;
std::lock_guard<std::mutex> objtable_lock(objtable_lock_);
auto &objstores = objtable_[call->arg(i).ref()];
std::lock_guard<std::mutex> workers_lock(workers_lock_);
if (!std::binary_search(objstores.begin(), objstores.end(), workers_[workerid].objstoreid)) {
std::cout << "have to send" << std::endl;
std::exit(1);
}
// if (objstoreid != workers_[workerid].objstoreid) {
// std::lock_guard<std::mutex> objstores_lock(objstores_lock_);
// objstores_.
// }
}
}
request.set_allocated_call(call.release()); // protobuf object takes ownership
Status status = workers_[workerid].worker_stub->InvokeCall(&context, request, &reply);
}
// A task can run once every argument passed by reference has at least one
// location recorded in the object table. Arguments passed by value are
// always ready.
bool Scheduler::can_run(const Call& task) {
  std::lock_guard<std::mutex> guard(objtable_lock_);
  const int num_args = task.arg_size();
  for (int idx = 0; idx < num_args; ++idx) {
    const auto& argument = task.arg(idx);
    if (argument.has_obj())
      continue;
    if (objtable_[argument.ref()].empty())
      return false;
  }
  return true;
}
// Register a worker and its object store (the store is registered only if
// its address has not been seen before). The new worker is appended to
// workers_, marked as available, and its index is returned as worker id.
WorkerId Scheduler::register_worker(const std::string& worker_address, const std::string& objstore_address) {
  // Sentinel meaning "no object store with this address found yet".
  const ObjStoreId not_registered = std::numeric_limits<size_t>::max();
  ObjStoreId objstoreid = not_registered;
  {
    // Scoped guards replace the manual lock()/unlock() pairs so the mutexes
    // are released even if an allocation or channel creation throws.
    std::lock_guard<std::mutex> objstores_lock(objstores_lock_);
    for (size_t i = 0; i < objstores_.size(); ++i) {
      std::cout << "adress: " << objstores_[i].address << std::endl;
      std::cout << "my adress: " << objstore_address << std::endl;
      if (objstores_[i].address == objstore_address) {
        objstoreid = i;
      }
    }
    if (objstoreid == not_registered) {
      // register objstore
      objstoreid = objstores_.size();
      auto channel = grpc::CreateChannel(objstore_address, grpc::InsecureChannelCredentials());
      objstores_.push_back(ObjStoreHandle());
      objstores_[objstoreid].address = objstore_address;
      objstores_[objstoreid].channel = channel;
      objstores_[objstoreid].objstore_stub = ObjStore::NewStub(channel);
    }
  }
  WorkerId workerid = 0;
  {
    std::lock_guard<std::mutex> workers_lock(workers_lock_);
    workerid = workers_.size();
    workers_.push_back(WorkerHandle());
    auto channel = grpc::CreateChannel(worker_address, grpc::InsecureChannelCredentials());
    workers_[workerid].channel = channel;
    workers_[workerid].objstoreid = objstoreid;
    workers_[workerid].worker_stub = WorkerServer::NewStub(channel);
  }
  {
    std::lock_guard<std::mutex> avail_lock(avail_workers_lock_);
    avail_workers_.push_back(workerid);
  }
  return workerid;
}
// Append an empty location list to the object table and return its index,
// which serves as the reference of the new object.
ObjRef Scheduler::register_new_object() {
  // Scoped guard: released on every exit path, including a throwing push_back.
  std::lock_guard<std::mutex> objtable_lock(objtable_lock_);
  ObjRef result = objtable_.size();
  objtable_.push_back(std::vector<ObjStoreId>());
  return result;
}
// Record that object `objref` is available in object store `objstoreid`.
// The per-object location list is kept sorted and free of duplicates.
void Scheduler::add_location(ObjRef objref, ObjStoreId objstoreid) {
  // Scoped guard: released on every exit path, including a throwing insert.
  std::lock_guard<std::mutex> objtable_lock(objtable_lock_);
  std::vector<ObjStoreId>& locations = objtable_[objref];
  // do a binary search
  auto pos = std::lower_bound(locations.begin(), locations.end(), objstoreid);
  if (pos == locations.end() || objstoreid < *pos) {
    locations.insert(pos, objstoreid);
  }
}
// Return the id of the object store associated with worker `workerid`.
ObjStoreId Scheduler::get_store(WorkerId workerid) {
  // Scoped guard instead of a manual lock()/unlock() pair.
  std::lock_guard<std::mutex> workers_lock(workers_lock_);
  return workers_[workerid].objstoreid;
}
void Scheduler::register_function(const std::string& name, WorkerId workerid, size_t num_return_vals) {
fntable_lock_.lock();
FnInfo& info = fntable_[name];
info.set_num_return_vals(num_return_vals);
info.add_worker(workerid);
fntable_lock_.unlock();
}
void Scheduler::debug_info(const GetDebugInfoRequest& request, GetDebugInfoReply* reply) {
if (request.do_scheduling()) {
schedule();
}
fntable_lock_.lock();
auto function_table = reply->mutable_function_table();
for (const auto& entry : fntable_) {
(*function_table)[entry.first].set_num_return_vals(entry.second.num_return_vals());
for (const WorkerId& worker : entry.second.workers()) {
(*function_table)[entry.first].add_workerid(worker);
}
}
fntable_lock_.unlock();
tasks_lock_.lock();
for (const auto& entry : tasks_) {
Call* call = reply->add_task();
call->CopyFrom(*entry);
}
tasks_lock_.unlock();
avail_workers_lock_.lock();
for (const WorkerId& entry : avail_workers_) {
reply->add_avail_worker(entry);
}
avail_workers_lock_.unlock();
}
+30 -100
View File
@@ -3,6 +3,8 @@
#include <deque>
#include <memory>
#include <algorithm>
#include <iostream>
#include <grpc++/grpc++.h>
@@ -16,25 +18,52 @@ using grpc::ServerReader;
using grpc::ServerContext;
using grpc::Status;
using grpc::ClientContext;
using grpc::Channel;
struct WorkerHandle {
std::shared_ptr<Channel> channel;
std::unique_ptr<WorkerServer::Stub> worker_stub;
ObjStoreId objstoreid;
};
struct ObjStoreHandle {
std::shared_ptr<Channel> channel;
std::unique_ptr<ObjStore::Stub> objstore_stub;
std::string address;
};
class Scheduler {
public:
// returns number of return values of task
size_t add_task(const Call& task);
// assign a task to a worker
void schedule();
// execute a task on a worker and ship required object references
void submit_task(std::unique_ptr<Call> call, WorkerId workerid);
// checks if the dependencies of the task are met
bool can_run(const Call& task);
// register a worker and its object store (if it has not been registered yet)
WorkerId register_worker(const std::string& worker_address, const std::string& objstore_address);
// register a new object with the scheduler and return its object reference
ObjRef register_new_object();
// register the location of the object reference in the object table
void add_location(ObjRef objref, ObjStoreId objstoreid);
// get object store associated with a workerid
ObjStoreId get_store(WorkerId workerid);
// register a function with the scheduler
void register_function(const std::string& name, WorkerId workerid, size_t num_return_vals);
// get debugging information for the scheduler
void debug_info(const GetDebugInfoRequest& request, GetDebugInfoReply* reply);
private:
// Vector of all workers registered in the system. Their index in this vector
// is the workerid.
std::vector<WorkerHandle> workers_;
std::mutex workers_lock_;
// Vector of all workers that are currently idle.
std::vector<WorkerId> available_workers_;
std::vector<WorkerId> avail_workers_;
std::mutex avail_workers_lock_;
// Vector of all object stores registered in the system. Their index in this
// vector is the objstoreid.
std::vector<ObjStoreHandle> objstores_;
@@ -48,105 +77,6 @@ class Scheduler {
// List of pending tasks.
std::deque<std::unique_ptr<Call> > tasks_;
std::mutex tasks_lock_;
public:
// returns number of return values of task
size_t add_task(const Call& task) {
fntable_lock_.lock();
size_t num_return_vals = fntable_[task.name()].num_return_vals();
fntable_lock_.unlock();
std::unique_ptr<Call> task_ptr(new Call(task)); // TODO: perform copy outside
tasks_lock_.lock();
tasks_.emplace_back(std::move(task_ptr));
tasks_lock_.unlock();
return num_return_vals;
}
WorkerId register_worker(const std::string& worker_address, const std::string& objstore_address) {
ObjStoreId objstoreid = std::numeric_limits<size_t>::max();
objstores_lock_.lock();
for (size_t i = 0; i < objstores_.size(); ++i) {
std::cout << "adress: " << objstores_[i].address << std::endl;
std::cout << "my adress: " << objstore_address << std::endl;
if (objstores_[i].address == objstore_address) {
objstoreid = i;
}
}
if (objstoreid == std::numeric_limits<size_t>::max()) {
// throw objstore_not_registered_error("objectstore not registered");
std::cout << "bad bad bad" << std::endl;
}
objstores_lock_.unlock();
workers_lock_.lock();
WorkerId result = workers_.size();
workers_.push_back(WorkerHandle());
workers_[result].channel = grpc::CreateChannel(worker_address, grpc::InsecureChannelCredentials());
workers_[result].objstoreid = objstoreid;
workers_lock_.unlock();
return result;
}
ObjStoreId register_objstore(const std::string& objstore_address) {
// auto handle = ObjStoreHandle(objstore_address);
// auto handlecopy = handle;
// auto handle = ObjStoreHandle("0.0.0.0:22222");
objstores_lock_.lock();
std::cout << "capacity" << objstores_.capacity() << std::endl;
ObjStoreId result = objstores_.size();
// auto handle = ObjStoreHandle(objstore_address);
// objstores_.emplace_back(objstore_address);
objstores_.push_back(ObjStoreHandle());
objstores_[result].channel = grpc::CreateChannel(objstore_address, grpc::InsecureChannelCredentials());
objstores_[result].address = std::string(objstore_address);
// auto handlecopy = handle;
// auto handle = grpc::CreateChannel(objstore_address, grpc::InsecureChannelCredentials());
// auto handlecopy = grpc::CreateChannel(objstore_address, grpc::InsecureChannelCredentials());
objstores_lock_.unlock();
return result;
}
ObjRef register_new_object() {
objtable_lock_.lock();
ObjRef result = objtable_.size();
objtable_.push_back(std::vector<ObjStoreId>());
objtable_lock_.unlock();
return result;
}
void add_objstore_to_obj(ObjRef objref, ObjStoreId objstoreid) {
objtable_lock_.lock();
// do a binary search
auto pos = std::lower_bound(objtable_[objref].begin(), objtable_[objref].end(), objstoreid);
if (pos == objtable_[objref].end() || objstoreid < *pos) {
objtable_[objref].insert(pos, objstoreid);
}
objtable_lock_.unlock();
}
ObjStoreId get_store(WorkerId workerid) {
workers_lock_.lock();
ObjStoreId result = workers_[workerid].objstoreid;
workers_lock_.unlock();
return result;
}
void register_function(const std::string& name, WorkerId workerid, size_t num_return_vals) {
fntable_lock_.lock();
FnInfo& info = fntable_[name];
info.set_num_return_vals(num_return_vals);
info.add_worker(workerid);
fntable_lock_.unlock();
}
void debug_info(GetDebugInfoReply* debug_info) {
fntable_lock_.lock();
for (const auto& entry : fntable_) {
auto function_table = debug_info->mutable_function_table();
(*function_table)[entry.first].set_num_return_vals(entry.second.num_return_vals());
// TODO: set workerid
}
fntable_lock_.unlock();
tasks_lock_.lock();
for (const auto& entry : tasks_) {
Call* call = debug_info->add_task();
call->CopyFrom(*entry);
}
tasks_lock_.unlock();
}
};
#endif
+1 -1
View File
@@ -12,7 +12,7 @@ Status SchedulerServerServiceImpl::RemoteCall(ServerContext* context, const Remo
Status SchedulerServerServiceImpl::PushObj(ServerContext* context, const PushObjRequest* request, PushObjReply* reply) {
ObjRef objref = scheduler_->register_new_object();
ObjStoreId objstoreid = scheduler_->get_store(request->workerid());
scheduler_->add_objstore_to_obj(objref, objstoreid);
scheduler_->add_location(objref, objstoreid);
reply->set_objref(objref);
return Status::OK;
}
+5 -1
View File
@@ -22,9 +22,11 @@ public:
}
// Handles the RegisterWorker RPC: registers the worker (identified by its own
// address and the address of its object store) with the scheduler and sends
// the assigned worker id back in the reply. Always returns Status::OK.
Status RegisterWorker(ServerContext* context, const RegisterWorkerRequest* request, RegisterWorkerReply* reply) override {
  const WorkerId assigned_id =
      scheduler_->register_worker(request->worker_address(), request->objstore_address());
  std::cout << "registered worker with workerid" << assigned_id << std::endl;
  reply->set_workerid(assigned_id);
  return Status::OK;
}
/*
Status RegisterObjStore(ServerContext* context, const RegisterObjStoreRequest* request, RegisterObjStoreReply* reply) override {
try {
reply->set_objstoreid(scheduler_->register_objstore(request->address()));
@@ -33,12 +35,14 @@ public:
}
return Status::OK;
}
*/
// Handles the RegisterFunction RPC: logs the id of the registering worker and
// forwards the function name, worker id and number of return values to the
// scheduler. The reply carries no data. Always returns Status::OK.
Status RegisterFunction(ServerContext* context, const RegisterFunctionRequest* request, AckReply* reply) override {
  const auto registering_worker = request->workerid();
  std::cout << "RegisterFunction: workerid is" << registering_worker << std::endl;
  scheduler_->register_function(
      request->fnname(), registering_worker, request->num_return_vals());
  return Status::OK;
}
// Handles the GetDebugInfo RPC by letting the scheduler fill in the reply.
// `context` is not used. Always returns Status::OK.
Status GetDebugInfo(ServerContext* context, const GetDebugInfoRequest* request, GetDebugInfoReply* reply) override {
// NOTE(review): two different call shapes of debug_info appear back to back
// (reply only, and request plus reply). This looks like the old and the new
// form of the same call shown together; confirm which one should remain.
scheduler_->debug_info(reply);
scheduler_->debug_info(*request, reply);
return Status::OK;
}
};
+134
View File
@@ -0,0 +1,134 @@
# include "worker.h"
// Handles the InvokeCall RPC: stores a copy of the requested call in call_ and
// hands a pointer to that copy to the worker's main thread through the
// message queue named after the worker address.
//
// Returns Status::OK when the pointer was sent, and an INTERNAL status
// carrying the exception text when the queue could not be opened or the send
// failed. The queue name is removed in both cases, as before.
//
// NOTE(review): call_ is a single member that is overwritten by every request,
// so the pointer handed out is only valid until the next InvokeCall; this
// presumably relies on calls being processed one at a time. Confirm.
Status WorkerServiceImpl::InvokeCall(ServerContext* context, const InvokeCallRequest* request, InvokeCallReply* reply) {
  call_ = request->call();
  std::cout << "invoke call request" << std::endl;
  Status result = Status::OK;
  try {
    // Only the pointer travels through the queue; the receiver reads call_.
    Call* callptr = &call_;
    message_queue mq(open_only, worker_address_.c_str());
    std::cout << "before send: num args" << call_.arg_size() << std::endl;
    mq.send(&callptr, sizeof(Call*), 0);
  }
  catch (interprocess_exception& ex) {
    std::cout << ex.what() << std::endl;
    // Report the failure to the caller instead of claiming success.
    result = Status(grpc::StatusCode::INTERNAL, ex.what());
  }
  message_queue::remove(worker_address_.c_str());
  if (result.ok()) {
    std::cout << "notified server" << std::endl;
  }
  return result;
}
// Submits a remote call to the scheduler.
//
// The scheduler's reply is not evaluated yet (see the TODO), so the function
// currently always returns the placeholder value 0. A failed RPC is logged.
size_t Worker::remote_call(RemoteCallRequest* request) {
  RemoteCallReply reply;
  ClientContext context;
  Status status = scheduler_stub_->RemoteCall(&context, *request, &reply);
  if (!status.ok()) {
    std::cout << "remote call failed: " << status.error_message() << std::endl;
  }
  // TODO: Return results: return reply.result(0);
  // A non-void function must return a value on every path; flowing off the
  // end is undefined behaviour.
  return 0;
}
// Announces this worker to the scheduler and remembers the worker id taken
// from the scheduler's reply in workerid_.
//
// worker_address:   address under which this worker can be reached
// objstore_address: address of the object store this worker uses
//
// The RPC status is not inspected; workerid_ is assigned from the reply
// regardless of whether the call succeeded.
void Worker::register_worker(const std::string& worker_address, const std::string& objstore_address) {
  ClientContext context;
  RegisterWorkerReply reply;
  RegisterWorkerRequest request;
  request.set_worker_address(worker_address);
  request.set_objstore_address(objstore_address);
  Status rpc_status = scheduler_stub_->RegisterWorker(&context, request, &reply);
  workerid_ = reply.workerid();
}
// Pushes `obj` to the object store and returns its object reference.
//
// First the scheduler is asked for a new object reference; then the
// serialized object is streamed to the object store in pieces of at most
// CHUNK_SIZE bytes, each tagged with the reference and the total size.
//
// The RPC statuses are not inspected; the reference from the scheduler's reply
// is returned even if streaming failed.
ObjRef Worker::push_obj(Obj* obj) {
  // first get objref for the new object
  PushObjRequest push_request;
  // Identify this worker: the scheduler looks up the object store by the
  // request's workerid, which would otherwise stay at its default value.
  push_request.set_workerid(workerid_);
  PushObjReply push_reply;
  ClientContext push_context;
  Status push_status = scheduler_stub_->PushObj(&push_context, push_request, &push_reply);
  ObjRef objref = push_reply.objref();
  // then stream the object to the object store
  std::string data;
  obj->SerializeToString(&data);
  const size_t totalsize = data.size();
  ClientContext context;
  AckReply reply;
  std::unique_ptr<ClientWriter<ObjChunk> > writer(
      objstore_stub_->StreamObj(&context, &reply));
  // objref and totalsize are the same for every chunk; set them once.
  ObjChunk chunk;
  chunk.set_objref(objref);
  chunk.set_totalsize(totalsize);
  const char* head = data.data();
  for (size_t i = 0; i < totalsize; i += CHUNK_SIZE) {
    std::cout << "chunk totalsize" << std::endl;
    chunk.set_data(head + i, std::min(CHUNK_SIZE, totalsize - i));
    if (!writer->Write(chunk)) {
      std::cout << "write failed" << std::endl;
      // A failed Write means the stream is broken; further writes cannot
      // succeed, so stop sending.
      // TODO: Better error handling: throw std::runtime_error("write failed");
      break;
    }
  }
  writer->WritesDone();
  Status status = writer->Finish();
  return objref;
}
// Retrieves the serialized form of the object `objref`.
//
// Asks the object store where the object lives, opens the shared memory
// segment named in the reply (kept in segment_), and returns a slice whose
// data pointer is resolved from the replied handle and whose length is the
// replied size. The RPC status is not inspected.
slice Worker::get_serialized_obj(ObjRef objref) {
  GetObjRequest request;
  request.set_objref(objref);
  GetObjReply reply;
  ClientContext context;
  objstore_stub_->GetObj(&context, request, &reply);
  // Replacing segment_ re-opens the mapping for the bucket of this object.
  segment_ = managed_shared_memory(open_only, reply.bucket().c_str());
  slice result;
  result.len = reply.size();
  result.data = static_cast<char*>(segment_.get_address_from_handle(reply.handle()));
  return result;
}
void Worker::register_function(const std::string& name, size_t num_return_vals) {
ClientContext context;
RegisterFunctionRequest request;
request.set_fnname(name);
request.set_num_return_vals(num_return_vals);
request.set_workerid(workerid_);
AckReply reply;
scheduler_stub_->RegisterFunction(&context, request, &reply);
}
// Runs the worker's gRPC service on `server_address` and blocks until the
// server shuts down. Intended to be run on its own thread.
//
// `server_address` is used both as the listening address and as the worker
// address handed to the service. If the server cannot be started, the failure
// is logged and the function returns immediately.
void start_worker_server(const char* server_address) {
  WorkerServiceImpl service(server_address);
  ServerBuilder builder;
  builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
  builder.RegisterService(&service);
  std::unique_ptr<Server> server(builder.BuildAndStart());
  if (!server) {
    // BuildAndStart yields a null pointer on failure (e.g. the port is in
    // use); calling Wait() on it would dereference null.
    std::cout << "Failed to start server on " << server_address << std::endl;
    return;
  }
  std::cout << "Server listening on " << server_address << std::endl;
  server->Wait();
}
// Communication between the WorkerServer and the Worker happens via a message
// queue. This is because the Python interpreter needs to be single threaded
// (in our case running in the main thread), whereas the WorkerService will
// run in a separate thread and potentially utilize multiple threads.
// Starts the worker server thread if it is not running yet, then blocks until
// the next call arrives on the message queue and returns it.
//
// Safe to call repeatedly: the server thread is created only once.
Call* Worker::main_loop() {
  // Assigning a new std::thread to one that is still joinable terminates the
  // program, so only start the server on the first invocation.
  if (!worker_server_thread_.joinable()) {
    worker_server_thread_ = std::thread(start_worker_server, worker_address_.c_str());
  }
  // process the next call
  return receive(worker_address_.c_str());
}
// Waits for one Call pointer on the message queue `message_queue_name` and
// returns it.
//
// Any existing queue of that name is removed and a fresh queue holding a
// single pointer-sized message is created before waiting. Returns nullptr if
// a queue operation fails; the failure is logged and the queue name removed.
Call* receive(const char* message_queue_name) {
  try {
    message_queue::remove(message_queue_name);
    message_queue mq(create_only, message_queue_name, 1, sizeof(Call*));
    unsigned int priority;
    message_queue::size_type recvd_size;
    Call* call = nullptr;
    // Blocks until a sender has put a pointer into the queue.
    mq.receive(&call, sizeof(Call*), recvd_size, priority);
    std::cout << "got call" << call << std::endl;
    std::cout << "after send: num args" << call->arg_size() << std::endl;
    return call;
  }
  catch (interprocess_exception& ex) {
    message_queue::remove(message_queue_name);
    std::cout << ex.what() << std::endl;
  }
  // Every path of a non-void function must return a value.
  return nullptr;
}
+33 -100
View File
@@ -7,6 +7,8 @@
#include <thread>
#include <boost/interprocess/managed_shared_memory.hpp>
#include <boost/interprocess/ipc/message_queue.hpp>
using namespace boost::interprocess;
#include <grpc++/grpc++.h>
@@ -26,118 +28,49 @@ using grpc::ClientContext;
using grpc::ClientWriter;
// gRPC service run by a worker; receives InvokeCall requests.
//
// NOTE(review): this class body shows InvokeCall twice, once with an inline
// body that only logs and once as a bare declaration. The two look like the
// pre- and post-refactoring versions shown together; confirm which one applies.
class WorkerServiceImpl final : public WorkerServer::Service {
// Inline variant: logs the request and reports success without doing work.
Status InvokeCall(ServerContext* context, const InvokeCallRequest* request,
InvokeCallReply* reply) override {
std::cout << "invoke call request" << std::endl;
return Status::OK;
}
public:
// Stores the worker address for later use by the service.
WorkerServiceImpl(const std::string& worker_address)
: worker_address_(worker_address) {}
// Declared-only variant; the definition is expected in a source file.
Status InvokeCall(ServerContext* context, const InvokeCallRequest* request, InvokeCallReply* reply) override;
private:
std::string worker_address_;
Call call_; // copy of the current call
};
// Runs the worker service on the fixed address 0.0.0.0:50053 and blocks until
// the server shuts down.
void start_server() {
  const std::string listen_address = "0.0.0.0:50053";
  WorkerServiceImpl service;
  ServerBuilder builder;
  builder.AddListeningPort(listen_address, grpc::InsecureServerCredentials());
  builder.RegisterService(&service);
  std::unique_ptr<Server> running_server(builder.BuildAndStart());
  std::cout << "Server listening on " << listen_address << std::endl;
  running_server->Wait();
}
void start_worker_server(const char* worker_addr);
Call* receive(const char* worker_addr);
// Client-side handle of a worker: talks to the scheduler and to an object
// store through gRPC stubs.
//
// NOTE(review): this class body contains two constructor headers, and it
// contains both inline definitions (RemoteCall, PushObj, GetSerializedObj,
// MainLoop, ...) and declaration-only members with the same purpose
// (remote_call, push_obj, get_serialized_obj, main_loop, ...), as well as
// duplicated data members. It looks like the pre- and post-refactoring
// versions shown together; confirm which lines are current before relying on
// any of them.
class Worker {
managed_shared_memory segment;
public:
// Two-channel constructor header (no worker address).
Worker(std::shared_ptr<Channel> scheduler_channel, std::shared_ptr<Channel> objstore_channel)
: scheduler_stub_(SchedulerServer::NewStub(scheduler_channel)),
// Constructor header that additionally takes the worker's own address.
Worker(const std::string& worker_address, std::shared_ptr<Channel> scheduler_channel, std::shared_ptr<Channel> objstore_channel)
: worker_address_(worker_address),
scheduler_stub_(SchedulerServer::NewStub(scheduler_channel)),
objstore_stub_(ObjStore::NewStub(objstore_channel))
{}
// Placeholder: the real RPC is commented out and the constant 42 is returned.
size_t RemoteCall(RemoteCallRequest* request) {
// RemoteCallReply reply;
// ClientContext context;
// Status status = stub_->RemoteCall(&context, *request, &reply);
// return reply.result();
return 42;
}
// Sends the worker and object store addresses to the scheduler. This inline
// variant does not read anything from the reply.
void register_worker(const std::string& worker_address, const std::string& objstore_address) {
RegisterWorkerRequest request;
request.set_worker_address(worker_address);
request.set_objstore_address(objstore_address);
RegisterWorkerReply reply;
ClientContext context;
Status status = scheduler_stub_->RegisterWorker(&context, request, &reply);
return;
}
// Maximum number of payload bytes per streamed chunk.
const size_t CHUNK_SIZE = 8 * 1024;
// Obtains an object reference from the scheduler, then streams the serialized
// object to the object store in chunks of at most CHUNK_SIZE bytes.
ObjRef PushObj(Obj* obj) {
// first get objref for the new object
PushObjRequest push_request;
PushObjReply push_reply;
ClientContext push_context;
Status push_status = scheduler_stub_->PushObj(&push_context, push_request, &push_reply);
ObjRef objref = push_reply.objref();
ObjChunk chunk;
std::string data;
obj->SerializeToString(&data);
size_t totalsize = data.size();
ClientContext context;
AckReply reply;
std::unique_ptr<ClientWriter<ObjChunk> > writer(
objstore_stub_->StreamObj(&context, &reply));
const char* head = data.c_str();
for (size_t i = 0; i < data.length(); i += CHUNK_SIZE) {
chunk.set_objref(objref);
std::cout << "chunk totalsize" << std::endl;
chunk.set_totalsize(totalsize);
// The last chunk may be shorter than CHUNK_SIZE.
chunk.set_data(head + i, std::min(CHUNK_SIZE, data.length() - i));
if (!writer->Write(chunk)) {
std::cout << "write failed" << std::endl;
// throw std::runtime_error("write failed");
}
}
writer->WritesDone();
Status status = writer->Finish();
return objref;
}
// Asks the object store for the object, opens the shared memory segment named
// in the reply, and returns a slice pointing into it.
slice GetSerializedObj(ObjRef objref) {
ClientContext context;
GetObjRequest request;
request.set_objref(objref);
GetObjReply reply;
objstore_stub_->GetObj(&context, request, &reply);
segment = managed_shared_memory(open_only, reply.bucket().c_str());
slice slice;
slice.data = static_cast<char*>(segment.get_address_from_handle(reply.handle()));
slice.len = reply.size();
return slice;
}
// Registers a function name and its number of return values with the
// scheduler. This inline variant does not set a worker id on the request.
void register_function(const std::string& name, size_t num_return_vals) {
ClientContext context;
RegisterFunctionRequest request;
request.set_fnname(name);
request.set_num_return_vals(num_return_vals);
AckReply reply;
scheduler_stub_->RegisterFunction(&context, request, &reply);
}
// Starts start_server on a separate thread.
void MainLoop() {
scheduler_thread = std::thread(start_server);
}
// submit a remote call to the scheduler
size_t remote_call(RemoteCallRequest* request);
// send request to the scheduler to register this worker
void register_worker(const std::string& worker_address, const std::string& objstore_address);
// push object to local object store, register it with the server and return object reference
ObjRef push_obj(Obj* obj);
// retrieve serialized object from local object store
slice get_serialized_obj(ObjRef objref);
// register function with scheduler
void register_function(const std::string& name, size_t num_return_vals);
// start the main loop
Call* main_loop();
private:
const size_t CHUNK_SIZE = 8 * 1024;
std::unique_ptr<SchedulerServer::Stub> scheduler_stub_;
std::unique_ptr<ObjStore::Stub> objstore_stub_;
std::thread scheduler_thread;
std::thread worker_server_thread_;
std::thread other_thread_;
managed_shared_memory segment_;
WorkerId workerid_;
// NOTE(review): declared after the stubs but listed first in the constructor's
// initializer list; members are initialized in declaration order.
std::string worker_address_;
};
#endif
+63 -84
View File
@@ -4,6 +4,10 @@ import orchpy.services as services
import orchpy.worker as worker
import numpy as np
import time
import subprocess32 as subprocess
import os
from google.protobuf.text_format import *
from grpc.beta import implementations
import orchestra_pb2
@@ -27,12 +31,25 @@ class UnisonTest(unittest.TestCase):
b = unison.deserialize_args(res)
self.assertTrue((a == b).all())
a = [unison.ObjRef(42, int)]
res = unison.serialize_args(a)
b = unison.deserialize_args(res)
self.assertEqual(a, b)
TIMEOUT_SECONDS = 5
def produce_data(num_chunks):
    """Yield ``num_chunks`` identical ``ObjChunk`` messages.

    Every chunk carries objref 1, a total size of 1000 and the payload
    ``b"hello world"``. Nothing is yielded when ``num_chunks`` is zero or
    negative.
    """
    chunk_fields = dict(objref=1, totalsize=1000, data=b"hello world")
    for _unused in range(num_chunks):
        yield orchestra_pb2.ObjChunk(**chunk_fields)
def connect_to_scheduler(host, port):
    """Return a SchedulerServer stub on an insecure channel to ``host:port``."""
    return orchestra_pb2.beta_create_SchedulerServer_stub(
        implementations.insecure_channel(host, port))
def connect_to_objstore(host, port):
    """Return an ObjStore stub on an insecure channel to ``host:port``."""
    return orchestra_pb2.beta_create_ObjStore_stub(
        implementations.insecure_channel(host, port))
class ObjStoreTest(unittest.TestCase):
"""Test setting up object stores, transferring data between them and retrieving data to a client"""
@@ -40,123 +57,85 @@ class ObjStoreTest(unittest.TestCase):
services.start_scheduler("0.0.0.0:22221")
services.start_objstore("0.0.0.0:22222")
services.start_objstore("0.0.0.0:22223")
time.sleep(0.5)
scheduler_channel = implementations.insecure_channel('localhost', 22221)
scheduler_stub = orchestra_pb2.beta_create_SchedulerServer_stub(scheduler_channel)
objstore1_channel = implementations.insecure_channel('localhost', 22222)
objstore1_stub = orchestra_pb2.beta_create_ObjStore_stub(objstore1_channel)
objstore2_channel = implementations.insecure_channel('localhost', 22223)
objstore2_stub = orchestra_pb2.beta_create_ObjStore_stub(objstore2_channel)
time.sleep(0.2)
scheduler_stub.RegisterObjStore(orchestra_pb2.RegisterObjStoreRequest(address="127.0.0.1:22222"), TIMEOUT_SECONDS)
scheduler_stub.RegisterObjStore(orchestra_pb2.RegisterObjStoreRequest(address="127.0.0.1:22223"), TIMEOUT_SECONDS)
scheduler_stub = connect_to_scheduler('localhost', 22221)
objstore1_stub = connect_to_objstore('localhost', 22222)
objstore2_stub = connect_to_objstore('localhost', 22223)
worker.global_worker.connect("127.0.0.1:22221", "127.0.0.1:40000", "127.0.0.1:22222")
worker1 = worker.Worker()
worker1.connect("127.0.0.1:22221", "127.0.0.1:40000", "127.0.0.1:22222")
other_worker = worker.Worker()
other_worker.connect("127.0.0.1:22221", "127.0.0.1:40001", "127.0.0.1:22223")
worker2 = worker.Worker()
worker2.connect("127.0.0.1:22221", "127.0.0.1:40001", "127.0.0.1:22223")
# import IPython
# IPython.embed()
for i in range(1, 10):
for i in range(1, 100):
l = i * 100 * "h"
objref = worker.global_worker.do_push(l)
# time.sleep(5.0)
objref = worker1.do_push(l)
response = objstore1_stub.DeliverObj(orchestra_pb2.DeliverObjRequest(objref=objref, objstore_address="0.0.0.0:22223"), TIMEOUT_SECONDS)
# time.sleep(5.0)
str = other_worker.get_serialized(objref)
result = worker.unison.deserialize_from_string(str)
# import IPython
# IPython.embed()
s = worker2.get_serialized(objref)
result = worker.unison.deserialize_from_string(s)
self.assertEqual(len(result), 100 * i)
class SchedulerTest(unittest.TestCase):
services.cleanup()
def testRegister(self):
scheduler_channel = implementations.insecure_channel('localhost', 22221)
scheduler_stub = orchestra_pb2.beta_create_SchedulerServer_stub(scheduler_channel)
w = worker.Worker()
w.connect("127.0.0.1:22221", "127.0.0.1:40002", "127.0.0.1:22222")
w.register_function("hello_world", 2)
reply = scheduler_stub.GetDebugInfo(orchestra_pb2.GetDebugInfoRequest(), TIMEOUT_SECONDS)
self.assertEqual(reply.function_table.items()[0][0], u'hello_world')
class SchedulerTest(unittest.TestCase):
def testCall(self):
scheduler_channel = implementations.insecure_channel('localhost', 22221)
scheduler_stub = orchestra_pb2.beta_create_SchedulerServer_stub(scheduler_channel)
w = worker.Worker()
w.connect("127.0.0.1:22221", "127.0.0.1:40003", "127.0.0.1:22222")
"""
class SchedulerTest(unittest.TestCase):
def testServer(self):
services.start_scheduler("0.0.0.0:22221")
services.start_objstore("0.0.0.0:22222")
services.start_objstore("0.0.0.0:22223")
time.sleep(1.0)
scheduler_channel = implementations.insecure_channel('localhost', 22221)
scheduler_stub = orchestra_pb2.beta_create_SchedulerServer_stub(scheduler_channel)
objstore_channel = implementations.insecure_channel('localhost', 22222)
objstore_stub = orchestra_pb2.beta_create_ObjStore_stub(objstore_channel)
objstore_channel2 = implementations.insecure_channel('localhost', 22223)
objstore_stub2 = orchestra_pb2.beta_create_ObjStore_stub(objstore_channel2)
time.sleep(0.2)
# call = types_pb2.Call(name="test")
# response = scheduler_stub.RemoteCall(orchestra_pb2.RemoteCallRequest(call=call), TIMEOUT_SECONDS)
# response = scheduler_stub.RegisterFunction(orchestra_pb2.RegisterFunctionRequest(workerid=1, fnname="hello"), TIMEOUT_SECONDS)
scheduler_stub = connect_to_scheduler('localhost', 22221)
objstore_stub = connect_to_objstore('localhost', 22222)
response2 = scheduler_stub.RegisterObjStore(orchestra_pb2.RegisterObjStoreRequest(address="127.0.0.1:22222"), TIMEOUT_SECONDS)
response2 = scheduler_stub.RegisterObjStore(orchestra_pb2.RegisterObjStoreRequest(address="127.0.0.1:22223"), TIMEOUT_SECONDS)
time.sleep(0.2)
# response2 = scheduler_stub.RegisterObjStore(orchestra_pb2.RegisterObjStoreRequest(address="127.0.0.1:22222"), TIMEOUT_SECONDS)
# response3 = scheduler_stub.RegisterObjStore(orchestra_pb2.RegisterObjStoreRequest(address="127.0.0.1:22223"), TIMEOUT_SECONDS)
w = worker.Worker()
w.connect("127.0.0.1:22221", "127.0.0.1:40003", "127.0.0.1:22222")
w2 = worker.Worker()
w2.connect("127.0.0.1:22221", "127.0.0.1:40004", "127.0.0.1:22222")
# objstore_stub.StreamObj(produce_data(100), TIMEOUT_SECONDS)
time.sleep(0.2)
worker.global_worker.connect("127.0.0.1:22221", "127.0.0.1:40000", "127.0.0.1:22222")
w.register_function("hello_world", 2)
w2.register_function("hello_world", 2)
l = [1, 2, 3, 4]
worker.global_worker.do_push(l)
time.sleep(0.1)
## res = scheduler_stub.PushObj(orchestra_pb2.PushObjRequest(workerid=0), TIMEOUT_SECONDS)
w.call("hello_world", ["hi"])
response = objstore_stub.DeliverObj(orchestra_pb2.DeliverObjRequest(objref=0, objstore_address="0.0.0.0:22223"), TIMEOUT_SECONDS)
time.sleep(0.1)
# res = objstore_stub2.DebugInfo(orchestra_pb2.DebugInfoRequest(), TIMEOUT_SECONDS)
reply = scheduler_stub.GetDebugInfo(orchestra_pb2.GetDebugInfoRequest(), TIMEOUT_SECONDS)
response = objstore_stub.GetObj(orchestra_pb2.GetObjRequest(objref=0), TIMEOUT_SECONDS)
self.assertEqual(reply.task[0].name, u'hello_world')
worker.global_worker.get_serialized(0)
test_path = os.path.dirname(os.path.abspath(__file__))
import IPython
IPython.embed()
p = subprocess.Popen(["python", os.path.join(test_path, "testrecv.py")])
l = [1, 2, 3, 4]
worker.global_worker.do_push(l)
time.sleep(0.2)
response = objstore_stub.DeliverObj(orchestra_pb2.DeliverObjRequest(), TIMEOUT_SECONDS)
scheduler_stub.PushObj(orchestra_pb2.PushObjRequest(workerid=0), TIMEOUT_SECONDS)
# response = objstore_stub.DebugInfo(orchestra_pb2.DebugInfoRequest(), TIMEOUT_SECONDS)
reply = scheduler_stub.GetDebugInfo(orchestra_pb2.GetDebugInfoRequest(do_scheduling=True), TIMEOUT_SECONDS)
# import IPython
# IPython.embed()
self.assertEqual(p.wait(), 0, "argument was not received by the test program")
# worker.global_worker.connect("127.0.0.1:22221", "127.0.0.1:22222")
# l = [1, 2, 3, 4]
# worker.global_worker.do_push(l)
# w.main_loop()
# w2.main_loop()
#
# reply = scheduler_stub.GetDebugInfo(orchestra_pb2.GetDebugInfoRequest(do_scheduling=True), TIMEOUT_SECONDS)
# time.sleep(0.1)
# reply = scheduler_stub.GetDebugInfo(orchestra_pb2.GetDebugInfoRequest(), TIMEOUT_SECONDS)
#
# self.assertEqual(list(reply.task), [])
#
# services.cleanup()
# import IPython
# IPython.embed()
# response = objstore_stub.DeliverObj(orchestra_pb2.DeliverObjRequest())
# print "Greeter client received: " + response.message
# import IPython
# IPython.embed()
"""
if __name__ == '__main__':
unittest.main()