From 30d7a3d0f53c5e739ca4ed7812c8d4af68901925 Mon Sep 17 00:00:00 2001
From: Tom Zehle
Date: Sun, 15 Jan 2023 11:55:40 +0100
Subject: [PATCH] Added Codeinstructor, fixed file-reading issue, and allowing for csv files in data_augment.py (#679)

---
 scripts/data_augment/data_augment.py | 62 ++++++++++++++++++++++------
 1 file changed, 50 insertions(+), 12 deletions(-)

diff --git a/scripts/data_augment/data_augment.py b/scripts/data_augment/data_augment.py
index b85d3bf5..0a8d3611 100644
--- a/scripts/data_augment/data_augment.py
+++ b/scripts/data_augment/data_augment.py
@@ -25,10 +25,10 @@ from bs4 import BeautifulSoup as bs
 from logic.logic_injector import LogicBug
 from nltk.corpus import wordnet
 from syntax.syntax_injector import SyntaxBug
-from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
+from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, T5ForConditionalGeneration, pipeline
 
 
-class DataArgumenter:
+class DataAugmenter:
     def __init__(self):
         raise NotImplementedError()
 
@@ -48,7 +48,7 @@ class DataArgumenter:
         pass
 
 
-class EssayInstructor(DataArgumenter):
+class EssayInstructor(DataAugmenter):
     def __init__(self, model_name=None):
         if model_name is None:
             model_name = "snrspeaks/t5-one-line-summary"
@@ -86,7 +86,7 @@ class EssayInstructor(DataArgumenter):
         return prompts, essay_paragraphs
 
 
-class EssayReviser(DataArgumenter):
+class EssayReviser(DataAugmenter):
     def __init__(self):
         nltk.download("wordnet")
         nltk.download("omw-1.4")
@@ -132,7 +132,7 @@ class EssayReviser(DataArgumenter):
         return instructions, [essay] * len(instructions)
 
 
-class StackExchangeBuilder(DataArgumenter):
+class StackExchangeBuilder(DataAugmenter):
     def __init__(self, base_url=None, filter_opts=None):
         self.base_url = (
             base_url
@@ -271,7 +271,7 @@ class StackExchangeBuilder(DataArgumenter):
         return questions, answers
 
 
-class HierachicalSummarizer(DataArgumenter):
+class HierachicalSummarizer(DataAugmenter):
     def __init__(self):
         self.summarizer = pipeline(
             "summarization",
@@ -342,7 +342,7 @@ class HierachicalSummarizer(DataArgumenter):
         return instructions, answers
 
 
-class EntityRecognizedSummarizer(DataArgumenter):
+class EntityRecognizedSummarizer(DataAugmenter):
     def __init__(self):
         self.nlp = spacy.load("en_core_web_sm")
         # run !python -m spacy download en_core_web_sm in order to download
@@ -357,7 +357,7 @@ class EntityRecognizedSummarizer(DataArgumenter):
         return [question], [answer]
 
 
-class CodeBugger(DataArgumenter):
+class CodeBugger(DataAugmenter):
     """
     https://github.com/LAION-AI/Open-Assistant/blob/main/notebooks/code-bugger/openbugger_example.md
     Openbugger is a Python package that allows you to inject syntax and logic errors into your code.
@@ -391,6 +391,38 @@ class CodeBugger(DataArgumenter):
         return [question], [answer]
 
 
+class CodeInstructor(DataAugmenter):
+    def __init__(self):
+        self.tokenizer = AutoTokenizer.from_pretrained("Graverman/t5-code-summary")
+        self.model = T5ForConditionalGeneration.from_pretrained("Graverman/t5-code-summary")
+
+    def parse(self, codes):
+        source_encoding = self.tokenizer(
+            codes,
+            max_length=300,
+            padding="max_length",
+            truncation=True,
+            return_attention_mask=True,
+            add_special_tokens=True,
+            return_tensors="pt",
+        )
+        outputs = self.model.generate(
+            input_ids=source_encoding["input_ids"],
+            attention_mask=source_encoding["attention_mask"],
+            max_length=100,
+            length_penalty=0.75,
+            repetition_penalty=2.5,
+            early_stopping=True,
+            use_cache=True,
+        )
+        summaries = [self.tokenizer.decode(o, skip_special_tokens=True) for o in outputs]
+
+        questions = ["Write a script that does the following:\n" + s for s in summaries]
+        answers = codes
+
+        return questions, answers
+
+
 def recognize_entities(text, model, n=4, person="ignore"):
     """Given a text and a model for entity recognition, return the most occuring entites in the text as a string"""
     doc = model(text)
@@ -417,20 +449,23 @@ def parse_arguments():
     args.add_argument("--output", type=str, required=True)
     args = args.parse_args()
 
-    assert args.dataset.endswith(".tsv"), "Dataset file must be a tsv file, containing a list of files to be augmented"
+    assert args.dataset.endswith(".tsv") or args.dataset.endswith(
+        ".csv"
+    ), "Dataset file must be a tsv or csv file, containing a list of files to be augmented"
     assert args.output.endswith(".json"), "Output file must be a json file"
 
     return args
 
 
 def read_data(args):
-    files = pd.read_csv(args.dataset, sep="\t", header=None)
-    files = files[0].tolist()
+    files = pd.read_csv(args.dataset, sep=",", header=None, names=["file"])
+    files = files["file"].tolist()
     data = []
     for file in files:
         with open(file, "r") as f:
             text = f.read()
             data.append(text)
+
     return data
 
 
@@ -453,9 +488,12 @@ def get_augmenter(args):
     elif args.augmenter == "codebugger":
         augmenter = CodeBugger()
 
+    elif args.augmenter == "codeinstructor":
+        augmenter = CodeInstructor()
+
     else:
         raise ValueError(
-            "Augmenter must be one of 'essayinstruction', 'essayrevision', 'stackexchange', 'hierarchicalsummarizer', 'entityrecognizedsummarizer', 'codebugger"
+            "Augmenter must be one of 'essayinstruction', 'essayrevision', 'stackexchange', 'hierarchicalsummarizer', 'entityrecognizedsummarizer', 'codebugger', 'codeinstructor"
         )
 
     return augmenter