update download script

This commit is contained in:
deep1
2023-04-22 16:03:20 +08:00
parent 3410c9cad0
commit ba547ccb70
+26 -18
View File
@@ -4,7 +4,8 @@ Downloads models from Hugging Face to models/model-name.
Example:
python download-model.py facebook/opt-1.3b
From https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/download-model.py
see https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/download-model.py
'''
@@ -22,17 +23,6 @@ import tqdm
from tqdm.contrib.concurrent import thread_map
parser = argparse.ArgumentParser()
parser.add_argument('MODEL', type=str, default=None, nargs='?')
parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.')
parser.add_argument('--threads', type=int, default=1, help='Number of files to download simultaneously.')
parser.add_argument('--text-only', action='store_true', help='Only download text files (txt/json).')
parser.add_argument('--output', type=str, default=None, help='The folder where the model should be saved.')
parser.add_argument('--clean', action='store_true', help='Does not resume the previous download.')
parser.add_argument('--check', action='store_true', help='Validates the checksums of model files.')
args = parser.parse_args()
def select_model_from_default_options():
models = {
"OPT 6.7B": ("facebook", "opt-6.7b", "main"),
@@ -54,13 +44,17 @@ def select_model_from_default_options():
char = chr(ord('A') + i)
choices[char] = name
print(f"{char}) {name}")
char = chr(ord('A') + len(models))
print(f"{char}) None of the above")
char_hugging = chr(ord('A') + len(models))
print(f"{char_hugging}) Manually specify a Hugging Face model")
char_exit = chr(ord('A') + len(models) + 1)
print(f"{char_exit}) Do not download a model")
print()
print("Input> ", end='')
choice = input()[0].strip().upper()
if choice == char:
if choice == char_exit:
exit()
elif choice == char_hugging:
print("""\nThen type the name of your desired Hugging Face model in the format organization/name.
Examples:
@@ -94,7 +88,7 @@ def sanitize_model_and_branch_names(model, branch):
def get_download_links_from_huggingface(model, branch, text_only=False):
base = "https://huggingface.co"
page = f"/api/models/{model}/tree/{branch}?cursor="
page = f"/api/models/{model}/tree/{branch}"
cursor = b""
links = []
@@ -106,7 +100,10 @@ def get_download_links_from_huggingface(model, branch, text_only=False):
has_safetensors = False
is_lora = False
while True:
content = requests.get(f"{base}{page}{cursor.decode()}").content
url = f"{base}{page}" + (f"?cursor={cursor.decode()}" if cursor else "")
r = requests.get(url)
r.raise_for_status()
content = r.content
dict = json.loads(content)
if len(dict) == 0:
@@ -124,7 +121,7 @@ def get_download_links_from_huggingface(model, branch, text_only=False):
is_safetensors = re.match(".*\.safetensors", fname)
is_pt = re.match(".*\.pt", fname)
is_ggml = re.match("ggml.*\.bin", fname)
is_tokenizer = re.match("tokenizer.*\.model", fname)
is_tokenizer = re.match("(tokenizer|ice).*\.model", fname)
is_text = re.match(".*\.(txt|json|py|md)", fname) or is_tokenizer
if any((is_pytorch, is_safetensors, is_pt, is_ggml, is_tokenizer, is_text)):
@@ -249,6 +246,17 @@ def check_model_files(model, branch, links, sha256, output_folder):
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('MODEL', type=str, default=None, nargs='?')
parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.')
parser.add_argument('--threads', type=int, default=1, help='Number of files to download simultaneously.')
parser.add_argument('--text-only', action='store_true', help='Only download text files (txt/json).')
parser.add_argument('--output', type=str, default=None, help='The folder where the model should be saved.')
parser.add_argument('--clean', action='store_true', help='Does not resume the previous download.')
parser.add_argument('--check', action='store_true', help='Validates the checksums of model files.')
args = parser.parse_args()
branch = args.branch
model = args.MODEL
if model is None: