mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 21:53:18 +08:00
133 lines
4.6 KiB
C++
133 lines
4.6 KiB
C++
# include "worker.h"
|
|
|
|
// gRPC handler for InvokeCall.
//
// Stores a copy of the call carried by `request` in the member `call_`, logs
// the call's name, and pushes a pointer to that stored copy onto send_queue_.
// `context` and `reply` are not touched. Always returns Status::OK.
Status WorkerServiceImpl::InvokeCall(ServerContext* context, const InvokeCallRequest* request, InvokeCallReply* reply) {
  const auto& incoming = request->call();
  call_ = incoming;  // keep our own copy; only a pointer to it goes on the queue
  ORCH_LOG(ORCH_INFO, "invoked task " << incoming.name());

  Call* queued = &call_;
  send_queue_.send(&queued);
  return Status::OK;
}
|
|
|
|
Worker::Worker(const std::string& worker_address, std::shared_ptr<Channel> scheduler_channel, std::shared_ptr<Channel> objstore_channel)
|
|
: worker_address_(worker_address),
|
|
scheduler_stub_(Scheduler::NewStub(scheduler_channel)),
|
|
objstore_stub_(ObjStore::NewStub(objstore_channel)) {
|
|
receive_queue_.connect(worker_address_, true);
|
|
}
|
|
|
|
// Forwards `request` to the scheduler through the RemoteCall RPC and returns
// the scheduler's reply by value.
//
// A failed RPC is logged; the reply object is still returned in whatever
// state the RPC left it, so the return value alone does not signal failure.
RemoteCallReply Worker::remote_call(RemoteCallRequest* request) {
  RemoteCallReply reply;
  ClientContext context;
  const Status status = scheduler_stub_->RemoteCall(&context, *request, &reply);
  if (!status.ok()) {
    // Surface the failure instead of dropping the status on the floor.
    ORCH_LOG(ORCH_INFO, "RemoteCall RPC failed: " << status.error_message());
  }
  return reply;
}
|
|
|
|
void Worker::register_worker(const std::string& worker_address, const std::string& objstore_address) {
|
|
RegisterWorkerRequest request;
|
|
request.set_worker_address(worker_address);
|
|
request.set_objstore_address(objstore_address);
|
|
RegisterWorkerReply reply;
|
|
ClientContext context;
|
|
Status status = scheduler_stub_->RegisterWorker(&context, request, &reply);
|
|
workerid_ = reply.workerid();
|
|
request_obj_queue_.connect(std::string("queue:") + objstore_address + std::string(":obj"), false);
|
|
std::string queue_name = std::string("queue:") + objstore_address + std::string(":worker:") + std::to_string(workerid_) + std::string(":obj");
|
|
receive_obj_queue_.connect(queue_name, true);
|
|
return;
|
|
}
|
|
|
|
// Asks the scheduler (PullObj RPC) on behalf of this worker for `objref`,
// then fetches the object through get_object() and returns the resulting
// slice.
//
// A failed RPC is logged; get_object() is still called afterwards.
slice Worker::pull_object(ObjRef objref) {
  PullObjRequest request;
  request.set_workerid(workerid_);
  request.set_objref(objref);

  AckReply reply;
  ClientContext context;
  const Status status = scheduler_stub_->PullObj(&context, request, &reply);
  if (!status.ok()) {
    // Record the failure so a subsequent problem in get_object() can be
    // traced back to the rejected pull request.
    ORCH_LOG(ORCH_INFO, "PullObj RPC failed: " << status.error_message());
  }
  return get_object(objref);
}
|
|
|
|
// Pushes `obj` into the object store and returns the reference assigned to it.
//
// First obtains an objref for the new object from the scheduler (PushObj RPC
// with a default-constructed request), then hands the object to put_object()
// under that reference.
//
// A failed RPC is logged; the objref found in the reply object is still used.
ObjRef Worker::push_object(const Obj* obj) {
  // first get objref for the new object
  PushObjRequest push_request;
  PushObjReply push_reply;
  ClientContext push_context;
  const Status push_status = scheduler_stub_->PushObj(&push_context, push_request, &push_reply);
  if (!push_status.ok()) {
    // Make it visible that the objref below did not come from a successful call.
    ORCH_LOG(ORCH_INFO, "PushObj RPC failed: " << push_status.error_message());
  }
  const ObjRef objref = push_reply.objref();

  // then stream the object to the object store
  put_object(objref, obj);
  return objref;
}
|
|
|
|
// Fetches the object identified by `objref` from the object store.
//
// Sends a GET request (tagged with this worker's id) on request_obj_queue_,
// reads the resulting ObjHandle from receive_obj_queue_, and returns a slice
// whose `data` is the address segmentpool_ reports for that handle and whose
// `len` is the handle's size.
slice Worker::get_object(ObjRef objref) {
  ObjRequest get_request;
  get_request.workerid = workerid_;
  get_request.type = ObjRequestType::GET;
  get_request.objref = objref;
  request_obj_queue_.send(&get_request);

  ObjHandle handle;
  receive_obj_queue_.receive(&handle);

  slice view;
  view.data = segmentpool_.get_address(handle);
  view.len = handle.size();
  return view;
}
|
|
|
|
// TODO(pcm): More error handling
|
|
void Worker::put_object(ObjRef objref, const Obj* obj) {
|
|
std::string data;
|
|
obj->SerializeToString(&data); // TODO(pcm): get rid of this serialization
|
|
ObjRequest request;
|
|
request.workerid = workerid_;
|
|
request.type = ObjRequestType::ALLOC;
|
|
request.objref = objref;
|
|
request.size = data.size();
|
|
request_obj_queue_.send(&request);
|
|
ObjHandle result;
|
|
receive_obj_queue_.receive(&result);
|
|
char* target = segmentpool_.get_address(result);
|
|
std::memcpy(target, &data[0], data.size());
|
|
request.type = ObjRequestType::DONE;
|
|
request_obj_queue_.send(&request);
|
|
}
|
|
|
|
void Worker::register_function(const std::string& name, size_t num_return_vals) {
|
|
ClientContext context;
|
|
RegisterFunctionRequest request;
|
|
request.set_fnname(name);
|
|
request.set_num_return_vals(num_return_vals);
|
|
request.set_workerid(workerid_);
|
|
AckReply reply;
|
|
scheduler_stub_->RegisterFunction(&context, request, &reply);
|
|
}
|
|
|
|
// Reads the next Call pointer from receive_queue_ and returns it.
Call* Worker::receive_next_task() {
  Call* next_task;
  receive_queue_.receive(&next_task);
  return next_task;
}
|
|
|
|
void Worker::notify_task_completed() {
|
|
ClientContext context;
|
|
WorkerReadyRequest request;
|
|
request.set_workerid(workerid_);
|
|
AckReply reply;
|
|
scheduler_stub_->WorkerReady(&context, request, &reply);
|
|
}
|
|
|
|
// Communication between the WorkerServer and the Worker happens via a message
|
|
// queue. This is because the Python interpreter needs to be single threaded
|
|
// (in our case running in the main thread), whereas the WorkerService will
|
|
// run in a separate thread and potentially utilize multiple threads.
|
|
void Worker::start_worker_service() {
|
|
const char* server_address = worker_address_.c_str();
|
|
worker_server_thread_ = std::thread([server_address]() {
|
|
WorkerServiceImpl service(server_address);
|
|
ServerBuilder builder;
|
|
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
|
|
builder.RegisterService(&service);
|
|
std::unique_ptr<Server> server(builder.BuildAndStart());
|
|
ORCH_LOG(ORCH_INFO, "worker server listening on " << server_address);
|
|
server->Wait();
|
|
});
|
|
}
|