diff --git a/configs.yaml b/configs.yaml index 2be2d69..03b2c7a 100644 --- a/configs.yaml +++ b/configs.yaml @@ -168,6 +168,27 @@ craftax_small: batch_size: 128 batch_length: 16 +craftax_small2: + task: craftax_Craftax-Symbolic-AutoReset-v1 + step: 1e6 + action_repeat: 1 + envs: 1 + train_ratio: 512 + video_pred_log: false + dyn_hidden: 256 + dyn_deter: 512 + # note: depth is cnn hidden_dim + encoder: {cnn_keys: 'state_map', cnn_depth: 16, kernel_size: 4, minres: 4, mlp_keys: "state_inventory", mlp_layers: 2, mlp_units: 16} + decoder: {cnn_keys: 'state_map', cnn_depth: 16, kernel_size: 4, minres: 4, mlp_keys: "state_inventory", mlp_layers: 2, mlp_units: 16} + actor: {layers: 3, dist: 'onehot', std: 'none'} + value: {layers: 3} + # note units is the head hidden_dim + units: 256 + reward_head: {layers: 3} + cont_head: {layers: 3} + imag_gradient: 'reinforce' + batch_size: 256 + batch_length: 16 craftax_smaller: task: craftax_Craftax-Symbolic-AutoReset-v1 diff --git a/envs/craftax_env.py b/envs/craftax_env.py index 2ea73cf..56efcd3 100644 --- a/envs/craftax_env.py +++ b/envs/craftax_env.py @@ -1,5 +1,7 @@ import gymnasium as gym import numpy as np +from einops import rearrange +import torch.nn.functional as F from craftax.craftax_env import make_craftax_env_from_name from craftax.craftax.play_craftax import CraftaxRenderer @@ -185,6 +187,20 @@ def create_craftax_env( env = CraftaxCompatWrapper(env) return env + +def reshape_state(state: Float[Tensor, 'frames state_dim']) -> (Float[Tensor,'frames h w c'], Float[Tensor,'frames inv']): + """ + reshapes state into map and inv + + https://github.com/MichaelTMatthews/Craftax/blob/main/obs_description.md + """ + map = rearrange(state[:, :8217], 'frames (h w c) -> frames h w c', h=9, w=11, c=83) + # now pad from (9,11) to (12,12) using torch + map = F.pad(map, (0, 0, 1, 0, 2, 1)) + map = rearrange(map, 'f h w c -> h w (f c)') + inventories = rearrange(state[:, 8217:], 'frames c -> (frames c)') + return map, inventories + class Craftax: metadata = {} @@ -196,8 +212,11 @@ class Craftax: @property def observation_space(self): + frames = self._env.observation_space.shape[0] spaces = { "state": gym.spaces.Box(0, 1, (np.prod(self._env.observation_space.shape),), dtype=np.float32), + "state_map": gym.spaces.Box(0, 1, (12, 12, frames*83), dtype=np.float32), + "state_inventory": gym.spaces.Box(0, 1, (frames * 51,), dtype=np.float32), "image": gym.spaces.Box(0, 255, (130, 110, 3), dtype=np.uint8), "is_first": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.uint8), "is_last": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.uint8), @@ -227,9 +246,12 @@ class Craftax: info2 = {k.replace('Ach','log_ach'):v for k,v in info.items()} reward = np.float32(reward) + state_map, state_inv = reshape_state(state) obs = { "image": self.get_image(), "state": state.flatten(), + "state_map": state_map, + "state_inventory": state_inv, "is_first": False, "is_last": done, "is_terminal": info["discount"] == 0, @@ -248,9 +270,12 @@ class Craftax: def reset(self, seed=None, options=None): state, info = self._env.reset() + state_map, state_inv = reshape_state(state) obs = { "image": self.get_image(), "state": state.flatten(), + "state_map": state_map, + "state_inventory": state_inv, "is_first": True, "is_last": False, "is_terminal": False, diff --git a/nbs/02_torchinfo copy.ipynb b/nbs/02_torchinfo copy.ipynb new file mode 100644 index 0000000..7503e74 --- /dev/null +++ b/nbs/02_torchinfo copy.ipynb @@ -0,0 +1,517 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this notebook we load a saved dreamer, and run it, to look at params, speed and improve hackability" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loading textures from cache\n" + ] + } + ], + "source": [ + "# TODO make this a proper package\n", + "import os, sys\n", + "sys.path.append('..')\n", + "\n", + "\n", + "from dreamer import parse_args, main, make_env, make_dataset, count_steps,Dreamer" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['../dreamer.py', '--configs', 'craftax_small2', '--logdir', '../logdir/craftax_small2']\n" + ] + }, + { + "data": { + "text/plain": [ + "Namespace(act='SiLU', action_repeat=1, actor={'layers': 3, 'dist': 'onehot', 'entropy': 0.0003, 'unimix_ratio': 0.01, 'std': 'none', 'min_std': 0.1, 'max_std': 1.0, 'temp': 0.1, 'lr': 3e-05, 'eps': 1e-05, 'grad_clip': 100.0, 'outscale': 1.0}, batch_length=16, batch_size=256, compile=True, cont_head={'layers': 3, 'loss_scale': 1.0, 'outscale': 1.0}, critic={'layers': 2, 'dist': 'symlog_disc', 'slow_target': True, 'slow_target_update': 1, 'slow_target_fraction': 0.02, 'lr': 3e-05, 'eps': 1e-05, 'grad_clip': 100.0, 'outscale': 0.0}, dataset_size=1000000, debug=False, decoder={'mlp_keys': 'state_inventory', 'cnn_keys': 'state_map', 'act': 'SiLU', 'norm': True, 'cnn_depth': 16, 'kernel_size': 4, 'minres': 4, 'mlp_layers': 2, 'mlp_units': 16, 'cnn_sigmoid': False, 'image_dist': 'mse', 'vector_dist': 'symlog_mse', 'outscale': 1.0}, deterministic_run=False, device='cuda:0', disag_action_cond=False, disag_layers=4, disag_log=True, disag_models=10, disag_offset=1, disag_target='stoch', disag_units=400, discount=0.997, discount_lambda=0.95, dyn_deter=512, dyn_discrete=32, dyn_hidden=256, dyn_mean_act='none', dyn_min_std=0.1, dyn_rec_depth=1, dyn_scale=0.5, dyn_std_act='sigmoid2', dyn_stoch=32, encoder={'mlp_keys': 'state_inventory', 'cnn_keys': 'state_map', 'act': 'SiLU', 'norm': True, 'cnn_depth': 16, 'kernel_size': 4, 'minres': 4, 'mlp_layers': 2, 'mlp_units': 16, 'symlog_inputs': True}, envs=1, eval_episode_num=10, eval_every=10000.0, eval_state_mean=False, evaldir=None, expl_behavior='greedy', expl_extr_scale=0.0, expl_intr_scale=1.0, expl_until=0, grad_clip=1000, grad_heads=('decoder', 'reward', 'cont'), grayscale=False, imag_gradient='reinforce', imag_gradient_mix=0.0, imag_horizon=15, initial='learned', kl_free=1.0, log_every=10000.0, logdir='../logdir/craftax_small2', model_lr=0.0001, norm=True, offline_evaldir='', offline_traindir='', opt='adam', opt_eps=1e-08, parallel=False, precision=32, prefill=2500, pretrain=100, rep_scale=0.1, reset_every=0, reward_EMA=True, reward_head={'layers': 3, 'dist': 'symlog_disc', 'loss_scale': 1.0, 'outscale': 0.0}, seed=0, size=(64, 64), step=1000000.0, steps=1000000.0, task='craftax_Craftax-Symbolic-AutoReset-v1', time_limit=1000, train_ratio=512, traindir=None, unimix_ratio=0.01, units=256, value={'layers': 3}, video_pred_log=False, weight_decay=0.0)" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# emulate cli\n", + "argv = f\"../dreamer.py --configs craftax_small2 --logdir ../logdir/craftax_small2\"\n", + "argv = argv.split()\n", + "print(argv)\n", + "config = parse_args(argv)\n", + "config" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[32m2024-06-06 16:21:39.870\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m\u001b[0m:\u001b[36m27\u001b[0m - \u001b[1mLogdir ../logdir/craftax_small2\u001b[0m\n", + "\u001b[32m2024-06-06 16:21:39.887\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m\u001b[0m:\u001b[36m36\u001b[0m - \u001b[1mCreate envs.\u001b[0m\n", + "\u001b[32m2024-06-06 16:22:16.800\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m\u001b[0m:\u001b[36m57\u001b[0m - \u001b[1mAction Space Box(0.0, 1.0, (43,), float32)\u001b[0m\n", + "\u001b[32m2024-06-06 16:22:16.801\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m\u001b[0m:\u001b[36m63\u001b[0m - \u001b[1mPrefill dataset (2500 steps).\u001b[0m\n", + "\u001b[32m2024-06-06 16:23:40.174\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mtools\u001b[0m:\u001b[36mwrite\u001b[0m:\u001b[36m85\u001b[0m - \u001b[1m[0] log_achievements/cast_fireball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/cast_iceball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_coal \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_diamond \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_drink \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_iron \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_ruby \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_sapling \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_sapphire \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_stone \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_wood \u001b[31m100.0\u001b[0m\u001b[1m / log_achievements/damage_necromancer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_archer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_deep_thing \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_fire_elemental \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_frost_troll \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_gnome_archer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_gnome_warrior \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_ice_elemental \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_knight \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_kobold \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_lizard \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_necromancer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_orc_mage \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_orc_solider \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_pigman \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_skeleton \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_troll \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_zombie \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/drink_potion \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_bat \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_cow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_plant \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_snail \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enchant_armour \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enchant_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_dungeon \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_fire_realm \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_gnomish_mines \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_graveyard \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_ice_realm \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_sewers \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_troll_mines \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_vault \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/find_bow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/fire_bow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/learn_fireball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/learn_iceball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_arrow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_diamond_armour \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_diamond_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_diamond_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_iron_armour \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_iron_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_iron_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_stone_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_stone_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_torch \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_wood_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_wood_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/open_chest \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_furnace \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_plant \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_stone \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_table \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_torch \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/wake_up \u001b[31m100.0\u001b[0m\u001b[1m / dataset_size \u001b[31m242.0\u001b[0m\u001b[1m / train_return \u001b[31m1.1\u001b[0m\u001b[1m / train_length \u001b[31m242.0\u001b[0m\u001b[1m / train_episodes \u001b[31m1.0\u001b[0m\u001b[1m\u001b[0m\n", + "\u001b[32m2024-06-06 16:23:42.913\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mtools\u001b[0m:\u001b[36mwrite\u001b[0m:\u001b[36m85\u001b[0m - \u001b[1m[0] log_achievements/cast_fireball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/cast_iceball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_coal \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_diamond \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_drink \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_iron \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_ruby \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_sapling \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_sapphire \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_stone \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_wood \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/damage_necromancer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_archer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_deep_thing \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_fire_elemental \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_frost_troll \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_gnome_archer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_gnome_warrior \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_ice_elemental \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_knight \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_kobold \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_lizard \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_necromancer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_orc_mage \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_orc_solider \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_pigman \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_skeleton \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_troll \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_zombie \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/drink_potion \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_bat \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_cow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_plant \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_snail \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enchant_armour \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enchant_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_dungeon \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_fire_realm \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_gnomish_mines \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_graveyard \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_ice_realm \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_sewers \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_troll_mines \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_vault \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/find_bow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/fire_bow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/learn_fireball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/learn_iceball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_arrow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_diamond_armour \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_diamond_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_diamond_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_iron_armour \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_iron_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_iron_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_stone_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_stone_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_torch \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_wood_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_wood_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/open_chest \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_furnace \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_plant \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_stone \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_table \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_torch \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/wake_up \u001b[31m100.0\u001b[0m\u001b[1m / dataset_size \u001b[31m551.0\u001b[0m\u001b[1m / train_return \u001b[31m0.1\u001b[0m\u001b[1m / train_length \u001b[31m309.0\u001b[0m\u001b[1m / train_episodes \u001b[31m2.0\u001b[0m\u001b[1m\u001b[0m\n", + "\u001b[32m2024-06-06 16:23:44.957\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mtools\u001b[0m:\u001b[36mwrite\u001b[0m:\u001b[36m85\u001b[0m - \u001b[1m[0] log_achievements/cast_fireball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/cast_iceball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_coal \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_diamond \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_drink \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_iron \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_ruby \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_sapling \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_sapphire \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_stone \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_wood \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/damage_necromancer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_archer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_deep_thing \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_fire_elemental \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_frost_troll \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_gnome_archer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_gnome_warrior \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_ice_elemental \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_knight \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_kobold \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_lizard \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_necromancer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_orc_mage \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_orc_solider \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_pigman \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_skeleton \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_troll \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_zombie \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/drink_potion \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_bat \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_cow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_plant \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_snail \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enchant_armour \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enchant_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_dungeon \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_fire_realm \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_gnomish_mines \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_graveyard \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_ice_realm \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_sewers \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_troll_mines \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_vault \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/find_bow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/fire_bow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/learn_fireball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/learn_iceball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_arrow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_diamond_armour \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_diamond_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_diamond_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_iron_armour \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_iron_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_iron_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_stone_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_stone_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_torch \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_wood_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_wood_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/open_chest \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_furnace \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_plant \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_stone \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_table \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_torch \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/wake_up \u001b[31m100.0\u001b[0m\u001b[1m / dataset_size \u001b[31m781.0\u001b[0m\u001b[1m / train_return \u001b[31m0.1\u001b[0m\u001b[1m / train_length \u001b[31m230.0\u001b[0m\u001b[1m / train_episodes \u001b[31m3.0\u001b[0m\u001b[1m\u001b[0m\n", + "\u001b[32m2024-06-06 16:23:48.066\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mtools\u001b[0m:\u001b[36mwrite\u001b[0m:\u001b[36m85\u001b[0m - \u001b[1m[0] log_achievements/cast_fireball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/cast_iceball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_coal \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_diamond \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_drink \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_iron \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_ruby \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_sapling \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_sapphire \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_stone \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_wood \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/damage_necromancer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_archer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_deep_thing \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_fire_elemental \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_frost_troll \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_gnome_archer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_gnome_warrior \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_ice_elemental \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_knight \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_kobold \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_lizard \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_necromancer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_orc_mage \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_orc_solider \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_pigman \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_skeleton \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_troll \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_zombie \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/drink_potion \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_bat \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_cow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_plant \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_snail \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enchant_armour \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enchant_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_dungeon \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_fire_realm \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_gnomish_mines \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_graveyard \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_ice_realm \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_sewers \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_troll_mines \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_vault \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/find_bow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/fire_bow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/learn_fireball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/learn_iceball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_arrow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_diamond_armour \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_diamond_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_diamond_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_iron_armour \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_iron_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_iron_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_stone_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_stone_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_torch \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_wood_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_wood_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/open_chest \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_furnace \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_plant \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_stone \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_table \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_torch \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/wake_up \u001b[31m100.0\u001b[0m\u001b[1m / dataset_size \u001b[31m1134.0\u001b[0m\u001b[1m / train_return \u001b[31m0.1\u001b[0m\u001b[1m / train_length \u001b[31m353.0\u001b[0m\u001b[1m / train_episodes \u001b[31m4.0\u001b[0m\u001b[1m\u001b[0m\n", + "\u001b[32m2024-06-06 16:23:51.250\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mtools\u001b[0m:\u001b[36mwrite\u001b[0m:\u001b[36m85\u001b[0m - \u001b[1m[0] log_achievements/cast_fireball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/cast_iceball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_coal \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_diamond \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_drink \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_iron \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_ruby \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_sapling \u001b[31m100.0\u001b[0m\u001b[1m / log_achievements/collect_sapphire \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_stone \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_wood \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/damage_necromancer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_archer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_deep_thing \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_fire_elemental \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_frost_troll \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_gnome_archer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_gnome_warrior \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_ice_elemental \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_knight \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_kobold \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_lizard \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_necromancer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_orc_mage \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_orc_solider \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_pigman \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_skeleton \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_troll \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_zombie \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/drink_potion \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_bat \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_cow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_plant \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_snail \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enchant_armour \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enchant_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_dungeon \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_fire_realm \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_gnomish_mines \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_graveyard \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_ice_realm \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_sewers \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_troll_mines \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_vault \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/find_bow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/fire_bow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/learn_fireball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/learn_iceball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_arrow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_diamond_armour \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_diamond_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_diamond_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_iron_armour \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_iron_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_iron_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_stone_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_stone_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_torch \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_wood_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_wood_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/open_chest \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_furnace \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_plant \u001b[31m100.0\u001b[0m\u001b[1m / log_achievements/place_stone \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_table \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_torch \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/wake_up \u001b[31m100.0\u001b[0m\u001b[1m / dataset_size \u001b[31m1509.0\u001b[0m\u001b[1m / train_return \u001b[31m2.1\u001b[0m\u001b[1m / train_length \u001b[31m375.0\u001b[0m\u001b[1m / train_episodes \u001b[31m5.0\u001b[0m\u001b[1m\u001b[0m\n", + "\u001b[32m2024-06-06 16:23:54.187\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mtools\u001b[0m:\u001b[36mwrite\u001b[0m:\u001b[36m85\u001b[0m - \u001b[1m[0] log_achievements/cast_fireball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/cast_iceball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_coal \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_diamond \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_drink \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_iron \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_ruby \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_sapling \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_sapphire \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_stone \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_wood \u001b[31m100.0\u001b[0m\u001b[1m / log_achievements/damage_necromancer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_archer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_deep_thing \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_fire_elemental \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_frost_troll \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_gnome_archer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_gnome_warrior \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_ice_elemental \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_knight \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_kobold \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_lizard \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_necromancer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_orc_mage \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_orc_solider \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_pigman \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_skeleton \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_troll \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_zombie \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/drink_potion \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_bat \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_cow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_plant \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_snail \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enchant_armour \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enchant_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_dungeon \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_fire_realm \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_gnomish_mines \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_graveyard \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_ice_realm \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_sewers \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_troll_mines \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_vault \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/find_bow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/fire_bow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/learn_fireball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/learn_iceball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_arrow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_diamond_armour \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_diamond_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_diamond_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_iron_armour \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_iron_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_iron_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_stone_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_stone_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_torch \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_wood_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_wood_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/open_chest \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_furnace \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_plant \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_stone \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_table \u001b[31m100.0\u001b[0m\u001b[1m / log_achievements/place_torch \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/wake_up \u001b[31m100.0\u001b[0m\u001b[1m / dataset_size \u001b[31m1827.0\u001b[0m\u001b[1m / train_return \u001b[31m2.1\u001b[0m\u001b[1m / train_length \u001b[31m318.0\u001b[0m\u001b[1m / train_episodes \u001b[31m6.0\u001b[0m\u001b[1m\u001b[0m\n", + "\u001b[32m2024-06-06 16:23:56.423\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mtools\u001b[0m:\u001b[36mwrite\u001b[0m:\u001b[36m85\u001b[0m - \u001b[1m[0] log_achievements/cast_fireball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/cast_iceball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_coal \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_diamond \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_drink \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_iron \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_ruby \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_sapling \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_sapphire \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_stone \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_wood \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/damage_necromancer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_archer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_deep_thing \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_fire_elemental \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_frost_troll \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_gnome_archer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_gnome_warrior \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_ice_elemental \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_knight \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_kobold \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_lizard \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_necromancer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_orc_mage \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_orc_solider \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_pigman \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_skeleton \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_troll \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_zombie \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/drink_potion \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_bat \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_cow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_plant \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_snail \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enchant_armour \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enchant_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_dungeon \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_fire_realm \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_gnomish_mines \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_graveyard \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_ice_realm \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_sewers \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_troll_mines \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_vault \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/find_bow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/fire_bow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/learn_fireball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/learn_iceball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_arrow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_diamond_armour \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_diamond_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_diamond_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_iron_armour \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_iron_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_iron_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_stone_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_stone_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_torch \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_wood_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_wood_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/open_chest \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_furnace \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_plant \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_stone \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_table \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_torch \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/wake_up \u001b[31m100.0\u001b[0m\u001b[1m / dataset_size \u001b[31m2084.0\u001b[0m\u001b[1m / train_return \u001b[31m0.1\u001b[0m\u001b[1m / train_length \u001b[31m257.0\u001b[0m\u001b[1m / train_episodes \u001b[31m7.0\u001b[0m\u001b[1m\u001b[0m\n", + "\u001b[32m2024-06-06 16:23:59.835\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mtools\u001b[0m:\u001b[36mwrite\u001b[0m:\u001b[36m85\u001b[0m - \u001b[1m[0] log_achievements/cast_fireball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/cast_iceball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_coal \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_diamond \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_drink \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_iron \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_ruby \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_sapling \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_sapphire \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_stone \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_wood \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/damage_necromancer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_archer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_deep_thing \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_fire_elemental \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_frost_troll \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_gnome_archer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_gnome_warrior \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_ice_elemental \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_knight \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_kobold \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_lizard \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_necromancer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_orc_mage \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_orc_solider \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_pigman \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_skeleton \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_troll \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_zombie \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/drink_potion \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_bat \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_cow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_plant \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_snail \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enchant_armour \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enchant_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_dungeon \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_fire_realm \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_gnomish_mines \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_graveyard \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_ice_realm \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_sewers \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_troll_mines \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_vault \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/find_bow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/fire_bow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/learn_fireball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/learn_iceball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_arrow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_diamond_armour \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_diamond_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_diamond_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_iron_armour \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_iron_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_iron_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_stone_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_stone_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_torch \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_wood_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_wood_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/open_chest \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_furnace \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_plant \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_stone \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_table \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_torch \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/wake_up \u001b[31m100.0\u001b[0m\u001b[1m / dataset_size \u001b[31m2474.0\u001b[0m\u001b[1m / train_return \u001b[31m0.1\u001b[0m\u001b[1m / train_length \u001b[31m390.0\u001b[0m\u001b[1m / train_episodes \u001b[31m8.0\u001b[0m\u001b[1m\u001b[0m\n", + "\u001b[32m2024-06-06 16:24:00.056\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m\u001b[0m:\u001b[36m92\u001b[0m - \u001b[1mLogger: (2500 steps).\u001b[0m\n", + "\u001b[32m2024-06-06 16:24:00.057\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m\u001b[0m:\u001b[36m94\u001b[0m - \u001b[1mSimulate agent.\u001b[0m\n" + ] + } + ], + "source": [ + "from loguru import logger\n", + "from tqdm.auto import tqdm\n", + "import pathlib\n", + "\n", + "import torch\n", + "from torch import nn\n", + "from torch import distributions as torchd\n", + "\n", + "import exploration as expl\n", + "import models\n", + "import tools\n", + "import envs.wrappers as wrappers\n", + "from parallel import Parallel, Damy\n", + "\n", + "# from main\n", + "tools.set_seed_everywhere(config.seed)\n", + "if config.deterministic_run:\n", + " tools.enable_deterministic_run()\n", + "logdir = pathlib.Path(config.logdir).expanduser()\n", + "config.traindir = config.traindir or logdir / \"train_eps\"\n", + "config.evaldir = config.evaldir or logdir / \"eval_eps\"\n", + "config.steps //= config.action_repeat\n", + "config.eval_every //= config.action_repeat\n", + "config.log_every //= config.action_repeat\n", + "config.time_limit //= config.action_repeat\n", + "\n", + "logger.info(f\"Logdir {logdir}\")\n", + "logdir.mkdir(parents=True, exist_ok=True)\n", + "config.traindir.mkdir(parents=True, exist_ok=True)\n", + "config.evaldir.mkdir(parents=True, exist_ok=True)\n", + "step = count_steps(config.traindir)\n", + "# step in logger is environmental step\n", + "tlogger = tools.Logger(logdir, config.action_repeat * step)\n", + "logger.add(logdir/\"logger.log\")\n", + "\n", + "logger.info(\"Create envs.\")\n", + "if config.offline_traindir:\n", + " directory = config.offline_traindir.format(**vars(config))\n", + "else:\n", + " directory = config.traindir\n", + "train_eps = tools.load_episodes(directory, limit=config.dataset_size)\n", + "if config.offline_evaldir:\n", + " directory = config.offline_evaldir.format(**vars(config))\n", + "else:\n", + " directory = config.evaldir\n", + "eval_eps = tools.load_episodes(directory, limit=1)\n", + "make = lambda mode, id: make_env(config, mode, id)\n", + "train_envs = [make(\"train\", i) for i in range(config.envs)]\n", + "eval_envs = [make(\"eval\", i) for i in range(config.envs)]\n", + "if config.parallel:\n", + " train_envs = [Parallel(env, \"process\") for env in train_envs]\n", + " eval_envs = [Parallel(env, \"process\") for env in eval_envs]\n", + "else:\n", + " train_envs = [Damy(env) for env in train_envs]\n", + " eval_envs = [Damy(env) for env in eval_envs]\n", + "acts = train_envs[0].action_space\n", + "logger.info(f\"Action Space {acts}\" )\n", + "config.num_actions = acts.n if hasattr(acts, \"n\") else acts.shape[0]\n", + "\n", + "state = None\n", + "if not config.offline_traindir:\n", + " prefill = max(0, config.prefill - count_steps(config.traindir))\n", + " logger.info(f\"Prefill dataset ({prefill} steps).\")\n", + " if hasattr(acts, \"discrete\"):\n", + " random_actor = tools.OneHotDist(\n", + " torch.zeros(config.num_actions).repeat(config.envs, 1)\n", + " )\n", + " else:\n", + " random_actor = torchd.independent.Independent(\n", + " torchd.uniform.Uniform(\n", + " torch.Tensor(acts.low).repeat(config.envs, 1),\n", + " torch.Tensor(acts.high).repeat(config.envs, 1),\n", + " ),\n", + " 1,\n", + " )\n", + "\n", + " def random_agent(o, d, s):\n", + " action = random_actor.sample()\n", + " logprob = random_actor.log_prob(action)\n", + " return {\"action\": action, \"logprob\": logprob}, None\n", + "\n", + " state = tools.simulate(\n", + " random_agent,\n", + " train_envs,\n", + " train_eps,\n", + " config.traindir,\n", + " tlogger,\n", + " limit=config.dataset_size,\n", + " steps=prefill,\n", + " )\n", + " tlogger.step += prefill * config.action_repeat\n", + " logger.info(f\"Logger: ({tlogger.step} steps).\")\n", + "\n", + "logger.info(\"Simulate agent.\")\n", + "train_dataset = make_dataset(train_eps, config)\n", + "eval_dataset = make_dataset(eval_eps, config)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "train_envs[0].observation_space" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "config = parse_args(argv)\n", + "config.num_actions = acts.n if hasattr(acts, \"n\") else acts.shape[0]\n", + "agent = Dreamer(\n", + " train_envs[0].observation_space,\n", + " train_envs[0].action_space,\n", + " config,\n", + " tlogger,\n", + " train_dataset,\n", + ").to(config.device)\n", + "# print(agent)\n", + "agent.requires_grad_(requires_grad=False)\n", + "if (logdir / \"latest.pt\").exists():\n", + " checkpoint = torch.load(logdir / \"latest.pt\")\n", + " agent.load_state_dict(checkpoint[\"agent_state_dict\"])\n", + " tools.recursively_load_optim_state_dict(agent, checkpoint[\"optims_state_dict\"])\n", + " agent._should_pretrain._once = False\n", + " logger.warning(f\"Loaded model from {logdir / 'latest.pt'}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "- note model_opt includes actor.wm\n", + " - encoder\n", + " - rssm\n", + " - heads\n", + "- actor" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Now lets play" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "assert state is not None\n", + "import numpy as np\n", + "\n", + "state" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from tools import convert, add_to_cache\n", + "envs = train_envs\n", + "cache = train_eps\n", + "\n", + "step, episode = 0, 0\n", + "done = np.ones(len(envs), bool)\n", + "length = np.zeros(len(envs), np.int32)\n", + "obs = [None] * len(envs)\n", + "agent_state = None\n", + "reward = [0] * len(envs)\n", + "\n", + "indices = [index for index, d in enumerate(done) if d]\n", + "results = [envs[i].reset() for i in indices]\n", + "results = [r() for r in results]\n", + "for index, result in zip(indices, results):\n", + " t = result.copy()\n", + " t = {k: convert(v) for k, v in t.items()}\n", + " # action will be added to transition in add_to_cache\n", + " t[\"reward\"] = 0.0\n", + " t[\"discount\"] = 1.0\n", + " # initial state should be added to cache\n", + " add_to_cache(cache, envs[index].id, t)\n", + " # replace obs with done by initial state\n", + " obs[index] = result\n", + "# step agents" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "envs[0].observation_space" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "obs[0]['state_map'].shape, obs[0]['state_inventory'].shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# from tools.simulate\n", + "\n", + "# step\n", + "# step, episode, done, length, obs, agent_state, reward = state\n", + "obs2 = {k: np.stack([o[k] for o in obs]) for k in obs[0] if \"log_\" not in k}\n", + "action, agent_state = agent(obs2, done, agent_state)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from torchinfo import summary\n", + "\n", + "summary(agent, input=(obs, done, agent_state), depth=4)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# agent._wm.heads" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Fine grained torchinfo" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "wm = agent._wm\n", + "data = next(agent._dataset) \n", + "# self._train()\n", + "# post, context, mets = wm._train(data)\n", + "data = wm.preprocess(data)\n", + "embed = wm.encoder(data)\n", + "post, prior = wm.dynamics.observe(\n", + " embed, data[\"action\"], data[\"is_first\"]\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "summary(wm.encoder, input_data=(data,), depth=4, col_names=[\"input_size\", \"output_size\", \"num_params\", ])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# heads\n", + "feat = wm.dynamics.get_feat(post)\n", + "for name, head in wm.heads.items():\n", + " try:\n", + " o = summary(head, input_data=(feat,), depth=3, col_names=[\"input_size\", \"output_size\", \"num_params\", ])\n", + " print(name)\n", + " print(o)\n", + " except Exception as e:\n", + " print(f\"Summary Failed for {name} {e}\")\n", + " continue" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# fail as no call method\n", + "# summary(wm.dynamics, input_data=(embed, data[\"action\"], data[\"is_first\"]), depth=3, col_names=[\"output_size\", \"num_params\", ])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "actor = agent._task_behavior.actor\n", + "\n", + "summary(actor.layers, input_data=(feat,), depth=3, col_names=[\"output_size\", \"num_params\", \"output_size\" ])\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "value = agent._task_behavior.actor\n", + "summary(value.layers, input_data=(feat,), depth=3, col_names=[\"output_size\", \"num_params\", \"output_size\" ])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "8268" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "o = obs['state'].reshape((-1, 8268))\n", + "map = o[:, :8217].reshape((-1, 9, 11, 83))\n", + "map.shape\n", + "inventories = o[:, 8217:]\n", + "inventories\n", + "\n", + "map" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.16" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/nbs/02_torchinfo.ipynb b/nbs/02_torchinfo.ipynb index 3e551c8..c01f92e 100644 --- a/nbs/02_torchinfo.ipynb +++ b/nbs/02_torchinfo.ipynb @@ -235,6 +235,17 @@ " logger.warning(f\"Loaded model from {logdir / 'latest.pt'}\")" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "- note model_opt includes actor.wm\n", + " - encoder\n", + " - rssm\n", + " - heads\n", + "- actor" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -327,7 +338,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 40, "metadata": {}, "outputs": [ { @@ -339,21 +350,49 @@ "Dreamer --\n", "├─OptimizedModule: 1-1 --\n", "│ └─WorldModel: 2-1 --\n", - "│ │ └─MultiEncoder: 3-1 (4,365,824)\n", - "│ │ └─RSSM: 3-2 (3,831,808)\n", - "│ │ └─ModuleDict: 3-3 (7,534,488)\n", + "│ │ └─MultiEncoder: 3-1 --\n", + "│ │ │ └─MLP: 4-1 (4,365,824)\n", + "│ │ └─RSSM: 3-2 512\n", + "│ │ │ └─Sequential: 4-2 (547,328)\n", + "│ │ │ └─GRUCell: 4-3 (1,575,936)\n", + "│ │ │ └─Sequential: 4-4 (263,168)\n", + "│ │ │ └─Sequential: 4-5 (394,240)\n", + "│ │ │ └─Linear: 4-6 (525,312)\n", + "│ │ │ └─Linear: 4-7 (525,312)\n", + "│ │ └─ModuleDict: 3-3 --\n", + "│ │ │ └─MultiDecoder: 4-8 (4,775,576)\n", + "│ │ │ └─MLP: 4-9 (1,444,607)\n", + "│ │ │ └─MLP: 4-10 (1,314,305)\n", "├─OptimizedModule: 1-2 --\n", "│ └─ImagBehavior: 2-2 15,732,120\n", "│ │ └─WorldModel: 3-4 (recursive)\n", - "│ │ └─MLP: 3-5 (1,335,851)\n", - "│ │ └─MLP: 3-6 (1,181,439)\n", - "│ │ └─MLP: 3-7 (1,181,439)\n", + "│ │ │ └─MultiEncoder: 4-11 (recursive)\n", + "│ │ │ └─RSSM: 4-12 (recursive)\n", + "│ │ │ └─ModuleDict: 4-13 (recursive)\n", + "│ │ └─MLP: 3-5 --\n", + "│ │ │ └─Sequential: 4-14 (1,313,792)\n", + "│ │ │ └─Linear: 4-15 (22,059)\n", + "│ │ └─MLP: 3-6 --\n", + "│ │ │ └─Sequential: 4-16 (1,050,624)\n", + "│ │ │ └─Linear: 4-17 (130,815)\n", + "│ │ └─MLP: 3-7 --\n", + "│ │ │ └─Sequential: 4-18 (1,050,624)\n", + "│ │ │ └─Linear: 4-19 (130,815)\n", "├─OptimizedModule: 1-3 (recursive)\n", "│ └─ImagBehavior: 2-3 (recursive)\n", "│ │ └─WorldModel: 3-8 (recursive)\n", + "│ │ │ └─MultiEncoder: 4-20 (recursive)\n", + "│ │ │ └─RSSM: 4-21 (recursive)\n", + "│ │ │ └─ModuleDict: 4-22 (recursive)\n", "│ │ └─MLP: 3-9 (recursive)\n", + "│ │ │ └─Sequential: 4-23 (recursive)\n", + "│ │ │ └─Linear: 4-24 (recursive)\n", "│ │ └─MLP: 3-10 (recursive)\n", + "│ │ │ └─Sequential: 4-25 (recursive)\n", + "│ │ │ └─Linear: 4-26 (recursive)\n", "│ │ └─MLP: 3-11 (recursive)\n", + "│ │ │ └─Sequential: 4-27 (recursive)\n", + "│ │ │ └─Linear: 4-28 (recursive)\n", "=====================================================================================\n", "Total params: 35,162,969\n", "Trainable params: 0\n", @@ -361,7 +400,7 @@ "=====================================================================================" ] }, - "execution_count": 11, + "execution_count": 40, "metadata": {}, "output_type": "execute_result" } @@ -369,7 +408,16 @@ "source": [ "from torchinfo import summary\n", "\n", - "summary(agent, input=(obs, done, agent_state), depth=3)" + "summary(agent, input=(obs, done, agent_state), depth=4)" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": {}, + "outputs": [], + "source": [ + "# agent._wm.heads" ] }, { @@ -398,52 +446,52 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 41, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "==========================================================================================\n", - "Layer (type:depth-idx) Output Shape Param #\n", - "==========================================================================================\n", - "MultiEncoder [128, 16, 256] --\n", - "├─MLP: 1-1 [128, 16, 256] --\n", - "│ └─Sequential: 2-1 [128, 16, 256] --\n", - "│ │ └─Linear: 3-1 [128, 16, 256] (4,233,216)\n", - "│ │ └─LayerNorm: 3-2 [128, 16, 256] (512)\n", - "│ │ └─SiLU: 3-3 [128, 16, 256] --\n", - "│ │ └─Linear: 3-4 [128, 16, 256] (65,536)\n", - "│ │ └─LayerNorm: 3-5 [128, 16, 256] (512)\n", - "│ │ └─SiLU: 3-6 [128, 16, 256] --\n", - "│ │ └─Linear: 3-7 [128, 16, 256] (65,536)\n", - "│ │ └─LayerNorm: 3-8 [128, 16, 256] (512)\n", - "│ │ └─SiLU: 3-9 [128, 16, 256] --\n", - "==========================================================================================\n", + "===================================================================================================================\n", + "Layer (type:depth-idx) Input Shape Output Shape Param #\n", + "===================================================================================================================\n", + "MultiEncoder [128, 16, 130, 110, 3] [128, 16, 256] --\n", + "├─MLP: 1-1 [128, 16, 16536] [128, 16, 256] --\n", + "│ └─Sequential: 2-1 [128, 16, 16536] [128, 16, 256] --\n", + "│ │ └─Linear: 3-1 [128, 16, 16536] [128, 16, 256] (4,233,216)\n", + "│ │ └─LayerNorm: 3-2 [128, 16, 256] [128, 16, 256] (512)\n", + "│ │ └─SiLU: 3-3 [128, 16, 256] [128, 16, 256] --\n", + "│ │ └─Linear: 3-4 [128, 16, 256] [128, 16, 256] (65,536)\n", + "│ │ └─LayerNorm: 3-5 [128, 16, 256] [128, 16, 256] (512)\n", + "│ │ └─SiLU: 3-6 [128, 16, 256] [128, 16, 256] --\n", + "│ │ └─Linear: 3-7 [128, 16, 256] [128, 16, 256] (65,536)\n", + "│ │ └─LayerNorm: 3-8 [128, 16, 256] [128, 16, 256] (512)\n", + "│ │ └─SiLU: 3-9 [128, 16, 256] [128, 16, 256] --\n", + "===================================================================================================================\n", "Total params: 4,365,824\n", "Trainable params: 0\n", "Non-trainable params: 4,365,824\n", "Total mult-adds (M): 558.83\n", - "==========================================================================================\n", + "===================================================================================================================\n", "Input size (MB): 487.31\n", "Forward/backward pass size (MB): 25.17\n", "Params size (MB): 17.46\n", "Estimated Total Size (MB): 529.94\n", - "==========================================================================================" + "===================================================================================================================" ] }, - "execution_count": 13, + "execution_count": 41, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "summary(wm.encoder, input_data=(data,), depth=3, col_names=[\"output_size\", \"num_params\", ])" + "summary(wm.encoder, input_data=(data,), depth=4, col_names=[\"input_size\", \"output_size\", \"num_params\", ])" ] }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 42, "metadata": {}, "outputs": [ { @@ -451,34 +499,34 @@ "output_type": "stream", "text": [ "decoder\n", - "==========================================================================================\n", - "Layer (type:depth-idx) Output Shape Param #\n", - "==========================================================================================\n", - "MultiDecoder -- --\n", - "├─MLP: 1-1 -- --\n", - "│ └─Sequential: 2-1 [128, 16, 256] --\n", - "│ │ └─Linear: 3-1 [128, 16, 256] (393,216)\n", - "│ │ └─LayerNorm: 3-2 [128, 16, 256] (512)\n", - "│ │ └─SiLU: 3-3 [128, 16, 256] --\n", - "│ │ └─Linear: 3-4 [128, 16, 256] (65,536)\n", - "│ │ └─LayerNorm: 3-5 [128, 16, 256] (512)\n", - "│ │ └─SiLU: 3-6 [128, 16, 256] --\n", - "│ │ └─Linear: 3-7 [128, 16, 256] (65,536)\n", - "│ │ └─LayerNorm: 3-8 [128, 16, 256] (512)\n", - "│ │ └─SiLU: 3-9 [128, 16, 256] --\n", - "│ └─ModuleDict: 2-2 -- --\n", - "│ │ └─Linear: 3-10 [128, 16, 16536] (4,249,752)\n", - "==========================================================================================\n", + "===================================================================================================================\n", + "Layer (type:depth-idx) Input Shape Output Shape Param #\n", + "===================================================================================================================\n", + "MultiDecoder [128, 16, 1536] -- --\n", + "├─MLP: 1-1 [128, 16, 1536] -- --\n", + "│ └─Sequential: 2-1 [128, 16, 1536] [128, 16, 256] --\n", + "│ │ └─Linear: 3-1 [128, 16, 1536] [128, 16, 256] (393,216)\n", + "│ │ └─LayerNorm: 3-2 [128, 16, 256] [128, 16, 256] (512)\n", + "│ │ └─SiLU: 3-3 [128, 16, 256] [128, 16, 256] --\n", + "│ │ └─Linear: 3-4 [128, 16, 256] [128, 16, 256] (65,536)\n", + "│ │ └─LayerNorm: 3-5 [128, 16, 256] [128, 16, 256] (512)\n", + "│ │ └─SiLU: 3-6 [128, 16, 256] [128, 16, 256] --\n", + "│ │ └─Linear: 3-7 [128, 16, 256] [128, 16, 256] (65,536)\n", + "│ │ └─LayerNorm: 3-8 [128, 16, 256] [128, 16, 256] (512)\n", + "│ │ └─SiLU: 3-9 [128, 16, 256] [128, 16, 256] --\n", + "│ └─ModuleDict: 2-2 -- -- --\n", + "│ │ └─Linear: 3-10 [128, 16, 256] [128, 16, 16536] (4,249,752)\n", + "===================================================================================================================\n", "Total params: 4,775,576\n", "Trainable params: 0\n", "Non-trainable params: 4,775,576\n", "Total mult-adds (M): 611.27\n", - "==========================================================================================\n", + "===================================================================================================================\n", "Input size (MB): 12.58\n", "Forward/backward pass size (MB): 296.09\n", "Params size (MB): 19.10\n", "Estimated Total Size (MB): 327.78\n", - "==========================================================================================\n", + "===================================================================================================================\n", "Summary Failed for reward Failed to run torchinfo. See above stack traces for more details. Executed layers up to: [Sequential: 1, Linear: 2, LayerNorm: 2, SiLU: 2, Linear: 2, LayerNorm: 2, SiLU: 2, Linear: 2, LayerNorm: 2, SiLU: 2, Linear: 1]\n", "Summary Failed for cont Failed to run torchinfo. See above stack traces for more details. Executed layers up to: [Sequential: 1, Linear: 2, LayerNorm: 2, SiLU: 2, Linear: 2, LayerNorm: 2, SiLU: 2, Linear: 2, LayerNorm: 2, SiLU: 2, Linear: 1]\n" ] @@ -489,7 +537,7 @@ "feat = wm.dynamics.get_feat(post)\n", "for name, head in wm.heads.items():\n", " try:\n", - " o = summary(head, input_data=(feat,), depth=3, col_names=[\"output_size\", \"num_params\", ])\n", + " o = summary(head, input_data=(feat,), depth=3, col_names=[\"input_size\", \"output_size\", \"num_params\", ])\n", " print(name)\n", " print(o)\n", " except Exception as e:\n", @@ -514,6 +562,221 @@ "# summary(wm.dynamics, input_data=(embed, data[\"action\"], data[\"is_first\"]), depth=3, col_names=[\"output_size\", \"num_params\", ])" ] }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "===================================================================================================================\n", + "Layer (type:depth-idx) Output Shape Param # Output Shape\n", + "===================================================================================================================\n", + "Sequential [128, 16, 512] -- [128, 16, 512]\n", + "├─Linear: 1-1 [128, 16, 512] (786,432) [128, 16, 512]\n", + "├─LayerNorm: 1-2 [128, 16, 512] (1,024) [128, 16, 512]\n", + "├─SiLU: 1-3 [128, 16, 512] -- [128, 16, 512]\n", + "├─Linear: 1-4 [128, 16, 512] (262,144) [128, 16, 512]\n", + "├─LayerNorm: 1-5 [128, 16, 512] (1,024) [128, 16, 512]\n", + "├─SiLU: 1-6 [128, 16, 512] -- [128, 16, 512]\n", + "├─Linear: 1-7 [128, 16, 512] (262,144) [128, 16, 512]\n", + "├─LayerNorm: 1-8 [128, 16, 512] (1,024) [128, 16, 512]\n", + "├─SiLU: 1-9 [128, 16, 512] -- [128, 16, 512]\n", + "===================================================================================================================\n", + "Total params: 1,313,792\n", + "Trainable params: 0\n", + "Non-trainable params: 1,313,792\n", + "Total mult-adds (M): 168.17\n", + "===================================================================================================================\n", + "Input size (MB): 12.58\n", + "Forward/backward pass size (MB): 50.33\n", + "Params size (MB): 5.26\n", + "Estimated Total Size (MB): 68.17\n", + "===================================================================================================================" + ] + }, + "execution_count": 32, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "actor = agent._task_behavior.actor\n", + "\n", + "summary(actor.layers, input_data=(feat,), depth=3, col_names=[\"output_size\", \"num_params\", \"output_size\" ])\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "===================================================================================================================\n", + "Layer (type:depth-idx) Output Shape Param # Output Shape\n", + "===================================================================================================================\n", + "Sequential [128, 16, 512] -- [128, 16, 512]\n", + "├─Linear: 1-1 [128, 16, 512] (786,432) [128, 16, 512]\n", + "├─LayerNorm: 1-2 [128, 16, 512] (1,024) [128, 16, 512]\n", + "├─SiLU: 1-3 [128, 16, 512] -- [128, 16, 512]\n", + "├─Linear: 1-4 [128, 16, 512] (262,144) [128, 16, 512]\n", + "├─LayerNorm: 1-5 [128, 16, 512] (1,024) [128, 16, 512]\n", + "├─SiLU: 1-6 [128, 16, 512] -- [128, 16, 512]\n", + "├─Linear: 1-7 [128, 16, 512] (262,144) [128, 16, 512]\n", + "├─LayerNorm: 1-8 [128, 16, 512] (1,024) [128, 16, 512]\n", + "├─SiLU: 1-9 [128, 16, 512] -- [128, 16, 512]\n", + "===================================================================================================================\n", + "Total params: 1,313,792\n", + "Trainable params: 0\n", + "Non-trainable params: 1,313,792\n", + "Total mult-adds (M): 168.17\n", + "===================================================================================================================\n", + "Input size (MB): 12.58\n", + "Forward/backward pass size (MB): 50.33\n", + "Params size (MB): 5.26\n", + "Estimated Total Size (MB): 68.17\n", + "===================================================================================================================" + ] + }, + "execution_count": 33, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "value = agent._task_behavior.actor\n", + "summary(value.layers, input_data=(feat,), depth=3, col_names=[\"output_size\", \"num_params\", \"output_size\" ])" + ] + }, + { + "cell_type": "code", + "execution_count": 62, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[[[0., 0., 0., ..., 0., 0., 1.],\n", + " [0., 0., 0., ..., 0., 0., 1.],\n", + " [0., 0., 0., ..., 0., 0., 1.],\n", + " ...,\n", + " [0., 0., 0., ..., 0., 0., 1.],\n", + " [0., 0., 0., ..., 0., 0., 1.],\n", + " [0., 0., 0., ..., 0., 0., 1.]],\n", + "\n", + " [[0., 0., 0., ..., 0., 0., 1.],\n", + " [0., 0., 0., ..., 0., 0., 1.],\n", + " [0., 0., 0., ..., 0., 0., 1.],\n", + " ...,\n", + " [0., 0., 0., ..., 0., 0., 1.],\n", + " [0., 0., 0., ..., 0., 0., 1.],\n", + " [0., 0., 0., ..., 0., 0., 1.]],\n", + "\n", + " [[0., 0., 0., ..., 0., 0., 1.],\n", + " [0., 0., 0., ..., 0., 0., 1.],\n", + " [0., 0., 0., ..., 0., 0., 1.],\n", + " ...,\n", + " [0., 0., 1., ..., 0., 0., 1.],\n", + " [0., 0., 0., ..., 0., 0., 1.],\n", + " [0., 0., 0., ..., 0., 0., 1.]],\n", + "\n", + " ...,\n", + "\n", + " [[0., 0., 0., ..., 0., 0., 1.],\n", + " [0., 0., 0., ..., 0., 0., 1.],\n", + " [0., 0., 0., ..., 0., 0., 1.],\n", + " ...,\n", + " [0., 0., 1., ..., 0., 0., 1.],\n", + " [0., 0., 1., ..., 0., 0., 1.],\n", + " [0., 0., 0., ..., 0., 0., 1.]],\n", + "\n", + " [[0., 0., 0., ..., 0., 0., 1.],\n", + " [0., 0., 0., ..., 0., 0., 1.],\n", + " [0., 0., 0., ..., 0., 0., 1.],\n", + " ...,\n", + " [0., 0., 1., ..., 0., 0., 1.],\n", + " [0., 0., 1., ..., 0., 0., 1.],\n", + " [0., 0., 1., ..., 0., 0., 1.]],\n", + "\n", + " [[0., 0., 0., ..., 0., 0., 1.],\n", + " [0., 0., 0., ..., 0., 0., 1.],\n", + " [0., 0., 0., ..., 0., 0., 1.],\n", + " ...,\n", + " [0., 0., 1., ..., 0., 0., 1.],\n", + " [0., 0., 1., ..., 0., 0., 1.],\n", + " [0., 0., 1., ..., 0., 0., 1.]]],\n", + "\n", + "\n", + " [[[0., 0., 0., ..., 0., 0., 1.],\n", + " [0., 0., 0., ..., 0., 0., 1.],\n", + " [0., 0., 0., ..., 0., 0., 1.],\n", + " ...,\n", + " [0., 0., 0., ..., 0., 0., 1.],\n", + " [0., 0., 0., ..., 0., 0., 1.],\n", + " [0., 0., 0., ..., 0., 0., 1.]],\n", + "\n", + " [[0., 0., 0., ..., 0., 0., 1.],\n", + " [0., 0., 0., ..., 0., 0., 1.],\n", + " [0., 0., 0., ..., 0., 0., 1.],\n", + " ...,\n", + " [0., 0., 0., ..., 0., 0., 1.],\n", + " [0., 0., 0., ..., 0., 0., 1.],\n", + " [0., 0., 0., ..., 0., 0., 1.]],\n", + "\n", + " [[0., 0., 0., ..., 0., 0., 1.],\n", + " [0., 0., 0., ..., 0., 0., 1.],\n", + " [0., 0., 0., ..., 0., 0., 1.],\n", + " ...,\n", + " [0., 0., 1., ..., 0., 0., 1.],\n", + " [0., 0., 0., ..., 0., 0., 1.],\n", + " [0., 0., 0., ..., 0., 0., 1.]],\n", + "\n", + " ...,\n", + "\n", + " [[0., 0., 0., ..., 0., 0., 1.],\n", + " [0., 0., 0., ..., 0., 0., 1.],\n", + " [0., 0., 0., ..., 0., 0., 1.],\n", + " ...,\n", + " [0., 0., 1., ..., 0., 0., 1.],\n", + " [0., 0., 1., ..., 0., 0., 1.],\n", + " [0., 0., 0., ..., 0., 0., 1.]],\n", + "\n", + " [[0., 0., 0., ..., 0., 0., 1.],\n", + " [0., 0., 0., ..., 0., 0., 1.],\n", + " [0., 0., 0., ..., 0., 0., 1.],\n", + " ...,\n", + " [0., 0., 1., ..., 0., 0., 1.],\n", + " [0., 0., 1., ..., 0., 0., 1.],\n", + " [0., 0., 1., ..., 0., 0., 1.]],\n", + "\n", + " [[0., 0., 0., ..., 0., 0., 1.],\n", + " [0., 0., 0., ..., 0., 0., 1.],\n", + " [0., 0., 0., ..., 0., 0., 1.],\n", + " ...,\n", + " [0., 0., 1., ..., 0., 0., 1.],\n", + " [0., 0., 1., ..., 0., 0., 1.],\n", + " [0., 0., 1., ..., 0., 0., 1.]]]], dtype=float16)" + ] + }, + "execution_count": 62, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "o = obs['state'].reshape((-1, 8268))\n", + "map = o[:, :8217].reshape((-1, 9, 11, 83))\n", + "map.shape\n", + "inventories = o[:, 8217:]\n", + "inventories\n", + "\n", + "map" + ] + }, { "cell_type": "code", "execution_count": null, @@ -521,6 +784,24 @@ "outputs": [], "source": [] }, + { + "cell_type": "code", + "execution_count": 46, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "8268" + ] + }, + "execution_count": 46, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [] + }, { "cell_type": "code", "execution_count": null, diff --git a/networks.py b/networks.py index 7661f7f..1b1dd7e 100644 --- a/networks.py +++ b/networks.py @@ -8,6 +8,7 @@ import torch.nn.functional as F from torch import distributions as torchd from loguru import logger import tools +from einops import rearrange class RSSM(nn.Module): @@ -485,15 +486,16 @@ class ConvEncoder(nn.Module): def forward(self, obs): obs -= 0.5 - # (batch, time, h, w, ch) -> (batch * time, h, w, ch) - x = obs.reshape((-1,) + tuple(obs.shape[-3:])) - # (batch * time, h, w, ch) -> (batch * time, ch, h, w) - x = x.permute(0, 3, 1, 2) - x = self.layers(x) - # (batch * time, ...) -> (batch * time, -1) - x = x.reshape([x.shape[0], np.prod(x.shape[1:])]) - # (batch * time, -1) -> (batch, time, -1) - return x.reshape(list(obs.shape[:-3]) + [x.shape[-1]]) + if obs.ndim == 4: + x = rearrange(obs, "b h w c -> b c h w") + x = self.layers(x) + x = rearrange(x, "b c h w -> b (c h w)") + else: + x = rearrange(obs, "b t h w c -> (b t) c h w") + x = self.layers(x) + x = rearrange(x, "(b t) c h w -> b t (c h w)", t=obs.shape[1]) + assert x.shape[-1] == self.outdim, f"{x.shape[-1]}!={self.outdim}" + return x class ConvDecoder(nn.Module): @@ -605,6 +607,7 @@ class MLP(nn.Module): symlog_inputs=False, device="cuda", name="NoName", + embedding_dim=None, ): super(MLP, self).__init__() self._shape = (shape,) if isinstance(shape, int) else shape @@ -623,9 +626,15 @@ class MLP(nn.Module): self.layers = nn.Sequential() for i in range(layers): - self.layers.add_module( - f"{name}_linear{i}", nn.Linear(inp_dim, units, bias=False) - ) + if i==0 and embedding_dim is not None: + self.layers.add_module( + f"{name}_embed", nn.Embedding(inp_dim, units) + ) + inp_dim = units + else: + self.layers.add_module( + f"{name}_linear{i}", nn.Linear(inp_dim, units, bias=False) + ) if norm: self.layers.add_module( f"{name}_norm{i}", nn.LayerNorm(units, eps=1e-03)