From bbbf111f8d12b441abe3b1e5e7e48557eba941e2 Mon Sep 17 00:00:00 2001 From: Michael J Clark Date: Wed, 7 Aug 2024 21:50:03 +0800 Subject: [PATCH] retain_grad on input --- baukit/nethook.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/baukit/nethook.py b/baukit/nethook.py index 532088d..10d1b8b 100644 --- a/baukit/nethook.py +++ b/baukit/nethook.py @@ -78,8 +78,8 @@ class Trace(contextlib.AbstractContextManager): inputs[0] if len(inputs) == 1 else inputs, clone=clone, detach=detach, - retain_grad=False, - ) # retain_grad applies to output only. + retain_grad=retain_grad, + ) if retain_output: retainer.output = recursive_copy( output, clone=clone, detach=detach, retain_grad=retain_grad