mirror of
https://github.com/wassname/pytorch-ts.git
synced 2026-06-27 19:00:30 +08:00
9d73d9aae8
For issue #11
249 lines
9.4 KiB
Python
249 lines
9.4 KiB
Python
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License").
|
|
# You may not use this file except in compliance with the License.
|
|
# A copy of the License is located at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# or in the "license" file accompanying this file. This file is distributed
|
|
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
|
|
# express or implied. See the License for the specific language governing
|
|
# permissions and limitations under the License.
|
|
|
|
|
|
from itertools import chain, combinations
|
|
|
|
import pytest
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from pts.modules import FeatureEmbedder, FeatureAssembler
|
|
|
|
|
|
@pytest.mark.parametrize(
|
|
"config",
|
|
(
|
|
lambda N, T: [
|
|
# single static feature
|
|
dict(shape=(N, 1), kwargs=dict(cardinalities=[50], embedding_dims=[10]),),
|
|
# single dynamic feature
|
|
dict(shape=(N, T, 1), kwargs=dict(cardinalities=[2], embedding_dims=[10]),),
|
|
# multiple static features
|
|
dict(
|
|
shape=(N, 4),
|
|
kwargs=dict(
|
|
cardinalities=[50, 50, 50, 50], embedding_dims=[10, 20, 30, 40],
|
|
),
|
|
),
|
|
# multiple dynamic features
|
|
dict(
|
|
shape=(N, T, 3),
|
|
kwargs=dict(cardinalities=[30, 30, 30], embedding_dims=[10, 20, 30]),
|
|
),
|
|
]
|
|
)(10, 20),
|
|
)
|
|
def test_feature_embedder(config):
|
|
out_shape = config["shape"][:-1] + (sum(config["kwargs"]["embedding_dims"]),)
|
|
embed_feature = FeatureEmbedder(**config["kwargs"])
|
|
for embed in embed_feature._FeatureEmbedder__embedders:
|
|
nn.init.constant_(embed.weight, 1.0)
|
|
|
|
def test_parameters_length():
|
|
exp_params_len = len([p for p in embed_feature.parameters()])
|
|
act_params_len = len(config["kwargs"]["embedding_dims"])
|
|
assert exp_params_len == act_params_len
|
|
|
|
def test_forward_pass():
|
|
act_output = embed_feature(torch.ones(config["shape"]).to(torch.long))
|
|
exp_output = torch.ones(out_shape)
|
|
|
|
assert act_output.shape == exp_output.shape
|
|
assert torch.abs(torch.sum(act_output - exp_output)) < 1e-20
|
|
|
|
test_parameters_length()
|
|
test_forward_pass()
|
|
|
|
|
|
@pytest.mark.parametrize(
|
|
"config",
|
|
(
|
|
lambda N, T: [
|
|
dict(
|
|
N=N,
|
|
T=T,
|
|
static_cat=dict(C=2),
|
|
static_real=dict(C=5),
|
|
dynamic_cat=dict(C=3),
|
|
dynamic_real=dict(C=4),
|
|
embed_static=dict(cardinalities=[2, 4], embedding_dims=[3, 6],),
|
|
embed_dynamic=dict(
|
|
cardinalities=[30, 30, 30], embedding_dims=[10, 20, 30],
|
|
),
|
|
)
|
|
]
|
|
)(10, 25),
|
|
)
|
|
def test_feature_assembler(config):
|
|
# iterate over the power-set of all possible feature types, excluding the empty set
|
|
feature_types = {
|
|
"static_cat",
|
|
"static_real",
|
|
"dynamic_cat",
|
|
"dynamic_real",
|
|
}
|
|
feature_combs = chain.from_iterable(
|
|
combinations(feature_types, r) for r in range(1, len(feature_types) + 1)
|
|
)
|
|
|
|
# iterate over the power-set of all possible feature types, including the empty set
|
|
embedder_types = {"embed_static", "embed_dynamic"}
|
|
embedder_combs = chain.from_iterable(
|
|
combinations(embedder_types, r) for r in range(0, len(embedder_types) + 1)
|
|
)
|
|
|
|
for enabled_embedders in embedder_combs:
|
|
embed_static = (
|
|
FeatureEmbedder(**config["embed_static"])
|
|
if "embed_static" in enabled_embedders
|
|
else None
|
|
)
|
|
embed_dynamic = (
|
|
FeatureEmbedder(**config["embed_dynamic"])
|
|
if "embed_dynamic" in enabled_embedders
|
|
else None
|
|
)
|
|
|
|
for enabled_features in feature_combs:
|
|
assemble_feature = FeatureAssembler(
|
|
T=config["T"], embed_static=embed_static, embed_dynamic=embed_dynamic,
|
|
)
|
|
# assemble_feature.collect_params().initialize(mx.initializer.One())
|
|
|
|
def test_parameters_length():
|
|
exp_params_len = sum(
|
|
[
|
|
len(config[k]["embedding_dims"])
|
|
for k in ["embed_static", "embed_dynamic"]
|
|
if k in enabled_embedders
|
|
]
|
|
)
|
|
act_params_len = len([p for p in assemble_feature.parameters()])
|
|
assert exp_params_len == act_params_len
|
|
|
|
def test_forward_pass():
|
|
N, T = config["N"], config["T"]
|
|
|
|
inp_features = []
|
|
out_features = []
|
|
|
|
if "static_cat" not in enabled_features:
|
|
inp_features.append(torch.zeros((N, 1)))
|
|
out_features.append(torch.zeros((N, T, 1)))
|
|
elif embed_static: # and 'static_cat' in enabled_features
|
|
C = config["static_cat"]["C"]
|
|
inp_features.append(
|
|
torch.cat(
|
|
[
|
|
torch.randint(
|
|
0,
|
|
config["embed_static"]["cardinalities"][c],
|
|
(N, 1),
|
|
)
|
|
for c in range(C)
|
|
],
|
|
dim=1,
|
|
)
|
|
)
|
|
out_features.append(
|
|
torch.ones(
|
|
(N, T, sum(config["embed_static"]["embedding_dims"]),)
|
|
)
|
|
)
|
|
else: # not embed_static and 'static_cat' in enabled_features
|
|
C = config["static_cat"]["C"]
|
|
inp_features.append(
|
|
torch.cat(
|
|
[
|
|
torch.randint(
|
|
0,
|
|
config["embed_static"]["cardinalities"][c],
|
|
(N, 1),
|
|
)
|
|
for c in range(C)
|
|
],
|
|
dim=1,
|
|
)
|
|
)
|
|
out_features.append(
|
|
inp_features[-1].unsqueeze(1).expand(-1, T, -1).float()
|
|
)
|
|
|
|
if "static_real" not in enabled_features:
|
|
inp_features.append(torch.zeros((N, 1)))
|
|
out_features.append(torch.zeros((N, T, 1)))
|
|
else:
|
|
C = config["static_real"]["C"]
|
|
static_real = torch.empty((N, C)).uniform_(0, 100)
|
|
inp_features.append(static_real)
|
|
out_features.append(static_real.unsqueeze(-2).expand(-1, T, -1))
|
|
|
|
if "dynamic_cat" not in enabled_features:
|
|
inp_features.append(torch.zeros((N, T, 1)))
|
|
out_features.append(torch.zeros((N, T, 1)))
|
|
elif embed_dynamic: # and 'static_cat' in enabled_features
|
|
C = config["dynamic_cat"]["C"]
|
|
inp_features.append(
|
|
torch.cat(
|
|
[
|
|
torch.randint(
|
|
0,
|
|
config["embed_dynamic"]["cardinalities"][c],
|
|
(N, T, 1),
|
|
)
|
|
for c in range(C)
|
|
],
|
|
dim=2,
|
|
)
|
|
)
|
|
out_features.append(
|
|
torch.ones(
|
|
(N, T, sum(config["embed_dynamic"]["embedding_dims"]),)
|
|
)
|
|
)
|
|
else: # not embed_dynamic and 'dynamic_cat' in enabled_features
|
|
C = config["dynamic_cat"]["C"]
|
|
inp_features.append(
|
|
torch.cat(
|
|
[
|
|
torch.randint(
|
|
0,
|
|
config["embed_dynamic"]["cardinalities"][c],
|
|
(N, T, 1),
|
|
)
|
|
for c in range(C)
|
|
],
|
|
dim=2,
|
|
)
|
|
)
|
|
out_features.append(inp_features[-1].float())
|
|
|
|
if "dynamic_real" not in enabled_features:
|
|
inp_features.append(torch.zeros((N, T, 1)))
|
|
out_features.append(torch.zeros((N, T, 1)))
|
|
else:
|
|
C = config["dynamic_real"]["C"]
|
|
dynamic_real = torch.empty((N, T, C)).uniform_(0, 100)
|
|
inp_features.append(dynamic_real)
|
|
out_features.append(dynamic_real)
|
|
|
|
act_output = assemble_feature(*inp_features)
|
|
exp_output = torch.cat(out_features, dim=2)
|
|
|
|
assert exp_output.shape == act_output.shape
|
|
assert torch.sum(exp_output - act_output) < 1e-20
|
|
|
|
test_parameters_length()
|
|
test_forward_pass()
|