import io

import numpy as np
import pandas as pd
import PIL.Image  # binds `PIL` and guarantees the `PIL.Image` submodule is loaded
import torch
from matplotlib import pyplot as plt
from pandas.plotting import register_matplotlib_converters
from torchvision.transforms import ToTensor

register_matplotlib_converters()

# Offset used by the log transform; undo_log maps values back with exp(v) - eps.
eps = 1e-5


def plot_rows(
    x_context_rows: pd.DataFrame,
    x_target_rows: pd.DataFrame,
    context_y_rows: pd.DataFrame,
    target_y_rows: pd.DataFrame,
    pred_y: np.ndarray,
    std: np.ndarray,
    undo_log=False,
    legend=True,
):
    """Plot the predicted mean, uncertainty bands, true values and context points.

    Draws onto the current matplotlib axes (``plt.gca()``); it does not create
    or close a figure.

    Args:
        x_context_rows: Context inputs. Only truncated to the last 96 rows here;
            not otherwise plotted.
        x_target_rows: Target inputs. Not used for plotting.
        context_y_rows: DataFrame with a datetime index and an
            ``"energy(kWh/hh)"`` column holding the context labels.
        target_y_rows: Frame with a datetime index holding the true target values.
        pred_y: Array of shape [B, num_targets, 1] with the predicted means at
            the target points. Only batch element 0 is plotted.
        std: Array of shape [B, num_targets, 1] with the predicted std devs at
            the target points. Only batch element 0 is plotted.
        undo_log: If True, labels, predictions and uncertainty bands are mapped
            back with ``exp(v) - eps`` before plotting (presumably the inverse
            of a ``log(v + eps)`` preprocessing step — confirm with the dataset).
        legend: Whether to draw a legend.
    """

    def to_plot_space(values):
        # Map model-space values back to the original scale when requested.
        return np.exp(values) - eps if undo_log else values

    target_y_rows = to_plot_space(target_y_rows)
    context_y_rows = to_plot_space(context_y_rows)

    # I don't want to show too much context
    context_y_rows = context_y_rows[-96:]
    x_context_rows = x_context_rows[-96:]

    label = "energy(kWh/hh)"

    # Start with the true data and use it to fix the y-limits, so the
    # (possibly wide) uncertainty bands don't rescale the axes.
    plt.plot(target_y_rows.index, target_y_rows.values, "k:", linewidth=2, label="true")
    plt.plot(context_y_rows.index, context_y_rows.values, "k:", linewidth=2, label="true")
    ylims = plt.ylim()

    # Predicted mean for the first batch element.
    plt.plot(target_y_rows.index, to_plot_space(pred_y[0]), "b", linewidth=2, label="predicted")

    # 1- and 2-sigma bands. Edges are computed in model space and only then
    # transformed; the transform is monotonic, so the bands stay correct
    # (and become asymmetric) when undo_log is set.
    mean = pred_y[0, :, 0]
    sigma = std[0, :, 0]
    for k, alpha in ((1, 0.25), (2, 0.125)):
        plt.fill_between(
            target_y_rows.index,
            to_plot_space(mean - sigma * k),
            to_plot_space(mean + sigma * k),
            alpha=alpha,
            facecolor="blue",
            interpolate=True,
            label="uncertainty",
        )

    # Finally context, we do this with pandas so it will override x ax and make it nice
    context_y_rows[label].plot(style="ko", linewidth=2, label="input data", ax=plt.gca())

    # Make the plot pretty
    plt.ylim(*ylims)
    plt.xlabel("Date")
    plt.ylabel("Energy (kWh/hh)")
    # Explicitly hide the grid. The previous grid("off") + grid(b=None) pair
    # relied on string truthiness and the deprecated `b` keyword.
    plt.grid(False)
    if legend:
        plt.legend()


def plot_from_loader(
    loader, model, i=0, undo_log=False, title="", plot=True, legend=False, context_in_target=None
):
    """Run ``model`` on item ``i`` of ``loader.dataset`` and optionally plot it.

    Args:
        loader: Loader whose ``collate_fn`` accepts ``sample=False`` and whose
            ``dataset`` supports indexing and ``get_rows(i)``.
        model: Torch module called as
            ``model(context_x, context_y, target_x, target_y)`` and returning
            ``(y_pred, losses, extra)`` with ``extra["y_dist"].scale``. It must
            also expose ``hparams``.
        i: Index of the dataset item to evaluate.
        undo_log: Forwarded to ``plot_rows`` to undo the log transform.
        title: Prefix for the plot title.
        plot: If False, only the loss is computed and no figure is drawn.
        legend: Whether ``plot_rows`` draws a legend.
        context_in_target: Whether the context rows are plotted as part of the
            target. Defaults to ``model.hparams["context_in_target"]``.
            NOTE(review): when True the plotted target spans context + target
            rows; this presumably relies on the model predicting over both — confirm.

    Returns:
        ``losses["loss"]`` if the model reported one, otherwise ``0.``.

    The model's train/eval mode is restored afterwards, so this is safe to call
    in the middle of training.
    """
    if context_in_target is None:
        context_in_target = model.hparams["context_in_target"]

    device = next(model.parameters()).device
    batch = loader.collate_fn([loader.dataset[i]], sample=False)
    context_x, context_y, target_x, target_y = [d.to(device) for d in batch]

    # Get context, like dates, from dataset
    x_rows, y_rows = loader.dataset.get_rows(i)
    max_num_context = context_x.shape[1]
    y_context_rows = y_rows[:max_num_context]
    y_target_extra_rows = y_rows[max_num_context:]
    x_context_rows = x_rows[:max_num_context]
    x_target_extra_rows = x_rows[max_num_context:]
    dt = y_target_extra_rows.index[0]

    if context_in_target:
        y_target_rows = y_rows
        x_target_rows = x_rows
    else:
        y_target_rows = y_target_extra_rows
        x_target_rows = x_target_extra_rows

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            y_pred, losses, extra = model(context_x, context_y, target_x, target_y)
    finally:
        # Restore the caller's mode so dropout/batch-norm aren't left in eval
        # behaviour when this is called during training.
        model.train(was_training)

    loss_test = losses["loss"] if "loss" in losses else 0.
    y_std = extra["y_dist"].scale

    if plot:
        plt.figure()
        plt.title(title + f" loss={loss_test: 2.2g} {dt}")
        plot_rows(
            x_context_rows,
            x_target_rows,
            y_context_rows,
            y_target_rows,
            y_pred.detach().cpu().numpy(),
            y_std.detach().cpu().numpy(),
            undo_log=undo_log,
            legend=legend,
        )
    return loss_test


def plot_from_loader_to_tensor(*args, **kwargs):
    """Render ``plot_from_loader``'s figure and return it as an image tensor.

    Accepts exactly the arguments of ``plot_from_loader``. The current figure
    is saved as JPEG into memory, closed, and decoded with torchvision's
    ``ToTensor`` (a float [C, H, W] tensor), e.g. for logging to TensorBoard.
    """
    plot_from_loader(*args, **kwargs)

    # Send fig to tensorboard
    buf = io.BytesIO()
    plt.savefig(buf, format="jpeg")
    plt.close()
    buf.seek(0)
    # ToTensor reads the pixels before the context manager closes the image.
    with PIL.Image.open(buf) as image:
        return ToTensor()(image)