mirror of
https://github.com/wassname/ray.git
synced 2026-07-03 22:42:17 +08:00
[rllib] Fix PyTorch initialization (#1961)
* Fix typo * Fix A3C PyTorch agent initialization `registry` needs to be passed as an argument or else the `super` init will fail.
This commit is contained in:
committed by
Philipp Moritz
parent
b55f4a7f04
commit
06a0898af7
@@ -17,9 +17,9 @@ class SharedTorchPolicy(TorchPolicy):
|
||||
other_output = ["vf_preds"]
|
||||
is_recurrent = False
|
||||
|
||||
def __init__(self, ob_space, ac_space, config, **kwargs):
|
||||
def __init__(self, registry, ob_space, ac_space, config, **kwargs):
|
||||
super(SharedTorchPolicy, self).__init__(
|
||||
ob_space, ac_space, config, **kwargs)
|
||||
registry, ob_space, ac_space, config, **kwargs)
|
||||
|
||||
def _setup_graph(self, ob_space, ac_space):
|
||||
_, self.logit_dim = ModelCatalog.get_action_dist(ac_space)
|
||||
|
||||
@@ -23,7 +23,7 @@ def register_trainable(name, trainable):
|
||||
|
||||
Args:
|
||||
name (str): Name to register.
|
||||
trainable (obj): Function or tune.Trainable clsas. Functions must
|
||||
trainable (obj): Function or tune.Trainable class. Functions must
|
||||
take (config, status_reporter) as arguments and will be
|
||||
automatically converted into a class during registration.
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user