try craftax smaller

This commit is contained in:
wassname
2024-06-06 20:29:17 +08:00
parent ae030731e3
commit c77993207a
6 changed files with 827 additions and 116 deletions
+17 -31
View File
@@ -149,26 +149,6 @@ craftax:
imag_gradient: 'reinforce'
craftax_small:
task: craftax_Craftax-Symbolic-AutoReset-v1
step: 1e6
action_repeat: 1
envs: 1
train_ratio: 512
video_pred_log: false
dyn_hidden: 512
dyn_deter: 512
units: 512
encoder: {cnn_keys: '$^', mlp_keys: "state", mlp_layers: 3, mlp_units: 256}
decoder: {cnn_keys: '$^', mlp_keys: "state", mlp_layers: 3, mlp_units: 256}
actor: {layers: 3, dist: 'onehot', std: 'none'}
value: {layers: 3}
reward_head: {layers: 3}
cont_head: {layers: 3}
imag_gradient: 'reinforce'
batch_size: 128
batch_length: 16
craftax_small2:
task: craftax_Craftax-Symbolic-AutoReset-v1
step: 1e6
action_repeat: 1
@@ -176,10 +156,12 @@ craftax_small2:
train_ratio: 512
video_pred_log: false
dyn_hidden: 256
dyn_deter: 512
dyn_deter: 256
dyn_stoch: 24
dyn_discrete: 24
# note: depth is cnn hidden_dim
encoder: {cnn_keys: 'state_map', cnn_depth: 16, kernel_size: 4, minres: 4, mlp_keys: "state_inventory", mlp_layers: 2, mlp_units: 16}
decoder: {cnn_keys: 'state_map', cnn_depth: 16, kernel_size: 4, minres: 4, mlp_keys: "state_inventory", mlp_layers: 2, mlp_units: 16}
encoder: {cnn_keys: 'state_map', cnn_depth: 32, kernel_size: 4, minres: 2, mlp_keys: "state_inventory", mlp_layers: 2, mlp_units: 16}
decoder: {cnn_keys: 'state_map', cnn_depth: 32, kernel_size: 4, minres: 2, mlp_keys: "state_inventory", mlp_layers: 2, mlp_units: 16}
actor: {layers: 3, dist: 'onehot', std: 'none'}
value: {layers: 3}
# note units is the head hidden_dim
@@ -188,27 +170,31 @@ craftax_small2:
cont_head: {layers: 3}
imag_gradient: 'reinforce'
batch_size: 256
batch_length: 16
batch_length: 32
craftax_smaller:
task: craftax_Craftax-Symbolic-AutoReset-v1
step: 1e6
action_repeat: 1
envs: 1
train_ratio: 256
train_ratio: 512
video_pred_log: false
dyn_hidden: 256
dyn_deter: 1024
units: 256
encoder: {cnn_keys: '$^', mlp_keys: "state", mlp_layers: 2, mlp_units: 256, }
decoder: {cnn_keys: '$^', mlp_keys: "state", mlp_layers: 2, mlp_units: 256}
dyn_hidden: 128
dyn_deter: 128
dyn_stoch: 16
dyn_discrete: 16
# note: depth is cnn hidden_dim
encoder: {cnn_keys: 'state_map', cnn_depth: 16, kernel_size: 4, minres: 2, mlp_keys: "state_inventory", mlp_layers: 2, mlp_units: 8}
decoder: {cnn_keys: 'state_map', cnn_depth: 16, kernel_size: 4, minres: 2, mlp_keys: "state_inventory", mlp_layers: 2, mlp_units: 8}
actor: {layers: 2, dist: 'onehot', std: 'none'}
value: {layers: 2}
# note units is the head hidden_dim
units: 128
reward_head: {layers: 2}
cont_head: {layers: 2}
imag_gradient: 'reinforce'
batch_size: 256
batch_length: 16
batch_length: 32
atari100k:
steps: 4e5
+7 -1
View File
@@ -360,9 +360,14 @@ def main(config):
def parse_args(argv=None):
parser = argparse.ArgumentParser()
parser.add_argument("--configs", nargs="+")
if argv is None:
argv = sys.argv
args, remaining = parser.parse_known_args(argv[1:])
# load config, using relative path
root_dir = pathlib.Path(__file__).parent
configs = yaml.safe_load(
(pathlib.Path(argv[0]).parent / "configs.yaml").read_text()
(root_dir / "configs.yaml").read_text()
)
def recursive_update(base, update):
@@ -376,6 +381,7 @@ def parse_args(argv=None):
defaults = {}
for name in name_list:
recursive_update(defaults, configs[name])
parser = argparse.ArgumentParser()
for key, value in sorted(defaults.items(), key=lambda x: x[0]):
arg_type = tools.args_type(value)
+5 -5
View File
@@ -195,8 +195,8 @@ def reshape_state(state: Float[Tensor, 'frames state_dim']) -> (Float[Tensor,'fr
https://github.com/MichaelTMatthews/Craftax/blob/main/obs_description.md
"""
map = rearrange(state[:, :8217], 'frames (h w c) -> frames h w c', h=9, w=11, c=83)
# now pad from (9,11) to (12,12) using torch
map = F.pad(map, (0, 0, 1, 0, 2, 1))
# now pad from (9,11) to (16,16) so that it's 2*n
map = F.pad(map, (0, 0, 3, 2, 4, 3))
map = rearrange(map, 'f h w c -> h w (f c)')
inventories = rearrange(state[:, 8217:], 'frames c -> (frames c)')
return map, inventories
@@ -214,9 +214,9 @@ class Craftax:
def observation_space(self):
frames = self._env.observation_space.shape[0]
spaces = {
"state": gym.spaces.Box(0, 1, (np.prod(self._env.observation_space.shape),), dtype=np.float32),
"state_map": gym.spaces.Box(0, 1, (12, 12, frames*83), dtype=np.float32),
"state_inventory": gym.spaces.Box(0, 1, (frames * 51,), dtype=np.float32),
# "state": gym.spaces.Box(0, 1, (np.prod(self._env.observation_space.shape),), dtype=np.float32),
"state_map": gym.spaces.Box(0, 1, (16, 16, frames*83), dtype=np.float16),
"state_inventory": gym.spaces.Box(0, 1, (frames * 51,), dtype=np.float16),
"image": gym.spaces.Box(0, 255, (130, 110, 3), dtype=np.uint8),
"is_first": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.uint8),
"is_last": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.uint8),
+777 -69
View File
@@ -9,7 +9,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
@@ -19,7 +19,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 3,
"metadata": {},
"outputs": [
{
@@ -48,7 +48,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 4,
"metadata": {},
"outputs": [
{
@@ -61,10 +61,10 @@
{
"data": {
"text/plain": [
"Namespace(act='SiLU', action_repeat=1, actor={'layers': 3, 'dist': 'onehot', 'entropy': 0.0003, 'unimix_ratio': 0.01, 'std': 'none', 'min_std': 0.1, 'max_std': 1.0, 'temp': 0.1, 'lr': 3e-05, 'eps': 1e-05, 'grad_clip': 100.0, 'outscale': 1.0}, batch_length=16, batch_size=256, compile=True, cont_head={'layers': 3, 'loss_scale': 1.0, 'outscale': 1.0}, critic={'layers': 2, 'dist': 'symlog_disc', 'slow_target': True, 'slow_target_update': 1, 'slow_target_fraction': 0.02, 'lr': 3e-05, 'eps': 1e-05, 'grad_clip': 100.0, 'outscale': 0.0}, dataset_size=1000000, debug=False, decoder={'mlp_keys': 'state_inventory', 'cnn_keys': 'state_map', 'act': 'SiLU', 'norm': True, 'cnn_depth': 16, 'kernel_size': 4, 'minres': 4, 'mlp_layers': 2, 'mlp_units': 16, 'cnn_sigmoid': False, 'image_dist': 'mse', 'vector_dist': 'symlog_mse', 'outscale': 1.0}, deterministic_run=False, device='cuda:0', disag_action_cond=False, disag_layers=4, disag_log=True, disag_models=10, disag_offset=1, disag_target='stoch', disag_units=400, discount=0.997, discount_lambda=0.95, dyn_deter=512, dyn_discrete=32, dyn_hidden=256, dyn_mean_act='none', dyn_min_std=0.1, dyn_rec_depth=1, dyn_scale=0.5, dyn_std_act='sigmoid2', dyn_stoch=32, encoder={'mlp_keys': 'state_inventory', 'cnn_keys': 'state_map', 'act': 'SiLU', 'norm': True, 'cnn_depth': 16, 'kernel_size': 4, 'minres': 4, 'mlp_layers': 2, 'mlp_units': 16, 'symlog_inputs': True}, envs=1, eval_episode_num=10, eval_every=10000.0, eval_state_mean=False, evaldir=None, expl_behavior='greedy', expl_extr_scale=0.0, expl_intr_scale=1.0, expl_until=0, grad_clip=1000, grad_heads=('decoder', 'reward', 'cont'), grayscale=False, imag_gradient='reinforce', imag_gradient_mix=0.0, imag_horizon=15, initial='learned', kl_free=1.0, log_every=10000.0, logdir='../logdir/craftax_small2', model_lr=0.0001, norm=True, offline_evaldir='', offline_traindir='', opt='adam', opt_eps=1e-08, parallel=False, precision=32, prefill=2500, pretrain=100, rep_scale=0.1, reset_every=0, reward_EMA=True, reward_head={'layers': 3, 'dist': 'symlog_disc', 'loss_scale': 1.0, 'outscale': 0.0}, seed=0, size=(64, 64), step=1000000.0, steps=1000000.0, task='craftax_Craftax-Symbolic-AutoReset-v1', time_limit=1000, train_ratio=512, traindir=None, unimix_ratio=0.01, units=256, value={'layers': 3}, video_pred_log=False, weight_decay=0.0)"
"Namespace(act='SiLU', action_repeat=1, actor={'layers': 3, 'dist': 'onehot', 'entropy': 0.0003, 'unimix_ratio': 0.01, 'std': 'none', 'min_std': 0.1, 'max_std': 1.0, 'temp': 0.1, 'lr': 3e-05, 'eps': 1e-05, 'grad_clip': 100.0, 'outscale': 1.0}, batch_length=32, batch_size=256, compile=True, cont_head={'layers': 3, 'loss_scale': 1.0, 'outscale': 1.0}, critic={'layers': 2, 'dist': 'symlog_disc', 'slow_target': True, 'slow_target_update': 1, 'slow_target_fraction': 0.02, 'lr': 3e-05, 'eps': 1e-05, 'grad_clip': 100.0, 'outscale': 0.0}, dataset_size=1000000, debug=False, decoder={'mlp_keys': 'state_inventory', 'cnn_keys': 'state_map', 'act': 'SiLU', 'norm': True, 'cnn_depth': 32, 'kernel_size': 4, 'minres': 2, 'mlp_layers': 2, 'mlp_units': 16, 'cnn_sigmoid': False, 'image_dist': 'mse', 'vector_dist': 'symlog_mse', 'outscale': 1.0}, deterministic_run=False, device='cuda:0', disag_action_cond=False, disag_layers=4, disag_log=True, disag_models=10, disag_offset=1, disag_target='stoch', disag_units=400, discount=0.997, discount_lambda=0.95, dyn_deter=256, dyn_discrete=24, dyn_hidden=256, dyn_mean_act='none', dyn_min_std=0.1, dyn_rec_depth=1, dyn_scale=0.5, dyn_std_act='sigmoid2', dyn_stoch=24, encoder={'mlp_keys': 'state_inventory', 'cnn_keys': 'state_map', 'act': 'SiLU', 'norm': True, 'cnn_depth': 32, 'kernel_size': 4, 'minres': 2, 'mlp_layers': 2, 'mlp_units': 16, 'symlog_inputs': True}, envs=1, eval_episode_num=10, eval_every=10000.0, eval_state_mean=False, evaldir=None, expl_behavior='greedy', expl_extr_scale=0.0, expl_intr_scale=1.0, expl_until=0, grad_clip=1000, grad_heads=('decoder', 'reward', 'cont'), grayscale=False, imag_gradient='reinforce', imag_gradient_mix=0.0, imag_horizon=15, initial='learned', kl_free=1.0, log_every=10000.0, logdir='../logdir/craftax_small2', model_lr=0.0001, norm=True, offline_evaldir='', offline_traindir='', opt='adam', opt_eps=1e-08, parallel=False, precision=32, prefill=2500, pretrain=100, rep_scale=0.1, reset_every=0, reward_EMA=True, reward_head={'layers': 3, 'dist': 'symlog_disc', 'loss_scale': 1.0, 'outscale': 0.0}, seed=0, size=(64, 64), step=1000000.0, steps=1000000.0, task='craftax_Craftax-Symbolic-AutoReset-v1', time_limit=1000, train_ratio=512, traindir=None, unimix_ratio=0.01, units=256, value={'layers': 3}, video_pred_log=False, weight_decay=0.0)"
]
},
"execution_count": 3,
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
@@ -80,27 +80,19 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[32m2024-06-06 16:21:39.870\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m27\u001b[0m - \u001b[1mLogdir ../logdir/craftax_small2\u001b[0m\n",
"\u001b[32m2024-06-06 16:21:39.887\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m36\u001b[0m - \u001b[1mCreate envs.\u001b[0m\n",
"\u001b[32m2024-06-06 16:22:16.800\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m57\u001b[0m - \u001b[1mAction Space Box(0.0, 1.0, (43,), float32)\u001b[0m\n",
"\u001b[32m2024-06-06 16:22:16.801\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m63\u001b[0m - \u001b[1mPrefill dataset (2500 steps).\u001b[0m\n",
"\u001b[32m2024-06-06 16:23:40.174\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mtools\u001b[0m:\u001b[36mwrite\u001b[0m:\u001b[36m85\u001b[0m - \u001b[1m[0] log_achievements/cast_fireball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/cast_iceball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_coal \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_diamond \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_drink \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_iron \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_ruby \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_sapling \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_sapphire \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_stone \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_wood \u001b[31m100.0\u001b[0m\u001b[1m / log_achievements/damage_necromancer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_archer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_deep_thing \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_fire_elemental \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_frost_troll \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_gnome_archer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_gnome_warrior \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_ice_elemental \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_knight \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_kobold \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_lizard \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_necromancer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_orc_mage \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_orc_solider \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_pigman \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_skeleton \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_troll \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_zombie \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/drink_potion \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_bat \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_cow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_plant \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_snail \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enchant_armour \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enchant_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_dungeon \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_fire_realm \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_gnomish_mines \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_graveyard \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_ice_realm \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_sewers \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_troll_mines \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_vault \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/find_bow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/fire_bow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/learn_fireball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/learn_iceball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_arrow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_diamond_armour \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_diamond_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_diamond_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_iron_armour \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_iron_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_iron_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_stone_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_stone_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_torch \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_wood_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_wood_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/open_chest \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_furnace \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_plant \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_stone \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_table \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_torch \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/wake_up \u001b[31m100.0\u001b[0m\u001b[1m / dataset_size \u001b[31m242.0\u001b[0m\u001b[1m / train_return \u001b[31m1.1\u001b[0m\u001b[1m / train_length \u001b[31m242.0\u001b[0m\u001b[1m / train_episodes \u001b[31m1.0\u001b[0m\u001b[1m\u001b[0m\n",
"\u001b[32m2024-06-06 16:23:42.913\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mtools\u001b[0m:\u001b[36mwrite\u001b[0m:\u001b[36m85\u001b[0m - \u001b[1m[0] log_achievements/cast_fireball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/cast_iceball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_coal \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_diamond \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_drink \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_iron \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_ruby \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_sapling \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_sapphire \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_stone \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_wood \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/damage_necromancer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_archer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_deep_thing \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_fire_elemental \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_frost_troll \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_gnome_archer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_gnome_warrior \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_ice_elemental \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_knight \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_kobold \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_lizard \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_necromancer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_orc_mage \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_orc_solider \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_pigman \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_skeleton \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_troll \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_zombie \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/drink_potion \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_bat \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_cow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_plant \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_snail \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enchant_armour \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enchant_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_dungeon \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_fire_realm \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_gnomish_mines \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_graveyard \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_ice_realm \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_sewers \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_troll_mines \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_vault \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/find_bow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/fire_bow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/learn_fireball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/learn_iceball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_arrow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_diamond_armour \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_diamond_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_diamond_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_iron_armour \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_iron_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_iron_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_stone_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_stone_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_torch \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_wood_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_wood_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/open_chest \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_furnace \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_plant \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_stone \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_table \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_torch \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/wake_up \u001b[31m100.0\u001b[0m\u001b[1m / dataset_size \u001b[31m551.0\u001b[0m\u001b[1m / train_return \u001b[31m0.1\u001b[0m\u001b[1m / train_length \u001b[31m309.0\u001b[0m\u001b[1m / train_episodes \u001b[31m2.0\u001b[0m\u001b[1m\u001b[0m\n",
"\u001b[32m2024-06-06 16:23:44.957\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mtools\u001b[0m:\u001b[36mwrite\u001b[0m:\u001b[36m85\u001b[0m - \u001b[1m[0] log_achievements/cast_fireball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/cast_iceball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_coal \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_diamond \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_drink \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_iron \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_ruby \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_sapling \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_sapphire \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_stone \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_wood \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/damage_necromancer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_archer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_deep_thing \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_fire_elemental \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_frost_troll \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_gnome_archer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_gnome_warrior \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_ice_elemental \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_knight \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_kobold \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_lizard \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_necromancer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_orc_mage \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_orc_solider \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_pigman \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_skeleton \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_troll \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_zombie \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/drink_potion \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_bat \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_cow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_plant \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_snail \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enchant_armour \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enchant_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_dungeon \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_fire_realm \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_gnomish_mines \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_graveyard \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_ice_realm \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_sewers \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_troll_mines \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_vault \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/find_bow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/fire_bow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/learn_fireball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/learn_iceball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_arrow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_diamond_armour \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_diamond_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_diamond_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_iron_armour \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_iron_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_iron_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_stone_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_stone_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_torch \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_wood_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_wood_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/open_chest \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_furnace \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_plant \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_stone \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_table \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_torch \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/wake_up \u001b[31m100.0\u001b[0m\u001b[1m / dataset_size \u001b[31m781.0\u001b[0m\u001b[1m / train_return \u001b[31m0.1\u001b[0m\u001b[1m / train_length \u001b[31m230.0\u001b[0m\u001b[1m / train_episodes \u001b[31m3.0\u001b[0m\u001b[1m\u001b[0m\n",
"\u001b[32m2024-06-06 16:23:48.066\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mtools\u001b[0m:\u001b[36mwrite\u001b[0m:\u001b[36m85\u001b[0m - \u001b[1m[0] log_achievements/cast_fireball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/cast_iceball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_coal \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_diamond \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_drink \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_iron \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_ruby \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_sapling \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_sapphire \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_stone \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_wood \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/damage_necromancer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_archer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_deep_thing \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_fire_elemental \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_frost_troll \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_gnome_archer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_gnome_warrior \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_ice_elemental \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_knight \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_kobold \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_lizard \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_necromancer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_orc_mage \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_orc_solider \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_pigman \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_skeleton \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_troll \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_zombie \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/drink_potion \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_bat \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_cow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_plant \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_snail \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enchant_armour \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enchant_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_dungeon \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_fire_realm \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_gnomish_mines \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_graveyard \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_ice_realm \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_sewers \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_troll_mines \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_vault \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/find_bow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/fire_bow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/learn_fireball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/learn_iceball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_arrow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_diamond_armour \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_diamond_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_diamond_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_iron_armour \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_iron_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_iron_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_stone_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_stone_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_torch \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_wood_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_wood_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/open_chest \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_furnace \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_plant \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_stone \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_table \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_torch \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/wake_up \u001b[31m100.0\u001b[0m\u001b[1m / dataset_size \u001b[31m1134.0\u001b[0m\u001b[1m / train_return \u001b[31m0.1\u001b[0m\u001b[1m / train_length \u001b[31m353.0\u001b[0m\u001b[1m / train_episodes \u001b[31m4.0\u001b[0m\u001b[1m\u001b[0m\n",
"\u001b[32m2024-06-06 16:23:51.250\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mtools\u001b[0m:\u001b[36mwrite\u001b[0m:\u001b[36m85\u001b[0m - \u001b[1m[0] log_achievements/cast_fireball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/cast_iceball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_coal \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_diamond \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_drink \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_iron \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_ruby \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_sapling \u001b[31m100.0\u001b[0m\u001b[1m / log_achievements/collect_sapphire \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_stone \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_wood \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/damage_necromancer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_archer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_deep_thing \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_fire_elemental \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_frost_troll \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_gnome_archer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_gnome_warrior \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_ice_elemental \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_knight \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_kobold \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_lizard \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_necromancer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_orc_mage \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_orc_solider \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_pigman \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_skeleton \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_troll \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_zombie \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/drink_potion \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_bat \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_cow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_plant \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_snail \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enchant_armour \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enchant_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_dungeon \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_fire_realm \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_gnomish_mines \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_graveyard \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_ice_realm \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_sewers \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_troll_mines \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_vault \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/find_bow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/fire_bow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/learn_fireball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/learn_iceball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_arrow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_diamond_armour \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_diamond_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_diamond_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_iron_armour \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_iron_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_iron_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_stone_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_stone_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_torch \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_wood_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_wood_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/open_chest \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_furnace \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_plant \u001b[31m100.0\u001b[0m\u001b[1m / log_achievements/place_stone \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_table \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_torch \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/wake_up \u001b[31m100.0\u001b[0m\u001b[1m / dataset_size \u001b[31m1509.0\u001b[0m\u001b[1m / train_return \u001b[31m2.1\u001b[0m\u001b[1m / train_length \u001b[31m375.0\u001b[0m\u001b[1m / train_episodes \u001b[31m5.0\u001b[0m\u001b[1m\u001b[0m\n",
"\u001b[32m2024-06-06 16:23:54.187\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mtools\u001b[0m:\u001b[36mwrite\u001b[0m:\u001b[36m85\u001b[0m - \u001b[1m[0] log_achievements/cast_fireball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/cast_iceball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_coal \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_diamond \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_drink \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_iron \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_ruby \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_sapling \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_sapphire \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_stone \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_wood \u001b[31m100.0\u001b[0m\u001b[1m / log_achievements/damage_necromancer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_archer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_deep_thing \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_fire_elemental \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_frost_troll \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_gnome_archer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_gnome_warrior \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_ice_elemental \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_knight \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_kobold \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_lizard \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_necromancer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_orc_mage \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_orc_solider \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_pigman \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_skeleton \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_troll \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_zombie \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/drink_potion \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_bat \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_cow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_plant \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_snail \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enchant_armour \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enchant_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_dungeon \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_fire_realm \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_gnomish_mines \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_graveyard \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_ice_realm \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_sewers \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_troll_mines \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_vault \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/find_bow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/fire_bow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/learn_fireball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/learn_iceball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_arrow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_diamond_armour \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_diamond_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_diamond_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_iron_armour \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_iron_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_iron_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_stone_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_stone_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_torch \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_wood_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_wood_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/open_chest \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_furnace \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_plant \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_stone \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_table \u001b[31m100.0\u001b[0m\u001b[1m / log_achievements/place_torch \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/wake_up \u001b[31m100.0\u001b[0m\u001b[1m / dataset_size \u001b[31m1827.0\u001b[0m\u001b[1m / train_return \u001b[31m2.1\u001b[0m\u001b[1m / train_length \u001b[31m318.0\u001b[0m\u001b[1m / train_episodes \u001b[31m6.0\u001b[0m\u001b[1m\u001b[0m\n",
"\u001b[32m2024-06-06 16:23:56.423\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mtools\u001b[0m:\u001b[36mwrite\u001b[0m:\u001b[36m85\u001b[0m - \u001b[1m[0] log_achievements/cast_fireball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/cast_iceball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_coal \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_diamond \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_drink \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_iron \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_ruby \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_sapling \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_sapphire \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_stone \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_wood \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/damage_necromancer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_archer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_deep_thing \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_fire_elemental \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_frost_troll \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_gnome_archer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_gnome_warrior \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_ice_elemental \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_knight \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_kobold \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_lizard \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_necromancer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_orc_mage \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_orc_solider \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_pigman \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_skeleton \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_troll \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_zombie \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/drink_potion \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_bat \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_cow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_plant \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_snail \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enchant_armour \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enchant_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_dungeon \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_fire_realm \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_gnomish_mines \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_graveyard \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_ice_realm \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_sewers \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_troll_mines \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_vault \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/find_bow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/fire_bow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/learn_fireball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/learn_iceball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_arrow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_diamond_armour \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_diamond_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_diamond_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_iron_armour \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_iron_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_iron_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_stone_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_stone_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_torch \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_wood_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_wood_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/open_chest \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_furnace \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_plant \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_stone \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_table \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_torch \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/wake_up \u001b[31m100.0\u001b[0m\u001b[1m / dataset_size \u001b[31m2084.0\u001b[0m\u001b[1m / train_return \u001b[31m0.1\u001b[0m\u001b[1m / train_length \u001b[31m257.0\u001b[0m\u001b[1m / train_episodes \u001b[31m7.0\u001b[0m\u001b[1m\u001b[0m\n",
"\u001b[32m2024-06-06 16:23:59.835\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mtools\u001b[0m:\u001b[36mwrite\u001b[0m:\u001b[36m85\u001b[0m - \u001b[1m[0] log_achievements/cast_fireball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/cast_iceball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_coal \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_diamond \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_drink \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_iron \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_ruby \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_sapling \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_sapphire \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_stone \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/collect_wood \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/damage_necromancer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_archer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_deep_thing \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_fire_elemental \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_frost_troll \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_gnome_archer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_gnome_warrior \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_ice_elemental \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_knight \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_kobold \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_lizard \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_necromancer \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_orc_mage \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_orc_solider \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_pigman \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_skeleton \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_troll \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/defeat_zombie \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/drink_potion \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_bat \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_cow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_plant \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/eat_snail \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enchant_armour \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enchant_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_dungeon \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_fire_realm \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_gnomish_mines \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_graveyard \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_ice_realm \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_sewers \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_troll_mines \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/enter_vault \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/find_bow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/fire_bow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/learn_fireball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/learn_iceball \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_arrow \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_diamond_armour \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_diamond_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_diamond_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_iron_armour \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_iron_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_iron_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_stone_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_stone_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_torch \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_wood_pickaxe \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/make_wood_sword \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/open_chest \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_furnace \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_plant \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_stone \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_table \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/place_torch \u001b[31m0.0\u001b[0m\u001b[1m / log_achievements/wake_up \u001b[31m100.0\u001b[0m\u001b[1m / dataset_size \u001b[31m2474.0\u001b[0m\u001b[1m / train_return \u001b[31m0.1\u001b[0m\u001b[1m / train_length \u001b[31m390.0\u001b[0m\u001b[1m / train_episodes \u001b[31m8.0\u001b[0m\u001b[1m\u001b[0m\n",
"\u001b[32m2024-06-06 16:24:00.056\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m92\u001b[0m - \u001b[1mLogger: (2500 steps).\u001b[0m\n",
"\u001b[32m2024-06-06 16:24:00.057\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m94\u001b[0m - \u001b[1mSimulate agent.\u001b[0m\n"
"\u001b[32m2024-06-06 17:08:10.379\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m27\u001b[0m - \u001b[1mLogdir ../logdir/craftax_small2\u001b[0m\n",
"\u001b[32m2024-06-06 17:08:10.384\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m36\u001b[0m - \u001b[1mCreate envs.\u001b[0m\n",
"\u001b[32m2024-06-06 17:08:41.190\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m57\u001b[0m - \u001b[1mAction Space Box(0.0, 1.0, (43,), float32)\u001b[0m\n",
"\u001b[32m2024-06-06 17:08:41.191\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m63\u001b[0m - \u001b[1mPrefill dataset (26 steps).\u001b[0m\n",
"\u001b[32m2024-06-06 17:09:31.587\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m92\u001b[0m - \u001b[1mLogger: (2500 steps).\u001b[0m\n",
"\u001b[32m2024-06-06 17:09:31.588\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m94\u001b[0m - \u001b[1mSimulate agent.\u001b[0m\n"
]
}
],
@@ -205,18 +197,476 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 6,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/media/wassname/SGIronWolf/projects5/worldmodels/dreamerv3-torch/.venv/lib/python3.9/site-packages/numpy/core/numeric.py:330: RuntimeWarning: invalid value encountered in cast\n",
" multiarray.copyto(a, fill_value, casting='unsafe')\n"
]
},
{
"data": {
"text/plain": [
"Dict('image': Box(0, 255, (130, 110, 3), uint8), 'is_first': Box(0, 0, (1,), uint8), 'is_last': Box(0, 0, (1,), uint8), 'is_terminal': Box(0, 0, (1,), uint8), 'log_achievement_cast_fireball': Box(-inf, inf, (1,), float32), 'log_achievement_cast_iceball': Box(-inf, inf, (1,), float32), 'log_achievement_collect_coal': Box(-inf, inf, (1,), float32), 'log_achievement_collect_diamond': Box(-inf, inf, (1,), float32), 'log_achievement_collect_drink': Box(-inf, inf, (1,), float32), 'log_achievement_collect_iron': Box(-inf, inf, (1,), float32), 'log_achievement_collect_ruby': Box(-inf, inf, (1,), float32), 'log_achievement_collect_sapling': Box(-inf, inf, (1,), float32), 'log_achievement_collect_sapphire': Box(-inf, inf, (1,), float32), 'log_achievement_collect_stone': Box(-inf, inf, (1,), float32), 'log_achievement_collect_wood': Box(-inf, inf, (1,), float32), 'log_achievement_damage_necromancer': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_archer': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_deep_thing': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_fire_elemental': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_frost_troll': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_gnome_archer': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_gnome_warrior': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_ice_elemental': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_knight': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_kobold': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_lizard': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_necromancer': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_orc_mage': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_orc_solider': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_pigman': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_skeleton': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_troll': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_zombie': Box(-inf, inf, (1,), float32), 'log_achievement_drink_potion': Box(-inf, inf, (1,), float32), 'log_achievement_eat_bat': Box(-inf, inf, (1,), float32), 'log_achievement_eat_cow': Box(-inf, inf, (1,), float32), 'log_achievement_eat_plant': Box(-inf, inf, (1,), float32), 'log_achievement_eat_snail': Box(-inf, inf, (1,), float32), 'log_achievement_enchant_armour': Box(-inf, inf, (1,), float32), 'log_achievement_enchant_sword': Box(-inf, inf, (1,), float32), 'log_achievement_enter_dungeon': Box(-inf, inf, (1,), float32), 'log_achievement_enter_fire_realm': Box(-inf, inf, (1,), float32), 'log_achievement_enter_gnomish_mines': Box(-inf, inf, (1,), float32), 'log_achievement_enter_graveyard': Box(-inf, inf, (1,), float32), 'log_achievement_enter_ice_realm': Box(-inf, inf, (1,), float32), 'log_achievement_enter_sewers': Box(-inf, inf, (1,), float32), 'log_achievement_enter_troll_mines': Box(-inf, inf, (1,), float32), 'log_achievement_enter_vault': Box(-inf, inf, (1,), float32), 'log_achievement_find_bow': Box(-inf, inf, (1,), float32), 'log_achievement_fire_bow': Box(-inf, inf, (1,), float32), 'log_achievement_learn_fireball': Box(-inf, inf, (1,), float32), 'log_achievement_learn_iceball': Box(-inf, inf, (1,), float32), 'log_achievement_make_arrow': Box(-inf, inf, (1,), float32), 'log_achievement_make_diamond_armour': Box(-inf, inf, (1,), float32), 'log_achievement_make_diamond_pickaxe': Box(-inf, inf, (1,), float32), 'log_achievement_make_diamond_sword': Box(-inf, inf, (1,), float32), 'log_achievement_make_iron_armour': Box(-inf, inf, (1,), float32), 'log_achievement_make_iron_pickaxe': Box(-inf, inf, (1,), float32), 'log_achievement_make_iron_sword': Box(-inf, inf, (1,), float32), 'log_achievement_make_stone_pickaxe': Box(-inf, inf, (1,), float32), 'log_achievement_make_stone_sword': Box(-inf, inf, (1,), float32), 'log_achievement_make_torch': Box(-inf, inf, (1,), float32), 'log_achievement_make_wood_pickaxe': Box(-inf, inf, (1,), float32), 'log_achievement_make_wood_sword': Box(-inf, inf, (1,), float32), 'log_achievement_open_chest': Box(-inf, inf, (1,), float32), 'log_achievement_place_furnace': Box(-inf, inf, (1,), float32), 'log_achievement_place_plant': Box(-inf, inf, (1,), float32), 'log_achievement_place_stone': Box(-inf, inf, (1,), float32), 'log_achievement_place_table': Box(-inf, inf, (1,), float32), 'log_achievement_place_torch': Box(-inf, inf, (1,), float32), 'log_achievement_wake_up': Box(-inf, inf, (1,), float32), 'log_reward': Box(-inf, inf, (1,), float32), 'state': Box(0.0, 1.0, (16536,), float32), 'state_inventory': Box(0.0, 1.0, (102,), float32), 'state_map': Box(0.0, 1.0, (12, 12, 166), float32))"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_envs[0].observation_space"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 7,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[32m2024-06-06 17:09:31.695\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mnetworks\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m324\u001b[0m - \u001b[1mEncoder CNN shapes: {'state_map': (12, 12, 166)}\u001b[0m\n",
"\u001b[32m2024-06-06 17:09:31.696\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mnetworks\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m325\u001b[0m - \u001b[1mEncoder MLP shapes: {'state_inventory': (102,)}\u001b[0m\n",
"\u001b[32m2024-06-06 17:09:31.913\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mnetworks\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m391\u001b[0m - \u001b[1mDecoder CNN shapes: {'state_map': (12, 12, 166)}\u001b[0m\n",
"\u001b[32m2024-06-06 17:09:31.914\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mnetworks\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m392\u001b[0m - \u001b[1mDecoder MLP shapes: {'state_inventory': (102,)}\u001b[0m\n",
"\u001b[32m2024-06-06 17:09:32.650\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmodels\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m102\u001b[0m - \u001b[1mOptimizer model_opt has 2357196 variables.\u001b[0m\n",
"\u001b[32m2024-06-06 17:09:32.657\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmodels\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m281\u001b[0m - \u001b[1mOptimizer actor_opt has 356651 variables.\u001b[0m\n",
"\u001b[32m2024-06-06 17:09:32.657\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmodels\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m292\u001b[0m - \u001b[1mOptimizer value_opt has 345087 variables.\u001b[0m\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/wassname/miniforge3/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
" self.pid = os.fork()\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Dreamer(\n",
" (_wm): OptimizedModule(\n",
" (_orig_mod): WorldModel(\n",
" (encoder): MultiEncoder(\n",
" (_cnn): ConvEncoder(\n",
" (layers): Sequential(\n",
" (0): Conv2dSamePad(166, 32, kernel_size=(4, 4), stride=(2, 2), bias=False)\n",
" (1): ImgChLayerNorm(\n",
" (norm): LayerNorm((32,), eps=0.001, elementwise_affine=True)\n",
" )\n",
" (2): SiLU()\n",
" (3): Conv2dSamePad(32, 64, kernel_size=(4, 4), stride=(2, 2), bias=False)\n",
" (4): ImgChLayerNorm(\n",
" (norm): LayerNorm((64,), eps=0.001, elementwise_affine=True)\n",
" )\n",
" (5): SiLU()\n",
" )\n",
" )\n",
" (_mlp): MLP(\n",
" (layers): Sequential(\n",
" (Encoder_linear0): Linear(in_features=102, out_features=16, bias=False)\n",
" (Encoder_norm0): LayerNorm((16,), eps=0.001, elementwise_affine=True)\n",
" (Encoder_act0): SiLU()\n",
" (Encoder_linear1): Linear(in_features=16, out_features=16, bias=False)\n",
" (Encoder_norm1): LayerNorm((16,), eps=0.001, elementwise_affine=True)\n",
" (Encoder_act1): SiLU()\n",
" )\n",
" )\n",
" )\n",
" (dynamics): RSSM(\n",
" (_img_in_layers): Sequential(\n",
" (0): Linear(in_features=619, out_features=256, bias=False)\n",
" (1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n",
" (2): SiLU()\n",
" )\n",
" (_cell): GRUCell(\n",
" (layers): Sequential(\n",
" (GRU_linear): Linear(in_features=512, out_features=768, bias=False)\n",
" (GRU_norm): LayerNorm((768,), eps=0.001, elementwise_affine=True)\n",
" )\n",
" )\n",
" (_img_out_layers): Sequential(\n",
" (0): Linear(in_features=256, out_features=256, bias=False)\n",
" (1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n",
" (2): SiLU()\n",
" )\n",
" (_obs_out_layers): Sequential(\n",
" (0): Linear(in_features=848, out_features=256, bias=False)\n",
" (1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n",
" (2): SiLU()\n",
" )\n",
" (_imgs_stat_layer): Linear(in_features=256, out_features=576, bias=True)\n",
" (_obs_stat_layer): Linear(in_features=256, out_features=576, bias=True)\n",
" )\n",
" (heads): ModuleDict(\n",
" (decoder): MultiDecoder(\n",
" (_cnn): ConvDecoder(\n",
" (_linear_layer): Linear(in_features=832, out_features=256, bias=True)\n",
" (layers): Sequential(\n",
" (0): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)\n",
" (1): ImgChLayerNorm(\n",
" (norm): LayerNorm((32,), eps=0.001, elementwise_affine=True)\n",
" )\n",
" (2): SiLU()\n",
" (3): ConvTranspose2d(32, 166, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
" )\n",
" )\n",
" (_mlp): MLP(\n",
" (layers): Sequential(\n",
" (Decoder_linear0): Linear(in_features=832, out_features=16, bias=False)\n",
" (Decoder_norm0): LayerNorm((16,), eps=0.001, elementwise_affine=True)\n",
" (Decoder_act0): SiLU()\n",
" (Decoder_linear1): Linear(in_features=16, out_features=16, bias=False)\n",
" (Decoder_norm1): LayerNorm((16,), eps=0.001, elementwise_affine=True)\n",
" (Decoder_act1): SiLU()\n",
" )\n",
" (mean_layer): ModuleDict(\n",
" (state_inventory): Linear(in_features=16, out_features=102, bias=True)\n",
" )\n",
" )\n",
" )\n",
" (reward): MLP(\n",
" (layers): Sequential(\n",
" (Reward_linear0): Linear(in_features=832, out_features=256, bias=False)\n",
" (Reward_norm0): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n",
" (Reward_act0): SiLU()\n",
" (Reward_linear1): Linear(in_features=256, out_features=256, bias=False)\n",
" (Reward_norm1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n",
" (Reward_act1): SiLU()\n",
" (Reward_linear2): Linear(in_features=256, out_features=256, bias=False)\n",
" (Reward_norm2): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n",
" (Reward_act2): SiLU()\n",
" )\n",
" (mean_layer): Linear(in_features=256, out_features=255, bias=True)\n",
" )\n",
" (cont): MLP(\n",
" (layers): Sequential(\n",
" (Cont_linear0): Linear(in_features=832, out_features=256, bias=False)\n",
" (Cont_norm0): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n",
" (Cont_act0): SiLU()\n",
" (Cont_linear1): Linear(in_features=256, out_features=256, bias=False)\n",
" (Cont_norm1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n",
" (Cont_act1): SiLU()\n",
" (Cont_linear2): Linear(in_features=256, out_features=256, bias=False)\n",
" (Cont_norm2): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n",
" (Cont_act2): SiLU()\n",
" )\n",
" (mean_layer): Linear(in_features=256, out_features=1, bias=True)\n",
" )\n",
" )\n",
" )\n",
" )\n",
" (_task_behavior): OptimizedModule(\n",
" (_orig_mod): ImagBehavior(\n",
" (_world_model): WorldModel(\n",
" (encoder): MultiEncoder(\n",
" (_cnn): ConvEncoder(\n",
" (layers): Sequential(\n",
" (0): Conv2dSamePad(166, 32, kernel_size=(4, 4), stride=(2, 2), bias=False)\n",
" (1): ImgChLayerNorm(\n",
" (norm): LayerNorm((32,), eps=0.001, elementwise_affine=True)\n",
" )\n",
" (2): SiLU()\n",
" (3): Conv2dSamePad(32, 64, kernel_size=(4, 4), stride=(2, 2), bias=False)\n",
" (4): ImgChLayerNorm(\n",
" (norm): LayerNorm((64,), eps=0.001, elementwise_affine=True)\n",
" )\n",
" (5): SiLU()\n",
" )\n",
" )\n",
" (_mlp): MLP(\n",
" (layers): Sequential(\n",
" (Encoder_linear0): Linear(in_features=102, out_features=16, bias=False)\n",
" (Encoder_norm0): LayerNorm((16,), eps=0.001, elementwise_affine=True)\n",
" (Encoder_act0): SiLU()\n",
" (Encoder_linear1): Linear(in_features=16, out_features=16, bias=False)\n",
" (Encoder_norm1): LayerNorm((16,), eps=0.001, elementwise_affine=True)\n",
" (Encoder_act1): SiLU()\n",
" )\n",
" )\n",
" )\n",
" (dynamics): RSSM(\n",
" (_img_in_layers): Sequential(\n",
" (0): Linear(in_features=619, out_features=256, bias=False)\n",
" (1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n",
" (2): SiLU()\n",
" )\n",
" (_cell): GRUCell(\n",
" (layers): Sequential(\n",
" (GRU_linear): Linear(in_features=512, out_features=768, bias=False)\n",
" (GRU_norm): LayerNorm((768,), eps=0.001, elementwise_affine=True)\n",
" )\n",
" )\n",
" (_img_out_layers): Sequential(\n",
" (0): Linear(in_features=256, out_features=256, bias=False)\n",
" (1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n",
" (2): SiLU()\n",
" )\n",
" (_obs_out_layers): Sequential(\n",
" (0): Linear(in_features=848, out_features=256, bias=False)\n",
" (1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n",
" (2): SiLU()\n",
" )\n",
" (_imgs_stat_layer): Linear(in_features=256, out_features=576, bias=True)\n",
" (_obs_stat_layer): Linear(in_features=256, out_features=576, bias=True)\n",
" )\n",
" (heads): ModuleDict(\n",
" (decoder): MultiDecoder(\n",
" (_cnn): ConvDecoder(\n",
" (_linear_layer): Linear(in_features=832, out_features=256, bias=True)\n",
" (layers): Sequential(\n",
" (0): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)\n",
" (1): ImgChLayerNorm(\n",
" (norm): LayerNorm((32,), eps=0.001, elementwise_affine=True)\n",
" )\n",
" (2): SiLU()\n",
" (3): ConvTranspose2d(32, 166, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
" )\n",
" )\n",
" (_mlp): MLP(\n",
" (layers): Sequential(\n",
" (Decoder_linear0): Linear(in_features=832, out_features=16, bias=False)\n",
" (Decoder_norm0): LayerNorm((16,), eps=0.001, elementwise_affine=True)\n",
" (Decoder_act0): SiLU()\n",
" (Decoder_linear1): Linear(in_features=16, out_features=16, bias=False)\n",
" (Decoder_norm1): LayerNorm((16,), eps=0.001, elementwise_affine=True)\n",
" (Decoder_act1): SiLU()\n",
" )\n",
" (mean_layer): ModuleDict(\n",
" (state_inventory): Linear(in_features=16, out_features=102, bias=True)\n",
" )\n",
" )\n",
" )\n",
" (reward): MLP(\n",
" (layers): Sequential(\n",
" (Reward_linear0): Linear(in_features=832, out_features=256, bias=False)\n",
" (Reward_norm0): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n",
" (Reward_act0): SiLU()\n",
" (Reward_linear1): Linear(in_features=256, out_features=256, bias=False)\n",
" (Reward_norm1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n",
" (Reward_act1): SiLU()\n",
" (Reward_linear2): Linear(in_features=256, out_features=256, bias=False)\n",
" (Reward_norm2): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n",
" (Reward_act2): SiLU()\n",
" )\n",
" (mean_layer): Linear(in_features=256, out_features=255, bias=True)\n",
" )\n",
" (cont): MLP(\n",
" (layers): Sequential(\n",
" (Cont_linear0): Linear(in_features=832, out_features=256, bias=False)\n",
" (Cont_norm0): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n",
" (Cont_act0): SiLU()\n",
" (Cont_linear1): Linear(in_features=256, out_features=256, bias=False)\n",
" (Cont_norm1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n",
" (Cont_act1): SiLU()\n",
" (Cont_linear2): Linear(in_features=256, out_features=256, bias=False)\n",
" (Cont_norm2): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n",
" (Cont_act2): SiLU()\n",
" )\n",
" (mean_layer): Linear(in_features=256, out_features=1, bias=True)\n",
" )\n",
" )\n",
" )\n",
" (actor): MLP(\n",
" (layers): Sequential(\n",
" (Actor_linear0): Linear(in_features=832, out_features=256, bias=False)\n",
" (Actor_norm0): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n",
" (Actor_act0): SiLU()\n",
" (Actor_linear1): Linear(in_features=256, out_features=256, bias=False)\n",
" (Actor_norm1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n",
" (Actor_act1): SiLU()\n",
" (Actor_linear2): Linear(in_features=256, out_features=256, bias=False)\n",
" (Actor_norm2): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n",
" (Actor_act2): SiLU()\n",
" )\n",
" (mean_layer): Linear(in_features=256, out_features=43, bias=True)\n",
" )\n",
" (value): MLP(\n",
" (layers): Sequential(\n",
" (Value_linear0): Linear(in_features=832, out_features=256, bias=False)\n",
" (Value_norm0): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n",
" (Value_act0): SiLU()\n",
" (Value_linear1): Linear(in_features=256, out_features=256, bias=False)\n",
" (Value_norm1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n",
" (Value_act1): SiLU()\n",
" )\n",
" (mean_layer): Linear(in_features=256, out_features=255, bias=True)\n",
" )\n",
" (_slow_value): MLP(\n",
" (layers): Sequential(\n",
" (Value_linear0): Linear(in_features=832, out_features=256, bias=False)\n",
" (Value_norm0): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n",
" (Value_act0): SiLU()\n",
" (Value_linear1): Linear(in_features=256, out_features=256, bias=False)\n",
" (Value_norm1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n",
" (Value_act1): SiLU()\n",
" )\n",
" (mean_layer): Linear(in_features=256, out_features=255, bias=True)\n",
" )\n",
" )\n",
" )\n",
" (_expl_behavior): OptimizedModule(\n",
" (_orig_mod): ImagBehavior(\n",
" (_world_model): WorldModel(\n",
" (encoder): MultiEncoder(\n",
" (_cnn): ConvEncoder(\n",
" (layers): Sequential(\n",
" (0): Conv2dSamePad(166, 32, kernel_size=(4, 4), stride=(2, 2), bias=False)\n",
" (1): ImgChLayerNorm(\n",
" (norm): LayerNorm((32,), eps=0.001, elementwise_affine=True)\n",
" )\n",
" (2): SiLU()\n",
" (3): Conv2dSamePad(32, 64, kernel_size=(4, 4), stride=(2, 2), bias=False)\n",
" (4): ImgChLayerNorm(\n",
" (norm): LayerNorm((64,), eps=0.001, elementwise_affine=True)\n",
" )\n",
" (5): SiLU()\n",
" )\n",
" )\n",
" (_mlp): MLP(\n",
" (layers): Sequential(\n",
" (Encoder_linear0): Linear(in_features=102, out_features=16, bias=False)\n",
" (Encoder_norm0): LayerNorm((16,), eps=0.001, elementwise_affine=True)\n",
" (Encoder_act0): SiLU()\n",
" (Encoder_linear1): Linear(in_features=16, out_features=16, bias=False)\n",
" (Encoder_norm1): LayerNorm((16,), eps=0.001, elementwise_affine=True)\n",
" (Encoder_act1): SiLU()\n",
" )\n",
" )\n",
" )\n",
" (dynamics): RSSM(\n",
" (_img_in_layers): Sequential(\n",
" (0): Linear(in_features=619, out_features=256, bias=False)\n",
" (1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n",
" (2): SiLU()\n",
" )\n",
" (_cell): GRUCell(\n",
" (layers): Sequential(\n",
" (GRU_linear): Linear(in_features=512, out_features=768, bias=False)\n",
" (GRU_norm): LayerNorm((768,), eps=0.001, elementwise_affine=True)\n",
" )\n",
" )\n",
" (_img_out_layers): Sequential(\n",
" (0): Linear(in_features=256, out_features=256, bias=False)\n",
" (1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n",
" (2): SiLU()\n",
" )\n",
" (_obs_out_layers): Sequential(\n",
" (0): Linear(in_features=848, out_features=256, bias=False)\n",
" (1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n",
" (2): SiLU()\n",
" )\n",
" (_imgs_stat_layer): Linear(in_features=256, out_features=576, bias=True)\n",
" (_obs_stat_layer): Linear(in_features=256, out_features=576, bias=True)\n",
" )\n",
" (heads): ModuleDict(\n",
" (decoder): MultiDecoder(\n",
" (_cnn): ConvDecoder(\n",
" (_linear_layer): Linear(in_features=832, out_features=256, bias=True)\n",
" (layers): Sequential(\n",
" (0): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)\n",
" (1): ImgChLayerNorm(\n",
" (norm): LayerNorm((32,), eps=0.001, elementwise_affine=True)\n",
" )\n",
" (2): SiLU()\n",
" (3): ConvTranspose2d(32, 166, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
" )\n",
" )\n",
" (_mlp): MLP(\n",
" (layers): Sequential(\n",
" (Decoder_linear0): Linear(in_features=832, out_features=16, bias=False)\n",
" (Decoder_norm0): LayerNorm((16,), eps=0.001, elementwise_affine=True)\n",
" (Decoder_act0): SiLU()\n",
" (Decoder_linear1): Linear(in_features=16, out_features=16, bias=False)\n",
" (Decoder_norm1): LayerNorm((16,), eps=0.001, elementwise_affine=True)\n",
" (Decoder_act1): SiLU()\n",
" )\n",
" (mean_layer): ModuleDict(\n",
" (state_inventory): Linear(in_features=16, out_features=102, bias=True)\n",
" )\n",
" )\n",
" )\n",
" (reward): MLP(\n",
" (layers): Sequential(\n",
" (Reward_linear0): Linear(in_features=832, out_features=256, bias=False)\n",
" (Reward_norm0): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n",
" (Reward_act0): SiLU()\n",
" (Reward_linear1): Linear(in_features=256, out_features=256, bias=False)\n",
" (Reward_norm1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n",
" (Reward_act1): SiLU()\n",
" (Reward_linear2): Linear(in_features=256, out_features=256, bias=False)\n",
" (Reward_norm2): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n",
" (Reward_act2): SiLU()\n",
" )\n",
" (mean_layer): Linear(in_features=256, out_features=255, bias=True)\n",
" )\n",
" (cont): MLP(\n",
" (layers): Sequential(\n",
" (Cont_linear0): Linear(in_features=832, out_features=256, bias=False)\n",
" (Cont_norm0): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n",
" (Cont_act0): SiLU()\n",
" (Cont_linear1): Linear(in_features=256, out_features=256, bias=False)\n",
" (Cont_norm1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n",
" (Cont_act1): SiLU()\n",
" (Cont_linear2): Linear(in_features=256, out_features=256, bias=False)\n",
" (Cont_norm2): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n",
" (Cont_act2): SiLU()\n",
" )\n",
" (mean_layer): Linear(in_features=256, out_features=1, bias=True)\n",
" )\n",
" )\n",
" )\n",
" (actor): MLP(\n",
" (layers): Sequential(\n",
" (Actor_linear0): Linear(in_features=832, out_features=256, bias=False)\n",
" (Actor_norm0): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n",
" (Actor_act0): SiLU()\n",
" (Actor_linear1): Linear(in_features=256, out_features=256, bias=False)\n",
" (Actor_norm1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n",
" (Actor_act1): SiLU()\n",
" (Actor_linear2): Linear(in_features=256, out_features=256, bias=False)\n",
" (Actor_norm2): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n",
" (Actor_act2): SiLU()\n",
" )\n",
" (mean_layer): Linear(in_features=256, out_features=43, bias=True)\n",
" )\n",
" (value): MLP(\n",
" (layers): Sequential(\n",
" (Value_linear0): Linear(in_features=832, out_features=256, bias=False)\n",
" (Value_norm0): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n",
" (Value_act0): SiLU()\n",
" (Value_linear1): Linear(in_features=256, out_features=256, bias=False)\n",
" (Value_norm1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n",
" (Value_act1): SiLU()\n",
" )\n",
" (mean_layer): Linear(in_features=256, out_features=255, bias=True)\n",
" )\n",
" (_slow_value): MLP(\n",
" (layers): Sequential(\n",
" (Value_linear0): Linear(in_features=832, out_features=256, bias=False)\n",
" (Value_norm0): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n",
" (Value_act0): SiLU()\n",
" (Value_linear1): Linear(in_features=256, out_features=256, bias=False)\n",
" (Value_norm1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n",
" (Value_act1): SiLU()\n",
" )\n",
" (mean_layer): Linear(in_features=256, out_features=255, bias=True)\n",
" )\n",
" )\n",
" )\n",
")\n"
]
}
],
"source": [
"config = parse_args(argv)\n",
"config.num_actions = acts.n if hasattr(acts, \"n\") else acts.shape[0]\n",
@@ -227,7 +677,7 @@
" tlogger,\n",
" train_dataset,\n",
").to(config.device)\n",
"# print(agent)\n",
"print(agent)\n",
"agent.requires_grad_(requires_grad=False)\n",
"if (logdir / \"latest.pt\").exists():\n",
" checkpoint = torch.load(logdir / \"latest.pt\")\n",
@@ -262,19 +712,19 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"assert state is not None\n",
"import numpy as np\n",
"\n",
"state"
"# state"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
@@ -307,27 +757,33 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"envs[0].observation_space"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"obs[0]['state_map'].shape, obs[0]['state_inventory'].shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/media/wassname/SGIronWolf/projects5/worldmodels/dreamerv3-torch/nbs/../networks.py:790: UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at ../aten/src/ATen/native/cudnn/Conv_v8.cpp:919.)\n",
" ret = F.conv2d(\n"
]
},
{
"ename": "AssertionError",
"evalue": "(torch.Size([256, 32, 8, 8, 166]), torch.Size([256, 32, 12, 12, 166]))",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[10], line 6\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m# from tools.simulate\u001b[39;00m\n\u001b[1;32m 2\u001b[0m \n\u001b[1;32m 3\u001b[0m \u001b[38;5;66;03m# step\u001b[39;00m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;66;03m# step, episode, done, length, obs, agent_state, reward = state\u001b[39;00m\n\u001b[1;32m 5\u001b[0m obs2 \u001b[38;5;241m=\u001b[39m {k: np\u001b[38;5;241m.\u001b[39mstack([o[k] \u001b[38;5;28;01mfor\u001b[39;00m o \u001b[38;5;129;01min\u001b[39;00m obs]) \u001b[38;5;28;01mfor\u001b[39;00m k \u001b[38;5;129;01min\u001b[39;00m obs[\u001b[38;5;241m0\u001b[39m] \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlog_\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m k}\n\u001b[0;32m----> 6\u001b[0m action, agent_state \u001b[38;5;241m=\u001b[39m \u001b[43magent\u001b[49m\u001b[43m(\u001b[49m\u001b[43mobs2\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdone\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43magent_state\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m/media/wassname/SGIronWolf/projects5/worldmodels/dreamerv3-torch/nbs/../dreamer.py:71\u001b[0m, in \u001b[0;36mDreamer.__call__\u001b[0;34m(self, obs, reset, state, training)\u001b[0m\n\u001b[1;32m 65\u001b[0m steps \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 66\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_config\u001b[38;5;241m.\u001b[39mpretrain\n\u001b[1;32m 67\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_should_pretrain()\n\u001b[1;32m 68\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_should_train(step)\n\u001b[1;32m 69\u001b[0m )\n\u001b[1;32m 70\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m _ \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(steps):\n\u001b[0;32m---> 71\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_train\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mnext\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_dataset\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 72\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_update_count \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m 73\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_metrics[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mupdate_count\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_update_count\n",
"File \u001b[0;32m/media/wassname/SGIronWolf/projects5/worldmodels/dreamerv3-torch/nbs/../dreamer.py:124\u001b[0m, in \u001b[0;36mDreamer._train\u001b[0;34m(self, data)\u001b[0m\n\u001b[1;32m 122\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_train\u001b[39m(\u001b[38;5;28mself\u001b[39m, data):\n\u001b[1;32m 123\u001b[0m metrics \u001b[38;5;241m=\u001b[39m {}\n\u001b[0;32m--> 124\u001b[0m post, context, mets \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_wm\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_train\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 125\u001b[0m metrics\u001b[38;5;241m.\u001b[39mupdate(mets)\n\u001b[1;32m 126\u001b[0m start \u001b[38;5;241m=\u001b[39m post\n",
"File \u001b[0;32m/media/wassname/SGIronWolf/projects5/worldmodels/dreamerv3-torch/nbs/../models.py:143\u001b[0m, in \u001b[0;36mWorldModel._train\u001b[0;34m(self, data)\u001b[0m\n\u001b[1;32m 141\u001b[0m losses \u001b[38;5;241m=\u001b[39m {}\n\u001b[1;32m 142\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m name, pred \u001b[38;5;129;01min\u001b[39;00m preds\u001b[38;5;241m.\u001b[39mitems():\n\u001b[0;32m--> 143\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m-\u001b[39m\u001b[43mpred\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlog_prob\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata\u001b[49m\u001b[43m[\u001b[49m\u001b[43mname\u001b[49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 144\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m loss\u001b[38;5;241m.\u001b[39mshape \u001b[38;5;241m==\u001b[39m embed\u001b[38;5;241m.\u001b[39mshape[:\u001b[38;5;241m2\u001b[39m], (name, loss\u001b[38;5;241m.\u001b[39mshape)\n\u001b[1;32m 145\u001b[0m losses[name] \u001b[38;5;241m=\u001b[39m loss\n",
"File \u001b[0;32m/media/wassname/SGIronWolf/projects5/worldmodels/dreamerv3-torch/nbs/../tools.py:528\u001b[0m, in \u001b[0;36mMSEDist.log_prob\u001b[0;34m(self, value)\u001b[0m\n\u001b[1;32m 527\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mlog_prob\u001b[39m(\u001b[38;5;28mself\u001b[39m, value):\n\u001b[0;32m--> 528\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_mode\u001b[38;5;241m.\u001b[39mshape \u001b[38;5;241m==\u001b[39m value\u001b[38;5;241m.\u001b[39mshape, (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_mode\u001b[38;5;241m.\u001b[39mshape, value\u001b[38;5;241m.\u001b[39mshape)\n\u001b[1;32m 529\u001b[0m distance \u001b[38;5;241m=\u001b[39m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_mode \u001b[38;5;241m-\u001b[39m value) \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39m \u001b[38;5;241m2\u001b[39m\n\u001b[1;32m 530\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_agg \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmean\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n",
"\u001b[0;31mAssertionError\u001b[0m: (torch.Size([256, 32, 8, 8, 166]), torch.Size([256, 32, 12, 12, 166]))"
]
}
],
"source": [
"# from tools.simulate\n",
"\n",
@@ -341,10 +797,120 @@
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"==========================================================================================\n",
"Layer (type:depth-idx) Param #\n",
"==========================================================================================\n",
"Dreamer --\n",
"├─OptimizedModule: 1-1 --\n",
"│ └─WorldModel: 2-1 --\n",
"│ │ └─MultiEncoder: 3-1 (44,480)\n",
"│ │ └─RSSM: 3-2 (2,397,952)\n",
"│ │ └─ModuleDict: 3-3 (1,580,204)\n",
"├─OptimizedModule: 1-2 --\n",
"│ └─ImagBehavior: 2-2 4,022,636\n",
"│ │ └─WorldModel: 3-4 (recursive)\n",
"│ │ └─MLP: 3-5 (536,875)\n",
"│ │ └─MLP: 3-6 (525,311)\n",
"│ │ └─MLP: 3-7 (525,311)\n",
"├─OptimizedModule: 1-3 (recursive)\n",
"│ └─ImagBehavior: 2-3 (recursive)\n",
"│ │ └─WorldModel: 3-8 (recursive)\n",
"│ │ └─MLP: 3-9 (recursive)\n",
"│ │ └─MLP: 3-10 (recursive)\n",
"│ │ └─MLP: 3-11 (recursive)\n",
"==========================================================================================\n",
"Total params: 9,632,769\n",
"Trainable params: 0\n",
"Non-trainable params: 9,632,769\n",
"=========================================================================================="
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from torchinfo import summary\n",
"\n",
"summary(agent, input=(obs, done, agent_state), depth=3)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"==========================================================================================\n",
"Layer (type:depth-idx) Param #\n",
"==========================================================================================\n",
"Dreamer --\n",
"├─OptimizedModule: 1-1 --\n",
"│ └─WorldModel: 2-1 --\n",
"│ │ └─MultiEncoder: 3-1 --\n",
"│ │ │ └─ConvEncoder: 4-1 (42,528)\n",
"│ │ │ └─MLP: 4-2 (1,952)\n",
"│ │ └─RSSM: 3-2 512\n",
"│ │ │ └─Sequential: 4-3 (273,664)\n",
"│ │ │ └─GRUCell: 4-4 (1,182,720)\n",
"│ │ │ └─Sequential: 4-5 (131,584)\n",
"│ │ │ └─Sequential: 4-6 (283,136)\n",
"│ │ │ └─Linear: 4-7 (263,168)\n",
"│ │ │ └─Linear: 4-8 (263,168)\n",
"│ │ └─ModuleDict: 3-3 --\n",
"│ │ │ └─MultiDecoder: 4-9 (462,764)\n",
"│ │ │ └─MLP: 4-10 (591,359)\n",
"│ │ │ └─MLP: 4-11 (526,081)\n",
"├─OptimizedModule: 1-2 --\n",
"│ └─ImagBehavior: 2-2 4,022,636\n",
"│ │ └─WorldModel: 3-4 (recursive)\n",
"│ │ │ └─MultiEncoder: 4-12 (recursive)\n",
"│ │ │ └─RSSM: 4-13 (recursive)\n",
"│ │ │ └─ModuleDict: 4-14 (recursive)\n",
"│ │ └─MLP: 3-5 --\n",
"│ │ │ └─Sequential: 4-15 (525,824)\n",
"│ │ │ └─Linear: 4-16 (11,051)\n",
"│ │ └─MLP: 3-6 --\n",
"│ │ │ └─Sequential: 4-17 (459,776)\n",
"│ │ │ └─Linear: 4-18 (65,535)\n",
"│ │ └─MLP: 3-7 --\n",
"│ │ │ └─Sequential: 4-19 (459,776)\n",
"│ │ │ └─Linear: 4-20 (65,535)\n",
"├─OptimizedModule: 1-3 (recursive)\n",
"│ └─ImagBehavior: 2-3 (recursive)\n",
"│ │ └─WorldModel: 3-8 (recursive)\n",
"│ │ │ └─MultiEncoder: 4-21 (recursive)\n",
"│ │ │ └─RSSM: 4-22 (recursive)\n",
"│ │ │ └─ModuleDict: 4-23 (recursive)\n",
"│ │ └─MLP: 3-9 (recursive)\n",
"│ │ │ └─Sequential: 4-24 (recursive)\n",
"│ │ │ └─Linear: 4-25 (recursive)\n",
"│ │ └─MLP: 3-10 (recursive)\n",
"│ │ │ └─Sequential: 4-26 (recursive)\n",
"│ │ │ └─Linear: 4-27 (recursive)\n",
"│ │ └─MLP: 3-11 (recursive)\n",
"│ │ │ └─Sequential: 4-28 (recursive)\n",
"│ │ │ └─Linear: 4-29 (recursive)\n",
"==========================================================================================\n",
"Total params: 9,632,769\n",
"Trainable params: 0\n",
"Non-trainable params: 9,632,769\n",
"=========================================================================================="
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"summary(agent, input=(obs, done, agent_state), depth=4)"
]
},
@@ -385,7 +951,46 @@
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"===================================================================================================================\n",
"Layer (type:depth-idx) Input Shape Output Shape Param #\n",
"===================================================================================================================\n",
"MultiEncoder [256, 16, 130, 110, 3] [256, 16, 592] --\n",
"├─ConvEncoder: 1-1 [256, 16, 12, 12, 166] [256, 16, 576] --\n",
"│ └─Sequential: 2-1 [4096, 166, 12, 12] [4096, 16, 6, 6] --\n",
"│ │ └─Conv2dSamePad: 3-1 [4096, 166, 12, 12] [4096, 16, 6, 6] (42,496)\n",
"│ │ └─ImgChLayerNorm: 3-2 [4096, 16, 6, 6] [4096, 16, 6, 6] --\n",
"│ │ │ └─LayerNorm: 4-1 [4096, 6, 6, 16] [4096, 6, 6, 16] (32)\n",
"│ │ └─SiLU: 3-3 [4096, 16, 6, 6] [4096, 16, 6, 6] --\n",
"├─MLP: 1-2 [256, 16, 102] [256, 16, 16] --\n",
"│ └─Sequential: 2-2 [256, 16, 102] [256, 16, 16] --\n",
"│ │ └─Linear: 3-4 [256, 16, 102] [256, 16, 16] (1,632)\n",
"│ │ └─LayerNorm: 3-5 [256, 16, 16] [256, 16, 16] (32)\n",
"│ │ └─SiLU: 3-6 [256, 16, 16] [256, 16, 16] --\n",
"│ │ └─Linear: 3-7 [256, 16, 16] [256, 16, 16] (256)\n",
"│ │ └─LayerNorm: 3-8 [256, 16, 16] [256, 16, 16] (32)\n",
"│ │ └─SiLU: 3-9 [256, 16, 16] [256, 16, 16] --\n",
"===================================================================================================================\n",
"Total params: 44,480\n",
"Trainable params: 0\n",
"Non-trainable params: 44,480\n",
"Total mult-adds (G): 6.27\n",
"===================================================================================================================\n",
"Input size (MB): 1367.93\n",
"Forward/backward pass size (MB): 39.85\n",
"Params size (MB): 0.18\n",
"Estimated Total Size (MB): 1407.96\n",
"==================================================================================================================="
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"summary(wm.encoder, input_data=(data,), depth=4, col_names=[\"input_size\", \"output_size\", \"num_params\", ])"
]
@@ -394,7 +999,46 @@
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"decoder\n",
"===================================================================================================================\n",
"Layer (type:depth-idx) Input Shape Output Shape Param #\n",
"===================================================================================================================\n",
"MultiDecoder [256, 16, 1536] -- --\n",
"├─ConvDecoder: 1-1 [256, 16, 1536] [256, 16, 8, 8, 166] --\n",
"│ └─Linear: 2-1 [256, 16, 1536] [256, 16, 256] (393,472)\n",
"│ └─Sequential: 2-2 [4096, 16, 4, 4] [4096, 166, 8, 8] --\n",
"│ │ └─ConvTranspose2d: 3-1 [4096, 16, 4, 4] [4096, 166, 8, 8] (42,662)\n",
"├─MLP: 1-2 [256, 16, 1536] -- --\n",
"│ └─Sequential: 2-3 [256, 16, 1536] [256, 16, 16] --\n",
"│ │ └─Linear: 3-2 [256, 16, 1536] [256, 16, 16] (24,576)\n",
"│ │ └─LayerNorm: 3-3 [256, 16, 16] [256, 16, 16] (32)\n",
"│ │ └─SiLU: 3-4 [256, 16, 16] [256, 16, 16] --\n",
"│ │ └─Linear: 3-5 [256, 16, 16] [256, 16, 16] (256)\n",
"│ │ └─LayerNorm: 3-6 [256, 16, 16] [256, 16, 16] (32)\n",
"│ │ └─SiLU: 3-7 [256, 16, 16] [256, 16, 16] --\n",
"│ └─ModuleDict: 2-4 -- -- --\n",
"│ │ └─Linear: 3-8 [256, 16, 16] [256, 16, 102] (1,734)\n",
"===================================================================================================================\n",
"Total params: 462,764\n",
"Trainable params: 0\n",
"Non-trainable params: 462,764\n",
"Total mult-adds (G): 11.29\n",
"===================================================================================================================\n",
"Input size (MB): 25.17\n",
"Forward/backward pass size (MB): 361.96\n",
"Params size (MB): 1.85\n",
"Estimated Total Size (MB): 388.97\n",
"===================================================================================================================\n",
"Summary Failed for reward Failed to run torchinfo. See above stack traces for more details. Executed layers up to: [Sequential: 1, Linear: 2, LayerNorm: 2, SiLU: 2, Linear: 2, LayerNorm: 2, SiLU: 2, Linear: 2, LayerNorm: 2, SiLU: 2, Linear: 1]\n",
"Summary Failed for cont Failed to run torchinfo. See above stack traces for more details. Executed layers up to: [Sequential: 1, Linear: 2, LayerNorm: 2, SiLU: 2, Linear: 2, LayerNorm: 2, SiLU: 2, Linear: 2, LayerNorm: 2, SiLU: 2, Linear: 1]\n"
]
}
],
"source": [
"# heads\n",
"feat = wm.dynamics.get_feat(post)\n",
@@ -429,7 +1073,41 @@
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"===================================================================================================================\n",
"Layer (type:depth-idx) Output Shape Param # Output Shape\n",
"===================================================================================================================\n",
"Sequential [256, 16, 256] -- [256, 16, 256]\n",
"├─Linear: 1-1 [256, 16, 256] (393,216) [256, 16, 256]\n",
"├─LayerNorm: 1-2 [256, 16, 256] (512) [256, 16, 256]\n",
"├─SiLU: 1-3 [256, 16, 256] -- [256, 16, 256]\n",
"├─Linear: 1-4 [256, 16, 256] (65,536) [256, 16, 256]\n",
"├─LayerNorm: 1-5 [256, 16, 256] (512) [256, 16, 256]\n",
"├─SiLU: 1-6 [256, 16, 256] -- [256, 16, 256]\n",
"├─Linear: 1-7 [256, 16, 256] (65,536) [256, 16, 256]\n",
"├─LayerNorm: 1-8 [256, 16, 256] (512) [256, 16, 256]\n",
"├─SiLU: 1-9 [256, 16, 256] -- [256, 16, 256]\n",
"===================================================================================================================\n",
"Total params: 525,824\n",
"Trainable params: 0\n",
"Non-trainable params: 525,824\n",
"Total mult-adds (M): 134.61\n",
"===================================================================================================================\n",
"Input size (MB): 25.17\n",
"Forward/backward pass size (MB): 50.33\n",
"Params size (MB): 2.10\n",
"Estimated Total Size (MB): 77.60\n",
"==================================================================================================================="
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"actor = agent._task_behavior.actor\n",
"\n",
@@ -441,7 +1119,41 @@
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"===================================================================================================================\n",
"Layer (type:depth-idx) Output Shape Param # Output Shape\n",
"===================================================================================================================\n",
"Sequential [256, 16, 256] -- [256, 16, 256]\n",
"├─Linear: 1-1 [256, 16, 256] (393,216) [256, 16, 256]\n",
"├─LayerNorm: 1-2 [256, 16, 256] (512) [256, 16, 256]\n",
"├─SiLU: 1-3 [256, 16, 256] -- [256, 16, 256]\n",
"├─Linear: 1-4 [256, 16, 256] (65,536) [256, 16, 256]\n",
"├─LayerNorm: 1-5 [256, 16, 256] (512) [256, 16, 256]\n",
"├─SiLU: 1-6 [256, 16, 256] -- [256, 16, 256]\n",
"├─Linear: 1-7 [256, 16, 256] (65,536) [256, 16, 256]\n",
"├─LayerNorm: 1-8 [256, 16, 256] (512) [256, 16, 256]\n",
"├─SiLU: 1-9 [256, 16, 256] -- [256, 16, 256]\n",
"===================================================================================================================\n",
"Total params: 525,824\n",
"Trainable params: 0\n",
"Non-trainable params: 525,824\n",
"Total mult-adds (M): 134.61\n",
"===================================================================================================================\n",
"Input size (MB): 25.17\n",
"Forward/backward pass size (MB): 50.33\n",
"Params size (MB): 2.10\n",
"Estimated Total Size (MB): 77.60\n",
"==================================================================================================================="
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"value = agent._task_behavior.actor\n",
"summary(value.layers, input_data=(feat,), depth=3, col_names=[\"output_size\", \"num_params\", \"output_size\" ])"
@@ -451,26 +1163,22 @@
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"8268"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"8268"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"o = obs['state'].reshape((-1, 8268))\n",
"map = o[:, :8217].reshape((-1, 9, 11, 83))\n",
"map.shape\n",
"inventories = o[:, 8217:]\n",
"inventories\n",
"\n",
"map"
]
},
{
"cell_type": "code",
"execution_count": null,
+17
View File
@@ -246,6 +246,23 @@
"- actor"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Check encoder decoder"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"agent.encoder(x)\n",
"agent.heads['decoder'](x) "
]
},
{
"cell_type": "markdown",
"metadata": {},
+4 -10
View File
@@ -560,6 +560,7 @@ class ConvDecoder(nn.Module):
[m.apply(tools.weight_init) for m in layers[:-1]]
layers[-1].apply(tools.uniform_weight_init(outscale))
self.layers = nn.Sequential(*layers)
self.outdim = out_dim
def calc_same_pad(self, k, s, d):
val = d * (k - 1) - s + 1
@@ -569,17 +570,10 @@ class ConvDecoder(nn.Module):
def forward(self, features, dtype=None):
x = self._linear_layer(features)
# (batch, time, -1) -> (batch * time, h, w, ch)
x = x.reshape(
[-1, self._minres, self._minres, self._embed_size // self._minres**2]
)
# (batch, time, -1) -> (batch * time, ch, h, w)
x = x.permute(0, 3, 1, 2)
x = rearrange(x, "b t (h w c) -> (b t) c h w", h=self._minres, w=self._minres)
x = self.layers(x)
# (batch, time, -1) -> (batch, time, ch, h, w)
mean = x.reshape(features.shape[:-1] + self._shape)
# (batch, time, ch, h, w) -> (batch, time, h, w, ch)
mean = mean.permute(0, 1, 3, 4, 2)
mean = rearrange(x, "(b t) c h w -> b t h w c ", t=features.shape[1])
# assert mean.shape[-2]==self.outdim, f"{mean.shape[-2]}!={self.outdim}"
if self._cnn_sigmoid:
mean = F.sigmoid(mean)
else: