mirror of
https://github.com/wassname/attentive-neural-processes.git
synced 2026-06-27 18:03:39 +08:00
lightning
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
/data/
|
||||
/old_notebooks/
|
||||
/lightning_logs/
|
||||
|
||||
# Created by https://www.gitignore.io/api/code,linux,macos,python,windows,jupyternotebook,jupyternotebooks
|
||||
# Edit at https://www.gitignore.io/?templates=code,linux,macos,python,windows,jupyternotebook,jupyternotebooks
|
||||
|
||||
@@ -3,3 +3,4 @@ tqdm
|
||||
pandas
|
||||
numpy
|
||||
torchsummaryX
|
||||
pytorch_lightning
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -11,7 +11,7 @@ def npsample_batch(x, y, size=None, sort=True):
|
||||
return x[:, inds], y[:, inds]
|
||||
|
||||
def collate_fns(max_num_context, max_num_extra_target, sample, sort=True):
|
||||
def collate_fn(batch):
|
||||
def collate_fn(batch, sample=sample):
|
||||
# Collate
|
||||
x = np.stack([x for x, y in batch], 0)
|
||||
y = np.stack([y for x, y in batch], 0)
|
||||
@@ -52,7 +52,7 @@ class SmartMeterDataSet(torch.utils.data.Dataset):
|
||||
self.num_extra_target = num_extra_target
|
||||
self.label_names = label_names
|
||||
|
||||
def __getitem__(self, i):
|
||||
def get_rows(self, i):
|
||||
rows = self.df.iloc[i : i + (self.num_context + self.num_extra_target)].copy()
|
||||
rows['tstp'] = (rows['tstp'] - rows['tstp'].iloc[0]).dt.total_seconds() / 86400.0
|
||||
rows = rows.sort_values('tstp')
|
||||
@@ -61,9 +61,14 @@ class SmartMeterDataSet(torch.utils.data.Dataset):
|
||||
columns = ['tstp'] + list(set(rows.columns) - set(['tstp']))
|
||||
rows = rows[columns]
|
||||
|
||||
x = rows.drop(columns=self.label_names).values
|
||||
y = rows[self.label_names].values
|
||||
x = rows.drop(columns=self.label_names)
|
||||
y = rows[self.label_names]
|
||||
return x, y
|
||||
|
||||
|
||||
def __getitem__(self, i):
|
||||
x,y = self.get_rows(i)
|
||||
return x.values, y.values
|
||||
|
||||
def __len__(self):
|
||||
return len(self.df) - (self.num_context + self.num_extra_target)
|
||||
|
||||
+2
-1
@@ -52,7 +52,8 @@ class LatentModel(nn.Module):
|
||||
attention_dropout=0,
|
||||
min_std=0.1,
|
||||
use_lvar=False,
|
||||
use_deterministic_path=True
|
||||
use_deterministic_path=True,
|
||||
**kwargs
|
||||
):
|
||||
|
||||
super().__init__()
|
||||
|
||||
Reference in New Issue
Block a user