mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 19:32:11 +08:00
@@ -0,0 +1,2 @@
|
||||
import random, linalg
|
||||
from core import zeros, eye, dot, vstack, hstack, subarray, copy, tril, triu
|
||||
@@ -0,0 +1,43 @@
|
||||
from typing import List
|
||||
import numpy as np
|
||||
import orchpy as op
|
||||
|
||||
@op.distributed([List[int]], [np.ndarray])
|
||||
def zeros(shape):
|
||||
return np.zeros(shape)
|
||||
|
||||
@op.distributed([int], [np.ndarray])
|
||||
def eye(dim):
|
||||
return np.eye(dim)
|
||||
|
||||
@op.distributed([np.ndarray, np.ndarray], [np.ndarray])
|
||||
def dot(a, b):
|
||||
return np.dot(a, b)
|
||||
|
||||
# TODO(rkn): My preferred signature would have been
|
||||
# @op.distributed([List[np.ndarray]], [np.ndarray]) but that currently doesn't
|
||||
# work because that would expect a list of ndarrays not a list of ObjRefs
|
||||
@op.distributed([np.ndarray, None], [np.ndarray])
|
||||
def vstack(*xs):
|
||||
return np.vstack(xs)
|
||||
|
||||
@op.distributed([np.ndarray, None], [np.ndarray])
|
||||
def hstack(*xs):
|
||||
return np.hstack(xs)
|
||||
|
||||
# TODO(rkn): this doesn't parallel the numpy API, but we can't really slice an ObjRef, think about this
|
||||
@op.distributed([np.ndarray, List[int], List[int]], [np.ndarray])
|
||||
def subarray(a, lower_indices, upper_indices): # TODO(rkn): be consistent about using "index" versus "indices"
|
||||
return a[[slice(l, u) for (l, u) in zip(lower_indices, upper_indices)]]
|
||||
|
||||
@op.distributed([np.ndarray], [np.ndarray])
|
||||
def copy(a):
|
||||
return np.copy(a)
|
||||
|
||||
@op.distributed([np.ndarray], [np.ndarray])
|
||||
def tril(a):
|
||||
return np.tril(a)
|
||||
|
||||
@op.distributed([np.ndarray], [np.ndarray])
|
||||
def triu(a):
|
||||
return np.triu(a)
|
||||
@@ -0,0 +1,50 @@
|
||||
from typing import List
|
||||
import numpy as np
|
||||
import orchpy as op
|
||||
|
||||
# TODO(rkn): this should take the same optional "mode" argument as np.linalg.qr, except that the different options sometimes have different numbers of return values, which could be a problem
|
||||
@op.distributed([np.ndarray], [np.ndarray, np.ndarray])
|
||||
def qr(a):
|
||||
"""
|
||||
Suppose (n, m) = a.shape
|
||||
If n >= m:
|
||||
q.shape == (n, m)
|
||||
r.shape == (m, m)
|
||||
If n < m:
|
||||
q.shape == (n, n)
|
||||
r.shape == (n, m)
|
||||
"""
|
||||
return np.linalg.qr(a)
|
||||
|
||||
#@op.distributed([np.ndarray], [np.ndarray, np.ndarray, np.ndarray])
|
||||
def modified_lu(q):
|
||||
"""
|
||||
Algorithm 5 from http://www.eecs.berkeley.edu/Pubs/TechRpts/2013/EECS-2013-175.pdf
|
||||
takes a matrix q with orthonormal columns, returns l, u, s such that q - s = l * u
|
||||
arguments:
|
||||
q: a two dimensional orthonormal q
|
||||
return values:
|
||||
l: lower triangular
|
||||
u: upper triangular
|
||||
s: a diagonal matrix represented by its diagonal
|
||||
"""
|
||||
m, b = q.shape[0], q.shape[1]
|
||||
S = np.zeros(b)
|
||||
|
||||
q_work = np.copy(q)
|
||||
|
||||
for i in range(b):
|
||||
S[i] = -1 * np.sign(q_work[i, i])
|
||||
q_work[i, i] -= S[i]
|
||||
|
||||
# scale ith column of L by diagonal element
|
||||
q_work[(i + 1):m, i] /= q_work[i, i]
|
||||
|
||||
# perform Schur complement update
|
||||
q_work[(i + 1):m, (i + 1):b] -= np.outer(q_work[(i + 1):m, i], q_work[i, (i + 1):b])
|
||||
|
||||
L = np.tril(q_work)
|
||||
for i in range(b):
|
||||
L[i, i] = 1
|
||||
U = np.triu(q_work)[:b, :]
|
||||
return L, U, S
|
||||
@@ -0,0 +1,7 @@
|
||||
from typing import List
|
||||
import numpy as np
|
||||
import orchpy as op
|
||||
|
||||
@op.distributed([List[int]], [np.ndarray])
|
||||
def normal(shape):
|
||||
return np.random.normal(size=shape)
|
||||
@@ -1 +1,2 @@
|
||||
import liborchpylib as lib
|
||||
from worker import register_module, connect, pull, push, distributed
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from types import ModuleType
|
||||
import typing
|
||||
|
||||
import orchpy
|
||||
@@ -28,11 +29,22 @@ class Worker(object):
|
||||
def remote_call(self, func_name, args):
|
||||
"""Tell the scheduler to schedule the execution of the function with name `func_name` with arguments `args`. Retrieve object references for the outputs of the function from the scheduler and immediately return them."""
|
||||
call_capsule = orchpy.lib.serialize_call(func_name, args)
|
||||
return orchpy.lib.remote_call(self.handle, call_capsule)
|
||||
objrefs = orchpy.lib.remote_call(self.handle, call_capsule)
|
||||
return objrefs
|
||||
|
||||
# We make `global_worker` a global variable so that there is one worker per worker process.
|
||||
global_worker = Worker()
|
||||
|
||||
def register_module(module, recursive=False, worker=global_worker):
|
||||
print "registering functions in module {}.".format(module.__name__)
|
||||
for name in dir(module):
|
||||
val = getattr(module, name)
|
||||
if hasattr(val, "is_distributed") and val.is_distributed:
|
||||
print "registering {}.".format(val.func_name)
|
||||
worker.register_function(val)
|
||||
# elif recursive and isinstance(val, ModuleType):
|
||||
# register_module(val, recursive, worker)
|
||||
|
||||
def connect(scheduler_addr, objstore_addr, worker_addr, worker=global_worker):
|
||||
if worker.connected:
|
||||
raise Exception("Worker called connect, but worker is already connected")
|
||||
@@ -70,24 +82,54 @@ def distributed(arg_types, return_types, worker=global_worker):
|
||||
return result
|
||||
def func_call(*args):
|
||||
"""This is what gets run immediately when a worker calls a distributed function."""
|
||||
# TODO(rkn): check types
|
||||
return worker.remote_call(func_call.func_name, list(args))
|
||||
check_arguments(func_call, list(args)) # throws an exception if args are invalid
|
||||
objrefs = worker.remote_call(func_call.func_name, list(args))
|
||||
return objrefs[0] if len(objrefs) == 1 else objrefs
|
||||
func_call.func_name = "{}.{}".format(func.__module__, func.__name__)
|
||||
func_call.executor = func_executor
|
||||
func_call.arg_types = arg_types
|
||||
func_call.return_types = return_types
|
||||
func_call.is_distributed = True
|
||||
return func_call
|
||||
return distributed_decorator
|
||||
|
||||
# helper method, this should not be called by the user
|
||||
def get_arguments_for_execution(function, args, worker=global_worker):
|
||||
arguments = []
|
||||
def check_arguments(function, args):
|
||||
# check the number of args
|
||||
if len(args) != len(function.arg_types) and function.arg_types[-1] is not None:
|
||||
raise Exception("Function {} expects {} arguments, but received {}.".format(function.__name__, len(function.arg_types), len(args)))
|
||||
elif len(args) < len(function.arg_types) - 1 and function.arg_types[-1] is None:
|
||||
raise Exception("Function {} expects at least {} arguments, but received {}.".format(function.__name__, len(function.arg_types) - 1, len(args)))
|
||||
|
||||
for (i, arg) in enumerate(args):
|
||||
if i < len(function.arg_types) - 1:
|
||||
expected_type = function.arg_types[i]
|
||||
elif i == len(function.arg_types) - 1 and function.arg_types[-1] is not None:
|
||||
expected_type = function.arg_types[-1]
|
||||
elif function.arg_types[-1] is None and len(function.arg_types > 1):
|
||||
expected_type = function.arg_types[-2]
|
||||
else:
|
||||
assert False, "This code should be unreachable."
|
||||
|
||||
if type(arg) == orchpy.lib.ObjRef:
|
||||
# TODO(rkn): When we have type information in the ObjRef, do type checking here.
|
||||
pass
|
||||
else:
|
||||
if not isinstance(arg, expected_type): # TODO(rkn): This check doesn't really work, e.g., isinstance([1,2,3], typing.List[str]) == True
|
||||
raise Exception("Argument {} for function {} has type {} but an argument of type {} was expected.".format(i, function.__name__, type(arg), expected_type))
|
||||
|
||||
# helper method, this should not be called by the user
|
||||
def get_arguments_for_execution(function, args, worker=global_worker):
|
||||
# TODO(rkn): Eventually, all of the type checking can be put in `check_arguments` above so that the error will happen immediately when calling a remote function.
|
||||
arguments = []
|
||||
"""
|
||||
# check the number of args
|
||||
if len(args) != len(function.arg_types) and function.arg_types[-1] is not None:
|
||||
raise Exception("Function {} expects {} arguments, but received {}.".format(function.__name__, len(function.arg_types), len(args)))
|
||||
elif len(args) < len(function.arg_types) - 1 and function.arg_types[-1] is None:
|
||||
raise Exception("Function {} expects at least {} arguments, but received {}.".format(function.__name__, len(function.arg_types) - 1, len(args)))
|
||||
"""
|
||||
|
||||
for (i, arg) in enumerate(args):
|
||||
print "Pulling argument {} for function {}.".format(i, function.__name__)
|
||||
if i < len(function.arg_types) - 1:
|
||||
@@ -108,8 +150,8 @@ def get_arguments_for_execution(function, args, worker=global_worker):
|
||||
# pass the argument by value
|
||||
argument = arg
|
||||
|
||||
if expected_type != type(argument):
|
||||
raise Exception("Argument {} for function {} has type {} but an argument of type {} was expected.".format(i, function.__name__, type(argument), arg_type))
|
||||
if not isinstance(argument, expected_type):
|
||||
raise Exception("Argument {} for function {} has type {} but an argument of type {} was expected.".format(i, function.__name__, type(argument), expected_type))
|
||||
arguments.append(argument)
|
||||
return arguments
|
||||
|
||||
|
||||
+4
-4
@@ -6,7 +6,7 @@ from Cython.Build import cythonize
|
||||
|
||||
# because of relative paths, this must be run from inside orch/lib/orchpy/
|
||||
|
||||
MACOSX = (sys.platform in ['darwin'])
|
||||
MACOSX = (sys.platform in ["darwin"])
|
||||
|
||||
setup(
|
||||
name = "orchestra",
|
||||
@@ -14,9 +14,9 @@ setup(
|
||||
use_2to3=True,
|
||||
packages=find_packages(),
|
||||
package_data = {
|
||||
'orchpy': ['liborchpylib.dylib' if MACOSX else 'liborchpylib.so',
|
||||
'scheduler',
|
||||
'objstore']
|
||||
"orchpy": ["liborchpylib.dylib" if MACOSX else "liborchpylib.so",
|
||||
"scheduler",
|
||||
"objstore"]
|
||||
},
|
||||
zip_safe=False
|
||||
)
|
||||
|
||||
@@ -0,0 +1,108 @@
|
||||
import unittest
|
||||
import orchpy
|
||||
import orchpy.services as services
|
||||
import orchpy.worker as worker
|
||||
import numpy as np
|
||||
import time
|
||||
import subprocess32 as subprocess
|
||||
import os
|
||||
|
||||
import arrays.single as single
|
||||
|
||||
from google.protobuf.text_format import *
|
||||
|
||||
from grpc.beta import implementations
|
||||
import orchestra_pb2
|
||||
import types_pb2
|
||||
|
||||
IP_ADDRESS = "127.0.0.1"
|
||||
TIMEOUT_SECONDS = 5
|
||||
|
||||
def connect_to_scheduler(host, port):
|
||||
channel = implementations.insecure_channel(host, port)
|
||||
return orchestra_pb2.beta_create_Scheduler_stub(channel)
|
||||
|
||||
def connect_to_objstore(host, port):
|
||||
channel = implementations.insecure_channel(host, port)
|
||||
return orchestra_pb2.beta_create_ObjStore_stub(channel)
|
||||
|
||||
def address(host, port):
|
||||
return host + ":" + str(port)
|
||||
|
||||
scheduler_port_counter = 0
|
||||
def new_scheduler_port():
|
||||
global scheduler_port_counter
|
||||
scheduler_port_counter += 1
|
||||
return 10000 + scheduler_port_counter
|
||||
|
||||
worker_port_counter = 0
|
||||
def new_worker_port():
|
||||
global worker_port_counter
|
||||
worker_port_counter += 1
|
||||
return 40000 + worker_port_counter
|
||||
|
||||
objstore_port_counter = 0
|
||||
def new_objstore_port():
|
||||
global objstore_port_counter
|
||||
objstore_port_counter += 1
|
||||
return 20000 + objstore_port_counter
|
||||
|
||||
class ArraysSingleTest(unittest.TestCase):
|
||||
|
||||
def testMethods(self):
|
||||
scheduler_port = new_scheduler_port()
|
||||
objstore_port = new_objstore_port()
|
||||
worker1_port = new_worker_port()
|
||||
worker2_port = new_worker_port()
|
||||
|
||||
services.start_scheduler(address(IP_ADDRESS, scheduler_port))
|
||||
|
||||
time.sleep(0.1)
|
||||
|
||||
services.start_objstore(address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore_port))
|
||||
|
||||
time.sleep(0.2)
|
||||
|
||||
worker.connect(address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore_port), address(IP_ADDRESS, worker1_port))
|
||||
|
||||
test_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
test_path = os.path.join(test_dir, "testrecv.py")
|
||||
services.start_worker(test_path, address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore_port), address(IP_ADDRESS, worker2_port))
|
||||
|
||||
time.sleep(0.2)
|
||||
|
||||
# test eye
|
||||
ref = single.eye(3)
|
||||
time.sleep(0.2)
|
||||
val = orchpy.pull(ref)
|
||||
self.assertTrue(np.alltrue(val == np.eye(3)))
|
||||
|
||||
# test zeros
|
||||
ref = single.zeros([3, 4, 5])
|
||||
time.sleep(0.2)
|
||||
val = orchpy.pull(ref)
|
||||
self.assertTrue(np.alltrue(val == np.zeros([3, 4, 5])))
|
||||
|
||||
# test qr - pass by value
|
||||
val_a = np.random.normal(size=[10, 13])
|
||||
time.sleep(0.2)
|
||||
ref_q, ref_r = single.linalg.qr(val_a)
|
||||
time.sleep(0.2)
|
||||
val_q = orchpy.pull(ref_q)
|
||||
val_r = orchpy.pull(ref_r)
|
||||
self.assertTrue(np.allclose(np.dot(val_q, val_r), val_a))
|
||||
|
||||
# test qr - pass by objref
|
||||
a = single.random.normal([10, 13])
|
||||
time.sleep(0.2) # TODO(rkn): fails without this sleep
|
||||
ref_q, ref_r = single.linalg.qr(a)
|
||||
time.sleep(0.2)
|
||||
val_a = orchpy.pull(a)
|
||||
val_q = orchpy.pull(ref_q)
|
||||
val_r = orchpy.pull(ref_r)
|
||||
self.assertTrue(np.allclose(np.dot(val_q, val_r), val_a))
|
||||
|
||||
services.cleanup()
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
+4
-1
@@ -7,6 +7,8 @@ import time
|
||||
import subprocess32 as subprocess
|
||||
import os
|
||||
|
||||
import arrays.single as single
|
||||
|
||||
from google.protobuf.text_format import *
|
||||
|
||||
from grpc.beta import implementations
|
||||
@@ -85,6 +87,8 @@ class OrchPyLibTest(unittest.TestCase):
|
||||
|
||||
self.assertEqual(result, 'hello world')
|
||||
|
||||
services.cleanup()
|
||||
|
||||
class ObjStoreTest(unittest.TestCase):
|
||||
|
||||
"""Test setting up object stores, transfering data between them and retrieving data to a client"""
|
||||
@@ -217,6 +221,5 @@ class WorkerTest(unittest.TestCase):
|
||||
|
||||
services.cleanup()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -1,8 +1,12 @@
|
||||
import argparse
|
||||
|
||||
import orchpy
|
||||
import orchpy.services as services
|
||||
import orchpy.worker as worker
|
||||
|
||||
import arrays.single as single
|
||||
# import arrays.dist as dist
|
||||
|
||||
from grpc.beta import implementations
|
||||
import orchestra_pb2
|
||||
import types_pb2
|
||||
@@ -14,6 +18,17 @@ parser.add_argument("--scheduler-address", default="127.0.0.1:10001", type=str,
|
||||
parser.add_argument("--objstore-address", default="127.0.0.1:20001", type=str, help="the objstore's address")
|
||||
parser.add_argument("--worker-address", default="127.0.0.1:40001", type=str, help="the worker's address")
|
||||
|
||||
@orchpy.distributed([str], [str])
|
||||
def print_string(string):
|
||||
print "called print_string with", string
|
||||
f = open("asdfasdf.txt", "w")
|
||||
f.write("successfully called print_string with argument {}.".format(string))
|
||||
return string
|
||||
|
||||
@orchpy.distributed([int, int], [int, int])
|
||||
def handle_int(a, b):
|
||||
return a + 1, b + 1
|
||||
|
||||
def connect_to_scheduler(host, port):
|
||||
channel = implementations.insecure_channel(host, port)
|
||||
return orchestra_pb2.beta_create_Scheduler_stub(channel)
|
||||
|
||||
+11
-4
@@ -1,5 +1,9 @@
|
||||
import sys
|
||||
import argparse
|
||||
|
||||
import arrays.single as single
|
||||
# import arrays.dist as dist
|
||||
|
||||
import orchpy
|
||||
import orchpy.services as services
|
||||
import orchpy.worker as worker
|
||||
@@ -9,14 +13,14 @@ parser.add_argument("--scheduler-address", default="127.0.0.1:10001", type=str,
|
||||
parser.add_argument("--objstore-address", default="127.0.0.1:20001", type=str, help="the objstore's address")
|
||||
parser.add_argument("--worker-address", default="127.0.0.1:40001", type=str, help="the worker's address")
|
||||
|
||||
@worker.distributed([str], [str])
|
||||
@orchpy.distributed([str], [str])
|
||||
def print_string(string):
|
||||
print "called print_string with", string
|
||||
f = open("asdfasdf.txt", "w")
|
||||
f.write("successfully called print_string with argument {}.".format(string))
|
||||
return string
|
||||
|
||||
@worker.distributed([int, int], [int, int])
|
||||
@orchpy.distributed([int, int], [int, int])
|
||||
def handle_int(a, b):
|
||||
return a + 1, b + 1
|
||||
|
||||
@@ -24,7 +28,10 @@ if __name__ == '__main__':
|
||||
args = parser.parse_args()
|
||||
worker.connect(args.scheduler_address, args.objstore_address, args.worker_address)
|
||||
|
||||
worker.global_worker.register_function(print_string)
|
||||
worker.global_worker.register_function(handle_int)
|
||||
orchpy.register_module(single)
|
||||
orchpy.register_module(single.random)
|
||||
orchpy.register_module(single.linalg)
|
||||
# orchpy.register_module(dist)
|
||||
orchpy.register_module(sys.modules[__name__])
|
||||
|
||||
worker.main_loop()
|
||||
|
||||
Reference in New Issue
Block a user