diff --git a/CMakeLists.txt b/CMakeLists.txt index af3572aee..0650f2a8d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -2,12 +2,19 @@ cmake_minimum_required(VERSION 2.8) project(orchestra) +list(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake/Modules) + find_package(Protobuf REQUIRED) +find_package(PythonInterp REQUIRED) +find_package(PythonLibs REQUIRED) +find_package(NumPy REQUIRED) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11") include_directories("${CMAKE_SOURCE_DIR}/include") include_directories("/usr/local/include") +include_directories("${PYTHON_INCLUDE_DIRS}") +include_directories("${NUMPY_INCLUDE_DIR}") set(PROTO_PATH "${CMAKE_SOURCE_DIR}/protos") @@ -71,6 +78,6 @@ endif() add_executable(objstore src/objstore.cc ${GENERATED_PROTOBUF_FILES}) add_executable(scheduler src/scheduler.cc ${GENERATED_PROTOBUF_FILES}) -add_library(orchlib SHARED src/orchlib.cc src/worker.cc ${GENERATED_PROTOBUF_FILES}) +add_library(orchpylib SHARED src/orchpylib.cc src/worker.cc ${GENERATED_PROTOBUF_FILES}) -install(TARGETS objstore scheduler orchlib DESTINATION ${CMAKE_SOURCE_DIR}/lib/orchpy/orchpy) +install(TARGETS objstore scheduler orchpylib DESTINATION ${CMAKE_SOURCE_DIR}/lib/orchpy/orchpy) diff --git a/include/orchestra/orchestra.h b/include/orchestra/orchestra.h index cb0b96341..f325a0c8d 100644 --- a/include/orchestra/orchestra.h +++ b/include/orchestra/orchestra.h @@ -41,10 +41,9 @@ public: objstore_not_registered_error(const std::string& msg) : std::runtime_error(msg) {} }; - -// struct slice { -// char* data; -// size_t len; -// }; +struct slice { + char* data; + size_t len; +}; #endif diff --git a/lib/orchpy/orchpy/__init__.py b/lib/orchpy/orchpy/__init__.py index cdf0cd44e..eb1ef5f1a 100644 --- a/lib/orchpy/orchpy/__init__.py +++ b/lib/orchpy/orchpy/__init__.py @@ -1,8 +1 @@ -import os, sys, ctypes - -MACOSX = (sys.platform in ['darwin']) - -_orchlib_handle = ctypes.CDLL( - os.path.join(os.path.dirname(os.path.abspath(__file__)), 'liborchlib.dylib' if MACOSX else 'liborchlib.so'), - ctypes.RTLD_GLOBAL -) +import liborchpylib as lib diff --git a/lib/orchpy/orchpy/services.py b/lib/orchpy/orchpy/services.py index 5465989d2..d4fb6a126 100644 --- a/lib/orchpy/orchpy/services.py +++ b/lib/orchpy/orchpy/services.py @@ -41,5 +41,10 @@ def start_objstore(host, port): all_processes.append((p, port)) def start_worker(test_path, host, scheduler_port, worker_port, objstore_port): - p = subprocess.Popen(["python", test_path, host, str(scheduler_port), str(worker_port), str(objstore_port)]) + p = subprocess.Popen(["python", + test_path, + "--ip_address=" + host, + "--scheduler_port=" + str(scheduler_port), + "--objstore_port=" + str(objstore_port), + "--worker_port=" + str(worker_port)]) all_processes.append((p, worker_port)) diff --git a/lib/orchpy/orchpy/unison.pyx b/lib/orchpy/orchpy/unison.pyx deleted file mode 100644 index 75b8ffaaf..000000000 --- a/lib/orchpy/orchpy/unison.pyx +++ /dev/null @@ -1,202 +0,0 @@ -# Will be rewritten in C++ for easier deployment once the API is stabilized - -from libc.stdint cimport uint64_t, int64_t, uintptr_t -from libcpp cimport bool -from libcpp.string cimport string -import numpy as np - -try: - import cPickle as pickle -except: - import pickle - -cdef extern from "../../../build/generated/types.pb.h": - - cdef cppclass Call: - Value* add_arg(); - void set_name(const char* value) - Value* mutable_arg(int index); - int arg_size() const; - -cdef extern from "../../../build/generated/types.pb.h": - ctypedef enum DataType: - INT32 - INT64 - FLOAT32 - FLOAT64 - - cdef cppclass Value: - Value() - void set_ref(uint64_t value) - uint64_t ref() - bool has_obj() - Obj* mutable_obj() - - cdef cppclass String: - String() - void set_data(const char* val) - string* mutable_data() - - cdef cppclass Int: - Int() - void set_data(int64_t val) - int64_t data() - - cdef cppclass Double: - Double() - void set_data(double val) - double data() - - cdef cppclass PyObj: - PyObj() - void set_data(const char* val, size_t len) - string* mutable_data() - - cdef cppclass Obj: - Obj() - String* mutable_string_data() - Int* mutable_int_data() - Double* mutable_double_data() - PyObj* mutable_pyobj_data() - bool has_string_data() - bool has_int_data() - bool has_double_data() - bool ParseFromString(const string& data) - -cdef class PyValue: # TODO: unify with the below - cdef Value *thisptr - def __cinit__(self): - self.thisptr = new Value() - def __dealloc__(self): - del self.thisptr - def get_value(self): - return self.thisptr - -cdef class ObjWrapper: # TODO: unify with the above - cdef Obj *thisptr - def __cinit__(self): - self.thisptr = new Obj() - # def __dealloc__(self): - # del self.thisptr - def get_value(self): - return self.thisptr - -cdef class PythonCall: - cdef Call* thisptr - def __cinit__(self): - self.thisptr = new Call() - def __dealloc__(self): - del self.thisptr - def get_value(self): - return self.thisptr - -cdef class ObjRef: - cdef size_t _id - cdef object type - - def __cinit__(self, id, type): - self._id = id - - def __init__(self, id, type): - self.type = type - - def __richcmp__(self, other, int op): - if op == 2: - return self.get_id() == other.get_id() - else: - raise NotImplementedError("operator not implemented") - - cpdef get_id(self): - return self._id - -cpdef serialize_into(val, objptr): - cdef uintptr_t ptr = objptr - cdef Obj* obj = ptr - cdef String* string_data - cdef Int* int_data - cdef Double* double_data - if type(val) == str: - string_data = obj[0].mutable_string_data() - string_data[0].set_data(val) - elif type(val) == int or type(val) == long: - int_data = obj[0].mutable_int_data() - int_data[0].set_data(val) - elif type(val) == float: - double_data = obj[0].mutable_double_data() - double_data[0].set_data(val) - else: - data = pickle.dumps(val, pickle.HIGHEST_PROTOCOL) - pyobj_data = obj[0].mutable_pyobj_data() - pyobj_data[0].set_data(data, len(data)) - -cpdef serialize(val): - result = ObjWrapper() - serialize_into(val, result.get_value()) - return result - -cdef deserialize_from(Obj* obj): - if obj[0].has_string_data(): - return obj[0].mutable_string_data()[0].mutable_data()[0] - elif obj[0].has_int_data(): - return obj[0].mutable_int_data()[0].data() - elif obj[0].has_double_data(): - return obj[0].mutable_double_data()[0].data() - else: - data = obj[0].mutable_pyobj_data()[0].mutable_data()[0] - return pickle.loads(data) - -cpdef deserialize_from_string(str): - cdef string s = str - cdef Obj* obj = new Obj() # TODO: memory leak - obj[0].ParseFromString(s) - return deserialize_from(obj) - -# cpdef deserialize(str): -# cdef string s = string(str) -# return deserialize_from(obj.get_value()) - -# todo: unify with the above, at the moment this is copied -cdef deserialize_args_from_call(Call* call): - cdef Value* val - cdef Obj* obj - result = [] - for i in range(call[0].arg_size()): - val = call[0].mutable_arg(i) - if not val.has_obj(): - result.append(ObjRef(val.ref(), None)) # TODO: fix this - else: - obj = val[0].mutable_obj() - if obj[0].has_string_data(): - result.append(obj[0].mutable_string_data()[0].mutable_data()[0]) - elif obj[0].has_int_data(): - result.append(obj[0].mutable_int_data()[0].data()) - elif obj[0].has_double_data(): - result.append(obj[0].mutable_double_data()[0].data()) - else: - data = obj[0].mutable_pyobj_data()[0].mutable_data()[0] - result.append(pickle.loads(data)) - return result - -cpdef deserialize_call(PythonCall pycall): - cdef Call* call = pycall.thisptr - return deserialize_args_from_call(call) - - -cdef int numpy_dtype_to_proto(dtype): - if dtype == np.dtype('int32'): - return INT32 - if dtype == np.dtype('int64'): - return INT64 - if dtype == np.dtype('float32'): - return FLOAT32 - if dtype == np.dtype('float64'): - return FLOAT64 - -""" -cdef Value* ndarray_to_proto(array): - result = PyValue() - result.shape.extend(array.shape) - result.data = np.getbuffer(array, 0, array.size * array.dtype.itemsize) - result.dtype = numpy_dtype_to_proto(array.dtype) - return result -""" diff --git a/lib/orchpy/orchpy/worker.py b/lib/orchpy/orchpy/worker.py new file mode 100644 index 000000000..71c532170 --- /dev/null +++ b/lib/orchpy/orchpy/worker.py @@ -0,0 +1,122 @@ +import typing + +import orchpy + +class Worker(object): + """The methods in this class are considered unexposed to the user. The functions outside of this class are considered exposed.""" + + def __init__(self): + self.functions = {} + self.connected = False + self.handle = None + + def put_object(self, objref, value): + """Put `value` in the local object store with objref `objref`. This assumes that the value for `objref` has not yet been placed in the local object store.""" + object_capsule = orchpy.lib.serialize_object(value) + orchpy.lib.put_object(self.handle, objref, object_capsule) + + def get_object(self, objref): + """Return the value from the local object store for objref `objref`. This will block until the value for `objref` has been written to the local object store.""" + object_capsule = orchpy.lib.get_object(self.handle, objref) + return orchpy.lib.deserialize_object(object_capsule) + + def register_function(self, function): + """Notify the scheduler that this worker can execute the function with name `func_name`. Store the function `function` locally.""" + orchpy.lib.register_function(self.handle, function.func_name, len(function.return_types)) + self.functions[function.func_name] = function + + def remote_call(self, func_name, args): + """Tell the scheduler to schedule the execution of the function with name `func_name` with arguments `args`. Retrieve object references for the outputs of the function from the scheduler and immediately return them.""" + call_capsule = orchpy.lib.serialize_call(func_name, args) + return orchpy.lib.remote_call(self.handle, call_capsule) + +# We make `global_worker` a global variable so that there is one worker per worker process. +global_worker = Worker() + +def connect(scheduler_addr, objstore_addr, worker_addr, worker=global_worker): + if worker.connected: + raise Exception("Worker called connect, but worker is already connected") + worker.handle = orchpy.lib.create_worker(scheduler_addr, objstore_addr, worker_addr) + worker.connected = True + +def pull(objref, worker=global_worker): + object_capsule = orchpy.lib.pull_object(worker.handle, objref) + return orchpy.lib.deserialize_object(object_capsule) + +def push(value, worker=global_worker): + object_capsule = orchpy.lib.serialize_object(value) + return orchpy.lib.push_object(worker.handle, object_capsule) + +def main_loop(worker=global_worker): + if not worker.connected: + raise Exception("Worker is attempting to enter main_loop but has not been connected yet.") + orchpy.lib.start_worker_service(worker.handle) + while True: + call = orchpy.lib.wait_for_next_task(worker.handle) + func_name, args, return_objrefs = orchpy.lib.deserialize_call(call) + arguments = get_arguments_for_execution(worker.functions[func_name], args, worker) # get args from objstore + outputs = worker.functions[func_name].executor(arguments) # execute the function + store_outputs_in_objstore(return_objrefs, outputs, worker) # store output in local object store + # TODO(rkn): notify the scheduler that the task has completed, orchpy.lib.notify_task_completed(worker.handle) + +def distributed(arg_types, return_types, worker=global_worker): + def distributed_decorator(func): + def func_executor(arguments): + """This is what gets executed remotely on a worker after a distributed function is scheduled by the scheduler.""" + print "Calling function {} with arguments {}".format(func.__name__, arguments) + result = func(*arguments) + if len(return_types) != 1 and len(result) != len(return_types): + raise Exception("The @distributed decorator for function {} has {} return values with types {}, but {} returned {} values.".format(func.__name__, len(return_types), return_types, func.__name__, len(result))) + return result + def func_call(*args): + """This is what gets run immediately when a worker calls a distributed function.""" + # TODO(rkn): check types + return worker.remote_call(func_call.func_name, list(args)) + func_call.func_name = "{}.{}".format(func.__module__, func.__name__) + func_call.executor = func_executor + func_call.arg_types = arg_types + func_call.return_types = return_types + return func_call + return distributed_decorator + +# helper method, this should not be called by the user +def get_arguments_for_execution(function, args, worker=global_worker): + arguments = [] + # check the number of args + if len(args) != len(function.types) and function.types[-1] is not None: + raise Exception("Function {} expects {} arguments, but received {}.".format(function.__name__, len(function.types), len(args))) + elif len(args) < len(function.types) - 1 and function.types[-1] is None: + raise Exception("Function {} expects at least {} arguments, but received {}.".format(function.__name__, len(function.types) - 1, len(args))) + + for (i, arg) in enumerate(args): + print "Pulling argument {} for function {}.".format(i, function.__name__) + if i < len(function.types) - 1: + expected_type = function.types[i] + elif i == len(function.types) - 1 and function.types[-1] is not None: + expected_type = function.types[-1] + elif function.types[-1] is None and len(function.types > 1): + expected_type = function.types[-2] + else: + assert False, "This code should be unreachable." + + argument = worker.get_object(arg) if type(arg) == orchpy.ObjRef else arg + if type(arg) == orchpy.ObjRef: + # get the object from the local object store + # TODO(rkn): Do we know that it is already there? Maybe we should call pull(arg, worker). + argument = worker.get_object(arg) + else: + # pass the argument by value + argument = arg + + if expected_type != type(argument): + raise Exception("Argument {} for function {} has type {} but an argument of type {} was expected.".format(i, function.__name__, type(argument), arg_type)) + arguments.append(argument) + return arguments + +# helper method, this should not be called by the user +def store_outputs_in_objstore(objrefs, outputs, worker=global_worker): + if len(objrefs) == 1: + worker.put_object(objrefs[0], outputs) + else: + for i in range(len(objrefs)): + worker.put_object(objrefs[i], outputs[i]) diff --git a/lib/orchpy/orchpy/worker.pyx b/lib/orchpy/orchpy/worker.pyx deleted file mode 100644 index 146cabf8c..000000000 --- a/lib/orchpy/orchpy/worker.pyx +++ /dev/null @@ -1,342 +0,0 @@ -from libc.stdint cimport uintptr_t -import orchpy.unison as unison - -from libc.stdint cimport uint64_t, int64_t, uintptr_t -from libcpp cimport bool -from libcpp.string cimport string - -try: - import cPickle as pickle -except: - import pickle - -cdef struct Slice: - char* ptr - size_t size - -cdef extern void* orch_create_context(const char* server_addr, const char* worker_addr, const char* objstore_addr) -cdef extern void orch_start_worker_service(void* worker) -cdef extern void* orch_wait_for_next_task(void* worker) -cdef extern void orch_register_function(void* worker, const char* name, size_t num_return_vals) -cdef extern size_t orch_remote_call(void* worker, void* request) -cdef extern size_t orch_push(void* worker, void* value) -cdef extern Slice orch_get_serialized_obj(void* worker, size_t objref) -cdef extern void orch_put_obj(void* worker, size_t objref, void* obj) - -cdef extern from "Python.h": - Py_ssize_t PyByteArray_GET_SIZE(object array) - object PyUnicode_FromStringAndSize(char *buff, Py_ssize_t len) - object PyBytes_FromStringAndSize(char *buff, Py_ssize_t len) - object PyString_FromStringAndSize(char *buff, Py_ssize_t len) - int PyByteArray_Resize(object self, Py_ssize_t size) except -1 - char* PyByteArray_AS_STRING(object bytearray) - -cdef extern from "../../../build/generated/orchestra.pb.h": - cdef cppclass RemoteCallRequest: - RemoteCallRequest() - Call* mutable_call() - int arg_size() const; - -cdef extern from "../../../build/generated/types.pb.h": - cdef cppclass Values - - cdef cppclass Call: - Value* add_arg(); - void set_name(const char* value) - const string& name() - Value* mutable_arg(int index); - size_t result(int index); - int result_size(); - int arg_size() const; - - ctypedef enum DataType: - INT32 - INT64 - FLOAT32 - FLOAT64 - - cdef cppclass Value: - Value() - void set_ref(uint64_t value) - uint64_t ref() - bool has_obj() - Obj* mutable_obj() - - cdef cppclass Values: - Values() - int value_size() - Value* add_value() - Value* mutable_value(int index) - - - cdef cppclass String: - String() - void set_data(const char* val) - string* mutable_data() - - cdef cppclass Int: - Int() - void set_data(int64_t val) - int64_t data() - - cdef cppclass Double: - Double() - void set_data(double val) - double data() - - cdef cppclass PyObj: - PyObj() - void set_data(const char* val, size_t len) - string* mutable_data() - - cdef cppclass Obj: - Obj() - String* mutable_string_data() - Int* mutable_int_data() - Double* mutable_double_data() - PyObj* mutable_pyobj_data() - bool has_string_data() - bool has_int_data() - bool has_double_data() - -cdef serialize_into(val, Obj* obj): - cdef String* string_data - cdef Int* int_data - cdef Double* double_data - if type(val) == str: - string_data = obj[0].mutable_string_data() - string_data[0].set_data(val) - elif type(val) == int or type(val) == long: - int_data = obj[0].mutable_int_data() - int_data[0].set_data(val) - elif type(val) == float: - double_data = obj[0].mutable_double_data() - double_data[0].set_data(val) - # else: - # data = pickle.dumps(val, pickle.HIGHEST_PROTOCOL) - # pyobj_data = obj[0].mutable_pyobj_data() - # pyobj_data[0].set_data(data, len(data)) - -""" -cpdef deserialize_call(PyValues args): - cdef Values* vals = args.thisptr - cdef Value* val - cdef Obj* obj - result = [] - for i in range(vals[0].value_size()): - val = vals[0].mutable_value(i) - if not val.has_obj(): - result.append(ObjRef(val.ref(), None)) # TODO: fix this - else: - obj = val[0].mutable_obj() - if obj[0].has_string_data(): - result.append(obj[0].mutable_string_data()[0].mutable_data()[0]) - elif obj[0].has_int_data(): - result.append(obj[0].mutable_int_data()[0].data()) - elif obj[0].has_double_data(): - result.append(obj[0].mutable_double_data()[0].data()) - else: - data = obj[0].mutable_pyobj_data()[0].mutable_data()[0] - result.append(pickle.loads(data)) - return result -""" - -cdef class ObjWrapper: # TODO: unify with the above - cdef Obj *thisptr - def __cinit__(self): - self.thisptr = new Obj() - # def __dealloc__(self): - # del self.thisptr - def get_value(self): - return self.thisptr - -cpdef serialize_into_2(val, objptr): - cdef uintptr_t ptr = objptr - cdef Obj* obj = ptr - cdef String* string_data - cdef Int* int_data - cdef Double* double_data - if type(val) == str: - string_data = obj[0].mutable_string_data() - string_data[0].set_data(val) - elif type(val) == int or type(val) == long: - int_data = obj[0].mutable_int_data() - int_data[0].set_data(val) - elif type(val) == float: - double_data = obj[0].mutable_double_data() - double_data[0].set_data(val) - # else: - # data = pickle.dumps(val, pickle.HIGHEST_PROTOCOL) - # pyobj_data = obj[0].mutable_pyobj_data() - # pyobj_data[0].set_data(data, len(data)) - -cdef serialize_args_into_call(args, Call* call): - cdef Value* val - cdef Obj* obj - for arg in args: - val = call.add_arg() - if type(arg) == unison.ObjRef: - val[0].set_ref(arg.get_id()) - else: - obj = val[0].mutable_obj() - objptr = obj - unison.serialize_into(arg, objptr) - -cdef deserialize_obj(Obj* obj): - if obj[0].has_string_data(): - return obj[0].mutable_string_data()[0].mutable_data()[0] - elif obj[0].has_int_data(): - return obj[0].mutable_int_data()[0].data() - elif obj[0].has_double_data(): - return obj[0].mutable_double_data()[0].data() - else: - data = obj[0].mutable_pyobj_data()[0].mutable_data()[0] - return pickle.loads(data) - -# todo: unify with the above, at the moment this is copied -cdef deserialize_args_from_call(Call* call): - cdef Value* val - cdef Obj* obj - result = [] - for i in range(call[0].arg_size()): - val = call[0].mutable_arg(i) - if not val.has_obj(): - result.append(unison.ObjRef(val.ref(), None)) # TODO: fix this - else: - obj = val[0].mutable_obj() - if obj[0].has_string_data(): - result.append(obj[0].mutable_string_data()[0].mutable_data()[0]) - elif obj[0].has_int_data(): - result.append(obj[0].mutable_int_data()[0].data()) - elif obj[0].has_double_data(): - result.append(obj[0].mutable_double_data()[0].data()) - else: - data = obj[0].mutable_pyobj_data()[0].mutable_data()[0] - result.append(pickle.loads(data)) - return result - -cdef class Worker: - cdef void* context - cdef dict functions - - def __cinit__(self): - self.context = NULL - self.functions = {} - - def connect(self, server_addr, worker_addr, objstore_addr): - self.context = orch_create_context(server_addr, worker_addr, objstore_addr) - - def start_worker_service(self): - orch_start_worker_service(self.context) - - cpdef call(self, name, args): - cdef RemoteCallRequest* result = new RemoteCallRequest() - cdef Call* call = result[0].mutable_call() - call.set_name(name) - serialize_args_into_call(args, call) - orch_remote_call(self.context, result) - # return result - - cpdef do_call(self, ptr): - return orch_remote_call(self.context, ptr) - - cpdef push(self, val): - result = unison.serialize(val) - ptr = result.get_value() - return unison.ObjRef(orch_push(self.context, ptr), None) - - cpdef get_serialized(self, objref): - cdef Slice slice = orch_get_serialized_obj(self.context, objref) - data = PyBytes_FromStringAndSize(slice.ptr, slice.size) - return data - - cpdef put_obj(self, objref, obj): - result = unison.serialize(obj) - p = result.get_value() - cdef void* ptr = p - print "before put" - orch_put_obj(self.context, objref, ptr) - print "after put" - - cpdef do_pull(self, objref): - cdef Slice slice = orch_get_serialized_obj(self.context, objref) - - cpdef pull(self, objref): - print "before get_serialized_obj, getting", objref.get_id() - cdef Slice slice = orch_get_serialized_obj(self.context, objref.get_id()) - print "after get_serialized_ob" - data = PyBytes_FromStringAndSize(slice.ptr, slice.size) - print "after get data" - return unison.deserialize_from_string(data) - - cpdef register_function(self, func_name, function, num_return_vals): - orch_register_function(self.context, func_name, num_return_vals) - self.functions[func_name] = function - - cpdef wait_for_next_task(self): - result = [] - cdef Call* call = orch_wait_for_next_task(self.context) - cdef int size = call[0].arg_size() - cdef Obj* obj - print "hello from python" - print "size", size - args = deserialize_args_from_call(call) - print "done deserializing" - returnrefs = [] - for i in range(call[0].result_size()): - returnrefs.append(call[0].result(i)) - return call[0].name(), args, returnrefs - - cpdef invoke_function(self, name, args): - return self.functions[name].executor(args) - - cpdef main_loop(self): - while True: - name, args, returnrefs = self.wait_for_next_task() - print "got returnref", returnrefs - result = self.functions[name].executor(args, returnrefs) - if len(returnrefs) == 1: - self.put_obj(returnrefs[0], result) - else: - for i in range(len(returnrefs)): - self.put_obj(returnrefs[i], result[i]) - -global_worker = Worker() - -def distributed(types, return_types, worker=global_worker): - def distributed_decorator(func): - # deserialize arguments, execute function and return results - def func_executor(args, returnrefs): - arguments = [] - for (i, arg) in enumerate(args): - print "pulling argument", i - if type(arg) == unison.ObjRef: - if i < len(types) - 1: - arguments.append(worker.pull(arg)) - elif i == len(types) - 1 and types[-1] is not None: - arguments.append(worker.pull(arg)) - elif types[-1] is None: - arguments.append(worker.pull(arg)) - else: - raise Exception("Passed in " + str(len(args)) + " arguments to function " + func.__name__ + ", which takes only " + str(len(types)) + " arguments.") - else: - arguments.append(arg) - # print "done pulling argument", i - # TODO - # buf = bytearray() - print "called with arguments", arguments - result = func(*arguments) - # check number of return values and return types - if len(return_types) != 1 and len(result) != len(return_types): - raise Exception("The @distributed decorator for function " + func.__name__ + " has " + str(len(return_types)) + " return values with types " + str(return_types) + " but " + func.__name__ + " returned " + str(len(result)) + " values.") - return result - # for remotely executing the function - def func_call(*args, typecheck=False): - return worker.call(func_call.func_name, func_call.module_name, args) - func_call.func_name = func.__name__.encode() # why do we call encode()? - func_call.module_name = func.__module__.encode() # why do we call encode()? - func_call.is_distributed = True - func_call.executor = func_executor - func_call.types = types - return func_call - return distributed_decorator diff --git a/lib/orchpy/setup.py b/lib/orchpy/setup.py index 231eea778..5b3511fb1 100644 --- a/lib/orchpy/setup.py +++ b/lib/orchpy/setup.py @@ -1,6 +1,7 @@ import sys from setuptools import setup, Extension, find_packages +import setuptools from Cython.Build import cythonize # because of relative paths, this must be run from inside orch/lib/orchpy/ @@ -10,28 +11,12 @@ MACOSX = (sys.platform in ['darwin']) setup( name = "orchestra", version = "0.1.dev0", - ext_modules = cythonize([ - Extension("orchpy/worker", - include_dirs = ["../../src", "/usr/local/include/"], - sources = ["orchpy/worker.pyx"], - extra_link_args=["-Iorchpy -lorchlib"], - language = "c++"), - Extension("orchpy/unison", - include_dirs = ["../../src/", "/usr/local/include/"], - sources = ["orchpy/unison.pyx"], - extra_link_args=["-Iorchpy -lorchlib"], - language = "c++")], - compiler_directives={'language_level': 2}), # switch to 3 for python 3 use_2to3=True, packages=find_packages(), package_data = { - 'orchpy': ['liborchlib.dylib' if MACOSX else 'liborchlib.so', + 'orchpy': ['liborchpylib.dylib' if MACOSX else 'liborchpylib.so', 'scheduler', 'objstore'] }, zip_safe=False ) - -extension_mod = Extension("symphony", ["orchpy/symphony.cpp"], include_dirs=["../../build/generated/"]) - -setup(name = "symphony", ext_modules=[extension_mod]) diff --git a/protos/orchestra.proto b/protos/orchestra.proto index e55f43f87..0df17ec0f 100644 --- a/protos/orchestra.proto +++ b/protos/orchestra.proto @@ -38,7 +38,8 @@ message RemoteCallReply { } message PullObjRequest { - uint64 objref = 1; + uint64 workerid = 1; + uint64 objref = 2; } message PushObjRequest { diff --git a/protos/types.proto b/protos/types.proto index 4c39158cc..98340a0a4 100644 --- a/protos/types.proto +++ b/protos/types.proto @@ -21,11 +21,13 @@ message Obj { String string_data = 1; Int int_data = 2; Double double_data = 3; - PyObj pyobj_data = 4; + List list_data = 4; + Array array_data = 5; + PyObj pyobj_data = 6; } message List { - repeated Obj elems = 1; + repeated Obj elem = 1; } message Value { @@ -48,6 +50,6 @@ enum DataType { message Array { repeated uint64 shape = 1; - DataType dtype = 2; - bytes data = 3; + DataType dtype = 3; + repeated double double_data = 2; } diff --git a/src/objstore.h b/src/objstore.h index e11ea9cd0..feed29cff 100644 --- a/src/objstore.h +++ b/src/objstore.h @@ -13,8 +13,6 @@ using namespace boost::interprocess; #include "orchestra.grpc.pb.h" #include "types.pb.h" -#include "orchlib.h" - using grpc::Server; using grpc::ServerBuilder; using grpc::ServerReader; diff --git a/src/orchlib.cc b/src/orchlib.cc deleted file mode 100644 index ef600a2bf..000000000 --- a/src/orchlib.cc +++ /dev/null @@ -1,37 +0,0 @@ -#include "worker.h" - -Worker* orch_create_context(const char* server_addr, const char* worker_addr, const char* objstore_addr) { - auto server_channel = grpc::CreateChannel(server_addr, grpc::InsecureChannelCredentials()); - auto objstore_channel = grpc::CreateChannel(objstore_addr, grpc::InsecureChannelCredentials()); - Worker* worker = new Worker(std::string(worker_addr), server_channel, objstore_channel); - worker->register_worker(std::string(worker_addr), std::string(objstore_addr)); - return worker; -} - -size_t orch_remote_call(Worker* worker, RemoteCallRequest* request) { - return worker->remote_call(request); -} - -void orch_start_worker_service(Worker* worker) { - worker->start_worker_service(); -} - -Call* orch_wait_for_next_task(Worker* worker) { - return worker->receive_next_task(); -} - -size_t orch_push(Worker* worker, Obj* obj) { - return worker->push_obj(obj); -} - -slice orch_get_serialized_obj(Worker* worker, ObjRef objref) { - return worker->get_serialized_obj(objref); -} - -void orch_put_obj(Worker* worker, size_t objref, const Obj* obj) { - worker->put_obj(objref, obj); -} - -void orch_register_function(Worker* worker, const char* name, size_t num_return_vals) { - worker->register_function(std::string(name), num_return_vals); -} diff --git a/src/orchlib.h b/src/orchlib.h deleted file mode 100644 index bac8e5ca7..000000000 --- a/src/orchlib.h +++ /dev/null @@ -1,27 +0,0 @@ -// A minimal C API that is used for implementing Orchestra workers in C based -// languages (Python at the moment, in the future potentially Julia, R, MATLAB) - -extern "C" { - -struct slice { - char* data; - size_t len; -}; - -struct Worker; -struct RemoteCallRequest; -struct Value; - -// connect to the scheduler and the object store -Worker* orch_create_context(const char* server_addr, const char* worker_addr, const char* objstore_addr); -// start the worker service for this worker -void orch_start_worker_service(Worker* worker); -// Submit a function call to the scheduler -size_t orch_remote_call(Worker* worker, RemoteCallRequest* request); -size_t orch_push(Worker* worker, Obj* value); -Call* orch_wait_for_next_task(Worker* worker); -slice orch_get_serialized_obj(Worker* worker, size_t objref); -void orch_put_obj(Worker* worker, size_t objref, const Obj* obj); -void orch_register_function(Worker* worker, const char* name, size_t num_return_vals); - -} diff --git a/src/orchpylib.cc b/src/orchpylib.cc new file mode 100644 index 000000000..40dfe1a43 --- /dev/null +++ b/src/orchpylib.cc @@ -0,0 +1,454 @@ +// TODO: - Implement other datatypes for ndarray + +#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION + +#include +#include +#include +#include + +#include "types.pb.h" +#include "worker.h" + +extern "C" { + // Error handling + + static PyObject *OrchPyError; +} + +// extracts a pointer from a python C API capsule +template +T* get_pointer_or_fail(PyObject* capsule, const char* name) { + if (PyCapsule_IsValid(capsule, name)) { + return static_cast(PyCapsule_GetPointer(capsule, name)); + } else { + PyErr_SetString(OrchPyError, "not a vaid capsule"); + return NULL; + } +} + +extern "C" { + +// Object references + +typedef struct { + PyObject_HEAD + ObjRef val; +} PyObjRef; + +static void PyObjRef_dealloc(PyObjRef *self) { + self->ob_type->tp_free((PyObject*) self); +} + +static PyObject* PyObjRef_new(PyTypeObject *type, PyObject *args, PyObject *kwds) { + PyObjRef* self = (PyObjRef*) type->tp_alloc(type, 0); + if (self != NULL) { + self->val = 0; + } + return (PyObject*) self; +} + +static int PyObjRef_init(PyObjRef *self, PyObject *args, PyObject *kwds) { + if (!PyArg_ParseTuple(args, "i", &self->val)) { + return -1; + } + return 0; +}; + +static PyMemberDef PyObjRef_members[] = { + {"val", T_INT, offsetof(PyObjRef, val), 0, "object reference"}, + {NULL} +}; + +static PyTypeObject PyObjRefType = { + PyObject_HEAD_INIT(NULL) + 0, /* ob_size */ + "orchpy.ObjRef", /* tp_name */ + sizeof(PyObjRef), /* tp_basicsize */ + 0, /* tp_itemsize */ + 0, /* tp_dealloc */ + 0, /* tp_print */ + 0, /* tp_getattr */ + 0, /* tp_setattr */ + 0, /* tp_compare */ + 0, /* tp_repr */ + 0, /* tp_as_number */ + 0, /* tp_as_sequence */ + 0, /* tp_as_mapping */ + 0, /* tp_hash */ + 0, /* tp_call */ + 0, /* tp_str */ + 0, /* tp_getattro */ + 0, /* tp_setattro */ + 0, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT, /* tp_flags */ + "OrchPy objects", /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + 0, /* tp_iter */ + 0, /* tp_iternext */ + 0, /* tp_methods */ + PyObjRef_members, /* tp_members */ + 0, /* tp_getset */ + 0, /* tp_base */ + 0, /* tp_dict */ + 0, /* tp_descr_get */ + 0, /* tp_descr_set */ + 0, /* tp_dictoffset */ + (initproc)PyObjRef_init, /* tp_init */ + 0, /* tp_alloc */ + PyObjRef_new, /* tp_new */ +}; + +// create PyObjRef from C++ (could be made more efficient if neccessary) +PyObject* make_pyobjref(ObjRef objref) { + PyObject* arglist = Py_BuildValue("(i)", objref); + PyObject* result = PyObject_CallObject((PyObject*) &PyObjRefType, arglist); + Py_DECREF(arglist); + return result; +} + +// Serialization + +// serialize will serialize the python object val into the protocol buffer +// object obj, returns 0 if successful and something else if not +int serialize(PyObject* val, Obj* obj) { + if (PyInt_Check(val)) { + Int* data = obj->mutable_int_data(); + long d = PyInt_AsLong(val); + data->set_data(d); + } else if (PyFloat_Check(val)) { + Double* data = obj->mutable_double_data(); + double d = PyFloat_AsDouble(val); + data->set_data(d); + } else if (PyList_Check(val)) { + List* data = obj->mutable_list_data(); + for (size_t i = 0, size = PyList_Size(val); i < size; ++i) { + Obj* elem = data->add_elem(); + if (serialize(PyList_GetItem(val, i), elem) != 0) { + return -1; + } + } + } else if (PyString_Check(val)) { + char* buffer; + Py_ssize_t length; + PyString_AsStringAndSize(val, &buffer, &length); // creates pointer to internal buffer + obj->mutable_string_data()->set_data(buffer, length); + } else if (PyArray_Check(val)) { + PyArrayObject* array = PyArray_GETCONTIGUOUS((PyArrayObject*)val); + Array* data = obj->mutable_array_data(); + npy_intp size = PyArray_SIZE(array); + for (int i = 0; i < PyArray_NDIM(array); ++i) { + data->add_shape(PyArray_DIM(array, i)); + } + if (PyArray_ISFLOAT(array)) { + double* buffer = (double*) PyArray_DATA(array); + for (npy_intp i = 0; i < size; ++i) { + data->add_double_data(buffer[i]); + } + } + } else { + return -1; + } + return 0; +} + +PyObject* deserialize(const Obj& obj) { + if (obj.has_int_data()) { + return PyInt_FromLong(obj.int_data().data()); + } else if (obj.has_double_data()) { + return PyFloat_FromDouble(obj.double_data().data()); + } else if (obj.has_list_data()) { + const List& data = obj.list_data(); + size_t size = data.elem_size(); + PyObject* list = PyList_New(size); + for (size_t i = 0; i < size; ++i) { + PyList_SetItem(list, i, deserialize(data.elem(i))); + } + return list; + } else if (obj.has_string_data()) { + const char* buffer = obj.string_data().data().data(); + Py_ssize_t length = obj.string_data().data().size(); + return PyString_FromStringAndSize(buffer, length); + } else if (obj.has_array_data()) { + const Array& array = obj.array_data(); + if (array.double_data_size() > 0) { // TODO: this is not quite right + npy_intp size = array.double_data_size(); + std::vector dims; + for (int i = 0; i < array.shape_size(); ++i) { + dims.push_back(array.shape(i)); + } + PyArrayObject* pyarray = (PyArrayObject*)PyArray_SimpleNew(array.shape_size(), &dims[0], NPY_DOUBLE); + double* buffer = (double*) PyArray_DATA(pyarray); + for (npy_intp i = 0; i < size; ++i) { + buffer[i] = array.double_data(i); + } + return (PyObject*)pyarray; + } + } else { + std::cout << "don't have object" << std::endl; + } +} + +PyObject* serialize_object(PyObject* self, PyObject* args) { + Obj* obj = new Obj(); // TODO: to be freed in capsul destructor + PyObject* pyval; + if (!PyArg_ParseTuple(args, "O", &pyval)) { + return NULL; + } + if (serialize(pyval, obj) != 0) { + PyErr_SetString(OrchPyError, "serialization: type not know"); // TODO: put a more expressive error message here + return NULL; + } + return PyCapsule_New(static_cast(obj), "obj", NULL); +} + +PyObject* deserialize_object(PyObject* self, PyObject* args) { + PyObject* capsule; + if (!PyArg_ParseTuple(args, "O", &capsule)) { + return NULL; + } + Obj* obj = get_pointer_or_fail(capsule, "obj"); + if (!obj) { + return NULL; + } + return deserialize(*obj); +} + +PyObject* serialize_call(PyObject* self, PyObject* args) { + Call* call = new Call(); // TODO: to be freed in capsul destructor + char* name; + int len; + PyObject* arguments; + if (!PyArg_ParseTuple(args, "s#O", &name, &len, &arguments)) { + return NULL; + } + call->set_name(name, len); + if (PyList_Check(arguments)) { + for (size_t i = 0, size = PyList_Size(arguments); i < size; ++i) { + Obj* arg = call->add_arg()->mutable_obj(); + serialize(PyList_GetItem(arguments, i), arg); + } + } else { + PyErr_SetString(OrchPyError, "serialize_call: second argument needs to be a list"); + return NULL; + } + return PyCapsule_New(static_cast(call), "call", NULL); +} + +PyObject* deserialize_call(PyObject* self, PyObject* args) { + PyObject* capsule = PyTuple_GetItem(args, 0); + Call* call = get_pointer_or_fail(capsule, "call"); + if (!call) { + return NULL; + } + PyObject* string = PyString_FromStringAndSize(call->name().c_str(), call->name().size()); + int argsize = call->arg_size(); + PyObject* arglist = PyList_New(argsize); + for (int i = 0; i < argsize; ++i) { + const Value& val = call->arg(i); + if (!val.has_obj()) { + // TODO: Deserialize object reference here + } else { + PyList_SetItem(arglist, i, deserialize(val.obj())); + } + } + int resultsize = call->result_size(); + PyObject* resultlist = PyList_New(resultsize); + for (int i = 0; i < resultsize; ++i) { + PyList_SetItem(resultlist, i, make_pyobjref(call->result(i))); + } + return PyTuple_Pack(3, string, arglist, resultlist); +} + +// Orchestra Python API + +PyObject* create_worker(PyObject* self, PyObject* args) { + const char* scheduler_addr; + const char* objstore_addr; + const char* worker_addr; + if (!PyArg_ParseTuple(args, "sss", &scheduler_addr, &objstore_addr, &worker_addr)) { + return NULL; + } + auto scheduler_channel = grpc::CreateChannel(scheduler_addr, grpc::InsecureChannelCredentials()); + auto objstore_channel = grpc::CreateChannel(objstore_addr, grpc::InsecureChannelCredentials()); + Worker* worker = new Worker(std::string(worker_addr), scheduler_channel, objstore_channel); + worker->register_worker(std::string(worker_addr), std::string(objstore_addr)); + return PyCapsule_New(static_cast(worker), "worker", NULL); // TODO: add destructor the deallocates worker +} + +PyObject* wait_for_next_task(PyObject* self, PyObject* args) { + PyObject* capsule = PyTuple_GetItem(args, 0); + Worker* worker = get_pointer_or_fail(capsule, "worker"); + if (!worker) { + return NULL; + } + Call* call = worker->receive_next_task(); + return PyCapsule_New(static_cast(call), "call", NULL); // TODO: how is destruction going to be handled here? +} + +PyObject* remote_call(PyObject* self, PyObject* args) { + PyObject* worker_capsule; + PyObject* call_capsule; + if (!PyArg_ParseTuple(args, "OO", &worker_capsule, &call_capsule)) { + return NULL; + } + Worker* worker = get_pointer_or_fail(worker_capsule, "worker"); + if (!worker) { + return NULL; + } + Call* call = get_pointer_or_fail(call_capsule, "call"); + if (!call) { + return NULL; + } + RemoteCallRequest request; + request.set_allocated_call(call); + RemoteCallReply reply = worker->remote_call(&request); + request.release_call(); // TODO: Make sure that call is not moved, otherwise capsule pointer needs to be updated + int size = reply.result_size(); + PyObject* list = PyList_New(size); + for (int i = 0; i < size; ++i) { + PyList_SetItem(list, i, make_pyobjref(reply.result(i))); + } + return list; +} + +PyObject* register_function(PyObject* self, PyObject* args) { + PyObject* worker_capsule; + const char* function_name; + int num_return_vals; + if (!PyArg_ParseTuple(args, "Osi", &worker_capsule, &function_name, &num_return_vals)) { + return NULL; + } + Worker* worker = get_pointer_or_fail(worker_capsule, "worker"); + if (!worker) { + return NULL; + } + worker->register_function(std::string(function_name), num_return_vals); + Py_RETURN_NONE; +} + +// TODO: test this +PyObject* push_object(PyObject* self, PyObject* args) { + PyObject* worker_capsule; + PyObject* obj_capsule; + if (!PyArg_ParseTuple(args, "OO", &worker_capsule, &obj_capsule)) { + return NULL; + } + Worker* worker = get_pointer_or_fail(worker_capsule, "worker"); + if (!worker) { + return NULL; + } + Obj* obj = get_pointer_or_fail(obj_capsule, "obj"); + if (!obj) { + return NULL; + } + ObjRef objref = worker->push_object(obj); + return make_pyobjref(objref); +} + +// TODO: test this +PyObject* put_object(PyObject* self, PyObject* args) { + PyObject* worker_capsule; + PyObject* pyobjref; + PyObject* obj_capsule; + if (!PyArg_ParseTuple(args, "OOO", &worker_capsule, &pyobjref, &obj_capsule)) { + return NULL; + } + Worker* worker = get_pointer_or_fail(worker_capsule, "worker"); + if (!worker) { + return NULL; + } + Obj* obj = get_pointer_or_fail(obj_capsule, "obj"); + if (!obj) { + return NULL; + } + ObjRef objref = ((PyObjRef*) pyobjref)->val; + worker->put_object(objref, obj); + Py_RETURN_NONE; +} + +PyObject* get_object(PyObject* self, PyObject* args) { + PyObject* worker_capsule; + PyObject* pyobjref; + if (!PyArg_ParseTuple(args, "OO", &worker_capsule, &pyobjref)) { + return NULL; + } + Worker* worker = get_pointer_or_fail(worker_capsule, "worker"); + if (!worker) { + return NULL; + } + ObjRef objref = ((PyObjRef*) pyobjref)->val; + slice s = worker->get_object(objref); + Obj* obj = new Obj(); // TODO: Make sure this will get deleted + obj->ParseFromString(std::string(s.data, s.len)); + return PyCapsule_New(static_cast(obj), "obj", NULL); +} + +// TODO: implement this +PyObject* pull_object(PyObject* self, PyObject* args) { + PyObject* worker_capsule; + PyObject* pyobjref; + if (!PyArg_ParseTuple(args, "OO", &worker_capsule, &pyobjref)) { + return NULL; + } + Worker* worker = get_pointer_or_fail(worker_capsule, "worker"); + if (!worker) { + return NULL; + } + ObjRef objref = ((PyObjRef*) pyobjref)->val; + slice s = worker->get_object(objref); + Obj* obj = new Obj(); // TODO: Make sure this will get deleted + obj->ParseFromString(std::string(s.data, s.len)); + return PyCapsule_New(static_cast(obj), "obj", NULL); +} + +// TODO: test this +PyObject* start_worker_service(PyObject* self, PyObject* args) { + PyObject* worker_capsule; + if (!PyArg_ParseTuple(args, "O", &worker_capsule)) { + return NULL; + } + Worker* worker = get_pointer_or_fail(worker_capsule, "worker"); + if (!worker) { + return NULL; + } + worker->start_worker_service(); + Py_RETURN_NONE; +} + +static PyMethodDef SymphonyMethods[] = { + { "serialize_object", serialize_object, METH_VARARGS, "serialize an object to protocol buffers" }, + { "deserialize_object", deserialize_object, METH_VARARGS, "deserialize an object from protocol buffers" }, + { "serialize_call", serialize_call, METH_VARARGS, "serialize a call to protocol buffers" }, + { "deserialize_call", deserialize_call, METH_VARARGS, "deserialize a call from protocol buffers" }, + { "create_worker", create_worker, METH_VARARGS, "connect to the scheduler and the object store" }, + { "register_function", register_function, METH_VARARGS, "register a function with the scheduler" }, + { "put_object", put_object, METH_VARARGS, "put a protocol buffer object (given as a capsule) on the local object store" }, + { "get_object", get_object, METH_VARARGS, "get protocol buffer object from the local object store" }, + { "push_object", push_object, METH_VARARGS, "push a protocol buffer object (given as a capsule) to the object store" }, + { "pull_object" , pull_object, METH_VARARGS, "pull object with a given object id from the object store" }, + { "wait_for_next_task", wait_for_next_task, METH_VARARGS, "get next task from scheduler (blocking)" }, + { "remote_call", remote_call, METH_VARARGS, "call a remote function" }, + { "start_worker_service", start_worker_service, METH_VARARGS, "start the worker service" }, + { NULL, NULL, 0, NULL } +}; + +PyMODINIT_FUNC initliborchpylib(void) { + PyObject* m; + PyObjRefType.tp_new = PyType_GenericNew; + if (PyType_Ready(&PyObjRefType) < 0) { + return; + } + m = Py_InitModule3("liborchpylib", SymphonyMethods, "Python C Extension for Orchestra"); + Py_INCREF(&PyObjRefType); + PyModule_AddObject(m, "ObjRef", (PyObject *)&PyObjRefType); + OrchPyError = PyErr_NewException("orchpy.error", NULL, NULL); + Py_INCREF(OrchPyError); + PyModule_AddObject(m, "error", OrchPyError); + import_array(); +} + +} diff --git a/src/scheduler.cc b/src/scheduler.cc index 3877f396f..4c5e5268a 100644 --- a/src/scheduler.cc +++ b/src/scheduler.cc @@ -1,3 +1,5 @@ +#include + #include "scheduler.h" Status SchedulerService::RemoteCall(ServerContext* context, const RemoteCallRequest* request, RemoteCallReply* reply) { @@ -30,6 +32,24 @@ Status SchedulerService::PushObj(ServerContext* context, const PushObjRequest* r } Status SchedulerService::PullObj(ServerContext* context, const PullObjRequest* request, AckReply* reply) { + std::lock_guard objtable_lock(objtable_lock_); + ObjRef objref = request->objref(); + if (objref >= objtable_.size() || objtable_[objref].size() == 0) { + std::cout << "internal error: no object with objref exists" << std::endl; + std::exit(1); + } + std::mt19937 rng; + std::uniform_int_distribution uni(0, objtable_[objref].size()-1); + ObjStoreId objstoreid = uni(rng); + std::lock_guard objstore_lock(objstores_lock_); + + DeliverObjRequest deliver_request; + ObjStoreId id = get_store(request->workerid()); + deliver_request.set_objstore_address(objstores_[id].address); + deliver_request.set_objref(objref); + AckReply deliver_reply; + ClientContext deliver_context; + objstores_[objstoreid].objstore_stub->DeliverObj(&deliver_context, deliver_request, &deliver_reply); return Status::OK; } @@ -117,8 +137,8 @@ WorkerId SchedulerService::register_worker(const std::string& worker_address, co ObjStoreId objstoreid = std::numeric_limits::max(); objstores_lock_.lock(); for (size_t i = 0; i < objstores_.size(); ++i) { - std::cout << "adress: " << objstores_[i].address << std::endl; - std::cout << "my adress: " << objstore_address << std::endl; + std::cout << "address: " << objstores_[i].address << std::endl; + std::cout << "my address: " << objstore_address << std::endl; if (objstores_[i].address == objstore_address) { objstoreid = i; } diff --git a/src/worker.cc b/src/worker.cc index 057735e3b..9b4ba2b77 100644 --- a/src/worker.cc +++ b/src/worker.cc @@ -19,11 +19,11 @@ Status WorkerServiceImpl::InvokeCall(ServerContext* context, const InvokeCallReq return Status::OK; } -size_t Worker::remote_call(RemoteCallRequest* request) { +RemoteCallReply Worker::remote_call(RemoteCallRequest* request) { RemoteCallReply reply; ClientContext context; Status status = scheduler_stub_->RemoteCall(&context, *request, &reply); - // TODO: Return results: return reply.result(0); + return reply; } void Worker::register_worker(const std::string& worker_address, const std::string& objstore_address) { @@ -37,7 +37,17 @@ void Worker::register_worker(const std::string& worker_address, const std::strin return; } -ObjRef Worker::push_obj(const Obj* obj) { +slice Worker::pull_object(ObjRef objref) { + PullObjRequest request; + request.set_workerid(workerid_); + request.set_objref(objref); + AckReply reply; + ClientContext context; + Status status = scheduler_stub_->PullObj(&context, request, &reply); + return get_object(objref); +} + +ObjRef Worker::push_object(const Obj* obj) { // first get objref for the new object PushObjRequest push_request; PushObjReply push_reply; @@ -45,11 +55,11 @@ ObjRef Worker::push_obj(const Obj* obj) { Status push_status = scheduler_stub_->PushObj(&push_context, push_request, &push_reply); ObjRef objref = push_reply.objref(); // then stream the object to the object store - put_obj(objref, obj); + put_object(objref, obj); return objref; } -slice Worker::get_serialized_obj(ObjRef objref) { +slice Worker::get_object(ObjRef objref) { ClientContext context; GetObjRequest request; request.set_objref(objref); @@ -62,7 +72,8 @@ slice Worker::get_serialized_obj(ObjRef objref) { return slice; } -void Worker::put_obj(ObjRef objref, const Obj* obj) { +// TODO: Do this with shared memory +void Worker::put_object(ObjRef objref, const Obj* obj) { ObjChunk chunk; std::string data; obj->SerializeToString(&data); diff --git a/src/worker.h b/src/worker.h index b72d9aba4..07a88d31c 100644 --- a/src/worker.h +++ b/src/worker.h @@ -19,8 +19,6 @@ using grpc::ServerContext; using grpc::Status; #include "orchestra.grpc.pb.h" -#include "orchlib.h" - #include "orchestra/orchestra.h" using grpc::Channel; @@ -46,15 +44,17 @@ class Worker { {} // submit a remote call to the scheduler - size_t remote_call(RemoteCallRequest* request); + RemoteCallReply remote_call(RemoteCallRequest* request); // send request to the scheduler to register this worker void register_worker(const std::string& worker_address, const std::string& objstore_address); // push object to local object store, register it with the server and return object reference - ObjRef push_obj(const Obj* obj); - // retrieve serialized object from local object store - slice get_serialized_obj(ObjRef objref); + ObjRef push_object(const Obj* obj); + // pull object from a potentially remote object store + slice pull_object(ObjRef objref); // stores an object to the local object store - void put_obj(ObjRef objref, const Obj* obj); + void put_object(ObjRef objref, const Obj* obj); + // retrieve serialized object from local object store + slice get_object(ObjRef objref); // register function with scheduler void register_function(const std::string& name, size_t num_return_vals); // start the worker server which accepts tasks from the scheduler and stores diff --git a/test/runtest.py b/test/runtest.py index 43f584dc8..cce69a113 100644 --- a/test/runtest.py +++ b/test/runtest.py @@ -1,5 +1,5 @@ import unittest -import orchpy.unison as unison +import orchpy import orchpy.services as services import orchpy.worker as worker import numpy as np @@ -45,6 +45,21 @@ def new_objstore_port(): objstore_port_counter += 1 return 20000 + objstore_port_counter +class SerializationTest(unittest.TestCase): + + def roundTripTest(self, data): + serialized = orchpy.lib.serialize_object(data) + result = orchpy.lib.deserialize_object(serialized) + self.assertEqual(data, result) + + def testSerialize(self): + data = [1, "hello", 3.0] + self.roundTripTest(data) + + a = np.zeros((100, 100)) + res = orchpy.lib.serialize_object(a) + b = orchpy.lib.deserialize_object(res) + self.assertTrue((a == b).all()) class ObjStoreTest(unittest.TestCase): @@ -67,22 +82,22 @@ class ObjStoreTest(unittest.TestCase): objstore2_stub = connect_to_objstore(IP_ADDRESS, objstore2_port) worker1 = worker.Worker() - worker1.connect(address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, worker1_port), address(IP_ADDRESS, objstore1_port)) + worker.connect(address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore1_port), address(IP_ADDRESS, worker1_port), worker1) worker2 = worker.Worker() - worker2.connect(address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, worker2_port), address(IP_ADDRESS, objstore2_port)) + worker.connect(address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore2_port), address(IP_ADDRESS, worker2_port), worker2) # pushing and pulling an object shouldn't change it for data in ["h", "h" * 10000, 0, 0.0]: - objref = worker1.push(data) - result = worker1.pull(objref) + objref = worker.push(data, worker1) + result = worker.pull(objref, worker1) self.assertEqual(result, data) # pushing an object, shipping it to another worker, and pulling it shouldn't change it for data in ["h", "h" * 10000, 0, 0.0]: - objref = worker1.push(data) - response = objstore1_stub.DeliverObj(orchestra_pb2.DeliverObjRequest(objref=objref.get_id(), objstore_address=address(IP_ADDRESS, objstore2_port)), TIMEOUT_SECONDS) - result = worker2.pull(objref) + objref = worker.push(data, worker1) + response = objstore1_stub.DeliverObj(orchestra_pb2.DeliverObjRequest(objref=objref.val, objstore_address=address(IP_ADDRESS, objstore2_port)), TIMEOUT_SECONDS) + result = worker.pull(objref, worker2) self.assertEqual(result, data) services.cleanup() @@ -106,8 +121,7 @@ class SchedulerTest(unittest.TestCase): time.sleep(0.2) worker1 = worker.Worker() - worker1.connect(address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, worker1_port), address(IP_ADDRESS, objstore_port)) - worker1.start_worker_service() + worker.connect(address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore_port), address(IP_ADDRESS, worker1_port), worker1) test_dir = os.path.dirname(os.path.abspath(__file__)) test_path = os.path.join(test_dir, "testrecv.py") @@ -115,7 +129,7 @@ class SchedulerTest(unittest.TestCase): time.sleep(0.2) - worker1.call("hello_world", ["hi"]) + worker1.remote_call("print_string", ["hi"]) time.sleep(0.1) diff --git a/test/shell.py b/test/shell.py index 74c5791b5..5cc683d9d 100644 --- a/test/shell.py +++ b/test/shell.py @@ -1,4 +1,5 @@ -import orchpy.unison as unison +import argparse + import orchpy.services as services import orchpy.worker as worker @@ -6,13 +7,33 @@ from grpc.beta import implementations import orchestra_pb2 import types_pb2 +parser = argparse.ArgumentParser(description='Parse addresses for the worker to connect to.') +parser.add_argument("--ip_address", default="127.0.0.1", help="the IP address to use for both the scheduler and objstore") +parser.add_argument("--scheduler_port", default=10001, type=int, help="the scheduler's port") +parser.add_argument("--objstore_port", default=20001, type=int, help="the objstore's port") +parser.add_argument("--worker_port", default=40001, type=int, help="the worker's port") + +@worker.distributed([str], [str]) +def print_string(string): + print "called print_string with", string + f = open("asdfasdf.txt", "w") + f.write("successfully called print_string with argument {}.".format(string)) + return string + +@worker.distributed([int, int], [int, int]) +def handle_int(a, b): + return a + 1, b + 1 + def connect_to_scheduler(host, port): channel = implementations.insecure_channel(host, port) return orchestra_pb2.beta_create_Scheduler_stub(channel) +def address(host, port): + return host + ":" + str(port) + if __name__ == '__main__': - scheduler_stub = connect_to_scheduler("127.0.0.1", 22221) - worker = worker.Worker() - worker.connect("127.0.0.1:22221", "127.0.0.1:10000", "127.0.0.1:22222") + args = parser.parse_args() + scheduler_stub = connect_to_scheduler(args.ip_address, args.scheduler_port) + worker.connect(address(args.ip_address, args.scheduler_port), address(args.ip_address, args.objstore_port), address(args.ip_address, args.worker_port)) import IPython IPython.embed() diff --git a/test/testrecv.py b/test/testrecv.py index 13383a728..759fbf743 100644 --- a/test/testrecv.py +++ b/test/testrecv.py @@ -1,37 +1,34 @@ -import sys +import argparse -import orchpy.unison as unison +import orchpy import orchpy.services as services import orchpy.worker as worker +parser = argparse.ArgumentParser(description='Parse addresses for the worker to connect to.') +parser.add_argument("--ip_address", default="127.0.0.1", help="the IP address to use for both the scheduler and objstore") +parser.add_argument("--scheduler_port", default=10001, type=int, help="the scheduler's port") +parser.add_argument("--objstore_port", default=20001, type=int, help="the objstore's port") +parser.add_argument("--worker_port", default=40001, type=int, help="the worker's port") + @worker.distributed([str], [str]) def print_string(string): print "called print_string with", string + f = open("asdfasdf.txt", "w") + f.write("successfully called print_string with argument {}.".format(string)) return string @worker.distributed([int, int], [int, int]) def handle_int(a, b): return a + 1, b + 1 +def address(host, port): + return host + ":" + str(port) + if __name__ == '__main__': - ip_address = sys.argv[1] - scheduler_port = sys.argv[2] - worker_port = sys.argv[3] - objstore_port = sys.argv[4] + args = parser.parse_args() + worker.connect(address(args.ip_address, args.scheduler_port), address(args.ip_address, args.objstore_port), address(args.ip_address, args.worker_port)) - def address(host, port): - return host + ":" + str(port) + worker.global_worker.register_function(print_string) + worker.global_worker.register_function(handle_int) - worker = worker.Worker() - worker.connect(address(ip_address, scheduler_port), address(ip_address, worker_port), address(ip_address, objstore_port)) - worker.start_worker_service() - - worker.register_function("print_string", print_string, 0) - worker.register_function("handle_int", handle_int, 0) - - name, args, returnref = worker.wait_for_next_task() - print "received args ", args - if args == ["hi"]: - sys.exit(0) - else: - sys.exit(1) + worker.main_loop()