Merge pull request #16 from amplab/types

Types
This commit is contained in:
Philipp Moritz
2016-03-16 18:14:04 -07:00
9 changed files with 170 additions and 34 deletions
+1
View File
@@ -0,0 +1 @@
from core import DistArray, BLOCK_SIZE
+60
View File
@@ -0,0 +1,60 @@
from typing import List
import numpy as np
import arrays.single as single
import orchpy as op
BLOCK_SIZE = 10
class DistArray(object):
def construct(self, shape, dtype, objrefs):
self.shape = shape
self.dtype = dtype
self.objrefs = objrefs
self.ndim = len(shape)
self.num_blocks = [int(np.ceil(1.0 * a / BLOCK_SIZE)) for a in self.shape]
if self.num_blocks != list(self.objrefs.shape):
raise Exception("The fields `num_blocks` and `objrefs` are inconsistent, `num_blocks` is {} and `objrefs` has shape {}".format(self.num_blocks, list(self.objrefs.shape)))
def deserialize(self, primitives):
(shape, dtype_name, objrefs) = primitives
self.construct(shape, np.dtype(dtype_name), objrefs)
def serialize(self):
return (self.shape, self.dtype.__name__, self.objrefs)
def __init__(self):
self.shape = None
self.dtype = None
self.objrefs = None
def compute_block_lower(self, index):
if len(index) != self.ndim:
raise Exception("The value `index` equals {}, but `ndim` is {}.".format(index, self.ndim))
return [elem * BLOCK_SIZE for elem in index]
def compute_block_upper(self, index):
if len(index) != self.ndim:
raise Exception("The value `index` equals {}, but `ndim` is {}.".format(index, self.ndim))
upper = []
for i in range(self.ndim):
upper.append(min((index[i] + 1) * BLOCK_SIZE, self.shape[i]))
return upper
def compute_block_shape(self, index):
lower = self.compute_block_lower(index)
upper = self.compute_block_upper(index)
return [u - l for (l, u) in zip(lower, upper)]
def assemble(self):
"""Assemble an array on this node from a distributed array object reference."""
result = np.zeros(self.shape)
for index in np.ndindex(*self.num_blocks):
lower = self.compute_block_lower(index)
upper = self.compute_block_upper(index)
result[[slice(l, u) for (l, u) in zip(lower, upper)]] = op.pull(self.objrefs[index])
return result
def __getitem__(self, sliced):
# TODO(rkn): fix this, this is just a placeholder that should work but is inefficient
a = self.assemble()
return a[sliced]
+1 -1
View File
@@ -1,2 +1,2 @@
import random, linalg
from core import zeros, eye, dot, vstack, hstack, subarray, copy, tril, triu
from core import zeros, ones, eye, dot, vstack, hstack, subarray, copy, tril, triu
+4
View File
@@ -6,6 +6,10 @@ import orchpy as op
def zeros(shape):
return np.zeros(shape)
@op.distributed([List[int]], [np.ndarray])
def ones(shape):
return np.ones(shape)
@op.distributed([int], [np.ndarray])
def eye(dim):
return np.eye(dim)
+1
View File
@@ -1,2 +1,3 @@
import liborchpylib as lib
import serialization
from worker import register_module, connect, pull, push, distributed
+26
View File
@@ -0,0 +1,26 @@
import importlib
import orchpy
def serialize(obj):
if hasattr(obj, "serialize"):
primitive_obj = ((type(obj).__module__, type(obj).__name__), obj.serialize())
else:
# TODO(rkn): Right now we don't handle arbitrary python objects, but later
# we can unpack the fields of a python object into a list and call
# orchpy.lib.serialize_object.
primitive_obj = ("primitive", obj)
return orchpy.lib.serialize_object(primitive_obj)
def deserialize(capsule):
primitive_obj = orchpy.lib.deserialize_object(capsule)
if primitive_obj[0] == "primitive":
return primitive_obj[1]
else:
# assert primitive_obj[0] must be a tuple of module and class name
type_module, type_name = primitive_obj[0]
module = importlib.import_module(type_module)
if hasattr(module.__dict__[type_name], "deserialize"):
obj = module.__dict__[type_name]()
obj.deserialize(primitive_obj[1])
return obj
+6 -5
View File
@@ -2,6 +2,7 @@ from types import ModuleType
import typing
import orchpy
import serialization
class Worker(object):
"""The methods in this class are considered unexposed to the user. The functions outside of this class are considered exposed."""
@@ -13,13 +14,13 @@ class Worker(object):
def put_object(self, objref, value):
"""Put `value` in the local object store with objref `objref`. This assumes that the value for `objref` has not yet been placed in the local object store."""
object_capsule = orchpy.lib.serialize_object(value)
object_capsule = serialization.serialize(value)
orchpy.lib.put_object(self.handle, objref, object_capsule)
def get_object(self, objref):
"""Return the value from the local object store for objref `objref`. This will block until the value for `objref` has been written to the local object store."""
object_capsule = orchpy.lib.get_object(self.handle, objref)
return orchpy.lib.deserialize_object(object_capsule)
return serialization.deserialize(object_capsule)
def register_function(self, function):
"""Notify the scheduler that this worker can execute the function with name `func_name`. Store the function `function` locally."""
@@ -47,16 +48,16 @@ def register_module(module, recursive=False, worker=global_worker):
def connect(scheduler_addr, objstore_addr, worker_addr, worker=global_worker):
if worker.connected:
raise Exception("Worker called connect, but worker is already connected")
del worker.handle # TODO(rkn): Make sure this actually deallocates (need a destructor for the capsule)
worker.handle = orchpy.lib.create_worker(scheduler_addr, objstore_addr, worker_addr)
worker.connected = True
def pull(objref, worker=global_worker):
object_capsule = orchpy.lib.pull_object(worker.handle, objref)
return orchpy.lib.deserialize_object(object_capsule)
return serialization.deserialize(object_capsule)
def push(value, worker=global_worker):
object_capsule = orchpy.lib.serialize_object(value)
object_capsule = serialization.serialize(value)
return orchpy.lib.push_object(worker.handle, object_capsule)
def main_loop(worker=global_worker):
+44 -2
View File
@@ -1,13 +1,14 @@
import unittest
import orchpy
import orchpy.serialization as serialization
import orchpy.services as services
import orchpy.worker as worker
import numpy as np
import time
import subprocess32 as subprocess
import os
import arrays.single as single
import arrays.dist as dist
from google.protobuf.text_format import *
@@ -63,7 +64,7 @@ class ArraysSingleTest(unittest.TestCase):
time.sleep(0.2)
worker.connect(address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore_port), address(IP_ADDRESS, worker1_port))
orchpy.connect(address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore_port), address(IP_ADDRESS, worker1_port))
test_dir = os.path.dirname(os.path.abspath(__file__))
test_path = os.path.join(test_dir, "testrecv.py")
@@ -98,5 +99,46 @@ class ArraysSingleTest(unittest.TestCase):
services.cleanup()
class ArraysDistTest(unittest.TestCase):
def testMethods(self):
x = dist.DistArray()
x.construct([2, 3, 4], float, np.array([[[orchpy.lib.ObjRef(0)]]]))
capsule = serialization.serialize(x)
y = serialization.deserialize(capsule)
self.assertEqual(x.shape, y.shape)
self.assertEqual(x.dtype, y.dtype)
self.assertEqual(x.objrefs[0, 0, 0].val, y.objrefs[0, 0, 0].val)
def testAssemble(self):
scheduler_port = new_scheduler_port()
objstore_port = new_objstore_port()
worker1_port = new_worker_port()
worker2_port = new_worker_port()
services.start_scheduler(address(IP_ADDRESS, scheduler_port))
time.sleep(0.1)
services.start_objstore(address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore_port))
time.sleep(0.2)
orchpy.connect(address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore_port), address(IP_ADDRESS, worker1_port))
test_dir = os.path.dirname(os.path.abspath(__file__))
test_path = os.path.join(test_dir, "testrecv.py")
services.start_worker(test_path, address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore_port), address(IP_ADDRESS, worker2_port))
time.sleep(0.2)
a = single.ones([dist.BLOCK_SIZE, dist.BLOCK_SIZE])
b = single.zeros([dist.BLOCK_SIZE, dist.BLOCK_SIZE])
x = dist.DistArray()
x.construct([2 * dist.BLOCK_SIZE, dist.BLOCK_SIZE], float, np.array([[a], [b]]))
self.assertTrue(np.alltrue(x.assemble() == np.vstack([np.ones([dist.BLOCK_SIZE, dist.BLOCK_SIZE]), np.zeros([dist.BLOCK_SIZE, dist.BLOCK_SIZE])])))
services.cleanup()
if __name__ == '__main__':
unittest.main()
+27 -26
View File
@@ -1,5 +1,6 @@
import unittest
import orchpy
import orchpy.serialization as serialization
import orchpy.services as services
import orchpy.worker as worker
import numpy as np
@@ -50,14 +51,14 @@ def new_objstore_port():
class SerializationTest(unittest.TestCase):
def roundTripTest(self, data):
serialized = orchpy.lib.serialize_object(data)
result = orchpy.lib.deserialize_object(serialized)
serialized = serialization.serialize(data)
result = serialization.deserialize(serialized)
self.assertEqual(data, result)
def numpyTypeTest(self, typ):
a = np.random.randint(0, 10, size=(100, 100)).astype(typ)
b = orchpy.lib.serialize_object(a)
c = orchpy.lib.deserialize_object(b)
b = serialization.serialize(a)
c = serialization.deserialize(b)
self.assertTrue((a == c).all())
def testSerialize(self):
@@ -68,8 +69,8 @@ class SerializationTest(unittest.TestCase):
self.roundTripTest((1.0, "hi"))
a = np.zeros((100, 100))
res = orchpy.lib.serialize_object(a)
b = orchpy.lib.deserialize_object(res)
res = serialization.serialize(a)
b = serialization.deserialize(res)
self.assertTrue((a == b).all())
self.numpyTypeTest('int8')
@@ -80,8 +81,8 @@ class SerializationTest(unittest.TestCase):
self.numpyTypeTest('float64')
a = np.array([[orchpy.lib.ObjRef(0), orchpy.lib.ObjRef(1)], [orchpy.lib.ObjRef(41), orchpy.lib.ObjRef(42)]])
capsule = orchpy.lib.serialize_object(a)
result = orchpy.lib.deserialize_object(capsule)
capsule = serialization.serialize(a)
result = serialization.deserialize(capsule)
self.assertTrue((a == result).all())
class OrchPyLibTest(unittest.TestCase):
@@ -101,7 +102,7 @@ class OrchPyLibTest(unittest.TestCase):
w = worker.Worker()
worker.connect(address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore_port), address(IP_ADDRESS, worker_port), w)
orchpy.connect(address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore_port), address(IP_ADDRESS, worker_port), w)
w.put_object(orchpy.lib.ObjRef(0), 'hello world')
result = w.get_object(orchpy.lib.ObjRef(0))
@@ -134,22 +135,22 @@ class ObjStoreTest(unittest.TestCase):
objstore2_stub = connect_to_objstore(IP_ADDRESS, objstore2_port)
worker1 = worker.Worker()
worker.connect(address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore1_port), address(IP_ADDRESS, worker1_port), worker1)
orchpy.connect(address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore1_port), address(IP_ADDRESS, worker1_port), worker1)
worker2 = worker.Worker()
worker.connect(address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore2_port), address(IP_ADDRESS, worker2_port), worker2)
orchpy.connect(address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore2_port), address(IP_ADDRESS, worker2_port), worker2)
# pushing and pulling an object shouldn't change it
for data in ["h", "h" * 10000, 0, 0.0]:
objref = worker.push(data, worker1)
result = worker.pull(objref, worker1)
objref = orchpy.push(data, worker1)
result = orchpy.pull(objref, worker1)
self.assertEqual(result, data)
# pushing an object, shipping it to another worker, and pulling it shouldn't change it
for data in ["h", "h" * 10000, 0, 0.0]:
objref = worker.push(data, worker1)
objref = orchpy.push(data, worker1)
response = objstore1_stub.DeliverObj(orchestra_pb2.DeliverObjRequest(objref=objref.val, objstore_address=address(IP_ADDRESS, objstore2_port)), TIMEOUT_SECONDS)
result = worker.pull(objref, worker2)
result = orchpy.pull(objref, worker2)
self.assertEqual(result, data)
services.cleanup()
@@ -176,7 +177,7 @@ class SchedulerTest(unittest.TestCase):
time.sleep(0.2)
worker1 = worker.Worker()
worker.connect(address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore_port), address(IP_ADDRESS, worker1_port), worker1)
orchpy.connect(address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore_port), address(IP_ADDRESS, worker1_port), worker1)
test_dir = os.path.dirname(os.path.abspath(__file__))
test_path = os.path.join(test_dir, "testrecv.py")
@@ -189,7 +190,7 @@ class SchedulerTest(unittest.TestCase):
time.sleep(0.2)
value_after = worker.pull(objref[0], worker1)
value_after = orchpy.pull(objref[0], worker1)
self.assertEqual(value_before, value_after)
time.sleep(0.1)
@@ -214,30 +215,30 @@ class WorkerTest(unittest.TestCase):
time.sleep(0.2)
worker1 = worker.Worker()
worker.connect(address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore_port), address(IP_ADDRESS, worker1_port), worker1)
orchpy.connect(address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore_port), address(IP_ADDRESS, worker1_port), worker1)
for i in range(100):
value_before = i * 10 ** 6
objref = worker.push(value_before, worker1)
value_after = worker.pull(objref, worker1)
objref = orchpy.push(value_before, worker1)
value_after = orchpy.pull(objref, worker1)
self.assertEqual(value_before, value_after)
for i in range(100):
value_before = i * 10 ** 6 * 1.0
objref = worker.push(value_before, worker1)
value_after = worker.pull(objref, worker1)
objref = orchpy.push(value_before, worker1)
value_after = orchpy.pull(objref, worker1)
self.assertEqual(value_before, value_after)
for i in range(100):
value_before = "h" * i
objref = worker.push(value_before, worker1)
value_after = worker.pull(objref, worker1)
objref = orchpy.push(value_before, worker1)
value_after = orchpy.pull(objref, worker1)
self.assertEqual(value_before, value_after)
for i in range(100):
value_before = [1] * i
objref = worker.push(value_before, worker1)
value_after = worker.pull(objref, worker1)
objref = orchpy.push(value_before, worker1)
value_after = orchpy.pull(objref, worker1)
self.assertEqual(value_before, value_after)
services.cleanup()