mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 19:32:11 +08:00
clean up, mainly the scheduler
This commit is contained in:
+2
-2
@@ -70,7 +70,7 @@ if (UNIX AND NOT APPLE)
|
||||
endif()
|
||||
|
||||
add_executable(objstore src/objstore.cc ${GENERATED_PROTOBUF_FILES})
|
||||
add_executable(scheduler_server src/scheduler_server.cc src/scheduler.cc ${GENERATED_PROTOBUF_FILES})
|
||||
add_executable(scheduler src/scheduler.cc ${GENERATED_PROTOBUF_FILES})
|
||||
add_library(orchlib SHARED src/orchlib.cc src/worker.cc ${GENERATED_PROTOBUF_FILES})
|
||||
|
||||
install(TARGETS objstore scheduler_server orchlib DESTINATION ${CMAKE_SOURCE_DIR}/lib/orchpy/orchpy)
|
||||
install(TARGETS objstore scheduler orchlib DESTINATION ${CMAKE_SOURCE_DIR}/lib/orchpy/orchpy)
|
||||
|
||||
@@ -24,7 +24,7 @@ def cleanup():
|
||||
atexit.register(cleanup)
|
||||
|
||||
def start_scheduler(scheduler_address):
|
||||
p = subprocess.Popen([os.path.join(_services_path, "scheduler_server"), str(scheduler_address)])
|
||||
p = subprocess.Popen([os.path.join(_services_path, "scheduler"), str(scheduler_address)])
|
||||
all_processes.append(p)
|
||||
|
||||
def start_objstore(objstore_address):
|
||||
|
||||
@@ -32,13 +32,6 @@ cdef extern from "../../../build/generated/types.pb.h":
|
||||
bool has_obj()
|
||||
Obj* mutable_obj()
|
||||
|
||||
cdef cppclass Values:
|
||||
Values()
|
||||
int value_size()
|
||||
Value* add_value()
|
||||
Value* mutable_value(int index)
|
||||
|
||||
|
||||
cdef cppclass String:
|
||||
String()
|
||||
void set_data(const char* val)
|
||||
@@ -70,15 +63,6 @@ cdef extern from "../../../build/generated/types.pb.h":
|
||||
bool has_double_data()
|
||||
bool ParseFromString(const string& data)
|
||||
|
||||
cdef class PyValues: # TODO: unify with the below
|
||||
cdef Values *thisptr
|
||||
def __cinit__(self):
|
||||
self.thisptr = new Values()
|
||||
def __dealloc__(self):
|
||||
del self.thisptr
|
||||
def get_value(self):
|
||||
return <uintptr_t>self.thisptr
|
||||
|
||||
cdef class PyValue: # TODO: unify with the below
|
||||
cdef Value *thisptr
|
||||
def __cinit__(self):
|
||||
@@ -150,28 +134,6 @@ cpdef serialize(val):
|
||||
serialize_into(val, result.get_value())
|
||||
return result
|
||||
|
||||
cpdef serialize_args_into(args, valsptr):
|
||||
cdef uintptr_t ptr = <uintptr_t>valsptr
|
||||
cdef Values* vals = <Values*>ptr
|
||||
serialize_args_into_vals(args, vals)
|
||||
|
||||
# this code is a mess right now, will be improved in the C++ version
|
||||
cdef serialize_args_into_vals(args, Values* vals):
|
||||
cdef Value* val
|
||||
cdef Obj* obj
|
||||
for arg in args:
|
||||
val = vals[0].add_value()
|
||||
if type(arg) == ObjRef:
|
||||
val[0].set_ref(arg.get_id())
|
||||
else:
|
||||
obj = val[0].mutable_obj()
|
||||
serialize_into(arg, <uintptr_t>obj)
|
||||
|
||||
cpdef serialize_args(args):
|
||||
result = PyValues()
|
||||
serialize_args_into(args, result.get_value())
|
||||
return result
|
||||
|
||||
cdef deserialize_from(Obj* obj):
|
||||
if obj[0].has_string_data():
|
||||
return obj[0].mutable_string_data()[0].mutable_data()[0]
|
||||
@@ -193,28 +155,6 @@ cpdef deserialize_from_string(str):
|
||||
# cdef string s = string(str)
|
||||
# return deserialize_from(obj.get_value())
|
||||
|
||||
cpdef deserialize_args(PyValues args):
|
||||
cdef Values* vals = args.thisptr
|
||||
cdef Value* val
|
||||
cdef Obj* obj
|
||||
result = []
|
||||
for i in range(vals[0].value_size()):
|
||||
val = vals[0].mutable_value(i)
|
||||
if not val.has_obj():
|
||||
result.append(ObjRef(val.ref(), None)) # TODO: fix this
|
||||
else:
|
||||
obj = val[0].mutable_obj()
|
||||
if obj[0].has_string_data():
|
||||
result.append(obj[0].mutable_string_data()[0].mutable_data()[0])
|
||||
elif obj[0].has_int_data():
|
||||
result.append(obj[0].mutable_int_data()[0].data())
|
||||
elif obj[0].has_double_data():
|
||||
result.append(obj[0].mutable_double_data()[0].data())
|
||||
else:
|
||||
data = obj[0].mutable_pyobj_data()[0].mutable_data()[0]
|
||||
result.append(pickle.loads(data))
|
||||
return result
|
||||
|
||||
# todo: unify with the above, at the moment this is copied
|
||||
cdef deserialize_args_from_call(Call* call):
|
||||
cdef Value* val
|
||||
|
||||
@@ -14,12 +14,13 @@ cdef struct Slice:
|
||||
char* ptr
|
||||
size_t size
|
||||
|
||||
cdef extern void* orch_create_context(const char* server_addr, const char* worker_addr, const char* objstore_addr);
|
||||
cdef extern void* orch_create_context(const char* server_addr, const char* worker_addr, const char* objstore_addr)
|
||||
cdef extern void orch_start_worker_service(void* worker)
|
||||
cdef extern void* orch_wait_for_next_task(void* worker)
|
||||
cdef extern void orch_register_function(void* worker, const char* name, size_t num_return_vals)
|
||||
cdef extern size_t orch_remote_call(void* context, void* request);
|
||||
cdef extern size_t orch_push(void* context, void* value);
|
||||
cdef extern void* orch_main_loop(void* context);
|
||||
cdef extern Slice orch_get_serialized_obj(void* context, size_t objref);
|
||||
cdef extern size_t orch_remote_call(void* worker, void* request)
|
||||
cdef extern size_t orch_push(void* worker, void* value)
|
||||
cdef extern Slice orch_get_serialized_obj(void* worker, size_t objref)
|
||||
|
||||
cdef extern from "Python.h":
|
||||
Py_ssize_t PyByteArray_GET_SIZE(object array)
|
||||
@@ -41,6 +42,7 @@ cdef extern from "../../../build/generated/types.pb.h":
|
||||
cdef cppclass Call:
|
||||
Value* add_arg();
|
||||
void set_name(const char* value)
|
||||
const string& name()
|
||||
Value* mutable_arg(int index);
|
||||
int arg_size() const;
|
||||
|
||||
@@ -212,13 +214,18 @@ cdef deserialize_args_from_call(Call* call):
|
||||
|
||||
cdef class Worker:
|
||||
cdef void* context
|
||||
cdef dict functions
|
||||
|
||||
def __cinit__(self):
|
||||
self.context = NULL
|
||||
self.functions = {}
|
||||
|
||||
def connect(self, server_addr, worker_addr, objstore_addr):
|
||||
self.context = orch_create_context(server_addr, worker_addr, objstore_addr)
|
||||
|
||||
def start_worker_service(self):
|
||||
orch_start_worker_service(self.context)
|
||||
|
||||
cpdef call(self, name, args):
|
||||
cdef RemoteCallRequest* result = new RemoteCallRequest()
|
||||
cdef Call* call = result[0].mutable_call()
|
||||
@@ -230,17 +237,9 @@ cdef class Worker:
|
||||
cpdef do_call(self, ptr):
|
||||
return orch_remote_call(self.context, <void*>ptr)
|
||||
|
||||
cpdef do_push(self, val):
|
||||
print("before serialization")
|
||||
cpdef push(self, val):
|
||||
result = unison.serialize(val)
|
||||
print("before push")
|
||||
# ptr = result.get_value()
|
||||
# print "pointer is", ptr
|
||||
# cdef Obj* obj = new Obj()
|
||||
o = ObjWrapper()
|
||||
# serialize_into_2(0, <uintptr_t>obj)
|
||||
# cdef Obj* ptr = new Obj() # o.get_value()
|
||||
## ptr = <uintptr_t>o.get_value()
|
||||
ptr = <uintptr_t>result.get_value()
|
||||
serialize_into_2(0, ptr)
|
||||
return orch_push(self.context, <void*>ptr)
|
||||
@@ -258,17 +257,21 @@ cdef class Worker:
|
||||
data = PyBytes_FromStringAndSize(slice.ptr, slice.size)
|
||||
return unison.deserialize_from_string(data)
|
||||
|
||||
cpdef register_function(self, func_name, num_args):
|
||||
cpdef register_function(self, func_name, function, num_args):
|
||||
orch_register_function(self.context, func_name, num_args)
|
||||
self.functions[func_name] = function
|
||||
|
||||
cpdef main_loop(self):
|
||||
result = []
|
||||
cdef Call* call = <Call*>orch_main_loop(self.context)
|
||||
cdef Call* call = <Call*>orch_wait_for_next_task(self.context)
|
||||
cdef int size = call[0].arg_size()
|
||||
cdef Obj* obj
|
||||
print "hello from python"
|
||||
print "size", size
|
||||
return deserialize_args_from_call(call)
|
||||
return call[0].name(), deserialize_args_from_call(call)
|
||||
|
||||
cpdef invoke_function(self, name, args):
|
||||
return self.functions[name].executor(args)
|
||||
|
||||
global_worker = Worker()
|
||||
|
||||
@@ -277,25 +280,27 @@ def distributed(types, return_type, worker=global_worker):
|
||||
# deserialize arguments, execute function and serialize result
|
||||
def func_executor(args):
|
||||
arguments = []
|
||||
protoargs = unison.deserialize_call(args, types)
|
||||
for (i, proto) in enumerate(protoargs):
|
||||
if type(proto) == unison.ObjRef:
|
||||
for (i, arg) in enumerate(args):
|
||||
if type(arg) == unison.ObjRef:
|
||||
if i < len(types) - 1:
|
||||
arguments.append(worker.get_object(proto, types[i]))
|
||||
arguments.append(worker.get_object(arg, types[i]))
|
||||
elif i == len(types) - 1 and types[-1] is not None:
|
||||
arguments.append(global_worker.get_object(proto, types[i]))
|
||||
arguments.append(global_worker.get_object(arg, types[i]))
|
||||
elif types[-1] is None:
|
||||
arguments.append(worker.get_object(proto, types[-2]))
|
||||
arguments.append(worker.get_object(arg, types[-2]))
|
||||
else:
|
||||
raise Exception("Passed in " + str(len(args)) + " arguments to function " + func.__name__ + ", which takes only " + str(len(types)) + " arguments.")
|
||||
else:
|
||||
arguments.append(proto)
|
||||
buf = bytearray()
|
||||
arguments.append(arg)
|
||||
# TODO
|
||||
# buf = bytearray()
|
||||
print "called with arguments", arguments
|
||||
result = func(*arguments)
|
||||
if unison.unison_type(result) != return_type:
|
||||
raise Exception("Return type of " + func.func_name + " does not match the return type specified in the @distributed decorator, was expecting " + str(return_type) + " but received " + str(unison.unison_type(result)))
|
||||
unison.serialize(buf, result)
|
||||
return memoryview(buf).tobytes()
|
||||
# if unison.unison_type(result) != return_type:
|
||||
# raise Exception("Return type of " + func.func_name + " does not match the return type specified in the @distributed decorator, was expecting " + str(return_type) + " but received " + str(unison.unison_type(result)))
|
||||
# unison.serialize(buf, result)
|
||||
# return memoryview(buf).tobytes()
|
||||
return result
|
||||
# for remotely executing the function
|
||||
def func_call(*args, typecheck=False):
|
||||
return worker.call(func_call.func_name, func_call.module_name, args)
|
||||
|
||||
+1
-1
@@ -26,7 +26,7 @@ setup(
|
||||
packages=find_packages(),
|
||||
package_data = {
|
||||
'orchpy': ['liborchlib.dylib' if MACOSX else 'liborchlib.so',
|
||||
'scheduler_server',
|
||||
'scheduler',
|
||||
'objstore']
|
||||
},
|
||||
zip_safe=False
|
||||
|
||||
@@ -68,7 +68,7 @@ message GetDebugInfoReply {
|
||||
map<string, FnTableEntry> function_table = 2;
|
||||
}
|
||||
|
||||
service SchedulerServer {
|
||||
service Scheduler {
|
||||
rpc RegisterWorker(RegisterWorkerRequest) returns (RegisterWorkerReply);
|
||||
rpc RegisterObjStore(RegisterObjStoreRequest) returns (RegisterObjStoreReply);
|
||||
rpc RegisterFunction(RegisterFunctionRequest) returns (AckReply);
|
||||
@@ -130,6 +130,6 @@ message InvokeCallReply {
|
||||
|
||||
}
|
||||
|
||||
service WorkerServer {
|
||||
service WorkerService {
|
||||
rpc InvokeCall(InvokeCallRequest) returns (InvokeCallReply);
|
||||
}
|
||||
|
||||
+1
-1
@@ -36,7 +36,7 @@ message Value {
|
||||
message Call {
|
||||
string name = 1;
|
||||
repeated Value arg = 2;
|
||||
repeated uint64 result = 3;
|
||||
repeated uint64 result = 3; // object references for result
|
||||
}
|
||||
|
||||
enum DataType {
|
||||
|
||||
+6
-2
@@ -12,8 +12,12 @@ size_t orch_remote_call(Worker* worker, RemoteCallRequest* request) {
|
||||
return worker->remote_call(request);
|
||||
}
|
||||
|
||||
Call* orch_main_loop(Worker* worker) {
|
||||
return worker->main_loop();
|
||||
void orch_start_worker_service(Worker* worker) {
|
||||
worker->start_worker_service();
|
||||
}
|
||||
|
||||
Call* orch_wait_for_next_task(Worker* worker) {
|
||||
return worker->receive_next_task();
|
||||
}
|
||||
|
||||
size_t orch_push(Worker* worker, Obj* obj) {
|
||||
|
||||
+10
-3
@@ -1,3 +1,6 @@
|
||||
// A minimal C API that is used for implementing Orchestra workers in C based
|
||||
// languages (Python at the moment, in the future potentially Julia, R, MATLAB)
|
||||
|
||||
extern "C" {
|
||||
|
||||
struct slice {
|
||||
@@ -9,10 +12,14 @@ struct Worker;
|
||||
struct RemoteCallRequest;
|
||||
struct Value;
|
||||
|
||||
// connect to the scheduler and the object store
|
||||
Worker* orch_create_context(const char* server_addr, const char* worker_addr, const char* objstore_addr);
|
||||
size_t orch_remote_call(Worker* context, RemoteCallRequest* request);
|
||||
size_t orch_push(Worker* context, Obj* value);
|
||||
Call* orch_main_loop(Worker* worker);
|
||||
// start the worker service for this worker
|
||||
void orch_start_worker_service(Worker* worker);
|
||||
// Submit a function call to the scheduler
|
||||
size_t orch_remote_call(Worker* worker, RemoteCallRequest* request);
|
||||
size_t orch_push(Worker* worker, Obj* value);
|
||||
Call* orch_wait_for_next_task(Worker* worker);
|
||||
slice orch_get_serialized_obj(Worker* worker, size_t objref);
|
||||
void orch_register_function(Worker* worker, const char* name, size_t num_return_vals);
|
||||
|
||||
|
||||
+68
-15
@@ -1,17 +1,54 @@
|
||||
#include "scheduler.h"
|
||||
|
||||
size_t Scheduler::add_task(const Call& task) {
|
||||
Status SchedulerService::RemoteCall(ServerContext* context, const RemoteCallRequest* request, RemoteCallReply* reply) {
|
||||
std::unique_ptr<Call> task(new Call(request->call())); // need to copy, because request is const
|
||||
fntable_lock_.lock();
|
||||
size_t num_return_vals = fntable_[task.name()].num_return_vals();
|
||||
size_t num_return_vals = fntable_[task->name()].num_return_vals();
|
||||
fntable_lock_.unlock();
|
||||
std::unique_ptr<Call> task_ptr(new Call(task));
|
||||
|
||||
for (size_t i = 0; i < num_return_vals; ++i) {
|
||||
ObjRef result = register_new_object();
|
||||
reply->add_result(result);
|
||||
task->add_result(result);
|
||||
}
|
||||
|
||||
tasks_lock_.lock();
|
||||
tasks_.emplace_back(std::move(task_ptr));
|
||||
tasks_.emplace_back(std::move(task));
|
||||
tasks_lock_.unlock();
|
||||
return num_return_vals;
|
||||
return Status::OK;
|
||||
}
|
||||
|
||||
void Scheduler::schedule() {
|
||||
Status SchedulerService::PushObj(ServerContext* context, const PushObjRequest* request, PushObjReply* reply) {
|
||||
ObjRef objref = register_new_object();
|
||||
ObjStoreId objstoreid = get_store(request->workerid());
|
||||
add_location(objref, objstoreid);
|
||||
reply->set_objref(objref);
|
||||
return Status::OK;
|
||||
}
|
||||
|
||||
Status SchedulerService::PullObj(ServerContext* context, const PullObjRequest* request, AckReply* reply) {
|
||||
return Status::OK;
|
||||
}
|
||||
|
||||
Status SchedulerService::RegisterWorker(ServerContext* context, const RegisterWorkerRequest* request, RegisterWorkerReply* reply) {
|
||||
WorkerId workerid = register_worker(request->worker_address(), request->objstore_address());
|
||||
std::cout << "registered worker with workerid" << workerid << std::endl;
|
||||
reply->set_workerid(workerid);
|
||||
return Status::OK;
|
||||
}
|
||||
|
||||
Status SchedulerService::RegisterFunction(ServerContext* context, const RegisterFunctionRequest* request, AckReply* reply) {
|
||||
std::cout << "RegisterFunction: workerid is" << request->workerid() << std::endl;
|
||||
register_function(request->fnname(), request->workerid(), request->num_return_vals());
|
||||
return Status::OK;
|
||||
}
|
||||
|
||||
Status SchedulerService::GetDebugInfo(ServerContext* context, const GetDebugInfoRequest* request, GetDebugInfoReply* reply) {
|
||||
debug_info(*request, reply);
|
||||
return Status::OK;
|
||||
}
|
||||
|
||||
void SchedulerService::schedule() {
|
||||
// TODO: work out a better strategy here
|
||||
WorkerId workerid = 0;
|
||||
{
|
||||
@@ -34,7 +71,7 @@ void Scheduler::schedule() {
|
||||
}
|
||||
}
|
||||
|
||||
void Scheduler::submit_task(std::unique_ptr<Call> call, WorkerId workerid) {
|
||||
void SchedulerService::submit_task(std::unique_ptr<Call> call, WorkerId workerid) {
|
||||
ClientContext context;
|
||||
InvokeCallRequest request;
|
||||
InvokeCallReply reply;
|
||||
@@ -59,7 +96,7 @@ void Scheduler::submit_task(std::unique_ptr<Call> call, WorkerId workerid) {
|
||||
Status status = workers_[workerid].worker_stub->InvokeCall(&context, request, &reply);
|
||||
}
|
||||
|
||||
bool Scheduler::can_run(const Call& task) {
|
||||
bool SchedulerService::can_run(const Call& task) {
|
||||
std::lock_guard<std::mutex> lock(objtable_lock_);
|
||||
for (int i = 0; i < task.arg_size(); ++i) {
|
||||
if (!task.arg(i).has_obj()) {
|
||||
@@ -71,7 +108,7 @@ bool Scheduler::can_run(const Call& task) {
|
||||
return true;
|
||||
}
|
||||
|
||||
WorkerId Scheduler::register_worker(const std::string& worker_address, const std::string& objstore_address) {
|
||||
WorkerId SchedulerService::register_worker(const std::string& worker_address, const std::string& objstore_address) {
|
||||
ObjStoreId objstoreid = std::numeric_limits<size_t>::max();
|
||||
objstores_lock_.lock();
|
||||
for (size_t i = 0; i < objstores_.size(); ++i) {
|
||||
@@ -97,7 +134,7 @@ WorkerId Scheduler::register_worker(const std::string& worker_address, const std
|
||||
auto channel = grpc::CreateChannel(worker_address, grpc::InsecureChannelCredentials());
|
||||
workers_[workerid].channel = channel;
|
||||
workers_[workerid].objstoreid = objstoreid;
|
||||
workers_[workerid].worker_stub = WorkerServer::NewStub(channel);
|
||||
workers_[workerid].worker_stub = WorkerService::NewStub(channel);
|
||||
workers_lock_.unlock();
|
||||
avail_workers_lock_.lock();
|
||||
avail_workers_.push_back(workerid);
|
||||
@@ -105,7 +142,7 @@ WorkerId Scheduler::register_worker(const std::string& worker_address, const std
|
||||
return workerid;
|
||||
}
|
||||
|
||||
ObjRef Scheduler::register_new_object() {
|
||||
ObjRef SchedulerService::register_new_object() {
|
||||
objtable_lock_.lock();
|
||||
ObjRef result = objtable_.size();
|
||||
objtable_.push_back(std::vector<ObjStoreId>());
|
||||
@@ -113,7 +150,7 @@ ObjRef Scheduler::register_new_object() {
|
||||
return result;
|
||||
}
|
||||
|
||||
void Scheduler::add_location(ObjRef objref, ObjStoreId objstoreid) {
|
||||
void SchedulerService::add_location(ObjRef objref, ObjStoreId objstoreid) {
|
||||
objtable_lock_.lock();
|
||||
// do a binary search
|
||||
auto pos = std::lower_bound(objtable_[objref].begin(), objtable_[objref].end(), objstoreid);
|
||||
@@ -123,14 +160,14 @@ void Scheduler::add_location(ObjRef objref, ObjStoreId objstoreid) {
|
||||
objtable_lock_.unlock();
|
||||
}
|
||||
|
||||
ObjStoreId Scheduler::get_store(WorkerId workerid) {
|
||||
ObjStoreId SchedulerService::get_store(WorkerId workerid) {
|
||||
workers_lock_.lock();
|
||||
ObjStoreId result = workers_[workerid].objstoreid;
|
||||
workers_lock_.unlock();
|
||||
return result;
|
||||
}
|
||||
|
||||
void Scheduler::register_function(const std::string& name, WorkerId workerid, size_t num_return_vals) {
|
||||
void SchedulerService::register_function(const std::string& name, WorkerId workerid, size_t num_return_vals) {
|
||||
fntable_lock_.lock();
|
||||
FnInfo& info = fntable_[name];
|
||||
info.set_num_return_vals(num_return_vals);
|
||||
@@ -138,7 +175,7 @@ void Scheduler::register_function(const std::string& name, WorkerId workerid, si
|
||||
fntable_lock_.unlock();
|
||||
}
|
||||
|
||||
void Scheduler::debug_info(const GetDebugInfoRequest& request, GetDebugInfoReply* reply) {
|
||||
void SchedulerService::debug_info(const GetDebugInfoRequest& request, GetDebugInfoReply* reply) {
|
||||
if (request.do_scheduling()) {
|
||||
schedule();
|
||||
}
|
||||
@@ -163,3 +200,19 @@ void Scheduler::debug_info(const GetDebugInfoRequest& request, GetDebugInfoReply
|
||||
}
|
||||
avail_workers_lock_.unlock();
|
||||
}
|
||||
|
||||
void start_scheduler_service(const char* server_address) {
|
||||
SchedulerService service;
|
||||
ServerBuilder builder;
|
||||
builder.AddListeningPort(std::string(server_address), grpc::InsecureServerCredentials());
|
||||
builder.RegisterService(&service);
|
||||
std::unique_ptr<Server> server(builder.BuildAndStart());
|
||||
server->Wait();
|
||||
}
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
if (argc != 2)
|
||||
return 1;
|
||||
start_scheduler_service(argv[1]);
|
||||
return 0;
|
||||
}
|
||||
|
||||
+10
-4
@@ -1,6 +1,7 @@
|
||||
#ifndef ORCHESTRA_SCHEDULER_H
|
||||
#define ORCHESTRA_SCHEDULER_H
|
||||
|
||||
|
||||
#include <deque>
|
||||
#include <memory>
|
||||
#include <algorithm>
|
||||
@@ -24,7 +25,7 @@ using grpc::Channel;
|
||||
|
||||
struct WorkerHandle {
|
||||
std::shared_ptr<Channel> channel;
|
||||
std::unique_ptr<WorkerServer::Stub> worker_stub;
|
||||
std::unique_ptr<WorkerService::Stub> worker_stub;
|
||||
ObjStoreId objstoreid;
|
||||
};
|
||||
|
||||
@@ -34,10 +35,15 @@ struct ObjStoreHandle {
|
||||
std::string address;
|
||||
};
|
||||
|
||||
class Scheduler {
|
||||
class SchedulerService : public Scheduler::Service {
|
||||
public:
|
||||
// returns number of return values of task
|
||||
size_t add_task(const Call& task);
|
||||
Status RemoteCall(ServerContext* context, const RemoteCallRequest* request, RemoteCallReply* reply) override;
|
||||
Status PushObj(ServerContext* context, const PushObjRequest* request, PushObjReply* reply) override;
|
||||
Status PullObj(ServerContext* context, const PullObjRequest* request, AckReply* reply) override;
|
||||
Status RegisterWorker(ServerContext* context, const RegisterWorkerRequest* request, RegisterWorkerReply* reply) override;
|
||||
Status RegisterFunction(ServerContext* context, const RegisterFunctionRequest* request, AckReply* reply) override;
|
||||
Status GetDebugInfo(ServerContext* context, const GetDebugInfoRequest* request, GetDebugInfoReply* reply) override;
|
||||
|
||||
// assign a task to a worker
|
||||
void schedule();
|
||||
// execute a task on a worker and ship required object references
|
||||
|
||||
+14
-17
@@ -92,29 +92,26 @@ void Worker::register_function(const std::string& name, size_t num_return_vals)
|
||||
scheduler_stub_->RegisterFunction(&context, request, &reply);
|
||||
}
|
||||
|
||||
void start_worker_server(const char* server_address) {
|
||||
WorkerServiceImpl service(server_address);
|
||||
ServerBuilder builder;
|
||||
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
|
||||
builder.RegisterService(&service);
|
||||
std::unique_ptr<Server> server(builder.BuildAndStart());
|
||||
std::cout << "Server listening on " << server_address << std::endl;
|
||||
server->Wait();
|
||||
}
|
||||
|
||||
// Communication between the WorkerServer and the Worker happens via a message
|
||||
// queue. This is because the Python interpreter needs to be single threaded
|
||||
// (in our case running in the main thread), whereas the WorkerService will
|
||||
// run in a separate thread and potentially utilize multiple threads.
|
||||
Call* Worker::main_loop() {
|
||||
// start the worker server
|
||||
worker_server_thread_ = std::thread(start_worker_server, worker_address_.c_str());
|
||||
// process the next call
|
||||
return receive(worker_address_.c_str());
|
||||
void Worker::start_worker_service() {
|
||||
const char* server_address = worker_address_.c_str();
|
||||
worker_server_thread_ = std::thread([server_address]() {
|
||||
WorkerServiceImpl service(server_address);
|
||||
ServerBuilder builder;
|
||||
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
|
||||
builder.RegisterService(&service);
|
||||
std::unique_ptr<Server> server(builder.BuildAndStart());
|
||||
std::cout << "WorkerServer listening on " << server_address << std::endl;
|
||||
server->Wait();
|
||||
});
|
||||
}
|
||||
|
||||
Call* receive(const char* message_queue_name) {
|
||||
try {
|
||||
Call* Worker::receive_next_task() {
|
||||
const char* message_queue_name = worker_address_.c_str();
|
||||
try {
|
||||
message_queue::remove(message_queue_name);
|
||||
message_queue mq(create_only, message_queue_name, 1, sizeof(Call*));
|
||||
unsigned int priority;
|
||||
|
||||
+8
-9
@@ -27,7 +27,7 @@ using grpc::Channel;
|
||||
using grpc::ClientContext;
|
||||
using grpc::ClientWriter;
|
||||
|
||||
class WorkerServiceImpl final : public WorkerServer::Service {
|
||||
class WorkerServiceImpl final : public WorkerService::Service {
|
||||
public:
|
||||
WorkerServiceImpl(const std::string& worker_address)
|
||||
: worker_address_(worker_address) {}
|
||||
@@ -37,15 +37,11 @@ private:
|
||||
Call call_; // copy of the current call
|
||||
};
|
||||
|
||||
void start_worker_server(const char* worker_addr);
|
||||
|
||||
Call* receive(const char* worker_addr);
|
||||
|
||||
class Worker {
|
||||
public:
|
||||
Worker(const std::string& worker_address, std::shared_ptr<Channel> scheduler_channel, std::shared_ptr<Channel> objstore_channel)
|
||||
: worker_address_(worker_address),
|
||||
scheduler_stub_(SchedulerServer::NewStub(scheduler_channel)),
|
||||
scheduler_stub_(Scheduler::NewStub(scheduler_channel)),
|
||||
objstore_stub_(ObjStore::NewStub(objstore_channel))
|
||||
{}
|
||||
|
||||
@@ -59,12 +55,15 @@ class Worker {
|
||||
slice get_serialized_obj(ObjRef objref);
|
||||
// register function with scheduler
|
||||
void register_function(const std::string& name, size_t num_return_vals);
|
||||
// start the main loop
|
||||
Call* main_loop();
|
||||
// start the worker server which accepts tasks from the scheduler and stores
|
||||
// it in the message queue, which is read by the Python interpreter
|
||||
void start_worker_service();
|
||||
// wait for next task from the RPC system
|
||||
Call* receive_next_task();
|
||||
|
||||
private:
|
||||
const size_t CHUNK_SIZE = 8 * 1024;
|
||||
std::unique_ptr<SchedulerServer::Stub> scheduler_stub_;
|
||||
std::unique_ptr<Scheduler::Stub> scheduler_stub_;
|
||||
std::unique_ptr<ObjStore::Stub> objstore_stub_;
|
||||
std::thread worker_server_thread_;
|
||||
std::thread other_thread_;
|
||||
|
||||
+6
-4
@@ -13,6 +13,7 @@ from grpc.beta import implementations
|
||||
import orchestra_pb2
|
||||
import types_pb2
|
||||
|
||||
"""
|
||||
class UnisonTest(unittest.TestCase):
|
||||
|
||||
def testSerialize(self):
|
||||
@@ -35,6 +36,7 @@ class UnisonTest(unittest.TestCase):
|
||||
res = unison.serialize_args(a)
|
||||
b = unison.deserialize_args(res)
|
||||
self.assertEqual(a, b)
|
||||
"""
|
||||
|
||||
TIMEOUT_SECONDS = 5
|
||||
|
||||
@@ -44,7 +46,7 @@ def produce_data(num_chunks):
|
||||
|
||||
def connect_to_scheduler(host, port):
|
||||
channel = implementations.insecure_channel(host, port)
|
||||
return orchestra_pb2.beta_create_SchedulerServer_stub(channel)
|
||||
return orchestra_pb2.beta_create_Scheduler_stub(channel)
|
||||
|
||||
def connect_to_objstore(host, port):
|
||||
channel = implementations.insecure_channel(host, port)
|
||||
@@ -72,7 +74,7 @@ class ObjStoreTest(unittest.TestCase):
|
||||
|
||||
for i in range(1, 100):
|
||||
l = i * 100 * "h"
|
||||
objref = worker1.do_push(l)
|
||||
objref = worker1.push(l)
|
||||
response = objstore1_stub.DeliverObj(orchestra_pb2.DeliverObjRequest(objref=objref, objstore_address="0.0.0.0:22223"), TIMEOUT_SECONDS)
|
||||
s = worker2.get_serialized(objref)
|
||||
result = worker.unison.deserialize_from_string(s)
|
||||
@@ -100,8 +102,8 @@ class SchedulerTest(unittest.TestCase):
|
||||
|
||||
time.sleep(0.2)
|
||||
|
||||
w.register_function("hello_world", 2)
|
||||
w2.register_function("hello_world", 2)
|
||||
w.register_function("hello_world", None, 2)
|
||||
w2.register_function("hello_world", None, 2)
|
||||
|
||||
time.sleep(0.1)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user