mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 19:32:38 +08:00
push/pull -> put/get
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
import libraylib as lib
|
||||
import serialization
|
||||
from worker import scheduler_info, task_info, register_module, connect, disconnect, pull, push, remote
|
||||
from worker import scheduler_info, task_info, register_module, connect, disconnect, get, put, remote
|
||||
from libraylib import ObjRef
|
||||
|
||||
@@ -55,13 +55,13 @@ class DistArray(object):
|
||||
|
||||
def assemble(self):
|
||||
"""Assemble an array on this node from a distributed array object reference."""
|
||||
first_block = ray.pull(self.objrefs[(0,) * self.ndim])
|
||||
first_block = ray.get(self.objrefs[(0,) * self.ndim])
|
||||
dtype = first_block.dtype
|
||||
result = np.zeros(self.shape, dtype=dtype)
|
||||
for index in np.ndindex(*self.num_blocks):
|
||||
lower = DistArray.compute_block_lower(index, self.shape)
|
||||
upper = DistArray.compute_block_upper(index, self.shape)
|
||||
result[[slice(l, u) for (l, u) in zip(lower, upper)]] = ray.pull(self.objrefs[index])
|
||||
result[[slice(l, u) for (l, u) in zip(lower, upper)]] = ray.get(self.objrefs[index])
|
||||
return result
|
||||
|
||||
def __getitem__(self, sliced):
|
||||
@@ -80,7 +80,7 @@ def numpy_to_dist(a):
|
||||
for index in np.ndindex(*result.num_blocks):
|
||||
lower = DistArray.compute_block_lower(index, a.shape)
|
||||
upper = DistArray.compute_block_upper(index, a.shape)
|
||||
result.objrefs[index] = ray.push(a[[slice(l, u) for (l, u) in zip(lower, upper)]])
|
||||
result.objrefs[index] = ray.put(a[[slice(l, u) for (l, u) in zip(lower, upper)]])
|
||||
return result
|
||||
|
||||
@ray.remote([List[int], str], [DistArray])
|
||||
|
||||
@@ -17,10 +17,10 @@ def tsqr(a):
|
||||
a.shape == (M, N)
|
||||
K == min(M, N)
|
||||
return values:
|
||||
q: DistArray, if q_full = ray.context.pull(DistArray, q).assemble(), then
|
||||
q: DistArray, if q_full = ray.get(DistArray, q).assemble(), then
|
||||
q_full.shape == (M, K)
|
||||
np.allclose(np.dot(q_full.T, q_full), np.eye(K)) == True
|
||||
r: np.ndarray, if r_val = ray.context.pull(np.ndarray, r), then
|
||||
r: np.ndarray, if r_val = ray.get(np.ndarray, r), then
|
||||
r_val.shape == (K, N)
|
||||
np.allclose(r, np.triu(r)) == True
|
||||
"""
|
||||
@@ -108,7 +108,7 @@ def modified_lu(q):
|
||||
for i in range(b):
|
||||
L[i, i] = 1
|
||||
U = np.triu(q_work)[:b, :]
|
||||
return numpy_to_dist(ray.push(L)), U, S # TODO(rkn): get rid of push and pull
|
||||
return numpy_to_dist(ray.put(L)), U, S # TODO(rkn): get rid of put
|
||||
|
||||
@ray.remote([np.ndarray, np.ndarray, np.ndarray, int], [np.ndarray, np.ndarray])
|
||||
def tsqr_hr_helper1(u, s, y_top_block, b):
|
||||
@@ -127,7 +127,7 @@ def tsqr_hr(a):
|
||||
"""Algorithm 6 from http://www.eecs.berkeley.edu/Pubs/TechRpts/2013/EECS-2013-175.pdf"""
|
||||
q, r_temp = tsqr(a)
|
||||
y, u, s = modified_lu(q)
|
||||
y_blocked = ray.pull(y)
|
||||
y_blocked = ray.get(y)
|
||||
t, y_top = tsqr_hr_helper1(u, s, y_blocked.objrefs[0, 0], a.shape[1])
|
||||
r = tsqr_hr_helper2(s, r_temp)
|
||||
return y, t, y_top, r
|
||||
@@ -150,21 +150,21 @@ def qr(a):
|
||||
a_work = DistArray()
|
||||
a_work.construct(a.shape, np.copy(a.objrefs))
|
||||
|
||||
result_dtype = np.linalg.qr(ray.pull(a.objrefs[0, 0]))[0].dtype.name
|
||||
r_res = ray.pull(zeros([k, n], result_dtype)) # TODO(rkn): It would be preferable not to pull this right after creating it.
|
||||
y_res = ray.pull(zeros([m, k], result_dtype)) # TODO(rkn): It would be preferable not to pull this right after creating it.
|
||||
result_dtype = np.linalg.qr(ray.get(a.objrefs[0, 0]))[0].dtype.name
|
||||
r_res = ray.get(zeros([k, n], result_dtype)) # TODO(rkn): It would be preferable not to get this right after creating it.
|
||||
y_res = ray.get(zeros([m, k], result_dtype)) # TODO(rkn): It would be preferable not to get this right after creating it.
|
||||
Ts = []
|
||||
|
||||
for i in range(min(a.num_blocks[0], a.num_blocks[1])): # this differs from the paper, which says "for i in range(a.num_blocks[1])", but that doesn't seem to make any sense when a.num_blocks[1] > a.num_blocks[0]
|
||||
sub_dist_array = subblocks(a_work, range(i, a_work.num_blocks[0]), [i])
|
||||
y, t, _, R = tsqr_hr(sub_dist_array)
|
||||
y_val = ray.pull(y)
|
||||
y_val = ray.get(y)
|
||||
|
||||
for j in range(i, a.num_blocks[0]):
|
||||
y_res.objrefs[j, i] = y_val.objrefs[j - i, 0]
|
||||
if a.shape[0] > a.shape[1]:
|
||||
# in this case, R needs to be square
|
||||
R_shape = ray.pull(ra.shape(R))
|
||||
R_shape = ray.get(ra.shape(R))
|
||||
eye_temp = ra.eye(R_shape[1], R_shape[0], dtype_name=result_dtype)
|
||||
r_res.objrefs[i, i] = ra.dot(eye_temp, R)
|
||||
else:
|
||||
|
||||
@@ -72,7 +72,7 @@ class Worker(object):
|
||||
elif result == None:
|
||||
return None # can't subclass None and don't need to because there is a global None
|
||||
# TODO(pcm): close the associated memory segment; if we don't, this leaks memory (but very little, so it is ok for now)
|
||||
result.ray_objref = objref # TODO(pcm): This could be done only for the "pull" case in the future if we want to increase performance
|
||||
result.ray_objref = objref # TODO(pcm): This could be done only for the "get" case in the future if we want to increase performance
|
||||
result.ray_deallocator = RayDealloc(self.handle, segmentid)
|
||||
return result
|
||||
|
||||
@@ -141,13 +141,13 @@ def connect(scheduler_addr, objstore_addr, worker_addr, worker=global_worker, pr
|
||||
def disconnect(worker=global_worker):
|
||||
ray.lib.disconnect(worker.handle)
|
||||
|
||||
def pull(objref, worker=global_worker):
|
||||
def get(objref, worker=global_worker):
|
||||
ray.lib.request_object(worker.handle, objref)
|
||||
if worker.print_task_info:
|
||||
print_task_info(ray.lib.task_info(worker.handle))
|
||||
return worker.get_object(objref)
|
||||
|
||||
def push(value, worker=global_worker):
|
||||
def put(value, worker=global_worker):
|
||||
objref = ray.lib.get_objref(worker.handle)
|
||||
worker.put_object(objref, value)
|
||||
if worker.print_task_info:
|
||||
|
||||
Reference in New Issue
Block a user