From b3cf3afc8ecd473918e0e5f411cfb30c687458c9 Mon Sep 17 00:00:00 2001 From: Michael J Clark Date: Fri, 29 Dec 2023 16:14:54 +0800 Subject: [PATCH] pass kwargs for recursive_copy, as per #3 for #3 --- baukit/nethook.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/baukit/nethook.py b/baukit/nethook.py index 9c12faa..532088d 100644 --- a/baukit/nethook.py +++ b/baukit/nethook.py @@ -222,9 +222,9 @@ def recursive_copy(x, clone=None, detach=None, retain_grad=None): return x # Only dicts, lists, and tuples (and subclasses) can be copied. if isinstance(x, dict): - return type(x)({k: recursive_copy(v) for k, v in x.items()}) + return type(x)({k: recursive_copy(v, clone=clone, detach=detach, retain_grad=retain_grad) for k, v in x.items()}) elif isinstance(x, (list, tuple)): - return type(x)([recursive_copy(v) for v in x]) + return type(x)([recursive_copy(v, clone=clone, detach=detach, retain_grad=retain_grad) for v in x]) else: assert False, f"Unknown type {type(x)} cannot be broken into tensors."