mirror of
https://github.com/wassname/dreamerv3-torch.git
synced 2026-06-27 16:30:24 +08:00
wip
This commit is contained in:
+1
-1
@@ -311,7 +311,7 @@ def main(config):
|
||||
logger.warning(f"Loaded model from {logdir / 'latest.pt'}")
|
||||
|
||||
# make sure eval will be executed once after config.steps
|
||||
with tqdm(total=config.steps + config.eval_every, unit='step') as pbar:
|
||||
with tqdm(total=config.steps + config.eval_every, unit='step', mininterval=60) as pbar:
|
||||
while agent._step < config.steps + config.eval_every:
|
||||
tlogger.write()
|
||||
if config.eval_episode_num > 0:
|
||||
|
||||
@@ -6,7 +6,7 @@ export TQDM_MININTERVAL := "30"
|
||||
|
||||
main:
|
||||
. ./.venv/bin/activate
|
||||
python dreamer.py --configs craftax_smaller --logdir ./logdir/crafterer
|
||||
python dreamer.py --configs craftax_smaller --logdir ./logdir/craftax_smaller
|
||||
|
||||
logs:
|
||||
tensorboard --logdir logdir/craftax
|
||||
|
||||
@@ -99,7 +99,7 @@ class WorldModel(nn.Module):
|
||||
opt=config.opt,
|
||||
use_amp=self._use_amp,
|
||||
)
|
||||
logger.info(f"World Model\n{summary(self, row_settings=['var_names'],)}")
|
||||
logger.info(f"World Model\n{summary(self, row_settings=['depth', 'var_names'], verbose=False)}")
|
||||
# other losses are scaled by 1.0.
|
||||
self._scales = dict(
|
||||
reward=config.reward_head["loss_scale"],
|
||||
|
||||
+1
-1
@@ -12,7 +12,7 @@ from einops import rearrange
|
||||
from torchinfo import summary
|
||||
|
||||
def my_summary(model, input_data):
|
||||
return summary(model, input_data, col_names=('input_size', 'output_size', 'num_params', 'mult_adds'), verbose=0, row_settings=['depth', 'var_names', 'ascii_only'])
|
||||
return summary(model, input_data, col_names=('input_size', 'output_size', 'num_params', 'mult_adds'), verbose=0, row_settings=['depth', 'var_names'])
|
||||
|
||||
|
||||
class RSSM(nn.Module):
|
||||
|
||||
Reference in New Issue
Block a user