mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
d1d185edb6
* Adding clinical note dataset * Fix flake8 issues * Fix prepare.py for straggling commas, replace assistant with Rosey in prompt Co-authored-by: Yada P <yadapruk@Tests-MacBook-Pro.local>
85 lines
2.9 KiB
Python
85 lines
2.9 KiB
Python
import json
|
|
import math
|
|
import os
|
|
import random
|
|
import re
|
|
import sys
|
|
from string import punctuation
|
|
|
|
import kaggle
|
|
import pandas as pd
|
|
|
|
CLINICAL_NOTE_GENERATION_TEMPLATE = """User: Write a clinical note about a patient with the following {section}: {section_information}.
|
|
Rosey: {note}"""
|
|
|
|
|
|
def preprocess(mt_dataset):
|
|
def filter_for_notes(row):
|
|
normalized_transcript = row["transcription"].lower()
|
|
if "chief complaint:" in normalized_transcript:
|
|
return True
|
|
return False
|
|
|
|
mt_dataset = mt_dataset.dropna(subset=["description", "transcription"])
|
|
mt_note_subset = mt_dataset[mt_dataset.apply(filter_for_notes, axis=1)]
|
|
return mt_note_subset
|
|
|
|
|
|
def is_chief_complaint(section):
|
|
return "chief complaint" in section.lower()
|
|
|
|
|
|
def get_conversations(dataset):
|
|
def normalize_transcript(x):
|
|
x = re.sub(r"\.+", ".", x)
|
|
x = re.sub(r"\,+", ",", x)
|
|
x = re.sub(r":\s+", ": ", x)
|
|
x = re.sub(r"\.\s+", ". ", x)
|
|
x = re.sub(r":(\s)*\,+", ": ", x)
|
|
x = re.sub(r"\.\,+", ". ", x)
|
|
return x
|
|
|
|
conversations = []
|
|
for idx in range(len(dataset)):
|
|
transcript = normalize_transcript(dataset.iloc[idx]["transcription"])
|
|
sections = re.findall(r"\b[A-Z]+(?: [A-Z]+)*:", transcript)
|
|
if len(sections) >= 2:
|
|
note_prompt = transcript.split(sections[0])[1].split(sections[1])[0]
|
|
else:
|
|
continue
|
|
section_name = sections[0].lower().strip(punctuation)
|
|
if len(note_prompt.split(" ")) > 30 and is_chief_complaint(section_name):
|
|
# There are some chief complaints that seem to be HPI
|
|
section_name = "history of present illness"
|
|
conversations.append(
|
|
CLINICAL_NOTE_GENERATION_TEMPLATE.format(
|
|
section=section_name, section_information=note_prompt, note=transcript
|
|
)
|
|
)
|
|
return conversations
|
|
|
|
|
|
def main(output_dir: str = "data"):
|
|
"""Download and prepare the dataset for use."""
|
|
os.makedirs(output_dir, exist_ok=True)
|
|
kaggle.api.dataset_download_files("tboyle10/medicaltranscriptions", "data", unzip=True)
|
|
mt_samples = preprocess(pd.read_csv("mtsamples.csv"))
|
|
conversations = get_conversations(mt_samples)
|
|
random.shuffle(conversations)
|
|
train_limit = math.ceil(len(conversations) * 0.6)
|
|
dev_limit = math.ceil(len(conversations) * 0.8)
|
|
train, validation, test = (
|
|
conversations[:train_limit],
|
|
conversations[train_limit:dev_limit],
|
|
conversations[dev_limit:],
|
|
)
|
|
splits = {"train": train, "validation": validation, "test": test}
|
|
for split in ["train", "validation", "test"]:
|
|
with open(f"{output_dir}/mt_note_generation_{split}.jsonl", "w") as f:
|
|
for conversation in splits[split]:
|
|
f.write(f"{json.dumps({'conversation': conversation})}\n")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
sys.exit(main())
|