From 5da148c1ab4f240351f9152da72bf70c702bb1d2 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 22 Feb 2016 13:55:06 -0800 Subject: [PATCH] getting the object store working --- CMakeLists.txt | 72 ++++++++++++++ Makefile | 89 ----------------- include/orchestra/orchestra.h | 47 +++++++++ lib/orchlib/orchlib.cc | 86 ---------------- lib/orchlib/orchlib.h | 11 --- lib/orchpy/orchpy/services.py | 30 ++++++ lib/orchpy/orchpy/unison.pyx | 111 +++++++++++++++++---- lib/orchpy/orchpy/worker.pyx | 179 ++++++++++++++++++++++++++++++++-- lib/orchpy/setup.py | 6 +- protos/orchestra.proto | 113 ++++++++++++++++++--- protos/types.proto | 11 ++- src/objstore.cc | 59 +++++++++++ src/objstore.h | 133 +++++++++++++++++++++++++ src/orchlib.cc | 29 ++++++ src/orchlib.h | 20 ++++ src/scheduler.cc | 0 src/scheduler.h | 145 +++++++++++++++++++++++++++ src/scheduler_server.cc | 50 ++++++++++ src/scheduler_server.h | 45 +++++++++ src/server.cc | 88 ----------------- test/gen-python-code.sh | 4 + test/runtest.py | 154 +++++++++++++++++++++++++++++ 22 files changed, 1164 insertions(+), 318 deletions(-) create mode 100644 CMakeLists.txt delete mode 100644 Makefile create mode 100644 include/orchestra/orchestra.h delete mode 100644 lib/orchlib/orchlib.cc delete mode 100644 lib/orchlib/orchlib.h create mode 100644 lib/orchpy/orchpy/services.py create mode 100644 src/objstore.cc create mode 100644 src/objstore.h create mode 100644 src/orchlib.cc create mode 100644 src/orchlib.h create mode 100644 src/scheduler.cc create mode 100644 src/scheduler.h create mode 100644 src/scheduler_server.cc create mode 100644 src/scheduler_server.h delete mode 100644 src/server.cc create mode 100644 test/gen-python-code.sh create mode 100644 test/runtest.py diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 000000000..7777ca8af --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,72 @@ +cmake_minimum_required(VERSION 2.8) + +project(orchestra) + +find_package(Protobuf REQUIRED) + +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11") + +include_directories("${CMAKE_SOURCE_DIR}/include") + +set(PROTO_PATH "${CMAKE_SOURCE_DIR}/protos") + +set(ORCHESTRA_PROTO "${PROTO_PATH}/orchestra.proto") +set(TYPES_PROTO "${PROTO_PATH}/types.proto") +set(GENERATED_PROTOBUF_PATH "${CMAKE_BINARY_DIR}/generated") +file(MAKE_DIRECTORY ${GENERATED_PROTOBUF_PATH}) + +set(ORCHESTRA_PB_CPP_FILE "${GENERATED_PROTOBUF_PATH}/orchestra.pb.cc") +set(ORCHESTRA_PB_H_FILE "${GENERATED_PROTOBUF_PATH}/orchestra.pb.h") +set(ORCHESTRA_GRPC_PB_CPP_FILE "${GENERATED_PROTOBUF_PATH}/orchestra.grpc.pb.cc") +set(ORCHESTRA_GRPC_PB_H_FILE "${GENERATED_PROTOBUF_PATH}/orchestra.grpc.pb.h") + +set(TYPES_PB_CPP_FILE "${GENERATED_PROTOBUF_PATH}/types.pb.cc") +set(TYPES_PB_H_FILE "${GENERATED_PROTOBUF_PATH}/types.pb.h") +set(TYPES_GRPC_PB_CPP_FILE "${GENERATED_PROTOBUF_PATH}/types.grpc.pb.cc") +set(TYPES_GRPC_PB_H_FILE "${GENERATED_PROTOBUF_PATH}/types.grpc.pb.h") + +add_custom_command( + OUTPUT "${ORCHESTRA_PB_H_FILE}" + "${ORCHESTRA_PB_CPP_FILE}" + "${ORCHESTRA_GRPC_PB_H_FILE}" + "${ORCHESTRA_GRPC_PB_CPP_FILE}" + COMMAND ${PROTOBUF_PROTOC_EXECUTABLE} + ARGS "--proto_path=${PROTO_PATH}" + "--cpp_out=${GENERATED_PROTOBUF_PATH}" + "${ORCHESTRA_PROTO}" + COMMAND ${PROTOBUF_PROTOC_EXECUTABLE} + ARGS "--proto_path=${PROTO_PATH}" + "--grpc_out=${GENERATED_PROTOBUF_PATH}" + "--plugin=protoc-gen-grpc=/usr/local/bin/grpc_cpp_plugin" + "${ORCHESTRA_PROTO}" + ) + +add_custom_command( + OUTPUT "${TYPES_PB_H_FILE}" + "${TYPES_PB_CPP_FILE}" + "${TYPES_GRPC_PB_H_FILE}" + "${TYPES_GRPC_PB_CPP_FILE}" + COMMAND ${PROTOBUF_PROTOC_EXECUTABLE} + ARGS "--proto_path=${PROTO_PATH}" + "--cpp_out=${GENERATED_PROTOBUF_PATH}" + "${TYPES_PROTO}" + COMMAND ${PROTOBUF_PROTOC_EXECUTABLE} + ARGS "--proto_path=${PROTO_PATH}" + "--grpc_out=${GENERATED_PROTOBUF_PATH}" + "--plugin=protoc-gen-grpc=/usr/local/bin/grpc_cpp_plugin" + "${TYPES_PROTO}" + ) + +set(GENERATED_PROTOBUF_FILES ${ORCHESTRA_PB_H_FILE} ${ORCHESTRA_PB_CPP_FILE} + ${ORCHESTRA_GRPC_PB_H_FILE} ${ORCHESTRA_GRPC_PB_CPP_FILE} + ${TYPES_PB_H_FILE} ${TYPES_PB_CPP_FILE} + ${TYPES_GRPC_PB_H_FILE} ${TYPES_GRPC_PB_CPP_FILE}) + +include_directories(${GENERATED_PROTOBUF_PATH}) +link_libraries(grpc++_unsecure grpc pthread rt ${PROTOBUF_LIBRARY}) + +add_executable(objstore src/objstore.cc ${GENERATED_PROTOBUF_FILES}) +add_executable(scheduler_server src/scheduler_server.cc ${GENERATED_PROTOBUF_FILES}) +add_library(orchlib SHARED src/orchlib.cc ${GENERATED_PROTOBUF_FILES}) + +install(TARGETS objstore scheduler_server orchlib DESTINATION ${CMAKE_SOURCE_DIR}/lib/orchpy/orchpy) diff --git a/Makefile b/Makefile deleted file mode 100644 index d2afb705d..000000000 --- a/Makefile +++ /dev/null @@ -1,89 +0,0 @@ -SRC_PATH = src -PROTOS_PATH = protos -LIB_PATH = lib/orchlib - -CXX = g++ -CPPFLAGS += -I/usr/local/include -pthread -CXXFLAGS += -std=c++11 -fPIC -I$(SRC_PATH) -O3 -LDFLAGS += -L/usr/local/lib -lgrpc++_unsecure -lgrpc -lprotobuf -lpthread -ldl -PROTOC = protoc -GRPC_CPP_PLUGIN = grpc_cpp_plugin -GRPC_CPP_PLUGIN_PATH ?= `which $(GRPC_CPP_PLUGIN)` - -vpath %.proto $(PROTOS_PATH) - -all: system-check $(LIB_PATH)/liborchlib.so $(SRC_PATH)/server lib/orchpy/orchpy/liborchlib.so - -$(LIB_PATH)/liborchlib.so: $(SRC_PATH)/types.pb.o $(SRC_PATH)/orchestra.pb.o $(SRC_PATH)/types.grpc.pb.o $(SRC_PATH)/orchestra.grpc.pb.o $(LIB_PATH)/orchlib.o - $(CXX) $^ $(LDFLAGS) -shared -o $@ - -$(SRC_PATH)/server: $(SRC_PATH)/orchestra.pb.o $(SRC_PATH)/types.pb.o $(SRC_PATH)/types.grpc.pb.o $(SRC_PATH)/orchestra.grpc.pb.o $(SRC_PATH)/server.o - $(CXX) $^ $(LDFLAGS) -o $@ - -.PRECIOUS: ./src/%.grpc.pb.cc -$(SRC_PATH)/%.grpc.pb.cc: %.proto - $(PROTOC) -I $(PROTOS_PATH) --grpc_out=$(SRC_PATH)/ --plugin=protoc-gen-grpc=$(GRPC_CPP_PLUGIN_PATH) $< - -.PRECIOUS: ./src/%.pb.cc -$(SRC_PATH)/%.pb.cc: %.proto - $(PROTOC) -I $(PROTOS_PATH) --cpp_out=./src $< - -lib/orchpy/orchpy/liborchlib.so: - cp -f lib/orchlib/liborchlib.so lib/orchpy/orchpy/liborchlib.so - -clean: - rm -f $(SRC_PATH)/*.o $(LIB_PATH)/*.o $(SRC_PATH)/*.pb.cc $(SRC_PATH)/*.pb.h $(LIB_PATH)/orchlib.so $(SRC_PATH)/server - - -# The following is to test your system and ensure a smoother experience. -# They are by no means necessary to actually compile a grpc-enabled software. - -PROTOC_CMD = which $(PROTOC) -PROTOC_CHECK_CMD = $(PROTOC) --version | grep -q libprotoc.3 -PLUGIN_CHECK_CMD = which $(GRPC_CPP_PLUGIN) -HAS_PROTOC = $(shell $(PROTOC_CMD) > /dev/null && echo true || echo false) -ifeq ($(HAS_PROTOC),true) -HAS_VALID_PROTOC = $(shell $(PROTOC_CHECK_CMD) 2> /dev/null && echo true || echo false) -endif -HAS_PLUGIN = $(shell $(PLUGIN_CHECK_CMD) > /dev/null && echo true || echo false) - -SYSTEM_OK = false -ifeq ($(HAS_VALID_PROTOC),true) -ifeq ($(HAS_PLUGIN),true) -SYSTEM_OK = true -endif -endif - -system-check: -ifneq ($(HAS_VALID_PROTOC),true) - @echo " DEPENDENCY ERROR" - @echo - @echo "You don't have protoc 3.0.0 installed in your path." - @echo "Please install Google protocol buffers 3.0.0 and its compiler." - @echo "You can find it here:" - @echo - @echo " https://github.com/google/protobuf/releases/tag/v3.0.0-alpha-1" - @echo - @echo "Here is what I get when trying to evaluate your version of protoc:" - @echo - -$(PROTOC) --version - @echo - @echo -endif -ifneq ($(HAS_PLUGIN),true) - @echo " DEPENDENCY ERROR" - @echo - @echo "You don't have the grpc c++ protobuf plugin installed in your path." - @echo "Please install grpc. You can find it here:" - @echo - @echo " https://github.com/grpc/grpc" - @echo - @echo "Here is what I get when trying to detect if you have the plugin:" - @echo - -which $(GRPC_CPP_PLUGIN) - @echo - @echo -endif -ifneq ($(SYSTEM_OK),true) - @false -endif diff --git a/include/orchestra/orchestra.h b/include/orchestra/orchestra.h new file mode 100644 index 000000000..4b78905c8 --- /dev/null +++ b/include/orchestra/orchestra.h @@ -0,0 +1,47 @@ +#ifndef ORCHESTRA_INCLUDE_ORCHESTRA_H +#define ORCHESTRA_INCLUDE_ORCHESTRA_H + +#include +#include + +typedef size_t ObjRef; +typedef size_t WorkerId; +typedef size_t ObjStoreId; + +class FnInfo { + size_t num_return_vals_; + std::vector workers_; +public: + void set_num_return_vals(size_t num) { + num_return_vals_ = num; + } + size_t num_return_vals() const { + return num_return_vals_; + } + void add_worker(WorkerId workerid) { + workers_.push_back(workerid); + } + size_t num_workers() const { + return workers_.size(); + } + ObjRef worker(size_t i) const { + return workers_[i]; + } +}; + +typedef std::vector > ObjTable; +typedef std::unordered_map FnTable; + +class objstore_not_registered_error : public std::runtime_error +{ +public: + objstore_not_registered_error(const std::string& msg) : std::runtime_error(msg) {} +}; + + +// struct slice { +// char* data; +// size_t len; +// }; + +#endif diff --git a/lib/orchlib/orchlib.cc b/lib/orchlib/orchlib.cc deleted file mode 100644 index 127084180..000000000 --- a/lib/orchlib/orchlib.cc +++ /dev/null @@ -1,86 +0,0 @@ -#include -#include -#include -#include - -#include - -using grpc::Server; -using grpc::ServerBuilder; -using grpc::ServerContext; -using grpc::Status; - -#include "orchestra.grpc.pb.h" -#include "orchlib.h" - -using grpc::Channel; -using grpc::ClientContext; -using grpc::Status; - -class Client { - public: - Client(std::shared_ptr channel) - : stub_(Orchestra::NewStub(channel)) {} - - size_t RemoteCall(const std::string& name) { - RemoteCallRequest request; - request.set_name(name); - - RemoteCallReply reply; - ClientContext context; - - Status status = stub_->RemoteCall(&context, request, &reply); - - return reply.result(); - } - - void RegisterWorker() { - RegisterWorkerRequest request; - RegisterWorkerReply reply; - ClientContext context; - Status status = stub_->RegisterWorker(&context, request, &reply); - return; - } - - private: - std::unique_ptr stub_; -}; - -class WorkerServiceImpl final : public Worker::Service { - Status InvokeCall(ServerContext* context, const InvokeCallRequest* request, - InvokeCallReply* reply) override { - std::cout << "invoke call request" << std::endl; - return Status::OK; - } -}; - -void start_server() { - std::string server_address("0.0.0.0:50053"); - WorkerServiceImpl service; - ServerBuilder builder; - builder.AddListeningPort(server_address, grpc::InsecureServerCredentials()); - builder.RegisterService(&service); - std::unique_ptr server(builder.BuildAndStart()); - std::cout << "Server listening on " << server_address << std::endl; - server->Wait(); -} - -void* orch_create_context(const char* server_addr) { - Client* client = new Client(grpc::CreateChannel(server_addr, grpc::InsecureChannelCredentials())); - client->RegisterWorker(); - return client; -} - -size_t orch_remote_call(void* context, const char* name, void* args) { - Client* client = (Client*)context; - return client->RemoteCall(std::string(name)); -} - -int main(int argc, char** argv) { - Client greeter( - grpc::CreateChannel("localhost:50052", grpc::InsecureChannelCredentials())); - std::string user("world"); - greeter.RemoteCall(user); - - return 0; -} diff --git a/lib/orchlib/orchlib.h b/lib/orchlib/orchlib.h deleted file mode 100644 index f732c317b..000000000 --- a/lib/orchlib/orchlib.h +++ /dev/null @@ -1,11 +0,0 @@ -extern "C" { - -void* orch_create_context(const char* server_addr); -size_t orch_remote_call(void* context, const char* name, void* args); - -void* orch_arglist_create(); -void orch_arglist_add_ref(void* arglist, size_t ref); -void orch_arglist_add_string(void* arglist, const char* str); -void orch_arglist_destroy(void* arglist); - -} diff --git a/lib/orchpy/orchpy/services.py b/lib/orchpy/orchpy/services.py new file mode 100644 index 000000000..048d21cdd --- /dev/null +++ b/lib/orchpy/orchpy/services.py @@ -0,0 +1,30 @@ +import subprocess32 as subprocess +import os +import atexit +import time + +_services_path = os.path.dirname(os.path.abspath(__file__)) + +all_processes = [] + +def cleanup(): + timeout_sec = 5 + for p in all_processes: + p_sec = 0 + for second in range(timeout_sec): + if p.poll() == None: + time.sleep(1) + p_sec += 1 + if p_sec >= timeout_sec: + p.kill() # supported from python 2.6 + print 'helper processes shut down!' + +atexit.register(cleanup) + +def start_scheduler(scheduler_address): + p = subprocess.Popen([os.path.join(_services_path, "scheduler_server"), str(scheduler_address)]) + all_processes.append(p) + +def start_objstore(objstore_address): + p = subprocess.Popen([os.path.join(_services_path, "objstore"), str(objstore_address)]) + all_processes.append(p) diff --git a/lib/orchpy/orchpy/unison.pyx b/lib/orchpy/orchpy/unison.pyx index 1dc19135b..d6b1b017a 100644 --- a/lib/orchpy/orchpy/unison.pyx +++ b/lib/orchpy/orchpy/unison.pyx @@ -1,8 +1,15 @@ -from libc.stdint cimport uint64_t, int64_t +# Will be rewritten in C++ for easier deployment once the API is stabilized + +from libc.stdint cimport uint64_t, int64_t, uintptr_t from libcpp cimport bool from libcpp.string cimport string import numpy as np +try: + import cPickle as pickle +except: + import pickle + cdef extern from "types.pb.h": ctypedef enum DataType: INT32 @@ -23,10 +30,11 @@ cdef extern from "types.pb.h": Value* add_value() Value* mutable_value(int index) + cdef cppclass String: - String() - void set_data(const char* val) - string* mutable_data() + String() + void set_data(const char* val) + string* mutable_data() cdef cppclass Int: Int() @@ -38,28 +46,48 @@ cdef extern from "types.pb.h": void set_data(double val) double data() + cdef cppclass PyObj: + PyObj() + void set_data(const char* val, size_t len) + string* mutable_data() + cdef cppclass Obj: Obj() String* mutable_string_data() Int* mutable_int_data() Double* mutable_double_data() + PyObj* mutable_pyobj_data() bool has_string_data() bool has_int_data() bool has_double_data() + bool ParseFromString(const string& data) -cdef class PyValues: +cdef class PyValues: # TODO: unify with the below cdef Values *thisptr def __cinit__(self): self.thisptr = new Values() def __dealloc__(self): del self.thisptr + def get_value(self): + return self.thisptr -cdef class PyValue: +cdef class PyValue: # TODO: unify with the below cdef Value *thisptr def __cinit__(self): self.thisptr = new Value() def __dealloc__(self): del self.thisptr + def get_value(self): + return self.thisptr + +cdef class ObjWrapper: # TODO: unify with the above + cdef Obj *thisptr + def __cinit__(self): + self.thisptr = new Obj() + # def __dealloc__(self): + # del self.thisptr + def get_value(self): + return self.thisptr cdef class ObjRef: cdef size_t _id @@ -80,32 +108,70 @@ cdef class ObjRef: cpdef get_id(self): return self._id -cpdef serialize_args(args): - cdef Values* vals - cdef Value* val - cdef Obj* obj +cpdef serialize_into(val, objptr): + cdef uintptr_t ptr = objptr + cdef Obj* obj = ptr cdef String* string_data cdef Int* int_data cdef Double* double_data - result = PyValues() - vals = result.thisptr + if type(val) == str: + string_data = obj[0].mutable_string_data() + string_data[0].set_data(val) + elif type(val) == int or type(val) == long: + int_data = obj[0].mutable_int_data() + int_data[0].set_data(val) + elif type(val) == float: + double_data = obj[0].mutable_double_data() + double_data[0].set_data(val) + else: + data = pickle.dumps(val, pickle.HIGHEST_PROTOCOL) + pyobj_data = obj[0].mutable_pyobj_data() + pyobj_data[0].set_data(data, len(data)) + +cpdef serialize(val): + result = ObjWrapper() + serialize_into(val, result.get_value()) + return result + +cpdef serialize_args_into(args, valsptr): + cdef uintptr_t ptr = valsptr + cdef Values* vals = ptr + cdef Value* val + cdef Obj* obj for arg in args: val = vals[0].add_value() if type(arg) == ObjRef: val[0].set_ref(arg.get_id()) else: obj = val[0].mutable_obj() - if type(arg) == str: - string_data = obj[0].mutable_string_data() - string_data[0].set_data(arg) - elif type(arg) == int or type(arg) == long: - int_data = obj[0].mutable_int_data() - int_data[0].set_data(arg) - elif type(arg) == float: - double_data = obj[0].mutable_double_data() - double_data[0].set_data(arg) + serialize_into(arg, obj) + +cpdef serialize_args(args): + result = PyValues() + serialize_args_into(args, result.get_value()) return result +cdef deserialize_from(Obj* obj): + if obj[0].has_string_data(): + return obj[0].mutable_string_data()[0].mutable_data()[0] + elif obj[0].has_int_data(): + return obj[0].mutable_int_data()[0].data() + elif obj[0].has_double_data(): + return obj[0].mutable_double_data()[0].data() + else: + data = obj[0].mutable_pyobj_data()[0].mutable_data()[0] + return pickle.loads(data) + +cpdef deserialize_from_string(str): + cdef string s = str + cdef Obj* obj = new Obj() # TODO: memory leak + obj[0].ParseFromString(s) + return deserialize_from(obj) + +# cpdef deserialize(str): +# cdef string s = string(str) +# return deserialize_from(obj.get_value()) + cpdef deserialize_args(PyValues args): cdef Values* vals = args.thisptr cdef Value* val @@ -123,6 +189,9 @@ cpdef deserialize_args(PyValues args): result.append(obj[0].mutable_int_data()[0].data()) elif obj[0].has_double_data(): result.append(obj[0].mutable_double_data()[0].data()) + else: + data = obj[0].mutable_pyobj_data()[0].mutable_data()[0] + result.append(pickle.loads(data)) return result cdef int numpy_dtype_to_proto(dtype): diff --git a/lib/orchpy/orchpy/worker.pyx b/lib/orchpy/orchpy/worker.pyx index 76ff560de..5c1bc124b 100644 --- a/lib/orchpy/orchpy/worker.pyx +++ b/lib/orchpy/orchpy/worker.pyx @@ -1,5 +1,135 @@ -cdef extern void* orch_create_context(const char* server_addr); -cdef extern size_t orch_remote_call(void* context, const char* name, void* args); +from libc.stdint cimport uintptr_t +import orchpy.unison as unison + +from libc.stdint cimport uint64_t, int64_t, uintptr_t +from libcpp cimport bool +from libcpp.string cimport string + +cdef struct Slice: + char* ptr + size_t size + +cdef extern void* orch_create_context(const char* server_addr, const char* worker_addr, const char* objstore_addr); +cdef extern void orch_register_function(void* worker, const char* name, size_t num_return_vals) +cdef extern size_t orch_remote_call(void* context, void* request); +cdef extern size_t orch_push(void* context, void* value); +cdef extern void orch_main_loop(void* context); +cdef extern Slice orch_get_serialized_obj(void* context, size_t objref); + +cdef extern from "Python.h": + Py_ssize_t PyByteArray_GET_SIZE(object array) + object PyUnicode_FromStringAndSize(char *buff, Py_ssize_t len) + object PyBytes_FromStringAndSize(char *buff, Py_ssize_t len) + object PyString_FromStringAndSize(char *buff, Py_ssize_t len) + int PyByteArray_Resize(object self, Py_ssize_t size) except -1 + char* PyByteArray_AS_STRING(object bytearray) + +cdef extern from "types.pb.h": + cdef cppclass Values + +cdef extern from "orchestra.pb.h": + cdef cppclass RemoteCallRequest: + RemoteCallRequest() + void set_name(const char* value) + Values* mutable_arg() + +cdef extern from "types.pb.h": + ctypedef enum DataType: + INT32 + INT64 + FLOAT32 + FLOAT64 + + cdef cppclass Value: + Value() + void set_ref(uint64_t value) + uint64_t ref() + bool has_obj() + Obj* mutable_obj() + + cdef cppclass Values: + Values() + int value_size() + Value* add_value() + Value* mutable_value(int index) + + + cdef cppclass String: + String() + void set_data(const char* val) + string* mutable_data() + + cdef cppclass Int: + Int() + void set_data(int64_t val) + int64_t data() + + cdef cppclass Double: + Double() + void set_data(double val) + double data() + + cdef cppclass PyObj: + PyObj() + void set_data(const char* val, size_t len) + string* mutable_data() + + cdef cppclass Obj: + Obj() + String* mutable_string_data() + Int* mutable_int_data() + Double* mutable_double_data() + PyObj* mutable_pyobj_data() + bool has_string_data() + bool has_int_data() + bool has_double_data() + +cdef serialize_into(val, Obj* obj): + cdef String* string_data + cdef Int* int_data + cdef Double* double_data + if type(val) == str: + string_data = obj[0].mutable_string_data() + string_data[0].set_data(val) + elif type(val) == int or type(val) == long: + int_data = obj[0].mutable_int_data() + int_data[0].set_data(val) + elif type(val) == float: + double_data = obj[0].mutable_double_data() + double_data[0].set_data(val) + # else: + # data = pickle.dumps(val, pickle.HIGHEST_PROTOCOL) + # pyobj_data = obj[0].mutable_pyobj_data() + # pyobj_data[0].set_data(data, len(data)) + +cdef class ObjWrapper: # TODO: unify with the above + cdef Obj *thisptr + def __cinit__(self): + self.thisptr = new Obj() + # def __dealloc__(self): + # del self.thisptr + def get_value(self): + return self.thisptr + +cpdef serialize_into_2(val, objptr): + cdef uintptr_t ptr = objptr + cdef Obj* obj = ptr + cdef String* string_data + cdef Int* int_data + cdef Double* double_data + if type(val) == str: + string_data = obj[0].mutable_string_data() + string_data[0].set_data(val) + elif type(val) == int or type(val) == long: + int_data = obj[0].mutable_int_data() + int_data[0].set_data(val) + elif type(val) == float: + double_data = obj[0].mutable_double_data() + double_data[0].set_data(val) + # else: + # data = pickle.dumps(val, pickle.HIGHEST_PROTOCOL) + # pyobj_data = obj[0].mutable_pyobj_data() + # pyobj_data[0].set_data(data, len(data)) cdef class Worker: cdef void* context @@ -7,11 +137,48 @@ cdef class Worker: def __cinit__(self): self.context = NULL - def connect(self, server_addr): - self.context = orch_create_context(server_addr) + def connect(self, server_addr, worker_addr, objstore_addr): + self.context = orch_create_context(server_addr, worker_addr, objstore_addr) - def call(self, name): - return orch_remote_call(self.context, name, 0) +# cpdef call(self, name, args): +# cdef RemoteCallRequest* result = new RemoteCallRequest() +# result[0].set_name(name) +# unison.serialize_args_into(args, result[0].mutable_arg()) +# for i in range(10): +# orch_remote_call(self.context, result) +# # return result + + cpdef do_call(self, ptr): + return orch_remote_call(self.context, ptr) + + cpdef do_push(self, val): + print("before serialization") + result = unison.serialize(val) + print("before push") + # ptr = result.get_value() + # print "pointer is", ptr + # cdef Obj* obj = new Obj() + o = ObjWrapper() + # serialize_into_2(0, obj) + # cdef Obj* ptr = new Obj() # o.get_value() + ## ptr = o.get_value() + ptr = result.get_value() + serialize_into_2(0, ptr) + return orch_push(self.context, ptr) + + cpdef get_serialized(self, objref): + cdef Slice slice = orch_get_serialized_obj(self.context, objref) + data = PyBytes_FromStringAndSize(slice.ptr, slice.size) + return data + + cpdef pull(self, objref): + cdef Slice slice = orch_get_serialized_obj(self.context, objref) + + cpdef register_function(self, func_name, num_args): + orch_register_function(self.context, func_name, num_args) + + cpdef main_loop(self): + orch_main_loop(self.context) global_worker = Worker() diff --git a/lib/orchpy/setup.py b/lib/orchpy/setup.py index c2002b6ed..477a50833 100644 --- a/lib/orchpy/setup.py +++ b/lib/orchpy/setup.py @@ -8,8 +8,10 @@ setup( version = "0.1.dev0", ext_modules = cythonize([ Extension("orchpy/worker", + include_dirs = ["../../src"], sources = ["orchpy/worker.pyx"], - extra_link_args=["-Iorchpy -lorchlib"]), + extra_link_args=["-Iorchpy -lorchlib"], + language = "c++"), Extension("orchpy/unison", include_dirs = ["../../src/"], sources = ["orchpy/unison.pyx"], @@ -19,7 +21,7 @@ setup( use_2to3=True, packages=find_packages(), package_data = { - 'orchpy': ['liborchlib.so'] + 'orchpy': ['liborchlib.so', 'scheduler_server', 'objstore'] }, zip_safe=False ) diff --git a/protos/orchestra.proto b/protos/orchestra.proto index 6761b9b03..28ccd8229 100644 --- a/protos/orchestra.proto +++ b/protos/orchestra.proto @@ -2,33 +2,122 @@ syntax = "proto3"; import "types.proto"; +message AckReply { + string errormsg = 1; +} + message RegisterWorkerRequest { - string address = 1; + string worker_address = 1; + string objstore_address = 2; } message RegisterWorkerReply { uint64 workerid = 1; } +message RegisterObjStoreRequest { + string address = 1; +} + +message RegisterObjStoreReply { + uint64 objstoreid = 1; +} + +message RegisterFunctionRequest { + uint64 workerid = 1; + string fnname = 2; + uint64 num_return_vals = 3; +} + message RemoteCallRequest { - string name = 1; - Values arg = 2; + Call call = 1; } message RemoteCallReply { - uint64 result = 1; + repeated uint64 result = 1; } -message PullObjectRequest { - uint64 ref = 1; +message PullObjRequest { + uint64 objref = 1; } -service Orchestra { +message PushObjRequest { + uint64 workerid = 1; +} + +message PushObjReply { + uint64 objref = 1; +} + +message ChangeCountRequest { + uint64 objref = 1; +} + +message GetDebugInfoRequest { + +} + +message FnTableEntry { + uint64 workerid = 1; + uint64 num_return_vals = 2; +} + +message GetDebugInfoReply { + map function_table = 1; +} + +service SchedulerServer { rpc RegisterWorker(RegisterWorkerRequest) returns (RegisterWorkerReply); - // rpc RegisterFunction + rpc RegisterObjStore(RegisterObjStoreRequest) returns (RegisterObjStoreReply); + rpc RegisterFunction(RegisterFunctionRequest) returns (AckReply); rpc RemoteCall(RemoteCallRequest) returns (RemoteCallReply); - // rpc PushObject - // rpc PullObject(PullObjectRequest) + rpc IncrementCount(ChangeCountRequest) returns (AckReply); + rpc DecrementCount(ChangeCountRequest) returns (AckReply); + rpc PushObj(PushObjRequest) returns (PushObjReply); + rpc PullObj(PullObjRequest) returns (AckReply); + rpc GetDebugInfo(GetDebugInfoRequest) returns (GetDebugInfoReply); +} + +message DeliverObjRequest { + string objstore_address = 1; // objstore to deliver the object to + uint64 objref = 2; // reference of object that gets delivered +} + +message RegisterObjRequest { + uint64 objref = 1; // reference of object that gets registered +} + +message RegisterObjReply { + uint64 handle = 1; // handle to memory segment where object is stored +} + +message ObjChunk { + uint64 objref = 1; + uint64 totalsize = 2; + bytes data = 3; +} + +message GetObjRequest { + uint64 objref = 1; +} + +message GetObjReply { + string bucket = 1; + uint64 handle = 2; + uint64 size = 3; +} + +message DebugInfoRequest {} + +message DebugInfoReply { + repeated uint64 objref = 1; +} + +service ObjStore { + rpc DeliverObj(DeliverObjRequest) returns (AckReply); + rpc StreamObj(stream ObjChunk) returns (AckReply); + rpc GetObj(GetObjRequest) returns (GetObjReply); + rpc DebugInfo(DebugInfoRequest) returns (DebugInfoReply); } message InvokeCallRequest { @@ -39,8 +128,6 @@ message InvokeCallReply { } -service Worker { +service WorkerServer { rpc InvokeCall(InvokeCallRequest) returns (InvokeCallReply); - // rpc PushObj(PushObjRequest) returns (PushObjReply); - // rpc RequestTransfer(RequestTransferRequest) returns (RequestTransferReply); } diff --git a/protos/types.proto b/protos/types.proto index 74e01b85e..f3adea98d 100644 --- a/protos/types.proto +++ b/protos/types.proto @@ -13,10 +13,10 @@ message Double { } message PyObj { - string type = 1; - bytes data = 2; + bytes data = 1; } +// Union of possible object types message Obj { String string_data = 1; Int int_data = 2; @@ -33,10 +33,17 @@ message Value { Obj obj = 2; // for pass by value } +// Deprecated message Values { repeated Value value = 1; } +message Call { + string name = 1; + repeated Value args = 2; + repeated uint64 result = 3; +} + enum DataType { INT32 = 0; INT64 = 1; diff --git a/src/objstore.cc b/src/objstore.cc new file mode 100644 index 000000000..fd842ba43 --- /dev/null +++ b/src/objstore.cc @@ -0,0 +1,59 @@ +#include "objstore.h" + +const size_t ObjStoreClient::CHUNK_SIZE = 8 * 1024; + +Status ObjStoreClient::upload_data_to(slice data, ObjRef objref, ObjStore::Stub& stub) { + ObjChunk chunk; + ClientContext context; + AckReply reply; + std::unique_ptr > writer(stub.StreamObj(&context, &reply)); + const char* head = data.data; + for (size_t i = 0; i < data.len; i += CHUNK_SIZE) { + chunk.set_objref(objref); + chunk.set_totalsize(data.len); + chunk.set_data(head + i, std::min(CHUNK_SIZE, data.len - i)); + if (!writer->Write(chunk)) { + std::cout << "write failed" << std::endl; + // throw std::runtime_error("write failed"); + } + } + writer->WritesDone(); + return writer->Finish(); +} + +void ObjStoreServiceImpl::allocate_memory(ObjRef objref, size_t size) { + std::ostringstream stream; + stream << "obj-" << memory_names_.size(); + std::string name = stream.str(); + // Make sure that the name is not taken yet + shared_memory_object::remove(name.c_str()); + memory_names_.push_back(name); + // Make room for boost::interprocess metadata + size_t new_size = (size / page_size + 2) * page_size; + shared_object& object = memory_[objref]; + object.name = name; + object.memory = std::make_shared(create_only, name.c_str(), new_size); + object.ptr.data = static_cast(memory_[objref].memory->allocate(size)); + object.ptr.len = size; +} + +void start_objstore(const char* objstore_address) { + ObjStoreServiceImpl service; + ServerBuilder builder; + + builder.AddListeningPort(std::string(objstore_address), grpc::InsecureServerCredentials()); + builder.RegisterService(&service); + std::unique_ptr server(builder.BuildAndStart()); + + server->Wait(); +} + +int main(int argc, char** argv) { + if (argc != 2) { + return 1; + } + + start_objstore(argv[1]); + + return 0; +} diff --git a/src/objstore.h b/src/objstore.h new file mode 100644 index 000000000..6651113dc --- /dev/null +++ b/src/objstore.h @@ -0,0 +1,133 @@ +#ifndef ORCHESTRA_OBJSTORE_SERVER_H +#define ORCHESTRA_OBJSTORE_SERVER_H + +#include +#include +#include +#include +#include + +using namespace boost::interprocess; + +#include "orchestra/orchestra.h" +#include "orchestra.grpc.pb.h" +#include "types.pb.h" + +#include "orchlib.h" + +using grpc::Server; +using grpc::ServerBuilder; +using grpc::ServerReader; +using grpc::ServerContext; +using grpc::ClientContext; +using grpc::ClientWriter; +using grpc::Status; + +using grpc::Channel; + +class ObjStoreClient { +public: + static const size_t CHUNK_SIZE; + static Status upload_data_to(slice data, ObjRef objref, ObjStore::Stub& stub); +}; + +struct shared_object { + std::string name; + std::shared_ptr memory; + slice ptr; +}; + +class ObjStoreServiceImpl final : public ObjStore::Service { + std::vector memory_names_; + std::unordered_map memory_; + std::mutex memory_lock_; + size_t page_size = mapped_region::get_page_size(); + std::unordered_map> objstores_; + + void allocate_memory(ObjRef objref, size_t size); + + // check if we already connected to the other objstore, if yes, return reference to connection, otherwise connect + ObjStore::Stub& get_objstore_stub(const std::string& objstore_address) { + auto iter = objstores_.find(objstore_address); + if (iter != objstores_.end()) + return *(iter->second); + auto channel = grpc::CreateChannel(objstore_address, grpc::InsecureChannelCredentials()); + objstores_.emplace(objstore_address, ObjStore::NewStub(channel)); + return *objstores_[objstore_address]; + } + +public: + ObjStoreServiceImpl() {} + + ~ObjStoreServiceImpl() { + for (const auto& segment_name : memory_names_) { + shared_memory_object::remove(segment_name.c_str()); + } + } + + Status DeliverObj(ServerContext* context, const DeliverObjRequest* request, AckReply* reply) override { + ObjStore::Stub& stub = get_objstore_stub(request->objstore_address()); + ObjRef objref = request->objref(); + + // TODO: Have to introduce wait condition + + return ObjStoreClient::upload_data_to(memory_[objref].ptr, objref, stub); + } + + Status DebugInfo(ServerContext* context, const DebugInfoRequest* request, DebugInfoReply* reply) override { + for (const auto& entry : memory_) { + reply->add_objref(entry.first); + } + return Status::OK; + } + + Status GetObj(ServerContext* context, const GetObjRequest* request, GetObjReply* reply) override { + ObjRef objref = request->objref(); + std::cout << "getobj lock"; + memory_lock_.lock(); + shared_object& object = memory_[objref]; + reply->set_bucket(object.name); + auto handle = object.memory->get_handle_from_address(object.ptr.data); + reply->set_handle(handle); + reply->set_size(object.ptr.len); + memory_lock_.unlock(); + std::cout << "getobj unlock"; + return Status::OK; + } + + Status StreamObj(ServerContext* context, ServerReader* reader, AckReply* reply) override { + std::cout << "stream obj lock" << std::endl; + memory_lock_.lock(); + ObjChunk chunk; + ObjRef objref = 0; + size_t totalsize = 0; + if (reader->Read(&chunk)) { + objref = chunk.objref(); + totalsize = chunk.totalsize(); + allocate_memory(objref, totalsize); + } + size_t num_bytes = 0; + char* data = memory_[objref].ptr.data; + + std::cout << "before loop " << totalsize << std::endl; + + do { + if (num_bytes + chunk.data().size() > totalsize) { + std::cout << "cancelled" << std::endl; + memory_lock_.unlock(); + return Status::CANCELLED; + } + std::memcpy(data, chunk.data().c_str(), chunk.data().size()); + data += chunk.data().size(); + num_bytes += chunk.data().size(); + std::cout << "looping " << num_bytes << std::endl; + } while (reader->Read(&chunk)); + + std::cout << "finished" << std::endl; + memory_lock_.unlock(); + std::cout << "stream obj unlock" << std::endl; + return Status::OK; + } +}; + +#endif diff --git a/src/orchlib.cc b/src/orchlib.cc new file mode 100644 index 000000000..0c4aaf40f --- /dev/null +++ b/src/orchlib.cc @@ -0,0 +1,29 @@ +#include "worker.h" + +Worker* orch_create_context(const char* server_addr, const char* worker_addr, const char* objstore_addr) { + auto server_channel = grpc::CreateChannel(server_addr, grpc::InsecureChannelCredentials()); + auto objstore_channel = grpc::CreateChannel(objstore_addr, grpc::InsecureChannelCredentials()); + Worker* worker = new Worker(server_channel, objstore_channel); + worker->register_worker(std::string(worker_addr), std::string(objstore_addr)); + return worker; +} + +size_t orch_remote_call(Worker* worker, RemoteCallRequest* request) { + return worker->RemoteCall(request); +} + +void orch_main_loop(Worker* worker) { + worker->MainLoop(); +} + +size_t orch_push(Worker* worker, Obj* obj) { + return worker->PushObj(obj); +} + +slice orch_get_serialized_obj(Worker* worker, ObjRef objref) { + return worker->GetSerializedObj(objref); +} + +void orch_register_function(Worker* worker, const char* name, size_t num_return_vals) { + // worker->register_function(std::string(name), num_return_vals); +} diff --git a/src/orchlib.h b/src/orchlib.h new file mode 100644 index 000000000..070890745 --- /dev/null +++ b/src/orchlib.h @@ -0,0 +1,20 @@ + + +extern "C" { + +struct slice { + char* data; + size_t len; +}; + +struct Worker; +struct RemoteCallRequest; +struct Value; + +Worker* orch_create_context(const char* server_addr, const char* worker_addr, const char* objstore_addr); +size_t orch_remote_call(Worker* context, RemoteCallRequest* request); +size_t orch_push(Worker* context, Obj* value); +void orch_main_loop(Worker* worker); +slice orch_get_serialized_obj(Worker* worker, size_t objref); +void orch_register_function(Worker* worker, const char* name, size_t num_return_vals); +} diff --git a/src/scheduler.cc b/src/scheduler.cc new file mode 100644 index 000000000..e69de29bb diff --git a/src/scheduler.h b/src/scheduler.h new file mode 100644 index 000000000..2ee32996a --- /dev/null +++ b/src/scheduler.h @@ -0,0 +1,145 @@ +#ifndef ORCHESTRA_SCHEDULER_H +#define ORCHESTRA_SCHEDULER_H + +#include + +#include + +#include "orchestra/orchestra.h" +#include "orchestra.grpc.pb.h" +#include "types.pb.h" + +using grpc::Server; +using grpc::ServerBuilder; +using grpc::ServerReader; +using grpc::ServerContext; +using grpc::Status; + +using grpc::Channel; + +struct WorkerHandle { + std::shared_ptr channel; + ObjStoreId objstoreid; +}; + +struct ObjStoreHandle { + std::shared_ptr channel; + std::string address; +}; + +class Scheduler { + // Vector of all workers registered in the system. Their index in this vector + // is the workerid. + std::vector workers_; + std::mutex workers_lock_; + // Vector of all workers that are currently idle. + std::vector available_workers_; + // Vector of all object stores registered in the system. Their index in this + // vector is the objstoreid. + std::vector objstores_; + grpc::mutex objstores_lock_; + // Mapping from objref to list of object stores where the object is stored. + ObjTable objtable_; + std::mutex objtable_lock_; + // Hash map from function names to workers where the function is registered. + FnTable fntable_; + std::mutex fntable_lock_; + // List of pending tasks. + std::deque > tasks_; + std::mutex tasks_lock_; +public: + // returns number of return values of task + size_t add_task(const Call& task) { + fntable_lock_.lock(); + size_t num_return_vals = 2; // fn_table_[task.name()].num_return_vals(); + fntable_lock_.unlock(); + // std::unique_ptr task_ptr(new Call(task)); // TODO: perform copy outside + tasks_lock_.lock(); + // tasks_.push_back(task_ptr); + tasks_lock_.unlock(); + return num_return_vals; + } + WorkerId register_worker(const std::string& worker_address, const std::string& objstore_address) { + ObjStoreId objstoreid = std::numeric_limits::max(); + objstores_lock_.lock(); + for (size_t i = 0; i < objstores_.size(); ++i) { + std::cout << "adress: " << objstores_[i].address << std::endl; + std::cout << "my adress: " << objstore_address << std::endl; + if (objstores_[i].address == objstore_address) { + objstoreid = i; + } + } + if (objstoreid == std::numeric_limits::max()) { + // throw objstore_not_registered_error("objectstore not registered"); + std::cout << "bad bad bad" << std::endl; + } + objstores_lock_.unlock(); + workers_lock_.lock(); + WorkerId result = workers_.size(); + workers_.push_back(WorkerHandle()); + workers_[result].channel = grpc::CreateChannel(worker_address, grpc::InsecureChannelCredentials()); + workers_[result].objstoreid = objstoreid; + workers_lock_.unlock(); + return result; + } + ObjStoreId register_objstore(const std::string& objstore_address) { + // auto handle = ObjStoreHandle(objstore_address); + // auto handlecopy = handle; + // auto handle = ObjStoreHandle("0.0.0.0:22222"); + objstores_lock_.lock(); + std::cout << "capacity" << objstores_.capacity() << std::endl; + ObjStoreId result = objstores_.size(); + // auto handle = ObjStoreHandle(objstore_address); + // objstores_.emplace_back(objstore_address); + objstores_.push_back(ObjStoreHandle()); + + objstores_[result].channel = grpc::CreateChannel(objstore_address, grpc::InsecureChannelCredentials()); + objstores_[result].address = std::string(objstore_address); + + // auto handlecopy = handle; + // auto handle = grpc::CreateChannel(objstore_address, grpc::InsecureChannelCredentials()); + // auto handlecopy = grpc::CreateChannel(objstore_address, grpc::InsecureChannelCredentials()); + objstores_lock_.unlock(); + return result; + } + ObjRef register_new_object() { + objtable_lock_.lock(); + ObjRef result = objtable_.size(); + objtable_.push_back(std::vector()); + objtable_lock_.unlock(); + return result; + } + void add_objstore_to_obj(ObjRef objref, ObjStoreId objstoreid) { + objtable_lock_.lock(); + // do a binary search + auto pos = std::lower_bound(objtable_[objref].begin(), objtable_[objref].end(), objstoreid); + if (pos == objtable_[objref].end() || objstoreid < *pos) { + objtable_[objref].insert(pos, objstoreid); + } + objtable_lock_.unlock(); + } + ObjStoreId get_store(WorkerId workerid) { + workers_lock_.lock(); + ObjStoreId result = workers_[workerid].objstoreid; + workers_lock_.unlock(); + return result; + } + void register_function(const std::string& name, WorkerId workerid, size_t num_return_vals) { + fntable_lock_.lock(); + FnInfo& info = fntable_[name]; + info.set_num_return_vals(num_return_vals); + info.add_worker(workerid); + fntable_lock_.unlock(); + } + /* + void debug_info(DebugInfoReply* debug_info) { + fntable_lock_.lock(); + for (const auto& entry : fntable_) { + debug_info-> + } + fntable_lock_.lock(); + } + */ +}; + +#endif diff --git a/src/scheduler_server.cc b/src/scheduler_server.cc new file mode 100644 index 000000000..2021db4a2 --- /dev/null +++ b/src/scheduler_server.cc @@ -0,0 +1,50 @@ +#include "scheduler_server.h" + +Status SchedulerServerServiceImpl::RemoteCall(ServerContext* context, const RemoteCallRequest* request, RemoteCallReply* reply) { + size_t num_return_vals = scheduler_->add_task(request->call()); + for (size_t i = 0; i < num_return_vals; ++i) { + ObjRef result = scheduler_->register_new_object(); + reply->add_result(result); + } + return Status::OK; +} + +Status SchedulerServerServiceImpl::PushObj(ServerContext* context, const PushObjRequest* request, PushObjReply* reply) { + ObjRef objref = scheduler_->register_new_object(); + ObjStoreId objstoreid = scheduler_->get_store(request->workerid()); + scheduler_->add_objstore_to_obj(objref, objstoreid); + reply->set_objref(objref); + return Status::OK; +} + + /* + Status PushObj(ServerContext* context, ServerReader *reader, AckReply* reply) override { + ObjChunk chunk; + while (reader->Read(&chunk)) { + + } + std::cout << "got chunks" << std::endl; + return Status::OK; + } + */ + +void start_scheduler_server(const char* server_address) { + SchedulerServerServiceImpl service; + ServerBuilder builder; + + builder.AddListeningPort(std::string(server_address), grpc::InsecureServerCredentials()); + builder.RegisterService(&service); + std::unique_ptr server(builder.BuildAndStart()); + + server->Wait(); +} + +int main(int argc, char** argv) { + if (argc != 2) { + return 1; + } + + start_scheduler_server(argv[1]); + + return 0; +} diff --git a/src/scheduler_server.h b/src/scheduler_server.h new file mode 100644 index 000000000..d6739a23a --- /dev/null +++ b/src/scheduler_server.h @@ -0,0 +1,45 @@ +#ifndef ORCHESTRA_SCHEDULER_SERVER_H +#define ORCHESTRA_SCHEDULER_SERVER_H + +#include +#include +#include +#include + +#include "scheduler.h" + + +class SchedulerServerServiceImpl final : public SchedulerServer::Service { + ObjTable objtable_; + std::unique_ptr scheduler_; +public: + SchedulerServerServiceImpl() : scheduler_(new Scheduler()) { + } + Status RemoteCall(ServerContext* context, const RemoteCallRequest* request, RemoteCallReply* reply) override; + Status PushObj(ServerContext* context, const PushObjRequest* request, PushObjReply* reply) override; + Status PullObj(ServerContext* context, const PullObjRequest* request, AckReply* reply) override { + return Status::OK; + } + Status RegisterWorker(ServerContext* context, const RegisterWorkerRequest* request, RegisterWorkerReply* reply) override { + WorkerId workerid = scheduler_->register_worker(request->worker_address(), request->objstore_address()); + reply->set_workerid(workerid); + return Status::OK; + } + Status RegisterObjStore(ServerContext* context, const RegisterObjStoreRequest* request, RegisterObjStoreReply* reply) override { + try { + reply->set_objstoreid(scheduler_->register_objstore(request->address())); + } catch (...) { + std::cout << "caught exception" << std::endl; + } + return Status::OK; + } + Status RegisterFunction(ServerContext* context, const RegisterFunctionRequest* request, AckReply* reply) override { + scheduler_->register_function(request->fnname(), request->workerid(), request->num_return_vals()); + return Status::OK; + } + Status GetDebugInfo(ServerContext* context, const GetDebugInfoRequest* request, GetDebugInfoReply* reply) override { + return Status::OK; + } +}; + +#endif diff --git a/src/server.cc b/src/server.cc deleted file mode 100644 index 4501748a7..000000000 --- a/src/server.cc +++ /dev/null @@ -1,88 +0,0 @@ -#include -#include -#include -#include - -#include - -#include "orchestra.grpc.pb.h" - -using grpc::Server; -using grpc::ServerBuilder; -using grpc::ServerContext; -using grpc::Status; - -typedef size_t ObjRef; -typedef size_t WorkerId; -typedef std::vector > ObjTable; - -class OrchestraScheduler { - -}; - -class OrchestraServer { - ObjTable objtable; - std::mutex mutex; -public: - ObjRef register_new_object() { - mutex.lock(); - ObjRef result = objtable.size(); - // std::cout << "size " << result << std::endl; - objtable.push_back(std::vector()); - mutex.unlock(); - return result; - } - void register_object(ObjRef objref, WorkerId workerid) { - mutex.lock(); - objtable[objref].push_back(workerid); - mutex.unlock(); - } -}; - -// Logic and data behind the server's behavior. -class OrchestraServiceImpl final : public Orchestra::Service { - ObjTable objtable; - std::unique_ptr server; -public: - OrchestraServiceImpl() : server(new OrchestraServer()) { - } - Status RemoteCall(ServerContext* context, const RemoteCallRequest* request, - RemoteCallReply* reply) override { - // std::cout << "called" << std::endl; - ObjRef objref = server->register_new_object(); - reply->set_result(objref); - // std::string prefix("Hello "); - // reply->set_message(prefix + request->name()); - return Status::OK; - } - Status RegisterWorker(ServerContext* context, const RegisterWorkerRequest* request, - RegisterWorkerReply* reply) override { - std::cout << "register worker" << std::endl; - return Status::OK; - } -}; - -void RunServer() { - std::string server_address("0.0.0.0:50052"); - OrchestraServiceImpl service; - - ServerBuilder builder; - // Listen on the given address without any authentication mechanism. - builder.AddListeningPort(server_address, grpc::InsecureServerCredentials()); - // Register "service" as the instance through which we'll communicate with - // clients. In this case it corresponds to an *synchronous* service. - builder.RegisterService(&service); - // Finally assemble the server. - std::unique_ptr server(builder.BuildAndStart()); - std::cout << "Server listening on " << server_address << std::endl; - - // Wait for the server to shutdown. Note that some other thread must be - // responsible for shutting down the server for this call to ever return. - server->Wait(); -} - -int main(int argc, char** argv) { - RunServer(); - - return 0; -} diff --git a/test/gen-python-code.sh b/test/gen-python-code.sh new file mode 100644 index 000000000..595dab90d --- /dev/null +++ b/test/gen-python-code.sh @@ -0,0 +1,4 @@ +# For running the python tests + +protoc -I ../protos/ --python_out=. --grpc_out=. --plugin=protoc-gen-grpc=`which grpc_python_plugin` ../protos/orchestra.proto +protoc -I ../protos/ --python_out=. --grpc_out=. --plugin=protoc-gen-grpc=`which grpc_python_plugin` ../protos/types.proto diff --git a/test/runtest.py b/test/runtest.py new file mode 100644 index 000000000..987aeba45 --- /dev/null +++ b/test/runtest.py @@ -0,0 +1,154 @@ +import unittest +import orchpy.unison as unison +import orchpy.services as services +import orchpy.worker as worker +import numpy as np +import time + +from grpc.beta import implementations +import orchestra_pb2 +import types_pb2 + +class UnisonTest(unittest.TestCase): + + def testSerialize(self): + d = [1, 2L, "hello", 3.0] + res = unison.serialize_args(d) + c = unison.deserialize_args(res) + self.assertEqual(c, d) + + d = [{'hello': 'world'}] + res = unison.serialize_args(d) + c = unison.deserialize_args(res) + self.assertEqual(c, d) + + a = np.zeros((100, 100)) + res = unison.serialize_args(a) + b = unison.deserialize_args(res) + self.assertTrue((a == b).all()) + +TIMEOUT_SECONDS = 5 + +def produce_data(num_chunks): + for _ in range(num_chunks): + yield orchestra_pb2.ObjChunk(objref=1, totalsize=1000, data=b"hello world") + +class ObjStoreTest(unittest.TestCase): + + """Test setting up object stores, transfering data between them and retrieving data to a client""" + def testObjStore(self): + services.start_scheduler("0.0.0.0:22221") + services.start_objstore("0.0.0.0:22222") + services.start_objstore("0.0.0.0:22223") + time.sleep(0.5) + + scheduler_channel = implementations.insecure_channel('localhost', 22221) + scheduler_stub = orchestra_pb2.beta_create_SchedulerServer_stub(scheduler_channel) + objstore1_channel = implementations.insecure_channel('localhost', 22222) + objstore1_stub = orchestra_pb2.beta_create_ObjStore_stub(objstore1_channel) + objstore2_channel = implementations.insecure_channel('localhost', 22223) + objstore2_stub = orchestra_pb2.beta_create_ObjStore_stub(objstore2_channel) + + scheduler_stub.RegisterObjStore(orchestra_pb2.RegisterObjStoreRequest(address="127.0.0.1:22222"), TIMEOUT_SECONDS) + scheduler_stub.RegisterObjStore(orchestra_pb2.RegisterObjStoreRequest(address="127.0.0.1:22223"), TIMEOUT_SECONDS) + + worker.global_worker.connect("127.0.0.1:22221", "127.0.0.1:40000", "127.0.0.1:22222") + + other_worker = worker.Worker() + other_worker.connect("127.0.0.1:22221", "127.0.0.1:40001", "127.0.0.1:22223") + + # import IPython + # IPython.embed() + + for i in range(1, 10): + l = i * 100 * "h" + objref = worker.global_worker.do_push(l) + # time.sleep(5.0) + response = objstore1_stub.DeliverObj(orchestra_pb2.DeliverObjRequest(objref=objref, objstore_address="0.0.0.0:22223"), TIMEOUT_SECONDS) + # time.sleep(5.0) + str = other_worker.get_serialized(objref) + result = worker.unison.deserialize_from_string(str) + # import IPython + # IPython.embed() + self.assertEqual(len(result), 100 * i) + +class SchedulerTest(unittest.TestCase): + + def testRegister(self): + scheduler_channel = implementations.insecure_channel('localhost', 22221) + scheduler_stub = orchestra_pb2.beta_create_SchedulerServer_stub(scheduler_channel) + w = worker.Worker() + w.connect("127.0.0.1:22221", "127.0.0.1:40002", "127.0.0.1:22222") + w.register_function("hello_world", 2) + + +""" +class SchedulerTest(unittest.TestCase): + + def testServer(self): + services.start_scheduler("0.0.0.0:22221") + services.start_objstore("0.0.0.0:22222") + services.start_objstore("0.0.0.0:22223") + time.sleep(1.0) + + scheduler_channel = implementations.insecure_channel('localhost', 22221) + scheduler_stub = orchestra_pb2.beta_create_SchedulerServer_stub(scheduler_channel) + objstore_channel = implementations.insecure_channel('localhost', 22222) + objstore_stub = orchestra_pb2.beta_create_ObjStore_stub(objstore_channel) + objstore_channel2 = implementations.insecure_channel('localhost', 22223) + objstore_stub2 = orchestra_pb2.beta_create_ObjStore_stub(objstore_channel2) + + # call = types_pb2.Call(name="test") + # response = scheduler_stub.RemoteCall(orchestra_pb2.RemoteCallRequest(call=call), TIMEOUT_SECONDS) + # response = scheduler_stub.RegisterFunction(orchestra_pb2.RegisterFunctionRequest(workerid=1, fnname="hello"), TIMEOUT_SECONDS) + + response2 = scheduler_stub.RegisterObjStore(orchestra_pb2.RegisterObjStoreRequest(address="127.0.0.1:22222"), TIMEOUT_SECONDS) + response2 = scheduler_stub.RegisterObjStore(orchestra_pb2.RegisterObjStoreRequest(address="127.0.0.1:22223"), TIMEOUT_SECONDS) + + # response2 = scheduler_stub.RegisterObjStore(orchestra_pb2.RegisterObjStoreRequest(address="127.0.0.1:22222"), TIMEOUT_SECONDS) + # response3 = scheduler_stub.RegisterObjStore(orchestra_pb2.RegisterObjStoreRequest(address="127.0.0.1:22223"), TIMEOUT_SECONDS) + + # objstore_stub.StreamObj(produce_data(100), TIMEOUT_SECONDS) + + worker.global_worker.connect("127.0.0.1:22221", "127.0.0.1:40000", "127.0.0.1:22222") + + l = [1, 2, 3, 4] + worker.global_worker.do_push(l) + + ## res = scheduler_stub.PushObj(orchestra_pb2.PushObjRequest(workerid=0), TIMEOUT_SECONDS) + + response = objstore_stub.DeliverObj(orchestra_pb2.DeliverObjRequest(objref=0, objstore_address="0.0.0.0:22223"), TIMEOUT_SECONDS) + + # res = objstore_stub2.DebugInfo(orchestra_pb2.DebugInfoRequest(), TIMEOUT_SECONDS) + + response = objstore_stub.GetObj(orchestra_pb2.GetObjRequest(objref=0), TIMEOUT_SECONDS) + + worker.global_worker.get_serialized(0) + + import IPython + IPython.embed() + + l = [1, 2, 3, 4] + worker.global_worker.do_push(l) + + response = objstore_stub.DeliverObj(orchestra_pb2.DeliverObjRequest(), TIMEOUT_SECONDS) + + # response = objstore_stub.DebugInfo(orchestra_pb2.DebugInfoRequest(), TIMEOUT_SECONDS) + + # import IPython + # IPython.embed() + + # worker.global_worker.connect("127.0.0.1:22221", "127.0.0.1:22222") + # l = [1, 2, 3, 4] + # worker.global_worker.do_push(l) + + # import IPython + # IPython.embed() + # response = objstore_stub.DeliverObj(orchestra_pb2.DeliverObjRequest()) + # print "Greeter client received: " + response.message + # import IPython + # IPython.embed() +""" + +if __name__ == '__main__': + unittest.main()