mirror of
https://github.com/wassname/seq2seq-time.git
synced 2026-06-27 19:16:40 +08:00
26 lines
817 B
Python
26 lines
817 B
Python
from pathlib import Path
|
|
import torch
|
|
import xarray as xr
|
|
import logging
|
|
|
|
# Module-level logger named after the module (standard logging convention), so
# it takes part in the dotted logger hierarchy and inherits package-level config.
logger = logging.getLogger(__name__)

# Two directories above this file -- presumably the project root; verify if the
# module is relocated.
project_dir = Path(__file__).parent.parent
|
|
|
|
def to_numpy(x):
    """Return ``x`` as a NumPy array if it is a torch tensor, else unchanged.

    Tensors are moved to the CPU and detached from the autograd graph before
    conversion, so GPU tensors and tensors that require grad are both accepted.
    Any non-tensor value is passed straight through.
    """
    if not isinstance(x, torch.Tensor):
        return x
    return x.cpu().detach().numpy()
|
|
|
|
def mask_upper_triangular(N, device):
    """Build an ``N x N`` boolean causal-attention mask on ``device``.

    Entry ``[i, j]`` is ``True`` exactly when ``j > i`` (strictly above the
    diagonal), flagging the later positions that step ``i`` should not attend
    to; the diagonal and everything below it are ``False``.
    """
    positions = torch.arange(N, device=device)
    # Compare column indices (row vector) against row indices (column vector);
    # broadcasting yields the strict upper triangle directly as a bool tensor.
    return positions.unsqueeze(0) > positions.unsqueeze(1)
|
|
|
|
def dset_to_nc(dset, f, engine="netcdf4", compression=None):
    """Write an xarray Dataset (or DataArray) to a compressed netCDF file.

    Args:
        dset: ``xr.Dataset`` or ``xr.DataArray``. A DataArray is first wrapped
            in a Dataset under the variable name ``"data"``.
        f: Output path, as ``str`` or ``pathlib.Path``.
        engine: Backend passed through to ``Dataset.to_netcdf``.
        compression: Encoding options applied to every data variable.
            Defaults to ``{"zlib": True}`` (the previous hard-coded behavior).
    """
    # None sentinel instead of a dict literal: avoids a shared mutable default.
    if compression is None:
        compression = {"zlib": True}
    # Accept plain strings as well; .stat()/.name below require a Path.
    f = Path(f)

    if isinstance(dset, xr.DataArray):
        dset = dset.to_dataset(name="data")

    # Honor the caller's compression settings (previously ignored), giving each
    # variable its own copy so the encoding dicts are never aliased.
    encoding = {k: dict(compression) for k in dset.data_vars}

    logger.info("saving to %s", f)
    dset.to_netcdf(f, engine=engine, encoding=encoding)

    size_mb = f.stat().st_size / 1e6
    logger.info("Wrote %s size=%.2f MB", f.name, size_mb)
|