Files
dreamerv3-torch/nbs/02_torchinfo copy.ipynb
2024-06-07 06:00:35 +08:00

1869 lines
98 KiB
Plaintext

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this notebook we load a saved Dreamer agent and run it, in order to inspect its parameters, measure its speed, and make the code easier to hack on."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loading textures from cache\n"
]
}
],
"source": [
"# TODO make this a proper package\n",
"import os, sys\n",
"sys.path.append('..')\n",
"\n",
"\n",
"from dreamer import parse_args, main, make_env, make_dataset, count_steps,Dreamer"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['../dreamer.py', '--configs', 'craftax_smaller', '--logdir', '../logdir/craftax_smaller']\n",
"\u001b[32m2024-06-07 05:25:00.098\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mdreamer\u001b[0m:\u001b[36mparse_args\u001b[0m:\u001b[36m392\u001b[0m - \u001b[1margs=Namespace(act='SiLU', action_repeat=1, actor={'layers': 2, 'dist': 'onehot', 'entropy': 0.0003, 'unimix_ratio': 0.01, 'std': 'none', 'min_std': 0.1, 'max_std': 1.0, 'temp': 0.1, 'lr': 3e-05, 'eps': 1e-05, 'grad_clip': 100.0, 'outscale': 1.0}, batch_length=32, batch_size=256, compile=True, cont_head={'layers': 2, 'loss_scale': 1.0, 'outscale': 1.0}, critic={'layers': 2, 'dist': 'symlog_disc', 'slow_target': True, 'slow_target_update': 1, 'slow_target_fraction': 0.02, 'lr': 3e-05, 'eps': 1e-05, 'grad_clip': 100.0, 'outscale': 0.0}, dataset_size=20000, debug=False, decoder={'mlp_keys': 'state_inventory', 'cnn_keys': 'state_map', 'act': 'SiLU', 'norm': True, 'cnn_depth': 16, 'kernel_size': 4, 'minres': 2, 'mlp_layers': 2, 'mlp_units': 8, 'cnn_sigmoid': False, 'image_dist': 'mse', 'vector_dist': 'symlog_mse', 'outscale': 1.0}, deterministic_run=False, device='cuda:0', disag_action_cond=False, disag_layers=4, disag_log=True, disag_models=10, disag_offset=1, disag_target='stoch', disag_units=400, discount=0.997, discount_lambda=0.95, dyn_deter=128, dyn_discrete=16, dyn_hidden=128, dyn_mean_act='none', dyn_min_std=0.1, dyn_rec_depth=1, dyn_scale=0.5, dyn_std_act='sigmoid2', dyn_stoch=16, encoder={'mlp_keys': 'state_inventory', 'cnn_keys': 'state_map', 'act': 'SiLU', 'norm': True, 'cnn_depth': 16, 'kernel_size': 4, 'minres': 2, 'mlp_layers': 2, 'mlp_units': 8, 'symlog_inputs': True}, envs=1, eval_episode_num=10, eval_every=10000.0, eval_state_mean=False, evaldir=None, expl_behavior='greedy', expl_extr_scale=0.0, expl_intr_scale=1.0, expl_until=0, grad_clip=1000, grad_heads=('decoder', 'reward', 'cont'), grayscale=False, imag_gradient='reinforce', imag_gradient_mix=0.0, imag_horizon=15, initial='learned', kl_free=1.0, log_every=10000.0, logdir='../logdir/craftax_smaller', model_lr=0.0001, norm=True, 
offline_evaldir='', offline_traindir='', opt='adam', opt_eps=1e-08, parallel=False, precision=32, prefill=2500, pretrain=100, rep_scale=0.1, reset_every=0, reward_EMA=True, reward_head={'layers': 2, 'dist': 'symlog_disc', 'loss_scale': 1.0, 'outscale': 0.0}, seed=0, size=(64, 64), step=1000000.0, steps=1000000.0, task='craftax_Craftax-Symbolic-AutoReset-v1', time_limit=4000, train_ratio=512, traindir=None, unimix_ratio=0.01, units=128, value={'layers': 2}, video_pred_log=False, weight_decay=0.0)\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"Namespace(act='SiLU', action_repeat=1, actor={'layers': 2, 'dist': 'onehot', 'entropy': 0.0003, 'unimix_ratio': 0.01, 'std': 'none', 'min_std': 0.1, 'max_std': 1.0, 'temp': 0.1, 'lr': 3e-05, 'eps': 1e-05, 'grad_clip': 100.0, 'outscale': 1.0}, batch_length=32, batch_size=256, compile=True, cont_head={'layers': 2, 'loss_scale': 1.0, 'outscale': 1.0}, critic={'layers': 2, 'dist': 'symlog_disc', 'slow_target': True, 'slow_target_update': 1, 'slow_target_fraction': 0.02, 'lr': 3e-05, 'eps': 1e-05, 'grad_clip': 100.0, 'outscale': 0.0}, dataset_size=20000, debug=False, decoder={'mlp_keys': 'state_inventory', 'cnn_keys': 'state_map', 'act': 'SiLU', 'norm': True, 'cnn_depth': 16, 'kernel_size': 4, 'minres': 2, 'mlp_layers': 2, 'mlp_units': 8, 'cnn_sigmoid': False, 'image_dist': 'mse', 'vector_dist': 'symlog_mse', 'outscale': 1.0}, deterministic_run=False, device='cuda:0', disag_action_cond=False, disag_layers=4, disag_log=True, disag_models=10, disag_offset=1, disag_target='stoch', disag_units=400, discount=0.997, discount_lambda=0.95, dyn_deter=128, dyn_discrete=16, dyn_hidden=128, dyn_mean_act='none', dyn_min_std=0.1, dyn_rec_depth=1, dyn_scale=0.5, dyn_std_act='sigmoid2', dyn_stoch=16, encoder={'mlp_keys': 'state_inventory', 'cnn_keys': 'state_map', 'act': 'SiLU', 'norm': True, 'cnn_depth': 16, 'kernel_size': 4, 'minres': 2, 'mlp_layers': 2, 'mlp_units': 8, 'symlog_inputs': True}, envs=1, eval_episode_num=10, eval_every=10000.0, eval_state_mean=False, evaldir=None, expl_behavior='greedy', expl_extr_scale=0.0, expl_intr_scale=1.0, expl_until=0, grad_clip=1000, grad_heads=('decoder', 'reward', 'cont'), grayscale=False, imag_gradient='reinforce', imag_gradient_mix=0.0, imag_horizon=15, initial='learned', kl_free=1.0, log_every=10000.0, logdir='../logdir/craftax_smaller', model_lr=0.0001, norm=True, offline_evaldir='', offline_traindir='', opt='adam', opt_eps=1e-08, parallel=False, precision=32, prefill=2500, pretrain=100, rep_scale=0.1, reset_every=0, reward_EMA=True, 
reward_head={'layers': 2, 'dist': 'symlog_disc', 'loss_scale': 1.0, 'outscale': 0.0}, seed=0, size=(64, 64), step=1000000.0, steps=1000000.0, task='craftax_Craftax-Symbolic-AutoReset-v1', time_limit=4000, train_ratio=512, traindir=None, unimix_ratio=0.01, units=128, value={'layers': 2}, video_pred_log=False, weight_decay=0.0)"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Emulate the command line that would normally be passed to dreamer.py.\n",
"cli_line = \"../dreamer.py --configs craftax_smaller --logdir ../logdir/craftax_smaller\"\n",
"argv = cli_line.split()\n",
"print(argv)\n",
"config = parse_args(argv)\n",
"config"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/tmp/ipykernel_20867/1136853903.py:1: UnsafeLoaderWarning: \n",
"The default 'Loader' for 'load(stream)' without further arguments can be unsafe.\n",
"Use 'load(stream, Loader=ruamel.yaml.Loader)' explicitly if that is OK.\n",
"Alternatively include the following in your code:\n",
"\n",
" import warnings\n",
" warnings.simplefilter('ignore', ruamel.yaml.error.UnsafeLoaderWarning)\n",
"\n",
"In most other cases you should consider using 'safe_load(stream)'\n",
" yaml.load(\n"
]
},
{
"data": {
"text/plain": [
"{'defaults': {'logdir': None,\n",
" 'traindir': None,\n",
" 'evaldir': None,\n",
" 'offline_traindir': '',\n",
" 'offline_evaldir': '',\n",
" 'seed': 0,\n",
" 'deterministic_run': False,\n",
" 'steps': 1000000.0,\n",
" 'parallel': False,\n",
" 'eval_every': 10000.0,\n",
" 'eval_episode_num': 10,\n",
" 'log_every': 10000.0,\n",
" 'reset_every': 0,\n",
" 'device': 'cuda:0',\n",
" 'compile': True,\n",
" 'precision': 32,\n",
" 'debug': False,\n",
" 'video_pred_log': False,\n",
" 'task': 'dmc_walker_walk',\n",
" 'size': [64, 64],\n",
" 'envs': 1,\n",
" 'action_repeat': 2,\n",
" 'time_limit': 1000,\n",
" 'grayscale': False,\n",
" 'prefill': 2500,\n",
" 'reward_EMA': True,\n",
" 'dyn_hidden': 512,\n",
" 'dyn_deter': 512,\n",
" 'dyn_stoch': 32,\n",
" 'dyn_discrete': 32,\n",
" 'dyn_rec_depth': 1,\n",
" 'dyn_mean_act': 'none',\n",
" 'dyn_std_act': 'sigmoid2',\n",
" 'dyn_min_std': 0.1,\n",
" 'grad_heads': ['decoder', 'reward', 'cont'],\n",
" 'units': 512,\n",
" 'act': 'SiLU',\n",
" 'norm': True,\n",
" 'encoder': {'mlp_keys': '$^',\n",
" 'cnn_keys': 'image',\n",
" 'act': 'SiLU',\n",
" 'norm': True,\n",
" 'cnn_depth': 32,\n",
" 'kernel_size': 4,\n",
" 'minres': 4,\n",
" 'mlp_layers': 5,\n",
" 'mlp_units': 1024,\n",
" 'symlog_inputs': True},\n",
" 'decoder': {'mlp_keys': '$^',\n",
" 'cnn_keys': 'image',\n",
" 'act': 'SiLU',\n",
" 'norm': True,\n",
" 'cnn_depth': 32,\n",
" 'kernel_size': 4,\n",
" 'minres': 4,\n",
" 'mlp_layers': 5,\n",
" 'mlp_units': 1024,\n",
" 'cnn_sigmoid': False,\n",
" 'image_dist': 'mse',\n",
" 'vector_dist': 'symlog_mse',\n",
" 'outscale': 1.0},\n",
" 'actor': {'layers': 2,\n",
" 'dist': 'normal',\n",
" 'entropy': 0.0003,\n",
" 'unimix_ratio': 0.01,\n",
" 'std': 'learned',\n",
" 'min_std': 0.1,\n",
" 'max_std': 1.0,\n",
" 'temp': 0.1,\n",
" 'lr': 3e-05,\n",
" 'eps': 1e-05,\n",
" 'grad_clip': 100.0,\n",
" 'outscale': 1.0},\n",
" 'critic': {'layers': 2,\n",
" 'dist': 'symlog_disc',\n",
" 'slow_target': True,\n",
" 'slow_target_update': 1,\n",
" 'slow_target_fraction': 0.02,\n",
" 'lr': 3e-05,\n",
" 'eps': 1e-05,\n",
" 'grad_clip': 100.0,\n",
" 'outscale': 0.0},\n",
" 'reward_head': {'layers': 2,\n",
" 'dist': 'symlog_disc',\n",
" 'loss_scale': 1.0,\n",
" 'outscale': 0.0},\n",
" 'cont_head': {'layers': 2, 'loss_scale': 1.0, 'outscale': 1.0},\n",
" 'dyn_scale': 0.5,\n",
" 'rep_scale': 0.1,\n",
" 'kl_free': 1.0,\n",
" 'weight_decay': 0.0,\n",
" 'unimix_ratio': 0.01,\n",
" 'initial': 'learned',\n",
" 'batch_size': 64,\n",
" 'batch_length': 64,\n",
" 'train_ratio': 512,\n",
" 'pretrain': 100,\n",
" 'model_lr': 0.0001,\n",
" 'opt_eps': 1e-08,\n",
" 'grad_clip': 1000,\n",
" 'dataset_size': 1000000,\n",
" 'opt': 'adam',\n",
" 'discount': 0.997,\n",
" 'discount_lambda': 0.95,\n",
" 'imag_horizon': 15,\n",
" 'imag_gradient': 'dynamics',\n",
" 'imag_gradient_mix': 0.0,\n",
" 'eval_state_mean': False,\n",
" 'expl_behavior': 'greedy',\n",
" 'expl_until': 0,\n",
" 'expl_extr_scale': 0.0,\n",
" 'expl_intr_scale': 1.0,\n",
" 'disag_target': 'stoch',\n",
" 'disag_log': True,\n",
" 'disag_models': 10,\n",
" 'disag_offset': 1,\n",
" 'disag_layers': 4,\n",
" 'disag_units': 400,\n",
" 'disag_action_cond': False},\n",
" 'dmc_proprio': {'steps': 500000.0,\n",
" 'action_repeat': 2,\n",
" 'envs': 4,\n",
" 'train_ratio': 512,\n",
" 'video_pred_log': False,\n",
" 'encoder': {'mlp_keys': '.*', 'cnn_keys': '$^'},\n",
" 'decoder': {'mlp_keys': '.*', 'cnn_keys': '$^'}},\n",
" 'dmc_vision': {'steps': 1000000.0,\n",
" 'action_repeat': 2,\n",
" 'envs': 4,\n",
" 'train_ratio': 512,\n",
" 'video_pred_log': True,\n",
" 'encoder': {'mlp_keys': '$^', 'cnn_keys': 'image'},\n",
" 'decoder': {'mlp_keys': '$^', 'cnn_keys': 'image'}},\n",
" 'crafter': {'task': 'crafter_reward',\n",
" 'step': 1000000.0,\n",
" 'action_repeat': 1,\n",
" 'envs': 1,\n",
" 'train_ratio': 512,\n",
" 'video_pred_log': False,\n",
" 'dyn_hidden': 1024,\n",
" 'dyn_deter': 4096,\n",
" 'units': 1024,\n",
" 'encoder': {'mlp_keys': '$^',\n",
" 'cnn_keys': 'image',\n",
" 'cnn_depth': 96,\n",
" 'mlp_layers': 5,\n",
" 'mlp_units': 1024},\n",
" 'decoder': {'mlp_keys': '$^',\n",
" 'cnn_keys': 'image',\n",
" 'cnn_depth': 96,\n",
" 'mlp_layers': 5,\n",
" 'mlp_units': 1024},\n",
" 'actor': {'layers': 5, 'dist': 'onehot', 'std': 'none'},\n",
" 'value': {'layers': 5},\n",
" 'reward_head': {'layers': 5},\n",
" 'cont_head': {'layers': 5},\n",
" 'imag_gradient': 'reinforce'},\n",
" 'craftax': {'task': 'craftax_Craftax-Symbolic-AutoReset-v1',\n",
" 'step': 1000000.0,\n",
" 'action_repeat': 1,\n",
" 'envs': 1,\n",
" 'train_ratio': 512,\n",
" 'video_pred_log': False,\n",
" 'dyn_hidden': 1024,\n",
" 'dyn_deter': 4096,\n",
" 'units': 1024,\n",
" 'encoder': {'cnn_keys': '$^',\n",
" 'mlp_keys': 'state',\n",
" 'mlp_layers': 4,\n",
" 'mlp_units': 512},\n",
" 'decoder': {'cnn_keys': '$^',\n",
" 'mlp_keys': 'state',\n",
" 'mlp_layers': 4,\n",
" 'mlp_units': 512},\n",
" 'actor': {'layers': 5, 'dist': 'onehot', 'std': 'none'},\n",
" 'value': {'layers': 5},\n",
" 'reward_head': {'layers': 5},\n",
" 'cont_head': {'layers': 5},\n",
" 'imag_gradient': 'reinforce',\n",
" 'time_limit': 4000},\n",
" 'craftax_small': {'task': 'craftax_Craftax-Symbolic-AutoReset-v1',\n",
" 'step': 1000000.0,\n",
" 'action_repeat': 1,\n",
" 'envs': 1,\n",
" 'train_ratio': 512,\n",
" 'video_pred_log': False,\n",
" 'dyn_hidden': 256,\n",
" 'dyn_deter': 256,\n",
" 'dyn_stoch': 24,\n",
" 'dyn_discrete': 24,\n",
" 'encoder': {'cnn_keys': 'state_map',\n",
" 'cnn_depth': 32,\n",
" 'kernel_size': 4,\n",
" 'minres': 2,\n",
" 'mlp_keys': 'state_inventory',\n",
" 'mlp_layers': 2,\n",
" 'mlp_units': 16},\n",
" 'decoder': {'cnn_keys': 'state_map',\n",
" 'cnn_depth': 32,\n",
" 'kernel_size': 4,\n",
" 'minres': 2,\n",
" 'mlp_keys': 'state_inventory',\n",
" 'mlp_layers': 2,\n",
" 'mlp_units': 16},\n",
" 'actor': {'layers': 3, 'dist': 'onehot', 'std': 'none'},\n",
" 'value': {'layers': 3},\n",
" 'units': 256,\n",
" 'reward_head': {'layers': 3},\n",
" 'cont_head': {'layers': 3},\n",
" 'imag_gradient': 'reinforce',\n",
" 'batch_size': 256,\n",
" 'batch_length': 32,\n",
" 'time_limit': 4000},\n",
" 'craftax_smaller': {'task': 'craftax_Craftax-Symbolic-AutoReset-v1',\n",
" 'step': 1000000.0,\n",
" 'action_repeat': 1,\n",
" 'envs': 1,\n",
" 'train_ratio': 512,\n",
" 'video_pred_log': False,\n",
" 'dyn_hidden': 128,\n",
" 'dyn_deter': 128,\n",
" 'dyn_stoch': 16,\n",
" 'dyn_discrete': 16,\n",
" 'encoder': {'cnn_keys': 'state_map',\n",
" 'cnn_depth': 16,\n",
" 'kernel_size': 4,\n",
" 'minres': 2,\n",
" 'mlp_keys': 'state_inventory',\n",
" 'mlp_layers': 2,\n",
" 'mlp_units': 8},\n",
" 'decoder': {'cnn_keys': 'state_map',\n",
" 'cnn_depth': 16,\n",
" 'kernel_size': 4,\n",
" 'minres': 2,\n",
" 'mlp_keys': 'state_inventory',\n",
" 'mlp_layers': 2,\n",
" 'mlp_units': 8},\n",
" 'actor': {'layers': 2, 'dist': 'onehot', 'std': 'none'},\n",
" 'value': {'layers': 2},\n",
" 'units': 128,\n",
" 'reward_head': {'layers': 2},\n",
" 'cont_head': {'layers': 2},\n",
" 'imag_gradient': 'reinforce',\n",
" 'batch_size': 256,\n",
" 'batch_length': 32,\n",
" 'time_limit': 4000,\n",
" 'dataset_size': 20000},\n",
" 'atari100k': {'steps': 400000.0,\n",
" 'envs': 1,\n",
" 'action_repeat': 4,\n",
" 'train_ratio': 1024,\n",
" 'video_pred_log': True,\n",
" 'eval_episode_num': 100,\n",
" 'actor': {'dist': 'onehot', 'std': 'none'},\n",
" 'imag_gradient': 'reinforce',\n",
" 'stickey': False,\n",
" 'lives': 'unused',\n",
" 'noops': 30,\n",
" 'resize': 'opencv',\n",
" 'actions': 'needed',\n",
" 'time_limit': 108000},\n",
" 'minecraft': {'task': 'minecraft_diamond',\n",
" 'step': 100000000.0,\n",
" 'parallel': True,\n",
" 'envs': 16,\n",
" 'eval_episode_num': 0,\n",
" 'eval_every': 10000.0,\n",
" 'action_repeat': 1,\n",
" 'train_ratio': 16,\n",
" 'video_pred_log': True,\n",
" 'dyn_hidden': 1024,\n",
" 'dyn_deter': 4096,\n",
" 'units': 1024,\n",
" 'encoder': {'mlp_keys': 'inventory|inventory_max|equipped|health|hunger|breath|obs_reward',\n",
" 'cnn_keys': 'image',\n",
" 'cnn_depth': 96,\n",
" 'mlp_layers': 5,\n",
" 'mlp_units': 1024},\n",
" 'decoder': {'mlp_keys': 'inventory|inventory_max|equipped|health|hunger|breath',\n",
" 'cnn_keys': 'image',\n",
" 'cnn_depth': 96,\n",
" 'mlp_layers': 5,\n",
" 'mlp_units': 1024},\n",
" 'actor': {'layers': 5, 'dist': 'onehot', 'std': 'none'},\n",
" 'value': {'layers': 5},\n",
" 'reward_head': {'layers': 5},\n",
" 'cont_head': {'layers': 5},\n",
" 'imag_gradient': 'reinforce',\n",
" 'break_speed': 100.0,\n",
" 'time_limit': 36000},\n",
" 'memorymaze': {'steps': 100000000.0,\n",
" 'action_repeat': 2,\n",
" 'actor': {'dist': 'onehot', 'std': 'none'},\n",
" 'imag_gradient': 'reinforce',\n",
" 'task': 'memorymaze_9x9'},\n",
" 'debug': {'debug': True,\n",
" 'pretrain': 1,\n",
" 'prefill': 1,\n",
" 'batch_size': 10,\n",
" 'batch_length': 20}}"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Display the raw YAML config tree.\n",
"# Imported in this cell so it runs on a fresh kernel without relying on\n",
"# imports made by cells further down the notebook.\n",
"import pathlib\n",
"\n",
"import ruamel.yaml as yaml\n",
"\n",
"# Path is relative to the notebook's working directory, matching the other\n",
"# '..' paths used in this notebook (sys.path, dreamer.py, logdir).\n",
"yaml.safe_load(pathlib.Path(\"../configs.yaml\").read_text())\n"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"usage: ipykernel_launcher.py [-h] [--act ACT] [--action_repeat ACTION_REPEAT]\n",
" [--actor ACTOR] [--batch_length BATCH_LENGTH]\n",
" [--batch_size BATCH_SIZE] [--compile COMPILE]\n",
" [--cont_head CONT_HEAD] [--critic CRITIC]\n",
" [--dataset_size DATASET_SIZE] [--debug DEBUG]\n",
" [--decoder DECODER]\n",
" [--deterministic_run DETERMINISTIC_RUN]\n",
" [--device DEVICE]\n",
" [--disag_action_cond DISAG_ACTION_COND]\n",
" [--disag_layers DISAG_LAYERS]\n",
" [--disag_log DISAG_LOG]\n",
" [--disag_models DISAG_MODELS]\n",
" [--disag_offset DISAG_OFFSET]\n",
" [--disag_target DISAG_TARGET]\n",
" [--disag_units DISAG_UNITS] [--discount DISCOUNT]\n",
" [--discount_lambda DISCOUNT_LAMBDA]\n",
" [--dyn_deter DYN_DETER]\n",
" [--dyn_discrete DYN_DISCRETE]\n",
" [--dyn_hidden DYN_HIDDEN]\n",
" [--dyn_mean_act DYN_MEAN_ACT]\n",
" [--dyn_min_std DYN_MIN_STD]\n",
" [--dyn_rec_depth DYN_REC_DEPTH]\n",
" [--dyn_scale DYN_SCALE]\n",
" [--dyn_std_act DYN_STD_ACT]\n",
" [--dyn_stoch DYN_STOCH] [--encoder ENCODER]\n",
" [--envs ENVS]\n",
" [--eval_episode_num EVAL_EPISODE_NUM]\n",
" [--eval_every EVAL_EVERY]\n",
" [--eval_state_mean EVAL_STATE_MEAN]\n",
" [--evaldir EVALDIR]\n",
" [--expl_behavior EXPL_BEHAVIOR]\n",
" [--expl_extr_scale EXPL_EXTR_SCALE]\n",
" [--expl_intr_scale EXPL_INTR_SCALE]\n",
" [--expl_until EXPL_UNTIL] [--grad_clip GRAD_CLIP]\n",
" [--grad_heads GRAD_HEADS] [--grayscale GRAYSCALE]\n",
" [--imag_gradient IMAG_GRADIENT]\n",
" [--imag_gradient_mix IMAG_GRADIENT_MIX]\n",
" [--imag_horizon IMAG_HORIZON] [--initial INITIAL]\n",
" [--kl_free KL_FREE] [--log_every LOG_EVERY]\n",
" [--logdir LOGDIR] [--model_lr MODEL_LR]\n",
" [--norm NORM] [--offline_evaldir OFFLINE_EVALDIR]\n",
" [--offline_traindir OFFLINE_TRAINDIR] [--opt OPT]\n",
" [--opt_eps OPT_EPS] [--parallel PARALLEL]\n",
" [--precision PRECISION] [--prefill PREFILL]\n",
" [--pretrain PRETRAIN] [--rep_scale REP_SCALE]\n",
" [--reset_every RESET_EVERY]\n",
" [--reward_EMA REWARD_EMA]\n",
" [--reward_head REWARD_HEAD] [--seed SEED]\n",
" [--size SIZE] [--steps STEPS] [--task TASK]\n",
" [--time_limit TIME_LIMIT]\n",
" [--train_ratio TRAIN_RATIO] [--traindir TRAINDIR]\n",
" [--unimix_ratio UNIMIX_RATIO] [--units UNITS]\n",
" [--video_pred_log VIDEO_PRED_LOG]\n",
" [--weight_decay WEIGHT_DECAY]\n",
"\n",
"optional arguments:\n",
" -h, --help show this help message and exit\n",
" --act ACT\n",
" --action_repeat ACTION_REPEAT\n",
" --actor ACTOR\n",
" --batch_length BATCH_LENGTH\n",
" --batch_size BATCH_SIZE\n",
" --compile COMPILE\n",
" --cont_head CONT_HEAD\n",
" --critic CRITIC\n",
" --dataset_size DATASET_SIZE\n",
" --debug DEBUG\n",
" --decoder DECODER\n",
" --deterministic_run DETERMINISTIC_RUN\n",
" --device DEVICE\n",
" --disag_action_cond DISAG_ACTION_COND\n",
" --disag_layers DISAG_LAYERS\n",
" --disag_log DISAG_LOG\n",
" --disag_models DISAG_MODELS\n",
" --disag_offset DISAG_OFFSET\n",
" --disag_target DISAG_TARGET\n",
" --disag_units DISAG_UNITS\n",
" --discount DISCOUNT\n",
" --discount_lambda DISCOUNT_LAMBDA\n",
" --dyn_deter DYN_DETER\n",
" --dyn_discrete DYN_DISCRETE\n",
" --dyn_hidden DYN_HIDDEN\n",
" --dyn_mean_act DYN_MEAN_ACT\n",
" --dyn_min_std DYN_MIN_STD\n",
" --dyn_rec_depth DYN_REC_DEPTH\n",
" --dyn_scale DYN_SCALE\n",
" --dyn_std_act DYN_STD_ACT\n",
" --dyn_stoch DYN_STOCH\n",
" --encoder ENCODER\n",
" --envs ENVS\n",
" --eval_episode_num EVAL_EPISODE_NUM\n",
" --eval_every EVAL_EVERY\n",
" --eval_state_mean EVAL_STATE_MEAN\n",
" --evaldir EVALDIR\n",
" --expl_behavior EXPL_BEHAVIOR\n",
" --expl_extr_scale EXPL_EXTR_SCALE\n",
" --expl_intr_scale EXPL_INTR_SCALE\n",
" --expl_until EXPL_UNTIL\n",
" --grad_clip GRAD_CLIP\n",
" --grad_heads GRAD_HEADS\n",
" --grayscale GRAYSCALE\n",
" --imag_gradient IMAG_GRADIENT\n",
" --imag_gradient_mix IMAG_GRADIENT_MIX\n",
" --imag_horizon IMAG_HORIZON\n",
" --initial INITIAL\n",
" --kl_free KL_FREE\n",
" --log_every LOG_EVERY\n",
" --logdir LOGDIR\n",
" --model_lr MODEL_LR\n",
" --norm NORM\n",
" --offline_evaldir OFFLINE_EVALDIR\n",
" --offline_traindir OFFLINE_TRAINDIR\n",
" --opt OPT\n",
" --opt_eps OPT_EPS\n",
" --parallel PARALLEL\n",
" --precision PRECISION\n",
" --prefill PREFILL\n",
" --pretrain PRETRAIN\n",
" --rep_scale REP_SCALE\n",
" --reset_every RESET_EVERY\n",
" --reward_EMA REWARD_EMA\n",
" --reward_head REWARD_HEAD\n",
" --seed SEED\n",
" --size SIZE\n",
" --steps STEPS\n",
" --task TASK\n",
" --time_limit TIME_LIMIT\n",
" --train_ratio TRAIN_RATIO\n",
" --traindir TRAINDIR\n",
" --unimix_ratio UNIMIX_RATIO\n",
" --units UNITS\n",
" --video_pred_log VIDEO_PRED_LOG\n",
" --weight_decay WEIGHT_DECAY\n"
]
},
{
"data": {
"text/plain": [
"''"
]
},
"execution_count": 41,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# load config, using relative path\n",
"import argparse\n",
"import pathlib\n",
"\n",
"import ruamel.yaml as yaml\n",
"\n",
"import tools\n",
"\n",
"# Path is relative to the notebook's working directory, matching the other\n",
"# '..' paths used in this notebook (sys.path, dreamer.py, logdir).\n",
"configs = yaml.safe_load(pathlib.Path(\"../configs.yaml\").read_text())\n",
"\n",
"\n",
"def recursive_update(base, update):\n",
"    \"\"\"Merge `update` into `base` in place, descending into nested dicts.\n",
"\n",
"    A dict value whose key already exists in `base` is merged recursively\n",
"    into `base[key]`; any other value replaces (or creates) `base[key]`.\n",
"    Returns None; `base` is mutated.\n",
"    \"\"\"\n",
"    for key, value in update.items():\n",
"        if isinstance(value, dict) and key in base:\n",
"            recursive_update(base[key], value)\n",
"        else:\n",
"            base[key] = value\n",
"\n",
"\n",
"name_list = [\"defaults\"]\n",
"defaults = {}\n",
"for name in name_list:\n",
"    recursive_update(defaults, configs[name])\n",
"\n",
"# Pass every default through the callable that tools.args_type returns for it.\n",
"defaults2 = {k: tools.args_type(v)(v) for k, v in defaults.items()}\n",
"config2 = argparse.Namespace(**defaults2)\n",
"\n",
"# One --<key> option per default, parsed and defaulted with that same callable.\n",
"parser = argparse.ArgumentParser()\n",
"for key, value in sorted(defaults.items(), key=lambda x: x[0]):\n",
"    arg_type = tools.args_type(value)\n",
"    parser.add_argument(f\"--{key}\", type=arg_type, default=arg_type(value))\n",
"# Parse an empty argv so `args` holds only the defaults.\n",
"args = parser.parse_args([])\n"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'added': {},\n",
" 'removed': {},\n",
" 'modified': {'size': ([64, 64], (64, 64)),\n",
" 'grad_heads': (['decoder', 'reward', 'cont'],\n",
" ('decoder', 'reward', 'cont'))}}"
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def dict_diff(dict1, dict2):\n",
"    \"\"\"Compare two dicts key by key (top level only).\n",
"\n",
"    Returns a dict with three entries:\n",
"      'added'    - items whose key is only in dict2\n",
"      'removed'  - items whose key is only in dict1\n",
"      'modified' - shared keys with unequal values, as (dict1 value, dict2 value)\n",
"    \"\"\"\n",
"    added = {}\n",
"    for key, new_value in dict2.items():\n",
"        if key not in dict1:\n",
"            added[key] = new_value\n",
"\n",
"    removed = {}\n",
"    modified = {}\n",
"    for key, old_value in dict1.items():\n",
"        if key not in dict2:\n",
"            removed[key] = old_value\n",
"        elif old_value != dict2[key]:\n",
"            modified[key] = (old_value, dict2[key])\n",
"\n",
"    return {'added': added, 'removed': removed, 'modified': modified}\n",
"\n",
"\n",
"dict_diff(defaults, defaults2)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Namespace(act='SiLU', action_repeat=1, actor={'layers': 2, 'dist': 'onehot', 'entropy': 0.0003, 'unimix_ratio': 0.01, 'std': 'none', 'min_std': 0.1, 'max_std': 1.0, 'temp': 0.1, 'lr': 3e-05, 'eps': 1e-05, 'grad_clip': 100.0, 'outscale': 1.0}, batch_length=32, batch_size=256, compile=True, cont_head={'layers': 2, 'loss_scale': 1.0, 'outscale': 1.0}, critic={'layers': 2, 'dist': 'symlog_disc', 'slow_target': True, 'slow_target_update': 1, 'slow_target_fraction': 0.02, 'lr': 3e-05, 'eps': 1e-05, 'grad_clip': 100.0, 'outscale': 0.0}, dataset_size=20000, debug=False, decoder={'mlp_keys': 'state_inventory', 'cnn_keys': 'state_map', 'act': 'SiLU', 'norm': True, 'cnn_depth': 16, 'kernel_size': 4, 'minres': 2, 'mlp_layers': 2, 'mlp_units': 8, 'cnn_sigmoid': False, 'image_dist': 'mse', 'vector_dist': 'symlog_mse', 'outscale': 1.0}, deterministic_run=False, device='cuda:0', disag_action_cond=False, disag_layers=4, disag_log=True, disag_models=10, disag_offset=1, disag_target='stoch', disag_units=400, discount=0.997, discount_lambda=0.95, dyn_deter=128, dyn_discrete=16, dyn_hidden=128, dyn_mean_act='none', dyn_min_std=0.1, dyn_rec_depth=1, dyn_scale=0.5, dyn_std_act='sigmoid2', dyn_stoch=16, encoder={'mlp_keys': 'state_inventory', 'cnn_keys': 'state_map', 'act': 'SiLU', 'norm': True, 'cnn_depth': 16, 'kernel_size': 4, 'minres': 2, 'mlp_layers': 2, 'mlp_units': 8, 'symlog_inputs': True}, envs=1, eval_episode_num=10, eval_every=10000.0, eval_state_mean=False, evaldir=None, expl_behavior='greedy', expl_extr_scale=0.0, expl_intr_scale=1.0, expl_until=0, grad_clip=1000, grad_heads=('decoder', 'reward', 'cont'), grayscale=False, imag_gradient='reinforce', imag_gradient_mix=0.0, imag_horizon=15, initial='learned', kl_free=1.0, log_every=10000.0, logdir='../logdir/craftax_smaller', model_lr=0.0001, norm=True, offline_evaldir='', offline_traindir='', opt='adam', opt_eps=1e-08, parallel=False, precision=32, prefill=2500, pretrain=100, rep_scale=0.1, reset_every=0, reward_EMA=True, 
reward_head={'layers': 2, 'dist': 'symlog_disc', 'loss_scale': 1.0, 'outscale': 0.0}, seed=0, size=(64, 64), step=1000000.0, steps=1000000.0, task='craftax_Craftax-Symbolic-AutoReset-v1', time_limit=4000, train_ratio=512, traindir=None, unimix_ratio=0.01, units=128, value={'layers': 2}, video_pred_log=False, weight_decay=0.0)"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"config"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Namespace(logdir=None, traindir=None, evaldir=None, offline_traindir='', offline_evaldir='', seed=0, deterministic_run=False, steps='1e6', parallel=False, eval_every='1e4', eval_episode_num=10, log_every='1e4', reset_every=0, device='cuda:0', compile=True, precision=32, debug=False, video_pred_log=False, task='dmc_walker_walk', size=[64, 64], envs=1, action_repeat=2, time_limit=1000, grayscale=False, prefill=2500, reward_EMA=True, dyn_hidden=512, dyn_deter=512, dyn_stoch=32, dyn_discrete=32, dyn_rec_depth=1, dyn_mean_act='none', dyn_std_act='sigmoid2', dyn_min_std=0.1, grad_heads=['decoder', 'reward', 'cont'], units=512, act='SiLU', norm=True, encoder={'mlp_keys': '$^', 'cnn_keys': 'image', 'act': 'SiLU', 'norm': True, 'cnn_depth': 32, 'kernel_size': 4, 'minres': 4, 'mlp_layers': 5, 'mlp_units': 1024, 'symlog_inputs': True}, decoder={'mlp_keys': '$^', 'cnn_keys': 'image', 'act': 'SiLU', 'norm': True, 'cnn_depth': 32, 'kernel_size': 4, 'minres': 4, 'mlp_layers': 5, 'mlp_units': 1024, 'cnn_sigmoid': False, 'image_dist': 'mse', 'vector_dist': 'symlog_mse', 'outscale': 1.0}, actor={'layers': 2, 'dist': 'normal', 'entropy': '3e-4', 'unimix_ratio': 0.01, 'std': 'learned', 'min_std': 0.1, 'max_std': 1.0, 'temp': 0.1, 'lr': '3e-5', 'eps': '1e-5', 'grad_clip': 100.0, 'outscale': 1.0}, critic={'layers': 2, 'dist': 'symlog_disc', 'slow_target': True, 'slow_target_update': 1, 'slow_target_fraction': 0.02, 'lr': '3e-5', 'eps': '1e-5', 'grad_clip': 100.0, 'outscale': 0.0}, reward_head={'layers': 2, 'dist': 'symlog_disc', 'loss_scale': 1.0, 'outscale': 0.0}, cont_head={'layers': 2, 'loss_scale': 1.0, 'outscale': 1.0}, dyn_scale=0.5, rep_scale=0.1, kl_free=1.0, weight_decay=0.0, unimix_ratio=0.01, initial='learned', batch_size=64, batch_length=64, train_ratio=512, pretrain=100, model_lr='1e-4', opt_eps='1e-8', grad_clip=1000, dataset_size=1000000, opt='adam', discount=0.997, discount_lambda=0.95, imag_horizon=15, imag_gradient='dynamics', imag_gradient_mix=0.0, 
eval_state_mean=False, expl_behavior='greedy', expl_until=0, expl_extr_scale=0.0, expl_intr_scale=1.0, disag_target='stoch', disag_log=True, disag_models=10, disag_offset=1, disag_layers=4, disag_units=400, disag_action_cond=False)"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"config2"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[32m2024-06-07 03:49:38.306\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m27\u001b[0m - \u001b[1mLogdir ../logdir/craftax_smaller\u001b[0m\n",
"\u001b[32m2024-06-07 03:49:38.336\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m36\u001b[0m - \u001b[1mCreate envs.\u001b[0m\n",
"\u001b[32m2024-06-07 03:50:12.609\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m57\u001b[0m - \u001b[1mAction Space Box(0.0, 1.0, (43,), float32)\u001b[0m\n",
"\u001b[32m2024-06-07 03:50:12.611\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m63\u001b[0m - \u001b[1mPrefill dataset (0 steps).\u001b[0m\n",
"\u001b[32m2024-06-07 03:50:12.951\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m92\u001b[0m - \u001b[1mLogger: (2714 steps).\u001b[0m\n",
"\u001b[32m2024-06-07 03:50:12.952\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m94\u001b[0m - \u001b[1mSimulate agent.\u001b[0m\n"
]
}
],
"source": [
"from loguru import logger\n",
"from tqdm.auto import tqdm\n",
"import pathlib\n",
"\n",
"import torch\n",
"from torch import nn\n",
"from torch import distributions as torchd\n",
"\n",
"import exploration as expl\n",
"import models\n",
"import tools\n",
"import envs.wrappers as wrappers\n",
"from parallel import Parallel, Damy\n",
"\n",
"# from main\n",
"# Setup portion of the training entry point, run inline: seeds, directories,\n",
"# environments, an optional random-policy prefill, then the datasets.\n",
"tools.set_seed_everywhere(config.seed)\n",
"if config.deterministic_run:\n",
"    tools.enable_deterministic_run()\n",
"logdir = pathlib.Path(config.logdir).expanduser()\n",
"config.traindir = config.traindir or logdir / \"train_eps\"\n",
"config.evaldir = config.evaldir or logdir / \"eval_eps\"\n",
"# NOTE: these floor-divisions mutate `config` in place, so re-running this\n",
"# cell divides again; that is harmless only while action_repeat == 1.\n",
"config.steps //= config.action_repeat\n",
"config.eval_every //= config.action_repeat\n",
"config.log_every //= config.action_repeat\n",
"config.time_limit //= config.action_repeat\n",
"\n",
"logger.info(f\"Logdir {logdir}\")\n",
"logdir.mkdir(parents=True, exist_ok=True)\n",
"config.traindir.mkdir(parents=True, exist_ok=True)\n",
"config.evaldir.mkdir(parents=True, exist_ok=True)\n",
"step = count_steps(config.traindir)\n",
"# step in logger is environmental step\n",
"tlogger = tools.Logger(logdir, config.action_repeat * step)\n",
"# NOTE(review): logger.add is called on every run of this cell; presumably\n",
"# each call registers another sink for the same file - confirm.\n",
"logger.add(logdir / \"logger.log\")\n",
"\n",
"logger.info(\"Create envs.\")\n",
"if config.offline_traindir:\n",
"    directory = config.offline_traindir.format(**vars(config))\n",
"else:\n",
"    directory = config.traindir\n",
"train_eps = tools.load_episodes(directory, limit=config.dataset_size)\n",
"if config.offline_evaldir:\n",
"    directory = config.offline_evaldir.format(**vars(config))\n",
"else:\n",
"    directory = config.evaldir\n",
"eval_eps = tools.load_episodes(directory, limit=1)\n",
"\n",
"\n",
"def make(mode, index):\n",
"    \"\"\"Build one environment for `mode` ('train' or 'eval') with id `index`.\"\"\"\n",
"    return make_env(config, mode, index)\n",
"\n",
"\n",
"train_envs = [make(\"train\", i) for i in range(config.envs)]\n",
"eval_envs = [make(\"eval\", i) for i in range(config.envs)]\n",
"if config.parallel:\n",
"    train_envs = [Parallel(env, \"process\") for env in train_envs]\n",
"    eval_envs = [Parallel(env, \"process\") for env in eval_envs]\n",
"else:\n",
"    train_envs = [Damy(env) for env in train_envs]\n",
"    eval_envs = [Damy(env) for env in eval_envs]\n",
"acts = train_envs[0].action_space\n",
"logger.info(f\"Action Space {acts}\")\n",
"config.num_actions = acts.n if hasattr(acts, \"n\") else acts.shape[0]\n",
"\n",
"state = None\n",
"if not config.offline_traindir:\n",
"    # Only collect as many random steps as the train directory still lacks.\n",
"    prefill = max(0, config.prefill - count_steps(config.traindir))\n",
"    logger.info(f\"Prefill dataset ({prefill} steps).\")\n",
"    if hasattr(acts, \"discrete\"):\n",
"        random_actor = tools.OneHotDist(\n",
"            torch.zeros(config.num_actions).repeat(config.envs, 1)\n",
"        )\n",
"    else:\n",
"        random_actor = torchd.independent.Independent(\n",
"            torchd.uniform.Uniform(\n",
"                torch.Tensor(acts.low).repeat(config.envs, 1),\n",
"                torch.Tensor(acts.high).repeat(config.envs, 1),\n",
"            ),\n",
"            1,\n",
"        )\n",
"\n",
"    def random_agent(o, d, s):\n",
"        \"\"\"Sample from `random_actor`; the three arguments are ignored.\n",
"\n",
"        Returns ({'action', 'logprob'}, None).\n",
"        \"\"\"\n",
"        action = random_actor.sample()\n",
"        logprob = random_actor.log_prob(action)\n",
"        return {\"action\": action, \"logprob\": logprob}, None\n",
"\n",
"    state = tools.simulate(\n",
"        random_agent,\n",
"        train_envs,\n",
"        train_eps,\n",
"        config.traindir,\n",
"        tlogger,\n",
"        limit=config.dataset_size,\n",
"        steps=prefill,\n",
"    )\n",
"    tlogger.step += prefill * config.action_repeat\n",
"    logger.info(f\"Logger: ({tlogger.step} steps).\")\n",
"\n",
"logger.info(\"Simulate agent.\")\n",
"train_dataset = make_dataset(train_eps, config)\n",
"eval_dataset = make_dataset(eval_eps, config)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/media/wassname/SGIronWolf/projects5/worldmodels/dreamerv3-torch/.venv/lib/python3.9/site-packages/numpy/core/numeric.py:330: RuntimeWarning: invalid value encountered in cast\n",
" multiarray.copyto(a, fill_value, casting='unsafe')\n"
]
},
{
"data": {
"text/plain": [
"Dict('image': Box(0, 255, (130, 110, 3), uint8), 'is_first': Box(0, 0, (1,), uint8), 'is_last': Box(0, 0, (1,), uint8), 'is_terminal': Box(0, 0, (1,), uint8), 'log_achievement_cast_fireball': Box(-inf, inf, (1,), float32), 'log_achievement_cast_iceball': Box(-inf, inf, (1,), float32), 'log_achievement_collect_coal': Box(-inf, inf, (1,), float32), 'log_achievement_collect_diamond': Box(-inf, inf, (1,), float32), 'log_achievement_collect_drink': Box(-inf, inf, (1,), float32), 'log_achievement_collect_iron': Box(-inf, inf, (1,), float32), 'log_achievement_collect_ruby': Box(-inf, inf, (1,), float32), 'log_achievement_collect_sapling': Box(-inf, inf, (1,), float32), 'log_achievement_collect_sapphire': Box(-inf, inf, (1,), float32), 'log_achievement_collect_stone': Box(-inf, inf, (1,), float32), 'log_achievement_collect_wood': Box(-inf, inf, (1,), float32), 'log_achievement_damage_necromancer': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_archer': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_deep_thing': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_fire_elemental': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_frost_troll': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_gnome_archer': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_gnome_warrior': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_ice_elemental': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_knight': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_kobold': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_lizard': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_necromancer': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_orc_mage': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_orc_solider': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_pigman': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_skeleton': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_troll': Box(-inf, inf, (1,), 
float32), 'log_achievement_defeat_zombie': Box(-inf, inf, (1,), float32), 'log_achievement_drink_potion': Box(-inf, inf, (1,), float32), 'log_achievement_eat_bat': Box(-inf, inf, (1,), float32), 'log_achievement_eat_cow': Box(-inf, inf, (1,), float32), 'log_achievement_eat_plant': Box(-inf, inf, (1,), float32), 'log_achievement_eat_snail': Box(-inf, inf, (1,), float32), 'log_achievement_enchant_armour': Box(-inf, inf, (1,), float32), 'log_achievement_enchant_sword': Box(-inf, inf, (1,), float32), 'log_achievement_enter_dungeon': Box(-inf, inf, (1,), float32), 'log_achievement_enter_fire_realm': Box(-inf, inf, (1,), float32), 'log_achievement_enter_gnomish_mines': Box(-inf, inf, (1,), float32), 'log_achievement_enter_graveyard': Box(-inf, inf, (1,), float32), 'log_achievement_enter_ice_realm': Box(-inf, inf, (1,), float32), 'log_achievement_enter_sewers': Box(-inf, inf, (1,), float32), 'log_achievement_enter_troll_mines': Box(-inf, inf, (1,), float32), 'log_achievement_enter_vault': Box(-inf, inf, (1,), float32), 'log_achievement_find_bow': Box(-inf, inf, (1,), float32), 'log_achievement_fire_bow': Box(-inf, inf, (1,), float32), 'log_achievement_learn_fireball': Box(-inf, inf, (1,), float32), 'log_achievement_learn_iceball': Box(-inf, inf, (1,), float32), 'log_achievement_make_arrow': Box(-inf, inf, (1,), float32), 'log_achievement_make_diamond_armour': Box(-inf, inf, (1,), float32), 'log_achievement_make_diamond_pickaxe': Box(-inf, inf, (1,), float32), 'log_achievement_make_diamond_sword': Box(-inf, inf, (1,), float32), 'log_achievement_make_iron_armour': Box(-inf, inf, (1,), float32), 'log_achievement_make_iron_pickaxe': Box(-inf, inf, (1,), float32), 'log_achievement_make_iron_sword': Box(-inf, inf, (1,), float32), 'log_achievement_make_stone_pickaxe': Box(-inf, inf, (1,), float32), 'log_achievement_make_stone_sword': Box(-inf, inf, (1,), float32), 'log_achievement_make_torch': Box(-inf, inf, (1,), float32), 'log_achievement_make_wood_pickaxe': Box(-inf, inf, 
(1,), float32), 'log_achievement_make_wood_sword': Box(-inf, inf, (1,), float32), 'log_achievement_open_chest': Box(-inf, inf, (1,), float32), 'log_achievement_place_furnace': Box(-inf, inf, (1,), float32), 'log_achievement_place_plant': Box(-inf, inf, (1,), float32), 'log_achievement_place_stone': Box(-inf, inf, (1,), float32), 'log_achievement_place_table': Box(-inf, inf, (1,), float32), 'log_achievement_place_torch': Box(-inf, inf, (1,), float32), 'log_achievement_wake_up': Box(-inf, inf, (1,), float32), 'log_reward': Box(-inf, inf, (1,), float32), 'state_inventory': Box(0.0, 1.0, (102,), float16), 'state_map': Box(0.0, 1.0, (16, 16, 166), float16))"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_envs[0].observation_space"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[32m2024-06-07 03:50:13.077\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mnetworks\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m324\u001b[0m - \u001b[1mEncoder CNN shapes: {'state_map': (16, 16, 166)}\u001b[0m\n",
"\u001b[32m2024-06-07 03:50:13.077\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mnetworks\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m325\u001b[0m - \u001b[1mEncoder MLP shapes: {'state_inventory': (102,)}\u001b[0m\n",
"\u001b[32m2024-06-07 03:50:13.509\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mnetworks\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m391\u001b[0m - \u001b[1mDecoder CNN shapes: {'state_map': (16, 16, 166)}\u001b[0m\n",
"\u001b[32m2024-06-07 03:50:13.509\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mnetworks\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m392\u001b[0m - \u001b[1mDecoder MLP shapes: {'state_inventory': (102,)}\u001b[0m\n",
"\u001b[32m2024-06-07 03:50:15.165\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmodels\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m102\u001b[0m - \u001b[1mOptimizer model_opt has 706924 variables.\u001b[0m\n",
"\u001b[32m2024-06-07 03:50:15.185\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmodels\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m281\u001b[0m - \u001b[1mOptimizer actor_opt has 71595 variables.\u001b[0m\n",
"\u001b[32m2024-06-07 03:50:15.186\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmodels\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m292\u001b[0m - \u001b[1mOptimizer value_opt has 98943 variables.\u001b[0m\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/wassname/miniforge3/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
" self.pid = os.fork()\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Dreamer(\n",
" (_wm): OptimizedModule(\n",
" (_orig_mod): WorldModel(\n",
" (encoder): MultiEncoder(\n",
" (_cnn): ConvEncoder(\n",
" (layers): Sequential(\n",
" (0): Conv2dSamePad(166, 16, kernel_size=(4, 4), stride=(2, 2), bias=False)\n",
" (1): ImgChLayerNorm(\n",
" (norm): LayerNorm((16,), eps=0.001, elementwise_affine=True)\n",
" )\n",
" (2): SiLU()\n",
" (3): Conv2dSamePad(16, 32, kernel_size=(4, 4), stride=(2, 2), bias=False)\n",
" (4): ImgChLayerNorm(\n",
" (norm): LayerNorm((32,), eps=0.001, elementwise_affine=True)\n",
" )\n",
" (5): SiLU()\n",
" (6): Conv2dSamePad(32, 64, kernel_size=(4, 4), stride=(2, 2), bias=False)\n",
" (7): ImgChLayerNorm(\n",
" (norm): LayerNorm((64,), eps=0.001, elementwise_affine=True)\n",
" )\n",
" (8): SiLU()\n",
" )\n",
" )\n",
" (_mlp): MLP(\n",
" (layers): Sequential(\n",
" (Encoder_linear0): Linear(in_features=102, out_features=8, bias=False)\n",
" (Encoder_norm0): LayerNorm((8,), eps=0.001, elementwise_affine=True)\n",
" (Encoder_act0): SiLU()\n",
" (Encoder_linear1): Linear(in_features=8, out_features=8, bias=False)\n",
" (Encoder_norm1): LayerNorm((8,), eps=0.001, elementwise_affine=True)\n",
" (Encoder_act1): SiLU()\n",
" )\n",
" )\n",
" )\n",
" (dynamics): RSSM(\n",
" (_img_in_layers): Sequential(\n",
" (0): Linear(in_features=299, out_features=128, bias=False)\n",
" (1): LayerNorm((128,), eps=0.001, elementwise_affine=True)\n",
" (2): SiLU()\n",
" )\n",
" (_cell): GRUCell(\n",
" (layers): Sequential(\n",
" (GRU_linear): Linear(in_features=256, out_features=384, bias=False)\n",
" (GRU_norm): LayerNorm((384,), eps=0.001, elementwise_affine=True)\n",
" )\n",
" )\n",
" (_img_out_layers): Sequential(\n",
" (0): Linear(in_features=128, out_features=128, bias=False)\n",
" (1): LayerNorm((128,), eps=0.001, elementwise_affine=True)\n",
" (2): SiLU()\n",
" )\n",
" (_obs_out_layers): Sequential(\n",
" (0): Linear(in_features=392, out_features=128, bias=False)\n",
" (1): LayerNorm((128,), eps=0.001, elementwise_affine=True)\n",
" (2): SiLU()\n",
" )\n",
" (_imgs_stat_layer): Linear(in_features=128, out_features=256, bias=True)\n",
" (_obs_stat_layer): Linear(in_features=128, out_features=256, bias=True)\n",
" )\n",
" (heads): ModuleDict(\n",
" (decoder): MultiDecoder(\n",
" (_cnn): ConvDecoder(\n",
" (_linear_layer): Linear(in_features=384, out_features=256, bias=True)\n",
" (layers): Sequential(\n",
" (0): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)\n",
" (1): ImgChLayerNorm(\n",
" (norm): LayerNorm((32,), eps=0.001, elementwise_affine=True)\n",
" )\n",
" (2): SiLU()\n",
" (3): ConvTranspose2d(32, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)\n",
" (4): ImgChLayerNorm(\n",
" (norm): LayerNorm((16,), eps=0.001, elementwise_affine=True)\n",
" )\n",
" (5): SiLU()\n",
" (6): ConvTranspose2d(16, 166, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
" )\n",
" )\n",
" (_mlp): MLP(\n",
" (layers): Sequential(\n",
" (Decoder_linear0): Linear(in_features=384, out_features=8, bias=False)\n",
" (Decoder_norm0): LayerNorm((8,), eps=0.001, elementwise_affine=True)\n",
" (Decoder_act0): SiLU()\n",
" (Decoder_linear1): Linear(in_features=8, out_features=8, bias=False)\n",
" (Decoder_norm1): LayerNorm((8,), eps=0.001, elementwise_affine=True)\n",
" (Decoder_act1): SiLU()\n",
" )\n",
" (mean_layer): ModuleDict(\n",
" (state_inventory): Linear(in_features=8, out_features=102, bias=True)\n",
" )\n",
" )\n",
" )\n",
" (reward): MLP(\n",
" (layers): Sequential(\n",
" (Reward_linear0): Linear(in_features=384, out_features=128, bias=False)\n",
" (Reward_norm0): LayerNorm((128,), eps=0.001, elementwise_affine=True)\n",
" (Reward_act0): SiLU()\n",
" (Reward_linear1): Linear(in_features=128, out_features=128, bias=False)\n",
" (Reward_norm1): LayerNorm((128,), eps=0.001, elementwise_affine=True)\n",
" (Reward_act1): SiLU()\n",
" )\n",
" (mean_layer): Linear(in_features=128, out_features=255, bias=True)\n",
" )\n",
" (cont): MLP(\n",
" (layers): Sequential(\n",
" (Cont_linear0): Linear(in_features=384, out_features=128, bias=False)\n",
" (Cont_norm0): LayerNorm((128,), eps=0.001, elementwise_affine=True)\n",
" (Cont_act0): SiLU()\n",
" (Cont_linear1): Linear(in_features=128, out_features=128, bias=False)\n",
" (Cont_norm1): LayerNorm((128,), eps=0.001, elementwise_affine=True)\n",
" (Cont_act1): SiLU()\n",
" )\n",
" (mean_layer): Linear(in_features=128, out_features=1, bias=True)\n",
" )\n",
" )\n",
" )\n",
" )\n",
" (_task_behavior): OptimizedModule(\n",
" (_orig_mod): ImagBehavior(\n",
" (_world_model): WorldModel(\n",
" (encoder): MultiEncoder(\n",
" (_cnn): ConvEncoder(\n",
" (layers): Sequential(\n",
" (0): Conv2dSamePad(166, 16, kernel_size=(4, 4), stride=(2, 2), bias=False)\n",
" (1): ImgChLayerNorm(\n",
" (norm): LayerNorm((16,), eps=0.001, elementwise_affine=True)\n",
" )\n",
" (2): SiLU()\n",
" (3): Conv2dSamePad(16, 32, kernel_size=(4, 4), stride=(2, 2), bias=False)\n",
" (4): ImgChLayerNorm(\n",
" (norm): LayerNorm((32,), eps=0.001, elementwise_affine=True)\n",
" )\n",
" (5): SiLU()\n",
" (6): Conv2dSamePad(32, 64, kernel_size=(4, 4), stride=(2, 2), bias=False)\n",
" (7): ImgChLayerNorm(\n",
" (norm): LayerNorm((64,), eps=0.001, elementwise_affine=True)\n",
" )\n",
" (8): SiLU()\n",
" )\n",
" )\n",
" (_mlp): MLP(\n",
" (layers): Sequential(\n",
" (Encoder_linear0): Linear(in_features=102, out_features=8, bias=False)\n",
" (Encoder_norm0): LayerNorm((8,), eps=0.001, elementwise_affine=True)\n",
" (Encoder_act0): SiLU()\n",
" (Encoder_linear1): Linear(in_features=8, out_features=8, bias=False)\n",
" (Encoder_norm1): LayerNorm((8,), eps=0.001, elementwise_affine=True)\n",
" (Encoder_act1): SiLU()\n",
" )\n",
" )\n",
" )\n",
" (dynamics): RSSM(\n",
" (_img_in_layers): Sequential(\n",
" (0): Linear(in_features=299, out_features=128, bias=False)\n",
" (1): LayerNorm((128,), eps=0.001, elementwise_affine=True)\n",
" (2): SiLU()\n",
" )\n",
" (_cell): GRUCell(\n",
" (layers): Sequential(\n",
" (GRU_linear): Linear(in_features=256, out_features=384, bias=False)\n",
" (GRU_norm): LayerNorm((384,), eps=0.001, elementwise_affine=True)\n",
" )\n",
" )\n",
" (_img_out_layers): Sequential(\n",
" (0): Linear(in_features=128, out_features=128, bias=False)\n",
" (1): LayerNorm((128,), eps=0.001, elementwise_affine=True)\n",
" (2): SiLU()\n",
" )\n",
" (_obs_out_layers): Sequential(\n",
" (0): Linear(in_features=392, out_features=128, bias=False)\n",
" (1): LayerNorm((128,), eps=0.001, elementwise_affine=True)\n",
" (2): SiLU()\n",
" )\n",
" (_imgs_stat_layer): Linear(in_features=128, out_features=256, bias=True)\n",
" (_obs_stat_layer): Linear(in_features=128, out_features=256, bias=True)\n",
" )\n",
" (heads): ModuleDict(\n",
" (decoder): MultiDecoder(\n",
" (_cnn): ConvDecoder(\n",
" (_linear_layer): Linear(in_features=384, out_features=256, bias=True)\n",
" (layers): Sequential(\n",
" (0): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)\n",
" (1): ImgChLayerNorm(\n",
" (norm): LayerNorm((32,), eps=0.001, elementwise_affine=True)\n",
" )\n",
" (2): SiLU()\n",
" (3): ConvTranspose2d(32, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)\n",
" (4): ImgChLayerNorm(\n",
" (norm): LayerNorm((16,), eps=0.001, elementwise_affine=True)\n",
" )\n",
" (5): SiLU()\n",
" (6): ConvTranspose2d(16, 166, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
" )\n",
" )\n",
" (_mlp): MLP(\n",
" (layers): Sequential(\n",
" (Decoder_linear0): Linear(in_features=384, out_features=8, bias=False)\n",
" (Decoder_norm0): LayerNorm((8,), eps=0.001, elementwise_affine=True)\n",
" (Decoder_act0): SiLU()\n",
" (Decoder_linear1): Linear(in_features=8, out_features=8, bias=False)\n",
" (Decoder_norm1): LayerNorm((8,), eps=0.001, elementwise_affine=True)\n",
" (Decoder_act1): SiLU()\n",
" )\n",
" (mean_layer): ModuleDict(\n",
" (state_inventory): Linear(in_features=8, out_features=102, bias=True)\n",
" )\n",
" )\n",
" )\n",
" (reward): MLP(\n",
" (layers): Sequential(\n",
" (Reward_linear0): Linear(in_features=384, out_features=128, bias=False)\n",
" (Reward_norm0): LayerNorm((128,), eps=0.001, elementwise_affine=True)\n",
" (Reward_act0): SiLU()\n",
" (Reward_linear1): Linear(in_features=128, out_features=128, bias=False)\n",
" (Reward_norm1): LayerNorm((128,), eps=0.001, elementwise_affine=True)\n",
" (Reward_act1): SiLU()\n",
" )\n",
" (mean_layer): Linear(in_features=128, out_features=255, bias=True)\n",
" )\n",
" (cont): MLP(\n",
" (layers): Sequential(\n",
" (Cont_linear0): Linear(in_features=384, out_features=128, bias=False)\n",
" (Cont_norm0): LayerNorm((128,), eps=0.001, elementwise_affine=True)\n",
" (Cont_act0): SiLU()\n",
" (Cont_linear1): Linear(in_features=128, out_features=128, bias=False)\n",
" (Cont_norm1): LayerNorm((128,), eps=0.001, elementwise_affine=True)\n",
" (Cont_act1): SiLU()\n",
" )\n",
" (mean_layer): Linear(in_features=128, out_features=1, bias=True)\n",
" )\n",
" )\n",
" )\n",
" (actor): MLP(\n",
" (layers): Sequential(\n",
" (Actor_linear0): Linear(in_features=384, out_features=128, bias=False)\n",
" (Actor_norm0): LayerNorm((128,), eps=0.001, elementwise_affine=True)\n",
" (Actor_act0): SiLU()\n",
" (Actor_linear1): Linear(in_features=128, out_features=128, bias=False)\n",
" (Actor_norm1): LayerNorm((128,), eps=0.001, elementwise_affine=True)\n",
" (Actor_act1): SiLU()\n",
" )\n",
" (mean_layer): Linear(in_features=128, out_features=43, bias=True)\n",
" )\n",
" (value): MLP(\n",
" (layers): Sequential(\n",
" (Value_linear0): Linear(in_features=384, out_features=128, bias=False)\n",
" (Value_norm0): LayerNorm((128,), eps=0.001, elementwise_affine=True)\n",
" (Value_act0): SiLU()\n",
" (Value_linear1): Linear(in_features=128, out_features=128, bias=False)\n",
" (Value_norm1): LayerNorm((128,), eps=0.001, elementwise_affine=True)\n",
" (Value_act1): SiLU()\n",
" )\n",
" (mean_layer): Linear(in_features=128, out_features=255, bias=True)\n",
" )\n",
" (_slow_value): MLP(\n",
" (layers): Sequential(\n",
" (Value_linear0): Linear(in_features=384, out_features=128, bias=False)\n",
" (Value_norm0): LayerNorm((128,), eps=0.001, elementwise_affine=True)\n",
" (Value_act0): SiLU()\n",
" (Value_linear1): Linear(in_features=128, out_features=128, bias=False)\n",
" (Value_norm1): LayerNorm((128,), eps=0.001, elementwise_affine=True)\n",
" (Value_act1): SiLU()\n",
" )\n",
" (mean_layer): Linear(in_features=128, out_features=255, bias=True)\n",
" )\n",
" )\n",
" )\n",
" (_expl_behavior): OptimizedModule(\n",
" (_orig_mod): ImagBehavior(\n",
" (_world_model): WorldModel(\n",
" (encoder): MultiEncoder(\n",
" (_cnn): ConvEncoder(\n",
" (layers): Sequential(\n",
" (0): Conv2dSamePad(166, 16, kernel_size=(4, 4), stride=(2, 2), bias=False)\n",
" (1): ImgChLayerNorm(\n",
" (norm): LayerNorm((16,), eps=0.001, elementwise_affine=True)\n",
" )\n",
" (2): SiLU()\n",
" (3): Conv2dSamePad(16, 32, kernel_size=(4, 4), stride=(2, 2), bias=False)\n",
" (4): ImgChLayerNorm(\n",
" (norm): LayerNorm((32,), eps=0.001, elementwise_affine=True)\n",
" )\n",
" (5): SiLU()\n",
" (6): Conv2dSamePad(32, 64, kernel_size=(4, 4), stride=(2, 2), bias=False)\n",
" (7): ImgChLayerNorm(\n",
" (norm): LayerNorm((64,), eps=0.001, elementwise_affine=True)\n",
" )\n",
" (8): SiLU()\n",
" )\n",
" )\n",
" (_mlp): MLP(\n",
" (layers): Sequential(\n",
" (Encoder_linear0): Linear(in_features=102, out_features=8, bias=False)\n",
" (Encoder_norm0): LayerNorm((8,), eps=0.001, elementwise_affine=True)\n",
" (Encoder_act0): SiLU()\n",
" (Encoder_linear1): Linear(in_features=8, out_features=8, bias=False)\n",
" (Encoder_norm1): LayerNorm((8,), eps=0.001, elementwise_affine=True)\n",
" (Encoder_act1): SiLU()\n",
" )\n",
" )\n",
" )\n",
" (dynamics): RSSM(\n",
" (_img_in_layers): Sequential(\n",
" (0): Linear(in_features=299, out_features=128, bias=False)\n",
" (1): LayerNorm((128,), eps=0.001, elementwise_affine=True)\n",
" (2): SiLU()\n",
" )\n",
" (_cell): GRUCell(\n",
" (layers): Sequential(\n",
" (GRU_linear): Linear(in_features=256, out_features=384, bias=False)\n",
" (GRU_norm): LayerNorm((384,), eps=0.001, elementwise_affine=True)\n",
" )\n",
" )\n",
" (_img_out_layers): Sequential(\n",
" (0): Linear(in_features=128, out_features=128, bias=False)\n",
" (1): LayerNorm((128,), eps=0.001, elementwise_affine=True)\n",
" (2): SiLU()\n",
" )\n",
" (_obs_out_layers): Sequential(\n",
" (0): Linear(in_features=392, out_features=128, bias=False)\n",
" (1): LayerNorm((128,), eps=0.001, elementwise_affine=True)\n",
" (2): SiLU()\n",
" )\n",
" (_imgs_stat_layer): Linear(in_features=128, out_features=256, bias=True)\n",
" (_obs_stat_layer): Linear(in_features=128, out_features=256, bias=True)\n",
" )\n",
" (heads): ModuleDict(\n",
" (decoder): MultiDecoder(\n",
" (_cnn): ConvDecoder(\n",
" (_linear_layer): Linear(in_features=384, out_features=256, bias=True)\n",
" (layers): Sequential(\n",
" (0): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)\n",
" (1): ImgChLayerNorm(\n",
" (norm): LayerNorm((32,), eps=0.001, elementwise_affine=True)\n",
" )\n",
" (2): SiLU()\n",
" (3): ConvTranspose2d(32, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)\n",
" (4): ImgChLayerNorm(\n",
" (norm): LayerNorm((16,), eps=0.001, elementwise_affine=True)\n",
" )\n",
" (5): SiLU()\n",
" (6): ConvTranspose2d(16, 166, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
" )\n",
" )\n",
" (_mlp): MLP(\n",
" (layers): Sequential(\n",
" (Decoder_linear0): Linear(in_features=384, out_features=8, bias=False)\n",
" (Decoder_norm0): LayerNorm((8,), eps=0.001, elementwise_affine=True)\n",
" (Decoder_act0): SiLU()\n",
" (Decoder_linear1): Linear(in_features=8, out_features=8, bias=False)\n",
" (Decoder_norm1): LayerNorm((8,), eps=0.001, elementwise_affine=True)\n",
" (Decoder_act1): SiLU()\n",
" )\n",
" (mean_layer): ModuleDict(\n",
" (state_inventory): Linear(in_features=8, out_features=102, bias=True)\n",
" )\n",
" )\n",
" )\n",
" (reward): MLP(\n",
" (layers): Sequential(\n",
" (Reward_linear0): Linear(in_features=384, out_features=128, bias=False)\n",
" (Reward_norm0): LayerNorm((128,), eps=0.001, elementwise_affine=True)\n",
" (Reward_act0): SiLU()\n",
" (Reward_linear1): Linear(in_features=128, out_features=128, bias=False)\n",
" (Reward_norm1): LayerNorm((128,), eps=0.001, elementwise_affine=True)\n",
" (Reward_act1): SiLU()\n",
" )\n",
" (mean_layer): Linear(in_features=128, out_features=255, bias=True)\n",
" )\n",
" (cont): MLP(\n",
" (layers): Sequential(\n",
" (Cont_linear0): Linear(in_features=384, out_features=128, bias=False)\n",
" (Cont_norm0): LayerNorm((128,), eps=0.001, elementwise_affine=True)\n",
" (Cont_act0): SiLU()\n",
" (Cont_linear1): Linear(in_features=128, out_features=128, bias=False)\n",
" (Cont_norm1): LayerNorm((128,), eps=0.001, elementwise_affine=True)\n",
" (Cont_act1): SiLU()\n",
" )\n",
" (mean_layer): Linear(in_features=128, out_features=1, bias=True)\n",
" )\n",
" )\n",
" )\n",
" (actor): MLP(\n",
" (layers): Sequential(\n",
" (Actor_linear0): Linear(in_features=384, out_features=128, bias=False)\n",
" (Actor_norm0): LayerNorm((128,), eps=0.001, elementwise_affine=True)\n",
" (Actor_act0): SiLU()\n",
" (Actor_linear1): Linear(in_features=128, out_features=128, bias=False)\n",
" (Actor_norm1): LayerNorm((128,), eps=0.001, elementwise_affine=True)\n",
" (Actor_act1): SiLU()\n",
" )\n",
" (mean_layer): Linear(in_features=128, out_features=43, bias=True)\n",
" )\n",
" (value): MLP(\n",
" (layers): Sequential(\n",
" (Value_linear0): Linear(in_features=384, out_features=128, bias=False)\n",
" (Value_norm0): LayerNorm((128,), eps=0.001, elementwise_affine=True)\n",
" (Value_act0): SiLU()\n",
" (Value_linear1): Linear(in_features=128, out_features=128, bias=False)\n",
" (Value_norm1): LayerNorm((128,), eps=0.001, elementwise_affine=True)\n",
" (Value_act1): SiLU()\n",
" )\n",
" (mean_layer): Linear(in_features=128, out_features=255, bias=True)\n",
" )\n",
" (_slow_value): MLP(\n",
" (layers): Sequential(\n",
" (Value_linear0): Linear(in_features=384, out_features=128, bias=False)\n",
" (Value_norm0): LayerNorm((128,), eps=0.001, elementwise_affine=True)\n",
" (Value_act0): SiLU()\n",
" (Value_linear1): Linear(in_features=128, out_features=128, bias=False)\n",
" (Value_norm1): LayerNorm((128,), eps=0.001, elementwise_affine=True)\n",
" (Value_act1): SiLU()\n",
" )\n",
" (mean_layer): Linear(in_features=128, out_features=255, bias=True)\n",
" )\n",
" )\n",
" )\n",
")\n"
]
}
],
"source": [
"# Re-parse the arguments into a fresh config and set num_actions on it again\n",
"# from the action space of the already-built environments.\n",
"config = parse_args(argv)\n",
"config.num_actions = acts.n if hasattr(acts, \"n\") else acts.shape[0]\n",
"agent = Dreamer(\n",
"    train_envs[0].observation_space,\n",
"    train_envs[0].action_space,\n",
"    config,\n",
"    tlogger,\n",
"    train_dataset,\n",
").to(config.device)\n",
"print(agent)\n",
"agent.requires_grad_(requires_grad=False)\n",
"if (logdir / \"latest.pt\").exists():\n",
"    # map_location puts the stored tensors on config.device, so a checkpoint saved\n",
"    # on a different device still loads.\n",
"    # NOTE: torch.load unpickles the file - only load checkpoints you trust.\n",
"    checkpoint = torch.load(logdir / \"latest.pt\", map_location=config.device)\n",
"    agent.load_state_dict(checkpoint[\"agent_state_dict\"])\n",
"    tools.recursively_load_optim_state_dict(agent, checkpoint[\"optims_state_dict\"])\n",
"    # NOTE(review): presumably marks the one-off pretraining as already done - confirm.\n",
"    agent._should_pretrain._once = False\n",
"    logger.warning(f\"Loaded model from {logdir / 'latest.pt'}\")"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[0;31mType:\u001b[0m generator\n",
"\u001b[0;31mString form:\u001b[0m <generator object from_generator at 0x7ea9049cef20>\n",
"\u001b[0;31mDocstring:\u001b[0m <no docstring>"
]
},
{
"ename": "",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[1;31mThe Kernel crashed while executing code in the current cell or a previous cell. \n",
"\u001b[1;31mPlease review the code in the cell(s) to identify a possible cause of the failure. \n",
"\u001b[1;31mClick <a href='https://aka.ms/vscodeJupyterKernelCrash'>here</a> for more info. \n",
"\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
]
}
],
"source": [
"# Show what kind of object the dataset is, without the interactive `??` pager.\n",
"type(train_dataset)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- Note: `model_opt` includes the world model (`actor.wm`):\n",
" - encoder\n",
" - rssm\n",
" - heads\n",
"- actor"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Now let's play"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": []
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# `state` is only assigned by the prefill rollout, which is skipped when\n",
"# config.offline_traindir is set; say so instead of failing with a bare assert.\n",
"assert state is not None, (\n",
"    \"state is None: the prefill rollout did not run \"\n",
"    \"(is config.offline_traindir set?)\"\n",
")\n",
"\n",
"# state"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"from tools import convert, add_to_cache\n",
"\n",
"# Work on the training envs and write new transitions into the training episode cache.\n",
"envs = train_envs\n",
"cache = train_eps\n",
"\n",
"# Fresh bookkeeping: every env starts flagged as done, so all of them are reset below.\n",
"step = episode = 0\n",
"done = np.ones(len(envs), dtype=bool)\n",
"length = np.zeros(len(envs), dtype=np.int32)\n",
"obs = [None for _env in envs]\n",
"agent_state = None\n",
"reward = [0 for _env in envs]\n",
"\n",
"# Reset each flagged env; reset() hands back a callable that is called afterwards\n",
"# to obtain the actual observation.\n",
"indices = [pos for pos, flag in enumerate(done) if flag]\n",
"results = [envs[pos].reset() for pos in indices]\n",
"results = [pending() for pending in results]\n",
"for index, result in zip(indices, results):\n",
"    # converted copy of the reset observation; add_to_cache adds the action later\n",
"    t = {k: convert(v) for k, v in result.items()}\n",
"    t.update(reward=0.0, discount=1.0)\n",
"    # the initial transition is stored in the cache under the env's id\n",
"    add_to_cache(cache, envs[index].id, t)\n",
"    # the unconverted reset result becomes this env's current observation\n",
"    obs[index] = result\n",
"# step agents"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/media/wassname/SGIronWolf/projects5/worldmodels/dreamerv3-torch/nbs/../networks.py:792: UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at ../aten/src/ATen/native/cudnn/Conv_v8.cpp:919.)\n",
" ret = F.conv2d(\n",
"/media/wassname/SGIronWolf/projects5/worldmodels/dreamerv3-torch/.venv/lib/python3.9/site-packages/torch/autograd/graph.py:744: UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at ../aten/src/ATen/native/cudnn/Conv_v8.cpp:919.)\n",
" return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[32m2024-06-07 03:53:49.056\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mtools\u001b[0m:\u001b[36mwrite\u001b[0m:\u001b[36m85\u001b[0m - \u001b[1m[2714] model_loss \u001b[31m7464.5\u001b[0m\u001b[1m / model_grad_norm \u001b[31m4841.3\u001b[0m\u001b[1m / state_map_loss \u001b[31m7449.7\u001b[0m\u001b[1m / state_inventory_loss \u001b[31m7.5\u001b[0m\u001b[1m / reward_loss \u001b[31m5.0\u001b[0m\u001b[1m / cont_loss \u001b[31m0.1\u001b[0m\u001b[1m / kl_free \u001b[31m1.0\u001b[0m\u001b[1m / dyn_scale \u001b[31m0.5\u001b[0m\u001b[1m / rep_scale \u001b[31m0.1\u001b[0m\u001b[1m / dyn_loss \u001b[31m3.7\u001b[0m\u001b[1m / rep_loss \u001b[31m3.7\u001b[0m\u001b[1m / kl \u001b[31m3.7\u001b[0m\u001b[1m / prior_ent \u001b[31m42.9\u001b[0m\u001b[1m / post_ent \u001b[31m39.3\u001b[0m\u001b[1m / normed_target_mean \u001b[31m-0.0\u001b[0m\u001b[1m / normed_target_std \u001b[31m0.0\u001b[0m\u001b[1m / normed_target_min \u001b[31m-0.0\u001b[0m\u001b[1m / normed_target_max \u001b[31m0.0\u001b[0m\u001b[1m / EMA_005 \u001b[31m-0.0\u001b[0m\u001b[1m / EMA_095 \u001b[31m-0.0\u001b[0m\u001b[1m / value_mean \u001b[31m-0.0\u001b[0m\u001b[1m / value_std \u001b[31m0.0\u001b[0m\u001b[1m / value_min \u001b[31m-0.0\u001b[0m\u001b[1m / value_max \u001b[31m-0.0\u001b[0m\u001b[1m / target_mean \u001b[31m-0.0\u001b[0m\u001b[1m / target_std \u001b[31m0.0\u001b[0m\u001b[1m / target_min \u001b[31m-0.0\u001b[0m\u001b[1m / target_max \u001b[31m-0.0\u001b[0m\u001b[1m / imag_reward_mean \u001b[31m-0.0\u001b[0m\u001b[1m / imag_reward_std \u001b[31m0.0\u001b[0m\u001b[1m / imag_reward_min \u001b[31m-0.0\u001b[0m\u001b[1m / imag_reward_max \u001b[31m-0.0\u001b[0m\u001b[1m / imag_action_mean \u001b[31m21.2\u001b[0m\u001b[1m / imag_action_std \u001b[31m12.6\u001b[0m\u001b[1m / imag_action_min \u001b[31m0.0\u001b[0m\u001b[1m / imag_action_max \u001b[31m42.0\u001b[0m\u001b[1m / actor_entropy \u001b[31m3.6\u001b[0m\u001b[1m / actor_loss \u001b[31m-0.0\u001b[0m\u001b[1m / actor_grad_norm 
\u001b[31m0.0\u001b[0m\u001b[1m / value_loss \u001b[31m7.6\u001b[0m\u001b[1m / value_grad_norm \u001b[31m8.4\u001b[0m\u001b[1m / update_count \u001b[31m100.0\u001b[0m\u001b[1m / fps \u001b[31m0.0\u001b[0m\u001b[1m\u001b[0m\n"
]
}
],
"source": [
"# from tools.simulate\n",
"\n",
"# step\n",
"# step, episode, done, length, obs, agent_state, reward = state\n",
"# Batch the per-env observations key by key, leaving out every key containing \"log_\".\n",
"obs_keys = [key for key in obs[0] if \"log_\" not in key]\n",
"obs2 = {key: np.stack([ob[key] for ob in obs]) for key in obs_keys}\n",
"# One agent call: returns the policy output and the next agent state.\n",
"action, agent_state = agent(obs2, done, agent_state)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"==========================================================================================\n",
"Layer (type:depth-idx) Param #\n",
"==========================================================================================\n",
"Dreamer --\n",
"├─OptimizedModule: 1-1 --\n",
"│ └─WorldModel: 2-1 --\n",
"│ │ └─MultiEncoder: 3-1 (84,592)\n",
"│ │ └─RSSM: 3-2 (270,848)\n",
"│ │ └─ModuleDict: 3-3 (351,484)\n",
"├─OptimizedModule: 1-2 --\n",
"│ └─ImagBehavior: 2-2 706,924\n",
"│ │ └─WorldModel: 3-4 (recursive)\n",
"│ │ └─MLP: 3-5 (71,595)\n",
"│ │ └─MLP: 3-6 (98,943)\n",
"│ │ └─MLP: 3-7 (98,943)\n",
"├─OptimizedModule: 1-3 (recursive)\n",
"│ └─ImagBehavior: 2-3 (recursive)\n",
"│ │ └─WorldModel: 3-8 (recursive)\n",
"│ │ └─MLP: 3-9 (recursive)\n",
"│ │ └─MLP: 3-10 (recursive)\n",
"│ │ └─MLP: 3-11 (recursive)\n",
"==========================================================================================\n",
"Total params: 1,683,329\n",
"Trainable params: 0\n",
"Non-trainable params: 1,683,329\n",
"=========================================================================================="
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from torchinfo import summary\n",
"\n",
"summary(agent, input=(obs, done, agent_state), depth=3)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"==========================================================================================\n",
"Layer (type:depth-idx) Param #\n",
"==========================================================================================\n",
"Dreamer --\n",
"├─OptimizedModule: 1-1 --\n",
"│ └─WorldModel: 2-1 --\n",
"│ │ └─MultiEncoder: 3-1 --\n",
"│ │ │ └─ConvEncoder: 4-1 (83,680)\n",
"│ │ │ └─MLP: 4-2 (912)\n",
"│ │ └─RSSM: 3-2 128\n",
"│ │ │ └─Sequential: 4-3 (38,528)\n",
"│ │ │ └─GRUCell: 4-4 (99,072)\n",
"│ │ │ └─Sequential: 4-5 (16,640)\n",
"│ │ │ └─Sequential: 4-6 (50,432)\n",
"│ │ │ └─Linear: 4-7 (33,024)\n",
"│ │ │ └─Linear: 4-8 (33,024)\n",
"│ │ └─ModuleDict: 3-3 --\n",
"│ │ │ └─MultiDecoder: 4-9 (186,364)\n",
"│ │ │ └─MLP: 4-10 (98,943)\n",
"│ │ │ └─MLP: 4-11 (66,177)\n",
"├─OptimizedModule: 1-2 --\n",
"│ └─ImagBehavior: 2-2 706,924\n",
"│ │ └─WorldModel: 3-4 (recursive)\n",
"│ │ │ └─MultiEncoder: 4-12 (recursive)\n",
"│ │ │ └─RSSM: 4-13 (recursive)\n",
"│ │ │ └─ModuleDict: 4-14 (recursive)\n",
"│ │ └─MLP: 3-5 --\n",
"│ │ │ └─Sequential: 4-15 (66,048)\n",
"│ │ │ └─Linear: 4-16 (5,547)\n",
"│ │ └─MLP: 3-6 --\n",
"│ │ │ └─Sequential: 4-17 (66,048)\n",
"│ │ │ └─Linear: 4-18 (32,895)\n",
"│ │ └─MLP: 3-7 --\n",
"│ │ │ └─Sequential: 4-19 (66,048)\n",
"│ │ │ └─Linear: 4-20 (32,895)\n",
"├─OptimizedModule: 1-3 (recursive)\n",
"│ └─ImagBehavior: 2-3 (recursive)\n",
"│ │ └─WorldModel: 3-8 (recursive)\n",
"│ │ │ └─MultiEncoder: 4-21 (recursive)\n",
"│ │ │ └─RSSM: 4-22 (recursive)\n",
"│ │ │ └─ModuleDict: 4-23 (recursive)\n",
"│ │ └─MLP: 3-9 (recursive)\n",
"│ │ │ └─Sequential: 4-24 (recursive)\n",
"│ │ │ └─Linear: 4-25 (recursive)\n",
"│ │ └─MLP: 3-10 (recursive)\n",
"│ │ │ └─Sequential: 4-26 (recursive)\n",
"│ │ │ └─Linear: 4-27 (recursive)\n",
"│ │ └─MLP: 3-11 (recursive)\n",
"│ │ │ └─Sequential: 4-28 (recursive)\n",
"│ │ │ └─Linear: 4-29 (recursive)\n",
"==========================================================================================\n",
"Total params: 1,683,329\n",
"Trainable params: 0\n",
"Non-trainable params: 1,683,329\n",
"=========================================================================================="
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"summary(agent, input=(obs, done, agent_state), depth=4)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"# agent._wm.heads"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Fine-grained torchinfo"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"# Run the first stages of the world-model pass by hand instead of calling\n",
"# self._train() / `post, context, mets = wm._train(data)`, so that the intermediate\n",
"# values (data, embed, post, prior) stay available to the cells below.\n",
"wm = agent._wm\n",
"data = wm.preprocess(next(agent._dataset))\n",
"embed = wm.encoder(data)\n",
"post, prior = wm.dynamics.observe(embed, data[\"action\"], data[\"is_first\"])"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"===================================================================================================================\n",
"Layer (type:depth-idx) Input Shape Output Shape Param #\n",
"===================================================================================================================\n",
"MultiEncoder [256, 32, 130, 110, 3] [256, 32, 264] --\n",
"├─ConvEncoder: 1-1 [256, 32, 16, 16, 166] [256, 32, 256] --\n",
"│ └─Sequential: 2-1 [8192, 166, 16, 16] [8192, 64, 2, 2] --\n",
"│ │ └─Conv2dSamePad: 3-1 [8192, 166, 16, 16] [8192, 16, 8, 8] (42,496)\n",
"│ │ └─ImgChLayerNorm: 3-2 [8192, 16, 8, 8] [8192, 16, 8, 8] --\n",
"│ │ │ └─LayerNorm: 4-1 [8192, 8, 8, 16] [8192, 8, 8, 16] (32)\n",
"│ │ └─SiLU: 3-3 [8192, 16, 8, 8] [8192, 16, 8, 8] --\n",
"│ │ └─Conv2dSamePad: 3-4 [8192, 16, 8, 8] [8192, 32, 4, 4] (8,192)\n",
"│ │ └─ImgChLayerNorm: 3-5 [8192, 32, 4, 4] [8192, 32, 4, 4] --\n",
"│ │ │ └─LayerNorm: 4-2 [8192, 4, 4, 32] [8192, 4, 4, 32] (64)\n",
"│ │ └─SiLU: 3-6 [8192, 32, 4, 4] [8192, 32, 4, 4] --\n",
"│ │ └─Conv2dSamePad: 3-7 [8192, 32, 4, 4] [8192, 64, 2, 2] (32,768)\n",
"│ │ └─ImgChLayerNorm: 3-8 [8192, 64, 2, 2] [8192, 64, 2, 2] --\n",
"│ │ │ └─LayerNorm: 4-3 [8192, 2, 2, 64] [8192, 2, 2, 64] (128)\n",
"│ │ └─SiLU: 3-9 [8192, 64, 2, 2] [8192, 64, 2, 2] --\n",
"├─MLP: 1-2 [256, 32, 102] [256, 32, 8] --\n",
"│ └─Sequential: 2-2 [256, 32, 102] [256, 32, 8] --\n",
"│ │ └─Linear: 3-10 [256, 32, 102] [256, 32, 8] (816)\n",
"│ │ └─LayerNorm: 3-11 [256, 32, 8] [256, 32, 8] (16)\n",
"│ │ └─SiLU: 3-12 [256, 32, 8] [256, 32, 8] --\n",
"│ │ └─Linear: 3-13 [256, 32, 8] [256, 32, 8] (64)\n",
"│ │ └─LayerNorm: 3-14 [256, 32, 8] [256, 32, 8] (16)\n",
"│ │ └─SiLU: 3-15 [256, 32, 8] [256, 32, 8] --\n",
"===================================================================================================================\n",
"Total params: 84,592\n",
"Trainable params: 0\n",
"Non-trainable params: 84,592\n",
"Total mult-adds (G): 24.43\n",
"===================================================================================================================\n",
"Input size (MB): 3345.09\n",
"Forward/backward pass size (MB): 236.98\n",
"Params size (MB): 0.34\n",
"Estimated Total Size (MB): 3582.41\n",
"==================================================================================================================="
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"summary(wm.encoder, input_data=(data,), depth=4, col_names=[\"input_size\", \"output_size\", \"num_params\", ])"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"decoder\n",
"===================================================================================================================\n",
"Layer (type:depth-idx) Input Shape Output Shape Param #\n",
"===================================================================================================================\n",
"MultiDecoder [256, 32, 384] -- --\n",
"├─ConvDecoder: 1-1 [256, 32, 384] [256, 32, 16, 16, 166] --\n",
"│ └─Linear: 2-1 [256, 32, 384] [256, 32, 256] (98,560)\n",
"│ └─Sequential: 2-2 [8192, 64, 2, 2] [8192, 166, 16, 16] --\n",
"│ │ └─ConvTranspose2d: 3-1 [8192, 64, 2, 2] [8192, 32, 4, 4] (32,768)\n",
"│ │ └─ImgChLayerNorm: 3-2 [8192, 32, 4, 4] [8192, 32, 4, 4] (64)\n",
"│ │ └─SiLU: 3-3 [8192, 32, 4, 4] [8192, 32, 4, 4] --\n",
"│ │ └─ConvTranspose2d: 3-4 [8192, 32, 4, 4] [8192, 16, 8, 8] (8,192)\n",
"│ │ └─ImgChLayerNorm: 3-5 [8192, 16, 8, 8] [8192, 16, 8, 8] (32)\n",
"│ │ └─SiLU: 3-6 [8192, 16, 8, 8] [8192, 16, 8, 8] --\n",
"│ │ └─ConvTranspose2d: 3-7 [8192, 16, 8, 8] [8192, 166, 16, 16] (42,662)\n",
"├─MLP: 1-2 [256, 32, 384] -- --\n",
"│ └─Sequential: 2-3 [256, 32, 384] [256, 32, 8] --\n",
"│ │ └─Linear: 3-8 [256, 32, 384] [256, 32, 8] (3,072)\n",
"│ │ └─LayerNorm: 3-9 [256, 32, 8] [256, 32, 8] (16)\n",
"│ │ └─SiLU: 3-10 [256, 32, 8] [256, 32, 8] --\n",
"│ │ └─Linear: 3-11 [256, 32, 8] [256, 32, 8] (64)\n",
"│ │ └─LayerNorm: 3-12 [256, 32, 8] [256, 32, 8] (16)\n",
"│ │ └─SiLU: 3-13 [256, 32, 8] [256, 32, 8] --\n",
"│ └─ModuleDict: 2-4 -- -- --\n",
"│ │ └─Linear: 3-14 [256, 32, 8] [256, 32, 102] (918)\n",
"===================================================================================================================\n",
"Total params: 186,364\n",
"Trainable params: 0\n",
"Non-trainable params: 186,364\n",
"Total mult-adds (G): 98.09\n",
"===================================================================================================================\n",
"Input size (MB): 12.58\n",
"Forward/backward pass size (MB): 3011.90\n",
"Params size (MB): 0.75\n",
"Estimated Total Size (MB): 3025.23\n",
"===================================================================================================================\n",
"Summary Failed for reward Failed to run torchinfo. See above stack traces for more details. Executed layers up to: [Sequential: 1, Linear: 2, LayerNorm: 2, SiLU: 2, Linear: 2, LayerNorm: 2, SiLU: 2, Linear: 1]\n",
"Summary Failed for cont Failed to run torchinfo. See above stack traces for more details. Executed layers up to: [Sequential: 1, Linear: 2, LayerNorm: 2, SiLU: 2, Linear: 2, LayerNorm: 2, SiLU: 2, Linear: 1]\n"
]
}
],
"source": [
"# heads\n",
"feat = wm.dynamics.get_feat(post)\n",
"for name, head in wm.heads.items():\n",
" try:\n",
" o = summary(head, input_data=(feat,), depth=3, col_names=[\"input_size\", \"output_size\", \"num_params\", ])\n",
" print(name)\n",
" print(o)\n",
" except Exception as e:\n",
" print(f\"Summary Failed for {name} {e}\")\n",
" continue"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"# fail as no call method\n",
"# summary(wm.dynamics, input_data=(embed, data[\"action\"], data[\"is_first\"]), depth=3, col_names=[\"output_size\", \"num_params\", ])"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"===================================================================================================================\n",
"Layer (type:depth-idx) Output Shape Param # Output Shape\n",
"===================================================================================================================\n",
"Sequential [256, 32, 128] -- [256, 32, 128]\n",
"├─Linear: 1-1 [256, 32, 128] (49,152) [256, 32, 128]\n",
"├─LayerNorm: 1-2 [256, 32, 128] (256) [256, 32, 128]\n",
"├─SiLU: 1-3 [256, 32, 128] -- [256, 32, 128]\n",
"├─Linear: 1-4 [256, 32, 128] (16,384) [256, 32, 128]\n",
"├─LayerNorm: 1-5 [256, 32, 128] (256) [256, 32, 128]\n",
"├─SiLU: 1-6 [256, 32, 128] -- [256, 32, 128]\n",
"===================================================================================================================\n",
"Total params: 66,048\n",
"Trainable params: 0\n",
"Non-trainable params: 66,048\n",
"Total mult-adds (M): 16.91\n",
"===================================================================================================================\n",
"Input size (MB): 12.58\n",
"Forward/backward pass size (MB): 33.55\n",
"Params size (MB): 0.26\n",
"Estimated Total Size (MB): 46.40\n",
"==================================================================================================================="
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"actor = agent._task_behavior.actor\n",
"\n",
"summary(actor.layers, input_data=(feat,), depth=3, col_names=[\"output_size\", \"num_params\", \"output_size\" ])\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"===================================================================================================================\n",
"Layer (type:depth-idx) Output Shape Param # Output Shape\n",
"===================================================================================================================\n",
"Sequential [256, 32, 128] -- [256, 32, 128]\n",
"├─Linear: 1-1 [256, 32, 128] (49,152) [256, 32, 128]\n",
"├─LayerNorm: 1-2 [256, 32, 128] (256) [256, 32, 128]\n",
"├─SiLU: 1-3 [256, 32, 128] -- [256, 32, 128]\n",
"├─Linear: 1-4 [256, 32, 128] (16,384) [256, 32, 128]\n",
"├─LayerNorm: 1-5 [256, 32, 128] (256) [256, 32, 128]\n",
"├─SiLU: 1-6 [256, 32, 128] -- [256, 32, 128]\n",
"===================================================================================================================\n",
"Total params: 66,048\n",
"Trainable params: 0\n",
"Non-trainable params: 66,048\n",
"Total mult-adds (M): 16.91\n",
"===================================================================================================================\n",
"Input size (MB): 12.58\n",
"Forward/backward pass size (MB): 33.55\n",
"Params size (MB): 0.26\n",
"Estimated Total Size (MB): 46.40\n",
"==================================================================================================================="
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"value = agent._task_behavior.actor\n",
"summary(value.layers, input_data=(feat,), depth=3, col_names=[\"output_size\", \"num_params\", \"output_size\" ])"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"8268"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"8268"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
}
},
"nbformat": 4,
"nbformat_minor": 2
}