Files
dreamerv3-torch/dreamer.py
T
2024-06-07 06:00:35 +08:00

400 lines
14 KiB
Python

import argparse
import functools
import os
import pathlib
import sys
# os.environ["MUJOCO_GL"] = "osmesa"
import numpy as np
import ruamel.yaml as yaml
# sys.path.append(str(pathlib.Path(__file__).parent))
import exploration as expl
import models
import tools
import envs.wrappers as wrappers
from parallel import Parallel, Damy
import torch
from torch import nn
from torch import distributions as torchd
from loguru import logger
from tqdm.auto import tqdm
# Replace loguru's default handler with one that writes through tqdm.write,
# so log lines are printed without corrupting active tqdm progress bars.
logger.remove()
logger.add(lambda msg: tqdm.write(msg, end=""), colorize=True)


def to_np(x):
    """Detach tensor ``x`` from the autograd graph and return it as a NumPy array on the CPU."""
    return x.detach().cpu().numpy()
class Dreamer(nn.Module):
    """Dreamer agent: a world model plus task and exploration behaviors.

    Calling the agent runs any due training updates (when ``training``),
    periodically flushes accumulated metrics to the training logger, and
    returns the policy output for the current observation batch.
    """

    def __init__(self, obs_space, act_space, config, tlogger, dataset):
        super(Dreamer, self).__init__()
        self._config = config
        self._logger = tlogger
        self._should_log = tools.Every(config.log_every)
        batch_steps = config.batch_size * config.batch_length
        # Training cadence is derived from the batch size and `train_ratio`.
        self._should_train = tools.Every(batch_steps / config.train_ratio)
        self._should_pretrain = tools.Once()
        self._should_reset = tools.Every(config.reset_every)
        self._should_expl = tools.Until(int(config.expl_until / config.action_repeat))
        # Per-metric lists of values collected since the last log flush.
        self._metrics = {}
        # Agent step counter; the logger step is this value times action_repeat.
        self._step = tlogger.step // config.action_repeat
        self._update_count = 0
        self._dataset = dataset
        # Ensures the "no env_state for train_openl" warning is emitted once.
        self._warned_no_env_state = False
        self._wm = models.WorldModel(obs_space, act_space, self._step, config)
        self._task_behavior = models.ImagBehavior(config, self._wm)
        if config.compile and os.name != "nt":
            # torch.compile is not supported on Windows.
            self._wm = torch.compile(self._wm)
            self._task_behavior = torch.compile(self._task_behavior)
        reward = lambda f, s, a: self._wm.heads["reward"](f).mean()
        self._expl_behavior = dict(
            greedy=lambda: self._task_behavior,
            random=lambda: expl.Random(config, act_space),
            plan2explore=lambda: expl.Plan2Explore(config, self._wm, reward),
        )[config.expl_behavior]().to(self._config.device)

    def __call__(self, obs, reset, state=None, training=True, env_state=None):
        """Train if due, log if due, then act on ``obs``.

        Args:
            obs: observation batch, passed to ``self._wm.preprocess``.
            reset: per-environment reset flags; only ``len(reset)`` is used
                here, to advance the step counter while training.
            state: ``(latent, action)`` from the previous call, or None.
            training: when True, run due updates and sample actions; when
                False, act with the mode of the task actor.
            env_state: optional environment state forwarded to
                ``self._wm.video_pred`` for the ``train_openl`` video. When
                None, that video is skipped because there is no state to
                forward.

        Returns:
            ``(policy_output, state)`` as produced by ``_policy``.
        """
        step = self._step
        if training:
            steps = (
                self._config.pretrain
                if self._should_pretrain()
                else self._should_train(step)
            )
            for _ in range(steps):
                self._train(next(self._dataset))
                self._update_count += 1
                self._metrics["update_count"] = self._update_count
            if self._should_log(step):
                for name, values in self._metrics.items():
                    self._logger.scalar(name, float(np.mean(values)))
                    self._metrics[name] = []
                if self._config.video_pred_log:
                    if env_state is not None:
                        openl = self._wm.video_pred(next(self._dataset), env_state)
                        self._logger.video("train_openl", to_np(openl))
                    elif not self._warned_no_env_state:
                        # video_pred needs an env state; none is available on
                        # this call path, so skip the video instead of crashing.
                        logger.warning(
                            "Skipping train_openl video: no env_state was passed to the agent."
                        )
                        self._warned_no_env_state = True
                self._logger.write(fps=True)
        policy_output, state = self._policy(obs, state, training)
        if training:
            self._step += len(reset)
            self._logger.step = self._config.action_repeat * self._step
        return policy_output, state

    def _policy(self, obs, state, training):
        """Update the latent state from ``obs`` and choose an action.

        Returns:
            ``({"action": action, "logprob": logprob}, (latent, action))``,
            with the latent tensors and the action detached from the graph.
        """
        if state is None:
            latent = action = None
        else:
            latent, action = state
        obs = self._wm.preprocess(obs)
        embed = self._wm.encoder(obs)
        latent, _ = self._wm.dynamics.obs_step(latent, action, embed, obs["is_first"])
        if self._config.eval_state_mean:
            # Use the mean in place of the sampled stochastic state.
            latent["stoch"] = latent["mean"]
        feat = self._wm.dynamics.get_feat(latent)
        if not training:
            actor = self._task_behavior.actor(feat)
            action = actor.mode()
        elif self._should_expl(self._step):
            actor = self._expl_behavior.actor(feat)
            action = actor.sample()
        else:
            actor = self._task_behavior.actor(feat)
            action = actor.sample()
        logprob = actor.log_prob(action)
        latent = {k: v.detach() for k, v in latent.items()}
        action = action.detach()
        if self._config.actor["dist"] == "onehot_gumble":
            # torch has no top-level `one_hot`; use nn.functional.one_hot and
            # cast back, since it returns int64.
            action_dtype = action.dtype
            action = nn.functional.one_hot(
                torch.argmax(action, dim=-1), self._config.num_actions
            ).to(action_dtype)
        policy_output = {"action": action, "logprob": logprob}
        state = (latent, action)
        return policy_output, state

    def _train(self, data):
        """Run one update of the world model and behaviors on batch ``data``.

        The world model is trained first; its posterior states are the start
        states for the task behavior (rewarded by the mode of the reward head)
        and, unless ``expl_behavior == "greedy"``, the exploration behavior.
        Resulting metrics are appended to ``self._metrics``.
        """
        metrics = {}
        post, context, mets = self._wm._train(data)
        metrics.update(mets)
        start = post
        reward = lambda f, s, a: self._wm.heads["reward"](
            self._wm.dynamics.get_feat(s)
        ).mode()
        metrics.update(self._task_behavior._train(start, reward)[-1])
        if self._config.expl_behavior != "greedy":
            mets = self._expl_behavior.train(start, context, data)[-1]
            metrics.update({"expl_" + key: value for key, value in mets.items()})
        for name, value in metrics.items():
            self._metrics.setdefault(name, []).append(value)
