mirror of
https://github.com/wassname/seq2seq-time.git
synced 2026-06-27 20:22:59 +08:00
misc
This commit is contained in:
@@ -1,6 +1,9 @@
|
||||
from pathlib import Path
|
||||
import torch
|
||||
import xarray as xr
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__file__)
|
||||
project_dir = Path(__file__).parent.parent
|
||||
|
||||
def to_numpy(x):
|
||||
@@ -12,3 +15,11 @@ def to_numpy(x):
|
||||
def mask_upper_triangular(N, device):
|
||||
"""Causal attention."""
|
||||
return torch.triu(torch.ones(N, N), diagonal=1).to(device).bool()
|
||||
|
||||
def dset_to_nc(dset, f, engine="netcdf4", compression={"zlib": True}):
|
||||
if isinstance(dset, xr.DataArray):
|
||||
dset = dset.to_dataset(name="data")
|
||||
encoding = {k: {"zlib": True} for k in dset.data_vars}
|
||||
logger.info(f"saving to {f}")
|
||||
dset.to_netcdf(f, engine=engine, encoding=encoding)
|
||||
logger.info(f"Wrote {f.stem}.nc size={f.stat().st_size/1e6} M")
|
||||
|
||||
Reference in New Issue
Block a user