[rllib] Fix PyTorch initialization (#1961)

* Fix typo

* Fix A3C PyTorch agent initialization

`registry` needs to be passed as an argument or else the `super` init will
fail.
This commit is contained in:
Alok Singh
2018-05-01 18:39:01 -07:00
committed by Philipp Moritz
parent b55f4a7f04
commit 06a0898af7
2 changed files with 3 additions and 3 deletions
+2 -2
View File
@@ -17,9 +17,9 @@ class SharedTorchPolicy(TorchPolicy):
other_output = ["vf_preds"]
is_recurrent = False
def __init__(self, ob_space, ac_space, config, **kwargs):
def __init__(self, registry, ob_space, ac_space, config, **kwargs):
super(SharedTorchPolicy, self).__init__(
ob_space, ac_space, config, **kwargs)
registry, ob_space, ac_space, config, **kwargs)
def _setup_graph(self, ob_space, ac_space):
_, self.logit_dim = ModelCatalog.get_action_dist(ac_space)
+1 -1
View File
@@ -23,7 +23,7 @@ def register_trainable(name, trainable):
Args:
name (str): Name to register.
trainable (obj): Function or tune.Trainable clsas. Functions must
trainable (obj): Function or tune.Trainable class. Functions must
take (config, status_reporter) as arguments and will be
automatically converted into a class during registration.
"""