Initial commit.

Danijar Hafner
2020-12-04 17:44:38 -05:00
commit a4ff8ada65
15 changed files with 2974 additions and 0 deletions
+5
@@ -0,0 +1,5 @@
__pycache__/
*.py[cod]
*.egg-info
./dist
MUJOCO_LOG.TXT
+19
@@ -0,0 +1,19 @@
Copyright (c) 2020 Danijar Hafner
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
+106
@@ -0,0 +1,106 @@
# Mastering Atari with Discrete World Models
Implementation of the [DreamerV2][website] agent in TensorFlow 2. Training
curves for all 55 games are included.
<p align="center">
<img width="90%" src="https://imgur.com/gO1rvEn.gif">
</p>
If you find this code useful, please cite it in your paper:
```
@article{hafner2020dreamerv2,
title={Mastering Atari with Discrete World Models},
author={Hafner, Danijar and Lillicrap, Timothy and Norouzi, Mohammad and Ba, Jimmy},
journal={arXiv preprint arXiv:2010.02193},
year={2020}
}
```
[website]: https://danijar.com/dreamerv2
## Method
DreamerV2 is the first world model agent that achieves human-level performance
on the Atari benchmark. DreamerV2 also outperforms the final performance of the
top model-free agents Rainbow and IQN using the same amount of experience and
computation. The implementation in this repository alternates between training
the world model, training the policy, and collecting experience and runs on a
single GPU.
![World Model Learning](https://imgur.com/GRC9QAw.png)
DreamerV2 learns a model of the environment directly from high-dimensional
input images. For this, it predicts ahead using compact learned states. The
states consist of a deterministic part and several categorical variables that
are sampled. The prior for these categoricals is learned through a KL loss. The
world model is learned end-to-end via straight-through gradients, meaning that
the gradient with respect to each sample is passed on unchanged as the gradient
with respect to the probabilities it was drawn from.
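To make the straight-through idea concrete, the sketch below computes the gradient by hand for a single categorical variable. It is an illustration under stated assumptions, not the code used in this repository: the helper names `softmax`, `sample_one_hot`, and `straight_through_grad` are hypothetical, the downstream loss is made up, and plain Python is used instead of TensorFlow so that the backward pass is explicit.

```python
import math
import random


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    top = max(logits)
    exps = [math.exp(x - top) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sample_one_hot(probs, rng):
    """Draw a one-hot sample from a categorical distribution."""
    index = rng.choices(range(len(probs)), weights=probs)[0]
    return [1.0 if i == index else 0.0 for i in range(len(probs))]


def straight_through_grad(logits, grad_wrt_sample):
    """Backpropagate through a sampled one-hot vector (hypothetical helper).

    The forward pass uses the discrete sample. The backward pass hands the
    gradient that arrives at the sample to the probabilities unchanged and
    then applies the softmax Jacobian, diag(p) - p p^T.
    """
    probs = softmax(logits)
    grad_wrt_probs = list(grad_wrt_sample)  # The straight-through step.
    dot = sum(p * g for p, g in zip(probs, grad_wrt_probs))
    return [p * (g - dot) for p, g in zip(probs, grad_wrt_probs)]


rng = random.Random(0)
logits = [1.0, 2.0, 0.5]
sample = sample_one_hot(softmax(logits), rng)

# Assumed downstream loss: a weighted sum of the sample entries, so the
# gradient with respect to the sample is just the weight vector.
weights = [1.0, 2.0, 3.0]
grad = straight_through_grad(logits, grad_wrt_sample=weights)
print(sample, grad)
```

With automatic differentiation, the same effect is commonly obtained by computing `sample + probs - tf.stop_gradient(probs)`, which equals the sample in the forward pass but has the gradient of `probs`.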
![Actor Critic Learning](https://imgur.com/wH9kJ2O.png)
DreamerV2 learns actor and critic networks from imagined trajectories of latent
states. The trajectories start at encoded states of previously encountered
sequences. The world model then predicts ahead using the selected actions and
its learned state prior. The critic is trained using temporal difference
learning and the actor is trained to maximize the value function via REINFORCE
and straight-through gradients.
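One common way to form temporal difference targets along an imagined trajectory is the λ-return, which blends one-step bootstrapped targets with longer multi-step returns. The sketch below assumes the recursion `target[t] = reward[t] + discount[t] * ((1 - lambda) * value[t + 1] + lambda * target[t + 1])`, with the value after the last step used as the bootstrap. The function name and argument layout are illustrative and need not match the `lambda_return` helper that the training code calls.

```python
def lambda_return(rewards, values, discounts, bootstrap, lambda_=0.95):
    """Compute lambda-return targets for one trajectory (illustrative sketch).

    rewards, values, discounts: lists of length H, indexed by time step.
    bootstrap: value estimate for the state after the last step.
    Returns a list of H targets, one per time step.
    """
    horizon = len(rewards)
    # Value of the successor state for every step.
    next_values = list(values[1:]) + [bootstrap]
    targets = [0.0] * horizon
    last = bootstrap
    # Work backwards so each target can reuse the one after it.
    for t in reversed(range(horizon)):
        blended = (1 - lambda_) * next_values[t] + lambda_ * last
        last = rewards[t] + discounts[t] * blended
        targets[t] = last
    return targets


rewards = [1.0, 1.0, 1.0]
values = [0.5, 0.5, 0.5]
discounts = [0.9, 0.9, 0.9]

one_step = lambda_return(rewards, values, discounts, bootstrap=2.0, lambda_=0.0)
full = lambda_return(rewards, values, discounts, bootstrap=2.0, lambda_=1.0)
print(one_step, full)
```

Setting `lambda_=0.0` reduces every target to the one-step TD target, while `lambda_=1.0` gives the discounted sum of rewards up to the horizon plus the discounted bootstrap value.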
For more information:
- [Google AI Blog post](https://ai.googleblog.com/2021/02/mastering-atari-with-discrete-world.html)
- [Project website](https://danijar.com/dreamerv2/)
- [Research paper](https://arxiv.org/pdf/2010.02193.pdf)
## Instructions
Get dependencies:
```sh
pip3 install --user tensorflow==2.3.1
pip3 install --user tensorflow_probability==0.11.1
pip3 install --user pandas
pip3 install --user matplotlib
pip3 install --user ruamel.yaml
pip3 install --user 'gym[atari]'
```
Train the agent:
```sh
python3 dreamer.py --logdir ~/logdir/atari_pong/dreamerv2/1 \
--configs defaults atari --task atari_pong
```
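The `--configs` flag in the command above names sections of the configuration file. The training script merges them from left to right, so later sections override earlier ones, and the remaining flags such as `--task` override the merged result. The sketch below mimics that layering with a small in-memory dictionary. The `CONFIGS` excerpt and the `resolve` helper are hypothetical, and the sketch infers argument types with `type(value)`, which is a simplification of what the script does.

```python
import argparse

# Hypothetical excerpt of the named sections; the real values live in the
# YAML configuration file.
CONFIGS = {
    'defaults': {'task': 'dmc_walker_walk', 'precision': 32, 'batch_size': 50},
    'atari': {'task': 'atari_pong', 'precision': 16},
    'debug': {'batch_size': 10},
}


def resolve(names, overrides):
    """Merge named sections in order, then apply command line overrides."""
    merged = {}
    for name in names:
        merged.update(CONFIGS[name])  # Later sections win.
    parser = argparse.ArgumentParser()
    for key, value in sorted(merged.items()):
        # Simplification: only str and int values appear in this excerpt.
        parser.add_argument(f'--{key}', type=type(value), default=value)
    return vars(parser.parse_args(overrides))


config = resolve(['defaults', 'atari', 'debug'], ['--task', 'atari_breakout'])
print(config)
```

Here `precision` comes from the `atari` section, `batch_size` from `debug`, and `task` from the explicit flag.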
Monitor results:
```sh
tensorboard --logdir ~/logdir
```
Generate plots:
```sh
python3 plotting.py --indir ~/logdir --outdir ~/plots \
  --xaxis step --yaxis eval_return --bins 1e6
```
Tips:
- **Efficient debugging.** You can use the `debug` config as in `--configs
defaults atari debug`. This reduces the batch size, increases the evaluation
frequency, and disables `tf.function` graph compilation for easy line-by-line
debugging.
- **Infinite gradient norms.** Occasional infinite gradient norms are normal with
  mixed precision and are described under loss scaling in the
  [mixed precision][mixed] guide. You can disable mixed precision by passing
  `--precision 32` to the training script. Mixed precision is faster but can in
  principle cause numerical instabilities.
- **Accessing logged metrics.** The metrics are stored in both TensorBoard and
JSON lines format. You can directly load them using `pandas.read_json()`. The
plotting script also stores the binned and aggregated metrics of multiple runs
into a single JSON lines file for easy manual plotting.
[mixed]: https://www.tensorflow.org/guide/mixed_precision
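As a rough illustration of the tip on accessing logged metrics, the sketch below reads a JSON lines file with the standard library and extracts one curve. The file name `metrics.jsonl`, the record layout with a `step` key, and the helper names are assumptions made for this example, so check the files in your own log directory before relying on them.

```python
import json
import pathlib
import tempfile


def load_jsonl(path):
    """Parse a JSON lines file into a list of dictionaries."""
    records = []
    with open(path) as f:
        for line in f:
            line = line.strip()
            if line:
                records.append(json.loads(line))
    return records


def column(records, key):
    """Collect (step, value) pairs from the records that contain the key."""
    return [(r['step'], r[key]) for r in records if key in r]


# Write a tiny made-up log so that the example is self-contained.
with tempfile.TemporaryDirectory() as tmp:
    path = pathlib.Path(tmp) / 'metrics.jsonl'  # Assumed file name.
    path.write_text(
        '{"step": 10000, "eval_return": -21.0}\n'
        '{"step": 10000, "train_return": -20.0}\n'
        '{"step": 20000, "eval_return": -19.0}\n')
    records = load_jsonl(path)
    curve = column(records, 'eval_return')

print(curve)
```

With pandas installed, `pandas.read_json(path, lines=True)` loads the same kind of file into a DataFrame in one call.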
+228
@@ -0,0 +1,228 @@
defaults:
logdir: null
traindir: null
evaldir: null
offline_traindir: ''
offline_evaldir: ''
seed: 0
steps: 1e7
eval_every: 1e4
log_every: 1e4
reset_every: 0
gpu_growth: True
precision: 32
debug: False
expl_gifs: False
# Environment
task: 'dmc_walker_walk'
size: [64, 64]
envs: 1
action_repeat: 2
time_limit: 1000
grayscale: False
prefill: 2500
eval_noise: 0.0
clip_rewards: 'identity'
# Model
dyn_cell: 'gru'
dyn_hidden: 200
dyn_deter: 200
dyn_stoch: 50
dyn_discrete: 0
dyn_input_layers: 1
dyn_output_layers: 1
dyn_rec_depth: 1
dyn_shared: False
dyn_mean_act: 'none'
dyn_std_act: 'sigmoid2'
dyn_min_std: 0.1
dyn_temp_post: True
grad_heads: ['image', 'reward']
units: 400
reward_layers: 2
discount_layers: 3
value_layers: 3
actor_layers: 4
act: 'elu'
cnn_depth: 32
encoder_kernels: [4, 4, 4, 4]
decoder_kernels: [5, 5, 6, 6]
decoder_thin: True
value_head: 'normal'
kl_scale: '1.0'
kl_balance: '0.8'
kl_free: '1.0'
kl_forward: False
pred_discount: False
discount_scale: 1.0
reward_scale: 1.0
weight_decay: 0.0
# Training
batch_size: 50
batch_length: 50
train_every: 5
train_steps: 1
pretrain: 100
model_lr: 3e-4
value_lr: 8e-5
actor_lr: 8e-5
opt_eps: 1e-5
grad_clip: 100
value_grad_clip: 100
actor_grad_clip: 100
dataset_size: 0
oversample_ends: False
slow_value_target: True
slow_actor_target: True
slow_target_update: 100
slow_target_fraction: 1
opt: 'adam'
  # Behavior
discount: 0.99
discount_lambda: 0.95
imag_horizon: 15
imag_gradient: 'dynamics'
imag_gradient_mix: '0.1'
imag_sample: True
actor_dist: 'trunc_normal'
actor_entropy: '1e-4'
actor_state_entropy: 0.0
actor_init_std: 1.0
actor_min_std: 0.1
actor_disc: 5
actor_temp: 0.1
actor_outscale: 0.0
expl_amount: 0.0
eval_state_mean: False
collect_dyn_sample: True
behavior_stop_grad: True
value_decay: 0.0
future_entropy: False
# Exploration
expl_behavior: 'greedy'
expl_until: 0
expl_extr_scale: 0.0
expl_intr_scale: 1.0
disag_target: 'stoch'
disag_log: True
disag_models: 10
disag_offset: 1
disag_layers: 4
disag_units: 400
atari:
# General
task: 'atari_pong'
steps: 2e8
eval_every: 1e5
log_every: 1e4
prefill: 50000
dataset_size: 2e6
pretrain: 0
precision: 16
# Environment
time_limit: 108000 # 30 minutes of game play.
grayscale: True
action_repeat: 4
eval_noise: 0.0
train_every: 16
train_steps: 1
clip_rewards: 'tanh'
# Model
grad_heads: ['image', 'reward', 'discount']
dyn_cell: 'gru_layer_norm'
pred_discount: True
cnn_depth: 48
dyn_deter: 600
dyn_hidden: 600
dyn_stoch: 32
dyn_discrete: 32
reward_layers: 4
discount_layers: 4
value_layers: 4
actor_layers: 4
# Behavior
actor_dist: 'onehot'
actor_entropy: 'linear(3e-3,3e-4,2.5e6)'
expl_amount: 0.0
discount: 0.999
imag_gradient: 'both'
imag_gradient_mix: 'linear(0.1,0,2.5e6)'
# Training
discount_scale: 5.0
reward_scale: 1
weight_decay: 1e-6
model_lr: 2e-4
kl_scale: 0.1
kl_free: 0.0
actor_lr: 4e-5
value_lr: 1e-4
oversample_ends: True
dmc:
# General
task: 'dmc_walker_walk'
steps: 1e7
eval_every: 1e4
log_every: 1e4
prefill: 2500
dataset_size: 0
pretrain: 100
# Environment
time_limit: 1000
action_repeat: 2
train_every: 5
train_steps: 1
# Model
grad_heads: ['image', 'reward']
dyn_cell: 'gru_layer_norm'
pred_discount: False
cnn_depth: 32
dyn_deter: 200
dyn_stoch: 50
dyn_discrete: 0
reward_layers: 2
discount_layers: 3
value_layers: 3
actor_layers: 4
# Behavior
actor_dist: 'trunc_normal'
expl_amount: 0.0
actor_entropy: '1e-4'
discount: 0.99
imag_gradient: 'dynamics'
imag_gradient_mix: 1.0
# Training
reward_scale: 2
weight_decay: 0.0
model_lr: 3e-4
value_lr: 8e-5
actor_lr: 8e-5
opt_eps: 1e-5
kl_free: '1.0'
kl_scale: '1.0'
debug:
debug: True
pretrain: 1
prefill: 1
train_steps: 1
batch_size: 10
batch_length: 20
+323
@@ -0,0 +1,323 @@
import argparse
import collections
import functools
import os
import pathlib
import sys
import warnings
warnings.filterwarnings('ignore', '.*box bound precision lowered.*')
warnings.filterwarnings('ignore', '.*TensorFloat-32 matmul/conv.*')
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
os.environ['MUJOCO_GL'] = 'egl'
import numpy as np
import ruamel.yaml as yaml
import tensorflow as tf
from tensorflow.keras.mixed_precision import experimental as prec
tf.get_logger().setLevel('ERROR')
from tensorflow_probability import distributions as tfd
sys.path.append(str(pathlib.Path(__file__).parent))
import exploration as expl
import models
import tools
import wrappers
class Dreamer(tools.Module):
def __init__(self, config, logger, dataset):
self._config = config
self._logger = logger
self._float = prec.global_policy().compute_dtype
self._should_log = tools.Every(config.log_every)
self._should_train = tools.Every(config.train_every)
self._should_pretrain = tools.Once()
self._should_reset = tools.Every(config.reset_every)
self._should_expl = tools.Until(int(
config.expl_until / config.action_repeat))
self._metrics = collections.defaultdict(tf.metrics.Mean)
with tf.device('cpu:0'):
self._step = tf.Variable(count_steps(config.traindir), dtype=tf.int64)
# Schedules.
config.actor_entropy = (
lambda x=config.actor_entropy: tools.schedule(x, self._step))
config.actor_state_entropy = (
lambda x=config.actor_state_entropy: tools.schedule(x, self._step))
config.imag_gradient_mix = (
lambda x=config.imag_gradient_mix: tools.schedule(x, self._step))
self._dataset = iter(dataset)
self._wm = models.WorldModel(self._step, config)
self._task_behavior = models.ImagBehavior(
config, self._wm, config.behavior_stop_grad)
reward = lambda f, s, a: self._wm.heads['reward'](f).mode()
self._expl_behavior = dict(
greedy=lambda: self._task_behavior,
random=lambda: expl.Random(config),
plan2explore=lambda: expl.Plan2Explore(config, self._wm, reward),
)[config.expl_behavior]()
# Train step to initialize variables including optimizer statistics.
self._train(next(self._dataset))
def __call__(self, obs, reset, state=None, training=True):
step = self._step.numpy().item()
if self._should_reset(step):
state = None
if state is not None and reset.any():
mask = tf.cast(1 - reset, self._float)[:, None]
state = tf.nest.map_structure(lambda x: x * mask, state)
if training and self._should_train(step):
steps = (
self._config.pretrain if self._should_pretrain()
else self._config.train_steps)
for _ in range(steps):
self._train(next(self._dataset))
if self._should_log(step):
for name, mean in self._metrics.items():
self._logger.scalar(name, float(mean.result()))
mean.reset_states()
openl = self._wm.video_pred(next(self._dataset))
self._logger.video('train_openl', openl)
self._logger.write(fps=True)
policy_output, state = self._policy(obs, state, training)
if training:
self._step.assign_add(len(reset))
self._logger.step = self._config.action_repeat \
* self._step.numpy().item()
return policy_output, state
@tf.function
def _policy(self, obs, state, training):
if state is None:
batch_size = len(obs['image'])
latent = self._wm.dynamics.initial(len(obs['image']))
action = tf.zeros((batch_size, self._config.num_actions), self._float)
else:
latent, action = state
embed = self._wm.encoder(self._wm.preprocess(obs))
latent, _ = self._wm.dynamics.obs_step(
latent, action, embed, self._config.collect_dyn_sample)
if self._config.eval_state_mean:
latent['stoch'] = latent['mean']
feat = self._wm.dynamics.get_feat(latent)
if not training:
actor = self._task_behavior.actor(feat)
action = actor.mode()
elif self._should_expl(self._step):
actor = self._expl_behavior.actor(feat)
action = actor.sample()
else:
actor = self._task_behavior.actor(feat)
action = actor.sample()
logprob = actor.log_prob(tf.cast(action, tf.float32))
if self._config.actor_dist == 'onehot_gumble':
action = tf.cast(
tf.one_hot(tf.argmax(action, axis=-1), self._config.num_actions),
action.dtype)
action = self._exploration(action, training)
policy_output = {'action': action, 'logprob': logprob}
state = (latent, action)
return policy_output, state
def _exploration(self, action, training):
amount = self._config.expl_amount if training else self._config.eval_noise
if amount == 0:
return action
amount = tf.cast(amount, self._float)
if 'onehot' in self._config.actor_dist:
probs = amount / self._config.num_actions + (1 - amount) * action
return tools.OneHotDist(probs=probs).sample()
else:
return tf.clip_by_value(tfd.Normal(action, amount).sample(), -1, 1)
@tf.function
def _train(self, data):
print('Tracing train function.')
metrics = {}
post, context, mets = self._wm.train(data)
metrics.update(mets)
start = post
if self._config.pred_discount: # Last step could be terminal.
start = {k: v[:, :-1] for k, v in post.items()}
context = {k: v[:, :-1] for k, v in context.items()}
reward = lambda f, s, a: self._wm.heads['reward'](
self._wm.dynamics.get_feat(s)).mode()
metrics.update(self._task_behavior.train(start, reward)[-1])
if self._config.expl_behavior != 'greedy':
mets = self._expl_behavior.train(start, context)[-1]
metrics.update({'expl_' + key: value for key, value in mets.items()})
for name, value in metrics.items():
self._metrics[name].update_state(value)
def count_steps(folder):
return sum(int(str(n).split('-')[-1][:-4]) - 1 for n in folder.glob('*.npz'))
def make_dataset(episodes, config):
example = episodes[next(iter(episodes.keys()))]
types = {k: v.dtype for k, v in example.items()}
shapes = {k: (None,) + v.shape[1:] for k, v in example.items()}
generator = lambda: tools.sample_episodes(
episodes, config.batch_length, config.oversample_ends)
dataset = tf.data.Dataset.from_generator(generator, types, shapes)
dataset = dataset.batch(config.batch_size, drop_remainder=True)
dataset = dataset.prefetch(10)
return dataset
def make_env(config, logger, mode, train_eps, eval_eps):
suite, task = config.task.split('_', 1)
if suite == 'dmc':
env = wrappers.DeepMindControl(task, config.action_repeat, config.size)
env = wrappers.NormalizeActions(env)
elif suite == 'atari':
env = wrappers.Atari(
task, config.action_repeat, config.size,
grayscale=config.grayscale,
life_done=False and (mode == 'train'),
sticky_actions=True,
all_actions=True)
env = wrappers.OneHotAction(env)
else:
raise NotImplementedError(suite)
env = wrappers.TimeLimit(env, config.time_limit)
env = wrappers.SelectAction(env, key='action')
callbacks = [functools.partial(
process_episode, config, logger, mode, train_eps, eval_eps)]
env = wrappers.CollectDataset(env, callbacks)
env = wrappers.RewardObs(env)
return env
def process_episode(config, logger, mode, train_eps, eval_eps, episode):
directory = dict(train=config.traindir, eval=config.evaldir)[mode]
cache = dict(train=train_eps, eval=eval_eps)[mode]
filename = tools.save_episodes(directory, [episode])[0]
length = len(episode['reward']) - 1
score = float(episode['reward'].astype(np.float64).sum())
video = episode['image']
if mode == 'eval':
cache.clear()
if mode == 'train' and config.dataset_size:
total = 0
for key, ep in reversed(sorted(cache.items(), key=lambda x: x[0])):
if total <= config.dataset_size - length:
total += len(ep['reward']) - 1
else:
del cache[key]
logger.scalar('dataset_size', total + length)
cache[str(filename)] = episode
print(f'{mode.title()} episode has {length} steps and return {score:.1f}.')
logger.scalar(f'{mode}_return', score)
logger.scalar(f'{mode}_length', length)
logger.scalar(f'{mode}_episodes', len(cache))
if mode == 'eval' or config.expl_gifs:
logger.video(f'{mode}_policy', video[None])
logger.write()
def main(config):
logdir = pathlib.Path(config.logdir).expanduser()
config.traindir = config.traindir or logdir / 'train_eps'
config.evaldir = config.evaldir or logdir / 'eval_eps'
config.steps //= config.action_repeat
config.eval_every //= config.action_repeat
config.log_every //= config.action_repeat
config.time_limit //= config.action_repeat
config.act = getattr(tf.nn, config.act)
if config.debug:
    # Non-experimental name of the same switch, available since TF 2.3.
    tf.config.run_functions_eagerly(True)
if config.gpu_growth:
message = 'No GPU found. To actually train on CPU remove this assert.'
assert tf.config.experimental.list_physical_devices('GPU'), message
for gpu in tf.config.experimental.list_physical_devices('GPU'):
tf.config.experimental.set_memory_growth(gpu, True)
assert config.precision in (16, 32), config.precision
if config.precision == 16:
prec.set_policy(prec.Policy('mixed_float16'))
print('Logdir', logdir)
logdir.mkdir(parents=True, exist_ok=True)
config.traindir.mkdir(parents=True, exist_ok=True)
config.evaldir.mkdir(parents=True, exist_ok=True)
step = count_steps(config.traindir)
logger = tools.Logger(logdir, config.action_repeat * step)
print('Create envs.')
if config.offline_traindir:
directory = config.offline_traindir.format(**vars(config))
else:
directory = config.traindir
train_eps = tools.load_episodes(directory, limit=config.dataset_size)
if config.offline_evaldir:
directory = config.offline_evaldir.format(**vars(config))
else:
directory = config.evaldir
eval_eps = tools.load_episodes(directory, limit=1)
make = lambda mode: make_env(config, logger, mode, train_eps, eval_eps)
train_envs = [make('train') for _ in range(config.envs)]
eval_envs = [make('eval') for _ in range(config.envs)]
acts = train_envs[0].action_space
config.num_actions = acts.n if hasattr(acts, 'n') else acts.shape[0]
prefill = max(0, config.prefill - count_steps(config.traindir))
print(f'Prefill dataset ({prefill} steps).')
if hasattr(acts, 'discrete'):
random_actor = tools.OneHotDist(tf.zeros_like(acts.low)[None])
else:
random_actor = tfd.Independent(
tfd.Uniform(acts.low[None], acts.high[None]), 1)
def random_agent(o, d, s):
action = random_actor.sample()
logprob = random_actor.log_prob(action)
return {'action': action, 'logprob': logprob}, None
tools.simulate(random_agent, train_envs, prefill)
tools.simulate(random_agent, eval_envs, episodes=1)
logger.step = config.action_repeat * count_steps(config.traindir)
print('Simulate agent.')
train_dataset = make_dataset(train_eps, config)
eval_dataset = iter(make_dataset(eval_eps, config))
agent = Dreamer(config, logger, train_dataset)
if (logdir / 'variables.pkl').exists():
agent.load(logdir / 'variables.pkl')
agent._should_pretrain._once = False
state = None
while agent._step.numpy().item() < config.steps:
logger.write()
print('Start evaluation.')
video_pred = agent._wm.video_pred(next(eval_dataset))
logger.video('eval_openl', video_pred)
eval_policy = functools.partial(agent, training=False)
tools.simulate(eval_policy, eval_envs, episodes=1)
print('Start training.')
state = tools.simulate(agent, train_envs, config.eval_every, state=state)
agent.save(logdir / 'variables.pkl')
for env in train_envs + eval_envs:
try:
env.close()
except Exception:
pass
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--configs', nargs='+', required=True)
args, remaining = parser.parse_known_args()
configs = yaml.safe_load(
(pathlib.Path(sys.argv[0]).parent / 'configs.yaml').read_text())
defaults = {}
for name in args.configs:
defaults.update(configs[name])
parser = argparse.ArgumentParser()
for key, value in sorted(defaults.items(), key=lambda x: x[0]):
arg_type = tools.args_type(value)
parser.add_argument(f'--{key}', type=arg_type, default=arg_type(value))
main(parser.parse_args(remaining))
+91
@@ -0,0 +1,91 @@
import tensorflow as tf
from tensorflow.keras.mixed_precision import experimental as prec
from tensorflow_probability import distributions as tfd
import models
import networks
import tools
class Random(tools.Module):
def __init__(self, config):
self._config = config
self._float = prec.global_policy().compute_dtype
def actor(self, feat):
shape = feat.shape[:-1] + [self._config.num_actions]
if self._config.actor_dist == 'onehot':
return tools.OneHotDist(tf.zeros(shape))
else:
ones = tf.ones(shape, self._float)
return tfd.Uniform(-ones, ones)
def train(self, start, context):
return None, {}
class Plan2Explore(tools.Module):
def __init__(self, config, world_model, reward=None):
self._config = config
self._reward = reward
self._behavior = models.ImagBehavior(config, world_model)
self.actor = self._behavior.actor
stoch_size = config.dyn_stoch
if config.dyn_discrete:
stoch_size *= config.dyn_discrete
size = {
'embed': 32 * config.cnn_depth,
'stoch': stoch_size,
'deter': config.dyn_deter,
        # Use the flattened stochastic size so discrete latents are counted.
        'feat': stoch_size + config.dyn_deter,
}[self._config.disag_target]
kw = dict(
shape=size, layers=config.disag_layers, units=config.disag_units,
act=config.act)
self._networks = [
networks.DenseHead(**kw) for _ in range(config.disag_models)]
self._opt = tools.Optimizer(
'ensemble', config.model_lr, config.opt_eps, config.grad_clip,
config.weight_decay, opt=config.opt)
def train(self, start, context):
metrics = {}
stoch = start['stoch']
if self._config.dyn_discrete:
stoch = tf.reshape(
stoch, stoch.shape[:-2] + (stoch.shape[-2] * stoch.shape[-1]))
target = {
'embed': context['embed'],
'stoch': stoch,
'deter': start['deter'],
'feat': context['feat'],
}[self._config.disag_target]
metrics.update(self._train_ensemble(context['feat'], target))
metrics.update(self._behavior.train(start, self._intrinsic_reward)[-1])
return None, metrics
def _intrinsic_reward(self, feat, state, action):
preds = [head(feat, tf.float32).mean() for head in self._networks]
disag = tf.reduce_mean(tf.math.reduce_std(preds, 0), -1)
if self._config.disag_log:
disag = tf.math.log(disag)
reward = self._config.expl_intr_scale * disag
if self._config.expl_extr_scale:
reward += tf.cast(self._config.expl_extr_scale * self._reward(
feat, state, action), tf.float32)
return reward
def _train_ensemble(self, inputs, targets):
if self._config.disag_offset:
targets = targets[:, self._config.disag_offset:]
inputs = inputs[:, :-self._config.disag_offset]
targets = tf.stop_gradient(targets)
inputs = tf.stop_gradient(inputs)
with tf.GradientTape() as tape:
preds = [head(inputs) for head in self._networks]
likes = [tf.reduce_mean(pred.log_prob(targets)) for pred in preds]
loss = -tf.cast(tf.reduce_sum(likes), tf.float32)
metrics = self._opt(tape, loss, self._networks)
return metrics
+260
@@ -0,0 +1,260 @@
import tensorflow as tf
from tensorflow.keras.mixed_precision import experimental as prec
import networks
import tools
class WorldModel(tools.Module):
def __init__(self, step, config):
self._step = step
self._config = config
self.encoder = networks.ConvEncoder(
config.cnn_depth, config.act, config.encoder_kernels)
self.dynamics = networks.RSSM(
config.dyn_stoch, config.dyn_deter, config.dyn_hidden,
config.dyn_input_layers, config.dyn_output_layers,
config.dyn_rec_depth, config.dyn_shared, config.dyn_discrete,
config.act, config.dyn_mean_act, config.dyn_std_act,
config.dyn_temp_post, config.dyn_min_std, config.dyn_cell)
self.heads = {}
channels = (1 if config.grayscale else 3)
shape = config.size + (channels,)
self.heads['image'] = networks.ConvDecoder(
config.cnn_depth, config.act, shape, config.decoder_kernels,
config.decoder_thin)
self.heads['reward'] = networks.DenseHead(
[], config.reward_layers, config.units, config.act)
if config.pred_discount:
self.heads['discount'] = networks.DenseHead(
[], config.discount_layers, config.units, config.act, dist='binary')
for name in config.grad_heads:
assert name in self.heads, name
self._model_opt = tools.Optimizer(
'model', config.model_lr, config.opt_eps, config.grad_clip,
config.weight_decay, opt=config.opt)
self._scales = dict(
reward=config.reward_scale, discount=config.discount_scale)
def train(self, data):
data = self.preprocess(data)
with tf.GradientTape() as model_tape:
embed = self.encoder(data)
post, prior = self.dynamics.observe(embed, data['action'])
kl_balance = tools.schedule(self._config.kl_balance, self._step)
kl_free = tools.schedule(self._config.kl_free, self._step)
kl_scale = tools.schedule(self._config.kl_scale, self._step)
kl_loss, kl_value = self.dynamics.kl_loss(
post, prior, self._config.kl_forward, kl_balance, kl_free, kl_scale)
losses = {}
likes = {}
for name, head in self.heads.items():
grad_head = (name in self._config.grad_heads)
feat = self.dynamics.get_feat(post)
feat = feat if grad_head else tf.stop_gradient(feat)
pred = head(feat, tf.float32)
like = pred.log_prob(tf.cast(data[name], tf.float32))
likes[name] = like
losses[name] = -tf.reduce_mean(like) * self._scales.get(name, 1.0)
model_loss = sum(losses.values()) + kl_loss
model_parts = [self.encoder, self.dynamics] + list(self.heads.values())
metrics = self._model_opt(model_tape, model_loss, model_parts)
metrics.update({f'{name}_loss': loss for name, loss in losses.items()})
metrics['kl_balance'] = kl_balance
metrics['kl_free'] = kl_free
metrics['kl_scale'] = kl_scale
metrics['kl'] = tf.reduce_mean(kl_value)
metrics['prior_ent'] = self.dynamics.get_dist(prior).entropy()
metrics['post_ent'] = self.dynamics.get_dist(post).entropy()
context = dict(
embed=embed, feat=self.dynamics.get_feat(post),
kl=kl_value, postent=self.dynamics.get_dist(post).entropy())
return post, context, metrics
@tf.function
def preprocess(self, obs):
dtype = prec.global_policy().compute_dtype
obs = obs.copy()
obs['image'] = tf.cast(obs['image'], dtype) / 255.0 - 0.5
obs['reward'] = getattr(tf, self._config.clip_rewards)(obs['reward'])
if 'discount' in obs:
obs['discount'] *= self._config.discount
for key, value in obs.items():
if tf.dtypes.as_dtype(value.dtype) in (
tf.float16, tf.float32, tf.float64):
obs[key] = tf.cast(value, dtype)
return obs
@tf.function
def video_pred(self, data):
data = self.preprocess(data)
truth = data['image'][:6] + 0.5
embed = self.encoder(data)
states, _ = self.dynamics.observe(embed[:6, :5], data['action'][:6, :5])
recon = self.heads['image'](
self.dynamics.get_feat(states)).mode()[:6]
init = {k: v[:, -1] for k, v in states.items()}
prior = self.dynamics.imagine(data['action'][:6, 5:], init)
openl = self.heads['image'](self.dynamics.get_feat(prior)).mode()
model = tf.concat([recon[:, :5] + 0.5, openl + 0.5], 1)
error = (model - truth + 1) / 2
return tf.concat([truth, model, error], 2)
class ImagBehavior(tools.Module):
def __init__(self, config, world_model, stop_grad_actor=True, reward=None):
self._config = config
self._world_model = world_model
self._stop_grad_actor = stop_grad_actor
self._reward = reward
self.actor = networks.ActionHead(
config.num_actions, config.actor_layers, config.units, config.act,
config.actor_dist, config.actor_init_std, config.actor_min_std,
        config.actor_disc, config.actor_temp, config.actor_outscale)
self.value = networks.DenseHead(
[], config.value_layers, config.units, config.act,
config.value_head)
if config.slow_value_target or config.slow_actor_target:
self._slow_value = networks.DenseHead(
[], config.value_layers, config.units, config.act)
self._updates = tf.Variable(0, tf.int64)
kw = dict(wd=config.weight_decay, opt=config.opt)
self._actor_opt = tools.Optimizer(
'actor', config.actor_lr, config.opt_eps, config.actor_grad_clip, **kw)
self._value_opt = tools.Optimizer(
'value', config.value_lr, config.opt_eps, config.value_grad_clip, **kw)
def train(
self, start, objective=None, imagine=None, tape=None, repeats=None):
objective = objective or self._reward
self._update_slow_target()
metrics = {}
with (tape or tf.GradientTape()) as actor_tape:
assert bool(objective) != bool(imagine)
if objective:
imag_feat, imag_state, imag_action = self._imagine(
start, self.actor, self._config.imag_horizon, repeats)
reward = objective(imag_feat, imag_state, imag_action)
else:
imag_feat, imag_state, imag_action, reward = imagine(start)
actor_ent = self.actor(imag_feat, tf.float32).entropy()
state_ent = self._world_model.dynamics.get_dist(
imag_state, tf.float32).entropy()
target, weights = self._compute_target(
imag_feat, imag_state, imag_action, reward, actor_ent, state_ent,
self._config.slow_actor_target)
actor_loss, mets = self._compute_actor_loss(
imag_feat, imag_state, imag_action, target, actor_ent, state_ent,
weights)
metrics.update(mets)
if self._config.slow_value_target != self._config.slow_actor_target:
target, weights = self._compute_target(
imag_feat, imag_state, imag_action, reward, actor_ent, state_ent,
self._config.slow_value_target)
value_input = imag_feat
with tf.GradientTape() as value_tape:
value = self.value(value_input, tf.float32)[:-1]
value_loss = -value.log_prob(tf.stop_gradient(target))
if self._config.value_decay:
value_loss += self._config.value_decay * value.mode()
value_loss = tf.reduce_mean(weights[:-1] * value_loss)
metrics['reward_mean'] = tf.reduce_mean(reward)
metrics['reward_std'] = tf.math.reduce_std(reward)
metrics['actor_ent'] = tf.reduce_mean(actor_ent)
metrics.update(self._actor_opt(actor_tape, actor_loss, [self.actor]))
metrics.update(self._value_opt(value_tape, value_loss, [self.value]))
return imag_feat, imag_state, imag_action, weights, metrics
def _imagine(self, start, policy, horizon, repeats=None):
dynamics = self._world_model.dynamics
if repeats:
start = {k: tf.repeat(v, repeats, axis=1) for k, v in start.items()}
flatten = lambda x: tf.reshape(x, [-1] + list(x.shape[2:]))
start = {k: flatten(v) for k, v in start.items()}
def step(prev, _):
state, _, _ = prev
feat = dynamics.get_feat(state)
inp = tf.stop_gradient(feat) if self._stop_grad_actor else feat
action = policy(inp).sample()
succ = dynamics.img_step(state, action, sample=self._config.imag_sample)
return succ, feat, action
feat = 0 * dynamics.get_feat(start)
action = policy(feat).mode()
succ, feats, actions = tools.static_scan(
step, tf.range(horizon), (start, feat, action))
states = {k: tf.concat([
start[k][None], v[:-1]], 0) for k, v in succ.items()}
if repeats:
def unfold(tensor):
s = tensor.shape
return tf.reshape(tensor, [s[0], s[1] // repeats, repeats] + s[2:])
states, feats, actions = tf.nest.map_structure(
unfold, (states, feats, actions))
return feats, states, actions
def _compute_target(
self, imag_feat, imag_state, imag_action, reward, actor_ent, state_ent,
slow):
reward = tf.cast(reward, tf.float32)
if 'discount' in self._world_model.heads:
inp = self._world_model.dynamics.get_feat(imag_state)
discount = self._world_model.heads['discount'](inp, tf.float32).mean()
else:
discount = self._config.discount * tf.ones_like(reward)
if self._config.future_entropy and tf.greater(
self._config.actor_entropy(), 0):
reward += self._config.actor_entropy() * actor_ent
if self._config.future_entropy and tf.greater(
self._config.actor_state_entropy(), 0):
reward += self._config.actor_state_entropy() * state_ent
if slow:
value = self._slow_value(imag_feat, tf.float32).mode()
else:
value = self.value(imag_feat, tf.float32).mode()
target = tools.lambda_return(
reward[:-1], value[:-1], discount[:-1],
bootstrap=value[-1], lambda_=self._config.discount_lambda, axis=0)
weights = tf.stop_gradient(tf.math.cumprod(tf.concat(
[tf.ones_like(discount[:1]), discount[:-1]], 0), 0))
return target, weights
def _compute_actor_loss(
self, imag_feat, imag_state, imag_action, target, actor_ent, state_ent,
weights):
metrics = {}
inp = tf.stop_gradient(imag_feat) if self._stop_grad_actor else imag_feat
policy = self.actor(inp, tf.float32)
actor_ent = policy.entropy()
if self._config.imag_gradient == 'dynamics':
actor_target = target
elif self._config.imag_gradient == 'reinforce':
imag_action = tf.cast(imag_action, tf.float32)
actor_target = policy.log_prob(imag_action)[:-1] * tf.stop_gradient(
target - self.value(imag_feat[:-1], tf.float32).mode())
elif self._config.imag_gradient == 'both':
imag_action = tf.cast(imag_action, tf.float32)
actor_target = policy.log_prob(imag_action)[:-1] * tf.stop_gradient(
target - self.value(imag_feat[:-1], tf.float32).mode())
mix = self._config.imag_gradient_mix()
actor_target = mix * target + (1 - mix) * actor_target
metrics['imag_gradient_mix'] = mix
else:
raise NotImplementedError(self._config.imag_gradient)
if not self._config.future_entropy and tf.greater(
self._config.actor_entropy(), 0):
actor_target += self._config.actor_entropy() * actor_ent[:-1]
if not self._config.future_entropy and tf.greater(
self._config.actor_state_entropy(), 0):
actor_target += self._config.actor_state_entropy() * state_ent[:-1]
actor_loss = -tf.reduce_mean(weights[:-1] * actor_target)
return actor_loss, metrics
def _update_slow_target(self):
if self._config.slow_value_target or self._config.slow_actor_target:
if self._updates % self._config.slow_target_update == 0:
mix = self._config.slow_target_fraction
for s, d in zip(self.value.variables, self._slow_value.variables):
d.assign(mix * s + (1 - mix) * d)
self._updates.assign_add(1)
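# Illustrative sketch, not part of the original file. `tools.lambda_return` is
# not shown in this section, so the recursion below assumes the standard
# TD(lambda) return, written in NumPy with hypothetical names. It also
# reproduces the cumulative discount weights computed at the end of
# `_compute_target`.

```python
import numpy as np


def lambda_return(reward, value, discount, bootstrap, lambda_):
    """TD(lambda) targets along axis 0 (assumed form of the helper)."""
    reward = np.asarray(reward, dtype=np.float64)
    value = np.asarray(value, dtype=np.float64)
    discount = np.asarray(discount, dtype=np.float64)
    bootstrap = np.asarray(bootstrap, dtype=np.float64)
    # Value of the state that follows each step; the last one is the bootstrap.
    next_values = np.concatenate([value[1:], bootstrap[None]], 0)
    inputs = reward + discount * next_values * (1 - lambda_)
    returns = np.zeros_like(reward)
    agg = bootstrap
    # Accumulate backwards in time.
    for t in reversed(range(len(reward))):
        agg = inputs[t] + discount[t] * lambda_ * agg
        returns[t] = agg
    return returns


def discount_weights(discount):
    """Cumulative product of discounts, shifted so the first weight is 1."""
    discount = np.asarray(discount, dtype=np.float64)
    shifted = np.concatenate([np.ones_like(discount[:1]), discount[:-1]], 0)
    return np.cumprod(shifted, 0)


monte_carlo = lambda_return([1, 1, 1], [0, 0, 0], [0.5, 0.5, 0.5], 0.0, 1.0)
one_step = lambda_return([1, 1, 1], [10, 20, 30], [0.5, 0.5, 0.5], 40.0, 0.0)
weights = discount_weights([0.5, 0.5, 0.5])
```

# With lambda_=1 the sketch reduces to a discounted Monte Carlo return, and
# with lambda_=0 to a one-step bootstrapped target.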
+397
@@ -0,0 +1,397 @@
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers as tfkl
from tensorflow_probability import distributions as tfd
from tensorflow.keras.mixed_precision import experimental as prec
import tools
class RSSM(tools.Module):
def __init__(
self, stoch=30, deter=200, hidden=200, layers_input=1, layers_output=1,
rec_depth=1, shared=False, discrete=False, act=tf.nn.elu,
mean_act='none', std_act='softplus', temp_post=True, min_std=0.1,
cell='gru'):
super().__init__()
self._stoch = stoch
self._deter = deter
self._hidden = hidden
self._min_std = min_std
self._layers_input = layers_input
self._layers_output = layers_output
self._rec_depth = rec_depth
self._shared = shared
self._discrete = discrete
self._act = act
self._mean_act = mean_act
self._std_act = std_act
self._temp_post = temp_post
self._embed = None
if cell == 'gru':
self._cell = tfkl.GRUCell(self._deter)
elif cell == 'gru_layer_norm':
self._cell = GRUCell(self._deter, norm=True)
else:
raise NotImplementedError(cell)
def initial(self, batch_size):
dtype = prec.global_policy().compute_dtype
if self._discrete:
state = dict(
logit=tf.zeros([batch_size, self._stoch, self._discrete], dtype),
stoch=tf.zeros([batch_size, self._stoch, self._discrete], dtype),
deter=self._cell.get_initial_state(None, batch_size, dtype))
else:
state = dict(
mean=tf.zeros([batch_size, self._stoch], dtype),
std=tf.zeros([batch_size, self._stoch], dtype),
stoch=tf.zeros([batch_size, self._stoch], dtype),
deter=self._cell.get_initial_state(None, batch_size, dtype))
return state
@tf.function
def observe(self, embed, action, state=None):
swap = lambda x: tf.transpose(x, [1, 0] + list(range(2, len(x.shape))))
if state is None:
state = self.initial(tf.shape(action)[0])
embed, action = swap(embed), swap(action)
post, prior = tools.static_scan(
lambda prev, inputs: self.obs_step(prev[0], *inputs),
(action, embed), (state, state))
post = {k: swap(v) for k, v in post.items()}
prior = {k: swap(v) for k, v in prior.items()}
return post, prior
@tf.function
def imagine(self, action, state=None):
swap = lambda x: tf.transpose(x, [1, 0] + list(range(2, len(x.shape))))
if state is None:
state = self.initial(tf.shape(action)[0])
assert isinstance(state, dict), state
action = swap(action)
prior = tools.static_scan(self.img_step, action, state)
prior = {k: swap(v) for k, v in prior.items()}
return prior
def get_feat(self, state):
stoch = state['stoch']
if self._discrete:
shape = stoch.shape[:-2] + [self._stoch * self._discrete]
stoch = tf.reshape(stoch, shape)
return tf.concat([stoch, state['deter']], -1)
def get_dist(self, state, dtype=None):
if self._discrete:
logit = state['logit']
logit = tf.cast(logit, tf.float32)
dist = tfd.Independent(tools.OneHotDist(logit), 1)
if dtype != tf.float32:
dist = tools.DtypeDist(dist, dtype or state['logit'].dtype)
else:
mean, std = state['mean'], state['std']
if dtype:
mean = tf.cast(mean, dtype)
std = tf.cast(std, dtype)
dist = tfd.MultivariateNormalDiag(mean, std)
return dist
@tf.function
def obs_step(self, prev_state, prev_action, embed, sample=True):
if not self._embed:
self._embed = embed.shape[-1]
prior = self.img_step(prev_state, prev_action, None, sample)
if self._shared:
post = self.img_step(prev_state, prev_action, embed, sample)
else:
if self._temp_post:
x = tf.concat([prior['deter'], embed], -1)
else:
x = embed
for i in range(self._layers_output):
x = self.get(f'obi{i}', tfkl.Dense, self._hidden, self._act)(x)
stats = self._suff_stats_layer('obs', x)
if sample:
stoch = self.get_dist(stats).sample()
else:
stoch = self.get_dist(stats).mode()
post = {'stoch': stoch, 'deter': prior['deter'], **stats}
return post, prior
@tf.function
def img_step(self, prev_state, prev_action, embed=None, sample=True):
prev_stoch = prev_state['stoch']
if self._discrete:
shape = prev_stoch.shape[:-2] + [self._stoch * self._discrete]
prev_stoch = tf.reshape(prev_stoch, shape)
if self._shared:
if embed is None:
shape = prev_action.shape[:-1] + [self._embed]
embed = tf.zeros(shape, prev_action.dtype)
x = tf.concat([prev_stoch, prev_action, embed], -1)
else:
x = tf.concat([prev_stoch, prev_action], -1)
for i in range(self._layers_input):
x = self.get(f'ini{i}', tfkl.Dense, self._hidden, self._act)(x)
for _ in range(self._rec_depth):
deter = prev_state['deter']
x, deter = self._cell(x, [deter])
deter = deter[0] # Keras wraps the state in a list.
for i in range(self._layers_output):
x = self.get(f'imo{i}', tfkl.Dense, self._hidden, self._act)(x)
stats = self._suff_stats_layer('ims', x)
if sample:
stoch = self.get_dist(stats).sample()
else:
stoch = self.get_dist(stats).mode()
prior = {'stoch': stoch, 'deter': deter, **stats}
return prior
def _suff_stats_layer(self, name, x):
if self._discrete:
x = self.get(name, tfkl.Dense, self._stoch * self._discrete, None)(x)
logit = tf.reshape(x, x.shape[:-1] + [self._stoch, self._discrete])
return {'logit': logit}
else:
x = self.get(name, tfkl.Dense, 2 * self._stoch, None)(x)
mean, std = tf.split(x, 2, -1)
mean = {
'none': lambda: mean,
'tanh5': lambda: 5.0 * tf.math.tanh(mean / 5.0),
}[self._mean_act]()
std = {
'softplus': lambda: tf.nn.softplus(std),
'abs': lambda: tf.math.abs(std + 1),
'sigmoid': lambda: tf.nn.sigmoid(std),
'sigmoid2': lambda: 2 * tf.nn.sigmoid(std / 2),
}[self._std_act]()
std = std + self._min_std
return {'mean': mean, 'std': std}
def kl_loss(self, post, prior, forward, balance, free, scale):
kld = tfd.kl_divergence
dist = lambda x: self.get_dist(x, tf.float32)
sg = lambda x: tf.nest.map_structure(tf.stop_gradient, x)
lhs, rhs = (prior, post) if forward else (post, prior)
mix = balance if forward else (1 - balance)
if balance == 0.5:
value = kld(dist(lhs), dist(rhs))
loss = tf.reduce_mean(tf.maximum(value, free))
else:
value_lhs = value = kld(dist(lhs), dist(sg(rhs)))
value_rhs = kld(dist(sg(lhs)), dist(rhs))
loss_lhs = tf.maximum(tf.reduce_mean(value_lhs), free)
loss_rhs = tf.maximum(tf.reduce_mean(value_rhs), free)
loss = mix * loss_lhs + (1 - mix) * loss_rhs
loss *= scale
return loss, value
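# Illustrative sketch, not part of the original file. The two branches of
# `kl_loss` apply the `free` threshold differently: the `balance == 0.5` branch
# clips each KL value before averaging, while the balanced branch clips the
# average. The NumPy functions below (hypothetical names) show that the two
# choices can give different numbers for the same KL values.

```python
import numpy as np


def free_nats_per_element(kl, free):
    # Mirrors tf.reduce_mean(tf.maximum(value, free)).
    return float(np.mean(np.maximum(kl, free)))


def free_nats_on_mean(kl, free):
    # Mirrors tf.maximum(tf.reduce_mean(value), free).
    return float(np.maximum(np.mean(kl), free))


kl = np.array([0.5, 3.0])
per_element = free_nats_per_element(kl, free=1.0)
on_mean = free_nats_on_mean(kl, free=1.0)
```

# Here the per-element variant averages [1.0, 3.0], whereas the other variant
# keeps the raw mean because it already exceeds the threshold.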
class ConvEncoder(tools.Module):
def __init__(
self, depth=32, act=tf.nn.relu, kernels=(4, 4, 4, 4)):
self._act = act
self._depth = depth
self._kernels = kernels
def __call__(self, obs):
x = tf.reshape(obs['image'], (-1,) + tuple(obs['image'].shape[-3:]))
for i, kernel in enumerate(self._kernels):
depth = 2 ** i * self._depth
x = self._act(self.get(f'h{i}', tfkl.Conv2D, depth, kernel, 2)(x))
x = tf.reshape(x, [x.shape[0], np.prod(x.shape[1:])])
# print('Encoder output:', x.shape)
shape = tf.concat([tf.shape(obs['image'])[:-3], [x.shape[-1]]], 0)
return tf.reshape(x, shape)
class ConvDecoder(tools.Module):
def __init__(
self, depth=32, act=tf.nn.relu, shape=(64, 64, 3), kernels=(5, 5, 6, 6),
thin=True):
self._act = act
self._depth = depth
self._shape = shape
self._kernels = kernels
self._thin = thin
def __call__(self, features, dtype=None):
ConvT = tfkl.Conv2DTranspose
if self._thin:
x = self.get('hin', tfkl.Dense, 32 * self._depth, None)(features)
x = tf.reshape(x, [-1, 1, 1, 32 * self._depth])
else:
x = self.get('hin', tfkl.Dense, 128 * self._depth, None)(features)
x = tf.reshape(x, [-1, 2, 2, 32 * self._depth])
for i, kernel in enumerate(self._kernels):
depth = 2 ** (len(self._kernels) - i - 1) * self._depth
act = self._act
if i == len(self._kernels) - 1:
depth = self._shape[-1]
act = None
x = self.get(f'h{i}', ConvT, depth, kernel, 2, activation=act)(x)
# print('Decoder output:', x.shape)
mean = tf.reshape(x, tf.concat([tf.shape(features)[:-1], self._shape], 0))
if dtype:
mean = tf.cast(mean, dtype)
return tfd.Independent(tfd.Normal(mean, 1), len(self._shape))
class DenseHead(tools.Module):
def __init__(
self, shape, layers, units, act=tf.nn.elu, dist='normal', std=1.0):
self._shape = (shape,) if isinstance(shape, int) else shape
self._layers = layers
self._units = units
self._act = act
self._dist = dist
self._std = std
def __call__(self, features, dtype=None):
x = features
for index in range(self._layers):
x = self.get(f'h{index}', tfkl.Dense, self._units, self._act)(x)
mean = self.get(f'hmean', tfkl.Dense, np.prod(self._shape))(x)
mean = tf.reshape(mean, tf.concat(
[tf.shape(features)[:-1], self._shape], 0))
if self._std == 'learned':
std = self.get(f'hstd', tfkl.Dense, np.prod(self._shape))(x)
std = tf.nn.softplus(std) + 0.01
std = tf.reshape(std, tf.concat(
[tf.shape(features)[:-1], self._shape], 0))
else:
std = self._std
if dtype:
mean, std = tf.cast(mean, dtype), tf.cast(std, dtype)
if self._dist == 'normal':
return tfd.Independent(tfd.Normal(mean, std), len(self._shape))
if self._dist == 'huber':
return tfd.Independent(
tools.UnnormalizedHuber(mean, std, 1.0), len(self._shape))
if self._dist == 'binary':
return tfd.Independent(tfd.Bernoulli(mean), len(self._shape))
raise NotImplementedError(self._dist)
class ActionHead(tools.Module):
def __init__(
self, size, layers, units, act=tf.nn.elu, dist='trunc_normal',
init_std=0.0, min_std=0.1, action_disc=5, temp=0.1, outscale=0):
# assert min_std <= 2
self._size = size
self._layers = layers
self._units = units
self._dist = dist
self._act = act
self._min_std = min_std
self._init_std = init_std
self._action_disc = action_disc
self._temp = temp() if callable(temp) else temp
self._outscale = outscale
def __call__(self, features, dtype=None):
x = features
for index in range(self._layers):
kw = {}
if index == self._layers - 1 and self._outscale:
kw['kernel_initializer'] = tf.keras.initializers.VarianceScaling(
self._outscale)
x = self.get(f'h{index}', tfkl.Dense, self._units, self._act, **kw)(x)
if self._dist == 'tanh_normal':
# https://www.desmos.com/calculator/rcmcf5jwe7
x = self.get(f'hout', tfkl.Dense, 2 * self._size)(x)
if dtype:
x = tf.cast(x, dtype)
mean, std = tf.split(x, 2, -1)
mean = tf.tanh(mean)
std = tf.nn.softplus(std + self._init_std) + self._min_std
dist = tfd.Normal(mean, std)
dist = tfd.TransformedDistribution(dist, tools.TanhBijector())
dist = tfd.Independent(dist, 1)
dist = tools.SampleDist(dist)
elif self._dist == 'tanh_normal_5':
x = self.get(f'hout', tfkl.Dense, 2 * self._size)(x)
if dtype:
x = tf.cast(x, dtype)
mean, std = tf.split(x, 2, -1)
mean = 5 * tf.tanh(mean / 5)
std = tf.nn.softplus(std + 5) + 5
dist = tfd.Normal(mean, std)
dist = tfd.TransformedDistribution(dist, tools.TanhBijector())
dist = tfd.Independent(dist, 1)
dist = tools.SampleDist(dist)
elif self._dist == 'normal':
x = self.get(f'hout', tfkl.Dense, 2 * self._size)(x)
if dtype:
x = tf.cast(x, dtype)
mean, std = tf.split(x, 2, -1)
std = tf.nn.softplus(std + self._init_std) + self._min_std
dist = tfd.Normal(mean, std)
dist = tfd.Independent(dist, 1)
elif self._dist == 'normal_1':
mean = self.get(f'hout', tfkl.Dense, self._size)(x)
if dtype:
mean = tf.cast(mean, dtype)
dist = tfd.Normal(mean, 1)
dist = tfd.Independent(dist, 1)
elif self._dist == 'trunc_normal':
# https://www.desmos.com/calculator/mmuvuhnyxo
x = self.get(f'hout', tfkl.Dense, 2 * self._size)(x)
x = tf.cast(x, tf.float32)
mean, std = tf.split(x, 2, -1)
mean = tf.tanh(mean)
std = 2 * tf.nn.sigmoid(std / 2) + self._min_std
dist = tools.SafeTruncatedNormal(mean, std, -1, 1)
dist = tools.DtypeDist(dist, dtype)
dist = tfd.Independent(dist, 1)
elif self._dist == 'onehot':
x = self.get(f'hout', tfkl.Dense, self._size)(x)
x = tf.cast(x, tf.float32)
dist = tools.OneHotDist(x, dtype=dtype)
dist = tools.DtypeDist(dist, dtype)
elif self._dist == 'onehot_gumble':
x = self.get(f'hout', tfkl.Dense, self._size)(x)
if dtype:
x = tf.cast(x, dtype)
temp = self._temp
dist = tools.GumbleDist(temp, x, dtype=dtype)
else:
raise NotImplementedError(self._dist)
return dist
class GRUCell(tf.keras.layers.AbstractRNNCell):
def __init__(self, size, norm=False, act=tf.tanh, update_bias=-1, **kwargs):
super().__init__()
self._size = size
self._act = act
self._norm = norm
self._update_bias = update_bias
self._layer = tfkl.Dense(3 * size, use_bias=norm is not None, **kwargs)
if norm:
self._norm = tfkl.LayerNormalization(dtype=tf.float32)
@property
def state_size(self):
return self._size
def call(self, inputs, state):
state = state[0] # Keras wraps the state in a list.
parts = self._layer(tf.concat([inputs, state], -1))
if self._norm:
dtype = parts.dtype
parts = tf.cast(parts, tf.float32)
parts = self._norm(parts)
parts = tf.cast(parts, dtype)
reset, cand, update = tf.split(parts, 3, -1)
reset = tf.nn.sigmoid(reset)
cand = self._act(reset * cand)
update = tf.nn.sigmoid(update + self._update_bias)
output = update * cand + (1 - update) * state
return output, [output]
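# Illustrative sketch, not part of the original file. It restates the gate
# arithmetic of `GRUCell.call` in NumPy, without layer normalization and with
# an explicit `kernel` and `bias` standing in for the Dense layer (both are
# hypothetical parameters introduced for this example).

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def gru_step(inputs, state, kernel, bias, update_bias=-1.0):
    """One step of the custom GRU update; returns the new state."""
    parts = np.concatenate([inputs, state], -1) @ kernel + bias
    reset, cand, update = np.split(parts, 3, -1)
    reset = sigmoid(reset)
    cand = np.tanh(reset * cand)
    # A negative update bias keeps more of the previous state at initialization.
    update = sigmoid(update + update_bias)
    return update * cand + (1 - update) * state


inputs = np.zeros((2, 3))
state = np.ones((2, 4))
kernel = np.zeros((3 + 4, 3 * 4))
bias = np.zeros(3 * 4)
new_state = gru_step(inputs, state, kernel, bias)
```

# With zero weights the candidate is zero, so the new state is the old state
# scaled by (1 - sigmoid(update_bias)).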
+540
@@ -0,0 +1,540 @@
import argparse
import collections
import functools
import itertools
import json
import multiprocessing as mp
import os
import pathlib
import re
import subprocess
import warnings
os.environ['NO_AT_BRIDGE'] = '1' # Hide X org false warning.
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
import pandas as pd
np.set_string_function(lambda x: f'<np.array shape={x.shape} dtype={x.dtype}>')
Run = collections.namedtuple('Run', 'task method seed xs ys')
PALETTES = dict(
discrete=(
'#377eb8', '#4daf4a', '#984ea3', '#e41a1c', '#ff7f00', '#a65628',
'#f781bf', '#888888', '#a6cee3', '#b2df8a', '#cab2d6', '#fb9a99',
),
contrast=(
'#0022ff', '#33aa00', '#ff0011', '#ddaa00', '#cc44dd', '#0088aa',
'#001177', '#117700', '#990022', '#885500', '#553366', '#006666',
),
gradient=(
'#fde725', '#a0da39', '#4ac16d', '#1fa187', '#277f8e', '#365c8d',
'#46327e', '#440154',
),
baselines=(
'#222222', '#666666', '#aaaaaa', '#cccccc',
),
)
LEGEND = dict(
fontsize='medium', numpoints=1, labelspacing=0, columnspacing=1.2,
handlelength=1.5, handletextpad=0.5, ncol=4, loc='lower center')
DEFAULT_BASELINES = [
'd4pg', 'dqn_sticky', 'rainbow_sticky', 'human$', 'impala']
BINS = collections.defaultdict(int)
BINS.update(dmc=1e5, atari=1e6, particle=1e5)
def find_keys(args):
filenames = []
for indir in args.indir:
task = next(indir.iterdir()) # First only.
for method in task.iterdir():
seed = next(method.iterdir()) # First only.
filenames += list(seed.glob('**/*.jsonl'))
keys = set()
for filename in filenames:
keys |= set(load_jsonl(filename).columns)
print(f'Keys ({len(keys)}):', ', '.join(keys), flush=True)
def load_runs(args):
total, toload = [], []
for indir in args.indir:
filenames = list(indir.glob('**/*.jsonl'))
total += filenames
for filename in filenames:
task, method, seed = filename.relative_to(indir).parts[:-1]
if not any(p.search(task) for p in args.tasks):
continue
if not any(p.search(method) for p in args.methods):
continue
toload.append((filename, indir))
print(f'Loading {len(toload)} of {len(total)} runs...')
jobs = [functools.partial(load_run, f, i, args) for f, i in toload]
# Disable async data loading:
# runs = [j() for j in jobs]
with mp.Pool(10) as pool:
promises = [pool.apply_async(j) for j in jobs]
runs = [p.get() for p in promises]
runs = [r for r in runs if r is not None]
return runs
def load_run(filename, indir, args):
task, method, seed = filename.relative_to(indir).parts[:-1]
prefix = f'indir{args.indir.index(indir)+1}_'
if task == 'atari_jamesbond':
task = 'atari_james_bond'
seed = prefix + seed
if args.prefix:
method = prefix + method
df = load_jsonl(filename)
if df is None:
print('Skipping empty run')
return
try:
df = df[[args.xaxis, args.yaxis]].dropna()
if args.maxval:
df = df.replace([+np.inf], +args.maxval)
df = df.replace([-np.inf], -args.maxval)
df[args.yaxis] = df[args.yaxis].clip(-args.maxval, +args.maxval)
except KeyError:
return
xs = df[args.xaxis].to_numpy()
ys = df[args.yaxis].to_numpy()
bins = BINS[task.split('_')[0]] if args.bins == -1 else args.bins
if bins:
borders = np.arange(0, xs.max() + 1e-8, bins)
xs, ys = bin_scores(xs, ys, borders)
if not len(xs):
print('Skipping empty run', task, method, seed)
return
return Run(task, method, seed, xs, ys)
def load_baselines(patterns, prefix=False):
runs = []
directory = pathlib.Path(__file__).parent / 'scores'
for filename in directory.glob('**/*_baselines.json'):
for task, methods in json.loads(filename.read_text()).items():
for method, score in methods.items():
if prefix:
method = f'baseline_{method}'
if not any(p.search(method) for p in patterns):
continue
runs.append(Run(task, method, None, None, score))
return runs
def stats(runs, baselines):
tasks = sorted(set(r.task for r in runs))
methods = sorted(set(r.method for r in runs))
seeds = sorted(set(r.seed for r in runs))
baseline = sorted(set(r.method for r in baselines))
print('Loaded', len(runs), 'runs.')
print(f'Tasks ({len(tasks)}):', ', '.join(tasks))
print(f'Methods ({len(methods)}):', ', '.join(methods))
print(f'Seeds ({len(seeds)}):', ', '.join(seeds))
print(f'Baselines ({len(baseline)}):', ', '.join(baseline))
def order_methods(runs, baselines, args):
methods = []
for pattern in args.methods:
for method in sorted(set(r.method for r in runs)):
if pattern.search(method):
if method not in methods:
methods.append(method)
if method not in args.colors:
index = len(args.colors) % len(args.palette)
args.colors[method] = args.palette[index]
non_baseline_colors = len(args.colors)
for pattern in args.baselines:
for method in sorted(set(r.method for r in baselines)):
if pattern.search(method):
if method not in methods:
methods.append(method)
if method not in args.colors:
index = len(args.colors) - non_baseline_colors
index = index % len(PALETTES['baselines'])
args.colors[method] = PALETTES['baselines'][index]
return methods
def figure(runs, methods, args):
tasks = sorted(set(r.task for r in runs if r.xs is not None))
rows = int(np.ceil((len(tasks) + len(args.add)) / args.cols))
figsize = args.size[0] * args.cols, args.size[1] * rows
fig, axes = plt.subplots(rows, args.cols, figsize=figsize)
for task, ax in zip(tasks, axes.flatten()):
relevant = [r for r in runs if r.task == task]
plot(task, ax, relevant, methods, args)
for name, ax in zip(args.add, axes.flatten()[len(tasks):]):
ax.set_facecolor((0.9, 0.9, 0.9))
if name == 'median':
plot_combined(
'combined_median', ax, runs, methods, args,
lo='random', hi='human$',
agg=lambda x: np.nanmedian(x, -1))
elif name == 'mean':
plot_combined(
'combined_mean', ax, runs, methods, args,
lo='random', hi='human$',
agg=lambda x: np.nanmean(x, -1))
elif name == 'gamer_median':
plot_combined(
'combined_gamer_median', ax, runs, methods, args,
lo='random', hi='human$',
agg=lambda x: np.nanmedian(x, -1))
elif name == 'gamer_mean':
plot_combined(
'combined_gamer_mean', ax, runs, methods, args,
lo='random', hi='human$',
agg=lambda x: np.nanmean(x, -1))
elif name == 'record_mean':
plot_combined(
'combined_record_mean', ax, runs, methods, args,
lo='random', hi='record',
agg=lambda x: np.nanmean(x, -1))
elif name == 'clipped_record_mean':
plot_combined(
'combined_clipped_record_mean', ax, runs, methods, args,
lo='random', hi='record', clip=True,
agg=lambda x: np.nanmean(x, -1))
elif name == 'num_seeds':
plot_combined(
'combined_num_seeds', ax, runs, methods, args,
agg=lambda x: np.isfinite(x).sum(-1))
elif name == 'human_above':
plot_combined(
'combined_above_human$', ax, runs, methods, args,
agg=lambda y: (y >= 1.0).astype(float).sum(-1))
elif name == 'human_below':
plot_combined(
'combined_below_human$', ax, runs, methods, args,
agg=lambda y: (y <= 1.0).astype(float).sum(-1))
else:
raise NotImplementedError(name)
if args.xlim:
for ax in axes[:-1].flatten():
ax.xaxis.get_offset_text().set_visible(False)
if args.xlabel:
for ax in axes[-1]:
ax.set_xlabel(args.xlabel)
if args.ylabel:
for ax in axes[:, 0]:
ax.set_ylabel(args.ylabel)
for ax in axes.flatten()[len(tasks) + len(args.add):]:
ax.axis('off')
legend(fig, args.labels, **LEGEND)
return fig
def plot(task, ax, runs, methods, args):
assert runs
try:
title = task.split('_', 1)[1].replace('_', ' ').title()
except IndexError:
title = task.title()
ax.set_title(title)
xlim = [+np.inf, -np.inf]
for index, method in enumerate(methods):
relevant = [r for r in runs if r.method == method]
if not relevant:
continue
if any(r.xs is None for r in relevant):
baseline(index, method, ax, relevant, args)
else:
if args.aggregate == 'none':
xs, ys = curve_lines(index, task, method, ax, relevant, args)
else:
xs, ys = curve_area(index, task, method, ax, relevant, args)
if len(xs) == len(ys) == 0:
print(f'Skipping empty: {task} {method}')
continue
xlim = [min(xlim[0], xs.min()), max(xlim[1], xs.max())]
ax.ticklabel_format(axis='x', style='sci', scilimits=(0, 0))
steps = [1, 2, 2.5, 5, 10]
ax.xaxis.set_major_locator(ticker.MaxNLocator(args.xticks, steps=steps))
ax.yaxis.set_major_locator(ticker.MaxNLocator(args.yticks, steps=steps))
if np.isfinite(xlim).all():
ax.set_xlim(args.xlim or xlim)
if args.xlim:
ticks = sorted({*ax.get_xticks(), *args.xlim})
ticks = [x for x in ticks if args.xlim[0] <= x <= args.xlim[1]]
ax.set_xticks(ticks)
if args.ylim:
ax.set_ylim(args.ylim)
if args.ylimticks:
ticks = sorted({*ax.get_yticks(), *args.ylim})
ticks = [x for x in ticks if args.ylim[0] <= x <= args.ylim[1]]
ax.set_yticks(ticks)
def plot_combined(
name, ax, runs, methods, args, agg, lo=None, hi=None, clip=False):
tasks = sorted(set(run.task for run in runs if run.xs is not None))
seeds = list(set(run.seed for run in runs))
runs = [r for r in runs if r.task in tasks] # Discard unused baselines.
# Bin all runs onto the same X steps.
borders = sorted(
[r.xs for r in runs if r.xs is not None],
key=lambda x: np.nanmax(x))[-1]
for index, run in enumerate(runs):
if run.xs is None:
continue
xs, ys = bin_scores(run.xs, run.ys, borders)
runs[index] = run._replace(xs=xs, ys=ys)
# Per-task normalization by low and high baseline.
if lo or hi:
mins = collections.defaultdict(list)
maxs = collections.defaultdict(list)
[mins[r.task].append(r.ys) for r in load_baselines([re.compile(lo)])]
[maxs[r.task].append(r.ys) for r in load_baselines([re.compile(hi)])]
mins = {task: min(ys) for task, ys in mins.items() if task in tasks}
maxs = {task: max(ys) for task, ys in maxs.items() if task in tasks}
missing_baselines = []
for task in tasks:
if task not in mins or task not in maxs:
missing_baselines.append(task)
if set(missing_baselines) == set(tasks):
print(f'No baselines found to normalize any tasks in {name} plot.')
else:
for task in missing_baselines:
print(f'No baselines found to normalize {task} in {name} plot.')
for index, run in enumerate(runs):
if run.task not in mins or run.task not in maxs:
continue
ys = (run.ys - mins[run.task]) / (maxs[run.task] - mins[run.task])
if clip:
ys = np.minimum(ys, 1.0)
runs[index] = run._replace(ys=ys)
# Aggregate across tasks but not methods or seeds.
combined = []
for method, seed in itertools.product(methods, seeds):
relevant = [r for r in runs if r.method == method and r.seed == seed]
if not relevant:
continue
if relevant[0].xs is None:
xs, ys = None, np.array([r.ys for r in relevant])
else:
xs, ys = stack_scores(*zip(*[(r.xs, r.ys) for r in relevant]))
with warnings.catch_warnings(): # Ignore empty slice warnings.
warnings.simplefilter('ignore', category=RuntimeWarning)
combined.append(Run('combined', method, seed, xs, agg(ys)))
plot(name, ax, combined, methods, args)
def curve_lines(index, task, method, ax, runs, args):
zorder = 10000 - 10 * index - 1
for run in runs:
color = args.colors[method]
ax.plot(run.xs, run.ys, label=method, color=color, zorder=zorder)
return runs[0].xs, runs[0].ys
def curve_area(index, task, method, ax, runs, args):
xs, ys = stack_scores(*zip(*[(r.xs, r.ys) for r in runs]))
with warnings.catch_warnings(): # NaN buckets remain NaN.
warnings.simplefilter('ignore', category=RuntimeWarning)
if args.aggregate == 'std1':
mean, std = np.nanmean(ys, -1), np.nanstd(ys, -1)
lo, mi, hi = mean - std, mean, mean + std
elif args.aggregate == 'per0':
lo, mi, hi = [np.nanpercentile(ys, k, -1) for k in (0, 50, 100)]
elif args.aggregate == 'per5':
lo, mi, hi = [np.nanpercentile(ys, k, -1) for k in (5, 50, 95)]
elif args.aggregate == 'per25':
lo, mi, hi = [np.nanpercentile(ys, k, -1) for k in (25, 50, 75)]
else:
raise NotImplementedError(args.aggregate)
color = args.colors[method]
kw = dict(color=color, zorder=1000 - 10 * index, alpha=0.1, linewidths=0)
ax.fill_between(xs, lo, hi, **kw)
ax.plot(xs, mi, label=method, color=color, zorder=10000 - 10 * index - 1)
return xs, mi
def baseline(index, method, ax, runs, args):
assert all(run.xs is None for run in runs)
ys = np.array([run.ys for run in runs])
mean, std = ys.mean(), ys.std()
color = args.colors[method]
kw = dict(color=color, zorder=500 - 20 * index - 1, alpha=0.1, linewidths=0)
ax.fill_between([-np.inf, np.inf], [mean - std] * 2, [mean + std] * 2, **kw)
kw = dict(ls='--', color=color, zorder=5000 - 10 * index - 1)
ax.axhline(mean, label=method, **kw)
def legend(fig, mapping=None, **kwargs):
entries = {}
for ax in fig.axes:
for handle, label in zip(*ax.get_legend_handles_labels()):
if mapping and label in mapping:
label = mapping[label]
entries[label] = handle
leg = fig.legend(entries.values(), entries.keys(), **kwargs)
leg.get_frame().set_edgecolor('white')
extent = leg.get_window_extent(fig.canvas.get_renderer())
extent = extent.transformed(fig.transFigure.inverted())
yloc, xloc = kwargs['loc'].split()
y0 = dict(lower=extent.y1, center=0, upper=0)[yloc]
y1 = dict(lower=1, center=1, upper=extent.y0)[yloc]
x0 = dict(left=extent.x1, center=0, right=0)[xloc]
x1 = dict(left=1, center=1, right=extent.x0)[xloc]
fig.tight_layout(rect=[x0, y0, x1, y1], h_pad=0.5, w_pad=0.5)
def save(fig, args):
args.outdir.mkdir(parents=True, exist_ok=True)
filename = args.outdir / 'curves.png'
fig.savefig(filename, dpi=args.dpi)
print('Saved to', filename)
filename = args.outdir / 'curves.pdf'
fig.savefig(filename)
try:
subprocess.call(['pdfcrop', str(filename), str(filename)])
except FileNotFoundError:
print('Install texlive-extra-utils to crop PDF outputs.')
def bin_scores(xs, ys, borders, reducer=np.nanmean):
order = np.argsort(xs)
xs, ys = xs[order], ys[order]
binned = []
with warnings.catch_warnings(): # Empty buckets become NaN.
warnings.simplefilter('ignore', category=RuntimeWarning)
for start, stop in zip(borders[:-1], borders[1:]):
left = (xs <= start).sum()
right = (xs <= stop).sum()
binned.append(reducer(ys[left:right]))
return borders[1:], np.array(binned)
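# Illustrative usage sketch, not part of the original file. It repeats
# `bin_scores` with indentation restored so the block runs on its own, then
# bins three unsorted points into buckets of width 2. The sample data is made
# up for the example.

```python
import warnings

import numpy as np


def bin_scores(xs, ys, borders, reducer=np.nanmean):
    order = np.argsort(xs)
    xs, ys = xs[order], ys[order]
    binned = []
    with warnings.catch_warnings():  # Empty buckets become NaN.
        warnings.simplefilter('ignore', category=RuntimeWarning)
        for start, stop in zip(borders[:-1], borders[1:]):
            # Each bucket covers the half-open interval (start, stop].
            left = (xs <= start).sum()
            right = (xs <= stop).sum()
            binned.append(reducer(ys[left:right]))
    return borders[1:], np.array(binned)


xs = np.array([5.0, 1.0, 2.0])
ys = np.array([50.0, 10.0, 20.0])
borders = np.arange(0, 7, 2)  # Bucket edges 0, 2, 4, 6.
out_xs, out_ys = bin_scores(xs, ys, borders)
```

# The middle bucket (2, 4] contains no points, so its score is NaN and the
# aggregation code downstream has to use NaN-aware reducers.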
def stack_scores(multiple_xs, multiple_ys):
longest_xs = sorted(multiple_xs, key=lambda x: len(x))[-1]
multiple_padded_ys = []
for xs, ys in zip(multiple_xs, multiple_ys):
assert (longest_xs[:len(xs)] == xs).all(), (list(xs), list(longest_xs))
padding = [np.inf] * (len(longest_xs) - len(xs))
padded_ys = np.concatenate([ys, padding])
multiple_padded_ys.append(padded_ys)
stacked_ys = np.stack(multiple_padded_ys, -1)
return longest_xs, stacked_ys
def load_jsonl(filename):
try:
with filename.open() as f:
lines = list(f.readlines())
records = []
for index, line in enumerate(lines):
try:
records.append(json.loads(line))
except Exception:
if index == len(lines) - 1:
continue # Silently skip last line if it is incomplete.
raise ValueError(
f'Skipping invalid JSON line ({index+1}/{len(lines)}) in '
f'{filename}: {line}')
return pd.DataFrame(records)
except ValueError as e:
print('Invalid', filename, e)
return None
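# Illustrative sketch, not part of the original file. It isolates the parsing
# rule used by `load_jsonl`: a malformed final line is treated as a partially
# written record and skipped, while a malformed line anywhere else is an
# error. `parse_jsonl` is a hypothetical helper that works on a string instead
# of a file and returns plain dicts instead of a DataFrame.

```python
import json


def parse_jsonl(text):
    lines = text.splitlines()
    records = []
    for index, line in enumerate(lines):
        try:
            records.append(json.loads(line))
        except json.JSONDecodeError:
            if index == len(lines) - 1:
                continue  # Tolerate a truncated last line.
            raise ValueError(
                f'Invalid JSON on line {index + 1} of {len(lines)}: {line}')
    return records


truncated = '{"step": 1}\n{"step": 2}\n{"ste'
records = parse_jsonl(truncated)

try:
    parse_jsonl('{"step": 1}\nnot json\n{"step": 3}')
    middle_error = False
except ValueError:
    middle_error = True
```

# Skipping only the last line keeps plotting usable while a training run is
# still appending to its log file.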
def save_runs(runs, filename):
filename.parent.mkdir(parents=True, exist_ok=True)
records = []
for run in runs:
if run.xs is None:
continue
records.append(dict(
task=run.task, method=run.method, seed=run.seed,
xs=run.xs.tolist(), ys=run.ys.tolist()))
runs = json.dumps(records)
filename.write_text(runs)
print('Saved', filename)
def main(args):
find_keys(args)
runs = load_runs(args)
save_runs(runs, args.outdir / 'runs.json')
baselines = load_baselines(args.baselines, args.prefix)
stats(runs, baselines)
methods = order_methods(runs, baselines, args)
if not runs:
print('Nothing to plot.')
return
print('Plotting...')
fig = figure(runs + baselines, methods, args)
save(fig, args)
def parse_args():
boolean = lambda x: bool(['False', 'True'].index(x))
parser = argparse.ArgumentParser()
parser.add_argument('--indir', nargs='+', type=pathlib.Path, required=True)
parser.add_argument('--outdir', type=pathlib.Path, required=True)
parser.add_argument('--subdir', type=boolean, default=True)
parser.add_argument('--xaxis', type=str, required=True)
parser.add_argument('--yaxis', type=str, required=True)
parser.add_argument('--tasks', nargs='+', default=[r'.*'])
parser.add_argument('--methods', nargs='+', default=[r'.*'])
parser.add_argument('--baselines', nargs='+', default=DEFAULT_BASELINES)
parser.add_argument('--prefix', type=boolean, default=False)
parser.add_argument('--bins', type=float, default=-1)
parser.add_argument('--aggregate', type=str, default='std1')
parser.add_argument('--size', nargs=2, type=float, default=[2.5, 2.3])
parser.add_argument('--dpi', type=int, default=80)
parser.add_argument('--cols', type=int, default=6)
parser.add_argument('--xlim', nargs=2, type=float, default=None)
parser.add_argument('--ylim', nargs=2, type=float, default=None)
parser.add_argument('--ylimticks', type=boolean, default=True)
parser.add_argument('--xlabel', type=str, default=None)
parser.add_argument('--ylabel', type=str, default=None)
parser.add_argument('--xticks', type=int, default=6)
parser.add_argument('--yticks', type=int, default=5)
parser.add_argument('--labels', nargs='+', default=None)
parser.add_argument('--palette', nargs='+', default=['contrast'])
parser.add_argument('--colors', nargs='+', default={})
parser.add_argument('--maxval', type=float, default=0)
parser.add_argument('--add', nargs='+', type=str, default=[
'gamer_median', 'gamer_mean', 'record_mean',
'clipped_record_mean', 'num_seeds'])
args = parser.parse_args()
if args.subdir:
args.outdir /= args.indir[0].stem
args.indir = [d.expanduser() for d in args.indir]
args.outdir = args.outdir.expanduser()
if args.labels:
assert len(args.labels) % 2 == 0
args.labels = {k: v for k, v in zip(args.labels[::2], args.labels[1::2])}
if args.colors:
assert len(args.colors) % 2 == 0
args.colors = {k: v for k, v in zip(args.colors[::2], args.colors[1::2])}
args.tasks = [re.compile(p) for p in args.tasks]
args.methods = [re.compile(p) for p in args.methods]
args.baselines = [re.compile(p) for p in args.baselines]
if 'return' not in args.yaxis:
args.baselines = []
if args.prefix is None:
args.prefix = len(args.indir) > 1
if len(args.palette) == 1 and args.palette[0] in PALETTES:
args.palette = 10 * PALETTES[args.palette[0]]
if len(args.add) == 1 and args.add[0] == 'none':
args.add = []
return args
if __name__ == '__main__':
main(parse_args())
File diff suppressed because one or more lines are too long
+1
@@ -0,0 +1 @@
{"dmc_acrobot_swingup": {"d4pg_100m": 91.7, "a3c_100m_proprio": 41.9}, "dmc_cartpole_balance": {"d4pg_100m": 992.8, "a3c_100m_proprio": 951.6}, "dmc_cartpole_swingup": {"d4pg_100m": 862.0, "planet_1e6": 821, "a3c_100m_proprio": 558.4}, "dmc_cartpole_balance_sparse": {"d4pg_100m": 1000.0, "a3c_100m_proprio": 857.4}, "dmc_cartpole_swingup_sparse": {"d4pg_100m": 482.0, "a3c_100m_proprio": 179.8}, "dmc_cheetah_run": {"slac_3e6": 880, "d4pg_100m": 523.8, "planet_1e6": 662, "a3c_100m_proprio": 213.9}, "dmc_cup_catch": {"slac_3e6": 970, "d4pg_100m": 980.5, "planet_1e6": 930, "a3c_100m_proprio": 104.7}, "dmc_finger_spin": {"slac_3e6": 950, "d4pg_100m": 985.7, "planet_1e6": 700, "a3c_100m_proprio": 129.4}, "dmc_finger_turn_easy": {"d4pg_100m": 971.4, "a3c_100m_proprio": 167.3}, "dmc_finger_turn_hard": {"d4pg_100m": 966.0, "a3c_100m_proprio": 88.7}, "dmc_hopper_hop": {"d4pg_100m": 242.0, "a3c_100m_proprio": 0.5}, "dmc_hopper_stand": {"d4pg_100m": 929.9, "a3c_100m_proprio": 27.9}, "dmc_reacher_easy": {"d4pg_100m": 967.4, "planet_1e6": 832, "a3c_100m_proprio": 95.6}, "dmc_reacher_hard": {"d4pg_100m": 957.1, "a3c_100m_proprio": 39.7}, "dmc_walker_stand": {"d4pg_100m": 985.2, "a3c_100m_proprio": 378.4}, "dmc_walker_walk": {"slac_3e6": 840, "d4pg_100m": 968.3, "planet_1e6": 951, "a3c_100m_proprio": 311.0}, "dmc_walker_run": {"d4pg_100m": 567.2, "a3c_100m_proprio": 191.8}, "dmc_pendulum_swingup": {"d4pg_100m": 680.9, "a3c_100m_proprio": 48.6}}
+694
View File
@@ -0,0 +1,694 @@
import datetime
import io
import json
import pathlib
import pickle
import re
import time
import uuid
import numpy as np
import tensorflow as tf
import tensorflow.compat.v1 as tf1
import tensorflow_probability as tfp
from tensorflow.keras.mixed_precision import experimental as prec
from tensorflow_probability import distributions as tfd
# Patch to ignore seed to avoid synchronization across GPUs.
_orig_random_categorical = tf.random.categorical
def random_categorical(*args, **kwargs):
kwargs['seed'] = None
return _orig_random_categorical(*args, **kwargs)
tf.random.categorical = random_categorical
# Patch to ignore seed to avoid synchronization across GPUs.
_orig_random_normal = tf.random.normal
def random_normal(*args, **kwargs):
kwargs['seed'] = None
return _orig_random_normal(*args, **kwargs)
tf.random.normal = random_normal
class AttrDict(dict):
__setattr__ = dict.__setitem__
__getattr__ = dict.__getitem__
class Module(tf.Module):
def save(self, filename):
values = tf.nest.map_structure(lambda x: x.numpy(), self.variables)
amount = len(tf.nest.flatten(values))
count = int(sum(np.prod(x.shape) for x in tf.nest.flatten(values)))
print(f'Save checkpoint with {amount} tensors and {count} parameters.')
with pathlib.Path(filename).open('wb') as f:
pickle.dump(values, f)
def load(self, filename):
with pathlib.Path(filename).open('rb') as f:
values = pickle.load(f)
amount = len(tf.nest.flatten(values))
count = int(sum(np.prod(x.shape) for x in tf.nest.flatten(values)))
print(f'Load checkpoint with {amount} tensors and {count} parameters.')
tf.nest.map_structure(lambda x, y: x.assign(y), self.variables, values)
def get(self, name, ctor, *args, **kwargs):
# Create or get layer by name to avoid mentioning it in the constructor.
if not hasattr(self, '_modules'):
self._modules = {}
if name not in self._modules:
self._modules[name] = ctor(*args, **kwargs)
return self._modules[name]
def var_nest_names(nest):
if isinstance(nest, dict):
items = ' '.join(f'{k}:{var_nest_names(v)}' for k, v in nest.items())
return '{' + items + '}'
if isinstance(nest, (list, tuple)):
items = ' '.join(var_nest_names(v) for v in nest)
return '[' + items + ']'
if hasattr(nest, 'name') and hasattr(nest, 'shape'):
return nest.name + str(nest.shape).replace(', ', 'x')
if hasattr(nest, 'shape'):
return str(nest.shape).replace(', ', 'x')
return '?'
class Logger:
def __init__(self, logdir, step):
self._logdir = logdir
self._writer = tf.summary.create_file_writer(str(logdir), max_queue=1000)
self._last_step = None
self._last_time = None
self._scalars = {}
self._images = {}
self._videos = {}
self.step = step
def scalar(self, name, value):
self._scalars[name] = float(value)
def image(self, name, value):
self._images[name] = np.array(value)
def video(self, name, value):
self._videos[name] = np.array(value)
def write(self, fps=False):
scalars = list(self._scalars.items())
if fps:
scalars.append(('fps', self._compute_fps(self.step)))
print(f'[{self.step}]', ' / '.join(f'{k} {v:.1f}' for k, v in scalars))
with (self._logdir / 'metrics.jsonl').open('a') as f:
f.write(json.dumps({'step': self.step, **dict(scalars)}) + '\n')
with self._writer.as_default():
for name, value in scalars:
tf.summary.scalar('scalars/' + name, value, self.step)
for name, value in self._images.items():
tf.summary.image(name, value, self.step)
for name, value in self._videos.items():
video_summary(name, value, self.step)
self._writer.flush()
self._scalars = {}
self._images = {}
self._videos = {}
def _compute_fps(self, step):
if self._last_step is None:
self._last_time = time.time()
self._last_step = step
return 0
steps = step - self._last_step
duration = time.time() - self._last_time
self._last_time += duration
self._last_step = step
return steps / duration
def graph_summary(writer, step, fn, *args):
def inner(*args):
tf.summary.experimental.set_step(step.numpy().item())
with writer.as_default():
fn(*args)
return tf.numpy_function(inner, args, [])
def video_summary(name, video, step=None, fps=20):
name = name if isinstance(name, str) else name.decode('utf-8')
if np.issubdtype(video.dtype, np.floating):
video = np.clip(255 * video, 0, 255).astype(np.uint8)
B, T, H, W, C = video.shape
try:
frames = video.transpose((1, 2, 0, 3, 4)).reshape((T, H, B * W, C))
summary = tf1.Summary()
image = tf1.Summary.Image(height=B * H, width=T * W, colorspace=C)
image.encoded_image_string = encode_gif(frames, fps)
summary.value.add(tag=name, image=image)
tf.summary.experimental.write_raw_pb(summary.SerializeToString(), step)
except (IOError, OSError) as e:
print('GIF summaries require ffmpeg in $PATH.', e)
frames = video.transpose((0, 2, 1, 3, 4)).reshape((1, B * H, T * W, C))
tf.summary.image(name, frames, step)
def encode_gif(frames, fps):
from subprocess import Popen, PIPE
h, w, c = frames[0].shape
pxfmt = {1: 'gray', 3: 'rgb24'}[c]
cmd = ' '.join([
'ffmpeg -y -f rawvideo -vcodec rawvideo',
f'-r {fps:.02f} -s {w}x{h} -pix_fmt {pxfmt} -i - -filter_complex',
'[0:v]split[x][z];[z]palettegen[y];[x]fifo[x];[x][y]paletteuse',
f'-r {fps:.02f} -f gif -'])
proc = Popen(cmd.split(' '), stdin=PIPE, stdout=PIPE, stderr=PIPE)
for image in frames:
proc.stdin.write(image.tobytes())  # tostring() is a deprecated alias.
out, err = proc.communicate()
if proc.returncode:
raise IOError('\n'.join([cmd, err.decode('utf8')]))  # cmd is already a string.
del proc
return out
def simulate(agent, envs, steps=0, episodes=0, state=None):
# Initialize or unpack simulation state.
if state is None:
step, episode = 0, 0
done = np.ones(len(envs), bool)  # np.bool was removed in newer NumPy.
length = np.zeros(len(envs), np.int32)
obs = [None] * len(envs)
agent_state = None
else:
step, episode, done, length, obs, agent_state = state
while (steps and step < steps) or (episodes and episode < episodes):
# Reset envs if necessary.
if done.any():
indices = [index for index, d in enumerate(done) if d]
results = [envs[i].reset() for i in indices]
for index, result in zip(indices, results):
obs[index] = result
# Step agents.
obs = {k: np.stack([o[k] for o in obs]) for k in obs[0]}
action, agent_state = agent(obs, done, agent_state)
if isinstance(action, dict):
action = [
{k: np.array(action[k][i]) for k in action}
for i in range(len(envs))]
else:
action = np.array(action)
assert len(action) == len(envs)
# Step envs.
results = [e.step(a) for e, a in zip(envs, action)]
obs, _, done = zip(*[p[:3] for p in results])
obs = list(obs)
done = np.stack(done)
episode += int(done.sum())
length += 1
step += (done * length).sum()
length *= (1 - done)
# Return new state to allow resuming the simulation.
return (step - steps, episode - episodes, done, length, obs, agent_state)
def save_episodes(directory, episodes):
directory = pathlib.Path(directory).expanduser()
directory.mkdir(parents=True, exist_ok=True)
timestamp = datetime.datetime.now().strftime('%Y%m%dT%H%M%S')
filenames = []
for episode in episodes:
identifier = str(uuid.uuid4().hex)
length = len(episode['reward'])
filename = directory / f'{timestamp}-{identifier}-{length}.npz'
with io.BytesIO() as f1:
np.savez_compressed(f1, **episode)
f1.seek(0)
with filename.open('wb') as f2:
f2.write(f1.read())
filenames.append(filename)
return filenames
def sample_episodes(episodes, length=None, balance=False, seed=0):
random = np.random.RandomState(seed)
while True:
episode = random.choice(list(episodes.values()))
if length:
total = len(next(iter(episode.values())))
available = total - length
if available < 1:
print(f'Skipped short episode of length {total}.')
continue
if balance:
index = min(random.randint(0, total), available)
else:
index = int(random.randint(0, available + 1))
episode = {k: v[index: index + length] for k, v in episode.items()}
yield episode
def load_episodes(directory, limit=None):
directory = pathlib.Path(directory).expanduser()
episodes = {}
total = 0
for filename in reversed(sorted(directory.glob('*.npz'))):
try:
with filename.open('rb') as f:
episode = np.load(f)
episode = {k: episode[k] for k in episode.keys()}
except Exception as e:
print(f'Could not load episode: {e}')
continue
episodes[str(filename)] = episode
total += len(episode['reward']) - 1
if limit and total >= limit:
break
return episodes
class DtypeDist:
def __init__(self, dist, dtype=None):
self._dist = dist
self._dtype = dtype or prec.global_policy().compute_dtype
@property
def name(self):
return 'DtypeDist'
def __getattr__(self, name):
return getattr(self._dist, name)
def mean(self):
return tf.cast(self._dist.mean(), self._dtype)
def mode(self):
return tf.cast(self._dist.mode(), self._dtype)
def entropy(self):
return tf.cast(self._dist.entropy(), self._dtype)
def sample(self, *args, **kwargs):
return tf.cast(self._dist.sample(*args, **kwargs), self._dtype)
class SampleDist:
def __init__(self, dist, samples=100):
self._dist = dist
self._samples = samples
@property
def name(self):
return 'SampleDist'
def __getattr__(self, name):
return getattr(self._dist, name)
def mean(self):
samples = self._dist.sample(self._samples)
return tf.reduce_mean(samples, 0)
def mode(self):
sample = self._dist.sample(self._samples)
logprob = self._dist.log_prob(sample)
return tf.gather(sample, tf.argmax(logprob))[0]
def entropy(self):
sample = self._dist.sample(self._samples)
logprob = self.log_prob(sample)
return -tf.reduce_mean(logprob, 0)
class OneHotDist(tfd.OneHotCategorical):
def __init__(self, logits=None, probs=None, dtype=None):
self._sample_dtype = dtype or prec.global_policy().compute_dtype
super().__init__(logits=logits, probs=probs)
def mode(self):
return tf.cast(super().mode(), self._sample_dtype)
def sample(self, sample_shape=(), seed=None):
# Straight through biased gradient estimator.
sample = tf.cast(super().sample(sample_shape, seed), self._sample_dtype)
probs = super().probs_parameter()
while len(probs.shape) < len(sample.shape):
probs = probs[None]
sample += tf.cast(probs - tf.stop_gradient(probs), self._sample_dtype)
return sample
class GumbleDist(tfd.RelaxedOneHotCategorical):
def __init__(self, temp, logits=None, probs=None, dtype=None):
self._sample_dtype = dtype or prec.global_policy().compute_dtype
self._exact = tfd.OneHotCategorical(logits=logits, probs=probs)
super().__init__(temp, logits=logits, probs=probs)
def mode(self):
return tf.cast(self._exact.mode(), self._sample_dtype)
def entropy(self):
return tf.cast(self._exact.entropy(), self._sample_dtype)
def sample(self, sample_shape=(), seed=None):
return tf.cast(super().sample(sample_shape, seed), self._sample_dtype)
class UnnormalizedHuber(tfd.Normal):
def __init__(self, loc, scale, threshold=1, **kwargs):
self._threshold = tf.cast(threshold, loc.dtype)
super().__init__(loc, scale, **kwargs)
def log_prob(self, event):
return -(tf.math.sqrt(
(event - self.mean()) ** 2 + self._threshold ** 2) - self._threshold)
class SafeTruncatedNormal(tfd.TruncatedNormal):
def __init__(self, loc, scale, low, high, clip=1e-6, mult=1):
super().__init__(loc, scale, low, high)
self._clip = clip
self._mult = mult
def sample(self, *args, **kwargs):
event = super().sample(*args, **kwargs)
if self._clip:
clipped = tf.clip_by_value(
event, self.low + self._clip, self.high - self._clip)
event = event - tf.stop_gradient(event) + tf.stop_gradient(clipped)
if self._mult:
event *= self._mult
return event
class TanhBijector(tfp.bijectors.Bijector):
def __init__(self, validate_args=False, name='tanh'):
super().__init__(
forward_min_event_ndims=0,
validate_args=validate_args,
name=name)
def _forward(self, x):
return tf.nn.tanh(x)
def _inverse(self, y):
dtype = y.dtype
y = tf.cast(y, tf.float32)
y = tf.where(
tf.less_equal(tf.abs(y), 1.),
tf.clip_by_value(y, -0.99999997, 0.99999997), y)
y = tf.atanh(y)
y = tf.cast(y, dtype)
return y
def _forward_log_det_jacobian(self, x):
log2 = tf.math.log(tf.constant(2.0, dtype=x.dtype))
return 2.0 * (log2 - x - tf.nn.softplus(-2.0 * x))
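The closed form in `_forward_log_det_jacobian` is a numerically stable rewrite of log(1 - tanh(x)^2). The following is a minimal pure-Python sketch (no TensorFlow) that checks the identity on scalars; the helper names `softplus`, `tanh_log_det_jacobian`, and `naive_log_det_jacobian` are illustrative and not part of this repository.

```python
import math


def softplus(x):
  # Numerically stable log(1 + exp(x)).
  return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def tanh_log_det_jacobian(x):
  # Same closed form as TanhBijector._forward_log_det_jacobian.
  return 2.0 * (math.log(2.0) - x - softplus(-2.0 * x))


def naive_log_det_jacobian(x):
  # Direct definition: log |d tanh(x) / dx| = log(1 - tanh(x)^2).
  return math.log(1.0 - math.tanh(x) ** 2)


for x in [-3.0, -0.5, 0.0, 1.0, 2.5]:
  assert abs(tanh_log_det_jacobian(x) - naive_log_det_jacobian(x)) < 1e-9
```

The stable form matters for large |x|, where 1 - tanh(x)^2 underflows to zero and the naive logarithm fails.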
def lambda_return(
reward, value, pcont, bootstrap, lambda_, axis):
# Setting lambda=1 gives a discounted Monte Carlo return.
# Setting lambda=0 gives a fixed 1-step return.
assert reward.shape.ndims == value.shape.ndims, (reward.shape, value.shape)
if isinstance(pcont, (int, float)):
pcont = pcont * tf.ones_like(reward)
dims = list(range(reward.shape.ndims))
dims = [axis] + dims[1:axis] + [0] + dims[axis + 1:]
if axis != 0:
reward = tf.transpose(reward, dims)
value = tf.transpose(value, dims)
pcont = tf.transpose(pcont, dims)
if bootstrap is None:
bootstrap = tf.zeros_like(value[-1])
next_values = tf.concat([value[1:], bootstrap[None]], 0)
inputs = reward + pcont * next_values * (1 - lambda_)
returns = static_scan(
lambda agg, cur: cur[0] + cur[1] * lambda_ * agg,
(inputs, pcont), bootstrap, reverse=True)
if axis != 0:
returns = tf.transpose(returns, dims)
return returns
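To make the recursion in `lambda_return` concrete, here is a minimal sketch on plain Python lists for a single trajectory. It assumes the time axis is the only axis; `lambda_return_1d` is a hypothetical helper, not a function from this repository.

```python
def lambda_return_1d(reward, value, pcont, bootstrap, lambda_):
  # R_t = r_t + c_t * ((1 - lambda) * v_{t+1} + lambda * R_{t+1}),
  # where both v_T and R_T are set to `bootstrap`.
  next_values = list(value[1:]) + [bootstrap]
  returns = []
  agg = bootstrap
  for r, c, nv in zip(
      reversed(reward), reversed(pcont), reversed(next_values)):
    agg = r + c * ((1 - lambda_) * nv + lambda_ * agg)
    returns.append(agg)
  return returns[::-1]


reward = [1.0, 1.0, 1.0]
value = [10.0, 20.0, 30.0]
pcont = [0.5, 0.5, 0.5]

# lambda=0 reduces to the 1-step return r_t + c_t * v_{t+1}.
print(lambda_return_1d(reward, value, pcont, 40.0, 0.0))  # [11.0, 16.0, 21.0]
# lambda=1 reduces to the discounted Monte Carlo return with bootstrap.
print(lambda_return_1d(reward, value, pcont, 40.0, 1.0))  # [6.75, 11.5, 21.0]
```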
class Optimizer(tf.Module):
def __init__(
self, name, lr, eps=1e-4, clip=None, wd=None, wd_pattern=r'.*',
opt='adam'):
assert wd is None or 0 <= wd < 1  # The default wd=None disables weight decay.
assert not clip or 1 <= clip
self._name = name
self._clip = clip
self._wd = wd
self._wd_pattern = wd_pattern
self._opt = {
'adam': lambda: tf.optimizers.Adam(lr, epsilon=eps),
'nadam': lambda: tf.optimizers.Nadam(lr, epsilon=eps),
'adamax': lambda: tf.optimizers.Adamax(lr, epsilon=eps),
'sgd': lambda: tf.optimizers.SGD(lr),
'momentum': lambda: tf.optimizers.SGD(lr, 0.9),
}[opt]()
self._mixed = (prec.global_policy().compute_dtype == tf.float16)
if self._mixed:
self._opt = prec.LossScaleOptimizer(self._opt, 'dynamic')
@property
def variables(self):
return self._opt.variables()
def __call__(self, tape, loss, modules):
assert loss.dtype is tf.float32, self._name
modules = modules if hasattr(modules, '__len__') else (modules,)
varibs = tf.nest.flatten([module.variables for module in modules])
count = sum(np.prod(x.shape) for x in varibs)
print(f'Found {count} {self._name} parameters.')
assert len(loss.shape) == 0, loss.shape
tf.debugging.check_numerics(loss, self._name + '_loss')
if self._mixed:
with tape:
loss = self._opt.get_scaled_loss(loss)
grads = tape.gradient(loss, varibs)
if self._mixed:
grads = self._opt.get_unscaled_gradients(grads)
norm = tf.linalg.global_norm(grads)
if not self._mixed:
tf.debugging.check_numerics(norm, self._name + '_norm')
if self._clip:
grads, _ = tf.clip_by_global_norm(grads, self._clip, norm)
if self._wd:
self._apply_weight_decay(varibs)
self._opt.apply_gradients(zip(grads, varibs))
metrics = {}
metrics[f'{self._name}_loss'] = loss
metrics[f'{self._name}_grad_norm'] = norm
if self._mixed:
try:
metrics[f'{self._name}_loss_scale'] = float(self._opt.loss_scale)
except TypeError:
metrics[f'{self._name}_loss_scale'] = float(
self._opt.loss_scale._current_loss_scale)
return metrics
def _apply_weight_decay(self, varibs):
nontrivial = (self._wd_pattern != r'.*')
if nontrivial:
print('Applied weight decay to variables:')
for var in varibs:
if re.search(self._wd_pattern, self._name + '/' + var.name):
if nontrivial:
print('- ' + self._name + '/' + var.name)
var.assign((1 - self._wd) * var)
def args_type(default):
def parse_string(x):
if default is None:
return x
if isinstance(default, bool):
return bool(['False', 'True'].index(x))
if isinstance(default, int):
return float(x) if ('e' in x or '.' in x) else int(x)
if isinstance(default, (list, tuple)):
return tuple(args_type(default[0])(y) for y in x.split(','))
return type(default)(x)
def parse_object(x):
if isinstance(default, (list, tuple)):
return tuple(x)
return x
return lambda x: parse_string(x) if isinstance(x, str) else parse_object(x)
def static_scan(fn, inputs, start, reverse=False):
last = start
outputs = [[] for _ in tf.nest.flatten(start)]
indices = range(len(tf.nest.flatten(inputs)[0]))
if reverse:
indices = reversed(indices)
for index in indices:
inp = tf.nest.map_structure(lambda x: x[index], inputs)
last = fn(last, inp)
[o.append(l) for o, l in zip(outputs, tf.nest.flatten(last))]
if reverse:
outputs = [list(reversed(x)) for x in outputs]
outputs = [tf.stack(x, 0) for x in outputs]
return tf.nest.pack_sequence_as(start, outputs)
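The behavior of `static_scan` may be easier to see without nested tensors. The following is a simplified sketch that assumes a flat Python list as input and a single carry value; the name `scan` is illustrative only.

```python
def scan(fn, inputs, start, reverse=False):
  # Fold `fn` over `inputs` and record every intermediate carry.
  indices = range(len(inputs))
  if reverse:
    indices = reversed(indices)
  last, outputs = start, []
  for index in indices:
    last = fn(last, inputs[index])
    outputs.append(last)
  if reverse:
    # Restore the original time order of the outputs.
    outputs.reverse()
  return outputs


print(scan(lambda agg, x: agg + x, [1, 2, 3], 0))                # [1, 3, 6]
print(scan(lambda agg, x: agg + x, [1, 2, 3], 0, reverse=True))  # [6, 5, 3]
```

With `reverse=True` the fold runs from the last element to the first, which is how `lambda_return` accumulates returns backward in time.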
def uniform_mixture(dist, dtype=None):
if dist.batch_shape[-1] == 1:
return tfd.BatchReshape(dist, dist.batch_shape[:-1])
dtype = dtype or prec.global_policy().compute_dtype
weights = tfd.Categorical(tf.zeros(dist.batch_shape, dtype))
return tfd.MixtureSameFamily(weights, dist)
def cat_mixture_entropy(dist):
if isinstance(dist, tfd.MixtureSameFamily):
probs = dist.components_distribution.probs_parameter()
else:
probs = dist.probs_parameter()
return -tf.reduce_mean(
tf.reduce_mean(probs, 2) *
tf.math.log(tf.reduce_mean(probs, 2) + 1e-8), -1)
@tf.function
def cem_planner(
state, num_actions, horizon, proposals, topk, iterations, imagine,
objective):
dtype = prec.global_policy().compute_dtype
B, P = list(state.values())[0].shape[0], proposals
H, A = horizon, num_actions
flat_state = {k: tf.repeat(v, P, 0) for k, v in state.items()}
mean = tf.zeros((B, H, A), dtype)
std = tf.ones((B, H, A), dtype)
for _ in range(iterations):
proposals = tf.random.normal((B, P, H, A), dtype=dtype)
proposals = proposals * std[:, None] + mean[:, None]
proposals = tf.clip_by_value(proposals, -1, 1)
flat_proposals = tf.reshape(proposals, (B * P, H, A))
states = imagine(flat_proposals, flat_state)
scores = objective(states)
scores = tf.reshape(tf.reduce_sum(scores, -1), (B, P))
_, indices = tf.math.top_k(scores, topk, sorted=False)
best = tf.gather(proposals, indices, axis=1, batch_dims=1)
mean, var = tf.nn.moments(best, 1)
std = tf.sqrt(var + 1e-6)
return mean[:, 0, :]
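The loop in `cem_planner` follows the cross-entropy method: sample proposals, score them, keep the top-k, and refit the mean and standard deviation. The following is a toy sketch on a single scalar action with no world model; `cem_1d` and the quadratic objective are assumptions made for illustration.

```python
import random


def cem_1d(objective, proposals=100, topk=10, iterations=10, seed=0):
  rng = random.Random(seed)
  mean, std = 0.0, 1.0
  for _ in range(iterations):
    # Sample candidates from the current Gaussian and clip to [-1, 1].
    candidates = [
        max(-1.0, min(1.0, rng.gauss(mean, std)))
        for _ in range(proposals)]
    # Keep the top-k candidates by score and refit mean and std to them.
    best = sorted(candidates, key=objective, reverse=True)[:topk]
    mean = sum(best) / topk
    var = sum((b - mean) ** 2 for b in best) / topk
    std = (var + 1e-6) ** 0.5
  return mean


# The objective peaks at 0.5, so the returned mean should end up close to it.
result = cem_1d(lambda a: -(a - 0.5) ** 2)
assert abs(result - 0.5) < 0.05
```

The `1e-6` added to the variance mirrors the function above and keeps the standard deviation from collapsing to exactly zero.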
@tf.function
def grad_planner(
state, num_actions, horizon, proposals, iterations, imagine, objective,
kl_scale, step_size):
dtype = prec.global_policy().compute_dtype
B, P = list(state.values())[0].shape[0], proposals
H, A = horizon, num_actions
flat_state = {k: tf.repeat(v, P, 0) for k, v in state.items()}
mean = tf.zeros((B, H, A), dtype)
rawstd = 0.54 * tf.ones((B, H, A), dtype)
for _ in range(iterations):
proposals = tf.random.normal((B, P, H, A), dtype=dtype)
with tf.GradientTape(watch_accessed_variables=False) as tape:
tape.watch(mean)
tape.watch(rawstd)
std = tf.nn.softplus(rawstd)
proposals = proposals * std[:, None] + mean[:, None]
proposals = (
tf.stop_gradient(tf.clip_by_value(proposals, -1, 1)) +
proposals - tf.stop_gradient(proposals))
flat_proposals = tf.reshape(proposals, (B * P, H, A))
states = imagine(flat_proposals, flat_state)
scores = objective(states)
scores = tf.reshape(tf.reduce_sum(scores, -1), (B, P))
div = tfd.kl_divergence(
tfd.Normal(mean, std),
tfd.Normal(tf.zeros_like(mean), tf.ones_like(std)))
elbo = tf.reduce_sum(scores) - kl_scale * div
elbo /= tf.cast(tf.reduce_prod(tf.shape(scores)), dtype)
grad_mean, grad_rawstd = tape.gradient(elbo, [mean, rawstd])
e, v = tf.nn.moments(grad_mean, [1, 2], keepdims=True)
grad_mean /= tf.sqrt(e * e + v + 1e-4)
e, v = tf.nn.moments(grad_rawstd, [1, 2], keepdims=True)
grad_rawstd /= tf.sqrt(e * e + v + 1e-4)
mean = tf.clip_by_value(mean + step_size * grad_mean, -1, 1)
rawstd = rawstd + step_size * grad_rawstd
return mean[:, 0, :]
class Every:
def __init__(self, every):
self._every = every
self._last = None
def __call__(self, step):
if not self._every:
return False
if self._last is None:
self._last = step
return True
if step >= self._last + self._every:
self._last += self._every
return True
return False
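One detail of `Every` is that `_last` advances by `every` rather than jumping to the current step, so a late call can trigger again sooner than `every` steps later. The sketch below copies the class so that the snippet runs standalone; the step values are made up for illustration.

```python
class Every:

  def __init__(self, every):
    self._every = every
    self._last = None

  def __call__(self, step):
    if not self._every:
      return False
    if self._last is None:
      # The first call always triggers and anchors the schedule.
      self._last = step
      return True
    if step >= self._last + self._every:
      # Advance by one interval, not to `step`.
      self._last += self._every
      return True
    return False


should_log = Every(10)
steps = [0, 5, 10, 25, 26, 27, 40]
print([should_log(s) for s in steps])
# [True, False, True, True, False, False, True]
```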
class Once:
def __init__(self):
self._once = True
def __call__(self):
if self._once:
self._once = False
return True
return False
class Until:
def __init__(self, until):
self._until = until
def __call__(self, step):
if not self._until:
return True
return step < self._until
def schedule(string, step):
try:
return float(string)
except ValueError:
step = tf.cast(step, tf.float32)
match = re.match(r'linear\((.+),(.+),(.+)\)', string)
if match:
initial, final, duration = [float(group) for group in match.groups()]
mix = tf.clip_by_value(step / duration, 0, 1)
return (1 - mix) * initial + mix * final
match = re.match(r'warmup\((.+),(.+)\)', string)
if match:
warmup, value = [float(group) for group in match.groups()]
scale = tf.clip_by_value(step / warmup, 0, 1)
return scale * value
match = re.match(r'exp\((.+),(.+),(.+)\)', string)
if match:
initial, final, halflife = [float(group) for group in match.groups()]
return (initial - final) * 0.5 ** (step / halflife) + final
match = re.match(r'horizon\((.+),(.+),(.+)\)', string)
if match:
initial, final, duration = [float(group) for group in match.groups()]
mix = tf.clip_by_value(step / duration, 0, 1)
horizon = (1 - mix) * initial + mix * final
return 1 - 1 / horizon
raise NotImplementedError(string)
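The schedule strings accepted above are either a plain number or a pattern such as `linear(initial,final,duration)`. The following is a reduced pure-Python sketch of the constant and linear cases only; `schedule_py` is a hypothetical name, and the TensorFlow casts are replaced by ordinary floats.

```python
import re


def schedule_py(string, step):
  # A plain number means a constant schedule.
  try:
    return float(string)
  except ValueError:
    pass
  match = re.match(r'linear\((.+),(.+),(.+)\)', string)
  if match:
    initial, final, duration = [float(g) for g in match.groups()]
    # Interpolate from `initial` to `final` over `duration` steps.
    mix = min(max(step / duration, 0.0), 1.0)
    return (1 - mix) * initial + mix * final
  raise NotImplementedError(string)


print(schedule_py('0.5', 10))                    # 0.5
print(schedule_py('linear(1.0,0.0,100)', 25))    # 0.75
print(schedule_py('linear(1.0,0.0,100)', 500))   # 0.0
```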
+306
View File
@@ -0,0 +1,306 @@
import threading
import gym
import numpy as np
class DeepMindControl:
def __init__(self, name, action_repeat=1, size=(64, 64), camera=None):
domain, task = name.split('_', 1)
if domain == 'cup': # Only domain with multiple words.
domain = 'ball_in_cup'
if isinstance(domain, str):
from dm_control import suite
self._env = suite.load(domain, task)
else:
assert task is None
self._env = domain()
self._action_repeat = action_repeat
self._size = size
if camera is None:
camera = dict(quadruped=2).get(domain, 0)
self._camera = camera
@property
def observation_space(self):
spaces = {}
for key, value in self._env.observation_spec().items():
spaces[key] = gym.spaces.Box(
-np.inf, np.inf, value.shape, dtype=np.float32)
spaces['image'] = gym.spaces.Box(
0, 255, self._size + (3,), dtype=np.uint8)
return gym.spaces.Dict(spaces)
@property
def action_space(self):
spec = self._env.action_spec()
return gym.spaces.Box(spec.minimum, spec.maximum, dtype=np.float32)
def step(self, action):
assert np.isfinite(action).all(), action
reward = 0
for _ in range(self._action_repeat):
time_step = self._env.step(action)
reward += time_step.reward or 0
if time_step.last():
break
obs = dict(time_step.observation)
obs['image'] = self.render()
done = time_step.last()
info = {'discount': np.array(time_step.discount, np.float32)}
return obs, reward, done, info
def reset(self):
time_step = self._env.reset()
obs = dict(time_step.observation)
obs['image'] = self.render()
return obs
def render(self, *args, **kwargs):
if kwargs.get('mode', 'rgb_array') != 'rgb_array':
raise ValueError("Only render mode 'rgb_array' is supported.")
return self._env.physics.render(*self._size, camera_id=self._camera)
class Atari:
LOCK = threading.Lock()
def __init__(
self, name, action_repeat=4, size=(84, 84), grayscale=True, noops=30,
life_done=False, sticky_actions=True, all_actions=False):
assert size[0] == size[1]
import gym.wrappers
import gym.envs.atari
if name == 'james_bond':
name = 'jamesbond'
with self.LOCK:
env = gym.envs.atari.AtariEnv(
game=name, obs_type='image', frameskip=1,
repeat_action_probability=0.25 if sticky_actions else 0.0,
full_action_space=all_actions)
# Avoid unnecessary rendering in inner env.
env._get_obs = lambda: None
# Tell wrapper that the inner env has no action repeat.
env.spec = gym.envs.registration.EnvSpec('NoFrameskip-v0')
env = gym.wrappers.AtariPreprocessing(
env, noops, action_repeat, size[0], life_done, grayscale)
self._env = env
self._grayscale = grayscale
@property
def observation_space(self):
return gym.spaces.Dict({
'image': self._env.observation_space,
'ram': gym.spaces.Box(0, 255, (128,), np.uint8),
})
@property
def action_space(self):
return self._env.action_space
def close(self):
return self._env.close()
def reset(self):
with self.LOCK:
image = self._env.reset()
if self._grayscale:
image = image[..., None]
obs = {'image': image, 'ram': self._env.env._get_ram()}
return obs
def step(self, action):
image, reward, done, info = self._env.step(action)
if self._grayscale:
image = image[..., None]
obs = {'image': image, 'ram': self._env.env._get_ram()}
return obs, reward, done, info
def render(self, mode):
return self._env.render(mode)
class CollectDataset:
def __init__(self, env, callbacks=None, precision=32):
self._env = env
self._callbacks = callbacks or ()
self._precision = precision
self._episode = None
def __getattr__(self, name):
return getattr(self._env, name)
def step(self, action):
obs, reward, done, info = self._env.step(action)
obs = {k: self._convert(v) for k, v in obs.items()}
transition = obs.copy()
if isinstance(action, dict):
transition.update(action)
else:
transition['action'] = action
transition['reward'] = reward
transition['discount'] = info.get('discount', np.array(1 - float(done)))
self._episode.append(transition)
if done:
for key, value in self._episode[1].items():
if key not in self._episode[0]:
self._episode[0][key] = 0 * value
episode = {k: [t[k] for t in self._episode] for k in self._episode[0]}
episode = {k: self._convert(v) for k, v in episode.items()}
info['episode'] = episode
for callback in self._callbacks:
callback(episode)
return obs, reward, done, info
def reset(self):
obs = self._env.reset()
transition = obs.copy()
# Missing keys will be filled with a zeroed out version of the first
# transition, because we do not know what action information the agent will
# pass yet.
transition['reward'] = 0.0
transition['discount'] = 1.0
self._episode = [transition]
return obs
def _convert(self, value):
value = np.array(value)
if np.issubdtype(value.dtype, np.floating):
dtype = {16: np.float16, 32: np.float32, 64: np.float64}[self._precision]
elif np.issubdtype(value.dtype, np.signedinteger):
dtype = {16: np.int16, 32: np.int32, 64: np.int64}[self._precision]
elif np.issubdtype(value.dtype, np.uint8):
dtype = np.uint8
else:
raise NotImplementedError(value.dtype)
return value.astype(dtype)
class TimeLimit:
def __init__(self, env, duration):
self._env = env
self._duration = duration
self._step = None
def __getattr__(self, name):
return getattr(self._env, name)
def step(self, action):
assert self._step is not None, 'Must reset environment.'
obs, reward, done, info = self._env.step(action)
self._step += 1
if self._step >= self._duration:
done = True
if 'discount' not in info:
info['discount'] = np.array(1.0).astype(np.float32)
self._step = None
return obs, reward, done, info
def reset(self):
self._step = 0
return self._env.reset()
class NormalizeActions:
def __init__(self, env):
self._env = env
self._mask = np.logical_and(
np.isfinite(env.action_space.low),
np.isfinite(env.action_space.high))
self._low = np.where(self._mask, env.action_space.low, -1)
self._high = np.where(self._mask, env.action_space.high, 1)
def __getattr__(self, name):
return getattr(self._env, name)
@property
def action_space(self):
low = np.where(self._mask, -np.ones_like(self._low), self._low)
high = np.where(self._mask, np.ones_like(self._low), self._high)
return gym.spaces.Box(low, high, dtype=np.float32)
def step(self, action):
original = (action + 1) / 2 * (self._high - self._low) + self._low
original = np.where(self._mask, original, action)
return self._env.step(original)
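`NormalizeActions.step` maps an action in [-1, 1] back to the environment's own bounds. The following is a minimal sketch of that affine map on plain lists, assuming all bounds are finite; `denormalize` is an illustrative helper, not part of this repository.

```python
def denormalize(action, low, high):
  # Map each component from [-1, 1] to [low, high].
  return [
      (a + 1) / 2 * (h - l) + l
      for a, l, h in zip(action, low, high)]


print(denormalize([-1, 0, 1], [0, 0, 0], [10, 10, 10]))  # [0.0, 5.0, 10.0]
```

Components with infinite bounds are passed through unchanged by the wrapper, which is what the `self._mask` lookup handles.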
class OneHotAction:
def __init__(self, env):
assert isinstance(env.action_space, gym.spaces.Discrete)
self._env = env
self._random = np.random.RandomState()
def __getattr__(self, name):
return getattr(self._env, name)
@property
def action_space(self):
shape = (self._env.action_space.n,)
space = gym.spaces.Box(low=0, high=1, shape=shape, dtype=np.float32)
space.sample = self._sample_action
space.discrete = True
return space
def step(self, action):
index = np.argmax(action).astype(int)
reference = np.zeros_like(action)
reference[index] = 1
if not np.allclose(reference, action):
raise ValueError(f'Invalid one-hot action:\n{action}')
return self._env.step(index)
def reset(self):
return self._env.reset()
def _sample_action(self):
actions = self._env.action_space.n
index = self._random.randint(0, actions)
reference = np.zeros(actions, dtype=np.float32)
reference[index] = 1.0
return reference
class RewardObs:
def __init__(self, env):
self._env = env
def __getattr__(self, name):
return getattr(self._env, name)
@property
def observation_space(self):
spaces = self._env.observation_space.spaces
assert 'reward' not in spaces
spaces['reward'] = gym.spaces.Box(-np.inf, np.inf, (), dtype=np.float32)
return gym.spaces.Dict(spaces)
def step(self, action):
obs, reward, done, info = self._env.step(action)
obs['reward'] = reward
return obs, reward, done, info
def reset(self):
obs = self._env.reset()
obs['reward'] = 0.0
return obs
class SelectAction:
def __init__(self, env, key):
self._env = env
self._key = key
def __getattr__(self, name):
return getattr(self._env, name)
def step(self, action):
return self._env.step(action[self._key])