mirror of
https://github.com/wassname/baukit.git
synced 2026-06-27 17:14:53 +08:00
Add optional inputs argument on edit_output.
This commit is contained in:
+16
-10
@@ -69,6 +69,10 @@ class Trace(contextlib.AbstractContextManager):
|
||||
module = get_module(module, layer)
|
||||
|
||||
def retain_hook(m, inputs, output):
|
||||
if edit_output:
|
||||
output = invoke_with_optional_args(
|
||||
edit_output, output=output, layer=self.layer, inputs=inputs
|
||||
)
|
||||
if retain_input:
|
||||
retainer.input = recursive_copy(
|
||||
inputs[0] if len(inputs) == 1 else inputs,
|
||||
@@ -76,10 +80,6 @@ class Trace(contextlib.AbstractContextManager):
|
||||
detach=detach,
|
||||
retain_grad=False,
|
||||
) # retain_grad applies to output only.
|
||||
if edit_output:
|
||||
output = invoke_with_optional_args(
|
||||
edit_output, output=output, layer=self.layer
|
||||
)
|
||||
if retain_output:
|
||||
retainer.output = recursive_copy(
|
||||
output, clone=clone, detach=detach, retain_grad=retain_grad
|
||||
@@ -155,15 +155,21 @@ class TraceDict(OrderedDict, contextlib.AbstractContextManager):
|
||||
yield True, prev
|
||||
|
||||
for is_last, layer in flag_last_unseen(layers):
|
||||
|
||||
def optional_dict(obj):
|
||||
if isinstance(obj, dict):
|
||||
return obj.get(layer, None)
|
||||
return obj
|
||||
|
||||
self[layer] = Trace(
|
||||
module=module,
|
||||
layer=layer,
|
||||
retain_output=retain_output,
|
||||
retain_input=retain_input,
|
||||
clone=clone,
|
||||
detach=detach,
|
||||
retain_grad=retain_grad,
|
||||
edit_output=edit_output,
|
||||
retain_output=optional_dict(retain_output),
|
||||
retain_input=optional_dict(retain_input),
|
||||
clone=optional_dict(clone),
|
||||
detach=optional_dict(detach),
|
||||
retain_grad=optional_dict(retain_grad),
|
||||
edit_output=optional_dict(edit_output),
|
||||
stop=stop and is_last,
|
||||
)
|
||||
|
||||
|
||||
Symlink
+1
@@ -0,0 +1 @@
|
||||
../baukit
|
||||
Reference in New Issue
Block a user