it learnt some acheivements, still takes a long time, and I might have made some params to small

This commit is contained in:
wassname
2024-06-06 13:39:17 +08:00
parent ff14ca4639
commit e0233faf88
9 changed files with 628 additions and 16 deletions
+45 -4
View File
@@ -62,7 +62,7 @@ defaults:
initial: 'learned'
# Training
batch_size: 256
batch_size: 64
batch_length: 64
train_ratio: 512
pretrain: 100
@@ -136,18 +136,59 @@ craftax:
action_repeat: 1
envs: 1
train_ratio: 512
video_pred_log: false # FIXME
video_pred_log: false
dyn_hidden: 1024
dyn_deter: 4096
units: 1024
encoder: {cnn_keys: '$^', mlp_keys: "state", mlp_layers: 5, mlp_units: 1024, }
decoder: {cnn_keys: '$^', mlp_keys: "state", mlp_layers: 5, mlp_units: 1024}
encoder: {cnn_keys: '$^', mlp_keys: "state", mlp_layers: 4, mlp_units: 512, }
decoder: {cnn_keys: '$^', mlp_keys: "state", mlp_layers: 4, mlp_units: 512}
actor: {layers: 5, dist: 'onehot', std: 'none'}
value: {layers: 5}
reward_head: {layers: 5}
cont_head: {layers: 5}
imag_gradient: 'reinforce'
craftax_small:
task: craftax_Craftax-Symbolic-AutoReset-v1
step: 1e6
action_repeat: 1
envs: 1
train_ratio: 512
video_pred_log: false
dyn_hidden: 512
dyn_deter: 512
units: 512
encoder: {cnn_keys: '$^', mlp_keys: "state", mlp_layers: 3, mlp_units: 256}
decoder: {cnn_keys: '$^', mlp_keys: "state", mlp_layers: 3, mlp_units: 256}
actor: {layers: 3, dist: 'onehot', std: 'none'}
value: {layers: 3}
reward_head: {layers: 3}
cont_head: {layers: 3}
imag_gradient: 'reinforce'
batch_size: 128
batch_length: 16
craftax_smaller:
task: craftax_Craftax-Symbolic-AutoReset-v1
step: 1e6
action_repeat: 1
envs: 1
train_ratio: 256
video_pred_log: false
dyn_hidden: 256
dyn_deter: 1024
units: 256
encoder: {cnn_keys: '$^', mlp_keys: "state", mlp_layers: 2, mlp_units: 256, }
decoder: {cnn_keys: '$^', mlp_keys: "state", mlp_layers: 2, mlp_units: 256}
actor: {layers: 2, dist: 'onehot', std: 'none'}
value: {layers: 2}
reward_head: {layers: 2}
cont_head: {layers: 2}
imag_gradient: 'reinforce'
batch_size: 256
batch_length: 16
atari100k:
steps: 4e5
envs: 1
+9 -5
View File
@@ -308,6 +308,7 @@ def main(config):
agent.load_state_dict(checkpoint["agent_state_dict"])
tools.recursively_load_optim_state_dict(agent, checkpoint["optims_state_dict"])
agent._should_pretrain._once = False
logger.warning(f"Loaded model from {logdir / 'latest.pt'}")
# make sure eval will be executed once after config.steps
with tqdm(total=config.steps + config.eval_every, unit='step') as pbar:
@@ -356,13 +357,12 @@ def main(config):
except Exception:
pass
if __name__ == "__main__":
def parse_args(argv=None):
parser = argparse.ArgumentParser()
parser.add_argument("--configs", nargs="+")
args, remaining = parser.parse_known_args()
args, remaining = parser.parse_known_args(argv[1:])
configs = yaml.safe_load(
(pathlib.Path(sys.argv[0]).parent / "configs.yaml").read_text()
(pathlib.Path(argv[0]).parent / "configs.yaml").read_text()
)
def recursive_update(base, update):
@@ -380,4 +380,8 @@ if __name__ == "__main__":
for key, value in sorted(defaults.items(), key=lambda x: x[0]):
arg_type = tools.args_type(value)
parser.add_argument(f"--{key}", type=arg_type, default=arg_type(value))
main(parser.parse_args(remaining))
args = parser.parse_args(remaining)
return args
if __name__ == "__main__":
main(parse_args())
+3 -1
View File
@@ -224,6 +224,8 @@ class Craftax:
def step(self, action):
state, reward, done, info = self._env.step(action)
info2 = {k.replace('Ach','log_ach'):v for k,v in info.items()}
reward = np.float32(reward)
obs = {
"image": self.get_image(),
@@ -231,7 +233,7 @@ class Craftax:
"is_first": False,
"is_last": done,
"is_terminal": info["discount"] == 0,
**info,
**info2,
}
return obs, reward, done, info
BIN
View File
Binary file not shown.

