import argparse
import functools
import os
import pathlib
import sys

# os.environ["MUJOCO_GL"] = "osmesa"

import numpy as np
import ruamel.yaml as yaml

# sys.path.append(str(pathlib.Path(__file__).parent))

import exploration as expl
import models
import tools
import envs.wrappers as wrappers
from parallel import Parallel, Damy

import torch
from torch import nn
from torch import distributions as torchd
from loguru import logger
from tqdm.auto import tqdm

# Route loguru output through tqdm.write so log lines do not break the
# progress bar that main() draws.
logger.remove()
logger.add(lambda msg: tqdm.write(msg, end=""), colorize=True)

to_np = lambda x: x.detach().cpu().numpy()


class Dreamer(nn.Module):
    """Agent that bundles a world model, a task behavior and an exploration behavior.

    Calling the agent optionally runs training updates on batches drawn from
    ``dataset`` and then returns an action for the given observation.
    """

    def __init__(self, obs_space, act_space, config, tlogger, dataset):
        """Build the world model, the behaviors and the scheduling helpers.

        Args:
            obs_space: Observation space, forwarded to ``models.WorldModel``.
            act_space: Action space, forwarded to ``models.WorldModel`` and
                ``expl.Random``.
            config: Namespace-like configuration object.
            tlogger: Metrics logger; this class uses its ``step`` attribute and
                its ``scalar``, ``video`` and ``write`` methods.
            dataset: Iterator of training batches, consumed with ``next()``.
        """
        super(Dreamer, self).__init__()
        self._config = config
        self._logger = tlogger
        self._should_log = tools.Every(config.log_every)
        batch_steps = config.batch_size * config.batch_length
        self._should_train = tools.Every(batch_steps / config.train_ratio)
        self._should_pretrain = tools.Once()
        self._should_reset = tools.Every(config.reset_every)
        self._should_expl = tools.Until(int(config.expl_until / config.action_repeat))
        self._metrics = {}
        # this is update step
        self._step = tlogger.step // config.action_repeat
        self._update_count = 0
        self._dataset = dataset
        # Optional environment state used by the training-time open-loop video
        # prediction in __call__. Nothing in this file sets it; while it is
        # None the training video is skipped instead of failing.
        self._env_state = None
        self._warned_missing_env_state = False
        self._wm = models.WorldModel(obs_space, act_space, self._step, config)
        self._task_behavior = models.ImagBehavior(config, self._wm)
        if (
            config.compile and os.name != "nt"
        ):  # compilation is not supported on windows
            self._wm = torch.compile(self._wm)
            self._task_behavior = torch.compile(self._task_behavior)
        reward = lambda f, s, a: self._wm.heads["reward"](f).mean()
        # Select the exploration behavior by name; each entry is a thunk so
        # only the chosen behavior is constructed.
        self._expl_behavior = dict(
            greedy=lambda: self._task_behavior,
            random=lambda: expl.Random(config, act_space),
            plan2explore=lambda: expl.Plan2Explore(config, self._wm, reward),
        )[config.expl_behavior]().to(self._config.device)

    def __call__(self, obs, reset, state=None, training=True):
        """Optionally train, then compute an action for ``obs``.

        Args:
            obs: Observation batch, passed to the world model's ``preprocess``.
            reset: Sized per-environment sequence; only ``len(reset)`` is used,
                to advance the step counter while training.
            state: ``(latent, action)`` from the previous call, or ``None``.
            training: When true, run training updates and logging and advance
                the step counters.

        Returns:
            Tuple ``(policy_output, state)`` as produced by ``_policy``.
        """
        step = self._step
        if training:
            steps = (
                self._config.pretrain
                if self._should_pretrain()
                else self._should_train(step)
            )
            for _ in range(steps):
                self._train(next(self._dataset))
                self._update_count += 1
                self._metrics["update_count"] = self._update_count
            if self._should_log(step):
                for name, values in self._metrics.items():
                    self._logger.scalar(name, float(np.mean(values)))
                    self._metrics[name] = []
                if self._config.video_pred_log:
                    # FIXME need to provide this state or synthetic one
                    # The original referenced an undefined name here and
                    # raised NameError; use the optional attribute instead and
                    # skip the video while no state has been provided.
                    env_state = getattr(self, "_env_state", None)
                    if env_state is not None:
                        openl = self._wm.video_pred(next(self._dataset), env_state)
                        self._logger.video("train_openl", to_np(openl))
                    elif not getattr(self, "_warned_missing_env_state", False):
                        self._warned_missing_env_state = True
                        logger.warning(
                            "Skipping train_openl video: no env state available "
                            "(set agent._env_state to enable it)."
                        )
                self._logger.write(fps=True)

        policy_output, state = self._policy(obs, state, training)

        if training:
            self._step += len(reset)
            self._logger.step = self._config.action_repeat * self._step
        return policy_output, state

    def _policy(self, obs, state, training):
        """Update the latent state from ``obs`` and choose an action.

        Args:
            obs: Raw observation batch; must contain ``"is_first"`` after
                preprocessing.
            state: ``(latent, action)`` from the previous step, or ``None``.
            training: Selects sampling (true) versus the actor's mode (false).

        Returns:
            Tuple ``(policy_output, state)`` where ``policy_output`` is a dict
            with keys ``"action"`` and ``"logprob"`` and ``state`` is the new
            ``(latent, action)`` pair with detached tensors.
        """
        if state is None:
            latent = action = None
        else:
            latent, action = state
        obs = self._wm.preprocess(obs)
        embed = self._wm.encoder(obs)
        latent, _ = self._wm.dynamics.obs_step(latent, action, embed, obs["is_first"])
        if self._config.eval_state_mean:
            latent["stoch"] = latent["mean"]
        feat = self._wm.dynamics.get_feat(latent)
        if not training:
            actor = self._task_behavior.actor(feat)
            action = actor.mode()
        elif self._should_expl(self._step):
            actor = self._expl_behavior.actor(feat)
            action = actor.sample()
        else:
            actor = self._task_behavior.actor(feat)
            action = actor.sample()
        # The log-probability is taken for the action as sampled, before any
        # one-hot conversion below.
        logprob = actor.log_prob(action)
        latent = {k: v.detach() for k, v in latent.items()}
        action = action.detach()
        if self._config.actor["dist"] == "onehot_gumble":
            # torch has no top-level `one_hot`; the function lives in
            # torch.nn.functional and returns an integer tensor. Cast back to
            # the sampled action's dtype -- presumably what downstream code
            # expects; TODO confirm.
            action = nn.functional.one_hot(
                torch.argmax(action, dim=-1), self._config.num_actions
            ).to(action.dtype)
        policy_output = {"action": action, "logprob": logprob}
        state = (latent, action)
        return policy_output, state

    def _train(self, data):
        """Run one update of the world model and behaviors and collect metrics.

        Args:
            data: One training batch, passed to the world model's ``_train``.
        """
        metrics = {}
        post, context, mets = self._wm._train(data)
        metrics.update(mets)
        start = post
        reward = lambda f, s, a: self._wm.heads["reward"](
            self._wm.dynamics.get_feat(s)
        ).mode()
        metrics.update(self._task_behavior._train(start, reward)[-1])
        if self._config.expl_behavior != "greedy":
            mets = self._expl_behavior.train(start, context, data)[-1]
            metrics.update({"expl_" + key: value for key, value in mets.items()})
        # Accumulate values per metric name; __call__ averages and clears them
        # when it logs.
        for name, value in metrics.items():
            self._metrics.setdefault(name, []).append(value)


def count_steps(folder):
    """Sum the step counts encoded in the ``*.npz`` file names under ``folder``.

    For each file, the text after the last ``-`` (minus the four-character
    suffix) is parsed as an integer and reduced by one.
    """
    return sum(int(str(n).split("-")[-1][:-4]) - 1 for n in folder.glob("*.npz"))


def make_dataset(episodes, config):
    """Build a batched dataset from ``episodes`` via the ``tools`` helpers.

    Sequences of ``config.batch_length`` are sampled and grouped into batches
    of ``config.batch_size``.
    """
    generator = tools.sample_episodes(episodes, config.batch_length)
    dataset = tools.from_generator(generator, config.batch_size)
    return dataset


def make_env(config, mode, id):
    """Create and wrap one environment for the suite named in ``config.task``.

    Args:
        config: Configuration; ``config.task`` has the form ``"<suite>_<task>"``
            and is split on the first underscore.
        mode: Mode string; only the ``dmlab`` suite reads it here.
        id: Environment index, added to ``config.seed`` for seeded suites.

    Returns:
        The environment wrapped with ``TimeLimit``, ``SelectAction`` and
        ``UUID`` (plus ``RewardObs`` for minecraft).

    Raises:
        NotImplementedError: If the suite is not one of the supported names.
    """
    suite, task = config.task.split("_", 1)
    if suite == "dmc":
        import envs.dmc as dmc

        env = dmc.DeepMindControl(
            task, config.action_repeat, config.size, seed=config.seed + id
        )
        env = wrappers.NormalizeActions(env)
    elif suite == "atari":
        import envs.atari as atari

        env = atari.Atari(
            task,
            config.action_repeat,
            config.size,
            gray=config.grayscale,
            noops=config.noops,
            lives=config.lives,
            sticky=config.stickey,
            actions=config.actions,
            resize=config.resize,
            seed=config.seed + id,
        )
        env = wrappers.OneHotAction(env)
    elif suite == "dmlab":
        import envs.dmlab as dmlab

        env = dmlab.DeepMindLabyrinth(
            task,
            mode if "train" in mode else "test",
            config.action_repeat,
            seed=config.seed + id,
        )
        env = wrappers.OneHotAction(env)
    elif suite == "memorymaze":
        from envs.memorymaze import MemoryMaze

        env = MemoryMaze(task, seed=config.seed + id)
        env = wrappers.OneHotAction(env)
    elif suite == "crafter":
        import envs.crafter as crafter

        env = crafter.Crafter(task, config.size, seed=config.seed + id)
        env = wrappers.OneHotAction(env)
    elif suite == "craftax":
        from envs import craftax_env

        env = craftax_env.Craftax(task, seed=config.seed + id)
        env = wrappers.OneHotAction(env)
    elif suite == "minecraft":
        import envs.minecraft as minecraft

        env = minecraft.make_env(task, size=config.size, break_speed=config.break_speed)
        env = wrappers.OneHotAction(env)
    else:
        raise NotImplementedError(suite)
    env = wrappers.TimeLimit(env, config.time_limit)
    env = wrappers.SelectAction(env, key="action")
    env = wrappers.UUID(env)
    if suite == "minecraft":
        env = wrappers.RewardObs(env)
    return env


