mirror of
https://github.com/wassname/ray.git
synced 2026-07-01 11:10:02 +08:00
315 lines
12 KiB
C++
315 lines
12 KiB
C++
#include "worker.h"
|
|
|
|
#include "utils.h"
|
|
|
|
#include <pynumbuf/serialize.h>
|
|
|
|
extern "C" {
|
|
static PyObject *RayError;
|
|
}
|
|
|
|
// gRPC handler: stores a copy of the requested task in task_, logs its name,
// and sends a pointer to that stored copy through send_queue_.
// NOTE(review): the queue carries the address of the member task_, so a second
// request arriving before the first is consumed would overwrite it -- confirm
// that only one task is ever in flight per worker.
Status WorkerServiceImpl::ExecuteTask(ServerContext* context, const ExecuteTaskRequest* request, ExecuteTaskReply* reply) {
  const auto& incoming = request->task();
  task_ = incoming;  // Copy task
  RAY_LOG(RAY_INFO, "invoked task " << incoming.name());
  Task* queued_task = &task_;
  send_queue_.send(&queued_task);
  return Status::OK;
}
|
|
|
|
// Builds the scheduler and object store stubs from the supplied channels,
// connects receive_queue_ under the worker's own address, and marks the
// worker as connected.
Worker::Worker(const std::string& worker_address,
               std::shared_ptr<Channel> scheduler_channel,
               std::shared_ptr<Channel> objstore_channel)
    : worker_address_(worker_address),
      scheduler_stub_(Scheduler::NewStub(scheduler_channel)),
      objstore_stub_(ObjStore::NewStub(objstore_channel)) {
  receive_queue_.connect(worker_address_, true);
  connected_ = true;
}
|
|
|
|
// Forwards the request to the scheduler's SubmitTask RPC and returns the reply
// message it filled in.
// NOTE(review): the Status returned by the RPC is not inspected, so a failed
// call yields whatever the reply holds at that point.
SubmitTaskReply Worker::submit_task(SubmitTaskRequest* request) {
  if (!connected_) {
    RAY_LOG(RAY_FATAL, "Attempting to perform submit_task, but connected_ = " << connected_ << ".");
  }
  SubmitTaskReply scheduler_reply;
  ClientContext rpc_context;
  scheduler_stub_->SubmitTask(&rpc_context, *request, &scheduler_reply);
  return scheduler_reply;
}
|
|
|
|
void Worker::register_worker(const std::string& worker_address, const std::string& objstore_address) {
|
|
RegisterWorkerRequest request;
|
|
request.set_worker_address(worker_address);
|
|
request.set_objstore_address(objstore_address);
|
|
RegisterWorkerReply reply;
|
|
ClientContext context;
|
|
Status status = scheduler_stub_->RegisterWorker(&context, request, &reply);
|
|
workerid_ = reply.workerid();
|
|
objstoreid_ = reply.objstoreid();
|
|
segmentpool_ = std::make_shared<MemorySegmentPool>(objstoreid_, false);
|
|
request_obj_queue_.connect(std::string("queue:") + objstore_address + std::string(":obj"), false);
|
|
std::string queue_name = std::string("queue:") + objstore_address + std::string(":worker:") + std::to_string(workerid_) + std::string(":obj");
|
|
receive_obj_queue_.connect(queue_name, true);
|
|
return;
|
|
}
|
|
|
|
// Sends a RequestObj RPC to the scheduler carrying this worker's id and the
// given object reference. The RPC's Status and reply are not examined.
void Worker::request_object(ObjRef objref) {
  if (!connected_) {
    RAY_LOG(RAY_FATAL, "Attempting to perform request_object, but connected_ = " << connected_ << ".");
  }
  RequestObjRequest obj_request;
  obj_request.set_workerid(workerid_);
  obj_request.set_objref(objref);

  AckReply ack;
  ClientContext rpc_context;
  scheduler_stub_->RequestObj(&rpc_context, obj_request, &ack);
}
|
|
|
|
// Obtains an object reference for a new object by issuing an empty PushObj
// request to the scheduler and returning the objref from its reply. The RPC's
// Status is not examined.
ObjRef Worker::get_objref() {
  if (!connected_) {
    RAY_LOG(RAY_FATAL, "Attempting to perform get_objref, but connected_ = " << connected_ << ".");
  }
  PushObjRequest empty_request;
  PushObjReply scheduler_reply;
  ClientContext rpc_context;
  scheduler_stub_->PushObj(&rpc_context, empty_request, &scheduler_reply);
  return scheduler_reply.objref();
}
|
|
|
|
// Sends a GET request for objref on request_obj_queue_, waits for the handle
// on receive_obj_queue_, and returns a slice holding the address the segment
// pool resolves for that handle together with the handle's size.
// get_object assumes that objref is a canonical objref.
slice Worker::get_object(ObjRef objref) {
  if (!connected_) {
    RAY_LOG(RAY_FATAL, "Attempting to perform get_object, but connected_ = " << connected_ << ".");
  }
  ObjRequest get_request;
  get_request.workerid = workerid_;
  get_request.type = ObjRequestType::GET;
  get_request.objref = objref;
  request_obj_queue_.send(&get_request);

  ObjHandle handle;
  receive_obj_queue_.receive(&handle);

  slice object_view;
  object_view.data = segmentpool_->get_address(handle);
  object_view.len = handle.size();
  return object_view;
}
|
|
|
|
// TODO(pcm): More error handling
|
|
// contained_objrefs is a vector of all the objrefs contained in obj
|
|
void Worker::put_object(ObjRef objref, const Obj* obj, std::vector<ObjRef> &contained_objrefs) {
|
|
if (!connected_) {
|
|
RAY_LOG(RAY_FATAL, "Attempting to perform put_object, but connected_ = " << connected_ << ".");
|
|
}
|
|
std::string data;
|
|
obj->SerializeToString(&data); // TODO(pcm): get rid of this serialization
|
|
ObjRequest request;
|
|
request.workerid = workerid_;
|
|
request.type = ObjRequestType::ALLOC;
|
|
request.objref = objref;
|
|
request.size = data.size();
|
|
request_obj_queue_.send(&request);
|
|
if (contained_objrefs.size() > 0) {
|
|
RAY_LOG(RAY_REFCOUNT, "In put_object, calling increment_reference_count for contained objrefs");
|
|
increment_reference_count(contained_objrefs); // Notify the scheduler that some object references are serialized in the objstore.
|
|
}
|
|
ObjHandle result;
|
|
receive_obj_queue_.receive(&result);
|
|
uint8_t* target = segmentpool_->get_address(result);
|
|
std::memcpy(target, &data[0], data.size());
|
|
request.type = ObjRequestType::WORKER_DONE;
|
|
request.metadata_offset = 0;
|
|
request_obj_queue_.send(&request);
|
|
|
|
// Notify the scheduler about the objrefs that we are serializing in the objstore.
|
|
AddContainedObjRefsRequest contained_objrefs_request;
|
|
contained_objrefs_request.set_objref(objref);
|
|
for (int i = 0; i < contained_objrefs.size(); ++i) {
|
|
contained_objrefs_request.add_contained_objref(contained_objrefs[i]); // TODO(rkn): The naming here is bad
|
|
}
|
|
AckReply reply;
|
|
ClientContext context;
|
|
scheduler_stub_->AddContainedObjRefs(&context, contained_objrefs_request, &reply);
|
|
}
|
|
|
|
// Evaluates the arrow::Status expression `s` once. If it is not ok, sets a
// Python RayError whose message is `msg` followed by the status text, and
// returns NULL from the enclosing function (which must therefore return a
// pointer). Invoke it as a statement with a trailing semicolon; the macro
// itself does not end in one, so it is safe inside an unbraced if/else.
#define CHECK_ARROW_STATUS(s, msg)                                \
  do {                                                            \
    arrow::Status _s = (s);                                       \
    if (!_s.ok()) {                                               \
      std::string _errmsg = std::string(msg) + _s.ToString();     \
      PyErr_SetString(RayError, _errmsg.c_str());                 \
      return NULL;                                                \
    }                                                             \
  } while (0)
