This commit is contained in:
Sotirios Anagnostidis
2023-01-06 21:29:38 +01:00
parent 88ee3b3264
commit 148244455c
3 changed files with 21 additions and 20 deletions
@@ -24,26 +24,24 @@ class DialogueDataCollator:
flatten_messages = []
label_masks = []
for messages in features:
assert len(messages) % 2 == 0, "Number of messages must be even"
for feature_one in features:
assert len(feature_one) % 2 == 0, "Number of messages must be even"
messages = [
(QA_SPECIAL_TOKENS["Question"] if i % 2 == 0 else "")
+ x
+ (QA_SPECIAL_TOKENS["Answer"] if i % 2 == 0 else "")
for i, x in enumerate(messages)
for i, x in enumerate(feature_one)
]
# Add a way for the model to terminate generation
# When we predict the start of a new expected question, we want to be able to stop generation
messages.append(QA_SPECIAL_TOKENS["Question"])
flatten_messages.append(
self.tokenizer(
"".join(messages),
truncation=True,
max_length=self.max_length,
return_offsets_mapping=True,
)
flatten_message = self.tokenizer(
"".join(messages),
truncation=True,
max_length=self.max_length,
return_offsets_mapping=True,
)
message_change_indices = np.cumsum([len(x) for x in messages[:-1]])
@@ -57,18 +55,19 @@ class DialogueDataCollator:
message_indices = list(
map(
lambda x: next((i for i, val in enumerate(message_change_indices) if val >= x), -2),
list(map(lambda x: x[1], flatten_messages[-1]["offset_mapping"])),
list(map(lambda x: x[1], flatten_message["offset_mapping"])),
)
)
label_mask = np.roll(list(map(lambda x: x % 2 == 1, message_indices)), -1, -1)
try:
label_mask[[i for i in range(len(message_indices)) if message_indices[i] == -2][0] - 1] = True
except IndexError:
# an aftermath of padding
pass
# due to truncation, we might not have the last termination token
label_mask[-1] = False
label_masks.append(label_mask)
flatten_messages[-1].pop("offset_mapping")
flatten_messages.append({k: v for k, v in flatten_message.items() if k != "offset_mapping"})
batch = self.tokenizer.pad(
flatten_messages,
@@ -79,10 +78,9 @@ class DialogueDataCollator:
)
dim = batch["input_ids"].shape[-1]
batch["label_masks"] = torch.stack([F.pad(torch.tensor(x), (0, dim - len(x))) for x in label_masks])
for k in list(batch.keys()):
if k not in ["input_ids", "attention_mask", "label_masks"]:
batch.pop(k)
batch["label_masks"] = torch.stack(
[F.pad(torch.tensor(x), (0, dim - len(x)), value=False) for x in label_masks]
)
batch["targets"] = torch.roll(batch["input_ids"], -1, -1)
return batch
+1 -1
View File
@@ -171,7 +171,7 @@ def get_model(model_name, cache_dir, quantization):
if quantization is None:
model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_dir)
elif quantization == "8bit":
print("Loading 8-bit model")
raise ValueError("Loading 8-bit model. Bitsandbytes does not behave so far...")
transformers.models.gptj.modeling_gptj.GPTJBlock = GPTJBlock
model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_dir)
add_adapters(model)
@@ -4,3 +4,6 @@ PyYAML==6.0
scikit_learn==1.2.0
torch==1.13.1
transformers==4.25.1
deepspeed==0.7.7
mpi4py==3.1.4
accelerate==0.15.0