fix mem overflow, torchinfo

This commit is contained in:
wassname
2024-06-07 06:00:27 +08:00
parent c77993207a
commit 718e92a9a1
8 changed files with 1145 additions and 336 deletions
+5 -1
View File
@@ -69,7 +69,7 @@ defaults:
model_lr: 1e-4 model_lr: 1e-4
opt_eps: 1e-8 opt_eps: 1e-8
grad_clip: 1000 grad_clip: 1000
dataset_size: 1000000 dataset_size: 1_000_000
opt: 'adam' opt: 'adam'
# Behavior. # Behavior.
@@ -147,6 +147,7 @@ craftax:
reward_head: {layers: 5} reward_head: {layers: 5}
cont_head: {layers: 5} cont_head: {layers: 5}
imag_gradient: 'reinforce' imag_gradient: 'reinforce'
time_limit: 4000
craftax_small: craftax_small:
task: craftax_Craftax-Symbolic-AutoReset-v1 task: craftax_Craftax-Symbolic-AutoReset-v1
@@ -171,6 +172,7 @@ craftax_small:
imag_gradient: 'reinforce' imag_gradient: 'reinforce'
batch_size: 256 batch_size: 256
batch_length: 32 batch_length: 32
time_limit: 4000
craftax_smaller: craftax_smaller:
task: craftax_Craftax-Symbolic-AutoReset-v1 task: craftax_Craftax-Symbolic-AutoReset-v1
@@ -195,6 +197,8 @@ craftax_smaller:
imag_gradient: 'reinforce' imag_gradient: 'reinforce'
batch_size: 256 batch_size: 256
batch_length: 32 batch_length: 32
time_limit: 4000
dataset_size: 20_000
atari100k: atari100k:
steps: 4e5 steps: 4e5
+7 -1
View File
@@ -350,7 +350,7 @@ def main(config):
} }
torch.save(items_to_save, logdir / "latest.pt") torch.save(items_to_save, logdir / "latest.pt")
logger.info(f"Saved model to {logdir / 'latest.pt'}") logger.info(f"Saved model to {logdir / 'latest.pt'}")
# pbar.update(agent._step-pbar.n) # 16858 at a time pbar.update(agent._step-pbar.n) # 16858 at a time
for env in train_envs + eval_envs: for env in train_envs + eval_envs:
try: try:
env.close() env.close()
@@ -358,6 +358,7 @@ def main(config):
pass pass
def parse_args(argv=None): def parse_args(argv=None):
# first load config name as arg from command line
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--configs", nargs="+") parser.add_argument("--configs", nargs="+")
if argv is None: if argv is None:
@@ -382,11 +383,16 @@ def parse_args(argv=None):
for name in name_list: for name in name_list:
recursive_update(defaults, configs[name]) recursive_update(defaults, configs[name])
# defaults = {k:tools.args_type(v)(v) for k, v in defaults.items()}
# config = argparse.Namespace(**defaults)
# Now build an argparse parser from the merged config so that any value can be overridden by extra CLI args (`-h` lists them).
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
for key, value in sorted(defaults.items(), key=lambda x: x[0]): for key, value in sorted(defaults.items(), key=lambda x: x[0]):
arg_type = tools.args_type(value) arg_type = tools.args_type(value)
parser.add_argument(f"--{key}", type=arg_type, default=arg_type(value)) parser.add_argument(f"--{key}", type=arg_type, default=arg_type(value))
args = parser.parse_args(remaining) args = parser.parse_args(remaining)
logger.info(f"config={args}")
return args return args
if __name__ == "__main__": if __name__ == "__main__":
+1 -1
View File
@@ -6,7 +6,7 @@ export TQDM_MININTERVAL := "30"
main: main:
. ./.venv/bin/activate . ./.venv/bin/activate
python dreamer.py --configs craftax_small --logdir ./logdir/crafter python dreamer.py --configs craftax_smaller --logdir ./logdir/crafterer
logs: logs:
tensorboard --logdir logdir/craftax tensorboard --logdir logdir/craftax
+2 -4
View File
@@ -5,7 +5,7 @@ from torch import nn
import networks import networks
import tools import tools
from loguru import logger from loguru import logger
from torchinfo import summary
from envs.craftax_env import state2img from envs.craftax_env import state2img
to_np = lambda x: x.detach().cpu().numpy() to_np = lambda x: x.detach().cpu().numpy()
@@ -99,9 +99,7 @@ class WorldModel(nn.Module):
opt=config.opt, opt=config.opt,
use_amp=self._use_amp, use_amp=self._use_amp,
) )
logger.info( logger.info(f"World Model\n{summary(self, row_settings=['var_names'],)}")
f"Optimizer model_opt has {sum(param.numel() for param in self.parameters())} variables."
)
# other losses are scaled by 1.0. # other losses are scaled by 1.0.
self._scales = dict( self._scales = dict(
reward=config.reward_head["loss_scale"], reward=config.reward_head["loss_scale"],
File diff suppressed because it is too large Load Diff
+142
View File
@@ -0,0 +1,142 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from craftax.environment_base.util import load_compressed_pickle"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"import os, sys\n",
"os.sys.path.append('/media/wassname/SGIronWolf/projects5/2024/Craftax/craftax/craftax/')"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"ename": "ModuleNotFoundError",
"evalue": "No module named 'craftax.craftax_state'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[14], line 11\u001b[0m\n\u001b[1;32m 8\u001b[0m data \u001b[38;5;241m=\u001b[39m pickle\u001b[38;5;241m.\u001b[39mload(data, errors\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mignore\u001b[39m\u001b[38;5;124m'\u001b[39m, fix_imports\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[1;32m 9\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m data\n\u001b[0;32m---> 11\u001b[0m \u001b[43mload_compressed_pickle\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43m/home/wassname/Downloads/people/run1.pbz2\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m\n",
"Cell \u001b[0;32mIn[14], line 8\u001b[0m, in \u001b[0;36mload_compressed_pickle\u001b[0;34m(file)\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mload_compressed_pickle\u001b[39m(file: \u001b[38;5;28mstr\u001b[39m):\n\u001b[1;32m 7\u001b[0m data \u001b[38;5;241m=\u001b[39m bz2\u001b[38;5;241m.\u001b[39mBZ2File(file, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mrb\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m----> 8\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[43mpickle\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mload\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43merrors\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mignore\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfix_imports\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[1;32m 9\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m data\n",
"\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'craftax.craftax_state'"
]
},
{
"ename": "",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[1;31mThe Kernel crashed while executing code in the current cell or a previous cell. \n",
"\u001b[1;31mPlease review the code in the cell(s) to identify a possible cause of the failure. \n",
"\u001b[1;31mClick <a href='https://aka.ms/vscodeJupyterKernelCrash'>here</a> for more info. \n",
"\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
]
}
],
"source": [
"import pickle, bz2\n",
"import craftax.craftax.craftax_state\n",
"craftax.craftax_state = craftax.craftax.craftax_state\n",
"\n",
"\n",
"def load_compressed_pickle(file: str):\n",
" data = bz2.BZ2File(file, \"rb\")\n",
" data = pickle.load(data, errors='ignore', fix_imports=False)\n",
" return data\n",
"\n",
"load_compressed_pickle('/home/wassname/Downloads/people/run1.pbz2')"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[0;31mSignature:\u001b[0m\n",
"\u001b[0mpickle\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mfile\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mfix_imports\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mencoding\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'ASCII'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0merrors\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'strict'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mbuffers\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mDocstring:\u001b[0m\n",
"Read and return an object from the pickle data stored in a file.\n",
"\n",
"This is equivalent to ``Unpickler(file).load()``, but may be more\n",
"efficient.\n",
"\n",
"The protocol version of the pickle is detected automatically, so no\n",
"protocol argument is needed. Bytes past the pickled object's\n",
"representation are ignored.\n",
"\n",
"The argument *file* must have two methods, a read() method that takes\n",
"an integer argument, and a readline() method that requires no\n",
"arguments. Both methods should return bytes. Thus *file* can be a\n",
"binary file object opened for reading, an io.BytesIO object, or any\n",
"other custom object that meets this interface.\n",
"\n",
"Optional keyword arguments are *fix_imports*, *encoding* and *errors*,\n",
"which are used to control compatibility support for pickle stream\n",
"generated by Python 2. If *fix_imports* is True, pickle will try to\n",
"map the old Python 2 names to the new names used in Python 3. The\n",
"*encoding* and *errors* tell pickle how to decode 8-bit string\n",
"instances pickled by Python 2; these default to 'ASCII' and 'strict',\n",
"respectively. The *encoding* can be 'bytes' to read these 8-bit\n",
"string instances as bytes objects.\n",
"\u001b[0;31mType:\u001b[0m builtin_function_or_method"
]
}
],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
+8
View File
@@ -9,6 +9,10 @@ from torch import distributions as torchd
from loguru import logger from loguru import logger
import tools import tools
from einops import rearrange from einops import rearrange
from torchinfo import summary
def my_summary(model, input_data):
    """Return a torchinfo summary of ``model`` with a fixed report layout.

    ``input_data`` is forwarded unchanged as the second positional argument
    of ``torchinfo.summary``. The report shows input size, output size,
    parameter count and mult-adds per layer, with depth and variable names
    in ASCII-only rows. ``verbose=0`` is passed so the caller decides how to
    emit the returned summary object (e.g. by formatting it into a log line).
    """
    # Columns to include for every layer row.
    columns = ('input_size', 'output_size', 'num_params', 'mult_adds')
    # Row decorations: nesting depth, attribute names, plain-ASCII tree.
    row_options = ['depth', 'var_names', 'ascii_only']
    return summary(
        model,
        input_data,
        col_names=columns,
        verbose=0,
        row_settings=row_options,
    )
