retain_grad on input

This commit is contained in:
Michael J Clark
2024-08-07 21:50:03 +08:00
committed by GitHub
parent 9d51abd51e
commit bbbf111f8d
+2 -2
View File
@@ -78,8 +78,8 @@ class Trace(contextlib.AbstractContextManager):
inputs[0] if len(inputs) == 1 else inputs,
clone=clone,
detach=detach,
retain_grad=False,
) # retain_grad applies to output only.
retain_grad=retain_grad,
)
if retain_output:
retainer.output = recursive_copy(
output, clone=clone, detach=detach, retain_grad=retain_grad