def count_steps(folder):
    """Sum the step counts encoded in the ``*.npz`` episode filenames of ``folder``.

    The number after the last ``-`` in each path (before the ``.npz``
    suffix) is read as the episode length L, and each episode contributes
    L - 1 steps.
    """
    total = 0
    for episode_path in folder.glob("*.npz"):
        length = int(str(episode_path).split("-")[-1][: -len(".npz")])
        total += length - 1
    return total
def make_dataset(episodes, config):
    """Return a batched dataset built from episode samples.

    Sequences of ``config.batch_length`` steps are sampled from ``episodes``
    via ``tools.sample_episodes`` and grouped into batches of
    ``config.batch_size`` by ``tools.from_generator``.
    """
    return tools.from_generator(
        tools.sample_episodes(episodes, config.batch_length),
        config.batch_size,
    )
def make_env(config, mode, id):
    """Create one wrapped environment for ``config.task``.

    ``config.task`` is split at its first ``_`` into ``<suite>`` and
    ``<task>``; the suite picks the environment implementation. Most suites
    are seeded with ``config.seed + id`` so parallel copies differ. ``mode``
    is only consulted by the dmlab suite.

    Raises:
        NotImplementedError: if the suite is unknown.
    """
    suite, task = config.task.split("_", 1)

    if suite == "dmc":
        import envs.dmc as dmc

        env = dmc.DeepMindControl(
            task, config.action_repeat, config.size, seed=config.seed + id
        )
    elif suite == "atari":
        import envs.atari as atari

        env = atari.Atari(
            task,
            config.action_repeat,
            config.size,
            gray=config.grayscale,
            noops=config.noops,
            lives=config.lives,
            sticky=config.stickey,
            actions=config.actions,
            resize=config.resize,
            seed=config.seed + id,
        )
    elif suite == "dmlab":
        import envs.dmlab as dmlab

        dmlab_mode = mode if "train" in mode else "test"
        env = dmlab.DeepMindLabyrinth(
            task, dmlab_mode, config.action_repeat, seed=config.seed + id
        )
    elif suite == "memorymaze":
        from envs.memorymaze import MemoryMaze

        env = MemoryMaze(task, seed=config.seed + id)
    elif suite == "crafter":
        import envs.crafter as crafter

        env = crafter.Crafter(task, config.size, seed=config.seed + id)
    elif suite == "craftax":
        from envs import craftax_env

        env = craftax_env.Craftax(task, seed=config.seed + id)
    elif suite == "minecraft":
        import envs.minecraft as minecraft

        env = minecraft.make_env(
            task, size=config.size, break_speed=config.break_speed
        )
    else:
        raise NotImplementedError(suite)

    # dmc gets normalized continuous actions; every other suite is wrapped
    # for one-hot actions.
    if suite == "dmc":
        env = wrappers.NormalizeActions(env)
    else:
        env = wrappers.OneHotAction(env)

    env = wrappers.TimeLimit(env, config.time_limit)
    env = wrappers.SelectAction(env, key="action")
    env = wrappers.UUID(env)
    if suite == "minecraft":
        env = wrappers.RewardObs(env)
    return env
