fix mem overflow, torchinfo

This commit is contained in:
wassname
2024-06-07 06:00:27 +08:00
parent c77993207a
commit 718e92a9a1
8 changed files with 1145 additions and 336 deletions
+5 -1
View File
@@ -69,7 +69,7 @@ defaults:
model_lr: 1e-4
opt_eps: 1e-8
grad_clip: 1000
dataset_size: 1000000
dataset_size: 1_000_000
opt: 'adam'
# Behavior.
@@ -147,6 +147,7 @@ craftax:
reward_head: {layers: 5}
cont_head: {layers: 5}
imag_gradient: 'reinforce'
time_limit: 4000
craftax_small:
task: craftax_Craftax-Symbolic-AutoReset-v1
@@ -171,6 +172,7 @@ craftax_small:
imag_gradient: 'reinforce'
batch_size: 256
batch_length: 32
time_limit: 4000
craftax_smaller:
task: craftax_Craftax-Symbolic-AutoReset-v1
@@ -195,6 +197,8 @@ craftax_smaller:
imag_gradient: 'reinforce'
batch_size: 256
batch_length: 32
time_limit: 4000
dataset_size: 20_000
atari100k:
steps: 4e5
+7 -1
View File
@@ -350,7 +350,7 @@ def main(config):
}
torch.save(items_to_save, logdir / "latest.pt")
logger.info(f"Saved model to {logdir / 'latest.pt'}")
# pbar.update(agent._step-pbar.n) # 16858 at a time
pbar.update(agent._step-pbar.n) # 16858 at a time
for env in train_envs + eval_envs:
try:
env.close()
@@ -358,6 +358,7 @@ def main(config):
pass
def parse_args(argv=None):
# first load config name as arg from command line
parser = argparse.ArgumentParser()
parser.add_argument("--configs", nargs="+")
if argv is None:
@@ -382,11 +383,16 @@ def parse_args(argv=None):
for name in name_list:
recursive_update(defaults, configs[name])
# defaults = {k:tools.args_type(v)(v) for k, v in defaults.items()}
# config = argparse.Namespace(**defaults)
# now use argparse to parse config, allowing us to override config with any extra args from cli. You can even use -h
parser = argparse.ArgumentParser()
for key, value in sorted(defaults.items(), key=lambda x: x[0]):
arg_type = tools.args_type(value)
parser.add_argument(f"--{key}", type=arg_type, default=arg_type(value))
args = parser.parse_args(remaining)
logger.info(f"config={args}")
return args
if __name__ == "__main__":
+1 -1
View File
@@ -6,7 +6,7 @@ export TQDM_MININTERVAL := "30"
main:
. ./.venv/bin/activate
python dreamer.py --configs craftax_small --logdir ./logdir/crafter
python dreamer.py --configs craftax_smaller --logdir ./logdir/crafterer
logs:
tensorboard --logdir logdir/craftax
+2 -4
View File
@@ -5,7 +5,7 @@ from torch import nn
import networks
import tools
from loguru import logger
from torchinfo import summary
from envs.craftax_env import state2img
to_np = lambda x: x.detach().cpu().numpy()
@@ -99,9 +99,7 @@ class WorldModel(nn.Module):
opt=config.opt,
use_amp=self._use_amp,
)
logger.info(
f"Optimizer model_opt has {sum(param.numel() for param in self.parameters())} variables."
)
logger.info(f"World Model\n{summary(self, row_settings=['var_names'],)}")
# other losses are scaled by 1.0.
self._scales = dict(
reward=config.reward_head["loss_scale"],
File diff suppressed because it is too large Load Diff
+142
View File
@@ -0,0 +1,142 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from craftax.environment_base.util import load_compressed_pickle"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"import os, sys\n",
"os.sys.path.append('/media/wassname/SGIronWolf/projects5/2024/Craftax/craftax/craftax/')"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"ename": "ModuleNotFoundError",
"evalue": "No module named 'craftax.craftax_state'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[14], line 11\u001b[0m\n\u001b[1;32m 8\u001b[0m data \u001b[38;5;241m=\u001b[39m pickle\u001b[38;5;241m.\u001b[39mload(data, errors\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mignore\u001b[39m\u001b[38;5;124m'\u001b[39m, fix_imports\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[1;32m 9\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m data\n\u001b[0;32m---> 11\u001b[0m \u001b[43mload_compressed_pickle\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43m/home/wassname/Downloads/people/run1.pbz2\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m\n",
"Cell \u001b[0;32mIn[14], line 8\u001b[0m, in \u001b[0;36mload_compressed_pickle\u001b[0;34m(file)\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mload_compressed_pickle\u001b[39m(file: \u001b[38;5;28mstr\u001b[39m):\n\u001b[1;32m 7\u001b[0m data \u001b[38;5;241m=\u001b[39m bz2\u001b[38;5;241m.\u001b[39mBZ2File(file, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mrb\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m----> 8\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[43mpickle\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mload\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43merrors\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mignore\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfix_imports\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[1;32m 9\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m data\n",
"\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'craftax.craftax_state'"
]
},
{
"ename": "",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[1;31mThe Kernel crashed while executing code in the current cell or a previous cell. \n",
"\u001b[1;31mPlease review the code in the cell(s) to identify a possible cause of the failure. \n",
"\u001b[1;31mClick <a href='https://aka.ms/vscodeJupyterKernelCrash'>here</a> for more info. \n",
"\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
]
}
],
"source": [
"import pickle, bz2\n",
"import craftax.craftax.craftax_state\n",
"craftax.craftax_state = craftax.craftax.craftax_state\n",
"\n",
"\n",
"def load_compressed_pickle(file: str):\n",
" data = bz2.BZ2File(file, \"rb\")\n",
" data = pickle.load(data, errors='ignore', fix_imports=False)\n",
" return data\n",
"\n",
"load_compressed_pickle('/home/wassname/Downloads/people/run1.pbz2')"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[0;31mSignature:\u001b[0m\n",
"\u001b[0mpickle\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mfile\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mfix_imports\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mencoding\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'ASCII'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0merrors\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'strict'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mbuffers\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mDocstring:\u001b[0m\n",
"Read and return an object from the pickle data stored in a file.\n",
"\n",
"This is equivalent to ``Unpickler(file).load()``, but may be more\n",
"efficient.\n",
"\n",
"The protocol version of the pickle is detected automatically, so no\n",
"protocol argument is needed. Bytes past the pickled object's\n",
"representation are ignored.\n",
"\n",
"The argument *file* must have two methods, a read() method that takes\n",
"an integer argument, and a readline() method that requires no\n",
"arguments. Both methods should return bytes. Thus *file* can be a\n",
"binary file object opened for reading, an io.BytesIO object, or any\n",
"other custom object that meets this interface.\n",
"\n",
"Optional keyword arguments are *fix_imports*, *encoding* and *errors*,\n",
"which are used to control compatibility support for pickle stream\n",
"generated by Python 2. If *fix_imports* is True, pickle will try to\n",
"map the old Python 2 names to the new names used in Python 3. The\n",
"*encoding* and *errors* tell pickle how to decode 8-bit string\n",
"instances pickled by Python 2; these default to 'ASCII' and 'strict',\n",
"respectively. The *encoding* can be 'bytes' to read these 8-bit\n",
"string instances as bytes objects.\n",
"\u001b[0;31mType:\u001b[0m builtin_function_or_method"
]
}
],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
+8
View File
@@ -9,6 +9,10 @@ from torch import distributions as torchd
from loguru import logger
import tools
from einops import rearrange
from torchinfo import summary
def my_summary(model, input_data):
return summary(model, input_data, col_names=('input_size', 'output_size', 'num_params', 'mult_adds'), verbose=0, row_settings=['depth', 'var_names', 'ascii_only'])
class RSSM(nn.Module):
@@ -332,6 +336,7 @@ class MultiEncoder(nn.Module):
input_shape, cnn_depth, act, norm, kernel_size, minres
)
self.outdim += self._cnn.outdim
logger.debug(f"Encoder cnn\n{my_summary(self._cnn, (1,)+input_shape)}")
if self.mlp_shapes:
input_size = sum([sum(v) for v in self.mlp_shapes.values()])
self._mlp = MLP(
@@ -344,6 +349,7 @@ class MultiEncoder(nn.Module):
symlog_inputs=symlog_inputs,
name="Encoder",
)
logger.debug(f"Encoder mlp\n{my_summary(self._mlp, (1,input_size))}")
self.outdim += mlp_units
def forward(self, obs):
@@ -405,6 +411,7 @@ class MultiDecoder(nn.Module):
outscale=outscale,
cnn_sigmoid=cnn_sigmoid,
)
logger.debug(f"Decoder cnn\n{my_summary(self._cnn, (1,1,feat_size))}")
if self.mlp_shapes:
self._mlp = MLP(
feat_size,
@@ -417,6 +424,7 @@ class MultiDecoder(nn.Module):
outscale=outscale,
name="Decoder",
)
logger.debug(f"Decoder mlp\n{my_summary(self._mlp, (1,feat_size))}")
self._image_dist = image_dist
def forward(self, features):
+11 -3
View File
@@ -15,7 +15,7 @@ from torch import nn
from torch.nn import functional as F
from torch import distributions as torchd
from torch.utils.tensorboard import SummaryWriter
from contextlib import contextmanager
to_np = lambda x: x.detach().cpu().numpy()
@@ -82,7 +82,7 @@ class Logger:
scalars.append(("fps", self._compute_fps(step)))
# print out the episode stats
stats = " / ".join(f"{k.replace('log_achievement_', '')} <red>{v:.1f}</red>" for k, v in scalars)
logger.opt(colors=True).info(f"[{step}] {stats}")
logger.opt(colors=True).debug(f"[{step}] {stats}")
with (self._logdir / "metrics.jsonl").open("a") as f:
f.write(json.dumps({"step": step, **dict(scalars)}) + "\n")
for name, value in scalars:
@@ -127,6 +127,14 @@ class Logger:
self._writer.add_video(name, value, step, 16)
@contextmanager
def cond_tqdm(pbar=None, *args, **kwargs):
if pbar is None:
with tqdm(*args, **kwargs) as pbar:
yield pbar
else:
yield pbar
def simulate(
agent,
envs,
@@ -151,7 +159,7 @@ def simulate(
reward = [0] * len(envs)
else:
step, episode, done, length, obs, agent_state, reward = state
with tqdm(total=steps, disable=pbar is None) as pbar:
with cond_tqdm(total=steps, pbar=pbar) as pbar:
while (steps and step < steps) or (episodes and episode < episodes):
# reset envs if necessary
if done.any():