mirror of
https://github.com/wassname/baukit.git
synced 2026-06-27 19:46:31 +08:00
move_to returns the moved objects.
This commit is contained in:
+15
-8
@@ -38,9 +38,18 @@ class TokenizedDataset(Dataset):
|
||||
|
||||
def move_to(device, *containers):
|
||||
"""
|
||||
Moves a container of tensors to the specified device.
|
||||
Moves tensors or containers of tensors to the specified device,
|
||||
moving tensors in-place within arrays, dictionaries, and Modules.
|
||||
|
||||
Example:
|
||||
[moved_a, moved_b] = move_to('cuda', a, b)
|
||||
|
||||
If arguments are arrays or dictionaries or torch.nn.Modules
|
||||
containing tensors, the tensors are moved to the given device and
|
||||
replaced in-place without making a newe instance of the container.
|
||||
"""
|
||||
for container in containers:
|
||||
containers = list(containers)
|
||||
for j, container in enumerate(containers):
|
||||
if isinstance(container, torch.nn.Module):
|
||||
container.to(device)
|
||||
elif isinstance(container, (list, dict)):
|
||||
@@ -49,14 +58,12 @@ def move_to(device, *containers):
|
||||
else:
|
||||
g = enumerate(container)
|
||||
for i, v in g:
|
||||
moved = move_to(device, v)
|
||||
if moved is not None and moved is not v:
|
||||
[moved] = move_to(device, v)
|
||||
if moved is not v:
|
||||
container[i] = moved
|
||||
elif isinstance(container, (torch.nn.Parameter, torch.Tensor)):
|
||||
assert len(containers) == 1, 'Use move_to to move tensors in containers.'
|
||||
return container.to(device)
|
||||
if len(containers) == 1:
|
||||
return container
|
||||
containers[j] = container.to(device)
|
||||
return containers
|
||||
|
||||
|
||||
def length_collation(token_size):
|
||||
|
||||
Reference in New Issue
Block a user