This commit is contained in:
wassname
2024-06-03 20:09:54 +08:00
parent 5a0e8dc5ac
commit ff14ca4639
9 changed files with 1858 additions and 105 deletions
+20 -1
View File
@@ -62,7 +62,7 @@ defaults:
initial: 'learned'
# Training
batch_size: 16
batch_size: 256
batch_length: 64
train_ratio: 512
pretrain: 100
@@ -129,6 +129,25 @@ crafter:
cont_head: {layers: 5}
imag_gradient: 'reinforce'
craftax:
task: craftax_Craftax-Symbolic-AutoReset-v1
step: 1e6
action_repeat: 1
envs: 1
train_ratio: 512
video_pred_log: false # FIXME
dyn_hidden: 1024
dyn_deter: 4096
units: 1024
encoder: {cnn_keys: '$^', mlp_keys: "state", mlp_layers: 5, mlp_units: 1024, }
decoder: {cnn_keys: '$^', mlp_keys: "state", mlp_layers: 5, mlp_units: 1024}
actor: {layers: 5, dist: 'onehot', std: 'none'}
value: {layers: 5}
reward_head: {layers: 5}
cont_head: {layers: 5}
imag_gradient: 'reinforce'
atari100k:
steps: 4e5
envs: 1
+13 -4
View File
@@ -76,7 +76,8 @@ class Dreamer(nn.Module):
self._logger.scalar(name, float(np.mean(values)))
self._metrics[name] = []
if self._config.video_pred_log:
openl = self._wm.video_pred(next(self._dataset))
# FIXME need to provide this state or synthetic one
openl = self._wm.video_pred(next(self._dataset), env_state)
self._logger.video("train_openl", to_np(openl))
self._logger.write(fps=True)
@@ -192,6 +193,11 @@ def make_env(config, mode, id):
env = crafter.Crafter(task, config.size, seed=config.seed + id)
env = wrappers.OneHotAction(env)
elif suite == "craftax":
from envs import craftax_env
env = craftax_env.Craftax(task,seed=config.seed + id)
env = wrappers.OneHotAction(env)
elif suite == "minecraft":
import envs.minecraft as minecraft
@@ -226,6 +232,7 @@ def main(config):
step = count_steps(config.traindir)
# step in logger is environmental step
tlogger = tools.Logger(logdir, config.action_repeat * step)
logger.add(logdir/"logger.log")
logger.info("Create envs.")
if config.offline_traindir:
@@ -303,7 +310,7 @@ def main(config):
agent._should_pretrain._once = False
# make sure eval will be executed once after config.steps
with tqdm(total=config.steps + config.eval_every) as pbar:
with tqdm(total=config.steps + config.eval_every, unit='step') as pbar:
while agent._step < config.steps + config.eval_every:
tlogger.write()
if config.eval_episode_num > 0:
@@ -321,7 +328,8 @@ def main(config):
pbar=pbar,
)
if config.video_pred_log:
video_pred = agent._wm.video_pred(next(eval_dataset))
env_state = eval_envs[0].env_state
video_pred = agent._wm.video_pred(next(eval_dataset), env_state)
tlogger.video("eval_openl", to_np(video_pred))
logger.info("Start training.")
state = tools.simulate(
@@ -333,6 +341,7 @@ def main(config):
limit=config.dataset_size,
steps=config.eval_every,
state=state,
pbar=pbar,
)
items_to_save = {
"agent_state_dict": agent.state_dict(),
@@ -340,7 +349,7 @@ def main(config):
}
torch.save(items_to_save, logdir / "latest.pt")
logger.info(f"Saved model to {logdir / 'latest.pt'}")
pbar.update(agent._step-pbar.n) # 16858 at a time
# pbar.update(agent._step-pbar.n) # 16858 at a time
for env in train_envs + eval_envs:
try:
env.close()
+256
View File
@@ -0,0 +1,256 @@
import gymnasium as gym
import numpy as np
from craftax.craftax_env import make_craftax_env_from_name
from craftax.craftax.play_craftax import CraftaxRenderer
from craftax.craftax.renderer import (
render_craftax_pixels,
render_craftax_text,
inverse_render_craftax_symbolic,
)
from craftax.craftax.constants import Action, Achievement
from craftax.craftax.craftax_state import EnvState
import gymnasium
# from gymnasium.wrappers.jax_to_torch import jax_to_torch
# from gymnasium.wrappers.numpy_to_torch import numpy_to_torch
from gymnasium.wrappers import FrameStack, TimeLimit, FrameStack
import gymnasium.spaces as gym_spaces
from gymnasium.wrappers import TransformObservation
# import jax
import chex
import jax.numpy as jnp
import torch
from jaxtyping import Float, Int, Bool
from torch import Tensor
from typing import Optional, Tuple, Union, Any, Dict
from jax import dlpack as jax_dlpack
from torch.utils import dlpack as torch_dlpack
from envs.gymnax2gymnasium import GymnaxToGymWrapper, GymnaxToVectorGymWrapper
def state2img(state: chex.Array, env_state: EnvState) -> np.ndarray:
    """Turn a symbolic Craftax observation into a uint8 image array.

    The symbolic ``state`` is passed, together with ``env_state``, to
    ``inverse_render_craftax_symbolic``; the result is cast to ``uint8`` and
    materialised as a NumPy array.
    """
    rendered = inverse_render_craftax_symbolic(state, env_state)
    as_bytes = rendered.astype(np.uint8)
    return np.array(as_bytes).astype(np.uint8)
