mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 09:29:05 +08:00
change remote function invocation from func() to func.remote() (#328)
This commit is contained in:
committed by
Philipp Moritz
parent
92f1976e94
commit
0e5b858324
@@ -87,14 +87,14 @@ def numpy_to_dist(a):
|
||||
def zeros(shape, dtype_name="float"):
|
||||
result = DistArray(shape)
|
||||
for index in np.ndindex(*result.num_blocks):
|
||||
result.objrefs[index] = ra.zeros(DistArray.compute_block_shape(index, shape), dtype_name=dtype_name)
|
||||
result.objrefs[index] = ra.zeros.remote(DistArray.compute_block_shape(index, shape), dtype_name=dtype_name)
|
||||
return result
|
||||
|
||||
@ray.remote([List, str], [DistArray])
|
||||
def ones(shape, dtype_name="float"):
|
||||
result = DistArray(shape)
|
||||
for index in np.ndindex(*result.num_blocks):
|
||||
result.objrefs[index] = ra.ones(DistArray.compute_block_shape(index, shape), dtype_name=dtype_name)
|
||||
result.objrefs[index] = ra.ones.remote(DistArray.compute_block_shape(index, shape), dtype_name=dtype_name)
|
||||
return result
|
||||
|
||||
@ray.remote([DistArray], [DistArray])
|
||||
@@ -112,9 +112,9 @@ def eye(dim1, dim2=-1, dtype_name="float"):
|
||||
for (i, j) in np.ndindex(*result.num_blocks):
|
||||
block_shape = DistArray.compute_block_shape([i, j], shape)
|
||||
if i == j:
|
||||
result.objrefs[i, j] = ra.eye(block_shape[0], block_shape[1], dtype_name=dtype_name)
|
||||
result.objrefs[i, j] = ra.eye.remote(block_shape[0], block_shape[1], dtype_name=dtype_name)
|
||||
else:
|
||||
result.objrefs[i, j] = ra.zeros(block_shape, dtype_name=dtype_name)
|
||||
result.objrefs[i, j] = ra.zeros.remote(block_shape, dtype_name=dtype_name)
|
||||
return result
|
||||
|
||||
@ray.remote([DistArray], [DistArray])
|
||||
@@ -124,11 +124,11 @@ def triu(a):
|
||||
result = DistArray(a.shape)
|
||||
for (i, j) in np.ndindex(*result.num_blocks):
|
||||
if i < j:
|
||||
result.objrefs[i, j] = ra.copy(a.objrefs[i, j])
|
||||
result.objrefs[i, j] = ra.copy.remote(a.objrefs[i, j])
|
||||
elif i == j:
|
||||
result.objrefs[i, j] = ra.triu(a.objrefs[i, j])
|
||||
result.objrefs[i, j] = ra.triu.remote(a.objrefs[i, j])
|
||||
else:
|
||||
result.objrefs[i, j] = ra.zeros_like(a.objrefs[i, j])
|
||||
result.objrefs[i, j] = ra.zeros_like.remote(a.objrefs[i, j])
|
||||
return result
|
||||
|
||||
@ray.remote([DistArray], [DistArray])
|
||||
@@ -138,11 +138,11 @@ def tril(a):
|
||||
result = DistArray(a.shape)
|
||||
for (i, j) in np.ndindex(*result.num_blocks):
|
||||
if i > j:
|
||||
result.objrefs[i, j] = ra.copy(a.objrefs[i, j])
|
||||
result.objrefs[i, j] = ra.copy.remote(a.objrefs[i, j])
|
||||
elif i == j:
|
||||
result.objrefs[i, j] = ra.tril(a.objrefs[i, j])
|
||||
result.objrefs[i, j] = ra.tril.remote(a.objrefs[i, j])
|
||||
else:
|
||||
result.objrefs[i, j] = ra.zeros_like(a.objrefs[i, j])
|
||||
result.objrefs[i, j] = ra.zeros_like.remote(a.objrefs[i, j])
|
||||
return result
|
||||
|
||||
@ray.remote([np.ndarray], [np.ndarray])
|
||||
@@ -168,7 +168,7 @@ def dot(a, b):
|
||||
result = DistArray(shape)
|
||||
for (i, j) in np.ndindex(*result.num_blocks):
|
||||
args = list(a.objrefs[i, :]) + list(b.objrefs[:, j])
|
||||
result.objrefs[i, j] = blockwise_dot(*args)
|
||||
result.objrefs[i, j] = blockwise_dot.remote(*args)
|
||||
return result
|
||||
|
||||
@ray.remote([DistArray, List], [DistArray])
|
||||
@@ -208,7 +208,7 @@ def transpose(a):
|
||||
result = DistArray([a.shape[1], a.shape[0]])
|
||||
for i in range(result.num_blocks[0]):
|
||||
for j in range(result.num_blocks[1]):
|
||||
result.objrefs[i, j] = ra.transpose(a.objrefs[j, i])
|
||||
result.objrefs[i, j] = ra.transpose.remote(a.objrefs[j, i])
|
||||
return result
|
||||
|
||||
# TODO(rkn): support broadcasting?
|
||||
@@ -218,7 +218,7 @@ def add(x1, x2):
|
||||
raise Exception("add expects arguments `x1` and `x2` to have the same shape, but x1.shape = {}, and x2.shape = {}.".format(x1.shape, x2.shape))
|
||||
result = DistArray(x1.shape)
|
||||
for index in np.ndindex(*result.num_blocks):
|
||||
result.objrefs[index] = ra.add(x1.objrefs[index], x2.objrefs[index])
|
||||
result.objrefs[index] = ra.add.remote(x1.objrefs[index], x2.objrefs[index])
|
||||
return result
|
||||
|
||||
# TODO(rkn): support broadcasting?
|
||||
@@ -228,5 +228,5 @@ def subtract(x1, x2):
|
||||
raise Exception("subtract expects arguments `x1` and `x2` to have the same shape, but x1.shape = {}, and x2.shape = {}.".format(x1.shape, x2.shape))
|
||||
result = DistArray(x1.shape)
|
||||
for index in np.ndindex(*result.num_blocks):
|
||||
result.objrefs[index] = ra.subtract(x1.objrefs[index], x2.objrefs[index])
|
||||
result.objrefs[index] = ra.subtract.remote(x1.objrefs[index], x2.objrefs[index])
|
||||
return result
|
||||
|
||||
@@ -33,14 +33,14 @@ def tsqr(a):
|
||||
current_rs = []
|
||||
for i in range(num_blocks):
|
||||
block = a.objrefs[i, 0]
|
||||
q, r = ra.linalg.qr(block)
|
||||
q, r = ra.linalg.qr.remote(block)
|
||||
q_tree[i, 0] = q
|
||||
current_rs.append(r)
|
||||
for j in range(1, K):
|
||||
new_rs = []
|
||||
for i in range(int(np.ceil(1.0 * len(current_rs) / 2))):
|
||||
stacked_rs = ra.vstack(*current_rs[(2 * i):(2 * i + 2)])
|
||||
q, r = ra.linalg.qr(stacked_rs)
|
||||
stacked_rs = ra.vstack.remote(*current_rs[(2 * i):(2 * i + 2)])
|
||||
q, r = ra.linalg.qr.remote(stacked_rs)
|
||||
q_tree[i, j] = q
|
||||
new_rs.append(r)
|
||||
current_rs = new_rs
|
||||
@@ -72,7 +72,7 @@ def tsqr(a):
|
||||
lower = [a.shape[1], 0]
|
||||
upper = [2 * a.shape[1], BLOCK_SIZE]
|
||||
ith_index /= 2
|
||||
q_block_current = ra.dot(q_block_current, ra.subarray(q_tree[ith_index, j], lower, upper))
|
||||
q_block_current = ra.dot.remote(q_block_current, ra.subarray.remote(q_tree[ith_index, j], lower, upper))
|
||||
q_result.objrefs[i] = q_block_current
|
||||
r = current_rs[0]
|
||||
return q_result, r
|
||||
@@ -106,7 +106,7 @@ def modified_lu(q):
|
||||
for i in range(b):
|
||||
L[i, i] = 1
|
||||
U = np.triu(q_work)[:b, :]
|
||||
return numpy_to_dist(ray.put(L)), U, S # TODO(rkn): get rid of put
|
||||
return numpy_to_dist.remote(ray.put(L)), U, S # TODO(rkn): get rid of put
|
||||
|
||||
@ray.remote([np.ndarray, np.ndarray, np.ndarray, int], [np.ndarray, np.ndarray])
|
||||
def tsqr_hr_helper1(u, s, y_top_block, b):
|
||||
@@ -123,11 +123,11 @@ def tsqr_hr_helper2(s, r_temp):
|
||||
@ray.remote([DistArray], [DistArray, np.ndarray, np.ndarray, np.ndarray])
|
||||
def tsqr_hr(a):
|
||||
"""Algorithm 6 from http://www.eecs.berkeley.edu/Pubs/TechRpts/2013/EECS-2013-175.pdf"""
|
||||
q, r_temp = tsqr(a)
|
||||
y, u, s = modified_lu(q)
|
||||
q, r_temp = tsqr.remote(a)
|
||||
y, u, s = modified_lu.remote(q)
|
||||
y_blocked = ray.get(y)
|
||||
t, y_top = tsqr_hr_helper1(u, s, y_blocked.objrefs[0, 0], a.shape[1])
|
||||
r = tsqr_hr_helper2(s, r_temp)
|
||||
t, y_top = tsqr_hr_helper1.remote(u, s, y_blocked.objrefs[0, 0], a.shape[1])
|
||||
r = tsqr_hr_helper2.remote(s, r_temp)
|
||||
return y, t, y_top, r
|
||||
|
||||
@ray.remote([np.ndarray, np.ndarray, np.ndarray, np.ndarray], [np.ndarray])
|
||||
@@ -149,42 +149,42 @@ def qr(a):
|
||||
a_work.construct(a.shape, np.copy(a.objrefs))
|
||||
|
||||
result_dtype = np.linalg.qr(ray.get(a.objrefs[0, 0]))[0].dtype.name
|
||||
r_res = ray.get(zeros([k, n], result_dtype)) # TODO(rkn): It would be preferable not to get this right after creating it.
|
||||
y_res = ray.get(zeros([m, k], result_dtype)) # TODO(rkn): It would be preferable not to get this right after creating it.
|
||||
r_res = ray.get(zeros.remote([k, n], result_dtype)) # TODO(rkn): It would be preferable not to get this right after creating it.
|
||||
y_res = ray.get(zeros.remote([m, k], result_dtype)) # TODO(rkn): It would be preferable not to get this right after creating it.
|
||||
Ts = []
|
||||
|
||||
for i in range(min(a.num_blocks[0], a.num_blocks[1])): # this differs from the paper, which says "for i in range(a.num_blocks[1])", but that doesn't seem to make any sense when a.num_blocks[1] > a.num_blocks[0]
|
||||
sub_dist_array = subblocks(a_work, range(i, a_work.num_blocks[0]), [i])
|
||||
y, t, _, R = tsqr_hr(sub_dist_array)
|
||||
sub_dist_array = subblocks.remote(a_work, range(i, a_work.num_blocks[0]), [i])
|
||||
y, t, _, R = tsqr_hr.remote(sub_dist_array)
|
||||
y_val = ray.get(y)
|
||||
|
||||
for j in range(i, a.num_blocks[0]):
|
||||
y_res.objrefs[j, i] = y_val.objrefs[j - i, 0]
|
||||
if a.shape[0] > a.shape[1]:
|
||||
# in this case, R needs to be square
|
||||
R_shape = ray.get(ra.shape(R))
|
||||
eye_temp = ra.eye(R_shape[1], R_shape[0], dtype_name=result_dtype)
|
||||
r_res.objrefs[i, i] = ra.dot(eye_temp, R)
|
||||
R_shape = ray.get(ra.shape.remote(R))
|
||||
eye_temp = ra.eye.remote(R_shape[1], R_shape[0], dtype_name=result_dtype)
|
||||
r_res.objrefs[i, i] = ra.dot.remote(eye_temp, R)
|
||||
else:
|
||||
r_res.objrefs[i, i] = R
|
||||
Ts.append(numpy_to_dist(t))
|
||||
Ts.append(numpy_to_dist.remote(t))
|
||||
|
||||
for c in range(i + 1, a.num_blocks[1]):
|
||||
W_rcs = []
|
||||
for r in range(i, a.num_blocks[0]):
|
||||
y_ri = y_val.objrefs[r - i, 0]
|
||||
W_rcs.append(qr_helper2(y_ri, a_work.objrefs[r, c]))
|
||||
W_c = ra.sum_list(*W_rcs)
|
||||
W_rcs.append(qr_helper2.remote(y_ri, a_work.objrefs[r, c]))
|
||||
W_c = ra.sum_list.remote(*W_rcs)
|
||||
for r in range(i, a.num_blocks[0]):
|
||||
y_ri = y_val.objrefs[r - i, 0]
|
||||
A_rc = qr_helper1(a_work.objrefs[r, c], y_ri, t, W_c)
|
||||
A_rc = qr_helper1.remote(a_work.objrefs[r, c], y_ri, t, W_c)
|
||||
a_work.objrefs[r, c] = A_rc
|
||||
r_res.objrefs[i, c] = a_work.objrefs[i, c]
|
||||
|
||||
# construct q_res from Ys and Ts
|
||||
q = eye(m, k, dtype_name=result_dtype)
|
||||
q = eye.remote(m, k, dtype_name=result_dtype)
|
||||
for i in range(len(Ts))[::-1]:
|
||||
y_col_block = subblocks(y_res, [], [i])
|
||||
q = subtract(q, dot(y_col_block, dot(Ts[i], dot(transpose(y_col_block), q))))
|
||||
y_col_block = subblocks.remote(y_res, [], [i])
|
||||
q = subtract.remote(q, dot.remote(y_col_block, dot.remote(Ts[i], dot.remote(transpose.remote(y_col_block), q))))
|
||||
|
||||
return q, r_res
|
||||
|
||||
@@ -11,7 +11,7 @@ def normal(shape):
|
||||
num_blocks = DistArray.compute_num_blocks(shape)
|
||||
objrefs = np.empty(num_blocks, dtype=object)
|
||||
for index in np.ndindex(*num_blocks):
|
||||
objrefs[index] = ra.random.normal(DistArray.compute_block_shape(index, shape))
|
||||
objrefs[index] = ra.random.normal.remote(DistArray.compute_block_shape(index, shape))
|
||||
result = DistArray()
|
||||
result.construct(shape, objrefs)
|
||||
return result
|
||||
|
||||
+14
-10
@@ -836,20 +836,24 @@ def remote(arg_types, return_types, worker=global_worker):
|
||||
start_time = time.time()
|
||||
result = func(*arguments)
|
||||
end_time = time.time()
|
||||
check_return_values(func_call, result) # throws an exception if result is invalid
|
||||
check_return_values(func_invoker, result) # throws an exception if result is invalid
|
||||
_logger().info("Finished executing function {}, it took {} seconds".format(func.__name__, end_time - start_time))
|
||||
return result
|
||||
func_call.executor = func_executor
|
||||
func_call.arg_types = arg_types
|
||||
func_call.return_types = return_types
|
||||
func_call.is_remote = True
|
||||
def func_invoker(*args, **kwargs):
|
||||
"""This is returned by the decorator and used to invoke the function."""
|
||||
raise Exception("Remote functions cannot be called directly. Instead of running '{}()', try '{}.remote()'.".format(func_name, func_name))
|
||||
func_invoker.remote = func_call
|
||||
func_invoker.executor = func_executor
|
||||
func_invoker.arg_types = arg_types
|
||||
func_invoker.return_types = return_types
|
||||
func_invoker.is_remote = True
|
||||
func_name = "{}.{}".format(func.__module__, func.__name__)
|
||||
func_call.func_name = func_name
|
||||
func_call.func_doc = func.func_doc
|
||||
func_invoker.func_name = func_name
|
||||
func_invoker.func_doc = func.func_doc
|
||||
sig_params = [(k, v) for k, v in funcsigs.signature(func).parameters.iteritems()]
|
||||
keyword_defaults = [(k, v.default) for k, v in sig_params]
|
||||
has_vararg_param = any([v.kind == v.VAR_POSITIONAL for k, v in sig_params])
|
||||
func_call.has_vararg_param = has_vararg_param
|
||||
func_invoker.has_vararg_param = has_vararg_param
|
||||
has_kwargs_param = any([v.kind == v.VAR_KEYWORD for k, v in sig_params])
|
||||
check_signature_supported(has_kwargs_param, has_vararg_param, keyword_defaults, func_name)
|
||||
|
||||
@@ -858,7 +862,7 @@ def remote(arg_types, return_types, worker=global_worker):
|
||||
func_name_global_valid = func.__name__ in func.__globals__
|
||||
func_name_global_value = func.__globals__.get(func.__name__)
|
||||
# Set the function globally to make it refer to itself
|
||||
func.__globals__[func.__name__] = func_call # Allow the function to reference itself as a global variable
|
||||
func.__globals__[func.__name__] = func_invoker # Allow the function to reference itself as a global variable
|
||||
try:
|
||||
to_export = pickling.dumps((func, arg_types, return_types, func.__module__))
|
||||
finally:
|
||||
@@ -869,7 +873,7 @@ def remote(arg_types, return_types, worker=global_worker):
|
||||
ray.lib.export_function(worker.handle, to_export)
|
||||
elif worker.mode is None:
|
||||
worker.cached_remote_functions.append(to_export)
|
||||
return func_call
|
||||
return func_invoker
|
||||
return remote_decorator
|
||||
|
||||
def check_signature_supported(has_kwargs_param, has_vararg_param, keyword_defaults, name):
|
||||
|
||||
Reference in New Issue
Block a user