mirror of
https://github.com/wassname/baukit.git
synced 2026-06-27 17:14:53 +08:00
Return container itself from singleton move_to.
This commit is contained in:
@@ -50,11 +50,13 @@ def move_to(device, *containers):
|
||||
g = enumerate(container)
|
||||
for i, v in g:
|
||||
moved = move_to(device, v)
|
||||
if moved is not None:
|
||||
if moved is not None and moved is not v:
|
||||
container[i] = moved
|
||||
elif isinstance(container, (torch.nn.Parameter, torch.Tensor)):
|
||||
assert len(containers) == 1, 'Use move_to to move tensors in containers.'
|
||||
return container.to(device)
|
||||
if len(containers) == 1:
|
||||
return container
|
||||
|
||||
|
||||
def length_collation(token_size):
|
||||
|
||||
Reference in New Issue
Block a user