mirror of
https://github.com/wassname/GENIES.git
synced 2026-06-27 16:10:25 +08:00
removed build
This commit is contained in:
@@ -1,251 +0,0 @@
|
||||
import transformers
|
||||
from tqdm import tqdm
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from transformers import PreTrainedTokenizer
|
||||
from api.model import Model
|
||||
import re
|
||||
import copy
|
||||
from torch.utils.data import Dataset
|
||||
import api.util as util
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
from typing import Dict, Sequence, Any, List, Optional
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
class Distribution:
    """A directory of multiple-choice examples together with its prompt formats.

    The directory must contain ``train.json``, ``test.json`` and
    ``meta_data.json``. The meta data provides ``"formats"`` (a list of
    ``str.format`` templates) and ``"id"``. Every loaded example gets a
    ``"prompt"`` entry rendered from the matching template, and the two example
    lists are wrapped in ``MCDataset`` objects.

    Two distributions are equal, and hash the same, when their ids are equal.
    """

    # Extracts the field names of a template, e.g. "question" from "{question}".
    _FORMAT_KEY_PATTERN = re.compile(r"\{(.+?)\}")

    def __init__(self, dir: str):
        """Load the examples and meta data stored in the directory ``dir``."""
        self.dir = dir
        training_examples = util.load_json(f"{dir}/train.json")
        test_examples = util.load_json(f"{dir}/test.json")
        meta_data = util.load_json(f"{dir}/meta_data.json")

        self.formats = meta_data["formats"]
        self.id = meta_data["id"]

        # Add prompts for each example
        for example in training_examples:
            self.format_example(self.formats, example)
        for example in test_examples:
            self.format_example(self.formats, example)

        self.training_dataset = MCDataset(
            training_examples, distribution_id=self.id, distribution_dir=self.dir
        )
        self.test_dataset = MCDataset(
            test_examples, distribution_id=self.id, distribution_dir=self.dir
        )

    @staticmethod
    def format_example(formats: List[str], example: dict) -> dict:
        """Render ``example["prompt"]`` from the template that fits the example.

        A template fits when the set of fields it references equals the set of
        example keys that are referenced by *any* template. The first fitting
        template wins.

        Args:
            formats: ``str.format`` templates such as ``"Q: {question}"``.
            example: The example dict; it is modified in place.

        Returns:
            The same ``example`` object, now carrying a ``"prompt"`` key.

        Raises:
            ValueError: If no template fits the example.
        """
        # Parse every template once instead of once per comparison.
        keys_per_format = [
            set(Distribution._FORMAT_KEY_PATTERN.findall(template)) for template in formats
        ]
        all_format_keys = set().union(*keys_per_format)
        example_format_keys = set(example) & all_format_keys

        for template, template_keys in zip(formats, keys_per_format):
            if template_keys == example_format_keys:
                example["prompt"] = template.format(**example)
                return example
        raise ValueError(
            f"An example could not be matched to a format string.\nExample: {example}\nFormat strings: {formats}"
        )

    def __hash__(self):
        return hash(self.id)

    def __eq__(self, other):
        # Defer to Python's default handling for foreign types instead of
        # failing with AttributeError on ``other.id``.
        if not isinstance(other, Distribution):
            return NotImplemented
        return self.id == other.id