def permute_env(env, prm=(1, 0, 2)):
    """Wrap ``env`` so that its observation axes are permuted by ``prm``.

    Args:
        env: environment whose ``observation_space`` exposes ``low``, ``high``,
            ``shape`` and ``dtype`` (i.e. a Box-like space).
        prm: axis order handed to ``transpose``. The default is an immutable
            tuple so that it cannot be mutated and shared between calls.

    Returns:
        ``env`` wrapped in a ``TransformObservation`` that transposes every
        observation and advertises the correspondingly permuted Box space.
    """
    axes = tuple(prm)
    space = env.observation_space
    permuted_space = gym_spaces.Box(
        low=np.transpose(space.low, axes),
        high=np.transpose(space.high, axes),
        shape=tuple(space.shape[i] for i in axes),
        dtype=space.dtype,
    )
    # NOTE(review): the keyword name `obs_space` is kept from the original code;
    # confirm it matches the TransformObservation signature of the installed
    # gymnasium version.
    return TransformObservation(
        env, lambda x: jnp.transpose(x, axes), obs_space=permuted_space
    )
def jax_to_torch(v) -> torch.Tensor:
    """Convert a JAX array to a torch tensor through the DLPack protocol.

    The array is handed straight to torch's ``from_dlpack``, which consumes
    any object implementing ``__dlpack__``. This avoids
    ``jax.dlpack.to_dlpack``, which newer JAX releases deprecate and remove.

    Args:
        v: a JAX array (anything exposing ``__dlpack__``).

    Returns:
        A torch tensor built from ``v`` via DLPack.
    """
    return torch_dlpack.from_dlpack(v)
def numpy_to_torch(v) -> torch.Tensor:
    """Convert a NumPy array (or a lazy frame container) to a torch tensor.

    Objects carrying a ``_frames`` attribute (e.g. frame-stack lazy frames) are
    first materialised into a single NumPy array; anything else is passed to
    ``torch.from_numpy`` unchanged.
    """
    array = np.array(v._frames) if hasattr(v, "_frames") else v
    return torch.from_numpy(array)
def to_torch(v) -> torch.Tensor:
    """Convert ``v`` to a torch tensor, dispatching on its type.

    * JAX arrays go through ``jax_to_torch``, except boolean arrays, which are
      converted via a plain Python list (bool doesn't convert using the
      dlpack-based ``jax_to_torch``).
    * NumPy arrays go through ``numpy_to_torch``.
    * Torch tensors are returned as they are.
    * Everything else is passed to ``torch.as_tensor``.
    """
    if isinstance(v, jnp.ndarray):
        is_boolean = v.dtype == 'bool'
        return torch.as_tensor(v.tolist()) if is_boolean else jax_to_torch(v)
    if isinstance(v, np.ndarray):
        return numpy_to_torch(v)
    return v if isinstance(v, torch.Tensor) else torch.as_tensor(v)
class CraftaxCompatWrapper(gymnasium.core.Wrapper):
    """Compatibility layer between the wrapped Craftax env and torch code.

    * Observations are converted to ``float16`` torch tensors.
    * Rewards and the done flag are converted to torch tensors.
    * ``step`` collapses ``terminated`` and ``truncated`` into one ``done``
      flag and therefore returns a 4-tuple ``(obs, reward, done, info)``
      rather than the 5-tuple the wrapped env returns.
    """

    def __init__(self, env) -> None:
        super().__init__(env)
        # Direct handle on the `_env` attribute of the innermost environment.
        self._env = env.unwrapped._env

    def step(
        self, action: int
    ) -> Tuple[Float[Tensor, "frames odim"], Tensor, Tensor, Dict]:
        """Step the wrapped env and return ``(obs, reward, done, info)``.

        ``done`` is ``truncated | terminated``; the two flags are not reported
        separately.
        """
        next_obs, reward, terminated, truncated, info = self.env.step(action)
        done = truncated | terminated
        return (
            # in symbolic only lighting needs values other than 0 and 1
            numpy_to_torch(next_obs).to(torch.float16),
            to_torch(reward),
            to_torch(done),
            info,
        )

    def reset(self, *args, **kwargs):
        """Reset the wrapped env; return the float16 observation and the second reset value."""
        obs, state = self.env.reset(*args, **kwargs)
        return numpy_to_torch(obs).to(torch.float16), state

    def get_action_meanings(self) -> Dict[int, str]:
        """Map every ``Action`` enum value to its member name."""
        return {i.value: s for s, i in Action.__members__.items()}

    @property
    def env_state(self):
        """Current ``env_state`` of the innermost (unwrapped) environment."""
        return self.env.unwrapped.env_state
