Files
2023-01-18 17:02:16 +00:00

199 lines
6.1 KiB
Plaintext

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## CodeT Code Generation Datasets\n",
"\n",
"[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/LAION-AI/Open-Assistant/blob/main/notebooks/data-augmentation/codet-data/Augment_CodeT_codegen.ipynb)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This notebook contains code to parse CodeT code generation prompt and solution data and modify to `(prompt, solution)` pairs outputted in a `.jsonl` file.\n",
"\n",
"Requirements: `requests`"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"from pathlib import Path\n",
"import requests\n",
"from typing import Dict, List, Tuple"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"DATA_FILES: List[str] = [\n",
" \"HumanEval_for_code_generation.jsonl\",\n",
" \"mbpp_sanitized_for_code_generation.jsonl\",\n",
"]\n",
"\n",
"OUT_FILES: List[str] = [\n",
" \"HumanEval_codegen.jsonl\",\n",
" \"mbpp_codegen.jsonl\",\n",
"]\n",
"\n",
"FILE_PATHS: List[Path] = [Path(f\"data/{data_file}\") for data_file in DATA_FILES]\n",
"\n",
"OUT_PATHS: List[Path] = [Path(f\"data/augmented/{out_file}\") for out_file in OUT_FILES]"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def download_file(filename: str):\n",
" url = f\"https://raw.githubusercontent.com/microsoft/CodeT/main/CodeT/data/dataset/{filename}\"\n",
" response = requests.get(url)\n",
" with open(f\"data/{filename}\", \"wb\") as f:\n",
" f.write(response.content)\n",
"\n",
"\n",
"for filename in DATA_FILES:\n",
" download_file(filename)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can find the docstring, use its contents as the instruction (prefixed with \"Write a function corresponding to the docstring:\") and then use the content prior to the docstring and the canonical solution as the response."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def get_docstring_indices(prompt_lines: List[str]) -> Tuple[int, int]:\n",
" docstring_start, docstring_end = None, None\n",
"\n",
" for i, line in enumerate(prompt_lines):\n",
" if not (line.strip().startswith('\"\"\"') or line.strip().startswith(\"'''\")):\n",
" continue\n",
" if docstring_start:\n",
" docstring_end = i\n",
" break\n",
" docstring_start = i\n",
"\n",
" if docstring_end:\n",
" return docstring_start, docstring_end\n",
" raise ValueError(f\"No complete docstring found!\\n{prompt_lines}\")\n",
"\n",
"\n",
"def get_before(prompt_lines: List[str], before: int) -> List[str]:\n",
" before_lines = prompt_lines[:before]\n",
" return before_lines\n",
"\n",
"\n",
"def get_between(prompt_lines: List[str], start: int, end: int) -> List[str]:\n",
" between_lines = prompt_lines[start:end]\n",
" return between_lines"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"def get_request_and_solution(sample: dict) -> Tuple[List[str], List[str]]:\n",
" prompt = sample[\"prompt\"]\n",
" prompt_lines = prompt.splitlines()\n",
"\n",
" docstring_start, docstring_end = get_docstring_indices(prompt_lines)\n",
"\n",
" # Extract prompt\n",
" in_docstring = get_between(prompt_lines, docstring_start, docstring_end)\n",
" if '\"\"\"' in in_docstring[0] or \"'''\" in in_docstring[0]:\n",
" in_docstring[0] = in_docstring[0].replace('\"\"\"', \"\").replace(\"...\", \"\").strip()\n",
" request = \"Write a Python function corresponding to the docstring: \" + \" \".join([p.strip() for p in in_docstring])\n",
"\n",
" # Extract solution\n",
" before_docstring = get_before(prompt_lines, docstring_start)\n",
" after_docstring = sample[\"canonical_solution\"].splitlines()\n",
" solution = before_docstring + after_docstring\n",
" # Gets rid of consecutive empty lines\n",
" solution = [v for i, v in enumerate(solution) if v != \"\" or v != solution[i - 1]]\n",
" solution = \"\\n\".join(solution)\n",
"\n",
" return request, solution"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"def process_file(file_path: Path, out_path: Path):\n",
" lines = file_path.read_text().splitlines()\n",
" samples = list(map(json.loads, lines))\n",
"\n",
" output = []\n",
" for sample in samples:\n",
" prompt, solution = get_request_and_solution(sample)\n",
" output.append({\"prompt\": prompt, \"solution\": solution})\n",
"\n",
" with open(out_path, \"w\") as f:\n",
" for sample in output:\n",
" f.write(json.dumps(sample))\n",
" f.write(\"\\n\")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"for file_path, out_path in zip(FILE_PATHS, OUT_PATHS):\n",
" process_file(file_path, out_path)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.10.5 ('venv': venv)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.5"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "1f9a0efd3e4a33b8f30a65df6ca5a95cc3f93ce2f11519ee8c13fe711de61465"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}