def main(config):
    """Run Dreamer training with periodic evaluation and checkpointing.

    ``config`` is the namespace produced by ``parse_args``. ``steps``,
    ``eval_every``, ``log_every`` and ``time_limit`` are divided in place by
    ``action_repeat``. Episodes are stored in ``traindir``/``evaldir``
    (default ``<logdir>/train_eps`` and ``<logdir>/eval_eps``). The agent and
    optimizer state are saved to ``<logdir>/latest.pt`` after every training
    phase and restored from it on start-up when present. Environments are
    closed on exit, including when an exception propagates.
    """
    tools.set_seed_everywhere(config.seed)
    if config.deterministic_run:
        tools.enable_deterministic_run()
    logdir = pathlib.Path(config.logdir).expanduser()
    # traindir/evaldir may arrive as plain strings (e.g. overridden on the
    # command line); normalise to Path so .mkdir() and path joins work.
    config.traindir = pathlib.Path(config.traindir or logdir / "train_eps").expanduser()
    config.evaldir = pathlib.Path(config.evaldir or logdir / "eval_eps").expanduser()
    config.steps //= config.action_repeat
    config.eval_every //= config.action_repeat
    config.log_every //= config.action_repeat
    config.time_limit //= config.action_repeat

    logger.info(f"Logdir {logdir}")
    logdir.mkdir(parents=True, exist_ok=True)
    config.traindir.mkdir(parents=True, exist_ok=True)
    config.evaldir.mkdir(parents=True, exist_ok=True)
    step = count_steps(config.traindir)
    # step in logger is environmental step
    tlogger = tools.Logger(logdir, config.action_repeat * step)
    logger.add(logdir / "logger.log")

    logger.info("Create envs.")
    if config.offline_traindir:
        directory = config.offline_traindir.format(**vars(config))
    else:
        directory = config.traindir
    train_eps = tools.load_episodes(directory, limit=config.dataset_size)
    if config.offline_evaldir:
        directory = config.offline_evaldir.format(**vars(config))
    else:
        directory = config.evaldir
    eval_eps = tools.load_episodes(directory, limit=1)
    train_envs = [make_env(config, "train", i) for i in range(config.envs)]
    eval_envs = [make_env(config, "eval", i) for i in range(config.envs)]
    if config.parallel:
        train_envs = [Parallel(env, "process") for env in train_envs]
        eval_envs = [Parallel(env, "process") for env in eval_envs]
    else:
        train_envs = [Damy(env) for env in train_envs]
        eval_envs = [Damy(env) for env in eval_envs]

    try:
        acts = train_envs[0].action_space
        logger.info(f"Action Space {acts}")
        config.num_actions = acts.n if hasattr(acts, "n") else acts.shape[0]

        state = None
        if not config.offline_traindir:
            prefill = max(0, config.prefill - count_steps(config.traindir))
            logger.info(f"Prefill dataset ({prefill} steps).")
            if hasattr(acts, "discrete"):
                random_actor = tools.OneHotDist(
                    torch.zeros(config.num_actions).repeat(config.envs, 1)
                )
            else:
                random_actor = torchd.independent.Independent(
                    torchd.uniform.Uniform(
                        torch.Tensor(acts.low).repeat(config.envs, 1),
                        torch.Tensor(acts.high).repeat(config.envs, 1),
                    ),
                    1,
                )

            def random_agent(o, d, s):
                action = random_actor.sample()
                logprob = random_actor.log_prob(action)
                return {"action": action, "logprob": logprob}, None

            state = tools.simulate(
                random_agent,
                train_envs,
                train_eps,
                config.traindir,
                tlogger,
                limit=config.dataset_size,
                steps=prefill,
            )
            tlogger.step += prefill * config.action_repeat
            logger.info(f"Logger: ({tlogger.step} steps).")

        logger.info("Simulate agent.")
        train_dataset = make_dataset(train_eps, config)
        eval_dataset = make_dataset(eval_eps, config)
        agent = Dreamer(
            train_envs[0].observation_space,
            train_envs[0].action_space,
            config,
            tlogger,
            train_dataset,
        ).to(config.device)
        agent.requires_grad_(requires_grad=False)
        checkpoint_path = logdir / "latest.pt"
        if checkpoint_path.exists():
            # map_location lets a checkpoint written on a different device
            # (e.g. a GPU) be restored onto config.device.
            checkpoint = torch.load(checkpoint_path, map_location=config.device)
            agent.load_state_dict(checkpoint["agent_state_dict"])
            tools.recursively_load_optim_state_dict(agent, checkpoint["optims_state_dict"])
            agent._should_pretrain._once = False
            logger.warning(f"Loaded model from {checkpoint_path}")

        # make sure eval will be executed once after config.steps
        with tqdm(total=config.steps + config.eval_every, unit="step") as pbar:
            while agent._step < config.steps + config.eval_every:
                tlogger.write()
                if config.eval_episode_num > 0:
                    logger.info("Start evaluation.")
                    eval_policy = functools.partial(agent, training=False)
                    tools.simulate(
                        eval_policy,
                        eval_envs,
                        eval_eps,
                        config.evaldir,
                        tlogger,
                        is_eval=True,
                        episodes=config.eval_episode_num,
                        video_pred_log=config.video_pred_log,
                        pbar=pbar,
                    )
                    if config.video_pred_log:
                        env_state = eval_envs[0].env_state
                        video_pred = agent._wm.video_pred(next(eval_dataset), env_state)
                        tlogger.video("eval_openl", to_np(video_pred))
                logger.info("Start training.")
                state = tools.simulate(
                    agent,
                    train_envs,
                    train_eps,
                    config.traindir,
                    tlogger,
                    limit=config.dataset_size,
                    steps=config.eval_every,
                    state=state,
                    pbar=pbar,
                )
                items_to_save = {
                    "agent_state_dict": agent.state_dict(),
                    "optims_state_dict": tools.recursively_collect_optim_state_dict(agent),
                }
                # Write to a temp file first, then swap it in, so an interrupted
                # save never leaves a truncated latest.pt behind.
                tmp_path = checkpoint_path.with_suffix(".pt.tmp")
                torch.save(items_to_save, tmp_path)
                tmp_path.replace(checkpoint_path)
                logger.info(f"Saved model to {checkpoint_path}")
                # Advance the bar to the agent's current step count.
                pbar.update(agent._step - pbar.n)
    finally:
        _close_envs(train_envs + eval_envs)


