git clone script working

This commit is contained in:
deep1
2023-04-22 18:42:01 +08:00
parent ba547ccb70
commit e4c04daadf
6 changed files with 35 additions and 26 deletions
+17 -16
View File
@@ -4,34 +4,35 @@ clones models from Hugging Face to models/model-name.
Example:
python clone-model.py facebook/opt-1.3b
this seems nicer than the previous script https://github.com/oobabooga/text-generation-webui/blob/main/download-model.py
'''
from git import Repo
import argparse
from tqdm.auto import tqdm
from git import RemoteProgress
import os
import subprocess
parser = argparse.ArgumentParser()
parser.add_argument('MODEL', type=str, default=None, help="`tloen/alpaca-lora-7b`")
parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.')
parser.add_argument('-b', '--branch', type=str, default='main', help='Name of the Git branch to download from.')
parser.add_argument('-t', '--text_only', action='store_true')
args = parser.parse_args()
class CloneProgress(RemoteProgress):
"""tqdm progress bar for GitPython"""
def __init__(self):
super().__init__()
self.pbar = tqdm()
def update(self, op_code, cur_count, max_count=None, message=''):
self.pbar.total = max_count
self.pbar.n = cur_count
self.pbar.refresh()
if __name__ == '__main__':
model = args.MODEL
repo = 'https://huggingface.co/' + model
name = model.replace('/', '_')
dest = f'./models/{name}'
output_folder = './data/loras' if 'lora' in name else './data/models'
dest = f'{output_folder}/{name}_git'
if args.text_only:
os.environ['GIT_LFS_SKIP_SMUDGE']="1"
result = subprocess.run(['git', 'lfs'], capture_output=True)
assert result.returncode==0, 'git lfs should be installed'
print(f'cloning "{repo}" to "{dest}"')
Repo.clone_from(repo, dest, progress=CloneProgress())
Repo.clone_from(repo, dest, multi_options=['--depth=1'])