mirror of
https://github.com/wassname/dreamerv3-torch.git
synced 2026-06-27 17:30:36 +08:00
1226 lines
72 KiB
Plaintext
1226 lines
72 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"In this notebook we load a saved Dreamer agent and run it, in order to inspect its parameters, check its speed, and improve hackability."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"%load_ext autoreload\n",
|
|
"%autoreload 2"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Loading textures from cache\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# TODO make this a proper package\n",
|
|
"import os, sys\n",
|
|
"sys.path.append('..')\n",
|
|
"\n",
|
|
"\n",
|
|
"from dreamer import parse_args, main, make_env, make_dataset, count_steps,Dreamer"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"['../dreamer.py', '--configs', 'craftax_small2', '--logdir', '../logdir/craftax_small2']\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"Namespace(act='SiLU', action_repeat=1, actor={'layers': 3, 'dist': 'onehot', 'entropy': 0.0003, 'unimix_ratio': 0.01, 'std': 'none', 'min_std': 0.1, 'max_std': 1.0, 'temp': 0.1, 'lr': 3e-05, 'eps': 1e-05, 'grad_clip': 100.0, 'outscale': 1.0}, batch_length=32, batch_size=256, compile=True, cont_head={'layers': 3, 'loss_scale': 1.0, 'outscale': 1.0}, critic={'layers': 2, 'dist': 'symlog_disc', 'slow_target': True, 'slow_target_update': 1, 'slow_target_fraction': 0.02, 'lr': 3e-05, 'eps': 1e-05, 'grad_clip': 100.0, 'outscale': 0.0}, dataset_size=1000000, debug=False, decoder={'mlp_keys': 'state_inventory', 'cnn_keys': 'state_map', 'act': 'SiLU', 'norm': True, 'cnn_depth': 32, 'kernel_size': 4, 'minres': 2, 'mlp_layers': 2, 'mlp_units': 16, 'cnn_sigmoid': False, 'image_dist': 'mse', 'vector_dist': 'symlog_mse', 'outscale': 1.0}, deterministic_run=False, device='cuda:0', disag_action_cond=False, disag_layers=4, disag_log=True, disag_models=10, disag_offset=1, disag_target='stoch', disag_units=400, discount=0.997, discount_lambda=0.95, dyn_deter=256, dyn_discrete=24, dyn_hidden=256, dyn_mean_act='none', dyn_min_std=0.1, dyn_rec_depth=1, dyn_scale=0.5, dyn_std_act='sigmoid2', dyn_stoch=24, encoder={'mlp_keys': 'state_inventory', 'cnn_keys': 'state_map', 'act': 'SiLU', 'norm': True, 'cnn_depth': 32, 'kernel_size': 4, 'minres': 2, 'mlp_layers': 2, 'mlp_units': 16, 'symlog_inputs': True}, envs=1, eval_episode_num=10, eval_every=10000.0, eval_state_mean=False, evaldir=None, expl_behavior='greedy', expl_extr_scale=0.0, expl_intr_scale=1.0, expl_until=0, grad_clip=1000, grad_heads=('decoder', 'reward', 'cont'), grayscale=False, imag_gradient='reinforce', imag_gradient_mix=0.0, imag_horizon=15, initial='learned', kl_free=1.0, log_every=10000.0, logdir='../logdir/craftax_small2', model_lr=0.0001, norm=True, offline_evaldir='', offline_traindir='', opt='adam', opt_eps=1e-08, parallel=False, precision=32, prefill=2500, pretrain=100, rep_scale=0.1, reset_every=0, reward_EMA=True, 
reward_head={'layers': 3, 'dist': 'symlog_disc', 'loss_scale': 1.0, 'outscale': 0.0}, seed=0, size=(64, 64), step=1000000.0, steps=1000000.0, task='craftax_Craftax-Symbolic-AutoReset-v1', time_limit=1000, train_ratio=512, traindir=None, unimix_ratio=0.01, units=256, value={'layers': 3}, video_pred_log=False, weight_decay=0.0)"
|
|
]
|
|
},
|
|
"execution_count": 4,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# Emulate the command-line arguments that dreamer.py would receive\n",
|
|
"argv = f\"../dreamer.py --configs craftax_small2 --logdir ../logdir/craftax_small2\"\n",
|
|
"argv = argv.split()\n",
|
|
"print(argv)\n",
|
|
"config = parse_args(argv)\n",
|
|
"config"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\u001b[32m2024-06-06 17:08:10.379\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m27\u001b[0m - \u001b[1mLogdir ../logdir/craftax_small2\u001b[0m\n",
|
|
"\u001b[32m2024-06-06 17:08:10.384\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m36\u001b[0m - \u001b[1mCreate envs.\u001b[0m\n",
|
|
"\u001b[32m2024-06-06 17:08:41.190\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m57\u001b[0m - \u001b[1mAction Space Box(0.0, 1.0, (43,), float32)\u001b[0m\n",
|
|
"\u001b[32m2024-06-06 17:08:41.191\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m63\u001b[0m - \u001b[1mPrefill dataset (26 steps).\u001b[0m\n",
|
|
"\u001b[32m2024-06-06 17:09:31.587\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m92\u001b[0m - \u001b[1mLogger: (2500 steps).\u001b[0m\n",
|
|
"\u001b[32m2024-06-06 17:09:31.588\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m94\u001b[0m - \u001b[1mSimulate agent.\u001b[0m\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"from loguru import logger\n",
|
|
"from tqdm.auto import tqdm\n",
|
|
"import pathlib\n",
|
|
"\n",
|
|
"import torch\n",
|
|
"from torch import nn\n",
|
|
"from torch import distributions as torchd\n",
|
|
"\n",
|
|
"import exploration as expl\n",
|
|
"import models\n",
|
|
"import tools\n",
|
|
"import envs.wrappers as wrappers\n",
|
|
"from parallel import Parallel, Damy\n",
|
|
"\n",
|
|
"# The setup below is adapted from main() in dreamer.py\n",
|
|
"tools.set_seed_everywhere(config.seed)\n",
|
|
"if config.deterministic_run:\n",
|
|
" tools.enable_deterministic_run()\n",
|
|
"logdir = pathlib.Path(config.logdir).expanduser()\n",
|
|
"config.traindir = config.traindir or logdir / \"train_eps\"\n",
|
|
"config.evaldir = config.evaldir or logdir / \"eval_eps\"\n",
|
|
"config.steps //= config.action_repeat\n",
|
|
"config.eval_every //= config.action_repeat\n",
|
|
"config.log_every //= config.action_repeat\n",
|
|
"config.time_limit //= config.action_repeat\n",
|
|
"\n",
|
|
"logger.info(f\"Logdir {logdir}\")\n",
|
|
"logdir.mkdir(parents=True, exist_ok=True)\n",
|
|
"config.traindir.mkdir(parents=True, exist_ok=True)\n",
|
|
"config.evaldir.mkdir(parents=True, exist_ok=True)\n",
|
|
"step = count_steps(config.traindir)\n",
|
|
"# the logger's step counts environment steps (counted steps * action_repeat)\n",
|
|
"tlogger = tools.Logger(logdir, config.action_repeat * step)\n",
|
|
"logger.add(logdir/\"logger.log\")\n",
|
|
"\n",
|
|
"logger.info(\"Create envs.\")\n",
|
|
"if config.offline_traindir:\n",
|
|
" directory = config.offline_traindir.format(**vars(config))\n",
|
|
"else:\n",
|
|
" directory = config.traindir\n",
|
|
"train_eps = tools.load_episodes(directory, limit=config.dataset_size)\n",
|
|
"if config.offline_evaldir:\n",
|
|
" directory = config.offline_evaldir.format(**vars(config))\n",
|
|
"else:\n",
|
|
" directory = config.evaldir\n",
|
|
"eval_eps = tools.load_episodes(directory, limit=1)\n",
|
|
"make = lambda mode, id: make_env(config, mode, id)\n",
|
|
"train_envs = [make(\"train\", i) for i in range(config.envs)]\n",
|
|
"eval_envs = [make(\"eval\", i) for i in range(config.envs)]\n",
|
|
"if config.parallel:\n",
|
|
" train_envs = [Parallel(env, \"process\") for env in train_envs]\n",
|
|
" eval_envs = [Parallel(env, \"process\") for env in eval_envs]\n",
|
|
"else:\n",
|
|
" train_envs = [Damy(env) for env in train_envs]\n",
|
|
" eval_envs = [Damy(env) for env in eval_envs]\n",
|
|
"acts = train_envs[0].action_space\n",
|
|
"logger.info(f\"Action Space {acts}\" )\n",
|
|
"config.num_actions = acts.n if hasattr(acts, \"n\") else acts.shape[0]\n",
|
|
"\n",
|
|
"state = None\n",
|
|
"if not config.offline_traindir:\n",
|
|
" prefill = max(0, config.prefill - count_steps(config.traindir))\n",
|
|
" logger.info(f\"Prefill dataset ({prefill} steps).\")\n",
|
|
" if hasattr(acts, \"discrete\"):\n",
|
|
" random_actor = tools.OneHotDist(\n",
|
|
" torch.zeros(config.num_actions).repeat(config.envs, 1)\n",
|
|
" )\n",
|
|
" else:\n",
|
|
" random_actor = torchd.independent.Independent(\n",
|
|
" torchd.uniform.Uniform(\n",
|
|
" torch.Tensor(acts.low).repeat(config.envs, 1),\n",
|
|
" torch.Tensor(acts.high).repeat(config.envs, 1),\n",
|
|
" ),\n",
|
|
" 1,\n",
|
|
" )\n",
|
|
"\n",
|
|
" def random_agent(o, d, s):\n",
|
|
" action = random_actor.sample()\n",
|
|
" logprob = random_actor.log_prob(action)\n",
|
|
" return {\"action\": action, \"logprob\": logprob}, None\n",
|
|
"\n",
|
|
" state = tools.simulate(\n",
|
|
" random_agent,\n",
|
|
" train_envs,\n",
|
|
" train_eps,\n",
|
|
" config.traindir,\n",
|
|
" tlogger,\n",
|
|
" limit=config.dataset_size,\n",
|
|
" steps=prefill,\n",
|
|
" )\n",
|
|
" tlogger.step += prefill * config.action_repeat\n",
|
|
" logger.info(f\"Logger: ({tlogger.step} steps).\")\n",
|
|
"\n",
|
|
"logger.info(\"Simulate agent.\")\n",
|
|
"train_dataset = make_dataset(train_eps, config)\n",
|
|
"eval_dataset = make_dataset(eval_eps, config)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"/media/wassname/SGIronWolf/projects5/worldmodels/dreamerv3-torch/.venv/lib/python3.9/site-packages/numpy/core/numeric.py:330: RuntimeWarning: invalid value encountered in cast\n",
|
|
" multiarray.copyto(a, fill_value, casting='unsafe')\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"Dict('image': Box(0, 255, (130, 110, 3), uint8), 'is_first': Box(0, 0, (1,), uint8), 'is_last': Box(0, 0, (1,), uint8), 'is_terminal': Box(0, 0, (1,), uint8), 'log_achievement_cast_fireball': Box(-inf, inf, (1,), float32), 'log_achievement_cast_iceball': Box(-inf, inf, (1,), float32), 'log_achievement_collect_coal': Box(-inf, inf, (1,), float32), 'log_achievement_collect_diamond': Box(-inf, inf, (1,), float32), 'log_achievement_collect_drink': Box(-inf, inf, (1,), float32), 'log_achievement_collect_iron': Box(-inf, inf, (1,), float32), 'log_achievement_collect_ruby': Box(-inf, inf, (1,), float32), 'log_achievement_collect_sapling': Box(-inf, inf, (1,), float32), 'log_achievement_collect_sapphire': Box(-inf, inf, (1,), float32), 'log_achievement_collect_stone': Box(-inf, inf, (1,), float32), 'log_achievement_collect_wood': Box(-inf, inf, (1,), float32), 'log_achievement_damage_necromancer': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_archer': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_deep_thing': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_fire_elemental': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_frost_troll': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_gnome_archer': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_gnome_warrior': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_ice_elemental': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_knight': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_kobold': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_lizard': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_necromancer': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_orc_mage': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_orc_solider': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_pigman': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_skeleton': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_troll': Box(-inf, inf, (1,), 
float32), 'log_achievement_defeat_zombie': Box(-inf, inf, (1,), float32), 'log_achievement_drink_potion': Box(-inf, inf, (1,), float32), 'log_achievement_eat_bat': Box(-inf, inf, (1,), float32), 'log_achievement_eat_cow': Box(-inf, inf, (1,), float32), 'log_achievement_eat_plant': Box(-inf, inf, (1,), float32), 'log_achievement_eat_snail': Box(-inf, inf, (1,), float32), 'log_achievement_enchant_armour': Box(-inf, inf, (1,), float32), 'log_achievement_enchant_sword': Box(-inf, inf, (1,), float32), 'log_achievement_enter_dungeon': Box(-inf, inf, (1,), float32), 'log_achievement_enter_fire_realm': Box(-inf, inf, (1,), float32), 'log_achievement_enter_gnomish_mines': Box(-inf, inf, (1,), float32), 'log_achievement_enter_graveyard': Box(-inf, inf, (1,), float32), 'log_achievement_enter_ice_realm': Box(-inf, inf, (1,), float32), 'log_achievement_enter_sewers': Box(-inf, inf, (1,), float32), 'log_achievement_enter_troll_mines': Box(-inf, inf, (1,), float32), 'log_achievement_enter_vault': Box(-inf, inf, (1,), float32), 'log_achievement_find_bow': Box(-inf, inf, (1,), float32), 'log_achievement_fire_bow': Box(-inf, inf, (1,), float32), 'log_achievement_learn_fireball': Box(-inf, inf, (1,), float32), 'log_achievement_learn_iceball': Box(-inf, inf, (1,), float32), 'log_achievement_make_arrow': Box(-inf, inf, (1,), float32), 'log_achievement_make_diamond_armour': Box(-inf, inf, (1,), float32), 'log_achievement_make_diamond_pickaxe': Box(-inf, inf, (1,), float32), 'log_achievement_make_diamond_sword': Box(-inf, inf, (1,), float32), 'log_achievement_make_iron_armour': Box(-inf, inf, (1,), float32), 'log_achievement_make_iron_pickaxe': Box(-inf, inf, (1,), float32), 'log_achievement_make_iron_sword': Box(-inf, inf, (1,), float32), 'log_achievement_make_stone_pickaxe': Box(-inf, inf, (1,), float32), 'log_achievement_make_stone_sword': Box(-inf, inf, (1,), float32), 'log_achievement_make_torch': Box(-inf, inf, (1,), float32), 'log_achievement_make_wood_pickaxe': Box(-inf, inf, 
(1,), float32), 'log_achievement_make_wood_sword': Box(-inf, inf, (1,), float32), 'log_achievement_open_chest': Box(-inf, inf, (1,), float32), 'log_achievement_place_furnace': Box(-inf, inf, (1,), float32), 'log_achievement_place_plant': Box(-inf, inf, (1,), float32), 'log_achievement_place_stone': Box(-inf, inf, (1,), float32), 'log_achievement_place_table': Box(-inf, inf, (1,), float32), 'log_achievement_place_torch': Box(-inf, inf, (1,), float32), 'log_achievement_wake_up': Box(-inf, inf, (1,), float32), 'log_reward': Box(-inf, inf, (1,), float32), 'state': Box(0.0, 1.0, (16536,), float32), 'state_inventory': Box(0.0, 1.0, (102,), float32), 'state_map': Box(0.0, 1.0, (12, 12, 166), float32))"
|
|
]
|
|
},
|
|
"execution_count": 6,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"train_envs[0].observation_space"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\u001b[32m2024-06-06 17:09:31.695\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mnetworks\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m324\u001b[0m - \u001b[1mEncoder CNN shapes: {'state_map': (12, 12, 166)}\u001b[0m\n",
|
|
"\u001b[32m2024-06-06 17:09:31.696\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mnetworks\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m325\u001b[0m - \u001b[1mEncoder MLP shapes: {'state_inventory': (102,)}\u001b[0m\n",
|
|
"\u001b[32m2024-06-06 17:09:31.913\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mnetworks\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m391\u001b[0m - \u001b[1mDecoder CNN shapes: {'state_map': (12, 12, 166)}\u001b[0m\n",
|
|
"\u001b[32m2024-06-06 17:09:31.914\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mnetworks\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m392\u001b[0m - \u001b[1mDecoder MLP shapes: {'state_inventory': (102,)}\u001b[0m\n",
|
|
"\u001b[32m2024-06-06 17:09:32.650\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmodels\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m102\u001b[0m - \u001b[1mOptimizer model_opt has 2357196 variables.\u001b[0m\n",
|
|
"\u001b[32m2024-06-06 17:09:32.657\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmodels\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m281\u001b[0m - \u001b[1mOptimizer actor_opt has 356651 variables.\u001b[0m\n",
|
|
"\u001b[32m2024-06-06 17:09:32.657\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmodels\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m292\u001b[0m - \u001b[1mOptimizer value_opt has 345087 variables.\u001b[0m\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"/home/wassname/miniforge3/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
|
|
" self.pid = os.fork()\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Dreamer(\n",
|
|
" (_wm): OptimizedModule(\n",
|
|
" (_orig_mod): WorldModel(\n",
|
|
" (encoder): MultiEncoder(\n",
|
|
" (_cnn): ConvEncoder(\n",
|
|
" (layers): Sequential(\n",
|
|
" (0): Conv2dSamePad(166, 32, kernel_size=(4, 4), stride=(2, 2), bias=False)\n",
|
|
" (1): ImgChLayerNorm(\n",
|
|
" (norm): LayerNorm((32,), eps=0.001, elementwise_affine=True)\n",
|
|
" )\n",
|
|
" (2): SiLU()\n",
|
|
" (3): Conv2dSamePad(32, 64, kernel_size=(4, 4), stride=(2, 2), bias=False)\n",
|
|
" (4): ImgChLayerNorm(\n",
|
|
" (norm): LayerNorm((64,), eps=0.001, elementwise_affine=True)\n",
|
|
" )\n",
|
|
" (5): SiLU()\n",
|
|
" )\n",
|
|
" )\n",
|
|
" (_mlp): MLP(\n",
|
|
" (layers): Sequential(\n",
|
|
" (Encoder_linear0): Linear(in_features=102, out_features=16, bias=False)\n",
|
|
" (Encoder_norm0): LayerNorm((16,), eps=0.001, elementwise_affine=True)\n",
|
|
" (Encoder_act0): SiLU()\n",
|
|
" (Encoder_linear1): Linear(in_features=16, out_features=16, bias=False)\n",
|
|
" (Encoder_norm1): LayerNorm((16,), eps=0.001, elementwise_affine=True)\n",
|
|
" (Encoder_act1): SiLU()\n",
|
|
" )\n",
|
|
" )\n",
|
|
" )\n",
|
|
" (dynamics): RSSM(\n",
|
|
" (_img_in_layers): Sequential(\n",
|
|
" (0): Linear(in_features=619, out_features=256, bias=False)\n",
|
|
" (1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n",
|
|
" (2): SiLU()\n",
|
|
" )\n",
|
|
" (_cell): GRUCell(\n",
|
|
" (layers): Sequential(\n",
|
|
" (GRU_linear): Linear(in_features=512, out_features=768, bias=False)\n",
|
|
" (GRU_norm): LayerNorm((768,), eps=0.001, elementwise_affine=True)\n",
|
|
" )\n",
|
|
" )\n",
|
|
" (_img_out_layers): Sequential(\n",
|
|
" (0): Linear(in_features=256, out_features=256, bias=False)\n",
|
|
" (1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n",
|
|
" (2): SiLU()\n",
|
|
" )\n",
|
|
" (_obs_out_layers): Sequential(\n",
|
|
" (0): Linear(in_features=848, out_features=256, bias=False)\n",
|
|
" (1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n",
|
|
" (2): SiLU()\n",
|
|
" )\n",
|
|
" (_imgs_stat_layer): Linear(in_features=256, out_features=576, bias=True)\n",
|
|
" (_obs_stat_layer): Linear(in_features=256, out_features=576, bias=True)\n",
|
|
" )\n",
|
|
" (heads): ModuleDict(\n",
|
|
" (decoder): MultiDecoder(\n",
|
|
" (_cnn): ConvDecoder(\n",
|
|
" (_linear_layer): Linear(in_features=832, out_features=256, bias=True)\n",
|
|
" (layers): Sequential(\n",
|
|
" (0): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)\n",
|
|
" (1): ImgChLayerNorm(\n",
|
|
" (norm): LayerNorm((32,), eps=0.001, elementwise_affine=True)\n",
|
|
" )\n",
|
|
" (2): SiLU()\n",
|
|
" (3): ConvTranspose2d(32, 166, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
|
|
" )\n",
|
|
" )\n",
|
|
" (_mlp): MLP(\n",
|
|
" (layers): Sequential(\n",
|
|
" (Decoder_linear0): Linear(in_features=832, out_features=16, bias=False)\n",
|
|
" (Decoder_norm0): LayerNorm((16,), eps=0.001, elementwise_affine=True)\n",
|
|
" (Decoder_act0): SiLU()\n",
|
|
" (Decoder_linear1): Linear(in_features=16, out_features=16, bias=False)\n",
|
|
" (Decoder_norm1): LayerNorm((16,), eps=0.001, elementwise_affine=True)\n",
|
|
" (Decoder_act1): SiLU()\n",
|
|
" )\n",
|
|
" (mean_layer): ModuleDict(\n",
|
|
" (state_inventory): Linear(in_features=16, out_features=102, bias=True)\n",
|
|
" )\n",
|
|
" )\n",
|
|
" )\n",
|
|
" (reward): MLP(\n",
|
|
" (layers): Sequential(\n",
|
|
" (Reward_linear0): Linear(in_features=832, out_features=256, bias=False)\n",
|
|
" (Reward_norm0): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n",
|
|
" (Reward_act0): SiLU()\n",
|
|
" (Reward_linear1): Linear(in_features=256, out_features=256, bias=False)\n",
|
|
" (Reward_norm1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n",
|
|
" (Reward_act1): SiLU()\n",
|
|
" (Reward_linear2): Linear(in_features=256, out_features=256, bias=False)\n",
|
|
" (Reward_norm2): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n",
|
|
" (Reward_act2): SiLU()\n",
|
|
" )\n",
|
|
" (mean_layer): Linear(in_features=256, out_features=255, bias=True)\n",
|
|
" )\n",
|
|
" (cont): MLP(\n",
|
|
" (layers): Sequential(\n",
|
|
" (Cont_linear0): Linear(in_features=832, out_features=256, bias=False)\n",
|
|
" (Cont_norm0): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n",
|
|
" (Cont_act0): SiLU()\n",
|
|
" (Cont_linear1): Linear(in_features=256, out_features=256, bias=False)\n",
|
|
" (Cont_norm1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n",
|
|
" (Cont_act1): SiLU()\n",
|
|
" (Cont_linear2): Linear(in_features=256, out_features=256, bias=False)\n",
|
|
" (Cont_norm2): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n",
|
|
" (Cont_act2): SiLU()\n",
|
|
" )\n",
|
|
" (mean_layer): Linear(in_features=256, out_features=1, bias=True)\n",
|
|
" )\n",
|
|
" )\n",
|
|
" )\n",
|
|
" )\n",
|
|
" (_task_behavior): OptimizedModule(\n",
|
|
" (_orig_mod): ImagBehavior(\n",
|
|
" (_world_model): WorldModel(\n",
|
|
" (encoder): MultiEncoder(\n",
|
|
" (_cnn): ConvEncoder(\n",
|
|
" (layers): Sequential(\n",
|
|
" (0): Conv2dSamePad(166, 32, kernel_size=(4, 4), stride=(2, 2), bias=False)\n",
|
|
" (1): ImgChLayerNorm(\n",
|
|
" (norm): LayerNorm((32,), eps=0.001, elementwise_affine=True)\n",
|
|
" )\n",
|
|
" (2): SiLU()\n",
|
|
" (3): Conv2dSamePad(32, 64, kernel_size=(4, 4), stride=(2, 2), bias=False)\n",
|
|
" (4): ImgChLayerNorm(\n",
|
|
" (norm): LayerNorm((64,), eps=0.001, elementwise_affine=True)\n",
|
|
" )\n",
|
|
" (5): SiLU()\n",
|
|
" )\n",
|
|
" )\n",
|
|
" (_mlp): MLP(\n",
|
|
" (layers): Sequential(\n",
|
|
" (Encoder_linear0): Linear(in_features=102, out_features=16, bias=False)\n",
|
|
" (Encoder_norm0): LayerNorm((16,), eps=0.001, elementwise_affine=True)\n",
|
|
" (Encoder_act0): SiLU()\n",
|
|
" (Encoder_linear1): Linear(in_features=16, out_features=16, bias=False)\n",
|
|
" (Encoder_norm1): LayerNorm((16,), eps=0.001, elementwise_affine=True)\n",
|
|
" (Encoder_act1): SiLU()\n",
|
|
" )\n",
|
|
" )\n",
|
|
" )\n",
|
|
" (dynamics): RSSM(\n",
|
|
" (_img_in_layers): Sequential(\n",
|
|
" (0): Linear(in_features=619, out_features=256, bias=False)\n",
|
|
" (1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n",
|
|
" (2): SiLU()\n",
|
|
" )\n",
|
|
" (_cell): GRUCell(\n",
|
|
" (layers): Sequential(\n",
|
|
" (GRU_linear): Linear(in_features=512, out_features=768, bias=False)\n",
|
|
" (GRU_norm): LayerNorm((768,), eps=0.001, elementwise_affine=True)\n",
|
|
" )\n",
|
|
" )\n",
|
|
" (_img_out_layers): Sequential(\n",
|
|
" (0): Linear(in_features=256, out_features=256, bias=False)\n",
|
|
" (1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n",
|
|
" (2): SiLU()\n",
|
|
" )\n",
|
|
" (_obs_out_layers): Sequential(\n",
|
|
" (0): Linear(in_features=848, out_features=256, bias=False)\n",
|
|
" (1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n",
|
|
" (2): SiLU()\n",
|
|
" )\n",
|
|
" (_imgs_stat_layer): Linear(in_features=256, out_features=576, bias=True)\n",
|
|
" (_obs_stat_layer): Linear(in_features=256, out_features=576, bias=True)\n",
|
|
" )\n",
|
|
" (heads): ModuleDict(\n",
|
|
" (decoder): MultiDecoder(\n",
|
|
" (_cnn): ConvDecoder(\n",
|
|
" (_linear_layer): Linear(in_features=832, out_features=256, bias=True)\n",
|
|
" (layers): Sequential(\n",
|
|
" (0): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)\n",
|
|
" (1): ImgChLayerNorm(\n",
|
|
" (norm): LayerNorm((32,), eps=0.001, elementwise_affine=True)\n",
|
|
" )\n",
|
|
" (2): SiLU()\n",
|
|
" (3): ConvTranspose2d(32, 166, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
|
|
" )\n",
|
|
" )\n",
|
|
" (_mlp): MLP(\n",
|
|
" (layers): Sequential(\n",
|
|
" (Decoder_linear0): Linear(in_features=832, out_features=16, bias=False)\n",
|
|
" (Decoder_norm0): LayerNorm((16,), eps=0.001, elementwise_affine=True)\n",
|
|
" (Decoder_act0): SiLU()\n",
|
|
" (Decoder_linear1): Linear(in_features=16, out_features=16, bias=False)\n",
|
|
" (Decoder_norm1): LayerNorm((16,), eps=0.001, elementwise_affine=True)\n",
|
|
" (Decoder_act1): SiLU()\n",
|
|
" )\n",
|
|
" (mean_layer): ModuleDict(\n",
|
|
" (state_inventory): Linear(in_features=16, out_features=102, bias=True)\n",
|
|
" )\n",
|
|
" )\n",
|
|
" )\n",
|
|
" (reward): MLP(\n",
|
|
" (layers): Sequential(\n",
|
|
" (Reward_linear0): Linear(in_features=832, out_features=256, bias=False)\n",
|
|
" (Reward_norm0): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n",
|
|
" (Reward_act0): SiLU()\n",
|
|
" (Reward_linear1): Linear(in_features=256, out_features=256, bias=False)\n",
|
|
" (Reward_norm1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n",
|
|
" (Reward_act1): SiLU()\n",
|
|
" (Reward_linear2): Linear(in_features=256, out_features=256, bias=False)\n",
|
|
" (Reward_norm2): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n",
|
|
" (Reward_act2): SiLU()\n",
|
|
" )\n",
|
|
" (mean_layer): Linear(in_features=256, out_features=255, bias=True)\n",
|
|
" )\n",
|
|
" (cont): MLP(\n",
|
|
" (layers): Sequential(\n",
|
|
" (Cont_linear0): Linear(in_features=832, out_features=256, bias=False)\n",
|
|
" (Cont_norm0): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n",
|
|
" (Cont_act0): SiLU()\n",
|
|
" (Cont_linear1): Linear(in_features=256, out_features=256, bias=False)\n",
|
|
" (Cont_norm1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n",
|
|
" (Cont_act1): SiLU()\n",
|
|
" (Cont_linear2): Linear(in_features=256, out_features=256, bias=False)\n",
|
|
" (Cont_norm2): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n",
|
|
" (Cont_act2): SiLU()\n",
|
|
" )\n",
|
|
" (mean_layer): Linear(in_features=256, out_features=1, bias=True)\n",
|
|
" )\n",
|
|
" )\n",
|
|
" )\n",
|
|
" (actor): MLP(\n",
|
|
" (layers): Sequential(\n",
|
|
" (Actor_linear0): Linear(in_features=832, out_features=256, bias=False)\n",
|
|
" (Actor_norm0): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n",
|
|
" (Actor_act0): SiLU()\n",
|
|
" (Actor_linear1): Linear(in_features=256, out_features=256, bias=False)\n",
|
|
" (Actor_norm1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n",
|
|
" (Actor_act1): SiLU()\n",
|
|
" (Actor_linear2): Linear(in_features=256, out_features=256, bias=False)\n",
|
|
" (Actor_norm2): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n",
|
|
" (Actor_act2): SiLU()\n",
|
|
" )\n",
|
|
" (mean_layer): Linear(in_features=256, out_features=43, bias=True)\n",
|
|
" )\n",
|
|
" (value): MLP(\n",
|
|
" (layers): Sequential(\n",
|
|
" (Value_linear0): Linear(in_features=832, out_features=256, bias=False)\n",
|
|
" (Value_norm0): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n",
|
|
" (Value_act0): SiLU()\n",
|
|
" (Value_linear1): Linear(in_features=256, out_features=256, bias=False)\n",
|
|
" (Value_norm1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n",
|
|
" (Value_act1): SiLU()\n",
|
|
" )\n",
|
|
" (mean_layer): Linear(in_features=256, out_features=255, bias=True)\n",
|
|
" )\n",
|
|
" (_slow_value): MLP(\n",
|
|
" (layers): Sequential(\n",
|
|
" (Value_linear0): Linear(in_features=832, out_features=256, bias=False)\n",
|
|
" (Value_norm0): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n",
|
|
" (Value_act0): SiLU()\n",
|
|
" (Value_linear1): Linear(in_features=256, out_features=256, bias=False)\n",
|
|
" (Value_norm1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n",
|
|
" (Value_act1): SiLU()\n",
|
|
" )\n",
|
|
" (mean_layer): Linear(in_features=256, out_features=255, bias=True)\n",
|
|
" )\n",
|
|
" )\n",
|
|
" )\n",
|
|
" (_expl_behavior): OptimizedModule(\n",
|
|
" (_orig_mod): ImagBehavior(\n",
|
|
" (_world_model): WorldModel(\n",
|
|
" (encoder): MultiEncoder(\n",
|
|
" (_cnn): ConvEncoder(\n",
|
|
" (layers): Sequential(\n",
|
|
" (0): Conv2dSamePad(166, 32, kernel_size=(4, 4), stride=(2, 2), bias=False)\n",
|
|
" (1): ImgChLayerNorm(\n",
|
|
" (norm): LayerNorm((32,), eps=0.001, elementwise_affine=True)\n",
|
|
" )\n",
|
|
" (2): SiLU()\n",
|
|
" (3): Conv2dSamePad(32, 64, kernel_size=(4, 4), stride=(2, 2), bias=False)\n",
|
|
" (4): ImgChLayerNorm(\n",
|
|
" (norm): LayerNorm((64,), eps=0.001, elementwise_affine=True)\n",
|
|
" )\n",
|
|
" (5): SiLU()\n",
|
|
" )\n",
|
|
" )\n",
|
|
" (_mlp): MLP(\n",
|
|
" (layers): Sequential(\n",
|
|
" (Encoder_linear0): Linear(in_features=102, out_features=16, bias=False)\n",
|
|
" (Encoder_norm0): LayerNorm((16,), eps=0.001, elementwise_affine=True)\n",
|
|
" (Encoder_act0): SiLU()\n",
|
|
" (Encoder_linear1): Linear(in_features=16, out_features=16, bias=False)\n",
|
|
" (Encoder_norm1): LayerNorm((16,), eps=0.001, elementwise_affine=True)\n",
|
|
" (Encoder_act1): SiLU()\n",
|
|
" )\n",
|
|
" )\n",
|
|
" )\n",
|
|
" (dynamics): RSSM(\n",
|
|
" (_img_in_layers): Sequential(\n",
|
|
" (0): Linear(in_features=619, out_features=256, bias=False)\n",
|
|
" (1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n",
|
|
" (2): SiLU()\n",
|
|
" )\n",
|
|
" (_cell): GRUCell(\n",
|
|
" (layers): Sequential(\n",
|
|
" (GRU_linear): Linear(in_features=512, out_features=768, bias=False)\n",
|
|
" (GRU_norm): LayerNorm((768,), eps=0.001, elementwise_affine=True)\n",
|
|
" )\n",
|
|
" )\n",
|
|
" (_img_out_layers): Sequential(\n",
|
|
" (0): Linear(in_features=256, out_features=256, bias=False)\n",
|
|
" (1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n",
|
|
" (2): SiLU()\n",
|
|
" )\n",
|
|
" (_obs_out_layers): Sequential(\n",
|
|
" (0): Linear(in_features=848, out_features=256, bias=False)\n",
|
|
" (1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n",
|
|
" (2): SiLU()\n",
|
|
" )\n",
|
|
" (_imgs_stat_layer): Linear(in_features=256, out_features=576, bias=True)\n",
|
|
" (_obs_stat_layer): Linear(in_features=256, out_features=576, bias=True)\n",
|
|
" )\n",
|
|
" (heads): ModuleDict(\n",
|
|
" (decoder): MultiDecoder(\n",
|
|
" (_cnn): ConvDecoder(\n",
|
|
" (_linear_layer): Linear(in_features=832, out_features=256, bias=True)\n",
|
|
" (layers): Sequential(\n",
|
|
" (0): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)\n",
|
|
" (1): ImgChLayerNorm(\n",
|
|
" (norm): LayerNorm((32,), eps=0.001, elementwise_affine=True)\n",
|
|
" )\n",
|
|
" (2): SiLU()\n",
|
|
" (3): ConvTranspose2d(32, 166, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
|
|
" )\n",
|
|
" )\n",
|
|
" (_mlp): MLP(\n",
|
|
" (layers): Sequential(\n",
|
|
" (Decoder_linear0): Linear(in_features=832, out_features=16, bias=False)\n",
|
|
" (Decoder_norm0): LayerNorm((16,), eps=0.001, elementwise_affine=True)\n",
|
|
" (Decoder_act0): SiLU()\n",
|
|
" (Decoder_linear1): Linear(in_features=16, out_features=16, bias=False)\n",
|
|
" (Decoder_norm1): LayerNorm((16,), eps=0.001, elementwise_affine=True)\n",
|
|
" (Decoder_act1): SiLU()\n",
|
|
" )\n",
|
|
" (mean_layer): ModuleDict(\n",
|
|
" (state_inventory): Linear(in_features=16, out_features=102, bias=True)\n",
|
|
" )\n",
|
|
" )\n",
|
|
" )\n",
|
|
" (reward): MLP(\n",
|
|
" (layers): Sequential(\n",
|
|
" (Reward_linear0): Linear(in_features=832, out_features=256, bias=False)\n",
|
|
" (Reward_norm0): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n",
|
|
" (Reward_act0): SiLU()\n",
|
|
" (Reward_linear1): Linear(in_features=256, out_features=256, bias=False)\n",
|
|
" (Reward_norm1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n",
|
|
" (Reward_act1): SiLU()\n",
|
|
" (Reward_linear2): Linear(in_features=256, out_features=256, bias=False)\n",
|
|
" (Reward_norm2): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n",
|
|
" (Reward_act2): SiLU()\n",
|
|
" )\n",
|
|
" (mean_layer): Linear(in_features=256, out_features=255, bias=True)\n",
|
|
" )\n",
|
|
" (cont): MLP(\n",
|
|
" (layers): Sequential(\n",
|
|
" (Cont_linear0): Linear(in_features=832, out_features=256, bias=False)\n",
|
|
" (Cont_norm0): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n",
|
|
" (Cont_act0): SiLU()\n",
|
|
" (Cont_linear1): Linear(in_features=256, out_features=256, bias=False)\n",
|
|
" (Cont_norm1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n",
|
|
" (Cont_act1): SiLU()\n",
|
|
" (Cont_linear2): Linear(in_features=256, out_features=256, bias=False)\n",
|
|
" (Cont_norm2): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n",
|
|
" (Cont_act2): SiLU()\n",
|
|
" )\n",
|
|
" (mean_layer): Linear(in_features=256, out_features=1, bias=True)\n",
|
|
" )\n",
|
|
" )\n",
|
|
" )\n",
|
|
" (actor): MLP(\n",
|
|
" (layers): Sequential(\n",
|
|
" (Actor_linear0): Linear(in_features=832, out_features=256, bias=False)\n",
|
|
" (Actor_norm0): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n",
|
|
" (Actor_act0): SiLU()\n",
|
|
" (Actor_linear1): Linear(in_features=256, out_features=256, bias=False)\n",
|
|
" (Actor_norm1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n",
|
|
" (Actor_act1): SiLU()\n",
|
|
" (Actor_linear2): Linear(in_features=256, out_features=256, bias=False)\n",
|
|
" (Actor_norm2): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n",
|
|
" (Actor_act2): SiLU()\n",
|
|
" )\n",
|
|
" (mean_layer): Linear(in_features=256, out_features=43, bias=True)\n",
|
|
" )\n",
|
|
" (value): MLP(\n",
|
|
" (layers): Sequential(\n",
|
|
" (Value_linear0): Linear(in_features=832, out_features=256, bias=False)\n",
|
|
" (Value_norm0): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n",
|
|
" (Value_act0): SiLU()\n",
|
|
" (Value_linear1): Linear(in_features=256, out_features=256, bias=False)\n",
|
|
" (Value_norm1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n",
|
|
" (Value_act1): SiLU()\n",
|
|
" )\n",
|
|
" (mean_layer): Linear(in_features=256, out_features=255, bias=True)\n",
|
|
" )\n",
|
|
" (_slow_value): MLP(\n",
|
|
" (layers): Sequential(\n",
|
|
" (Value_linear0): Linear(in_features=832, out_features=256, bias=False)\n",
|
|
" (Value_norm0): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n",
|
|
" (Value_act0): SiLU()\n",
|
|
" (Value_linear1): Linear(in_features=256, out_features=256, bias=False)\n",
|
|
" (Value_norm1): LayerNorm((256,), eps=0.001, elementwise_affine=True)\n",
|
|
" (Value_act1): SiLU()\n",
|
|
" )\n",
|
|
" (mean_layer): Linear(in_features=256, out_features=255, bias=True)\n",
|
|
" )\n",
|
|
" )\n",
|
|
" )\n",
|
|
")\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"config = parse_args(argv)\n",
|
|
"config.num_actions = acts.n if hasattr(acts, \"n\") else acts.shape[0]\n",
|
|
"agent = Dreamer(\n",
|
|
" train_envs[0].observation_space,\n",
|
|
" train_envs[0].action_space,\n",
|
|
" config,\n",
|
|
" tlogger,\n",
|
|
" train_dataset,\n",
|
|
").to(config.device)\n",
|
|
"print(agent)\n",
|
|
"agent.requires_grad_(requires_grad=False)\n",
|
|
"if (logdir / \"latest.pt\").exists():\n",
|
|
" checkpoint = torch.load(logdir / \"latest.pt\")\n",
|
|
" agent.load_state_dict(checkpoint[\"agent_state_dict\"])\n",
|
|
" tools.recursively_load_optim_state_dict(agent, checkpoint[\"optims_state_dict\"])\n",
|
|
" agent._should_pretrain._once = False\n",
|
|
" logger.warning(f\"Loaded model from {logdir / 'latest.pt'}\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"- note model_opt includes actor.wm\n",
|
|
" - encoder\n",
|
|
" - rssm\n",
|
|
" - heads\n",
|
|
"- actor"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Now let's play"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"assert state is not None\n",
|
|
"import numpy as np\n",
|
|
"\n",
|
|
"# state"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 9,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from tools import convert, add_to_cache\n",
|
|
"envs = train_envs\n",
|
|
"cache = train_eps\n",
|
|
"\n",
|
|
"step, episode = 0, 0\n",
|
|
"done = np.ones(len(envs), bool)\n",
|
|
"length = np.zeros(len(envs), np.int32)\n",
|
|
"obs = [None] * len(envs)\n",
|
|
"agent_state = None\n",
|
|
"reward = [0] * len(envs)\n",
|
|
"\n",
|
|
"indices = [index for index, d in enumerate(done) if d]\n",
|
|
"results = [envs[i].reset() for i in indices]\n",
|
|
"results = [r() for r in results]\n",
|
|
"for index, result in zip(indices, results):\n",
|
|
" t = result.copy()\n",
|
|
" t = {k: convert(v) for k, v in t.items()}\n",
|
|
" # action will be added to transition in add_to_cache\n",
|
|
" t[\"reward\"] = 0.0\n",
|
|
" t[\"discount\"] = 1.0\n",
|
|
" # initial state should be added to cache\n",
|
|
" add_to_cache(cache, envs[index].id, t)\n",
|
|
" # replace obs with done by initial state\n",
|
|
" obs[index] = result\n",
|
|
"# step agents"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 10,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"/media/wassname/SGIronWolf/projects5/worldmodels/dreamerv3-torch/nbs/../networks.py:790: UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at ../aten/src/ATen/native/cudnn/Conv_v8.cpp:919.)\n",
|
|
" ret = F.conv2d(\n"
|
|
]
|
|
},
|
|
{
|
|
"ename": "AssertionError",
|
|
"evalue": "(torch.Size([256, 32, 8, 8, 166]), torch.Size([256, 32, 12, 12, 166]))",
|
|
"output_type": "error",
|
|
"traceback": [
|
|
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
|
"\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)",
|
|
"Cell \u001b[0;32mIn[10], line 6\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m# from tools.simulate\u001b[39;00m\n\u001b[1;32m 2\u001b[0m \n\u001b[1;32m 3\u001b[0m \u001b[38;5;66;03m# step\u001b[39;00m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;66;03m# step, episode, done, length, obs, agent_state, reward = state\u001b[39;00m\n\u001b[1;32m 5\u001b[0m obs2 \u001b[38;5;241m=\u001b[39m {k: np\u001b[38;5;241m.\u001b[39mstack([o[k] \u001b[38;5;28;01mfor\u001b[39;00m o \u001b[38;5;129;01min\u001b[39;00m obs]) \u001b[38;5;28;01mfor\u001b[39;00m k \u001b[38;5;129;01min\u001b[39;00m obs[\u001b[38;5;241m0\u001b[39m] \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlog_\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m k}\n\u001b[0;32m----> 6\u001b[0m action, agent_state \u001b[38;5;241m=\u001b[39m \u001b[43magent\u001b[49m\u001b[43m(\u001b[49m\u001b[43mobs2\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdone\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43magent_state\u001b[49m\u001b[43m)\u001b[49m\n",
|
|
"File \u001b[0;32m/media/wassname/SGIronWolf/projects5/worldmodels/dreamerv3-torch/nbs/../dreamer.py:71\u001b[0m, in \u001b[0;36mDreamer.__call__\u001b[0;34m(self, obs, reset, state, training)\u001b[0m\n\u001b[1;32m 65\u001b[0m steps \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 66\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_config\u001b[38;5;241m.\u001b[39mpretrain\n\u001b[1;32m 67\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_should_pretrain()\n\u001b[1;32m 68\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_should_train(step)\n\u001b[1;32m 69\u001b[0m )\n\u001b[1;32m 70\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m _ \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(steps):\n\u001b[0;32m---> 71\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_train\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mnext\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_dataset\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 72\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_update_count \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m 73\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_metrics[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mupdate_count\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_update_count\n",
|
|
"File \u001b[0;32m/media/wassname/SGIronWolf/projects5/worldmodels/dreamerv3-torch/nbs/../dreamer.py:124\u001b[0m, in \u001b[0;36mDreamer._train\u001b[0;34m(self, data)\u001b[0m\n\u001b[1;32m 122\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_train\u001b[39m(\u001b[38;5;28mself\u001b[39m, data):\n\u001b[1;32m 123\u001b[0m metrics \u001b[38;5;241m=\u001b[39m {}\n\u001b[0;32m--> 124\u001b[0m post, context, mets \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_wm\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_train\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 125\u001b[0m metrics\u001b[38;5;241m.\u001b[39mupdate(mets)\n\u001b[1;32m 126\u001b[0m start \u001b[38;5;241m=\u001b[39m post\n",
|
|
"File \u001b[0;32m/media/wassname/SGIronWolf/projects5/worldmodels/dreamerv3-torch/nbs/../models.py:143\u001b[0m, in \u001b[0;36mWorldModel._train\u001b[0;34m(self, data)\u001b[0m\n\u001b[1;32m 141\u001b[0m losses \u001b[38;5;241m=\u001b[39m {}\n\u001b[1;32m 142\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m name, pred \u001b[38;5;129;01min\u001b[39;00m preds\u001b[38;5;241m.\u001b[39mitems():\n\u001b[0;32m--> 143\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m-\u001b[39m\u001b[43mpred\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlog_prob\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata\u001b[49m\u001b[43m[\u001b[49m\u001b[43mname\u001b[49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 144\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m loss\u001b[38;5;241m.\u001b[39mshape \u001b[38;5;241m==\u001b[39m embed\u001b[38;5;241m.\u001b[39mshape[:\u001b[38;5;241m2\u001b[39m], (name, loss\u001b[38;5;241m.\u001b[39mshape)\n\u001b[1;32m 145\u001b[0m losses[name] \u001b[38;5;241m=\u001b[39m loss\n",
|
|
"File \u001b[0;32m/media/wassname/SGIronWolf/projects5/worldmodels/dreamerv3-torch/nbs/../tools.py:528\u001b[0m, in \u001b[0;36mMSEDist.log_prob\u001b[0;34m(self, value)\u001b[0m\n\u001b[1;32m 527\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mlog_prob\u001b[39m(\u001b[38;5;28mself\u001b[39m, value):\n\u001b[0;32m--> 528\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_mode\u001b[38;5;241m.\u001b[39mshape \u001b[38;5;241m==\u001b[39m value\u001b[38;5;241m.\u001b[39mshape, (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_mode\u001b[38;5;241m.\u001b[39mshape, value\u001b[38;5;241m.\u001b[39mshape)\n\u001b[1;32m 529\u001b[0m distance \u001b[38;5;241m=\u001b[39m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_mode \u001b[38;5;241m-\u001b[39m value) \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39m \u001b[38;5;241m2\u001b[39m\n\u001b[1;32m 530\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_agg \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmean\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n",
|
|
"\u001b[0;31mAssertionError\u001b[0m: (torch.Size([256, 32, 8, 8, 166]), torch.Size([256, 32, 12, 12, 166]))"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# from tools.simulate\n",
|
|
"\n",
|
|
"# step\n",
|
|
"# step, episode, done, length, obs, agent_state, reward = state\n",
|
|
"obs2 = {k: np.stack([o[k] for o in obs]) for k in obs[0] if \"log_\" not in k}\n",
|
|
"action, agent_state = agent(obs2, done, agent_state)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"==========================================================================================\n",
|
|
"Layer (type:depth-idx) Param #\n",
|
|
"==========================================================================================\n",
|
|
"Dreamer --\n",
|
|
"├─OptimizedModule: 1-1 --\n",
|
|
"│ └─WorldModel: 2-1 --\n",
|
|
"│ │ └─MultiEncoder: 3-1 (44,480)\n",
|
|
"│ │ └─RSSM: 3-2 (2,397,952)\n",
|
|
"│ │ └─ModuleDict: 3-3 (1,580,204)\n",
|
|
"├─OptimizedModule: 1-2 --\n",
|
|
"│ └─ImagBehavior: 2-2 4,022,636\n",
|
|
"│ │ └─WorldModel: 3-4 (recursive)\n",
|
|
"│ │ └─MLP: 3-5 (536,875)\n",
|
|
"│ │ └─MLP: 3-6 (525,311)\n",
|
|
"│ │ └─MLP: 3-7 (525,311)\n",
|
|
"├─OptimizedModule: 1-3 (recursive)\n",
|
|
"│ └─ImagBehavior: 2-3 (recursive)\n",
|
|
"│ │ └─WorldModel: 3-8 (recursive)\n",
|
|
"│ │ └─MLP: 3-9 (recursive)\n",
|
|
"│ │ └─MLP: 3-10 (recursive)\n",
|
|
"│ │ └─MLP: 3-11 (recursive)\n",
|
|
"==========================================================================================\n",
|
|
"Total params: 9,632,769\n",
|
|
"Trainable params: 0\n",
|
|
"Non-trainable params: 9,632,769\n",
|
|
"=========================================================================================="
|
|
]
|
|
},
|
|
"execution_count": 19,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"from torchinfo import summary\n",
|
|
"\n",
|
|
"summary(agent, input=(obs, done, agent_state), depth=3)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"==========================================================================================\n",
|
|
"Layer (type:depth-idx) Param #\n",
|
|
"==========================================================================================\n",
|
|
"Dreamer --\n",
|
|
"├─OptimizedModule: 1-1 --\n",
|
|
"│ └─WorldModel: 2-1 --\n",
|
|
"│ │ └─MultiEncoder: 3-1 --\n",
|
|
"│ │ │ └─ConvEncoder: 4-1 (42,528)\n",
|
|
"│ │ │ └─MLP: 4-2 (1,952)\n",
|
|
"│ │ └─RSSM: 3-2 512\n",
|
|
"│ │ │ └─Sequential: 4-3 (273,664)\n",
|
|
"│ │ │ └─GRUCell: 4-4 (1,182,720)\n",
|
|
"│ │ │ └─Sequential: 4-5 (131,584)\n",
|
|
"│ │ │ └─Sequential: 4-6 (283,136)\n",
|
|
"│ │ │ └─Linear: 4-7 (263,168)\n",
|
|
"│ │ │ └─Linear: 4-8 (263,168)\n",
|
|
"│ │ └─ModuleDict: 3-3 --\n",
|
|
"│ │ │ └─MultiDecoder: 4-9 (462,764)\n",
|
|
"│ │ │ └─MLP: 4-10 (591,359)\n",
|
|
"│ │ │ └─MLP: 4-11 (526,081)\n",
|
|
"├─OptimizedModule: 1-2 --\n",
|
|
"│ └─ImagBehavior: 2-2 4,022,636\n",
|
|
"│ │ └─WorldModel: 3-4 (recursive)\n",
|
|
"│ │ │ └─MultiEncoder: 4-12 (recursive)\n",
|
|
"│ │ │ └─RSSM: 4-13 (recursive)\n",
|
|
"│ │ │ └─ModuleDict: 4-14 (recursive)\n",
|
|
"│ │ └─MLP: 3-5 --\n",
|
|
"│ │ │ └─Sequential: 4-15 (525,824)\n",
|
|
"│ │ │ └─Linear: 4-16 (11,051)\n",
|
|
"│ │ └─MLP: 3-6 --\n",
|
|
"│ │ │ └─Sequential: 4-17 (459,776)\n",
|
|
"│ │ │ └─Linear: 4-18 (65,535)\n",
|
|
"│ │ └─MLP: 3-7 --\n",
|
|
"│ │ │ └─Sequential: 4-19 (459,776)\n",
|
|
"│ │ │ └─Linear: 4-20 (65,535)\n",
|
|
"├─OptimizedModule: 1-3 (recursive)\n",
|
|
"│ └─ImagBehavior: 2-3 (recursive)\n",
|
|
"│ │ └─WorldModel: 3-8 (recursive)\n",
|
|
"│ │ │ └─MultiEncoder: 4-21 (recursive)\n",
|
|
"│ │ │ └─RSSM: 4-22 (recursive)\n",
|
|
"│ │ │ └─ModuleDict: 4-23 (recursive)\n",
|
|
"│ │ └─MLP: 3-9 (recursive)\n",
|
|
"│ │ │ └─Sequential: 4-24 (recursive)\n",
|
|
"│ │ │ └─Linear: 4-25 (recursive)\n",
|
|
"│ │ └─MLP: 3-10 (recursive)\n",
|
|
"│ │ │ └─Sequential: 4-26 (recursive)\n",
|
|
"│ │ │ └─Linear: 4-27 (recursive)\n",
|
|
"│ │ └─MLP: 3-11 (recursive)\n",
|
|
"│ │ │ └─Sequential: 4-28 (recursive)\n",
|
|
"│ │ │ └─Linear: 4-29 (recursive)\n",
|
|
"==========================================================================================\n",
|
|
"Total params: 9,632,769\n",
|
|
"Trainable params: 0\n",
|
|
"Non-trainable params: 9,632,769\n",
|
|
"=========================================================================================="
|
|
]
|
|
},
|
|
"execution_count": 28,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"summary(agent, input=(obs, done, agent_state), depth=4)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# agent._wm.heads"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Fine grained torchinfo"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"wm = agent._wm\n",
|
|
"data = next(agent._dataset) \n",
|
|
"# self._train()\n",
|
|
"# post, context, mets = wm._train(data)\n",
|
|
"data = wm.preprocess(data)\n",
|
|
"embed = wm.encoder(data)\n",
|
|
"post, prior = wm.dynamics.observe(\n",
|
|
" embed, data[\"action\"], data[\"is_first\"]\n",
|
|
")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"===================================================================================================================\n",
|
|
"Layer (type:depth-idx) Input Shape Output Shape Param #\n",
|
|
"===================================================================================================================\n",
|
|
"MultiEncoder [256, 16, 130, 110, 3] [256, 16, 592] --\n",
|
|
"├─ConvEncoder: 1-1 [256, 16, 12, 12, 166] [256, 16, 576] --\n",
|
|
"│ └─Sequential: 2-1 [4096, 166, 12, 12] [4096, 16, 6, 6] --\n",
|
|
"│ │ └─Conv2dSamePad: 3-1 [4096, 166, 12, 12] [4096, 16, 6, 6] (42,496)\n",
|
|
"│ │ └─ImgChLayerNorm: 3-2 [4096, 16, 6, 6] [4096, 16, 6, 6] --\n",
|
|
"│ │ │ └─LayerNorm: 4-1 [4096, 6, 6, 16] [4096, 6, 6, 16] (32)\n",
|
|
"│ │ └─SiLU: 3-3 [4096, 16, 6, 6] [4096, 16, 6, 6] --\n",
|
|
"├─MLP: 1-2 [256, 16, 102] [256, 16, 16] --\n",
|
|
"│ └─Sequential: 2-2 [256, 16, 102] [256, 16, 16] --\n",
|
|
"│ │ └─Linear: 3-4 [256, 16, 102] [256, 16, 16] (1,632)\n",
|
|
"│ │ └─LayerNorm: 3-5 [256, 16, 16] [256, 16, 16] (32)\n",
|
|
"│ │ └─SiLU: 3-6 [256, 16, 16] [256, 16, 16] --\n",
|
|
"│ │ └─Linear: 3-7 [256, 16, 16] [256, 16, 16] (256)\n",
|
|
"│ │ └─LayerNorm: 3-8 [256, 16, 16] [256, 16, 16] (32)\n",
|
|
"│ │ └─SiLU: 3-9 [256, 16, 16] [256, 16, 16] --\n",
|
|
"===================================================================================================================\n",
|
|
"Total params: 44,480\n",
|
|
"Trainable params: 0\n",
|
|
"Non-trainable params: 44,480\n",
|
|
"Total mult-adds (G): 6.27\n",
|
|
"===================================================================================================================\n",
|
|
"Input size (MB): 1367.93\n",
|
|
"Forward/backward pass size (MB): 39.85\n",
|
|
"Params size (MB): 0.18\n",
|
|
"Estimated Total Size (MB): 1407.96\n",
|
|
"==================================================================================================================="
|
|
]
|
|
},
|
|
"execution_count": 21,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"summary(wm.encoder, input_data=(data,), depth=4, col_names=[\"input_size\", \"output_size\", \"num_params\", ])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"decoder\n",
|
|
"===================================================================================================================\n",
|
|
"Layer (type:depth-idx) Input Shape Output Shape Param #\n",
|
|
"===================================================================================================================\n",
|
|
"MultiDecoder [256, 16, 1536] -- --\n",
|
|
"├─ConvDecoder: 1-1 [256, 16, 1536] [256, 16, 8, 8, 166] --\n",
|
|
"│ └─Linear: 2-1 [256, 16, 1536] [256, 16, 256] (393,472)\n",
|
|
"│ └─Sequential: 2-2 [4096, 16, 4, 4] [4096, 166, 8, 8] --\n",
|
|
"│ │ └─ConvTranspose2d: 3-1 [4096, 16, 4, 4] [4096, 166, 8, 8] (42,662)\n",
|
|
"├─MLP: 1-2 [256, 16, 1536] -- --\n",
|
|
"│ └─Sequential: 2-3 [256, 16, 1536] [256, 16, 16] --\n",
|
|
"│ │ └─Linear: 3-2 [256, 16, 1536] [256, 16, 16] (24,576)\n",
|
|
"│ │ └─LayerNorm: 3-3 [256, 16, 16] [256, 16, 16] (32)\n",
|
|
"│ │ └─SiLU: 3-4 [256, 16, 16] [256, 16, 16] --\n",
|
|
"│ │ └─Linear: 3-5 [256, 16, 16] [256, 16, 16] (256)\n",
|
|
"│ │ └─LayerNorm: 3-6 [256, 16, 16] [256, 16, 16] (32)\n",
|
|
"│ │ └─SiLU: 3-7 [256, 16, 16] [256, 16, 16] --\n",
|
|
"│ └─ModuleDict: 2-4 -- -- --\n",
|
|
"│ │ └─Linear: 3-8 [256, 16, 16] [256, 16, 102] (1,734)\n",
|
|
"===================================================================================================================\n",
|
|
"Total params: 462,764\n",
|
|
"Trainable params: 0\n",
|
|
"Non-trainable params: 462,764\n",
|
|
"Total mult-adds (G): 11.29\n",
|
|
"===================================================================================================================\n",
|
|
"Input size (MB): 25.17\n",
|
|
"Forward/backward pass size (MB): 361.96\n",
|
|
"Params size (MB): 1.85\n",
|
|
"Estimated Total Size (MB): 388.97\n",
|
|
"===================================================================================================================\n",
|
|
"Summary Failed for reward Failed to run torchinfo. See above stack traces for more details. Executed layers up to: [Sequential: 1, Linear: 2, LayerNorm: 2, SiLU: 2, Linear: 2, LayerNorm: 2, SiLU: 2, Linear: 2, LayerNorm: 2, SiLU: 2, Linear: 1]\n",
|
|
"Summary Failed for cont Failed to run torchinfo. See above stack traces for more details. Executed layers up to: [Sequential: 1, Linear: 2, LayerNorm: 2, SiLU: 2, Linear: 2, LayerNorm: 2, SiLU: 2, Linear: 2, LayerNorm: 2, SiLU: 2, Linear: 1]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# heads\n",
|
|
"feat = wm.dynamics.get_feat(post)\n",
|
|
"for name, head in wm.heads.items():\n",
|
|
" try:\n",
|
|
" o = summary(head, input_data=(feat,), depth=3, col_names=[\"input_size\", \"output_size\", \"num_params\", ])\n",
|
|
" print(name)\n",
|
|
" print(o)\n",
|
|
" except Exception as e:\n",
|
|
" print(f\"Summary Failed for {name} {e}\")\n",
|
|
" continue"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Fails because wm.dynamics has no call (forward) method\n",
|
|
"# summary(wm.dynamics, input_data=(embed, data[\"action\"], data[\"is_first\"]), depth=3, col_names=[\"output_size\", \"num_params\", ])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"===================================================================================================================\n",
|
|
"Layer (type:depth-idx) Output Shape Param # Output Shape\n",
|
|
"===================================================================================================================\n",
|
|
"Sequential [256, 16, 256] -- [256, 16, 256]\n",
|
|
"├─Linear: 1-1 [256, 16, 256] (393,216) [256, 16, 256]\n",
|
|
"├─LayerNorm: 1-2 [256, 16, 256] (512) [256, 16, 256]\n",
|
|
"├─SiLU: 1-3 [256, 16, 256] -- [256, 16, 256]\n",
|
|
"├─Linear: 1-4 [256, 16, 256] (65,536) [256, 16, 256]\n",
|
|
"├─LayerNorm: 1-5 [256, 16, 256] (512) [256, 16, 256]\n",
|
|
"├─SiLU: 1-6 [256, 16, 256] -- [256, 16, 256]\n",
|
|
"├─Linear: 1-7 [256, 16, 256] (65,536) [256, 16, 256]\n",
|
|
"├─LayerNorm: 1-8 [256, 16, 256] (512) [256, 16, 256]\n",
|
|
"├─SiLU: 1-9 [256, 16, 256] -- [256, 16, 256]\n",
|
|
"===================================================================================================================\n",
|
|
"Total params: 525,824\n",
|
|
"Trainable params: 0\n",
|
|
"Non-trainable params: 525,824\n",
|
|
"Total mult-adds (M): 134.61\n",
|
|
"===================================================================================================================\n",
|
|
"Input size (MB): 25.17\n",
|
|
"Forward/backward pass size (MB): 50.33\n",
|
|
"Params size (MB): 2.10\n",
|
|
"Estimated Total Size (MB): 77.60\n",
|
|
"==================================================================================================================="
|
|
]
|
|
},
|
|
"execution_count": 24,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"actor = agent._task_behavior.actor\n",
|
|
"\n",
|
|
"summary(actor.layers, input_data=(feat,), depth=3, col_names=[\"output_size\", \"num_params\", \"output_size\" ])\n",
|
|
"\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"===================================================================================================================\n",
|
|
"Layer (type:depth-idx) Output Shape Param # Output Shape\n",
|
|
"===================================================================================================================\n",
|
|
"Sequential [256, 16, 256] -- [256, 16, 256]\n",
|
|
"├─Linear: 1-1 [256, 16, 256] (393,216) [256, 16, 256]\n",
|
|
"├─LayerNorm: 1-2 [256, 16, 256] (512) [256, 16, 256]\n",
|
|
"├─SiLU: 1-3 [256, 16, 256] -- [256, 16, 256]\n",
|
|
"├─Linear: 1-4 [256, 16, 256] (65,536) [256, 16, 256]\n",
|
|
"├─LayerNorm: 1-5 [256, 16, 256] (512) [256, 16, 256]\n",
|
|
"├─SiLU: 1-6 [256, 16, 256] -- [256, 16, 256]\n",
|
|
"├─Linear: 1-7 [256, 16, 256] (65,536) [256, 16, 256]\n",
|
|
"├─LayerNorm: 1-8 [256, 16, 256] (512) [256, 16, 256]\n",
|
|
"├─SiLU: 1-9 [256, 16, 256] -- [256, 16, 256]\n",
|
|
"===================================================================================================================\n",
|
|
"Total params: 525,824\n",
|
|
"Trainable params: 0\n",
|
|
"Non-trainable params: 525,824\n",
|
|
"Total mult-adds (M): 134.61\n",
|
|
"===================================================================================================================\n",
|
|
"Input size (MB): 25.17\n",
|
|
"Forward/backward pass size (MB): 50.33\n",
|
|
"Params size (MB): 2.10\n",
|
|
"Estimated Total Size (MB): 77.60\n",
|
|
"==================================================================================================================="
|
|
]
|
|
},
|
|
"execution_count": 25,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"value = agent._task_behavior.actor\n",
|
|
"summary(value.layers, input_data=(feat,), depth=3, col_names=[\"output_size\", \"num_params\", \"output_size\" ])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"8268"
|
|
]
|
|
},
|
|
"execution_count": 26,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"8268"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": ".venv",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.9.16"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2
|
|
}
|