Files
wassname 052fd6596c misc
2020-10-27 06:43:50 +08:00

26 lines
817 B
Python

from pathlib import Path
import torch
import xarray as xr
import logging
# Module logger named after the import path (e.g. "pkg.helpers") rather than the
# file path, so it sits in the standard logger hierarchy and picks up
# configuration applied to its parent package loggers.
logger = logging.getLogger(__name__)
# Directory two levels above this file; presumably the project root — TODO confirm.
project_dir = Path(__file__).parent.parent
def to_numpy(x):
    """Return ``x`` as a NumPy array if it is a torch tensor, otherwise unchanged.

    A tensor is detached from the autograd graph and brought to host memory
    before conversion, because ``Tensor.numpy()`` refuses tensors that require
    grad or live on a non-CPU device. Any non-tensor value is returned as-is.
    """
    if not isinstance(x, torch.Tensor):
        return x
    return x.detach().cpu().numpy()
def mask_upper_triangular(N, device=None):
    """Build a boolean (N, N) causal-attention mask.

    Entry ``[i, j]`` is True exactly when ``j > i`` (strictly above the main
    diagonal). Under a True-means-masked convention (as in
    ``torch.nn.MultiheadAttention``), these are the future positions that
    step ``i`` may not attend to.

    Args:
        N: number of positions (side length of the mask).
        device: torch device to create the mask on; ``None`` uses torch's
            default device.

    Returns:
        A ``torch.bool`` tensor of shape ``(N, N)`` on ``device``.
    """
    # Allocate directly on the target device instead of building on the CPU
    # and then copying, which costs a host-to-device transfer every call.
    ones = torch.ones(N, N, device=device)
    return torch.triu(ones, diagonal=1).bool()
def dset_to_nc(dset, f, engine="netcdf4", compression=None):
    """Write an xarray Dataset or DataArray to a netCDF file, compressing every variable.

    Args:
        dset: ``xr.Dataset`` or ``xr.DataArray``. A DataArray is first wrapped
            in a Dataset under the variable name ``"data"``.
        f: output path, as a ``str`` or ``pathlib.Path``.
        engine: backend name passed through to ``Dataset.to_netcdf``.
        compression: encoding options applied to each data variable.
            ``None`` (the default) means ``{"zlib": True}``; pass ``{}`` to
            disable compression.
    """
    if compression is None:
        compression = {"zlib": True}
    # Accept plain strings as well: the size report below needs Path.stem/.stat().
    f = Path(f)
    if isinstance(dset, xr.DataArray):
        dset = dset.to_dataset(name="data")
    # A fresh dict per variable, so the caller's ``compression`` object is
    # never aliased into (or shared across) the encoding entries.
    encoding = {k: dict(compression) for k in dset.data_vars}
    logger.info(f"saving to {f}")
    dset.to_netcdf(f, engine=engine, encoding=encoding)
    logger.info(f"Wrote {f.stem}.nc size={f.stat().st_size/1e6} M")