move_to returns the moved objects.

This commit is contained in:
David Bau
2022-08-27 05:14:59 -04:00
parent 0df3359918
commit 9263bff5ff
+15 -8
View File
@@ -38,9 +38,18 @@ class TokenizedDataset(Dataset):
def move_to(device, *containers):
"""
Moves a container of tensors to the specified device.
Moves tensors or containers of tensors to the specified device,
moving tensors in-place within arrays, dictionaries, and Modules.
Example:
[moved_a, moved_b] = move_to('cuda', a, b)
If arguments are arrays or dictionaries or torch.nn.Modules
containing tensors, the tensors are moved to the given device and
replaced in-place without making a newe instance of the container.
"""
for container in containers:
containers = list(containers)
for j, container in enumerate(containers):
if isinstance(container, torch.nn.Module):
container.to(device)
elif isinstance(container, (list, dict)):
@@ -49,14 +58,12 @@ def move_to(device, *containers):
else:
g = enumerate(container)
for i, v in g:
moved = move_to(device, v)
if moved is not None and moved is not v:
[moved] = move_to(device, v)
if moved is not v:
container[i] = moved
elif isinstance(container, (torch.nn.Parameter, torch.Tensor)):
assert len(containers) == 1, 'Use move_to to move tensors in containers.'
return container.to(device)
if len(containers) == 1:
return container
containers[j] = container.to(device)
return containers
def length_collation(token_size):