Files
dreamerv3-torch/nbs/02_torchinfo.ipynb
T

554 lines
26 KiB
Plaintext

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this notebook we load a saved Dreamer agent and run it, in order to inspect its parameter counts and speed and to make the code easier to hack on."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loading textures from cache\n"
]
}
],
"source": [
"# TODO make this a proper package\n",
"import os, sys\n",
"sys.path.append('..')\n",
"\n",
"\n",
"from dreamer import parse_args, main, make_env, make_dataset, count_steps,Dreamer"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['../dreamer.py', '--configs', 'craftax_small', '--logdir', '../logdir/craftax_small']\n"
]
},
{
"data": {
"text/plain": [
"Namespace(act='SiLU', action_repeat=1, actor={'layers': 3, 'dist': 'onehot', 'entropy': 0.0003, 'unimix_ratio': 0.01, 'std': 'none', 'min_std': 0.1, 'max_std': 1.0, 'temp': 0.1, 'lr': 3e-05, 'eps': 1e-05, 'grad_clip': 100.0, 'outscale': 1.0}, batch_length=16, batch_size=128, compile=True, cont_head={'layers': 3, 'loss_scale': 1.0, 'outscale': 1.0}, critic={'layers': 2, 'dist': 'symlog_disc', 'slow_target': True, 'slow_target_update': 1, 'slow_target_fraction': 0.02, 'lr': 3e-05, 'eps': 1e-05, 'grad_clip': 100.0, 'outscale': 0.0}, dataset_size=1000000, debug=False, decoder={'mlp_keys': 'state', 'cnn_keys': '$^', 'act': 'SiLU', 'norm': True, 'cnn_depth': 32, 'kernel_size': 4, 'minres': 4, 'mlp_layers': 3, 'mlp_units': 256, 'cnn_sigmoid': False, 'image_dist': 'mse', 'vector_dist': 'symlog_mse', 'outscale': 1.0}, deterministic_run=False, device='cuda:0', disag_action_cond=False, disag_layers=4, disag_log=True, disag_models=10, disag_offset=1, disag_target='stoch', disag_units=400, discount=0.997, discount_lambda=0.95, dyn_deter=512, dyn_discrete=32, dyn_hidden=256, dyn_mean_act='none', dyn_min_std=0.1, dyn_rec_depth=1, dyn_scale=0.5, dyn_std_act='sigmoid2', dyn_stoch=32, encoder={'mlp_keys': 'state', 'cnn_keys': '$^', 'act': 'SiLU', 'norm': True, 'cnn_depth': 32, 'kernel_size': 4, 'minres': 4, 'mlp_layers': 3, 'mlp_units': 256, 'symlog_inputs': True}, envs=1, eval_episode_num=10, eval_every=10000.0, eval_state_mean=False, evaldir=None, expl_behavior='greedy', expl_extr_scale=0.0, expl_intr_scale=1.0, expl_until=0, grad_clip=1000, grad_heads=('decoder', 'reward', 'cont'), grayscale=False, imag_gradient='reinforce', imag_gradient_mix=0.0, imag_horizon=15, initial='learned', kl_free=1.0, log_every=10000.0, logdir='../logdir/craftax_small', model_lr=0.0001, norm=True, offline_evaldir='', offline_traindir='', opt='adam', opt_eps=1e-08, parallel=False, precision=32, prefill=2500, pretrain=100, rep_scale=0.1, reset_every=0, reward_EMA=True, reward_head={'layers': 3, 'dist': 
'symlog_disc', 'loss_scale': 1.0, 'outscale': 0.0}, seed=0, size=(64, 64), step=1000000.0, steps=1000000.0, task='craftax_Craftax-Symbolic-AutoReset-v1', time_limit=1000, train_ratio=512, traindir=None, unimix_ratio=0.01, units=512, value={'layers': 3}, video_pred_log=False, weight_decay=0.0)"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Emulate the command line: build the argv list that dreamer.py would receive.\n",
"argv = \"../dreamer.py --configs craftax_small --logdir ../logdir/craftax_small\".split()\n",
"print(argv)\n",
"config = parse_args(argv)\n",
"config"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[32m2024-06-06 13:35:50.147\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m27\u001b[0m - \u001b[1mLogdir ../logdir/craftax_small\u001b[0m\n",
"\u001b[32m2024-06-06 13:35:50.153\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m36\u001b[0m - \u001b[1mCreate envs.\u001b[0m\n",
"\u001b[32m2024-06-06 13:36:42.176\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m57\u001b[0m - \u001b[1mAction Space Box(0.0, 1.0, (43,), float32)\u001b[0m\n",
"\u001b[32m2024-06-06 13:36:42.178\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m63\u001b[0m - \u001b[1mPrefill dataset (0 steps).\u001b[0m\n",
"\u001b[32m2024-06-06 13:36:42.180\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m92\u001b[0m - \u001b[1mLogger: (128521 steps).\u001b[0m\n",
"\u001b[32m2024-06-06 13:36:42.180\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m94\u001b[0m - \u001b[1mSimulate agent.\u001b[0m\n"
]
}
],
"source": [
"from loguru import logger\n",
"from tqdm.auto import tqdm\n",
"import pathlib\n",
"\n",
"import torch\n",
"from torch import nn\n",
"from torch import distributions as torchd\n",
"\n",
"import exploration as expl\n",
"import models\n",
"import tools\n",
"import envs.wrappers as wrappers\n",
"from parallel import Parallel, Damy\n",
"\n",
"# from main\n",
"tools.set_seed_everywhere(config.seed)\n",
"if config.deterministic_run:\n",
"    tools.enable_deterministic_run()\n",
"logdir = pathlib.Path(config.logdir).expanduser()\n",
"config.traindir = config.traindir or logdir / \"train_eps\"\n",
"config.evaldir = config.evaldir or logdir / \"eval_eps\"\n",
"# step-based settings are divided by the action repeat\n",
"config.steps //= config.action_repeat\n",
"config.eval_every //= config.action_repeat\n",
"config.log_every //= config.action_repeat\n",
"config.time_limit //= config.action_repeat\n",
"\n",
"logger.info(f\"Logdir {logdir}\")\n",
"logdir.mkdir(parents=True, exist_ok=True)\n",
"config.traindir.mkdir(parents=True, exist_ok=True)\n",
"config.evaldir.mkdir(parents=True, exist_ok=True)\n",
"step = count_steps(config.traindir)\n",
"# step in logger is environmental step\n",
"tlogger = tools.Logger(logdir, config.action_repeat * step)\n",
"logger.add(logdir / \"logger.log\")\n",
"\n",
"logger.info(\"Create envs.\")\n",
"# Episodes come from the offline directory when one is configured, otherwise from the run's own directory.\n",
"if config.offline_traindir:\n",
"    directory = config.offline_traindir.format(**vars(config))\n",
"else:\n",
"    directory = config.traindir\n",
"train_eps = tools.load_episodes(directory, limit=config.dataset_size)\n",
"if config.offline_evaldir:\n",
"    directory = config.offline_evaldir.format(**vars(config))\n",
"else:\n",
"    directory = config.evaldir\n",
"eval_eps = tools.load_episodes(directory, limit=1)\n",
"\n",
"\n",
"def make(mode, id):\n",
"    \"\"\"Build one environment via make_env for the given mode string and index.\"\"\"\n",
"    return make_env(config, mode, id)\n",
"\n",
"\n",
"train_envs = [make(\"train\", i) for i in range(config.envs)]\n",
"eval_envs = [make(\"eval\", i) for i in range(config.envs)]\n",
"if config.parallel:\n",
"    train_envs = [Parallel(env, \"process\") for env in train_envs]\n",
"    eval_envs = [Parallel(env, \"process\") for env in eval_envs]\n",
"else:\n",
"    train_envs = [Damy(env) for env in train_envs]\n",
"    eval_envs = [Damy(env) for env in eval_envs]\n",
"acts = train_envs[0].action_space\n",
"logger.info(f\"Action Space {acts}\")\n",
"config.num_actions = acts.n if hasattr(acts, \"n\") else acts.shape[0]\n",
"\n",
"# `state` stays None when an offline training directory is used (no prefill rollout happens).\n",
"state = None\n",
"if not config.offline_traindir:\n",
"    prefill = max(0, config.prefill - count_steps(config.traindir))\n",
"    logger.info(f\"Prefill dataset ({prefill} steps).\")\n",
"    if hasattr(acts, \"discrete\"):\n",
"        random_actor = tools.OneHotDist(\n",
"            torch.zeros(config.num_actions).repeat(config.envs, 1)\n",
"        )\n",
"    else:\n",
"        random_actor = torchd.independent.Independent(\n",
"            torchd.uniform.Uniform(\n",
"                torch.Tensor(acts.low).repeat(config.envs, 1),\n",
"                torch.Tensor(acts.high).repeat(config.envs, 1),\n",
"            ),\n",
"            1,\n",
"        )\n",
"\n",
"    def random_agent(o, d, s):\n",
"        \"\"\"Sample a random action from random_actor; the three inputs are ignored and the returned state is None.\"\"\"\n",
"        action = random_actor.sample()\n",
"        logprob = random_actor.log_prob(action)\n",
"        return {\"action\": action, \"logprob\": logprob}, None\n",
"\n",
"    state = tools.simulate(\n",
"        random_agent,\n",
"        train_envs,\n",
"        train_eps,\n",
"        config.traindir,\n",
"        tlogger,\n",
"        limit=config.dataset_size,\n",
"        steps=prefill,\n",
"    )\n",
"    tlogger.step += prefill * config.action_repeat\n",
"    logger.info(f\"Logger: ({tlogger.step} steps).\")\n",
"\n",
"logger.info(\"Simulate agent.\")\n",
"train_dataset = make_dataset(train_eps, config)\n",
"eval_dataset = make_dataset(eval_eps, config)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[32m2024-06-06 13:38:20.651\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mnetworks\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m323\u001b[0m - \u001b[1mEncoder CNN shapes: {}\u001b[0m\n",
"\u001b[32m2024-06-06 13:38:20.651\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mnetworks\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m324\u001b[0m - \u001b[1mEncoder MLP shapes: {'state': (16536,)}\u001b[0m\n",
"\u001b[32m2024-06-06 13:38:20.751\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mnetworks\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m390\u001b[0m - \u001b[1mDecoder CNN shapes: {}\u001b[0m\n",
"\u001b[32m2024-06-06 13:38:20.751\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mnetworks\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m391\u001b[0m - \u001b[1mDecoder MLP shapes: {'state': (16536,)}\u001b[0m\n",
"\u001b[32m2024-06-06 13:38:20.813\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmodels\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m102\u001b[0m - \u001b[1mOptimizer model_opt has 15732120 variables.\u001b[0m\n",
"\u001b[32m2024-06-06 13:38:20.836\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmodels\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m281\u001b[0m - \u001b[1mOptimizer actor_opt has 1335851 variables.\u001b[0m\n",
"\u001b[32m2024-06-06 13:38:20.837\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmodels\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m292\u001b[0m - \u001b[1mOptimizer value_opt has 1181439 variables.\u001b[0m\n",
"\u001b[32m2024-06-06 13:38:21.032\u001b[0m | \u001b[33m\u001b[1mWARNING \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m17\u001b[0m - \u001b[33m\u001b[1mLoaded model from ../logdir/craftax_small/latest.pt\u001b[0m\n"
]
}
],
"source": [
"# Build the agent from a freshly parsed config; num_actions is not a CLI option, so set it again.\n",
"config = parse_args(argv)\n",
"config.num_actions = acts.n if hasattr(acts, \"n\") else acts.shape[0]\n",
"agent = Dreamer(\n",
"    train_envs[0].observation_space,\n",
"    train_envs[0].action_space,\n",
"    config,\n",
"    tlogger,\n",
"    train_dataset,\n",
").to(config.device)\n",
"# print(agent)\n",
"agent.requires_grad_(requires_grad=False)\n",
"checkpoint_path = logdir / \"latest.pt\"\n",
"if checkpoint_path.exists():\n",
"    # map_location keeps the load working when the checkpoint was saved on another device.\n",
"    # NOTE(review): torch.load unpickles the file, so only open checkpoints you trust.\n",
"    checkpoint = torch.load(checkpoint_path, map_location=config.device)\n",
"    agent.load_state_dict(checkpoint[\"agent_state_dict\"])\n",
"    tools.recursively_load_optim_state_dict(agent, checkpoint[\"optims_state_dict\"])\n",
"    agent._should_pretrain._once = False\n",
"    logger.warning(f\"Loaded model from {logdir / 'latest.pt'}\")\n",
"else:\n",
"    # Without a checkpoint the cells below would describe an untrained agent; say so loudly.\n",
"    logger.warning(f\"No checkpoint found at {checkpoint_path}; the agent keeps its initial weights\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Now let's play"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": []
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(0, 0, array([ True]), array([0], dtype=int32), [None], None, [0])"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import numpy as np\n",
"\n",
"# The prefill step above must have produced a rollout state before we continue.\n",
"assert state is not None\n",
"state"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"from tools import convert, add_to_cache\n",
"\n",
"envs = train_envs\n",
"cache = train_eps\n",
"\n",
"# Rollout bookkeeping initialised by hand (same layout as the `state` tuple displayed above).\n",
"num_envs = len(envs)\n",
"step = 0\n",
"episode = 0\n",
"done = np.ones(num_envs, bool)\n",
"length = np.zeros(num_envs, np.int32)\n",
"obs = [None] * num_envs\n",
"agent_state = None\n",
"reward = [0] * num_envs\n",
"\n",
"# Reset every env flagged as done; each reset() result is called to obtain the observation.\n",
"indices = [i for i, is_done in enumerate(done) if is_done]\n",
"pending = [envs[i].reset() for i in indices]\n",
"results = [fetch() for fetch in pending]\n",
"for index, result in zip(indices, results):\n",
"    transition = {key: convert(value) for key, value in result.copy().items()}\n",
"    # the action is filled in later by add_to_cache\n",
"    transition[\"reward\"] = 0.0\n",
"    transition[\"discount\"] = 1.0\n",
"    # the initial transition is stored in the cache too\n",
"    add_to_cache(cache, envs[index].id, transition)\n",
"    # the slot of the finished env now holds its initial observation\n",
"    obs[index] = result\n",
"# next: step the agent"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[32m2024-06-06 13:38:34.000\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mtools\u001b[0m:\u001b[36mwrite\u001b[0m:\u001b[36m85\u001b[0m - \u001b[1m[128521] model_loss \u001b[31m22.2\u001b[0m\u001b[1m / model_grad_norm \u001b[31m14.4\u001b[0m\u001b[1m / state_loss \u001b[31m17.4\u001b[0m\u001b[1m / reward_loss \u001b[31m0.1\u001b[0m\u001b[1m / cont_loss \u001b[31m0.0\u001b[0m\u001b[1m / kl_free \u001b[31m1.0\u001b[0m\u001b[1m / dyn_scale \u001b[31m0.5\u001b[0m\u001b[1m / rep_scale \u001b[31m0.1\u001b[0m\u001b[1m / dyn_loss \u001b[31m7.8\u001b[0m\u001b[1m / rep_loss \u001b[31m7.8\u001b[0m\u001b[1m / kl \u001b[31m7.7\u001b[0m\u001b[1m / prior_ent \u001b[31m48.4\u001b[0m\u001b[1m / post_ent \u001b[31m40.7\u001b[0m\u001b[1m / normed_target_mean \u001b[31m0.4\u001b[0m\u001b[1m / normed_target_std \u001b[31m0.3\u001b[0m\u001b[1m / normed_target_min \u001b[31m-0.3\u001b[0m\u001b[1m / normed_target_max \u001b[31m1.8\u001b[0m\u001b[1m / EMA_005 \u001b[31m12.3\u001b[0m\u001b[1m / EMA_095 \u001b[31m26.4\u001b[0m\u001b[1m / value_mean \u001b[31m18.2\u001b[0m\u001b[1m / value_std \u001b[31m4.3\u001b[0m\u001b[1m / value_min \u001b[31m10.1\u001b[0m\u001b[1m / value_max \u001b[31m31.1\u001b[0m\u001b[1m / target_mean \u001b[31m18.4\u001b[0m\u001b[1m / target_std \u001b[31m4.7\u001b[0m\u001b[1m / target_min \u001b[31m8.4\u001b[0m\u001b[1m / target_max \u001b[31m37.8\u001b[0m\u001b[1m / imag_reward_mean \u001b[31m0.0\u001b[0m\u001b[1m / imag_reward_std \u001b[31m0.1\u001b[0m\u001b[1m / imag_reward_min \u001b[31m-0.2\u001b[0m\u001b[1m / imag_reward_max \u001b[31m1.0\u001b[0m\u001b[1m / imag_action_mean \u001b[31m10.0\u001b[0m\u001b[1m / imag_action_std \u001b[31m12.9\u001b[0m\u001b[1m / imag_action_min \u001b[31m0.0\u001b[0m\u001b[1m / imag_action_max \u001b[31m42.0\u001b[0m\u001b[1m / actor_entropy \u001b[31m0.9\u001b[0m\u001b[1m / actor_loss \u001b[31m0.1\u001b[0m\u001b[1m / actor_grad_norm \u001b[31m0.5\u001b[0m\u001b[1m / value_loss \u001b[31m1.3\u001b[0m\u001b[1m / 
value_grad_norm \u001b[31m0.9\u001b[0m\u001b[1m / update_count \u001b[31m1.0\u001b[0m\u001b[1m / fps \u001b[31m0.0\u001b[0m\u001b[1m\u001b[0m\n"
]
}
],
"source": [
"# from tools.simulate\n",
"\n",
"# step\n",
"# step, episode, done, length, obs, agent_state, reward = state\n",
"if isinstance(obs, list):\n",
"    # First run: stack the per-env observation dicts into one batched dict, dropping \"log_\" keys.\n",
"    # The guard makes the cell re-runnable once `obs` has already been batched.\n",
"    obs = {k: np.stack([o[k] for o in obs]) for k in obs[0] if \"log_\" not in k}\n",
"action, agent_state = agent(obs, done, agent_state)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"=====================================================================================\n",
"Layer (type:depth-idx) Param #\n",
"=====================================================================================\n",
"Dreamer --\n",
"├─OptimizedModule: 1-1 --\n",
"│ └─WorldModel: 2-1 --\n",
"│ │ └─MultiEncoder: 3-1 (4,365,824)\n",
"│ │ └─RSSM: 3-2 (3,831,808)\n",
"│ │ └─ModuleDict: 3-3 (7,534,488)\n",
"├─OptimizedModule: 1-2 --\n",
"│ └─ImagBehavior: 2-2 15,732,120\n",
"│ │ └─WorldModel: 3-4 (recursive)\n",
"│ │ └─MLP: 3-5 (1,335,851)\n",
"│ │ └─MLP: 3-6 (1,181,439)\n",
"│ │ └─MLP: 3-7 (1,181,439)\n",
"├─OptimizedModule: 1-3 (recursive)\n",
"│ └─ImagBehavior: 2-3 (recursive)\n",
"│ │ └─WorldModel: 3-8 (recursive)\n",
"│ │ └─MLP: 3-9 (recursive)\n",
"│ │ └─MLP: 3-10 (recursive)\n",
"│ │ └─MLP: 3-11 (recursive)\n",
"=====================================================================================\n",
"Total params: 35,162,969\n",
"Trainable params: 0\n",
"Non-trainable params: 35,162,969\n",
"====================================================================================="
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from torchinfo import summary\n",
"\n",
"# Parameter counts only. torchinfo traces a forward pass only when `input_data` or `input_size`\n",
"# is given; any other keyword (such as `input=`) does not trigger one, so none is passed here.\n",
"summary(agent, depth=3)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Fine grained torchinfo"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"wm = agent._wm\n",
"# Reference call this cell unrolls by hand (presumably; check models.py):\n",
"#   post, context, mets = wm._train(data)\n",
"batch = next(agent._dataset)\n",
"data = wm.preprocess(batch)\n",
"embed = wm.encoder(data)\n",
"post, prior = wm.dynamics.observe(embed, data[\"action\"], data[\"is_first\"])"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"==========================================================================================\n",
"Layer (type:depth-idx) Output Shape Param #\n",
"==========================================================================================\n",
"MultiEncoder [128, 16, 256] --\n",
"├─MLP: 1-1 [128, 16, 256] --\n",
"│ └─Sequential: 2-1 [128, 16, 256] --\n",
"│ │ └─Linear: 3-1 [128, 16, 256] (4,233,216)\n",
"│ │ └─LayerNorm: 3-2 [128, 16, 256] (512)\n",
"│ │ └─SiLU: 3-3 [128, 16, 256] --\n",
"│ │ └─Linear: 3-4 [128, 16, 256] (65,536)\n",
"│ │ └─LayerNorm: 3-5 [128, 16, 256] (512)\n",
"│ │ └─SiLU: 3-6 [128, 16, 256] --\n",
"│ │ └─Linear: 3-7 [128, 16, 256] (65,536)\n",
"│ │ └─LayerNorm: 3-8 [128, 16, 256] (512)\n",
"│ │ └─SiLU: 3-9 [128, 16, 256] --\n",
"==========================================================================================\n",
"Total params: 4,365,824\n",
"Trainable params: 0\n",
"Non-trainable params: 4,365,824\n",
"Total mult-adds (M): 558.83\n",
"==========================================================================================\n",
"Input size (MB): 487.31\n",
"Forward/backward pass size (MB): 25.17\n",
"Params size (MB): 17.46\n",
"Estimated Total Size (MB): 529.94\n",
"=========================================================================================="
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Per-layer output shapes and parameter counts for the encoder on one preprocessed batch.\n",
"summary(\n",
"    wm.encoder,\n",
"    input_data=(data,),\n",
"    depth=3,\n",
"    col_names=[\"output_size\", \"num_params\"],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"decoder\n",
"==========================================================================================\n",
"Layer (type:depth-idx) Output Shape Param #\n",
"==========================================================================================\n",
"MultiDecoder -- --\n",
"├─MLP: 1-1 -- --\n",
"│ └─Sequential: 2-1 [128, 16, 256] --\n",
"│ │ └─Linear: 3-1 [128, 16, 256] (393,216)\n",
"│ │ └─LayerNorm: 3-2 [128, 16, 256] (512)\n",
"│ │ └─SiLU: 3-3 [128, 16, 256] --\n",
"│ │ └─Linear: 3-4 [128, 16, 256] (65,536)\n",
"│ │ └─LayerNorm: 3-5 [128, 16, 256] (512)\n",
"│ │ └─SiLU: 3-6 [128, 16, 256] --\n",
"│ │ └─Linear: 3-7 [128, 16, 256] (65,536)\n",
"│ │ └─LayerNorm: 3-8 [128, 16, 256] (512)\n",
"│ │ └─SiLU: 3-9 [128, 16, 256] --\n",
"│ └─ModuleDict: 2-2 -- --\n",
"│ │ └─Linear: 3-10 [128, 16, 16536] (4,249,752)\n",
"==========================================================================================\n",
"Total params: 4,775,576\n",
"Trainable params: 0\n",
"Non-trainable params: 4,775,576\n",
"Total mult-adds (M): 611.27\n",
"==========================================================================================\n",
"Input size (MB): 12.58\n",
"Forward/backward pass size (MB): 296.09\n",
"Params size (MB): 19.10\n",
"Estimated Total Size (MB): 327.78\n",
"==========================================================================================\n",
"Summary Failed for reward Failed to run torchinfo. See above stack traces for more details. Executed layers up to: [Sequential: 1, Linear: 2, LayerNorm: 2, SiLU: 2, Linear: 2, LayerNorm: 2, SiLU: 2, Linear: 2, LayerNorm: 2, SiLU: 2, Linear: 1]\n",
"Summary Failed for cont Failed to run torchinfo. See above stack traces for more details. Executed layers up to: [Sequential: 1, Linear: 2, LayerNorm: 2, SiLU: 2, Linear: 2, LayerNorm: 2, SiLU: 2, Linear: 2, LayerNorm: 2, SiLU: 2, Linear: 1]\n"
]
}
],
"source": [
"# heads\n",
"feat = wm.dynamics.get_feat(post)\n",
"for name, head in wm.heads.items():\n",
"    try:\n",
"        head_summary = summary(head, input_data=(feat,), depth=3, col_names=[\"output_size\", \"num_params\"])\n",
"    except Exception as e:\n",
"        # Best effort: torchinfo fails on some heads (reward and cont in the recorded run); report and move on.\n",
"        print(f\"Summary Failed for {name} {e}\")\n",
"    else:\n",
"        print(name)\n",
"        print(head_summary)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"# Kept for reference: this fails because wm.dynamics has no call method for torchinfo to run.\n",
"# summary(wm.dynamics, input_data=(embed, data[\"action\"], data[\"is_first\"]), depth=3, col_names=[\"output_size\", \"num_params\", ])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
}
},
"nbformat": 4,
"nbformat_minor": 2
}