fix serialization

This commit is contained in:
Dr. Kashif Rasul
2020-05-22 15:15:05 +02:00
parent 2d8e72054d
commit e5f66530b1
+6
View File
@@ -168,6 +168,7 @@ class StudentTOutput(DistributionOutput):
class StudentTMixtureOutput(DistributionOutput):
@validated()
def __init__(self, components: int = 1) -> None:
self.components = components
self.args_dim = {
@@ -207,6 +208,7 @@ class StudentTMixtureOutput(DistributionOutput):
class NormalMixtureOutput(DistributionOutput):
@validated()
def __init__(self, components: int = 1) -> None:
self.components = components
self.args_dim = {
@@ -237,6 +239,7 @@ class NormalMixtureOutput(DistributionOutput):
class LowRankMultivariateNormalOutput(DistributionOutput):
@validated()
def __init__(
self, dim: int, rank: int, sigma_init: float = 1.0, sigma_minimum: float = 1e-3,
) -> None:
@@ -270,6 +273,7 @@ class LowRankMultivariateNormalOutput(DistributionOutput):
class IndependentNormalOutput(DistributionOutput):
@validated()
def __init__(self, dim: int) -> None:
self.dim = dim
self.args_dim = {"loc": self.dim, "scale": self.dim}
@@ -294,6 +298,7 @@ class IndependentNormalOutput(DistributionOutput):
class MultivariateNormalOutput(DistributionOutput):
@validated()
def __init__(self, dim: int) -> None:
self.args_dim = {"loc": dim, "scale_tril": dim * dim}
self.dim = dim
@@ -331,6 +336,7 @@ class MultivariateNormalOutput(DistributionOutput):
class FlowOutput(DistributionOutput):
@validated()
def __init__(self, flow, input_size, cond_size):
self.args_dim = {"cond": cond_size}
self.flow = flow