Files
seq2seq-time/seq2seq_time/util.py
T
2020-10-23 08:30:02 +08:00

15 lines
373 B
Python

from pathlib import Path
import torch
# Parent of the directory that contains this module file
# (presumably the repository root -- confirm against the project layout).
project_dir = Path(__file__).parent.parent
def to_numpy(x):
    """Convert a ``torch.Tensor`` to a NumPy array; pass anything else through.

    Args:
        x: Any object. Only ``torch.Tensor`` instances are converted.

    Returns:
        A NumPy array holding the tensor's values when ``x`` is a
        ``torch.Tensor``; otherwise ``x`` itself, unchanged.
    """
    if isinstance(x, torch.Tensor):
        # Detach first so the move to the CPU is not recorded by autograd.
        return x.detach().cpu().numpy()
    return x
def mask_upper_triangular(N, device):
    """Build an ``N x N`` boolean mask for causal attention.

    Entries strictly above the main diagonal are ``True``; the diagonal and
    everything below it are ``False``.

    Args:
        N: Side length of the square mask.
        device: Device on which the mask is created.

    Returns:
        A boolean ``torch.Tensor`` of shape ``(N, N)`` located on ``device``.
    """
    # Allocate directly on the target device instead of building the matrix
    # on the CPU and copying it over afterwards.
    ones = torch.ones(N, N, device=device)
    return torch.triu(ones, diagonal=1).bool()