[fix] patch translated history conversation

This commit is contained in:
theblackcat102
2023-02-08 00:20:11 +00:00
parent a39cbab524
commit 2c35ff6e50
2 changed files with 17 additions and 4 deletions
@@ -32,7 +32,7 @@ SUMMARIZATION_DATASETS = [
"debate_sum",
"tldr_news",
]
OTHER = ["prosocial_dialogue", "explain_prosocial", "instruct_tuning", "private_tuning"]
OTHER = ["prosocial_dialogue", "explain_prosocial", "instruct_tuning", "private_tuning", "oa_translated"]
def train_val_dataset(dataset, val_split=0.2):
@@ -95,7 +95,7 @@ def get_one_dataset(conf, dataset_name):
elif dataset_name == "private_tuning":
dataset = PrivateInstructionTuning(conf.cache_dir)
train, eval = train_val_dataset(dataset, val_split=0.2)
elif dataset_name == "translate_qa":
elif dataset_name == "oa_translated":
dataset = TranslatedQA(conf.cache_dir)
train, eval = train_val_dataset(dataset, val_split=0.01)
else:
@@ -349,8 +349,21 @@ class TranslatedQA(Dataset):
if "Python " in data["text"]:
continue
# incorrect, TODO: fix later
prefix = ""
for convo_round in data["translate"]:
self.pairs.append((convo_round["human"], convo_round["answer"]))
human, answer = format_pair((convo_round["human"], convo_round["answer"]))
if convo_round["round"] > 2:
# TODO: remove this later
self.pairs.append(("{}{}{}".format(prefix, "<sep>", human), answer))
else:
self.pairs.append((human, answer))
prefix += "{}{}{}{}".format(
QA_SPECIAL_TOKENS["Question"],
convo_round["human"],
QA_SPECIAL_TOKENS["Answer"],
convo_round["answer"],
)
self.length = len(self.pairs)
@@ -358,4 +371,4 @@ class TranslatedQA(Dataset):
return self.length
def __getitem__(self, index):
return format_pair(self.pairs[index])
return self.pairs[index]