From 06a0898af7b194720590826b7e76baf6f02e009a Mon Sep 17 00:00:00 2001 From: Alok Singh <8325708+alok@users.noreply.github.com> Date: Tue, 1 May 2018 18:39:01 -0700 Subject: [PATCH] [rllib] Fix PyTorch initialization (#1961) * Fix typo * Fix A3C PyTorch agent initialization `registry` needs to be passed as an argument or else the `super` init will fail. --- python/ray/rllib/a3c/shared_torch_policy.py | 4 ++-- python/ray/tune/registry.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/ray/rllib/a3c/shared_torch_policy.py b/python/ray/rllib/a3c/shared_torch_policy.py index 59b7a2577..36b39dcfc 100644 --- a/python/ray/rllib/a3c/shared_torch_policy.py +++ b/python/ray/rllib/a3c/shared_torch_policy.py @@ -17,9 +17,9 @@ class SharedTorchPolicy(TorchPolicy): other_output = ["vf_preds"] is_recurrent = False - def __init__(self, ob_space, ac_space, config, **kwargs): + def __init__(self, registry, ob_space, ac_space, config, **kwargs): super(SharedTorchPolicy, self).__init__( - ob_space, ac_space, config, **kwargs) + registry, ob_space, ac_space, config, **kwargs) def _setup_graph(self, ob_space, ac_space): _, self.logit_dim = ModelCatalog.get_action_dist(ac_space) diff --git a/python/ray/tune/registry.py b/python/ray/tune/registry.py index f17267eaa..cb9505771 100644 --- a/python/ray/tune/registry.py +++ b/python/ray/tune/registry.py @@ -23,7 +23,7 @@ def register_trainable(name, trainable): Args: name (str): Name to register. - trainable (obj): Function or tune.Trainable clsas. Functions must + trainable (obj): Function or tune.Trainable class. Functions must take (config, status_reporter) as arguments and will be automatically converted into a class during registration. """