Changed how ray treats deserialization of custom classes (#333)

This commit is contained in:
Wapaul1
2016-08-01 15:38:05 -07:00
committed by Philipp Moritz
parent 98a508d6ca
commit 97b923a750
8 changed files with 20 additions and 29 deletions
+4 -7
View File
@@ -9,7 +9,7 @@ __all__ = ["BLOCK_SIZE", "DistArray", "assemble", "zeros", "ones", "copy",
BLOCK_SIZE = 10
class DistArray(object):
def construct(self, shape, objectids=None):
def __init__(self, shape, objectids=None):
self.shape = shape
self.ndim = len(shape)
self.num_blocks = [int(np.ceil(1.0 * a / BLOCK_SIZE)) for a in self.shape]
@@ -17,17 +17,14 @@ class DistArray(object):
if self.num_blocks != list(self.objectids.shape):
raise Exception("The fields `num_blocks` and `objectids` are inconsistent, `num_blocks` is {} and `objectids` has shape {}".format(self.num_blocks, list(self.objectids.shape)))
def deserialize(self, primitives):
@staticmethod
def deserialize(primitives):
(shape, objectids) = primitives
self.construct(shape, objectids)
return DistArray(shape, objectids)
def serialize(self):
return (self.shape, self.objectids)
def __init__(self, shape=None):
if shape is not None:
self.construct(shape)
@staticmethod
def compute_block_lower(index, shape):
if len(index) != len(shape):
+2 -6
View File
@@ -46,8 +46,6 @@ def tsqr(a):
current_rs = new_rs
assert len(current_rs) == 1, "len(current_rs) = " + str(len(current_rs))
q_result = DistArray()
# handle the special case in which the whole DistArray "a" fits in one block
# and has fewer rows than columns, this is a bit ugly so think about how to
# remove it
@@ -56,9 +54,8 @@ def tsqr(a):
else:
q_shape = [a.shape[0], a.shape[0]]
q_num_blocks = DistArray.compute_num_blocks(q_shape)
q_result = DistArray()
q_objectids = np.empty(q_num_blocks, dtype=object)
q_result.construct(q_shape, q_objectids)
q_result = DistArray(q_shape, q_objectids)
# reconstruct output
for i in range(num_blocks):
@@ -145,8 +142,7 @@ def qr(a):
k = min(m, n)
# we will store our scratch work in a_work
a_work = DistArray()
a_work.construct(a.shape, np.copy(a.objectids))
a_work = DistArray(a.shape, np.copy(a.objectids))
result_dtype = np.linalg.qr(ray.get(a.objectids[0, 0]))[0].dtype.name
r_res = ray.get(zeros.remote([k, n], result_dtype)) # TODO(rkn): It would be preferable not to get this right after creating it.
+1 -2
View File
@@ -12,6 +12,5 @@ def normal(shape):
objectids = np.empty(num_blocks, dtype=object)
for index in np.ndindex(*num_blocks):
objectids[index] = ra.random.normal.remote(DistArray.compute_block_shape(index, shape))
result = DistArray()
result.construct(shape, objectids)
result = DistArray(shape, objectids)
return result
+1 -2
View File
@@ -46,8 +46,7 @@ def from_primitive(primitive_obj):
# This code assumes that the type module.__dict__[type_name] knows how to deserialize itself
type_module, type_name = primitive_obj[0]
module = importlib.import_module(type_module)
obj = module.__dict__[type_name]()
obj.deserialize(primitive_obj[1])
obj = module.__dict__[type_name].deserialize(primitive_obj[1])
return obj
def is_arrow_serializable(value):
+4 -3
View File
@@ -30,7 +30,7 @@ class RayFailedObject(object):
error_message (str): The error message raised by the task that failed.
"""
def __init__(self, error_message=None):
def __init__(self, error_message):
"""Initialize a RayFailedObject.
Args:
@@ -39,7 +39,8 @@ class RayFailedObject(object):
"""
self.error_message = error_message
def deserialize(self, primitives):
@staticmethod
def deserialize(primitives):
"""Create a RayFailedObject from a primitive object.
This initializes a RayFailedObject from a primitive object created by the
@@ -52,7 +53,7 @@ class RayFailedObject(object):
Args:
primitives (str): The object's error message.
"""
self.error_message = primitives
return RayFailedObject(primitives)
def serialize(self):
"""Turn a RayFailedObject into a primitive object.