|
||||
|
||||
|
||||
class MCDataset(Dataset):
    """Multiple-choice examples of one distribution, exposed as a torch ``Dataset``.

    Every example is a dict. The constructor stores an integer ``"id"`` on each
    example (its position in ``examples``); the scoring code additionally reads
    the keys ``"prompt"`` and ``"responses"``.
    """

    def __init__(self, examples, distribution_id, distribution_dir, max_examples=None):
        """Keep ``examples`` and tag each one, in place, with its position as ``"id"``."""
        self.distribution_dir = distribution_dir
        self.distribution_id = distribution_id
        self.examples = examples
        for i, example in enumerate(self.examples):
            example["id"] = i
        self.max_examples = max_examples

    def __len__(self):
        if self.max_examples is not None:
            return min(len(self.examples), self.max_examples)
        return len(self.examples)

    def __getitem__(self, idx):
        return self.examples[idx]

    def set_max_examples(self, max_examples):
        """Record the cap and drop every example beyond the first ``max_examples``."""
        self.max_examples = max_examples
        self.examples = self.examples[:max_examples]

    def filter_out_long_examples(self, tokenizer: "PreTrainedTokenizer"):
        """Drop examples whose longest prompt + response + EOS exceeds the tokenizer limit.

        ``self.examples`` is replaced by the filtered list, which is also
        returned. The ``"id"`` stored on each kept example is left untouched.
        """
        filtered_examples = []
        for example in self.examples:
            longest = max(
                len(tokenizer.encode(example["prompt"] + r + tokenizer.eos_token))
                for r in example["responses"]
            )
            if longest <= tokenizer.model_max_length:
                filtered_examples.append(example)
        num_examples_filtered_out = len(self.examples) - len(filtered_examples)

        if num_examples_filtered_out > 0:
            util.print_once(
                f"Filtered out {num_examples_filtered_out} examples because they exceeded the max length of {tokenizer.model_max_length}"
            )

        self.examples = filtered_examples
        return filtered_examples

    def get_example_scores(
        self, model: "Model", accelerator: Optional["Accelerator"], per_device_batch_size
    ) -> List[Optional[float]]:
        """Score every example that fits into the model's context window.

        Args:
            model: Wrapper providing ``tokenizer`` and ``hf_model``.
            accelerator: Used to prepare the model and dataloader and to gather
                the scores. Despite the ``Optional`` annotation the method calls
                it unconditionally, so ``None`` is not supported.
            per_device_batch_size: Number of examples per batch.

        Returns:
            One score per entry of ``self.examples`` (after filtering), in the
            same order. An entry is ``None`` if no score was gathered for it.
        """
        # First filter out examples that exceed the model's maximum sequence length
        self.filter_out_long_examples(model.tokenizer)

        dataloader = DataLoader(
            self,
            batch_size=per_device_batch_size,
            shuffle=False,
            collate_fn=MCDataCollator(tokenizer=model.tokenizer),
        )

        model.hf_model.eval()
        model.hf_model, dataloader = accelerator.prepare(model.hf_model, dataloader)

        example_scores = []  # a list of tuples. Each tuple is (score, example id)
        for batch in tqdm(dataloader):
            batch_scores = self.get_scores_for_batch(model.hf_model, batch)
            ids = batch.pop("ids")
            example_scores.extend(list(zip(batch_scores, ids)))

        # Gather tensors from different devices
        example_scores = torch.tensor(example_scores, dtype=torch.float32).to(model.hf_model.device)
        example_scores = accelerator.gather(example_scores).cpu()

        accelerator.free_memory()

        return self._reorder_scores(example_scores, self.examples)

    @staticmethod
    def _reorder_scores(example_scores, examples):
        """Arrange gathered ``(score, example id)`` rows in the order of ``examples``.

        The ids were assigned before filtering, so they can be larger than the
        number of remaining examples. Each id is therefore translated to the
        example's current position instead of being used as a list index.
        """
        position_of_id = {example["id"]: position for position, example in enumerate(examples)}
        ordered_example_scores = [None] * len(examples)
        for score, example_id in example_scores.tolist():
            ordered_example_scores[position_of_id[int(example_id)]] = float(score)
        return ordered_example_scores

    @staticmethod
    def get_scores_for_batch(model, batch):
        """Return, per example, the probability mass placed on the labelled responses.

        ``batch`` is the output of ``MCDataCollator``: one row of ``input_ids``,
        ``attention_mask`` and ``labels`` per (example, response) pair, plus
        ``response_labels`` with one row per example.
        """
        output = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])
        logits = output.logits

        # Shift so that tokens < n predict n
        logits_for_batch = logits[:, :-1, :]
        shift_labels_batch = batch["labels"][:, 1:]

        # Flatten the tokens
        shift_logits_batch = logits_for_batch.contiguous().view(-1, logits.shape[-1])
        shift_labels_batch = shift_labels_batch.contiguous().view(-1)

        # Cross-entropy loss
        loss_fct = CrossEntropyLoss(reduction="none")  # use "none" to get loss per item
        batch_losses = loss_fct(shift_logits_batch, shift_labels_batch)
        # NOTE(review): the mean runs over every position of the row, not only
        # over the response tokens, so it depends on the padded length -
        # confirm this is intended.
        avg_log_probs = -batch_losses.view(logits.size(0), -1).mean(dim=1)

        # Unflatten avg log probs into (number of examples, responses per example)
        response_labels = batch["response_labels"]
        response_labels = response_labels.to(model.device).to(dtype=avg_log_probs.dtype)
        # NOTE(review): presumably needed by callers that differentiate the
        # scores - confirm before removing.
        response_labels.requires_grad = True
        avg_log_probs_unflattened = avg_log_probs.contiguous().view(
            response_labels.shape[0], response_labels.shape[1]
        )
        response_probabilities = torch.softmax(avg_log_probs_unflattened, dim=1).to(model.device)
        example_scores = (response_probabilities * response_labels).sum(dim=1)
        return example_scores
|
||||
|
||||
|
||||
@dataclass
class MCDataCollator:
    """Collate multiple-choice examples into one flat batch of prompt/response rows.

    Calling the collator returns the dict produced by ``SupervisedDataCollator``
    for all (prompt, response) pairs, extended by ``"response_labels"`` (a
    tensor with one row of response values per example) and ``"ids"`` (the list
    of example ids).
    """

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]):
        # One row per example holding the value attached to each response.
        label_rows = []
        for instance in instances:
            label_rows.append(list(instance["responses"].values()))
        response_labels = torch.tensor(label_rows)

        example_ids = []
        for instance in instances:
            example_ids.append(instance["id"])

        # Flatten: every response becomes its own prompt/response pair.
        prompt_response_pairs = []
        for instance in instances:
            for response in instance["responses"]:
                prompt_response_pairs.append({"prompt": instance["prompt"], "response": response})

        batch = SupervisedDataCollator(tokenizer=self.tokenizer)(prompt_response_pairs)
        batch["response_labels"] = response_labels
        batch["ids"] = example_ids
        return batch
