mirror of
https://github.com/wassname/seq2seq-time.git
synced 2026-06-28 22:37:59 +08:00
11 lines
236 B
Python
from pathlib import Path
|
|
import torch
|
|
|
|
# Directory two levels above this file (this file's parent's parent);
# presumably the project root — confirm against the repo layout.
# Not resolved: if __file__ is relative, this path is relative too.
project_dir = Path(__file__).parent.parent
|
|
|
|
def to_numpy(x):
    """Convert a torch tensor to a NumPy array; return anything else unchanged.

    Helper to avoid repeating ``x.detach().cpu().numpy()`` at every call site.

    Args:
        x: A ``torch.Tensor`` (any device, with or without autograd history)
            or any other object.

    Returns:
        A ``numpy.ndarray`` when ``x`` is a tensor, otherwise ``x`` itself.
        ``bfloat16`` tensors are upcast to ``float32`` because NumPy has no
        bfloat16 dtype.
    """
    if isinstance(x, torch.Tensor):
        # Detach first so the device-to-host copy is not recorded by autograd.
        x = x.detach().cpu()
        # Tensor.numpy() raises TypeError for bfloat16; upcast losslessly.
        if x.dtype == torch.bfloat16:
            x = x.float()
        x = x.numpy()
    return x
|