mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-29 16:30:24 +08:00
pre-commits
This commit is contained in:
@@ -33,6 +33,6 @@ Experimental results in wandb
|
||||
## TODOS
|
||||
|
||||
- decide on a model
|
||||
- add special token to declare prompt and reply. Do nto freeze the weights for
|
||||
these
|
||||
- Merge utils etc with reward model
|
||||
- Casual Modelling for GPT-JT does not leverage the bidirectional mask for the
|
||||
prompt? (https://huggingface.co/togethercomputer/GPT-JT-6B-v1)
|
||||
|
||||
@@ -32,6 +32,17 @@ galactica-125:
|
||||
per_device_train_batch_size: 4
|
||||
per_device_eval_batch_size: 4
|
||||
|
||||
gpt-jt:
|
||||
learning_rate: 2e-6
|
||||
model_name: togethercomputer/GPT-JT-6B-v1
|
||||
weight_decay: 0.01
|
||||
max_length: 1024
|
||||
warmup_steps: 600
|
||||
gradient_checkpointing: false
|
||||
gradient_accumulation_steps: 2
|
||||
per_device_train_batch_size: 4
|
||||
per_device_eval_batch_size: 4
|
||||
|
||||
debug:
|
||||
eval_steps: 20
|
||||
eval_size: 100
|
||||
|
||||
@@ -2,11 +2,7 @@ from datasets import load_dataset
|
||||
from sklearn.model_selection import train_test_split
|
||||
from torch.utils.data import Dataset, Subset
|
||||
|
||||
QA_SPECIAL_TOKENS = {
|
||||
'Question': '<question>',
|
||||
'Answer': '<answer>'
|
||||
}
|
||||
|
||||
QA_SPECIAL_TOKENS = {"Question": "<question>", "Answer": "<answer>"}
|
||||
|
||||
|
||||
class SquadV2Dataset(Dataset):
|
||||
|
||||
@@ -1,15 +1,14 @@
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
from custom_datasets import get_one_dataset
|
||||
from custom_datasets import QA_SPECIAL_TOKENS, get_one_dataset
|
||||
from custom_datasets.dialogue_collator import DialogueDataCollator
|
||||
from losses import CrossEntropyLoss
|
||||
from sklearn.model_selection import train_test_split
|
||||
from torch.utils.data import ConcatDataset, Subset
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from custom_datasets import QA_SPECIAL_TOKENS
|
||||
|
||||
SUPPORTED_MODELS = ["galactica"]
|
||||
SUPPORTED_MODELS = ["galactica", "GPT-JT"] # deprecated ..
|
||||
|
||||
|
||||
def get_tokenizer(conf):
|
||||
@@ -20,7 +19,7 @@ def get_tokenizer(conf):
|
||||
|
||||
additional_special_tokens = (
|
||||
[]
|
||||
if not "additional_special_tokens" in tokenizer.special_tokens_map
|
||||
if "additional_special_tokens" not in tokenizer.special_tokens_map
|
||||
else tokenizer.special_tokens_map["additional_special_tokens"]
|
||||
)
|
||||
additional_special_tokens = list(set(additional_special_tokens + list(QA_SPECIAL_TOKENS.values())))
|
||||
|
||||
Reference in New Issue
Block a user