lightning

This commit is contained in:
wassname
2020-01-25 14:40:23 +08:00
parent 84dc7dca27
commit f1664a89cb
6 changed files with 696 additions and 5 deletions
+1
View File
@@ -1,5 +1,6 @@
/data/
/old_notebooks/
/lightning_logs/
# Created by https://www.gitignore.io/api/code,linux,macos,python,windows,jupyternotebook,jupyternotebooks
# Edit at https://www.gitignore.io/?templates=code,linux,macos,python,windows,jupyternotebook,jupyternotebooks
+1
View File
@@ -3,3 +3,4 @@ tqdm
pandas
numpy
torchsummaryX
pytorch_lightning
File diff suppressed because one or more lines are too long
+9 -4
View File
@@ -11,7 +11,7 @@ def npsample_batch(x, y, size=None, sort=True):
return x[:, inds], y[:, inds]
def collate_fns(max_num_context, max_num_extra_target, sample, sort=True):
def collate_fn(batch):
def collate_fn(batch, sample=sample):
# Collate
x = np.stack([x for x, y in batch], 0)
y = np.stack([y for x, y in batch], 0)
@@ -52,7 +52,7 @@ class SmartMeterDataSet(torch.utils.data.Dataset):
self.num_extra_target = num_extra_target
self.label_names = label_names
def __getitem__(self, i):
def get_rows(self, i):
rows = self.df.iloc[i : i + (self.num_context + self.num_extra_target)].copy()
rows['tstp'] = (rows['tstp'] - rows['tstp'].iloc[0]).dt.total_seconds() / 86400.0
rows = rows.sort_values('tstp')
@@ -61,9 +61,14 @@ class SmartMeterDataSet(torch.utils.data.Dataset):
columns = ['tstp'] + list(set(rows.columns) - set(['tstp']))
rows = rows[columns]
x = rows.drop(columns=self.label_names).values
y = rows[self.label_names].values
x = rows.drop(columns=self.label_names)
y = rows[self.label_names]
return x, y
def __getitem__(self, i):
x,y = self.get_rows(i)
return x.values, y.values
def __len__(self):
return len(self.df) - (self.num_context + self.num_extra_target)
+2 -1
View File
@@ -52,7 +52,8 @@ class LatentModel(nn.Module):
attention_dropout=0,
min_std=0.1,
use_lvar=False,
use_deterministic_path=True
use_deterministic_path=True,
**kwargs
):
super().__init__()