|
||||
|
||||
|
||||
@dataclass
class SupervisedDataCollator:
    """Collate prompt/response pairs into padded tensors for supervised fine-tuning."""

    tokenizer: transformers.PreTrainedTokenizer

    @staticmethod
    def tokenize_prompts_and_responses(
        prompts: Sequence[str],
        responses: Sequence[str],
        tokenizer: transformers.PreTrainedTokenizer,
    ) -> Dict:
        """Tokenize each pair and build labels that ignore the prompt tokens.

        Returns:
            A dict with ``input_ids`` and ``labels``, each a list holding one
            1-D tensor per pair. ``labels`` is a copy of ``input_ids`` whose
            prompt positions are set to ``util.IGNORE_INDEX``.
        """
        input_ids = []
        labels = []
        for prompt, response in zip(prompts, responses):
            if prompt == "":
                # An empty prompt contributes no tokens.
                prompt_tokenized = torch.tensor([], dtype=torch.int64)
            else:
                prompt_tokenized = tokenizer(
                    prompt, return_tensors="pt", padding=False, add_special_tokens=False
                ).input_ids[0]
            response_tokenized = tokenizer(
                response, return_tensors="pt", padding=False, add_special_tokens=False
            ).input_ids[0]

            sequence = torch.cat((prompt_tokenized, response_tokenized), dim=0)
            label = sequence.clone()
            label[: len(prompt_tokenized)] = util.IGNORE_INDEX
            input_ids.append(sequence)
            labels.append(label)
        return dict(input_ids=input_ids, labels=labels)

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        """Return right-padded ``input_ids``, ``labels`` and ``attention_mask`` tensors.

        The tokenizer's EOS token is appended to every response before
        tokenizing.
        """
        prompts = [instance["prompt"] for instance in instances]
        responses = [instance["response"] for instance in instances]

        suffix = self.tokenizer.eos_token
        responses = [f"{response}{suffix}" for response in responses]

        data_dict = self.tokenize_prompts_and_responses(prompts, responses, self.tokenizer)

        input_ids = data_dict["input_ids"]
        labels = data_dict["labels"]
        sequence_lengths = torch.tensor([len(sequence) for sequence in input_ids])

        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        labels = torch.nn.utils.rnn.pad_sequence(
            labels, batch_first=True, padding_value=util.IGNORE_INDEX
        )

        # Derive the mask from the true lengths. Comparing against pad_token_id
        # would also mask real tokens whenever the pad token doubles as the EOS
        # token, which is exactly how Model.get_tokenizer configures it.
        positions = torch.arange(input_ids.shape[1])
        attention_mask = positions.unsqueeze(0) < sequence_lengths.unsqueeze(1)

        return dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=attention_mask,
        )
|
||||
@@ -1,92 +0,0 @@
|
||||
import fire
|
||||
from api.data_classes import Distribution
|
||||
from typing import Union, List, Optional
|
||||
import wandb
|
||||
import copy
|
||||
import api.util as util
|
||||
import os
|
||||
import wandb
|
||||
|
||||
# Example sweep configuration:
# Used as the default ``sweep_configuration`` of hyper_parameter_sweep, which
# passes it to ``wandb.sweep``. The metric name matches the "average_score"
# value that train_with_params logs at the end of every run.
default_sweep_configuration = {
    "method": "random",
    "name": "sweep",
    "metric": {"goal": "maximize", "name": "average_score"},
    "parameters": {
        "learning_rate": {"max": 1e-4, "min": 8e-6},
    },
}
|
||||
|
||||
|
||||
def hyper_parameter_sweep(
    model_dir: str,
    distribution: Union[str, Distribution],
    intervention_dir: str,
    sweep_configuration: dict = default_sweep_configuration,
    train_kwargs: Optional[dict] = None,
    eval_kwargs: Optional[dict] = None,
    count: int = 4,
    wandb_project="project",
) -> Optional[List[dict]]:
    """Run a Weights & Biases sweep that trains and evaluates one intervention.

    Args:
        model_dir: Directory of the model to start from.
        distribution: A ``Distribution`` or the directory to load one from.
        intervention_dir: Directory containing the intervention's ``train.py``
            and ``eval.py``.
        sweep_configuration: Configuration passed to ``wandb.sweep``.
        train_kwargs: Extra arguments for the training command (default: none).
        eval_kwargs: Extra arguments for the evaluation command (default: none).
        count: Number of runs the sweep agent executes.
        wandb_project: Name of the W&B project that receives the sweep.

    Returns:
        None; the results are logged to W&B by ``train_with_params``.
    """
    # Fresh dicts per call: a mutable default would be shared between calls.
    train_kwargs = {} if train_kwargs is None else train_kwargs
    eval_kwargs = {} if eval_kwargs is None else eval_kwargs

    if isinstance(distribution, str):
        distribution = Distribution(distribution)

    wandb.login(key=util.wandb_api_key())

    sweep_id = wandb.sweep(sweep=sweep_configuration, project=wandb_project)

    wandb.agent(
        sweep_id,
        function=lambda: train_with_params(
            model_dir, distribution, intervention_dir, train_kwargs, eval_kwargs
        ),
        count=count,
    )
