diff --git a/baukit/nethook.py b/baukit/nethook.py index 12177ca..24da005 100644 --- a/baukit/nethook.py +++ b/baukit/nethook.py @@ -372,6 +372,20 @@ def get_parameter(model, name): raise LookupError(name) +def module_names(model): + """ + Lists all the module names. + """ + return [n for n, _ in model.named_modules()] + + +def parameter_names(model): + """ + Lists all the parameter names. + """ + return [n for n, _ in model.named_parameters()] + + def replace_module(model, name, new_module): """ Replaces the named module within the given model.