pass kwargs for recursive_copy, as per #3

for #3
This commit is contained in:
Michael J Clark
2023-12-29 16:14:54 +08:00
committed by GitHub
parent 5e23007c02
commit b3cf3afc8e
+2 -2
View File
@@ -222,9 +222,9 @@ def recursive_copy(x, clone=None, detach=None, retain_grad=None):
return x
# Only dicts, lists, and tuples (and subclasses) can be copied.
if isinstance(x, dict):
return type(x)({k: recursive_copy(v) for k, v in x.items()})
return type(x)({k: recursive_copy(v, clone=clone, detach=detach, retain_grad=retain_grad) for k, v in x.items()})
elif isinstance(x, (list, tuple)):
return type(x)([recursive_copy(v) for v in x])
return type(x)([recursive_copy(v, clone=clone, detach=detach, retain_grad=retain_grad) for v in x])
else:
assert False, f"Unknown type {type(x)} cannot be broken into tensors."