From 1041564db7fa7cc4bb3f62710c9e1c10ef7b4218 Mon Sep 17 00:00:00 2001 From: theblackcat102 Date: Fri, 3 Feb 2023 00:15:29 +0000 Subject: [PATCH] [feature] mix generation from different tasks --- .../custom_datasets/dialogue_collator.py | 39 ++++++++++++++++++- .../tests/test_datasets.py | 11 ++++-- 2 files changed, 44 insertions(+), 6 deletions(-) diff --git a/model/supervised_finetuning/custom_datasets/dialogue_collator.py b/model/supervised_finetuning/custom_datasets/dialogue_collator.py index c96ed576..d812bb26 100644 --- a/model/supervised_finetuning/custom_datasets/dialogue_collator.py +++ b/model/supervised_finetuning/custom_datasets/dialogue_collator.py @@ -1,3 +1,4 @@ +import random from dataclasses import dataclass from typing import Optional, Union @@ -16,12 +17,14 @@ class DialogueDataCollator: tokenizer: PreTrainedTokenizerBase padding: Union[bool, str, PaddingStrategy] = True max_length: Optional[int] = None + mix_length_threshold: Optional[int] = 256 + mix_probability: Optional[int] = 0.6 pad_to_multiple_of: Optional[int] = None def __call__(self, features): flatten_messages = [] label_masks = [] - + total_short_context = 0 for messages in features: messages = list(messages) @@ -58,8 +61,40 @@ class DialogueDataCollator: label_mask[-1] = False label_masks.append(label_mask) - + if len(flatten_message["input_ids"]) < self.mix_length_threshold: + total_short_context += len(flatten_message["input_ids"]) flatten_messages.append({k: v for k, v in flatten_message.items() if k != "offset_mapping"}) + # packing + if total_short_context > 2: + _flatten_messages, _label_masks = [], [] + prev_short_msg, prev_short_mask = None, None + for flatten_msg, label_mask in zip(flatten_messages, label_masks): + if len(flatten_msg["input_ids"]) < self.mix_length_threshold and random.random() > 0.6: + if prev_short_msg is not None: + for key in flatten_msg.keys(): + flatten_msg[key] += prev_short_msg[key] + flatten_msg[key] = flatten_msg[key][: self.max_length] + label_mask = np.concatenate([label_mask, prev_short_mask]) + _label_masks.append(label_mask[: self.max_length]) + _flatten_messages.append(flatten_msg) + # reset + prev_short_msg, prev_short_mask = None, None + else: + # prime + prev_short_msg, prev_short_mask = flatten_msg, label_mask + else: + _label_masks.append(label_mask) + _flatten_messages.append(flatten_msg) + if prev_short_msg is not None: + for key in flatten_msg.keys(): + flatten_msg[key] += prev_short_msg[key] + flatten_msg[key] = flatten_msg[key][: self.max_length] + label_mask = np.concatenate([label_mask, prev_short_mask])[: self.max_length] + _label_masks.append(label_mask) + _flatten_messages.append(flatten_msg) + + label_masks = _label_masks + flatten_messages = _flatten_messages batch = self.tokenizer.pad( flatten_messages, diff --git a/model/supervised_finetuning/tests/test_datasets.py b/model/supervised_finetuning/tests/test_datasets.py index 8d5ad08f..2a0d4481 100644 --- a/model/supervised_finetuning/tests/test_datasets.py +++ b/model/supervised_finetuning/tests/test_datasets.py @@ -27,7 +27,7 @@ def test_collate_fn(): config = Namespace(cache_dir=".cache", model_name="Salesforce/codegen-2B-multi") tokenizer = get_tokenizer(config) - collate_fn = DialogueDataCollator(tokenizer, max_length=512) + collate_fn = DialogueDataCollator(tokenizer, max_length=620) qa_base = QA_DATASETS summarize_base = SUMMARIZATION_DATASETS others = ["prompt_dialogue", "webgpt", "soda", "joke", "gsm8k"] @@ -40,11 +40,14 @@ def test_collate_fn(): dataloader = DataLoader(ConcatDataset(trains), collate_fn=collate_fn, batch_size=128) for batch in dataloader: - print(batch.keys()) + print(batch["targets"].shape[0]) print(tokenizer.decode(batch["input_ids"][0])) print("-----") print(tokenizer.decode(batch["targets"][0][batch["label_masks"][0]])) - assert batch["targets"].shape[1] <= 512 + assert batch["targets"].shape[1] <= 620 dataloader = DataLoader(ConcatDataset(evals), collate_fn=collate_fn, batch_size=128) for batch in dataloader: - assert batch["targets"].shape[1] <= 512 + assert batch["targets"].shape[1] <= 620 + + +test_collate_fn()