From c77993207a9182ae3fee358cbbdceef3b12836ec Mon Sep 17 00:00:00 2001 From: wassname Date: Thu, 6 Jun 2024 20:29:17 +0800 Subject: [PATCH] try craftax smaller --- configs.yaml | 48 +- dreamer.py | 8 +- envs/craftax_env.py | 10 +- nbs/02_torchinfo copy.ipynb | 846 +++++++++++++++++++++++++++++++++--- nbs/02_torchinfo.ipynb | 17 + networks.py | 14 +- 6 files changed, 827 insertions(+), 116 deletions(-) diff --git a/configs.yaml b/configs.yaml index 03b2c7a..50ee679 100644 --- a/configs.yaml +++ b/configs.yaml @@ -149,26 +149,6 @@ craftax: imag_gradient: 'reinforce' craftax_small: - task: craftax_Craftax-Symbolic-AutoReset-v1 - step: 1e6 - action_repeat: 1 - envs: 1 - train_ratio: 512 - video_pred_log: false - dyn_hidden: 512 - dyn_deter: 512 - units: 512 - encoder: {cnn_keys: '$^', mlp_keys: "state", mlp_layers: 3, mlp_units: 256} - decoder: {cnn_keys: '$^', mlp_keys: "state", mlp_layers: 3, mlp_units: 256} - actor: {layers: 3, dist: 'onehot', std: 'none'} - value: {layers: 3} - reward_head: {layers: 3} - cont_head: {layers: 3} - imag_gradient: 'reinforce' - batch_size: 128 - batch_length: 16 - -craftax_small2: task: craftax_Craftax-Symbolic-AutoReset-v1 step: 1e6 action_repeat: 1 @@ -176,10 +156,12 @@ craftax_small2: train_ratio: 512 video_pred_log: false dyn_hidden: 256 - dyn_deter: 512 + dyn_deter: 256 + dyn_stoch: 24 + dyn_discrete: 24 # note: depth is cnn hidden_dim - encoder: {cnn_keys: 'state_map', cnn_depth: 16, kernel_size: 4, minres: 4, mlp_keys: "state_inventory", mlp_layers: 2, mlp_units: 16} - decoder: {cnn_keys: 'state_map', cnn_depth: 16, kernel_size: 4, minres: 4, mlp_keys: "state_inventory", mlp_layers: 2, mlp_units: 16} + encoder: {cnn_keys: 'state_map', cnn_depth: 32, kernel_size: 4, minres: 2, mlp_keys: "state_inventory", mlp_layers: 2, mlp_units: 16} + decoder: {cnn_keys: 'state_map', cnn_depth: 32, kernel_size: 4, minres: 2, mlp_keys: "state_inventory", mlp_layers: 2, mlp_units: 16} actor: {layers: 3, dist: 'onehot', std: 'none'} value: {layers: 3} # note units is the head hidden_dim @@ -188,27 +170,31 @@ craftax_small2: cont_head: {layers: 3} imag_gradient: 'reinforce' batch_size: 256 - batch_length: 16 + batch_length: 32 craftax_smaller: task: craftax_Craftax-Symbolic-AutoReset-v1 step: 1e6 action_repeat: 1 envs: 1 - train_ratio: 256 + train_ratio: 512 video_pred_log: false - dyn_hidden: 256 - dyn_deter: 1024 - units: 256 - encoder: {cnn_keys: '$^', mlp_keys: "state", mlp_layers: 2, mlp_units: 256, } - decoder: {cnn_keys: '$^', mlp_keys: "state", mlp_layers: 2, mlp_units: 256} + dyn_hidden: 128 + dyn_deter: 128 + dyn_stoch: 16 + dyn_discrete: 16 + # note: depth is cnn hidden_dim + encoder: {cnn_keys: 'state_map', cnn_depth: 16, kernel_size: 4, minres: 2, mlp_keys: "state_inventory", mlp_layers: 2, mlp_units: 8} + decoder: {cnn_keys: 'state_map', cnn_depth: 16, kernel_size: 4, minres: 2, mlp_keys: "state_inventory", mlp_layers: 2, mlp_units: 8} actor: {layers: 2, dist: 'onehot', std: 'none'} value: {layers: 2} + # note units is the head hidden_dim + units: 128 reward_head: {layers: 2} cont_head: {layers: 2} imag_gradient: 'reinforce' batch_size: 256 - batch_length: 16 + batch_length: 32 atari100k: steps: 4e5 diff --git a/dreamer.py b/dreamer.py index ed940e0..79204a3 100644 --- a/dreamer.py +++ b/dreamer.py @@ -360,9 +360,14 @@ def main(config): def parse_args(argv=None): parser = argparse.ArgumentParser() parser.add_argument("--configs", nargs="+") + if argv is None: + argv = sys.argv args, remaining = parser.parse_known_args(argv[1:]) + + # load config, using relative path + root_dir = pathlib.Path(__file__).parent configs = yaml.safe_load( - (pathlib.Path(argv[0]).parent / "configs.yaml").read_text() + (root_dir / "configs.yaml").read_text() ) def recursive_update(base, update): @@ -376,6 +381,7 @@ def parse_args(argv=None): defaults = {} for name in name_list: recursive_update(defaults, configs[name]) + parser = argparse.ArgumentParser() for key, value in sorted(defaults.items(), key=lambda x: x[0]): arg_type = tools.args_type(value) diff --git a/envs/craftax_env.py b/envs/craftax_env.py index 56efcd3..192700e 100644 --- a/envs/craftax_env.py +++ b/envs/craftax_env.py @@ -195,8 +195,8 @@ def reshape_state(state: Float[Tensor, 'frames state_dim']) -> (Float[Tensor,'fr https://github.com/MichaelTMatthews/Craftax/blob/main/obs_description.md """ map = rearrange(state[:, :8217], 'frames (h w c) -> frames h w c', h=9, w=11, c=83) - # now pad from (9,11) to (12,12) using torch - map = F.pad(map, (0, 0, 1, 0, 2, 1)) + # now pad from (9,11) to (16,16) so that it's 2*n + map = F.pad(map, (0, 0, 3, 2, 4, 3)) map = rearrange(map, 'f h w c -> h w (f c)') inventories = rearrange(state[:, 8217:], 'frames c -> (frames c)') return map, inventories @@ -214,9 +214,9 @@ class Craftax: def observation_space(self): frames = self._env.observation_space.shape[0] spaces = { - "state": gym.spaces.Box(0, 1, (np.prod(self._env.observation_space.shape),), dtype=np.float32), - "state_map": gym.spaces.Box(0, 1, (12, 12, frames*83), dtype=np.float32), - "state_inventory": gym.spaces.Box(0, 1, (frames * 51,), dtype=np.float32), + # "state": gym.spaces.Box(0, 1, (np.prod(self._env.observation_space.shape),), dtype=np.float32), + "state_map": gym.spaces.Box(0, 1, (16, 16, frames*83), dtype=np.float16), + "state_inventory": gym.spaces.Box(0, 1, (frames * 51,), dtype=np.float16), "image": gym.spaces.Box(0, 255, (130, 110, 3), dtype=np.uint8), "is_first": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.uint8), "is_last": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.uint8), diff --git a/nbs/02_torchinfo copy.ipynb b/nbs/02_torchinfo copy.ipynb index 7503e74..5886a8c 100644 --- a/nbs/02_torchinfo copy.ipynb +++ b/nbs/02_torchinfo copy.ipynb @@ -9,7 +9,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -19,7 +19,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "metadata": {}, "outputs": [ { @@ -48,7 +48,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -61,10 +61,10 @@ { "data": { "text/plain": [ - "Namespace(act='SiLU', action_repeat=1, actor={'layers': 3, 'dist': 'onehot', 'entropy': 0.0003, 'unimix_ratio': 0.01, 'std': 'none', 'min_std': 0.1, 'max_std': 1.0, 'temp': 0.1, 'lr': 3e-05, 'eps': 1e-05, 'grad_clip': 100.0, 'outscale': 1.0}, batch_length=16, batch_size=256, compile=True, cont_head={'layers': 3, 'loss_scale': 1.0, 'outscale': 1.0}, critic={'layers': 2, 'dist': 'symlog_disc', 'slow_target': True, 'slow_target_update': 1, 'slow_target_fraction': 0.02, 'lr': 3e-05, 'eps': 1e-05, 'grad_clip': 100.0, 'outscale': 0.0}, dataset_size=1000000, debug=False, decoder={'mlp_keys': 'state_inventory', 'cnn_keys': 'state_map', 'act': 'SiLU', 'norm': True, 'cnn_depth': 16, 'kernel_size': 4, 'minres': 4, 'mlp_layers': 2, 'mlp_units': 16, 'cnn_sigmoid': False, 'image_dist': 'mse', 'vector_dist': 'symlog_mse', 'outscale': 1.0}, deterministic_run=False, device='cuda:0', disag_action_cond=False, disag_layers=4, disag_log=True, disag_models=10, disag_offset=1, disag_target='stoch', disag_units=400, discount=0.997, discount_lambda=0.95, dyn_deter=512, dyn_discrete=32, dyn_hidden=256, dyn_mean_act='none', dyn_min_std=0.1, dyn_rec_depth=1, dyn_scale=0.5, dyn_std_act='sigmoid2', dyn_stoch=32, encoder={'mlp_keys': 'state_inventory', 'cnn_keys': 'state_map', 'act': 'SiLU', 'norm': True, 'cnn_depth': 16, 'kernel_size': 4, 'minres': 4, 'mlp_layers': 2, 'mlp_units': 16, 'symlog_inputs': True}, envs=1, eval_episode_num=10, eval_every=10000.0, eval_state_mean=False, evaldir=None, expl_behavior='greedy', expl_extr_scale=0.0, expl_intr_scale=1.0, expl_until=0, grad_clip=1000, grad_heads=('decoder', 'reward', 'cont'), grayscale=False, imag_gradient='reinforce', imag_gradient_mix=0.0, imag_horizon=15, initial='learned', kl_free=1.0, log_every=10000.0, logdir='../logdir/craftax_small2', model_lr=0.0001, norm=True, offline_evaldir='', offline_traindir='', opt='adam', opt_eps=1e-08, parallel=False, precision=32, prefill=2500, pretrain=100, rep_scale=0.1, reset_every=0, reward_EMA=True, reward_head={'layers': 3, 'dist': 'symlog_disc', 'loss_scale': 1.0, 'outscale': 0.0}, seed=0, size=(64, 64), step=1000000.0, steps=1000000.0, task='craftax_Craftax-Symbolic-AutoReset-v1', time_limit=1000, train_ratio=512, traindir=None, unimix_ratio=0.01, units=256, value={'layers': 3}, video_pred_log=False, weight_decay=0.0)" + "Namespace(act='SiLU', action_repeat=1, actor={'layers': 3, 'dist': 'onehot', 'entropy': 0.0003, 'unimix_ratio': 0.01, 'std': 'none', 'min_std': 0.1, 'max_std': 1.0, 'temp': 0.1, 'lr': 3e-05, 'eps': 1e-05, 'grad_clip': 100.0, 'outscale': 1.0}, batch_length=32, batch_size=256, compile=True, cont_head={'layers': 3, 'loss_scale': 1.0, 'outscale': 1.0}, critic={'layers': 2, 'dist': 'symlog_disc', 'slow_target': True, 'slow_target_update': 1, 'slow_target_fraction': 0.02, 'lr': 3e-05, 'eps': 1e-05, 'grad_clip': 100.0, 'outscale': 0.0}, dataset_size=1000000, debug=False, decoder={'mlp_keys': 'state_inventory', 'cnn_keys': 'state_map', 'act': 'SiLU', 'norm': True, 'cnn_depth': 32, 'kernel_size': 4, 'minres': 2, 'mlp_layers': 2, 'mlp_units': 16, 'cnn_sigmoid': False, 'image_dist': 'mse', 'vector_dist': 'symlog_mse', 'outscale': 1.0}, deterministic_run=False, device='cuda:0', disag_action_cond=False, disag_layers=4, disag_log=True, disag_models=10, disag_offset=1, disag_target='stoch', disag_units=400, discount=0.997, discount_lambda=0.95, dyn_deter=256, dyn_discrete=24, dyn_hidden=256, dyn_mean_act='none', dyn_min_std=0.1, dyn_rec_depth=1, dyn_scale=0.5, dyn_std_act='sigmoid2', dyn_stoch=24, encoder={'mlp_keys': 'state_inventory', 'cnn_keys': 'state_map', 'act': 'SiLU', 'norm': True, 'cnn_depth': 32, 'kernel_size': 4, 'minres': 2, 'mlp_layers': 2, 'mlp_units': 16, 'symlog_inputs': True}, envs=1, eval_episode_num=10, eval_every=10000.0, eval_state_mean=False, evaldir=None, expl_behavior='greedy', expl_extr_scale=0.0, expl_intr_scale=1.0, expl_until=0, grad_clip=1000, grad_heads=('decoder', 'reward', 'cont'), grayscale=False, imag_gradient='reinforce', imag_gradient_mix=0.0, imag_horizon=15, initial='learned', kl_free=1.0, log_every=10000.0, logdir='../logdir/craftax_small2', model_lr=0.0001, norm=True, offline_evaldir='', offline_traindir='', opt='adam', opt_eps=1e-08, parallel=False, precision=32, prefill=2500, pretrain=100, rep_scale=0.1, reset_every=0, reward_EMA=True, reward_head={'layers': 3, 'dist': 'symlog_disc', 'loss_scale': 1.0, 'outscale': 0.0}, seed=0, size=(64, 64), step=1000000.0, steps=1000000.0, task='craftax_Craftax-Symbolic-AutoReset-v1', time_limit=1000, train_ratio=512, traindir=None, unimix_ratio=0.01, units=256, value={'layers': 3}, video_pred_log=False, weight_decay=0.0)" ] }, - "execution_count": 3, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } @@ -80,27 +80,19 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "\u001b[32m2024-06-06 16:21:39.870\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m\u001b[0m:\u001b[36m27\u001b[0m - \u001b[1mLogdir ../logdir/craftax_small2\u001b[0m\n", - "\u001b[32m2024-06-06 16:21:39.887\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m\u001b[0m:\u001b[36m36\u001b[0m - \u001b[1mCreate envs.\u001b[0m\n", - "\u001b[32m2024-06-06 16:22:16.800\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m\u001b[0m:\u001b[36m57\u001b[0m - \u001b[1mAction Space Box(0.0, 1.0, (43,), float32)\u001b[0m\n", - "\u001b[32m2024-06-06 16:22:16.801\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m\u001b[0m:\u001b[36m63\u001b[0m - \u001b[1mPrefill dataset (2500 steps).\u001b[0m\n", - "\u001b[32m2024-06-06 16:23:40.174\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mtools\u001b[0m:\u001b[36mwrite\u001b[0m:\u001b[36m85\u001b[0m - \u001b[1m[0] log_achievements/cast_fireball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/cast_iceball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_coal \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_diamond \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_drink \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_iron \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_ruby \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_sapling \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_sapphire \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_stone \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_wood \u001b[31m100.0\u001b[0m\u001b[1m / log_achievements/damage_necromancer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_archer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_deep_thing \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_fire_elemental \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_frost_troll \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_gnome_archer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_gnome_warrior \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_ice_elemental \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_knight \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_kobold \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_lizard \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_necromancer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_orc_mage \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_orc_solider \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_pigman \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_skeleton \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_troll \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_zombie \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/drink_potion \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_bat \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_cow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_plant \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_snail \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enchant_armour \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enchant_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_dungeon \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_fire_realm \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_gnomish_mines \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_graveyard \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_ice_realm \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_sewers \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_troll_mines \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_vault \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/find_bow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/fire_bow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/learn_fireball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/learn_iceball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_arrow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_diamond_armour \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_diamond_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_diamond_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_iron_armour \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_iron_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_iron_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_stone_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_stone_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_torch \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_wood_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_wood_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/open_chest \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_furnace \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_plant \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_stone \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_table \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_torch \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/wake_up \u001b[31m100.0\u001b[0m\u001b[1m / dataset_size \u001b[31m242.0\u001b[0m\u001b[1m / train_return \u001b[31m1.1\u001b[0m\u001b[1m / train_length \u001b[31m242.0\u001b[0m\u001b[1m / train_episodes \u001b[31m1.0\u001b[0m\u001b[1m\u001b[0m\n", - "\u001b[32m2024-06-06 16:23:42.913\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mtools\u001b[0m:\u001b[36mwrite\u001b[0m:\u001b[36m85\u001b[0m - \u001b[1m[0] log_achievements/cast_fireball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/cast_iceball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_coal \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_diamond \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_drink \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_iron \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_ruby \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_sapling \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_sapphire \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_stone \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_wood \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/damage_necromancer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_archer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_deep_thing \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_fire_elemental \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_frost_troll \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_gnome_archer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_gnome_warrior \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_ice_elemental \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_knight \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_kobold \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_lizard \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_necromancer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_orc_mage \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_orc_solider \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_pigman \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_skeleton \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_troll \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_zombie \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/drink_potion \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_bat \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_cow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_plant \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_snail \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enchant_armour \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enchant_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_dungeon \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_fire_realm \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_gnomish_mines \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_graveyard \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_ice_realm \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_sewers \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_troll_mines \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_vault \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/find_bow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/fire_bow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/learn_fireball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/learn_iceball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_arrow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_diamond_armour \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_diamond_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_diamond_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_iron_armour \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_iron_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_iron_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_stone_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_stone_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_torch \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_wood_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_wood_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/open_chest \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_furnace \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_plant \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_stone \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_table \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_torch \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/wake_up \u001b[31m100.0\u001b[0m\u001b[1m / dataset_size \u001b[31m551.0\u001b[0m\u001b[1m / train_return \u001b[31m0.1\u001b[0m\u001b[1m / train_length \u001b[31m309.0\u001b[0m\u001b[1m / train_episodes \u001b[31m2.0\u001b[0m\u001b[1m\u001b[0m\n", - "\u001b[32m2024-06-06 16:23:44.957\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mtools\u001b[0m:\u001b[36mwrite\u001b[0m:\u001b[36m85\u001b[0m - \u001b[1m[0] log_achievements/cast_fireball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/cast_iceball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_coal \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_diamond \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_drink \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_iron \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_ruby \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_sapling \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_sapphire \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_stone \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_wood \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/damage_necromancer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_archer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_deep_thing \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_fire_elemental \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_frost_troll \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_gnome_archer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_gnome_warrior \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_ice_elemental \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_knight \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_kobold \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_lizard \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_necromancer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_orc_mage \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_orc_solider \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_pigman \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_skeleton \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_troll \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_zombie \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/drink_potion \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_bat \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_cow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_plant \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_snail \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enchant_armour \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enchant_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_dungeon \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_fire_realm \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_gnomish_mines \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_graveyard \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_ice_realm \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_sewers \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_troll_mines \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_vault \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/find_bow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/fire_bow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/learn_fireball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/learn_iceball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_arrow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_diamond_armour \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_diamond_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_diamond_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_iron_armour \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_iron_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_iron_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_stone_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_stone_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_torch \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_wood_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_wood_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/open_chest \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_furnace \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_plant \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_stone \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_table \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_torch \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/wake_up \u001b[31m100.0\u001b[0m\u001b[1m / dataset_size \u001b[31m781.0\u001b[0m\u001b[1m / train_return \u001b[31m0.1\u001b[0m\u001b[1m / train_length \u001b[31m230.0\u001b[0m\u001b[1m / train_episodes \u001b[31m3.0\u001b[0m\u001b[1m\u001b[0m\n", - "\u001b[32m2024-06-06 16:23:48.066\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mtools\u001b[0m:\u001b[36mwrite\u001b[0m:\u001b[36m85\u001b[0m - \u001b[1m[0] log_achievements/cast_fireball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/cast_iceball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_coal \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_diamond \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_drink \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_iron \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_ruby \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_sapling \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_sapphire \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_stone \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_wood \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/damage_necromancer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_archer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_deep_thing \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_fire_elemental \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_frost_troll \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_gnome_archer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_gnome_warrior \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_ice_elemental \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_knight \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_kobold \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_lizard \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_necromancer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_orc_mage \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_orc_solider \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_pigman \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_skeleton \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_troll \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_zombie \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/drink_potion \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_bat \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_cow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_plant \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_snail \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enchant_armour \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enchant_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_dungeon \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_fire_realm \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_gnomish_mines \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_graveyard \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_ice_realm \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_sewers \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_troll_mines \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_vault \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/find_bow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/fire_bow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/learn_fireball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/learn_iceball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_arrow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_diamond_armour \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_diamond_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_diamond_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_iron_armour \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_iron_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_iron_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_stone_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_stone_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_torch \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_wood_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_wood_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/open_chest \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_furnace \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_plant \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_stone \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_table \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_torch \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/wake_up \u001b[31m100.0\u001b[0m\u001b[1m / dataset_size \u001b[31m1134.0\u001b[0m\u001b[1m / train_return \u001b[31m0.1\u001b[0m\u001b[1m / train_length \u001b[31m353.0\u001b[0m\u001b[1m / train_episodes \u001b[31m4.0\u001b[0m\u001b[1m\u001b[0m\n", - "\u001b[32m2024-06-06 16:23:51.250\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mtools\u001b[0m:\u001b[36mwrite\u001b[0m:\u001b[36m85\u001b[0m - \u001b[1m[0] log_achievements/cast_fireball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/cast_iceball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_coal \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_diamond \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_drink \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_iron \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_ruby \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_sapling \u001b[31m100.0\u001b[0m\u001b[1m / log_achievements/collect_sapphire \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_stone \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_wood \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/damage_necromancer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_archer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_deep_thing \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_fire_elemental \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_frost_troll \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_gnome_archer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_gnome_warrior \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_ice_elemental \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_knight \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_kobold \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_lizard \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_necromancer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_orc_mage \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_orc_solider \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_pigman \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_skeleton \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_troll \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_zombie \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/drink_potion \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_bat \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_cow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_plant \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_snail \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enchant_armour \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enchant_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_dungeon \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_fire_realm \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_gnomish_mines \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_graveyard \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_ice_realm \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_sewers \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_troll_mines \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_vault \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/find_bow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/fire_bow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/learn_fireball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/learn_iceball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_arrow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_diamond_armour \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_diamond_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_diamond_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_iron_armour \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_iron_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_iron_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_stone_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_stone_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_torch \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_wood_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_wood_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/open_chest \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_furnace \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_plant \u001b[31m100.0\u001b[0m\u001b[1m / log_achievements/place_stone \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_table \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_torch \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/wake_up \u001b[31m100.0\u001b[0m\u001b[1m / dataset_size \u001b[31m1509.0\u001b[0m\u001b[1m / train_return \u001b[31m2.1\u001b[0m\u001b[1m / train_length \u001b[31m375.0\u001b[0m\u001b[1m / train_episodes \u001b[31m5.0\u001b[0m\u001b[1m\u001b[0m\n", - "\u001b[32m2024-06-06 16:23:54.187\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mtools\u001b[0m:\u001b[36mwrite\u001b[0m:\u001b[36m85\u001b[0m - \u001b[1m[0] log_achievements/cast_fireball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/cast_iceball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_coal \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_diamond \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_drink \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_iron \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_ruby \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_sapling \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_sapphire \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_stone \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_wood \u001b[31m100.0\u001b[0m\u001b[1m / log_achievements/damage_necromancer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_archer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_deep_thing \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_fire_elemental \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_frost_troll \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_gnome_archer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_gnome_warrior \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_ice_elemental \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_knight \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_kobold \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_lizard \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_necromancer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_orc_mage \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_orc_solider \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_pigman \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_skeleton \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_troll \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_zombie \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/drink_potion \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_bat \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_cow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_plant \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_snail \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enchant_armour \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enchant_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_dungeon \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_fire_realm \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_gnomish_mines \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_graveyard \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_ice_realm \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_sewers \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_troll_mines \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_vault \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/find_bow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/fire_bow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/learn_fireball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/learn_iceball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_arrow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_diamond_armour \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_diamond_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_diamond_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_iron_armour \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_iron_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_iron_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_stone_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_stone_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_torch \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_wood_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_wood_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/open_chest \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_furnace \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_plant \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_stone \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_table \u001b[31m100.0\u001b[0m\u001b[1m / log_achievements/place_torch \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/wake_up \u001b[31m100.0\u001b[0m\u001b[1m / dataset_size \u001b[31m1827.0\u001b[0m\u001b[1m / train_return \u001b[31m2.1\u001b[0m\u001b[1m / train_length \u001b[31m318.0\u001b[0m\u001b[1m / train_episodes \u001b[31m6.0\u001b[0m\u001b[1m\u001b[0m\n", - "\u001b[32m2024-06-06 16:23:56.423\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mtools\u001b[0m:\u001b[36mwrite\u001b[0m:\u001b[36m85\u001b[0m - \u001b[1m[0] log_achievements/cast_fireball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/cast_iceball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_coal \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_diamond \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_drink \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_iron \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_ruby \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_sapling \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_sapphire \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_stone \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_wood \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/damage_necromancer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_archer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_deep_thing \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_fire_elemental \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_frost_troll \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_gnome_archer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_gnome_warrior \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_ice_elemental \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_knight \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_kobold \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_lizard \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_necromancer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_orc_mage \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_orc_solider \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_pigman \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_skeleton \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_troll \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_zombie \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/drink_potion \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_bat \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_cow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_plant \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_snail \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enchant_armour \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enchant_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_dungeon \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_fire_realm \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_gnomish_mines \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_graveyard \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_ice_realm \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_sewers \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_troll_mines \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_vault \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/find_bow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/fire_bow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/learn_fireball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/learn_iceball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_arrow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_diamond_armour \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_diamond_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_diamond_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_iron_armour \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_iron_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_iron_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_stone_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_stone_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_torch \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_wood_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_wood_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/open_chest \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_furnace \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_plant \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_stone \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_table \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_torch \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/wake_up \u001b[31m100.0\u001b[0m\u001b[1m / dataset_size \u001b[31m2084.0\u001b[0m\u001b[1m / train_return \u001b[31m0.1\u001b[0m\u001b[1m / train_length \u001b[31m257.0\u001b[0m\u001b[1m / train_episodes \u001b[31m7.0\u001b[0m\u001b[1m\u001b[0m\n", - "\u001b[32m2024-06-06 16:23:59.835\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mtools\u001b[0m:\u001b[36mwrite\u001b[0m:\u001b[36m85\u001b[0m - \u001b[1m[0] log_achievements/cast_fireball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/cast_iceball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_coal \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_diamond \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_drink \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_iron \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_ruby \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_sapling \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_sapphire \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_stone \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_wood \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/damage_necromancer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_archer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_deep_thing \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_fire_elemental \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_frost_troll \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_gnome_archer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_gnome_warrior \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_ice_elemental \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_knight \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_kobold \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_lizard \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_necromancer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_orc_mage \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_orc_solider \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_pigman \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_skeleton \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_troll \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_zombie \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/drink_potion \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_bat \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_cow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_plant \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_snail \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enchant_armour \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enchant_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_dungeon \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_fire_realm \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_gnomish_mines \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_graveyard \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_ice_realm \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_sewers \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_troll_mines \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_vault \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/find_bow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/fire_bow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/learn_fireball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/learn_iceball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_arrow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_diamond_armour \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_diamond_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_diamond_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_iron_armour \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_iron_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_iron_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_stone_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_stone_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_torch \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_wood_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_wood_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/open_chest \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_furnace \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_plant \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_stone \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_table \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_torch \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/wake_up \u001b[31m100.0\u001b[0m\u001b[1m / dataset_size \u001b[31m2474.0\u001b[0m\u001b[1m / train_return \u001b[31m0.1\u001b[0m\u001b[1m / train_length \u001b[31m390.0\u001b[0m\u001b[1m / train_episodes \u001b[31m8.0\u001b[0m\u001b[1m\u001b[0m\n", - "\u001b[32m2024-06-06 16:24:00.056\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m\u001b[0m:\u001b[36m92\u001b[0m - \u001b[1mLogger: (2500 steps).\u001b[0m\n", - "\u001b[32m2024-06-06 16:24:00.057\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m\u001b[0m:\u001b[36m94\u001b[0m - \u001b[1mSimulate agent.\u001b[0m\n" + "\u001b[32m2024-06-06 17:08:10.379\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m\u001b[0m:\u001b[36m27\u001b[0m - \u001b[1mLogdir ../logdir/craftax_small2\u001b[0m\n", + "\u001b[32m2024-06-06 17:08:10.384\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m\u001b[0m:\u001b[36m36\u001b[0m - \u001b[1mCreate envs.\u001b[0m\n", + "\u001b[32m2024-06-06 17:08:41.190\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m\u001b[0m:\u001b[36m57\u001b[0m - \u001b[1mAction Space Box(0.0, 1.0, (43,), float32)\u001b[0m\n", + "\u001b[32m2024-06-06 17:08:41.191\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m\u001b[0m:\u001b[36m63\u001b[0m - \u001b[1mPrefill dataset (26 steps).\u001b[0m\n", + "\u001b[32m2024-06-06 17:09:31.587\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m\u001b[0m:\u001b[36m92\u001b[0m - \u001b[1mLogger: (2500 steps).\u001b[0m\n", + "\u001b[32m2024-06-06 17:09:31.588\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m\u001b[0m:\u001b[36m94\u001b[0m - \u001b[1mSimulate agent.\u001b[0m\n" ] } ], @@ -205,18 +197,476 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/media/wassname/SGIronWolf/projects5/worldmodels/dreamerv3-torch/.venv/lib/python3.9/site-packages/numpy/core/numeric.py:330: RuntimeWarning: invalid value encountered in cast\n", + " multiarray.copyto(a, fill_value, casting='unsafe')\n" + ] + }, + { + "data": { + "text/plain": [ + "Dict('image': Box(0, 255, (130, 110, 3), uint8), 'is_first': Box(0, 0, (1,), uint8), 'is_last': Box(0, 0, (1,), uint8), 'is_terminal': Box(0, 0, (1,), uint8), 'log_achievement_cast_fireball': Box(-inf, inf, (1,), float32), 'log_achievement_cast_iceball': Box(-inf, inf, (1,), float32), 'log_achievement_collect_coal': Box(-inf, inf, (1,), float32), 'log_achievement_collect_diamond': Box(-inf, inf, (1,), float32), 'log_achievement_collect_drink': Box(-inf, inf, (1,), float32), 'log_achievement_collect_iron': Box(-inf, inf, (1,), float32), 'log_achievement_collect_ruby': Box(-inf, inf, (1,), float32), 'log_achievement_collect_sapling': Box(-inf, inf, (1,), float32), 'log_achievement_collect_sapphire': Box(-inf, inf, (1,), float32), 'log_achievement_collect_stone': Box(-inf, inf, (1,), float32), 'log_achievement_collect_wood': Box(-inf, inf, (1,), float32), 'log_achievement_damage_necromancer': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_archer': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_deep_thing': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_fire_elemental': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_frost_troll': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_gnome_archer': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_gnome_warrior': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_ice_elemental': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_knight': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_kobold': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_lizard': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_necromancer': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_orc_mage': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_orc_solider': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_pigman': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_skeleton': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_troll': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_zombie': Box(-inf, inf, (1,), float32), 'log_achievement_drink_potion': Box(-inf, inf, (1,), float32), 'log_achievement_eat_bat': Box(-inf, inf, (1,), float32), 'log_achievement_eat_cow': Box(-inf, inf, (1,), float32), 'log_achievement_eat_plant': Box(-inf, inf, (1,), float32), 'log_achievement_eat_snail': Box(-inf, inf, (1,), float32), 'log_achievement_enchant_armour': Box(-inf, inf, (1,), float32), 'log_achievement_enchant_sword': Box(-inf, inf, (1,), float32), 'log_achievement_enter_dungeon': Box(-inf, inf, (1,), float32), 'log_achievement_enter_fire_realm': Box(-inf, inf, (1,), float32), 'log_achievement_enter_gnomish_mines': Box(-inf, inf, (1,), float32), 'log_achievement_enter_graveyard': Box(-inf, inf, (1,), float32), 'log_achievement_enter_ice_realm': Box(-inf, inf, (1,), float32), 'log_achievement_enter_sewers': Box(-inf, inf, (1,), float32), 'log_achievement_enter_troll_mines': Box(-inf, inf, (1,), float32), 'log_achievement_enter_vault': Box(-inf, inf, (1,), float32), 'log_achievement_find_bow': Box(-inf, inf, (1,), float32), 'log_achievement_fire_bow': Box(-inf, inf, (1,), float32), 'log_achievement_learn_fireball': Box(-inf, inf, (1,), float32), 'log_achievement_learn_iceball': Box(-inf, inf, (1,), float32), 'log_achievement_make_arrow': Box(-inf, inf, (1,), float32), 'log_achievement_make_diamond_armour': Box(-inf, inf, (1,), float32), 'log_achievement_make_diamond_pickaxe': Box(-inf, inf, (1,), float32), 'log_achievement_make_diamond_sword': Box(-inf, inf, (1,), float32), 'log_achievement_make_iron_armour': Box(-inf, inf, (1,), float32), 'log_achievement_make_iron_pickaxe': Box(-inf, inf, (1,), float32), 'log_achievement_make_iron_sword': Box(-inf, inf, (1,), float32), 'log_achievement_make_stone_pickaxe': Box(-inf, inf, (1,), float32), 'log_achievement_make_stone_sword': Box(-inf, inf, (1,), float32), 'log_achievement_make_torch': Box(-inf, inf, (1,), float32), 'log_achievement_make_wood_pickaxe': Box(-inf, inf, (1,), float32), 'log_achievement_make_wood_sword': Box(-inf, inf, (1,), float32), 'log_achievement_open_chest': Box(-inf, inf, (1,), float32), 'log_achievement_place_furnace': Box(-inf, inf, (1,), float32), 'log_achievement_place_plant': Box(-inf, inf, (1,), float32), 'log_achievement_place_stone': Box(-inf, inf, (1,), float32), 'log_achievement_place_table': Box(-inf, inf, (1,), float32), 'log_achievement_place_torch': Box(-inf, inf, (1,), float32), 'log_achievement_wake_up': Box(-inf, inf, (1,), float32), 'log_reward': Box(-inf, inf, (1,), float32), 'state': Box(0.0, 1.0, (16536,), float32), 'state_inventory': Box(0.0, 1.0, (102,), float32), 'state_map': Box(0.0, 1.0, (12, 12, 166), float32))" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "train_envs[0].observation_space" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[32m2024-06-06 17:09:31.695\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mnetworks\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m324\u001b[0m - \u001b[1mEncoder CNN shapes: {'state_map': (12, 12, 166)}\u001b[0m\n", + "\u001b[32m2024-06-06 17:09:31.696\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mnetworks\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m325\u001b[0m - \u001b[1mEncoder MLP shapes: {'state_inventory': (102,)}\u001b[0m\n", + "\u001b[32m2024-06-06 17:09:31.913\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mnetworks\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m391\u001b[0m - \u001b[1mDecoder CNN shapes: {'state_map': (12, 12, 166)}\u001b[0m\n", + "\u001b[32m2024-06-06 17:09:31.914\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mnetworks\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m392\u001b[0m - \u001b[1mDecoder MLP shapes: {'state_inventory': (102,)}\u001b[0m\n", + "\u001b[32m2024-06-06 17:09:32.650\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmodels\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m102\u001b[0m - \u001b[1mOptimizer model_opt has 2357196 variables.\u001b[0m\n", + "\u001b[32m2024-06-06 17:09:32.657\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmodels\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m281\u001b[0m - \u001b[1mOptimizer actor_opt has 356651 variables.\u001b[0m\n", + "\u001b[32m2024-06-06 17:09:32.657\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmodels\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m292\u001b[0m - \u001b[1mOptimizer value_opt has 345087 variables.\u001b[0m\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/wassname/miniforge3/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n", + " self.pid = os.fork()\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Dreamer(\n", + " (_wm): OptimizedModule(\n", + " (_orig_mod): WorldModel(\n", + " (encoder): MultiEncoder(\n", + " (_cnn): ConvEncoder(\n", + " (layers): Sequential(\n", + " (0): Conv2dSamePad(166, 32, kernel_size=(4, 4), stride=(2, 2), bias=False)\n", + " (1): ImgChLayerNorm(\n", + " (norm): LayerNorm((32,), eps=0.001, elementwise_affine=True)\n", + " )\n", + " (2): SiLU()\n", + " (3): Conv2dSamePad(32, 64, kernel_size=(4, 4), stride=(2, 2), bias=False)\n", + " (4): ImgChLayerNorm(\n", + " (norm): LayerNorm((64,), eps=0.001, elementwise_affine=True)\n", + " )\n", + " (5): SiLU()\n", + " )\n", + " )\n", + " (_mlp): MLP(\n", + " (layers): Sequential(\n", + " (Encoder_linear0): Linear(in_features=102, out_features=16, bias=False)\n", + " (Encoder_norm0): LayerNorm((16,), eps=0.001, elementwise_affine=True)\n", + " (Encoder_act0): SiLU()\n", + " (Encoder_linear1): Linear(in_features=16, out_features=16, bias=False)\n", + " (Encoder_norm1): LayerNorm((16,), eps=0.001, elementwise_affine=True)\n", + " (Encoder_act1): SiLU()\n", + " )\n", + " )\n", + " )\n", + " (dynamics): RSSM(\n", + " (_img_in_layers): Sequential(\n", + " (0): Linear(in_features=619, out_features=256, bias=False)\n", + " (1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n", + " (2): SiLU()\n", + " )\n", + " (_cell): GRUCell(\n", + " (layers): Sequential(\n", + " (GRU_linear): Linear(in_features=512, out_features=768, bias=False)\n", + " (GRU_norm): LayerNorm((768,), eps=0.001, elementwise_affine=True)\n", + " )\n", + " )\n", + " (_img_out_layers): Sequential(\n", + " (0): Linear(in_features=256, out_features=256, bias=False)\n", + " (1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n", + " (2): SiLU()\n", + " )\n", + " (_obs_out_layers): Sequential(\n", + " (0): Linear(in_features=848, out_features=256, bias=False)\n", + " (1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n", + " (2): SiLU()\n", + " )\n", + " (_imgs_stat_layer): Linear(in_features=256, out_features=576, bias=True)\n", + " (_obs_stat_layer): Linear(in_features=256, out_features=576, bias=True)\n", + " )\n", + " (heads): ModuleDict(\n", + " (decoder): MultiDecoder(\n", + " (_cnn): ConvDecoder(\n", + " (_linear_layer): Linear(in_features=832, out_features=256, bias=True)\n", + " (layers): Sequential(\n", + " (0): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)\n", + " (1): ImgChLayerNorm(\n", + " (norm): LayerNorm((32,), eps=0.001, elementwise_affine=True)\n", + " )\n", + " (2): SiLU()\n", + " (3): ConvTranspose2d(32, 166, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", + " )\n", + " )\n", + " (_mlp): MLP(\n", + " (layers): Sequential(\n", + " (Decoder_linear0): Linear(in_features=832, out_features=16, bias=False)\n", + " (Decoder_norm0): LayerNorm((16,), eps=0.001, elementwise_affine=True)\n", + " (Decoder_act0): SiLU()\n", + " (Decoder_linear1): Linear(in_features=16, out_features=16, bias=False)\n", + " (Decoder_norm1): LayerNorm((16,), eps=0.001, elementwise_affine=True)\n", + " (Decoder_act1): SiLU()\n", + " )\n", + " (mean_layer): ModuleDict(\n", + " (state_inventory): Linear(in_features=16, out_features=102, bias=True)\n", + " )\n", + " )\n", + " )\n", + " (reward): MLP(\n", + " (layers): Sequential(\n", + " (Reward_linear0): Linear(in_features=832, out_features=256, bias=False)\n", + " (Reward_norm0): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n", + " (Reward_act0): SiLU()\n", + " (Reward_linear1): Linear(in_features=256, out_features=256, bias=False)\n", + " (Reward_norm1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n", + " (Reward_act1): SiLU()\n", + " (Reward_linear2): Linear(in_features=256, out_features=256, bias=False)\n", + " (Reward_norm2): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n", + " (Reward_act2): SiLU()\n", + " )\n", + " (mean_layer): Linear(in_features=256, out_features=255, bias=True)\n", + " )\n", + " (cont): MLP(\n", + " (layers): Sequential(\n", + " (Cont_linear0): Linear(in_features=832, out_features=256, bias=False)\n", + " (Cont_norm0): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n", + " (Cont_act0): SiLU()\n", + " (Cont_linear1): Linear(in_features=256, out_features=256, bias=False)\n", + " (Cont_norm1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n", + " (Cont_act1): SiLU()\n", + " (Cont_linear2): Linear(in_features=256, out_features=256, bias=False)\n", + " (Cont_norm2): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n", + " (Cont_act2): SiLU()\n", + " )\n", + " (mean_layer): Linear(in_features=256, out_features=1, bias=True)\n", + " )\n", + " )\n", + " )\n", + " )\n", + " (_task_behavior): OptimizedModule(\n", + " (_orig_mod): ImagBehavior(\n", + " (_world_model): WorldModel(\n", + " (encoder): MultiEncoder(\n", + " (_cnn): ConvEncoder(\n", + " (layers): Sequential(\n", + " (0): Conv2dSamePad(166, 32, kernel_size=(4, 4), stride=(2, 2), bias=False)\n", + " (1): ImgChLayerNorm(\n", + " (norm): LayerNorm((32,), eps=0.001, elementwise_affine=True)\n", + " )\n", + " (2): SiLU()\n", + " (3): Conv2dSamePad(32, 64, kernel_size=(4, 4), stride=(2, 2), bias=False)\n", + " (4): ImgChLayerNorm(\n", + " (norm): LayerNorm((64,), eps=0.001, elementwise_affine=True)\n", + " )\n", + " (5): SiLU()\n", + " )\n", + " )\n", + " (_mlp): MLP(\n", + " (layers): Sequential(\n", + " (Encoder_linear0): Linear(in_features=102, out_features=16, bias=False)\n", + " (Encoder_norm0): LayerNorm((16,), eps=0.001, elementwise_affine=True)\n", + " (Encoder_act0): SiLU()\n", + " (Encoder_linear1): Linear(in_features=16, out_features=16, bias=False)\n", + " (Encoder_norm1): LayerNorm((16,), eps=0.001, elementwise_affine=True)\n", + " (Encoder_act1): SiLU()\n", + " )\n", + " )\n", + " )\n", + " (dynamics): RSSM(\n", + " (_img_in_layers): Sequential(\n", + " (0): Linear(in_features=619, out_features=256, bias=False)\n", + " (1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n", + " (2): SiLU()\n", + " )\n", + " (_cell): GRUCell(\n", + " (layers): Sequential(\n", + " (GRU_linear): Linear(in_features=512, out_features=768, bias=False)\n", + " (GRU_norm): LayerNorm((768,), eps=0.001, elementwise_affine=True)\n", + " )\n", + " )\n", + " (_img_out_layers): Sequential(\n", + " (0): Linear(in_features=256, out_features=256, bias=False)\n", + " (1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n", + " (2): SiLU()\n", + " )\n", + " (_obs_out_layers): Sequential(\n", + " (0): Linear(in_features=848, out_features=256, bias=False)\n", + " (1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n", + " (2): SiLU()\n", + " )\n", + " (_imgs_stat_layer): Linear(in_features=256, out_features=576, bias=True)\n", + " (_obs_stat_layer): Linear(in_features=256, out_features=576, bias=True)\n", + " )\n", + " (heads): ModuleDict(\n", + " (decoder): MultiDecoder(\n", + " (_cnn): ConvDecoder(\n", + " (_linear_layer): Linear(in_features=832, out_features=256, bias=True)\n", + " (layers): Sequential(\n", + " (0): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)\n", + " (1): ImgChLayerNorm(\n", + " (norm): LayerNorm((32,), eps=0.001, elementwise_affine=True)\n", + " )\n", + " (2): SiLU()\n", + " (3): ConvTranspose2d(32, 166, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", + " )\n", + " )\n", + " (_mlp): MLP(\n", + " (layers): Sequential(\n", + " (Decoder_linear0): Linear(in_features=832, out_features=16, bias=False)\n", + " (Decoder_norm0): LayerNorm((16,), eps=0.001, elementwise_affine=True)\n", + " (Decoder_act0): SiLU()\n", + " (Decoder_linear1): Linear(in_features=16, out_features=16, bias=False)\n", + " (Decoder_norm1): LayerNorm((16,), eps=0.001, elementwise_affine=True)\n", + " (Decoder_act1): SiLU()\n", + " )\n", + " (mean_layer): ModuleDict(\n", + " (state_inventory): Linear(in_features=16, out_features=102, bias=True)\n", + " )\n", + " )\n", + " )\n", + " (reward): MLP(\n", + " (layers): Sequential(\n", + " (Reward_linear0): Linear(in_features=832, out_features=256, bias=False)\n", + " (Reward_norm0): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n", + " (Reward_act0): SiLU()\n", + " (Reward_linear1): Linear(in_features=256, out_features=256, bias=False)\n", + " (Reward_norm1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n", + " (Reward_act1): SiLU()\n", + " (Reward_linear2): Linear(in_features=256, out_features=256, bias=False)\n", + " (Reward_norm2): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n", + " (Reward_act2): SiLU()\n", + " )\n", + " (mean_layer): Linear(in_features=256, out_features=255, bias=True)\n", + " )\n", + " (cont): MLP(\n", + " (layers): Sequential(\n", + " (Cont_linear0): Linear(in_features=832, out_features=256, bias=False)\n", + " (Cont_norm0): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n", + " (Cont_act0): SiLU()\n", + " (Cont_linear1): Linear(in_features=256, out_features=256, bias=False)\n", + " (Cont_norm1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n", + " (Cont_act1): SiLU()\n", + " (Cont_linear2): Linear(in_features=256, out_features=256, bias=False)\n", + " (Cont_norm2): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n", + " (Cont_act2): SiLU()\n", + " )\n", + " (mean_layer): Linear(in_features=256, out_features=1, bias=True)\n", + " )\n", + " )\n", + " )\n", + " (actor): MLP(\n", + " (layers): Sequential(\n", + " (Actor_linear0): Linear(in_features=832, out_features=256, bias=False)\n", + " (Actor_norm0): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n", + " (Actor_act0): SiLU()\n", + " (Actor_linear1): Linear(in_features=256, out_features=256, bias=False)\n", + " (Actor_norm1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n", + " (Actor_act1): SiLU()\n", + " (Actor_linear2): Linear(in_features=256, out_features=256, bias=False)\n", + " (Actor_norm2): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n", + " (Actor_act2): SiLU()\n", + " )\n", + " (mean_layer): Linear(in_features=256, out_features=43, bias=True)\n", + " )\n", + " (value): MLP(\n", + " (layers): Sequential(\n", + " (Value_linear0): Linear(in_features=832, out_features=256, bias=False)\n", + " (Value_norm0): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n", + " (Value_act0): SiLU()\n", + " (Value_linear1): Linear(in_features=256, out_features=256, bias=False)\n", + " (Value_norm1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n", + " (Value_act1): SiLU()\n", + " )\n", + " (mean_layer): Linear(in_features=256, out_features=255, bias=True)\n", + " )\n", + " (_slow_value): MLP(\n", + " (layers): Sequential(\n", + " (Value_linear0): Linear(in_features=832, out_features=256, bias=False)\n", + " (Value_norm0): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n", + " (Value_act0): SiLU()\n", + " (Value_linear1): Linear(in_features=256, out_features=256, bias=False)\n", + " (Value_norm1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n", + " (Value_act1): SiLU()\n", + " )\n", + " (mean_layer): Linear(in_features=256, out_features=255, bias=True)\n", + " )\n", + " )\n", + " )\n", + " (_expl_behavior): OptimizedModule(\n", + " (_orig_mod): ImagBehavior(\n", + " (_world_model): WorldModel(\n", + " (encoder): MultiEncoder(\n", + " (_cnn): ConvEncoder(\n", + " (layers): Sequential(\n", + " (0): Conv2dSamePad(166, 32, kernel_size=(4, 4), stride=(2, 2), bias=False)\n", + " (1): ImgChLayerNorm(\n", + " (norm): LayerNorm((32,), eps=0.001, elementwise_affine=True)\n", + " )\n", + " (2): SiLU()\n", + " (3): Conv2dSamePad(32, 64, kernel_size=(4, 4), stride=(2, 2), bias=False)\n", + " (4): ImgChLayerNorm(\n", + " (norm): LayerNorm((64,), eps=0.001, elementwise_affine=True)\n", + " )\n", + " (5): SiLU()\n", + " )\n", + " )\n", + " (_mlp): MLP(\n", + " (layers): Sequential(\n", + " (Encoder_linear0): Linear(in_features=102, out_features=16, bias=False)\n", + " (Encoder_norm0): LayerNorm((16,), eps=0.001, elementwise_affine=True)\n", + " (Encoder_act0): SiLU()\n", + " (Encoder_linear1): Linear(in_features=16, out_features=16, bias=False)\n", + " (Encoder_norm1): LayerNorm((16,), eps=0.001, elementwise_affine=True)\n", + " (Encoder_act1): SiLU()\n", + " )\n", + " )\n", + " )\n", + " (dynamics): RSSM(\n", + " (_img_in_layers): Sequential(\n", + " (0): Linear(in_features=619, out_features=256, bias=False)\n", + " (1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n", + " (2): SiLU()\n", + " )\n", + " (_cell): GRUCell(\n", + " (layers): Sequential(\n", + " (GRU_linear): Linear(in_features=512, out_features=768, bias=False)\n", + " (GRU_norm): LayerNorm((768,), eps=0.001, elementwise_affine=True)\n", + " )\n", + " )\n", + " (_img_out_layers): Sequential(\n", + " (0): Linear(in_features=256, out_features=256, bias=False)\n", + " (1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n", + " (2): SiLU()\n", + " )\n", + " (_obs_out_layers): Sequential(\n", + " (0): Linear(in_features=848, out_features=256, bias=False)\n", + " (1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n", + " (2): SiLU()\n", + " )\n", + " (_imgs_stat_layer): Linear(in_features=256, out_features=576, bias=True)\n", + " (_obs_stat_layer): Linear(in_features=256, out_features=576, bias=True)\n", + " )\n", + " (heads): ModuleDict(\n", + " (decoder): MultiDecoder(\n", + " (_cnn): ConvDecoder(\n", + " (_linear_layer): Linear(in_features=832, out_features=256, bias=True)\n", + " (layers): Sequential(\n", + " (0): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)\n", + " (1): ImgChLayerNorm(\n", + " (norm): LayerNorm((32,), eps=0.001, elementwise_affine=True)\n", + " )\n", + " (2): SiLU()\n", + " (3): ConvTranspose2d(32, 166, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", + " )\n", + " )\n", + " (_mlp): MLP(\n", + " (layers): Sequential(\n", + " (Decoder_linear0): Linear(in_features=832, out_features=16, bias=False)\n", + " (Decoder_norm0): LayerNorm((16,), eps=0.001, elementwise_affine=True)\n", + " (Decoder_act0): SiLU()\n", + " (Decoder_linear1): Linear(in_features=16, out_features=16, bias=False)\n", + " (Decoder_norm1): LayerNorm((16,), eps=0.001, elementwise_affine=True)\n", + " (Decoder_act1): SiLU()\n", + " )\n", + " (mean_layer): ModuleDict(\n", + " (state_inventory): Linear(in_features=16, out_features=102, bias=True)\n", + " )\n", + " )\n", + " )\n", + " (reward): MLP(\n", + " (layers): Sequential(\n", + " (Reward_linear0): Linear(in_features=832, out_features=256, bias=False)\n", + " (Reward_norm0): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n", + " (Reward_act0): SiLU()\n", + " (Reward_linear1): Linear(in_features=256, out_features=256, bias=False)\n", + " (Reward_norm1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n", + " (Reward_act1): SiLU()\n", + " (Reward_linear2): Linear(in_features=256, out_features=256, bias=False)\n", + " (Reward_norm2): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n", + " (Reward_act2): SiLU()\n", + " )\n", + " (mean_layer): Linear(in_features=256, out_features=255, bias=True)\n", + " )\n", + " (cont): MLP(\n", + " (layers): Sequential(\n", + " (Cont_linear0): Linear(in_features=832, out_features=256, bias=False)\n", + " (Cont_norm0): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n", + " (Cont_act0): SiLU()\n", + " (Cont_linear1): Linear(in_features=256, out_features=256, bias=False)\n", + " (Cont_norm1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n", + " (Cont_act1): SiLU()\n", + " (Cont_linear2): Linear(in_features=256, out_features=256, bias=False)\n", + " (Cont_norm2): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n", + " (Cont_act2): SiLU()\n", + " )\n", + " (mean_layer): Linear(in_features=256, out_features=1, bias=True)\n", + " )\n", + " )\n", + " )\n", + " (actor): MLP(\n", + " (layers): Sequential(\n", + " (Actor_linear0): Linear(in_features=832, out_features=256, bias=False)\n", + " (Actor_norm0): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n", + " (Actor_act0): SiLU()\n", + " (Actor_linear1): Linear(in_features=256, out_features=256, bias=False)\n", + " (Actor_norm1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n", + " (Actor_act1): SiLU()\n", + " (Actor_linear2): Linear(in_features=256, out_features=256, bias=False)\n", + " (Actor_norm2): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n", + " (Actor_act2): SiLU()\n", + " )\n", + " (mean_layer): Linear(in_features=256, out_features=43, bias=True)\n", + " )\n", + " (value): MLP(\n", + " (layers): Sequential(\n", + " (Value_linear0): Linear(in_features=832, out_features=256, bias=False)\n", + " (Value_norm0): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n", + " (Value_act0): SiLU()\n", + " (Value_linear1): Linear(in_features=256, out_features=256, bias=False)\n", + " (Value_norm1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n", + " (Value_act1): SiLU()\n", + " )\n", + " (mean_layer): Linear(in_features=256, out_features=255, bias=True)\n", + " )\n", + " (_slow_value): MLP(\n", + " (layers): Sequential(\n", + " (Value_linear0): Linear(in_features=832, out_features=256, bias=False)\n", + " (Value_norm0): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n", + " (Value_act0): SiLU()\n", + " (Value_linear1): Linear(in_features=256, out_features=256, bias=False)\n", + " (Value_norm1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n", + " (Value_act1): SiLU()\n", + " )\n", + " (mean_layer): Linear(in_features=256, out_features=255, bias=True)\n", + " )\n", + " )\n", + " )\n", + ")\n" + ] + } + ], "source": [ "config = parse_args(argv)\n", "config.num_actions = acts.n if hasattr(acts, \"n\") else acts.shape[0]\n", @@ -227,7 +677,7 @@ " tlogger,\n", " train_dataset,\n", ").to(config.device)\n", - "# print(agent)\n", + "print(agent)\n", "agent.requires_grad_(requires_grad=False)\n", "if (logdir / \"latest.pt\").exists():\n", " checkpoint = torch.load(logdir / \"latest.pt\")\n", @@ -262,19 +712,19 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "assert state is not None\n", "import numpy as np\n", "\n", - "state" + "# state" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ @@ -307,27 +757,33 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 10, "metadata": {}, - "outputs": [], - "source": [ - "envs[0].observation_space" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "obs[0]['state_map'].shape, obs[0]['state_inventory'].shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/media/wassname/SGIronWolf/projects5/worldmodels/dreamerv3-torch/nbs/../networks.py:790: UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at ../aten/src/ATen/native/cudnn/Conv_v8.cpp:919.)\n", + " ret = F.conv2d(\n" + ] + }, + { + "ename": "AssertionError", + "evalue": "(torch.Size([256, 32, 8, 8, 166]), torch.Size([256, 32, 12, 12, 166]))", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[10], line 6\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m# from tools.simulate\u001b[39;00m\n\u001b[1;32m 2\u001b[0m \n\u001b[1;32m 3\u001b[0m \u001b[38;5;66;03m# step\u001b[39;00m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;66;03m# step, episode, done, length, obs, agent_state, reward = state\u001b[39;00m\n\u001b[1;32m 5\u001b[0m obs2 \u001b[38;5;241m=\u001b[39m {k: np\u001b[38;5;241m.\u001b[39mstack([o[k] \u001b[38;5;28;01mfor\u001b[39;00m o \u001b[38;5;129;01min\u001b[39;00m obs]) \u001b[38;5;28;01mfor\u001b[39;00m k \u001b[38;5;129;01min\u001b[39;00m obs[\u001b[38;5;241m0\u001b[39m] \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlog_\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m k}\n\u001b[0;32m----> 6\u001b[0m action, agent_state \u001b[38;5;241m=\u001b[39m \u001b[43magent\u001b[49m\u001b[43m(\u001b[49m\u001b[43mobs2\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdone\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43magent_state\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/media/wassname/SGIronWolf/projects5/worldmodels/dreamerv3-torch/nbs/../dreamer.py:71\u001b[0m, in \u001b[0;36mDreamer.__call__\u001b[0;34m(self, obs, reset, state, training)\u001b[0m\n\u001b[1;32m 65\u001b[0m steps \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 66\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_config\u001b[38;5;241m.\u001b[39mpretrain\n\u001b[1;32m 67\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_should_pretrain()\n\u001b[1;32m 68\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_should_train(step)\n\u001b[1;32m 69\u001b[0m )\n\u001b[1;32m 70\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m _ \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(steps):\n\u001b[0;32m---> 71\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_train\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mnext\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_dataset\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 72\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_update_count \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m 73\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_metrics[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mupdate_count\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_update_count\n", + "File \u001b[0;32m/media/wassname/SGIronWolf/projects5/worldmodels/dreamerv3-torch/nbs/../dreamer.py:124\u001b[0m, in \u001b[0;36mDreamer._train\u001b[0;34m(self, data)\u001b[0m\n\u001b[1;32m 122\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_train\u001b[39m(\u001b[38;5;28mself\u001b[39m, data):\n\u001b[1;32m 123\u001b[0m metrics \u001b[38;5;241m=\u001b[39m {}\n\u001b[0;32m--> 124\u001b[0m post, context, mets \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_wm\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_train\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 125\u001b[0m metrics\u001b[38;5;241m.\u001b[39mupdate(mets)\n\u001b[1;32m 126\u001b[0m start \u001b[38;5;241m=\u001b[39m post\n", + "File \u001b[0;32m/media/wassname/SGIronWolf/projects5/worldmodels/dreamerv3-torch/nbs/../models.py:143\u001b[0m, in \u001b[0;36mWorldModel._train\u001b[0;34m(self, data)\u001b[0m\n\u001b[1;32m 141\u001b[0m losses \u001b[38;5;241m=\u001b[39m {}\n\u001b[1;32m 142\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m name, pred \u001b[38;5;129;01min\u001b[39;00m preds\u001b[38;5;241m.\u001b[39mitems():\n\u001b[0;32m--> 143\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m-\u001b[39m\u001b[43mpred\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlog_prob\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata\u001b[49m\u001b[43m[\u001b[49m\u001b[43mname\u001b[49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 144\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m loss\u001b[38;5;241m.\u001b[39mshape \u001b[38;5;241m==\u001b[39m embed\u001b[38;5;241m.\u001b[39mshape[:\u001b[38;5;241m2\u001b[39m], (name, loss\u001b[38;5;241m.\u001b[39mshape)\n\u001b[1;32m 145\u001b[0m losses[name] \u001b[38;5;241m=\u001b[39m loss\n", + "File \u001b[0;32m/media/wassname/SGIronWolf/projects5/worldmodels/dreamerv3-torch/nbs/../tools.py:528\u001b[0m, in \u001b[0;36mMSEDist.log_prob\u001b[0;34m(self, value)\u001b[0m\n\u001b[1;32m 527\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mlog_prob\u001b[39m(\u001b[38;5;28mself\u001b[39m, value):\n\u001b[0;32m--> 528\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_mode\u001b[38;5;241m.\u001b[39mshape \u001b[38;5;241m==\u001b[39m value\u001b[38;5;241m.\u001b[39mshape, (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_mode\u001b[38;5;241m.\u001b[39mshape, value\u001b[38;5;241m.\u001b[39mshape)\n\u001b[1;32m 529\u001b[0m distance \u001b[38;5;241m=\u001b[39m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_mode \u001b[38;5;241m-\u001b[39m value) \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39m \u001b[38;5;241m2\u001b[39m\n\u001b[1;32m 530\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_agg \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmean\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n", + "\u001b[0;31mAssertionError\u001b[0m: (torch.Size([256, 32, 8, 8, 166]), torch.Size([256, 32, 12, 12, 166]))" + ] + } + ], "source": [ "# from tools.simulate\n", "\n", @@ -341,10 +797,120 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "==========================================================================================\n", + "Layer (type:depth-idx) Param #\n", + "==========================================================================================\n", + "Dreamer --\n", + "├─OptimizedModule: 1-1 --\n", + "│ └─WorldModel: 2-1 --\n", + "│ │ └─MultiEncoder: 3-1 (44,480)\n", + "│ │ └─RSSM: 3-2 (2,397,952)\n", + "│ │ └─ModuleDict: 3-3 (1,580,204)\n", + "├─OptimizedModule: 1-2 --\n", + "│ └─ImagBehavior: 2-2 4,022,636\n", + "│ │ └─WorldModel: 3-4 (recursive)\n", + "│ │ └─MLP: 3-5 (536,875)\n", + "│ │ └─MLP: 3-6 (525,311)\n", + "│ │ └─MLP: 3-7 (525,311)\n", + "├─OptimizedModule: 1-3 (recursive)\n", + "│ └─ImagBehavior: 2-3 (recursive)\n", + "│ │ └─WorldModel: 3-8 (recursive)\n", + "│ │ └─MLP: 3-9 (recursive)\n", + "│ │ └─MLP: 3-10 (recursive)\n", + "│ │ └─MLP: 3-11 (recursive)\n", + "==========================================================================================\n", + "Total params: 9,632,769\n", + "Trainable params: 0\n", + "Non-trainable params: 9,632,769\n", + "==========================================================================================" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "from torchinfo import summary\n", "\n", + "summary(agent, input=(obs, done, agent_state), depth=3)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "==========================================================================================\n", + "Layer (type:depth-idx) Param #\n", + "==========================================================================================\n", + "Dreamer --\n", + "├─OptimizedModule: 1-1 --\n", + "│ └─WorldModel: 2-1 --\n", + "│ │ └─MultiEncoder: 3-1 --\n", + "│ │ │ └─ConvEncoder: 4-1 (42,528)\n", + "│ │ │ └─MLP: 4-2 (1,952)\n", + "│ │ └─RSSM: 3-2 512\n", + "│ │ │ └─Sequential: 4-3 (273,664)\n", + "│ │ │ └─GRUCell: 4-4 (1,182,720)\n", + "│ │ │ └─Sequential: 4-5 (131,584)\n", + "│ │ │ └─Sequential: 4-6 (283,136)\n", + "│ │ │ └─Linear: 4-7 (263,168)\n", + "│ │ │ └─Linear: 4-8 (263,168)\n", + "│ │ └─ModuleDict: 3-3 --\n", + "│ │ │ └─MultiDecoder: 4-9 (462,764)\n", + "│ │ │ └─MLP: 4-10 (591,359)\n", + "│ │ │ └─MLP: 4-11 (526,081)\n", + "├─OptimizedModule: 1-2 --\n", + "│ └─ImagBehavior: 2-2 4,022,636\n", + "│ │ └─WorldModel: 3-4 (recursive)\n", + "│ │ │ └─MultiEncoder: 4-12 (recursive)\n", + "│ │ │ └─RSSM: 4-13 (recursive)\n", + "│ │ │ └─ModuleDict: 4-14 (recursive)\n", + "│ │ └─MLP: 3-5 --\n", + "│ │ │ └─Sequential: 4-15 (525,824)\n", + "│ │ │ └─Linear: 4-16 (11,051)\n", + "│ │ └─MLP: 3-6 --\n", + "│ │ │ └─Sequential: 4-17 (459,776)\n", + "│ │ │ └─Linear: 4-18 (65,535)\n", + "│ │ └─MLP: 3-7 --\n", + "│ │ │ └─Sequential: 4-19 (459,776)\n", + "│ │ │ └─Linear: 4-20 (65,535)\n", + "├─OptimizedModule: 1-3 (recursive)\n", + "│ └─ImagBehavior: 2-3 (recursive)\n", + "│ │ └─WorldModel: 3-8 (recursive)\n", + "│ │ │ └─MultiEncoder: 4-21 (recursive)\n", + "│ │ │ └─RSSM: 4-22 (recursive)\n", + "│ │ │ └─ModuleDict: 4-23 (recursive)\n", + "│ │ └─MLP: 3-9 (recursive)\n", + "│ │ │ └─Sequential: 4-24 (recursive)\n", + "│ │ │ └─Linear: 4-25 (recursive)\n", + "│ │ └─MLP: 3-10 (recursive)\n", + "│ │ │ └─Sequential: 4-26 (recursive)\n", + "│ │ │ └─Linear: 4-27 (recursive)\n", + "│ │ └─MLP: 3-11 (recursive)\n", + "│ │ │ └─Sequential: 4-28 (recursive)\n", + "│ │ │ └─Linear: 4-29 (recursive)\n", + "==========================================================================================\n", + "Total params: 9,632,769\n", + "Trainable params: 0\n", + "Non-trainable params: 9,632,769\n", + "==========================================================================================" + ] + }, + "execution_count": 28, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ "summary(agent, input=(obs, done, agent_state), depth=4)" ] }, @@ -385,7 +951,46 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "===================================================================================================================\n", + "Layer (type:depth-idx) Input Shape Output Shape Param #\n", + "===================================================================================================================\n", + "MultiEncoder [256, 16, 130, 110, 3] [256, 16, 592] --\n", + "├─ConvEncoder: 1-1 [256, 16, 12, 12, 166] [256, 16, 576] --\n", + "│ └─Sequential: 2-1 [4096, 166, 12, 12] [4096, 16, 6, 6] --\n", + "│ │ └─Conv2dSamePad: 3-1 [4096, 166, 12, 12] [4096, 16, 6, 6] (42,496)\n", + "│ │ └─ImgChLayerNorm: 3-2 [4096, 16, 6, 6] [4096, 16, 6, 6] --\n", + "│ │ │ └─LayerNorm: 4-1 [4096, 6, 6, 16] [4096, 6, 6, 16] (32)\n", + "│ │ └─SiLU: 3-3 [4096, 16, 6, 6] [4096, 16, 6, 6] --\n", + "├─MLP: 1-2 [256, 16, 102] [256, 16, 16] --\n", + "│ └─Sequential: 2-2 [256, 16, 102] [256, 16, 16] --\n", + "│ │ └─Linear: 3-4 [256, 16, 102] [256, 16, 16] (1,632)\n", + "│ │ └─LayerNorm: 3-5 [256, 16, 16] [256, 16, 16] (32)\n", + "│ │ └─SiLU: 3-6 [256, 16, 16] [256, 16, 16] --\n", + "│ │ └─Linear: 3-7 [256, 16, 16] [256, 16, 16] (256)\n", + "│ │ └─LayerNorm: 3-8 [256, 16, 16] [256, 16, 16] (32)\n", + "│ │ └─SiLU: 3-9 [256, 16, 16] [256, 16, 16] --\n", + "===================================================================================================================\n", + "Total params: 44,480\n", + "Trainable params: 0\n", + "Non-trainable params: 44,480\n", + "Total mult-adds (G): 6.27\n", + "===================================================================================================================\n", + "Input size (MB): 1367.93\n", + "Forward/backward pass size (MB): 39.85\n", + "Params size (MB): 0.18\n", + "Estimated Total Size (MB): 1407.96\n", + "===================================================================================================================" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "summary(wm.encoder, input_data=(data,), depth=4, col_names=[\"input_size\", \"output_size\", \"num_params\", ])" ] @@ -394,7 +999,46 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "decoder\n", + "===================================================================================================================\n", + "Layer (type:depth-idx) Input Shape Output Shape Param #\n", + "===================================================================================================================\n", + "MultiDecoder [256, 16, 1536] -- --\n", + "├─ConvDecoder: 1-1 [256, 16, 1536] [256, 16, 8, 8, 166] --\n", + "│ └─Linear: 2-1 [256, 16, 1536] [256, 16, 256] (393,472)\n", + "│ └─Sequential: 2-2 [4096, 16, 4, 4] [4096, 166, 8, 8] --\n", + "│ │ └─ConvTranspose2d: 3-1 [4096, 16, 4, 4] [4096, 166, 8, 8] (42,662)\n", + "├─MLP: 1-2 [256, 16, 1536] -- --\n", + "│ └─Sequential: 2-3 [256, 16, 1536] [256, 16, 16] --\n", + "│ │ └─Linear: 3-2 [256, 16, 1536] [256, 16, 16] (24,576)\n", + "│ │ └─LayerNorm: 3-3 [256, 16, 16] [256, 16, 16] (32)\n", + "│ │ └─SiLU: 3-4 [256, 16, 16] [256, 16, 16] --\n", + "│ │ └─Linear: 3-5 [256, 16, 16] [256, 16, 16] (256)\n", + "│ │ └─LayerNorm: 3-6 [256, 16, 16] [256, 16, 16] (32)\n", + "│ │ └─SiLU: 3-7 [256, 16, 16] [256, 16, 16] --\n", + "│ └─ModuleDict: 2-4 -- -- --\n", + "│ │ └─Linear: 3-8 [256, 16, 16] [256, 16, 102] (1,734)\n", + "===================================================================================================================\n", + "Total params: 462,764\n", + "Trainable params: 0\n", + "Non-trainable params: 462,764\n", + "Total mult-adds (G): 11.29\n", + "===================================================================================================================\n", + "Input size (MB): 25.17\n", + "Forward/backward pass size (MB): 361.96\n", + "Params size (MB): 1.85\n", + "Estimated Total Size (MB): 388.97\n", + "===================================================================================================================\n", + "Summary Failed for reward Failed to run torchinfo. See above stack traces for more details. Executed layers up to: [Sequential: 1, Linear: 2, LayerNorm: 2, SiLU: 2, Linear: 2, LayerNorm: 2, SiLU: 2, Linear: 2, LayerNorm: 2, SiLU: 2, Linear: 1]\n", + "Summary Failed for cont Failed to run torchinfo. See above stack traces for more details. Executed layers up to: [Sequential: 1, Linear: 2, LayerNorm: 2, SiLU: 2, Linear: 2, LayerNorm: 2, SiLU: 2, Linear: 2, LayerNorm: 2, SiLU: 2, Linear: 1]\n" + ] + } + ], "source": [ "# heads\n", "feat = wm.dynamics.get_feat(post)\n", @@ -429,7 +1073,41 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "===================================================================================================================\n", + "Layer (type:depth-idx) Output Shape Param # Output Shape\n", + "===================================================================================================================\n", + "Sequential [256, 16, 256] -- [256, 16, 256]\n", + "├─Linear: 1-1 [256, 16, 256] (393,216) [256, 16, 256]\n", + "├─LayerNorm: 1-2 [256, 16, 256] (512) [256, 16, 256]\n", + "├─SiLU: 1-3 [256, 16, 256] -- [256, 16, 256]\n", + "├─Linear: 1-4 [256, 16, 256] (65,536) [256, 16, 256]\n", + "├─LayerNorm: 1-5 [256, 16, 256] (512) [256, 16, 256]\n", + "├─SiLU: 1-6 [256, 16, 256] -- [256, 16, 256]\n", + "├─Linear: 1-7 [256, 16, 256] (65,536) [256, 16, 256]\n", + "├─LayerNorm: 1-8 [256, 16, 256] (512) [256, 16, 256]\n", + "├─SiLU: 1-9 [256, 16, 256] -- [256, 16, 256]\n", + "===================================================================================================================\n", + "Total params: 525,824\n", + "Trainable params: 0\n", + "Non-trainable params: 525,824\n", + "Total mult-adds (M): 134.61\n", + "===================================================================================================================\n", + "Input size (MB): 25.17\n", + "Forward/backward pass size (MB): 50.33\n", + "Params size (MB): 2.10\n", + "Estimated Total Size (MB): 77.60\n", + "===================================================================================================================" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "actor = agent._task_behavior.actor\n", "\n", @@ -441,7 +1119,41 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "===================================================================================================================\n", + "Layer (type:depth-idx) Output Shape Param # Output Shape\n", + "===================================================================================================================\n", + "Sequential [256, 16, 256] -- [256, 16, 256]\n", + "├─Linear: 1-1 [256, 16, 256] (393,216) [256, 16, 256]\n", + "├─LayerNorm: 1-2 [256, 16, 256] (512) [256, 16, 256]\n", + "├─SiLU: 1-3 [256, 16, 256] -- [256, 16, 256]\n", + "├─Linear: 1-4 [256, 16, 256] (65,536) [256, 16, 256]\n", + "├─LayerNorm: 1-5 [256, 16, 256] (512) [256, 16, 256]\n", + "├─SiLU: 1-6 [256, 16, 256] -- [256, 16, 256]\n", + "├─Linear: 1-7 [256, 16, 256] (65,536) [256, 16, 256]\n", + "├─LayerNorm: 1-8 [256, 16, 256] (512) [256, 16, 256]\n", + "├─SiLU: 1-9 [256, 16, 256] -- [256, 16, 256]\n", + "===================================================================================================================\n", + "Total params: 525,824\n", + "Trainable params: 0\n", + "Non-trainable params: 525,824\n", + "Total mult-adds (M): 134.61\n", + "===================================================================================================================\n", + "Input size (MB): 25.17\n", + "Forward/backward pass size (MB): 50.33\n", + "Params size (MB): 2.10\n", + "Estimated Total Size (MB): 77.60\n", + "===================================================================================================================" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "value = agent._task_behavior.actor\n", "summary(value.layers, input_data=(feat,), depth=3, col_names=[\"output_size\", \"num_params\", \"output_size\" ])" @@ -451,26 +1163,22 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "8268" + ] + }, + "execution_count": 26, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "8268" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "o = obs['state'].reshape((-1, 8268))\n", - "map = o[:, :8217].reshape((-1, 9, 11, 83))\n", - "map.shape\n", - "inventories = o[:, 8217:]\n", - "inventories\n", - "\n", - "map" - ] - }, { "cell_type": "code", "execution_count": null, diff --git a/nbs/02_torchinfo.ipynb b/nbs/02_torchinfo.ipynb index c01f92e..551ce26 100644 --- a/nbs/02_torchinfo.ipynb +++ b/nbs/02_torchinfo.ipynb @@ -246,6 +246,23 @@ "- actor" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Check encoder decoder" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "agent.encoder(x)\n", + "agent.heads['decoder'](x) " + ] + }, { "cell_type": "markdown", "metadata": {}, diff --git a/networks.py b/networks.py index 1b1dd7e..a0cca6e 100644 --- a/networks.py +++ b/networks.py @@ -560,6 +560,7 @@ class ConvDecoder(nn.Module): [m.apply(tools.weight_init) for m in layers[:-1]] layers[-1].apply(tools.uniform_weight_init(outscale)) self.layers = nn.Sequential(*layers) + self.outdim = out_dim def calc_same_pad(self, k, s, d): val = d * (k - 1) - s + 1 @@ -569,17 +570,10 @@ class ConvDecoder(nn.Module): def forward(self, features, dtype=None): x = self._linear_layer(features) - # (batch, time, -1) -> (batch * time, h, w, ch) - x = x.reshape( - [-1, self._minres, self._minres, self._embed_size // self._minres**2] - ) - # (batch, time, -1) -> (batch * time, ch, h, w) - x = x.permute(0, 3, 1, 2) + x = rearrange(x, "b t (h w c) -> (b t) c h w", h=self._minres, w=self._minres) x = self.layers(x) - # (batch, time, -1) -> (batch, time, ch, h, w) - mean = x.reshape(features.shape[:-1] + self._shape) - # (batch, time, ch, h, w) -> (batch, time, h, w, ch) - mean = mean.permute(0, 1, 3, 4, 2) + mean = rearrange(x, "(b t) c h w -> b t h w c ", t=features.shape[1]) + # assert mean.shape[-2]==self.outdim, f"{mean.shape[-2]}!={self.outdim}" if self._cnn_sigmoid: mean = F.sigmoid(mean) else: