mirror of
https://github.com/wassname/dreamerv3-torch.git
synced 2026-06-27 16:15:31 +08:00
284 lines
9.4 KiB
Python
284 lines
9.4 KiB
Python
import gymnasium as gym
|
|
import numpy as np
|
|
from einops import rearrange
|
|
import torch.nn.functional as F
|
|
|
|
from craftax.craftax_env import make_craftax_env_from_name
|
|
from craftax.craftax.play_craftax import CraftaxRenderer
|
|
from craftax.craftax.renderer import (
|
|
render_craftax_pixels,
|
|
render_craftax_text,
|
|
inverse_render_craftax_symbolic,
|
|
)
|
|
from craftax.craftax.constants import Action, Achievement
|
|
from craftax.craftax.craftax_state import EnvState
|
|
|
|
import gymnasium
|
|
# from gymnasium.wrappers.jax_to_torch import jax_to_torch
|
|
# from gymnasium.wrappers.numpy_to_torch import numpy_to_torch
|
|
from gymnasium.wrappers import FrameStack, TimeLimit, FrameStack
|
|
import gymnasium.spaces as gym_spaces
|
|
from gymnasium.wrappers import TransformObservation
|
|
|
|
# import jax
|
|
import chex
|
|
import jax.numpy as jnp
|
|
import torch
|
|
from jaxtyping import Float, Int, Bool
|
|
from torch import Tensor
|
|
from typing import Optional, Tuple, Union, Any, Dict
|
|
from jax import dlpack as jax_dlpack
|
|
from torch.utils import dlpack as torch_dlpack
|
|
|
|
from envs.gymnax2gymnasium import GymnaxToGymWrapper, GymnaxToVectorGymWrapper
|
|
|
|
def state2img(state: chex.Array, env_state: EnvState) -> np.ndarray:
|
|
img = inverse_render_craftax_symbolic(state, env_state).astype(np.uint8)
|
|
return np.array(img).astype(np.uint8)
|
|
|
|
|
|
def permute_env(env, prm=[1, 0, 2]):
|
|
os = env.observation_space
|
|
oshape = os.shape
|
|
new_os = gym_spaces.Box(
|
|
low=np.transpose(os.low, prm),
|
|
high=np.transpose(os.high, prm),
|
|
shape=[oshape[i] for i in prm],
|
|
dtype=os.dtype,
|
|
)
|
|
env = TransformObservation(env, lambda x: jnp.transpose(x, prm), obs_space=new_os)
|
|
return env
|
|
|
|
|
|
def jax_to_torch(v) -> torch.Tensor:
|
|
dlpack = jax_dlpack.to_dlpack(v)
|
|
return torch_dlpack.from_dlpack(dlpack)
|
|
|
|
def numpy_to_torch(v) -> torch.Tensor:
|
|
|
|
# for lazyframes
|
|
if hasattr(v, '_frames'): v = np.array(v._frames)
|
|
return torch.from_numpy(v)
|
|
|
|
def to_torch(v) -> torch.Tensor:
|
|
if isinstance(v, jnp.ndarray):
|
|
if v.dtype=='bool':
|
|
# bool doesn't convert using the jax_to_torch dlpack
|
|
# return torch.from_numpy(v._npy_value.copy())
|
|
return torch.as_tensor(v.tolist())
|
|
return jax_to_torch(v)
|
|
if isinstance(v, np.ndarray):
|
|
return numpy_to_torch(v)
|
|
if isinstance(v, torch.Tensor):
|
|
return v
|
|
else:
|
|
return torch.as_tensor(v)
|
|
|
|
|
|
|
|
class CraftaxCompatWrapper(gymnasium.core.Wrapper):
|
|
"""
|
|
Misc compat
|
|
- from jax
|
|
"""
|
|
|
|
def __init__(self, env) -> None:
|
|
super().__init__(env)
|
|
self._env = env.unwrapped._env
|
|
|
|
def step(
|
|
self, action: int
|
|
) -> Tuple[Float[Tensor, "frames odim"], float, bool, bool, Dict]:
|
|
next_obs, reward, terminated, truncated, info = self.env.step(action)
|
|
return (
|
|
numpy_to_torch(next_obs).to(torch.float16), # in symbolic only lighting needs values other than 0 and 1
|
|
to_torch(reward),
|
|
# to_torch(terminated),
|
|
to_torch(truncated | terminated),
|
|
info,
|
|
)
|
|
|
|
def reset(self, *args, **kwargs):
|
|
obs, state = self.env.reset(*args, **kwargs)
|
|
return numpy_to_torch(obs).to(torch.float16), state
|
|
|
|
def get_action_meanings(self) -> Dict[int, str]:
|
|
return {i.value: s for s, i in Action.__members__.items()}
|
|
|
|
@property
|
|
def env_state(self):
|
|
return self.env.unwrapped.env_state
|
|
|
|
|
|
class CraftaxRenderWrapper(gymnasium.core.Wrapper):
|
|
"""
|
|
Wrap Gymax (jas gym) to Gym (original gym)
|
|
The main difference is that Gymax needs a rng key for every step and reset
|
|
"""
|
|
|
|
def __init__(self, env, render_method: Optional[str] = None) -> None:
|
|
super().__init__(env)
|
|
self.render_method = render_method
|
|
if render_method == "play":
|
|
self.renderer = CraftaxRenderer(
|
|
self.env, self.env_params, pixel_render_size=1
|
|
)
|
|
self.renderer = None
|
|
|
|
def step(self, *args, **kwargs):
|
|
o = self.env.step(*args, **kwargs)
|
|
if self.renderer is not None:
|
|
self.renderer.update()
|
|
return o
|
|
|
|
def reset(self, *args, **kwargs):
|
|
o = self.env.reset(*args, **kwargs)
|
|
if self.renderer is not None:
|
|
self.renderer.update()
|
|
return o
|
|
|
|
def render(self, mode="rgb_array"):
|
|
o = self.env.render()
|
|
if self.renderer is not None:
|
|
return self.renderer.render(self.env_state)
|
|
elif self.render_method == "text":
|
|
return render_craftax_text(self.env_state)
|
|
else:
|
|
return render_craftax_pixels(self.env_state, 10)
|
|
return o
|
|
|
|
def close(self):
|
|
if self.renderer is not None:
|
|
self.renderer.pygame.quit()
|
|
self.renderer.close()
|
|
|
|
|
|
def create_craftax_env(
|
|
game="Craftax-Symbolic-AutoReset-v1", frame_stack=2, time_limit=None, seed=42, eval=False, num_envs=1
|
|
):
|
|
"""
|
|
Craftax with
|
|
- frame_stack 4?
|
|
time_limit = 27000
|
|
|
|
"""
|
|
# see https://github.dev/MichaelTMatthews/Craftax_Baselines/blob/main/ppo_rnn.py
|
|
assert 'AutoReset' in game, f"Only AutoReset games supported, got {game}"
|
|
env = make_craftax_env_from_name(game, auto_reset=True)
|
|
if num_envs > 1:
|
|
# FIXME: naive optimistic resets don't work well with multiple envs see OptimisticResetVecEnvWrapper
|
|
env = GymnaxToVectorGymWrapper(env, seed=seed, num_envs=num_envs)
|
|
raise NotImplementedError("Only num_envs > 1 supported FIXME")
|
|
else:
|
|
env = GymnaxToGymWrapper(env, env.default_params, seed=seed)
|
|
# env = LogWrapper(env)
|
|
|
|
# We have to vectorise using jax earlier as there is not framestack wrapepr avaiable for jax
|
|
env = FrameStack(env, frame_stack)
|
|
if num_envs > 1:
|
|
# but then the framestack dim is before the env dim [framestack, batch, obs_dim] so lets swap those
|
|
env = permute_env(env, [1, 0, 2])
|
|
|
|
# env.unwrapped.spec = gym.spec(game) # required for AtariPreprocessing
|
|
if not eval and time_limit is not None:
|
|
env = TimeLimit(env, max_episode_steps=time_limit)
|
|
|
|
env = CraftaxRenderWrapper(env, render_method=None)
|
|
env = CraftaxCompatWrapper(env)
|
|
return env
|
|
|
|
|
|
def reshape_state(state: Float[Tensor, 'frames state_dim']) -> (Float[Tensor,'frames h w c'], Float[Tensor,'frames inv']):
|
|
"""
|
|
reshapes state into map and inv
|
|
|
|
https://github.com/MichaelTMatthews/Craftax/blob/main/obs_description.md
|
|
"""
|
|
map = rearrange(state[:, :8217], 'frames (h w c) -> frames h w c', h=9, w=11, c=83)
|
|
# now pad from (9,11) to (16,16) so that it's 2*n
|
|
map = F.pad(map, (0, 0, 3, 2, 4, 3))
|
|
map = rearrange(map, 'f h w c -> h w (f c)')
|
|
inventories = rearrange(state[:, 8217:], 'frames c -> (frames c)')
|
|
return map, inventories
|
|
|
|
class Craftax:
|
|
metadata = {}
|
|
|
|
def __init__(self, task, seed=0):
|
|
|
|
self._env = create_craftax_env(task, seed=seed)
|
|
# self._achievements = crafter.constants.achievements.copy()
|
|
self.reward_range = [-np.inf, np.inf]
|
|
|
|
@property
|
|
def observation_space(self):
|
|
frames = self._env.observation_space.shape[0]
|
|
spaces = {
|
|
# "state": gym.spaces.Box(0, 1, (np.prod(self._env.observation_space.shape),), dtype=np.float32),
|
|
"state_map": gym.spaces.Box(0, 1, (16, 16, frames*83), dtype=np.float16),
|
|
"state_inventory": gym.spaces.Box(0, 1, (frames * 51,), dtype=np.float16),
|
|
"image": gym.spaces.Box(0, 255, (130, 110, 3), dtype=np.uint8),
|
|
"is_first": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.uint8),
|
|
"is_last": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.uint8),
|
|
"is_terminal": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.uint8),
|
|
"log_reward": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.float32),
|
|
}
|
|
|
|
spaces.update(
|
|
{
|
|
f"log_achievement_{k.name.lower()}": gym.spaces.Box(
|
|
-np.inf, np.inf, (1,), dtype=np.float32
|
|
)
|
|
for k in Achievement
|
|
}
|
|
)
|
|
return gym.spaces.Dict(spaces)
|
|
|
|
@property
|
|
def action_space(self):
|
|
action_space = self._env.action_space
|
|
# action_space.discrete = True
|
|
return action_space
|
|
|
|
def step(self, action):
|
|
state, reward, done, info = self._env.step(action)
|
|
|
|
info2 = {k.replace('Ach','log_ach'):v for k,v in info.items()}
|
|
|
|
reward = np.float32(reward)
|
|
state_map, state_inv = reshape_state(state)
|
|
obs = {
|
|
"image": self.get_image(),
|
|
"state": state.flatten(),
|
|
"state_map": state_map,
|
|
"state_inventory": state_inv,
|
|
"is_first": False,
|
|
"is_last": done,
|
|
"is_terminal": info["discount"] == 0,
|
|
**info2,
|
|
}
|
|
return obs, reward, done, info
|
|
|
|
def render(self):
|
|
return self._env.render()
|
|
|
|
def get_image(self):
|
|
# it looks like we need an image in the obs, even if it's not used, so that we can record videos?
|
|
image = render_craftax_pixels(self._env.env_state, 10).astype(np.uint8)
|
|
image = np.array(image).astype(np.uint8)
|
|
return image
|
|
|
|
def reset(self, seed=None, options=None):
|
|
state, info = self._env.reset()
|
|
state_map, state_inv = reshape_state(state)
|
|
obs = {
|
|
"image": self.get_image(),
|
|
"state": state.flatten(),
|
|
"state_map": state_map,
|
|
"state_inventory": state_inv,
|
|
"is_first": True,
|
|
"is_last": False,
|
|
"is_terminal": False,
|
|
}
|
|
return obs
|