mirror of
https://github.com/wassname/DeepTime.git
synced 2026-06-27 21:20:13 +08:00
14 lines
385 B
Python
14 lines
385 B
Python
from json_tricks import dump, dumps, load, loads, strip_comments
|
|
|
|
def torch_encode(obj, primitives=False):
    """json_tricks encoder hook for ``torch.Tensor`` objects.

    Args:
        obj: Any object being encoded.
        primitives: When True, tensors are converted to (nested) Python lists
            of Python scalars.

    Returns:
        A nested list for a tensor when ``primitives`` is True; ``obj``
        unchanged for any non-tensor object.

    Raises:
        NotImplementedError: if ``obj`` is a tensor and ``primitives`` is
            False (no type-tagged tensor encoding is implemented).
    """
    # Import lazily so torch stays an optional dependency. If torch is not
    # installed, nothing can be a Tensor, so pass the object through untouched.
    try:
        from torch import Tensor
    except ImportError:
        return obj

    if not isinstance(obj, Tensor):
        return obj

    if primitives:
        # Tensor.numpy() raises for tensors that require grad, live on a
        # non-CPU device, or use dtypes numpy lacks (e.g. bfloat16);
        # detach().cpu().tolist() handles all of those.
        return obj.detach().cpu().tolist()

    # NotImplementedError is the exception class; NotImplemented is a
    # non-callable constant, and calling it raises TypeError instead.
    raise NotImplementedError(
        "torch_encode only supports tensors with primitives=True"
    )
|
|
|
|
def serialize(o):
    """Round-trip ``o`` through json_tricks and return the decoded result.

    The object is dumped with ``primitives=True`` and ``torch_encode`` as an
    extra object encoder, then the JSON text is parsed back with ``loads``.
    Presumably this yields plain JSON-compatible data (tensors as nested
    lists) -- NOTE(review): relies on json_tricks forwarding ``primitives``
    to the extra encoder; confirm against the installed json_tricks version.
    """
    encoders = [torch_encode]
    json_text = dumps(o, extra_obj_encoders=encoders, primitives=True)
    decoded = loads(json_text)
    return decoded
|