Rename dict_to_ to move_to

This commit is contained in:
David Bau
2022-08-23 05:34:12 -04:00
parent add9bb03ad
commit 15fd98a1c7
2 changed files with 18 additions and 6 deletions
+1 -1
View File
@@ -17,5 +17,5 @@ from .runningstats import Reservoir, History, CombinedStat
from .runningstats import tally
from . import show
from .workerpool import WorkerBase, WorkerPool
from .tokendataset import TokenizedDataset, dict_to_, length_collation
from .tokendataset import TokenizedDataset, move_to, length_collation
from .tokendataset import make_padded_batch, flatten_masked_batch
+17 -5
View File
@@ -36,13 +36,25 @@ class TokenizedDataset(Dataset):
)
def dict_to_(data, device):
def move_to(device, *containers):
"""
Moves a dictionary of tensors to the specified device.
Moves a container of tensors to the specified device.
"""
for k in data:
data[k] = data[k].to(device)
return data
for container in containers:
if isinstance(container, torch.nn.Module):
container.to(device)
elif isinstance(container, (list, dict)):
if isinstance(container, dict):
g = list(container.items())
else:
g = enumerate(list)
for i, v in g:
moved = move_to(device, v)
if moved is not None:
container[i] = moved
elif isinstance(container, (torch.nn.Parameter, torch.Tensor)):
assert len(containers) == 1, 'Use move_to to move tensors in containers.'
return container.to(device)
def length_collation(token_size):