mirror of
https://github.com/wassname/dreamerv3-torch.git
synced 2026-06-27 16:15:31 +08:00
wip
This commit is contained in:
+1
-1
@@ -117,7 +117,7 @@ crafter:
|
||||
action_repeat: 1
|
||||
envs: 1
|
||||
train_ratio: 512
|
||||
video_pred_log: true
|
||||
video_pred_log: false # FIXME
|
||||
dyn_hidden: 1024
|
||||
dyn_deter: 4096
|
||||
units: 1024
|
||||
|
||||
+3
-2
@@ -67,7 +67,7 @@ class Dreamer(nn.Module):
|
||||
if self._should_pretrain()
|
||||
else self._should_train(step)
|
||||
)
|
||||
for _ in tqdm(range(steps), desc='minitrain'):
|
||||
for _ in range(steps):
|
||||
self._train(next(self._dataset))
|
||||
self._update_count += 1
|
||||
self._metrics["update_count"] = self._update_count
|
||||
@@ -318,6 +318,7 @@ def main(config):
|
||||
is_eval=True,
|
||||
episodes=config.eval_episode_num,
|
||||
video_pred_log=config.video_pred_log,
|
||||
pbar=pbar,
|
||||
)
|
||||
if config.video_pred_log:
|
||||
video_pred = agent._wm.video_pred(next(eval_dataset))
|
||||
@@ -339,7 +340,7 @@ def main(config):
|
||||
}
|
||||
torch.save(items_to_save, logdir / "latest.pt")
|
||||
logger.info(f"Saved model to {logdir / 'latest.pt'}")
|
||||
pbar.update_to(agent._step)
|
||||
pbar.update(agent._step-pbar.n) # 16858 at a time
|
||||
for env in train_envs + eval_envs:
|
||||
try:
|
||||
env.close()
|
||||
|
||||
@@ -4,6 +4,7 @@ from torch import nn
|
||||
|
||||
import networks
|
||||
import tools
|
||||
from loguru import logger
|
||||
|
||||
to_np = lambda x: x.detach().cpu().numpy()
|
||||
|
||||
@@ -96,7 +97,7 @@ class WorldModel(nn.Module):
|
||||
opt=config.opt,
|
||||
use_amp=self._use_amp,
|
||||
)
|
||||
print(
|
||||
logger.info(
|
||||
f"Optimizer model_opt has {sum(param.numel() for param in self.parameters())} variables."
|
||||
)
|
||||
# other losses are scaled by 1.0.
|
||||
@@ -261,7 +262,7 @@ class ImagBehavior(nn.Module):
|
||||
config.actor["grad_clip"],
|
||||
**kw,
|
||||
)
|
||||
print(
|
||||
logger.info(
|
||||
f"Optimizer actor_opt has {sum(param.numel() for param in self.actor.parameters())} variables."
|
||||
)
|
||||
self._value_opt = tools.Optimizer(
|
||||
@@ -272,7 +273,7 @@ class ImagBehavior(nn.Module):
|
||||
config.critic["grad_clip"],
|
||||
**kw,
|
||||
)
|
||||
print(
|
||||
logger.info(
|
||||
f"Optimizer value_opt has {sum(param.numel() for param in self.value.parameters())} variables."
|
||||
)
|
||||
if self._config.reward_EMA:
|
||||
|
||||
+2
-2
@@ -387,8 +387,8 @@ class MultiDecoder(nn.Module):
|
||||
for k, v in shapes.items()
|
||||
if len(v) in (1, 2) and re.match(mlp_keys, k)
|
||||
}
|
||||
print("Decoder CNN shapes:", self.cnn_shapes)
|
||||
print("Decoder MLP shapes:", self.mlp_shapes)
|
||||
# loguru formats messages with str.format(), not %-interpolation: use "{}" so the
# shapes are actually rendered instead of a literal "%s".
# NOTE(review): this hunk does not add `from loguru import logger` — confirm the
# module already imports it.
logger.info("Decoder CNN shapes: {}", self.cnn_shapes)
|
||||
# loguru formats messages with str.format(); "%s" would be logged literally.
logger.info("Decoder MLP shapes: {}", self.mlp_shapes)
|
||||
|
||||
if self.cnn_shapes:
|
||||
some_shape = list(self.cnn_shapes.values())[0]
|
||||
|
||||
@@ -7,9 +7,9 @@ import pathlib
|
||||
import re
|
||||
import time
|
||||
import random
|
||||
|
||||
from tqdm.auto import tqdm
|
||||
import numpy as np
|
||||
|
||||
from loguru import logger
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
@@ -51,7 +51,7 @@ class TimeRecording:
|
||||
def __exit__(self, *args):
|
||||
self._nd.record()
|
||||
torch.cuda.synchronize()
|
||||
print(self._comment, self._st.elapsed_time(self._nd) / 1000)
|
||||
# loguru applies str.format() to the message, so extra positional arguments need
# "{}" placeholders; without them the elapsed time was silently dropped.
logger.info("{} {}", self._comment, self._st.elapsed_time(self._nd) / 1000)
|
||||
|
||||
|
||||
class Logger:
|
||||
@@ -80,7 +80,7 @@ class Logger:
|
||||
scalars = list(self._scalars.items())
|
||||
if fps:
|
||||
scalars.append(("fps", self._compute_fps(step)))
|
||||
print(f"[{step}]", " / ".join(f"{k} {v:.1f}" for k, v in scalars))
|
||||
# loguru applies str.format() to the message; route the step and the joined scalars
# through "{}" placeholders so the metrics appear in the log line (same text the
# previous print() produced).
logger.info("[{}] {}", step, " / ".join(f"{k} {v:.1f}" for k, v in scalars))
|
||||
with (self._logdir / "metrics.jsonl").open("a") as f:
|
||||
f.write(json.dumps({"step": step, **dict(scalars)}) + "\n")
|
||||
for name, value in scalars:
|
||||
@@ -137,6 +137,7 @@ def simulate(
|
||||
episodes=0,
|
||||
state=None,
|
||||
video_pred_log=True,
|
||||
pbar=None
|
||||
):
|
||||
# initialize or unpack simulation state
|
||||
if state is None:
|
||||
@@ -185,6 +186,7 @@ def simulate(
|
||||
episode += int(done.sum())
|
||||
length += 1
|
||||
step += len(envs)
|
||||
# pbar defaults to None in the signature; only advance it when a progress bar was
# actually supplied, otherwise callers that omit it would hit AttributeError.
if pbar is not None:
    pbar.update(len(envs))
|
||||
length *= 1 - done
|
||||
# add to cache
|
||||
for a, result, env in zip(action, results, envs):
|
||||
|
||||
Reference in New Issue
Block a user