mirror of
https://github.com/wassname/alpaca_convert.git
synced 2026-06-27 16:14:08 +08:00
update download script
This commit is contained in:
+26
-18
@@ -4,7 +4,8 @@ Downloads models from Hugging Face to models/model-name.
|
||||
Example:
|
||||
python download-model.py facebook/opt-1.3b
|
||||
|
||||
From https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/download-model.py
|
||||
|
||||
see https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/download-model.py
|
||||
|
||||
'''
|
||||
|
||||
@@ -22,17 +23,6 @@ import tqdm
|
||||
from tqdm.contrib.concurrent import thread_map
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('MODEL', type=str, default=None, nargs='?')
|
||||
parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.')
|
||||
parser.add_argument('--threads', type=int, default=1, help='Number of files to download simultaneously.')
|
||||
parser.add_argument('--text-only', action='store_true', help='Only download text files (txt/json).')
|
||||
parser.add_argument('--output', type=str, default=None, help='The folder where the model should be saved.')
|
||||
parser.add_argument('--clean', action='store_true', help='Does not resume the previous download.')
|
||||
parser.add_argument('--check', action='store_true', help='Validates the checksums of model files.')
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
def select_model_from_default_options():
|
||||
models = {
|
||||
"OPT 6.7B": ("facebook", "opt-6.7b", "main"),
|
||||
@@ -54,13 +44,17 @@ def select_model_from_default_options():
|
||||
char = chr(ord('A') + i)
|
||||
choices[char] = name
|
||||
print(f"{char}) {name}")
|
||||
char = chr(ord('A') + len(models))
|
||||
print(f"{char}) None of the above")
|
||||
char_hugging = chr(ord('A') + len(models))
|
||||
print(f"{char_hugging}) Manually specify a Hugging Face model")
|
||||
char_exit = chr(ord('A') + len(models) + 1)
|
||||
print(f"{char_exit}) Do not download a model")
|
||||
|
||||
print()
|
||||
print("Input> ", end='')
|
||||
choice = input()[0].strip().upper()
|
||||
if choice == char:
|
||||
if choice == char_exit:
|
||||
exit()
|
||||
elif choice == char_hugging:
|
||||
print("""\nThen type the name of your desired Hugging Face model in the format organization/name.
|
||||
|
||||
Examples:
|
||||
@@ -94,7 +88,7 @@ def sanitize_model_and_branch_names(model, branch):
|
||||
|
||||
def get_download_links_from_huggingface(model, branch, text_only=False):
|
||||
base = "https://huggingface.co"
|
||||
page = f"/api/models/{model}/tree/{branch}?cursor="
|
||||
page = f"/api/models/{model}/tree/{branch}"
|
||||
cursor = b""
|
||||
|
||||
links = []
|
||||
@@ -106,7 +100,10 @@ def get_download_links_from_huggingface(model, branch, text_only=False):
|
||||
has_safetensors = False
|
||||
is_lora = False
|
||||
while True:
|
||||
content = requests.get(f"{base}{page}{cursor.decode()}").content
|
||||
url = f"{base}{page}" + (f"?cursor={cursor.decode()}" if cursor else "")
|
||||
r = requests.get(url)
|
||||
r.raise_for_status()
|
||||
content = r.content
|
||||
|
||||
dict = json.loads(content)
|
||||
if len(dict) == 0:
|
||||
@@ -124,7 +121,7 @@ def get_download_links_from_huggingface(model, branch, text_only=False):
|
||||
is_safetensors = re.match(".*\.safetensors", fname)
|
||||
is_pt = re.match(".*\.pt", fname)
|
||||
is_ggml = re.match("ggml.*\.bin", fname)
|
||||
is_tokenizer = re.match("tokenizer.*\.model", fname)
|
||||
is_tokenizer = re.match("(tokenizer|ice).*\.model", fname)
|
||||
is_text = re.match(".*\.(txt|json|py|md)", fname) or is_tokenizer
|
||||
|
||||
if any((is_pytorch, is_safetensors, is_pt, is_ggml, is_tokenizer, is_text)):
|
||||
@@ -249,6 +246,17 @@ def check_model_files(model, branch, links, sha256, output_folder):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('MODEL', type=str, default=None, nargs='?')
|
||||
parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.')
|
||||
parser.add_argument('--threads', type=int, default=1, help='Number of files to download simultaneously.')
|
||||
parser.add_argument('--text-only', action='store_true', help='Only download text files (txt/json).')
|
||||
parser.add_argument('--output', type=str, default=None, help='The folder where the model should be saved.')
|
||||
parser.add_argument('--clean', action='store_true', help='Does not resume the previous download.')
|
||||
parser.add_argument('--check', action='store_true', help='Validates the checksums of model files.')
|
||||
args = parser.parse_args()
|
||||
|
||||
branch = args.branch
|
||||
model = args.MODEL
|
||||
if model is None:
|
||||
|
||||
Reference in New Issue
Block a user