mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
[feature] mix generation from different tasks
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Union
|
||||
|
||||
@@ -16,12 +17,14 @@ class DialogueDataCollator:
|
||||
tokenizer: PreTrainedTokenizerBase
|
||||
padding: Union[bool, str, PaddingStrategy] = True
|
||||
max_length: Optional[int] = None
|
||||
mix_length_threshold: Optional[int] = 256
|
||||
mix_probability: Optional[int] = 0.6
|
||||
pad_to_multiple_of: Optional[int] = None
|
||||
|
||||
def __call__(self, features):
|
||||
flatten_messages = []
|
||||
label_masks = []
|
||||
|
||||
total_short_context = 0
|
||||
for messages in features:
|
||||
messages = list(messages)
|
||||
|
||||
@@ -58,8 +61,40 @@ class DialogueDataCollator:
|
||||
label_mask[-1] = False
|
||||
|
||||
label_masks.append(label_mask)
|
||||
|
||||
if len(flatten_message["input_ids"]) < self.mix_length_threshold:
|
||||
total_short_context += len(flatten_message["input_ids"])
|
||||
flatten_messages.append({k: v for k, v in flatten_message.items() if k != "offset_mapping"})
|
||||
# packing
|
||||
if total_short_context > 2:
|
||||
_flatten_messages, _label_masks = [], []
|
||||
prev_short_msg, prev_short_mask = None, None
|
||||
for flatten_msg, label_mask in zip(flatten_messages, label_masks):
|
||||
if len(flatten_msg["input_ids"]) < self.mix_length_threshold and random.random() > 0.6:
|
||||
if prev_short_msg is not None:
|
||||
for key in flatten_msg.keys():
|
||||
flatten_msg[key] += prev_short_msg[key]
|
||||
flatten_msg[key] = flatten_msg[key][: self.max_length]
|
||||
label_mask = np.concatenate([label_mask, prev_short_mask])
|
||||
_label_masks.append(label_mask[: self.max_length])
|
||||
_flatten_messages.append(flatten_msg)
|
||||
# reset
|
||||
prev_short_msg, prev_short_mask = None, None
|
||||
else:
|
||||
# prime
|
||||
prev_short_msg, prev_short_mask = flatten_msg, label_mask
|
||||
else:
|
||||
_label_masks.append(label_mask)
|
||||
_flatten_messages.append(flatten_msg)
|
||||
if prev_short_msg is not None:
|
||||
for key in flatten_msg.keys():
|
||||
flatten_msg[key] += prev_short_msg[key]
|
||||
flatten_msg[key] = flatten_msg[key][: self.max_length]
|
||||
label_mask = np.concatenate([label_mask, prev_short_mask])[: self.max_length]
|
||||
_label_masks.append(label_mask)
|
||||
_flatten_messages.append(flatten_msg)
|
||||
|
||||
label_masks = _label_masks
|
||||
flatten_messages = _flatten_messages
|
||||
|
||||
batch = self.tokenizer.pad(
|
||||
flatten_messages,
|
||||
|
||||
@@ -27,7 +27,7 @@ def test_collate_fn():
|
||||
|
||||
config = Namespace(cache_dir=".cache", model_name="Salesforce/codegen-2B-multi")
|
||||
tokenizer = get_tokenizer(config)
|
||||
collate_fn = DialogueDataCollator(tokenizer, max_length=512)
|
||||
collate_fn = DialogueDataCollator(tokenizer, max_length=620)
|
||||
qa_base = QA_DATASETS
|
||||
summarize_base = SUMMARIZATION_DATASETS
|
||||
others = ["prompt_dialogue", "webgpt", "soda", "joke", "gsm8k"]
|
||||
@@ -40,11 +40,14 @@ def test_collate_fn():
|
||||
|
||||
dataloader = DataLoader(ConcatDataset(trains), collate_fn=collate_fn, batch_size=128)
|
||||
for batch in dataloader:
|
||||
print(batch.keys())
|
||||
print(batch["targets"].shape[0])
|
||||
print(tokenizer.decode(batch["input_ids"][0]))
|
||||
print("-----")
|
||||
print(tokenizer.decode(batch["targets"][0][batch["label_masks"][0]]))
|
||||
assert batch["targets"].shape[1] <= 512
|
||||
assert batch["targets"].shape[1] <= 620
|
||||
dataloader = DataLoader(ConcatDataset(evals), collate_fn=collate_fn, batch_size=128)
|
||||
for batch in dataloader:
|
||||
assert batch["targets"].shape[1] <= 512
|
||||
assert batch["targets"].shape[1] <= 620
|
||||
|
||||
|
||||
test_collate_fn()
|
||||
|
||||
Reference in New Issue
Block a user