Return container itself from singleton move_to.

This commit is contained in:
David Bau
2022-08-27 05:03:12 -04:00
parent 5390101e06
commit 0df3359918
+3 -1
View File
@@ -50,11 +50,13 @@ def move_to(device, *containers):
g = enumerate(container)
for i, v in g:
moved = move_to(device, v)
if moved is not None:
if moved is not None and moved is not v:
container[i] = moved
elif isinstance(container, (torch.nn.Parameter, torch.Tensor)):
assert len(containers) == 1, 'Use move_to to move tensors in containers.'
return container.to(device)
if len(containers) == 1:
return container
def length_collation(token_size):