mirror of
https://github.com/wassname/pytorch-ts.git
synced 2026-06-27 19:32:05 +08:00
fix serialization
This commit is contained in:
@@ -168,6 +168,7 @@ class StudentTOutput(DistributionOutput):
|
||||
|
||||
|
||||
class StudentTMixtureOutput(DistributionOutput):
|
||||
@validated()
|
||||
def __init__(self, components: int = 1) -> None:
|
||||
self.components = components
|
||||
self.args_dim = {
|
||||
@@ -207,6 +208,7 @@ class StudentTMixtureOutput(DistributionOutput):
|
||||
|
||||
|
||||
class NormalMixtureOutput(DistributionOutput):
|
||||
@validated()
|
||||
def __init__(self, components: int = 1) -> None:
|
||||
self.components = components
|
||||
self.args_dim = {
|
||||
@@ -237,6 +239,7 @@ class NormalMixtureOutput(DistributionOutput):
|
||||
|
||||
|
||||
class LowRankMultivariateNormalOutput(DistributionOutput):
|
||||
@validated()
|
||||
def __init__(
|
||||
self, dim: int, rank: int, sigma_init: float = 1.0, sigma_minimum: float = 1e-3,
|
||||
) -> None:
|
||||
@@ -270,6 +273,7 @@ class LowRankMultivariateNormalOutput(DistributionOutput):
|
||||
|
||||
|
||||
class IndependentNormalOutput(DistributionOutput):
|
||||
@validated()
|
||||
def __init__(self, dim: int) -> None:
|
||||
self.dim = dim
|
||||
self.args_dim = {"loc": self.dim, "scale": self.dim}
|
||||
@@ -294,6 +298,7 @@ class IndependentNormalOutput(DistributionOutput):
|
||||
|
||||
|
||||
class MultivariateNormalOutput(DistributionOutput):
|
||||
@validated()
|
||||
def __init__(self, dim: int) -> None:
|
||||
self.args_dim = {"loc": dim, "scale_tril": dim * dim}
|
||||
self.dim = dim
|
||||
@@ -331,6 +336,7 @@ class MultivariateNormalOutput(DistributionOutput):
|
||||
|
||||
|
||||
class FlowOutput(DistributionOutput):
|
||||
@validated()
|
||||
def __init__(self, flow, input_size, cond_size):
|
||||
self.args_dim = {"cond": cond_size}
|
||||
self.flow = flow
|
||||
|
||||
Reference in New Issue
Block a user