Initial commit

This commit is contained in:
JoshuaClymer
2023-11-11 19:44:09 +00:00
commit 70bd2ea15d
6313 changed files with 4401768 additions and 0 deletions
+52
View File
@@ -0,0 +1,52 @@
import os
import api.util as util
from typing import List, Optional, Union
import fire
from api.data_classes import Distribution
def train(
model_dir: str,
distribution: Union[str, Distribution],
intervention_dir: str,
dry_run: bool = False,
train_kwargs: bool = {},
retrain: dict = False,
) -> Optional[List[dict]]:
train_module = util.import_module_from_path(f"{intervention_dir}/train.py")
model_name = os.path.basename(model_dir)
if isinstance(distribution, str):
distribution = Distribution(distribution)
intervention_name = os.path.basename(intervention_dir)
output_model_name = f"{intervention_name}/{model_name}-{distribution.id}"
output_model_dir = f"models/{output_model_name}"
util.print_once("")
if os.path.exists(output_model_dir) and not retrain:
util.print_once(f"Model {output_model_name} already exists. Skipping training.")
return output_model_dir, None
util.print_once(
f"# Training {model_name} on {distribution.id} with strategy '{intervention_name}'"
)
if dry_run:
util.print_once("Skipping training because dry_run=True.")
return output_model_dir, None
else:
logs = train_module.main(
model_dir=model_dir,
output_dir=output_model_dir,
training_distribution_dir=distribution.dir,
test_distribution_dir=distribution.dir,
**train_kwargs,
)
return output_model_dir, logs
def fire_wrap(*args, **kwargs):
train(*args, **kwargs)
if __name__ == "__main__":
fire.Fire(fire_wrap)