From 9263bff5ff8441a8d0ef8d79abaf44233c1c8b19 Mon Sep 17 00:00:00 2001 From: David Bau Date: Sat, 27 Aug 2022 05:14:59 -0400 Subject: [PATCH] move_to returns the moved objects. --- baukit/tokendataset.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/baukit/tokendataset.py b/baukit/tokendataset.py index 2f6dd34..1a89780 100644 --- a/baukit/tokendataset.py +++ b/baukit/tokendataset.py @@ -38,9 +38,18 @@ class TokenizedDataset(Dataset): def move_to(device, *containers): """ - Moves a container of tensors to the specified device. + Moves tensors or containers of tensors to the specified device, + moving tensors in-place within arrays, dictionaries, and Modules. + + Example: + [moved_a, moved_b] = move_to('cuda', a, b) + + If arguments are arrays or dictionaries or torch.nn.Modules + containing tensors, the tensors are moved to the given device and + replaced in-place without making a new instance of the container. """ - for container in containers: + containers = list(containers) + for j, container in enumerate(containers): if isinstance(container, torch.nn.Module): container.to(device) elif isinstance(container, (list, dict)): @@ -49,14 +58,12 @@ def move_to(device, *containers): else: g = enumerate(container) for i, v in g: - moved = move_to(device, v) - if moved is not None and moved is not v: + [moved] = move_to(device, v) + if moved is not v: container[i] = moved elif isinstance(container, (torch.nn.Parameter, torch.Tensor)): - assert len(containers) == 1, 'Use move_to to move tensors in containers.' - return container.to(device) - if len(containers) == 1: - return container + containers[j] = container.to(device) + return containers def length_collation(token_size):