class CraftaxRenderWrapper(gymnasium.core.Wrapper):
    """Add Craftax rendering on top of a wrapped environment.

    ``render_method`` selects the output of ``render``:

    * ``"play"``  - an interactive ``CraftaxRenderer`` is created and used.
    * ``"text"``  - ``render_craftax_text`` of the current env state.
    * otherwise   - ``render_craftax_pixels`` of the current env state.

    The env state and params are read from ``self.env.unwrapped`` so that the
    wrapper does not depend on attribute forwarding through wrapper layers.
    """

    def __init__(self, env, render_method: Optional[str] = None) -> None:
        super().__init__(env)
        self.render_method = render_method
        # Default first, then override: previously the renderer was reset to
        # None unconditionally after being created, so "play" never worked.
        self.renderer = None
        if render_method == "play":
            # NOTE(review): the wrapped gymnasium env is passed as in the
            # original code - confirm CraftaxRenderer accepts it.
            self.renderer = CraftaxRenderer(
                self.env, self.env.unwrapped.env_params, pixel_render_size=1
            )

    def step(self, *args, **kwargs):
        """Step the wrapped env and refresh the interactive renderer, if any."""
        o = self.env.step(*args, **kwargs)
        if self.renderer is not None:
            self.renderer.update()
        return o

    def reset(self, *args, **kwargs):
        """Reset the wrapped env and refresh the interactive renderer, if any."""
        o = self.env.reset(*args, **kwargs)
        if self.renderer is not None:
            self.renderer.update()
        return o

    def render(self, mode="rgb_array"):
        """Render the current env state using the configured method.

        ``mode`` is accepted for call compatibility but is not used.
        """
        env_state = self.env.unwrapped.env_state
        if self.renderer is not None:
            return self.renderer.render(env_state)
        if self.render_method == "text":
            return render_craftax_text(env_state)
        return render_craftax_pixels(env_state, 10)

    def close(self):
        """Shut down the interactive renderer (if any) and close the wrapped env."""
        if self.renderer is not None:
            self.renderer.pygame.quit()
            self.renderer.close()
        super().close()
def create_craftax_env(
    game="Craftax-Symbolic-AutoReset-v1", frame_stack=2, time_limit=None, seed=42, eval=False, num_envs=1
):
    """Build a single Craftax environment with the wrappers used by this project.

    Wrapper stack, innermost first: ``GymnaxToGymWrapper`` -> ``FrameStack`` ->
    optional ``TimeLimit`` -> ``CraftaxRenderWrapper`` -> ``CraftaxCompatWrapper``.

    Args:
        game: Craftax environment name; must contain ``"AutoReset"``.
        frame_stack: number of frames stacked by ``FrameStack``.
        time_limit: if given (and ``eval`` is false), maximum episode steps
            enforced by ``TimeLimit``. The original note mentions 27000.
        seed: seed forwarded to ``GymnaxToGymWrapper``.
        eval: when true, no ``TimeLimit`` is applied.
        num_envs: must be 1; vectorised envs are not supported yet.

    Raises:
        AssertionError: if ``game`` is not an AutoReset variant.
        NotImplementedError: if ``num_envs > 1``.
    """
    # see https://github.dev/MichaelTMatthews/Craftax_Baselines/blob/main/ppo_rnn.py
    assert 'AutoReset' in game, f"Only AutoReset games supported, got {game}"

    if num_envs > 1:
        # FIXME: naive optimistic resets don't work well with multiple envs,
        # see OptimisticResetVecEnvWrapper. The intended path is
        # GymnaxToVectorGymWrapper(env, seed=seed, num_envs=num_envs) followed
        # by FrameStack and permute_env(env, [1, 0, 2]) to move the env dim in
        # front of the frame-stack dim. Fail before building anything.
        raise NotImplementedError("num_envs > 1 is not supported yet FIXME")

    env = make_craftax_env_from_name(game, auto_reset=True)
    env = GymnaxToGymWrapper(env, env.default_params, seed=seed)

    # env = LogWrapper(env)

    # Frame stacking happens on the gymnasium side, as there is no framestack
    # wrapper available for jax.
    env = FrameStack(env, frame_stack)

    # env.unwrapped.spec = gym.spec(game) # required for AtariPreprocessing
    if not eval and time_limit is not None:
        env = TimeLimit(env, max_episode_steps=time_limit)

    env = CraftaxRenderWrapper(env, render_method=None)
    env = CraftaxCompatWrapper(env)
    return env