|
|
|
|
// Writes a Python value into the object store through pynumbuf: assembles the
// payload, requests a buffer of the reported total size (ALLOC), writes the
// payload at the address resolved for the returned handle, and signals
// WORKER_DONE with the metadata offset the writer reported.
// Returns None on success; on an Arrow error a Python exception is set and
// NULL is returned (see CHECK_ARROW_STATUS).
// NOTE(review): if Write fails, the function returns after the ALLOC request
// was sent but without sending WORKER_DONE -- confirm the object store
// tolerates that.
PyObject* Worker::put_arrow(ObjRef objref, PyObject* value) {
  if (!connected_) {
    RAY_LOG(RAY_FATAL, "Attempting to perform put_arrow, but connected_ = " << connected_ << ".");
  }
  ObjRequest request;
  pynumbuf::PythonObjectWriter writer;
  int64_t total_size = 0;
  CHECK_ARROW_STATUS(writer.AssemblePayload(value), "error during AssemblePayload: ");
  CHECK_ARROW_STATUS(writer.GetTotalSize(&total_size), "error during GetTotalSize: ");

  request.workerid = workerid_;
  request.type = ObjRequestType::ALLOC;
  request.objref = objref;
  request.size = total_size;
  request_obj_queue_.send(&request);

  ObjHandle handle;
  receive_obj_queue_.receive(&handle);

  int64_t metadata_offset = 0;
  uint8_t* buffer = segmentpool_->get_address(handle);
  auto sink = std::make_shared<BufferMemorySource>(buffer, total_size);
  CHECK_ARROW_STATUS(writer.Write(sink.get(), &metadata_offset), "error during Write: ");

  // Reuse the request struct for the completion message.
  request.type = ObjRequestType::WORKER_DONE;
  request.metadata_offset = metadata_offset;
  request_obj_queue_.send(&request);
  Py_RETURN_NONE;
}
|
|
|
|
// Fetches the object behind objref (GET request) and deserializes it with
// pynumbuf, using the handle's size and metadata offset.
// Returns the resulting Python object; on an Arrow error a Python exception is
// set and NULL is returned (see CHECK_ARROW_STATUS).
PyObject* Worker::get_arrow(ObjRef objref) {
  if (!connected_) {
    RAY_LOG(RAY_FATAL, "Attempting to perform get_arrow, but connected_ = " << connected_ << ".");
  }
  ObjRequest get_request;
  get_request.workerid = workerid_;
  get_request.type = ObjRequestType::GET;
  get_request.objref = objref;
  request_obj_queue_.send(&get_request);

  ObjHandle handle;
  receive_obj_queue_.receive(&handle);

  uint8_t* buffer = segmentpool_->get_address(handle);
  auto reader_source = std::make_shared<BufferMemorySource>(buffer, handle.size());
  PyObject* py_value;
  CHECK_ARROW_STATUS(pynumbuf::ReadPythonObjectFrom(reader_source.get(), handle.metadata_offset(), &py_value), "error during ReadPythonObjectFrom: ");
  return py_value;
}
|
|
|
|
// Issues a GET request for objref and reports whether the returned handle has
// a nonzero metadata offset. In this file put_object stores offset 0, while
// put_arrow stores the offset reported by the pynumbuf writer.
bool Worker::is_arrow(ObjRef objref) {
  if (!connected_) {
    RAY_LOG(RAY_FATAL, "Attempting to perform is_arrow, but connected_ = " << connected_ << ".");
  }
  ObjRequest get_request;
  get_request.workerid = workerid_;
  get_request.type = ObjRequestType::GET;
  get_request.objref = objref;
  request_obj_queue_.send(&get_request);

  ObjHandle handle;
  receive_obj_queue_.receive(&handle);
  const bool has_metadata = handle.metadata_offset() != 0;
  return has_metadata;
}
|
|
|
|
// Sends an AliasObjRefs RPC to the scheduler naming alias_objref as the alias
// and target_objref as the target. The RPC's result is not examined.
void Worker::alias_objrefs(ObjRef alias_objref, ObjRef target_objref) {
  if (!connected_) {
    RAY_LOG(RAY_FATAL, "Attempting to perform alias_objrefs, but connected_ = " << connected_ << ".");
  }
  ClientContext rpc_context;
  AliasObjRefsRequest alias_request;
  alias_request.set_alias_objref(alias_objref);
  alias_request.set_target_objref(target_objref);

  AckReply ack;
  scheduler_stub_->AliasObjRefs(&rpc_context, alias_request, &ack);
}
|
|
|
|
void Worker::increment_reference_count(std::vector<ObjRef> &objrefs) {
|
|
if (!connected_) {
|
|
RAY_LOG(RAY_DEBUG, "Attempting to increment_reference_count for objrefs, but connected_ = " << connected_ << " so returning instead.");
|
|
return;
|
|
}
|
|
if (objrefs.size() > 0) {
|
|
ClientContext context;
|
|
IncrementRefCountRequest request;
|
|
for (int i = 0; i < objrefs.size(); ++i) {
|
|
RAY_LOG(RAY_REFCOUNT, "Incrementing reference count for objref " << objrefs[i]);
|
|
request.add_objref(objrefs[i]);
|
|
}
|
|
AckReply reply;
|
|
scheduler_stub_->IncrementRefCount(&context, request, &reply);
|
|
}
|
|
}
|
|
|
|
void Worker::decrement_reference_count(std::vector<ObjRef> &objrefs) {
|
|
if (!connected_) {
|
|
RAY_LOG(RAY_DEBUG, "Attempting to decrement_reference_count, but connected_ = " << connected_ << " so returning instead.");
|
|
return;
|
|
}
|
|
if (objrefs.size() > 0) {
|
|
ClientContext context;
|
|
DecrementRefCountRequest request;
|
|
for (int i = 0; i < objrefs.size(); ++i) {
|
|
RAY_LOG(RAY_REFCOUNT, "Decrementing reference count for objref " << objrefs[i]);
|
|
request.add_objref(objrefs[i]);
|
|
}
|
|
AckReply reply;
|
|
scheduler_stub_->DecrementRefCount(&context, request, &reply);
|
|
}
|
|
}
|
|
|
|
void Worker::register_function(const std::string& name, size_t num_return_vals) {
|
|
if (!connected_) {
|
|
RAY_LOG(RAY_FATAL, "Attempting to perform register_function, but connected_ = " << connected_ << ".");
|
|
}
|
|
ClientContext context;
|
|
RegisterFunctionRequest request;
|
|
request.set_fnname(name);
|
|
request.set_num_return_vals(num_return_vals);
|
|
request.set_workerid(workerid_);
|
|
AckReply reply;
|
|
scheduler_stub_->RegisterFunction(&context, request, &reply);
|
|
}
|
|
|
|
// Receives one Task pointer from receive_queue_ and returns it.
// NOTE(review): presumably this is the pointer ExecuteTask enqueued, in which
// case the caller does not own it -- confirm against the queue wiring.
Task* Worker::receive_next_task() {
  Task* next_task;
  receive_queue_.receive(&next_task);
  return next_task;
}
|
|
|
|
void Worker::notify_task_completed(bool task_succeeded, std::string error_message) {
|
|
if (!connected_) {
|
|
RAY_LOG(RAY_FATAL, "Attempting to perform notify_task_completed, but connected_ = " << connected_ << ".");
|
|
}
|
|
ClientContext context;
|
|
NotifyTaskCompletedRequest request;
|
|
request.set_workerid(workerid_);
|
|
request.set_task_succeeded(task_succeeded);
|
|
request.set_error_message(error_message);
|
|
AckReply reply;
|
|
scheduler_stub_->NotifyTaskCompleted(&context, request, &reply);
|
|
}
|
|
|
|
// Clears the connected_ flag. Nothing else is torn down here; the other
// methods in this file check the flag before talking to the scheduler.
void Worker::disconnect() {
  connected_ = false;
}
|
|
|
|
bool Worker::connected() {
|
|
return connected_;
|
|
}
|
|
|
|
// TODO(rkn): Should we be using pointers or references? And should they be const?
|
|
void Worker::scheduler_info(ClientContext &context, SchedulerInfoRequest &request, SchedulerInfoReply &reply) {
|
|
if (!connected_) {
|
|
RAY_LOG(RAY_FATAL, "Attempting to get scheduler info, but connected_ = " << connected_ << ".");
|
|
}
|
|
scheduler_stub_->SchedulerInfo(&context, request, &reply);
|
|
}
|
|
|
|
// Communication between the WorkerServer and the Worker happens via a message
|
|
// queue. This is because the Python interpreter needs to be single threaded
|
|
// (in our case running in the main thread), whereas the WorkerService will
|
|
// run in a separate thread and potentially utilize multiple threads.
|
|
void Worker::start_worker_service() {
|
|
const char* service_addr = worker_address_.c_str();
|
|
worker_server_thread_ = std::thread([service_addr]() {
|
|
std::string service_address(service_addr);
|
|
std::string::iterator split_point = split_ip_address(service_address);
|
|
std::string port;
|
|
port.assign(split_point, service_address.end());
|
|
WorkerServiceImpl service(service_address);
|
|
ServerBuilder builder;
|
|
builder.AddListeningPort(std::string("0.0.0.0:") + port, grpc::InsecureServerCredentials());
|
|
builder.RegisterService(&service);
|
|
std::unique_ptr<Server> server(builder.BuildAndStart());
|
|
RAY_LOG(RAY_INFO, "worker server listening on " << service_address);
|
|
server->Wait();
|
|
});
|
|
}
|