def _close_envs(envs):
    """Close every environment in ``envs``, logging (not raising) individual failures."""
    for env in envs:
        try:
            env.close()
        except Exception as exc:  # best-effort cleanup: keep closing the rest
            logger.warning(f"Failed to close environment: {exc!r}")
def parse_args(argv=None):
    """Build the run configuration from ``configs.yaml`` plus CLI overrides.

    ``--configs NAME ...`` selects sections of ``configs.yaml`` (read from
    the directory containing this file). They are merged in order on top of
    the ``defaults`` section. Every resulting top-level key becomes a
    ``--key`` option, so any value can be overridden from the command line.

    Args:
        argv: full argument vector including the program name; defaults to
            ``sys.argv``.

    Returns:
        argparse.Namespace with one attribute per configuration key.

    Raises:
        KeyError: if a name passed to ``--configs`` is not a section of
            ``configs.yaml``.
    """
    # First pass: extract only the config section names; keep the rest.
    parser = argparse.ArgumentParser()
    parser.add_argument("--configs", nargs="+")
    if argv is None:
        argv = sys.argv
    args, remaining = parser.parse_known_args(argv[1:])

    # Load the config file using a path relative to this module.
    root_dir = pathlib.Path(__file__).parent
    # The module-level `safe_load` was removed in ruamel.yaml 0.18; the YAML
    # object API with the pure safe loader is the supported equivalent.
    configs = yaml.YAML(typ="safe", pure=True).load(
        (root_dir / "configs.yaml").read_text()
    )

    def recursive_update(base, update):
        """Merge ``update`` into ``base`` in place, descending into nested dicts."""
        for key, value in update.items():
            # Only recurse when both sides are mappings; otherwise replace, so
            # a dict override of a scalar/None default does not crash.
            if isinstance(value, dict) and isinstance(base.get(key), dict):
                recursive_update(base[key], value)
            else:
                base[key] = value

    name_list = ["defaults", *args.configs] if args.configs else ["defaults"]
    defaults = {}
    for name in name_list:
        if name not in configs:
            raise KeyError(
                f"Unknown config {name!r}; available: {sorted(configs)}"
            )
        recursive_update(defaults, configs[name])

    # Second pass: expose every merged key as a CLI option (so -h works too)
    # and parse the remaining arguments as overrides.
    parser = argparse.ArgumentParser()
    for key, value in sorted(defaults.items(), key=lambda x: x[0]):
        arg_type = tools.args_type(value)
        parser.add_argument(f"--{key}", type=arg_type, default=arg_type(value))
    args = parser.parse_args(remaining)
    logger.info(f"config={args}")
    return args
if __name__ == "__main__":
    # Build the configuration (configs.yaml + CLI overrides), then run.
    cli_config = parse_args()
    main(cli_config)