From 2500cbaf729dad637a2da169a3391902cd352f29 Mon Sep 17 00:00:00 2001 From: Robert Nishihara Date: Wed, 16 Mar 2016 18:11:43 -0700 Subject: [PATCH] improve serialization and add DistArray class --- lib/orchpy/arrays/dist/__init__.py | 1 + lib/orchpy/arrays/dist/core.py | 60 ++++++++++++++++++++++++++++ lib/orchpy/arrays/single/__init__.py | 2 +- lib/orchpy/arrays/single/core.py | 4 ++ lib/orchpy/orchpy/__init__.py | 1 + lib/orchpy/orchpy/serialization.py | 26 ++++++++++++ lib/orchpy/orchpy/worker.py | 11 ++--- test/arrays_test.py | 46 ++++++++++++++++++++- test/runtest.py | 53 ++++++++++++------------ 9 files changed, 170 insertions(+), 34 deletions(-) create mode 100644 lib/orchpy/arrays/dist/__init__.py create mode 100644 lib/orchpy/arrays/dist/core.py create mode 100644 lib/orchpy/orchpy/serialization.py diff --git a/lib/orchpy/arrays/dist/__init__.py b/lib/orchpy/arrays/dist/__init__.py new file mode 100644 index 000000000..e1d2e30fd --- /dev/null +++ b/lib/orchpy/arrays/dist/__init__.py @@ -0,0 +1 @@ +from core import DistArray, BLOCK_SIZE diff --git a/lib/orchpy/arrays/dist/core.py b/lib/orchpy/arrays/dist/core.py new file mode 100644 index 000000000..024faa2cd --- /dev/null +++ b/lib/orchpy/arrays/dist/core.py @@ -0,0 +1,60 @@ +from typing import List +import numpy as np +import arrays.single as single +import orchpy as op + +BLOCK_SIZE = 10 + +class DistArray(object): + def construct(self, shape, dtype, objrefs): + self.shape = shape + self.dtype = dtype + self.objrefs = objrefs + self.ndim = len(shape) + self.num_blocks = [int(np.ceil(1.0 * a / BLOCK_SIZE)) for a in self.shape] + if self.num_blocks != list(self.objrefs.shape): + raise Exception("The fields `num_blocks` and `objrefs` are inconsistent, `num_blocks` is {} and `objrefs` has shape {}".format(self.num_blocks, list(self.objrefs.shape))) + + def deserialize(self, primitives): + (shape, dtype_name, objrefs) = primitives + self.construct(shape, np.dtype(dtype_name), objrefs) + + def serialize(self): + return (self.shape, self.dtype.__name__, self.objrefs) + + def __init__(self): + self.shape = None + self.dtype = None + self.objrefs = None + + def compute_block_lower(self, index): + if len(index) != self.ndim: + raise Exception("The value `index` equals {}, but `ndim` is {}.".format(index, self.ndim)) + return [elem * BLOCK_SIZE for elem in index] + + def compute_block_upper(self, index): + if len(index) != self.ndim: + raise Exception("The value `index` equals {}, but `ndim` is {}.".format(index, self.ndim)) + upper = [] + for i in range(self.ndim): + upper.append(min((index[i] + 1) * BLOCK_SIZE, self.shape[i])) + return upper + + def compute_block_shape(self, index): + lower = self.compute_block_lower(index) + upper = self.compute_block_upper(index) + return [u - l for (l, u) in zip(lower, upper)] + + def assemble(self): + """Assemble an array on this node from a distributed array object reference.""" + result = np.zeros(self.shape) + for index in np.ndindex(*self.num_blocks): + lower = self.compute_block_lower(index) + upper = self.compute_block_upper(index) + result[[slice(l, u) for (l, u) in zip(lower, upper)]] = op.pull(self.objrefs[index]) + return result + + def __getitem__(self, sliced): + # TODO(rkn): fix this, this is just a placeholder that should work but is inefficient + a = self.assemble() + return a[sliced] diff --git a/lib/orchpy/arrays/single/__init__.py b/lib/orchpy/arrays/single/__init__.py index 512a6e544..2890639cb 100644 --- a/lib/orchpy/arrays/single/__init__.py +++ b/lib/orchpy/arrays/single/__init__.py @@ -1,2 +1,2 @@ import random, linalg -from core import zeros, eye, dot, vstack, hstack, subarray, copy, tril, triu +from core import zeros, ones, eye, dot, vstack, hstack, subarray, copy, tril, triu diff --git a/lib/orchpy/arrays/single/core.py b/lib/orchpy/arrays/single/core.py index 598734627..914c51ec0 100644 --- a/lib/orchpy/arrays/single/core.py +++ b/lib/orchpy/arrays/single/core.py @@ -6,6 +6,10 @@ import orchpy as op def zeros(shape): return np.zeros(shape) +@op.distributed([List[int]], [np.ndarray]) +def ones(shape): + return np.ones(shape) + @op.distributed([int], [np.ndarray]) def eye(dim): return np.eye(dim) diff --git a/lib/orchpy/orchpy/__init__.py b/lib/orchpy/orchpy/__init__.py index bab8ef44b..9e4123b40 100644 --- a/lib/orchpy/orchpy/__init__.py +++ b/lib/orchpy/orchpy/__init__.py @@ -1,2 +1,3 @@ import liborchpylib as lib +import serialization from worker import register_module, connect, pull, push, distributed diff --git a/lib/orchpy/orchpy/serialization.py b/lib/orchpy/orchpy/serialization.py new file mode 100644 index 000000000..dc5c0952c --- /dev/null +++ b/lib/orchpy/orchpy/serialization.py @@ -0,0 +1,26 @@ +import importlib + +import orchpy + +def serialize(obj): + if hasattr(obj, "serialize"): + primitive_obj = ((type(obj).__module__, type(obj).__name__), obj.serialize()) + else: + # TODO(rkn): Right now we don't handle arbitrary python objects, but later + # we can unpack the fields of a python object into a list and call + # orchpy.lib.serialize_object. + primitive_obj = ("primitive", obj) + return orchpy.lib.serialize_object(primitive_obj) + +def deserialize(capsule): + primitive_obj = orchpy.lib.deserialize_object(capsule) + if primitive_obj[0] == "primitive": + return primitive_obj[1] + else: + # assert primitive_obj[0] must be a tuple of module and class name + type_module, type_name = primitive_obj[0] + module = importlib.import_module(type_module) + if hasattr(module.__dict__[type_name], "deserialize"): + obj = module.__dict__[type_name]() + obj.deserialize(primitive_obj[1]) + return obj diff --git a/lib/orchpy/orchpy/worker.py b/lib/orchpy/orchpy/worker.py index 87c3e4479..6cf787221 100644 --- a/lib/orchpy/orchpy/worker.py +++ b/lib/orchpy/orchpy/worker.py @@ -2,6 +2,7 @@ from types import ModuleType import typing import orchpy +import serialization class Worker(object): """The methods in this class are considered unexposed to the user. The functions outside of this class are considered exposed.""" @@ -13,13 +14,13 @@ class Worker(object): def put_object(self, objref, value): """Put `value` in the local object store with objref `objref`. This assumes that the value for `objref` has not yet been placed in the local object store.""" - object_capsule = orchpy.lib.serialize_object(value) + object_capsule = serialization.serialize(value) orchpy.lib.put_object(self.handle, objref, object_capsule) def get_object(self, objref): """Return the value from the local object store for objref `objref`. This will block until the value for `objref` has been written to the local object store.""" object_capsule = orchpy.lib.get_object(self.handle, objref) - return orchpy.lib.deserialize_object(object_capsule) + return serialization.deserialize(object_capsule) def register_function(self, function): """Notify the scheduler that this worker can execute the function with name `func_name`. Store the function `function` locally.""" @@ -47,16 +48,16 @@ def register_module(module, recursive=False, worker=global_worker): def connect(scheduler_addr, objstore_addr, worker_addr, worker=global_worker): if worker.connected: - raise Exception("Worker called connect, but worker is already connected") + del worker.handle # TODO(rkn): Make sure this actually deallocates (need a destructor for the capsule) worker.handle = orchpy.lib.create_worker(scheduler_addr, objstore_addr, worker_addr) worker.connected = True def pull(objref, worker=global_worker): object_capsule = orchpy.lib.pull_object(worker.handle, objref) - return orchpy.lib.deserialize_object(object_capsule) + return serialization.deserialize(object_capsule) def push(value, worker=global_worker): - object_capsule = orchpy.lib.serialize_object(value) + object_capsule = serialization.serialize(value) return orchpy.lib.push_object(worker.handle, object_capsule) def main_loop(worker=global_worker): diff --git a/test/arrays_test.py b/test/arrays_test.py index 289d050ac..ccc8e90a6 100644 --- a/test/arrays_test.py +++ b/test/arrays_test.py @@ -1,13 +1,14 @@ import unittest import orchpy +import orchpy.serialization as serialization import orchpy.services as services -import orchpy.worker as worker import numpy as np import time import subprocess32 as subprocess import os import arrays.single as single +import arrays.dist as dist from google.protobuf.text_format import * @@ -63,7 +64,7 @@ class ArraysSingleTest(unittest.TestCase): time.sleep(0.2) - worker.connect(address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore_port), address(IP_ADDRESS, worker1_port)) + orchpy.connect(address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore_port), address(IP_ADDRESS, worker1_port)) test_dir = os.path.dirname(os.path.abspath(__file__)) test_path = os.path.join(test_dir, "testrecv.py") @@ -98,5 +99,46 @@ class ArraysSingleTest(unittest.TestCase): services.cleanup() +class ArraysDistTest(unittest.TestCase): + + def testMethods(self): + x = dist.DistArray() + x.construct([2, 3, 4], float, np.array([[[orchpy.lib.ObjRef(0)]]])) + capsule = serialization.serialize(x) + y = serialization.deserialize(capsule) + self.assertEqual(x.shape, y.shape) + self.assertEqual(x.dtype, y.dtype) + self.assertEqual(x.objrefs[0, 0, 0].val, y.objrefs[0, 0, 0].val) + + def testAssemble(self): + scheduler_port = new_scheduler_port() + objstore_port = new_objstore_port() + worker1_port = new_worker_port() + worker2_port = new_worker_port() + + services.start_scheduler(address(IP_ADDRESS, scheduler_port)) + + time.sleep(0.1) + + services.start_objstore(address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore_port)) + + time.sleep(0.2) + + orchpy.connect(address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore_port), address(IP_ADDRESS, worker1_port)) + + test_dir = os.path.dirname(os.path.abspath(__file__)) + test_path = os.path.join(test_dir, "testrecv.py") + services.start_worker(test_path, address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore_port), address(IP_ADDRESS, worker2_port)) + + time.sleep(0.2) + + a = single.ones([dist.BLOCK_SIZE, dist.BLOCK_SIZE]) + b = single.zeros([dist.BLOCK_SIZE, dist.BLOCK_SIZE]) + x = dist.DistArray() + x.construct([2 * dist.BLOCK_SIZE, dist.BLOCK_SIZE], float, np.array([[a], [b]])) + self.assertTrue(np.alltrue(x.assemble() == np.vstack([np.ones([dist.BLOCK_SIZE, dist.BLOCK_SIZE]), np.zeros([dist.BLOCK_SIZE, dist.BLOCK_SIZE])]))) + + services.cleanup() + if __name__ == '__main__': unittest.main() diff --git a/test/runtest.py b/test/runtest.py index 195a56aa9..17877435c 100644 --- a/test/runtest.py +++ b/test/runtest.py @@ -1,5 +1,6 @@ import unittest import orchpy +import orchpy.serialization as serialization import orchpy.services as services import orchpy.worker as worker import numpy as np @@ -50,14 +51,14 @@ def new_objstore_port(): class SerializationTest(unittest.TestCase): def roundTripTest(self, data): - serialized = orchpy.lib.serialize_object(data) - result = orchpy.lib.deserialize_object(serialized) + serialized = serialization.serialize(data) + result = serialization.deserialize(serialized) self.assertEqual(data, result) def numpyTypeTest(self, typ): a = np.random.randint(0, 10, size=(100, 100)).astype(typ) - b = orchpy.lib.serialize_object(a) - c = orchpy.lib.deserialize_object(b) + b = serialization.serialize(a) + c = serialization.deserialize(b) self.assertTrue((a == c).all()) def testSerialize(self): @@ -68,8 +69,8 @@ class SerializationTest(unittest.TestCase): self.roundTripTest((1.0, "hi")) a = np.zeros((100, 100)) - res = orchpy.lib.serialize_object(a) - b = orchpy.lib.deserialize_object(res) + res = serialization.serialize(a) + b = serialization.deserialize(res) self.assertTrue((a == b).all()) self.numpyTypeTest('int8') @@ -80,8 +81,8 @@ class SerializationTest(unittest.TestCase): self.numpyTypeTest('float64') a = np.array([[orchpy.lib.ObjRef(0), orchpy.lib.ObjRef(1)], [orchpy.lib.ObjRef(41), orchpy.lib.ObjRef(42)]]) - capsule = orchpy.lib.serialize_object(a) - result = orchpy.lib.deserialize_object(capsule) + capsule = serialization.serialize(a) + result = serialization.deserialize(capsule) self.assertTrue((a == result).all()) class OrchPyLibTest(unittest.TestCase): @@ -101,7 +102,7 @@ class OrchPyLibTest(unittest.TestCase): w = worker.Worker() - worker.connect(address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore_port), address(IP_ADDRESS, worker_port), w) + orchpy.connect(address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore_port), address(IP_ADDRESS, worker_port), w) w.put_object(orchpy.lib.ObjRef(0), 'hello world') result = w.get_object(orchpy.lib.ObjRef(0)) @@ -134,22 +135,22 @@ class ObjStoreTest(unittest.TestCase): objstore2_stub = connect_to_objstore(IP_ADDRESS, objstore2_port) worker1 = worker.Worker() - worker.connect(address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore1_port), address(IP_ADDRESS, worker1_port), worker1) + orchpy.connect(address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore1_port), address(IP_ADDRESS, worker1_port), worker1) worker2 = worker.Worker() - worker.connect(address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore2_port), address(IP_ADDRESS, worker2_port), worker2) + orchpy.connect(address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore2_port), address(IP_ADDRESS, worker2_port), worker2) # pushing and pulling an object shouldn't change it for data in ["h", "h" * 10000, 0, 0.0]: - objref = worker.push(data, worker1) - result = worker.pull(objref, worker1) + objref = orchpy.push(data, worker1) + result = orchpy.pull(objref, worker1) self.assertEqual(result, data) # pushing an object, shipping it to another worker, and pulling it shouldn't change it for data in ["h", "h" * 10000, 0, 0.0]: - objref = worker.push(data, worker1) + objref = orchpy.push(data, worker1) response = objstore1_stub.DeliverObj(orchestra_pb2.DeliverObjRequest(objref=objref.val, objstore_address=address(IP_ADDRESS, objstore2_port)), TIMEOUT_SECONDS) - result = worker.pull(objref, worker2) + result = orchpy.pull(objref, worker2) self.assertEqual(result, data) services.cleanup() @@ -176,7 +177,7 @@ class SchedulerTest(unittest.TestCase): time.sleep(0.2) worker1 = worker.Worker() - worker.connect(address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore_port), address(IP_ADDRESS, worker1_port), worker1) + orchpy.connect(address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore_port), address(IP_ADDRESS, worker1_port), worker1) test_dir = os.path.dirname(os.path.abspath(__file__)) test_path = os.path.join(test_dir, "testrecv.py") @@ -189,7 +190,7 @@ class SchedulerTest(unittest.TestCase): time.sleep(0.2) - value_after = worker.pull(objref[0], worker1) + value_after = orchpy.pull(objref[0], worker1) self.assertEqual(value_before, value_after) time.sleep(0.1) @@ -214,30 +215,30 @@ class WorkerTest(unittest.TestCase): time.sleep(0.2) worker1 = worker.Worker() - worker.connect(address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore_port), address(IP_ADDRESS, worker1_port), worker1) + orchpy.connect(address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore_port), address(IP_ADDRESS, worker1_port), worker1) for i in range(100): value_before = i * 10 ** 6 - objref = worker.push(value_before, worker1) - value_after = worker.pull(objref, worker1) + objref = orchpy.push(value_before, worker1) + value_after = orchpy.pull(objref, worker1) self.assertEqual(value_before, value_after) for i in range(100): value_before = i * 10 ** 6 * 1.0 - objref = worker.push(value_before, worker1) - value_after = worker.pull(objref, worker1) + objref = orchpy.push(value_before, worker1) + value_after = orchpy.pull(objref, worker1) self.assertEqual(value_before, value_after) for i in range(100): value_before = "h" * i - objref = worker.push(value_before, worker1) - value_after = worker.pull(objref, worker1) + objref = orchpy.push(value_before, worker1) + value_after = orchpy.pull(objref, worker1) self.assertEqual(value_before, value_after) for i in range(100): value_before = [1] * i - objref = worker.push(value_before, worker1) - value_after = worker.pull(objref, worker1) + objref = orchpy.push(value_before, worker1) + value_after = orchpy.pull(objref, worker1) self.assertEqual(value_before, value_after) services.cleanup()