diff --git a/baukit/nethook.py b/baukit/nethook.py index 532088d..10d1b8b 100644 --- a/baukit/nethook.py +++ b/baukit/nethook.py @@ -78,8 +78,8 @@ class Trace(contextlib.AbstractContextManager): inputs[0] if len(inputs) == 1 else inputs, clone=clone, detach=detach, - retain_grad=False, - ) # retain_grad applies to output only. + retain_grad=retain_grad, + ) if retain_output: retainer.output = recursive_copy( output, clone=clone, detach=detach, retain_grad=retain_grad