diff --git a/python/ray/rllib/a3c/shared_torch_policy.py b/python/ray/rllib/a3c/shared_torch_policy.py index 59b7a2577..36b39dcfc 100644 --- a/python/ray/rllib/a3c/shared_torch_policy.py +++ b/python/ray/rllib/a3c/shared_torch_policy.py @@ -17,9 +17,9 @@ class SharedTorchPolicy(TorchPolicy): other_output = ["vf_preds"] is_recurrent = False - def __init__(self, ob_space, ac_space, config, **kwargs): + def __init__(self, registry, ob_space, ac_space, config, **kwargs): super(SharedTorchPolicy, self).__init__( - ob_space, ac_space, config, **kwargs) + registry, ob_space, ac_space, config, **kwargs) def _setup_graph(self, ob_space, ac_space): _, self.logit_dim = ModelCatalog.get_action_dist(ac_space) diff --git a/python/ray/tune/registry.py b/python/ray/tune/registry.py index f17267eaa..cb9505771 100644 --- a/python/ray/tune/registry.py +++ b/python/ray/tune/registry.py @@ -23,7 +23,7 @@ def register_trainable(name, trainable): Args: name (str): Name to register. - trainable (obj): Function or tune.Trainable clsas. Functions must + trainable (obj): Function or tune.Trainable class. Functions must take (config, status_reporter) as arguments and will be automatically converted into a class during registration. """