mirror of
https://github.com/wassname/dreamerv3-torch.git
synced 2026-06-27 16:15:31 +08:00
trains
This commit is contained in:
+20
-1
@@ -62,7 +62,7 @@ defaults:
|
||||
initial: 'learned'
|
||||
|
||||
# Training
|
||||
batch_size: 16
|
||||
batch_size: 256
|
||||
batch_length: 64
|
||||
train_ratio: 512
|
||||
pretrain: 100
|
||||
@@ -129,6 +129,25 @@ crafter:
|
||||
cont_head: {layers: 5}
|
||||
imag_gradient: 'reinforce'
|
||||
|
||||
|
||||
craftax:
|
||||
task: craftax_Craftax-Symbolic-AutoReset-v1
|
||||
step: 1e6
|
||||
action_repeat: 1
|
||||
envs: 1
|
||||
train_ratio: 512
|
||||
video_pred_log: false # FIXME
|
||||
dyn_hidden: 1024
|
||||
dyn_deter: 4096
|
||||
units: 1024
|
||||
encoder: {cnn_keys: '$^', mlp_keys: "state", mlp_layers: 5, mlp_units: 1024, }
|
||||
decoder: {cnn_keys: '$^', mlp_keys: "state", mlp_layers: 5, mlp_units: 1024}
|
||||
actor: {layers: 5, dist: 'onehot', std: 'none'}
|
||||
value: {layers: 5}
|
||||
reward_head: {layers: 5}
|
||||
cont_head: {layers: 5}
|
||||
imag_gradient: 'reinforce'
|
||||
|
||||
atari100k:
|
||||
steps: 4e5
|
||||
envs: 1
|
||||
|
||||
+13
-4
@@ -76,7 +76,8 @@ class Dreamer(nn.Module):
|
||||
self._logger.scalar(name, float(np.mean(values)))
|
||||
self._metrics[name] = []
|
||||
if self._config.video_pred_log:
|
||||
openl = self._wm.video_pred(next(self._dataset))
|
||||
# FIXME need to provide this state or synthetic one
|
||||
openl = self._wm.video_pred(next(self._dataset), env_state)
|
||||
self._logger.video("train_openl", to_np(openl))
|
||||
self._logger.write(fps=True)
|
||||
|
||||
@@ -192,6 +193,11 @@ def make_env(config, mode, id):
|
||||
|
||||
env = crafter.Crafter(task, config.size, seed=config.seed + id)
|
||||
env = wrappers.OneHotAction(env)
|
||||
elif suite == "craftax":
|
||||
from envs import craftax_env
|
||||
|
||||
env = craftax_env.Craftax(task,seed=config.seed + id)
|
||||
env = wrappers.OneHotAction(env)
|
||||
elif suite == "minecraft":
|
||||
import envs.minecraft as minecraft
|
||||
|
||||
@@ -226,6 +232,7 @@ def main(config):
|
||||
step = count_steps(config.traindir)
|
||||
# step in logger is environmental step
|
||||
tlogger = tools.Logger(logdir, config.action_repeat * step)
|
||||
logger.add(logdir/"logger.log")
|
||||
|
||||
logger.info("Create envs.")
|
||||
if config.offline_traindir:
|
||||
@@ -303,7 +310,7 @@ def main(config):
|
||||
agent._should_pretrain._once = False
|
||||
|
||||
# make sure eval will be executed once after config.steps
|
||||
with tqdm(total=config.steps + config.eval_every) as pbar:
|
||||
with tqdm(total=config.steps + config.eval_every, unit='step') as pbar:
|
||||
while agent._step < config.steps + config.eval_every:
|
||||
tlogger.write()
|
||||
if config.eval_episode_num > 0:
|
||||
@@ -321,7 +328,8 @@ def main(config):
|
||||
pbar=pbar,
|
||||
)
|
||||
if config.video_pred_log:
|
||||
video_pred = agent._wm.video_pred(next(eval_dataset))
|
||||
env_state = eval_envs[0].env_state
|
||||
video_pred = agent._wm.video_pred(next(eval_dataset), env_state)
|
||||
tlogger.video("eval_openl", to_np(video_pred))
|
||||
logger.info("Start training.")
|
||||
state = tools.simulate(
|
||||
@@ -333,6 +341,7 @@ def main(config):
|
||||
limit=config.dataset_size,
|
||||
steps=config.eval_every,
|
||||
state=state,
|
||||
pbar=pbar,
|
||||
)
|
||||
items_to_save = {
|
||||
"agent_state_dict": agent.state_dict(),
|
||||
@@ -340,7 +349,7 @@ def main(config):
|
||||
}
|
||||
torch.save(items_to_save, logdir / "latest.pt")
|
||||
logger.info(f"Saved model to {logdir / 'latest.pt'}")
|
||||
pbar.update(agent._step-pbar.n) # 16858 at a time
|
||||
# pbar.update(agent._step-pbar.n) # 16858 at a time
|
||||
for env in train_envs + eval_envs:
|
||||
try:
|
||||
env.close()
|
||||
|
||||
@@ -0,0 +1,256 @@
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
|
||||
from craftax.craftax_env import make_craftax_env_from_name
|
||||
from craftax.craftax.play_craftax import CraftaxRenderer
|
||||
from craftax.craftax.renderer import (
|
||||
render_craftax_pixels,
|
||||
render_craftax_text,
|
||||
inverse_render_craftax_symbolic,
|
||||
)
|
||||
from craftax.craftax.constants import Action, Achievement
|
||||
from craftax.craftax.craftax_state import EnvState
|
||||
|
||||
import gymnasium
|
||||
# from gymnasium.wrappers.jax_to_torch import jax_to_torch
|
||||
# from gymnasium.wrappers.numpy_to_torch import numpy_to_torch
|
||||
from gymnasium.wrappers import FrameStack, TimeLimit, FrameStack
|
||||
import gymnasium.spaces as gym_spaces
|
||||
from gymnasium.wrappers import TransformObservation
|
||||
|
||||
# import jax
|
||||
import chex
|
||||
import jax.numpy as jnp
|
||||
import torch
|
||||
from jaxtyping import Float, Int, Bool
|
||||
from torch import Tensor
|
||||
from typing import Optional, Tuple, Union, Any, Dict
|
||||
from jax import dlpack as jax_dlpack
|
||||
from torch.utils import dlpack as torch_dlpack
|
||||
|
||||
from envs.gymnax2gymnasium import GymnaxToGymWrapper, GymnaxToVectorGymWrapper
|
||||
|
||||
def state2img(state: chex.Array, env_state: EnvState) -> np.ndarray:
|
||||
img = inverse_render_craftax_symbolic(state, env_state).astype(np.uint8)
|
||||
return np.array(img).astype(np.uint8)
|
||||
|
||||
|
||||
def permute_env(env, prm=[1, 0, 2]):
|
||||
os = env.observation_space
|
||||
oshape = os.shape
|
||||
new_os = gym_spaces.Box(
|
||||
low=np.transpose(os.low, prm),
|
||||
high=np.transpose(os.high, prm),
|
||||
shape=[oshape[i] for i in prm],
|
||||
dtype=os.dtype,
|
||||
)
|
||||
env = TransformObservation(env, lambda x: jnp.transpose(x, prm), obs_space=new_os)
|
||||
return env
|
||||
|
||||
|
||||
def jax_to_torch(v) -> torch.Tensor:
|
||||
dlpack = jax_dlpack.to_dlpack(v)
|
||||
return torch_dlpack.from_dlpack(dlpack)
|
||||
|
||||
def numpy_to_torch(v) -> torch.Tensor:
|
||||
|
||||
# for lazyframes
|
||||
if hasattr(v, '_frames'): v = np.array(v._frames)
|
||||
return torch.from_numpy(v)
|
||||
|
||||
def to_torch(v) -> torch.Tensor:
|
||||
if isinstance(v, jnp.ndarray):
|
||||
if v.dtype=='bool':
|
||||
# bool doesn't convert using the jax_to_torch dlpack
|
||||
# return torch.from_numpy(v._npy_value.copy())
|
||||
return torch.as_tensor(v.tolist())
|
||||
return jax_to_torch(v)
|
||||
if isinstance(v, np.ndarray):
|
||||
return numpy_to_torch(v)
|
||||
if isinstance(v, torch.Tensor):
|
||||
return v
|
||||
else:
|
||||
return torch.as_tensor(v)
|
||||
|
||||
|
||||
|
||||
class CraftaxCompatWrapper(gymnasium.core.Wrapper):
|
||||
"""
|
||||
Misc compat
|
||||
- from jax
|
||||
"""
|
||||
|
||||
def __init__(self, env) -> None:
|
||||
super().__init__(env)
|
||||
self._env = env.unwrapped._env
|
||||
|
||||
def step(
|
||||
self, action: int
|
||||
) -> Tuple[Float[Tensor, "frames odim"], float, bool, bool, Dict]:
|
||||
next_obs, reward, terminated, truncated, info = self.env.step(action)
|
||||
return (
|
||||
numpy_to_torch(next_obs).to(torch.float16), # in symbolic only lighting needs values other than 0 and 1
|
||||
to_torch(reward),
|
||||
# to_torch(terminated),
|
||||
to_torch(truncated | terminated),
|
||||
info,
|
||||
)
|
||||
|
||||
def reset(self, *args, **kwargs):
|
||||
obs, state = self.env.reset(*args, **kwargs)
|
||||
return numpy_to_torch(obs).to(torch.float16), state
|
||||
|
||||
def get_action_meanings(self) -> Dict[int, str]:
|
||||
return {i.value: s for s, i in Action.__members__.items()}
|
||||
|
||||
@property
|
||||
def env_state(self):
|
||||
return self.env.unwrapped.env_state
|
||||
|
||||
|
||||
class CraftaxRenderWrapper(gymnasium.core.Wrapper):
|
||||
"""
|
||||
Wrap Gymax (jas gym) to Gym (original gym)
|
||||
The main difference is that Gymax needs a rng key for every step and reset
|
||||
"""
|
||||
|
||||
def __init__(self, env, render_method: Optional[str] = None) -> None:
|
||||
super().__init__(env)
|
||||
self.render_method = render_method
|
||||
if render_method == "play":
|
||||
self.renderer = CraftaxRenderer(
|
||||
self.env, self.env_params, pixel_render_size=1
|
||||
)
|
||||
self.renderer = None
|
||||
|
||||
def step(self, *args, **kwargs):
|
||||
o = self.env.step(*args, **kwargs)
|
||||
if self.renderer is not None:
|
||||
self.renderer.update()
|
||||
return o
|
||||
|
||||
def reset(self, *args, **kwargs):
|
||||
o = self.env.reset(*args, **kwargs)
|
||||
if self.renderer is not None:
|
||||
self.renderer.update()
|
||||
return o
|
||||
|
||||
def render(self, mode="rgb_array"):
|
||||
o = self.env.render()
|
||||
if self.renderer is not None:
|
||||
return self.renderer.render(self.env_state)
|
||||
elif self.render_method == "text":
|
||||
return render_craftax_text(self.env_state)
|
||||
else:
|
||||
return render_craftax_pixels(self.env_state, 10)
|
||||
return o
|
||||
|
||||
def close(self):
|
||||
if self.renderer is not None:
|
||||
self.renderer.pygame.quit()
|
||||
self.renderer.close()
|
||||
|
||||
|
||||
def create_craftax_env(
|
||||
game="Craftax-Symbolic-AutoReset-v1", frame_stack=2, time_limit=None, seed=42, eval=False, num_envs=1
|
||||
):
|
||||
"""
|
||||
Craftax with
|
||||
- frame_stack 4?
|
||||
time_limit = 27000
|
||||
|
||||
"""
|
||||
# see https://github.dev/MichaelTMatthews/Craftax_Baselines/blob/main/ppo_rnn.py
|
||||
assert 'AutoReset' in game, f"Only AutoReset games supported, got {game}"
|
||||
env = make_craftax_env_from_name(game, auto_reset=True)
|
||||
if num_envs > 1:
|
||||
# FIXME: naive optimistic resets don't work well with multiple envs see OptimisticResetVecEnvWrapper
|
||||
env = GymnaxToVectorGymWrapper(env, seed=seed, num_envs=num_envs)
|
||||
raise NotImplementedError("Only num_envs > 1 supported FIXME")
|
||||
else:
|
||||
env = GymnaxToGymWrapper(env, env.default_params, seed=seed)
|
||||
# env = LogWrapper(env)
|
||||
|
||||
# We have to vectorise using jax earlier as there is not framestack wrapepr avaiable for jax
|
||||
env = FrameStack(env, frame_stack)
|
||||
if num_envs > 1:
|
||||
# but then the framestack dim is before the env dim [framestack, batch, obs_dim] so lets swap those
|
||||
env = permute_env(env, [1, 0, 2])
|
||||
|
||||
# env.unwrapped.spec = gym.spec(game) # required for AtariPreprocessing
|
||||
if not eval and time_limit is not None:
|
||||
env = TimeLimit(env, max_episode_steps=time_limit)
|
||||
|
||||
env = CraftaxRenderWrapper(env, render_method=None)
|
||||
env = CraftaxCompatWrapper(env)
|
||||
return env
|
||||
|
||||
class Craftax:
|
||||
metadata = {}
|
||||
|
||||
def __init__(self, task, seed=0):
|
||||
|
||||
self._env = create_craftax_env(task, seed=seed)
|
||||
# self._achievements = crafter.constants.achievements.copy()
|
||||
self.reward_range = [-np.inf, np.inf]
|
||||
|
||||
@property
|
||||
def observation_space(self):
|
||||
spaces = {
|
||||
"state": gym.spaces.Box(0, 1, (np.prod(self._env.observation_space.shape),), dtype=np.float32),
|
||||
"image": gym.spaces.Box(0, 255, (130, 110, 3), dtype=np.uint8),
|
||||
"is_first": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.uint8),
|
||||
"is_last": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.uint8),
|
||||
"is_terminal": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.uint8),
|
||||
"log_reward": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.float32),
|
||||
}
|
||||
|
||||
spaces.update(
|
||||
{
|
||||
f"log_achievement_{k.name.lower()}": gym.spaces.Box(
|
||||
-np.inf, np.inf, (1,), dtype=np.float32
|
||||
)
|
||||
for k in Achievement
|
||||
}
|
||||
)
|
||||
return gym.spaces.Dict(spaces)
|
||||
|
||||
@property
|
||||
def action_space(self):
|
||||
action_space = self._env.action_space
|
||||
# action_space.discrete = True
|
||||
return action_space
|
||||
|
||||
def step(self, action):
|
||||
state, reward, done, info = self._env.step(action)
|
||||
|
||||
reward = np.float32(reward)
|
||||
obs = {
|
||||
"image": self.get_image(),
|
||||
"state": state.flatten(),
|
||||
"is_first": False,
|
||||
"is_last": done,
|
||||
"is_terminal": info["discount"] == 0,
|
||||
**info,
|
||||
}
|
||||
return obs, reward, done, info
|
||||
|
||||
def render(self):
|
||||
return self._env.render()
|
||||
|
||||
def get_image(self):
|
||||
# it looks like we need an image in the obs, even if it's not used, so that we can record videos?
|
||||
image = render_craftax_pixels(self._env.env_state, 10).astype(np.uint8)
|
||||
image = np.array(image).astype(np.uint8)
|
||||
return image
|
||||
|
||||
def reset(self, seed=None, options=None):
|
||||
state, info = self._env.reset()
|
||||
obs = {
|
||||
"image": self.get_image(),
|
||||
"state": state.flatten(),
|
||||
"is_first": True,
|
||||
"is_last": False,
|
||||
"is_terminal": False,
|
||||
}
|
||||
return obs
|
||||
@@ -0,0 +1,216 @@
|
||||
"""modified from gymnax, to update to gymnasium 1.0.0a2."""
|
||||
|
||||
from typing import Optional, Tuple, Union, List
|
||||
|
||||
import chex
|
||||
import gymnasium
|
||||
import jax.random
|
||||
from gymnasium.core import ActType, ObsType, RenderFrame
|
||||
from gymnasium.vector.utils import batch_space
|
||||
from gymnax.environments.spaces import Space, Discrete, Box
|
||||
from gymnax.environments.environment import Environment, EnvParams
|
||||
from copy import deepcopy
|
||||
|
||||
from typing import Tuple, Sequence, Any, Dict
|
||||
import chex
|
||||
import numpy as np
|
||||
from gymnasium import spaces as gspc
|
||||
|
||||
|
||||
def gymnax_space_to_gym_space(space: Space) -> gspc.Space:
|
||||
"""Convert Gymnax space to equivalent Gym space"""
|
||||
if isinstance(space, Discrete):
|
||||
return gspc.Discrete(space.n)
|
||||
elif isinstance(space, Box):
|
||||
low = (
|
||||
float(space.low)
|
||||
if (np.isscalar(space.low) or space.low.size == 1)
|
||||
else np.array(space.low)
|
||||
)
|
||||
high = (
|
||||
float(space.high)
|
||||
if (np.isscalar(space.high) or space.low.size == 1)
|
||||
else np.array(space.high)
|
||||
)
|
||||
return gspc.Box(low, high, space.shape, space.dtype)
|
||||
elif isinstance(space, Dict):
|
||||
return gspc.Dict({k: gymnax_space_to_gym_space(v) for k, v in space.spaces})
|
||||
elif isinstance(space, Tuple):
|
||||
return gspc.Tuple(space.spaces)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Conversion of {space.__class__.__name__} not supported"
|
||||
)
|
||||
|
||||
|
||||
class GymnaxToGymWrapper(gymnasium.Env):
|
||||
def __init__(
|
||||
self,
|
||||
env: Environment,
|
||||
params: Optional[EnvParams] = None,
|
||||
seed: Optional[int] = None,
|
||||
):
|
||||
"""Wrap Gymnax environment as OOP Gym environment
|
||||
|
||||
Args:
|
||||
env: Gymnax Environment instance
|
||||
params: If provided, gymnax EnvParams for environment (otherwise uses default)
|
||||
seed: If provided, seed for JAX PRNG (otherwise picks 0)
|
||||
"""
|
||||
super().__init__()
|
||||
self._env = deepcopy(env)
|
||||
self.env_params = params if params is not None else env.default_params
|
||||
self.metadata.update(
|
||||
{
|
||||
"name": env.name,
|
||||
"render_modes": ["human", "rgb_array"]
|
||||
if hasattr(env, "render")
|
||||
else [],
|
||||
}
|
||||
)
|
||||
self.rng: chex.PRNGKey = jax.random.PRNGKey(0) # Placeholder
|
||||
self._seed(seed)
|
||||
_, self.env_state = self._env.reset(self.rng, self.env_params)
|
||||
|
||||
@property
|
||||
def action_space(self):
|
||||
"""Dynamically adjust action space depending on params"""
|
||||
return gymnax_space_to_gym_space(self._env.action_space(self.env_params))
|
||||
|
||||
@property
|
||||
def observation_space(self):
|
||||
"""Dynamically adjust state space depending on params"""
|
||||
return gymnax_space_to_gym_space(self._env.observation_space(self.env_params))
|
||||
|
||||
def _seed(self, seed: Optional[int] = None):
|
||||
"""Set RNG seed (or use 0)"""
|
||||
self.rng = jax.random.PRNGKey(seed or 0)
|
||||
|
||||
def step(
|
||||
self, action: ActType
|
||||
) -> Union[
|
||||
Tuple[ObsType, float, bool, bool, dict],
|
||||
Tuple[ObsType, float, bool, dict],
|
||||
]:
|
||||
"""Step environment, follow new step API"""
|
||||
self.rng, step_key = jax.random.split(self.rng)
|
||||
o, self.env_state, r, d, info = self._env.step(
|
||||
step_key, self.env_state, action, self.env_params
|
||||
)
|
||||
return o, r, d, d, info
|
||||
|
||||
def reset(
|
||||
self,
|
||||
*,
|
||||
seed: Optional[int] = None,
|
||||
return_info: bool = False,
|
||||
options: Optional[dict] = None,
|
||||
) -> Tuple[ObsType, dict]:
|
||||
"""Reset environment, update parameters and seed if provided"""
|
||||
if seed is not None:
|
||||
self._seed(seed)
|
||||
if options is not None:
|
||||
self.env_params = options.get(
|
||||
"env_params", self.env_params
|
||||
) # Allow changing environment parameters on reset
|
||||
self.rng, reset_key = jax.random.split(self.rng)
|
||||
o, self.env_state = self._env.reset(reset_key, self.env_params)
|
||||
return o, {}
|
||||
|
||||
def render(self, mode="human") -> Optional[Union[RenderFrame, List[RenderFrame]]]:
|
||||
"""use underlying environment rendering if it exists, otherwise return None"""
|
||||
return getattr(self._env, "render", lambda x, y: None)(
|
||||
self.env_state, self.env_params
|
||||
)
|
||||
|
||||
|
||||
class GymnaxToVectorGymWrapper(gymnasium.vector.VectorEnv):
|
||||
def __init__(
|
||||
self,
|
||||
env: Environment,
|
||||
num_envs: int = 1,
|
||||
params: Optional[EnvParams] = None,
|
||||
seed: Optional[int] = None,
|
||||
):
|
||||
"""Wrap Gymnax environment as OOP Gym Vector Environment
|
||||
|
||||
Args:
|
||||
env: Gymnax Environment instance
|
||||
num_envs: Desired number of environments to run in parallel
|
||||
params: If provided, gymnax EnvParams for environment (otherwise uses default)
|
||||
seed: If provided, seed for JAX PRNG (otherwise picks 0)
|
||||
"""
|
||||
self._env = deepcopy(env)
|
||||
self.num_envs = num_envs
|
||||
self.is_vector_env = True
|
||||
self.new_step_api = True
|
||||
self.closed = False
|
||||
self.viewer = None
|
||||
self.rng: chex.PRNGKey = jax.random.PRNGKey(0) # Placeholder
|
||||
self._seed(seed)
|
||||
# Jit-of-vmap is faster than vmap-of-jit. Map over leading axis of all but env params
|
||||
self._env.reset = jax.jit(jax.vmap(self._env.reset, in_axes=(0, None)))
|
||||
self._env.step = jax.jit(jax.vmap(self._env.step, in_axes=(0, 0, 0, None)))
|
||||
self.env_params = params if params is not None else env.default_params
|
||||
_, self.env_state = self._env.reset(self.rng, self.env_params) # Placeholder
|
||||
self._batched_rng_split = jax.jit(
|
||||
jax.vmap(jax.random.split, in_axes=0, out_axes=1)
|
||||
) # Split all rng keys
|
||||
|
||||
@property
|
||||
def single_action_space(self):
|
||||
"""Dynamically adjust action space depending on params"""
|
||||
return gymnax_space_to_gym_space(self._env.action_space(self.env_params))
|
||||
|
||||
@property
|
||||
def single_observation_space(self):
|
||||
"""Dynamically adjust state space depending on params"""
|
||||
return gymnax_space_to_gym_space(self._env.observation_space(self.env_params))
|
||||
|
||||
@property
|
||||
def action_space(self):
|
||||
"""Dynamically adjust action space depending on params"""
|
||||
return batch_space(self.single_action_space, self.num_envs)
|
||||
|
||||
@property
|
||||
def observation_space(self):
|
||||
"""Dynamically adjust state space depending on params"""
|
||||
return batch_space(self.single_observation_space, self.num_envs)
|
||||
|
||||
def _seed(self, seed: Optional[int] = None):
|
||||
"""Set RNG seed (or use 0)"""
|
||||
self.rng = jax.random.split(
|
||||
jax.random.PRNGKey(seed or 0), self.num_envs
|
||||
) # 1 RNG per env
|
||||
|
||||
def reset(
|
||||
self,
|
||||
*,
|
||||
seed: Optional[Union[int, List[int]]] = None,
|
||||
return_info: bool = False,
|
||||
options: Optional[dict] = None,
|
||||
) -> Tuple[ObsType, dict]:
|
||||
"""Reset environment, update parameters and seed if provided"""
|
||||
if seed is not None:
|
||||
self._seed(seed)
|
||||
if options is not None:
|
||||
self.env_params = options.get(
|
||||
"env_params", self.env_params
|
||||
) # Allow changing environment parameters on reset
|
||||
self.rng, reset_key = self._batched_rng_split(self.rng) # Split all keys
|
||||
o, self.env_state = self._env.reset(reset_key, self.env_params)
|
||||
return o, {}
|
||||
|
||||
def step(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]:
|
||||
"""Step environment, follow new step API"""
|
||||
self.rng, step_key = self._batched_rng_split(self.rng)
|
||||
o, self.env_state, r, d, info = self._env.step(
|
||||
step_key, self.env_state, action, self.env_params
|
||||
)
|
||||
return o, r, d, d, info
|
||||
|
||||
def render(self, mode="human") -> Optional[Union[RenderFrame, List[RenderFrame]]]:
|
||||
"""use underlying environment rendering if it exists (for first environment), otherwise return None"""
|
||||
return getattr(self._env, "render", lambda x, y: None)(
|
||||
jax.tree_map(lambda x: x[0], self.env_state), self.env_params
|
||||
)
|
||||
@@ -1,4 +1,12 @@
|
||||
|
||||
|
||||
set export
|
||||
export OSTYPE := "linux-gnu"
|
||||
export TQDM_MININTERVAL := "30"
|
||||
|
||||
main:
|
||||
. ./.venv/bin/activate
|
||||
python dreamer.py --configs crafter --task crafter_reward --logdir ./logdir/crafter
|
||||
|
||||
logs:
|
||||
tensorboard --logdir logdir/craftax
|
||||
|
||||
@@ -6,6 +6,8 @@ import networks
|
||||
import tools
|
||||
from loguru import logger
|
||||
|
||||
from envs.craftax_env import state2img
|
||||
|
||||
to_np = lambda x: x.detach().cpu().numpy()
|
||||
|
||||
|
||||
@@ -174,7 +176,8 @@ class WorldModel(nn.Module):
|
||||
# this function is called during both rollout and training
|
||||
def preprocess(self, obs):
|
||||
obs = obs.copy()
|
||||
obs["image"] = torch.Tensor(obs["image"]) / 255.0
|
||||
if "image" in obs:
|
||||
obs["image"] = torch.Tensor(obs["image"]) / 255.0
|
||||
if "discount" in obs:
|
||||
obs["discount"] *= self._config.discount
|
||||
# (batch_size, batch_length) -> (batch_size, batch_length, 1)
|
||||
@@ -187,24 +190,37 @@ class WorldModel(nn.Module):
|
||||
obs = {k: torch.Tensor(v).to(self._config.device) for k, v in obs.items()}
|
||||
return obs
|
||||
|
||||
def video_pred(self, data):
|
||||
def video_pred(self, data, env_state= None):
|
||||
# FIXME: in crafter we are not imagining image
|
||||
data = self.preprocess(data)
|
||||
embed = self.encoder(data)
|
||||
|
||||
states, _ = self.dynamics.observe(
|
||||
embed[:6, :5], data["action"][:6, :5], data["is_first"][:6, :5]
|
||||
)
|
||||
recon = self.heads["decoder"](self.dynamics.get_feat(states))["image"].mode()[
|
||||
:6
|
||||
]
|
||||
# FIXME: assume decoder returns image
|
||||
recon = self.heads["decoder"](self.dynamics.get_feat(states))
|
||||
reward_post = self.heads["reward"](self.dynamics.get_feat(states)).mode()[:6]
|
||||
init = {k: v[:, -1] for k, v in states.items()}
|
||||
prior = self.dynamics.imagine_with_action(data["action"][:6, 5:], init)
|
||||
openl = self.heads["decoder"](self.dynamics.get_feat(prior))["image"].mode()
|
||||
openl = self.heads["decoder"](self.dynamics.get_feat(prior))
|
||||
reward_prior = self.heads["reward"](self.dynamics.get_feat(prior)).mode()
|
||||
|
||||
if "Craftax" in self.config['task']:
|
||||
# in Craftax we run in symbolic mode with no image, so to make a video we need to convert
|
||||
|
||||
recon = state2img(recon["state"].mode()[:6], env_state)
|
||||
truth = state2img(data["state"][:6], env_state)
|
||||
openl = state2img(openl['state'][:6], env_state)
|
||||
|
||||
else:
|
||||
recon = recon["image"].mode()[:6]
|
||||
truth = data["image"][:6]
|
||||
openl = ["image"].mode()
|
||||
|
||||
|
||||
# observed image is given until 5 steps
|
||||
model = torch.cat([recon[:, :5], openl], 1)
|
||||
truth = data["image"][:6]
|
||||
model = model
|
||||
error = (model - truth + 1.0) / 2.0
|
||||
|
||||
|
||||
Generated
+1221
-1
File diff suppressed because it is too large
Load Diff
@@ -35,8 +35,14 @@ loguru = "^0.7.2"
|
||||
imageio-ffmpeg = "^0.5.0"
|
||||
importlib = "^1.0.4"
|
||||
imageio = "^2.34.1"
|
||||
craftax = {path = "/media/wassname/SGIronWolf/projects5/2024/Craftax", develop = true }
|
||||
# craftax = {git = "https://github.com/wassname/Craftax" , develop = true }
|
||||
chex = "^0.1.86"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
ipywidgets = "^8.1.3"
|
||||
ipykernel = "^6.29.4"
|
||||
ruff = "^0.1.3"
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
|
||||
@@ -80,7 +80,9 @@ class Logger:
|
||||
scalars = list(self._scalars.items())
|
||||
if fps:
|
||||
scalars.append(("fps", self._compute_fps(step)))
|
||||
logger.info(f"[{step}]", " / ".join(f"{k} {v:.1f}" for k, v in scalars))
|
||||
# print out the episode stats
|
||||
stats = " / ".join(f"{k.replace('log_achievement_', '')} <red>{v:.1f}</red>" for k, v in scalars)
|
||||
logger.opt(colors=True).info(f"[{step}] {stats}")
|
||||
with (self._logdir / "metrics.jsonl").open("a") as f:
|
||||
f.write(json.dumps({"step": step, **dict(scalars)}) + "\n")
|
||||
for name, value in scalars:
|
||||
@@ -149,102 +151,103 @@ def simulate(
|
||||
reward = [0] * len(envs)
|
||||
else:
|
||||
step, episode, done, length, obs, agent_state, reward = state
|
||||
while (steps and step < steps) or (episodes and episode < episodes):
|
||||
# reset envs if necessary
|
||||
if done.any():
|
||||
indices = [index for index, d in enumerate(done) if d]
|
||||
results = [envs[i].reset() for i in indices]
|
||||
results = [r() for r in results]
|
||||
for index, result in zip(indices, results):
|
||||
t = result.copy()
|
||||
t = {k: convert(v) for k, v in t.items()}
|
||||
# action will be added to transition in add_to_cache
|
||||
t["reward"] = 0.0
|
||||
t["discount"] = 1.0
|
||||
# initial state should be added to cache
|
||||
add_to_cache(cache, envs[index].id, t)
|
||||
# replace obs with done by initial state
|
||||
obs[index] = result
|
||||
# step agents
|
||||
obs = {k: np.stack([o[k] for o in obs]) for k in obs[0] if "log_" not in k}
|
||||
action, agent_state = agent(obs, done, agent_state)
|
||||
if isinstance(action, dict):
|
||||
action = [
|
||||
{k: np.array(action[k][i].detach().cpu()) for k in action}
|
||||
for i in range(len(envs))
|
||||
]
|
||||
else:
|
||||
action = np.array(action)
|
||||
assert len(action) == len(envs)
|
||||
# step envs
|
||||
results = [e.step(a) for e, a in zip(envs, action)]
|
||||
results = [r() for r in results]
|
||||
obs, reward, done = zip(*[p[:3] for p in results])
|
||||
obs = list(obs)
|
||||
reward = list(reward)
|
||||
done = np.stack(done)
|
||||
episode += int(done.sum())
|
||||
length += 1
|
||||
step += len(envs)
|
||||
pbar.update(len(envs))
|
||||
length *= 1 - done
|
||||
# add to cache
|
||||
for a, result, env in zip(action, results, envs):
|
||||
o, r, d, info = result
|
||||
o = {k: convert(v) for k, v in o.items()}
|
||||
transition = o.copy()
|
||||
if isinstance(a, dict):
|
||||
transition.update(a)
|
||||
with tqdm(total=steps, disable=pbar is None) as pbar:
|
||||
while (steps and step < steps) or (episodes and episode < episodes):
|
||||
# reset envs if necessary
|
||||
if done.any():
|
||||
indices = [index for index, d in enumerate(done) if d]
|
||||
results = [envs[i].reset() for i in indices]
|
||||
results = [r() for r in results]
|
||||
for index, result in zip(indices, results):
|
||||
t = result.copy()
|
||||
t = {k: convert(v) for k, v in t.items()}
|
||||
# action will be added to transition in add_to_cache
|
||||
t["reward"] = 0.0
|
||||
t["discount"] = 1.0
|
||||
# initial state should be added to cache
|
||||
add_to_cache(cache, envs[index].id, t)
|
||||
# replace obs with done by initial state
|
||||
obs[index] = result
|
||||
# step agents
|
||||
obs = {k: np.stack([o[k] for o in obs]) for k in obs[0] if "log_" not in k}
|
||||
action, agent_state = agent(obs, done, agent_state)
|
||||
if isinstance(action, dict):
|
||||
action = [
|
||||
{k: np.array(action[k][i].detach().cpu()) for k in action}
|
||||
for i in range(len(envs))
|
||||
]
|
||||
else:
|
||||
transition["action"] = a
|
||||
transition["reward"] = r
|
||||
transition["discount"] = info.get("discount", np.array(1 - float(d)))
|
||||
add_to_cache(cache, env.id, transition)
|
||||
|
||||
if done.any():
|
||||
indices = [index for index, d in enumerate(done) if d]
|
||||
# logging for done episode
|
||||
for i in indices:
|
||||
save_episodes(directory, {envs[i].id: cache[envs[i].id]})
|
||||
length = len(cache[envs[i].id]["reward"]) - 1
|
||||
score = float(np.array(cache[envs[i].id]["reward"]).sum())
|
||||
video = cache[envs[i].id]["image"]
|
||||
# record logs given from environments
|
||||
for key in list(cache[envs[i].id].keys()):
|
||||
if "log_" in key:
|
||||
logger.scalar(
|
||||
key, float(np.array(cache[envs[i].id][key]).sum())
|
||||
)
|
||||
# log items won't be used later
|
||||
cache[envs[i].id].pop(key)
|
||||
|
||||
if not is_eval:
|
||||
step_in_dataset = erase_over_episodes(cache, limit)
|
||||
logger.scalar(f"dataset_size", step_in_dataset)
|
||||
logger.scalar(f"train_return", score)
|
||||
logger.scalar(f"train_length", length)
|
||||
logger.scalar(f"train_episodes", len(cache))
|
||||
logger.write(step=logger.step)
|
||||
action = np.array(action)
|
||||
assert len(action) == len(envs)
|
||||
# step envs
|
||||
results = [e.step(a) for e, a in zip(envs, action)]
|
||||
results = [r() for r in results]
|
||||
obs, reward, done = zip(*[p[:3] for p in results])
|
||||
obs = list(obs)
|
||||
reward = list(reward)
|
||||
done = np.stack(done)
|
||||
episode += int(done.sum())
|
||||
length += 1
|
||||
step += len(envs)
|
||||
pbar.update(len(envs))
|
||||
length *= 1 - done
|
||||
# add to cache
|
||||
for a, result, env in zip(action, results, envs):
|
||||
o, r, d, info = result
|
||||
o = {k: convert(v) for k, v in o.items()}
|
||||
transition = o.copy()
|
||||
if isinstance(a, dict):
|
||||
transition.update(a)
|
||||
else:
|
||||
if not "eval_lengths" in locals():
|
||||
eval_lengths = []
|
||||
eval_scores = []
|
||||
eval_done = False
|
||||
# start counting scores for evaluation
|
||||
eval_scores.append(score)
|
||||
eval_lengths.append(length)
|
||||
transition["action"] = a
|
||||
transition["reward"] = r
|
||||
transition["discount"] = info.get("discount", np.array(1 - float(d)))
|
||||
add_to_cache(cache, env.id, transition)
|
||||
|
||||
score = sum(eval_scores) / len(eval_scores)
|
||||
length = sum(eval_lengths) / len(eval_lengths)
|
||||
if video_pred_log:
|
||||
logger.video(f"eval_policy", np.array(video)[None])
|
||||
if done.any():
|
||||
indices = [index for index, d in enumerate(done) if d]
|
||||
# logging for done episode
|
||||
for i in indices:
|
||||
save_episodes(directory, {envs[i].id: cache[envs[i].id]})
|
||||
length = len(cache[envs[i].id]["reward"]) - 1
|
||||
score = float(np.array(cache[envs[i].id]["reward"]).sum())
|
||||
video = cache[envs[i].id]["image"]
|
||||
# record logs given from environments
|
||||
for key in list(cache[envs[i].id].keys()):
|
||||
if "log_" in key:
|
||||
logger.scalar(
|
||||
key, float(np.array(cache[envs[i].id][key]).sum())
|
||||
)
|
||||
# log items won't be used later
|
||||
cache[envs[i].id].pop(key)
|
||||
|
||||
if len(eval_scores) >= episodes and not eval_done:
|
||||
logger.scalar(f"eval_return", score)
|
||||
logger.scalar(f"eval_length", length)
|
||||
logger.scalar(f"eval_episodes", len(eval_scores))
|
||||
if not is_eval:
|
||||
step_in_dataset = erase_over_episodes(cache, limit)
|
||||
logger.scalar(f"dataset_size", step_in_dataset)
|
||||
logger.scalar(f"train_return", score)
|
||||
logger.scalar(f"train_length", length)
|
||||
logger.scalar(f"train_episodes", len(cache))
|
||||
logger.write(step=logger.step)
|
||||
eval_done = True
|
||||
else:
|
||||
if not "eval_lengths" in locals():
|
||||
eval_lengths = []
|
||||
eval_scores = []
|
||||
eval_done = False
|
||||
# start counting scores for evaluation
|
||||
eval_scores.append(score)
|
||||
eval_lengths.append(length)
|
||||
|
||||
score = sum(eval_scores) / len(eval_scores)
|
||||
length = sum(eval_lengths) / len(eval_lengths)
|
||||
if video_pred_log:
|
||||
logger.video(f"eval_policy", np.array(video)[None])
|
||||
|
||||
if len(eval_scores) >= episodes and not eval_done:
|
||||
logger.scalar(f"eval_return", score)
|
||||
logger.scalar(f"eval_length", length)
|
||||
logger.scalar(f"eval_episodes", len(eval_scores))
|
||||
logger.write(step=logger.step)
|
||||
eval_done = True
|
||||
if is_eval:
|
||||
# keep only last item for saving memory. this cache is used for video_pred later
|
||||
while len(cache) > 1:
|
||||
|
||||
Reference in New Issue
Block a user