class Craftax:
    """Dict-observation adapter around the env built by ``create_craftax_env``.

    Observations are dictionaries holding the flattened symbolic ``state``, a
    rendered ``image`` and the ``is_first`` / ``is_last`` / ``is_terminal``
    flags; ``step`` additionally merges every entry of ``info`` into the
    observation.
    """

    metadata = {}

    def __init__(self, task, seed=0):
        self._env = create_craftax_env(task, seed=seed)
        # self._achievements = crafter.constants.achievements.copy()
        self.reward_range = [-np.inf, np.inf]

    @property
    def observation_space(self):
        """Dict space describing the observation keys and per-achievement logs."""

        def unbounded(dtype):
            # One-element box spanning the whole real line.
            return gym.spaces.Box(-np.inf, np.inf, (1,), dtype=dtype)

        flat_size = np.prod(self._env.observation_space.shape)
        spaces = {
            "state": gym.spaces.Box(0, 1, (flat_size,), dtype=np.float32),
            "image": gym.spaces.Box(0, 255, (130, 110, 3), dtype=np.uint8),
            "is_first": unbounded(np.uint8),
            "is_last": unbounded(np.uint8),
            "is_terminal": unbounded(np.uint8),
            "log_reward": unbounded(np.float32),
        }
        for achievement in Achievement:
            key = f"log_achievement_{achievement.name.lower()}"
            spaces[key] = unbounded(np.float32)
        return gym.spaces.Dict(spaces)

    @property
    def action_space(self):
        """Action space of the wrapped environment."""
        # action_space.discrete = True
        return self._env.action_space

    def step(self, action):
        """Advance one step; return ``(obs, reward, done, info)`` with a float32 reward."""
        state, raw_reward, done, info = self._env.step(action)
        reward = np.float32(raw_reward)
        obs = {
            "image": self.get_image(),
            "state": state.flatten(),
            "is_first": False,
            "is_last": done,
            "is_terminal": info["discount"] == 0,
        }
        # Entries of `info` are merged last, so they win on key clashes.
        obs.update(info)
        return obs, reward, done, info

    def render(self):
        """Delegate rendering to the wrapped environment."""
        return self._env.render()

    def get_image(self):
        """Render the current env state to a uint8 pixel array."""
        # it looks like we need an image in the obs, even if it's not used,
        # so that we can record videos?
        pixels = render_craftax_pixels(self._env.env_state, 10)
        return np.array(pixels.astype(np.uint8)).astype(np.uint8)

    def reset(self, seed=None, options=None):
        """Reset the wrapped env and return the initial observation dict.

        ``seed`` and ``options`` are accepted but not forwarded.
        """
        state, _info = self._env.reset()
        return {
            "image": self.get_image(),
            "state": state.flatten(),
            "is_first": True,
            "is_last": False,
            "is_terminal": False,
        }
+216
View File
@@ -0,0 +1,216 @@
"""modified from gymnax, to update to gymnasium 1.0.0a2."""
from typing import Optional, Tuple, Union, List
import chex
import gymnasium
import jax.random
from gymnasium.core import ActType, ObsType, RenderFrame
from gymnasium.vector.utils import batch_space
from gymnax.environments.spaces import Space, Discrete, Box
from gymnax.environments.environment import Environment, EnvParams
from copy import deepcopy
from typing import Tuple, Sequence, Any, Dict
import chex
import numpy as np
from gymnasium import spaces as gspc
def _box_bound(bound):
    """Return a Box bound as a Python float when scalar-like, else as an ndarray."""
    if np.isscalar(bound):
        return float(bound)
    bound = np.asarray(bound)
    return float(bound.item()) if bound.size == 1 else bound


