# %% [markdown]
# ## CodeT Test Generation Datasets
#
# [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/LAION-AI/Open-Assistant/blob/main/notebooks/data-augmentation/codet-data/Augment_CodeT_testgen.ipynb)

# %% [markdown]
# This notebook parses the CodeT test-case-generation data (prompts and their
# reference tests) and converts each sample into a `(prompt, solution)` pair,
# written one JSON object per line to a `.jsonl` file.
#
# Requirements: `requests`

# %%
import json
from pathlib import Path
from typing import List, Tuple

# %%
# Directory that holds the downloaded inputs; augmented outputs go in a sub-folder.
DATA_DIR: Path = Path("data")

# Seconds to wait for the download server before giving up.
DOWNLOAD_TIMEOUT_S: int = 60

DATA_FILES: List[str] = [
    "HumanEval_for_test_case_generation.jsonl",
    "mbpp_sanitized_for_test_case_generation.jsonl",
]

OUT_FILES: List[str] = [
    "HumanEval_testgen.jsonl",
    "mbpp_testgen.jsonl",
]

FILE_PATHS: List[Path] = [DATA_DIR / data_file for data_file in DATA_FILES]

OUT_PATHS: List[Path] = [DATA_DIR / "augmented" / out_file for out_file in OUT_FILES]


# %%
def download_file(filename: str) -> None:
    """Download one CodeT dataset file into ``DATA_DIR``.

    Args:
        filename: Name of the file inside the CodeT ``data/dataset`` folder.

    Raises:
        requests.HTTPError: If the server answers with an error status, so an
            error page is never saved as if it were dataset content.
    """
    # Imported here so the pure parsing helpers below can be imported and
    # exercised without the third-party HTTP dependency being installed.
    import requests

    url = f"https://raw.githubusercontent.com/microsoft/CodeT/main/CodeT/data/dataset/{filename}"
    response = requests.get(url, timeout=DOWNLOAD_TIMEOUT_S)
    response.raise_for_status()

    target = DATA_DIR / filename
    # A fresh checkout / Colab runtime has no data directory yet.
    target.parent.mkdir(parents=True, exist_ok=True)
    target.write_bytes(response.content)


# `__name__` is "__main__" inside a notebook kernel, so this still runs on
# "Run All" but is skipped when the code is imported as a module.
if __name__ == "__main__":
    for filename in DATA_FILES:
        download_file(filename)


# %%
_DOCSTRING_QUOTES = ('"""', "'''")


def get_docstring_indices(prompt_lines: List[str]) -> Tuple[int, int]:
    """Locate the first docstring in a prompt.

    A docstring delimiter is any line whose stripped text starts with triple
    quotes (either style).

    Args:
        prompt_lines: The prompt split into lines.

    Returns:
        ``(start, end)`` such that ``prompt_lines[start:end]`` holds the
        docstring text. For a multi-line docstring ``end`` is the index of the
        closing-delimiter line (excluded by the slice); for a one-line
        docstring ``end`` is ``start + 1``.

    Raises:
        ValueError: If no complete docstring is found.
    """
    docstring_start = None

    for i, line in enumerate(prompt_lines):
        stripped = line.strip()
        if not stripped.startswith(_DOCSTRING_QUOTES):
            continue
        # Compare against None: index 0 is a valid start position.
        if docstring_start is not None:
            return docstring_start, i
        docstring_start = i
        # One-line docstring: opened and closed on the same line.
        if len(stripped) >= 6 and stripped.endswith(stripped[:3]):
            return i, i + 1

    raise ValueError(f"No complete docstring found!\n{prompt_lines}")


def get_between(prompt_lines: List[str], start: int, end: int) -> List[str]:
    """Return a new list with ``prompt_lines[start:end]`` (``end`` excluded)."""
    return prompt_lines[start:end]


# %%
def get_request(sample: dict) -> str:
    """Build the natural-language request for one sample.

    Args:
        sample: A dataset record; its ``"prompt"`` value must contain a
            function docstring.

    Returns:
        A single-line instruction that embeds the docstring text.

    Raises:
        ValueError: If the prompt has no complete docstring.
    """
    prompt_lines = sample["prompt"].splitlines()

    docstring_start, docstring_end = get_docstring_indices(prompt_lines)

    # Extract prompt
    in_docstring = get_between(prompt_lines, docstring_start, docstring_end)
    # Drop the delimiters (both quote styles) and any "..." from the opening line.
    in_docstring[0] = in_docstring[0].replace('"""', "").replace("'''", "").replace("...", "")

    # Skip lines that are empty after stripping so the joined text has no
    # runs of double spaces.
    stripped_lines = (p.strip() for p in in_docstring)
    docstring_text = " ".join(p for p in stripped_lines if p)

    return "Write a test for a Python function with the following docstring: " + docstring_text


def get_test_code(sample: dict) -> str:
    """Return the sample's test code starting at its ``check`` function.

    Args:
        sample: A dataset record with a ``"test"`` source string.

    Returns:
        The test source from the last line containing ``def check(`` to the
        end, or the whole test when no such line exists.
    """
    test_lines = sample["test"].splitlines()
    start = 0
    for i, line in enumerate(test_lines):
        if "def check(" in line:
            start = i
    return "\n".join(test_lines[start:])


# %%
def process_file(file_path: Path, out_path: Path) -> None:
    """Convert one CodeT ``.jsonl`` file into ``(prompt, solution)`` records.

    Args:
        file_path: Input file with one JSON sample per line.
        out_path: Destination ``.jsonl`` file; parent folders are created.

    Raises:
        ValueError: If a sample's prompt has no complete docstring.
    """
    # Iterate the file line by line and ignore blank lines (e.g. a trailing
    # newline) instead of handing them to the JSON parser.
    with open(file_path, encoding="utf-8") as f:
        samples = [json.loads(line) for line in f if line.strip()]

    # Build everything first so a bad sample cannot leave a half-written file.
    output = [{"prompt": get_request(sample), "solution": get_test_code(sample)} for sample in samples]

    out_path.parent.mkdir(parents=True, exist_ok=True)
    with open(out_path, "w", encoding="utf-8") as f:
        for record in output:
            f.write(json.dumps(record))
            f.write("\n")


# %%
# Guarded for the same reason as the download cell above.
if __name__ == "__main__":
    for file_path, out_path in zip(FILE_PATHS, OUT_PATHS):
        process_file(file_path, out_path)