class RSSM(nn.Module): class RSSM(nn.Module):
@@ -332,6 +336,7 @@ class MultiEncoder(nn.Module):
input_shape, cnn_depth, act, norm, kernel_size, minres input_shape, cnn_depth, act, norm, kernel_size, minres
) )
self.outdim += self._cnn.outdim self.outdim += self._cnn.outdim
logger.debug(f"Encoder cnn\n{my_summary(self._cnn, (1,)+input_shape)}")
if self.mlp_shapes: if self.mlp_shapes:
input_size = sum([sum(v) for v in self.mlp_shapes.values()]) input_size = sum([sum(v) for v in self.mlp_shapes.values()])
self._mlp = MLP( self._mlp = MLP(
@@ -344,6 +349,7 @@ class MultiEncoder(nn.Module):
symlog_inputs=symlog_inputs, symlog_inputs=symlog_inputs,
name="Encoder", name="Encoder",
) )
logger.debug(f"Encoder mlp\n{my_summary(self._mlp, (1,input_size))}")
self.outdim += mlp_units self.outdim += mlp_units
def forward(self, obs): def forward(self, obs):
@@ -405,6 +411,7 @@ class MultiDecoder(nn.Module):
outscale=outscale, outscale=outscale,
cnn_sigmoid=cnn_sigmoid, cnn_sigmoid=cnn_sigmoid,
) )
logger.debug(f"Decoder cnn\n{my_summary(self._cnn, (1,1,feat_size))}")
if self.mlp_shapes: if self.mlp_shapes:
self._mlp = MLP( self._mlp = MLP(
feat_size, feat_size,
@@ -417,6 +424,7 @@ class MultiDecoder(nn.Module):
outscale=outscale, outscale=outscale,
name="Decoder", name="Decoder",
) )
logger.debug(f"Decoder mlp\n{my_summary(self._mlp, (1,feat_size))}")
self._image_dist = image_dist self._image_dist = image_dist
def forward(self, features): def forward(self, features):
+11 -3
View File
@@ -15,7 +15,7 @@ from torch import nn
from torch.nn import functional as F from torch.nn import functional as F
from torch import distributions as torchd from torch import distributions as torchd
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from contextlib import contextmanager
to_np = lambda x: x.detach().cpu().numpy() to_np = lambda x: x.detach().cpu().numpy()
@@ -82,7 +82,7 @@ class Logger:
scalars.append(("fps", self._compute_fps(step))) scalars.append(("fps", self._compute_fps(step)))
# print out the episode stats # print out the episode stats
stats = " / ".join(f"{k.replace('log_achievement_', '')} <red>{v:.1f}</red>" for k, v in scalars) stats = " / ".join(f"{k.replace('log_achievement_', '')} <red>{v:.1f}</red>" for k, v in scalars)
logger.opt(colors=True).info(f"[{step}] {stats}") logger.opt(colors=True).debug(f"[{step}] {stats}")
with (self._logdir / "metrics.jsonl").open("a") as f: with (self._logdir / "metrics.jsonl").open("a") as f:
f.write(json.dumps({"step": step, **dict(scalars)}) + "\n") f.write(json.dumps({"step": step, **dict(scalars)}) + "\n")
for name, value in scalars: for name, value in scalars:
@@ -127,6 +127,14 @@ class Logger:
self._writer.add_video(name, value, step, 16) self._writer.add_video(name, value, step, 16)
@contextmanager
def cond_tqdm(pbar=None, *args, **kwargs):
    """Yield a progress bar, creating one only when the caller has none.

    Args:
        pbar: An existing progress bar to reuse. When it is not ``None`` it
            is yielded as-is; this function neither enters nor exits it.
        *args, **kwargs: Forwarded to ``tqdm(...)`` only when ``pbar`` is
            ``None``; ignored otherwise.

    Yields:
        The supplied ``pbar``, or a new ``tqdm`` bar whose ``with`` context
        is exited when the caller's ``with`` block ends.
    """
    if pbar is not None:
        # Caller owns this bar: hand it back untouched.
        yield pbar
        return
    # No bar supplied: create one scoped to the caller's ``with`` block.
    with tqdm(*args, **kwargs) as fresh_bar:
        yield fresh_bar
def simulate( def simulate(
agent, agent,
envs, envs,
@@ -151,7 +159,7 @@ def simulate(
reward = [0] * len(envs) reward = [0] * len(envs)
else: else:
step, episode, done, length, obs, agent_state, reward = state step, episode, done, length, obs, agent_state, reward = state
with tqdm(total=steps, disable=pbar is None) as pbar: with cond_tqdm(total=steps, pbar=pbar) as pbar:
while (steps and step < steps) or (episodes and episode < episodes): while (steps and step < steps) or (episodes and episode < episodes):
# reset envs if necessary # reset envs if necessary
if done.any(): if done.any():