From 718e92a9a11d1480d22a5fd1967b135e3d68999a Mon Sep 17 00:00:00 2001 From: wassname Date: Fri, 7 Jun 2024 06:00:27 +0800 Subject: [PATCH] fix mem overflow, torchinfo --- configs.yaml | 6 +- dreamer.py | 8 +- justfile | 2 +- models.py | 6 +- nbs/02_torchinfo copy.ipynb | 1295 ++++++++++++++++++++++++++--------- nbs/load_runs.ipynb | 142 ++++ networks.py | 8 + tools.py | 14 +- 8 files changed, 1145 insertions(+), 336 deletions(-) create mode 100644 nbs/load_runs.ipynb diff --git a/configs.yaml b/configs.yaml index 50ee679..1d0149a 100644 --- a/configs.yaml +++ b/configs.yaml @@ -69,7 +69,7 @@ defaults: model_lr: 1e-4 opt_eps: 1e-8 grad_clip: 1000 - dataset_size: 1000000 + dataset_size: 1_000_000 opt: 'adam' # Behavior. @@ -147,6 +147,7 @@ craftax: reward_head: {layers: 5} cont_head: {layers: 5} imag_gradient: 'reinforce' + time_limit: 4000 craftax_small: task: craftax_Craftax-Symbolic-AutoReset-v1 @@ -171,6 +172,7 @@ craftax_small: imag_gradient: 'reinforce' batch_size: 256 batch_length: 32 + time_limit: 4000 craftax_smaller: task: craftax_Craftax-Symbolic-AutoReset-v1 @@ -195,6 +197,8 @@ craftax_smaller: imag_gradient: 'reinforce' batch_size: 256 batch_length: 32 + time_limit: 4000 + dataset_size: 20_000 atari100k: steps: 4e5 diff --git a/dreamer.py b/dreamer.py index 79204a3..0ffc536 100644 --- a/dreamer.py +++ b/dreamer.py @@ -350,7 +350,7 @@ def main(config): } torch.save(items_to_save, logdir / "latest.pt") logger.info(f"Saved model to {logdir / 'latest.pt'}") - # pbar.update(agent._step-pbar.n) # 16858 at a time + pbar.update(agent._step-pbar.n) # 16858 at a time for env in train_envs + eval_envs: try: env.close() @@ -358,6 +358,7 @@ def main(config): pass def parse_args(argv=None): + # first load config name as arg from command line parser = argparse.ArgumentParser() parser.add_argument("--configs", nargs="+") if argv is None: @@ -382,11 +383,16 @@ def parse_args(argv=None): for name in name_list: recursive_update(defaults, configs[name]) + # defaults = {k:tools.args_type(v)(v) for k, v in defaults.items()} + # config = argparse.Namespace(**defaults) + + # now use argparse to parse config, allowing us to override config with any extra args from cli. You can even use -h parser = argparse.ArgumentParser() for key, value in sorted(defaults.items(), key=lambda x: x[0]): arg_type = tools.args_type(value) parser.add_argument(f"--{key}", type=arg_type, default=arg_type(value)) args = parser.parse_args(remaining) + logger.info(f"config={args}") return args if __name__ == "__main__": diff --git a/justfile b/justfile index c6c8b0a..e441114 100644 --- a/justfile +++ b/justfile @@ -6,7 +6,7 @@ export TQDM_MININTERVAL := "30" main: . ./.venv/bin/activate - python dreamer.py --configs craftax_small --logdir ./logdir/crafter + python dreamer.py --configs craftax_smaller --logdir ./logdir/crafterer logs: tensorboard --logdir logdir/craftax diff --git a/models.py b/models.py index b7dccb2..84b3f52 100644 --- a/models.py +++ b/models.py @@ -5,7 +5,7 @@ from torch import nn import networks import tools from loguru import logger - +from torchinfo import summary from envs.craftax_env import state2img to_np = lambda x: x.detach().cpu().numpy() @@ -99,9 +99,7 @@ class WorldModel(nn.Module): opt=config.opt, use_amp=self._use_amp, ) - logger.info( - f"Optimizer model_opt has {sum(param.numel() for param in self.parameters())} variables." - ) + logger.info(f"World Model\n{summary(self, row_settings=['var_names'],)}") # other losses are scaled by 1.0. self._scales = dict( reward=config.reward_head["loss_scale"], diff --git a/nbs/02_torchinfo copy.ipynb b/nbs/02_torchinfo copy.ipynb index 5886a8c..3159920 100644 --- a/nbs/02_torchinfo copy.ipynb +++ b/nbs/02_torchinfo copy.ipynb @@ -9,7 +9,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -19,7 +19,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "metadata": {}, "outputs": [ { @@ -48,30 +48,31 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "['../dreamer.py', '--configs', 'craftax_small2', '--logdir', '../logdir/craftax_small2']\n" + "['../dreamer.py', '--configs', 'craftax_smaller', '--logdir', '../logdir/craftax_smaller']\n", + "\u001b[32m2024-06-07 05:25:00.098\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mdreamer\u001b[0m:\u001b[36mparse_args\u001b[0m:\u001b[36m392\u001b[0m - \u001b[1margs=Namespace(act='SiLU', action_repeat=1, actor={'layers': 2, 'dist': 'onehot', 'entropy': 0.0003, 'unimix_ratio': 0.01, 'std': 'none', 'min_std': 0.1, 'max_std': 1.0, 'temp': 0.1, 'lr': 3e-05, 'eps': 1e-05, 'grad_clip': 100.0, 'outscale': 1.0}, batch_length=32, batch_size=256, compile=True, cont_head={'layers': 2, 'loss_scale': 1.0, 'outscale': 1.0}, critic={'layers': 2, 'dist': 'symlog_disc', 'slow_target': True, 'slow_target_update': 1, 'slow_target_fraction': 0.02, 'lr': 3e-05, 'eps': 1e-05, 'grad_clip': 100.0, 'outscale': 0.0}, dataset_size=20000, debug=False, decoder={'mlp_keys': 'state_inventory', 'cnn_keys': 'state_map', 'act': 'SiLU', 'norm': True, 'cnn_depth': 16, 'kernel_size': 4, 'minres': 2, 'mlp_layers': 2, 'mlp_units': 8, 'cnn_sigmoid': False, 'image_dist': 'mse', 'vector_dist': 'symlog_mse', 'outscale': 1.0}, deterministic_run=False, device='cuda:0', disag_action_cond=False, disag_layers=4, disag_log=True, disag_models=10, disag_offset=1, disag_target='stoch', disag_units=400, discount=0.997, discount_lambda=0.95, dyn_deter=128, dyn_discrete=16, dyn_hidden=128, dyn_mean_act='none', dyn_min_std=0.1, dyn_rec_depth=1, dyn_scale=0.5, dyn_std_act='sigmoid2', dyn_stoch=16, encoder={'mlp_keys': 'state_inventory', 'cnn_keys': 'state_map', 'act': 'SiLU', 'norm': True, 'cnn_depth': 16, 'kernel_size': 4, 'minres': 2, 'mlp_layers': 2, 'mlp_units': 8, 'symlog_inputs': True}, envs=1, eval_episode_num=10, eval_every=10000.0, eval_state_mean=False, evaldir=None, expl_behavior='greedy', expl_extr_scale=0.0, expl_intr_scale=1.0, expl_until=0, grad_clip=1000, grad_heads=('decoder', 'reward', 'cont'), grayscale=False, imag_gradient='reinforce', imag_gradient_mix=0.0, imag_horizon=15, initial='learned', kl_free=1.0, log_every=10000.0, logdir='../logdir/craftax_smaller', model_lr=0.0001, norm=True, offline_evaldir='', offline_traindir='', opt='adam', opt_eps=1e-08, parallel=False, precision=32, prefill=2500, pretrain=100, rep_scale=0.1, reset_every=0, reward_EMA=True, reward_head={'layers': 2, 'dist': 'symlog_disc', 'loss_scale': 1.0, 'outscale': 0.0}, seed=0, size=(64, 64), step=1000000.0, steps=1000000.0, task='craftax_Craftax-Symbolic-AutoReset-v1', time_limit=4000, train_ratio=512, traindir=None, unimix_ratio=0.01, units=128, value={'layers': 2}, video_pred_log=False, weight_decay=0.0)\u001b[0m\n" ] }, { "data": { "text/plain": [ - "Namespace(act='SiLU', action_repeat=1, actor={'layers': 3, 'dist': 'onehot', 'entropy': 0.0003, 'unimix_ratio': 0.01, 'std': 'none', 'min_std': 0.1, 'max_std': 1.0, 'temp': 0.1, 'lr': 3e-05, 'eps': 1e-05, 'grad_clip': 100.0, 'outscale': 1.0}, batch_length=32, batch_size=256, compile=True, cont_head={'layers': 3, 'loss_scale': 1.0, 'outscale': 1.0}, critic={'layers': 2, 'dist': 'symlog_disc', 'slow_target': True, 'slow_target_update': 1, 'slow_target_fraction': 0.02, 'lr': 3e-05, 'eps': 1e-05, 'grad_clip': 100.0, 'outscale': 0.0}, dataset_size=1000000, debug=False, decoder={'mlp_keys': 'state_inventory', 'cnn_keys': 'state_map', 'act': 'SiLU', 'norm': True, 'cnn_depth': 32, 'kernel_size': 4, 'minres': 2, 'mlp_layers': 2, 'mlp_units': 16, 'cnn_sigmoid': False, 'image_dist': 'mse', 'vector_dist': 'symlog_mse', 'outscale': 1.0}, deterministic_run=False, device='cuda:0', disag_action_cond=False, disag_layers=4, disag_log=True, disag_models=10, disag_offset=1, disag_target='stoch', disag_units=400, discount=0.997, discount_lambda=0.95, dyn_deter=256, dyn_discrete=24, dyn_hidden=256, dyn_mean_act='none', dyn_min_std=0.1, dyn_rec_depth=1, dyn_scale=0.5, dyn_std_act='sigmoid2', dyn_stoch=24, encoder={'mlp_keys': 'state_inventory', 'cnn_keys': 'state_map', 'act': 'SiLU', 'norm': True, 'cnn_depth': 32, 'kernel_size': 4, 'minres': 2, 'mlp_layers': 2, 'mlp_units': 16, 'symlog_inputs': True}, envs=1, eval_episode_num=10, eval_every=10000.0, eval_state_mean=False, evaldir=None, expl_behavior='greedy', expl_extr_scale=0.0, expl_intr_scale=1.0, expl_until=0, grad_clip=1000, grad_heads=('decoder', 'reward', 'cont'), grayscale=False, imag_gradient='reinforce', imag_gradient_mix=0.0, imag_horizon=15, initial='learned', kl_free=1.0, log_every=10000.0, logdir='../logdir/craftax_small2', model_lr=0.0001, norm=True, offline_evaldir='', offline_traindir='', opt='adam', opt_eps=1e-08, parallel=False, precision=32, prefill=2500, pretrain=100, rep_scale=0.1, reset_every=0, reward_EMA=True, reward_head={'layers': 3, 'dist': 'symlog_disc', 'loss_scale': 1.0, 'outscale': 0.0}, seed=0, size=(64, 64), step=1000000.0, steps=1000000.0, task='craftax_Craftax-Symbolic-AutoReset-v1', time_limit=1000, train_ratio=512, traindir=None, unimix_ratio=0.01, units=256, value={'layers': 3}, video_pred_log=False, weight_decay=0.0)" + "Namespace(act='SiLU', action_repeat=1, actor={'layers': 2, 'dist': 'onehot', 'entropy': 0.0003, 'unimix_ratio': 0.01, 'std': 'none', 'min_std': 0.1, 'max_std': 1.0, 'temp': 0.1, 'lr': 3e-05, 'eps': 1e-05, 'grad_clip': 100.0, 'outscale': 1.0}, batch_length=32, batch_size=256, compile=True, cont_head={'layers': 2, 'loss_scale': 1.0, 'outscale': 1.0}, critic={'layers': 2, 'dist': 'symlog_disc', 'slow_target': True, 'slow_target_update': 1, 'slow_target_fraction': 0.02, 'lr': 3e-05, 'eps': 1e-05, 'grad_clip': 100.0, 'outscale': 0.0}, dataset_size=20000, debug=False, decoder={'mlp_keys': 'state_inventory', 'cnn_keys': 'state_map', 'act': 'SiLU', 'norm': True, 'cnn_depth': 16, 'kernel_size': 4, 'minres': 2, 'mlp_layers': 2, 'mlp_units': 8, 'cnn_sigmoid': False, 'image_dist': 'mse', 'vector_dist': 'symlog_mse', 'outscale': 1.0}, deterministic_run=False, device='cuda:0', disag_action_cond=False, disag_layers=4, disag_log=True, disag_models=10, disag_offset=1, disag_target='stoch', disag_units=400, discount=0.997, discount_lambda=0.95, dyn_deter=128, dyn_discrete=16, dyn_hidden=128, dyn_mean_act='none', dyn_min_std=0.1, dyn_rec_depth=1, dyn_scale=0.5, dyn_std_act='sigmoid2', dyn_stoch=16, encoder={'mlp_keys': 'state_inventory', 'cnn_keys': 'state_map', 'act': 'SiLU', 'norm': True, 'cnn_depth': 16, 'kernel_size': 4, 'minres': 2, 'mlp_layers': 2, 'mlp_units': 8, 'symlog_inputs': True}, envs=1, eval_episode_num=10, eval_every=10000.0, eval_state_mean=False, evaldir=None, expl_behavior='greedy', expl_extr_scale=0.0, expl_intr_scale=1.0, expl_until=0, grad_clip=1000, grad_heads=('decoder', 'reward', 'cont'), grayscale=False, imag_gradient='reinforce', imag_gradient_mix=0.0, imag_horizon=15, initial='learned', kl_free=1.0, log_every=10000.0, logdir='../logdir/craftax_smaller', model_lr=0.0001, norm=True, offline_evaldir='', offline_traindir='', opt='adam', opt_eps=1e-08, parallel=False, precision=32, prefill=2500, pretrain=100, rep_scale=0.1, reset_every=0, reward_EMA=True, reward_head={'layers': 2, 'dist': 'symlog_disc', 'loss_scale': 1.0, 'outscale': 0.0}, seed=0, size=(64, 64), step=1000000.0, steps=1000000.0, task='craftax_Craftax-Symbolic-AutoReset-v1', time_limit=4000, train_ratio=512, traindir=None, unimix_ratio=0.01, units=128, value={'layers': 2}, video_pred_log=False, weight_decay=0.0)" ] }, - "execution_count": 4, + "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# emulate cli\n", - "argv = f\"../dreamer.py --configs craftax_small2 --logdir ../logdir/craftax_small2\"\n", + "argv = f\"../dreamer.py --configs craftax_smaller --logdir ../logdir/craftax_smaller\"\n", "argv = argv.split()\n", "print(argv)\n", "config = parse_args(argv)\n", @@ -80,19 +81,623 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_20867/1136853903.py:1: UnsafeLoaderWarning: \n", + "The default 'Loader' for 'load(stream)' without further arguments can be unsafe.\n", + "Use 'load(stream, Loader=ruamel.yaml.Loader)' explicitly if that is OK.\n", + "Alternatively include the following in your code:\n", + "\n", + " import warnings\n", + " warnings.simplefilter('ignore', ruamel.yaml.error.UnsafeLoaderWarning)\n", + "\n", + "In most other cases you should consider using 'safe_load(stream)'\n", + " yaml.load(\n" + ] + }, + { + "data": { + "text/plain": [ + "{'defaults': {'logdir': None,\n", + " 'traindir': None,\n", + " 'evaldir': None,\n", + " 'offline_traindir': '',\n", + " 'offline_evaldir': '',\n", + " 'seed': 0,\n", + " 'deterministic_run': False,\n", + " 'steps': 1000000.0,\n", + " 'parallel': False,\n", + " 'eval_every': 10000.0,\n", + " 'eval_episode_num': 10,\n", + " 'log_every': 10000.0,\n", + " 'reset_every': 0,\n", + " 'device': 'cuda:0',\n", + " 'compile': True,\n", + " 'precision': 32,\n", + " 'debug': False,\n", + " 'video_pred_log': False,\n", + " 'task': 'dmc_walker_walk',\n", + " 'size': [64, 64],\n", + " 'envs': 1,\n", + " 'action_repeat': 2,\n", + " 'time_limit': 1000,\n", + " 'grayscale': False,\n", + " 'prefill': 2500,\n", + " 'reward_EMA': True,\n", + " 'dyn_hidden': 512,\n", + " 'dyn_deter': 512,\n", + " 'dyn_stoch': 32,\n", + " 'dyn_discrete': 32,\n", + " 'dyn_rec_depth': 1,\n", + " 'dyn_mean_act': 'none',\n", + " 'dyn_std_act': 'sigmoid2',\n", + " 'dyn_min_std': 0.1,\n", + " 'grad_heads': ['decoder', 'reward', 'cont'],\n", + " 'units': 512,\n", + " 'act': 'SiLU',\n", + " 'norm': True,\n", + " 'encoder': {'mlp_keys': '$^',\n", + " 'cnn_keys': 'image',\n", + " 'act': 'SiLU',\n", + " 'norm': True,\n", + " 'cnn_depth': 32,\n", + " 'kernel_size': 4,\n", + " 'minres': 4,\n", + " 'mlp_layers': 5,\n", + " 'mlp_units': 1024,\n", + " 'symlog_inputs': True},\n", + " 'decoder': {'mlp_keys': '$^',\n", + " 'cnn_keys': 'image',\n", + " 'act': 'SiLU',\n", + " 'norm': True,\n", + " 'cnn_depth': 32,\n", + " 'kernel_size': 4,\n", + " 'minres': 4,\n", + " 'mlp_layers': 5,\n", + " 'mlp_units': 1024,\n", + " 'cnn_sigmoid': False,\n", + " 'image_dist': 'mse',\n", + " 'vector_dist': 'symlog_mse',\n", + " 'outscale': 1.0},\n", + " 'actor': {'layers': 2,\n", + " 'dist': 'normal',\n", + " 'entropy': 0.0003,\n", + " 'unimix_ratio': 0.01,\n", + " 'std': 'learned',\n", + " 'min_std': 0.1,\n", + " 'max_std': 1.0,\n", + " 'temp': 0.1,\n", + " 'lr': 3e-05,\n", + " 'eps': 1e-05,\n", + " 'grad_clip': 100.0,\n", + " 'outscale': 1.0},\n", + " 'critic': {'layers': 2,\n", + " 'dist': 'symlog_disc',\n", + " 'slow_target': True,\n", + " 'slow_target_update': 1,\n", + " 'slow_target_fraction': 0.02,\n", + " 'lr': 3e-05,\n", + " 'eps': 1e-05,\n", + " 'grad_clip': 100.0,\n", + " 'outscale': 0.0},\n", + " 'reward_head': {'layers': 2,\n", + " 'dist': 'symlog_disc',\n", + " 'loss_scale': 1.0,\n", + " 'outscale': 0.0},\n", + " 'cont_head': {'layers': 2, 'loss_scale': 1.0, 'outscale': 1.0},\n", + " 'dyn_scale': 0.5,\n", + " 'rep_scale': 0.1,\n", + " 'kl_free': 1.0,\n", + " 'weight_decay': 0.0,\n", + " 'unimix_ratio': 0.01,\n", + " 'initial': 'learned',\n", + " 'batch_size': 64,\n", + " 'batch_length': 64,\n", + " 'train_ratio': 512,\n", + " 'pretrain': 100,\n", + " 'model_lr': 0.0001,\n", + " 'opt_eps': 1e-08,\n", + " 'grad_clip': 1000,\n", + " 'dataset_size': 1000000,\n", + " 'opt': 'adam',\n", + " 'discount': 0.997,\n", + " 'discount_lambda': 0.95,\n", + " 'imag_horizon': 15,\n", + " 'imag_gradient': 'dynamics',\n", + " 'imag_gradient_mix': 0.0,\n", + " 'eval_state_mean': False,\n", + " 'expl_behavior': 'greedy',\n", + " 'expl_until': 0,\n", + " 'expl_extr_scale': 0.0,\n", + " 'expl_intr_scale': 1.0,\n", + " 'disag_target': 'stoch',\n", + " 'disag_log': True,\n", + " 'disag_models': 10,\n", + " 'disag_offset': 1,\n", + " 'disag_layers': 4,\n", + " 'disag_units': 400,\n", + " 'disag_action_cond': False},\n", + " 'dmc_proprio': {'steps': 500000.0,\n", + " 'action_repeat': 2,\n", + " 'envs': 4,\n", + " 'train_ratio': 512,\n", + " 'video_pred_log': False,\n", + " 'encoder': {'mlp_keys': '.*', 'cnn_keys': '$^'},\n", + " 'decoder': {'mlp_keys': '.*', 'cnn_keys': '$^'}},\n", + " 'dmc_vision': {'steps': 1000000.0,\n", + " 'action_repeat': 2,\n", + " 'envs': 4,\n", + " 'train_ratio': 512,\n", + " 'video_pred_log': True,\n", + " 'encoder': {'mlp_keys': '$^', 'cnn_keys': 'image'},\n", + " 'decoder': {'mlp_keys': '$^', 'cnn_keys': 'image'}},\n", + " 'crafter': {'task': 'crafter_reward',\n", + " 'step': 1000000.0,\n", + " 'action_repeat': 1,\n", + " 'envs': 1,\n", + " 'train_ratio': 512,\n", + " 'video_pred_log': False,\n", + " 'dyn_hidden': 1024,\n", + " 'dyn_deter': 4096,\n", + " 'units': 1024,\n", + " 'encoder': {'mlp_keys': '$^',\n", + " 'cnn_keys': 'image',\n", + " 'cnn_depth': 96,\n", + " 'mlp_layers': 5,\n", + " 'mlp_units': 1024},\n", + " 'decoder': {'mlp_keys': '$^',\n", + " 'cnn_keys': 'image',\n", + " 'cnn_depth': 96,\n", + " 'mlp_layers': 5,\n", + " 'mlp_units': 1024},\n", + " 'actor': {'layers': 5, 'dist': 'onehot', 'std': 'none'},\n", + " 'value': {'layers': 5},\n", + " 'reward_head': {'layers': 5},\n", + " 'cont_head': {'layers': 5},\n", + " 'imag_gradient': 'reinforce'},\n", + " 'craftax': {'task': 'craftax_Craftax-Symbolic-AutoReset-v1',\n", + " 'step': 1000000.0,\n", + " 'action_repeat': 1,\n", + " 'envs': 1,\n", + " 'train_ratio': 512,\n", + " 'video_pred_log': False,\n", + " 'dyn_hidden': 1024,\n", + " 'dyn_deter': 4096,\n", + " 'units': 1024,\n", + " 'encoder': {'cnn_keys': '$^',\n", + " 'mlp_keys': 'state',\n", + " 'mlp_layers': 4,\n", + " 'mlp_units': 512},\n", + " 'decoder': {'cnn_keys': '$^',\n", + " 'mlp_keys': 'state',\n", + " 'mlp_layers': 4,\n", + " 'mlp_units': 512},\n", + " 'actor': {'layers': 5, 'dist': 'onehot', 'std': 'none'},\n", + " 'value': {'layers': 5},\n", + " 'reward_head': {'layers': 5},\n", + " 'cont_head': {'layers': 5},\n", + " 'imag_gradient': 'reinforce',\n", + " 'time_limit': 4000},\n", + " 'craftax_small': {'task': 'craftax_Craftax-Symbolic-AutoReset-v1',\n", + " 'step': 1000000.0,\n", + " 'action_repeat': 1,\n", + " 'envs': 1,\n", + " 'train_ratio': 512,\n", + " 'video_pred_log': False,\n", + " 'dyn_hidden': 256,\n", + " 'dyn_deter': 256,\n", + " 'dyn_stoch': 24,\n", + " 'dyn_discrete': 24,\n", + " 'encoder': {'cnn_keys': 'state_map',\n", + " 'cnn_depth': 32,\n", + " 'kernel_size': 4,\n", + " 'minres': 2,\n", + " 'mlp_keys': 'state_inventory',\n", + " 'mlp_layers': 2,\n", + " 'mlp_units': 16},\n", + " 'decoder': {'cnn_keys': 'state_map',\n", + " 'cnn_depth': 32,\n", + " 'kernel_size': 4,\n", + " 'minres': 2,\n", + " 'mlp_keys': 'state_inventory',\n", + " 'mlp_layers': 2,\n", + " 'mlp_units': 16},\n", + " 'actor': {'layers': 3, 'dist': 'onehot', 'std': 'none'},\n", + " 'value': {'layers': 3},\n", + " 'units': 256,\n", + " 'reward_head': {'layers': 3},\n", + " 'cont_head': {'layers': 3},\n", + " 'imag_gradient': 'reinforce',\n", + " 'batch_size': 256,\n", + " 'batch_length': 32,\n", + " 'time_limit': 4000},\n", + " 'craftax_smaller': {'task': 'craftax_Craftax-Symbolic-AutoReset-v1',\n", + " 'step': 1000000.0,\n", + " 'action_repeat': 1,\n", + " 'envs': 1,\n", + " 'train_ratio': 512,\n", + " 'video_pred_log': False,\n", + " 'dyn_hidden': 128,\n", + " 'dyn_deter': 128,\n", + " 'dyn_stoch': 16,\n", + " 'dyn_discrete': 16,\n", + " 'encoder': {'cnn_keys': 'state_map',\n", + " 'cnn_depth': 16,\n", + " 'kernel_size': 4,\n", + " 'minres': 2,\n", + " 'mlp_keys': 'state_inventory',\n", + " 'mlp_layers': 2,\n", + " 'mlp_units': 8},\n", + " 'decoder': {'cnn_keys': 'state_map',\n", + " 'cnn_depth': 16,\n", + " 'kernel_size': 4,\n", + " 'minres': 2,\n", + " 'mlp_keys': 'state_inventory',\n", + " 'mlp_layers': 2,\n", + " 'mlp_units': 8},\n", + " 'actor': {'layers': 2, 'dist': 'onehot', 'std': 'none'},\n", + " 'value': {'layers': 2},\n", + " 'units': 128,\n", + " 'reward_head': {'layers': 2},\n", + " 'cont_head': {'layers': 2},\n", + " 'imag_gradient': 'reinforce',\n", + " 'batch_size': 256,\n", + " 'batch_length': 32,\n", + " 'time_limit': 4000,\n", + " 'dataset_size': 20000},\n", + " 'atari100k': {'steps': 400000.0,\n", + " 'envs': 1,\n", + " 'action_repeat': 4,\n", + " 'train_ratio': 1024,\n", + " 'video_pred_log': True,\n", + " 'eval_episode_num': 100,\n", + " 'actor': {'dist': 'onehot', 'std': 'none'},\n", + " 'imag_gradient': 'reinforce',\n", + " 'stickey': False,\n", + " 'lives': 'unused',\n", + " 'noops': 30,\n", + " 'resize': 'opencv',\n", + " 'actions': 'needed',\n", + " 'time_limit': 108000},\n", + " 'minecraft': {'task': 'minecraft_diamond',\n", + " 'step': 100000000.0,\n", + " 'parallel': True,\n", + " 'envs': 16,\n", + " 'eval_episode_num': 0,\n", + " 'eval_every': 10000.0,\n", + " 'action_repeat': 1,\n", + " 'train_ratio': 16,\n", + " 'video_pred_log': True,\n", + " 'dyn_hidden': 1024,\n", + " 'dyn_deter': 4096,\n", + " 'units': 1024,\n", + " 'encoder': {'mlp_keys': 'inventory|inventory_max|equipped|health|hunger|breath|obs_reward',\n", + " 'cnn_keys': 'image',\n", + " 'cnn_depth': 96,\n", + " 'mlp_layers': 5,\n", + " 'mlp_units': 1024},\n", + " 'decoder': {'mlp_keys': 'inventory|inventory_max|equipped|health|hunger|breath',\n", + " 'cnn_keys': 'image',\n", + " 'cnn_depth': 96,\n", + " 'mlp_layers': 5,\n", + " 'mlp_units': 1024},\n", + " 'actor': {'layers': 5, 'dist': 'onehot', 'std': 'none'},\n", + " 'value': {'layers': 5},\n", + " 'reward_head': {'layers': 5},\n", + " 'cont_head': {'layers': 5},\n", + " 'imag_gradient': 'reinforce',\n", + " 'break_speed': 100.0,\n", + " 'time_limit': 36000},\n", + " 'memorymaze': {'steps': 100000000.0,\n", + " 'action_repeat': 2,\n", + " 'actor': {'dist': 'onehot', 'std': 'none'},\n", + " 'imag_gradient': 'reinforce',\n", + " 'task': 'memorymaze_9x9'},\n", + " 'debug': {'debug': True,\n", + " 'pretrain': 1,\n", + " 'prefill': 1,\n", + " 'batch_size': 10,\n", + " 'batch_length': 20}}" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "yaml.safe_load(\n", + " pathlib.Path(\"/media/wassname/SGIronWolf/projects5/worldmodels/dreamerv3-torch/configs.yaml\").read_text()\n", + ")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 41, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "\u001b[32m2024-06-06 17:08:10.379\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m\u001b[0m:\u001b[36m27\u001b[0m - \u001b[1mLogdir ../logdir/craftax_small2\u001b[0m\n", - "\u001b[32m2024-06-06 17:08:10.384\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m\u001b[0m:\u001b[36m36\u001b[0m - \u001b[1mCreate envs.\u001b[0m\n", - "\u001b[32m2024-06-06 17:08:41.190\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m\u001b[0m:\u001b[36m57\u001b[0m - \u001b[1mAction Space Box(0.0, 1.0, (43,), float32)\u001b[0m\n", - "\u001b[32m2024-06-06 17:08:41.191\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m\u001b[0m:\u001b[36m63\u001b[0m - \u001b[1mPrefill dataset (26 steps).\u001b[0m\n", - "\u001b[32m2024-06-06 17:09:31.587\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m\u001b[0m:\u001b[36m92\u001b[0m - \u001b[1mLogger: (2500 steps).\u001b[0m\n", - "\u001b[32m2024-06-06 17:09:31.588\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m\u001b[0m:\u001b[36m94\u001b[0m - \u001b[1mSimulate agent.\u001b[0m\n" + "usage: ipykernel_launcher.py [-h] [--act ACT] [--action_repeat ACTION_REPEAT]\n", + " [--actor ACTOR] [--batch_length BATCH_LENGTH]\n", + " [--batch_size BATCH_SIZE] [--compile COMPILE]\n", + " [--cont_head CONT_HEAD] [--critic CRITIC]\n", + " [--dataset_size DATASET_SIZE] [--debug DEBUG]\n", + " [--decoder DECODER]\n", + " [--deterministic_run DETERMINISTIC_RUN]\n", + " [--device DEVICE]\n", + " [--disag_action_cond DISAG_ACTION_COND]\n", + " [--disag_layers DISAG_LAYERS]\n", + " [--disag_log DISAG_LOG]\n", + " [--disag_models DISAG_MODELS]\n", + " [--disag_offset DISAG_OFFSET]\n", + " [--disag_target DISAG_TARGET]\n", + " [--disag_units DISAG_UNITS] [--discount DISCOUNT]\n", + " [--discount_lambda DISCOUNT_LAMBDA]\n", + " [--dyn_deter DYN_DETER]\n", + " [--dyn_discrete DYN_DISCRETE]\n", + " [--dyn_hidden DYN_HIDDEN]\n", + " [--dyn_mean_act DYN_MEAN_ACT]\n", + " [--dyn_min_std DYN_MIN_STD]\n", + " [--dyn_rec_depth DYN_REC_DEPTH]\n", + " [--dyn_scale DYN_SCALE]\n", + " [--dyn_std_act DYN_STD_ACT]\n", + " [--dyn_stoch DYN_STOCH] [--encoder ENCODER]\n", + " [--envs ENVS]\n", + " [--eval_episode_num EVAL_EPISODE_NUM]\n", + " [--eval_every EVAL_EVERY]\n", + " [--eval_state_mean EVAL_STATE_MEAN]\n", + " [--evaldir EVALDIR]\n", + " [--expl_behavior EXPL_BEHAVIOR]\n", + " [--expl_extr_scale EXPL_EXTR_SCALE]\n", + " [--expl_intr_scale EXPL_INTR_SCALE]\n", + " [--expl_until EXPL_UNTIL] [--grad_clip GRAD_CLIP]\n", + " [--grad_heads GRAD_HEADS] [--grayscale GRAYSCALE]\n", + " [--imag_gradient IMAG_GRADIENT]\n", + " [--imag_gradient_mix IMAG_GRADIENT_MIX]\n", + " [--imag_horizon IMAG_HORIZON] [--initial INITIAL]\n", + " [--kl_free KL_FREE] [--log_every LOG_EVERY]\n", + " [--logdir LOGDIR] [--model_lr MODEL_LR]\n", + " [--norm NORM] [--offline_evaldir OFFLINE_EVALDIR]\n", + " [--offline_traindir OFFLINE_TRAINDIR] [--opt OPT]\n", + " [--opt_eps OPT_EPS] [--parallel PARALLEL]\n", + " [--precision PRECISION] [--prefill PREFILL]\n", + " [--pretrain PRETRAIN] [--rep_scale REP_SCALE]\n", + " [--reset_every RESET_EVERY]\n", + " [--reward_EMA REWARD_EMA]\n", + " [--reward_head REWARD_HEAD] [--seed SEED]\n", + " [--size SIZE] [--steps STEPS] [--task TASK]\n", + " [--time_limit TIME_LIMIT]\n", + " [--train_ratio TRAIN_RATIO] [--traindir TRAINDIR]\n", + " [--unimix_ratio UNIMIX_RATIO] [--units UNITS]\n", + " [--video_pred_log VIDEO_PRED_LOG]\n", + " [--weight_decay WEIGHT_DECAY]\n", + "\n", + "optional arguments:\n", + " -h, --help show this help message and exit\n", + " --act ACT\n", + " --action_repeat ACTION_REPEAT\n", + " --actor ACTOR\n", + " --batch_length BATCH_LENGTH\n", + " --batch_size BATCH_SIZE\n", + " --compile COMPILE\n", + " --cont_head CONT_HEAD\n", + " --critic CRITIC\n", + " --dataset_size DATASET_SIZE\n", + " --debug DEBUG\n", + " --decoder DECODER\n", + " --deterministic_run DETERMINISTIC_RUN\n", + " --device DEVICE\n", + " --disag_action_cond DISAG_ACTION_COND\n", + " --disag_layers DISAG_LAYERS\n", + " --disag_log DISAG_LOG\n", + " --disag_models DISAG_MODELS\n", + " --disag_offset DISAG_OFFSET\n", + " --disag_target DISAG_TARGET\n", + " --disag_units DISAG_UNITS\n", + " --discount DISCOUNT\n", + " --discount_lambda DISCOUNT_LAMBDA\n", + " --dyn_deter DYN_DETER\n", + " --dyn_discrete DYN_DISCRETE\n", + " --dyn_hidden DYN_HIDDEN\n", + " --dyn_mean_act DYN_MEAN_ACT\n", + " --dyn_min_std DYN_MIN_STD\n", + " --dyn_rec_depth DYN_REC_DEPTH\n", + " --dyn_scale DYN_SCALE\n", + " --dyn_std_act DYN_STD_ACT\n", + " --dyn_stoch DYN_STOCH\n", + " --encoder ENCODER\n", + " --envs ENVS\n", + " --eval_episode_num EVAL_EPISODE_NUM\n", + " --eval_every EVAL_EVERY\n", + " --eval_state_mean EVAL_STATE_MEAN\n", + " --evaldir EVALDIR\n", + " --expl_behavior EXPL_BEHAVIOR\n", + " --expl_extr_scale EXPL_EXTR_SCALE\n", + " --expl_intr_scale EXPL_INTR_SCALE\n", + " --expl_until EXPL_UNTIL\n", + " --grad_clip GRAD_CLIP\n", + " --grad_heads GRAD_HEADS\n", + " --grayscale GRAYSCALE\n", + " --imag_gradient IMAG_GRADIENT\n", + " --imag_gradient_mix IMAG_GRADIENT_MIX\n", + " --imag_horizon IMAG_HORIZON\n", + " --initial INITIAL\n", + " --kl_free KL_FREE\n", + " --log_every LOG_EVERY\n", + " --logdir LOGDIR\n", + " --model_lr MODEL_LR\n", + " --norm NORM\n", + " --offline_evaldir OFFLINE_EVALDIR\n", + " --offline_traindir OFFLINE_TRAINDIR\n", + " --opt OPT\n", + " --opt_eps OPT_EPS\n", + " --parallel PARALLEL\n", + " --precision PRECISION\n", + " --prefill PREFILL\n", + " --pretrain PRETRAIN\n", + " --rep_scale REP_SCALE\n", + " --reset_every RESET_EVERY\n", + " --reward_EMA REWARD_EMA\n", + " --reward_head REWARD_HEAD\n", + " --seed SEED\n", + " --size SIZE\n", + " --steps STEPS\n", + " --task TASK\n", + " --time_limit TIME_LIMIT\n", + " --train_ratio TRAIN_RATIO\n", + " --traindir TRAINDIR\n", + " --unimix_ratio UNIMIX_RATIO\n", + " --units UNITS\n", + " --video_pred_log VIDEO_PRED_LOG\n", + " --weight_decay WEIGHT_DECAY\n" + ] + }, + { + "data": { + "text/plain": [ + "''" + ] + }, + "execution_count": 41, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# load config, using relative path\n", + "import pathlib\n", + "import ruamel.yaml as yaml\n", + "import tools\n", + "\n", + "# root_dir = pathlib.Path(__file__).parent\n", + "configs = yaml.safe_load(\n", + " pathlib.Path(\"/media/wassname/SGIronWolf/projects5/worldmodels/dreamerv3-torch/configs.yaml\").read_text()\n", + ")\n", + "\n", + "def recursive_update(base, update):\n", + " for key, value in update.items():\n", + " if isinstance(value, dict) and key in base:\n", + " recursive_update(base[key], value)\n", + " else:\n", + " base[key] = value\n", + "\n", + "name_list = [\"defaults\"]\n", + "defaults = {}\n", + "for name in name_list:\n", + " recursive_update(defaults, configs[name])\n", + "defaults2 = {k:tools.args_type(v)(v) for k, v in defaults.items()}\n", + "config2 = argparse.Namespace(**defaults2)\n", + "defaults2\n", + "\n", + "import argparse\n", + "parser = argparse.ArgumentParser()\n", + "for key, value in sorted(defaults.items(), key=lambda x: x[0]):\n", + " arg_type = tools.args_type(value)\n", + " parser.add_argument(f\"--{key}\", type=arg_type, default=arg_type(value))\n", + "args = parser.parse_args([])\n", + "# logger.debug(f\"{parser}\")\n", + "\n", + "# from loguru import logger\n", + "# parser.print_help()\n", + "# ;\n" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'added': {},\n", + " 'removed': {},\n", + " 'modified': {'size': ([64, 64], (64, 64)),\n", + " 'grad_heads': (['decoder', 'reward', 'cont'],\n", + " ('decoder', 'reward', 'cont'))}}" + ] + }, + "execution_count": 29, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# diff two dics\n", + "def dict_diff(dict1, dict2):\n", + " diff = {\n", + " 'added': {k: dict2[k] for k in dict2 if k not in dict1},\n", + " 'removed': {k: dict1[k] for k in dict1 if k not in dict2},\n", + " 'modified': {k: (dict1[k], dict2[k]) for k in dict1 if k in dict2 and dict1[k] != dict2[k]}\n", + " }\n", + " return diff\n", + "\n", + "dict_diff(defaults, defaults2)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Namespace(act='SiLU', action_repeat=1, actor={'layers': 2, 'dist': 'onehot', 'entropy': 0.0003, 'unimix_ratio': 0.01, 'std': 'none', 'min_std': 0.1, 'max_std': 1.0, 'temp': 0.1, 'lr': 3e-05, 'eps': 1e-05, 'grad_clip': 100.0, 'outscale': 1.0}, batch_length=32, batch_size=256, compile=True, cont_head={'layers': 2, 'loss_scale': 1.0, 'outscale': 1.0}, critic={'layers': 2, 'dist': 'symlog_disc', 'slow_target': True, 'slow_target_update': 1, 'slow_target_fraction': 0.02, 'lr': 3e-05, 'eps': 1e-05, 'grad_clip': 100.0, 'outscale': 0.0}, dataset_size=20000, debug=False, decoder={'mlp_keys': 'state_inventory', 'cnn_keys': 'state_map', 'act': 'SiLU', 'norm': True, 'cnn_depth': 16, 'kernel_size': 4, 'minres': 2, 'mlp_layers': 2, 'mlp_units': 8, 'cnn_sigmoid': False, 'image_dist': 'mse', 'vector_dist': 'symlog_mse', 'outscale': 1.0}, deterministic_run=False, device='cuda:0', disag_action_cond=False, disag_layers=4, disag_log=True, disag_models=10, disag_offset=1, disag_target='stoch', disag_units=400, discount=0.997, discount_lambda=0.95, dyn_deter=128, dyn_discrete=16, dyn_hidden=128, dyn_mean_act='none', dyn_min_std=0.1, dyn_rec_depth=1, dyn_scale=0.5, dyn_std_act='sigmoid2', dyn_stoch=16, encoder={'mlp_keys': 'state_inventory', 'cnn_keys': 'state_map', 'act': 'SiLU', 'norm': True, 'cnn_depth': 16, 'kernel_size': 4, 'minres': 2, 'mlp_layers': 2, 'mlp_units': 8, 'symlog_inputs': True}, envs=1, eval_episode_num=10, eval_every=10000.0, eval_state_mean=False, evaldir=None, expl_behavior='greedy', expl_extr_scale=0.0, expl_intr_scale=1.0, expl_until=0, grad_clip=1000, grad_heads=('decoder', 'reward', 'cont'), grayscale=False, imag_gradient='reinforce', imag_gradient_mix=0.0, imag_horizon=15, initial='learned', kl_free=1.0, log_every=10000.0, logdir='../logdir/craftax_smaller', model_lr=0.0001, norm=True, offline_evaldir='', offline_traindir='', opt='adam', opt_eps=1e-08, parallel=False, precision=32, prefill=2500, pretrain=100, rep_scale=0.1, reset_every=0, reward_EMA=True, reward_head={'layers': 2, 'dist': 'symlog_disc', 'loss_scale': 1.0, 'outscale': 0.0}, seed=0, size=(64, 64), step=1000000.0, steps=1000000.0, task='craftax_Craftax-Symbolic-AutoReset-v1', time_limit=4000, train_ratio=512, traindir=None, unimix_ratio=0.01, units=128, value={'layers': 2}, video_pred_log=False, weight_decay=0.0)" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "config" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Namespace(logdir=None, traindir=None, evaldir=None, offline_traindir='', offline_evaldir='', seed=0, deterministic_run=False, steps='1e6', parallel=False, eval_every='1e4', eval_episode_num=10, log_every='1e4', reset_every=0, device='cuda:0', compile=True, precision=32, debug=False, video_pred_log=False, task='dmc_walker_walk', size=[64, 64], envs=1, action_repeat=2, time_limit=1000, grayscale=False, prefill=2500, reward_EMA=True, dyn_hidden=512, dyn_deter=512, dyn_stoch=32, dyn_discrete=32, dyn_rec_depth=1, dyn_mean_act='none', dyn_std_act='sigmoid2', dyn_min_std=0.1, grad_heads=['decoder', 'reward', 'cont'], units=512, act='SiLU', norm=True, encoder={'mlp_keys': '$^', 'cnn_keys': 'image', 'act': 'SiLU', 'norm': True, 'cnn_depth': 32, 'kernel_size': 4, 'minres': 4, 'mlp_layers': 5, 'mlp_units': 1024, 'symlog_inputs': True}, decoder={'mlp_keys': '$^', 'cnn_keys': 'image', 'act': 'SiLU', 'norm': True, 'cnn_depth': 32, 'kernel_size': 4, 'minres': 4, 'mlp_layers': 5, 'mlp_units': 1024, 'cnn_sigmoid': False, 'image_dist': 'mse', 'vector_dist': 'symlog_mse', 'outscale': 1.0}, actor={'layers': 2, 'dist': 'normal', 'entropy': '3e-4', 'unimix_ratio': 0.01, 'std': 'learned', 'min_std': 0.1, 'max_std': 1.0, 'temp': 0.1, 'lr': '3e-5', 'eps': '1e-5', 'grad_clip': 100.0, 'outscale': 1.0}, critic={'layers': 2, 'dist': 'symlog_disc', 'slow_target': True, 'slow_target_update': 1, 'slow_target_fraction': 0.02, 'lr': '3e-5', 'eps': '1e-5', 'grad_clip': 100.0, 'outscale': 0.0}, reward_head={'layers': 2, 'dist': 'symlog_disc', 'loss_scale': 1.0, 'outscale': 0.0}, cont_head={'layers': 2, 'loss_scale': 1.0, 'outscale': 1.0}, dyn_scale=0.5, rep_scale=0.1, kl_free=1.0, weight_decay=0.0, unimix_ratio=0.01, initial='learned', batch_size=64, batch_length=64, train_ratio=512, pretrain=100, model_lr='1e-4', opt_eps='1e-8', grad_clip=1000, dataset_size=1000000, opt='adam', discount=0.997, discount_lambda=0.95, imag_horizon=15, imag_gradient='dynamics', imag_gradient_mix=0.0, eval_state_mean=False, expl_behavior='greedy', expl_until=0, expl_extr_scale=0.0, expl_intr_scale=1.0, disag_target='stoch', disag_log=True, disag_models=10, disag_offset=1, disag_layers=4, disag_units=400, disag_action_cond=False)" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "config2" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[32m2024-06-07 03:49:38.306\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m\u001b[0m:\u001b[36m27\u001b[0m - \u001b[1mLogdir ../logdir/craftax_smaller\u001b[0m\n", + "\u001b[32m2024-06-07 03:49:38.336\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m\u001b[0m:\u001b[36m36\u001b[0m - \u001b[1mCreate envs.\u001b[0m\n", + "\u001b[32m2024-06-07 03:50:12.609\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m\u001b[0m:\u001b[36m57\u001b[0m - \u001b[1mAction Space Box(0.0, 1.0, (43,), float32)\u001b[0m\n", + "\u001b[32m2024-06-07 03:50:12.611\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m\u001b[0m:\u001b[36m63\u001b[0m - \u001b[1mPrefill dataset (0 steps).\u001b[0m\n", + "\u001b[32m2024-06-07 03:50:12.951\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m\u001b[0m:\u001b[36m92\u001b[0m - \u001b[1mLogger: (2714 steps).\u001b[0m\n", + "\u001b[32m2024-06-07 03:50:12.952\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m\u001b[0m:\u001b[36m94\u001b[0m - \u001b[1mSimulate agent.\u001b[0m\n" ] } ], @@ -197,7 +802,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -211,10 +816,10 @@ { "data": { "text/plain": [ - "Dict('image': Box(0, 255, (130, 110, 3), uint8), 'is_first': Box(0, 0, (1,), uint8), 'is_last': Box(0, 0, (1,), uint8), 'is_terminal': Box(0, 0, (1,), uint8), 'log_achievement_cast_fireball': Box(-inf, inf, (1,), float32), 'log_achievement_cast_iceball': Box(-inf, inf, (1,), float32), 'log_achievement_collect_coal': Box(-inf, inf, (1,), float32), 'log_achievement_collect_diamond': Box(-inf, inf, (1,), float32), 'log_achievement_collect_drink': Box(-inf, inf, (1,), float32), 'log_achievement_collect_iron': Box(-inf, inf, (1,), float32), 'log_achievement_collect_ruby': Box(-inf, inf, (1,), float32), 'log_achievement_collect_sapling': Box(-inf, inf, (1,), float32), 'log_achievement_collect_sapphire': Box(-inf, inf, (1,), float32), 'log_achievement_collect_stone': Box(-inf, inf, (1,), float32), 'log_achievement_collect_wood': Box(-inf, inf, (1,), float32), 'log_achievement_damage_necromancer': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_archer': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_deep_thing': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_fire_elemental': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_frost_troll': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_gnome_archer': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_gnome_warrior': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_ice_elemental': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_knight': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_kobold': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_lizard': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_necromancer': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_orc_mage': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_orc_solider': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_pigman': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_skeleton': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_troll': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_zombie': Box(-inf, inf, (1,), float32), 'log_achievement_drink_potion': Box(-inf, inf, (1,), float32), 'log_achievement_eat_bat': Box(-inf, inf, (1,), float32), 'log_achievement_eat_cow': Box(-inf, inf, (1,), float32), 'log_achievement_eat_plant': Box(-inf, inf, (1,), float32), 'log_achievement_eat_snail': Box(-inf, inf, (1,), float32), 'log_achievement_enchant_armour': Box(-inf, inf, (1,), float32), 'log_achievement_enchant_sword': Box(-inf, inf, (1,), float32), 'log_achievement_enter_dungeon': Box(-inf, inf, (1,), float32), 'log_achievement_enter_fire_realm': Box(-inf, inf, (1,), float32), 'log_achievement_enter_gnomish_mines': Box(-inf, inf, (1,), float32), 'log_achievement_enter_graveyard': Box(-inf, inf, (1,), float32), 'log_achievement_enter_ice_realm': Box(-inf, inf, (1,), float32), 'log_achievement_enter_sewers': Box(-inf, inf, (1,), float32), 'log_achievement_enter_troll_mines': Box(-inf, inf, (1,), float32), 'log_achievement_enter_vault': Box(-inf, inf, (1,), float32), 'log_achievement_find_bow': Box(-inf, inf, (1,), float32), 'log_achievement_fire_bow': Box(-inf, inf, (1,), float32), 'log_achievement_learn_fireball': Box(-inf, inf, (1,), float32), 'log_achievement_learn_iceball': Box(-inf, inf, (1,), float32), 'log_achievement_make_arrow': Box(-inf, inf, (1,), float32), 'log_achievement_make_diamond_armour': Box(-inf, inf, (1,), float32), 'log_achievement_make_diamond_pickaxe': Box(-inf, inf, (1,), float32), 'log_achievement_make_diamond_sword': Box(-inf, inf, (1,), float32), 'log_achievement_make_iron_armour': Box(-inf, inf, (1,), float32), 'log_achievement_make_iron_pickaxe': Box(-inf, inf, (1,), float32), 'log_achievement_make_iron_sword': Box(-inf, inf, (1,), float32), 'log_achievement_make_stone_pickaxe': Box(-inf, inf, (1,), float32), 'log_achievement_make_stone_sword': Box(-inf, inf, (1,), float32), 'log_achievement_make_torch': Box(-inf, inf, (1,), float32), 'log_achievement_make_wood_pickaxe': Box(-inf, inf, (1,), float32), 'log_achievement_make_wood_sword': Box(-inf, inf, (1,), float32), 'log_achievement_open_chest': Box(-inf, inf, (1,), float32), 'log_achievement_place_furnace': Box(-inf, inf, (1,), float32), 'log_achievement_place_plant': Box(-inf, inf, (1,), float32), 'log_achievement_place_stone': Box(-inf, inf, (1,), float32), 'log_achievement_place_table': Box(-inf, inf, (1,), float32), 'log_achievement_place_torch': Box(-inf, inf, (1,), float32), 'log_achievement_wake_up': Box(-inf, inf, (1,), float32), 'log_reward': Box(-inf, inf, (1,), float32), 'state': Box(0.0, 1.0, (16536,), float32), 'state_inventory': Box(0.0, 1.0, (102,), float32), 'state_map': Box(0.0, 1.0, (12, 12, 166), float32))" + "Dict('image': Box(0, 255, (130, 110, 3), uint8), 'is_first': Box(0, 0, (1,), uint8), 'is_last': Box(0, 0, (1,), uint8), 'is_terminal': Box(0, 0, (1,), uint8), 'log_achievement_cast_fireball': Box(-inf, inf, (1,), float32), 'log_achievement_cast_iceball': Box(-inf, inf, (1,), float32), 'log_achievement_collect_coal': Box(-inf, inf, (1,), float32), 'log_achievement_collect_diamond': Box(-inf, inf, (1,), float32), 'log_achievement_collect_drink': Box(-inf, inf, (1,), float32), 'log_achievement_collect_iron': Box(-inf, inf, (1,), float32), 'log_achievement_collect_ruby': Box(-inf, inf, (1,), float32), 'log_achievement_collect_sapling': Box(-inf, inf, (1,), float32), 'log_achievement_collect_sapphire': Box(-inf, inf, (1,), float32), 'log_achievement_collect_stone': Box(-inf, inf, (1,), float32), 'log_achievement_collect_wood': Box(-inf, inf, (1,), float32), 'log_achievement_damage_necromancer': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_archer': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_deep_thing': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_fire_elemental': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_frost_troll': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_gnome_archer': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_gnome_warrior': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_ice_elemental': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_knight': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_kobold': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_lizard': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_necromancer': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_orc_mage': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_orc_solider': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_pigman': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_skeleton': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_troll': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_zombie': Box(-inf, inf, (1,), float32), 'log_achievement_drink_potion': Box(-inf, inf, (1,), float32), 'log_achievement_eat_bat': Box(-inf, inf, (1,), float32), 'log_achievement_eat_cow': Box(-inf, inf, (1,), float32), 'log_achievement_eat_plant': Box(-inf, inf, (1,), float32), 'log_achievement_eat_snail': Box(-inf, inf, (1,), float32), 'log_achievement_enchant_armour': Box(-inf, inf, (1,), float32), 'log_achievement_enchant_sword': Box(-inf, inf, (1,), float32), 'log_achievement_enter_dungeon': Box(-inf, inf, (1,), float32), 'log_achievement_enter_fire_realm': Box(-inf, inf, (1,), float32), 'log_achievement_enter_gnomish_mines': Box(-inf, inf, (1,), float32), 'log_achievement_enter_graveyard': Box(-inf, inf, (1,), float32), 'log_achievement_enter_ice_realm': Box(-inf, inf, (1,), float32), 'log_achievement_enter_sewers': Box(-inf, inf, (1,), float32), 'log_achievement_enter_troll_mines': Box(-inf, inf, (1,), float32), 'log_achievement_enter_vault': Box(-inf, inf, (1,), float32), 'log_achievement_find_bow': Box(-inf, inf, (1,), float32), 'log_achievement_fire_bow': Box(-inf, inf, (1,), float32), 'log_achievement_learn_fireball': Box(-inf, inf, (1,), float32), 'log_achievement_learn_iceball': Box(-inf, inf, (1,), float32), 'log_achievement_make_arrow': Box(-inf, inf, (1,), float32), 'log_achievement_make_diamond_armour': Box(-inf, inf, (1,), float32), 'log_achievement_make_diamond_pickaxe': Box(-inf, inf, (1,), float32), 'log_achievement_make_diamond_sword': Box(-inf, inf, (1,), float32), 'log_achievement_make_iron_armour': Box(-inf, inf, (1,), float32), 'log_achievement_make_iron_pickaxe': Box(-inf, inf, (1,), float32), 'log_achievement_make_iron_sword': Box(-inf, inf, (1,), float32), 'log_achievement_make_stone_pickaxe': Box(-inf, inf, (1,), float32), 'log_achievement_make_stone_sword': Box(-inf, inf, (1,), float32), 'log_achievement_make_torch': Box(-inf, inf, (1,), float32), 'log_achievement_make_wood_pickaxe': Box(-inf, inf, (1,), float32), 'log_achievement_make_wood_sword': Box(-inf, inf, (1,), float32), 'log_achievement_open_chest': Box(-inf, inf, (1,), float32), 'log_achievement_place_furnace': Box(-inf, inf, (1,), float32), 'log_achievement_place_plant': Box(-inf, inf, (1,), float32), 'log_achievement_place_stone': Box(-inf, inf, (1,), float32), 'log_achievement_place_table': Box(-inf, inf, (1,), float32), 'log_achievement_place_torch': Box(-inf, inf, (1,), float32), 'log_achievement_wake_up': Box(-inf, inf, (1,), float32), 'log_reward': Box(-inf, inf, (1,), float32), 'state_inventory': Box(0.0, 1.0, (102,), float16), 'state_map': Box(0.0, 1.0, (16, 16, 166), float16))" ] }, - "execution_count": 6, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } @@ -225,20 +830,20 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "\u001b[32m2024-06-06 17:09:31.695\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mnetworks\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m324\u001b[0m - \u001b[1mEncoder CNN shapes: {'state_map': (12, 12, 166)}\u001b[0m\n", - "\u001b[32m2024-06-06 17:09:31.696\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mnetworks\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m325\u001b[0m - \u001b[1mEncoder MLP shapes: {'state_inventory': (102,)}\u001b[0m\n", - "\u001b[32m2024-06-06 17:09:31.913\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mnetworks\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m391\u001b[0m - \u001b[1mDecoder CNN shapes: {'state_map': (12, 12, 166)}\u001b[0m\n", - "\u001b[32m2024-06-06 17:09:31.914\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mnetworks\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m392\u001b[0m - \u001b[1mDecoder MLP shapes: {'state_inventory': (102,)}\u001b[0m\n", - "\u001b[32m2024-06-06 17:09:32.650\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmodels\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m102\u001b[0m - \u001b[1mOptimizer model_opt has 2357196 variables.\u001b[0m\n", - "\u001b[32m2024-06-06 17:09:32.657\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmodels\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m281\u001b[0m - \u001b[1mOptimizer actor_opt has 356651 variables.\u001b[0m\n", - "\u001b[32m2024-06-06 17:09:32.657\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmodels\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m292\u001b[0m - \u001b[1mOptimizer value_opt has 345087 variables.\u001b[0m\n" + "\u001b[32m2024-06-07 03:50:13.077\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mnetworks\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m324\u001b[0m - \u001b[1mEncoder CNN shapes: {'state_map': (16, 16, 166)}\u001b[0m\n", + "\u001b[32m2024-06-07 03:50:13.077\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mnetworks\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m325\u001b[0m - \u001b[1mEncoder MLP shapes: {'state_inventory': (102,)}\u001b[0m\n", + "\u001b[32m2024-06-07 03:50:13.509\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mnetworks\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m391\u001b[0m - \u001b[1mDecoder CNN shapes: {'state_map': (16, 16, 166)}\u001b[0m\n", + "\u001b[32m2024-06-07 03:50:13.509\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mnetworks\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m392\u001b[0m - \u001b[1mDecoder MLP shapes: {'state_inventory': (102,)}\u001b[0m\n", + "\u001b[32m2024-06-07 03:50:15.165\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmodels\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m102\u001b[0m - \u001b[1mOptimizer model_opt has 706924 variables.\u001b[0m\n", + "\u001b[32m2024-06-07 03:50:15.185\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmodels\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m281\u001b[0m - \u001b[1mOptimizer actor_opt has 71595 variables.\u001b[0m\n", + "\u001b[32m2024-06-07 03:50:15.186\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmodels\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m292\u001b[0m - \u001b[1mOptimizer value_opt has 98943 variables.\u001b[0m\n" ] }, { @@ -259,108 +864,112 @@ " (encoder): MultiEncoder(\n", " (_cnn): ConvEncoder(\n", " (layers): Sequential(\n", - " (0): Conv2dSamePad(166, 32, kernel_size=(4, 4), stride=(2, 2), bias=False)\n", + " (0): Conv2dSamePad(166, 16, kernel_size=(4, 4), stride=(2, 2), bias=False)\n", " (1): ImgChLayerNorm(\n", - " (norm): LayerNorm((32,), eps=0.001, elementwise_affine=True)\n", + " (norm): LayerNorm((16,), eps=0.001, elementwise_affine=True)\n", " )\n", " (2): SiLU()\n", - " (3): Conv2dSamePad(32, 64, kernel_size=(4, 4), stride=(2, 2), bias=False)\n", + " (3): Conv2dSamePad(16, 32, kernel_size=(4, 4), stride=(2, 2), bias=False)\n", " (4): ImgChLayerNorm(\n", - " (norm): LayerNorm((64,), eps=0.001, elementwise_affine=True)\n", + " (norm): LayerNorm((32,), eps=0.001, elementwise_affine=True)\n", " )\n", " (5): SiLU()\n", + " (6): Conv2dSamePad(32, 64, kernel_size=(4, 4), stride=(2, 2), bias=False)\n", + " (7): ImgChLayerNorm(\n", + " (norm): LayerNorm((64,), eps=0.001, elementwise_affine=True)\n", + " )\n", + " (8): SiLU()\n", " )\n", " )\n", " (_mlp): MLP(\n", " (layers): Sequential(\n", - " (Encoder_linear0): Linear(in_features=102, out_features=16, bias=False)\n", - " (Encoder_norm0): LayerNorm((16,), eps=0.001, elementwise_affine=True)\n", + " (Encoder_linear0): Linear(in_features=102, out_features=8, bias=False)\n", + " (Encoder_norm0): LayerNorm((8,), eps=0.001, elementwise_affine=True)\n", " (Encoder_act0): SiLU()\n", - " (Encoder_linear1): Linear(in_features=16, out_features=16, bias=False)\n", - " (Encoder_norm1): LayerNorm((16,), eps=0.001, elementwise_affine=True)\n", + " (Encoder_linear1): Linear(in_features=8, out_features=8, bias=False)\n", + " (Encoder_norm1): LayerNorm((8,), eps=0.001, elementwise_affine=True)\n", " (Encoder_act1): SiLU()\n", " )\n", " )\n", " )\n", " (dynamics): RSSM(\n", " (_img_in_layers): Sequential(\n", - " (0): Linear(in_features=619, out_features=256, bias=False)\n", - " (1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n", + " (0): Linear(in_features=299, out_features=128, bias=False)\n", + " (1): LayerNorm((128,), eps=0.001, elementwise_affine=True)\n", " (2): SiLU()\n", " )\n", " (_cell): GRUCell(\n", " (layers): Sequential(\n", - " (GRU_linear): Linear(in_features=512, out_features=768, bias=False)\n", - " (GRU_norm): LayerNorm((768,), eps=0.001, elementwise_affine=True)\n", + " (GRU_linear): Linear(in_features=256, out_features=384, bias=False)\n", + " (GRU_norm): LayerNorm((384,), eps=0.001, elementwise_affine=True)\n", " )\n", " )\n", " (_img_out_layers): Sequential(\n", - " (0): Linear(in_features=256, out_features=256, bias=False)\n", - " (1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n", + " (0): Linear(in_features=128, out_features=128, bias=False)\n", + " (1): LayerNorm((128,), eps=0.001, elementwise_affine=True)\n", " (2): SiLU()\n", " )\n", " (_obs_out_layers): Sequential(\n", - " (0): Linear(in_features=848, out_features=256, bias=False)\n", - " (1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n", + " (0): Linear(in_features=392, out_features=128, bias=False)\n", + " (1): LayerNorm((128,), eps=0.001, elementwise_affine=True)\n", " (2): SiLU()\n", " )\n", - " (_imgs_stat_layer): Linear(in_features=256, out_features=576, bias=True)\n", - " (_obs_stat_layer): Linear(in_features=256, out_features=576, bias=True)\n", + " (_imgs_stat_layer): Linear(in_features=128, out_features=256, bias=True)\n", + " (_obs_stat_layer): Linear(in_features=128, out_features=256, bias=True)\n", " )\n", " (heads): ModuleDict(\n", " (decoder): MultiDecoder(\n", " (_cnn): ConvDecoder(\n", - " (_linear_layer): Linear(in_features=832, out_features=256, bias=True)\n", + " (_linear_layer): Linear(in_features=384, out_features=256, bias=True)\n", " (layers): Sequential(\n", " (0): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)\n", " (1): ImgChLayerNorm(\n", " (norm): LayerNorm((32,), eps=0.001, elementwise_affine=True)\n", " )\n", " (2): SiLU()\n", - " (3): ConvTranspose2d(32, 166, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", + " (3): ConvTranspose2d(32, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)\n", + " (4): ImgChLayerNorm(\n", + " (norm): LayerNorm((16,), eps=0.001, elementwise_affine=True)\n", + " )\n", + " (5): SiLU()\n", + " (6): ConvTranspose2d(16, 166, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", " )\n", " )\n", " (_mlp): MLP(\n", " (layers): Sequential(\n", - " (Decoder_linear0): Linear(in_features=832, out_features=16, bias=False)\n", - " (Decoder_norm0): LayerNorm((16,), eps=0.001, elementwise_affine=True)\n", + " (Decoder_linear0): Linear(in_features=384, out_features=8, bias=False)\n", + " (Decoder_norm0): LayerNorm((8,), eps=0.001, elementwise_affine=True)\n", " (Decoder_act0): SiLU()\n", - " (Decoder_linear1): Linear(in_features=16, out_features=16, bias=False)\n", - " (Decoder_norm1): LayerNorm((16,), eps=0.001, elementwise_affine=True)\n", + " (Decoder_linear1): Linear(in_features=8, out_features=8, bias=False)\n", + " (Decoder_norm1): LayerNorm((8,), eps=0.001, elementwise_affine=True)\n", " (Decoder_act1): SiLU()\n", " )\n", " (mean_layer): ModuleDict(\n", - " (state_inventory): Linear(in_features=16, out_features=102, bias=True)\n", + " (state_inventory): Linear(in_features=8, out_features=102, bias=True)\n", " )\n", " )\n", " )\n", " (reward): MLP(\n", " (layers): Sequential(\n", - " (Reward_linear0): Linear(in_features=832, out_features=256, bias=False)\n", - " (Reward_norm0): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n", + " (Reward_linear0): Linear(in_features=384, out_features=128, bias=False)\n", + " (Reward_norm0): LayerNorm((128,), eps=0.001, elementwise_affine=True)\n", " (Reward_act0): SiLU()\n", - " (Reward_linear1): Linear(in_features=256, out_features=256, bias=False)\n", - " (Reward_norm1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n", + " (Reward_linear1): Linear(in_features=128, out_features=128, bias=False)\n", + " (Reward_norm1): LayerNorm((128,), eps=0.001, elementwise_affine=True)\n", " (Reward_act1): SiLU()\n", - " (Reward_linear2): Linear(in_features=256, out_features=256, bias=False)\n", - " (Reward_norm2): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n", - " (Reward_act2): SiLU()\n", " )\n", - " (mean_layer): Linear(in_features=256, out_features=255, bias=True)\n", + " (mean_layer): Linear(in_features=128, out_features=255, bias=True)\n", " )\n", " (cont): MLP(\n", " (layers): Sequential(\n", - " (Cont_linear0): Linear(in_features=832, out_features=256, bias=False)\n", - " (Cont_norm0): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n", + " (Cont_linear0): Linear(in_features=384, out_features=128, bias=False)\n", + " (Cont_norm0): LayerNorm((128,), eps=0.001, elementwise_affine=True)\n", " (Cont_act0): SiLU()\n", - " (Cont_linear1): Linear(in_features=256, out_features=256, bias=False)\n", - " (Cont_norm1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n", + " (Cont_linear1): Linear(in_features=128, out_features=128, bias=False)\n", + " (Cont_norm1): LayerNorm((128,), eps=0.001, elementwise_affine=True)\n", " (Cont_act1): SiLU()\n", - " (Cont_linear2): Linear(in_features=256, out_features=256, bias=False)\n", - " (Cont_norm2): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n", - " (Cont_act2): SiLU()\n", " )\n", - " (mean_layer): Linear(in_features=256, out_features=1, bias=True)\n", + " (mean_layer): Linear(in_features=128, out_features=1, bias=True)\n", " )\n", " )\n", " )\n", @@ -371,146 +980,147 @@ " (encoder): MultiEncoder(\n", " (_cnn): ConvEncoder(\n", " (layers): Sequential(\n", - " (0): Conv2dSamePad(166, 32, kernel_size=(4, 4), stride=(2, 2), bias=False)\n", + " (0): Conv2dSamePad(166, 16, kernel_size=(4, 4), stride=(2, 2), bias=False)\n", " (1): ImgChLayerNorm(\n", - " (norm): LayerNorm((32,), eps=0.001, elementwise_affine=True)\n", + " (norm): LayerNorm((16,), eps=0.001, elementwise_affine=True)\n", " )\n", " (2): SiLU()\n", - " (3): Conv2dSamePad(32, 64, kernel_size=(4, 4), stride=(2, 2), bias=False)\n", + " (3): Conv2dSamePad(16, 32, kernel_size=(4, 4), stride=(2, 2), bias=False)\n", " (4): ImgChLayerNorm(\n", - " (norm): LayerNorm((64,), eps=0.001, elementwise_affine=True)\n", + " (norm): LayerNorm((32,), eps=0.001, elementwise_affine=True)\n", " )\n", " (5): SiLU()\n", + " (6): Conv2dSamePad(32, 64, kernel_size=(4, 4), stride=(2, 2), bias=False)\n", + " (7): ImgChLayerNorm(\n", + " (norm): LayerNorm((64,), eps=0.001, elementwise_affine=True)\n", + " )\n", + " (8): SiLU()\n", " )\n", " )\n", " (_mlp): MLP(\n", " (layers): Sequential(\n", - " (Encoder_linear0): Linear(in_features=102, out_features=16, bias=False)\n", - " (Encoder_norm0): LayerNorm((16,), eps=0.001, elementwise_affine=True)\n", + " (Encoder_linear0): Linear(in_features=102, out_features=8, bias=False)\n", + " (Encoder_norm0): LayerNorm((8,), eps=0.001, elementwise_affine=True)\n", " (Encoder_act0): SiLU()\n", - " (Encoder_linear1): Linear(in_features=16, out_features=16, bias=False)\n", - " (Encoder_norm1): LayerNorm((16,), eps=0.001, elementwise_affine=True)\n", + " (Encoder_linear1): Linear(in_features=8, out_features=8, bias=False)\n", + " (Encoder_norm1): LayerNorm((8,), eps=0.001, elementwise_affine=True)\n", " (Encoder_act1): SiLU()\n", " )\n", " )\n", " )\n", " (dynamics): RSSM(\n", " (_img_in_layers): Sequential(\n", - " (0): Linear(in_features=619, out_features=256, bias=False)\n", - " (1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n", + " (0): Linear(in_features=299, out_features=128, bias=False)\n", + " (1): LayerNorm((128,), eps=0.001, elementwise_affine=True)\n", " (2): SiLU()\n", " )\n", " (_cell): GRUCell(\n", " (layers): Sequential(\n", - " (GRU_linear): Linear(in_features=512, out_features=768, bias=False)\n", - " (GRU_norm): LayerNorm((768,), eps=0.001, elementwise_affine=True)\n", + " (GRU_linear): Linear(in_features=256, out_features=384, bias=False)\n", + " (GRU_norm): LayerNorm((384,), eps=0.001, elementwise_affine=True)\n", " )\n", " )\n", " (_img_out_layers): Sequential(\n", - " (0): Linear(in_features=256, out_features=256, bias=False)\n", - " (1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n", + " (0): Linear(in_features=128, out_features=128, bias=False)\n", + " (1): LayerNorm((128,), eps=0.001, elementwise_affine=True)\n", " (2): SiLU()\n", " )\n", " (_obs_out_layers): Sequential(\n", - " (0): Linear(in_features=848, out_features=256, bias=False)\n", - " (1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n", + " (0): Linear(in_features=392, out_features=128, bias=False)\n", + " (1): LayerNorm((128,), eps=0.001, elementwise_affine=True)\n", " (2): SiLU()\n", " )\n", - " (_imgs_stat_layer): Linear(in_features=256, out_features=576, bias=True)\n", - " (_obs_stat_layer): Linear(in_features=256, out_features=576, bias=True)\n", + " (_imgs_stat_layer): Linear(in_features=128, out_features=256, bias=True)\n", + " (_obs_stat_layer): Linear(in_features=128, out_features=256, bias=True)\n", " )\n", " (heads): ModuleDict(\n", " (decoder): MultiDecoder(\n", " (_cnn): ConvDecoder(\n", - " (_linear_layer): Linear(in_features=832, out_features=256, bias=True)\n", + " (_linear_layer): Linear(in_features=384, out_features=256, bias=True)\n", " (layers): Sequential(\n", " (0): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)\n", " (1): ImgChLayerNorm(\n", " (norm): LayerNorm((32,), eps=0.001, elementwise_affine=True)\n", " )\n", " (2): SiLU()\n", - " (3): ConvTranspose2d(32, 166, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", + " (3): ConvTranspose2d(32, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)\n", + " (4): ImgChLayerNorm(\n", + " (norm): LayerNorm((16,), eps=0.001, elementwise_affine=True)\n", + " )\n", + " (5): SiLU()\n", + " (6): ConvTranspose2d(16, 166, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", " )\n", " )\n", " (_mlp): MLP(\n", " (layers): Sequential(\n", - " (Decoder_linear0): Linear(in_features=832, out_features=16, bias=False)\n", - " (Decoder_norm0): LayerNorm((16,), eps=0.001, elementwise_affine=True)\n", + " (Decoder_linear0): Linear(in_features=384, out_features=8, bias=False)\n", + " (Decoder_norm0): LayerNorm((8,), eps=0.001, elementwise_affine=True)\n", " (Decoder_act0): SiLU()\n", - " (Decoder_linear1): Linear(in_features=16, out_features=16, bias=False)\n", - " (Decoder_norm1): LayerNorm((16,), eps=0.001, elementwise_affine=True)\n", + " (Decoder_linear1): Linear(in_features=8, out_features=8, bias=False)\n", + " (Decoder_norm1): LayerNorm((8,), eps=0.001, elementwise_affine=True)\n", " (Decoder_act1): SiLU()\n", " )\n", " (mean_layer): ModuleDict(\n", - " (state_inventory): Linear(in_features=16, out_features=102, bias=True)\n", + " (state_inventory): Linear(in_features=8, out_features=102, bias=True)\n", " )\n", " )\n", " )\n", " (reward): MLP(\n", " (layers): Sequential(\n", - " (Reward_linear0): Linear(in_features=832, out_features=256, bias=False)\n", - " (Reward_norm0): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n", + " (Reward_linear0): Linear(in_features=384, out_features=128, bias=False)\n", + " (Reward_norm0): LayerNorm((128,), eps=0.001, elementwise_affine=True)\n", " (Reward_act0): SiLU()\n", - " (Reward_linear1): Linear(in_features=256, out_features=256, bias=False)\n", - " (Reward_norm1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n", + " (Reward_linear1): Linear(in_features=128, out_features=128, bias=False)\n", + " (Reward_norm1): LayerNorm((128,), eps=0.001, elementwise_affine=True)\n", " (Reward_act1): SiLU()\n", - " (Reward_linear2): Linear(in_features=256, out_features=256, bias=False)\n", - " (Reward_norm2): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n", - " (Reward_act2): SiLU()\n", " )\n", - " (mean_layer): Linear(in_features=256, out_features=255, bias=True)\n", + " (mean_layer): Linear(in_features=128, out_features=255, bias=True)\n", " )\n", " (cont): MLP(\n", " (layers): Sequential(\n", - " (Cont_linear0): Linear(in_features=832, out_features=256, bias=False)\n", - " (Cont_norm0): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n", + " (Cont_linear0): Linear(in_features=384, out_features=128, bias=False)\n", + " (Cont_norm0): LayerNorm((128,), eps=0.001, elementwise_affine=True)\n", " (Cont_act0): SiLU()\n", - " (Cont_linear1): Linear(in_features=256, out_features=256, bias=False)\n", - " (Cont_norm1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n", + " (Cont_linear1): Linear(in_features=128, out_features=128, bias=False)\n", + " (Cont_norm1): LayerNorm((128,), eps=0.001, elementwise_affine=True)\n", " (Cont_act1): SiLU()\n", - " (Cont_linear2): Linear(in_features=256, out_features=256, bias=False)\n", - " (Cont_norm2): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n", - " (Cont_act2): SiLU()\n", " )\n", - " (mean_layer): Linear(in_features=256, out_features=1, bias=True)\n", + " (mean_layer): Linear(in_features=128, out_features=1, bias=True)\n", " )\n", " )\n", " )\n", " (actor): MLP(\n", " (layers): Sequential(\n", - " (Actor_linear0): Linear(in_features=832, out_features=256, bias=False)\n", - " (Actor_norm0): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n", + " (Actor_linear0): Linear(in_features=384, out_features=128, bias=False)\n", + " (Actor_norm0): LayerNorm((128,), eps=0.001, elementwise_affine=True)\n", " (Actor_act0): SiLU()\n", - " (Actor_linear1): Linear(in_features=256, out_features=256, bias=False)\n", - " (Actor_norm1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n", + " (Actor_linear1): Linear(in_features=128, out_features=128, bias=False)\n", + " (Actor_norm1): LayerNorm((128,), eps=0.001, elementwise_affine=True)\n", " (Actor_act1): SiLU()\n", - " (Actor_linear2): Linear(in_features=256, out_features=256, bias=False)\n", - " (Actor_norm2): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n", - " (Actor_act2): SiLU()\n", " )\n", - " (mean_layer): Linear(in_features=256, out_features=43, bias=True)\n", + " (mean_layer): Linear(in_features=128, out_features=43, bias=True)\n", " )\n", " (value): MLP(\n", " (layers): Sequential(\n", - " (Value_linear0): Linear(in_features=832, out_features=256, bias=False)\n", - " (Value_norm0): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n", + " (Value_linear0): Linear(in_features=384, out_features=128, bias=False)\n", + " (Value_norm0): LayerNorm((128,), eps=0.001, elementwise_affine=True)\n", " (Value_act0): SiLU()\n", - " (Value_linear1): Linear(in_features=256, out_features=256, bias=False)\n", - " (Value_norm1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n", + " (Value_linear1): Linear(in_features=128, out_features=128, bias=False)\n", + " (Value_norm1): LayerNorm((128,), eps=0.001, elementwise_affine=True)\n", " (Value_act1): SiLU()\n", " )\n", - " (mean_layer): Linear(in_features=256, out_features=255, bias=True)\n", + " (mean_layer): Linear(in_features=128, out_features=255, bias=True)\n", " )\n", " (_slow_value): MLP(\n", " (layers): Sequential(\n", - " (Value_linear0): Linear(in_features=832, out_features=256, bias=False)\n", - " (Value_norm0): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n", + " (Value_linear0): Linear(in_features=384, out_features=128, bias=False)\n", + " (Value_norm0): LayerNorm((128,), eps=0.001, elementwise_affine=True)\n", " (Value_act0): SiLU()\n", - " (Value_linear1): Linear(in_features=256, out_features=256, bias=False)\n", - " (Value_norm1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n", + " (Value_linear1): Linear(in_features=128, out_features=128, bias=False)\n", + " (Value_norm1): LayerNorm((128,), eps=0.001, elementwise_affine=True)\n", " (Value_act1): SiLU()\n", " )\n", - " (mean_layer): Linear(in_features=256, out_features=255, bias=True)\n", + " (mean_layer): Linear(in_features=128, out_features=255, bias=True)\n", " )\n", " )\n", " )\n", @@ -520,146 +1130,147 @@ " (encoder): MultiEncoder(\n", " (_cnn): ConvEncoder(\n", " (layers): Sequential(\n", - " (0): Conv2dSamePad(166, 32, kernel_size=(4, 4), stride=(2, 2), bias=False)\n", + " (0): Conv2dSamePad(166, 16, kernel_size=(4, 4), stride=(2, 2), bias=False)\n", " (1): ImgChLayerNorm(\n", - " (norm): LayerNorm((32,), eps=0.001, elementwise_affine=True)\n", + " (norm): LayerNorm((16,), eps=0.001, elementwise_affine=True)\n", " )\n", " (2): SiLU()\n", - " (3): Conv2dSamePad(32, 64, kernel_size=(4, 4), stride=(2, 2), bias=False)\n", + " (3): Conv2dSamePad(16, 32, kernel_size=(4, 4), stride=(2, 2), bias=False)\n", " (4): ImgChLayerNorm(\n", - " (norm): LayerNorm((64,), eps=0.001, elementwise_affine=True)\n", + " (norm): LayerNorm((32,), eps=0.001, elementwise_affine=True)\n", " )\n", " (5): SiLU()\n", + " (6): Conv2dSamePad(32, 64, kernel_size=(4, 4), stride=(2, 2), bias=False)\n", + " (7): ImgChLayerNorm(\n", + " (norm): LayerNorm((64,), eps=0.001, elementwise_affine=True)\n", + " )\n", + " (8): SiLU()\n", " )\n", " )\n", " (_mlp): MLP(\n", " (layers): Sequential(\n", - " (Encoder_linear0): Linear(in_features=102, out_features=16, bias=False)\n", - " (Encoder_norm0): LayerNorm((16,), eps=0.001, elementwise_affine=True)\n", + " (Encoder_linear0): Linear(in_features=102, out_features=8, bias=False)\n", + " (Encoder_norm0): LayerNorm((8,), eps=0.001, elementwise_affine=True)\n", " (Encoder_act0): SiLU()\n", - " (Encoder_linear1): Linear(in_features=16, out_features=16, bias=False)\n", - " (Encoder_norm1): LayerNorm((16,), eps=0.001, elementwise_affine=True)\n", + " (Encoder_linear1): Linear(in_features=8, out_features=8, bias=False)\n", + " (Encoder_norm1): LayerNorm((8,), eps=0.001, elementwise_affine=True)\n", " (Encoder_act1): SiLU()\n", " )\n", " )\n", " )\n", " (dynamics): RSSM(\n", " (_img_in_layers): Sequential(\n", - " (0): Linear(in_features=619, out_features=256, bias=False)\n", - " (1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n", + " (0): Linear(in_features=299, out_features=128, bias=False)\n", + " (1): LayerNorm((128,), eps=0.001, elementwise_affine=True)\n", " (2): SiLU()\n", " )\n", " (_cell): GRUCell(\n", " (layers): Sequential(\n", - " (GRU_linear): Linear(in_features=512, out_features=768, bias=False)\n", - " (GRU_norm): LayerNorm((768,), eps=0.001, elementwise_affine=True)\n", + " (GRU_linear): Linear(in_features=256, out_features=384, bias=False)\n", + " (GRU_norm): LayerNorm((384,), eps=0.001, elementwise_affine=True)\n", " )\n", " )\n", " (_img_out_layers): Sequential(\n", - " (0): Linear(in_features=256, out_features=256, bias=False)\n", - " (1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n", + " (0): Linear(in_features=128, out_features=128, bias=False)\n", + " (1): LayerNorm((128,), eps=0.001, elementwise_affine=True)\n", " (2): SiLU()\n", " )\n", " (_obs_out_layers): Sequential(\n", - " (0): Linear(in_features=848, out_features=256, bias=False)\n", - " (1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n", + " (0): Linear(in_features=392, out_features=128, bias=False)\n", + " (1): LayerNorm((128,), eps=0.001, elementwise_affine=True)\n", " (2): SiLU()\n", " )\n", - " (_imgs_stat_layer): Linear(in_features=256, out_features=576, bias=True)\n", - " (_obs_stat_layer): Linear(in_features=256, out_features=576, bias=True)\n", + " (_imgs_stat_layer): Linear(in_features=128, out_features=256, bias=True)\n", + " (_obs_stat_layer): Linear(in_features=128, out_features=256, bias=True)\n", " )\n", " (heads): ModuleDict(\n", " (decoder): MultiDecoder(\n", " (_cnn): ConvDecoder(\n", - " (_linear_layer): Linear(in_features=832, out_features=256, bias=True)\n", + " (_linear_layer): Linear(in_features=384, out_features=256, bias=True)\n", " (layers): Sequential(\n", " (0): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)\n", " (1): ImgChLayerNorm(\n", " (norm): LayerNorm((32,), eps=0.001, elementwise_affine=True)\n", " )\n", " (2): SiLU()\n", - " (3): ConvTranspose2d(32, 166, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", + " (3): ConvTranspose2d(32, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)\n", + " (4): ImgChLayerNorm(\n", + " (norm): LayerNorm((16,), eps=0.001, elementwise_affine=True)\n", + " )\n", + " (5): SiLU()\n", + " (6): ConvTranspose2d(16, 166, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", " )\n", " )\n", " (_mlp): MLP(\n", " (layers): Sequential(\n", - " (Decoder_linear0): Linear(in_features=832, out_features=16, bias=False)\n", - " (Decoder_norm0): LayerNorm((16,), eps=0.001, elementwise_affine=True)\n", + " (Decoder_linear0): Linear(in_features=384, out_features=8, bias=False)\n", + " (Decoder_norm0): LayerNorm((8,), eps=0.001, elementwise_affine=True)\n", " (Decoder_act0): SiLU()\n", - " (Decoder_linear1): Linear(in_features=16, out_features=16, bias=False)\n", - " (Decoder_norm1): LayerNorm((16,), eps=0.001, elementwise_affine=True)\n", + " (Decoder_linear1): Linear(in_features=8, out_features=8, bias=False)\n", + " (Decoder_norm1): LayerNorm((8,), eps=0.001, elementwise_affine=True)\n", " (Decoder_act1): SiLU()\n", " )\n", " (mean_layer): ModuleDict(\n", - " (state_inventory): Linear(in_features=16, out_features=102, bias=True)\n", + " (state_inventory): Linear(in_features=8, out_features=102, bias=True)\n", " )\n", " )\n", " )\n", " (reward): MLP(\n", " (layers): Sequential(\n", - " (Reward_linear0): Linear(in_features=832, out_features=256, bias=False)\n", - " (Reward_norm0): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n", + " (Reward_linear0): Linear(in_features=384, out_features=128, bias=False)\n", + " (Reward_norm0): LayerNorm((128,), eps=0.001, elementwise_affine=True)\n", " (Reward_act0): SiLU()\n", - " (Reward_linear1): Linear(in_features=256, out_features=256, bias=False)\n", - " (Reward_norm1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n", + " (Reward_linear1): Linear(in_features=128, out_features=128, bias=False)\n", + " (Reward_norm1): LayerNorm((128,), eps=0.001, elementwise_affine=True)\n", " (Reward_act1): SiLU()\n", - " (Reward_linear2): Linear(in_features=256, out_features=256, bias=False)\n", - " (Reward_norm2): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n", - " (Reward_act2): SiLU()\n", " )\n", - " (mean_layer): Linear(in_features=256, out_features=255, bias=True)\n", + " (mean_layer): Linear(in_features=128, out_features=255, bias=True)\n", " )\n", " (cont): MLP(\n", " (layers): Sequential(\n", - " (Cont_linear0): Linear(in_features=832, out_features=256, bias=False)\n", - " (Cont_norm0): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n", + " (Cont_linear0): Linear(in_features=384, out_features=128, bias=False)\n", + " (Cont_norm0): LayerNorm((128,), eps=0.001, elementwise_affine=True)\n", " (Cont_act0): SiLU()\n", - " (Cont_linear1): Linear(in_features=256, out_features=256, bias=False)\n", - " (Cont_norm1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n", + " (Cont_linear1): Linear(in_features=128, out_features=128, bias=False)\n", + " (Cont_norm1): LayerNorm((128,), eps=0.001, elementwise_affine=True)\n", " (Cont_act1): SiLU()\n", - " (Cont_linear2): Linear(in_features=256, out_features=256, bias=False)\n", - " (Cont_norm2): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n", - " (Cont_act2): SiLU()\n", " )\n", - " (mean_layer): Linear(in_features=256, out_features=1, bias=True)\n", + " (mean_layer): Linear(in_features=128, out_features=1, bias=True)\n", " )\n", " )\n", " )\n", " (actor): MLP(\n", " (layers): Sequential(\n", - " (Actor_linear0): Linear(in_features=832, out_features=256, bias=False)\n", - " (Actor_norm0): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n", + " (Actor_linear0): Linear(in_features=384, out_features=128, bias=False)\n", + " (Actor_norm0): LayerNorm((128,), eps=0.001, elementwise_affine=True)\n", " (Actor_act0): SiLU()\n", - " (Actor_linear1): Linear(in_features=256, out_features=256, bias=False)\n", - " (Actor_norm1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n", + " (Actor_linear1): Linear(in_features=128, out_features=128, bias=False)\n", + " (Actor_norm1): LayerNorm((128,), eps=0.001, elementwise_affine=True)\n", " (Actor_act1): SiLU()\n", - " (Actor_linear2): Linear(in_features=256, out_features=256, bias=False)\n", - " (Actor_norm2): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n", - " (Actor_act2): SiLU()\n", " )\n", - " (mean_layer): Linear(in_features=256, out_features=43, bias=True)\n", + " (mean_layer): Linear(in_features=128, out_features=43, bias=True)\n", " )\n", " (value): MLP(\n", " (layers): Sequential(\n", - " (Value_linear0): Linear(in_features=832, out_features=256, bias=False)\n", - " (Value_norm0): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n", + " (Value_linear0): Linear(in_features=384, out_features=128, bias=False)\n", + " (Value_norm0): LayerNorm((128,), eps=0.001, elementwise_affine=True)\n", " (Value_act0): SiLU()\n", - " (Value_linear1): Linear(in_features=256, out_features=256, bias=False)\n", - " (Value_norm1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n", + " (Value_linear1): Linear(in_features=128, out_features=128, bias=False)\n", + " (Value_norm1): LayerNorm((128,), eps=0.001, elementwise_affine=True)\n", " (Value_act1): SiLU()\n", " )\n", - " (mean_layer): Linear(in_features=256, out_features=255, bias=True)\n", + " (mean_layer): Linear(in_features=128, out_features=255, bias=True)\n", " )\n", " (_slow_value): MLP(\n", " (layers): Sequential(\n", - " (Value_linear0): Linear(in_features=832, out_features=256, bias=False)\n", - " (Value_norm0): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n", + " (Value_linear0): Linear(in_features=384, out_features=128, bias=False)\n", + " (Value_norm0): LayerNorm((128,), eps=0.001, elementwise_affine=True)\n", " (Value_act0): SiLU()\n", - " (Value_linear1): Linear(in_features=256, out_features=256, bias=False)\n", - " (Value_norm1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n", + " (Value_linear1): Linear(in_features=128, out_features=128, bias=False)\n", + " (Value_norm1): LayerNorm((128,), eps=0.001, elementwise_affine=True)\n", " (Value_act1): SiLU()\n", " )\n", - " (mean_layer): Linear(in_features=256, out_features=255, bias=True)\n", + " (mean_layer): Linear(in_features=128, out_features=255, bias=True)\n", " )\n", " )\n", " )\n", @@ -687,6 +1298,36 @@ " logger.warning(f\"Loaded model from {logdir / 'latest.pt'}\")" ] }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[0;31mType:\u001b[0m generator\n", + "\u001b[0;31mString form:\u001b[0m \n", + "\u001b[0;31mDocstring:\u001b[0m " + ] + }, + { + "ename": "", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[1;31mThe Kernel crashed while executing code in the current cell or a previous cell. \n", + "\u001b[1;31mPlease review the code in the cell(s) to identify a possible cause of the failure. \n", + "\u001b[1;31mClick here for more info. \n", + "\u001b[1;31mView Jupyter log for further details." + ] + } + ], + "source": [ + "train_dataset??" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -712,7 +1353,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -724,7 +1365,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -757,30 +1398,24 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "/media/wassname/SGIronWolf/projects5/worldmodels/dreamerv3-torch/nbs/../networks.py:790: UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at ../aten/src/ATen/native/cudnn/Conv_v8.cpp:919.)\n", - " ret = F.conv2d(\n" + "/media/wassname/SGIronWolf/projects5/worldmodels/dreamerv3-torch/nbs/../networks.py:792: UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at ../aten/src/ATen/native/cudnn/Conv_v8.cpp:919.)\n", + " ret = F.conv2d(\n", + "/media/wassname/SGIronWolf/projects5/worldmodels/dreamerv3-torch/.venv/lib/python3.9/site-packages/torch/autograd/graph.py:744: UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at ../aten/src/ATen/native/cudnn/Conv_v8.cpp:919.)\n", + " return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass\n" ] }, { - "ename": "AssertionError", - "evalue": "(torch.Size([256, 32, 8, 8, 166]), torch.Size([256, 32, 12, 12, 166]))", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[10], line 6\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m# from tools.simulate\u001b[39;00m\n\u001b[1;32m 2\u001b[0m \n\u001b[1;32m 3\u001b[0m \u001b[38;5;66;03m# step\u001b[39;00m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;66;03m# step, episode, done, length, obs, agent_state, reward = state\u001b[39;00m\n\u001b[1;32m 5\u001b[0m obs2 \u001b[38;5;241m=\u001b[39m {k: np\u001b[38;5;241m.\u001b[39mstack([o[k] \u001b[38;5;28;01mfor\u001b[39;00m o \u001b[38;5;129;01min\u001b[39;00m obs]) \u001b[38;5;28;01mfor\u001b[39;00m k \u001b[38;5;129;01min\u001b[39;00m obs[\u001b[38;5;241m0\u001b[39m] \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlog_\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m k}\n\u001b[0;32m----> 6\u001b[0m action, agent_state \u001b[38;5;241m=\u001b[39m \u001b[43magent\u001b[49m\u001b[43m(\u001b[49m\u001b[43mobs2\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdone\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43magent_state\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m/media/wassname/SGIronWolf/projects5/worldmodels/dreamerv3-torch/nbs/../dreamer.py:71\u001b[0m, in \u001b[0;36mDreamer.__call__\u001b[0;34m(self, obs, reset, state, training)\u001b[0m\n\u001b[1;32m 65\u001b[0m steps \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 66\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_config\u001b[38;5;241m.\u001b[39mpretrain\n\u001b[1;32m 67\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_should_pretrain()\n\u001b[1;32m 68\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_should_train(step)\n\u001b[1;32m 69\u001b[0m )\n\u001b[1;32m 70\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m _ \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(steps):\n\u001b[0;32m---> 71\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_train\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mnext\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_dataset\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 72\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_update_count \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m 73\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_metrics[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mupdate_count\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_update_count\n", - "File \u001b[0;32m/media/wassname/SGIronWolf/projects5/worldmodels/dreamerv3-torch/nbs/../dreamer.py:124\u001b[0m, in \u001b[0;36mDreamer._train\u001b[0;34m(self, data)\u001b[0m\n\u001b[1;32m 122\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_train\u001b[39m(\u001b[38;5;28mself\u001b[39m, data):\n\u001b[1;32m 123\u001b[0m metrics \u001b[38;5;241m=\u001b[39m {}\n\u001b[0;32m--> 124\u001b[0m post, context, mets \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_wm\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_train\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 125\u001b[0m metrics\u001b[38;5;241m.\u001b[39mupdate(mets)\n\u001b[1;32m 126\u001b[0m start \u001b[38;5;241m=\u001b[39m post\n", - "File \u001b[0;32m/media/wassname/SGIronWolf/projects5/worldmodels/dreamerv3-torch/nbs/../models.py:143\u001b[0m, in \u001b[0;36mWorldModel._train\u001b[0;34m(self, data)\u001b[0m\n\u001b[1;32m 141\u001b[0m losses \u001b[38;5;241m=\u001b[39m {}\n\u001b[1;32m 142\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m name, pred \u001b[38;5;129;01min\u001b[39;00m preds\u001b[38;5;241m.\u001b[39mitems():\n\u001b[0;32m--> 143\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m-\u001b[39m\u001b[43mpred\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlog_prob\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata\u001b[49m\u001b[43m[\u001b[49m\u001b[43mname\u001b[49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 144\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m loss\u001b[38;5;241m.\u001b[39mshape \u001b[38;5;241m==\u001b[39m embed\u001b[38;5;241m.\u001b[39mshape[:\u001b[38;5;241m2\u001b[39m], (name, loss\u001b[38;5;241m.\u001b[39mshape)\n\u001b[1;32m 145\u001b[0m losses[name] \u001b[38;5;241m=\u001b[39m loss\n", - "File \u001b[0;32m/media/wassname/SGIronWolf/projects5/worldmodels/dreamerv3-torch/nbs/../tools.py:528\u001b[0m, in \u001b[0;36mMSEDist.log_prob\u001b[0;34m(self, value)\u001b[0m\n\u001b[1;32m 527\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mlog_prob\u001b[39m(\u001b[38;5;28mself\u001b[39m, value):\n\u001b[0;32m--> 528\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_mode\u001b[38;5;241m.\u001b[39mshape \u001b[38;5;241m==\u001b[39m value\u001b[38;5;241m.\u001b[39mshape, (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_mode\u001b[38;5;241m.\u001b[39mshape, value\u001b[38;5;241m.\u001b[39mshape)\n\u001b[1;32m 529\u001b[0m distance \u001b[38;5;241m=\u001b[39m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_mode \u001b[38;5;241m-\u001b[39m value) \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39m \u001b[38;5;241m2\u001b[39m\n\u001b[1;32m 530\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_agg \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmean\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n", - "\u001b[0;31mAssertionError\u001b[0m: (torch.Size([256, 32, 8, 8, 166]), torch.Size([256, 32, 12, 12, 166]))" + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[32m2024-06-07 03:53:49.056\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mtools\u001b[0m:\u001b[36mwrite\u001b[0m:\u001b[36m85\u001b[0m - \u001b[1m[2714] model_loss \u001b[31m7464.5\u001b[0m\u001b[1m / model_grad_norm \u001b[31m4841.3\u001b[0m\u001b[1m / state_map_loss \u001b[31m7449.7\u001b[0m\u001b[1m / state_inventory_loss \u001b[31m7.5\u001b[0m\u001b[1m / reward_loss \u001b[31m5.0\u001b[0m\u001b[1m / cont_loss \u001b[31m0.1\u001b[0m\u001b[1m / kl_free \u001b[31m1.0\u001b[0m\u001b[1m / dyn_scale \u001b[31m0.5\u001b[0m\u001b[1m / rep_scale \u001b[31m0.1\u001b[0m\u001b[1m / dyn_loss \u001b[31m3.7\u001b[0m\u001b[1m / rep_loss \u001b[31m3.7\u001b[0m\u001b[1m / kl \u001b[31m3.7\u001b[0m\u001b[1m / prior_ent \u001b[31m42.9\u001b[0m\u001b[1m / post_ent \u001b[31m39.3\u001b[0m\u001b[1m / normed_target_mean \u001b[31m-0.0\u001b[0m\u001b[1m / normed_target_std \u001b[31m0.0\u001b[0m\u001b[1m / normed_target_min \u001b[31m-0.0\u001b[0m\u001b[1m / normed_target_max \u001b[31m0.0\u001b[0m\u001b[1m / EMA_005 \u001b[31m-0.0\u001b[0m\u001b[1m / EMA_095 \u001b[31m-0.0\u001b[0m\u001b[1m / value_mean \u001b[31m-0.0\u001b[0m\u001b[1m / value_std \u001b[31m0.0\u001b[0m\u001b[1m / value_min \u001b[31m-0.0\u001b[0m\u001b[1m / value_max \u001b[31m-0.0\u001b[0m\u001b[1m / target_mean \u001b[31m-0.0\u001b[0m\u001b[1m / target_std \u001b[31m0.0\u001b[0m\u001b[1m / target_min \u001b[31m-0.0\u001b[0m\u001b[1m / target_max \u001b[31m-0.0\u001b[0m\u001b[1m / imag_reward_mean \u001b[31m-0.0\u001b[0m\u001b[1m / imag_reward_std \u001b[31m0.0\u001b[0m\u001b[1m / imag_reward_min \u001b[31m-0.0\u001b[0m\u001b[1m / imag_reward_max \u001b[31m-0.0\u001b[0m\u001b[1m / imag_action_mean \u001b[31m21.2\u001b[0m\u001b[1m / imag_action_std \u001b[31m12.6\u001b[0m\u001b[1m / imag_action_min \u001b[31m0.0\u001b[0m\u001b[1m / imag_action_max \u001b[31m42.0\u001b[0m\u001b[1m / actor_entropy \u001b[31m3.6\u001b[0m\u001b[1m / actor_loss \u001b[31m-0.0\u001b[0m\u001b[1m / actor_grad_norm \u001b[31m0.0\u001b[0m\u001b[1m / value_loss \u001b[31m7.6\u001b[0m\u001b[1m / value_grad_norm \u001b[31m8.4\u001b[0m\u001b[1m / update_count \u001b[31m100.0\u001b[0m\u001b[1m / fps \u001b[31m0.0\u001b[0m\u001b[1m\u001b[0m\n" ] } ], @@ -795,7 +1430,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 10, "metadata": {}, "outputs": [ { @@ -807,15 +1442,15 @@ "Dreamer --\n", "├─OptimizedModule: 1-1 --\n", "│ └─WorldModel: 2-1 --\n", - "│ │ └─MultiEncoder: 3-1 (44,480)\n", - "│ │ └─RSSM: 3-2 (2,397,952)\n", - "│ │ └─ModuleDict: 3-3 (1,580,204)\n", + "│ │ └─MultiEncoder: 3-1 (84,592)\n", + "│ │ └─RSSM: 3-2 (270,848)\n", + "│ │ └─ModuleDict: 3-3 (351,484)\n", "├─OptimizedModule: 1-2 --\n", - "│ └─ImagBehavior: 2-2 4,022,636\n", + "│ └─ImagBehavior: 2-2 706,924\n", "│ │ └─WorldModel: 3-4 (recursive)\n", - "│ │ └─MLP: 3-5 (536,875)\n", - "│ │ └─MLP: 3-6 (525,311)\n", - "│ │ └─MLP: 3-7 (525,311)\n", + "│ │ └─MLP: 3-5 (71,595)\n", + "│ │ └─MLP: 3-6 (98,943)\n", + "│ │ └─MLP: 3-7 (98,943)\n", "├─OptimizedModule: 1-3 (recursive)\n", "│ └─ImagBehavior: 2-3 (recursive)\n", "│ │ └─WorldModel: 3-8 (recursive)\n", @@ -823,13 +1458,13 @@ "│ │ └─MLP: 3-10 (recursive)\n", "│ │ └─MLP: 3-11 (recursive)\n", "==========================================================================================\n", - "Total params: 9,632,769\n", + "Total params: 1,683,329\n", "Trainable params: 0\n", - "Non-trainable params: 9,632,769\n", + "Non-trainable params: 1,683,329\n", "==========================================================================================" ] }, - "execution_count": 19, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -842,7 +1477,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 11, "metadata": {}, "outputs": [ { @@ -855,34 +1490,34 @@ "├─OptimizedModule: 1-1 --\n", "│ └─WorldModel: 2-1 --\n", "│ │ └─MultiEncoder: 3-1 --\n", - "│ │ │ └─ConvEncoder: 4-1 (42,528)\n", - "│ │ │ └─MLP: 4-2 (1,952)\n", - "│ │ └─RSSM: 3-2 512\n", - "│ │ │ └─Sequential: 4-3 (273,664)\n", - "│ │ │ └─GRUCell: 4-4 (1,182,720)\n", - "│ │ │ └─Sequential: 4-5 (131,584)\n", - "│ │ │ └─Sequential: 4-6 (283,136)\n", - "│ │ │ └─Linear: 4-7 (263,168)\n", - "│ │ │ └─Linear: 4-8 (263,168)\n", + "│ │ │ └─ConvEncoder: 4-1 (83,680)\n", + "│ │ │ └─MLP: 4-2 (912)\n", + "│ │ └─RSSM: 3-2 128\n", + "│ │ │ └─Sequential: 4-3 (38,528)\n", + "│ │ │ └─GRUCell: 4-4 (99,072)\n", + "│ │ │ └─Sequential: 4-5 (16,640)\n", + "│ │ │ └─Sequential: 4-6 (50,432)\n", + "│ │ │ └─Linear: 4-7 (33,024)\n", + "│ │ │ └─Linear: 4-8 (33,024)\n", "│ │ └─ModuleDict: 3-3 --\n", - "│ │ │ └─MultiDecoder: 4-9 (462,764)\n", - "│ │ │ └─MLP: 4-10 (591,359)\n", - "│ │ │ └─MLP: 4-11 (526,081)\n", + "│ │ │ └─MultiDecoder: 4-9 (186,364)\n", + "│ │ │ └─MLP: 4-10 (98,943)\n", + "│ │ │ └─MLP: 4-11 (66,177)\n", "├─OptimizedModule: 1-2 --\n", - "│ └─ImagBehavior: 2-2 4,022,636\n", + "│ └─ImagBehavior: 2-2 706,924\n", "│ │ └─WorldModel: 3-4 (recursive)\n", "│ │ │ └─MultiEncoder: 4-12 (recursive)\n", "│ │ │ └─RSSM: 4-13 (recursive)\n", "│ │ │ └─ModuleDict: 4-14 (recursive)\n", "│ │ └─MLP: 3-5 --\n", - "│ │ │ └─Sequential: 4-15 (525,824)\n", - "│ │ │ └─Linear: 4-16 (11,051)\n", + "│ │ │ └─Sequential: 4-15 (66,048)\n", + "│ │ │ └─Linear: 4-16 (5,547)\n", "│ │ └─MLP: 3-6 --\n", - "│ │ │ └─Sequential: 4-17 (459,776)\n", - "│ │ │ └─Linear: 4-18 (65,535)\n", + "│ │ │ └─Sequential: 4-17 (66,048)\n", + "│ │ │ └─Linear: 4-18 (32,895)\n", "│ │ └─MLP: 3-7 --\n", - "│ │ │ └─Sequential: 4-19 (459,776)\n", - "│ │ │ └─Linear: 4-20 (65,535)\n", + "│ │ │ └─Sequential: 4-19 (66,048)\n", + "│ │ │ └─Linear: 4-20 (32,895)\n", "├─OptimizedModule: 1-3 (recursive)\n", "│ └─ImagBehavior: 2-3 (recursive)\n", "│ │ └─WorldModel: 3-8 (recursive)\n", @@ -899,13 +1534,13 @@ "│ │ │ └─Sequential: 4-28 (recursive)\n", "│ │ │ └─Linear: 4-29 (recursive)\n", "==========================================================================================\n", - "Total params: 9,632,769\n", + "Total params: 1,683,329\n", "Trainable params: 0\n", - "Non-trainable params: 9,632,769\n", + "Non-trainable params: 1,683,329\n", "==========================================================================================" ] }, - "execution_count": 28, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" } @@ -916,7 +1551,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ @@ -932,7 +1567,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ @@ -949,7 +1584,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 14, "metadata": {}, "outputs": [ { @@ -958,35 +1593,43 @@ "===================================================================================================================\n", "Layer (type:depth-idx) Input Shape Output Shape Param #\n", "===================================================================================================================\n", - "MultiEncoder [256, 16, 130, 110, 3] [256, 16, 592] --\n", - "├─ConvEncoder: 1-1 [256, 16, 12, 12, 166] [256, 16, 576] --\n", - "│ └─Sequential: 2-1 [4096, 166, 12, 12] [4096, 16, 6, 6] --\n", - "│ │ └─Conv2dSamePad: 3-1 [4096, 166, 12, 12] [4096, 16, 6, 6] (42,496)\n", - "│ │ └─ImgChLayerNorm: 3-2 [4096, 16, 6, 6] [4096, 16, 6, 6] --\n", - "│ │ │ └─LayerNorm: 4-1 [4096, 6, 6, 16] [4096, 6, 6, 16] (32)\n", - "│ │ └─SiLU: 3-3 [4096, 16, 6, 6] [4096, 16, 6, 6] --\n", - "├─MLP: 1-2 [256, 16, 102] [256, 16, 16] --\n", - "│ └─Sequential: 2-2 [256, 16, 102] [256, 16, 16] --\n", - "│ │ └─Linear: 3-4 [256, 16, 102] [256, 16, 16] (1,632)\n", - "│ │ └─LayerNorm: 3-5 [256, 16, 16] [256, 16, 16] (32)\n", - "│ │ └─SiLU: 3-6 [256, 16, 16] [256, 16, 16] --\n", - "│ │ └─Linear: 3-7 [256, 16, 16] [256, 16, 16] (256)\n", - "│ │ └─LayerNorm: 3-8 [256, 16, 16] [256, 16, 16] (32)\n", - "│ │ └─SiLU: 3-9 [256, 16, 16] [256, 16, 16] --\n", + "MultiEncoder [256, 32, 130, 110, 3] [256, 32, 264] --\n", + "├─ConvEncoder: 1-1 [256, 32, 16, 16, 166] [256, 32, 256] --\n", + "│ └─Sequential: 2-1 [8192, 166, 16, 16] [8192, 64, 2, 2] --\n", + "│ │ └─Conv2dSamePad: 3-1 [8192, 166, 16, 16] [8192, 16, 8, 8] (42,496)\n", + "│ │ └─ImgChLayerNorm: 3-2 [8192, 16, 8, 8] [8192, 16, 8, 8] --\n", + "│ │ │ └─LayerNorm: 4-1 [8192, 8, 8, 16] [8192, 8, 8, 16] (32)\n", + "│ │ └─SiLU: 3-3 [8192, 16, 8, 8] [8192, 16, 8, 8] --\n", + "│ │ └─Conv2dSamePad: 3-4 [8192, 16, 8, 8] [8192, 32, 4, 4] (8,192)\n", + "│ │ └─ImgChLayerNorm: 3-5 [8192, 32, 4, 4] [8192, 32, 4, 4] --\n", + "│ │ │ └─LayerNorm: 4-2 [8192, 4, 4, 32] [8192, 4, 4, 32] (64)\n", + "│ │ └─SiLU: 3-6 [8192, 32, 4, 4] [8192, 32, 4, 4] --\n", + "│ │ └─Conv2dSamePad: 3-7 [8192, 32, 4, 4] [8192, 64, 2, 2] (32,768)\n", + "│ │ └─ImgChLayerNorm: 3-8 [8192, 64, 2, 2] [8192, 64, 2, 2] --\n", + "│ │ │ └─LayerNorm: 4-3 [8192, 2, 2, 64] [8192, 2, 2, 64] (128)\n", + "│ │ └─SiLU: 3-9 [8192, 64, 2, 2] [8192, 64, 2, 2] --\n", + "├─MLP: 1-2 [256, 32, 102] [256, 32, 8] --\n", + "│ └─Sequential: 2-2 [256, 32, 102] [256, 32, 8] --\n", + "│ │ └─Linear: 3-10 [256, 32, 102] [256, 32, 8] (816)\n", + "│ │ └─LayerNorm: 3-11 [256, 32, 8] [256, 32, 8] (16)\n", + "│ │ └─SiLU: 3-12 [256, 32, 8] [256, 32, 8] --\n", + "│ │ └─Linear: 3-13 [256, 32, 8] [256, 32, 8] (64)\n", + "│ │ └─LayerNorm: 3-14 [256, 32, 8] [256, 32, 8] (16)\n", + "│ │ └─SiLU: 3-15 [256, 32, 8] [256, 32, 8] --\n", "===================================================================================================================\n", - "Total params: 44,480\n", + "Total params: 84,592\n", "Trainable params: 0\n", - "Non-trainable params: 44,480\n", - "Total mult-adds (G): 6.27\n", + "Non-trainable params: 84,592\n", + "Total mult-adds (G): 24.43\n", "===================================================================================================================\n", - "Input size (MB): 1367.93\n", - "Forward/backward pass size (MB): 39.85\n", - "Params size (MB): 0.18\n", - "Estimated Total Size (MB): 1407.96\n", + "Input size (MB): 3345.09\n", + "Forward/backward pass size (MB): 236.98\n", + "Params size (MB): 0.34\n", + "Estimated Total Size (MB): 3582.41\n", "===================================================================================================================" ] }, - "execution_count": 21, + "execution_count": 14, "metadata": {}, "output_type": "execute_result" } @@ -997,7 +1640,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 15, "metadata": {}, "outputs": [ { @@ -1008,34 +1651,40 @@ "===================================================================================================================\n", "Layer (type:depth-idx) Input Shape Output Shape Param #\n", "===================================================================================================================\n", - "MultiDecoder [256, 16, 1536] -- --\n", - "├─ConvDecoder: 1-1 [256, 16, 1536] [256, 16, 8, 8, 166] --\n", - "│ └─Linear: 2-1 [256, 16, 1536] [256, 16, 256] (393,472)\n", - "│ └─Sequential: 2-2 [4096, 16, 4, 4] [4096, 166, 8, 8] --\n", - "│ │ └─ConvTranspose2d: 3-1 [4096, 16, 4, 4] [4096, 166, 8, 8] (42,662)\n", - "├─MLP: 1-2 [256, 16, 1536] -- --\n", - "│ └─Sequential: 2-3 [256, 16, 1536] [256, 16, 16] --\n", - "│ │ └─Linear: 3-2 [256, 16, 1536] [256, 16, 16] (24,576)\n", - "│ │ └─LayerNorm: 3-3 [256, 16, 16] [256, 16, 16] (32)\n", - "│ │ └─SiLU: 3-4 [256, 16, 16] [256, 16, 16] --\n", - "│ │ └─Linear: 3-5 [256, 16, 16] [256, 16, 16] (256)\n", - "│ │ └─LayerNorm: 3-6 [256, 16, 16] [256, 16, 16] (32)\n", - "│ │ └─SiLU: 3-7 [256, 16, 16] [256, 16, 16] --\n", + "MultiDecoder [256, 32, 384] -- --\n", + "├─ConvDecoder: 1-1 [256, 32, 384] [256, 32, 16, 16, 166] --\n", + "│ └─Linear: 2-1 [256, 32, 384] [256, 32, 256] (98,560)\n", + "│ └─Sequential: 2-2 [8192, 64, 2, 2] [8192, 166, 16, 16] --\n", + "│ │ └─ConvTranspose2d: 3-1 [8192, 64, 2, 2] [8192, 32, 4, 4] (32,768)\n", + "│ │ └─ImgChLayerNorm: 3-2 [8192, 32, 4, 4] [8192, 32, 4, 4] (64)\n", + "│ │ └─SiLU: 3-3 [8192, 32, 4, 4] [8192, 32, 4, 4] --\n", + "│ │ └─ConvTranspose2d: 3-4 [8192, 32, 4, 4] [8192, 16, 8, 8] (8,192)\n", + "│ │ └─ImgChLayerNorm: 3-5 [8192, 16, 8, 8] [8192, 16, 8, 8] (32)\n", + "│ │ └─SiLU: 3-6 [8192, 16, 8, 8] [8192, 16, 8, 8] --\n", + "│ │ └─ConvTranspose2d: 3-7 [8192, 16, 8, 8] [8192, 166, 16, 16] (42,662)\n", + "├─MLP: 1-2 [256, 32, 384] -- --\n", + "│ └─Sequential: 2-3 [256, 32, 384] [256, 32, 8] --\n", + "│ │ └─Linear: 3-8 [256, 32, 384] [256, 32, 8] (3,072)\n", + "│ │ └─LayerNorm: 3-9 [256, 32, 8] [256, 32, 8] (16)\n", + "│ │ └─SiLU: 3-10 [256, 32, 8] [256, 32, 8] --\n", + "│ │ └─Linear: 3-11 [256, 32, 8] [256, 32, 8] (64)\n", + "│ │ └─LayerNorm: 3-12 [256, 32, 8] [256, 32, 8] (16)\n", + "│ │ └─SiLU: 3-13 [256, 32, 8] [256, 32, 8] --\n", "│ └─ModuleDict: 2-4 -- -- --\n", - "│ │ └─Linear: 3-8 [256, 16, 16] [256, 16, 102] (1,734)\n", + "│ │ └─Linear: 3-14 [256, 32, 8] [256, 32, 102] (918)\n", "===================================================================================================================\n", - "Total params: 462,764\n", + "Total params: 186,364\n", "Trainable params: 0\n", - "Non-trainable params: 462,764\n", - "Total mult-adds (G): 11.29\n", + "Non-trainable params: 186,364\n", + "Total mult-adds (G): 98.09\n", "===================================================================================================================\n", - "Input size (MB): 25.17\n", - "Forward/backward pass size (MB): 361.96\n", - "Params size (MB): 1.85\n", - "Estimated Total Size (MB): 388.97\n", + "Input size (MB): 12.58\n", + "Forward/backward pass size (MB): 3011.90\n", + "Params size (MB): 0.75\n", + "Estimated Total Size (MB): 3025.23\n", "===================================================================================================================\n", - "Summary Failed for reward Failed to run torchinfo. See above stack traces for more details. Executed layers up to: [Sequential: 1, Linear: 2, LayerNorm: 2, SiLU: 2, Linear: 2, LayerNorm: 2, SiLU: 2, Linear: 2, LayerNorm: 2, SiLU: 2, Linear: 1]\n", - "Summary Failed for cont Failed to run torchinfo. See above stack traces for more details. Executed layers up to: [Sequential: 1, Linear: 2, LayerNorm: 2, SiLU: 2, Linear: 2, LayerNorm: 2, SiLU: 2, Linear: 2, LayerNorm: 2, SiLU: 2, Linear: 1]\n" + "Summary Failed for reward Failed to run torchinfo. See above stack traces for more details. Executed layers up to: [Sequential: 1, Linear: 2, LayerNorm: 2, SiLU: 2, Linear: 2, LayerNorm: 2, SiLU: 2, Linear: 1]\n", + "Summary Failed for cont Failed to run torchinfo. See above stack traces for more details. Executed layers up to: [Sequential: 1, Linear: 2, LayerNorm: 2, SiLU: 2, Linear: 2, LayerNorm: 2, SiLU: 2, Linear: 1]\n" ] } ], @@ -1061,7 +1710,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 16, "metadata": {}, "outputs": [], "source": [ @@ -1071,7 +1720,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 17, "metadata": {}, "outputs": [ { @@ -1080,30 +1729,27 @@ "===================================================================================================================\n", "Layer (type:depth-idx) Output Shape Param # Output Shape\n", "===================================================================================================================\n", - "Sequential [256, 16, 256] -- [256, 16, 256]\n", - "├─Linear: 1-1 [256, 16, 256] (393,216) [256, 16, 256]\n", - "├─LayerNorm: 1-2 [256, 16, 256] (512) [256, 16, 256]\n", - "├─SiLU: 1-3 [256, 16, 256] -- [256, 16, 256]\n", - "├─Linear: 1-4 [256, 16, 256] (65,536) [256, 16, 256]\n", - "├─LayerNorm: 1-5 [256, 16, 256] (512) [256, 16, 256]\n", - "├─SiLU: 1-6 [256, 16, 256] -- [256, 16, 256]\n", - "├─Linear: 1-7 [256, 16, 256] (65,536) [256, 16, 256]\n", - "├─LayerNorm: 1-8 [256, 16, 256] (512) [256, 16, 256]\n", - "├─SiLU: 1-9 [256, 16, 256] -- [256, 16, 256]\n", + "Sequential [256, 32, 128] -- [256, 32, 128]\n", + "├─Linear: 1-1 [256, 32, 128] (49,152) [256, 32, 128]\n", + "├─LayerNorm: 1-2 [256, 32, 128] (256) [256, 32, 128]\n", + "├─SiLU: 1-3 [256, 32, 128] -- [256, 32, 128]\n", + "├─Linear: 1-4 [256, 32, 128] (16,384) [256, 32, 128]\n", + "├─LayerNorm: 1-5 [256, 32, 128] (256) [256, 32, 128]\n", + "├─SiLU: 1-6 [256, 32, 128] -- [256, 32, 128]\n", "===================================================================================================================\n", - "Total params: 525,824\n", + "Total params: 66,048\n", "Trainable params: 0\n", - "Non-trainable params: 525,824\n", - "Total mult-adds (M): 134.61\n", + "Non-trainable params: 66,048\n", + "Total mult-adds (M): 16.91\n", "===================================================================================================================\n", - "Input size (MB): 25.17\n", - "Forward/backward pass size (MB): 50.33\n", - "Params size (MB): 2.10\n", - "Estimated Total Size (MB): 77.60\n", + "Input size (MB): 12.58\n", + "Forward/backward pass size (MB): 33.55\n", + "Params size (MB): 0.26\n", + "Estimated Total Size (MB): 46.40\n", "===================================================================================================================" ] }, - "execution_count": 24, + "execution_count": 17, "metadata": {}, "output_type": "execute_result" } @@ -1117,7 +1763,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 18, "metadata": {}, "outputs": [ { @@ -1126,30 +1772,27 @@ "===================================================================================================================\n", "Layer (type:depth-idx) Output Shape Param # Output Shape\n", "===================================================================================================================\n", - "Sequential [256, 16, 256] -- [256, 16, 256]\n", - "├─Linear: 1-1 [256, 16, 256] (393,216) [256, 16, 256]\n", - "├─LayerNorm: 1-2 [256, 16, 256] (512) [256, 16, 256]\n", - "├─SiLU: 1-3 [256, 16, 256] -- [256, 16, 256]\n", - "├─Linear: 1-4 [256, 16, 256] (65,536) [256, 16, 256]\n", - "├─LayerNorm: 1-5 [256, 16, 256] (512) [256, 16, 256]\n", - "├─SiLU: 1-6 [256, 16, 256] -- [256, 16, 256]\n", - "├─Linear: 1-7 [256, 16, 256] (65,536) [256, 16, 256]\n", - "├─LayerNorm: 1-8 [256, 16, 256] (512) [256, 16, 256]\n", - "├─SiLU: 1-9 [256, 16, 256] -- [256, 16, 256]\n", + "Sequential [256, 32, 128] -- [256, 32, 128]\n", + "├─Linear: 1-1 [256, 32, 128] (49,152) [256, 32, 128]\n", + "├─LayerNorm: 1-2 [256, 32, 128] (256) [256, 32, 128]\n", + "├─SiLU: 1-3 [256, 32, 128] -- [256, 32, 128]\n", + "├─Linear: 1-4 [256, 32, 128] (16,384) [256, 32, 128]\n", + "├─LayerNorm: 1-5 [256, 32, 128] (256) [256, 32, 128]\n", + "├─SiLU: 1-6 [256, 32, 128] -- [256, 32, 128]\n", "===================================================================================================================\n", - "Total params: 525,824\n", + "Total params: 66,048\n", "Trainable params: 0\n", - "Non-trainable params: 525,824\n", - "Total mult-adds (M): 134.61\n", + "Non-trainable params: 66,048\n", + "Total mult-adds (M): 16.91\n", "===================================================================================================================\n", - "Input size (MB): 25.17\n", - "Forward/backward pass size (MB): 50.33\n", - "Params size (MB): 2.10\n", - "Estimated Total Size (MB): 77.60\n", + "Input size (MB): 12.58\n", + "Forward/backward pass size (MB): 33.55\n", + "Params size (MB): 0.26\n", + "Estimated Total Size (MB): 46.40\n", "===================================================================================================================" ] }, - "execution_count": 25, + "execution_count": 18, "metadata": {}, "output_type": "execute_result" } @@ -1161,7 +1804,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 19, "metadata": {}, "outputs": [ { @@ -1170,7 +1813,7 @@ "8268" ] }, - "execution_count": 26, + "execution_count": 19, "metadata": {}, "output_type": "execute_result" } diff --git a/nbs/load_runs.ipynb b/nbs/load_runs.ipynb new file mode 100644 index 0000000..59c01cf --- /dev/null +++ b/nbs/load_runs.ipynb @@ -0,0 +1,142 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from craftax.environment_base.util import load_compressed_pickle" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "import os, sys\n", + "os.sys.path.append('/media/wassname/SGIronWolf/projects5/2024/Craftax/craftax/craftax/')" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "ename": "ModuleNotFoundError", + "evalue": "No module named 'craftax.craftax_state'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[14], line 11\u001b[0m\n\u001b[1;32m 8\u001b[0m data \u001b[38;5;241m=\u001b[39m pickle\u001b[38;5;241m.\u001b[39mload(data, errors\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mignore\u001b[39m\u001b[38;5;124m'\u001b[39m, fix_imports\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[1;32m 9\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m data\n\u001b[0;32m---> 11\u001b[0m \u001b[43mload_compressed_pickle\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43m/home/wassname/Downloads/people/run1.pbz2\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m\n", + "Cell \u001b[0;32mIn[14], line 8\u001b[0m, in \u001b[0;36mload_compressed_pickle\u001b[0;34m(file)\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mload_compressed_pickle\u001b[39m(file: \u001b[38;5;28mstr\u001b[39m):\n\u001b[1;32m 7\u001b[0m data \u001b[38;5;241m=\u001b[39m bz2\u001b[38;5;241m.\u001b[39mBZ2File(file, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mrb\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m----> 8\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[43mpickle\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mload\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43merrors\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mignore\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfix_imports\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[1;32m 9\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m data\n", + "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'craftax.craftax_state'" + ] + }, + { + "ename": "", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[1;31mThe Kernel crashed while executing code in the current cell or a previous cell. \n", + "\u001b[1;31mPlease review the code in the cell(s) to identify a possible cause of the failure. \n", + "\u001b[1;31mClick here for more info. \n", + "\u001b[1;31mView Jupyter log for further details." + ] + } + ], + "source": [ + "import pickle, bz2\n", + "import craftax.craftax.craftax_state\n", + "craftax.craftax_state = craftax.craftax.craftax_state\n", + "\n", + "\n", + "def load_compressed_pickle(file: str):\n", + " data = bz2.BZ2File(file, \"rb\")\n", + " data = pickle.load(data, errors='ignore', fix_imports=False)\n", + " return data\n", + "\n", + "load_compressed_pickle('/home/wassname/Downloads/people/run1.pbz2')" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[0;31mSignature:\u001b[0m\n", + "\u001b[0mpickle\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mfile\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mfix_imports\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mencoding\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'ASCII'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0merrors\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'strict'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mbuffers\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mDocstring:\u001b[0m\n", + "Read and return an object from the pickle data stored in a file.\n", + "\n", + "This is equivalent to ``Unpickler(file).load()``, but may be more\n", + "efficient.\n", + "\n", + "The protocol version of the pickle is detected automatically, so no\n", + "protocol argument is needed. Bytes past the pickled object's\n", + "representation are ignored.\n", + "\n", + "The argument *file* must have two methods, a read() method that takes\n", + "an integer argument, and a readline() method that requires no\n", + "arguments. Both methods should return bytes. Thus *file* can be a\n", + "binary file object opened for reading, an io.BytesIO object, or any\n", + "other custom object that meets this interface.\n", + "\n", + "Optional keyword arguments are *fix_imports*, *encoding* and *errors*,\n", + "which are used to control compatibility support for pickle stream\n", + "generated by Python 2. If *fix_imports* is True, pickle will try to\n", + "map the old Python 2 names to the new names used in Python 3. The\n", + "*encoding* and *errors* tell pickle how to decode 8-bit string\n", + "instances pickled by Python 2; these default to 'ASCII' and 'strict',\n", + "respectively. The *encoding* can be 'bytes' to read these 8-bit\n", + "string instances as bytes objects.\n", + "\u001b[0;31mType:\u001b[0m builtin_function_or_method" + ] + } + ], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.16" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/networks.py b/networks.py index a0cca6e..915c82a 100644 --- a/networks.py +++ b/networks.py @@ -9,6 +9,10 @@ from torch import distributions as torchd from loguru import logger import tools from einops import rearrange +from torchinfo import summary + +def my_summary(model, input_data): + return summary(model, input_data, col_names=('input_size', 'output_size', 'num_params', 'mult_adds'), verbose=0, row_settings=['depth', 'var_names', 'ascii_only']) class RSSM(nn.Module): @@ -332,6 +336,7 @@ class MultiEncoder(nn.Module): input_shape, cnn_depth, act, norm, kernel_size, minres ) self.outdim += self._cnn.outdim + logger.debug(f"Encoder cnn\n{my_summary(self._cnn, (1,)+input_shape)}") if self.mlp_shapes: input_size = sum([sum(v) for v in self.mlp_shapes.values()]) self._mlp = MLP( @@ -344,6 +349,7 @@ class MultiEncoder(nn.Module): symlog_inputs=symlog_inputs, name="Encoder", ) + logger.debug(f"Encoder mlp\n{my_summary(self._mlp, (1,input_size))}") self.outdim += mlp_units def forward(self, obs): @@ -405,6 +411,7 @@ class MultiDecoder(nn.Module): outscale=outscale, cnn_sigmoid=cnn_sigmoid, ) + logger.debug(f"Decoder cnn\n{my_summary(self._cnn, (1,1,feat_size))}") if self.mlp_shapes: self._mlp = MLP( feat_size, @@ -417,6 +424,7 @@ class MultiDecoder(nn.Module): outscale=outscale, name="Decoder", ) + logger.debug(f"Decoder mlp\n{my_summary(self._mlp, (1,feat_size))}") self._image_dist = image_dist def forward(self, features): diff --git a/tools.py b/tools.py index fcad775..ef05a82 100644 --- a/tools.py +++ b/tools.py @@ -15,7 +15,7 @@ from torch import nn from torch.nn import functional as F from torch import distributions as torchd from torch.utils.tensorboard import SummaryWriter - +from contextlib import contextmanager to_np = lambda x: x.detach().cpu().numpy() @@ -82,7 +82,7 @@ class Logger: scalars.append(("fps", self._compute_fps(step))) # print out the episode stats stats = " / ".join(f"{k.replace('log_achievement_', '')} {v:.1f}" for k, v in scalars) - logger.opt(colors=True).info(f"[{step}] {stats}") + logger.opt(colors=True).debug(f"[{step}] {stats}") with (self._logdir / "metrics.jsonl").open("a") as f: f.write(json.dumps({"step": step, **dict(scalars)}) + "\n") for name, value in scalars: @@ -127,6 +127,14 @@ class Logger: self._writer.add_video(name, value, step, 16) +@contextmanager +def cond_tqdm(pbar=None, *args, **kwargs): + if pbar is None: + with tqdm(*args, **kwargs) as pbar: + yield pbar + else: + yield pbar + def simulate( agent, envs, @@ -151,7 +159,7 @@ def simulate( reward = [0] * len(envs) else: step, episode, done, length, obs, agent_state, reward = state - with tqdm(total=steps, disable=pbar is None) as pbar: + with cond_tqdm(total=steps, pbar=pbar) as pbar: while (steps and step < steps) or (episodes and episode < episodes): # reset envs if necessary if done.any():