def main(config):
    """Run the full prefill / evaluate / train loop described by ``config``.

    Creates the log directories and environments, optionally prefills the
    replay data with a random policy, restores ``latest.pt`` if present, and
    then alternates evaluation and training until the step budget is reached,
    saving a checkpoint after every training phase.
    """
    tools.set_seed_everywhere(config.seed)
    if config.deterministic_run:
        tools.enable_deterministic_run()
    logdir = pathlib.Path(config.logdir).expanduser()
    config.traindir = config.traindir or logdir / "train_eps"
    config.evaldir = config.evaldir or logdir / "eval_eps"
    # Convert environment-step budgets into agent steps.
    config.steps //= config.action_repeat
    config.eval_every //= config.action_repeat
    config.log_every //= config.action_repeat
    config.time_limit //= config.action_repeat

    logger.info(f"Logdir {logdir}")
    logdir.mkdir(parents=True, exist_ok=True)
    config.traindir.mkdir(parents=True, exist_ok=True)
    config.evaldir.mkdir(parents=True, exist_ok=True)
    step = count_steps(config.traindir)
    # step in logger is environmental step
    tlogger = tools.Logger(logdir, config.action_repeat * step)
    logger.add(logdir / "logger.log")

    logger.info("Create envs.")
    if config.offline_traindir:
        directory = config.offline_traindir.format(**vars(config))
    else:
        directory = config.traindir
    train_eps = tools.load_episodes(directory, limit=config.dataset_size)
    if config.offline_evaldir:
        directory = config.offline_evaldir.format(**vars(config))
    else:
        directory = config.evaldir
    eval_eps = tools.load_episodes(directory, limit=1)
    make = lambda mode, id: make_env(config, mode, id)
    train_envs = [make("train", i) for i in range(config.envs)]
    eval_envs = [make("eval", i) for i in range(config.envs)]
    if config.parallel:
        train_envs = [Parallel(env, "process") for env in train_envs]
        eval_envs = [Parallel(env, "process") for env in eval_envs]
    else:
        train_envs = [Damy(env) for env in train_envs]
        eval_envs = [Damy(env) for env in eval_envs]
    acts = train_envs[0].action_space
    logger.info(f"Action Space {acts}")
    config.num_actions = acts.n if hasattr(acts, "n") else acts.shape[0]

    state = None
    if not config.offline_traindir:
        prefill = max(0, config.prefill - count_steps(config.traindir))
        logger.info(f"Prefill dataset ({prefill} steps).")
        if hasattr(acts, "discrete"):
            random_actor = tools.OneHotDist(
                torch.zeros(config.num_actions).repeat(config.envs, 1)
            )
        else:
            random_actor = torchd.independent.Independent(
                torchd.uniform.Uniform(
                    torch.Tensor(acts.low).repeat(config.envs, 1),
                    torch.Tensor(acts.high).repeat(config.envs, 1),
                ),
                1,
            )

        def random_agent(o, d, s):
            """Sample a random action; the three arguments are ignored."""
            action = random_actor.sample()
            logprob = random_actor.log_prob(action)
            return {"action": action, "logprob": logprob}, None

        state = tools.simulate(
            random_agent,
            train_envs,
            train_eps,
            config.traindir,
            tlogger,
            limit=config.dataset_size,
            steps=prefill,
        )
        tlogger.step += prefill * config.action_repeat
        logger.info(f"Logger: ({tlogger.step} steps).")

    logger.info("Simulate agent.")
    train_dataset = make_dataset(train_eps, config)
    eval_dataset = make_dataset(eval_eps, config)
    agent = Dreamer(
        train_envs[0].observation_space,
        train_envs[0].action_space,
        config,
        tlogger,
        train_dataset,
    ).to(config.device)
    agent.requires_grad_(requires_grad=False)
    if (logdir / "latest.pt").exists():
        checkpoint = torch.load(logdir / "latest.pt")
        agent.load_state_dict(checkpoint["agent_state_dict"])
        tools.recursively_load_optim_state_dict(agent, checkpoint["optims_state_dict"])
        # A restored agent skips the pretraining phase.
        agent._should_pretrain._once = False
        logger.warning(
            f"⚠️ Loaded model from {logdir / 'latest.pt'}, this could invalidate your step budget"
        )

    # make sure eval will be executed once after config.steps
    with tqdm(
        total=config.steps + config.eval_every, unit='step', mininterval=60
    ) as pbar:
        while agent._step < config.steps + config.eval_every:
            tlogger.write()
            if config.eval_episode_num > 0:
                logger.info("Start evaluation.")
                eval_policy = functools.partial(agent, training=False)
                tools.simulate(
                    eval_policy,
                    eval_envs,
                    eval_eps,
                    config.evaldir,
                    tlogger,
                    is_eval=True,
                    episodes=config.eval_episode_num,
                    video_pred_log=config.video_pred_log,
                    pbar=pbar,
                )
                if config.video_pred_log:
                    env_state = eval_envs[0].env_state
                    video_pred = agent._wm.video_pred(next(eval_dataset), env_state)
                    tlogger.video("eval_openl", to_np(video_pred))
            logger.info("Start training.")
            state = tools.simulate(
                agent,
                train_envs,
                train_eps,
                config.traindir,
                tlogger,
                limit=config.dataset_size,
                steps=config.eval_every,
                state=state,
                pbar=pbar,
            )
            items_to_save = {
                "agent_state_dict": agent.state_dict(),
                "optims_state_dict": tools.recursively_collect_optim_state_dict(agent),
            }
            torch.save(items_to_save, logdir / "latest.pt")
            logger.info(f"Saved model to {logdir / 'latest.pt'}")
            # Sync the bar to the agent's step counter rather than incrementing.
            pbar.update(agent._step - pbar.n)  # 16858 at a time
    for env in train_envs + eval_envs:
        try:
            env.close()
        except Exception as exc:
            # Shutdown is best-effort: record the failure instead of silently
            # dropping it, but never let it abort closing the remaining envs.
            logger.debug(f"Ignoring error while closing env: {exc!r}")


