Files
seq2seq-time/seq2seq_time/util.py
T
2020-10-18 13:12:09 +08:00

11 lines
236 B
Python

from pathlib import Path
import torch
# Directory two levels above this module's file (the parent of the directory
# that contains this file). Presumably the project/repository root, per the
# name — confirm against callers. Note: the path is not resolved, so it is
# only absolute if ``__file__`` is absolute.
project_dir = Path(__file__).parent.parent
def to_numpy(x):
    """Convert a ``torch.Tensor`` to a NumPy array; pass anything else through.

    Helper to avoid repeating the detach/cpu/numpy chain at every call site.

    Args:
        x: Any object. Only ``torch.Tensor`` instances are converted.

    Returns:
        A ``numpy.ndarray`` if ``x`` is a tensor, otherwise ``x`` unchanged.
        For a tensor already on the CPU the returned array shares memory
        with ``x`` (``Tensor.numpy`` does not copy), so in-place edits to
        one are visible in the other.
    """
    if not isinstance(x, torch.Tensor):
        return x
    # Detach *before* the device move so the copy is not recorded in the
    # autograd graph; ``numpy()`` refuses tensors that still require grad.
    return x.detach().cpu().numpy()