|
||||
|
||||
|
||||
def train_with_params(model_dir, distribution, intervention_dir, train_kwargs, eval_kwargs):
    """Execute one sweep run: train with the sampled parameters, evaluate, log the score.

    The hyper-parameters sampled by W&B are read from ``wandb.config`` and
    merged into a copy of ``train_kwargs``. Training and evaluation are started
    as ``accelerate launch`` shell commands, with every keyword argument
    appended as ``--key "value"``. The ``"average_score"`` found in the
    evaluation output file is logged to W&B.

    Neither ``train_kwargs`` nor ``eval_kwargs`` is modified.
    """
    wandb.init()
    parameters = wandb.config
    run_train_kwargs = copy.deepcopy(train_kwargs) if train_kwargs else {}
    run_train_kwargs.update(parameters)
    # Work on a copy so that the caller's dict, which is shared by all runs of
    # the sweep, is not polluted with per-run values.
    run_eval_kwargs = copy.deepcopy(eval_kwargs) if eval_kwargs else {}
    if isinstance(distribution, str):
        distribution = Distribution(distribution)

    model_name = os.path.basename(model_dir)
    intervention_name = os.path.basename(intervention_dir)
    parameter_suffix = "-".join(f"{k}={v}" for k, v in parameters.items())
    output_model_name = f"{intervention_name}/{model_name}-{distribution.id}-{parameter_suffix}"
    output_model_dir = f"models/hps/{output_model_name}"
    eval_output_path = f"results/evaluations/hps/{output_model_name}.json"

    train_command = f"accelerate launch --deepspeed_config_file configs/ds_zero_3.json --use_deepspeed {intervention_dir}/train.py"
    run_train_kwargs.update(
        {
            "model_dir": model_dir,
            "output_dir": output_model_dir,
            "training_distribution_dir": distribution.dir,
        }
    )
    for k, v in run_train_kwargs.items():
        train_command += f' --{k} "{str(v)}"'

    util.execute_command(train_command)

    eval_command = f"accelerate launch --deepspeed_config_file configs/ds_zero_3.json --use_deepspeed {intervention_dir}/eval.py"
    run_eval_kwargs.update(
        {
            # Evaluate the model that was just trained. Passing the base model
            # here would make the score independent of the swept parameters.
            "model_dir": output_model_dir,
            "distribution_dirs": [distribution.dir],
            "output_paths": [eval_output_path],
        }
    )
    for k, v in run_eval_kwargs.items():
        eval_command += f' --{k} "{str(v)}"'
    util.execute_command(eval_command)

    # Open json to get evaluation
    result = util.load_json(eval_output_path)
    wandb.log({"average_score": result["average_score"]})
|
||||
|
||||
|
||||
# When run as a script, expose hyper_parameter_sweep as a command line
# interface through python-fire.
if __name__ == "__main__":
    fire.Fire(hyper_parameter_sweep)
|
||||
@@ -1,171 +0,0 @@
|
||||
import os
|
||||
import random
|
||||
from api.data_classes import Distribution
|
||||
import api.util as util
|
||||
import fire
|
||||
import json
|
||||
from typing import Optional, List
|
||||
from api.evaluate import evaluate
|
||||
from api.train import train
|
||||
|
||||
|
||||
def compute_pgc(
    performance_of_base_model,
    performance_trained_on_source,
    performance_trained_on_target,
):
    """Return ``(source - base) / (target - base)``, or ``None`` if it is undefined.

    The ratio is undefined when a measurement is missing (``None``) or when
    training on the target did not change the base model's score.
    """
    measurements = (
        performance_of_base_model,
        performance_trained_on_source,
        performance_trained_on_target,
    )
    if any(measurement is None for measurement in measurements):
        return None
    denominator = performance_trained_on_target - performance_of_base_model
    if denominator == 0:
        return None
    return (performance_trained_on_source - performance_of_base_model) / denominator


