diff --git a/baukit/nethook.py b/baukit/nethook.py index 24da005..195e315 100644 --- a/baukit/nethook.py +++ b/baukit/nethook.py @@ -69,6 +69,10 @@ class Trace(contextlib.AbstractContextManager): module = get_module(module, layer) def retain_hook(m, inputs, output): + if edit_output: + output = invoke_with_optional_args( + edit_output, output=output, layer=self.layer, inputs=inputs + ) if retain_input: retainer.input = recursive_copy( inputs[0] if len(inputs) == 1 else inputs, @@ -76,10 +80,6 @@ class Trace(contextlib.AbstractContextManager): detach=detach, retain_grad=False, ) # retain_grad applies to output only. - if edit_output: - output = invoke_with_optional_args( - edit_output, output=output, layer=self.layer - ) if retain_output: retainer.output = recursive_copy( output, clone=clone, detach=detach, retain_grad=retain_grad @@ -155,15 +155,21 @@ class TraceDict(OrderedDict, contextlib.AbstractContextManager): yield True, prev for is_last, layer in flag_last_unseen(layers): + + def optional_dict(obj): + if isinstance(obj, dict): + return obj.get(layer, None) + return obj + self[layer] = Trace( module=module, layer=layer, - retain_output=retain_output, - retain_input=retain_input, - clone=clone, - detach=detach, - retain_grad=retain_grad, - edit_output=edit_output, + retain_output=optional_dict(retain_output), + retain_input=optional_dict(retain_input), + clone=optional_dict(clone), + detach=optional_dict(detach), + retain_grad=optional_dict(retain_grad), + edit_output=optional_dict(edit_output), stop=stop and is_last, ) diff --git a/notebooks/baukit b/notebooks/baukit new file mode 120000 index 0000000..68d4a75 --- /dev/null +++ b/notebooks/baukit @@ -0,0 +1 @@ +../baukit \ No newline at end of file