This commit is contained in:
Kashif Rasul
2022-06-06 11:45:31 +02:00
parent c99572a6d9
commit 0b0ecc94f9
+2 -3
View File
@@ -108,8 +108,7 @@ class SwitchTransformerEstimator(PyTorchLightningEstimator):
self.capacity_factor = capacity_factor
self.is_scale_prob = is_scale_prob
self.drop_tokens = drop_tokens
self.num_feat_dynamic_real = num_feat_dynamic_real
self.num_feat_static_cat = num_feat_static_cat
self.num_feat_static_real = num_feat_static_real
@@ -321,4 +320,4 @@ class SwitchTransformerEstimator(PyTorchLightningEstimator):
num_parallel_samples=self.num_parallel_samples,
)
return TransformerLightningModule(model=model, loss=self.loss)
return SwitchTransformerLightningModule(model=model, loss=self.loss)