mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
Added CodeInstructor, fixed file-reading issue, and allowed csv files in data_augment.py (#679)
This commit is contained in:
@@ -25,10 +25,10 @@ from bs4 import BeautifulSoup as bs
|
||||
from logic.logic_injector import LogicBug
|
||||
from nltk.corpus import wordnet
|
||||
from syntax.syntax_injector import SyntaxBug
|
||||
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
|
||||
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, T5ForConditionalGeneration, pipeline
|
||||
|
||||
|
||||
class DataArgumenter:
|
||||
class DataAugmenter:
|
||||
def __init__(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
@@ -48,7 +48,7 @@ class DataArgumenter:
|
||||
pass
|
||||
|
||||
|
||||
class EssayInstructor(DataArgumenter):
|
||||
class EssayInstructor(DataAugmenter):
|
||||
def __init__(self, model_name=None):
|
||||
if model_name is None:
|
||||
model_name = "snrspeaks/t5-one-line-summary"
|
||||
@@ -86,7 +86,7 @@ class EssayInstructor(DataArgumenter):
|
||||
return prompts, essay_paragraphs
|
||||
|
||||
|
||||
class EssayReviser(DataArgumenter):
|
||||
class EssayReviser(DataAugmenter):
|
||||
def __init__(self):
|
||||
nltk.download("wordnet")
|
||||
nltk.download("omw-1.4")
|
||||
@@ -132,7 +132,7 @@ class EssayReviser(DataArgumenter):
|
||||
return instructions, [essay] * len(instructions)
|
||||
|
||||
|
||||
class StackExchangeBuilder(DataArgumenter):
|
||||
class StackExchangeBuilder(DataAugmenter):
|
||||
def __init__(self, base_url=None, filter_opts=None):
|
||||
self.base_url = (
|
||||
base_url
|
||||
@@ -271,7 +271,7 @@ class StackExchangeBuilder(DataArgumenter):
|
||||
return questions, answers
|
||||
|
||||
|
||||
class HierachicalSummarizer(DataArgumenter):
|
||||
class HierachicalSummarizer(DataAugmenter):
|
||||
def __init__(self):
|
||||
self.summarizer = pipeline(
|
||||
"summarization",
|
||||
@@ -342,7 +342,7 @@ class HierachicalSummarizer(DataArgumenter):
|
||||
return instructions, answers
|
||||
|
||||
|
||||
class EntityRecognizedSummarizer(DataArgumenter):
|
||||
class EntityRecognizedSummarizer(DataAugmenter):
|
||||
def __init__(self):
|
||||
self.nlp = spacy.load("en_core_web_sm") # run !python -m spacy download en_core_web_sm in order to download
|
||||
|
||||
@@ -357,7 +357,7 @@ class EntityRecognizedSummarizer(DataArgumenter):
|
||||
return [question], [answer]
|
||||
|
||||
|
||||
class CodeBugger(DataArgumenter):
|
||||
class CodeBugger(DataAugmenter):
|
||||
"""
|
||||
https://github.com/LAION-AI/Open-Assistant/blob/main/notebooks/code-bugger/openbugger_example.md
|
||||
Openbugger is a Python package that allows you to inject syntax and logic errors into your code.
|
||||
@@ -391,6 +391,38 @@ class CodeBugger(DataArgumenter):
|
||||
return [question], [answer]
|
||||
|
||||
|
||||
class CodeInstructor(DataAugmenter):
    """Build instruction/answer pairs from source code with a code-summary model.

    Each code snippet is summarized by the "Graverman/t5-code-summary" T5
    checkpoint. The summary becomes a "write a script" instruction and the
    original snippet is used as the answer.
    """

    def __init__(self):
        # The base-class __init__ raises NotImplementedError, so it is
        # intentionally not called here.
        self.tokenizer = AutoTokenizer.from_pretrained("Graverman/t5-code-summary")
        self.model = T5ForConditionalGeneration.from_pretrained("Graverman/t5-code-summary")

    def parse(self, codes):
        """Generate one instruction per code snippet.

        Args:
            codes: A single code string, or a sequence of code strings.

        Returns:
            A ``(questions, answers)`` tuple of two equally long lists.
            ``questions[i]`` is the generated instruction for ``answers[i]``,
            and ``answers`` holds the input snippets in their original order.
        """
        # Treat a bare string as a batch of one so that both return values are
        # always lists, matching the ``[question], [answer]`` shape returned by
        # the other augmenters in this file.
        if isinstance(codes, str):
            codes = [codes]
        else:
            codes = list(codes)

        # Nothing to summarize; skip tokenization and generation entirely.
        if not codes:
            return [], []

        # Inputs longer than 300 tokens are truncated, shorter ones are padded,
        # so the whole batch is encoded as one fixed-width tensor.
        source_encoding = self.tokenizer(
            codes,
            max_length=300,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        outputs = self.model.generate(
            input_ids=source_encoding["input_ids"],
            attention_mask=source_encoding["attention_mask"],
            max_length=100,
            length_penalty=0.75,
            repetition_penalty=2.5,
            early_stopping=True,
            use_cache=True,
        )
        summaries = [self.tokenizer.decode(o, skip_special_tokens=True) for o in outputs]

        questions = ["Write a script that does the following:\n" + s for s in summaries]
        answers = codes

        return questions, answers
|
||||
|
||||
|
||||
def recognize_entities(text, model, n=4, person="ignore"):
|
||||
"""Given a text and a model for entity recognition, return the most occuring entites in the text as a string"""
|
||||
doc = model(text)
|
||||
@@ -417,20 +449,23 @@ def parse_arguments():
|
||||
args.add_argument("--output", type=str, required=True)
|
||||
args = args.parse_args()
|
||||
|
||||
assert args.dataset.endswith(".tsv"), "Dataset file must be a tsv file, containing a list of files to be augmented"
|
||||
assert args.dataset.endswith(".tsv") or args.dataset.endswith(
|
||||
".csv"
|
||||
), "Dataset file must be a tsv or csv file, containing a list of files to be augmented"
|
||||
assert args.output.endswith(".json"), "Output file must be a json file"
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def read_data(args):
    """Load the text of every file listed in the dataset index file.

    Args:
        args: Parsed command-line arguments. Only ``args.dataset`` is read: the
            path of a header-less ``.tsv`` or ``.csv`` file whose first column
            lists the files to load, one per row.

    Returns:
        A list with the full contents of each listed file, in listing order.
    """
    # Both .tsv and .csv listings are accepted by the argument parser, so the
    # delimiter has to follow the extension. Reading a tab-separated listing
    # with "," would split any path that contains a comma into extra columns.
    separator = "\t" if args.dataset.endswith(".tsv") else ","
    listing = pd.read_csv(args.dataset, sep=separator, header=None)
    paths = listing[0].tolist()

    data = []
    for path in paths:
        with open(path, "r") as f:
            data.append(f.read())

    return data
|
||||
|
||||
|
||||
@@ -453,9 +488,12 @@ def get_augmenter(args):
|
||||
elif args.augmenter == "codebugger":
|
||||
augmenter = CodeBugger()
|
||||
|
||||
elif args.augmenter == "codeinstructor":
|
||||
augmenter = CodeInstructor()
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
"Augmenter must be one of 'essayinstruction', 'essayrevision', 'stackexchange', 'hierarchicalsummarizer', 'entityrecognizedsummarizer', 'codebugger"
|
||||
"Augmenter must be one of 'essayinstruction', 'essayrevision', 'stackexchange', 'hierarchicalsummarizer', 'entityrecognizedsummarizer', 'codebugger', 'codeinstructor"
|
||||
)
|
||||
|
||||
return augmenter
|
||||
|
||||
Reference in New Issue
Block a user