mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 20:22:39 +08:00
Merge pull request #36 from amplab/arrow-integration
arrow integration for ndarrays
This commit is contained in:
+3
-1
@@ -77,7 +77,9 @@ if (UNIX AND NOT APPLE)
|
||||
endif()
|
||||
|
||||
add_executable(objstore src/objstore.cc src/ipc.cc ${GENERATED_PROTOBUF_FILES})
|
||||
target_link_libraries(objstore arrow)
|
||||
add_executable(scheduler src/scheduler.cc ${GENERATED_PROTOBUF_FILES})
|
||||
add_library(orchpylib SHARED src/orchpylib.cc src/worker.cc src/ipc.cc ${GENERATED_PROTOBUF_FILES})
|
||||
add_library(orchpylib SHARED src/orchpylib.cc src/worker.cc src/ipc.cc src/serialize.cc ${GENERATED_PROTOBUF_FILES})
|
||||
target_link_libraries(orchpylib arrow)
|
||||
|
||||
install(TARGETS objstore scheduler orchpylib DESTINATION ${CMAKE_SOURCE_DIR}/lib/orchpy/orchpy)
|
||||
|
||||
@@ -55,7 +55,7 @@ public:
|
||||
};
|
||||
|
||||
struct slice {
|
||||
char* data;
|
||||
uint8_t* data;
|
||||
size_t len;
|
||||
};
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from types import ModuleType
|
||||
import typing
|
||||
import numpy as np
|
||||
|
||||
import orchpy
|
||||
import serialization
|
||||
@@ -14,13 +15,19 @@ class Worker(object):
|
||||
|
||||
def put_object(self, objref, value):
|
||||
"""Put `value` in the local object store with objref `objref`. This assumes that the value for `objref` has not yet been placed in the local object store."""
|
||||
object_capsule = serialization.serialize(value)
|
||||
orchpy.lib.put_object(self.handle, objref, object_capsule)
|
||||
if type(value) == np.ndarray:
|
||||
orchpy.lib.put_arrow(self.handle, objref, value)
|
||||
else:
|
||||
object_capsule = serialization.serialize(value)
|
||||
orchpy.lib.put_object(self.handle, objref, object_capsule)
|
||||
|
||||
def get_object(self, objref):
|
||||
"""Return the value from the local object store for objref `objref`. This will block until the value for `objref` has been written to the local object store."""
|
||||
object_capsule = orchpy.lib.get_object(self.handle, objref)
|
||||
return serialization.deserialize(object_capsule)
|
||||
if orchpy.lib.is_arrow(self.handle, objref):
|
||||
return orchpy.lib.get_arrow(self.handle, objref)
|
||||
else:
|
||||
object_capsule = orchpy.lib.get_object(self.handle, objref)
|
||||
return serialization.deserialize(object_capsule)
|
||||
|
||||
def register_function(self, function):
|
||||
"""Notify the scheduler that this worker can execute the function with name `func_name`. Store the function `function` locally."""
|
||||
@@ -53,12 +60,13 @@ def connect(scheduler_addr, objstore_addr, worker_addr, worker=global_worker):
|
||||
worker.connected = True
|
||||
|
||||
def pull(objref, worker=global_worker):
|
||||
object_capsule = orchpy.lib.pull_object(worker.handle, objref)
|
||||
return serialization.deserialize(object_capsule)
|
||||
orchpy.lib.request_object(worker.handle, objref)
|
||||
return worker.get_object(objref)
|
||||
|
||||
def push(value, worker=global_worker):
|
||||
object_capsule = serialization.serialize(value)
|
||||
return orchpy.lib.push_object(worker.handle, object_capsule)
|
||||
objref = orchpy.lib.get_objref(worker.handle)
|
||||
worker.put_object(objref, value)
|
||||
return objref
|
||||
|
||||
def main_loop(worker=global_worker):
|
||||
if not worker.connected:
|
||||
|
||||
@@ -32,7 +32,7 @@ service Scheduler {
|
||||
// Request an object reference for an object that will be pushed to an object store
|
||||
rpc PushObj(PushObjRequest) returns (PushObjReply);
|
||||
// Request delivery of an object from an object store that holds the object to the local object store
|
||||
rpc PullObj(PullObjRequest) returns (AckReply);
|
||||
rpc RequestObj(RequestObjRequest) returns (AckReply);
|
||||
// Used by an object store to tell the scheduler that an object is ready (i.e. has been finalized and can be shared)
|
||||
rpc ObjReady(ObjReadyRequest) returns (AckReply);
|
||||
// Used by the worker to report back and ask for more work
|
||||
@@ -75,9 +75,9 @@ message RemoteCallReply {
|
||||
repeated uint64 result = 1; // Object references of the function return values
|
||||
}
|
||||
|
||||
message PullObjRequest {
|
||||
uint64 workerid = 1; // Worker that tries to pull the object
|
||||
uint64 objref = 2; // Object reference of the object being pulled
|
||||
message RequestObjRequest {
|
||||
uint64 workerid = 1; // Worker that tries to request the object
|
||||
uint64 objref = 2; // Object reference of the object being requested
|
||||
}
|
||||
|
||||
message PushObjRequest {
|
||||
|
||||
+26
-4
@@ -1,9 +1,31 @@
|
||||
#include "ipc.h"
|
||||
|
||||
ObjHandle::ObjHandle(SegmentId segmentid, size_t size, IpcPointer ipcpointer)
|
||||
: segmentid_(segmentid), size_(size), ipcpointer_(ipcpointer)
|
||||
using namespace arrow;
|
||||
|
||||
ObjHandle::ObjHandle(SegmentId segmentid, size_t size, IpcPointer ipcpointer, size_t metadata_offset)
|
||||
: segmentid_(segmentid), size_(size), ipcpointer_(ipcpointer), metadata_offset_(metadata_offset)
|
||||
{}
|
||||
|
||||
Status BufferMemorySource::Write(int64_t position, const uint8_t* data, int64_t nbytes) {
|
||||
// TODO(pcm): error handling
|
||||
std::memcpy(data_ + position, data, nbytes);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status BufferMemorySource::ReadAt(int64_t position, int64_t nbytes, std::shared_ptr<Buffer>* out) {
|
||||
// TODO(pcm): error handling
|
||||
*out = std::make_shared<Buffer>(data_ + position, nbytes);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status BufferMemorySource::Close() {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
int64_t BufferMemorySource::Size() const {
|
||||
return size_;
|
||||
}
|
||||
|
||||
MemorySegmentPool::MemorySegmentPool(bool create) : create_mode_(create) { }
|
||||
|
||||
// creates a memory segment if it is not already there; if the pool is in create mode,
|
||||
@@ -37,14 +59,14 @@ ObjHandle MemorySegmentPool::allocate(size_t size) {
|
||||
|
||||
// returns address of the object refered to by the handle, needs to be called on
|
||||
// the process that will use the address
|
||||
char* MemorySegmentPool::get_address(ObjHandle pointer) {
|
||||
uint8_t* MemorySegmentPool::get_address(ObjHandle pointer) {
|
||||
if (pointer.segmentid() >= segments_.size()) {
|
||||
for (int i = segments_.size(); i <= pointer.segmentid(); ++i) {
|
||||
open_segment(i);
|
||||
}
|
||||
}
|
||||
managed_shared_memory* segment = segments_[pointer.segmentid()].get();
|
||||
return static_cast<char*>(segment->get_address_from_handle(pointer.ipcpointer()));
|
||||
return static_cast<uint8_t*>(segment->get_address_from_handle(pointer.ipcpointer()));
|
||||
}
|
||||
|
||||
MemorySegmentPool::~MemorySegmentPool() {
|
||||
|
||||
@@ -6,6 +6,9 @@
|
||||
#include <boost/interprocess/managed_shared_memory.hpp>
|
||||
#include <boost/interprocess/ipc/message_queue.hpp>
|
||||
|
||||
#include <arrow/api.h>
|
||||
#include <arrow/ipc/memory.h>
|
||||
|
||||
#include "orchestra/orchestra.h"
|
||||
|
||||
using namespace boost::interprocess;
|
||||
@@ -91,6 +94,7 @@ struct ObjRequest {
|
||||
ObjRequestType type; // do we want to allocate a new object or get a handle?
|
||||
ObjRef objref; // object reference of the object to be returned/allocated
|
||||
int64_t size; // if allocate, that's the size of the object
|
||||
int64_t metadata_offset; // if sending 'DONE', that's the location of the metadata relative to the beginning of the object
|
||||
};
|
||||
|
||||
typedef size_t SegmentId; // index into a memory segment table
|
||||
@@ -101,14 +105,30 @@ typedef managed_shared_memory::handle_t IpcPointer;
|
||||
|
||||
class ObjHandle {
|
||||
public:
|
||||
ObjHandle(SegmentId segmentid = 0, size_t size = 0, IpcPointer ipcpointer = IpcPointer());
|
||||
ObjHandle(SegmentId segmentid = 0, size_t size = 0, IpcPointer ipcpointer = IpcPointer(), size_t metadata_offset = 0);
|
||||
SegmentId segmentid() { return segmentid_; }
|
||||
size_t size() { return size_; }
|
||||
IpcPointer ipcpointer() { return ipcpointer_; }
|
||||
size_t metadata_offset() { return metadata_offset_; }
|
||||
void set_metadata_offset(size_t metadata_offset) {metadata_offset_ = metadata_offset; }
|
||||
private:
|
||||
SegmentId segmentid_;
|
||||
size_t size_;
|
||||
IpcPointer ipcpointer_;
|
||||
SegmentId segmentid_; // which shared memory file the object is stored in
|
||||
IpcPointer ipcpointer_; // pointer to the beginning of the object, exchangeable between processes
|
||||
size_t size_; // total size of the object
|
||||
size_t metadata_offset_; // offset of the metadata that describes this object
|
||||
};
|
||||
|
||||
class BufferMemorySource: public arrow::ipc::MemorySource {
|
||||
public:
|
||||
BufferMemorySource(uint8_t* data, int64_t capacity) : data_(data), capacity_(capacity), size_(0) {}
|
||||
virtual arrow::Status ReadAt(int64_t position, int64_t nbytes, std::shared_ptr<arrow::Buffer>* out);
|
||||
virtual arrow::Status Close();
|
||||
virtual arrow::Status Write(int64_t position, const uint8_t* data, int64_t nbytes);
|
||||
virtual int64_t Size() const;
|
||||
private:
|
||||
uint8_t* data_;
|
||||
int64_t capacity_;
|
||||
int64_t size_;
|
||||
};
|
||||
|
||||
// Memory segment pool: A collection of shared memory segments
|
||||
@@ -123,8 +143,8 @@ class MemorySegmentPool {
|
||||
public:
|
||||
MemorySegmentPool(bool create = false); // can be used in two modes: create mode and open mode (see above)
|
||||
~MemorySegmentPool();
|
||||
ObjHandle allocate(size_t nbytes); // allocate a new shared object, potentially creating a new segment (only run on object store)
|
||||
char* get_address(ObjHandle pointer); // get address of shared object
|
||||
ObjHandle allocate(size_t nbytes); // allocate memory, potentially creating a new segment (only run on object store)
|
||||
uint8_t* get_address(ObjHandle pointer); // get address of shared object
|
||||
private:
|
||||
void open_segment(SegmentId segmentid, size_t size = 0); // create a segment or map an existing one into memory
|
||||
bool create_mode_;
|
||||
|
||||
+2
-1
@@ -9,7 +9,7 @@ Status ObjStoreClient::upload_data_to(slice data, ObjRef objref, ObjStore::Stub&
|
||||
ClientContext context;
|
||||
AckReply reply;
|
||||
std::unique_ptr<ClientWriter<ObjChunk> > writer(stub.StreamObj(&context, &reply));
|
||||
const char* head = data.data;
|
||||
const uint8_t* head = data.data;
|
||||
for (size_t i = 0; i < data.len; i += CHUNK_SIZE) {
|
||||
chunk.set_objref(objref);
|
||||
chunk.set_totalsize(data.len);
|
||||
@@ -150,6 +150,7 @@ void ObjStoreService::process_requests() {
|
||||
break;
|
||||
case ObjRequestType::DONE: {
|
||||
std::pair<ObjHandle, bool>& item = memory_[request.objref];
|
||||
item.first.set_metadata_offset(request.metadata_offset);
|
||||
item.second = true;
|
||||
std::lock_guard<std::mutex> pull_queue_lock(pull_queue_lock_);
|
||||
for (size_t i = 0; i < pull_queue_.size(); ++i) {
|
||||
|
||||
+81
-12
@@ -4,12 +4,16 @@
|
||||
|
||||
#include <Python.h>
|
||||
#include <structmember.h>
|
||||
#define PY_ARRAY_UNIQUE_SYMBOL ORCHESTRA_ARRAY_API
|
||||
#include <numpy/arrayobject.h>
|
||||
#include <arrow/api.h>
|
||||
#include <iostream>
|
||||
|
||||
#include "types.pb.h"
|
||||
#include "worker.h"
|
||||
|
||||
#include "serialize.h"
|
||||
|
||||
extern "C" {
|
||||
|
||||
// Object references
|
||||
@@ -230,6 +234,13 @@ int serialize(PyObject* val, Obj* obj) {
|
||||
}
|
||||
}
|
||||
break;
|
||||
case NPY_INT64: {
|
||||
npy_int64* buffer = (npy_int64*) PyArray_DATA(array);
|
||||
for (npy_intp i = 0; i < size; ++i) {
|
||||
data->add_int_data(buffer[i]);
|
||||
}
|
||||
}
|
||||
break;
|
||||
case NPY_UINT8: {
|
||||
npy_uint8* buffer = (npy_uint8*) PyArray_DATA(array);
|
||||
for (npy_intp i = 0; i < size; ++i) {
|
||||
@@ -237,6 +248,13 @@ int serialize(PyObject* val, Obj* obj) {
|
||||
}
|
||||
}
|
||||
break;
|
||||
case NPY_UINT64: {
|
||||
npy_uint64* buffer = (npy_uint64*) PyArray_DATA(array);
|
||||
for (npy_intp i = 0; i < size; ++i) {
|
||||
data->add_uint_data(buffer[i]);
|
||||
}
|
||||
}
|
||||
break;
|
||||
case NPY_OBJECT: { // FIXME(pcm): Support arbitrary python objects, not only objrefs
|
||||
PyArrayIterObject* iter = (PyArrayIterObject*) PyArray_IterNew((PyObject*)array);
|
||||
while (PyArray_ITER_NOTDONE(iter)) {
|
||||
@@ -327,6 +345,13 @@ PyObject* deserialize(const Obj& obj) {
|
||||
}
|
||||
}
|
||||
break;
|
||||
case NPY_INT64: {
|
||||
npy_int64* buffer = (npy_int64*) PyArray_DATA(pyarray);
|
||||
for (npy_intp i = 0; i < size; ++i) {
|
||||
buffer[i] = array.int_data(i);
|
||||
}
|
||||
}
|
||||
break;
|
||||
default:
|
||||
PyErr_SetString(OrchPyError, "deserialization: internal error (array type not implemented)");
|
||||
return NULL;
|
||||
@@ -341,6 +366,13 @@ PyObject* deserialize(const Obj& obj) {
|
||||
}
|
||||
}
|
||||
break;
|
||||
case NPY_UINT64: {
|
||||
npy_uint64* buffer = (npy_uint64*) PyArray_DATA(pyarray);
|
||||
for (npy_intp i = 0; i < size; ++i) {
|
||||
buffer[i] = array.uint_data(i);
|
||||
}
|
||||
}
|
||||
break;
|
||||
default:
|
||||
PyErr_SetString(OrchPyError, "deserialization: internal error (array type not implemented)");
|
||||
return NULL;
|
||||
@@ -374,6 +406,43 @@ PyObject* serialize_object(PyObject* self, PyObject* args) {
|
||||
return PyCapsule_New(static_cast<void*>(obj), "obj", NULL);
|
||||
}
|
||||
|
||||
PyObject* put_arrow(PyObject* self, PyObject* args) {
|
||||
Worker* worker;
|
||||
ObjRef objref;
|
||||
PyObject* value;
|
||||
if (!PyArg_ParseTuple(args, "O&O&O", &PyObjectToWorker, &worker, &PyObjectToObjRef, &objref, &value)) {
|
||||
return NULL;
|
||||
}
|
||||
if (!PyArray_Check(value)) {
|
||||
PyErr_SetString(PyExc_TypeError, "only support arrays at this point");
|
||||
return NULL;
|
||||
}
|
||||
PyArrayObject* array = PyArray_GETCONTIGUOUS((PyArrayObject*) value);
|
||||
worker->put_arrow(objref, array);
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
PyObject* get_arrow(PyObject* self, PyObject* args) {
|
||||
Worker* worker;
|
||||
ObjRef objref;
|
||||
if (!PyArg_ParseTuple(args, "O&O&", &PyObjectToWorker, &worker, &PyObjectToObjRef, &objref)) {
|
||||
return NULL;
|
||||
}
|
||||
return (PyObject*) worker->get_arrow(objref);
|
||||
}
|
||||
|
||||
PyObject* is_arrow(PyObject* self, PyObject* args) {
|
||||
Worker* worker;
|
||||
ObjRef objref;
|
||||
if (!PyArg_ParseTuple(args, "O&O&", &PyObjectToWorker, &worker, &PyObjectToObjRef, &objref)) {
|
||||
return NULL;
|
||||
}
|
||||
if (worker->is_arrow(objref))
|
||||
Py_RETURN_TRUE;
|
||||
else
|
||||
Py_RETURN_FALSE;
|
||||
}
|
||||
|
||||
PyObject* deserialize_object(PyObject* self, PyObject* args) {
|
||||
Obj* obj;
|
||||
if (!PyArg_ParseTuple(args, "O&", &PyObjectToObj, &obj)) {
|
||||
@@ -496,13 +565,12 @@ PyObject* register_function(PyObject* self, PyObject* args) {
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
PyObject* push_object(PyObject* self, PyObject* args) {
|
||||
PyObject* get_objref(PyObject* self, PyObject* args) {
|
||||
Worker* worker;
|
||||
Obj* obj;
|
||||
if (!PyArg_ParseTuple(args, "O&O&", &PyObjectToWorker, &worker, &PyObjectToObj, &obj)) {
|
||||
if (!PyArg_ParseTuple(args, "O&", &PyObjectToWorker, &worker)) {
|
||||
return NULL;
|
||||
}
|
||||
ObjRef objref = worker->push_object(obj);
|
||||
ObjRef objref = worker->get_objref();
|
||||
return make_pyobjref(objref);
|
||||
}
|
||||
|
||||
@@ -525,20 +593,18 @@ PyObject* get_object(PyObject* self, PyObject* args) {
|
||||
}
|
||||
slice s = worker->get_object(objref);
|
||||
Obj* obj = new Obj(); // TODO: Make sure this will get deleted
|
||||
obj->ParseFromString(std::string(s.data, s.len));
|
||||
obj->ParseFromString(std::string(reinterpret_cast<char*>(s.data), s.len));
|
||||
return PyCapsule_New(static_cast<void*>(obj), "obj", NULL);
|
||||
}
|
||||
|
||||
PyObject* pull_object(PyObject* self, PyObject* args) {
|
||||
PyObject* request_object(PyObject* self, PyObject* args) {
|
||||
Worker* worker;
|
||||
ObjRef objref;
|
||||
if (!PyArg_ParseTuple(args, "O&O&", &PyObjectToWorker, &worker, &PyObjectToObjRef, &objref)) {
|
||||
return NULL;
|
||||
}
|
||||
slice s = worker->pull_object(objref);
|
||||
Obj* obj = new Obj(); // TODO: Make sure this will get deleted
|
||||
obj->ParseFromString(std::string(s.data, s.len));
|
||||
return PyCapsule_New(static_cast<void*>(obj), "obj", NULL);
|
||||
worker->request_object(objref);
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
PyObject* start_worker_service(PyObject* self, PyObject* args) {
|
||||
@@ -553,14 +619,17 @@ PyObject* start_worker_service(PyObject* self, PyObject* args) {
|
||||
static PyMethodDef OrchPyLibMethods[] = {
|
||||
{ "serialize_object", serialize_object, METH_VARARGS, "serialize an object to protocol buffers" },
|
||||
{ "deserialize_object", deserialize_object, METH_VARARGS, "deserialize an object from protocol buffers" },
|
||||
{ "put_arrow", put_arrow, METH_VARARGS, "put an arrow array on the local object store"},
|
||||
{ "get_arrow", get_arrow, METH_VARARGS, "get an arrow array from the local object store"},
|
||||
{ "is_arrow", is_arrow, METH_VARARGS, "is the object in the local object store an arrow object?"},
|
||||
{ "serialize_call", serialize_call, METH_VARARGS, "serialize a call to protocol buffers" },
|
||||
{ "deserialize_call", deserialize_call, METH_VARARGS, "deserialize a call from protocol buffers" },
|
||||
{ "create_worker", create_worker, METH_VARARGS, "connect to the scheduler and the object store" },
|
||||
{ "register_function", register_function, METH_VARARGS, "register a function with the scheduler" },
|
||||
{ "put_object", put_object, METH_VARARGS, "put a protocol buffer object (given as a capsule) on the local object store" },
|
||||
{ "get_object", get_object, METH_VARARGS, "get protocol buffer object from the local object store" },
|
||||
{ "push_object", push_object, METH_VARARGS, "push a protocol buffer object (given as a capsule) to the object store" },
|
||||
{ "pull_object" , pull_object, METH_VARARGS, "pull object with a given object id from the object store" },
|
||||
{ "get_objref", get_objref, METH_VARARGS, "register a new object reference with the scheduler" },
|
||||
{ "request_object" , request_object, METH_VARARGS, "request an object to be delivered to the local object store" },
|
||||
{ "wait_for_next_task", wait_for_next_task, METH_VARARGS, "get next task from scheduler (blocking)" },
|
||||
{ "remote_call", remote_call, METH_VARARGS, "call a remote function" },
|
||||
{ "notify_task_completed", notify_task_completed, METH_VARARGS, "notify the scheduler that a task has been completed" },
|
||||
|
||||
+4
-1
@@ -38,7 +38,7 @@ Status SchedulerService::PushObj(ServerContext* context, const PushObjRequest* r
|
||||
return Status::OK;
|
||||
}
|
||||
|
||||
Status SchedulerService::PullObj(ServerContext* context, const PullObjRequest* request, AckReply* reply) {
|
||||
Status SchedulerService::RequestObj(ServerContext* context, const RequestObjRequest* request, AckReply* reply) {
|
||||
objtable_lock_.lock();
|
||||
size_t size = objtable_.size();
|
||||
objtable_lock_.unlock();
|
||||
@@ -242,6 +242,9 @@ ObjRef SchedulerService::register_new_object() {
|
||||
|
||||
void SchedulerService::add_location(ObjRef objref, ObjStoreId objstoreid) {
|
||||
std::lock_guard<std::mutex> objtable_lock(objtable_lock_);
|
||||
if (objref >= objtable_.size()) {
|
||||
ORCH_LOG(ORCH_FATAL, "trying to put object on object store that was not registered with the scheduler");
|
||||
}
|
||||
// do a binary search
|
||||
auto pos = std::lower_bound(objtable_[objref].begin(), objtable_[objref].end(), objstoreid);
|
||||
if (pos == objtable_[objref].end() || objstoreid < *pos) {
|
||||
|
||||
+1
-1
@@ -39,7 +39,7 @@ class SchedulerService : public Scheduler::Service {
|
||||
public:
|
||||
Status RemoteCall(ServerContext* context, const RemoteCallRequest* request, RemoteCallReply* reply) override;
|
||||
Status PushObj(ServerContext* context, const PushObjRequest* request, PushObjReply* reply) override;
|
||||
Status PullObj(ServerContext* context, const PullObjRequest* request, AckReply* reply) override;
|
||||
Status RequestObj(ServerContext* context, const RequestObjRequest* request, AckReply* reply) override;
|
||||
Status RegisterObjStore(ServerContext* context, const RegisterObjStoreRequest* request, RegisterObjStoreReply* reply) override;
|
||||
Status RegisterWorker(ServerContext* context, const RegisterWorkerRequest* request, RegisterWorkerReply* reply) override;
|
||||
Status RegisterFunction(ServerContext* context, const RegisterFunctionRequest* request, AckReply* reply) override;
|
||||
|
||||
@@ -0,0 +1,201 @@
|
||||
#include "serialize.h"
|
||||
|
||||
using namespace arrow;
|
||||
|
||||
template <int TYPE>
|
||||
struct npy_traits {
|
||||
};
|
||||
|
||||
template <>
|
||||
struct npy_traits<NPY_BOOL> {
|
||||
typedef uint8_t value_type;
|
||||
static const std::shared_ptr<BooleanType> primitive_type;
|
||||
using ArrayType = arrow::BooleanArray;
|
||||
};
|
||||
|
||||
const std::shared_ptr<BooleanType> npy_traits<NPY_BOOL>::primitive_type = std::make_shared<BooleanType>();
|
||||
|
||||
#define NPY_INT_DECL(TYPE, CapType, T) \
|
||||
template <> \
|
||||
struct npy_traits<NPY_##TYPE> { \
|
||||
typedef T value_type; \
|
||||
static const std::shared_ptr<CapType##Type> primitive_type; \
|
||||
using ArrayType = arrow::CapType##Array; \
|
||||
}; \
|
||||
\
|
||||
const std::shared_ptr<CapType##Type> npy_traits<NPY_##TYPE>::primitive_type = std::make_shared<CapType##Type>();
|
||||
|
||||
NPY_INT_DECL(INT8, Int8, int8_t);
|
||||
NPY_INT_DECL(INT16, Int16, int16_t);
|
||||
NPY_INT_DECL(INT32, Int32, int32_t);
|
||||
NPY_INT_DECL(INT64, Int64, int64_t);
|
||||
NPY_INT_DECL(UINT8, UInt8, uint8_t);
|
||||
NPY_INT_DECL(UINT16, UInt16, uint16_t);
|
||||
NPY_INT_DECL(UINT32, UInt32, uint32_t);
|
||||
NPY_INT_DECL(UINT64, UInt64, uint64_t);
|
||||
|
||||
template <>
|
||||
struct npy_traits<NPY_FLOAT32> {
|
||||
typedef float value_type;
|
||||
static const std::shared_ptr<FloatType> primitive_type;
|
||||
using ArrayType = arrow::FloatArray;
|
||||
};
|
||||
|
||||
const std::shared_ptr<FloatType> npy_traits<NPY_FLOAT32>::primitive_type = std::make_shared<FloatType>();
|
||||
|
||||
template <>
|
||||
struct npy_traits<NPY_FLOAT64> {
|
||||
typedef double value_type;
|
||||
static const std::shared_ptr<DoubleType> primitive_type;
|
||||
using ArrayType = arrow::DoubleArray;
|
||||
};
|
||||
|
||||
const std::shared_ptr<DoubleType> npy_traits<NPY_FLOAT64>::primitive_type = std::make_shared<DoubleType>();
|
||||
|
||||
template <>
|
||||
struct npy_traits<NPY_OBJECT> {
|
||||
typedef PyObject* value_type;
|
||||
};
|
||||
|
||||
template<int NpyType>
|
||||
std::shared_ptr<arrow::RowBatch> make_flat_array(const std::string& fieldname, size_t size, std::shared_ptr<arrow::Buffer> data) {
|
||||
auto field = std::make_shared<arrow::Field>(fieldname, npy_traits<NpyType>::primitive_type);
|
||||
std::shared_ptr<arrow::Schema> schema(new arrow::Schema({field}));
|
||||
auto array = std::make_shared<typename npy_traits<NpyType>::ArrayType>(size, data);
|
||||
return std::shared_ptr<arrow::RowBatch>(new RowBatch(schema, size, {array}));
|
||||
}
|
||||
|
||||
const int64_t MAX_METADATA_SIZE = 5000;
|
||||
|
||||
#define SIZE_ARROW_CASE(TYPE) \
|
||||
case TYPE: \
|
||||
return size * sizeof(npy_traits<TYPE>::value_type) + MAX_METADATA_SIZE;
|
||||
|
||||
size_t arrow_size(PyArrayObject* array) {
|
||||
npy_intp size = PyArray_SIZE(array);
|
||||
switch (PyArray_TYPE(array)) {
|
||||
SIZE_ARROW_CASE(NPY_INT8)
|
||||
SIZE_ARROW_CASE(NPY_INT16)
|
||||
SIZE_ARROW_CASE(NPY_INT32)
|
||||
SIZE_ARROW_CASE(NPY_INT64)
|
||||
SIZE_ARROW_CASE(NPY_UINT8)
|
||||
SIZE_ARROW_CASE(NPY_UINT16)
|
||||
SIZE_ARROW_CASE(NPY_UINT32)
|
||||
SIZE_ARROW_CASE(NPY_UINT64)
|
||||
SIZE_ARROW_CASE(NPY_FLOAT)
|
||||
SIZE_ARROW_CASE(NPY_DOUBLE)
|
||||
default:
|
||||
ORCH_LOG(ORCH_FATAL, "serialization: numpy datatype not know");
|
||||
}
|
||||
}
|
||||
|
||||
#define SERIALIZE_ARROW_CASE(TYPE) \
|
||||
case TYPE: \
|
||||
{ \
|
||||
data = std::make_shared<arrow::Buffer>(reinterpret_cast<uint8_t*>(PyArray_DATA(array)), sizeof(npy_traits<TYPE>::value_type) * size); \
|
||||
batch_size = size * sizeof(npy_traits<TYPE>::value_type) + MAX_METADATA_SIZE; \
|
||||
batch = make_flat_array<TYPE>("data", size, data); \
|
||||
} \
|
||||
break;
|
||||
|
||||
// TODO(pcm): At the moment, this assumes that arrays are consecutive in memory
|
||||
void store_arrow(PyArrayObject* array, ObjHandle& location, MemorySegmentPool* pool) {
|
||||
npy_intp size = PyArray_SIZE(array);
|
||||
std::shared_ptr<arrow::Buffer> data;
|
||||
std::shared_ptr<arrow::RowBatch> batch;
|
||||
int64_t batch_size = 0;
|
||||
switch (PyArray_TYPE(array)) {
|
||||
SERIALIZE_ARROW_CASE(NPY_INT8)
|
||||
SERIALIZE_ARROW_CASE(NPY_INT16)
|
||||
SERIALIZE_ARROW_CASE(NPY_INT32)
|
||||
SERIALIZE_ARROW_CASE(NPY_INT64)
|
||||
SERIALIZE_ARROW_CASE(NPY_UINT8)
|
||||
SERIALIZE_ARROW_CASE(NPY_UINT16)
|
||||
SERIALIZE_ARROW_CASE(NPY_UINT32)
|
||||
SERIALIZE_ARROW_CASE(NPY_UINT64)
|
||||
SERIALIZE_ARROW_CASE(NPY_FLOAT)
|
||||
SERIALIZE_ARROW_CASE(NPY_DOUBLE)
|
||||
default:
|
||||
ORCH_LOG(ORCH_FATAL, "serialization: numpy datatype not know");
|
||||
}
|
||||
|
||||
// int64_t data_batch_size = ipc::GetRowBatchSize(batch.get()); // FIXME(pcm): once GetRowBatchSize is implemented, use it
|
||||
|
||||
size_t ndim = PyArray_NDIM(array);
|
||||
MemoryPool* default_pool = arrow::default_memory_pool();
|
||||
|
||||
auto metadata = std::make_shared<PoolBuffer>(default_pool);
|
||||
size_t metadata_size = 1 + ndim + 1; // dtype, list of shapes, pointer to header of the data segment
|
||||
metadata->Resize(metadata_size * sizeof(int64_t));
|
||||
|
||||
int64_t* buffer = reinterpret_cast<int64_t*>(metadata->mutable_data());
|
||||
buffer[0] = PyArray_TYPE(array);
|
||||
// serialize the shape information
|
||||
for (size_t i = 0; i < ndim; ++i) {
|
||||
buffer[i+1] = PyArray_DIM(array, i);
|
||||
}
|
||||
std::shared_ptr<arrow::RowBatch> metadata_batch = make_flat_array<NPY_UINT64>("metadata", metadata_size, metadata);
|
||||
|
||||
// int64_t metadata_batch_size = ipc::GetRowBatchSize(metadata_batch.get()); // FIXME(pcm): once GetRowBatchSize is implemented, use it
|
||||
|
||||
uint8_t* address = pool->get_address(location);
|
||||
auto source = std::make_shared<BufferMemorySource>(address, location.size());
|
||||
|
||||
int64_t data_header_offset = 0;
|
||||
ipc::WriteRowBatch(source.get(), batch.get(), 0, &data_header_offset);
|
||||
|
||||
buffer[1 + ndim] = data_header_offset;
|
||||
|
||||
int64_t metadata_header_offset = 0;
|
||||
ipc::WriteRowBatch(source.get(), metadata_batch.get(), location.size() + MAX_METADATA_SIZE/2, &metadata_header_offset);
|
||||
location.set_metadata_offset(metadata_header_offset);
|
||||
}
|
||||
|
||||
template<int NpyType>
|
||||
std::shared_ptr<arrow::Array> read_flat_array(BufferMemorySource* source, int64_t metadata_offset) {
|
||||
std::shared_ptr<ipc::RowBatchReader> reader;
|
||||
Status s = ipc::RowBatchReader::Open(source, metadata_offset, &reader);
|
||||
if (!s.ok()) {
|
||||
ORCH_LOG(ORCH_FATAL, s.ToString());
|
||||
}
|
||||
auto field = std::make_shared<arrow::Field>("data", npy_traits<NpyType>::primitive_type);
|
||||
std::shared_ptr<arrow::Schema> schema(new arrow::Schema({field}));
|
||||
std::shared_ptr<arrow::RowBatch> data;
|
||||
reader->GetRowBatch(schema, &data);
|
||||
return data->column(0);
|
||||
|
||||
}
|
||||
|
||||
#define DESERIALIZE_ARROW_CASE(TYPE) \
|
||||
case TYPE: \
|
||||
{ \
|
||||
auto array = read_flat_array<TYPE>(source.get(), buffer[metadata_array->length()-1]); \
|
||||
auto data_primitive_array = dynamic_cast<npy_traits<TYPE>::ArrayType*>(array.get()); \
|
||||
return PyArray_SimpleNewFromData(dims.size(), &dims[0], TYPE, (void*)data_primitive_array->raw_data()); \
|
||||
}
|
||||
|
||||
PyObject* deserialize_array(ObjHandle handle, MemorySegmentPool* pool) {
|
||||
auto source = std::make_shared<BufferMemorySource>(pool->get_address(handle), handle.size());
|
||||
auto metadata_array = read_flat_array<NPY_UINT64>(source.get(), handle.metadata_offset());
|
||||
const uint64_t* buffer = dynamic_cast<UInt64Array*>(metadata_array.get())->raw_data();
|
||||
uint64_t type = buffer[0];
|
||||
std::vector<npy_intp> dims;
|
||||
for (int i = 1; i < metadata_array->length()-1; ++i) {
|
||||
dims.push_back(buffer[i]);
|
||||
}
|
||||
|
||||
switch (type) {
|
||||
DESERIALIZE_ARROW_CASE(NPY_INT8)
|
||||
DESERIALIZE_ARROW_CASE(NPY_INT16)
|
||||
DESERIALIZE_ARROW_CASE(NPY_INT32)
|
||||
DESERIALIZE_ARROW_CASE(NPY_INT64)
|
||||
DESERIALIZE_ARROW_CASE(NPY_UINT8)
|
||||
DESERIALIZE_ARROW_CASE(NPY_UINT16)
|
||||
DESERIALIZE_ARROW_CASE(NPY_UINT32)
|
||||
DESERIALIZE_ARROW_CASE(NPY_UINT64)
|
||||
DESERIALIZE_ARROW_CASE(NPY_FLOAT)
|
||||
DESERIALIZE_ARROW_CASE(NPY_DOUBLE)
|
||||
default:
|
||||
ORCH_LOG(ORCH_FATAL, "deserialization: numpy datatype not know");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,21 @@
|
||||
#ifndef ORCHESTRA_SERIALIZE_H
|
||||
#define ORCHESTRA_SERIALIZE_H
|
||||
|
||||
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
|
||||
|
||||
#include <arrow/api.h>
|
||||
#include <arrow/ipc/memory.h>
|
||||
#include <arrow/ipc/adapter.h>
|
||||
#include <Python.h>
|
||||
#define NO_IMPORT_ARRAY
|
||||
#define PY_ARRAY_UNIQUE_SYMBOL ORCHESTRA_ARRAY_API
|
||||
#include <numpy/arrayobject.h>
|
||||
#include <memory>
|
||||
|
||||
#include "ipc.h"
|
||||
|
||||
size_t arrow_size(PyArrayObject* array);
|
||||
void store_arrow(PyArrayObject* array, ObjHandle& location, MemorySegmentPool* pool);
|
||||
PyObject* deserialize_array(ObjHandle handle, MemorySegmentPool* pool);
|
||||
|
||||
#endif
|
||||
+47
-11
@@ -1,4 +1,4 @@
|
||||
# include "worker.h"
|
||||
#include "worker.h"
|
||||
|
||||
Status WorkerServiceImpl::InvokeCall(ServerContext* context, const InvokeCallRequest* request, InvokeCallReply* reply) {
|
||||
call_ = request->call(); // Copy call
|
||||
@@ -36,26 +36,23 @@ void Worker::register_worker(const std::string& worker_address, const std::strin
|
||||
return;
|
||||
}
|
||||
|
||||
slice Worker::pull_object(ObjRef objref) {
|
||||
PullObjRequest request;
|
||||
void Worker::request_object(ObjRef objref) {
|
||||
RequestObjRequest request;
|
||||
request.set_workerid(workerid_);
|
||||
request.set_objref(objref);
|
||||
AckReply reply;
|
||||
ClientContext context;
|
||||
Status status = scheduler_stub_->PullObj(&context, request, &reply);
|
||||
return get_object(objref);
|
||||
Status status = scheduler_stub_->RequestObj(&context, request, &reply);
|
||||
return;
|
||||
}
|
||||
|
||||
ObjRef Worker::push_object(const Obj* obj) {
|
||||
ObjRef Worker::get_objref() {
|
||||
// first get objref for the new object
|
||||
PushObjRequest push_request;
|
||||
PushObjReply push_reply;
|
||||
ClientContext push_context;
|
||||
Status push_status = scheduler_stub_->PushObj(&push_context, push_request, &push_reply);
|
||||
ObjRef objref = push_reply.objref();
|
||||
// then stream the object to the object store
|
||||
put_object(objref, obj);
|
||||
return objref;
|
||||
return push_reply.objref();
|
||||
}
|
||||
|
||||
slice Worker::get_object(ObjRef objref) {
|
||||
@@ -84,12 +81,51 @@ void Worker::put_object(ObjRef objref, const Obj* obj) {
|
||||
request_obj_queue_.send(&request);
|
||||
ObjHandle result;
|
||||
receive_obj_queue_.receive(&result);
|
||||
char* target = segmentpool_.get_address(result);
|
||||
uint8_t* target = segmentpool_.get_address(result);
|
||||
std::memcpy(target, &data[0], data.size());
|
||||
request.type = ObjRequestType::DONE;
|
||||
request.metadata_offset = 0;
|
||||
request_obj_queue_.send(&request);
|
||||
}
|
||||
|
||||
void Worker::put_arrow(ObjRef objref, PyArrayObject* array) {
|
||||
ObjRequest request;
|
||||
size_t size = arrow_size(array);
|
||||
request.workerid = workerid_;
|
||||
request.type = ObjRequestType::ALLOC;
|
||||
request.objref = objref;
|
||||
request.size = size;
|
||||
request_obj_queue_.send(&request);
|
||||
ObjHandle result;
|
||||
receive_obj_queue_.receive(&result);
|
||||
store_arrow(array, result, &segmentpool_);
|
||||
request.type = ObjRequestType::DONE;
|
||||
request.metadata_offset = result.metadata_offset();
|
||||
request_obj_queue_.send(&request);
|
||||
}
|
||||
|
||||
PyArrayObject* Worker::get_arrow(ObjRef objref) {
|
||||
ObjRequest request;
|
||||
request.workerid = workerid_;
|
||||
request.type = ObjRequestType::GET;
|
||||
request.objref = objref;
|
||||
request_obj_queue_.send(&request);
|
||||
ObjHandle result;
|
||||
receive_obj_queue_.receive(&result);
|
||||
return (PyArrayObject*)deserialize_array(result, &segmentpool_);
|
||||
}
|
||||
|
||||
bool Worker::is_arrow(ObjRef objref) {
|
||||
ObjRequest request;
|
||||
request.workerid = workerid_;
|
||||
request.type = ObjRequestType::GET;
|
||||
request.objref = objref;
|
||||
request_obj_queue_.send(&request);
|
||||
ObjHandle result;
|
||||
receive_obj_queue_.receive(&result);
|
||||
return result.metadata_offset() != 0;
|
||||
}
|
||||
|
||||
void Worker::register_function(const std::string& name, size_t num_return_vals) {
|
||||
ClientContext context;
|
||||
RegisterFunctionRequest request;
|
||||
|
||||
+12
-4
@@ -16,6 +16,7 @@ using grpc::Status;
|
||||
#include "orchestra.grpc.pb.h"
|
||||
#include "orchestra/orchestra.h"
|
||||
#include "ipc.h"
|
||||
#include "serialize.h"
|
||||
|
||||
using grpc::Channel;
|
||||
using grpc::ClientContext;
|
||||
@@ -42,14 +43,21 @@ class Worker {
|
||||
RemoteCallReply remote_call(RemoteCallRequest* request);
|
||||
// send request to the scheduler to register this worker
|
||||
void register_worker(const std::string& worker_address, const std::string& objstore_address);
|
||||
// push object to local object store, register it with the server and return object reference
|
||||
ObjRef push_object(const Obj* obj);
|
||||
// pull object from a potentially remote object store
|
||||
slice pull_object(ObjRef objref);
|
||||
// get a new object reference that is registered with the scheduler
|
||||
ObjRef get_objref();
|
||||
// request an object to be delivered to the local object store
|
||||
void request_object(ObjRef objref);
|
||||
// stores an object to the local object store
|
||||
void put_object(ObjRef objref, const Obj* obj);
|
||||
// retrieve serialized object from local object store
|
||||
slice get_object(ObjRef objref);
|
||||
// stores an arrow object to the local object store
|
||||
// FIXME(pcm): Once we have structs in arrow, get rid of the memcpy here
|
||||
void put_arrow(ObjRef objref, PyArrayObject* array);
|
||||
// gets an arrow object from the local object store
|
||||
PyArrayObject* get_arrow(ObjRef objref);
|
||||
// determine if the object stored in objref is an arrow object // TODO(pcm): more general mechanism for this?
|
||||
bool is_arrow(ObjRef objref);
|
||||
// register function with scheduler
|
||||
void register_function(const std::string& name, size_t num_return_vals);
|
||||
// start the worker server which accepts tasks from the scheduler and stores
|
||||
|
||||
Reference in New Issue
Block a user