Files
Open-Assistant/openassistant/datasets/mt_note_generation/prepare.py
T
Yada Pruksachatkun d1d185edb6 Adding MT Sample clinical note dataset (#804)
* Adding clinical note dataset

* Fix flake8 issues

* Fix prepare.py for straggling commas, replace assistant with Rosey in prompt

Co-authored-by: Yada P <yadapruk@Tests-MacBook-Pro.local>
2023-01-23 14:24:43 +01:00

85 lines
2.9 KiB
Python

import json
import math
import os
import random
import re
import sys
from string import punctuation
import kaggle
import pandas as pd
CLINICAL_NOTE_GENERATION_TEMPLATE = """User: Write a clinical note about a patient with the following {section}: {section_information}.
Rosey: {note}"""
def preprocess(mt_dataset):
def filter_for_notes(row):
normalized_transcript = row["transcription"].lower()
if "chief complaint:" in normalized_transcript:
return True
return False
mt_dataset = mt_dataset.dropna(subset=["description", "transcription"])
mt_note_subset = mt_dataset[mt_dataset.apply(filter_for_notes, axis=1)]
return mt_note_subset
def is_chief_complaint(section):
return "chief complaint" in section.lower()
def get_conversations(dataset):
def normalize_transcript(x):
x = re.sub(r"\.+", ".", x)
x = re.sub(r"\,+", ",", x)
x = re.sub(r":\s+", ": ", x)
x = re.sub(r"\.\s+", ". ", x)
x = re.sub(r":(\s)*\,+", ": ", x)
x = re.sub(r"\.\,+", ". ", x)
return x
conversations = []
for idx in range(len(dataset)):
transcript = normalize_transcript(dataset.iloc[idx]["transcription"])
sections = re.findall(r"\b[A-Z]+(?: [A-Z]+)*:", transcript)
if len(sections) >= 2:
note_prompt = transcript.split(sections[0])[1].split(sections[1])[0]
else:
continue
section_name = sections[0].lower().strip(punctuation)
if len(note_prompt.split(" ")) > 30 and is_chief_complaint(section_name):
# There are some chief complaints that seem to be HPI
section_name = "history of present illness"
conversations.append(
CLINICAL_NOTE_GENERATION_TEMPLATE.format(
section=section_name, section_information=note_prompt, note=transcript
)
)
return conversations
def main(output_dir: str = "data"):
"""Download and prepare the dataset for use."""
os.makedirs(output_dir, exist_ok=True)
kaggle.api.dataset_download_files("tboyle10/medicaltranscriptions", "data", unzip=True)
mt_samples = preprocess(pd.read_csv("mtsamples.csv"))
conversations = get_conversations(mt_samples)
random.shuffle(conversations)
train_limit = math.ceil(len(conversations) * 0.6)
dev_limit = math.ceil(len(conversations) * 0.8)
train, validation, test = (
conversations[:train_limit],
conversations[train_limit:dev_limit],
conversations[dev_limit:],
)
splits = {"train": train, "validation": validation, "test": test}
for split in ["train", "validation", "test"]:
with open(f"{output_dir}/mt_note_generation_{split}.jsonl", "w") as f:
for conversation in splits[split]:
f.write(f"{json.dumps({'conversation': conversation})}\n")
if __name__ == "__main__":
sys.exit(main())