Files
GENIES/build/lib/api/hyperparameter_sweep.py
T
JoshuaClymer 70bd2ea15d Initial commit
2023-11-11 19:44:09 +00:00

93 lines
2.9 KiB
Python

import fire
from api.data_classes import Distribution
from typing import Union, List, Optional
import wandb
import copy
import api.util as util
import os
import wandb
# Example sweep configuration (used when the caller supplies none):
# random search over the learning rate, maximizing the "average_score" metric.
default_sweep_configuration = dict(
    method="random",
    name="sweep",
    metric=dict(goal="maximize", name="average_score"),
    parameters=dict(
        learning_rate=dict(max=1e-4, min=8e-6),
    ),
)
def hyper_parameter_sweep(
    model_dir: str,
    distribution: Union[str, Distribution],
    intervention_dir: str,
    sweep_configuration: dict = default_sweep_configuration,
    train_kwargs: Optional[dict] = None,
    eval_kwargs: Optional[dict] = None,
    count: int = 4,
    wandb_project="project",
) -> Optional[List[dict]]:
    """Run a Weights & Biases sweep over training hyperparameters.

    Logs in to W&B with the key from ``util.wandb_api_key()``, registers
    ``sweep_configuration`` under ``wandb_project``, then starts an agent
    that invokes ``train_with_params`` for each sampled configuration.

    Args:
        model_dir: Directory of the model to train; forwarded to
            ``train_with_params``.
        distribution: A ``Distribution``, or a string that is passed to the
            ``Distribution`` constructor.
        intervention_dir: Directory containing the intervention's
            ``train.py`` and ``eval.py`` scripts.
        sweep_configuration: W&B sweep configuration dict.
        train_kwargs: Extra keyword arguments for the training script.
            ``None`` (the default) means no extra arguments.
        eval_kwargs: Extra keyword arguments for the evaluation script.
            ``None`` (the default) means no extra arguments.
        count: Passed to ``wandb.agent`` as the number of runs to execute.
        wandb_project: Name of the W&B project the sweep is created in.

    Returns:
        None. Nothing is returned explicitly; the annotation is kept only
        for interface compatibility.
    """
    # Copy the keyword dicts so that neither a shared default nor the
    # caller's own dict is modified by the per-run code, which updates
    # eval kwargs in place.
    train_kwargs = dict(train_kwargs) if train_kwargs is not None else {}
    eval_kwargs = dict(eval_kwargs) if eval_kwargs is not None else {}

    if isinstance(distribution, str):
        distribution = Distribution(distribution)

    wandb.login(key=util.wandb_api_key())
    sweep_id = wandb.sweep(sweep=sweep_configuration, project=wandb_project)

    def run_trial():
        # The agent calls this with no arguments; the sampled hyperparameters
        # are read from wandb.config inside train_with_params.
        train_with_params(
            model_dir, distribution, intervention_dir, train_kwargs, eval_kwargs
        )

    wandb.agent(
        sweep_id,
        function=run_trial,
        count=count,
    )
def _append_cli_flags(command, kwargs):
    """Return ``command`` with one ``--key "value"`` flag appended per kwargs item."""
    return command + "".join(f' --{k} "{str(v)}"' for k, v in kwargs.items())


def train_with_params(model_dir, distribution, intervention_dir, train_kwargs, eval_kwargs):
    """Train and evaluate once using the hyperparameters of the current W&B run.

    Starts a W&B run, merges ``wandb.config`` into a copy of ``train_kwargs``,
    launches ``{intervention_dir}/train.py`` and then
    ``{intervention_dir}/eval.py`` through ``accelerate launch``, and logs the
    ``average_score`` found in the evaluation JSON to W&B.

    Args:
        model_dir: Directory of the model passed to both scripts.
        distribution: A ``Distribution``, or a string that is passed to the
            ``Distribution`` constructor.
        intervention_dir: Directory containing ``train.py`` and ``eval.py``.
        train_kwargs: Extra keyword arguments for ``train.py``. Not modified.
        eval_kwargs: Extra keyword arguments for ``eval.py``. Not modified.

    Returns:
        None.
    """
    wandb.init()
    parameters = wandb.config

    # Work on copies so the caller's dicts are left untouched across runs.
    run_train_kwargs = copy.deepcopy(train_kwargs)
    run_train_kwargs.update(parameters)
    run_eval_kwargs = copy.deepcopy(eval_kwargs)

    if isinstance(distribution, str):
        distribution = Distribution(distribution)

    # normpath drops a trailing separator, which would otherwise make
    # basename return an empty string (e.g. for "models/llama/").
    model_name = os.path.basename(os.path.normpath(model_dir))
    intervention_name = os.path.basename(os.path.normpath(intervention_dir))

    # Encode the sampled hyperparameters in the output name so every run of
    # the sweep writes to its own model directory and results file.
    param_tag = "-".join([f"{k}={v}" for k, v in parameters.items()])
    output_model_name = f"{intervention_name}/{model_name}-{distribution.id}-{param_tag}"
    output_model_dir = f"models/hps/{output_model_name}"
    eval_output_path = f"results/evaluations/hps/{output_model_name}.json"

    launcher = "accelerate launch --deepspeed_config_file configs/ds_zero_3.json --use_deepspeed"

    # These fixed entries are applied last, so they override any same-named
    # key coming from train_kwargs or the sweep parameters.
    run_train_kwargs.update(
        {
            "model_dir": model_dir,
            "output_dir": output_model_dir,
            "training_distribution_dir": distribution.dir,
        }
    )
    train_command = _append_cli_flags(
        f"{launcher} {intervention_dir}/train.py", run_train_kwargs
    )
    util.execute_command(train_command)

    # NOTE(review): evaluation is pointed at the input `model_dir`, not at the
    # `output_model_dir` that training was told to write to - confirm this is
    # intended.
    run_eval_kwargs.update(
        {
            "model_dir": model_dir,
            "distribution_dirs": [distribution.dir],
            "output_paths": [eval_output_path],
        }
    )
    eval_command = _append_cli_flags(
        f"{launcher} {intervention_dir}/eval.py", run_eval_kwargs
    )
    util.execute_command(eval_command)

    # Read the evaluation output back and report the sweep's target metric.
    result = util.load_json(eval_output_path)
    wandb.log({"average_score": result["average_score"]})
# Script entry point: expose hyper_parameter_sweep as a command-line
# interface through python-fire.
if __name__ == "__main__":
    fire.Fire(hyper_parameter_sweep)