From e5f66530b124daba187b484e0aedd0d7aeebe494 Mon Sep 17 00:00:00 2001 From: "Dr. Kashif Rasul" Date: Fri, 22 May 2020 15:15:05 +0200 Subject: [PATCH] fix serialization --- pts/modules/distribution_output.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/pts/modules/distribution_output.py b/pts/modules/distribution_output.py index 2cbe753..dc89fbd 100644 --- a/pts/modules/distribution_output.py +++ b/pts/modules/distribution_output.py @@ -168,6 +168,7 @@ class StudentTOutput(DistributionOutput): class StudentTMixtureOutput(DistributionOutput): + @validated() def __init__(self, components: int = 1) -> None: self.components = components self.args_dim = { @@ -207,6 +208,7 @@ class StudentTMixtureOutput(DistributionOutput): class NormalMixtureOutput(DistributionOutput): + @validated() def __init__(self, components: int = 1) -> None: self.components = components self.args_dim = { @@ -237,6 +239,7 @@ class NormalMixtureOutput(DistributionOutput): class LowRankMultivariateNormalOutput(DistributionOutput): + @validated() def __init__( self, dim: int, rank: int, sigma_init: float = 1.0, sigma_minimum: float = 1e-3, ) -> None: @@ -270,6 +273,7 @@ class LowRankMultivariateNormalOutput(DistributionOutput): class IndependentNormalOutput(DistributionOutput): + @validated() def __init__(self, dim: int) -> None: self.dim = dim self.args_dim = {"loc": self.dim, "scale": self.dim} @@ -294,6 +298,7 @@ class IndependentNormalOutput(DistributionOutput): class MultivariateNormalOutput(DistributionOutput): + @validated() def __init__(self, dim: int) -> None: self.args_dim = {"loc": dim, "scale_tril": dim * dim} self.dim = dim @@ -331,6 +336,7 @@ class MultivariateNormalOutput(DistributionOutput): class FlowOutput(DistributionOutput): + @validated() def __init__(self, flow, input_size, cond_size): self.args_dim = {"cond": cond_size} self.flow = flow