diff --git a/test/modules/test_distribution_output.py b/test/modules/test_distribution_output.py index 750cafe..fe81cff 100644 --- a/test/modules/test_distribution_output.py +++ b/test/modules/test_distribution_output.py @@ -212,8 +212,8 @@ def test_lowrank_multivariate_normal() -> None: ) distr = LowRankMultivariateNormal( - loc=torch.Tensor(loc_hat), - cov_diag=torch.Tensor(cov_diag_hat), + loc=torch.Tensor(loc_hat), + cov_diag=torch.Tensor(cov_diag_hat), cov_factor=torch.Tensor(cov_factor_hat), )