mirror of
https://github.com/wassname/baukit.git
synced 2026-06-27 17:14:53 +08:00
Merge pull request #4 from wassname/patch-1
pass kwargs for recursive_copy, as per #3
This commit is contained in:
+2
-2
@@ -222,9 +222,9 @@ def recursive_copy(x, clone=None, detach=None, retain_grad=None):
|
||||
return x
|
||||
# Only dicts, lists, and tuples (and subclasses) can be copied.
|
||||
if isinstance(x, dict):
|
||||
return type(x)({k: recursive_copy(v) for k, v in x.items()})
|
||||
return type(x)({k: recursive_copy(v, clone=clone, detach=detach, retain_grad=retain_grad) for k, v in x.items()})
|
||||
elif isinstance(x, (list, tuple)):
|
||||
return type(x)([recursive_copy(v) for v in x])
|
||||
return type(x)([recursive_copy(v, clone=clone, detach=detach, retain_grad=retain_grad) for v in x])
|
||||
else:
|
||||
assert False, f"Unknown type {type(x)} cannot be broken into tensors."
|
||||
|
||||
|
||||
Reference in New Issue
Block a user