def gymnax_space_to_gym_space(space: Space) -> gspc.Space:
    """Convert a Gymnax space to the equivalent Gymnasium space.

    Supports ``Discrete``, ``Box``, ``Dict`` and ``Tuple``; container spaces
    are converted recursively.

    Raises:
        NotImplementedError: for any other space type.
    """
    # Local aliased import: at module level the names `Dict` and `Tuple` refer
    # to the `typing` generics, which never match a Gymnax space instance.
    from gymnax.environments import spaces as gymnax_spaces

    if isinstance(space, gymnax_spaces.Discrete):
        return gspc.Discrete(space.n)
    if isinstance(space, gymnax_spaces.Box):
        # Each bound is inspected on its own (the `high` bound used to be
        # classified by looking at the size of `low`).
        return gspc.Box(
            _box_bound(space.low), _box_bound(space.high), space.shape, space.dtype
        )
    if isinstance(space, gymnax_spaces.Dict):
        return gspc.Dict(
            {k: gymnax_space_to_gym_space(v) for k, v in space.spaces.items()}
        )
    if isinstance(space, gymnax_spaces.Tuple):
        return gspc.Tuple(tuple(gymnax_space_to_gym_space(s) for s in space.spaces))
    raise NotImplementedError(
        f"Conversion of {space.__class__.__name__} not supported"
    )
class GymnaxToGymWrapper(gymnasium.Env):
    """Expose a functional Gymnax environment through the stateful Gymnasium API.

    The wrapper owns the JAX PRNG key and the current ``env_state`` and
    splits off a fresh key for every ``reset`` and ``step``.
    """

    def __init__(
        self,
        env: Environment,
        params: Optional[EnvParams] = None,
        seed: Optional[int] = None,
    ):
        """Wrap Gymnax environment as OOP Gym environment

        Args:
            env: Gymnax Environment instance
            params: If provided, gymnax EnvParams for environment (otherwise uses default)
            seed: If provided, seed for JAX PRNG (otherwise picks 0)
        """
        super().__init__()
        self._env = deepcopy(env)
        self.env_params = params if params is not None else env.default_params
        # Build a per-instance dict instead of updating in place: `metadata`
        # is a class-level attribute, so `.update()` would leak these values
        # into every other environment sharing that dict.
        self.metadata = {
            **self.metadata,
            "name": env.name,
            "render_modes": ["human", "rgb_array"] if hasattr(env, "render") else [],
        }
        self.rng: chex.PRNGKey = jax.random.PRNGKey(0)  # Placeholder
        self._seed(seed)
        _, self.env_state = self._env.reset(self.rng, self.env_params)

    @property
    def action_space(self):
        """Dynamically adjust action space depending on params"""
        return gymnax_space_to_gym_space(self._env.action_space(self.env_params))

    @property
    def observation_space(self):
        """Dynamically adjust state space depending on params"""
        return gymnax_space_to_gym_space(self._env.observation_space(self.env_params))

    def _seed(self, seed: Optional[int] = None):
        """Set RNG seed (or use 0)"""
        self.rng = jax.random.PRNGKey(seed or 0)

    def step(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]:
        """Step environment, follow new step API.

        The single Gymnax ``done`` flag is returned as both ``terminated`` and
        ``truncated``.
        """
        self.rng, step_key = jax.random.split(self.rng)
        o, self.env_state, r, d, info = self._env.step(
            step_key, self.env_state, action, self.env_params
        )
        return o, r, d, d, info

    def reset(
        self,
        *,
        seed: Optional[int] = None,
        return_info: bool = False,
        options: Optional[dict] = None,
    ) -> Tuple[ObsType, dict]:
        """Reset environment, update parameters and seed if provided.

        ``return_info`` is accepted but not used; an empty info dict is
        always returned alongside the observation.
        """
        if seed is not None:
            self._seed(seed)
        if options is not None:
            # Allow changing environment parameters on reset
            self.env_params = options.get("env_params", self.env_params)
        self.rng, reset_key = jax.random.split(self.rng)
        o, self.env_state = self._env.reset(reset_key, self.env_params)
        return o, {}

    def render(self, mode="human") -> Optional[Union[RenderFrame, List[RenderFrame]]]:
        """use underlying environment rendering if it exists, otherwise return None"""
        render_fn = getattr(self._env, "render", lambda x, y: None)
        return render_fn(self.env_state, self.env_params)
