diff --git a/baukit/tokendataset.py b/baukit/tokendataset.py index c704671..2f6dd34 100644 --- a/baukit/tokendataset.py +++ b/baukit/tokendataset.py @@ -50,11 +50,13 @@ def move_to(device, *containers): g = enumerate(container) for i, v in g: moved = move_to(device, v) - if moved is not None: + if moved is not None and moved is not v: container[i] = moved elif isinstance(container, (torch.nn.Parameter, torch.Tensor)): assert len(containers) == 1, 'Use move_to to move tensors in containers.' return container.to(device) + if len(containers) == 1: + return container def length_collation(token_size):