Merge pull request #4 from wassname/patch-1

pass kwargs for recursive_copy, as per #3
This commit is contained in:
David Bau
2024-02-22 09:24:02 -05:00
committed by GitHub
+2 -2
View File
@@ -222,9 +222,9 @@ def recursive_copy(x, clone=None, detach=None, retain_grad=None):
return x
# Only dicts, lists, and tuples (and subclasses) can be copied.
if isinstance(x, dict):
return type(x)({k: recursive_copy(v) for k, v in x.items()})
return type(x)({k: recursive_copy(v, clone=clone, detach=detach, retain_grad=retain_grad) for k, v in x.items()})
elif isinstance(x, (list, tuple)):
return type(x)([recursive_copy(v) for v in x])
return type(x)([recursive_copy(v, clone=clone, detach=detach, retain_grad=retain_grad) for v in x])
else:
assert False, f"Unknown type {type(x)} cannot be broken into tensors."