From 15fd98a1c757ad4c6a715c1e59f58670eb4b267c Mon Sep 17 00:00:00 2001 From: David Bau Date: Tue, 23 Aug 2022 05:34:12 -0400 Subject: [PATCH] Rename dict_to_ to move_to --- baukit/__init__.py | 2 +- baukit/tokendataset.py | 22 +++++++++++++++++----- 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/baukit/__init__.py b/baukit/__init__.py index 6526863..51f38ea 100644 --- a/baukit/__init__.py +++ b/baukit/__init__.py @@ -17,5 +17,5 @@ from .runningstats import Reservoir, History, CombinedStat from .runningstats import tally from . import show from .workerpool import WorkerBase, WorkerPool -from .tokendataset import TokenizedDataset, dict_to_, length_collation +from .tokendataset import TokenizedDataset, move_to, length_collation from .tokendataset import make_padded_batch, flatten_masked_batch diff --git a/baukit/tokendataset.py b/baukit/tokendataset.py index e2c7e61..3715c8e 100644 --- a/baukit/tokendataset.py +++ b/baukit/tokendataset.py @@ -36,13 +36,25 @@ class TokenizedDataset(Dataset): ) -def dict_to_(data, device): +def move_to(device, *containers): """ - Moves a dictionary of tensors to the specified device. + Moves a container of tensors to the specified device. """ - for k in data: - data[k] = data[k].to(device) - return data + for container in containers: + if isinstance(container, torch.nn.Module): + container.to(device) + elif isinstance(container, (list, dict)): + if isinstance(container, dict): + g = list(container.items()) + else: + g = enumerate(list) + for i, v in g: + moved = move_to(device, v) + if moved is not None: + container[i] = moved + elif isinstance(container, (torch.nn.Parameter, torch.Tensor)): + assert len(containers) == 1, 'Use move_to to move tensors in containers.' + return container.to(device) def length_collation(token_size):