mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 21:12:15 +08:00
add custom callbacks for serialization
This commit is contained in:
@@ -59,6 +59,37 @@ class SerializationTests(unittest.TestCase):
|
||||
for obj in TEST_OBJECTS:
|
||||
self.roundTripTest([obj])
|
||||
|
||||
def testCallback(self):
|
||||
|
||||
class Foo(object):
|
||||
def __init__(self):
|
||||
self.x = 1
|
||||
|
||||
class Bar(object):
|
||||
def __init__(self):
|
||||
self.foo = Foo()
|
||||
|
||||
def serialize(obj):
|
||||
return dict(obj.__dict__, **{"_pytype_": type(obj).__name__})
|
||||
|
||||
def deserialize(obj):
|
||||
if obj["_pytype_"] == "Foo":
|
||||
result = Foo()
|
||||
elif obj["_pytype_"] == "Bar":
|
||||
result = Bar()
|
||||
|
||||
obj.pop("_pytype_", None)
|
||||
result.__dict__ = obj
|
||||
return result
|
||||
|
||||
bar = Bar()
|
||||
bar.foo.x = 42
|
||||
|
||||
libnumbuf.register_callbacks(serialize, deserialize)
|
||||
|
||||
metadata, size, serialized = libnumbuf.serialize_list([bar])
|
||||
self.assertEqual(libnumbuf.deserialize_list(serialized)[0].foo.x, 42)
|
||||
|
||||
def testBuffer(self):
|
||||
for (i, obj) in enumerate(TEST_OBJECTS):
|
||||
schema, size, batch = libnumbuf.serialize_list([obj])
|
||||
|
||||
Reference in New Issue
Block a user