After

Width:  |  Height:  |  Size: 36 KiB

+1 -1
View File
@@ -6,7 +6,7 @@ export TQDM_MININTERVAL := "30"
main:
. ./.venv/bin/activate
python dreamer.py --configs crafter --task crafter_reward --logdir ./logdir/crafter
python dreamer.py --configs craftax_small --logdir ./logdir/crafter
logs:
tensorboard --logdir logdir/craftax
+553
View File
@@ -0,0 +1,553 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this notebook we load a saved dreamer, and run it, to look at params, speed and improve hackability"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loading textures from cache\n"
]
}
],
"source": [
"# TODO make this a proper package\n",
"import os, sys\n",
"sys.path.append('..')\n",
"\n",
"\n",
"from dreamer import parse_args, main, make_env, make_dataset, count_steps,Dreamer"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['../dreamer.py', '--configs', 'craftax_small', '--logdir', '../logdir/craftax_small']\n"
]
},
{
"data": {
"text/plain": [
"Namespace(act='SiLU', action_repeat=1, actor={'layers': 3, 'dist': 'onehot', 'entropy': 0.0003, 'unimix_ratio': 0.01, 'std': 'none', 'min_std': 0.1, 'max_std': 1.0, 'temp': 0.1, 'lr': 3e-05, 'eps': 1e-05, 'grad_clip': 100.0, 'outscale': 1.0}, batch_length=16, batch_size=128, compile=True, cont_head={'layers': 3, 'loss_scale': 1.0, 'outscale': 1.0}, critic={'layers': 2, 'dist': 'symlog_disc', 'slow_target': True, 'slow_target_update': 1, 'slow_target_fraction': 0.02, 'lr': 3e-05, 'eps': 1e-05, 'grad_clip': 100.0, 'outscale': 0.0}, dataset_size=1000000, debug=False, decoder={'mlp_keys': 'state', 'cnn_keys': '$^', 'act': 'SiLU', 'norm': True, 'cnn_depth': 32, 'kernel_size': 4, 'minres': 4, 'mlp_layers': 3, 'mlp_units': 256, 'cnn_sigmoid': False, 'image_dist': 'mse', 'vector_dist': 'symlog_mse', 'outscale': 1.0}, deterministic_run=False, device='cuda:0', disag_action_cond=False, disag_layers=4, disag_log=True, disag_models=10, disag_offset=1, disag_target='stoch', disag_units=400, discount=0.997, discount_lambda=0.95, dyn_deter=512, dyn_discrete=32, dyn_hidden=256, dyn_mean_act='none', dyn_min_std=0.1, dyn_rec_depth=1, dyn_scale=0.5, dyn_std_act='sigmoid2', dyn_stoch=32, encoder={'mlp_keys': 'state', 'cnn_keys': '$^', 'act': 'SiLU', 'norm': True, 'cnn_depth': 32, 'kernel_size': 4, 'minres': 4, 'mlp_layers': 3, 'mlp_units': 256, 'symlog_inputs': True}, envs=1, eval_episode_num=10, eval_every=10000.0, eval_state_mean=False, evaldir=None, expl_behavior='greedy', expl_extr_scale=0.0, expl_intr_scale=1.0, expl_until=0, grad_clip=1000, grad_heads=('decoder', 'reward', 'cont'), grayscale=False, imag_gradient='reinforce', imag_gradient_mix=0.0, imag_horizon=15, initial='learned', kl_free=1.0, log_every=10000.0, logdir='../logdir/craftax_small', model_lr=0.0001, norm=True, offline_evaldir='', offline_traindir='', opt='adam', opt_eps=1e-08, parallel=False, precision=32, prefill=2500, pretrain=100, rep_scale=0.1, reset_every=0, reward_EMA=True, reward_head={'layers': 3, 'dist': 'symlog_disc', 'loss_scale': 1.0, 'outscale': 0.0}, seed=0, size=(64, 64), step=1000000.0, steps=1000000.0, task='craftax_Craftax-Symbolic-AutoReset-v1', time_limit=1000, train_ratio=512, traindir=None, unimix_ratio=0.01, units=512, value={'layers': 3}, video_pred_log=False, weight_decay=0.0)"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# emulate cli\n",
"argv = f\"../dreamer.py --configs craftax_small --logdir ../logdir/craftax_small\"\n",
"argv = argv.split()\n",
"print(argv)\n",
"config = parse_args(argv)\n",
"config"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[32m2024-06-06 13:35:50.147\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m27\u001b[0m - \u001b[1mLogdir ../logdir/craftax_small\u001b[0m\n",
"\u001b[32m2024-06-06 13:35:50.153\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m36\u001b[0m - \u001b[1mCreate envs.\u001b[0m\n",
"\u001b[32m2024-06-06 13:36:42.176\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m57\u001b[0m - \u001b[1mAction Space Box(0.0, 1.0, (43,), float32)\u001b[0m\n",
"\u001b[32m2024-06-06 13:36:42.178\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m63\u001b[0m - \u001b[1mPrefill dataset (0 steps).\u001b[0m\n",
"\u001b[32m2024-06-06 13:36:42.180\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m92\u001b[0m - \u001b[1mLogger: (128521 steps).\u001b[0m\n",
"\u001b[32m2024-06-06 13:36:42.180\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m94\u001b[0m - \u001b[1mSimulate agent.\u001b[0m\n"
]
}
],
"source": [
"from loguru import logger\n",
"from tqdm.auto import tqdm\n",
"import pathlib\n",
"\n",
"import torch\n",
"from torch import nn\n",
"from torch import distributions as torchd\n",
"\n",
"import exploration as expl\n",
"import models\n",
"import tools\n",
"import envs.wrappers as wrappers\n",
"from parallel import Parallel, Damy\n",
"\n",
"# from main\n",
"tools.set_seed_everywhere(config.seed)\n",
"if config.deterministic_run:\n",
" tools.enable_deterministic_run()\n",
"logdir = pathlib.Path(config.logdir).expanduser()\n",
"config.traindir = config.traindir or logdir / \"train_eps\"\n",
"config.evaldir = config.evaldir or logdir / \"eval_eps\"\n",
"config.steps //= config.action_repeat\n",
"config.eval_every //= config.action_repeat\n",
"config.log_every //= config.action_repeat\n",
"config.time_limit //= config.action_repeat\n",
"\n",
"logger.info(f\"Logdir {logdir}\")\n",
"logdir.mkdir(parents=True, exist_ok=True)\n",
"config.traindir.mkdir(parents=True, exist_ok=True)\n",
"config.evaldir.mkdir(parents=True, exist_ok=True)\n",
"step = count_steps(config.traindir)\n",
"# step in logger is environmental step\n",
"tlogger = tools.Logger(logdir, config.action_repeat * step)\n",
"logger.add(logdir/\"logger.log\")\n",
"\n",
"logger.info(\"Create envs.\")\n",
"if config.offline_traindir:\n",
" directory = config.offline_traindir.format(**vars(config))\n",
"else:\n",
" directory = config.traindir\n",
"train_eps = tools.load_episodes(directory, limit=config.dataset_size)\n",
"if config.offline_evaldir:\n",
" directory = config.offline_evaldir.format(**vars(config))\n",
"else:\n",
" directory = config.evaldir\n",
"eval_eps = tools.load_episodes(directory, limit=1)\n",
"make = lambda mode, id: make_env(config, mode, id)\n",
"train_envs = [make(\"train\", i) for i in range(config.envs)]\n",
"eval_envs = [make(\"eval\", i) for i in range(config.envs)]\n",
"if config.parallel:\n",
" train_envs = [Parallel(env, \"process\") for env in train_envs]\n",
" eval_envs = [Parallel(env, \"process\") for env in eval_envs]\n",
"else:\n",
" train_envs = [Damy(env) for env in train_envs]\n",
" eval_envs = [Damy(env) for env in eval_envs]\n",
"acts = train_envs[0].action_space\n",
"logger.info(f\"Action Space {acts}\" )\n",
"config.num_actions = acts.n if hasattr(acts, \"n\") else acts.shape[0]\n",
"\n",
"state = None\n",
"if not config.offline_traindir:\n",
" prefill = max(0, config.prefill - count_steps(config.traindir))\n",
" logger.info(f\"Prefill dataset ({prefill} steps).\")\n",
" if hasattr(acts, \"discrete\"):\n",
" random_actor = tools.OneHotDist(\n",
" torch.zeros(config.num_actions).repeat(config.envs, 1)\n",
" )\n",
" else:\n",
" random_actor = torchd.independent.Independent(\n",
" torchd.uniform.Uniform(\n",
" torch.Tensor(acts.low).repeat(config.envs, 1),\n",
" torch.Tensor(acts.high).repeat(config.envs, 1),\n",
" ),\n",
" 1,\n",
" )\n",
"\n",
" def random_agent(o, d, s):\n",
" action = random_actor.sample()\n",
" logprob = random_actor.log_prob(action)\n",
" return {\"action\": action, \"logprob\": logprob}, None\n",
"\n",
" state = tools.simulate(\n",
" random_agent,\n",
" train_envs,\n",
" train_eps,\n",
" config.traindir,\n",
" tlogger,\n",
" limit=config.dataset_size,\n",
" steps=prefill,\n",
" )\n",
" tlogger.step += prefill * config.action_repeat\n",
" logger.info(f\"Logger: ({tlogger.step} steps).\")\n",
"\n",
"logger.info(\"Simulate agent.\")\n",
"train_dataset = make_dataset(train_eps, config)\n",
"eval_dataset = make_dataset(eval_eps, config)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[32m2024-06-06 13:38:20.651\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mnetworks\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m323\u001b[0m - \u001b[1mEncoder CNN shapes: {}\u001b[0m\n",
"\u001b[32m2024-06-06 13:38:20.651\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mnetworks\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m324\u001b[0m - \u001b[1mEncoder MLP shapes: {'state': (16536,)}\u001b[0m\n",
"\u001b[32m2024-06-06 13:38:20.751\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mnetworks\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m390\u001b[0m - \u001b[1mDecoder CNN shapes: {}\u001b[0m\n",
"\u001b[32m2024-06-06 13:38:20.751\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mnetworks\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m391\u001b[0m - \u001b[1mDecoder MLP shapes: {'state': (16536,)}\u001b[0m\n",
"\u001b[32m2024-06-06 13:38:20.813\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmodels\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m102\u001b[0m - \u001b[1mOptimizer model_opt has 15732120 variables.\u001b[0m\n",
"\u001b[32m2024-06-06 13:38:20.836\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmodels\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m281\u001b[0m - \u001b[1mOptimizer actor_opt has 1335851 variables.\u001b[0m\n",
"\u001b[32m2024-06-06 13:38:20.837\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmodels\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m292\u001b[0m - \u001b[1mOptimizer value_opt has 1181439 variables.\u001b[0m\n",
"\u001b[32m2024-06-06 13:38:21.032\u001b[0m | \u001b[33m\u001b[1mWARNING \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m17\u001b[0m - \u001b[33m\u001b[1mLoaded model from ../logdir/craftax_small/latest.pt\u001b[0m\n"
]
}
],
"source": [
"config = parse_args(argv)\n",
"config.num_actions = acts.n if hasattr(acts, \"n\") else acts.shape[0]\n",
"agent = Dreamer(\n",
" train_envs[0].observation_space,\n",
" train_envs[0].action_space,\n",
" config,\n",
" tlogger,\n",
" train_dataset,\n",
").to(config.device)\n",
"# print(agent)\n",
"agent.requires_grad_(requires_grad=False)\n",
"if (logdir / \"latest.pt\").exists():\n",
" checkpoint = torch.load(logdir / \"latest.pt\")\n",
" agent.load_state_dict(checkpoint[\"agent_state_dict\"])\n",
" tools.recursively_load_optim_state_dict(agent, checkpoint[\"optims_state_dict\"])\n",
" agent._should_pretrain._once = False\n",
" logger.warning(f\"Loaded model from {logdir / 'latest.pt'}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Now lets play"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": []
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(0, 0, array([ True]), array([0], dtype=int32), [None], None, [0])"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"assert state is not None\n",
"import numpy as np\n",
"\n",
"state"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"from tools import convert, add_to_cache\n",
"envs = train_envs\n",
"cache = train_eps\n",
"\n",
"step, episode = 0, 0\n",
"done = np.ones(len(envs), bool)\n",
"length = np.zeros(len(envs), np.int32)\n",
"obs = [None] * len(envs)\n",
"agent_state = None\n",
"reward = [0] * len(envs)\n",
"\n",
"indices = [index for index, d in enumerate(done) if d]\n",
"results = [envs[i].reset() for i in indices]\n",
"results = [r() for r in results]\n",
"for index, result in zip(indices, results):\n",
" t = result.copy()\n",
" t = {k: convert(v) for k, v in t.items()}\n",
" # action will be added to transition in add_to_cache\n",
" t[\"reward\"] = 0.0\n",
" t[\"discount\"] = 1.0\n",
" # initial state should be added to cache\n",
" add_to_cache(cache, envs[index].id, t)\n",
" # replace obs with done by initial state\n",
" obs[index] = result\n",
"# step agents"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[32m2024-06-06 13:38:34.000\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mtools\u001b[0m:\u001b[36mwrite\u001b[0m:\u001b[36m85\u001b[0m - \u001b[1m[128521] model_loss \u001b[31m22.2\u001b[0m\u001b[1m / model_grad_norm \u001b[31m14.4\u001b[0m\u001b[1m / state_loss \u001b[31m17.4\u001b[0m\u001b[1m / reward_loss \u001b[31m0.1\u001b[0m\u001b[1m / cont_loss \u001b[31m0.0\u001b[0m\u001b[1m / kl_free \u001b[31m1.0\u001b[0m\u001b[1m / dyn_scale \u001b[31m0.5\u001b[0m\u001b[1m / rep_scale \u001b[31m0.1\u001b[0m\u001b[1m / dyn_loss \u001b[31m7.8\u001b[0m\u001b[1m / rep_loss \u001b[31m7.8\u001b[0m\u001b[1m / kl \u001b[31m7.7\u001b[0m\u001b[1m / prior_ent \u001b[31m48.4\u001b[0m\u001b[1m / post_ent \u001b[31m40.7\u001b[0m\u001b[1m / normed_target_mean \u001b[31m0.4\u001b[0m\u001b[1m / normed_target_std \u001b[31m0.3\u001b[0m\u001b[1m / normed_target_min \u001b[31m-0.3\u001b[0m\u001b[1m / normed_target_max \u001b[31m1.8\u001b[0m\u001b[1m / EMA_005 \u001b[31m12.3\u001b[0m\u001b[1m / EMA_095 \u001b[31m26.4\u001b[0m\u001b[1m / value_mean \u001b[31m18.2\u001b[0m\u001b[1m / value_std \u001b[31m4.3\u001b[0m\u001b[1m / value_min \u001b[31m10.1\u001b[0m\u001b[1m / value_max \u001b[31m31.1\u001b[0m\u001b[1m / target_mean \u001b[31m18.4\u001b[0m\u001b[1m / target_std \u001b[31m4.7\u001b[0m\u001b[1m / target_min \u001b[31m8.4\u001b[0m\u001b[1m / target_max \u001b[31m37.8\u001b[0m\u001b[1m / imag_reward_mean \u001b[31m0.0\u001b[0m\u001b[1m / imag_reward_std \u001b[31m0.1\u001b[0m\u001b[1m / imag_reward_min \u001b[31m-0.2\u001b[0m\u001b[1m / imag_reward_max \u001b[31m1.0\u001b[0m\u001b[1m / imag_action_mean \u001b[31m10.0\u001b[0m\u001b[1m / imag_action_std \u001b[31m12.9\u001b[0m\u001b[1m / imag_action_min \u001b[31m0.0\u001b[0m\u001b[1m / imag_action_max \u001b[31m42.0\u001b[0m\u001b[1m / actor_entropy \u001b[31m0.9\u001b[0m\u001b[1m / actor_loss \u001b[31m0.1\u001b[0m\u001b[1m / actor_grad_norm \u001b[31m0.5\u001b[0m\u001b[1m / value_loss \u001b[31m1.3\u001b[0m\u001b[1m / value_grad_norm \u001b[31m0.9\u001b[0m\u001b[1m / update_count \u001b[31m1.0\u001b[0m\u001b[1m / fps \u001b[31m0.0\u001b[0m\u001b[1m\u001b[0m\n"
]
}
],
"source": [
"# from tools.simulate\n",
"\n",
"# step\n",
"# step, episode, done, length, obs, agent_state, reward = state\n",
"obs = {k: np.stack([o[k] for o in obs]) for k in obs[0] if \"log_\" not in k}\n",
"action, agent_state = agent(obs, done, agent_state)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"=====================================================================================\n",
"Layer (type:depth-idx) Param #\n",
"=====================================================================================\n",
"Dreamer --\n",
"├─OptimizedModule: 1-1 --\n",
"│ └─WorldModel: 2-1 --\n",
"│ │ └─MultiEncoder: 3-1 (4,365,824)\n",
"│ │ └─RSSM: 3-2 (3,831,808)\n",
"│ │ └─ModuleDict: 3-3 (7,534,488)\n",
"├─OptimizedModule: 1-2 --\n",
"│ └─ImagBehavior: 2-2 15,732,120\n",
"│ │ └─WorldModel: 3-4 (recursive)\n",
"│ │ └─MLP: 3-5 (1,335,851)\n",
"│ │ └─MLP: 3-6 (1,181,439)\n",
"│ │ └─MLP: 3-7 (1,181,439)\n",
"├─OptimizedModule: 1-3 (recursive)\n",
"│ └─ImagBehavior: 2-3 (recursive)\n",
"│ │ └─WorldModel: 3-8 (recursive)\n",
"│ │ └─MLP: 3-9 (recursive)\n",
"│ │ └─MLP: 3-10 (recursive)\n",
"│ │ └─MLP: 3-11 (recursive)\n",
"=====================================================================================\n",
"Total params: 35,162,969\n",
"Trainable params: 0\n",
"Non-trainable params: 35,162,969\n",
"====================================================================================="
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from torchinfo import summary\n",
"\n",
"summary(agent, input=(obs, done, agent_state), depth=3)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Fine grained torchinfo"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"wm = agent._wm\n",
"data = next(agent._dataset) \n",
"# self._train()\n",
"# post, context, mets = wm._train(data)\n",
"data = wm.preprocess(data)\n",
"embed = wm.encoder(data)\n",
"post, prior = wm.dynamics.observe(\n",
" embed, data[\"action\"], data[\"is_first\"]\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"==========================================================================================\n",
"Layer (type:depth-idx) Output Shape Param #\n",
"==========================================================================================\n",
"MultiEncoder [128, 16, 256] --\n",
"├─MLP: 1-1 [128, 16, 256] --\n",
"│ └─Sequential: 2-1 [128, 16, 256] --\n",
"│ │ └─Linear: 3-1 [128, 16, 256] (4,233,216)\n",
"│ │ └─LayerNorm: 3-2 [128, 16, 256] (512)\n",
"│ │ └─SiLU: 3-3 [128, 16, 256] --\n",
"│ │ └─Linear: 3-4 [128, 16, 256] (65,536)\n",
"│ │ └─LayerNorm: 3-5 [128, 16, 256] (512)\n",
"│ │ └─SiLU: 3-6 [128, 16, 256] --\n",
"│ │ └─Linear: 3-7 [128, 16, 256] (65,536)\n",
"│ │ └─LayerNorm: 3-8 [128, 16, 256] (512)\n",
"│ │ └─SiLU: 3-9 [128, 16, 256] --\n",
"==========================================================================================\n",
"Total params: 4,365,824\n",
"Trainable params: 0\n",
"Non-trainable params: 4,365,824\n",
"Total mult-adds (M): 558.83\n",
"==========================================================================================\n",
"Input size (MB): 487.31\n",
"Forward/backward pass size (MB): 25.17\n",
"Params size (MB): 17.46\n",
"Estimated Total Size (MB): 529.94\n",
"=========================================================================================="
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"summary(wm.encoder, input_data=(data,), depth=3, col_names=[\"output_size\", \"num_params\", ])"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"decoder\n",
"==========================================================================================\n",
"Layer (type:depth-idx) Output Shape Param #\n",
"==========================================================================================\n",
"MultiDecoder -- --\n",
"├─MLP: 1-1 -- --\n",
"│ └─Sequential: 2-1 [128, 16, 256] --\n",
"│ │ └─Linear: 3-1 [128, 16, 256] (393,216)\n",
"│ │ └─LayerNorm: 3-2 [128, 16, 256] (512)\n",
"│ │ └─SiLU: 3-3 [128, 16, 256] --\n",
"│ │ └─Linear: 3-4 [128, 16, 256] (65,536)\n",
"│ │ └─LayerNorm: 3-5 [128, 16, 256] (512)\n",
"│ │ └─SiLU: 3-6 [128, 16, 256] --\n",
"│ │ └─Linear: 3-7 [128, 16, 256] (65,536)\n",
"│ │ └─LayerNorm: 3-8 [128, 16, 256] (512)\n",
"│ │ └─SiLU: 3-9 [128, 16, 256] --\n",
"│ └─ModuleDict: 2-2 -- --\n",
"│ │ └─Linear: 3-10 [128, 16, 16536] (4,249,752)\n",
"==========================================================================================\n",
"Total params: 4,775,576\n",
"Trainable params: 0\n",
"Non-trainable params: 4,775,576\n",
"Total mult-adds (M): 611.27\n",
"==========================================================================================\n",
"Input size (MB): 12.58\n",
"Forward/backward pass size (MB): 296.09\n",
"Params size (MB): 19.10\n",
"Estimated Total Size (MB): 327.78\n",
"==========================================================================================\n",
"Summary Failed for reward Failed to run torchinfo. See above stack traces for more details. Executed layers up to: [Sequential: 1, Linear: 2, LayerNorm: 2, SiLU: 2, Linear: 2, LayerNorm: 2, SiLU: 2, Linear: 2, LayerNorm: 2, SiLU: 2, Linear: 1]\n",
"Summary Failed for cont Failed to run torchinfo. See above stack traces for more details. Executed layers up to: [Sequential: 1, Linear: 2, LayerNorm: 2, SiLU: 2, Linear: 2, LayerNorm: 2, SiLU: 2, Linear: 2, LayerNorm: 2, SiLU: 2, Linear: 1]\n"
]
}
],
"source": [
"# heads\n",
"feat = wm.dynamics.get_feat(post)\n",
"for name, head in wm.heads.items():\n",
" try:\n",
" o = summary(head, input_data=(feat,), depth=3, col_names=[\"output_size\", \"num_params\", ])\n",
" print(name)\n",
" print(o)\n",
" except Exception as e:\n",
" print(f\"Summary Failed for {name} {e}\")\n",
" continue"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"# fail as no call method\n",
"# summary(wm.dynamics, input_data=(embed, data[\"action\"], data[\"is_first\"]), depth=3, col_names=[\"output_size\", \"num_params\", ])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
+4 -4
View File
@@ -320,8 +320,8 @@ class MultiEncoder(nn.Module):
for k, v in shapes.items()
if len(v) in (1, 2) and re.match(mlp_keys, k)
}
logger.info("Encoder CNN shapes:", self.cnn_shapes)
logger.info("Encoder MLP shapes:", self.mlp_shapes)
logger.info("Encoder CNN shapes: {}", self.cnn_shapes)
logger.info("Encoder MLP shapes: {}", self.mlp_shapes)
self.outdim = 0
if self.cnn_shapes:
@@ -387,8 +387,8 @@ class MultiDecoder(nn.Module):
for k, v in shapes.items()
if len(v) in (1, 2) and re.match(mlp_keys, k)
}
logger.info("Decoder CNN shapes: %s", self.cnn_shapes)
logger.info("Decoder MLP shapes: %s", self.mlp_shapes)
logger.info("Decoder CNN shapes: {}", self.cnn_shapes)
logger.info("Decoder MLP shapes: {}", self.mlp_shapes)
if self.cnn_shapes:
some_shape = list(self.cnn_shapes.values())[0]
Generated
+12 -1
View File
@@ -3578,6 +3578,17 @@ type = "legacy"
url = "https://download.pytorch.org/whl/cu121"
reference = "pytorch"
[[package]]
name = "torchinfo"
version = "1.8.0"
description = "Model summary in PyTorch, based off of the original torchsummary."
optional = false
python-versions = ">=3.7"
files = [
{file = "torchinfo-1.8.0-py3-none-any.whl", hash = "sha256:2e911c2918603f945c26ff21a3a838d12709223dc4ccf243407bce8b6e897b46"},
{file = "torchinfo-1.8.0.tar.gz", hash = "sha256:72e94b0e9a3e64dc583a8e5b7940b8938a1ac0f033f795457f27e6f4e7afa2e9"},
]
[[package]]
name = "tornado"
version = "6.4"
@@ -3836,4 +3847,4 @@ test = ["big-O", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more-it
[metadata]
lock-version = "2.0"
python-versions = "^3.9"
content-hash = "8d04aef5b114f7ae76dc03bc61d308f1b239d390b0c71fab7d0c8f467cc95dd4"
content-hash = "0275da73363d94f6a5cdadc9662c1b254ef50310aabce3aa663552aa4802b001"
+1
View File
@@ -38,6 +38,7 @@ imageio = "^2.34.1"
craftax = {path = "/media/wassname/SGIronWolf/projects5/2024/Craftax", develop = true }
# craftax = {git = "https://github.com/wassname/Craftax" , develop = true }
chex = "^0.1.86"
torchinfo = "^1.8.0"
[tool.poetry.group.dev.dependencies]
ipywidgets = "^8.1.3"