This commit is contained in:
wassname
2024-06-03 08:29:29 +08:00
parent f853c03f4b
commit 5a0e8dc5ac
5 changed files with 16 additions and 12 deletions
+6 -4
View File
@@ -7,9 +7,9 @@ import pathlib
import re
import time
import random
from tqdm.auto import tqdm
import numpy as np
from loguru import logger
import torch
from torch import nn
from torch.nn import functional as F
@@ -51,7 +51,7 @@ class TimeRecording:
def __exit__(self, *args):
self._nd.record()
torch.cuda.synchronize()
print(self._comment, self._st.elapsed_time(self._nd) / 1000)
logger.info(self._comment, self._st.elapsed_time(self._nd) / 1000)
class Logger:
@@ -80,7 +80,7 @@ class Logger:
scalars = list(self._scalars.items())
if fps:
scalars.append(("fps", self._compute_fps(step)))
print(f"[{step}]", " / ".join(f"{k} {v:.1f}" for k, v in scalars))
logger.info(f"[{step}]", " / ".join(f"{k} {v:.1f}" for k, v in scalars))
with (self._logdir / "metrics.jsonl").open("a") as f:
f.write(json.dumps({"step": step, **dict(scalars)}) + "\n")
for name, value in scalars:
@@ -137,6 +137,7 @@ def simulate(
episodes=0,
state=None,
video_pred_log=True,
pbar=None
):
# initialize or unpack simulation state
if state is None:
@@ -185,6 +186,7 @@ def simulate(
episode += int(done.sum())
length += 1
step += len(envs)
pbar.update(len(envs))
length *= 1 - done
# add to cache
for a, result, env in zip(action, results, envs):