mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
precommits
This commit is contained in:
@@ -64,4 +64,4 @@ debug:
|
||||
gradient_accumulation_steps: 1
|
||||
per_device_train_batch_size: 1
|
||||
per_device_eval_batch_size: 1
|
||||
quantization: false
|
||||
quantization: false
|
||||
|
||||
@@ -30,6 +30,7 @@ summarization_config_mapping = {
|
||||
QA_DATASETS = ["squad_v2", "adversarial_qa", "trivia_qa_context", "trivia_qa_noconext"]
|
||||
SUMMARIZATION_DATASETS = ["xsum", "cnn_dailymail", "samsum", "multi_news"]
|
||||
|
||||
|
||||
def index_squad_v2(example):
|
||||
return example["title"] + ". " + example["context"] + " " + example["question"], example["answers"]["text"][0]
|
||||
|
||||
@@ -159,4 +160,4 @@ def get_one_dataset(conf, dataset_name):
|
||||
else:
|
||||
raise ValueError(f"Unknown dataset {dataset_name}")
|
||||
|
||||
return train, eval
|
||||
return train, eval
|
||||
|
||||
@@ -2,11 +2,11 @@ accelerate==0.15.0
|
||||
bitsandbytes==0.36.0.post2
|
||||
datasets==2.8.0
|
||||
deepspeed==0.7.7
|
||||
evaluate==0.4.0
|
||||
mpi4py==3.1.4
|
||||
nltk==3.8.1
|
||||
numpy==1.23.0
|
||||
PyYAML==6.0
|
||||
scikit_learn==1.2.0
|
||||
torch==1.13.1
|
||||
transformers==4.25.1
|
||||
evaluate==0.4.0
|
||||
nltk==3.8.1
|
||||
@@ -1,6 +1,6 @@
|
||||
import argparse
|
||||
import os
|
||||
from distutils.util import strtobool
|
||||
from functools import partial
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import bitsandbytes
|
||||
@@ -9,7 +9,6 @@ from torch import nn
|
||||
from transformers import PreTrainedModel, Trainer, TrainingArguments
|
||||
from transformers.training_args import OptimizerNames
|
||||
from utils import get_dataset, get_loss, get_metrics, get_model, get_tokenizer, read_yamls
|
||||
from functools import partial
|
||||
|
||||
|
||||
def compute_metrics(eval_pred, preprocess_fns, metrics):
|
||||
|
||||
@@ -6,7 +6,7 @@ import nltk
|
||||
import numpy as np
|
||||
import transformers
|
||||
import yaml
|
||||
from custom_datasets import QA_SPECIAL_TOKENS, QA_DATASETS, SUMMARIZATION_DATASETS, get_one_dataset
|
||||
from custom_datasets import QA_DATASETS, QA_SPECIAL_TOKENS, SUMMARIZATION_DATASETS, get_one_dataset
|
||||
from custom_datasets.dialogue_collator import DialogueDataCollator
|
||||
from losses import CrossEntropyLoss, PolyLoss
|
||||
from models import freeze_top_n_layers, get_specific_model
|
||||
|
||||
Reference in New Issue
Block a user