From 0df3359918a2c6d3b7f3bb3a281ec90e409cf0fc Mon Sep 17 00:00:00 2001 From: David Bau Date: Sat, 27 Aug 2022 05:03:12 -0400 Subject: [PATCH] Return container itself from singleton move_to. --- baukit/tokendataset.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/baukit/tokendataset.py b/baukit/tokendataset.py index c704671..2f6dd34 100644 --- a/baukit/tokendataset.py +++ b/baukit/tokendataset.py @@ -50,11 +50,13 @@ def move_to(device, *containers): g = enumerate(container) for i, v in g: moved = move_to(device, v) - if moved is not None: + if moved is not None and moved is not v: container[i] = moved elif isinstance(container, (torch.nn.Parameter, torch.Tensor)): assert len(containers) == 1, 'Use move_to to move tensors in containers.' return container.to(device) + if len(containers) == 1: + return container def length_collation(token_size):