Added CodeInstructor, fixed file-reading issue, and allowed csv files in data_augment.py (#679)

This commit is contained in:
Tom Zehle
2023-01-15 11:55:40 +01:00
committed by GitHub
parent 112ad5cd44
commit 30d7a3d0f5
+50 -12
View File
@@ -25,10 +25,10 @@ from bs4 import BeautifulSoup as bs
from logic.logic_injector import LogicBug
from nltk.corpus import wordnet
from syntax.syntax_injector import SyntaxBug
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, T5ForConditionalGeneration, pipeline
class DataArgumenter:
class DataAugmenter:
def __init__(self):
raise NotImplementedError()
@@ -48,7 +48,7 @@ class DataArgumenter:
pass
class EssayInstructor(DataArgumenter):
class EssayInstructor(DataAugmenter):
def __init__(self, model_name=None):
if model_name is None:
model_name = "snrspeaks/t5-one-line-summary"
@@ -86,7 +86,7 @@ class EssayInstructor(DataArgumenter):
return prompts, essay_paragraphs
class EssayReviser(DataArgumenter):
class EssayReviser(DataAugmenter):
def __init__(self):
nltk.download("wordnet")
nltk.download("omw-1.4")
@@ -132,7 +132,7 @@ class EssayReviser(DataArgumenter):
return instructions, [essay] * len(instructions)
class StackExchangeBuilder(DataArgumenter):
class StackExchangeBuilder(DataAugmenter):
def __init__(self, base_url=None, filter_opts=None):
self.base_url = (
base_url
@@ -271,7 +271,7 @@ class StackExchangeBuilder(DataArgumenter):
return questions, answers
class HierachicalSummarizer(DataArgumenter):
class HierachicalSummarizer(DataAugmenter):
def __init__(self):
self.summarizer = pipeline(
"summarization",
@@ -342,7 +342,7 @@ class HierachicalSummarizer(DataArgumenter):
return instructions, answers
class EntityRecognizedSummarizer(DataArgumenter):
class EntityRecognizedSummarizer(DataAugmenter):
def __init__(self):
self.nlp = spacy.load("en_core_web_sm") # run !python -m spacy download en_core_web_sm in order to download
@@ -357,7 +357,7 @@ class EntityRecognizedSummarizer(DataArgumenter):
return [question], [answer]
class CodeBugger(DataArgumenter):
class CodeBugger(DataAugmenter):
"""
https://github.com/LAION-AI/Open-Assistant/blob/main/notebooks/code-bugger/openbugger_example.md
Openbugger is a Python package that allows you to inject syntax and logic errors into your code.
@@ -391,6 +391,38 @@ class CodeBugger(DataArgumenter):
return [question], [answer]
class CodeInstructor(DataAugmenter):
    """Build instruction/answer pairs from code snippets.

    A T5 checkpoint ("Graverman/t5-code-summary") generates a short summary
    of each snippet. The summary, prefixed with a fixed prompt, becomes the
    instruction; the untouched snippet is the answer.
    """

    def __init__(self):
        # Tokenizer and model are both loaded from the same checkpoint name.
        self.tokenizer = AutoTokenizer.from_pretrained("Graverman/t5-code-summary")
        self.model = T5ForConditionalGeneration.from_pretrained("Graverman/t5-code-summary")

    def parse(self, codes):
        """Summarize each snippet and return ``(questions, answers)``.

        Args:
            codes: A single code string or a list of code strings.

        Returns:
            A tuple of two equally long lists: the generated instructions and
            the corresponding original snippets.
        """
        # A lone string is tokenized as a batch of one, so wrap it to keep
        # `answers` a list that lines up with `questions` (the other
        # augmenters in this file also return two lists).
        if isinstance(codes, str):
            codes = [codes]
        # Nothing to summarize; skip tokenization and generation entirely.
        if not codes:
            return [], []

        # Inputs are padded/truncated to exactly 300 tokens.
        source_encoding = self.tokenizer(
            codes,
            max_length=300,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        outputs = self.model.generate(
            input_ids=source_encoding["input_ids"],
            attention_mask=source_encoding["attention_mask"],
            max_length=100,
            length_penalty=0.75,
            repetition_penalty=2.5,
            early_stopping=True,
            use_cache=True,
        )

        summaries = [self.tokenizer.decode(o, skip_special_tokens=True) for o in outputs]
        questions = ["Write a script that does the following:\n" + s for s in summaries]
        answers = codes
        return questions, answers
def recognize_entities(text, model, n=4, person="ignore"):
"""Given a text and a model for entity recognition, return the most occuring entites in the text as a string"""
doc = model(text)
@@ -417,20 +449,23 @@ def parse_arguments():
args.add_argument("--output", type=str, required=True)
args = args.parse_args()
assert args.dataset.endswith(".tsv"), "Dataset file must be a tsv file, containing a list of files to be augmented"
assert args.dataset.endswith(".tsv") or args.dataset.endswith(
".csv"
), "Dataset file must be a tsv or csv file, containing a list of files to be augmented"
assert args.output.endswith(".json"), "Output file must be a json file"
return args
def read_data(args):
    """Load the text of every file listed in the dataset index.

    ``args.dataset`` is read with pandas as a header-less, comma-separated
    table whose single column is named "file". Each entry is treated as a
    path, opened in text mode and read in full.

    Args:
        args: Object exposing a ``dataset`` attribute (path to the index file).

    Returns:
        A list with one string per listed file, in listing order.
    """
    # Parse the index exactly once; a second, tab-separated parse whose
    # result was immediately overwritten would only add redundant I/O.
    listing = pd.read_csv(args.dataset, sep=",", header=None, names=["file"])

    data = []
    for file in listing["file"].tolist():
        with open(file, "r") as f:
            data.append(f.read())
    return data
@@ -453,9 +488,12 @@ def get_augmenter(args):
elif args.augmenter == "codebugger":
augmenter = CodeBugger()
elif args.augmenter == "codeinstructor":
augmenter = CodeInstructor()
else:
raise ValueError(
"Augmenter must be one of 'essayinstruction', 'essayrevision', 'stackexchange', 'hierarchicalsummarizer', 'entityrecognizedsummarizer', 'codebugger"
"Augmenter must be one of 'essayinstruction', 'essayrevision', 'stackexchange', 'hierarchicalsummarizer', 'entityrecognizedsummarizer', 'codebugger', 'codeinstructor"
)
return augmenter