diff --git a/configs.yaml b/configs.yaml index 1d0149a..817d62d 100644 --- a/configs.yaml +++ b/configs.yaml @@ -149,30 +149,30 @@ craftax: imag_gradient: 'reinforce' time_limit: 4000 -craftax_small: - task: craftax_Craftax-Symbolic-AutoReset-v1 - step: 1e6 - action_repeat: 1 - envs: 1 - train_ratio: 512 - video_pred_log: false - dyn_hidden: 256 - dyn_deter: 256 - dyn_stoch: 24 - dyn_discrete: 24 - # note: depth is cnn hidden_dim - encoder: {cnn_keys: 'state_map', cnn_depth: 32, kernel_size: 4, minres: 2, mlp_keys: "state_inventory", mlp_layers: 2, mlp_units: 16} - decoder: {cnn_keys: 'state_map', cnn_depth: 32, kernel_size: 4, minres: 2, mlp_keys: "state_inventory", mlp_layers: 2, mlp_units: 16} - actor: {layers: 3, dist: 'onehot', std: 'none'} - value: {layers: 3} - # note units is the head hidden_dim - units: 256 - reward_head: {layers: 3} - cont_head: {layers: 3} - imag_gradient: 'reinforce' - batch_size: 256 - batch_length: 32 - time_limit: 4000 +# craftax_small: +# task: craftax_Craftax-Symbolic-AutoReset-v1 +# step: 1e6 +# action_repeat: 1 +# envs: 1 +# train_ratio: 512 +# video_pred_log: false +# dyn_hidden: 256 +# dyn_deter: 256 +# dyn_stoch: 24 +# dyn_discrete: 24 +# # note: depth is cnn hidden_dim +# encoder: {cnn_keys: 'state_map', cnn_depth: 32, kernel_size: 4, minres: 2, mlp_keys: "state_inventory", mlp_layers: 2, mlp_units: 16} +# decoder: {cnn_keys: 'state_map', cnn_depth: 32, kernel_size: 4, minres: 2, mlp_keys: "state_inventory", mlp_layers: 2, mlp_units: 16} +# actor: {layers: 3, dist: 'onehot', std: 'none'} +# value: {layers: 3} +# # note units is the head hidden_dim +# units: 256 +# reward_head: {layers: 3} +# cont_head: {layers: 3} +# imag_gradient: 'reinforce' +# batch_size: 256 +# batch_length: 32 +# time_limit: 4000 craftax_smaller: task: craftax_Craftax-Symbolic-AutoReset-v1 @@ -195,10 +195,11 @@ craftax_smaller: reward_head: {layers: 2} cont_head: {layers: 2} imag_gradient: 'reinforce' - batch_size: 256 - batch_length: 32 + batch_size: 128 + batch_length: 64 time_limit: 4000 dataset_size: 20_000 + precision: 16 atari100k: steps: 4e5 diff --git a/dreamer.py b/dreamer.py index bbc4d8f..0431121 100644 --- a/dreamer.py +++ b/dreamer.py @@ -308,7 +308,7 @@ def main(config): agent.load_state_dict(checkpoint["agent_state_dict"]) tools.recursively_load_optim_state_dict(agent, checkpoint["optims_state_dict"]) agent._should_pretrain._once = False - logger.warning(f"Loaded model from {logdir / 'latest.pt'}") + logger.warning(f"⚠️ Loaded model from {logdir / 'latest.pt'}, this could invalidate your step budget") # make sure eval will be executed once after config.steps with tqdm(total=config.steps + config.eval_every, unit='step', mininterval=60) as pbar: @@ -360,7 +360,7 @@ def main(config): def parse_args(argv=None): # first load config name as arg from command line parser = argparse.ArgumentParser() - parser.add_argument("--configs", nargs="+") + parser.add_argument("--configs", nargs="+", help="one or more config files") if argv is None: argv = sys.argv args, remaining = parser.parse_known_args(argv[1:]) @@ -391,6 +391,7 @@ def parse_args(argv=None): for key, value in sorted(defaults.items(), key=lambda x: x[0]): arg_type = tools.args_type(value) parser.add_argument(f"--{key}", type=arg_type, default=arg_type(value)) + parser.print_usage() args = parser.parse_args(remaining) logger.info(f"config={args}") return args diff --git a/poetry.lock b/poetry.lock index 7cee8d2..9af5adf 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3578,6 +3578,27 @@ type = "legacy" url = "https://download.pytorch.org/whl/cu121" reference = "pytorch" +[[package]] +name = "torch-tb-profiler" +version = "0.4.3" +description = "PyTorch Profiler TensorBoard Plugin" +optional = false +python-versions = ">=3.6.2" +files = [ + {file = "torch_tb_profiler-0.4.3-py3-none-any.whl", hash = "sha256:207a49b05572dd983e4ab29eb5e0fcadd60374a8f93c78ec638217e8d18788dc"}, + {file = "torch_tb_profiler-0.4.3.tar.gz", hash = "sha256:8b8d29b2de960b3c4423087b23cec29beaf9ac3a8c7b046c18fd25b218f726b1"}, +] + +[package.dependencies] +pandas = ">=1.0.0" +tensorboard = ">=1.15,<2.1.0 || >2.1.0" + +[package.extras] +blob = ["azure-storage-blob"] +gs = ["google-cloud-storage"] +hdfs = ["fsspec", "pyarrow"] +s3 = ["boto3"] + [[package]] name = "torchinfo" version = "1.8.0" @@ -3847,4 +3868,4 @@ test = ["big-O", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more-it [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "0275da73363d94f6a5cdadc9662c1b254ef50310aabce3aa663552aa4802b001" +content-hash = "d33d68488789a4b8e6b23f5ee89e706b6fbb5efd73d56404d2bc2883da6157fb" diff --git a/pyproject.toml b/pyproject.toml index 5af5591..67efe48 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,6 +39,7 @@ craftax = {path = "/media/wassname/SGIronWolf/projects5/2024/Craftax", develop = # craftax = {git = "https://github.com/wassname/Craftax" , develop = true } chex = "^0.1.86" torchinfo = "^1.8.0" +torch-tb-profiler = "^0.4.3" [tool.poetry.group.dev.dependencies] ipywidgets = "^8.1.3"