Add module_ and parameter_ names.

This commit is contained in:
David Bau
2022-03-23 20:59:53 -04:00
parent b850af0e39
commit 1e4782a6e9
+14
View File
@@ -372,6 +372,20 @@ def get_parameter(model, name):
raise LookupError(name)
def module_names(model):
"""
Lists all the module names.
"""
return [n for n, _ in model.named_modules()]
def parameter_names(model):
"""
Lists all the parameter names.
"""
return [n for n, _ in model.named_parameters()]
def replace_module(model, name, new_module):
"""
Replaces the named module within the given model.