mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 20:06:31 +08:00
Vendored
+1
@@ -0,0 +1 @@
|
||||
from core import DistArray, BLOCK_SIZE
|
||||
Vendored
+60
@@ -0,0 +1,60 @@
|
||||
from typing import List
|
||||
import numpy as np
|
||||
import arrays.single as single
|
||||
import orchpy as op
|
||||
|
||||
BLOCK_SIZE = 10
|
||||
|
||||
class DistArray(object):
|
||||
def construct(self, shape, dtype, objrefs):
|
||||
self.shape = shape
|
||||
self.dtype = dtype
|
||||
self.objrefs = objrefs
|
||||
self.ndim = len(shape)
|
||||
self.num_blocks = [int(np.ceil(1.0 * a / BLOCK_SIZE)) for a in self.shape]
|
||||
if self.num_blocks != list(self.objrefs.shape):
|
||||
raise Exception("The fields `num_blocks` and `objrefs` are inconsistent, `num_blocks` is {} and `objrefs` has shape {}".format(self.num_blocks, list(self.objrefs.shape)))
|
||||
|
||||
def deserialize(self, primitives):
|
||||
(shape, dtype_name, objrefs) = primitives
|
||||
self.construct(shape, np.dtype(dtype_name), objrefs)
|
||||
|
||||
def serialize(self):
|
||||
return (self.shape, self.dtype.__name__, self.objrefs)
|
||||
|
||||
def __init__(self):
|
||||
self.shape = None
|
||||
self.dtype = None
|
||||
self.objrefs = None
|
||||
|
||||
def compute_block_lower(self, index):
|
||||
if len(index) != self.ndim:
|
||||
raise Exception("The value `index` equals {}, but `ndim` is {}.".format(index, self.ndim))
|
||||
return [elem * BLOCK_SIZE for elem in index]
|
||||
|
||||
def compute_block_upper(self, index):
|
||||
if len(index) != self.ndim:
|
||||
raise Exception("The value `index` equals {}, but `ndim` is {}.".format(index, self.ndim))
|
||||
upper = []
|
||||
for i in range(self.ndim):
|
||||
upper.append(min((index[i] + 1) * BLOCK_SIZE, self.shape[i]))
|
||||
return upper
|
||||
|
||||
def compute_block_shape(self, index):
|
||||
lower = self.compute_block_lower(index)
|
||||
upper = self.compute_block_upper(index)
|
||||
return [u - l for (l, u) in zip(lower, upper)]
|
||||
|
||||
def assemble(self):
|
||||
"""Assemble an array on this node from a distributed array object reference."""
|
||||
result = np.zeros(self.shape)
|
||||
for index in np.ndindex(*self.num_blocks):
|
||||
lower = self.compute_block_lower(index)
|
||||
upper = self.compute_block_upper(index)
|
||||
result[[slice(l, u) for (l, u) in zip(lower, upper)]] = op.pull(self.objrefs[index])
|
||||
return result
|
||||
|
||||
def __getitem__(self, sliced):
|
||||
# TODO(rkn): fix this, this is just a placeholder that should work but is inefficient
|
||||
a = self.assemble()
|
||||
return a[sliced]
|
||||
@@ -1,2 +1,2 @@
|
||||
import random, linalg
|
||||
from core import zeros, eye, dot, vstack, hstack, subarray, copy, tril, triu
|
||||
from core import zeros, ones, eye, dot, vstack, hstack, subarray, copy, tril, triu
|
||||
|
||||
@@ -6,6 +6,10 @@ import orchpy as op
|
||||
def zeros(shape):
|
||||
return np.zeros(shape)
|
||||
|
||||
@op.distributed([List[int]], [np.ndarray])
|
||||
def ones(shape):
|
||||
return np.ones(shape)
|
||||
|
||||
@op.distributed([int], [np.ndarray])
|
||||
def eye(dim):
|
||||
return np.eye(dim)
|
||||
|
||||
@@ -1,2 +1,3 @@
|
||||
import liborchpylib as lib
|
||||
import serialization
|
||||
from worker import register_module, connect, pull, push, distributed
|
||||
|
||||
@@ -0,0 +1,26 @@
|
||||
import importlib
|
||||
|
||||
import orchpy
|
||||
|
||||
def serialize(obj):
|
||||
if hasattr(obj, "serialize"):
|
||||
primitive_obj = ((type(obj).__module__, type(obj).__name__), obj.serialize())
|
||||
else:
|
||||
# TODO(rkn): Right now we don't handle arbitrary python objects, but later
|
||||
# we can unpack the fields of a python object into a list and call
|
||||
# orchpy.lib.serialize_object.
|
||||
primitive_obj = ("primitive", obj)
|
||||
return orchpy.lib.serialize_object(primitive_obj)
|
||||
|
||||
def deserialize(capsule):
|
||||
primitive_obj = orchpy.lib.deserialize_object(capsule)
|
||||
if primitive_obj[0] == "primitive":
|
||||
return primitive_obj[1]
|
||||
else:
|
||||
# assert primitive_obj[0] must be a tuple of module and class name
|
||||
type_module, type_name = primitive_obj[0]
|
||||
module = importlib.import_module(type_module)
|
||||
if hasattr(module.__dict__[type_name], "deserialize"):
|
||||
obj = module.__dict__[type_name]()
|
||||
obj.deserialize(primitive_obj[1])
|
||||
return obj
|
||||
@@ -2,6 +2,7 @@ from types import ModuleType
|
||||
import typing
|
||||
|
||||
import orchpy
|
||||
import serialization
|
||||
|
||||
class Worker(object):
|
||||
"""The methods in this class are considered unexposed to the user. The functions outside of this class are considered exposed."""
|
||||
@@ -13,13 +14,13 @@ class Worker(object):
|
||||
|
||||
def put_object(self, objref, value):
|
||||
"""Put `value` in the local object store with objref `objref`. This assumes that the value for `objref` has not yet been placed in the local object store."""
|
||||
object_capsule = orchpy.lib.serialize_object(value)
|
||||
object_capsule = serialization.serialize(value)
|
||||
orchpy.lib.put_object(self.handle, objref, object_capsule)
|
||||
|
||||
def get_object(self, objref):
|
||||
"""Return the value from the local object store for objref `objref`. This will block until the value for `objref` has been written to the local object store."""
|
||||
object_capsule = orchpy.lib.get_object(self.handle, objref)
|
||||
return orchpy.lib.deserialize_object(object_capsule)
|
||||
return serialization.deserialize(object_capsule)
|
||||
|
||||
def register_function(self, function):
|
||||
"""Notify the scheduler that this worker can execute the function with name `func_name`. Store the function `function` locally."""
|
||||
@@ -47,16 +48,16 @@ def register_module(module, recursive=False, worker=global_worker):
|
||||
|
||||
def connect(scheduler_addr, objstore_addr, worker_addr, worker=global_worker):
|
||||
if worker.connected:
|
||||
raise Exception("Worker called connect, but worker is already connected")
|
||||
del worker.handle # TODO(rkn): Make sure this actually deallocates (need a destructor for the capsule)
|
||||
worker.handle = orchpy.lib.create_worker(scheduler_addr, objstore_addr, worker_addr)
|
||||
worker.connected = True
|
||||
|
||||
def pull(objref, worker=global_worker):
|
||||
object_capsule = orchpy.lib.pull_object(worker.handle, objref)
|
||||
return orchpy.lib.deserialize_object(object_capsule)
|
||||
return serialization.deserialize(object_capsule)
|
||||
|
||||
def push(value, worker=global_worker):
|
||||
object_capsule = orchpy.lib.serialize_object(value)
|
||||
object_capsule = serialization.serialize(value)
|
||||
return orchpy.lib.push_object(worker.handle, object_capsule)
|
||||
|
||||
def main_loop(worker=global_worker):
|
||||
|
||||
+44
-2
@@ -1,13 +1,14 @@
|
||||
import unittest
|
||||
import orchpy
|
||||
import orchpy.serialization as serialization
|
||||
import orchpy.services as services
|
||||
import orchpy.worker as worker
|
||||
import numpy as np
|
||||
import time
|
||||
import subprocess32 as subprocess
|
||||
import os
|
||||
|
||||
import arrays.single as single
|
||||
import arrays.dist as dist
|
||||
|
||||
from google.protobuf.text_format import *
|
||||
|
||||
@@ -63,7 +64,7 @@ class ArraysSingleTest(unittest.TestCase):
|
||||
|
||||
time.sleep(0.2)
|
||||
|
||||
worker.connect(address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore_port), address(IP_ADDRESS, worker1_port))
|
||||
orchpy.connect(address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore_port), address(IP_ADDRESS, worker1_port))
|
||||
|
||||
test_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
test_path = os.path.join(test_dir, "testrecv.py")
|
||||
@@ -98,5 +99,46 @@ class ArraysSingleTest(unittest.TestCase):
|
||||
|
||||
services.cleanup()
|
||||
|
||||
class ArraysDistTest(unittest.TestCase):
|
||||
|
||||
def testMethods(self):
|
||||
x = dist.DistArray()
|
||||
x.construct([2, 3, 4], float, np.array([[[orchpy.lib.ObjRef(0)]]]))
|
||||
capsule = serialization.serialize(x)
|
||||
y = serialization.deserialize(capsule)
|
||||
self.assertEqual(x.shape, y.shape)
|
||||
self.assertEqual(x.dtype, y.dtype)
|
||||
self.assertEqual(x.objrefs[0, 0, 0].val, y.objrefs[0, 0, 0].val)
|
||||
|
||||
def testAssemble(self):
|
||||
scheduler_port = new_scheduler_port()
|
||||
objstore_port = new_objstore_port()
|
||||
worker1_port = new_worker_port()
|
||||
worker2_port = new_worker_port()
|
||||
|
||||
services.start_scheduler(address(IP_ADDRESS, scheduler_port))
|
||||
|
||||
time.sleep(0.1)
|
||||
|
||||
services.start_objstore(address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore_port))
|
||||
|
||||
time.sleep(0.2)
|
||||
|
||||
orchpy.connect(address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore_port), address(IP_ADDRESS, worker1_port))
|
||||
|
||||
test_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
test_path = os.path.join(test_dir, "testrecv.py")
|
||||
services.start_worker(test_path, address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore_port), address(IP_ADDRESS, worker2_port))
|
||||
|
||||
time.sleep(0.2)
|
||||
|
||||
a = single.ones([dist.BLOCK_SIZE, dist.BLOCK_SIZE])
|
||||
b = single.zeros([dist.BLOCK_SIZE, dist.BLOCK_SIZE])
|
||||
x = dist.DistArray()
|
||||
x.construct([2 * dist.BLOCK_SIZE, dist.BLOCK_SIZE], float, np.array([[a], [b]]))
|
||||
self.assertTrue(np.alltrue(x.assemble() == np.vstack([np.ones([dist.BLOCK_SIZE, dist.BLOCK_SIZE]), np.zeros([dist.BLOCK_SIZE, dist.BLOCK_SIZE])])))
|
||||
|
||||
services.cleanup()
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
+27
-26
@@ -1,5 +1,6 @@
|
||||
import unittest
|
||||
import orchpy
|
||||
import orchpy.serialization as serialization
|
||||
import orchpy.services as services
|
||||
import orchpy.worker as worker
|
||||
import numpy as np
|
||||
@@ -50,14 +51,14 @@ def new_objstore_port():
|
||||
class SerializationTest(unittest.TestCase):
|
||||
|
||||
def roundTripTest(self, data):
|
||||
serialized = orchpy.lib.serialize_object(data)
|
||||
result = orchpy.lib.deserialize_object(serialized)
|
||||
serialized = serialization.serialize(data)
|
||||
result = serialization.deserialize(serialized)
|
||||
self.assertEqual(data, result)
|
||||
|
||||
def numpyTypeTest(self, typ):
|
||||
a = np.random.randint(0, 10, size=(100, 100)).astype(typ)
|
||||
b = orchpy.lib.serialize_object(a)
|
||||
c = orchpy.lib.deserialize_object(b)
|
||||
b = serialization.serialize(a)
|
||||
c = serialization.deserialize(b)
|
||||
self.assertTrue((a == c).all())
|
||||
|
||||
def testSerialize(self):
|
||||
@@ -68,8 +69,8 @@ class SerializationTest(unittest.TestCase):
|
||||
self.roundTripTest((1.0, "hi"))
|
||||
|
||||
a = np.zeros((100, 100))
|
||||
res = orchpy.lib.serialize_object(a)
|
||||
b = orchpy.lib.deserialize_object(res)
|
||||
res = serialization.serialize(a)
|
||||
b = serialization.deserialize(res)
|
||||
self.assertTrue((a == b).all())
|
||||
|
||||
self.numpyTypeTest('int8')
|
||||
@@ -80,8 +81,8 @@ class SerializationTest(unittest.TestCase):
|
||||
self.numpyTypeTest('float64')
|
||||
|
||||
a = np.array([[orchpy.lib.ObjRef(0), orchpy.lib.ObjRef(1)], [orchpy.lib.ObjRef(41), orchpy.lib.ObjRef(42)]])
|
||||
capsule = orchpy.lib.serialize_object(a)
|
||||
result = orchpy.lib.deserialize_object(capsule)
|
||||
capsule = serialization.serialize(a)
|
||||
result = serialization.deserialize(capsule)
|
||||
self.assertTrue((a == result).all())
|
||||
|
||||
class OrchPyLibTest(unittest.TestCase):
|
||||
@@ -101,7 +102,7 @@ class OrchPyLibTest(unittest.TestCase):
|
||||
|
||||
w = worker.Worker()
|
||||
|
||||
worker.connect(address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore_port), address(IP_ADDRESS, worker_port), w)
|
||||
orchpy.connect(address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore_port), address(IP_ADDRESS, worker_port), w)
|
||||
|
||||
w.put_object(orchpy.lib.ObjRef(0), 'hello world')
|
||||
result = w.get_object(orchpy.lib.ObjRef(0))
|
||||
@@ -134,22 +135,22 @@ class ObjStoreTest(unittest.TestCase):
|
||||
objstore2_stub = connect_to_objstore(IP_ADDRESS, objstore2_port)
|
||||
|
||||
worker1 = worker.Worker()
|
||||
worker.connect(address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore1_port), address(IP_ADDRESS, worker1_port), worker1)
|
||||
orchpy.connect(address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore1_port), address(IP_ADDRESS, worker1_port), worker1)
|
||||
|
||||
worker2 = worker.Worker()
|
||||
worker.connect(address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore2_port), address(IP_ADDRESS, worker2_port), worker2)
|
||||
orchpy.connect(address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore2_port), address(IP_ADDRESS, worker2_port), worker2)
|
||||
|
||||
# pushing and pulling an object shouldn't change it
|
||||
for data in ["h", "h" * 10000, 0, 0.0]:
|
||||
objref = worker.push(data, worker1)
|
||||
result = worker.pull(objref, worker1)
|
||||
objref = orchpy.push(data, worker1)
|
||||
result = orchpy.pull(objref, worker1)
|
||||
self.assertEqual(result, data)
|
||||
|
||||
# pushing an object, shipping it to another worker, and pulling it shouldn't change it
|
||||
for data in ["h", "h" * 10000, 0, 0.0]:
|
||||
objref = worker.push(data, worker1)
|
||||
objref = orchpy.push(data, worker1)
|
||||
response = objstore1_stub.DeliverObj(orchestra_pb2.DeliverObjRequest(objref=objref.val, objstore_address=address(IP_ADDRESS, objstore2_port)), TIMEOUT_SECONDS)
|
||||
result = worker.pull(objref, worker2)
|
||||
result = orchpy.pull(objref, worker2)
|
||||
self.assertEqual(result, data)
|
||||
|
||||
services.cleanup()
|
||||
@@ -176,7 +177,7 @@ class SchedulerTest(unittest.TestCase):
|
||||
time.sleep(0.2)
|
||||
|
||||
worker1 = worker.Worker()
|
||||
worker.connect(address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore_port), address(IP_ADDRESS, worker1_port), worker1)
|
||||
orchpy.connect(address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore_port), address(IP_ADDRESS, worker1_port), worker1)
|
||||
|
||||
test_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
test_path = os.path.join(test_dir, "testrecv.py")
|
||||
@@ -189,7 +190,7 @@ class SchedulerTest(unittest.TestCase):
|
||||
|
||||
time.sleep(0.2)
|
||||
|
||||
value_after = worker.pull(objref[0], worker1)
|
||||
value_after = orchpy.pull(objref[0], worker1)
|
||||
self.assertEqual(value_before, value_after)
|
||||
|
||||
time.sleep(0.1)
|
||||
@@ -214,30 +215,30 @@ class WorkerTest(unittest.TestCase):
|
||||
time.sleep(0.2)
|
||||
|
||||
worker1 = worker.Worker()
|
||||
worker.connect(address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore_port), address(IP_ADDRESS, worker1_port), worker1)
|
||||
orchpy.connect(address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore_port), address(IP_ADDRESS, worker1_port), worker1)
|
||||
|
||||
for i in range(100):
|
||||
value_before = i * 10 ** 6
|
||||
objref = worker.push(value_before, worker1)
|
||||
value_after = worker.pull(objref, worker1)
|
||||
objref = orchpy.push(value_before, worker1)
|
||||
value_after = orchpy.pull(objref, worker1)
|
||||
self.assertEqual(value_before, value_after)
|
||||
|
||||
for i in range(100):
|
||||
value_before = i * 10 ** 6 * 1.0
|
||||
objref = worker.push(value_before, worker1)
|
||||
value_after = worker.pull(objref, worker1)
|
||||
objref = orchpy.push(value_before, worker1)
|
||||
value_after = orchpy.pull(objref, worker1)
|
||||
self.assertEqual(value_before, value_after)
|
||||
|
||||
for i in range(100):
|
||||
value_before = "h" * i
|
||||
objref = worker.push(value_before, worker1)
|
||||
value_after = worker.pull(objref, worker1)
|
||||
objref = orchpy.push(value_before, worker1)
|
||||
value_after = orchpy.pull(objref, worker1)
|
||||
self.assertEqual(value_before, value_after)
|
||||
|
||||
for i in range(100):
|
||||
value_before = [1] * i
|
||||
objref = worker.push(value_before, worker1)
|
||||
value_after = worker.pull(objref, worker1)
|
||||
objref = orchpy.push(value_before, worker1)
|
||||
value_after = orchpy.pull(objref, worker1)
|
||||
self.assertEqual(value_before, value_after)
|
||||
|
||||
services.cleanup()
|
||||
|
||||
Reference in New Issue
Block a user