This commit is contained in:
wassname
2024-06-07 12:38:08 +08:00
parent 3c7ee12182
commit 66440b0e48
4 changed files with 53 additions and 29 deletions
+27 -26
View File
@@ -149,30 +149,30 @@ craftax:
imag_gradient: 'reinforce'
time_limit: 4000
craftax_small:
task: craftax_Craftax-Symbolic-AutoReset-v1
step: 1e6
action_repeat: 1
envs: 1
train_ratio: 512
video_pred_log: false
dyn_hidden: 256
dyn_deter: 256
dyn_stoch: 24
dyn_discrete: 24
# note: depth is cnn hidden_dim
encoder: {cnn_keys: 'state_map', cnn_depth: 32, kernel_size: 4, minres: 2, mlp_keys: "state_inventory", mlp_layers: 2, mlp_units: 16}
decoder: {cnn_keys: 'state_map', cnn_depth: 32, kernel_size: 4, minres: 2, mlp_keys: "state_inventory", mlp_layers: 2, mlp_units: 16}
actor: {layers: 3, dist: 'onehot', std: 'none'}
value: {layers: 3}
# note: units is the head hidden_dim
units: 256
reward_head: {layers: 3}
cont_head: {layers: 3}
imag_gradient: 'reinforce'
batch_size: 256
batch_length: 32
time_limit: 4000
# craftax_small:
# task: craftax_Craftax-Symbolic-AutoReset-v1
# step: 1e6
# action_repeat: 1
# envs: 1
# train_ratio: 512
# video_pred_log: false
# dyn_hidden: 256
# dyn_deter: 256
# dyn_stoch: 24
# dyn_discrete: 24
# # note: depth is cnn hidden_dim
# encoder: {cnn_keys: 'state_map', cnn_depth: 32, kernel_size: 4, minres: 2, mlp_keys: "state_inventory", mlp_layers: 2, mlp_units: 16}
# decoder: {cnn_keys: 'state_map', cnn_depth: 32, kernel_size: 4, minres: 2, mlp_keys: "state_inventory", mlp_layers: 2, mlp_units: 16}
# actor: {layers: 3, dist: 'onehot', std: 'none'}
# value: {layers: 3}
# # note: units is the head hidden_dim
# units: 256
# reward_head: {layers: 3}
# cont_head: {layers: 3}
# imag_gradient: 'reinforce'
# batch_size: 256
# batch_length: 32
# time_limit: 4000
craftax_smaller:
task: craftax_Craftax-Symbolic-AutoReset-v1
@@ -195,10 +195,11 @@ craftax_smaller:
reward_head: {layers: 2}
cont_head: {layers: 2}
imag_gradient: 'reinforce'
batch_size: 256
batch_length: 32
batch_size: 128
batch_length: 64
time_limit: 4000
dataset_size: 20_000
precision: 16
atari100k:
steps: 4e5
+3 -2
View File
@@ -308,7 +308,7 @@ def main(config):
agent.load_state_dict(checkpoint["agent_state_dict"])
tools.recursively_load_optim_state_dict(agent, checkpoint["optims_state_dict"])
agent._should_pretrain._once = False
logger.warning(f"Loaded model from {logdir / 'latest.pt'}")
logger.warning(f"⚠️ Loaded model from {logdir / 'latest.pt'}, this could invalidate your step budget")
# make sure eval will be executed once after config.steps
with tqdm(total=config.steps + config.eval_every, unit='step', mininterval=60) as pbar:
@@ -360,7 +360,7 @@ def main(config):
def parse_args(argv=None):
# first load config name as arg from command line
parser = argparse.ArgumentParser()
parser.add_argument("--configs", nargs="+")
parser.add_argument("--configs", nargs="+", help="one or more config files")
if argv is None:
argv = sys.argv
args, remaining = parser.parse_known_args(argv[1:])
@@ -391,6 +391,7 @@ def parse_args(argv=None):
for key, value in sorted(defaults.items(), key=lambda x: x[0]):
arg_type = tools.args_type(value)
parser.add_argument(f"--{key}", type=arg_type, default=arg_type(value))
parser.print_usage()
args = parser.parse_args(remaining)
logger.info(f"config={args}")
return args
Generated
+22 -1
View File
@@ -3578,6 +3578,27 @@ type = "legacy"
url = "https://download.pytorch.org/whl/cu121"
reference = "pytorch"
[[package]]
name = "torch-tb-profiler"
version = "0.4.3"
description = "PyTorch Profiler TensorBoard Plugin"
optional = false
python-versions = ">=3.6.2"
files = [
{file = "torch_tb_profiler-0.4.3-py3-none-any.whl", hash = "sha256:207a49b05572dd983e4ab29eb5e0fcadd60374a8f93c78ec638217e8d18788dc"},
{file = "torch_tb_profiler-0.4.3.tar.gz", hash = "sha256:8b8d29b2de960b3c4423087b23cec29beaf9ac3a8c7b046c18fd25b218f726b1"},
]
[package.dependencies]
pandas = ">=1.0.0"
tensorboard = ">=1.15,<2.1.0 || >2.1.0"
[package.extras]
blob = ["azure-storage-blob"]
gs = ["google-cloud-storage"]
hdfs = ["fsspec", "pyarrow"]
s3 = ["boto3"]
[[package]]
name = "torchinfo"
version = "1.8.0"
@@ -3847,4 +3868,4 @@ test = ["big-O", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more-it
[metadata]
lock-version = "2.0"
python-versions = "^3.9"
content-hash = "0275da73363d94f6a5cdadc9662c1b254ef50310aabce3aa663552aa4802b001"
content-hash = "d33d68488789a4b8e6b23f5ee89e706b6fbb5efd73d56404d2bc2883da6157fb"
+1
View File
@@ -39,6 +39,7 @@ craftax = {path = "/media/wassname/SGIronWolf/projects5/2024/Craftax", develop =
# craftax = {git = "https://github.com/wassname/Craftax" , develop = true }
chex = "^0.1.86"
torchinfo = "^1.8.0"
torch-tb-profiler = "^0.4.3"
[tool.poetry.group.dev.dependencies]
ipywidgets = "^8.1.3"