[feature] mix generation from different tasks

This commit is contained in:
theblackcat102
2023-02-03 00:15:29 +00:00
parent 9be4c921cd
commit 1041564db7
2 changed files with 44 additions and 6 deletions
@@ -1,3 +1,4 @@
import random
from dataclasses import dataclass
from typing import Optional, Union
@@ -16,12 +17,14 @@ class DialogueDataCollator:
tokenizer: PreTrainedTokenizerBase
padding: Union[bool, str, PaddingStrategy] = True
max_length: Optional[int] = None
mix_length_threshold: Optional[int] = 256
mix_probability: Optional[int] = 0.6
pad_to_multiple_of: Optional[int] = None
def __call__(self, features):
flatten_messages = []
label_masks = []
total_short_context = 0
for messages in features:
messages = list(messages)
@@ -58,8 +61,40 @@ class DialogueDataCollator:
label_mask[-1] = False
label_masks.append(label_mask)
if len(flatten_message["input_ids"]) < self.mix_length_threshold:
total_short_context += len(flatten_message["input_ids"])
flatten_messages.append({k: v for k, v in flatten_message.items() if k != "offset_mapping"})
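# packing: randomly concatenate pairs of short samples (shorter than mix_length_threshold) into one sequence, truncated to max_length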
if total_short_context > 2:
_flatten_messages, _label_masks = [], []
prev_short_msg, prev_short_mask = None, None
for flatten_msg, label_mask in zip(flatten_messages, label_masks):
if len(flatten_msg["input_ids"]) < self.mix_length_threshold and random.random() > 0.6:
if prev_short_msg is not None:
for key in flatten_msg.keys():
flatten_msg[key] += prev_short_msg[key]
flatten_msg[key] = flatten_msg[key][: self.max_length]
label_mask = np.concatenate([label_mask, prev_short_mask])
_label_masks.append(label_mask[: self.max_length])
_flatten_messages.append(flatten_msg)
# reset: the held short sample has been merged, so clear it
prev_short_msg, prev_short_mask = None, None
else:
# prime: hold this short sample until the next selected short sample arrives to be merged with it
prev_short_msg, prev_short_mask = flatten_msg, label_mask
else:
_label_masks.append(label_mask)
_flatten_messages.append(flatten_msg)
if prev_short_msg is not None:
for key in flatten_msg.keys():
flatten_msg[key] += prev_short_msg[key]
flatten_msg[key] = flatten_msg[key][: self.max_length]
label_mask = np.concatenate([label_mask, prev_short_mask])[: self.max_length]
_label_masks.append(label_mask)
_flatten_messages.append(flatten_msg)
label_masks = _label_masks
flatten_messages = _flatten_messages
batch = self.tokenizer.pad(
flatten_messages,
@@ -27,7 +27,7 @@ def test_collate_fn():
config = Namespace(cache_dir=".cache", model_name="Salesforce/codegen-2B-multi")
tokenizer = get_tokenizer(config)
collate_fn = DialogueDataCollator(tokenizer, max_length=512)
collate_fn = DialogueDataCollator(tokenizer, max_length=620)
qa_base = QA_DATASETS
summarize_base = SUMMARIZATION_DATASETS
others = ["prompt_dialogue", "webgpt", "soda", "joke", "gsm8k"]
@@ -40,11 +40,14 @@ def test_collate_fn():
dataloader = DataLoader(ConcatDataset(trains), collate_fn=collate_fn, batch_size=128)
for batch in dataloader:
print(batch.keys())
print(batch["targets"].shape[0])
print(tokenizer.decode(batch["input_ids"][0]))
print("-----")
print(tokenizer.decode(batch["targets"][0][batch["label_masks"][0]]))
assert batch["targets"].shape[1] <= 512
assert batch["targets"].shape[1] <= 620
dataloader = DataLoader(ConcatDataset(evals), collate_fn=collate_fn, batch_size=128)
for batch in dataloader:
assert batch["targets"].shape[1] <= 512
assert batch["targets"].shape[1] <= 620
test_collate_fn()