def main(
    base_model_dir: str,
    intervention_dir: str,
    output_path: str,
    path_to_distribution_shift_pairs: Optional[str] = None,
    source_dirs: Optional[str] = None,
    target_dirs: Optional[str] = None,
    dry_run: bool = False,
    baseline_intervention_dir: str = "src/interventions/pro",
    use_cached_evaluations: bool = True,
    retrain_models: bool = False,
    train_kwargs: Optional[dict] = None,
    eval_kwargs: Optional[dict] = None,
):
    """Measure the PGC value for a list of source -> target distribution pairs.

    For every pair three scores on the target distribution are collected: the
    base model, the base model trained on the source, and the base model
    trained on the target. ``pgc`` is computed from them by ``compute_pgc``.
    Intermediate and final results are written to ``output_path`` as JSON.

    Args:
        base_model_dir: Directory of the model every training run starts from.
        intervention_dir: Intervention used for training and evaluation.
        output_path: JSON file receiving the results.
        path_to_distribution_shift_pairs: JSON file with a list of objects
            having ``"source"`` and ``"target"`` entries, resolved relative to
            ``distributions/``. Takes precedence over the two arguments below.
        source_dirs: Source distribution directories; a single string is
            treated as one directory.
        target_dirs: Target distribution directories, paired by position with
            ``source_dirs``.
        dry_run: Forwarded to ``train`` and ``evaluate``.
        baseline_intervention_dir: Intervention used to evaluate the base model.
        use_cached_evaluations: Forwarded to ``evaluate`` as ``use_cached``.
        retrain_models: Forwarded to ``train`` as ``retrain``.
        train_kwargs: Extra training arguments (default: none).
        eval_kwargs: Extra evaluation arguments (default: none).

    Raises:
        ValueError: If neither the pairs file nor both directory lists are given.
    """
    # Fresh dicts per call: a mutable default would be shared between calls.
    train_kwargs = {} if train_kwargs is None else train_kwargs
    eval_kwargs = {} if eval_kwargs is None else eval_kwargs

    # Check that the arguments are valid
    if path_to_distribution_shift_pairs is None and (source_dirs is None or target_dirs is None):
        raise ValueError(
            "Must provide either path_to_distribution_shift_pairs or paths_to_source and target_dirs"
        )

    if path_to_distribution_shift_pairs is not None:
        # Parse the pairs file into source_dirs and target_dirs
        pairs_data = util.load_json(path_to_distribution_shift_pairs)
        source_dirs = [f"distributions/{pair['source']}" for pair in pairs_data]
        target_dirs = [f"distributions/{pair['target']}" for pair in pairs_data]
    else:
        # A bare string would otherwise be iterated character by character.
        if isinstance(source_dirs, str):
            source_dirs = [source_dirs]
        if isinstance(target_dirs, str):
            target_dirs = [target_dirs]

    # Load distributions to obtain meta data
    source_distributions = [Distribution(dir) for dir in source_dirs]
    target_distributions = [Distribution(dir) for dir in target_dirs]

    # Initialize output data with empty values
    pgc_result = [
        {
            "source": source.id,
            "target": target.id,
            "pgc": None,
            "performance_of_base_model": None,
            "performance_trained_on_source": None,
            "performance_trained_on_target": None,
        }
        for source, target in zip(source_distributions, target_distributions)
    ]

    util.print_once("")
    util.print_once("----------- Measuring baselines -----------")
    util.print_once("")

    unique_target_distributions = list(set(target_distributions))
    unique_target_ids = [t.id for t in unique_target_distributions]

    # Evaluate the base model on the target distributions
    evaluations = evaluate(
        model_dir=base_model_dir,
        distribution_dirs=unique_target_distributions,
        intervention_dir=baseline_intervention_dir,
        use_cached=use_cached_evaluations,
        dry_run=dry_run,
        eval_kwargs=eval_kwargs,
    )

    # Write evaluation to the pgc result file
    for target_id, evaluation in zip(unique_target_ids, evaluations):
        for item in pgc_result:
            if item["target"] == target_id:
                item["performance_of_base_model"] = evaluation["average_score"]

    util.save_json(pgc_result, output_path)
    util.print_once(f"Saved intermediate results in {output_path}")

    # For each target distribution, train the base model on the distribution and then evaluate its performance.
    for target_distribution in unique_target_distributions:
        output_model_dir, logs = train(
            base_model_dir,
            target_distribution,
            intervention_dir,
            dry_run=dry_run,
            train_kwargs=train_kwargs,
            retrain=retrain_models,
        )
        evaluations = evaluate(
            output_model_dir,
            [target_distribution],
            intervention_dir,
            use_cached=use_cached_evaluations,
            dry_run=dry_run,
            eval_kwargs=eval_kwargs,
        )
        for item in pgc_result:
            if item["target"] == target_distribution.id:
                item["performance_trained_on_target"] = evaluations[0]["average_score"]
        util.save_json(pgc_result, output_path)
        util.print_once(f"Saved intermediate results in {output_path}")

    util.print_once("")
    util.print_once("----------- Measuring generalization -----------")
    util.print_once("")

    # Train the base model on each source distribution and then evaluate its performance on all target distributions that it is paired with
    unique_source_distributions = list(set(source_distributions))

    for source_distribution in unique_source_distributions:
        output_model_dir, logs = train(
            base_model_dir,
            source_distribution,
            intervention_dir,
            train_kwargs=train_kwargs,
            retrain=retrain_models,
            dry_run=dry_run,
        )

        # Get the target distributions that this source distribution is paired with
        target_distributions_for_source = [
            target
            for source, target in zip(source_distributions, target_distributions)
            if source_distribution == source
        ]

        evaluations = evaluate(
            output_model_dir,
            target_distributions_for_source,
            intervention_dir,
            use_cached=use_cached_evaluations,
            dry_run=dry_run,
            eval_kwargs=eval_kwargs,
        )

        for evaluation, target_distribution in zip(evaluations, target_distributions_for_source):
            for item in pgc_result:
                if (
                    item["source"] == source_distribution.id
                    and item["target"] == target_distribution.id
                ):
                    # Use the evaluation belonging to this target, not always
                    # the first one of the list.
                    item["performance_trained_on_source"] = evaluation["average_score"]
        util.save_json(pgc_result, output_path)
        util.print_once(f"Saved intermediate results in {output_path}")

    for item in pgc_result:
        item["pgc"] = compute_pgc(
            item["performance_of_base_model"],
            item["performance_trained_on_source"],
            item["performance_trained_on_target"],
        )
    pgcs_to_string = "\n".join(
        [
            f"{item['source']} -> {item['target']}: pgc: {item['pgc']}, base: {item['performance_of_base_model']}, source: {item['performance_trained_on_source']}, target: {item['performance_trained_on_target']}"
            for item in pgc_result
        ]
    )

    util.save_json(pgc_result, output_path)
    util.print_once("")
    util.print_once(f"Results saved at {output_path}")
    util.print_once(f"PGC: {pgcs_to_string}")
    util.print_once("")
|
||||
|
||||
# When run as a script, expose main as a command line interface through
# python-fire.
if __name__ == "__main__":
    fire.Fire(main)
