This commit is contained in:
wassname
2024-06-03 08:29:29 +08:00
parent f853c03f4b
commit 5a0e8dc5ac
5 changed files with 16 additions and 12 deletions
+1 -1
View File
@@ -117,7 +117,7 @@ crafter:
action_repeat: 1
envs: 1
train_ratio: 512
video_pred_log: true
video_pred_log: false # FIXME
dyn_hidden: 1024
dyn_deter: 4096
units: 1024
+3 -2
View File
@@ -67,7 +67,7 @@ class Dreamer(nn.Module):
if self._should_pretrain()
else self._should_train(step)
)
for _ in tqdm(range(steps), desc='minitrain'):
for _ in range(steps):
self._train(next(self._dataset))
self._update_count += 1
self._metrics["update_count"] = self._update_count
@@ -318,6 +318,7 @@ def main(config):
is_eval=True,
episodes=config.eval_episode_num,
video_pred_log=config.video_pred_log,
pbar=pbar,
)
if config.video_pred_log:
video_pred = agent._wm.video_pred(next(eval_dataset))
@@ -339,7 +340,7 @@ def main(config):
}
torch.save(items_to_save, logdir / "latest.pt")
logger.info(f"Saved model to {logdir / 'latest.pt'}")
pbar.update_to(agent._step)
pbar.update(agent._step-pbar.n) # 16858 at a time
for env in train_envs + eval_envs:
try:
env.close()
+4 -3
View File
@@ -4,6 +4,7 @@ from torch import nn
import networks
import tools
from loguru import logger
to_np = lambda x: x.detach().cpu().numpy()
@@ -96,7 +97,7 @@ class WorldModel(nn.Module):
opt=config.opt,
use_amp=self._use_amp,
)
print(
logger.info(
f"Optimizer model_opt has {sum(param.numel() for param in self.parameters())} variables."
)
# other losses are scaled by 1.0.
@@ -261,7 +262,7 @@ class ImagBehavior(nn.Module):
config.actor["grad_clip"],
**kw,
)
print(
logger.info(
f"Optimizer actor_opt has {sum(param.numel() for param in self.actor.parameters())} variables."
)
self._value_opt = tools.Optimizer(
@@ -272,7 +273,7 @@ class ImagBehavior(nn.Module):
config.critic["grad_clip"],
**kw,
)
print(
logger.info(
f"Optimizer value_opt has {sum(param.numel() for param in self.value.parameters())} variables."
)
if self._config.reward_EMA:
+2 -2
View File
@@ -387,8 +387,8 @@ class MultiDecoder(nn.Module):
for k, v in shapes.items()
if len(v) in (1, 2) and re.match(mlp_keys, k)
}
print("Decoder CNN shapes:", self.cnn_shapes)
print("Decoder MLP shapes:", self.mlp_shapes)
# Interpolate eagerly: "%s" placeholders are not expanded by loguru's
# str.format-style formatting, so the shapes would be silently dropped.
# An f-string logs the same text the original print() calls produced.
logger.info(f"Decoder CNN shapes: {self.cnn_shapes}")
logger.info(f"Decoder MLP shapes: {self.mlp_shapes}")
if self.cnn_shapes:
some_shape = list(self.cnn_shapes.values())[0]
+6 -4
View File
@@ -7,9 +7,9 @@ import pathlib
import re
import time
import random
from tqdm.auto import tqdm
import numpy as np
from loguru import logger
import torch
from torch import nn
from torch.nn import functional as F
@@ -51,7 +51,7 @@ class TimeRecording:
def __exit__(self, *args):
self._nd.record()
torch.cuda.synchronize()
print(self._comment, self._st.elapsed_time(self._nd) / 1000)
# loguru treats extra positional args as str.format() arguments for the
# message; the comment has no "{}" placeholder, so the elapsed time would be
# dropped. Build the full line up front, matching the old print() output.
logger.info(f"{self._comment} {self._st.elapsed_time(self._nd) / 1000}")
class Logger:
@@ -80,7 +80,7 @@ class Logger:
scalars = list(self._scalars.items())
if fps:
scalars.append(("fps", self._compute_fps(step)))
print(f"[{step}]", " / ".join(f"{k} {v:.1f}" for k, v in scalars))
# Join the step tag and the scalar summary into one message: passing the
# summary as a second positional arg makes loguru treat it as a format
# argument for a message without placeholders, so it was never logged.
logger.info(f"[{step}] " + " / ".join(f"{k} {v:.1f}" for k, v in scalars))
with (self._logdir / "metrics.jsonl").open("a") as f:
f.write(json.dumps({"step": step, **dict(scalars)}) + "\n")
for name, value in scalars:
@@ -137,6 +137,7 @@ def simulate(
episodes=0,
state=None,
video_pred_log=True,
pbar=None
):
# initialize or unpack simulation state
if state is None:
@@ -185,6 +186,7 @@ def simulate(
episode += int(done.sum())
length += 1
step += len(envs)
# pbar defaults to None in the signature, so only advance the progress bar
# when the caller actually supplied one.
if pbar is not None:
    pbar.update(len(envs))
length *= 1 - done
# add to cache
for a, result, env in zip(action, results, envs):