mirror of
https://github.com/wassname/dreamerv3-torch.git
synced 2026-06-27 17:48:35 +08:00
wip
This commit is contained in:
@@ -7,9 +7,9 @@ import pathlib
|
||||
import re
|
||||
import time
|
||||
import random
|
||||
|
||||
from tqdm.auto import tqdm
|
||||
import numpy as np
|
||||
|
||||
from loguru import logger
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
@@ -51,7 +51,7 @@ class TimeRecording:
|
||||
def __exit__(self, *args):
|
||||
self._nd.record()
|
||||
torch.cuda.synchronize()
|
||||
print(self._comment, self._st.elapsed_time(self._nd) / 1000)
|
||||
logger.info(self._comment, self._st.elapsed_time(self._nd) / 1000)
|
||||
|
||||
|
||||
class Logger:
|
||||
@@ -80,7 +80,7 @@ class Logger:
|
||||
scalars = list(self._scalars.items())
|
||||
if fps:
|
||||
scalars.append(("fps", self._compute_fps(step)))
|
||||
print(f"[{step}]", " / ".join(f"{k} {v:.1f}" for k, v in scalars))
|
||||
logger.info(f"[{step}]", " / ".join(f"{k} {v:.1f}" for k, v in scalars))
|
||||
with (self._logdir / "metrics.jsonl").open("a") as f:
|
||||
f.write(json.dumps({"step": step, **dict(scalars)}) + "\n")
|
||||
for name, value in scalars:
|
||||
@@ -137,6 +137,7 @@ def simulate(
|
||||
episodes=0,
|
||||
state=None,
|
||||
video_pred_log=True,
|
||||
pbar=None
|
||||
):
|
||||
# initialize or unpack simulation state
|
||||
if state is None:
|
||||
@@ -185,6 +186,7 @@ def simulate(
|
||||
episode += int(done.sum())
|
||||
length += 1
|
||||
step += len(envs)
|
||||
pbar.update(len(envs))
|
||||
length *= 1 - done
|
||||
# add to cache
|
||||
for a, result, env in zip(action, results, envs):
|
||||
|
||||
Reference in New Issue
Block a user