class GymnaxToVectorGymWrapper(gymnasium.vector.VectorEnv):
    """Expose a Gymnax environment as a Gymnasium vector environment.

    ``reset`` and ``step`` of the (deep-copied) Gymnax env are replaced by
    jitted, vmapped versions that map over the leading axis of everything
    except the env params. One PRNG key is kept per environment.
    """

    def __init__(
        self,
        env: Environment,
        num_envs: int = 1,
        params: Optional[EnvParams] = None,
        seed: Optional[int] = None,
    ):
        """Wrap Gymnax environment as OOP Gym Vector Environment

        Args:
            env: Gymnax Environment instance
            num_envs: Desired number of environments to run in parallel
            params: If provided, gymnax EnvParams for environment (otherwise uses default)
            seed: If provided, seed for JAX PRNG (otherwise picks 0)
        """
        self._env = deepcopy(env)
        self.num_envs = num_envs
        self.is_vector_env = True
        self.new_step_api = True
        self.closed = False
        self.viewer = None
        self.rng: chex.PRNGKey = jax.random.PRNGKey(0)  # Placeholder
        self._seed(seed)
        # Jit-of-vmap is faster than vmap-of-jit. Map over leading axis of all but env params
        self._env.reset = jax.jit(jax.vmap(self._env.reset, in_axes=(0, None)))
        self._env.step = jax.jit(jax.vmap(self._env.step, in_axes=(0, 0, 0, None)))
        self.env_params = params if params is not None else env.default_params
        _, self.env_state = self._env.reset(self.rng, self.env_params)  # Placeholder
        # Split all rng keys at once; `out_axes=1` puts the "new key / sub key"
        # axis first so the result unpacks into two batches of keys.
        self._batched_rng_split = jax.jit(
            jax.vmap(jax.random.split, in_axes=0, out_axes=1)
        )

    @property
    def single_action_space(self):
        """Dynamically adjust action space depending on params"""
        return gymnax_space_to_gym_space(self._env.action_space(self.env_params))

    @property
    def single_observation_space(self):
        """Dynamically adjust state space depending on params"""
        return gymnax_space_to_gym_space(self._env.observation_space(self.env_params))

    @property
    def action_space(self):
        """Dynamically adjust action space depending on params"""
        return batch_space(self.single_action_space, self.num_envs)

    @property
    def observation_space(self):
        """Dynamically adjust state space depending on params"""
        return batch_space(self.single_observation_space, self.num_envs)

    def _seed(self, seed: Optional[int] = None):
        """Set RNG seed (or use 0)"""
        # 1 RNG per env
        self.rng = jax.random.split(jax.random.PRNGKey(seed or 0), self.num_envs)

    def reset(
        self,
        *,
        seed: Optional[Union[int, List[int]]] = None,
        return_info: bool = False,
        options: Optional[dict] = None,
    ) -> Tuple[ObsType, dict]:
        """Reset environment, update parameters and seed if provided.

        ``return_info`` is accepted but not used; an empty info dict is
        always returned alongside the observations.
        """
        if seed is not None:
            self._seed(seed)
        if options is not None:
            # Allow changing environment parameters on reset
            self.env_params = options.get("env_params", self.env_params)
        self.rng, reset_key = self._batched_rng_split(self.rng)  # Split all keys
        o, self.env_state = self._env.reset(reset_key, self.env_params)
        return o, {}

    def step(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]:
        """Step environment, follow new step API.

        The Gymnax ``done`` flags are returned as both ``terminated`` and
        ``truncated``.
        """
        self.rng, step_key = self._batched_rng_split(self.rng)
        o, self.env_state, r, d, info = self._env.step(
            step_key, self.env_state, action, self.env_params
        )
        return o, r, d, d, info

    def render(self, mode="human") -> Optional[Union[RenderFrame, List[RenderFrame]]]:
        """use underlying environment rendering if it exists (for first environment), otherwise return None"""
        render_fn = getattr(self._env, "render", lambda x, y: None)
        # `jax.tree_util.tree_map` replaces the deprecated top-level
        # `jax.tree_map` alias; it selects the first environment's state.
        first_state = jax.tree_util.tree_map(lambda x: x[0], self.env_state)
        return render_fn(first_state, self.env_params)
