mirror of
https://github.com/wassname/baukit.git
synced 2026-06-27 17:14:53 +08:00
Add module_ and parameter_ names.
This commit is contained in:
@@ -372,6 +372,20 @@ def get_parameter(model, name):
|
||||
raise LookupError(name)
|
||||
|
||||
|
||||
def module_names(model):
|
||||
"""
|
||||
Lists all the module names.
|
||||
"""
|
||||
return [n for n, _ in model.named_modules()]
|
||||
|
||||
|
||||
def parameter_names(model):
|
||||
"""
|
||||
Lists all the parameter names.
|
||||
"""
|
||||
return [n for n, _ in model.named_parameters()]
|
||||
|
||||
|
||||
def replace_module(model, name, new_module):
|
||||
"""
|
||||
Replaces the named module within the given model.
|
||||
|
||||
Reference in New Issue
Block a user