mirror of
https://github.com/wassname/dreamerv3-torch.git
synced 2026-06-27 15:00:13 +08:00
it learnt some achievements, still takes a long time, and I might have made some params too small
This commit is contained in:
+45
-4
@@ -62,7 +62,7 @@ defaults:
|
||||
initial: 'learned'
|
||||
|
||||
# Training
|
||||
batch_size: 256
|
||||
batch_size: 64
|
||||
batch_length: 64
|
||||
train_ratio: 512
|
||||
pretrain: 100
|
||||
@@ -136,18 +136,59 @@ craftax:
|
||||
action_repeat: 1
|
||||
envs: 1
|
||||
train_ratio: 512
|
||||
video_pred_log: false # FIXME
|
||||
video_pred_log: false
|
||||
dyn_hidden: 1024
|
||||
dyn_deter: 4096
|
||||
units: 1024
|
||||
encoder: {cnn_keys: '$^', mlp_keys: "state", mlp_layers: 5, mlp_units: 1024, }
|
||||
decoder: {cnn_keys: '$^', mlp_keys: "state", mlp_layers: 5, mlp_units: 1024}
|
||||
encoder: {cnn_keys: '$^', mlp_keys: "state", mlp_layers: 4, mlp_units: 512, }
|
||||
decoder: {cnn_keys: '$^', mlp_keys: "state", mlp_layers: 4, mlp_units: 512}
|
||||
actor: {layers: 5, dist: 'onehot', std: 'none'}
|
||||
value: {layers: 5}
|
||||
reward_head: {layers: 5}
|
||||
cont_head: {layers: 5}
|
||||
imag_gradient: 'reinforce'
|
||||
|
||||
craftax_small:
|
||||
task: craftax_Craftax-Symbolic-AutoReset-v1
|
||||
step: 1e6
|
||||
action_repeat: 1
|
||||
envs: 1
|
||||
train_ratio: 512
|
||||
video_pred_log: false
|
||||
dyn_hidden: 512
|
||||
dyn_deter: 512
|
||||
units: 512
|
||||
encoder: {cnn_keys: '$^', mlp_keys: "state", mlp_layers: 3, mlp_units: 256}
|
||||
decoder: {cnn_keys: '$^', mlp_keys: "state", mlp_layers: 3, mlp_units: 256}
|
||||
actor: {layers: 3, dist: 'onehot', std: 'none'}
|
||||
value: {layers: 3}
|
||||
reward_head: {layers: 3}
|
||||
cont_head: {layers: 3}
|
||||
imag_gradient: 'reinforce'
|
||||
batch_size: 128
|
||||
batch_length: 16
|
||||
|
||||
|
||||
craftax_smaller:
|
||||
task: craftax_Craftax-Symbolic-AutoReset-v1
|
||||
step: 1e6
|
||||
action_repeat: 1
|
||||
envs: 1
|
||||
train_ratio: 256
|
||||
video_pred_log: false
|
||||
dyn_hidden: 256
|
||||
dyn_deter: 1024
|
||||
units: 256
|
||||
encoder: {cnn_keys: '$^', mlp_keys: "state", mlp_layers: 2, mlp_units: 256, }
|
||||
decoder: {cnn_keys: '$^', mlp_keys: "state", mlp_layers: 2, mlp_units: 256}
|
||||
actor: {layers: 2, dist: 'onehot', std: 'none'}
|
||||
value: {layers: 2}
|
||||
reward_head: {layers: 2}
|
||||
cont_head: {layers: 2}
|
||||
imag_gradient: 'reinforce'
|
||||
batch_size: 256
|
||||
batch_length: 16
|
||||
|
||||
atari100k:
|
||||
steps: 4e5
|
||||
envs: 1
|
||||
|
||||
+9
-5
@@ -308,6 +308,7 @@ def main(config):
|
||||
agent.load_state_dict(checkpoint["agent_state_dict"])
|
||||
tools.recursively_load_optim_state_dict(agent, checkpoint["optims_state_dict"])
|
||||
agent._should_pretrain._once = False
|
||||
logger.warning(f"Loaded model from {logdir / 'latest.pt'}")
|
||||
|
||||
# make sure eval will be executed once after config.steps
|
||||
with tqdm(total=config.steps + config.eval_every, unit='step') as pbar:
|
||||
@@ -356,13 +357,12 @@ def main(config):
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
def parse_args(argv=None):
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--configs", nargs="+")
|
||||
args, remaining = parser.parse_known_args()
|
||||
args, remaining = parser.parse_known_args(argv[1:])
|
||||
configs = yaml.safe_load(
|
||||
(pathlib.Path(sys.argv[0]).parent / "configs.yaml").read_text()
|
||||
(pathlib.Path(argv[0]).parent / "configs.yaml").read_text()
|
||||
)
|
||||
|
||||
def recursive_update(base, update):
|
||||
@@ -380,4 +380,8 @@ if __name__ == "__main__":
|
||||
for key, value in sorted(defaults.items(), key=lambda x: x[0]):
|
||||
arg_type = tools.args_type(value)
|
||||
parser.add_argument(f"--{key}", type=arg_type, default=arg_type(value))
|
||||
main(parser.parse_args(remaining))
|
||||
args = parser.parse_args(remaining)
|
||||
return args
|
||||
|
||||
if __name__ == "__main__":
|
||||
main(parse_args())
|
||||
|
||||
+3
-1
@@ -224,6 +224,8 @@ class Craftax:
|
||||
def step(self, action):
|
||||
state, reward, done, info = self._env.step(action)
|
||||
|
||||
info2 = {k.replace('Ach','log_ach'):v for k,v in info.items()}
|
||||
|
||||
reward = np.float32(reward)
|
||||
obs = {
|
||||
"image": self.get_image(),
|
||||
@@ -231,7 +233,7 @@ class Craftax:
|
||||
"is_first": False,
|
||||
"is_last": done,
|
||||
"is_terminal": info["discount"] == 0,
|
||||
**info,
|
||||
**info2,
|
||||
}
|
||||
return obs, reward, done, info
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ export TQDM_MININTERVAL := "30"
|
||||
|
||||
main:
|
||||
. ./.venv/bin/activate
|
||||
python dreamer.py --configs crafter --task crafter_reward --logdir ./logdir/crafter
|
||||
python dreamer.py --configs craftax_small --logdir ./logdir/crafter
|
||||
|
||||
logs:
|
||||
tensorboard --logdir logdir/craftax
|
||||
|
||||
@@ -0,0 +1,553 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"In this notebook we load a saved dreamer, and run it, to look at params, speed and improve hackability"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%load_ext autoreload\n",
|
||||
"%autoreload 2"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Loading textures from cache\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# TODO make this a proper package\n",
|
||||
"import os, sys\n",
|
||||
"sys.path.append('..')\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"from dreamer import parse_args, main, make_env, make_dataset, count_steps,Dreamer"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"['../dreamer.py', '--configs', 'craftax_small', '--logdir', '../logdir/craftax_small']\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"Namespace(act='SiLU', action_repeat=1, actor={'layers': 3, 'dist': 'onehot', 'entropy': 0.0003, 'unimix_ratio': 0.01, 'std': 'none', 'min_std': 0.1, 'max_std': 1.0, 'temp': 0.1, 'lr': 3e-05, 'eps': 1e-05, 'grad_clip': 100.0, 'outscale': 1.0}, batch_length=16, batch_size=128, compile=True, cont_head={'layers': 3, 'loss_scale': 1.0, 'outscale': 1.0}, critic={'layers': 2, 'dist': 'symlog_disc', 'slow_target': True, 'slow_target_update': 1, 'slow_target_fraction': 0.02, 'lr': 3e-05, 'eps': 1e-05, 'grad_clip': 100.0, 'outscale': 0.0}, dataset_size=1000000, debug=False, decoder={'mlp_keys': 'state', 'cnn_keys': '$^', 'act': 'SiLU', 'norm': True, 'cnn_depth': 32, 'kernel_size': 4, 'minres': 4, 'mlp_layers': 3, 'mlp_units': 256, 'cnn_sigmoid': False, 'image_dist': 'mse', 'vector_dist': 'symlog_mse', 'outscale': 1.0}, deterministic_run=False, device='cuda:0', disag_action_cond=False, disag_layers=4, disag_log=True, disag_models=10, disag_offset=1, disag_target='stoch', disag_units=400, discount=0.997, discount_lambda=0.95, dyn_deter=512, dyn_discrete=32, dyn_hidden=256, dyn_mean_act='none', dyn_min_std=0.1, dyn_rec_depth=1, dyn_scale=0.5, dyn_std_act='sigmoid2', dyn_stoch=32, encoder={'mlp_keys': 'state', 'cnn_keys': '$^', 'act': 'SiLU', 'norm': True, 'cnn_depth': 32, 'kernel_size': 4, 'minres': 4, 'mlp_layers': 3, 'mlp_units': 256, 'symlog_inputs': True}, envs=1, eval_episode_num=10, eval_every=10000.0, eval_state_mean=False, evaldir=None, expl_behavior='greedy', expl_extr_scale=0.0, expl_intr_scale=1.0, expl_until=0, grad_clip=1000, grad_heads=('decoder', 'reward', 'cont'), grayscale=False, imag_gradient='reinforce', imag_gradient_mix=0.0, imag_horizon=15, initial='learned', kl_free=1.0, log_every=10000.0, logdir='../logdir/craftax_small', model_lr=0.0001, norm=True, offline_evaldir='', offline_traindir='', opt='adam', opt_eps=1e-08, parallel=False, precision=32, prefill=2500, pretrain=100, rep_scale=0.1, reset_every=0, reward_EMA=True, reward_head={'layers': 3, 'dist': 
'symlog_disc', 'loss_scale': 1.0, 'outscale': 0.0}, seed=0, size=(64, 64), step=1000000.0, steps=1000000.0, task='craftax_Craftax-Symbolic-AutoReset-v1', time_limit=1000, train_ratio=512, traindir=None, unimix_ratio=0.01, units=512, value={'layers': 3}, video_pred_log=False, weight_decay=0.0)"
|
||||
]
|
||||
},
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# emulate cli\n",
|
||||
"argv = f\"../dreamer.py --configs craftax_small --logdir ../logdir/craftax_small\"\n",
|
||||
"argv = argv.split()\n",
|
||||
"print(argv)\n",
|
||||
"config = parse_args(argv)\n",
|
||||
"config"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\u001b[32m2024-06-06 13:35:50.147\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m27\u001b[0m - \u001b[1mLogdir ../logdir/craftax_small\u001b[0m\n",
|
||||
"\u001b[32m2024-06-06 13:35:50.153\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m36\u001b[0m - \u001b[1mCreate envs.\u001b[0m\n",
|
||||
"\u001b[32m2024-06-06 13:36:42.176\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m57\u001b[0m - \u001b[1mAction Space Box(0.0, 1.0, (43,), float32)\u001b[0m\n",
|
||||
"\u001b[32m2024-06-06 13:36:42.178\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m63\u001b[0m - \u001b[1mPrefill dataset (0 steps).\u001b[0m\n",
|
||||
"\u001b[32m2024-06-06 13:36:42.180\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m92\u001b[0m - \u001b[1mLogger: (128521 steps).\u001b[0m\n",
|
||||
"\u001b[32m2024-06-06 13:36:42.180\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m94\u001b[0m - \u001b[1mSimulate agent.\u001b[0m\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from loguru import logger\n",
|
||||
"from tqdm.auto import tqdm\n",
|
||||
"import pathlib\n",
|
||||
"\n",
|
||||
"import torch\n",
|
||||
"from torch import nn\n",
|
||||
"from torch import distributions as torchd\n",
|
||||
"\n",
|
||||
"import exploration as expl\n",
|
||||
"import models\n",
|
||||
"import tools\n",
|
||||
"import envs.wrappers as wrappers\n",
|
||||
"from parallel import Parallel, Damy\n",
|
||||
"\n",
|
||||
"# from main\n",
|
||||
"tools.set_seed_everywhere(config.seed)\n",
|
||||
"if config.deterministic_run:\n",
|
||||
" tools.enable_deterministic_run()\n",
|
||||
"logdir = pathlib.Path(config.logdir).expanduser()\n",
|
||||
"config.traindir = config.traindir or logdir / \"train_eps\"\n",
|
||||
"config.evaldir = config.evaldir or logdir / \"eval_eps\"\n",
|
||||
"config.steps //= config.action_repeat\n",
|
||||
"config.eval_every //= config.action_repeat\n",
|
||||
"config.log_every //= config.action_repeat\n",
|
||||
"config.time_limit //= config.action_repeat\n",
|
||||
"\n",
|
||||
"logger.info(f\"Logdir {logdir}\")\n",
|
||||
"logdir.mkdir(parents=True, exist_ok=True)\n",
|
||||
"config.traindir.mkdir(parents=True, exist_ok=True)\n",
|
||||
"config.evaldir.mkdir(parents=True, exist_ok=True)\n",
|
||||
"step = count_steps(config.traindir)\n",
|
||||
"# step in logger is environmental step\n",
|
||||
"tlogger = tools.Logger(logdir, config.action_repeat * step)\n",
|
||||
"logger.add(logdir/\"logger.log\")\n",
|
||||
"\n",
|
||||
"logger.info(\"Create envs.\")\n",
|
||||
"if config.offline_traindir:\n",
|
||||
" directory = config.offline_traindir.format(**vars(config))\n",
|
||||
"else:\n",
|
||||
" directory = config.traindir\n",
|
||||
"train_eps = tools.load_episodes(directory, limit=config.dataset_size)\n",
|
||||
"if config.offline_evaldir:\n",
|
||||
" directory = config.offline_evaldir.format(**vars(config))\n",
|
||||
"else:\n",
|
||||
" directory = config.evaldir\n",
|
||||
"eval_eps = tools.load_episodes(directory, limit=1)\n",
|
||||
"make = lambda mode, id: make_env(config, mode, id)\n",
|
||||
"train_envs = [make(\"train\", i) for i in range(config.envs)]\n",
|
||||
"eval_envs = [make(\"eval\", i) for i in range(config.envs)]\n",
|
||||
"if config.parallel:\n",
|
||||
" train_envs = [Parallel(env, \"process\") for env in train_envs]\n",
|
||||
" eval_envs = [Parallel(env, \"process\") for env in eval_envs]\n",
|
||||
"else:\n",
|
||||
" train_envs = [Damy(env) for env in train_envs]\n",
|
||||
" eval_envs = [Damy(env) for env in eval_envs]\n",
|
||||
"acts = train_envs[0].action_space\n",
|
||||
"logger.info(f\"Action Space {acts}\" )\n",
|
||||
"config.num_actions = acts.n if hasattr(acts, \"n\") else acts.shape[0]\n",
|
||||
"\n",
|
||||
"state = None\n",
|
||||
"if not config.offline_traindir:\n",
|
||||
" prefill = max(0, config.prefill - count_steps(config.traindir))\n",
|
||||
" logger.info(f\"Prefill dataset ({prefill} steps).\")\n",
|
||||
" if hasattr(acts, \"discrete\"):\n",
|
||||
" random_actor = tools.OneHotDist(\n",
|
||||
" torch.zeros(config.num_actions).repeat(config.envs, 1)\n",
|
||||
" )\n",
|
||||
" else:\n",
|
||||
" random_actor = torchd.independent.Independent(\n",
|
||||
" torchd.uniform.Uniform(\n",
|
||||
" torch.Tensor(acts.low).repeat(config.envs, 1),\n",
|
||||
" torch.Tensor(acts.high).repeat(config.envs, 1),\n",
|
||||
" ),\n",
|
||||
" 1,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" def random_agent(o, d, s):\n",
|
||||
" action = random_actor.sample()\n",
|
||||
" logprob = random_actor.log_prob(action)\n",
|
||||
" return {\"action\": action, \"logprob\": logprob}, None\n",
|
||||
"\n",
|
||||
" state = tools.simulate(\n",
|
||||
" random_agent,\n",
|
||||
" train_envs,\n",
|
||||
" train_eps,\n",
|
||||
" config.traindir,\n",
|
||||
" tlogger,\n",
|
||||
" limit=config.dataset_size,\n",
|
||||
" steps=prefill,\n",
|
||||
" )\n",
|
||||
" tlogger.step += prefill * config.action_repeat\n",
|
||||
" logger.info(f\"Logger: ({tlogger.step} steps).\")\n",
|
||||
"\n",
|
||||
"logger.info(\"Simulate agent.\")\n",
|
||||
"train_dataset = make_dataset(train_eps, config)\n",
|
||||
"eval_dataset = make_dataset(eval_eps, config)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\u001b[32m2024-06-06 13:38:20.651\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mnetworks\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m323\u001b[0m - \u001b[1mEncoder CNN shapes: {}\u001b[0m\n",
|
||||
"\u001b[32m2024-06-06 13:38:20.651\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mnetworks\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m324\u001b[0m - \u001b[1mEncoder MLP shapes: {'state': (16536,)}\u001b[0m\n",
|
||||
"\u001b[32m2024-06-06 13:38:20.751\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mnetworks\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m390\u001b[0m - \u001b[1mDecoder CNN shapes: {}\u001b[0m\n",
|
||||
"\u001b[32m2024-06-06 13:38:20.751\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mnetworks\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m391\u001b[0m - \u001b[1mDecoder MLP shapes: {'state': (16536,)}\u001b[0m\n",
|
||||
"\u001b[32m2024-06-06 13:38:20.813\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmodels\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m102\u001b[0m - \u001b[1mOptimizer model_opt has 15732120 variables.\u001b[0m\n",
|
||||
"\u001b[32m2024-06-06 13:38:20.836\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmodels\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m281\u001b[0m - \u001b[1mOptimizer actor_opt has 1335851 variables.\u001b[0m\n",
|
||||
"\u001b[32m2024-06-06 13:38:20.837\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmodels\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m292\u001b[0m - \u001b[1mOptimizer value_opt has 1181439 variables.\u001b[0m\n",
|
||||
"\u001b[32m2024-06-06 13:38:21.032\u001b[0m | \u001b[33m\u001b[1mWARNING \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m17\u001b[0m - \u001b[33m\u001b[1mLoaded model from ../logdir/craftax_small/latest.pt\u001b[0m\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"config = parse_args(argv)\n",
|
||||
"config.num_actions = acts.n if hasattr(acts, \"n\") else acts.shape[0]\n",
|
||||
"agent = Dreamer(\n",
|
||||
" train_envs[0].observation_space,\n",
|
||||
" train_envs[0].action_space,\n",
|
||||
" config,\n",
|
||||
" tlogger,\n",
|
||||
" train_dataset,\n",
|
||||
").to(config.device)\n",
|
||||
"# print(agent)\n",
|
||||
"agent.requires_grad_(requires_grad=False)\n",
|
||||
"if (logdir / \"latest.pt\").exists():\n",
|
||||
" checkpoint = torch.load(logdir / \"latest.pt\")\n",
|
||||
" agent.load_state_dict(checkpoint[\"agent_state_dict\"])\n",
|
||||
" tools.recursively_load_optim_state_dict(agent, checkpoint[\"optims_state_dict\"])\n",
|
||||
" agent._should_pretrain._once = False\n",
|
||||
" logger.warning(f\"Loaded model from {logdir / 'latest.pt'}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Now lets play"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"(0, 0, array([ True]), array([0], dtype=int32), [None], None, [0])"
|
||||
]
|
||||
},
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"assert state is not None\n",
|
||||
"import numpy as np\n",
|
||||
"\n",
|
||||
"state"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from tools import convert, add_to_cache\n",
|
||||
"envs = train_envs\n",
|
||||
"cache = train_eps\n",
|
||||
"\n",
|
||||
"step, episode = 0, 0\n",
|
||||
"done = np.ones(len(envs), bool)\n",
|
||||
"length = np.zeros(len(envs), np.int32)\n",
|
||||
"obs = [None] * len(envs)\n",
|
||||
"agent_state = None\n",
|
||||
"reward = [0] * len(envs)\n",
|
||||
"\n",
|
||||
"indices = [index for index, d in enumerate(done) if d]\n",
|
||||
"results = [envs[i].reset() for i in indices]\n",
|
||||
"results = [r() for r in results]\n",
|
||||
"for index, result in zip(indices, results):\n",
|
||||
" t = result.copy()\n",
|
||||
" t = {k: convert(v) for k, v in t.items()}\n",
|
||||
" # action will be added to transition in add_to_cache\n",
|
||||
" t[\"reward\"] = 0.0\n",
|
||||
" t[\"discount\"] = 1.0\n",
|
||||
" # initial state should be added to cache\n",
|
||||
" add_to_cache(cache, envs[index].id, t)\n",
|
||||
" # replace obs with done by initial state\n",
|
||||
" obs[index] = result\n",
|
||||
"# step agents"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\u001b[32m2024-06-06 13:38:34.000\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mtools\u001b[0m:\u001b[36mwrite\u001b[0m:\u001b[36m85\u001b[0m - \u001b[1m[128521] model_loss \u001b[31m22.2\u001b[0m\u001b[1m / model_grad_norm \u001b[31m14.4\u001b[0m\u001b[1m / state_loss \u001b[31m17.4\u001b[0m\u001b[1m / reward_loss \u001b[31m0.1\u001b[0m\u001b[1m / cont_loss \u001b[31m0.0\u001b[0m\u001b[1m / kl_free \u001b[31m1.0\u001b[0m\u001b[1m / dyn_scale \u001b[31m0.5\u001b[0m\u001b[1m / rep_scale \u001b[31m0.1\u001b[0m\u001b[1m / dyn_loss \u001b[31m7.8\u001b[0m\u001b[1m / rep_loss \u001b[31m7.8\u001b[0m\u001b[1m / kl \u001b[31m7.7\u001b[0m\u001b[1m / prior_ent \u001b[31m48.4\u001b[0m\u001b[1m / post_ent \u001b[31m40.7\u001b[0m\u001b[1m / normed_target_mean \u001b[31m0.4\u001b[0m\u001b[1m / normed_target_std \u001b[31m0.3\u001b[0m\u001b[1m / normed_target_min \u001b[31m-0.3\u001b[0m\u001b[1m / normed_target_max \u001b[31m1.8\u001b[0m\u001b[1m / EMA_005 \u001b[31m12.3\u001b[0m\u001b[1m / EMA_095 \u001b[31m26.4\u001b[0m\u001b[1m / value_mean \u001b[31m18.2\u001b[0m\u001b[1m / value_std \u001b[31m4.3\u001b[0m\u001b[1m / value_min \u001b[31m10.1\u001b[0m\u001b[1m / value_max \u001b[31m31.1\u001b[0m\u001b[1m / target_mean \u001b[31m18.4\u001b[0m\u001b[1m / target_std \u001b[31m4.7\u001b[0m\u001b[1m / target_min \u001b[31m8.4\u001b[0m\u001b[1m / target_max \u001b[31m37.8\u001b[0m\u001b[1m / imag_reward_mean \u001b[31m0.0\u001b[0m\u001b[1m / imag_reward_std \u001b[31m0.1\u001b[0m\u001b[1m / imag_reward_min \u001b[31m-0.2\u001b[0m\u001b[1m / imag_reward_max \u001b[31m1.0\u001b[0m\u001b[1m / imag_action_mean \u001b[31m10.0\u001b[0m\u001b[1m / imag_action_std \u001b[31m12.9\u001b[0m\u001b[1m / imag_action_min \u001b[31m0.0\u001b[0m\u001b[1m / imag_action_max \u001b[31m42.0\u001b[0m\u001b[1m / actor_entropy \u001b[31m0.9\u001b[0m\u001b[1m / actor_loss \u001b[31m0.1\u001b[0m\u001b[1m / actor_grad_norm \u001b[31m0.5\u001b[0m\u001b[1m / value_loss \u001b[31m1.3\u001b[0m\u001b[1m / 
value_grad_norm \u001b[31m0.9\u001b[0m\u001b[1m / update_count \u001b[31m1.0\u001b[0m\u001b[1m / fps \u001b[31m0.0\u001b[0m\u001b[1m\u001b[0m\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# from tools.simulate\n",
|
||||
"\n",
|
||||
"# step\n",
|
||||
"# step, episode, done, length, obs, agent_state, reward = state\n",
|
||||
"obs = {k: np.stack([o[k] for o in obs]) for k in obs[0] if \"log_\" not in k}\n",
|
||||
"action, agent_state = agent(obs, done, agent_state)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"=====================================================================================\n",
|
||||
"Layer (type:depth-idx) Param #\n",
|
||||
"=====================================================================================\n",
|
||||
"Dreamer --\n",
|
||||
"├─OptimizedModule: 1-1 --\n",
|
||||
"│ └─WorldModel: 2-1 --\n",
|
||||
"│ │ └─MultiEncoder: 3-1 (4,365,824)\n",
|
||||
"│ │ └─RSSM: 3-2 (3,831,808)\n",
|
||||
"│ │ └─ModuleDict: 3-3 (7,534,488)\n",
|
||||
"├─OptimizedModule: 1-2 --\n",
|
||||
"│ └─ImagBehavior: 2-2 15,732,120\n",
|
||||
"│ │ └─WorldModel: 3-4 (recursive)\n",
|
||||
"│ │ └─MLP: 3-5 (1,335,851)\n",
|
||||
"│ │ └─MLP: 3-6 (1,181,439)\n",
|
||||
"│ │ └─MLP: 3-7 (1,181,439)\n",
|
||||
"├─OptimizedModule: 1-3 (recursive)\n",
|
||||
"│ └─ImagBehavior: 2-3 (recursive)\n",
|
||||
"│ │ └─WorldModel: 3-8 (recursive)\n",
|
||||
"│ │ └─MLP: 3-9 (recursive)\n",
|
||||
"│ │ └─MLP: 3-10 (recursive)\n",
|
||||
"│ │ └─MLP: 3-11 (recursive)\n",
|
||||
"=====================================================================================\n",
|
||||
"Total params: 35,162,969\n",
|
||||
"Trainable params: 0\n",
|
||||
"Non-trainable params: 35,162,969\n",
|
||||
"====================================================================================="
|
||||
]
|
||||
},
|
||||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from torchinfo import summary\n",
|
||||
"\n",
|
||||
"summary(agent, input=(obs, done, agent_state), depth=3)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Fine grained torchinfo"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"wm = agent._wm\n",
|
||||
"data = next(agent._dataset) \n",
|
||||
"# self._train()\n",
|
||||
"# post, context, mets = wm._train(data)\n",
|
||||
"data = wm.preprocess(data)\n",
|
||||
"embed = wm.encoder(data)\n",
|
||||
"post, prior = wm.dynamics.observe(\n",
|
||||
" embed, data[\"action\"], data[\"is_first\"]\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"==========================================================================================\n",
|
||||
"Layer (type:depth-idx) Output Shape Param #\n",
|
||||
"==========================================================================================\n",
|
||||
"MultiEncoder [128, 16, 256] --\n",
|
||||
"├─MLP: 1-1 [128, 16, 256] --\n",
|
||||
"│ └─Sequential: 2-1 [128, 16, 256] --\n",
|
||||
"│ │ └─Linear: 3-1 [128, 16, 256] (4,233,216)\n",
|
||||
"│ │ └─LayerNorm: 3-2 [128, 16, 256] (512)\n",
|
||||
"│ │ └─SiLU: 3-3 [128, 16, 256] --\n",
|
||||
"│ │ └─Linear: 3-4 [128, 16, 256] (65,536)\n",
|
||||
"│ │ └─LayerNorm: 3-5 [128, 16, 256] (512)\n",
|
||||
"│ │ └─SiLU: 3-6 [128, 16, 256] --\n",
|
||||
"│ │ └─Linear: 3-7 [128, 16, 256] (65,536)\n",
|
||||
"│ │ └─LayerNorm: 3-8 [128, 16, 256] (512)\n",
|
||||
"│ │ └─SiLU: 3-9 [128, 16, 256] --\n",
|
||||
"==========================================================================================\n",
|
||||
"Total params: 4,365,824\n",
|
||||
"Trainable params: 0\n",
|
||||
"Non-trainable params: 4,365,824\n",
|
||||
"Total mult-adds (M): 558.83\n",
|
||||
"==========================================================================================\n",
|
||||
"Input size (MB): 487.31\n",
|
||||
"Forward/backward pass size (MB): 25.17\n",
|
||||
"Params size (MB): 17.46\n",
|
||||
"Estimated Total Size (MB): 529.94\n",
|
||||
"=========================================================================================="
|
||||
]
|
||||
},
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"summary(wm.encoder, input_data=(data,), depth=3, col_names=[\"output_size\", \"num_params\", ])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"decoder\n",
|
||||
"==========================================================================================\n",
|
||||
"Layer (type:depth-idx) Output Shape Param #\n",
|
||||
"==========================================================================================\n",
|
||||
"MultiDecoder -- --\n",
|
||||
"├─MLP: 1-1 -- --\n",
|
||||
"│ └─Sequential: 2-1 [128, 16, 256] --\n",
|
||||
"│ │ └─Linear: 3-1 [128, 16, 256] (393,216)\n",
|
||||
"│ │ └─LayerNorm: 3-2 [128, 16, 256] (512)\n",
|
||||
"│ │ └─SiLU: 3-3 [128, 16, 256] --\n",
|
||||
"│ │ └─Linear: 3-4 [128, 16, 256] (65,536)\n",
|
||||
"│ │ └─LayerNorm: 3-5 [128, 16, 256] (512)\n",
|
||||
"│ │ └─SiLU: 3-6 [128, 16, 256] --\n",
|
||||
"│ │ └─Linear: 3-7 [128, 16, 256] (65,536)\n",
|
||||
"│ │ └─LayerNorm: 3-8 [128, 16, 256] (512)\n",
|
||||
"│ │ └─SiLU: 3-9 [128, 16, 256] --\n",
|
||||
"│ └─ModuleDict: 2-2 -- --\n",
|
||||
"│ │ └─Linear: 3-10 [128, 16, 16536] (4,249,752)\n",
|
||||
"==========================================================================================\n",
|
||||
"Total params: 4,775,576\n",
|
||||
"Trainable params: 0\n",
|
||||
"Non-trainable params: 4,775,576\n",
|
||||
"Total mult-adds (M): 611.27\n",
|
||||
"==========================================================================================\n",
|
||||
"Input size (MB): 12.58\n",
|
||||
"Forward/backward pass size (MB): 296.09\n",
|
||||
"Params size (MB): 19.10\n",
|
||||
"Estimated Total Size (MB): 327.78\n",
|
||||
"==========================================================================================\n",
|
||||
"Summary Failed for reward Failed to run torchinfo. See above stack traces for more details. Executed layers up to: [Sequential: 1, Linear: 2, LayerNorm: 2, SiLU: 2, Linear: 2, LayerNorm: 2, SiLU: 2, Linear: 2, LayerNorm: 2, SiLU: 2, Linear: 1]\n",
|
||||
"Summary Failed for cont Failed to run torchinfo. See above stack traces for more details. Executed layers up to: [Sequential: 1, Linear: 2, LayerNorm: 2, SiLU: 2, Linear: 2, LayerNorm: 2, SiLU: 2, Linear: 2, LayerNorm: 2, SiLU: 2, Linear: 1]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# heads\n",
|
||||
"feat = wm.dynamics.get_feat(post)\n",
|
||||
"for name, head in wm.heads.items():\n",
|
||||
" try:\n",
|
||||
" o = summary(head, input_data=(feat,), depth=3, col_names=[\"output_size\", \"num_params\", ])\n",
|
||||
" print(name)\n",
|
||||
" print(o)\n",
|
||||
" except Exception as e:\n",
|
||||
" print(f\"Summary Failed for {name} {e}\")\n",
|
||||
" continue"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# fail as no call method\n",
|
||||
"# summary(wm.dynamics, input_data=(embed, data[\"action\"], data[\"is_first\"]), depth=3, col_names=[\"output_size\", \"num_params\", ])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": ".venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.16"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
+4
-4
@@ -320,8 +320,8 @@ class MultiEncoder(nn.Module):
|
||||
for k, v in shapes.items()
|
||||
if len(v) in (1, 2) and re.match(mlp_keys, k)
|
||||
}
|
||||
logger.info("Encoder CNN shapes:", self.cnn_shapes)
|
||||
logger.info("Encoder MLP shapes:", self.mlp_shapes)
|
||||
logger.info("Encoder CNN shapes: {}", self.cnn_shapes)
|
||||
logger.info("Encoder MLP shapes: {}", self.mlp_shapes)
|
||||
|
||||
self.outdim = 0
|
||||
if self.cnn_shapes:
|
||||
@@ -387,8 +387,8 @@ class MultiDecoder(nn.Module):
|
||||
for k, v in shapes.items()
|
||||
if len(v) in (1, 2) and re.match(mlp_keys, k)
|
||||
}
|
||||
logger.info("Decoder CNN shapes: %s", self.cnn_shapes)
|
||||
logger.info("Decoder MLP shapes: %s", self.mlp_shapes)
|
||||
logger.info("Decoder CNN shapes: {}", self.cnn_shapes)
|
||||
logger.info("Decoder MLP shapes: {}", self.mlp_shapes)
|
||||
|
||||
if self.cnn_shapes:
|
||||
some_shape = list(self.cnn_shapes.values())[0]
|
||||
|
||||
Generated
+12
-1
@@ -3578,6 +3578,17 @@ type = "legacy"
|
||||
url = "https://download.pytorch.org/whl/cu121"
|
||||
reference = "pytorch"
|
||||
|
||||
[[package]]
|
||||
name = "torchinfo"
|
||||
version = "1.8.0"
|
||||
description = "Model summary in PyTorch, based off of the original torchsummary."
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "torchinfo-1.8.0-py3-none-any.whl", hash = "sha256:2e911c2918603f945c26ff21a3a838d12709223dc4ccf243407bce8b6e897b46"},
|
||||
{file = "torchinfo-1.8.0.tar.gz", hash = "sha256:72e94b0e9a3e64dc583a8e5b7940b8938a1ac0f033f795457f27e6f4e7afa2e9"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tornado"
|
||||
version = "6.4"
|
||||
@@ -3836,4 +3847,4 @@ test = ["big-O", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more-it
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.9"
|
||||
content-hash = "8d04aef5b114f7ae76dc03bc61d308f1b239d390b0c71fab7d0c8f467cc95dd4"
|
||||
content-hash = "0275da73363d94f6a5cdadc9662c1b254ef50310aabce3aa663552aa4802b001"
|
||||
|
||||
@@ -38,6 +38,7 @@ imageio = "^2.34.1"
|
||||
craftax = {path = "/media/wassname/SGIronWolf/projects5/2024/Craftax", develop = true }
|
||||
# craftax = {git = "https://github.com/wassname/Craftax" , develop = true }
|
||||
chex = "^0.1.86"
|
||||
torchinfo = "^1.8.0"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
ipywidgets = "^8.1.3"
|
||||
|
||||
Reference in New Issue
Block a user