Files
DeepTime/models/__init__.py
T
2022-07-13 16:03:34 +08:00

14 lines
340 B
Python

from typing import Union
import torch
from .DeepTIMe import deeptime
def get_model(model_type: str, **kwargs: Union[int, float]) -> torch.nn.Module:
    """Construct a model from its type name.

    Args:
        model_type: Name of the model to build. Only ``'deeptime'`` is
            currently recognised.
        **kwargs: Model settings. For ``'deeptime'`` the key
            ``datetime_feats`` is required and is forwarded to ``deeptime``;
            any other keys are ignored.

    Returns:
        The constructed model.

    Raises:
        KeyError: If a setting required by ``model_type`` is not supplied.
        ValueError: If ``model_type`` is not a recognised model name.
    """
    if model_type == 'deeptime':
        # Check up front so a missing setting produces a message naming the
        # model that needs it, instead of a bare KeyError from the lookup.
        if 'datetime_feats' not in kwargs:
            raise KeyError(
                f"model type {model_type!r} requires the keyword argument "
                f"'datetime_feats'"
            )
        return deeptime(datetime_feats=kwargs['datetime_feats'])
    raise ValueError(f"Unknown model type {model_type}")