mirror of
https://github.com/wassname/Run-Skeleton-Run.git
synced 2026-06-27 19:45:48 +08:00
88 lines
2.3 KiB
Python
88 lines
2.3 KiB
Python
import os
|
|
import sys
|
|
import random
|
|
import numpy as np
|
|
|
|
|
|
def create_if_need(path):
    """Create directory *path* (including missing parents) if it does not exist.

    Parameters
    ----------
    path: str
        directory to create

    Notes
    -----
    If *path* already exists, nothing is done -- even when it is a regular
    file rather than a directory (unchanged behaviour).
    """
    if not os.path.exists(path):
        # exist_ok=True closes the race between the exists() check and
        # makedirs(): if a concurrent process creates the directory in
        # between, we no longer crash with FileExistsError. makedirs still
        # raises if something that is *not* a directory appears there.
        os.makedirs(path, exist_ok=True)
|
|
|
|
|
|
def boolean_flag(parser, name, default=False, help=None):
    """Register a paired ``--<name>`` / ``--no-<name>`` switch on *parser*.

    Both options write to the same destination attribute (``name`` with
    dashes turned into underscores). ``--<name>`` stores True and
    ``--no-<name>`` stores False.

    Parameters
    ----------
    parser: argparse.ArgumentParser
        parser that receives the two options
    name: str
        base option name, without leading dashes
    default: bool or None
        default value attached to the ``--<name>`` option
    help: str
        help text attached to the ``--<name>`` option only
    """
    destination = "_".join(name.split("-"))
    enable_option = "--" + name
    disable_option = "--no-" + name
    parser.add_argument(
        enable_option,
        dest=destination,
        action="store_true",
        default=default,
        help=help,
    )
    parser.add_argument(disable_option, dest=destination, action="store_false")
|
|
|
|
|
|
def str2params(string, delimeter="-"):
    """Parse a delimiter-separated string of integers, e.g. ``"64-64"`` -> ``[64, 64]``.

    Parameters
    ----------
    string: str
        text such as ``"400-300"``
    delimeter: str
        separator between the integers (parameter name kept as-is for
        backward compatibility with keyword callers)

    Returns
    -------
    list of int or None
        the parsed integers, or None if *string* cannot be parsed (some piece
        is not an integer, the input is empty, the separator is empty, or
        *string* is not a string at all)
    """
    try:
        return [int(piece) for piece in string.split(delimeter)]
    except (AttributeError, TypeError, ValueError):
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit and
        # unrelated programming errors are no longer silently swallowed.
        return None
|
|
|
|
|
|
def set_global_seeds(i):
    """Seed every available random number generator with *i*.

    Python's ``random`` and NumPy are always seeded. PyTorch and TensorFlow
    are seeded only when they can be imported; missing frameworks are
    skipped silently.

    Parameters
    ----------
    i: int
        the seed
    """
    try:
        import torch
    except ImportError:
        pass
    else:
        torch.manual_seed(i)
    try:
        import tensorflow as tf
    except ImportError:
        pass
    else:
        # TensorFlow 2.x removed the top-level `tf.set_random_seed`; its
        # replacement is `tf.random.set_seed`. Prefer the new name and fall
        # back to the 1.x one so both major versions work.
        set_seed = getattr(getattr(tf, "random", None), "set_seed", None)
        if set_seed is None:
            set_seed = tf.set_random_seed
        set_seed(i)
    np.random.seed(i)
    random.seed(i)
|
|
|
|
|
|
def query_yes_no(question, default="no"):
    """Ask a yes/no question via input() and return the answer as a bool.

    The prompt suffix shows the default in capitals: ``[y/N]``, ``[Y/n]``,
    or ``[y/n]`` when there is no default. The question is repeated until a
    recognised answer is given.

    Parameters
    ----------
    question: str
        text presented to the user
    default: str or None
        answer presumed when the user just hits <Enter> (or types only
        whitespace). Must be "yes", "no" (the default), or None, meaning an
        explicit answer is required.

    Returns
    -------
    bool
        True for "yes"/"ye"/"y", False for "no"/"n". Matching ignores case
        and surrounding whitespace.

    Raises
    ------
    ValueError
        if *default* is not "yes", "no", or None
    """
    valid = {
        "yes": True, "y": True, "ye": True,
        "no": False, "n": False
    }
    if default is None:
        prompt = " [y/n] "
    elif default == "yes":
        prompt = " [Y/n] "
    elif default == "no":
        prompt = " [y/N] "
    else:
        raise ValueError("invalid default answer: '%s'" % default)

    while True:
        sys.stdout.write(question + prompt)
        # strip() so an otherwise valid reply with stray spaces (" y ") is
        # accepted instead of triggering another prompt.
        choice = input().strip().lower()
        if default is not None and choice == '':
            return valid[default]
        elif choice in valid:
            return valid[choice]
        else:
            sys.stdout.write("Please respond with 'yes' or 'no' "
                             "(or 'y' or 'n').\n")