diff --git a/baukit/nethook.py b/baukit/nethook.py index 9c12faa..532088d 100644 --- a/baukit/nethook.py +++ b/baukit/nethook.py @@ -222,9 +222,9 @@ def recursive_copy(x, clone=None, detach=None, retain_grad=None): return x # Only dicts, lists, and tuples (and subclasses) can be copied. if isinstance(x, dict): - return type(x)({k: recursive_copy(v) for k, v in x.items()}) + return type(x)({k: recursive_copy(v, clone=clone, detach=detach, retain_grad=retain_grad) for k, v in x.items()}) elif isinstance(x, (list, tuple)): - return type(x)([recursive_copy(v) for v in x]) + return type(x)([recursive_copy(v, clone=clone, detach=detach, retain_grad=retain_grad) for v in x]) else: assert False, f"Unknown type {type(x)} cannot be broken into tensors."