+8
View File
@@ -1,4 +1,12 @@
set export
export OSTYPE := "linux-gnu"
export TQDM_MININTERVAL := "30"
main:
. ./.venv/bin/activate
python dreamer.py --configs crafter --task crafter_reward --logdir ./logdir/crafter
logs:
tensorboard --logdir logdir/craftax
+23 -7
View File
@@ -6,6 +6,8 @@ import networks
import tools
from loguru import logger
from envs.craftax_env import state2img
to_np = lambda x: x.detach().cpu().numpy()
@@ -174,7 +176,8 @@ class WorldModel(nn.Module):
# this function is called during both rollout and training
def preprocess(self, obs):
obs = obs.copy()
obs["image"] = torch.Tensor(obs["image"]) / 255.0
if "image" in obs:
obs["image"] = torch.Tensor(obs["image"]) / 255.0
if "discount" in obs:
obs["discount"] *= self._config.discount
# (batch_size, batch_length) -> (batch_size, batch_length, 1)
@@ -187,24 +190,37 @@ class WorldModel(nn.Module):
obs = {k: torch.Tensor(v).to(self._config.device) for k, v in obs.items()}
return obs
def video_pred(self, data):
def video_pred(self, data, env_state= None):
# FIXME: in crafter we are not imagining image
data = self.preprocess(data)
embed = self.encoder(data)
states, _ = self.dynamics.observe(
embed[:6, :5], data["action"][:6, :5], data["is_first"][:6, :5]
)
recon = self.heads["decoder"](self.dynamics.get_feat(states))["image"].mode()[
:6
]
# FIXME: assume decoder returns image
recon = self.heads["decoder"](self.dynamics.get_feat(states))
reward_post = self.heads["reward"](self.dynamics.get_feat(states)).mode()[:6]
init = {k: v[:, -1] for k, v in states.items()}
prior = self.dynamics.imagine_with_action(data["action"][:6, 5:], init)
openl = self.heads["decoder"](self.dynamics.get_feat(prior))["image"].mode()
openl = self.heads["decoder"](self.dynamics.get_feat(prior))
reward_prior = self.heads["reward"](self.dynamics.get_feat(prior)).mode()
if "Craftax" in self.config['task']:
# in Craftax we run in symbolic mode with no image, so to make a video we need to convert
recon = state2img(recon["state"].mode()[:6], env_state)
truth = state2img(data["state"][:6], env_state)
openl = state2img(openl['state'][:6], env_state)
else:
recon = recon["image"].mode()[:6]
truth = data["image"][:6]
openl = ["image"].mode()
# observed image is given until 5 steps
model = torch.cat([recon[:, :5], openl], 1)
truth = data["image"][:6]
model = model
error = (model - truth + 1.0) / 2.0
Generated
+1221 -1
View File
File diff suppressed because it is too large Load Diff
+6
View File
@@ -35,8 +35,14 @@ loguru = "^0.7.2"
imageio-ffmpeg = "^0.5.0"
importlib = "^1.0.4"
imageio = "^2.34.1"
craftax = {path = "/media/wassname/SGIronWolf/projects5/2024/Craftax", develop = true }
# craftax = {git = "https://github.com/wassname/Craftax" , develop = true }
chex = "^0.1.86"
[tool.poetry.group.dev.dependencies]
ipywidgets = "^8.1.3"
ipykernel = "^6.29.4"
ruff = "^0.1.3"
[build-system]
requires = ["poetry-core"]
+95 -92
View File
@@ -80,7 +80,9 @@ class Logger:
scalars = list(self._scalars.items())
if fps:
scalars.append(("fps", self._compute_fps(step)))
logger.info(f"[{step}]", " / ".join(f"{k} {v:.1f}" for k, v in scalars))
# print out the episode stats
stats = " / ".join(f"{k.replace('log_achievement_', '')} <red>{v:.1f}</red>" for k, v in scalars)
logger.opt(colors=True).info(f"[{step}] {stats}")
with (self._logdir / "metrics.jsonl").open("a") as f:
f.write(json.dumps({"step": step, **dict(scalars)}) + "\n")
for name, value in scalars:
@@ -149,102 +151,103 @@ def simulate(
reward = [0] * len(envs)
else:
step, episode, done, length, obs, agent_state, reward = state
while (steps and step < steps) or (episodes and episode < episodes):
# reset envs if necessary
if done.any():
indices = [index for index, d in enumerate(done) if d]
results = [envs[i].reset() for i in indices]
results = [r() for r in results]
for index, result in zip(indices, results):
t = result.copy()
t = {k: convert(v) for k, v in t.items()}
# action will be added to transition in add_to_cache
t["reward"] = 0.0
t["discount"] = 1.0
# initial state should be added to cache
add_to_cache(cache, envs[index].id, t)
# replace obs with done by initial state
obs[index] = result
# step agents
obs = {k: np.stack([o[k] for o in obs]) for k in obs[0] if "log_" not in k}
action, agent_state = agent(obs, done, agent_state)
if isinstance(action, dict):
action = [
{k: np.array(action[k][i].detach().cpu()) for k in action}
for i in range(len(envs))
]
else:
action = np.array(action)
assert len(action) == len(envs)
# step envs
results = [e.step(a) for e, a in zip(envs, action)]
results = [r() for r in results]
obs, reward, done = zip(*[p[:3] for p in results])
obs = list(obs)
reward = list(reward)
done = np.stack(done)
episode += int(done.sum())
length += 1
step += len(envs)
pbar.update(len(envs))
length *= 1 - done
# add to cache
for a, result, env in zip(action, results, envs):
o, r, d, info = result
o = {k: convert(v) for k, v in o.items()}
transition = o.copy()
if isinstance(a, dict):
transition.update(a)
with tqdm(total=steps, disable=pbar is None) as pbar:
while (steps and step < steps) or (episodes and episode < episodes):
# reset envs if necessary
if done.any():
indices = [index for index, d in enumerate(done) if d]
results = [envs[i].reset() for i in indices]
results = [r() for r in results]
for index, result in zip(indices, results):
t = result.copy()
t = {k: convert(v) for k, v in t.items()}
# action will be added to transition in add_to_cache
t["reward"] = 0.0
t["discount"] = 1.0
# initial state should be added to cache
add_to_cache(cache, envs[index].id, t)
# replace obs with done by initial state
obs[index] = result
# step agents
obs = {k: np.stack([o[k] for o in obs]) for k in obs[0] if "log_" not in k}
action, agent_state = agent(obs, done, agent_state)
if isinstance(action, dict):
action = [
{k: np.array(action[k][i].detach().cpu()) for k in action}
for i in range(len(envs))
]
else:
transition["action"] = a
transition["reward"] = r
transition["discount"] = info.get("discount", np.array(1 - float(d)))
add_to_cache(cache, env.id, transition)
if done.any():
indices = [index for index, d in enumerate(done) if d]
# logging for done episode
for i in indices:
save_episodes(directory, {envs[i].id: cache[envs[i].id]})
length = len(cache[envs[i].id]["reward"]) - 1
score = float(np.array(cache[envs[i].id]["reward"]).sum())
video = cache[envs[i].id]["image"]
# record logs given from environments
for key in list(cache[envs[i].id].keys()):
if "log_" in key:
logger.scalar(
key, float(np.array(cache[envs[i].id][key]).sum())
)
# log items won't be used later
cache[envs[i].id].pop(key)
if not is_eval:
step_in_dataset = erase_over_episodes(cache, limit)
logger.scalar(f"dataset_size", step_in_dataset)
logger.scalar(f"train_return", score)
logger.scalar(f"train_length", length)
logger.scalar(f"train_episodes", len(cache))
logger.write(step=logger.step)
action = np.array(action)
assert len(action) == len(envs)
# step envs
results = [e.step(a) for e, a in zip(envs, action)]
results = [r() for r in results]
obs, reward, done = zip(*[p[:3] for p in results])
obs = list(obs)
reward = list(reward)
done = np.stack(done)
episode += int(done.sum())
length += 1
step += len(envs)
pbar.update(len(envs))
length *= 1 - done
# add to cache
for a, result, env in zip(action, results, envs):
o, r, d, info = result
o = {k: convert(v) for k, v in o.items()}
transition = o.copy()
if isinstance(a, dict):
transition.update(a)
else:
if not "eval_lengths" in locals():
eval_lengths = []
eval_scores = []
eval_done = False
# start counting scores for evaluation
eval_scores.append(score)
eval_lengths.append(length)
transition["action"] = a
transition["reward"] = r
transition["discount"] = info.get("discount", np.array(1 - float(d)))
add_to_cache(cache, env.id, transition)
score = sum(eval_scores) / len(eval_scores)
length = sum(eval_lengths) / len(eval_lengths)
if video_pred_log:
logger.video(f"eval_policy", np.array(video)[None])
if done.any():
indices = [index for index, d in enumerate(done) if d]
# logging for done episode
for i in indices:
save_episodes(directory, {envs[i].id: cache[envs[i].id]})
length = len(cache[envs[i].id]["reward"]) - 1
score = float(np.array(cache[envs[i].id]["reward"]).sum())
video = cache[envs[i].id]["image"]
# record logs given from environments
for key in list(cache[envs[i].id].keys()):
if "log_" in key:
logger.scalar(
key, float(np.array(cache[envs[i].id][key]).sum())
)
# log items won't be used later
cache[envs[i].id].pop(key)
if len(eval_scores) >= episodes and not eval_done:
logger.scalar(f"eval_return", score)
logger.scalar(f"eval_length", length)
logger.scalar(f"eval_episodes", len(eval_scores))
if not is_eval:
step_in_dataset = erase_over_episodes(cache, limit)
logger.scalar(f"dataset_size", step_in_dataset)
logger.scalar(f"train_return", score)
logger.scalar(f"train_length", length)
logger.scalar(f"train_episodes", len(cache))
logger.write(step=logger.step)
eval_done = True
else:
if not "eval_lengths" in locals():
eval_lengths = []
eval_scores = []
eval_done = False
# start counting scores for evaluation
eval_scores.append(score)
eval_lengths.append(length)
score = sum(eval_scores) / len(eval_scores)
length = sum(eval_lengths) / len(eval_lengths)
if video_pred_log:
logger.video(f"eval_policy", np.array(video)[None])
if len(eval_scores) >= episodes and not eval_done:
logger.scalar(f"eval_return", score)
logger.scalar(f"eval_length", length)
logger.scalar(f"eval_episodes", len(eval_scores))
logger.write(step=logger.step)
eval_done = True
if is_eval:
# keep only last item for saving memory. this cache is used for video_pred later
while len(cache) > 1: