mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 18:06:25 +08:00
before refactoring
This commit is contained in:
+2
-2
@@ -70,7 +70,7 @@ if (UNIX AND NOT APPLE)
|
||||
endif()
|
||||
|
||||
add_executable(objstore src/objstore.cc ${GENERATED_PROTOBUF_FILES})
|
||||
add_executable(scheduler_server src/scheduler_server.cc ${GENERATED_PROTOBUF_FILES})
|
||||
add_library(orchlib SHARED src/orchlib.cc ${GENERATED_PROTOBUF_FILES})
|
||||
add_executable(scheduler_server src/scheduler_server.cc src/scheduler.cc ${GENERATED_PROTOBUF_FILES})
|
||||
add_library(orchlib SHARED src/orchlib.cc src/worker.cc ${GENERATED_PROTOBUF_FILES})
|
||||
|
||||
install(TARGETS objstore scheduler_server orchlib DESTINATION ${CMAKE_SOURCE_DIR}/lib/orchpy/orchpy)
|
||||
|
||||
@@ -27,6 +27,9 @@ public:
|
||||
ObjRef worker(size_t i) const {
|
||||
return workers_[i];
|
||||
}
|
||||
const std::vector<WorkerId>& workers() const {
|
||||
return workers_;
|
||||
}
|
||||
};
|
||||
|
||||
typedef std::vector<std::vector<ObjStoreId> > ObjTable;
|
||||
|
||||
@@ -8,16 +8,18 @@ _services_path = os.path.dirname(os.path.abspath(__file__))
|
||||
all_processes = []
|
||||
|
||||
def cleanup():
|
||||
global all_processes
|
||||
timeout_sec = 5
|
||||
for p in all_processes:
|
||||
p_sec = 0
|
||||
for second in range(timeout_sec):
|
||||
if p.poll() == None:
|
||||
time.sleep(1)
|
||||
time.sleep(0.1)
|
||||
p_sec += 1
|
||||
if p_sec >= timeout_sec:
|
||||
p.kill() # supported from python 2.6
|
||||
print 'helper processes shut down!'
|
||||
all_processes = []
|
||||
|
||||
atexit.register(cleanup)
|
||||
|
||||
|
||||
@@ -10,6 +10,14 @@ try:
|
||||
except:
|
||||
import pickle
|
||||
|
||||
cdef extern from "../../../build/generated/types.pb.h":
|
||||
|
||||
cdef cppclass Call:
|
||||
Value* add_arg();
|
||||
void set_name(const char* value)
|
||||
Value* mutable_arg(int index);
|
||||
int arg_size() const;
|
||||
|
||||
cdef extern from "../../../build/generated/types.pb.h":
|
||||
ctypedef enum DataType:
|
||||
INT32
|
||||
@@ -89,6 +97,15 @@ cdef class ObjWrapper: # TODO: unify with the above
|
||||
def get_value(self):
|
||||
return <uintptr_t>self.thisptr
|
||||
|
||||
cdef class PythonCall:
|
||||
cdef Call* thisptr
|
||||
def __cinit__(self):
|
||||
self.thisptr = new Call()
|
||||
def __dealloc__(self):
|
||||
del self.thisptr
|
||||
def get_value(self):
|
||||
return <uintptr_t>self.thisptr
|
||||
|
||||
cdef class ObjRef:
|
||||
cdef size_t _id
|
||||
cdef object type
|
||||
@@ -136,6 +153,10 @@ cpdef serialize(val):
|
||||
cpdef serialize_args_into(args, valsptr):
|
||||
cdef uintptr_t ptr = <uintptr_t>valsptr
|
||||
cdef Values* vals = <Values*>ptr
|
||||
serialize_args_into_vals(args, vals)
|
||||
|
||||
# this code is a mess right now, will be improved in the C++ version
|
||||
cdef serialize_args_into_vals(args, Values* vals):
|
||||
cdef Value* val
|
||||
cdef Obj* obj
|
||||
for arg in args:
|
||||
@@ -194,6 +215,33 @@ cpdef deserialize_args(PyValues args):
|
||||
result.append(pickle.loads(data))
|
||||
return result
|
||||
|
||||
# todo: unify with the above, at the moment this is copied
|
||||
cdef deserialize_args_from_call(Call* call):
|
||||
cdef Value* val
|
||||
cdef Obj* obj
|
||||
result = []
|
||||
for i in range(call[0].arg_size()):
|
||||
val = call[0].mutable_arg(i)
|
||||
if not val.has_obj():
|
||||
result.append(ObjRef(val.ref(), None)) # TODO: fix this
|
||||
else:
|
||||
obj = val[0].mutable_obj()
|
||||
if obj[0].has_string_data():
|
||||
result.append(obj[0].mutable_string_data()[0].mutable_data()[0])
|
||||
elif obj[0].has_int_data():
|
||||
result.append(obj[0].mutable_int_data()[0].data())
|
||||
elif obj[0].has_double_data():
|
||||
result.append(obj[0].mutable_double_data()[0].data())
|
||||
else:
|
||||
data = obj[0].mutable_pyobj_data()[0].mutable_data()[0]
|
||||
result.append(pickle.loads(data))
|
||||
return result
|
||||
|
||||
cpdef deserialize_call(PythonCall pycall):
|
||||
cdef Call* call = pycall.thisptr
|
||||
return deserialize_args_from_call(call)
|
||||
|
||||
|
||||
cdef int numpy_dtype_to_proto(dtype):
|
||||
if dtype == np.dtype('int32'):
|
||||
return INT32
|
||||
|
||||
+140
-14
@@ -5,6 +5,11 @@ from libc.stdint cimport uint64_t, int64_t, uintptr_t
|
||||
from libcpp cimport bool
|
||||
from libcpp.string cimport string
|
||||
|
||||
try:
|
||||
import cPickle as pickle
|
||||
except:
|
||||
import pickle
|
||||
|
||||
cdef struct Slice:
|
||||
char* ptr
|
||||
size_t size
|
||||
@@ -13,7 +18,7 @@ cdef extern void* orch_create_context(const char* server_addr, const char* worke
|
||||
cdef extern void orch_register_function(void* worker, const char* name, size_t num_return_vals)
|
||||
cdef extern size_t orch_remote_call(void* context, void* request);
|
||||
cdef extern size_t orch_push(void* context, void* value);
|
||||
cdef extern void orch_main_loop(void* context);
|
||||
cdef extern void* orch_main_loop(void* context);
|
||||
cdef extern Slice orch_get_serialized_obj(void* context, size_t objref);
|
||||
|
||||
cdef extern from "Python.h":
|
||||
@@ -24,15 +29,21 @@ cdef extern from "Python.h":
|
||||
int PyByteArray_Resize(object self, Py_ssize_t size) except -1
|
||||
char* PyByteArray_AS_STRING(object bytearray)
|
||||
|
||||
# cdef extern from "../../../build/generated/orchestra.pb.h":
|
||||
# cdef cppclass RemoteCallRequest:
|
||||
# RemoteCallRequest()
|
||||
# void set_name(const char* value)
|
||||
# Call* mutable_call()
|
||||
cdef extern from "../../../build/generated/orchestra.pb.h":
|
||||
cdef cppclass RemoteCallRequest:
|
||||
RemoteCallRequest()
|
||||
Call* mutable_call()
|
||||
int arg_size() const;
|
||||
|
||||
cdef extern from "../../../build/generated/types.pb.h":
|
||||
cdef cppclass Values
|
||||
|
||||
cdef cppclass Call:
|
||||
Value* add_arg();
|
||||
void set_name(const char* value)
|
||||
Value* mutable_arg(int index);
|
||||
int arg_size() const;
|
||||
|
||||
ctypedef enum DataType:
|
||||
INT32
|
||||
INT64
|
||||
@@ -101,6 +112,30 @@ cdef serialize_into(val, Obj* obj):
|
||||
# pyobj_data = obj[0].mutable_pyobj_data()
|
||||
# pyobj_data[0].set_data(data, len(data))
|
||||
|
||||
"""
|
||||
cpdef deserialize_call(PyValues args):
|
||||
cdef Values* vals = args.thisptr
|
||||
cdef Value* val
|
||||
cdef Obj* obj
|
||||
result = []
|
||||
for i in range(vals[0].value_size()):
|
||||
val = vals[0].mutable_value(i)
|
||||
if not val.has_obj():
|
||||
result.append(ObjRef(val.ref(), None)) # TODO: fix this
|
||||
else:
|
||||
obj = val[0].mutable_obj()
|
||||
if obj[0].has_string_data():
|
||||
result.append(obj[0].mutable_string_data()[0].mutable_data()[0])
|
||||
elif obj[0].has_int_data():
|
||||
result.append(obj[0].mutable_int_data()[0].data())
|
||||
elif obj[0].has_double_data():
|
||||
result.append(obj[0].mutable_double_data()[0].data())
|
||||
else:
|
||||
data = obj[0].mutable_pyobj_data()[0].mutable_data()[0]
|
||||
result.append(pickle.loads(data))
|
||||
return result
|
||||
"""
|
||||
|
||||
cdef class ObjWrapper: # TODO: unify with the above
|
||||
cdef Obj *thisptr
|
||||
def __cinit__(self):
|
||||
@@ -130,6 +165,51 @@ cpdef serialize_into_2(val, objptr):
|
||||
# pyobj_data = obj[0].mutable_pyobj_data()
|
||||
# pyobj_data[0].set_data(data, len(data))
|
||||
|
||||
cdef serialize_args_into_call(args, Call* call):
  """Append one protobuf Value to `call` for every element of `args`.

  unison.ObjRef arguments are stored by reference (their id); every other
  argument is serialized by value into the Value's Obj via
  unison.serialize_into, which receives the Obj address as an integer.
  """
  cdef Value* slot
  for item in args:
    slot = call[0].add_arg()
    if type(item) == unison.ObjRef:
      # pass by reference
      slot[0].set_ref(item.get_id())
      continue
    # pass by value
    unison.serialize_into(item, <uintptr_t>slot[0].mutable_obj())
|
||||
|
||||
cdef deserialize_obj(Obj* obj):
  """Turn a protobuf Obj into a Python value.

  The payload fields are probed in the order string, int, double; when none
  of them is set the pyobj payload is unpickled.
  """
  if obj.has_string_data():
    return obj.mutable_string_data().mutable_data()[0]
  if obj.has_int_data():
    return obj.mutable_int_data().data()
  if obj.has_double_data():
    return obj.mutable_double_data().data()
  # Fallback: an arbitrary pickled Python object.
  pickled = obj.mutable_pyobj_data().mutable_data()[0]
  return pickle.loads(pickled)
|
||||
|
||||
cdef deserialize_args_from_call(Call* call):
  """Return the arguments of `call` as a list of Python values.

  Arguments without an embedded Obj are passed by reference and come back as
  unison.ObjRef(ref, None); arguments passed by value are decoded with
  deserialize_obj, which replaces the copy of that function's type dispatch
  that used to live here.
  """
  cdef Value* val
  result = []
  for i in range(call[0].arg_size()):
    val = call[0].mutable_arg(i)
    if not val.has_obj():
      result.append(unison.ObjRef(val.ref(), None)) # TODO: fix this
    else:
      result.append(deserialize_obj(val[0].mutable_obj()))
  return result
|
||||
|
||||
cdef class Worker:
|
||||
cdef void* context
|
||||
|
||||
@@ -139,13 +219,13 @@ cdef class Worker:
|
||||
def connect(self, server_addr, worker_addr, objstore_addr):
|
||||
self.context = orch_create_context(server_addr, worker_addr, objstore_addr)
|
||||
|
||||
# cpdef call(self, name, args):
|
||||
# cdef RemoteCallRequest* result = new RemoteCallRequest()
|
||||
# result[0].set_name(name)
|
||||
# unison.serialize_args_into(args, <uintptr_t>result[0].mutable_arg())
|
||||
# for i in range(10):
|
||||
# orch_remote_call(self.context, result)
|
||||
# # return <uintptr_t>result
|
||||
cpdef call(self, name, args):
|
||||
cdef RemoteCallRequest* result = new RemoteCallRequest()
|
||||
cdef Call* call = result[0].mutable_call()
|
||||
call.set_name(name)
|
||||
serialize_args_into_call(args, call)
|
||||
orch_remote_call(self.context, result)
|
||||
# return <uintptr_t>result
|
||||
|
||||
cpdef do_call(self, ptr):
|
||||
return orch_remote_call(self.context, <void*>ptr)
|
||||
@@ -170,16 +250,62 @@ cdef class Worker:
|
||||
data = PyBytes_FromStringAndSize(slice.ptr, slice.size)
|
||||
return data
|
||||
|
||||
cpdef do_pull(self, objref):
|
||||
cdef Slice slice = orch_get_serialized_obj(self.context, objref)
|
||||
|
||||
cpdef pull(self, objref):
|
||||
cdef Slice slice = orch_get_serialized_obj(self.context, objref)
|
||||
data = PyBytes_FromStringAndSize(slice.ptr, slice.size)
|
||||
return unison.deserialize_from_string(data)
|
||||
|
||||
cpdef register_function(self, func_name, num_args):
|
||||
orch_register_function(self.context, func_name, num_args)
|
||||
|
||||
cpdef main_loop(self):
|
||||
orch_main_loop(self.context)
|
||||
result = []
|
||||
cdef Call* call = <Call*>orch_main_loop(self.context)
|
||||
cdef int size = call[0].arg_size()
|
||||
cdef Obj* obj
|
||||
print "hello from python"
|
||||
print "size", size
|
||||
return deserialize_args_from_call(call)
|
||||
|
||||
global_worker = Worker()
|
||||
|
||||
def distributed(types, return_type, worker=global_worker):
|
||||
def distributed_decorator(func):
|
||||
# deserialize arguments, execute function and serialize result
|
||||
def func_executor(args):
|
||||
arguments = []
|
||||
protoargs = unison.deserialize_call(args, types)
|
||||
for (i, proto) in enumerate(protoargs):
|
||||
if type(proto) == unison.ObjRef:
|
||||
if i < len(types) - 1:
|
||||
arguments.append(worker.get_object(proto, types[i]))
|
||||
elif i == len(types) - 1 and types[-1] is not None:
|
||||
arguments.append(global_worker.get_object(proto, types[i]))
|
||||
elif types[-1] is None:
|
||||
arguments.append(worker.get_object(proto, types[-2]))
|
||||
else:
|
||||
raise Exception("Passed in " + str(len(args)) + " arguments to function " + func.__name__ + ", which takes only " + str(len(types)) + " arguments.")
|
||||
else:
|
||||
arguments.append(proto)
|
||||
buf = bytearray()
|
||||
result = func(*arguments)
|
||||
if unison.unison_type(result) != return_type:
|
||||
raise Exception("Return type of " + func.func_name + " does not match the return type specified in the @distributed decorator, was expecting " + str(return_type) + " but received " + str(unison.unison_type(result)))
|
||||
unison.serialize(buf, result)
|
||||
return memoryview(buf).tobytes()
|
||||
# for remotely executing the function
|
||||
def func_call(*args, typecheck=False):
|
||||
return worker.call(func_call.func_name, func_call.module_name, args)
|
||||
func_call.func_name = func.__name__.encode() # why do we call encode()?
|
||||
func_call.module_name = func.__module__.encode() # why do we call encode()?
|
||||
func_call.is_distributed = True
|
||||
func_call.executor = func_executor
|
||||
func_call.types = types
|
||||
return func_call
|
||||
return distributed_decorator
|
||||
|
||||
def pull(objref, worker=global_worker):
|
||||
return 1
|
||||
|
||||
@@ -54,16 +54,17 @@ message ChangeCountRequest {
|
||||
}
|
||||
|
||||
message GetDebugInfoRequest {
|
||||
|
||||
bool do_scheduling = 1;
|
||||
}
|
||||
|
||||
message FnTableEntry {
|
||||
uint64 workerid = 1;
|
||||
repeated uint64 workerid = 1;
|
||||
uint64 num_return_vals = 2;
|
||||
}
|
||||
|
||||
message GetDebugInfoReply {
|
||||
repeated Call task = 1;
|
||||
repeated uint64 avail_worker = 3;
|
||||
map<string, FnTableEntry> function_table = 2;
|
||||
}
|
||||
|
||||
@@ -122,7 +123,7 @@ service ObjStore {
|
||||
}
|
||||
|
||||
message InvokeCallRequest {
|
||||
|
||||
Call call = 1;
|
||||
}
|
||||
|
||||
message InvokeCallReply {
|
||||
|
||||
+1
-6
@@ -33,14 +33,9 @@ message Value {
|
||||
Obj obj = 2; // for pass by value
|
||||
}
|
||||
|
||||
// Deprecated
|
||||
message Values {
|
||||
repeated Value value = 1;
|
||||
}
|
||||
|
||||
message Call {
|
||||
string name = 1;
|
||||
repeated Value args = 2;
|
||||
repeated Value arg = 2;
|
||||
repeated uint64 result = 3;
|
||||
}
|
||||
|
||||
|
||||
+73
-2
@@ -21,7 +21,7 @@ Status ObjStoreClient::upload_data_to(slice data, ObjRef objref, ObjStore::Stub&
|
||||
return writer->Finish();
|
||||
}
|
||||
|
||||
void ObjStoreServiceImpl::allocate_memory(ObjRef objref, size_t size) {
|
||||
void ObjStoreServer::allocate_memory(ObjRef objref, size_t size) {
|
||||
std::ostringstream stream;
|
||||
stream << "obj-" << memory_names_.size();
|
||||
std::string name = stream.str();
|
||||
@@ -37,8 +37,79 @@ void ObjStoreServiceImpl::allocate_memory(ObjRef objref, size_t size) {
|
||||
object.ptr.len = size;
|
||||
}
|
||||
|
||||
// Return the stub used to talk to the object store at `objstore_address`.
// The first request for an address opens a channel and caches a new stub in
// objstores_; later requests reuse the cached stub.
ObjStore::Stub& ObjStoreServer::get_objstore_stub(const std::string& objstore_address) {
  auto known = objstores_.find(objstore_address);
  if (known == objstores_.end()) {
    auto channel = grpc::CreateChannel(objstore_address, grpc::InsecureChannelCredentials());
    known = objstores_.emplace(objstore_address, ObjStore::NewStub(channel)).first;
  }
  return *(known->second);
}
|
||||
|
||||
// RPC handler: upload the object named in the request from this store to the
// object store at request->objstore_address(), returning the upload status.
Status ObjStoreServer::DeliverObj(ServerContext* context, const DeliverObjRequest* request, AckReply* reply) {
  const ObjRef objref = request->objref();
  ObjStore::Stub& target = get_objstore_stub(request->objstore_address());
  // TODO: Have to introduce wait condition
  return ObjStoreClient::upload_data_to(memory_[objref].ptr, objref, target);
}
|
||||
|
||||
// RPC handler: report every object reference currently present in memory_.
Status ObjStoreServer::DebugInfo(ServerContext* context, const DebugInfoRequest* request, DebugInfoReply* reply) {
  for (auto it = memory_.begin(); it != memory_.end(); ++it) {
    reply->add_objref(it->first);
  }
  return Status::OK;
}
|
||||
|
||||
// RPC handler: describe where the requested object lives in shared memory.
//
// On success the reply carries the segment name (bucket), the handle of the
// object's data inside that segment, and its size in bytes.
// Returns NOT_FOUND when the object reference is not in memory_. The previous
// code used memory_[objref], which inserted an empty entry for an unknown
// reference and then dereferenced that entry's `memory` member (presumably
// unset for a default-constructed shared_object -- confirm against the struct).
Status ObjStoreServer::GetObj(ServerContext* context, const GetObjRequest* request, GetObjReply* reply) {
  ObjRef objref = request->objref();
  std::cout << "getobj lock";
  bool found = false;
  {
    // Scoped guard so the lock is released on every path, including exceptions.
    std::lock_guard<std::mutex> guard(memory_lock_);
    auto entry = memory_.find(objref);
    if (entry != memory_.end()) {
      found = true;
      shared_object& object = entry->second;
      reply->set_bucket(object.name);
      auto handle = object.memory->get_handle_from_address(object.ptr.data);
      reply->set_handle(handle);
      reply->set_size(object.ptr.len);
    }
  }
  std::cout << "getobj unlock";
  if (!found) {
    return Status(grpc::StatusCode::NOT_FOUND, "object reference is not present in this object store");
  }
  return Status::OK;
}
|
||||
|
||||
// RPC handler: receive an object as a stream of chunks and store it.
//
// The first chunk supplies the object reference and the total size; memory for
// the whole object is allocated once and every chunk's payload is appended to
// it. Returns CANCELLED if the chunks carry more bytes than the announced
// total size, OK otherwise.
//
// An empty stream is now a no-op that returns OK. Previously it fell through
// to memory_[0], creating a bogus entry for object reference 0, and ran
// memcpy on that entry's data pointer.
Status ObjStoreServer::StreamObj(ServerContext* context, ServerReader<ObjChunk>* reader, AckReply* reply) {
  std::cout << "stream obj lock" << std::endl;
  bool cancelled = false;
  {
    // Scoped guard replaces the manual unlock() calls on each exit path.
    std::lock_guard<std::mutex> guard(memory_lock_);
    ObjChunk chunk;
    if (reader->Read(&chunk)) {
      const ObjRef objref = chunk.objref();
      const size_t totalsize = chunk.totalsize();
      allocate_memory(objref, totalsize);
      size_t num_bytes = 0;
      char* data = memory_[objref].ptr.data;

      std::cout << "before loop " << totalsize << std::endl;

      do {
        const auto& payload = chunk.data();
        // num_bytes never exceeds totalsize, so the subtraction cannot wrap
        // (unlike num_bytes + payload.size() for very large chunks).
        if (payload.size() > totalsize - num_bytes) {
          std::cout << "cancelled" << std::endl;
          cancelled = true;
          break;
        }
        std::memcpy(data, payload.c_str(), payload.size());
        data += payload.size();
        num_bytes += payload.size();
        std::cout << "looping " << num_bytes << std::endl;
      } while (reader->Read(&chunk));
    }
    if (!cancelled) {
      std::cout << "finished" << std::endl;
    }
  }
  if (cancelled) {
    return Status::CANCELLED;
  }
  std::cout << "stream obj unlock" << std::endl;
  return Status::OK;
}
|
||||
|
||||
void start_objstore(const char* objstore_address) {
|
||||
ObjStoreServiceImpl service;
|
||||
ObjStoreServer service;
|
||||
ServerBuilder builder;
|
||||
|
||||
builder.AddListeningPort(std::string(objstore_address), grpc::InsecureServerCredentials());
|
||||
|
||||
+16
-80
@@ -37,97 +37,33 @@ struct shared_object {
|
||||
slice ptr;
|
||||
};
|
||||
|
||||
class ObjStoreServiceImpl final : public ObjStore::Service {
|
||||
std::vector<std::string> memory_names_;
|
||||
std::unordered_map<ObjRef, shared_object> memory_;
|
||||
std::mutex memory_lock_;
|
||||
size_t page_size = mapped_region::get_page_size();
|
||||
std::unordered_map<std::string, std::unique_ptr<ObjStore::Stub>> objstores_;
|
||||
|
||||
void allocate_memory(ObjRef objref, size_t size);
|
||||
|
||||
// check if we already connected to the other objstore, if yes, return reference to connection, otherwise connect
|
||||
ObjStore::Stub& get_objstore_stub(const std::string& objstore_address) {
|
||||
auto iter = objstores_.find(objstore_address);
|
||||
if (iter != objstores_.end())
|
||||
return *(iter->second);
|
||||
auto channel = grpc::CreateChannel(objstore_address, grpc::InsecureChannelCredentials());
|
||||
objstores_.emplace(objstore_address, ObjStore::NewStub(channel));
|
||||
return *objstores_[objstore_address];
|
||||
}
|
||||
|
||||
class ObjStoreServer final : public ObjStore::Service {
|
||||
public:
|
||||
ObjStoreServiceImpl() {}
|
||||
ObjStoreServer() {}
|
||||
|
||||
~ObjStoreServiceImpl() {
|
||||
~ObjStoreServer() {
|
||||
for (const auto& segment_name : memory_names_) {
|
||||
shared_memory_object::remove(segment_name.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
Status DeliverObj(ServerContext* context, const DeliverObjRequest* request, AckReply* reply) override {
|
||||
ObjStore::Stub& stub = get_objstore_stub(request->objstore_address());
|
||||
ObjRef objref = request->objref();
|
||||
Status DeliverObj(ServerContext* context, const DeliverObjRequest* request, AckReply* reply) override;
|
||||
|
||||
// TODO: Have to introduce wait condition
|
||||
Status DebugInfo(ServerContext* context, const DebugInfoRequest* request, DebugInfoReply* reply) override;
|
||||
|
||||
return ObjStoreClient::upload_data_to(memory_[objref].ptr, objref, stub);
|
||||
}
|
||||
Status GetObj(ServerContext* context, const GetObjRequest* request, GetObjReply* reply) override;
|
||||
|
||||
Status DebugInfo(ServerContext* context, const DebugInfoRequest* request, DebugInfoReply* reply) override {
|
||||
for (const auto& entry : memory_) {
|
||||
reply->add_objref(entry.first);
|
||||
}
|
||||
return Status::OK;
|
||||
}
|
||||
Status StreamObj(ServerContext* context, ServerReader<ObjChunk>* reader, AckReply* reply) override;
|
||||
private:
|
||||
void allocate_memory(ObjRef objref, size_t size);
|
||||
// check if we already connected to the other objstore, if yes, return reference to connection, otherwise connect
|
||||
ObjStore::Stub& get_objstore_stub(const std::string& objstore_address);
|
||||
|
||||
Status GetObj(ServerContext* context, const GetObjRequest* request, GetObjReply* reply) override {
|
||||
ObjRef objref = request->objref();
|
||||
std::cout << "getobj lock";
|
||||
memory_lock_.lock();
|
||||
shared_object& object = memory_[objref];
|
||||
reply->set_bucket(object.name);
|
||||
auto handle = object.memory->get_handle_from_address(object.ptr.data);
|
||||
reply->set_handle(handle);
|
||||
reply->set_size(object.ptr.len);
|
||||
memory_lock_.unlock();
|
||||
std::cout << "getobj unlock";
|
||||
return Status::OK;
|
||||
}
|
||||
|
||||
Status StreamObj(ServerContext* context, ServerReader<ObjChunk>* reader, AckReply* reply) override {
|
||||
std::cout << "stream obj lock" << std::endl;
|
||||
memory_lock_.lock();
|
||||
ObjChunk chunk;
|
||||
ObjRef objref = 0;
|
||||
size_t totalsize = 0;
|
||||
if (reader->Read(&chunk)) {
|
||||
objref = chunk.objref();
|
||||
totalsize = chunk.totalsize();
|
||||
allocate_memory(objref, totalsize);
|
||||
}
|
||||
size_t num_bytes = 0;
|
||||
char* data = memory_[objref].ptr.data;
|
||||
|
||||
std::cout << "before loop " << totalsize << std::endl;
|
||||
|
||||
do {
|
||||
if (num_bytes + chunk.data().size() > totalsize) {
|
||||
std::cout << "cancelled" << std::endl;
|
||||
memory_lock_.unlock();
|
||||
return Status::CANCELLED;
|
||||
}
|
||||
std::memcpy(data, chunk.data().c_str(), chunk.data().size());
|
||||
data += chunk.data().size();
|
||||
num_bytes += chunk.data().size();
|
||||
std::cout << "looping " << num_bytes << std::endl;
|
||||
} while (reader->Read(&chunk));
|
||||
|
||||
std::cout << "finished" << std::endl;
|
||||
memory_lock_.unlock();
|
||||
std::cout << "stream obj unlock" << std::endl;
|
||||
return Status::OK;
|
||||
}
|
||||
std::vector<std::string> memory_names_;
|
||||
std::unordered_map<ObjRef, shared_object> memory_;
|
||||
std::mutex memory_lock_;
|
||||
size_t page_size = mapped_region::get_page_size();
|
||||
std::unordered_map<std::string, std::unique_ptr<ObjStore::Stub>> objstores_;
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
+6
-6
@@ -3,25 +3,25 @@
|
||||
Worker* orch_create_context(const char* server_addr, const char* worker_addr, const char* objstore_addr) {
|
||||
auto server_channel = grpc::CreateChannel(server_addr, grpc::InsecureChannelCredentials());
|
||||
auto objstore_channel = grpc::CreateChannel(objstore_addr, grpc::InsecureChannelCredentials());
|
||||
Worker* worker = new Worker(server_channel, objstore_channel);
|
||||
Worker* worker = new Worker(std::string(worker_addr), server_channel, objstore_channel);
|
||||
worker->register_worker(std::string(worker_addr), std::string(objstore_addr));
|
||||
return worker;
|
||||
}
|
||||
|
||||
size_t orch_remote_call(Worker* worker, RemoteCallRequest* request) {
|
||||
return worker->RemoteCall(request);
|
||||
return worker->remote_call(request);
|
||||
}
|
||||
|
||||
void orch_main_loop(Worker* worker) {
|
||||
worker->MainLoop();
|
||||
Call* orch_main_loop(Worker* worker) {
|
||||
return worker->main_loop();
|
||||
}
|
||||
|
||||
size_t orch_push(Worker* worker, Obj* obj) {
|
||||
return worker->PushObj(obj);
|
||||
return worker->push_obj(obj);
|
||||
}
|
||||
|
||||
slice orch_get_serialized_obj(Worker* worker, ObjRef objref) {
|
||||
return worker->GetSerializedObj(objref);
|
||||
return worker->get_serialized_obj(objref);
|
||||
}
|
||||
|
||||
void orch_register_function(Worker* worker, const char* name, size_t num_return_vals) {
|
||||
|
||||
+2
-3
@@ -1,5 +1,3 @@
|
||||
|
||||
|
||||
extern "C" {
|
||||
|
||||
struct slice {
|
||||
@@ -14,7 +12,8 @@ struct Value;
|
||||
Worker* orch_create_context(const char* server_addr, const char* worker_addr, const char* objstore_addr);
|
||||
size_t orch_remote_call(Worker* context, RemoteCallRequest* request);
|
||||
size_t orch_push(Worker* context, Obj* value);
|
||||
void orch_main_loop(Worker* worker);
|
||||
Call* orch_main_loop(Worker* worker);
|
||||
slice orch_get_serialized_obj(Worker* worker, size_t objref);
|
||||
void orch_register_function(Worker* worker, const char* name, size_t num_return_vals);
|
||||
|
||||
}
|
||||
|
||||
@@ -0,0 +1,165 @@
|
||||
#include "scheduler.h"
|
||||
|
||||
size_t Scheduler::add_task(const Call& task) {
|
||||
fntable_lock_.lock();
|
||||
size_t num_return_vals = fntable_[task.name()].num_return_vals();
|
||||
fntable_lock_.unlock();
|
||||
std::unique_ptr<Call> task_ptr(new Call(task));
|
||||
tasks_lock_.lock();
|
||||
tasks_.emplace_back(std::move(task_ptr));
|
||||
tasks_lock_.unlock();
|
||||
return num_return_vals;
|
||||
}
|
||||
|
||||
void Scheduler::schedule() {
|
||||
// TODO: work out a better strategy here
|
||||
WorkerId workerid = 0;
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(avail_workers_lock_);
|
||||
if (avail_workers_.size() == 0)
|
||||
return;
|
||||
workerid = avail_workers_.back();
|
||||
std::cout << "got available worker" << workerid << std::endl;
|
||||
avail_workers_.pop_back();
|
||||
}
|
||||
// TODO: think about locking here
|
||||
for (auto it = tasks_.begin(); it != tasks_.end(); ++it) {
|
||||
const Call& task = *(*it);
|
||||
auto& workers = fntable_[task.name()].workers();
|
||||
if (std::binary_search(workers.begin(), workers.end(), workerid) && can_run(task)) {
|
||||
submit_task(std::move(*it), workerid);
|
||||
tasks_.erase(it);
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void Scheduler::submit_task(std::unique_ptr<Call> call, WorkerId workerid) {
|
||||
ClientContext context;
|
||||
InvokeCallRequest request;
|
||||
InvokeCallReply reply;
|
||||
std::cout << "sending arguments now" << std::endl;
|
||||
for (size_t i = 0; i < call->arg_size(); ++i) {
|
||||
if (!call->arg(i).has_obj()) {
|
||||
std::cout << "need to send object ref" << call->arg(i).ref() << std::endl;
|
||||
std::lock_guard<std::mutex> objtable_lock(objtable_lock_);
|
||||
auto &objstores = objtable_[call->arg(i).ref()];
|
||||
std::lock_guard<std::mutex> workers_lock(workers_lock_);
|
||||
if (!std::binary_search(objstores.begin(), objstores.end(), workers_[workerid].objstoreid)) {
|
||||
std::cout << "have to send" << std::endl;
|
||||
std::exit(1);
|
||||
}
|
||||
// if (objstoreid != workers_[workerid].objstoreid) {
|
||||
// std::lock_guard<std::mutex> objstores_lock(objstores_lock_);
|
||||
// objstores_.
|
||||
// }
|
||||
}
|
||||
}
|
||||
request.set_allocated_call(call.release()); // protobuf object takes ownership
|
||||
Status status = workers_[workerid].worker_stub->InvokeCall(&context, request, &reply);
|
||||
}
|
||||
|
||||
bool Scheduler::can_run(const Call& task) {
|
||||
std::lock_guard<std::mutex> lock(objtable_lock_);
|
||||
for (int i = 0; i < task.arg_size(); ++i) {
|
||||
if (!task.arg(i).has_obj()) {
|
||||
if (objtable_[task.arg(i).ref()].size() == 0) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Register a worker and, if it is not known yet, its object store.
//
// The object store is looked up by address; an unknown address gets a new
// entry (channel + stub) appended to objstores_. The worker gets a new entry
// in workers_ and is added to the idle pool. Returns the new worker's id,
// which is its index in workers_.
//
// Handles are fully built before they are appended, and all locks are scoped
// guards, so a throwing call can neither leave a lock held nor leave a
// half-initialised handle in the tables.
WorkerId Scheduler::register_worker(const std::string& worker_address, const std::string& objstore_address) {
  const ObjStoreId not_registered = std::numeric_limits<size_t>::max();
  ObjStoreId objstoreid = not_registered;
  {
    std::lock_guard<std::mutex> objstores_guard(objstores_lock_);
    for (size_t i = 0; i < objstores_.size(); ++i) {
      std::cout << "adress: " << objstores_[i].address << std::endl;
      std::cout << "my adress: " << objstore_address << std::endl;
      if (objstores_[i].address == objstore_address) {
        objstoreid = i;
      }
    }
    if (objstoreid == not_registered) {
      // register objstore
      ObjStoreHandle store;
      store.address = objstore_address;
      store.channel = grpc::CreateChannel(objstore_address, grpc::InsecureChannelCredentials());
      store.objstore_stub = ObjStore::NewStub(store.channel);
      objstoreid = objstores_.size();
      objstores_.push_back(std::move(store));
    }
  }

  WorkerHandle worker;
  worker.channel = grpc::CreateChannel(worker_address, grpc::InsecureChannelCredentials());
  worker.objstoreid = objstoreid;
  worker.worker_stub = WorkerServer::NewStub(worker.channel);

  WorkerId workerid = 0;
  {
    std::lock_guard<std::mutex> workers_guard(workers_lock_);
    workerid = workers_.size();
    workers_.push_back(std::move(worker));
  }
  {
    std::lock_guard<std::mutex> avail_guard(avail_workers_lock_);
    avail_workers_.push_back(workerid);
  }
  return workerid;
}
|
||||
|
||||
// Create a new, location-less entry in the object table and return its
// object reference (the index of the new entry).
ObjRef Scheduler::register_new_object() {
  std::lock_guard<std::mutex> guard(objtable_lock_);
  const ObjRef new_ref = objtable_.size();
  objtable_.emplace_back();
  return new_ref;
}
|
||||
|
||||
// Record that object `objref` is available in object store `objstoreid`.
// The per-object location list is kept sorted and free of duplicates.
void Scheduler::add_location(ObjRef objref, ObjStoreId objstoreid) {
  std::lock_guard<std::mutex> guard(objtable_lock_);
  auto& locations = objtable_[objref];
  // Binary search for the insertion point.
  auto slot = std::lower_bound(locations.begin(), locations.end(), objstoreid);
  const bool already_listed = (slot != locations.end()) && !(objstoreid < *slot);
  if (!already_listed) {
    locations.insert(slot, objstoreid);
  }
}
|
||||
|
||||
// Return the id of the object store associated with worker `workerid`.
ObjStoreId Scheduler::get_store(WorkerId workerid) {
  std::lock_guard<std::mutex> guard(workers_lock_);
  return workers_[workerid].objstoreid;
}
|
||||
|
||||
void Scheduler::register_function(const std::string& name, WorkerId workerid, size_t num_return_vals) {
|
||||
fntable_lock_.lock();
|
||||
FnInfo& info = fntable_[name];
|
||||
info.set_num_return_vals(num_return_vals);
|
||||
info.add_worker(workerid);
|
||||
fntable_lock_.unlock();
|
||||
}
|
||||
|
||||
void Scheduler::debug_info(const GetDebugInfoRequest& request, GetDebugInfoReply* reply) {
|
||||
if (request.do_scheduling()) {
|
||||
schedule();
|
||||
}
|
||||
fntable_lock_.lock();
|
||||
auto function_table = reply->mutable_function_table();
|
||||
for (const auto& entry : fntable_) {
|
||||
(*function_table)[entry.first].set_num_return_vals(entry.second.num_return_vals());
|
||||
for (const WorkerId& worker : entry.second.workers()) {
|
||||
(*function_table)[entry.first].add_workerid(worker);
|
||||
}
|
||||
}
|
||||
fntable_lock_.unlock();
|
||||
tasks_lock_.lock();
|
||||
for (const auto& entry : tasks_) {
|
||||
Call* call = reply->add_task();
|
||||
call->CopyFrom(*entry);
|
||||
}
|
||||
tasks_lock_.unlock();
|
||||
avail_workers_lock_.lock();
|
||||
for (const WorkerId& entry : avail_workers_) {
|
||||
reply->add_avail_worker(entry);
|
||||
}
|
||||
avail_workers_lock_.unlock();
|
||||
}
|
||||
|
||||
+30
-100
@@ -3,6 +3,8 @@
|
||||
|
||||
#include <deque>
|
||||
#include <memory>
|
||||
#include <algorithm>
|
||||
#include <iostream>
|
||||
|
||||
#include <grpc++/grpc++.h>
|
||||
|
||||
@@ -16,25 +18,52 @@ using grpc::ServerReader;
|
||||
using grpc::ServerContext;
|
||||
using grpc::Status;
|
||||
|
||||
using grpc::ClientContext;
|
||||
|
||||
using grpc::Channel;
|
||||
|
||||
struct WorkerHandle {
|
||||
std::shared_ptr<Channel> channel;
|
||||
std::unique_ptr<WorkerServer::Stub> worker_stub;
|
||||
ObjStoreId objstoreid;
|
||||
};
|
||||
|
||||
struct ObjStoreHandle {
|
||||
std::shared_ptr<Channel> channel;
|
||||
std::unique_ptr<ObjStore::Stub> objstore_stub;
|
||||
std::string address;
|
||||
};
|
||||
|
||||
class Scheduler {
|
||||
public:
|
||||
// returns number of return values of task
|
||||
size_t add_task(const Call& task);
|
||||
// assign a task to a worker
|
||||
void schedule();
|
||||
// execute a task on a worker and ship required object references
|
||||
void submit_task(std::unique_ptr<Call> call, WorkerId workerid);
|
||||
// checks if the dependencies of the task are met
|
||||
bool can_run(const Call& task);
|
||||
// register a worker and its object store (if it has not been registered yet)
|
||||
WorkerId register_worker(const std::string& worker_address, const std::string& objstore_address);
|
||||
// register a new object with the scheduler and return its object reference
|
||||
ObjRef register_new_object();
|
||||
// register the location of the object reference in the object table
|
||||
void add_location(ObjRef objref, ObjStoreId objstoreid);
|
||||
// get object store associated with a workerid
|
||||
ObjStoreId get_store(WorkerId workerid);
|
||||
// register a function with the scheduler
|
||||
void register_function(const std::string& name, WorkerId workerid, size_t num_return_vals);
|
||||
// get debugging information for the scheduler
|
||||
void debug_info(const GetDebugInfoRequest& request, GetDebugInfoReply* reply);
|
||||
private:
|
||||
// Vector of all workers registered in the system. Their index in this vector
|
||||
// is the workerid.
|
||||
std::vector<WorkerHandle> workers_;
|
||||
std::mutex workers_lock_;
|
||||
// Vector of all workers that are currently idle.
|
||||
std::vector<WorkerId> available_workers_;
|
||||
std::vector<WorkerId> avail_workers_;
|
||||
std::mutex avail_workers_lock_;
|
||||
// Vector of all object stores registered in the system. Their index in this
|
||||
// vector is the objstoreid.
|
||||
std::vector<ObjStoreHandle> objstores_;
|
||||
@@ -48,105 +77,6 @@ class Scheduler {
|
||||
// List of pending tasks.
|
||||
std::deque<std::unique_ptr<Call> > tasks_;
|
||||
std::mutex tasks_lock_;
|
||||
public:
|
||||
// returns number of return values of task
|
||||
size_t add_task(const Call& task) {
|
||||
fntable_lock_.lock();
|
||||
size_t num_return_vals = fntable_[task.name()].num_return_vals();
|
||||
fntable_lock_.unlock();
|
||||
std::unique_ptr<Call> task_ptr(new Call(task)); // TODO: perform copy outside
|
||||
tasks_lock_.lock();
|
||||
tasks_.emplace_back(std::move(task_ptr));
|
||||
tasks_lock_.unlock();
|
||||
return num_return_vals;
|
||||
}
|
||||
WorkerId register_worker(const std::string& worker_address, const std::string& objstore_address) {
|
||||
ObjStoreId objstoreid = std::numeric_limits<size_t>::max();
|
||||
objstores_lock_.lock();
|
||||
for (size_t i = 0; i < objstores_.size(); ++i) {
|
||||
std::cout << "adress: " << objstores_[i].address << std::endl;
|
||||
std::cout << "my adress: " << objstore_address << std::endl;
|
||||
if (objstores_[i].address == objstore_address) {
|
||||
objstoreid = i;
|
||||
}
|
||||
}
|
||||
if (objstoreid == std::numeric_limits<size_t>::max()) {
|
||||
// throw objstore_not_registered_error("objectstore not registered");
|
||||
std::cout << "bad bad bad" << std::endl;
|
||||
}
|
||||
objstores_lock_.unlock();
|
||||
workers_lock_.lock();
|
||||
WorkerId result = workers_.size();
|
||||
workers_.push_back(WorkerHandle());
|
||||
workers_[result].channel = grpc::CreateChannel(worker_address, grpc::InsecureChannelCredentials());
|
||||
workers_[result].objstoreid = objstoreid;
|
||||
workers_lock_.unlock();
|
||||
return result;
|
||||
}
|
||||
ObjStoreId register_objstore(const std::string& objstore_address) {
|
||||
// auto handle = ObjStoreHandle(objstore_address);
|
||||
// auto handlecopy = handle;
|
||||
// auto handle = ObjStoreHandle("0.0.0.0:22222");
|
||||
objstores_lock_.lock();
|
||||
std::cout << "capacity" << objstores_.capacity() << std::endl;
|
||||
ObjStoreId result = objstores_.size();
|
||||
// auto handle = ObjStoreHandle(objstore_address);
|
||||
// objstores_.emplace_back(objstore_address);
|
||||
objstores_.push_back(ObjStoreHandle());
|
||||
|
||||
objstores_[result].channel = grpc::CreateChannel(objstore_address, grpc::InsecureChannelCredentials());
|
||||
objstores_[result].address = std::string(objstore_address);
|
||||
|
||||
// auto handlecopy = handle;
|
||||
// auto handle = grpc::CreateChannel(objstore_address, grpc::InsecureChannelCredentials());
|
||||
// auto handlecopy = grpc::CreateChannel(objstore_address, grpc::InsecureChannelCredentials());
|
||||
objstores_lock_.unlock();
|
||||
return result;
|
||||
}
|
||||
ObjRef register_new_object() {
|
||||
objtable_lock_.lock();
|
||||
ObjRef result = objtable_.size();
|
||||
objtable_.push_back(std::vector<ObjStoreId>());
|
||||
objtable_lock_.unlock();
|
||||
return result;
|
||||
}
|
||||
void add_objstore_to_obj(ObjRef objref, ObjStoreId objstoreid) {
|
||||
objtable_lock_.lock();
|
||||
// do a binary search
|
||||
auto pos = std::lower_bound(objtable_[objref].begin(), objtable_[objref].end(), objstoreid);
|
||||
if (pos == objtable_[objref].end() || objstoreid < *pos) {
|
||||
objtable_[objref].insert(pos, objstoreid);
|
||||
}
|
||||
objtable_lock_.unlock();
|
||||
}
|
||||
ObjStoreId get_store(WorkerId workerid) {
|
||||
workers_lock_.lock();
|
||||
ObjStoreId result = workers_[workerid].objstoreid;
|
||||
workers_lock_.unlock();
|
||||
return result;
|
||||
}
|
||||
void register_function(const std::string& name, WorkerId workerid, size_t num_return_vals) {
|
||||
fntable_lock_.lock();
|
||||
FnInfo& info = fntable_[name];
|
||||
info.set_num_return_vals(num_return_vals);
|
||||
info.add_worker(workerid);
|
||||
fntable_lock_.unlock();
|
||||
}
|
||||
void debug_info(GetDebugInfoReply* debug_info) {
|
||||
fntable_lock_.lock();
|
||||
for (const auto& entry : fntable_) {
|
||||
auto function_table = debug_info->mutable_function_table();
|
||||
(*function_table)[entry.first].set_num_return_vals(entry.second.num_return_vals());
|
||||
// TODO: set workerid
|
||||
}
|
||||
fntable_lock_.unlock();
|
||||
tasks_lock_.lock();
|
||||
for (const auto& entry : tasks_) {
|
||||
Call* call = debug_info->add_task();
|
||||
call->CopyFrom(*entry);
|
||||
}
|
||||
tasks_lock_.unlock();
|
||||
}
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
@@ -12,7 +12,7 @@ Status SchedulerServerServiceImpl::RemoteCall(ServerContext* context, const Remo
|
||||
Status SchedulerServerServiceImpl::PushObj(ServerContext* context, const PushObjRequest* request, PushObjReply* reply) {
|
||||
ObjRef objref = scheduler_->register_new_object();
|
||||
ObjStoreId objstoreid = scheduler_->get_store(request->workerid());
|
||||
scheduler_->add_objstore_to_obj(objref, objstoreid);
|
||||
scheduler_->add_location(objref, objstoreid);
|
||||
reply->set_objref(objref);
|
||||
return Status::OK;
|
||||
}
|
||||
|
||||
@@ -22,9 +22,11 @@ public:
|
||||
}
|
||||
// RPC handler: registers the calling worker (and the object store address it
// names) with the scheduler and reports the assigned worker id in the reply.
Status RegisterWorker(ServerContext* context, const RegisterWorkerRequest* request, RegisterWorkerReply* reply) override {
  const std::string& worker_addr = request->worker_address();
  const std::string& objstore_addr = request->objstore_address();
  const WorkerId assigned_id = scheduler_->register_worker(worker_addr, objstore_addr);
  std::cout << "registered worker with workerid" << assigned_id << std::endl;
  reply->set_workerid(assigned_id);
  return Status::OK;
}
|
||||
/*
|
||||
Status RegisterObjStore(ServerContext* context, const RegisterObjStoreRequest* request, RegisterObjStoreReply* reply) override {
|
||||
try {
|
||||
reply->set_objstoreid(scheduler_->register_objstore(request->address()));
|
||||
@@ -33,12 +35,14 @@ public:
|
||||
}
|
||||
return Status::OK;
|
||||
}
|
||||
*/
|
||||
// RPC handler: records in the scheduler's function table that the given
// worker can run the named function, together with its number of return
// values.
Status RegisterFunction(ServerContext* context, const RegisterFunctionRequest* request, AckReply* reply) override {
  const auto registering_worker = request->workerid();
  std::cout << "RegisterFunction: workerid is" << registering_worker << std::endl;
  scheduler_->register_function(request->fnname(), registering_worker, request->num_return_vals());
  return Status::OK;
}
|
||||
Status GetDebugInfo(ServerContext* context, const GetDebugInfoRequest* request, GetDebugInfoReply* reply) override {
|
||||
scheduler_->debug_info(reply);
|
||||
scheduler_->debug_info(*request, reply);
|
||||
return Status::OK;
|
||||
}
|
||||
};
|
||||
|
||||
+134
@@ -0,0 +1,134 @@
|
||||
# include "worker.h"
|
||||
|
||||
// Handles an InvokeCall RPC: stores a copy of the requested call in call_ and
// sends a pointer to that copy through the message queue named after the
// worker address, where receive() picks it up.
//
// Returns Status::OK once the pointer has been queued. If the queue cannot be
// opened or written to, the queue name is removed and an INTERNAL status
// carrying the interprocess error message is returned.
Status WorkerServiceImpl::InvokeCall(ServerContext* context, const InvokeCallRequest* request, InvokeCallReply* reply) {
  // Only a raw pointer travels through the queue, so the pointee has to
  // outlive this handler; that is why the copy lives in the member call_.
  call_ = request->call();
  std::cout << "invoke call request" << std::endl;
  const char* queue_name = worker_address_.c_str();
  try {
    Call* callptr = &call_;
    message_queue mq(open_only, queue_name);
    std::cout << "before send: num args" << call_.arg_size() << std::endl;
    mq.send(&callptr, sizeof(Call*), 0);
  } catch (const interprocess_exception& ex) {
    message_queue::remove(queue_name);
    std::cout << ex.what() << std::endl;
    // Report the failure instead of claiming that the call was delivered.
    return Status(grpc::StatusCode::INTERNAL, ex.what());
  }
  message_queue::remove(queue_name);
  std::cout << "notified server" << std::endl;
  return Status::OK;
}
|
||||
|
||||
// Submits a remote call to the scheduler and blocks until the RPC finishes.
//
// The scheduler's reply is not forwarded to the caller yet; 0 is returned as
// a placeholder value.
size_t Worker::remote_call(RemoteCallRequest* request) {
  RemoteCallReply reply;
  ClientContext context;
  Status status = scheduler_stub_->RemoteCall(&context, *request, &reply);
  // TODO: Return results: return reply.result(0);
  // Flowing off the end of a non-void function is undefined behavior, so a
  // defined value is returned until the results are plumbed through.
  return 0;
}
|
||||
|
||||
void Worker::register_worker(const std::string& worker_address, const std::string& objstore_address) {
|
||||
RegisterWorkerRequest request;
|
||||
request.set_worker_address(worker_address);
|
||||
request.set_objstore_address(objstore_address);
|
||||
RegisterWorkerReply reply;
|
||||
ClientContext context;
|
||||
Status status = scheduler_stub_->RegisterWorker(&context, request, &reply);
|
||||
workerid_ = reply.workerid();
|
||||
return;
|
||||
}
|
||||
|
||||
// Pushes an object to the local object store and returns its object
// reference.
//
// A fresh object reference is first requested from the scheduler on behalf of
// this worker. The object is then serialized and streamed to the object store
// in pieces of at most CHUNK_SIZE bytes.
ObjRef Worker::push_obj(Obj* obj) {
  // first get objref for the new object
  PushObjRequest push_request;
  // The scheduler looks up the object store through the request's workerid,
  // so the request has to identify this worker rather than default to 0.
  push_request.set_workerid(workerid_);
  PushObjReply push_reply;
  ClientContext push_context;
  Status push_status = scheduler_stub_->PushObj(&push_context, push_request, &push_reply);
  ObjRef objref = push_reply.objref();

  // then stream the object to the object store
  std::string data;
  obj->SerializeToString(&data);
  const size_t totalsize = data.size();

  ClientContext context;
  AckReply reply;
  std::unique_ptr<ClientWriter<ObjChunk> > writer(
      objstore_stub_->StreamObj(&context, &reply));

  // objref and totalsize are identical for every chunk; only the payload
  // differs, so they are set once on the reused message.
  ObjChunk chunk;
  chunk.set_objref(objref);
  chunk.set_totalsize(totalsize);
  const char* head = data.c_str();
  for (size_t i = 0; i < totalsize; i += CHUNK_SIZE) {
    std::cout << "chunk totalsize" << std::endl;
    chunk.set_data(head + i, std::min(CHUNK_SIZE, totalsize - i));
    if (!writer->Write(chunk)) {
      std::cout << "write failed" << std::endl;
      // TODO: Better error handling: throw std::runtime_error("write failed");
    }
  }
  writer->WritesDone();
  Status status = writer->Finish();
  return objref;
}
|
||||
|
||||
// Looks up a serialized object in the local object store and returns a slice
// (pointer and length) describing it.
//
// The shared-memory segment named by the reply's bucket is opened and kept in
// segment_; the returned pointer is resolved from the reply's handle inside
// that segment.
// NOTE(review): the pointer presumably stays valid only until segment_ is
// reassigned by a later call -- confirm.
slice Worker::get_serialized_obj(ObjRef objref) {
  GetObjRequest lookup;
  lookup.set_objref(objref);

  GetObjReply located;
  ClientContext rpc_context;
  objstore_stub_->GetObj(&rpc_context, lookup, &located);

  segment_ = managed_shared_memory(open_only, located.bucket().c_str());

  slice view;
  view.data = static_cast<char*>(segment_.get_address_from_handle(located.handle()));
  view.len = located.size();
  return view;
}
|
||||
|
||||
void Worker::register_function(const std::string& name, size_t num_return_vals) {
|
||||
ClientContext context;
|
||||
RegisterFunctionRequest request;
|
||||
request.set_fnname(name);
|
||||
request.set_num_return_vals(num_return_vals);
|
||||
request.set_workerid(workerid_);
|
||||
AckReply reply;
|
||||
scheduler_stub_->RegisterFunction(&context, request, &reply);
|
||||
}
|
||||
|
||||
void start_worker_server(const char* server_address) {
|
||||
WorkerServiceImpl service(server_address);
|
||||
ServerBuilder builder;
|
||||
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
|
||||
builder.RegisterService(&service);
|
||||
std::unique_ptr<Server> server(builder.BuildAndStart());
|
||||
std::cout << "Server listening on " << server_address << std::endl;
|
||||
server->Wait();
|
||||
}
|
||||
|
||||
// Communication between the WorkerServer and the Worker happens via a message
|
||||
// queue. This is because the Python interpreter needs to be single threaded
|
||||
// (in our case running in the main thread), whereas the WorkerService will
|
||||
// run in a separate thread and potentially utilize multiple threads.
|
||||
Call* Worker::main_loop() {
|
||||
// start the worker server
|
||||
worker_server_thread_ = std::thread(start_worker_server, worker_address_.c_str());
|
||||
// process the next call
|
||||
return receive(worker_address_.c_str());
|
||||
}
|
||||
|
||||
// Creates a fresh single-slot message queue with the given name (removing any
// stale queue of that name first) and blocks until one Call pointer has been
// received on it.
//
// Returns the received pointer, or nullptr if a Boost.Interprocess error
// occurred; in that case the queue name is removed and the error is printed.
Call* receive(const char* message_queue_name) {
  try {
    message_queue::remove(message_queue_name);
    message_queue mq(create_only, message_queue_name, 1, sizeof(Call*));
    unsigned int priority = 0;
    message_queue::size_type recvd_size = 0;
    Call* call = nullptr;
    mq.receive(&call, sizeof(Call*), recvd_size, priority);
    std::cout << "got call" << call << std::endl;
    std::cout << "after send: num args" << call->arg_size() << std::endl;
    return call;
  } catch (const interprocess_exception& ex) {
    message_queue::remove(message_queue_name);
    std::cout << ex.what() << std::endl;
  }
  // Reached only after an interprocess error; previously control fell off the
  // end of the function here, which is undefined behavior.
  return nullptr;
}
|
||||
|
||||
+33
-100
@@ -7,6 +7,8 @@
|
||||
#include <thread>
|
||||
|
||||
#include <boost/interprocess/managed_shared_memory.hpp>
|
||||
#include <boost/interprocess/ipc/message_queue.hpp>
|
||||
|
||||
using namespace boost::interprocess;
|
||||
|
||||
#include <grpc++/grpc++.h>
|
||||
@@ -26,118 +28,49 @@ using grpc::ClientContext;
|
||||
using grpc::ClientWriter;
|
||||
|
||||
class WorkerServiceImpl final : public WorkerServer::Service {
|
||||
Status InvokeCall(ServerContext* context, const InvokeCallRequest* request,
|
||||
InvokeCallReply* reply) override {
|
||||
std::cout << "invoke call request" << std::endl;
|
||||
return Status::OK;
|
||||
}
|
||||
public:
|
||||
WorkerServiceImpl(const std::string& worker_address)
|
||||
: worker_address_(worker_address) {}
|
||||
Status InvokeCall(ServerContext* context, const InvokeCallRequest* request, InvokeCallReply* reply) override;
|
||||
private:
|
||||
std::string worker_address_;
|
||||
Call call_; // copy of the current call
|
||||
};
|
||||
|
||||
void start_server() {
|
||||
std::string server_address("0.0.0.0:50053");
|
||||
WorkerServiceImpl service;
|
||||
ServerBuilder builder;
|
||||
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
|
||||
builder.RegisterService(&service);
|
||||
std::unique_ptr<Server> server(builder.BuildAndStart());
|
||||
std::cout << "Server listening on " << server_address << std::endl;
|
||||
server->Wait();
|
||||
}
|
||||
void start_worker_server(const char* worker_addr);
|
||||
|
||||
Call* receive(const char* worker_addr);
|
||||
|
||||
class Worker {
|
||||
managed_shared_memory segment;
|
||||
public:
|
||||
Worker(std::shared_ptr<Channel> scheduler_channel, std::shared_ptr<Channel> objstore_channel)
|
||||
: scheduler_stub_(SchedulerServer::NewStub(scheduler_channel)),
|
||||
Worker(const std::string& worker_address, std::shared_ptr<Channel> scheduler_channel, std::shared_ptr<Channel> objstore_channel)
|
||||
: worker_address_(worker_address),
|
||||
scheduler_stub_(SchedulerServer::NewStub(scheduler_channel)),
|
||||
objstore_stub_(ObjStore::NewStub(objstore_channel))
|
||||
{}
|
||||
|
||||
size_t RemoteCall(RemoteCallRequest* request) {
|
||||
// RemoteCallReply reply;
|
||||
// ClientContext context;
|
||||
|
||||
// Status status = stub_->RemoteCall(&context, *request, &reply);
|
||||
|
||||
// return reply.result();
|
||||
return 42;
|
||||
}
|
||||
|
||||
void register_worker(const std::string& worker_address, const std::string& objstore_address) {
|
||||
RegisterWorkerRequest request;
|
||||
request.set_worker_address(worker_address);
|
||||
request.set_objstore_address(objstore_address);
|
||||
RegisterWorkerReply reply;
|
||||
ClientContext context;
|
||||
Status status = scheduler_stub_->RegisterWorker(&context, request, &reply);
|
||||
return;
|
||||
}
|
||||
|
||||
const size_t CHUNK_SIZE = 8 * 1024;
|
||||
|
||||
ObjRef PushObj(Obj* obj) {
|
||||
// first get objref for the new object
|
||||
PushObjRequest push_request;
|
||||
PushObjReply push_reply;
|
||||
ClientContext push_context;
|
||||
Status push_status = scheduler_stub_->PushObj(&push_context, push_request, &push_reply);
|
||||
ObjRef objref = push_reply.objref();
|
||||
ObjChunk chunk;
|
||||
std::string data;
|
||||
obj->SerializeToString(&data);
|
||||
size_t totalsize = data.size();
|
||||
ClientContext context;
|
||||
AckReply reply;
|
||||
std::unique_ptr<ClientWriter<ObjChunk> > writer(
|
||||
objstore_stub_->StreamObj(&context, &reply));
|
||||
const char* head = data.c_str();
|
||||
for (size_t i = 0; i < data.length(); i += CHUNK_SIZE) {
|
||||
chunk.set_objref(objref);
|
||||
std::cout << "chunk totalsize" << std::endl;
|
||||
chunk.set_totalsize(totalsize);
|
||||
chunk.set_data(head + i, std::min(CHUNK_SIZE, data.length() - i));
|
||||
if (!writer->Write(chunk)) {
|
||||
std::cout << "write failed" << std::endl;
|
||||
// throw std::runtime_error("write failed");
|
||||
}
|
||||
}
|
||||
writer->WritesDone();
|
||||
Status status = writer->Finish();
|
||||
return objref;
|
||||
}
|
||||
|
||||
slice GetSerializedObj(ObjRef objref) {
|
||||
ClientContext context;
|
||||
GetObjRequest request;
|
||||
request.set_objref(objref);
|
||||
GetObjReply reply;
|
||||
objstore_stub_->GetObj(&context, request, &reply);
|
||||
segment = managed_shared_memory(open_only, reply.bucket().c_str());
|
||||
slice slice;
|
||||
slice.data = static_cast<char*>(segment.get_address_from_handle(reply.handle()));
|
||||
slice.len = reply.size();
|
||||
return slice;
|
||||
}
|
||||
|
||||
void register_function(const std::string& name, size_t num_return_vals) {
|
||||
ClientContext context;
|
||||
RegisterFunctionRequest request;
|
||||
request.set_fnname(name);
|
||||
request.set_num_return_vals(num_return_vals);
|
||||
AckReply reply;
|
||||
scheduler_stub_->RegisterFunction(&context, request, &reply);
|
||||
}
|
||||
|
||||
void MainLoop() {
|
||||
scheduler_thread = std::thread(start_server);
|
||||
|
||||
}
|
||||
|
||||
|
||||
// submit a remote call to the scheduler
|
||||
size_t remote_call(RemoteCallRequest* request);
|
||||
// send request to the scheduler to register this worker
|
||||
void register_worker(const std::string& worker_address, const std::string& objstore_address);
|
||||
// push object to local object store, register it with the server and return object reference
|
||||
ObjRef push_obj(Obj* obj);
|
||||
// retrieve serialized object from local object store
|
||||
slice get_serialized_obj(ObjRef objref);
|
||||
// register function with scheduler
|
||||
void register_function(const std::string& name, size_t num_return_vals);
|
||||
// start the main loop
|
||||
Call* main_loop();
|
||||
|
||||
private:
|
||||
const size_t CHUNK_SIZE = 8 * 1024;
|
||||
std::unique_ptr<SchedulerServer::Stub> scheduler_stub_;
|
||||
std::unique_ptr<ObjStore::Stub> objstore_stub_;
|
||||
std::thread scheduler_thread;
|
||||
std::thread worker_server_thread_;
|
||||
std::thread other_thread_;
|
||||
managed_shared_memory segment_;
|
||||
WorkerId workerid_;
|
||||
std::string worker_address_;
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
+63
-84
@@ -4,6 +4,10 @@ import orchpy.services as services
|
||||
import orchpy.worker as worker
|
||||
import numpy as np
|
||||
import time
|
||||
import subprocess32 as subprocess
|
||||
import os
|
||||
|
||||
from google.protobuf.text_format import *
|
||||
|
||||
from grpc.beta import implementations
|
||||
import orchestra_pb2
|
||||
@@ -27,12 +31,25 @@ class UnisonTest(unittest.TestCase):
|
||||
b = unison.deserialize_args(res)
|
||||
self.assertTrue((a == b).all())
|
||||
|
||||
a = [unison.ObjRef(42, int)]
|
||||
res = unison.serialize_args(a)
|
||||
b = unison.deserialize_args(res)
|
||||
self.assertEqual(a, b)
|
||||
|
||||
TIMEOUT_SECONDS = 5
|
||||
|
||||
def produce_data(num_chunks):
  """Yield `num_chunks` ObjChunk messages, each built with the same fixed fields."""
  for _chunk_index in range(num_chunks):
    chunk = orchestra_pb2.ObjChunk(objref=1, totalsize=1000, data=b"hello world")
    yield chunk
|
||||
|
||||
def connect_to_scheduler(host, port):
  """Open an insecure gRPC channel to host:port and return a SchedulerServer stub on it."""
  return orchestra_pb2.beta_create_SchedulerServer_stub(
      implementations.insecure_channel(host, port))
|
||||
|
||||
def connect_to_objstore(host, port):
  """Open an insecure gRPC channel to host:port and return an ObjStore stub on it."""
  return orchestra_pb2.beta_create_ObjStore_stub(
      implementations.insecure_channel(host, port))
|
||||
|
||||
class ObjStoreTest(unittest.TestCase):
|
||||
|
||||
"""Test setting up object stores, transfering data between them and retrieving data to a client"""
|
||||
@@ -40,123 +57,85 @@ class ObjStoreTest(unittest.TestCase):
|
||||
services.start_scheduler("0.0.0.0:22221")
|
||||
services.start_objstore("0.0.0.0:22222")
|
||||
services.start_objstore("0.0.0.0:22223")
|
||||
time.sleep(0.5)
|
||||
|
||||
scheduler_channel = implementations.insecure_channel('localhost', 22221)
|
||||
scheduler_stub = orchestra_pb2.beta_create_SchedulerServer_stub(scheduler_channel)
|
||||
objstore1_channel = implementations.insecure_channel('localhost', 22222)
|
||||
objstore1_stub = orchestra_pb2.beta_create_ObjStore_stub(objstore1_channel)
|
||||
objstore2_channel = implementations.insecure_channel('localhost', 22223)
|
||||
objstore2_stub = orchestra_pb2.beta_create_ObjStore_stub(objstore2_channel)
|
||||
time.sleep(0.2)
|
||||
|
||||
scheduler_stub.RegisterObjStore(orchestra_pb2.RegisterObjStoreRequest(address="127.0.0.1:22222"), TIMEOUT_SECONDS)
|
||||
scheduler_stub.RegisterObjStore(orchestra_pb2.RegisterObjStoreRequest(address="127.0.0.1:22223"), TIMEOUT_SECONDS)
|
||||
scheduler_stub = connect_to_scheduler('localhost', 22221)
|
||||
objstore1_stub = connect_to_objstore('localhost', 22222)
|
||||
objstore2_stub = connect_to_objstore('localhost', 22223)
|
||||
|
||||
worker.global_worker.connect("127.0.0.1:22221", "127.0.0.1:40000", "127.0.0.1:22222")
|
||||
worker1 = worker.Worker()
|
||||
worker1.connect("127.0.0.1:22221", "127.0.0.1:40000", "127.0.0.1:22222")
|
||||
|
||||
other_worker = worker.Worker()
|
||||
other_worker.connect("127.0.0.1:22221", "127.0.0.1:40001", "127.0.0.1:22223")
|
||||
worker2 = worker.Worker()
|
||||
worker2.connect("127.0.0.1:22221", "127.0.0.1:40001", "127.0.0.1:22223")
|
||||
|
||||
# import IPython
|
||||
# IPython.embed()
|
||||
|
||||
for i in range(1, 10):
|
||||
for i in range(1, 100):
|
||||
l = i * 100 * "h"
|
||||
objref = worker.global_worker.do_push(l)
|
||||
# time.sleep(5.0)
|
||||
objref = worker1.do_push(l)
|
||||
response = objstore1_stub.DeliverObj(orchestra_pb2.DeliverObjRequest(objref=objref, objstore_address="0.0.0.0:22223"), TIMEOUT_SECONDS)
|
||||
# time.sleep(5.0)
|
||||
str = other_worker.get_serialized(objref)
|
||||
result = worker.unison.deserialize_from_string(str)
|
||||
# import IPython
|
||||
# IPython.embed()
|
||||
s = worker2.get_serialized(objref)
|
||||
result = worker.unison.deserialize_from_string(s)
|
||||
self.assertEqual(len(result), 100 * i)
|
||||
|
||||
class SchedulerTest(unittest.TestCase):
|
||||
services.cleanup()
|
||||
|
||||
def testRegister(self):
|
||||
scheduler_channel = implementations.insecure_channel('localhost', 22221)
|
||||
scheduler_stub = orchestra_pb2.beta_create_SchedulerServer_stub(scheduler_channel)
|
||||
w = worker.Worker()
|
||||
w.connect("127.0.0.1:22221", "127.0.0.1:40002", "127.0.0.1:22222")
|
||||
w.register_function("hello_world", 2)
|
||||
reply = scheduler_stub.GetDebugInfo(orchestra_pb2.GetDebugInfoRequest(), TIMEOUT_SECONDS)
|
||||
self.assertEqual(reply.function_table.items()[0][0], u'hello_world')
|
||||
class SchedulerTest(unittest.TestCase):
|
||||
|
||||
def testCall(self):
|
||||
scheduler_channel = implementations.insecure_channel('localhost', 22221)
|
||||
scheduler_stub = orchestra_pb2.beta_create_SchedulerServer_stub(scheduler_channel)
|
||||
w = worker.Worker()
|
||||
w.connect("127.0.0.1:22221", "127.0.0.1:40003", "127.0.0.1:22222")
|
||||
|
||||
|
||||
"""
|
||||
class SchedulerTest(unittest.TestCase):
|
||||
|
||||
def testServer(self):
|
||||
services.start_scheduler("0.0.0.0:22221")
|
||||
services.start_objstore("0.0.0.0:22222")
|
||||
services.start_objstore("0.0.0.0:22223")
|
||||
time.sleep(1.0)
|
||||
|
||||
scheduler_channel = implementations.insecure_channel('localhost', 22221)
|
||||
scheduler_stub = orchestra_pb2.beta_create_SchedulerServer_stub(scheduler_channel)
|
||||
objstore_channel = implementations.insecure_channel('localhost', 22222)
|
||||
objstore_stub = orchestra_pb2.beta_create_ObjStore_stub(objstore_channel)
|
||||
objstore_channel2 = implementations.insecure_channel('localhost', 22223)
|
||||
objstore_stub2 = orchestra_pb2.beta_create_ObjStore_stub(objstore_channel2)
|
||||
time.sleep(0.2)
|
||||
|
||||
# call = types_pb2.Call(name="test")
|
||||
# response = scheduler_stub.RemoteCall(orchestra_pb2.RemoteCallRequest(call=call), TIMEOUT_SECONDS)
|
||||
# response = scheduler_stub.RegisterFunction(orchestra_pb2.RegisterFunctionRequest(workerid=1, fnname="hello"), TIMEOUT_SECONDS)
|
||||
scheduler_stub = connect_to_scheduler('localhost', 22221)
|
||||
objstore_stub = connect_to_objstore('localhost', 22222)
|
||||
|
||||
response2 = scheduler_stub.RegisterObjStore(orchestra_pb2.RegisterObjStoreRequest(address="127.0.0.1:22222"), TIMEOUT_SECONDS)
|
||||
response2 = scheduler_stub.RegisterObjStore(orchestra_pb2.RegisterObjStoreRequest(address="127.0.0.1:22223"), TIMEOUT_SECONDS)
|
||||
time.sleep(0.2)
|
||||
|
||||
# response2 = scheduler_stub.RegisterObjStore(orchestra_pb2.RegisterObjStoreRequest(address="127.0.0.1:22222"), TIMEOUT_SECONDS)
|
||||
# response3 = scheduler_stub.RegisterObjStore(orchestra_pb2.RegisterObjStoreRequest(address="127.0.0.1:22223"), TIMEOUT_SECONDS)
|
||||
w = worker.Worker()
|
||||
w.connect("127.0.0.1:22221", "127.0.0.1:40003", "127.0.0.1:22222")
|
||||
w2 = worker.Worker()
|
||||
w2.connect("127.0.0.1:22221", "127.0.0.1:40004", "127.0.0.1:22222")
|
||||
|
||||
# objstore_stub.StreamObj(produce_data(100), TIMEOUT_SECONDS)
|
||||
time.sleep(0.2)
|
||||
|
||||
worker.global_worker.connect("127.0.0.1:22221", "127.0.0.1:40000", "127.0.0.1:22222")
|
||||
w.register_function("hello_world", 2)
|
||||
w2.register_function("hello_world", 2)
|
||||
|
||||
l = [1, 2, 3, 4]
|
||||
worker.global_worker.do_push(l)
|
||||
time.sleep(0.1)
|
||||
|
||||
## res = scheduler_stub.PushObj(orchestra_pb2.PushObjRequest(workerid=0), TIMEOUT_SECONDS)
|
||||
w.call("hello_world", ["hi"])
|
||||
|
||||
response = objstore_stub.DeliverObj(orchestra_pb2.DeliverObjRequest(objref=0, objstore_address="0.0.0.0:22223"), TIMEOUT_SECONDS)
|
||||
time.sleep(0.1)
|
||||
|
||||
# res = objstore_stub2.DebugInfo(orchestra_pb2.DebugInfoRequest(), TIMEOUT_SECONDS)
|
||||
reply = scheduler_stub.GetDebugInfo(orchestra_pb2.GetDebugInfoRequest(), TIMEOUT_SECONDS)
|
||||
|
||||
response = objstore_stub.GetObj(orchestra_pb2.GetObjRequest(objref=0), TIMEOUT_SECONDS)
|
||||
self.assertEqual(reply.task[0].name, u'hello_world')
|
||||
|
||||
worker.global_worker.get_serialized(0)
|
||||
test_path = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
import IPython
|
||||
IPython.embed()
|
||||
p = subprocess.Popen(["python", os.path.join(test_path, "testrecv.py")])
|
||||
|
||||
l = [1, 2, 3, 4]
|
||||
worker.global_worker.do_push(l)
|
||||
time.sleep(0.2)
|
||||
|
||||
response = objstore_stub.DeliverObj(orchestra_pb2.DeliverObjRequest(), TIMEOUT_SECONDS)
|
||||
scheduler_stub.PushObj(orchestra_pb2.PushObjRequest(workerid=0), TIMEOUT_SECONDS)
|
||||
|
||||
# response = objstore_stub.DebugInfo(orchestra_pb2.DebugInfoRequest(), TIMEOUT_SECONDS)
|
||||
reply = scheduler_stub.GetDebugInfo(orchestra_pb2.GetDebugInfoRequest(do_scheduling=True), TIMEOUT_SECONDS)
|
||||
|
||||
# import IPython
|
||||
# IPython.embed()
|
||||
self.assertEqual(p.wait(), 0, "argument was not received by the test program")
|
||||
|
||||
# worker.global_worker.connect("127.0.0.1:22221", "127.0.0.1:22222")
|
||||
# l = [1, 2, 3, 4]
|
||||
# worker.global_worker.do_push(l)
|
||||
# w.main_loop()
|
||||
# w2.main_loop()
|
||||
#
|
||||
# reply = scheduler_stub.GetDebugInfo(orchestra_pb2.GetDebugInfoRequest(do_scheduling=True), TIMEOUT_SECONDS)
|
||||
# time.sleep(0.1)
|
||||
# reply = scheduler_stub.GetDebugInfo(orchestra_pb2.GetDebugInfoRequest(), TIMEOUT_SECONDS)
|
||||
#
|
||||
# self.assertEqual(list(reply.task), [])
|
||||
#
|
||||
# services.cleanup()
|
||||
|
||||
# import IPython
|
||||
# IPython.embed()
|
||||
# response = objstore_stub.DeliverObj(orchestra_pb2.DeliverObjRequest())
|
||||
# print "Greeter client received: " + response.message
|
||||
# import IPython
|
||||
# IPython.embed()
|
||||
"""
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user