def parse_args(argv=None):
    """Build the run configuration from ``configs.yaml`` and the command line.

    The ``--configs`` option names sections of ``configs.yaml`` (located next
    to this file) that are merged on top of ``defaults``; every resulting key
    then becomes its own ``--<key>`` command-line option.

    Args:
        argv: Full argument vector including the program name; defaults to
            ``sys.argv``.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    # first load config name as arg from command line
    parser = argparse.ArgumentParser()
    parser.add_argument("--configs", nargs="+", help="one or more config files")
    if argv is None:
        argv = sys.argv
    args, remaining = parser.parse_known_args(argv[1:])

    # load config, using relative path
    root_dir = pathlib.Path(__file__).parent
    configs = yaml.safe_load(
        (root_dir / "configs.yaml").read_text()
    )

    def recursive_update(base, update):
        """Merge ``update`` into ``base`` in place, recursing into nested dicts."""
        for key, value in update.items():
            if isinstance(value, dict) and key in base:
                recursive_update(base[key], value)
            else:
                base[key] = value

    name_list = ["defaults", *args.configs] if args.configs else ["defaults"]
    defaults = {}
    for name in name_list:
        recursive_update(defaults, configs[name])
    # defaults = {k:tools.args_type(v)(v) for k, v in defaults.items()}
    # config = argparse.Namespace(**defaults)

    # now use argparse to parse config, allowing us to override config with any extra args from cli. You can even use -h
    parser = argparse.ArgumentParser()
    for key, value in sorted(defaults.items(), key=lambda x: x[0]):
        arg_type = tools.args_type(value)
        parser.add_argument(f"--{key}", type=arg_type, default=arg_type(value))
    parser.print_usage()
    args = parser.parse_args(remaining)
    logger.info(f"config={args}")
    return args


if __name__ == "__main__":
    main(parse_args())