Add optional inputs argument on edit_output.

This commit is contained in:
David Bau
2022-04-18 18:06:11 -04:00
parent 8d2cd35b62
commit 4e625de39a
2 changed files with 17 additions and 10 deletions
+16 -10
View File
@@ -69,6 +69,10 @@ class Trace(contextlib.AbstractContextManager):
module = get_module(module, layer)
def retain_hook(m, inputs, output):
if edit_output:
output = invoke_with_optional_args(
edit_output, output=output, layer=self.layer, inputs=inputs
)
if retain_input:
retainer.input = recursive_copy(
inputs[0] if len(inputs) == 1 else inputs,
@@ -76,10 +80,6 @@ class Trace(contextlib.AbstractContextManager):
detach=detach,
retain_grad=False,
) # retain_grad applies to output only.
if edit_output:
output = invoke_with_optional_args(
edit_output, output=output, layer=self.layer
)
if retain_output:
retainer.output = recursive_copy(
output, clone=clone, detach=detach, retain_grad=retain_grad
@@ -155,15 +155,21 @@ class TraceDict(OrderedDict, contextlib.AbstractContextManager):
yield True, prev
for is_last, layer in flag_last_unseen(layers):
def optional_dict(obj):
if isinstance(obj, dict):
return obj.get(layer, None)
return obj
self[layer] = Trace(
module=module,
layer=layer,
retain_output=retain_output,
retain_input=retain_input,
clone=clone,
detach=detach,
retain_grad=retain_grad,
edit_output=edit_output,
retain_output=optional_dict(retain_output),
retain_input=optional_dict(retain_input),
clone=optional_dict(clone),
detach=optional_dict(detach),
retain_grad=optional_dict(retain_grad),
edit_output=optional_dict(edit_output),
stop=stop and is_last,
)
+1
View File
@@ -0,0 +1 @@
../baukit