diff --git a/examples/Time-Grad-Electricity.ipynb b/examples/Time-Grad-Electricity.ipynb index fd2bf01..556895f 100644 --- a/examples/Time-Grad-Electricity.ipynb +++ b/examples/Time-Grad-Electricity.ipynb @@ -5,8 +5,24 @@ "execution_count": 1, "metadata": { "ExecuteTime": { - "end_time": "2022-12-23T07:08:26.228737Z", - "start_time": "2022-12-23T07:08:24.806396Z" + "end_time": "2022-12-23T09:57:39.738715Z", + "start_time": "2022-12-23T09:57:39.730077Z" + } + }, + "outputs": [], + "source": [ + "# autoreload import your package\n", + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "ExecuteTime": { + "end_time": "2022-12-23T09:57:40.738250Z", + "start_time": "2022-12-23T09:57:39.739677Z" } }, "outputs": [], @@ -22,23 +38,14 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "metadata": { "ExecuteTime": { - "end_time": "2022-12-23T07:08:26.417012Z", - "start_time": "2022-12-23T07:08:26.230461Z" + "end_time": "2022-12-23T09:57:40.909476Z", + "start_time": "2022-12-23T09:57:40.739741Z" } }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/wassname/miniforge3/envs/gluonts10.0/lib/python3.9/site-packages/gluonts/json.py:101: UserWarning: Using `json`-module for json-handling. Consider installing one of `orjson`, `ujson` to speed up serialization and deserialization.\n", - " warnings.warn(\n" - ] - } - ], + "outputs": [], "source": [ "from gluonts.dataset.multivariate_grouper import MultivariateGrouper\n", "from gluonts.dataset.repository.datasets import dataset_recipes, get_dataset\n", @@ -48,11 +55,11 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "metadata": { "ExecuteTime": { - "end_time": "2022-12-23T07:08:26.469299Z", - "start_time": "2022-12-23T07:08:26.419051Z" + "end_time": "2022-12-23T09:57:40.966206Z", + "start_time": "2022-12-23T09:57:40.910477Z" } }, "outputs": [], @@ -65,11 +72,11 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 5, "metadata": { "ExecuteTime": { - "end_time": "2022-12-23T07:08:26.477478Z", - "start_time": "2022-12-23T07:08:26.470628Z" + "end_time": "2022-12-23T09:57:40.980883Z", + "start_time": "2022-12-23T09:57:40.967157Z" } }, "outputs": [], @@ -79,11 +86,11 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 6, "metadata": { "ExecuteTime": { - "end_time": "2022-12-23T07:08:26.494456Z", - "start_time": "2022-12-23T07:08:26.478496Z" + "end_time": "2022-12-23T09:57:41.003468Z", + "start_time": "2022-12-23T09:57:40.981830Z" } }, "outputs": [], @@ -147,11 +154,11 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 7, "metadata": { "ExecuteTime": { - "end_time": "2022-12-23T07:08:26.516023Z", - "start_time": "2022-12-23T07:08:26.495676Z" + "end_time": "2022-12-23T09:57:41.025224Z", + "start_time": "2022-12-23T09:57:41.004330Z" } }, "outputs": [ @@ -169,11 +176,11 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 8, "metadata": { "ExecuteTime": { - "end_time": "2022-12-23T07:08:26.532361Z", - "start_time": "2022-12-23T07:08:26.517210Z" + "end_time": "2022-12-23T09:57:41.041326Z", + "start_time": "2022-12-23T09:57:41.026069Z" }, "scrolled": true }, @@ -185,11 +192,11 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 9, "metadata": { "ExecuteTime": { - "end_time": "2022-12-23T07:08:26.554669Z", - "start_time": "2022-12-23T07:08:26.534354Z" + "end_time": "2022-12-23T09:57:41.057694Z", + "start_time": "2022-12-23T09:57:41.042892Z" } }, "outputs": [ @@ -199,7 +206,7 @@ "MetaData(freq='H', target=None, feat_static_cat=[CategoricalFeatureInfo(name='feat_static_cat_0', cardinality='370')], feat_static_real=[], feat_dynamic_real=[], feat_dynamic_cat=[], prediction_length=24)" ] }, - "execution_count": 8, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -210,67 +217,51 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2022-12-23T07:12:03.693420Z", "start_time": "2022-12-23T07:12:02.845649Z" } }, - "outputs": [ - { - "data": { - "text/plain": [ - "7" - ] - }, - "execution_count": 17, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [] }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 10, "metadata": { "ExecuteTime": { - "end_time": "2022-12-23T07:12:11.716051Z", - "start_time": "2022-12-23T07:12:11.712672Z" + "end_time": "2022-12-23T09:57:41.073075Z", + "start_time": "2022-12-23T09:57:41.058526Z" } }, "outputs": [], "source": [ - "train_grouper = MultivariateGrouper(max_target_dim=min(2000, int(dataset.metadata.feat_static_cat[0].cardinality)))\n", - "\n", - "test_grouper = MultivariateGrouper(\n", - "# num_test_dates=int(len(dataset.test)/len(dataset.train)*2),\n", - " num_test_dates=7,\n", - " max_target_dim=min(2000, int(dataset.metadata.feat_static_cat[0].cardinality)))" + "train_grouper = MultivariateGrouper(max_target_dim=min(2000, int(dataset.metadata.feat_static_cat[0].cardinality)))\n" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 11, "metadata": { "ExecuteTime": { - "start_time": "2022-12-23T07:12:12.382Z" + "end_time": "2022-12-23T09:57:44.134871Z", + "start_time": "2022-12-23T09:57:41.073880Z" } }, "outputs": [], "source": [ - "dataset_train = train_grouper(dataset.train)\n", - "dataset_test = test_grouper(dataset.test)" + "dataset_train = train_grouper(dataset.train)\n" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 12, "metadata": { "ExecuteTime": { - "end_time": "2022-12-23T07:09:49.342107Z", - "start_time": "2022-12-23T07:09:49.342097Z" + "end_time": "2022-12-23T09:57:44.153700Z", + "start_time": "2022-12-23T09:57:44.136099Z" } }, "outputs": [], @@ -300,22 +291,265 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2022-12-23T07:09:49.342839Z", - "start_time": "2022-12-23T07:09:49.342830Z" + "start_time": "2022-12-23T09:57:39.746Z" } }, - "outputs": [], + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "744b148440674ef3902ca56622cb001b", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/99 [00:00 torch.Tensor: """ Computes sample paths by unrolling the RNN starting with a initial @@ -470,8 +468,8 @@ class TimeGradPredictionNetwork2(TimeGradTrainingNetwork2): num_features) scale Mean scale for each time series (batch_size, 1, target_dim) - begin_states - List of initial states for the RNN layers (batch_size, num_cells) + rnn_outputs + Outputs of the unrolled RNN (batch_size, seq_len, num_cells) Returns -------- sample_paths : Tensor @@ -482,56 +480,35 @@ class TimeGradPredictionNetwork2(TimeGradTrainingNetwork2): def repeat(tensor, dim=0): return tensor.repeat_interleave(repeats=self.num_parallel_samples, dim=dim) - # blows-up the dimension of each tensor to - # batch_size * self.num_sample_paths for increasing parallelism - repeated_past_target_cdf = repeat(past_target_cdf) - repeated_time_feat = repeat(time_feat) repeated_scale = repeat(scale) if self.scaling: self.diffusion.scale = repeated_scale - repeated_target_dimension_indicator = repeat(target_dimension_indicator) - - if self.cell_type == "LSTM": - repeated_states = [repeat(s, dim=1) for s in begin_states] - else: - repeated_states = repeat(begin_states, dim=1) - + future_samples = [] - - # for each future time-units we draw new samples for this time-unit - # and update the state - for k in range(self.prediction_length): - lags = self.get_lagged_subsequences( - sequence=repeated_past_target_cdf, - sequence_length=self.history_length + k, - indices=self.shifted_lags, - subsequences_length=1, - ) - - rnn_outputs, repeated_states, _, _ = self.unroll( - begin_state=repeated_states, - lags=lags, - scale=repeated_scale, - time_feat=repeated_time_feat[:, k : k + 1, ...], - target_dimension_indicator=repeated_target_dimension_indicator, - unroll_length=1, - ) - + for _ in range(self.num_parallel_samples): distr_args = self.distr_args(rnn_outputs=rnn_outputs) - # (batch_size, 1, target_dim) - new_samples = self.diffusion.sample(cond=distr_args) + distr_args = distr_args.permute(0, 2, 1) + samples = self.diffusion.sample(cond=distr_args) + future_samples.append(samples) + # import pdb; pdb.set_trace() + # rnn_outputs = torch.Size([7, 24, 40]) + # torch.Size([7, 100, 370, 24]) + # torch.Size([7, 370, 24]) + samples = torch.stack(future_samples, dim=1).permute(0, 1, 3, 2) - # (batch_size, seq_len, target_dim) - future_samples.append(new_samples) - repeated_past_target_cdf = torch.cat( - (repeated_past_target_cdf, new_samples), dim=1 - ) + # # (batch_size, seq_len, target_dim) + # future_samples.append(new_samples) + # repeated_past_target_cdf = torch.cat( + # (repeated_past_target_cdf, new_samples), dim=1 + # ) # (batch_size * num_samples, prediction_length, target_dim) - samples = torch.cat(future_samples, dim=1) + # samples = torch.cat(future_samples, dim=1) # (batch_size, num_samples, prediction_length, target_dim) + # print('samples.shape', samples.shape) + return samples return samples.reshape( ( -1, @@ -586,21 +563,20 @@ class TimeGradPredictionNetwork2(TimeGradTrainingNetwork2): past_observed_values, 1 - past_is_pad.unsqueeze(-1) ) - # unroll the decoder in "prediction mode", i.e. with past data only - _, begin_states, scale, _, _ = self.unroll_encoder( + + # unroll the decoder in "training mode", i.e. by providing future data + # as well + rnn_outputs, _, scale, _, _ = self.unroll_encoder( past_time_feat=past_time_feat, past_target_cdf=past_target_cdf, past_observed_values=past_observed_values, past_is_pad=past_is_pad, - future_time_feat=None, + future_time_feat=future_time_feat, future_target_cdf=None, target_dimension_indicator=target_dimension_indicator, ) return self.sampling_decoder( - past_target_cdf=past_target_cdf, - target_dimension_indicator=target_dimension_indicator, - time_feat=future_time_feat, scale=scale, - begin_states=begin_states, + rnn_outputs=rnn_outputs, )