mirror of
https://github.com/wassname/baukit.git
synced 2026-06-27 17:29:37 +08:00
Rename dict_to_ to move_to
This commit is contained in:
+1
-1
@@ -17,5 +17,5 @@ from .runningstats import Reservoir, History, CombinedStat
|
||||
from .runningstats import tally
|
||||
from . import show
|
||||
from .workerpool import WorkerBase, WorkerPool
|
||||
from .tokendataset import TokenizedDataset, dict_to_, length_collation
|
||||
from .tokendataset import TokenizedDataset, move_to, length_collation
|
||||
from .tokendataset import make_padded_batch, flatten_masked_batch
|
||||
|
||||
+17
-5
@@ -36,13 +36,25 @@ class TokenizedDataset(Dataset):
|
||||
)
|
||||
|
||||
|
||||
def dict_to_(data, device):
|
||||
def move_to(device, *containers):
|
||||
"""
|
||||
Moves a dictionary of tensors to the specified device.
|
||||
Moves a container of tensors to the specified device.
|
||||
"""
|
||||
for k in data:
|
||||
data[k] = data[k].to(device)
|
||||
return data
|
||||
for container in containers:
|
||||
if isinstance(container, torch.nn.Module):
|
||||
container.to(device)
|
||||
elif isinstance(container, (list, dict)):
|
||||
if isinstance(container, dict):
|
||||
g = list(container.items())
|
||||
else:
|
||||
g = enumerate(list)
|
||||
for i, v in g:
|
||||
moved = move_to(device, v)
|
||||
if moved is not None:
|
||||
container[i] = moved
|
||||
elif isinstance(container, (torch.nn.Parameter, torch.Tensor)):
|
||||
assert len(containers) == 1, 'Use move_to to move tensors in containers.'
|
||||
return container.to(device)
|
||||
|
||||
|
||||
def length_collation(token_size):
|
||||
|
||||
Reference in New Issue
Block a user