|
||||
@@ -1,228 +0,0 @@
|
||||
import os
|
||||
import traceback
|
||||
import time
|
||||
import torch
|
||||
import transformers
|
||||
from torch.utils.data import DataLoader
|
||||
import re
|
||||
from torch.utils.data import Dataset
|
||||
from transformers import (
|
||||
StoppingCriteriaList,
|
||||
StoppingCriteria,
|
||||
PreTrainedModel,
|
||||
PreTrainedTokenizer,
|
||||
)
|
||||
from typing import List, Union, Optional
|
||||
from tqdm import tqdm
|
||||
import os
|
||||
import torch
|
||||
import api.util as util
|
||||
|
||||
|
||||
class Model:
    """A causal language model loaded from a local directory, with its tokenizer.

    Attributes:
        dir: The model directory.
        tokenizer: Tokenizer configured for left padding.
        hf_model: The underlying Hugging Face model.
        max_length: ``hf_model.config.max_position_embeddings`` if the config
            defines it, otherwise ``None``.
    """

    def get_tokenizer(self, dir: str, use_fast=True):
        """Load the tokenizer stored in ``dir`` and configure it for left padding.

        ``use_fast`` is overridden when ``dir`` contains "pythia" (forced to
        True) or "llama" (forced to False). If the tokenizer has no padding
        token, the EOS token is used for padding.

        Raises:
            Exception: If ``dir`` is not a directory.
        """
        if "pythia" in dir:
            util.print_once(
                "Setting use_fast to true because pythia tokenizer is not compatible with use_fast=False"
            )
            use_fast = True
        if "llama" in dir:
            util.print_once(
                "Setting use_fast to false because llama tokenizer is not compatible with use_fast=True"
            )
            use_fast = False

        if not os.path.isdir(dir):
            raise Exception(f"The hf_model {dir} does not exist")

        tokenizer = transformers.AutoTokenizer.from_pretrained(
            dir, use_fast=use_fast, trust_remote_code=True
        )

        # Set padding side to left
        tokenizer.padding_side = "left"

        # Set padding token to eos token if pad token is not set
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        return tokenizer

    def __init__(
        self,
        dir: str,
        hf_model: Optional["PreTrainedModel"] = None,
        tokenizer: Optional["PreTrainedTokenizer"] = None,
        use_fast=True,
    ):
        """Wrap the model stored in ``dir``.

        Args:
            dir: Existing model directory.
            hf_model: An already loaded model; loaded from ``dir`` in bfloat16
                when omitted.
            tokenizer: An already loaded tokenizer; loaded from ``dir`` when
                omitted.
            use_fast: Forwarded to ``get_tokenizer``.

        Raises:
            Exception: If ``dir`` is not a directory or the model cannot be loaded.
        """
        self.dir = dir
        if not os.path.isdir(self.dir):
            raise Exception(f"The hf_model {dir} does not exist")

        if tokenizer is None:
            self.tokenizer = self.get_tokenizer(dir, use_fast=use_fast)
        else:
            self.tokenizer = tokenizer

        if hf_model is None:
            # Initialise both names up front so that the failure path below can
            # report the error instead of tripping over a missing attribute.
            self.hf_model = None
            exception_string = ""
            for _ in range(3):
                try:
                    self.hf_model = transformers.AutoModelForCausalLM.from_pretrained(
                        self.dir,
                        trust_remote_code=True,
                        torch_dtype=torch.bfloat16,
                    )
                    break
                except Exception:
                    exception_string = traceback.format_exc()
                    if os.path.exists(self.dir + "/pytorch_model.bin"):
                        print(
                            "Failed to load model but pytorch_model.bin exists. This indicates that the model may still be saving. Retrying in 5 seconds."
                        )
                        time.sleep(5)
                    else:
                        break
            if self.hf_model is None:
                raise Exception(f"Failed to load model: {exception_string}")
        else:
            self.hf_model = hf_model

        # Always define max_length so that generate_text can rely on it.
        config = getattr(self.hf_model, "config", None)
        self.max_length = getattr(config, "max_position_embeddings", None)

    def to(self, device: str):
        """Move the underlying model to ``device`` and return ``self``."""
        self.hf_model.to(device)
        return self

    @staticmethod
    def _build_completion_regex(stop_string, output_regex):
        """Combine the stop string and the output regex into one pattern.

        Returns ``""`` when neither is given. When both are given, a completion
        matches if it matches either the stop-string pattern or ``output_regex``.
        """
        output_regex = "" if output_regex is None else output_regex
        # NOTE(review): stop_string is inserted unescaped, so regex
        # metacharacters in it are interpreted - confirm whether callers rely
        # on that before escaping it.
        stop_string_regex = "" if stop_string is None else r"^(.*?" + stop_string + ")"
        if stop_string_regex != "" and output_regex != "":
            return stop_string_regex + "|" + output_regex
        return stop_string_regex + output_regex

    def generate_text(
        self,
        prompts: List[str],
        max_length: Optional[int] = None,
        stop_string: Optional[str] = None,
        output_regex: Optional[str] = None,
        per_device_batch_size=100,
        **kwargs,
    ) -> Union[str, List[str]]:
        """Generate one completion per prompt.

        Args:
            prompts: The prompts to complete.
            max_length: Passed to ``generate`` as ``max_new_tokens``; defaults
                to ``self.max_length``.
            stop_string: Generation of a batch stops once every completion
                contains this string; the text after it is cut off.
            output_regex: Alternative stopping pattern; a matching completion
                is reduced to the matched text.
            per_device_batch_size: Number of prompts per ``generate`` call.
            **kwargs: Forwarded to ``hf_model.generate``.

        Returns:
            The completions as a list of strings, in the order of ``prompts``.
        """
        batch_size = per_device_batch_size

        if max_length is None:
            max_length = self.max_length

        # Built once up front: it does not depend on the batch and must be
        # defined even when ``prompts`` is empty.
        completion_regex = self._build_completion_regex(stop_string, output_regex)

        dataset = TensorDataset(prompts)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
        eos_token_id = self.tokenizer.eos_token_id

        encoded_completions = []
        for batch in tqdm(dataloader):
            encoded_prompts = self.tokenizer.batch_encode_plus(
                batch,
                return_tensors="pt",
                padding="longest",
            ).to(self.hf_model.device)
            # Every row has the same padded width, so the completions of all
            # rows start at this column.
            completion_pos = len(encoded_prompts["input_ids"][0])
            stopping_criteria = StoppingCriteriaList(
                [RegexStoppingCriteria(self.tokenizer, completion_pos, regex=completion_regex)]
            )

            # Generate predictions
            completed_sequences = self.hf_model.generate(
                input_ids=encoded_prompts["input_ids"],
                attention_mask=encoded_prompts["attention_mask"],
                stopping_criteria=stopping_criteria,
                max_new_tokens=max_length,
                **kwargs,
            )

            for sequence in completed_sequences:
                completion = sequence[completion_pos:]
                # Remove the eos token and everything that follows it
                token_ids = completion.tolist()
                if eos_token_id in token_ids:
                    completion = completion[: token_ids.index(eos_token_id)]
                encoded_completions.append(completion)

        encoded_completions = [c.cpu().to(dtype=torch.int64) for c in encoded_completions]

        # Decode the predictions
        text_completions = [self.tokenizer.decode(ids) for ids in encoded_completions]

        # Post process to remove text generated after stopping conditions were met
        if completion_regex != "":
            text_completions = [
                self.process_completion(text_completion, completion_regex)
                for text_completion in text_completions
            ]
        return text_completions

    def process_completion(self, completion, regex):
        """Return the first match of ``regex`` in ``completion``, or ``completion`` itself."""
        match = re.search(regex, completion)
        if match:
            return match.group(0)
        return completion

    def print_generate(
        self,
        text: str,
        max_length: Optional[int] = 100,
        stop_string: Optional[str] = None,
        output_regex: Optional[str] = None,
        **kwargs,
    ):
        """Generate and return the completion of the single prompt ``text``."""
        result = self.generate_text([text], max_length, stop_string, output_regex, **kwargs)[0]
        return result
