mirror of
https://github.com/wassname/dreamerv3-torch.git
synced 2026-06-27 16:15:31 +08:00
runs
This commit is contained in:
+27
-26
@@ -149,30 +149,30 @@ craftax:
|
||||
imag_gradient: 'reinforce'
|
||||
time_limit: 4000
|
||||
|
||||
craftax_small:
|
||||
task: craftax_Craftax-Symbolic-AutoReset-v1
|
||||
step: 1e6
|
||||
action_repeat: 1
|
||||
envs: 1
|
||||
train_ratio: 512
|
||||
video_pred_log: false
|
||||
dyn_hidden: 256
|
||||
dyn_deter: 256
|
||||
dyn_stoch: 24
|
||||
dyn_discrete: 24
|
||||
# note: depth is cnn hidden_dim
|
||||
encoder: {cnn_keys: 'state_map', cnn_depth: 32, kernel_size: 4, minres: 2, mlp_keys: "state_inventory", mlp_layers: 2, mlp_units: 16}
|
||||
decoder: {cnn_keys: 'state_map', cnn_depth: 32, kernel_size: 4, minres: 2, mlp_keys: "state_inventory", mlp_layers: 2, mlp_units: 16}
|
||||
actor: {layers: 3, dist: 'onehot', std: 'none'}
|
||||
value: {layers: 3}
|
||||
# note: units is the head hidden_dim
|
||||
units: 256
|
||||
reward_head: {layers: 3}
|
||||
cont_head: {layers: 3}
|
||||
imag_gradient: 'reinforce'
|
||||
batch_size: 256
|
||||
batch_length: 32
|
||||
time_limit: 4000
|
||||
# craftax_small:
|
||||
# task: craftax_Craftax-Symbolic-AutoReset-v1
|
||||
# step: 1e6
|
||||
# action_repeat: 1
|
||||
# envs: 1
|
||||
# train_ratio: 512
|
||||
# video_pred_log: false
|
||||
# dyn_hidden: 256
|
||||
# dyn_deter: 256
|
||||
# dyn_stoch: 24
|
||||
# dyn_discrete: 24
|
||||
# # note: depth is cnn hidden_dim
|
||||
# encoder: {cnn_keys: 'state_map', cnn_depth: 32, kernel_size: 4, minres: 2, mlp_keys: "state_inventory", mlp_layers: 2, mlp_units: 16}
|
||||
# decoder: {cnn_keys: 'state_map', cnn_depth: 32, kernel_size: 4, minres: 2, mlp_keys: "state_inventory", mlp_layers: 2, mlp_units: 16}
|
||||
# actor: {layers: 3, dist: 'onehot', std: 'none'}
|
||||
# value: {layers: 3}
|
||||
# # note: units is the head hidden_dim
|
||||
# units: 256
|
||||
# reward_head: {layers: 3}
|
||||
# cont_head: {layers: 3}
|
||||
# imag_gradient: 'reinforce'
|
||||
# batch_size: 256
|
||||
# batch_length: 32
|
||||
# time_limit: 4000
|
||||
|
||||
craftax_smaller:
|
||||
task: craftax_Craftax-Symbolic-AutoReset-v1
|
||||
@@ -195,10 +195,11 @@ craftax_smaller:
|
||||
reward_head: {layers: 2}
|
||||
cont_head: {layers: 2}
|
||||
imag_gradient: 'reinforce'
|
||||
batch_size: 256
|
||||
batch_length: 32
|
||||
batch_size: 128
|
||||
batch_length: 64
|
||||
time_limit: 4000
|
||||
dataset_size: 20_000
|
||||
precision: 16
|
||||
|
||||
atari100k:
|
||||
steps: 4e5
|
||||
|
||||
+3
-2
@@ -308,7 +308,7 @@ def main(config):
|
||||
agent.load_state_dict(checkpoint["agent_state_dict"])
|
||||
tools.recursively_load_optim_state_dict(agent, checkpoint["optims_state_dict"])
|
||||
agent._should_pretrain._once = False
|
||||
logger.warning(f"Loaded model from {logdir / 'latest.pt'}")
|
||||
logger.warning(f"⚠️ Loaded model from {logdir / 'latest.pt'}, this could invalidate your step budget")
|
||||
|
||||
# make sure eval will be executed once after config.steps
|
||||
with tqdm(total=config.steps + config.eval_every, unit='step', mininterval=60) as pbar:
|
||||
@@ -360,7 +360,7 @@ def main(config):
|
||||
def parse_args(argv=None):
|
||||
# first load config name as arg from command line
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--configs", nargs="+")
|
||||
parser.add_argument("--configs", nargs="+", help="one or more config files")
|
||||
if argv is None:
|
||||
argv = sys.argv
|
||||
args, remaining = parser.parse_known_args(argv[1:])
|
||||
@@ -391,6 +391,7 @@ def parse_args(argv=None):
|
||||
for key, value in sorted(defaults.items(), key=lambda x: x[0]):
|
||||
arg_type = tools.args_type(value)
|
||||
parser.add_argument(f"--{key}", type=arg_type, default=arg_type(value))
|
||||
parser.print_usage()
|
||||
args = parser.parse_args(remaining)
|
||||
logger.info(f"config={args}")
|
||||
return args
|
||||
|
||||
Generated
+22
-1
@@ -3578,6 +3578,27 @@ type = "legacy"
|
||||
url = "https://download.pytorch.org/whl/cu121"
|
||||
reference = "pytorch"
|
||||
|
||||
[[package]]
|
||||
name = "torch-tb-profiler"
|
||||
version = "0.4.3"
|
||||
description = "PyTorch Profiler TensorBoard Plugin"
|
||||
optional = false
|
||||
python-versions = ">=3.6.2"
|
||||
files = [
|
||||
{file = "torch_tb_profiler-0.4.3-py3-none-any.whl", hash = "sha256:207a49b05572dd983e4ab29eb5e0fcadd60374a8f93c78ec638217e8d18788dc"},
|
||||
{file = "torch_tb_profiler-0.4.3.tar.gz", hash = "sha256:8b8d29b2de960b3c4423087b23cec29beaf9ac3a8c7b046c18fd25b218f726b1"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
pandas = ">=1.0.0"
|
||||
tensorboard = ">=1.15,<2.1.0 || >2.1.0"
|
||||
|
||||
[package.extras]
|
||||
blob = ["azure-storage-blob"]
|
||||
gs = ["google-cloud-storage"]
|
||||
hdfs = ["fsspec", "pyarrow"]
|
||||
s3 = ["boto3"]
|
||||
|
||||
[[package]]
|
||||
name = "torchinfo"
|
||||
version = "1.8.0"
|
||||
@@ -3847,4 +3868,4 @@ test = ["big-O", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more-it
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.9"
|
||||
content-hash = "0275da73363d94f6a5cdadc9662c1b254ef50310aabce3aa663552aa4802b001"
|
||||
content-hash = "d33d68488789a4b8e6b23f5ee89e706b6fbb5efd73d56404d2bc2883da6157fb"
|
||||
|
||||
@@ -39,6 +39,7 @@ craftax = {path = "/media/wassname/SGIronWolf/projects5/2024/Craftax", develop =
|
||||
# craftax = {git = "https://github.com/wassname/Craftax" , develop = true }
|
||||
chex = "^0.1.86"
|
||||
torchinfo = "^1.8.0"
|
||||
torch-tb-profiler = "^0.4.3"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
ipywidgets = "^8.1.3"
|
||||
|
||||
Reference in New Issue
Block a user