mirror of
https://github.com/wassname/dreamerv3-torch.git
synced 2026-06-27 16:30:24 +08:00
fix mem overflow, torchinfo
This commit is contained in:
+5
-1
@@ -69,7 +69,7 @@ defaults:
|
||||
model_lr: 1e-4
|
||||
opt_eps: 1e-8
|
||||
grad_clip: 1000
|
||||
dataset_size: 1000000
|
||||
dataset_size: 1_000_000
|
||||
opt: 'adam'
|
||||
|
||||
# Behavior.
|
||||
@@ -147,6 +147,7 @@ craftax:
|
||||
reward_head: {layers: 5}
|
||||
cont_head: {layers: 5}
|
||||
imag_gradient: 'reinforce'
|
||||
time_limit: 4000
|
||||
|
||||
craftax_small:
|
||||
task: craftax_Craftax-Symbolic-AutoReset-v1
|
||||
@@ -171,6 +172,7 @@ craftax_small:
|
||||
imag_gradient: 'reinforce'
|
||||
batch_size: 256
|
||||
batch_length: 32
|
||||
time_limit: 4000
|
||||
|
||||
craftax_smaller:
|
||||
task: craftax_Craftax-Symbolic-AutoReset-v1
|
||||
@@ -195,6 +197,8 @@ craftax_smaller:
|
||||
imag_gradient: 'reinforce'
|
||||
batch_size: 256
|
||||
batch_length: 32
|
||||
time_limit: 4000
|
||||
dataset_size: 20_000
|
||||
|
||||
atari100k:
|
||||
steps: 4e5
|
||||
|
||||
+7
-1
@@ -350,7 +350,7 @@ def main(config):
|
||||
}
|
||||
torch.save(items_to_save, logdir / "latest.pt")
|
||||
logger.info(f"Saved model to {logdir / 'latest.pt'}")
|
||||
# pbar.update(agent._step-pbar.n) # 16858 at a time
|
||||
pbar.update(agent._step-pbar.n) # 16858 at a time
|
||||
for env in train_envs + eval_envs:
|
||||
try:
|
||||
env.close()
|
||||
@@ -358,6 +358,7 @@ def main(config):
|
||||
pass
|
||||
|
||||
def parse_args(argv=None):
|
||||
# first load config name as arg from command line
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--configs", nargs="+")
|
||||
if argv is None:
|
||||
@@ -382,11 +383,16 @@ def parse_args(argv=None):
|
||||
for name in name_list:
|
||||
recursive_update(defaults, configs[name])
|
||||
|
||||
# defaults = {k:tools.args_type(v)(v) for k, v in defaults.items()}
|
||||
# config = argparse.Namespace(**defaults)
|
||||
|
||||
# now use argparse to parse config, allowing us to override config with any extra args from cli. You can even use -h
|
||||
parser = argparse.ArgumentParser()
|
||||
for key, value in sorted(defaults.items(), key=lambda x: x[0]):
|
||||
arg_type = tools.args_type(value)
|
||||
parser.add_argument(f"--{key}", type=arg_type, default=arg_type(value))
|
||||
args = parser.parse_args(remaining)
|
||||
logger.info(f"config={args}")
|
||||
return args
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -6,7 +6,7 @@ export TQDM_MININTERVAL := "30"
|
||||
|
||||
main:
|
||||
. ./.venv/bin/activate
|
||||
python dreamer.py --configs craftax_small --logdir ./logdir/crafter
|
||||
python dreamer.py --configs craftax_smaller --logdir ./logdir/crafterer
|
||||
|
||||
logs:
|
||||
tensorboard --logdir logdir/craftax
|
||||
|
||||
@@ -5,7 +5,7 @@ from torch import nn
|
||||
import networks
|
||||
import tools
|
||||
from loguru import logger
|
||||
|
||||
from torchinfo import summary
|
||||
from envs.craftax_env import state2img
|
||||
|
||||
to_np = lambda x: x.detach().cpu().numpy()
|
||||
@@ -99,9 +99,7 @@ class WorldModel(nn.Module):
|
||||
opt=config.opt,
|
||||
use_amp=self._use_amp,
|
||||
)
|
||||
logger.info(
|
||||
f"Optimizer model_opt has {sum(param.numel() for param in self.parameters())} variables."
|
||||
)
|
||||
logger.info(f"World Model\n{summary(self, row_settings=['var_names'],)}")
|
||||
# other losses are scaled by 1.0.
|
||||
self._scales = dict(
|
||||
reward=config.reward_head["loss_scale"],
|
||||
|
||||
+969
-326
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,142 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from craftax.environment_base.util import load_compressed_pickle"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os, sys\n",
|
||||
"os.sys.path.append('/media/wassname/SGIronWolf/projects5/2024/Craftax/craftax/craftax/')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"ename": "ModuleNotFoundError",
|
||||
"evalue": "No module named 'craftax.craftax_state'",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||||
"\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
|
||||
"Cell \u001b[0;32mIn[14], line 11\u001b[0m\n\u001b[1;32m 8\u001b[0m data \u001b[38;5;241m=\u001b[39m pickle\u001b[38;5;241m.\u001b[39mload(data, errors\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mignore\u001b[39m\u001b[38;5;124m'\u001b[39m, fix_imports\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[1;32m 9\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m data\n\u001b[0;32m---> 11\u001b[0m \u001b[43mload_compressed_pickle\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43m/home/wassname/Downloads/people/run1.pbz2\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m\n",
|
||||
"Cell \u001b[0;32mIn[14], line 8\u001b[0m, in \u001b[0;36mload_compressed_pickle\u001b[0;34m(file)\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mload_compressed_pickle\u001b[39m(file: \u001b[38;5;28mstr\u001b[39m):\n\u001b[1;32m 7\u001b[0m data \u001b[38;5;241m=\u001b[39m bz2\u001b[38;5;241m.\u001b[39mBZ2File(file, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mrb\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m----> 8\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[43mpickle\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mload\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43merrors\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mignore\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfix_imports\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[1;32m 9\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m data\n",
|
||||
"\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'craftax.craftax_state'"
|
||||
]
|
||||
},
|
||||
{
|
||||
"ename": "",
|
||||
"evalue": "",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001b[1;31mThe Kernel crashed while executing code in the current cell or a previous cell. \n",
|
||||
"\u001b[1;31mPlease review the code in the cell(s) to identify a possible cause of the failure. \n",
|
||||
"\u001b[1;31mClick <a href='https://aka.ms/vscodeJupyterKernelCrash'>here</a> for more info. \n",
|
||||
"\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import pickle, bz2\n",
|
||||
"import craftax.craftax.craftax_state\n",
|
||||
"craftax.craftax_state = craftax.craftax.craftax_state\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def load_compressed_pickle(file: str):\n",
|
||||
" data = bz2.BZ2File(file, \"rb\")\n",
|
||||
" data = pickle.load(data, errors='ignore', fix_imports=False)\n",
|
||||
" return data\n",
|
||||
"\n",
|
||||
"load_compressed_pickle('/home/wassname/Downloads/people/run1.pbz2')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\u001b[0;31mSignature:\u001b[0m\n",
|
||||
"\u001b[0mpickle\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\u001b[0m\n",
|
||||
"\u001b[0;34m\u001b[0m \u001b[0mfile\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
|
||||
"\u001b[0;34m\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
|
||||
"\u001b[0;34m\u001b[0m \u001b[0mfix_imports\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
|
||||
"\u001b[0;34m\u001b[0m \u001b[0mencoding\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'ASCII'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
|
||||
"\u001b[0;34m\u001b[0m \u001b[0merrors\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'strict'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
|
||||
"\u001b[0;34m\u001b[0m \u001b[0mbuffers\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
|
||||
"\u001b[0;34m\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0;31mDocstring:\u001b[0m\n",
|
||||
"Read and return an object from the pickle data stored in a file.\n",
|
||||
"\n",
|
||||
"This is equivalent to ``Unpickler(file).load()``, but may be more\n",
|
||||
"efficient.\n",
|
||||
"\n",
|
||||
"The protocol version of the pickle is detected automatically, so no\n",
|
||||
"protocol argument is needed. Bytes past the pickled object's\n",
|
||||
"representation are ignored.\n",
|
||||
"\n",
|
||||
"The argument *file* must have two methods, a read() method that takes\n",
|
||||
"an integer argument, and a readline() method that requires no\n",
|
||||
"arguments. Both methods should return bytes. Thus *file* can be a\n",
|
||||
"binary file object opened for reading, an io.BytesIO object, or any\n",
|
||||
"other custom object that meets this interface.\n",
|
||||
"\n",
|
||||
"Optional keyword arguments are *fix_imports*, *encoding* and *errors*,\n",
|
||||
"which are used to control compatibility support for pickle stream\n",
|
||||
"generated by Python 2. If *fix_imports* is True, pickle will try to\n",
|
||||
"map the old Python 2 names to the new names used in Python 3. The\n",
|
||||
"*encoding* and *errors* tell pickle how to decode 8-bit string\n",
|
||||
"instances pickled by Python 2; these default to 'ASCII' and 'strict',\n",
|
||||
"respectively. The *encoding* can be 'bytes' to read these 8-bit\n",
|
||||
"string instances as bytes objects.\n",
|
||||
"\u001b[0;31mType:\u001b[0m builtin_function_or_method"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": ".venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.16"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
@@ -9,6 +9,10 @@ from torch import distributions as torchd
|
||||
from loguru import logger
|
||||
import tools
|
||||
from einops import rearrange
|
||||
from torchinfo import summary
|
||||
|
||||
def my_summary(model, input_data):
|
||||
return summary(model, input_data, col_names=('input_size', 'output_size', 'num_params', 'mult_adds'), verbose=0, row_settings=['depth', 'var_names', 'ascii_only'])
|
||||
|
||||
|
||||
class RSSM(nn.Module):
|
||||
@@ -332,6 +336,7 @@ class MultiEncoder(nn.Module):
|
||||
input_shape, cnn_depth, act, norm, kernel_size, minres
|
||||
)
|
||||
self.outdim += self._cnn.outdim
|
||||
logger.debug(f"Encoder cnn\n{my_summary(self._cnn, (1,)+input_shape)}")
|
||||
if self.mlp_shapes:
|
||||
input_size = sum([sum(v) for v in self.mlp_shapes.values()])
|
||||
self._mlp = MLP(
|
||||
@@ -344,6 +349,7 @@ class MultiEncoder(nn.Module):
|
||||
symlog_inputs=symlog_inputs,
|
||||
name="Encoder",
|
||||
)
|
||||
logger.debug(f"Encoder mlp\n{my_summary(self._mlp, (1,input_size))}")
|
||||
self.outdim += mlp_units
|
||||
|
||||
def forward(self, obs):
|
||||
@@ -405,6 +411,7 @@ class MultiDecoder(nn.Module):
|
||||
outscale=outscale,
|
||||
cnn_sigmoid=cnn_sigmoid,
|
||||
)
|
||||
logger.debug(f"Decoder cnn\n{my_summary(self._cnn, (1,1,feat_size))}")
|
||||
if self.mlp_shapes:
|
||||
self._mlp = MLP(
|
||||
feat_size,
|
||||
@@ -417,6 +424,7 @@ class MultiDecoder(nn.Module):
|
||||
outscale=outscale,
|
||||
name="Decoder",
|
||||
)
|
||||
logger.debug(f"Decoder mlp\n{my_summary(self._mlp, (1,feat_size))}")
|
||||
self._image_dist = image_dist
|
||||
|
||||
def forward(self, features):
|
||||
|
||||
@@ -15,7 +15,7 @@ from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from torch import distributions as torchd
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from contextlib import contextmanager
|
||||
|
||||
to_np = lambda x: x.detach().cpu().numpy()
|
||||
|
||||
@@ -82,7 +82,7 @@ class Logger:
|
||||
scalars.append(("fps", self._compute_fps(step)))
|
||||
# print out the episode stats
|
||||
stats = " / ".join(f"{k.replace('log_achievement_', '')} <red>{v:.1f}</red>" for k, v in scalars)
|
||||
logger.opt(colors=True).info(f"[{step}] {stats}")
|
||||
logger.opt(colors=True).debug(f"[{step}] {stats}")
|
||||
with (self._logdir / "metrics.jsonl").open("a") as f:
|
||||
f.write(json.dumps({"step": step, **dict(scalars)}) + "\n")
|
||||
for name, value in scalars:
|
||||
@@ -127,6 +127,14 @@ class Logger:
|
||||
self._writer.add_video(name, value, step, 16)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def cond_tqdm(pbar=None, *args, **kwargs):
|
||||
if pbar is None:
|
||||
with tqdm(*args, **kwargs) as pbar:
|
||||
yield pbar
|
||||
else:
|
||||
yield pbar
|
||||
|
||||
def simulate(
|
||||
agent,
|
||||
envs,
|
||||
@@ -151,7 +159,7 @@ def simulate(
|
||||
reward = [0] * len(envs)
|
||||
else:
|
||||
step, episode, done, length, obs, agent_state, reward = state
|
||||
with tqdm(total=steps, disable=pbar is None) as pbar:
|
||||
with cond_tqdm(total=steps, pbar=pbar) as pbar:
|
||||
while (steps and step < steps) or (episodes and episode < episodes):
|
||||
# reset envs if necessary
|
||||
if done.any():
|
||||
|
||||
Reference in New Issue
Block a user