|
||||
|
||||
|
||||
class RegexStoppingCriteria(StoppingCriteria):
    """Stop generating once every sequence of the batch matches a regex.

    Only the text decoded from position ``completion_pos`` onwards is searched,
    i.e. the part that follows the prompt.
    """

    def __init__(self, tokenizer, completion_pos, regex=None):
        StoppingCriteria.__init__(self)
        self.tokenizer = tokenizer
        self.regex = regex
        self.completion_pos = completion_pos

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs):
        """Return True when all sequences in ``input_ids`` contain a match.

        Without a pattern (``None`` or ``""``) the criterion never stops the
        generation. Extra keyword arguments are accepted and ignored.
        """
        # Treat None like the empty pattern. Evaluating all() over nothing
        # would be True and stop the generation right away.
        if not self.regex:
            return False
        # stops if all generations include the regex pattern
        return all(
            re.search(self.regex, self.tokenizer.decode(sequence[self.completion_pos :]))
            is not None
            for sequence in input_ids
        )
|
||||
|
||||
|
||||
class TensorDataset(Dataset):
    """Minimal ``Dataset`` view over any sized, indexable collection.

    ``Model.generate_text`` uses it to batch a plain list of prompt strings.
    """

    def __init__(self, inputs):
        # The collection is referenced as given, not copied.
        self.inputs = inputs

    def __len__(self):
        size = len(self.inputs)
        return size

    def __getitem__(self, idx):
        item = self.inputs[idx]
        return item
