This commit is contained in:
wassname
2020-04-26 12:36:19 +08:00
parent 67cb91bb83
commit a1c26dfbb7
7 changed files with 29 additions and 24 deletions
+2 -2
View File
@@ -193,8 +193,8 @@ def get_smartmeter_df(indir=Path('./data/smart-meters-in-london'), max_files=60,
test_files = [f for f in csv_files if is_test(f)]
val_files = [f for f in csv_files if is_val(f) and (not is_test(f))]
train_files = [f for f in csv_files if (not is_val(f)) and (not is_test(f))]
print(len(train_files), len(val_files), len(test_files))
print(train_files, val_files, test_files)
# print(len(train_files), len(val_files), len(test_files))
# print(train_files, val_files, test_files)
assert not set(train_files).intersection(set(test_files), set(val_files))
assert not set(test_files).intersection(set(val_files))
+1 -1
View File
@@ -57,7 +57,7 @@ class PL_Seq2Seq(pl.LightningModule):
train_outputs = agg_logs(self.train_logs)
self.train_logs = []
print(f"step val {self.trainer.global_step}, {outputs} {train_outputs}")
logger.info(f"step val {self.trainer.global_step}, {outputs} {train_outputs}")
# tensorboard_logs_str = {k: f"{v}" for k, v in tensorboard_logs.items()}
# print(f"step {self.trainer.global_step}, {outputs}")
+1 -2
View File
@@ -38,7 +38,6 @@ class Seq2SeqNet(nn.Module):
def __init__(self, hparams, _min_std=0.05):
super().__init__()
hparams = hparams_power(hparams)
print(hparams)
self.hparams = hparams
self._min_std = _min_std
@@ -107,7 +106,7 @@ class Seq2SeqNet(nn.Module):
if output is not None:
num_layers = h_out.shape[1]
print(cell.max(), h_out.max(), h.max())
# print(cell.max(), h_out.max(), h.max())
h_out += h.mean(1).repeat(1, num_layers, 1)
cell += h.max(1).repeat(1, num_layers, 1)
@@ -17,15 +17,15 @@ class PL_NeuralProcess(PL_Seq2Seq):
DEFAULT_ARGS = {
'dropout': 0.1,
'learning_rate': 0.004,
'learning_rate': 0.003,
'attention_dropout': 0.5,
'batchnorm': False,
'attention_layers': 2,
'det_enc_cross_attn_type': 'uniform',
'det_enc_self_attn_type': 'uniform',
'latent_enc_self_attn_type': 'uniform',
'num_heads_power': 3,
'hidden_dim_power': 3,
'num_heads_power': 2,
'hidden_dim_power': 6,
'latent_dim_power': 5,
'n_latent_encoder_layers': 3,
'n_det_encoder_layers': 3,
@@ -309,11 +309,11 @@ class NeuralProcess(nn.Module):
device = next(self.parameters()).device
if self.hparams.get('bnorm_inputs', True):
# https://stackoverflow.com/a/46772183/221742
target_x = self.norm_x(target_x)
context_x = self.norm_x(context_x)
context_y = self.norm_y(context_y)
# if self.hparams.get('bnorm_inputs', True):
# https://stackoverflow.com/a/46772183/221742
target_x = self.norm_x(target_x)
context_x = self.norm_x(context_x)
context_y = self.norm_y(context_y)
if self._use_rnn:
# see https://arxiv.org/abs/1910.09323 where x is substituted with h = RNN(x)
+5
View File
@@ -63,6 +63,8 @@ class Attention(nn.Module):
dropout=dropout,
batchnorm=batchnorm,
)
elif self._rep == "lstm":
self._lstm = LSTMBlock(x_dim, hidden_dim, batchnorm=batchnorm, dropout=dropout, num_layers=attention_layers)
if attention_type == "uniform":
self._attention_func = self._uniform_attention
@@ -95,6 +97,9 @@ class Attention(nn.Module):
if self._rep == "mlp":
k = self.batch_mlp_k(k)
q = self.batch_mlp_q(q)
elif self._rep == "lstm":
k = self.batch_lstm(k)
q = self.batch_lstm(q)
rep = self._attention_func(k, v, q)
return rep
+12 -11
View File
@@ -42,6 +42,7 @@ def main(
val_percent_check=PERCENT_TEST_EXAMPLES,
checkpoint_callback=checkpoint_callback,
max_epochs=hparams["max_nb_epochs"],
weights_summary='top',
gpus=-1 if torch.cuda.is_available() else None,
early_stop_callback=PyTorchLightningPruningCallback(trial, monitor="val_loss")
if prune
@@ -62,7 +63,7 @@ def objective(trial, PL_MODEL_CLS, name, user_attrs):
trial = PL_MODEL_CLS.add_suggest(trial)
[trial.set_user_attr(k, v) for k, v in user_attrs.items()]
print(dict(number=trial.number, params=trial.params, user_attrs=trial.user_attrs))
logger.debug(dict(number=trial.number, params=trial.params, user_attrs=trial.user_attrs))
model, trainer = main(trial, PL_MODEL_CLS=PL_MODEL_CLS, name=name)
@@ -71,24 +72,24 @@ def objective(trial, PL_MODEL_CLS, name, user_attrs):
if len(checkpoints):
checkpoint = checkpoints[-1]
device = next(model.parameters()).device
print(f"Loading checkpoint {checkpoint}")
logger.info(f"Loading checkpoint {checkpoint}")
model = model.load_from_checkpoint(checkpoint).to(device)
trainer.test(model)
# also report to tensorboard & print
print("logger.metrics", model.logger.metrics[-1:])
logger.info("logger.metrics", model.logger.metrics[-1:])
model.logger.experiment.add_hparams(trial.params, model.logger.metrics[-1])
model.logger.save()
return model.logger.metrics[-1]["agg_test_loss"]
return model.logger.metrics[-1]["agg_test_score"]
def add_number(trial: optuna.Trial, model_dir: Path):
    """Give ``trial`` a negative number one below the lowest existing version.

    Entries in ``model_dir`` matching ``version_*`` are scanned and the text
    after the last underscore of each entry's stem is parsed as an integer.
    The trial number becomes ``min(found numbers and -1) - 1``.

    NOTE(review): with no ``version_*`` entries this yields -2, although the
    original comment said manual experiments "start at -1" -- confirm which
    one is intended before relying on the first number.

    Args:
        trial: Trial object whose ``number`` attribute is overwritten.
        model_dir: Directory searched for ``version_*`` entries.

    Returns:
        The same ``trial`` object, with ``number`` updated.

    Raises:
        ValueError: If a ``version_*`` entry has a suffix that ``int()``
            cannot parse.
    """
    existing = [int(entry.stem.split("_")[-1]) for entry in model_dir.glob("version_*")]
    # -1 is always included so the result is negative even when every
    # existing version number is non-negative.
    trial.number = min([*existing, -1]) - 1
    # Lazy %-formatting: the previous (commented-out) logger call passed an
    # argument without a placeholder, which makes logging report a
    # formatting error instead of the value.
    logger.debug("trial.number %s", trial.number)
    return trial
@@ -101,7 +102,7 @@ def run_trial(
plot_from_loader=plot_from_loader,
number=None,
):
print(f"now run `tensorboard --logdir {MODEL_DIR}`")
logger.info(f"now run `tensorboard --logdir {MODEL_DIR}`")
(MODEL_DIR / name).mkdir(parents=True, exist_ok=True)
if getattr(PL_MODEL_CLS, 'DEFAULT_ARGS', None):
@@ -121,7 +122,7 @@ def run_trial(
# Add user attributes
[trial.set_user_attr(k, v) for k, v in user_attrs.items()]
print('trial', trial.number, trial, trial.params, trial.user_attrs)
logger.info('trial number=%s trial=%s params=%s attrs=%s', trial.number, trial, trial.params, trial.user_attrs)
model, trainer = main(
trial, PL_MODEL_CLS, name=name, MODEL_DIR=MODEL_DIR, train=False, prune=False
@@ -132,7 +133,7 @@ def run_trial(
try:
trainer.fit(model)
except KeyboardInterrupt:
print('KeyboardInterrupt, skipping rest of training')
logger.warning('KeyboardInterrupt, skipping rest of training')
pass
# Plot
@@ -151,7 +152,7 @@ def run_trial(
if len(checkpoints):
checkpoint = checkpoints[-1]
device = next(model.parameters()).device
print(f"Loading checkpoint {checkpoint}")
logger.info(f"Loading checkpoint {checkpoint}")
model = model.load_from_checkpoint(checkpoint).to(device)
# Plot
@@ -162,11 +163,11 @@ def run_trial(
plot_from_loader(model.test_dataloader(), model, i=670, title='test 670')
plt.show()
else:
print('no checkpoints')
logger.warning('no checkpoints')
try:
trainer.test(model)
except KeyboardInterrupt:
print('KeyboardInterrupt, skipping rest of testing')
logger.warning('KeyboardInterrupt, skipping rest of testing')
pass
return trial, trainer, model