# Mirror of https://github.com/wassname/GENIES.git
# Synced 2026-06-27 16:10:25 +08:00
# 93 lines, 2.9 KiB, Python
import fire
|
|
from api.data_classes import Distribution
|
|
from typing import Union, List, Optional
|
|
import wandb
|
|
import copy
|
|
import api.util as util
|
|
import os
|
|
import wandb
|
|
|
|
# Example sweep configuration: a random search over the learning rate that
# maximizes the "average_score" metric logged by each trial.
default_sweep_configuration = {
    "method": "random",
    "name": "sweep",
    "metric": {"goal": "maximize", "name": "average_score"},
    "parameters": {
        "learning_rate": {"max": 1e-4, "min": 8e-6},
    },
}
|
|
|
|
|
|
def hyper_parameter_sweep(
    model_dir: str,
    distribution: Union[str, Distribution],
    intervention_dir: str,
    sweep_configuration: dict = default_sweep_configuration,
    train_kwargs: Optional[dict] = None,
    eval_kwargs: Optional[dict] = None,
    count: int = 4,
    wandb_project="project",
) -> Optional[List[dict]]:
    """Create a wandb sweep and run ``count`` trials of ``train_with_params``.

    Args:
        model_dir: Directory of the model to tune; forwarded to every trial.
        distribution: A ``Distribution`` or a string that is passed to the
            ``Distribution`` constructor.
        intervention_dir: Directory containing the ``train.py`` / ``eval.py``
            scripts launched by each trial.
        sweep_configuration: wandb sweep configuration dict.
        train_kwargs: Extra arguments forwarded to the training command.
            ``None`` (the default) means no extra arguments.
        eval_kwargs: Extra arguments forwarded to the evaluation command.
            ``None`` (the default) means no extra arguments.
        count: Number of trials the wandb agent is asked to run.
        wandb_project: Name of the wandb project that owns the sweep.

    Returns:
        Nothing is returned explicitly (the function yields ``None``); the
        results are reported through wandb by the individual trials.
    """
    if isinstance(distribution, str):
        distribution = Distribution(distribution)

    # Use fresh dicts instead of mutable default arguments so that state can
    # never leak between separate calls of this function.
    train_kwargs = {} if train_kwargs is None else train_kwargs
    eval_kwargs = {} if eval_kwargs is None else eval_kwargs

    # Authenticate with wandb before the sweep is created.
    wandb.login(key=util.wandb_api_key())

    sweep_id = wandb.sweep(sweep=sweep_configuration, project=wandb_project)

    # The agent calls the function without arguments, so the trial inputs are
    # bound through a closure.
    wandb.agent(
        sweep_id,
        function=lambda: train_with_params(
            model_dir, distribution, intervention_dir, train_kwargs, eval_kwargs
        ),
        count=count,
    )
|
|
|
|
|
|
def train_with_params(model_dir, distribution, intervention_dir, train_kwargs, eval_kwargs):
    """Run a single sweep trial: train, evaluate, and log the score to wandb.

    The hyper-parameters of the trial are read from ``wandb.config`` after
    ``wandb.init()`` and merged over ``train_kwargs``. Training and evaluation
    are launched as ``accelerate`` shell commands through
    ``util.execute_command``; the evaluation JSON is then loaded and its
    ``"average_score"`` entry is logged to wandb.

    Args:
        model_dir: Directory of the model handed to the train and eval scripts.
        distribution: A ``Distribution`` or a string that is passed to the
            ``Distribution`` constructor.
        intervention_dir: Directory containing ``train.py`` and ``eval.py``.
        train_kwargs: Extra ``--key "value"`` arguments for the train command.
            Not modified; may be ``None``.
        eval_kwargs: Extra ``--key "value"`` arguments for the eval command.
            Not modified; may be ``None``.
    """
    wandb.init()
    parameters = wandb.config

    # Work on private copies so neither caller-supplied dict is mutated; the
    # same dicts are reused for every trial of a sweep.
    run_train_kwargs = copy.deepcopy(train_kwargs) if train_kwargs else {}
    run_train_kwargs.update(parameters)
    run_eval_kwargs = copy.deepcopy(eval_kwargs) if eval_kwargs else {}

    if isinstance(distribution, str):
        distribution = Distribution(distribution)

    model_name = os.path.basename(model_dir)
    intervention_name = os.path.basename(intervention_dir)
    # The sampled hyper-parameters are encoded in the name so that every
    # trial gets its own model directory and evaluation file.
    output_model_name = f"{intervention_name}/{model_name}-{distribution.id}-{'-'.join([f'{k}={v}' for k,v in parameters.items()])}"
    output_model_dir = f"models/hps/{output_model_name}"
    eval_output_path = f"results/evaluations/hps/{output_model_name}.json"

    train_command = f"accelerate launch --deepspeed_config_file configs/ds_zero_3.json --use_deepspeed {intervention_dir}/train.py"
    run_train_kwargs.update(
        {
            "model_dir": model_dir,
            "output_dir": output_model_dir,
            "training_distribution_dir": distribution.dir,
        }
    )
    for k, v in run_train_kwargs.items():
        train_command += f' --{k} "{str(v)}"'

    util.execute_command(train_command)

    eval_command = f"accelerate launch --deepspeed_config_file configs/ds_zero_3.json --use_deepspeed {intervention_dir}/eval.py"
    # NOTE(review): evaluation receives the original model_dir, not the freshly
    # trained output_model_dir -- confirm against eval.py that this is intended.
    run_eval_kwargs.update(
        {
            "model_dir": model_dir,
            "distribution_dirs": [distribution.dir],
            "output_paths": [eval_output_path],
        }
    )
    for k, v in run_eval_kwargs.items():
        eval_command += f' --{k} "{str(v)}"'
    util.execute_command(eval_command)

    # Open json to get evaluation
    result = util.load_json(eval_output_path)
    wandb.log({"average_score": result["average_score"]})
|
|
|
|
|
|
if __name__ == "__main__":
    # Expose hyper_parameter_sweep as the command-line entry point via fire.
    fire.Fire(hyper_parameter_sweep)
|