|
||||
@@ -1 +0,0 @@
|
||||
from data_classes import Distribution
|
||||
@@ -1,52 +0,0 @@
|
||||
import os
|
||||
import api.util as util
|
||||
from typing import List, Optional, Union
|
||||
import fire
|
||||
from api.data_classes import Distribution
|
||||
|
||||
|
||||
def train(
    model_dir: str,
    distribution: Union[str, Distribution],
    intervention_dir: str,
    dry_run: bool = False,
    train_kwargs: Optional[dict] = None,
    retrain: bool = False,
) -> tuple:
    """Train ``model_dir`` on a distribution with the given intervention.

    The trained model is stored in
    ``models/<intervention name>/<model name>-<distribution id>``.

    Args:
        model_dir: Directory of the model to train.
        distribution: A ``Distribution`` or the directory to load one from.
        intervention_dir: Directory whose ``train.py`` provides ``main``.
        dry_run: Only report what would be trained.
        train_kwargs: Extra keyword arguments for the intervention's ``main``.
        retrain: Train even if the output directory already exists.

    Returns:
        ``(output_model_dir, logs)``. ``logs`` is the return value of the
        intervention's ``main``, or ``None`` when training was skipped.
    """
    # Fresh dict per call: a mutable default would be shared between calls.
    train_kwargs = {} if train_kwargs is None else train_kwargs

    train_module = util.import_module_from_path(f"{intervention_dir}/train.py")
    model_name = os.path.basename(model_dir)
    if isinstance(distribution, str):
        distribution = Distribution(distribution)
    intervention_name = os.path.basename(intervention_dir)
    output_model_name = f"{intervention_name}/{model_name}-{distribution.id}"
    output_model_dir = f"models/{output_model_name}"

    util.print_once("")
    if os.path.exists(output_model_dir) and not retrain:
        util.print_once(f"Model {output_model_name} already exists. Skipping training.")
        return output_model_dir, None

    util.print_once(
        f"# Training {model_name} on {distribution.id} with strategy '{intervention_name}'"
    )
    if dry_run:
        util.print_once("Skipping training because dry_run=True.")
        return output_model_dir, None

    logs = train_module.main(
        model_dir=model_dir,
        output_dir=output_model_dir,
        training_distribution_dir=distribution.dir,
        test_distribution_dir=distribution.dir,
        **train_kwargs,
    )
    return output_model_dir, logs
|
||||
|
||||
|
||||
def fire_wrap(*args, **kwargs):
    """Call ``train`` with the given arguments and discard its return value."""
    # NOTE(review): presumably the result is dropped so that python-fire does
    # not print the returned tuple - confirm.
    train(*args, **kwargs)
|
||||
|
||||
|
||||
# When run as a script, expose training as a command line interface through
# python-fire.
if __name__ == "__main__":
    fire.Fire(fire_wrap)
|
||||
@@ -1,57 +0,0 @@
|
||||
import os
|
||||
import importlib.util
|
||||
import json
|
||||
import subprocess
|
||||
|
||||
# Label value written over prompt and padding positions by the data collators.
IGNORE_INDEX = -100
# Absolute path of this file.
current_file_path = os.path.abspath(__file__)
# Three directory levels above this file; the credential helpers below look
# for configs/credentials.json underneath it.
repo_path = os.path.dirname(os.path.dirname(os.path.dirname(current_file_path)))
# Optional accelerator consulted by print_once. It stays None unless some
# other code assigns it.
accelerator = None
|
||||
|
||||
|
||||
def execute_command(command):
    """Run ``command`` through the shell and wait until it has finished.

    The command's output is not captured; it goes straight to this process's
    stdout and stderr. A non-zero exit status is reported with a printed
    message and is not raised, so callers cannot detect the failure.

    NOTE(review): the command string is interpreted by the shell
    (``shell=True``); it must never be assembled from untrusted input.
    """
    try:
        # The output is deliberately not captured, so there is nothing to
        # print afterwards; result.stdout would always be None.
        subprocess.run(command, shell=True, check=True, text=True)
    except subprocess.CalledProcessError as e:
        print("Error running Bash script:", e)
|
||||
|
||||
|
||||
def save_json(data, file_path):
    """Write ``data`` as indented JSON to ``file_path``, creating missing directories.

    An existing file is overwritten.
    """
    parent_dir = os.path.dirname(file_path)
    # A bare file name has no directory part, and os.makedirs("") would raise.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    # Same encoding as load_json, independent of the platform default.
    with open(file_path, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=4)
|
||||
|
||||
|
||||
def load_json(file_path):
    """Parse the UTF-8 encoded JSON file at ``file_path`` and return its content."""
    with open(file_path, "r", encoding="utf-8") as handle:
        parsed = json.loads(handle.read())
    return parsed
|
||||
|
||||
|
||||
def import_module_from_path(path):
    """Execute the Python file at ``path`` and return it as a module object.

    The module is created under the fixed name ``"module.name"`` and is not
    added to ``sys.modules``.

    Raises:
        ImportError: If no import spec can be built for ``path``, for example
            because its file extension is not a Python source extension.
    """
    spec = importlib.util.spec_from_file_location("module.name", path)
    # spec_from_file_location returns None rather than raising when it finds
    # no suitable loader; fail with a clear message in that case.
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load a Python module from {path!r}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module
|
||||
|
||||
|
||||
def print_once(string):
    """Print ``string`` unless an accelerator is set and this is not its main process.

    With the module-level ``accelerator`` left at ``None`` this behaves exactly
    like ``print``.
    """
    if accelerator is None or accelerator.is_main_process:
        print(string)
|
||||
|
||||
|
||||
def openai_api_key():
    """Return the OpenAI key stored in ``configs/credentials.json`` of the repository.

    Note that the entry read is ``"openai_api_key2"``, not ``"openai_api_key"``.
    """
    credentials_path = f"{repo_path}/configs/credentials.json"
    return load_json(credentials_path)["openai_api_key2"]
|
||||
|
||||
|
||||
def wandb_api_key():
    """Return the ``"wandb_api_key"`` entry of ``configs/credentials.json`` in the repository."""
    credentials_path = f"{repo_path}/configs/credentials.json"
    return load_json(credentials_path)["wandb_api_key"]
|
||||
Reference in New Issue
Block a user