# Mirror of https://github.com/wassname/seq2seq-time.git
# (synced 2026-06-27 19:16:40 +08:00; 15 lines, 373 B, Python)
from pathlib import Path
import torch
# Two directory levels above this file: presumably the repository root,
# assuming this module sits one package below it -- confirm if the file moves.
# Not resolved: it inherits whatever form __file__ has (absolute or relative).
project_dir: Path = Path(__file__).parent.parent


def to_numpy(x):
    """Convert a torch tensor to a NumPy array; pass anything else through.

    Args:
        x: Any value. ``torch.Tensor`` inputs are detached from autograd,
            moved to the CPU and converted; all other values are returned
            unchanged.

    Returns:
        A ``numpy.ndarray`` if ``x`` is a tensor (it may share memory with
        ``x`` when ``x`` already lives on the CPU), otherwise ``x`` itself.
    """
    if isinstance(x, torch.Tensor):
        # Detach *before* moving to the CPU so the device-to-host copy is not
        # recorded in the autograd graph. ``.numpy()`` also refuses tensors
        # that require grad, so the detach is required, not optional.
        x = x.detach().cpu().numpy()
    return x


def mask_upper_triangular(N, device=None):
    """Build a causal (look-ahead) attention mask.

    The result is an ``(N, N)`` boolean tensor whose entry ``[i, j]`` is
    ``True`` exactly when ``j > i``. That is the strict upper triangle,
    i.e. the "future" positions relative to row ``i``. This matches the
    boolean ``attn_mask`` convention of ``torch.nn.MultiheadAttention``,
    where ``True`` means "not allowed to attend".

    Args:
        N: Sequence length (number of rows and columns).
        device: Device to allocate the mask on. ``None`` uses the current
            default torch device.

    Returns:
        ``torch.Tensor`` of dtype ``torch.bool`` and shape ``(N, N)``.
    """
    # Compare column index against row index to get the bool mask directly on
    # the target device, instead of materialising a float ones-matrix on the
    # CPU, copying it across and then casting it.
    idx = torch.arange(N, device=device)
    return idx.unsqueeze(0) > idx.unsqueeze(1)