Files
2024-06-06 20:29:17 +08:00

284 lines
9.4 KiB
Python

import gymnasium as gym
import numpy as np
from einops import rearrange
import torch.nn.functional as F
from craftax.craftax_env import make_craftax_env_from_name
from craftax.craftax.play_craftax import CraftaxRenderer
from craftax.craftax.renderer import (
render_craftax_pixels,
render_craftax_text,
inverse_render_craftax_symbolic,
)
from craftax.craftax.constants import Action, Achievement
from craftax.craftax.craftax_state import EnvState
import gymnasium
# from gymnasium.wrappers.jax_to_torch import jax_to_torch
# from gymnasium.wrappers.numpy_to_torch import numpy_to_torch
from gymnasium.wrappers import FrameStack, TimeLimit, FrameStack
import gymnasium.spaces as gym_spaces
from gymnasium.wrappers import TransformObservation
# import jax
import chex
import jax.numpy as jnp
import torch
from jaxtyping import Float, Int, Bool
from torch import Tensor
from typing import Optional, Tuple, Union, Any, Dict
from jax import dlpack as jax_dlpack
from torch.utils import dlpack as torch_dlpack
from envs.gymnax2gymnasium import GymnaxToGymWrapper, GymnaxToVectorGymWrapper
def state2img(state: chex.Array, env_state: EnvState) -> np.ndarray:
img = inverse_render_craftax_symbolic(state, env_state).astype(np.uint8)
return np.array(img).astype(np.uint8)
def permute_env(env, prm=[1, 0, 2]):
os = env.observation_space
oshape = os.shape
new_os = gym_spaces.Box(
low=np.transpose(os.low, prm),
high=np.transpose(os.high, prm),
shape=[oshape[i] for i in prm],
dtype=os.dtype,
)
env = TransformObservation(env, lambda x: jnp.transpose(x, prm), obs_space=new_os)
return env
def jax_to_torch(v) -> torch.Tensor:
dlpack = jax_dlpack.to_dlpack(v)
return torch_dlpack.from_dlpack(dlpack)
def numpy_to_torch(v) -> torch.Tensor:
# for lazyframes
if hasattr(v, '_frames'): v = np.array(v._frames)
return torch.from_numpy(v)
def to_torch(v) -> torch.Tensor:
if isinstance(v, jnp.ndarray):
if v.dtype=='bool':
# bool doesn't convert using the jax_to_torch dlpack
# return torch.from_numpy(v._npy_value.copy())
return torch.as_tensor(v.tolist())
return jax_to_torch(v)
if isinstance(v, np.ndarray):
return numpy_to_torch(v)
if isinstance(v, torch.Tensor):
return v
else:
return torch.as_tensor(v)
class CraftaxCompatWrapper(gymnasium.core.Wrapper):
"""
Misc compat
- from jax
"""
def __init__(self, env) -> None:
super().__init__(env)
self._env = env.unwrapped._env
def step(
self, action: int
) -> Tuple[Float[Tensor, "frames odim"], float, bool, bool, Dict]:
next_obs, reward, terminated, truncated, info = self.env.step(action)
return (
numpy_to_torch(next_obs).to(torch.float16), # in symbolic only lighting needs values other than 0 and 1
to_torch(reward),
# to_torch(terminated),
to_torch(truncated | terminated),
info,
)
def reset(self, *args, **kwargs):
obs, state = self.env.reset(*args, **kwargs)
return numpy_to_torch(obs).to(torch.float16), state
def get_action_meanings(self) -> Dict[int, str]:
return {i.value: s for s, i in Action.__members__.items()}
@property
def env_state(self):
return self.env.unwrapped.env_state
class CraftaxRenderWrapper(gymnasium.core.Wrapper):
"""
Wrap Gymax (jas gym) to Gym (original gym)
The main difference is that Gymax needs a rng key for every step and reset
"""
def __init__(self, env, render_method: Optional[str] = None) -> None:
super().__init__(env)
self.render_method = render_method
if render_method == "play":
self.renderer = CraftaxRenderer(
self.env, self.env_params, pixel_render_size=1
)
self.renderer = None
def step(self, *args, **kwargs):
o = self.env.step(*args, **kwargs)
if self.renderer is not None:
self.renderer.update()
return o
def reset(self, *args, **kwargs):
o = self.env.reset(*args, **kwargs)
if self.renderer is not None:
self.renderer.update()
return o
def render(self, mode="rgb_array"):
o = self.env.render()
if self.renderer is not None:
return self.renderer.render(self.env_state)
elif self.render_method == "text":
return render_craftax_text(self.env_state)
else:
return render_craftax_pixels(self.env_state, 10)
return o
def close(self):
if self.renderer is not None:
self.renderer.pygame.quit()
self.renderer.close()
def create_craftax_env(
game="Craftax-Symbolic-AutoReset-v1", frame_stack=2, time_limit=None, seed=42, eval=False, num_envs=1
):
"""
Craftax with
- frame_stack 4?
time_limit = 27000
"""
# see https://github.dev/MichaelTMatthews/Craftax_Baselines/blob/main/ppo_rnn.py
assert 'AutoReset' in game, f"Only AutoReset games supported, got {game}"
env = make_craftax_env_from_name(game, auto_reset=True)
if num_envs > 1:
# FIXME: naive optimistic resets don't work well with multiple envs see OptimisticResetVecEnvWrapper
env = GymnaxToVectorGymWrapper(env, seed=seed, num_envs=num_envs)
raise NotImplementedError("Only num_envs > 1 supported FIXME")
else:
env = GymnaxToGymWrapper(env, env.default_params, seed=seed)
# env = LogWrapper(env)
# We have to vectorise using jax earlier as there is not framestack wrapepr avaiable for jax
env = FrameStack(env, frame_stack)
if num_envs > 1:
# but then the framestack dim is before the env dim [framestack, batch, obs_dim] so lets swap those
env = permute_env(env, [1, 0, 2])
# env.unwrapped.spec = gym.spec(game) # required for AtariPreprocessing
if not eval and time_limit is not None:
env = TimeLimit(env, max_episode_steps=time_limit)
env = CraftaxRenderWrapper(env, render_method=None)
env = CraftaxCompatWrapper(env)
return env
def reshape_state(state: Float[Tensor, 'frames state_dim']) -> (Float[Tensor,'frames h w c'], Float[Tensor,'frames inv']):
"""
reshapes state into map and inv
https://github.com/MichaelTMatthews/Craftax/blob/main/obs_description.md
"""
map = rearrange(state[:, :8217], 'frames (h w c) -> frames h w c', h=9, w=11, c=83)
# now pad from (9,11) to (16,16) so that it's 2*n
map = F.pad(map, (0, 0, 3, 2, 4, 3))
map = rearrange(map, 'f h w c -> h w (f c)')
inventories = rearrange(state[:, 8217:], 'frames c -> (frames c)')
return map, inventories
class Craftax:
metadata = {}
def __init__(self, task, seed=0):
self._env = create_craftax_env(task, seed=seed)
# self._achievements = crafter.constants.achievements.copy()
self.reward_range = [-np.inf, np.inf]
@property
def observation_space(self):
frames = self._env.observation_space.shape[0]
spaces = {
# "state": gym.spaces.Box(0, 1, (np.prod(self._env.observation_space.shape),), dtype=np.float32),
"state_map": gym.spaces.Box(0, 1, (16, 16, frames*83), dtype=np.float16),
"state_inventory": gym.spaces.Box(0, 1, (frames * 51,), dtype=np.float16),
"image": gym.spaces.Box(0, 255, (130, 110, 3), dtype=np.uint8),
"is_first": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.uint8),
"is_last": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.uint8),
"is_terminal": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.uint8),
"log_reward": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.float32),
}
spaces.update(
{
f"log_achievement_{k.name.lower()}": gym.spaces.Box(
-np.inf, np.inf, (1,), dtype=np.float32
)
for k in Achievement
}
)
return gym.spaces.Dict(spaces)
@property
def action_space(self):
action_space = self._env.action_space
# action_space.discrete = True
return action_space
def step(self, action):
state, reward, done, info = self._env.step(action)
info2 = {k.replace('Ach','log_ach'):v for k,v in info.items()}
reward = np.float32(reward)
state_map, state_inv = reshape_state(state)
obs = {
"image": self.get_image(),
"state": state.flatten(),
"state_map": state_map,
"state_inventory": state_inv,
"is_first": False,
"is_last": done,
"is_terminal": info["discount"] == 0,
**info2,
}
return obs, reward, done, info
def render(self):
return self._env.render()
def get_image(self):
# it looks like we need an image in the obs, even if it's not used, so that we can record videos?
image = render_craftax_pixels(self._env.env_state, 10).astype(np.uint8)
image = np.array(image).astype(np.uint8)
return image
def reset(self, seed=None, options=None):
state, info = self._env.reset()
state_map, state_inv = reshape_state(state)
obs = {
"image": self.get_image(),
"state": state.flatten(),
"state_map": state_map,
"state_inventory": state_inv,
"is_first": True,
"is_last": False,
"is_terminal": False,
}
return obs