mirror of
https://github.com/wassname/dreamerv2.git
synced 2026-06-27 16:45:34 +08:00
Initial commit.
@@ -0,0 +1,5 @@
__pycache__/
*.py[cod]
*.egg-info
./dist
MUJOCO_LOG.TXT
@@ -0,0 +1,19 @@
Copyright (c) 2020 Danijar Hafner

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
@@ -0,0 +1,106 @@
# Mastering Atari with Discrete World Models

Implementation of the [DreamerV2][website] agent in TensorFlow 2. Training
curves for all 55 games are included.

<p align="center">
<img width="90%" src="https://imgur.com/gO1rvEn.gif">
</p>

If you find this code useful, please cite it in your paper:

```
@article{hafner2020dreamerv2,
  title={Mastering Atari with Discrete World Models},
  author={Hafner, Danijar and Lillicrap, Timothy and Norouzi, Mohammad and Ba, Jimmy},
  journal={arXiv preprint arXiv:2010.02193},
  year={2020}
}
```

[website]: https://danijar.com/dreamerv2

## Method

DreamerV2 is the first world model agent that achieves human-level performance
on the Atari benchmark. DreamerV2 also exceeds the final performance of the
top model-free agents Rainbow and IQN using the same amount of experience and
computation. The implementation in this repository alternates between training
the world model, training the policy, and collecting experience, and it runs on
a single GPU.

![World Model Learning](https://imgur.com/GRC9QAw.png)

DreamerV2 learns a model of the environment directly from high-dimensional
input images. For this, it predicts ahead using compact learned states. The
states consist of a deterministic part and several categorical variables that
are sampled. The prior for these categoricals is learned through a KL loss. The
world model is learned end-to-end via straight-through gradients, meaning that
the gradient of the density is set to the gradient of the sample.
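
As a rough illustration of the straight-through idea described above, the
following pure-Python sketch samples a one-hot categorical in the forward pass
and, in the backward pass, reuses the gradient that arrives at the sample as
the gradient for the class probabilities before chaining through the softmax.
The function names are hypothetical and this is not the code used in this
repository, which relies on TensorFlow's automatic differentiation.

```python
import math
import random


def softmax(logits):
    """Convert unnormalized logits into class probabilities."""
    top = max(logits)
    exps = [math.exp(x - top) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sample_onehot(probs, rng):
    """Forward pass: draw one class and return it as a one-hot list."""
    index = rng.choices(range(len(probs)), weights=probs, k=1)[0]
    return [1.0 if i == index else 0.0 for i in range(len(probs))]


def straight_through_grad(probs, grad_sample):
    """Backward pass: gradient with respect to the logits.

    The straight-through step treats d(sample)/d(probs) as the identity, so
    the gradient arriving at the sample is reused for the probabilities.
    It is then chained through the softmax Jacobian as usual.
    """
    grad_probs = list(grad_sample)
    dot = sum(g * p for g, p in zip(grad_probs, probs))
    return [p * (g - dot) for p, g in zip(probs, grad_probs)]


rng = random.Random(0)
probs = softmax([0.0, 0.0, 0.0])
sample = sample_onehot(probs, rng)
grad_logits = straight_through_grad(probs, [1.0, 0.0, 0.0])
print(sample, grad_logits)
```

The sampled value stays discrete, while the logits still receive a usable
gradient signal, which is what allows end-to-end training in this sketch.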

![Actor Critic Learning](https://imgur.com/wH9kJ2O.png)

DreamerV2 learns actor and critic networks from imagined trajectories of latent
states. The trajectories start at encoded states of previously encountered
sequences. The world model then predicts ahead using the selected actions and
its learned state prior. The critic is trained using temporal difference
learning and the actor is trained to maximize the value function via REINFORCE
and straight-through gradients.
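
One common way to build temporal difference targets over an imagined
trajectory is the λ-return. The sketch below is an assumption about the general
technique, not a copy of this repository's critic code; the function name and
argument layout are hypothetical, and the default values only mirror the
`discount` and `discount_lambda` entries in the default config.

```python
def lambda_return(rewards, values, bootstrap, discount=0.99, lam=0.95):
    """Compute lambda-returns for one trajectory.

    rewards[t] and values[t] belong to step t, and bootstrap is the value
    estimate after the last step. The recursion is
    R_t = r_t + discount * ((1 - lam) * v_{t+1} + lam * R_{t+1}).
    """
    next_values = list(values[1:]) + [bootstrap]
    returns = []
    next_return = bootstrap
    for reward, next_value in zip(reversed(rewards), reversed(next_values)):
        next_return = reward + discount * (
            (1 - lam) * next_value + lam * next_return)
        returns.append(next_return)
    return list(reversed(returns))


# With lam=1 the targets reduce to discounted Monte Carlo returns.
monte_carlo = lambda_return([1.0, 1.0, 1.0], [0.0, 0.0, 0.0], 0.0, 0.5, 1.0)
# With lam=0 the targets reduce to one-step TD targets.
one_step = lambda_return([1.0, 1.0, 1.0], [0.0, 0.0, 0.0], 0.0, 0.5, 0.0)
print(monte_carlo, one_step)
```

Intermediate values of `lam` interpolate between the two extremes, trading
bias from the value estimates against variance from long rollouts.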

For more information:

- [Google AI Blog post](https://ai.googleblog.com/2021/02/mastering-atari-with-discrete-world.html)
- [Project website](https://danijar.com/dreamerv2/)
- [Research paper](https://arxiv.org/pdf/2010.02193.pdf)

## Instructions

Get dependencies:

```sh
pip3 install --user tensorflow==2.3.1
pip3 install --user tensorflow_probability==0.11.1
pip3 install --user pandas
pip3 install --user matplotlib
pip3 install --user ruamel.yaml
pip3 install --user 'gym[atari]'
```

Train the agent:

```sh
python3 dreamer.py --logdir ~/logdir/atari_pong/dreamerv2/1 \
  --configs defaults atari --task atari_pong
```
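
The names passed to `--configs` are applied in order, so later sections
override earlier ones; the training script does this with a plain
`dict.update` over the sections of `configs.yaml`. The sketch below mimics that
layering on a small hand-copied subset of values. The helper name
`merge_configs` is hypothetical and exists only for this illustration.

```python
def merge_configs(configs, names):
    """Layer named config sections; later names override earlier ones."""
    merged = {}
    for name in names:
        merged.update(configs[name])
    return merged


# A tiny subset of values copied from configs.yaml for illustration.
configs = {
    'defaults': {'precision': 32, 'batch_size': 50, 'steps': 1e7},
    'atari': {'precision': 16, 'steps': 2e8},
    'debug': {'batch_size': 10, 'batch_length': 20},
}
merged = merge_configs(configs, ['defaults', 'atari', 'debug'])
print(merged)
```

Individual flags such as `--task atari_pong` are parsed afterwards and take
precedence over the merged values.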

Monitor results:

```sh
tensorboard --logdir ~/logdir
```

Generate plots:

```sh
python3 plotting.py --indir ~/logdir --outdir ~/plots --xaxis step --yaxis eval_return --bins 1e6
```

Tips:

- **Efficient debugging.** You can use the `debug` config as in `--configs
  defaults atari debug`. This reduces the batch size, increases the evaluation
  frequency, and disables `tf.function` graph compilation for easy line-by-line
  debugging.

- **Infinite gradient norms.** This is normal and described under loss scaling in
  the [mixed precision][mixed] guide. You can disable mixed precision by passing
  `--precision 32` to the training script. Mixed precision is faster but can in
  principle cause numerical instabilities.

- **Accessing logged metrics.** The metrics are stored in both TensorBoard and
  JSON lines format. You can load them directly using `pandas.read_json()` with
  `lines=True`. The plotting script also stores the binned and aggregated
  metrics of multiple runs in a single JSON lines file for easy manual
  plotting.

[mixed]: https://www.tensorflow.org/guide/mixed_precision
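
As a small sketch of working with the JSON lines metrics mentioned in the tips,
the snippet below parses one JSON object per line using only the standard
library and extracts a series for plotting. The record contents and helper
names are made up for illustration; the keys `step` and `eval_return` are taken
from the plotting command above. With pandas installed,
`pandas.read_json(path, lines=True)` reads the same format into a DataFrame.

```python
import io
import json


def load_jsonl(stream):
    """Parse one JSON object per non-empty line."""
    return [json.loads(line) for line in stream if line.strip()]


def select_series(records, xaxis, yaxis):
    """Keep (x, y) pairs from records that contain both keys."""
    return [(r[xaxis], r[yaxis]) for r in records if xaxis in r and yaxis in r]


# Hypothetical log contents; a real run writes its metrics under the logdir.
example = io.StringIO(
    '{"step": 10000, "train_return": -20.0}\n'
    '{"step": 100000, "eval_return": -19.0}\n'
    '{"step": 200000, "eval_return": -15.5}\n'
)
series = select_series(load_jsonl(example), 'step', 'eval_return')
print(series)
```

Records that lack the requested metric are skipped, since not every logged
line necessarily contains every key.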
@@ -0,0 +1,228 @@
defaults:

  logdir: null
  traindir: null
  evaldir: null
  offline_traindir: ''
  offline_evaldir: ''
  seed: 0
  steps: 1e7
  eval_every: 1e4
  log_every: 1e4
  reset_every: 0
  gpu_growth: True
  precision: 32
  debug: False
  expl_gifs: False

  # Environment
  task: 'dmc_walker_walk'
  size: [64, 64]
  envs: 1
  action_repeat: 2
  time_limit: 1000
  grayscale: False
  prefill: 2500
  eval_noise: 0.0
  clip_rewards: 'identity'

  # Model
  dyn_cell: 'gru'
  dyn_hidden: 200
  dyn_deter: 200
  dyn_stoch: 50
  dyn_discrete: 0
  dyn_input_layers: 1
  dyn_output_layers: 1
  dyn_rec_depth: 1
  dyn_shared: False
  dyn_mean_act: 'none'
  dyn_std_act: 'sigmoid2'
  dyn_min_std: 0.1
  dyn_temp_post: True
  grad_heads: ['image', 'reward']
  units: 400
  reward_layers: 2
  discount_layers: 3
  value_layers: 3
  actor_layers: 4
  act: 'elu'
  cnn_depth: 32
  encoder_kernels: [4, 4, 4, 4]
  decoder_kernels: [5, 5, 6, 6]
  decoder_thin: True
  value_head: 'normal'
  kl_scale: '1.0'
  kl_balance: '0.8'
  kl_free: '1.0'
  kl_forward: False
  pred_discount: False
  discount_scale: 1.0
  reward_scale: 1.0
  weight_decay: 0.0

  # Training
  batch_size: 50
  batch_length: 50
  train_every: 5
  train_steps: 1
  pretrain: 100
  model_lr: 3e-4
  value_lr: 8e-5
  actor_lr: 8e-5
  opt_eps: 1e-5
  grad_clip: 100
  value_grad_clip: 100
  actor_grad_clip: 100
  dataset_size: 0
  oversample_ends: False
  slow_value_target: True
  slow_actor_target: True
  slow_target_update: 100
  slow_target_fraction: 1
  opt: 'adam'

  # Behavior.
  discount: 0.99
  discount_lambda: 0.95
  imag_horizon: 15
  imag_gradient: 'dynamics'
  imag_gradient_mix: '0.1'
  imag_sample: True
  actor_dist: 'trunc_normal'
  actor_entropy: '1e-4'
  actor_state_entropy: 0.0
  actor_init_std: 1.0
  actor_min_std: 0.1
  actor_disc: 5
  actor_temp: 0.1
  actor_outscale: 0.0
  expl_amount: 0.0
  eval_state_mean: False
  collect_dyn_sample: True
  behavior_stop_grad: True
  value_decay: 0.0
  future_entropy: False

  # Exploration
  expl_behavior: 'greedy'
  expl_until: 0
  expl_extr_scale: 0.0
  expl_intr_scale: 1.0
  disag_target: 'stoch'
  disag_log: True
  disag_models: 10
  disag_offset: 1
  disag_layers: 4
  disag_units: 400

atari:

  # General
  task: 'atari_pong'
  steps: 2e8
  eval_every: 1e5
  log_every: 1e4
  prefill: 50000
  dataset_size: 2e6
  pretrain: 0
  precision: 16

  # Environment
  time_limit: 108000  # 30 minutes of game play.
  grayscale: True
  action_repeat: 4
  eval_noise: 0.0
  train_every: 16
  train_steps: 1
  clip_rewards: 'tanh'

  # Model
  grad_heads: ['image', 'reward', 'discount']
  dyn_cell: 'gru_layer_norm'
  pred_discount: True
  cnn_depth: 48
  dyn_deter: 600
  dyn_hidden: 600
  dyn_stoch: 32
  dyn_discrete: 32
  reward_layers: 4
  discount_layers: 4
  value_layers: 4
  actor_layers: 4

  # Behavior
  actor_dist: 'onehot'
  actor_entropy: 'linear(3e-3,3e-4,2.5e6)'
  expl_amount: 0.0
  discount: 0.999
  imag_gradient: 'both'
  imag_gradient_mix: 'linear(0.1,0,2.5e6)'

  # Training
  discount_scale: 5.0
  reward_scale: 1
  weight_decay: 1e-6
  model_lr: 2e-4
  kl_scale: 0.1
  kl_free: 0.0
  actor_lr: 4e-5
  value_lr: 1e-4
  oversample_ends: True

dmc:

  # General
  task: 'dmc_walker_walk'
  steps: 1e7
  eval_every: 1e4
  log_every: 1e4
  prefill: 2500
  dataset_size: 0
  pretrain: 100

  # Environment
  time_limit: 1000
  action_repeat: 2
  train_every: 5
  train_steps: 1

  # Model
  grad_heads: ['image', 'reward']
  dyn_cell: 'gru_layer_norm'
  pred_discount: False
  cnn_depth: 32
  dyn_deter: 200
  dyn_stoch: 50
  dyn_discrete: 0
  reward_layers: 2
  discount_layers: 3
  value_layers: 3
  actor_layers: 4

  # Behavior
  actor_dist: 'trunc_normal'
  expl_amount: 0.0
  actor_entropy: '1e-4'
  discount: 0.99
  imag_gradient: 'dynamics'
  imag_gradient_mix: 1.0

  # Training
  reward_scale: 2
  weight_decay: 0.0
  model_lr: 3e-4
  value_lr: 8e-5
  actor_lr: 8e-5
  opt_eps: 1e-5
  kl_free: '1.0'
  kl_scale: '1.0'

debug:

  debug: True
  pretrain: 1
  prefill: 1
  train_steps: 1
  batch_size: 10
  batch_length: 20
@@ -0,0 +1,323 @@
import argparse
import collections
import functools
import os
import pathlib
import sys
import warnings

warnings.filterwarnings('ignore', '.*box bound precision lowered.*')
warnings.filterwarnings('ignore', '.*TensorFloat-32 matmul/conv*')
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
os.environ['MUJOCO_GL'] = 'egl'

import numpy as np
import ruamel.yaml as yaml
import tensorflow as tf
from tensorflow.keras.mixed_precision import experimental as prec

tf.get_logger().setLevel('ERROR')

from tensorflow_probability import distributions as tfd

sys.path.append(str(pathlib.Path(__file__).parent))

import exploration as expl
import models
import tools
import wrappers


class Dreamer(tools.Module):

  def __init__(self, config, logger, dataset):
    self._config = config
    self._logger = logger
    self._float = prec.global_policy().compute_dtype
    self._should_log = tools.Every(config.log_every)
    self._should_train = tools.Every(config.train_every)
    self._should_pretrain = tools.Once()
    self._should_reset = tools.Every(config.reset_every)
    self._should_expl = tools.Until(int(
        config.expl_until / config.action_repeat))
    self._metrics = collections.defaultdict(tf.metrics.Mean)
    with tf.device('cpu:0'):
      self._step = tf.Variable(count_steps(config.traindir), dtype=tf.int64)
    # Schedules.
    config.actor_entropy = (
        lambda x=config.actor_entropy: tools.schedule(x, self._step))
    config.actor_state_entropy = (
        lambda x=config.actor_state_entropy: tools.schedule(x, self._step))
    config.imag_gradient_mix = (
        lambda x=config.imag_gradient_mix: tools.schedule(x, self._step))
    self._dataset = iter(dataset)
    self._wm = models.WorldModel(self._step, config)
    self._task_behavior = models.ImagBehavior(
        config, self._wm, config.behavior_stop_grad)
    reward = lambda f, s, a: self._wm.heads['reward'](f).mode()
    self._expl_behavior = dict(
        greedy=lambda: self._task_behavior,
        random=lambda: expl.Random(config),
        plan2explore=lambda: expl.Plan2Explore(config, self._wm, reward),
    )[config.expl_behavior]()
    # Train step to initialize variables including optimizer statistics.
    self._train(next(self._dataset))

  def __call__(self, obs, reset, state=None, training=True):
    step = self._step.numpy().item()
    if self._should_reset(step):
      state = None
    if state is not None and reset.any():
      mask = tf.cast(1 - reset, self._float)[:, None]
      state = tf.nest.map_structure(lambda x: x * mask, state)
    if training and self._should_train(step):
      steps = (
          self._config.pretrain if self._should_pretrain()
          else self._config.train_steps)
      for _ in range(steps):
        self._train(next(self._dataset))
      if self._should_log(step):
        for name, mean in self._metrics.items():
          self._logger.scalar(name, float(mean.result()))
          mean.reset_states()
        openl = self._wm.video_pred(next(self._dataset))
        self._logger.video('train_openl', openl)
        self._logger.write(fps=True)
    policy_output, state = self._policy(obs, state, training)
    if training:
      self._step.assign_add(len(reset))
      self._logger.step = self._config.action_repeat \
          * self._step.numpy().item()
    return policy_output, state

  @tf.function
  def _policy(self, obs, state, training):
    if state is None:
      batch_size = len(obs['image'])
      latent = self._wm.dynamics.initial(len(obs['image']))
      action = tf.zeros((batch_size, self._config.num_actions), self._float)
    else:
      latent, action = state
    embed = self._wm.encoder(self._wm.preprocess(obs))
    latent, _ = self._wm.dynamics.obs_step(
        latent, action, embed, self._config.collect_dyn_sample)
    if self._config.eval_state_mean:
      latent['stoch'] = latent['mean']
    feat = self._wm.dynamics.get_feat(latent)
    if not training:
      actor = self._task_behavior.actor(feat)
      action = actor.mode()
    elif self._should_expl(self._step):
      actor = self._expl_behavior.actor(feat)
      action = actor.sample()
    else:
      actor = self._task_behavior.actor(feat)
      action = actor.sample()
    logprob = actor.log_prob(tf.cast(action, tf.float32))
    if self._config.actor_dist == 'onehot_gumble':
      action = tf.cast(
          tf.one_hot(tf.argmax(action, axis=-1), self._config.num_actions),
          action.dtype)
    action = self._exploration(action, training)
    policy_output = {'action': action, 'logprob': logprob}
    state = (latent, action)
    return policy_output, state

  def _exploration(self, action, training):
    amount = self._config.expl_amount if training else self._config.eval_noise
    if amount == 0:
      return action
    amount = tf.cast(amount, self._float)
    if 'onehot' in self._config.actor_dist:
      probs = amount / self._config.num_actions + (1 - amount) * action
      return tools.OneHotDist(probs=probs).sample()
    else:
      return tf.clip_by_value(tfd.Normal(action, amount).sample(), -1, 1)
    raise NotImplementedError(self._config.action_noise)

  @tf.function
  def _train(self, data):
    print('Tracing train function.')
    metrics = {}
    post, context, mets = self._wm.train(data)
    metrics.update(mets)
    start = post
    if self._config.pred_discount:  # Last step could be terminal.
      start = {k: v[:, :-1] for k, v in post.items()}
      context = {k: v[:, :-1] for k, v in context.items()}
    reward = lambda f, s, a: self._wm.heads['reward'](
        self._wm.dynamics.get_feat(s)).mode()
    metrics.update(self._task_behavior.train(start, reward)[-1])
    if self._config.expl_behavior != 'greedy':
      mets = self._expl_behavior.train(start, context)[-1]
      metrics.update({'expl_' + key: value for key, value in mets.items()})
    for name, value in metrics.items():
      self._metrics[name].update_state(value)


def count_steps(folder):
  return sum(int(str(n).split('-')[-1][:-4]) - 1 for n in folder.glob('*.npz'))


def make_dataset(episodes, config):
  example = episodes[next(iter(episodes.keys()))]
  types = {k: v.dtype for k, v in example.items()}
  shapes = {k: (None,) + v.shape[1:] for k, v in example.items()}
  generator = lambda: tools.sample_episodes(
      episodes, config.batch_length, config.oversample_ends)
  dataset = tf.data.Dataset.from_generator(generator, types, shapes)
  dataset = dataset.batch(config.batch_size, drop_remainder=True)
  dataset = dataset.prefetch(10)
  return dataset


def make_env(config, logger, mode, train_eps, eval_eps):
  suite, task = config.task.split('_', 1)
  if suite == 'dmc':
    env = wrappers.DeepMindControl(task, config.action_repeat, config.size)
    env = wrappers.NormalizeActions(env)
  elif suite == 'atari':
    env = wrappers.Atari(
        task, config.action_repeat, config.size,
        grayscale=config.grayscale,
        life_done=False and (mode == 'train'),
        sticky_actions=True,
        all_actions=True)
    env = wrappers.OneHotAction(env)
  else:
    raise NotImplementedError(suite)
  env = wrappers.TimeLimit(env, config.time_limit)
  env = wrappers.SelectAction(env, key='action')
  callbacks = [functools.partial(
      process_episode, config, logger, mode, train_eps, eval_eps)]
  env = wrappers.CollectDataset(env, callbacks)
  env = wrappers.RewardObs(env)
  return env


def process_episode(config, logger, mode, train_eps, eval_eps, episode):
  directory = dict(train=config.traindir, eval=config.evaldir)[mode]
  cache = dict(train=train_eps, eval=eval_eps)[mode]
  filename = tools.save_episodes(directory, [episode])[0]
  length = len(episode['reward']) - 1
  score = float(episode['reward'].astype(np.float64).sum())
  video = episode['image']
  if mode == 'eval':
    cache.clear()
  if mode == 'train' and config.dataset_size:
    total = 0
    for key, ep in reversed(sorted(cache.items(), key=lambda x: x[0])):
      if total <= config.dataset_size - length:
        total += len(ep['reward']) - 1
      else:
        del cache[key]
    logger.scalar('dataset_size', total + length)
  cache[str(filename)] = episode
  print(f'{mode.title()} episode has {length} steps and return {score:.1f}.')
  logger.scalar(f'{mode}_return', score)
  logger.scalar(f'{mode}_length', length)
  logger.scalar(f'{mode}_episodes', len(cache))
  if mode == 'eval' or config.expl_gifs:
    logger.video(f'{mode}_policy', video[None])
  logger.write()


def main(config):
  logdir = pathlib.Path(config.logdir).expanduser()
  config.traindir = config.traindir or logdir / 'train_eps'
  config.evaldir = config.evaldir or logdir / 'eval_eps'
  config.steps //= config.action_repeat
  config.eval_every //= config.action_repeat
  config.log_every //= config.action_repeat
  config.time_limit //= config.action_repeat
  config.act = getattr(tf.nn, config.act)

  if config.debug:
    tf.config.experimental_run_functions_eagerly(True)
  if config.gpu_growth:
    message = 'No GPU found. To actually train on CPU remove this assert.'
    assert tf.config.experimental.list_physical_devices('GPU'), message
    for gpu in tf.config.experimental.list_physical_devices('GPU'):
      tf.config.experimental.set_memory_growth(gpu, True)
  assert config.precision in (16, 32), config.precision
  if config.precision == 16:
    prec.set_policy(prec.Policy('mixed_float16'))
  print('Logdir', logdir)
  logdir.mkdir(parents=True, exist_ok=True)
  config.traindir.mkdir(parents=True, exist_ok=True)
  config.evaldir.mkdir(parents=True, exist_ok=True)
  step = count_steps(config.traindir)
  logger = tools.Logger(logdir, config.action_repeat * step)

  print('Create envs.')
  if config.offline_traindir:
    directory = config.offline_traindir.format(**vars(config))
  else:
    directory = config.traindir
  train_eps = tools.load_episodes(directory, limit=config.dataset_size)
  if config.offline_evaldir:
    directory = config.offline_evaldir.format(**vars(config))
  else:
    directory = config.evaldir
  eval_eps = tools.load_episodes(directory, limit=1)
  make = lambda mode: make_env(config, logger, mode, train_eps, eval_eps)
  train_envs = [make('train') for _ in range(config.envs)]
  eval_envs = [make('eval') for _ in range(config.envs)]
  acts = train_envs[0].action_space
  config.num_actions = acts.n if hasattr(acts, 'n') else acts.shape[0]

  prefill = max(0, config.prefill - count_steps(config.traindir))
  print(f'Prefill dataset ({prefill} steps).')
  if hasattr(acts, 'discrete'):
    random_actor = tools.OneHotDist(tf.zeros_like(acts.low)[None])
  else:
    random_actor = tfd.Independent(
        tfd.Uniform(acts.low[None], acts.high[None]), 1)
  def random_agent(o, d, s):
    action = random_actor.sample()
    logprob = random_actor.log_prob(action)
    return {'action': action, 'logprob': logprob}, None
  tools.simulate(random_agent, train_envs, prefill)
  tools.simulate(random_agent, eval_envs, episodes=1)
  logger.step = config.action_repeat * count_steps(config.traindir)

  print('Simulate agent.')
  train_dataset = make_dataset(train_eps, config)
  eval_dataset = iter(make_dataset(eval_eps, config))
  agent = Dreamer(config, logger, train_dataset)
  if (logdir / 'variables.pkl').exists():
    agent.load(logdir / 'variables.pkl')
    agent._should_pretrain._once = False

  state = None
  while agent._step.numpy().item() < config.steps:
    logger.write()
    print('Start evaluation.')
    video_pred = agent._wm.video_pred(next(eval_dataset))
    logger.video('eval_openl', video_pred)
    eval_policy = functools.partial(agent, training=False)
    tools.simulate(eval_policy, eval_envs, episodes=1)
    print('Start training.')
    state = tools.simulate(agent, train_envs, config.eval_every, state=state)
    agent.save(logdir / 'variables.pkl')
  for env in train_envs + eval_envs:
    try:
      env.close()
    except Exception:
      pass


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument('--configs', nargs='+', required=True)
  args, remaining = parser.parse_known_args()
  configs = yaml.safe_load(
      (pathlib.Path(sys.argv[0]).parent / 'configs.yaml').read_text())
  defaults = {}
  for name in args.configs:
    defaults.update(configs[name])
  parser = argparse.ArgumentParser()
  for key, value in sorted(defaults.items(), key=lambda x: x[0]):
    arg_type = tools.args_type(value)
    parser.add_argument(f'--{key}', type=arg_type, default=arg_type(value))
  main(parser.parse_args(remaining))
@@ -0,0 +1,91 @@
import tensorflow as tf
from tensorflow.keras.mixed_precision import experimental as prec
from tensorflow_probability import distributions as tfd

import models
import networks
import tools


class Random(tools.Module):

  def __init__(self, config):
    self._config = config
    self._float = prec.global_policy().compute_dtype

  def actor(self, feat):
    shape = feat.shape[:-1] + [self._config.num_actions]
    if self._config.actor_dist == 'onehot':
      return tools.OneHotDist(tf.zeros(shape))
    else:
      ones = tf.ones(shape, self._float)
      return tfd.Uniform(-ones, ones)

  def train(self, start, context):
    return None, {}


class Plan2Explore(tools.Module):

  def __init__(self, config, world_model, reward=None):
    self._config = config
    self._reward = reward
    self._behavior = models.ImagBehavior(config, world_model)
    self.actor = self._behavior.actor
    stoch_size = config.dyn_stoch
    if config.dyn_discrete:
      stoch_size *= config.dyn_discrete
    size = {
        'embed': 32 * config.cnn_depth,
        'stoch': stoch_size,
        'deter': config.dyn_deter,
        'feat': config.dyn_stoch + config.dyn_deter,
    }[self._config.disag_target]
    kw = dict(
        shape=size, layers=config.disag_layers, units=config.disag_units,
        act=config.act)
    self._networks = [
        networks.DenseHead(**kw) for _ in range(config.disag_models)]
    self._opt = tools.Optimizer(
        'ensemble', config.model_lr, config.opt_eps, config.grad_clip,
        config.weight_decay, opt=config.opt)

  def train(self, start, context):
    metrics = {}
    stoch = start['stoch']
    if self._config.dyn_discrete:
      stoch = tf.reshape(
          stoch, stoch.shape[:-2] + (stoch.shape[-2] * stoch.shape[-1]))
    target = {
        'embed': context['embed'],
        'stoch': stoch,
        'deter': start['deter'],
        'feat': context['feat'],
    }[self._config.disag_target]
    metrics.update(self._train_ensemble(context['feat'], target))
    metrics.update(self._behavior.train(start, self._intrinsic_reward)[-1])
    return None, metrics

  def _intrinsic_reward(self, feat, state, action):
    preds = [head(feat, tf.float32).mean() for head in self._networks]
    disag = tf.reduce_mean(tf.math.reduce_std(preds, 0), -1)
    if self._config.disag_log:
      disag = tf.math.log(disag)
    reward = self._config.expl_intr_scale * disag
    if self._config.expl_extr_scale:
      reward += tf.cast(self._config.expl_extr_scale * self._reward(
          feat, state, action), tf.float32)
    return reward

  def _train_ensemble(self, inputs, targets):
    if self._config.disag_offset:
      targets = targets[:, self._config.disag_offset:]
      inputs = inputs[:, :-self._config.disag_offset]
    targets = tf.stop_gradient(targets)
    inputs = tf.stop_gradient(inputs)
    with tf.GradientTape() as tape:
      preds = [head(inputs) for head in self._networks]
      likes = [tf.reduce_mean(pred.log_prob(targets)) for pred in preds]
      loss = -tf.cast(tf.reduce_sum(likes), tf.float32)
    metrics = self._opt(tape, loss, self._networks)
    return metrics
@@ -0,0 +1,260 @@
|
||||
import tensorflow as tf
|
||||
from tensorflow.keras.mixed_precision import experimental as prec
|
||||
|
||||
import networks
|
||||
import tools
|
||||
|
||||
|
||||
class WorldModel(tools.Module):
|
||||
|
||||
def __init__(self, step, config):
|
||||
self._step = step
|
||||
self._config = config
|
||||
self.encoder = networks.ConvEncoder(
|
||||
config.cnn_depth, config.act, config.encoder_kernels)
|
||||
self.dynamics = networks.RSSM(
|
||||
config.dyn_stoch, config.dyn_deter, config.dyn_hidden,
|
||||
config.dyn_input_layers, config.dyn_output_layers,
|
||||
config.dyn_rec_depth, config.dyn_shared, config.dyn_discrete,
|
||||
config.act, config.dyn_mean_act, config.dyn_std_act,
|
||||
config.dyn_temp_post, config.dyn_min_std, config.dyn_cell)
|
||||
self.heads = {}
|
||||
channels = (1 if config.grayscale else 3)
|
||||
shape = config.size + (channels,)
|
||||
self.heads['image'] = networks.ConvDecoder(
|
||||
config.cnn_depth, config.act, shape, config.decoder_kernels,
|
||||
config.decoder_thin)
|
||||
self.heads['reward'] = networks.DenseHead(
|
||||
[], config.reward_layers, config.units, config.act)
|
||||
if config.pred_discount:
|
||||
self.heads['discount'] = networks.DenseHead(
|
||||
[], config.discount_layers, config.units, config.act, dist='binary')
|
||||
for name in config.grad_heads:
|
||||
assert name in self.heads, name
|
||||
self._model_opt = tools.Optimizer(
|
||||
'model', config.model_lr, config.opt_eps, config.grad_clip,
|
||||
config.weight_decay, opt=config.opt)
|
||||
self._scales = dict(
|
||||
reward=config.reward_scale, discount=config.discount_scale)
|
||||
|
||||
def train(self, data):
|
||||
data = self.preprocess(data)
|
||||
with tf.GradientTape() as model_tape:
|
||||
embed = self.encoder(data)
|
||||
post, prior = self.dynamics.observe(embed, data['action'])
|
||||
      kl_balance = tools.schedule(self._config.kl_balance, self._step)
      kl_free = tools.schedule(self._config.kl_free, self._step)
      kl_scale = tools.schedule(self._config.kl_scale, self._step)
      kl_loss, kl_value = self.dynamics.kl_loss(
          post, prior, self._config.kl_forward, kl_balance, kl_free, kl_scale)
      losses = {}
      likes = {}
      for name, head in self.heads.items():
        grad_head = (name in self._config.grad_heads)
        feat = self.dynamics.get_feat(post)
        feat = feat if grad_head else tf.stop_gradient(feat)
        pred = head(feat, tf.float32)
        like = pred.log_prob(tf.cast(data[name], tf.float32))
        likes[name] = like
        losses[name] = -tf.reduce_mean(like) * self._scales.get(name, 1.0)
      model_loss = sum(losses.values()) + kl_loss
    model_parts = [self.encoder, self.dynamics] + list(self.heads.values())
    metrics = self._model_opt(model_tape, model_loss, model_parts)
    metrics.update({f'{name}_loss': loss for name, loss in losses.items()})
    metrics['kl_balance'] = kl_balance
    metrics['kl_free'] = kl_free
    metrics['kl_scale'] = kl_scale
    metrics['kl'] = tf.reduce_mean(kl_value)
    metrics['prior_ent'] = self.dynamics.get_dist(prior).entropy()
    metrics['post_ent'] = self.dynamics.get_dist(post).entropy()
    context = dict(
        embed=embed, feat=self.dynamics.get_feat(post),
        kl=kl_value, postent=self.dynamics.get_dist(post).entropy())
    return post, context, metrics

  @tf.function
  def preprocess(self, obs):
    dtype = prec.global_policy().compute_dtype
    obs = obs.copy()
    obs['image'] = tf.cast(obs['image'], dtype) / 255.0 - 0.5
    obs['reward'] = getattr(tf, self._config.clip_rewards)(obs['reward'])
    if 'discount' in obs:
      obs['discount'] *= self._config.discount
    for key, value in obs.items():
      if tf.dtypes.as_dtype(value.dtype) in (
          tf.float16, tf.float32, tf.float64):
        obs[key] = tf.cast(value, dtype)
    return obs

  @tf.function
  def video_pred(self, data):
    data = self.preprocess(data)
    truth = data['image'][:6] + 0.5
    embed = self.encoder(data)
    states, _ = self.dynamics.observe(embed[:6, :5], data['action'][:6, :5])
    recon = self.heads['image'](
        self.dynamics.get_feat(states)).mode()[:6]
    init = {k: v[:, -1] for k, v in states.items()}
    prior = self.dynamics.imagine(data['action'][:6, 5:], init)
    openl = self.heads['image'](self.dynamics.get_feat(prior)).mode()
    model = tf.concat([recon[:, :5] + 0.5, openl + 0.5], 1)
    error = (model - truth + 1) / 2
    return tf.concat([truth, model, error], 2)

class ImagBehavior(tools.Module):

  def __init__(self, config, world_model, stop_grad_actor=True, reward=None):
    self._config = config
    self._world_model = world_model
    self._stop_grad_actor = stop_grad_actor
    self._reward = reward
    # Note: the second `config.actor_dist` fills ActionHead's `action_disc`
    # parameter, which ActionHead stores but never reads.
    self.actor = networks.ActionHead(
        config.num_actions, config.actor_layers, config.units, config.act,
        config.actor_dist, config.actor_init_std, config.actor_min_std,
        config.actor_dist, config.actor_temp, config.actor_outscale)
    self.value = networks.DenseHead(
        [], config.value_layers, config.units, config.act,
        config.value_head)
    if config.slow_value_target or config.slow_actor_target:
      self._slow_value = networks.DenseHead(
          [], config.value_layers, config.units, config.act)
      # The second positional argument of tf.Variable is `trainable`, so the
      # dtype has to be passed by keyword.
      self._updates = tf.Variable(0, trainable=False, dtype=tf.int64)
    kw = dict(wd=config.weight_decay, opt=config.opt)
    self._actor_opt = tools.Optimizer(
        'actor', config.actor_lr, config.opt_eps, config.actor_grad_clip, **kw)
    self._value_opt = tools.Optimizer(
        'value', config.value_lr, config.opt_eps, config.value_grad_clip, **kw)

  def train(
      self, start, objective=None, imagine=None, tape=None, repeats=None):
    objective = objective or self._reward
    self._update_slow_target()
    metrics = {}
    with (tape or tf.GradientTape()) as actor_tape:
      assert bool(objective) != bool(imagine)
      if objective:
        imag_feat, imag_state, imag_action = self._imagine(
            start, self.actor, self._config.imag_horizon, repeats)
        reward = objective(imag_feat, imag_state, imag_action)
      else:
        imag_feat, imag_state, imag_action, reward = imagine(start)
      actor_ent = self.actor(imag_feat, tf.float32).entropy()
      state_ent = self._world_model.dynamics.get_dist(
          imag_state, tf.float32).entropy()
      target, weights = self._compute_target(
          imag_feat, imag_state, imag_action, reward, actor_ent, state_ent,
          self._config.slow_actor_target)
      actor_loss, mets = self._compute_actor_loss(
          imag_feat, imag_state, imag_action, target, actor_ent, state_ent,
          weights)
      metrics.update(mets)
      if self._config.slow_value_target != self._config.slow_actor_target:
        target, weights = self._compute_target(
            imag_feat, imag_state, imag_action, reward, actor_ent, state_ent,
            self._config.slow_value_target)
      value_input = imag_feat
    with tf.GradientTape() as value_tape:
      value = self.value(value_input, tf.float32)[:-1]
      value_loss = -value.log_prob(tf.stop_gradient(target))
      if self._config.value_decay:
        value_loss += self._config.value_decay * value.mode()
      value_loss = tf.reduce_mean(weights[:-1] * value_loss)
    metrics['reward_mean'] = tf.reduce_mean(reward)
    metrics['reward_std'] = tf.math.reduce_std(reward)
    metrics['actor_ent'] = tf.reduce_mean(actor_ent)
    metrics.update(self._actor_opt(actor_tape, actor_loss, [self.actor]))
    metrics.update(self._value_opt(value_tape, value_loss, [self.value]))
    return imag_feat, imag_state, imag_action, weights, metrics

  def _imagine(self, start, policy, horizon, repeats=None):
    dynamics = self._world_model.dynamics
    if repeats:
      start = {k: tf.repeat(v, repeats, axis=1) for k, v in start.items()}
    flatten = lambda x: tf.reshape(x, [-1] + list(x.shape[2:]))
    start = {k: flatten(v) for k, v in start.items()}
    def step(prev, _):
      state, _, _ = prev
      feat = dynamics.get_feat(state)
      inp = tf.stop_gradient(feat) if self._stop_grad_actor else feat
      action = policy(inp).sample()
      succ = dynamics.img_step(state, action, sample=self._config.imag_sample)
      return succ, feat, action
    feat = 0 * dynamics.get_feat(start)
    action = policy(feat).mode()
    succ, feats, actions = tools.static_scan(
        step, tf.range(horizon), (start, feat, action))
    states = {k: tf.concat([
        start[k][None], v[:-1]], 0) for k, v in succ.items()}
    if repeats:
      def unfold(tensor):
        s = tensor.shape
        return tf.reshape(tensor, [s[0], s[1] // repeats, repeats] + s[2:])
      states, feats, actions = tf.nest.map_structure(
          unfold, (states, feats, actions))
    return feats, states, actions

  def _compute_target(
      self, imag_feat, imag_state, imag_action, reward, actor_ent, state_ent,
      slow):
    reward = tf.cast(reward, tf.float32)
    if 'discount' in self._world_model.heads:
      inp = self._world_model.dynamics.get_feat(imag_state)
      discount = self._world_model.heads['discount'](inp, tf.float32).mean()
    else:
      discount = self._config.discount * tf.ones_like(reward)
    if self._config.future_entropy and tf.greater(
        self._config.actor_entropy(), 0):
      reward += self._config.actor_entropy() * actor_ent
    if self._config.future_entropy and tf.greater(
        self._config.actor_state_entropy(), 0):
      reward += self._config.actor_state_entropy() * state_ent
    if slow:
      value = self._slow_value(imag_feat, tf.float32).mode()
    else:
      value = self.value(imag_feat, tf.float32).mode()
    target = tools.lambda_return(
        reward[:-1], value[:-1], discount[:-1],
        bootstrap=value[-1], lambda_=self._config.discount_lambda, axis=0)
    weights = tf.stop_gradient(tf.math.cumprod(tf.concat(
        [tf.ones_like(discount[:1]), discount[:-1]], 0), 0))
    return target, weights

  def _compute_actor_loss(
      self, imag_feat, imag_state, imag_action, target, actor_ent, state_ent,
      weights):
    metrics = {}
    inp = tf.stop_gradient(imag_feat) if self._stop_grad_actor else imag_feat
    policy = self.actor(inp, tf.float32)
    actor_ent = policy.entropy()
    if self._config.imag_gradient == 'dynamics':
      actor_target = target
    elif self._config.imag_gradient == 'reinforce':
      imag_action = tf.cast(imag_action, tf.float32)
      actor_target = policy.log_prob(imag_action)[:-1] * tf.stop_gradient(
          target - self.value(imag_feat[:-1], tf.float32).mode())
    elif self._config.imag_gradient == 'both':
      imag_action = tf.cast(imag_action, tf.float32)
      actor_target = policy.log_prob(imag_action)[:-1] * tf.stop_gradient(
          target - self.value(imag_feat[:-1], tf.float32).mode())
      mix = self._config.imag_gradient_mix()
      actor_target = mix * target + (1 - mix) * actor_target
      metrics['imag_gradient_mix'] = mix
    else:
      raise NotImplementedError(self._config.imag_gradient)
    if not self._config.future_entropy and tf.greater(
        self._config.actor_entropy(), 0):
      actor_target += self._config.actor_entropy() * actor_ent[:-1]
    if not self._config.future_entropy and tf.greater(
        self._config.actor_state_entropy(), 0):
      actor_target += self._config.actor_state_entropy() * state_ent[:-1]
    actor_loss = -tf.reduce_mean(weights[:-1] * actor_target)
    return actor_loss, metrics

  def _update_slow_target(self):
    if self._config.slow_value_target or self._config.slow_actor_target:
      if self._updates % self._config.slow_target_update == 0:
        mix = self._config.slow_target_fraction
        for s, d in zip(self.value.variables, self._slow_value.variables):
          d.assign(mix * s + (1 - mix) * d)
      self._updates.assign_add(1)

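The two purely numeric steps of `ImagBehavior` above, the discount weights built in `_compute_target` and the soft update in `_update_slow_target`, can be sketched in plain NumPy. This is an illustrative sketch and not part of the repository: the helper names `discount_weights` and `soft_update` are hypothetical, and plain arrays stand in for TensorFlow tensors and variables.

```python
import numpy as np


def discount_weights(discount):
  """weights[0] = 1 and weights[t] = discount[0] * ... * discount[t - 1]."""
  discount = np.asarray(discount, dtype=np.float64)
  # Shift the discounts by one step along the time axis, padding with ones.
  shifted = np.concatenate([np.ones_like(discount[:1]), discount[:-1]], 0)
  return np.cumprod(shifted, 0)


def soft_update(source, target, mix):
  """Moves each target array a fraction `mix` of the way toward its source."""
  return [mix * s + (1 - mix) * d for s, d in zip(source, target)]


# Time-major discounts with shape [horizon, batch].
weights = discount_weights([[0.9], [0.9], [0.5], [0.9]])
slow = soft_update([np.ones(2)], [np.zeros(2)], mix=0.25)
print(weights[:, 0])
print(slow[0])
```

Because the discounts are shifted by one step, the first imagined state always has weight one, and a low predicted discount at step `t` only reduces the weight of the steps after `t`.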
+397
@@ -0,0 +1,397 @@
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers as tfkl
from tensorflow_probability import distributions as tfd
from tensorflow.keras.mixed_precision import experimental as prec

import tools

class RSSM(tools.Module):

  def __init__(
      self, stoch=30, deter=200, hidden=200, layers_input=1, layers_output=1,
      rec_depth=1, shared=False, discrete=False, act=tf.nn.elu,
      mean_act='none', std_act='softplus', temp_post=True, min_std=0.1,
      cell='gru'):
    # The default used to be cell='keras', which no branch below handles, so
    # constructing an RSSM with default arguments raised NotImplementedError.
    super().__init__()
    self._stoch = stoch
    self._deter = deter
    self._hidden = hidden
    self._min_std = min_std
    self._layers_input = layers_input
    self._layers_output = layers_output
    self._rec_depth = rec_depth
    self._shared = shared
    self._discrete = discrete
    self._act = act
    self._mean_act = mean_act
    self._std_act = std_act
    self._temp_post = temp_post
    self._embed = None
    if cell == 'gru':
      self._cell = tfkl.GRUCell(self._deter)
    elif cell == 'gru_layer_norm':
      self._cell = GRUCell(self._deter, norm=True)
    else:
      raise NotImplementedError(cell)

  def initial(self, batch_size):
    dtype = prec.global_policy().compute_dtype
    if self._discrete:
      state = dict(
          logit=tf.zeros([batch_size, self._stoch, self._discrete], dtype),
          stoch=tf.zeros([batch_size, self._stoch, self._discrete], dtype),
          deter=self._cell.get_initial_state(None, batch_size, dtype))
    else:
      state = dict(
          mean=tf.zeros([batch_size, self._stoch], dtype),
          std=tf.zeros([batch_size, self._stoch], dtype),
          stoch=tf.zeros([batch_size, self._stoch], dtype),
          deter=self._cell.get_initial_state(None, batch_size, dtype))
    return state

  @tf.function
  def observe(self, embed, action, state=None):
    swap = lambda x: tf.transpose(x, [1, 0] + list(range(2, len(x.shape))))
    if state is None:
      state = self.initial(tf.shape(action)[0])
    embed, action = swap(embed), swap(action)
    post, prior = tools.static_scan(
        lambda prev, inputs: self.obs_step(prev[0], *inputs),
        (action, embed), (state, state))
    post = {k: swap(v) for k, v in post.items()}
    prior = {k: swap(v) for k, v in prior.items()}
    return post, prior

  @tf.function
  def imagine(self, action, state=None):
    swap = lambda x: tf.transpose(x, [1, 0] + list(range(2, len(x.shape))))
    if state is None:
      state = self.initial(tf.shape(action)[0])
    assert isinstance(state, dict), state
    action = swap(action)
    prior = tools.static_scan(self.img_step, action, state)
    prior = {k: swap(v) for k, v in prior.items()}
    return prior

  def get_feat(self, state):
    stoch = state['stoch']
    if self._discrete:
      shape = stoch.shape[:-2] + [self._stoch * self._discrete]
      stoch = tf.reshape(stoch, shape)
    return tf.concat([stoch, state['deter']], -1)

  def get_dist(self, state, dtype=None):
    if self._discrete:
      logit = state['logit']
      logit = tf.cast(logit, tf.float32)
      dist = tfd.Independent(tools.OneHotDist(logit), 1)
      if dtype != tf.float32:
        dist = tools.DtypeDist(dist, dtype or state['logit'].dtype)
    else:
      mean, std = state['mean'], state['std']
      if dtype:
        mean = tf.cast(mean, dtype)
        std = tf.cast(std, dtype)
      dist = tfd.MultivariateNormalDiag(mean, std)
    return dist

  @tf.function
  def obs_step(self, prev_state, prev_action, embed, sample=True):
    if not self._embed:
      self._embed = embed.shape[-1]
    prior = self.img_step(prev_state, prev_action, None, sample)
    if self._shared:
      post = self.img_step(prev_state, prev_action, embed, sample)
    else:
      if self._temp_post:
        x = tf.concat([prior['deter'], embed], -1)
      else:
        x = embed
      for i in range(self._layers_output):
        x = self.get(f'obi{i}', tfkl.Dense, self._hidden, self._act)(x)
      stats = self._suff_stats_layer('obs', x)
      if sample:
        stoch = self.get_dist(stats).sample()
      else:
        stoch = self.get_dist(stats).mode()
      post = {'stoch': stoch, 'deter': prior['deter'], **stats}
    return post, prior

  @tf.function
  def img_step(self, prev_state, prev_action, embed=None, sample=True):
    prev_stoch = prev_state['stoch']
    if self._discrete:
      shape = prev_stoch.shape[:-2] + [self._stoch * self._discrete]
      prev_stoch = tf.reshape(prev_stoch, shape)
    if self._shared:
      if embed is None:
        shape = prev_action.shape[:-1] + [self._embed]
        embed = tf.zeros(shape, prev_action.dtype)
      x = tf.concat([prev_stoch, prev_action, embed], -1)
    else:
      x = tf.concat([prev_stoch, prev_action], -1)
    for i in range(self._layers_input):
      x = self.get(f'ini{i}', tfkl.Dense, self._hidden, self._act)(x)
    for _ in range(self._rec_depth):
      deter = prev_state['deter']
      x, deter = self._cell(x, [deter])
      deter = deter[0]  # Keras wraps the state in a list.
    for i in range(self._layers_output):
      x = self.get(f'imo{i}', tfkl.Dense, self._hidden, self._act)(x)
    stats = self._suff_stats_layer('ims', x)
    if sample:
      stoch = self.get_dist(stats).sample()
    else:
      stoch = self.get_dist(stats).mode()
    prior = {'stoch': stoch, 'deter': deter, **stats}
    return prior

  def _suff_stats_layer(self, name, x):
    if self._discrete:
      x = self.get(name, tfkl.Dense, self._stoch * self._discrete, None)(x)
      logit = tf.reshape(x, x.shape[:-1] + [self._stoch, self._discrete])
      return {'logit': logit}
    else:
      x = self.get(name, tfkl.Dense, 2 * self._stoch, None)(x)
      mean, std = tf.split(x, 2, -1)
      mean = {
          'none': lambda: mean,
          'tanh5': lambda: 5.0 * tf.math.tanh(mean / 5.0),
      }[self._mean_act]()
      std = {
          'softplus': lambda: tf.nn.softplus(std),
          'abs': lambda: tf.math.abs(std + 1),
          'sigmoid': lambda: tf.nn.sigmoid(std),
          'sigmoid2': lambda: 2 * tf.nn.sigmoid(std / 2),
      }[self._std_act]()
      std = std + self._min_std
      return {'mean': mean, 'std': std}

  def kl_loss(self, post, prior, forward, balance, free, scale):
    kld = tfd.kl_divergence
    dist = lambda x: self.get_dist(x, tf.float32)
    sg = lambda x: tf.nest.map_structure(tf.stop_gradient, x)
    lhs, rhs = (prior, post) if forward else (post, prior)
    mix = balance if forward else (1 - balance)
    if balance == 0.5:
      value = kld(dist(lhs), dist(rhs))
      loss = tf.reduce_mean(tf.maximum(value, free))
    else:
      value_lhs = value = kld(dist(lhs), dist(sg(rhs)))
      value_rhs = kld(dist(sg(lhs)), dist(rhs))
      loss_lhs = tf.maximum(tf.reduce_mean(value_lhs), free)
      loss_rhs = tf.maximum(tf.reduce_mean(value_rhs), free)
      loss = mix * loss_lhs + (1 - mix) * loss_rhs
    loss *= scale
    return loss, value

class ConvEncoder(tools.Module):

  def __init__(
      self, depth=32, act=tf.nn.relu, kernels=(4, 4, 4, 4)):
    self._act = act
    self._depth = depth
    self._kernels = kernels

  def __call__(self, obs):
    x = tf.reshape(obs['image'], (-1,) + tuple(obs['image'].shape[-3:]))
    for i, kernel in enumerate(self._kernels):
      depth = 2 ** i * self._depth
      x = self._act(self.get(f'h{i}', tfkl.Conv2D, depth, kernel, 2)(x))
    x = tf.reshape(x, [x.shape[0], np.prod(x.shape[1:])])
    # print('Encoder output:', x.shape)
    shape = tf.concat([tf.shape(obs['image'])[:-3], [x.shape[-1]]], 0)
    return tf.reshape(x, shape)

class ConvDecoder(tools.Module):

  def __init__(
      self, depth=32, act=tf.nn.relu, shape=(64, 64, 3), kernels=(5, 5, 6, 6),
      thin=True):
    self._act = act
    self._depth = depth
    self._shape = shape
    self._kernels = kernels
    self._thin = thin

  def __call__(self, features, dtype=None):
    ConvT = tfkl.Conv2DTranspose
    if self._thin:
      x = self.get('hin', tfkl.Dense, 32 * self._depth, None)(features)
      x = tf.reshape(x, [-1, 1, 1, 32 * self._depth])
    else:
      x = self.get('hin', tfkl.Dense, 128 * self._depth, None)(features)
      x = tf.reshape(x, [-1, 2, 2, 32 * self._depth])
    for i, kernel in enumerate(self._kernels):
      depth = 2 ** (len(self._kernels) - i - 1) * self._depth
      act = self._act
      if i == len(self._kernels) - 1:
        depth = self._shape[-1]
        act = None
      x = self.get(f'h{i}', ConvT, depth, kernel, 2, activation=act)(x)
    # print('Decoder output:', x.shape)
    mean = tf.reshape(x, tf.concat([tf.shape(features)[:-1], self._shape], 0))
    if dtype:
      mean = tf.cast(mean, dtype)
    return tfd.Independent(tfd.Normal(mean, 1), len(self._shape))

class DenseHead(tools.Module):

  def __init__(
      self, shape, layers, units, act=tf.nn.elu, dist='normal', std=1.0):
    self._shape = (shape,) if isinstance(shape, int) else shape
    self._layers = layers
    self._units = units
    self._act = act
    self._dist = dist
    self._std = std

  def __call__(self, features, dtype=None):
    x = features
    for index in range(self._layers):
      x = self.get(f'h{index}', tfkl.Dense, self._units, self._act)(x)
    mean = self.get(f'hmean', tfkl.Dense, np.prod(self._shape))(x)
    mean = tf.reshape(mean, tf.concat(
        [tf.shape(features)[:-1], self._shape], 0))
    if self._std == 'learned':
      std = self.get(f'hstd', tfkl.Dense, np.prod(self._shape))(x)
      std = tf.nn.softplus(std) + 0.01
      std = tf.reshape(std, tf.concat(
          [tf.shape(features)[:-1], self._shape], 0))
    else:
      std = self._std
    if dtype:
      mean, std = tf.cast(mean, dtype), tf.cast(std, dtype)
    if self._dist == 'normal':
      return tfd.Independent(tfd.Normal(mean, std), len(self._shape))
    if self._dist == 'huber':
      return tfd.Independent(
          tools.UnnormalizedHuber(mean, std, 1.0), len(self._shape))
    if self._dist == 'binary':
      return tfd.Independent(tfd.Bernoulli(mean), len(self._shape))
    raise NotImplementedError(self._dist)

class ActionHead(tools.Module):

  def __init__(
      self, size, layers, units, act=tf.nn.elu, dist='trunc_normal',
      init_std=0.0, min_std=0.1, action_disc=5, temp=0.1, outscale=0):
    # assert min_std <= 2
    self._size = size
    self._layers = layers
    self._units = units
    self._dist = dist
    self._act = act
    self._min_std = min_std
    self._init_std = init_std
    self._action_disc = action_disc
    self._temp = temp() if callable(temp) else temp
    self._outscale = outscale

  def __call__(self, features, dtype=None):
    x = features
    for index in range(self._layers):
      kw = {}
      if index == self._layers - 1 and self._outscale:
        kw['kernel_initializer'] = tf.keras.initializers.VarianceScaling(
            self._outscale)
      x = self.get(f'h{index}', tfkl.Dense, self._units, self._act, **kw)(x)
    if self._dist == 'tanh_normal':
      # https://www.desmos.com/calculator/rcmcf5jwe7
      x = self.get(f'hout', tfkl.Dense, 2 * self._size)(x)
      if dtype:
        x = tf.cast(x, dtype)
      mean, std = tf.split(x, 2, -1)
      mean = tf.tanh(mean)
      std = tf.nn.softplus(std + self._init_std) + self._min_std
      dist = tfd.Normal(mean, std)
      dist = tfd.TransformedDistribution(dist, tools.TanhBijector())
      dist = tfd.Independent(dist, 1)
      dist = tools.SampleDist(dist)
    elif self._dist == 'tanh_normal_5':
      x = self.get(f'hout', tfkl.Dense, 2 * self._size)(x)
      if dtype:
        x = tf.cast(x, dtype)
      mean, std = tf.split(x, 2, -1)
      mean = 5 * tf.tanh(mean / 5)
      std = tf.nn.softplus(std + 5) + 5
      dist = tfd.Normal(mean, std)
      dist = tfd.TransformedDistribution(dist, tools.TanhBijector())
      dist = tfd.Independent(dist, 1)
      dist = tools.SampleDist(dist)
    elif self._dist == 'normal':
      x = self.get(f'hout', tfkl.Dense, 2 * self._size)(x)
      if dtype:
        x = tf.cast(x, dtype)
      mean, std = tf.split(x, 2, -1)
      std = tf.nn.softplus(std + self._init_std) + self._min_std
      dist = tfd.Normal(mean, std)
      dist = tfd.Independent(dist, 1)
    elif self._dist == 'normal_1':
      mean = self.get(f'hout', tfkl.Dense, self._size)(x)
      if dtype:
        mean = tf.cast(mean, dtype)
      dist = tfd.Normal(mean, 1)
      dist = tfd.Independent(dist, 1)
    elif self._dist == 'trunc_normal':
      # https://www.desmos.com/calculator/mmuvuhnyxo
      x = self.get(f'hout', tfkl.Dense, 2 * self._size)(x)
      x = tf.cast(x, tf.float32)
      mean, std = tf.split(x, 2, -1)
      mean = tf.tanh(mean)
      std = 2 * tf.nn.sigmoid(std / 2) + self._min_std
      dist = tools.SafeTruncatedNormal(mean, std, -1, 1)
      dist = tools.DtypeDist(dist, dtype)
      dist = tfd.Independent(dist, 1)
    elif self._dist == 'onehot':
      x = self.get(f'hout', tfkl.Dense, self._size)(x)
      x = tf.cast(x, tf.float32)
      dist = tools.OneHotDist(x, dtype=dtype)
      dist = tools.DtypeDist(dist, dtype)
    elif self._dist == 'onehot_gumble':
      x = self.get(f'hout', tfkl.Dense, self._size)(x)
      if dtype:
        x = tf.cast(x, dtype)
      temp = self._temp
      dist = tools.GumbleDist(temp, x, dtype=dtype)
    else:
      raise NotImplementedError(self._dist)
    return dist

class GRUCell(tf.keras.layers.AbstractRNNCell):

  def __init__(self, size, norm=False, act=tf.tanh, update_bias=-1, **kwargs):
    super().__init__()
    self._size = size
    self._act = act
    self._norm = norm
    self._update_bias = update_bias
    self._layer = tfkl.Dense(3 * size, use_bias=norm is not None, **kwargs)
    if norm:
      self._norm = tfkl.LayerNormalization(dtype=tf.float32)

  @property
  def state_size(self):
    return self._size

  def call(self, inputs, state):
    state = state[0]  # Keras wraps the state in a list.
    parts = self._layer(tf.concat([inputs, state], -1))
    if self._norm:
      dtype = parts.dtype
      parts = tf.cast(parts, tf.float32)
      parts = self._norm(parts)
      parts = tf.cast(parts, dtype)
    reset, cand, update = tf.split(parts, 3, -1)
    reset = tf.nn.sigmoid(reset)
    cand = self._act(reset * cand)
    update = tf.nn.sigmoid(update + self._update_bias)
    output = update * cand + (1 - update) * state
    return output, [output]

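The gate arithmetic in `GRUCell.call` above can be checked outside TensorFlow with a small NumPy sketch. It is an illustration under stated assumptions, not the repository's implementation: `gru_step`, `kernel`, and `bias` are hypothetical names standing in for the fused `Dense(3 * size)` layer, and the optional layer normalization is left out.

```python
import numpy as np


def sigmoid(x):
  return 1.0 / (1.0 + np.exp(-x))


def gru_step(inputs, state, kernel, bias, update_bias=-1.0):
  """One step of the fused GRU variant, without layer normalization."""
  # A single affine map produces the reset, candidate, and update parts.
  parts = np.concatenate([inputs, state], -1) @ kernel + bias
  reset, cand, update = np.split(parts, 3, -1)
  reset = sigmoid(reset)
  cand = np.tanh(reset * cand)
  # The negative bias keeps the update gate small at initialization.
  update = sigmoid(update + update_bias)
  return update * cand + (1 - update) * state


inputs = np.ones((2, 4))
state = np.full((2, 3), 0.5)
kernel = np.zeros((4 + 3, 3 * 3))  # All-zero weights isolate the gate bias.
bias = np.zeros(3 * 3)
output = gru_step(inputs, state, kernel, bias)
print(output)
```

With all-zero weights the candidate is zero and the update gate equals `sigmoid(-1)`, so the output is the previous state scaled by `1 - sigmoid(-1)`. This shows how `update_bias=-1` biases the cell toward keeping most of its state.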
+540
@@ -0,0 +1,540 @@
import argparse
import collections
import functools
import itertools
import json
import multiprocessing as mp
import os
import pathlib
import re
import subprocess
import warnings

os.environ['NO_AT_BRIDGE'] = '1'  # Hide X org false warning.

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
import pandas as pd

np.set_string_function(lambda x: f'<np.array shape={x.shape} dtype={x.dtype}>')

Run = collections.namedtuple('Run', 'task method seed xs ys')

PALETTES = dict(
    discrete=(
        '#377eb8', '#4daf4a', '#984ea3', '#e41a1c', '#ff7f00', '#a65628',
        '#f781bf', '#888888', '#a6cee3', '#b2df8a', '#cab2d6', '#fb9a99',
    ),
    contrast=(
        '#0022ff', '#33aa00', '#ff0011', '#ddaa00', '#cc44dd', '#0088aa',
        '#001177', '#117700', '#990022', '#885500', '#553366', '#006666',
    ),
    gradient=(
        '#fde725', '#a0da39', '#4ac16d', '#1fa187', '#277f8e', '#365c8d',
        '#46327e', '#440154',
    ),
    baselines=(
        '#222222', '#666666', '#aaaaaa', '#cccccc',
    ),
)

LEGEND = dict(
    fontsize='medium', numpoints=1, labelspacing=0, columnspacing=1.2,
    handlelength=1.5, handletextpad=0.5, ncol=4, loc='lower center')

DEFAULT_BASELINES = [
    'd4pg', 'dqn_sticky', 'rainbow_sticky', 'human$', 'impala']

BINS = collections.defaultdict(int)
BINS.update(dmc=1e5, atari=1e6, particle=1e5)


def find_keys(args):
  filenames = []
  for indir in args.indir:
    task = next(indir.iterdir())  # First only.
    for method in task.iterdir():
      # Runs are laid out as indir/task/method/seed, so take the first seed
      # directory of this method (the original iterated `indir` here).
      seed = next(method.iterdir())  # First only.
      filenames += list(seed.glob('**/*.jsonl'))
  keys = set()
  for filename in filenames:
    keys |= set(load_jsonl(filename).columns)
  print(f'Keys ({len(keys)}):', ', '.join(keys), flush=True)


def load_runs(args):
  total, toload = [], []
  for indir in args.indir:
    filenames = list(indir.glob('**/*.jsonl'))
    total += filenames
    for filename in filenames:
      task, method, seed = filename.relative_to(indir).parts[:-1]
      if not any(p.search(task) for p in args.tasks):
        continue
      if not any(p.search(method) for p in args.methods):
        continue
      toload.append((filename, indir))
  print(f'Loading {len(toload)} of {len(total)} runs...')
  jobs = [functools.partial(load_run, f, i, args) for f, i in toload]
  # Disable async data loading:
  # runs = [j() for j in jobs]
  with mp.Pool(10) as pool:
    promises = [pool.apply_async(j) for j in jobs]
    runs = [p.get() for p in promises]
  runs = [r for r in runs if r is not None]
  return runs


def load_run(filename, indir, args):
  task, method, seed = filename.relative_to(indir).parts[:-1]
  prefix = f'indir{args.indir.index(indir)+1}_'
  if task == 'atari_jamesbond':
    task = 'atari_james_bond'
  seed = prefix + seed
  if args.prefix:
    method = prefix + method
  df = load_jsonl(filename)
  if df is None:
    print('Skipping empty run')
    return
  try:
    df = df[[args.xaxis, args.yaxis]].dropna()
    if args.maxval:
      df = df.replace([+np.inf], +args.maxval)
      df = df.replace([-np.inf], -args.maxval)
      df[args.yaxis] = df[args.yaxis].clip(-args.maxval, +args.maxval)
  except KeyError:
    return
  xs = df[args.xaxis].to_numpy()
  ys = df[args.yaxis].to_numpy()
  bins = BINS[task.split('_')[0]] if args.bins == -1 else args.bins
  if bins:
    borders = np.arange(0, xs.max() + 1e-8, bins)
    xs, ys = bin_scores(xs, ys, borders)
  if not len(xs):
    print('Skipping empty run', task, method, seed)
    return
  return Run(task, method, seed, xs, ys)


def load_baselines(patterns, prefix=False):
  runs = []
  directory = pathlib.Path(__file__).parent / 'scores'
  for filename in directory.glob('**/*_baselines.json'):
    for task, methods in json.loads(filename.read_text()).items():
      for method, score in methods.items():
        if prefix:
          method = f'baseline_{method}'
        if not any(p.search(method) for p in patterns):
          continue
        runs.append(Run(task, method, None, None, score))
  return runs


def stats(runs, baselines):
  tasks = sorted(set(r.task for r in runs))
  methods = sorted(set(r.method for r in runs))
  seeds = sorted(set(r.seed for r in runs))
  baseline = sorted(set(r.method for r in baselines))
  print('Loaded', len(runs), 'runs.')
  print(f'Tasks ({len(tasks)}):', ', '.join(tasks))
  print(f'Methods ({len(methods)}):', ', '.join(methods))
  print(f'Seeds ({len(seeds)}):', ', '.join(seeds))
  print(f'Baselines ({len(baseline)}):', ', '.join(baseline))


def order_methods(runs, baselines, args):
  methods = []
  for pattern in args.methods:
    for method in sorted(set(r.method for r in runs)):
      if pattern.search(method):
        if method not in methods:
          methods.append(method)
        if method not in args.colors:
          index = len(args.colors) % len(args.palette)
          args.colors[method] = args.palette[index]
  non_baseline_colors = len(args.colors)
  for pattern in args.baselines:
    for method in sorted(set(r.method for r in baselines)):
      if pattern.search(method):
        if method not in methods:
          methods.append(method)
        if method not in args.colors:
          index = len(args.colors) - non_baseline_colors
          index = index % len(PALETTES['baselines'])
          args.colors[method] = PALETTES['baselines'][index]
  return methods


def figure(runs, methods, args):
  tasks = sorted(set(r.task for r in runs if r.xs is not None))
  rows = int(np.ceil((len(tasks) + len(args.add)) / args.cols))
  figsize = args.size[0] * args.cols, args.size[1] * rows
  # squeeze=False keeps `axes` two-dimensional even for a single row or
  # column, which the axes[-1] and axes[:, 0] indexing below relies on.
  fig, axes = plt.subplots(rows, args.cols, figsize=figsize, squeeze=False)
  for task, ax in zip(tasks, axes.flatten()):
    relevant = [r for r in runs if r.task == task]
    plot(task, ax, relevant, methods, args)
  for name, ax in zip(args.add, axes.flatten()[len(tasks):]):
    ax.set_facecolor((0.9, 0.9, 0.9))
    if name == 'median':
      plot_combined(
          'combined_median', ax, runs, methods, args,
          lo='random', hi='human$',
          agg=lambda x: np.nanmedian(x, -1))
    elif name == 'mean':
      plot_combined(
          'combined_mean', ax, runs, methods, args,
          lo='random', hi='human$',
          agg=lambda x: np.nanmean(x, -1))
    elif name == 'gamer_median':
      plot_combined(
          'combined_gamer_median', ax, runs, methods, args,
          lo='random', hi='human$',
          agg=lambda x: np.nanmedian(x, -1))
    elif name == 'gamer_mean':
      plot_combined(
          'combined_gamer_mean', ax, runs, methods, args,
          lo='random', hi='human$',
          agg=lambda x: np.nanmean(x, -1))
    elif name == 'record_mean':
      plot_combined(
          'combined_record_mean', ax, runs, methods, args,
          lo='random', hi='record',
          agg=lambda x: np.nanmean(x, -1))
    elif name == 'clipped_record_mean':
      plot_combined(
          'combined_clipped_record_mean', ax, runs, methods, args,
          lo='random', hi='record', clip=True,
          agg=lambda x: np.nanmean(x, -1))
    elif name == 'num_seeds':
      plot_combined(
          'combined_num_seeds', ax, runs, methods, args,
          agg=lambda x: np.isfinite(x).sum(-1))
    elif name == 'human_above':
      plot_combined(
          'combined_above_human$', ax, runs, methods, args,
          agg=lambda y: (y >= 1.0).astype(float).sum(-1))
    elif name == 'human_below':
      plot_combined(
          'combined_below_human$', ax, runs, methods, args,
          agg=lambda y: (y <= 1.0).astype(float).sum(-1))
    else:
      raise NotImplementedError(name)
  if args.xlim:
    for ax in axes[:-1].flatten():
      ax.xaxis.get_offset_text().set_visible(False)
  if args.xlabel:
    for ax in axes[-1]:
      ax.set_xlabel(args.xlabel)
  if args.ylabel:
    for ax in axes[:, 0]:
      ax.set_ylabel(args.ylabel)
  for ax in axes.flatten()[len(tasks) + len(args.add):]:
    ax.axis('off')
  legend(fig, args.labels, **LEGEND)
  return fig


def plot(task, ax, runs, methods, args):
|
||||
assert runs
|
||||
try:
|
||||
title = task.split('_', 1)[1].replace('_', ' ').title()
|
||||
except IndexError:
|
||||
title = task.title()
|
||||
ax.set_title(title)
|
||||
xlim = [+np.inf, -np.inf]
|
||||
for index, method in enumerate(methods):
|
||||
relevant = [r for r in runs if r.method == method]
|
||||
if not relevant:
|
||||
continue
|
||||
if any(r.xs is None for r in relevant):
|
||||
baseline(index, method, ax, relevant, args)
|
||||
else:
|
||||
if args.aggregate == 'none':
|
||||
xs, ys = curve_lines(index, task, method, ax, relevant, args)
|
||||
else:
|
||||
xs, ys = curve_area(index, task, method, ax, relevant, args)
|
||||
if len(xs) == len(ys) == 0:
|
||||
print(f'Skipping empty: {task} {method}')
|
||||
continue
|
||||
xlim = [min(xlim[0], xs.min()), max(xlim[1], xs.max())]
|
||||
ax.ticklabel_format(axis='x', style='sci', scilimits=(0, 0))
|
||||
steps = [1, 2, 2.5, 5, 10]
|
||||
ax.xaxis.set_major_locator(ticker.MaxNLocator(args.xticks, steps=steps))
|
||||
ax.yaxis.set_major_locator(ticker.MaxNLocator(args.yticks, steps=steps))
|
||||
if np.isfinite(xlim).all():
|
||||
ax.set_xlim(args.xlim or xlim)
|
||||
if args.xlim:
|
||||
ticks = sorted({*ax.get_xticks(), *args.xlim})
|
||||
ticks = [x for x in ticks if args.xlim[0] <= x <= args.xlim[1]]
|
||||
ax.set_xticks(ticks)
|
||||
if args.ylim:
|
||||
ax.set_ylim(args.ylim)
|
||||
if args.ylimticks:
|
||||
ticks = sorted({*ax.get_yticks(), *args.ylim})
|
||||
ticks = [x for x in ticks if args.ylim[0] <= x <= args.ylim[1]]
|
||||
ax.set_yticks(ticks)
|
||||
|
||||
|
||||
def plot_combined(
|
||||
name, ax, runs, methods, args, agg, lo=None, hi=None, clip=False):
|
||||
tasks = sorted(set(run.task for run in runs if run.xs is not None))
|
||||
seeds = list(set(run.seed for run in runs))
|
||||
runs = [r for r in runs if r.task in tasks] # Discard unused baselines.
|
||||
# Bin all runs onto the same X steps.
|
||||
borders = sorted(
|
||||
[r.xs for r in runs if r.xs is not None],
|
||||
key=lambda x: np.nanmax(x))[-1]
|
||||
for index, run in enumerate(runs):
|
||||
if run.xs is None:
|
||||
continue
|
||||
xs, ys = bin_scores(run.xs, run.ys, borders)
|
||||
runs[index] = run._replace(xs=xs, ys=ys)
|
||||
# Per-task normalization by low and high baseline.
|
||||
if lo or hi:
|
||||
mins = collections.defaultdict(list)
|
||||
maxs = collections.defaultdict(list)
|
||||
[mins[r.task].append(r.ys) for r in load_baselines([re.compile(lo)])]
|
||||
[maxs[r.task].append(r.ys) for r in load_baselines([re.compile(hi)])]
|
||||
mins = {task: min(ys) for task, ys in mins.items() if task in tasks}
|
||||
maxs = {task: max(ys) for task, ys in maxs.items() if task in tasks}
|
||||
missing_baselines = []
|
||||
for task in tasks:
|
||||
if task not in mins or task not in maxs:
|
||||
missing_baselines.append(task)
|
||||
if set(missing_baselines) == set(tasks):
|
||||
print(f'No baselines found to normalize any tasks in {name} plot.')
|
||||
else:
|
||||
for task in missing_baselines:
|
||||
print(f'No baselines found to normalize {task} in {name} plot.')
|
||||
for index, run in enumerate(runs):
|
||||
if run.task not in mins or run.task not in maxs:
|
||||
continue
|
||||
ys = (run.ys - mins[run.task]) / (maxs[run.task] - mins[run.task])
|
||||
if clip:
|
||||
ys = np.minimum(ys, 1.0)
|
||||
runs[index] = run._replace(ys=ys)
|
||||
# Aggregate across tasks but not methods or seeds.
|
||||
combined = []
|
||||
for method, seed in itertools.product(methods, seeds):
|
||||
relevant = [r for r in runs if r.method == method and r.seed == seed]
|
||||
if not relevant:
|
||||
continue
|
||||
if relevant[0].xs is None:
|
||||
xs, ys = None, np.array([r.ys for r in relevant])
|
||||
else:
|
||||
xs, ys = stack_scores(*zip(*[(r.xs, r.ys) for r in relevant]))
|
||||
with warnings.catch_warnings(): # Ignore empty slice warnings.
|
||||
warnings.simplefilter('ignore', category=RuntimeWarning)
|
||||
combined.append(Run('combined', method, seed, xs, agg(ys)))
|
||||
plot(name, ax, combined, methods, args)
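The per-task normalization in `plot_combined` maps each score so that the low baseline lands at 0 and the high baseline at 1, and `clip=True` caps the result at 1. A minimal pure-Python sketch of that arithmetic follows. The `normalize_scores` helper and the sample numbers are hypothetical and not part of this repository.

```python
def normalize_scores(scores, lo, hi, clip=False):
  """Rescale scores so that `lo` maps to 0.0 and `hi` maps to 1.0."""
  span = hi - lo
  normed = [(score - lo) / span for score in scores]
  if clip:
    # Cap at the high baseline, analogous to np.minimum(ys, 1.0).
    normed = [min(value, 1.0) for value in normed]
  return normed


# Hypothetical task where the low baseline scores 100 and the high one 500.
raw = [100.0, 300.0, 700.0]
plain = normalize_scores(raw, lo=100.0, hi=500.0)
clipped = normalize_scores(raw, lo=100.0, hi=500.0, clip=True)
print(plain, clipped)
```

Clipping is presumably there so that one task with a very high normalized score cannot dominate a mean taken across tasks.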
|
||||
|
||||
|
||||
def curve_lines(index, task, method, ax, runs, args):
|
||||
zorder = 10000 - 10 * index - 1
|
||||
for run in runs:
|
||||
color = args.colors[method]
|
||||
ax.plot(run.xs, run.ys, label=method, color=color, zorder=zorder)
|
||||
return runs[0].xs, runs[0].ys
|
||||
|
||||
|
||||
def curve_area(index, task, method, ax, runs, args):
|
||||
xs, ys = stack_scores(*zip(*[(r.xs, r.ys) for r in runs]))
|
||||
with warnings.catch_warnings(): # NaN buckets remain NaN.
|
||||
warnings.simplefilter('ignore', category=RuntimeWarning)
|
||||
if args.aggregate == 'std1':
|
||||
mean, std = np.nanmean(ys, -1), np.nanstd(ys, -1)
|
||||
lo, mi, hi = mean - std, mean, mean + std
|
||||
elif args.aggregate == 'per0':
|
||||
lo, mi, hi = [np.nanpercentile(ys, k, -1) for k in (0, 50, 100)]
|
||||
elif args.aggregate == 'per5':
|
||||
lo, mi, hi = [np.nanpercentile(ys, k, -1) for k in (5, 50, 95)]
|
||||
elif args.aggregate == 'per25':
|
||||
lo, mi, hi = [np.nanpercentile(ys, k, -1) for k in (25, 50, 75)]
|
||||
else:
|
||||
raise NotImplementedError(args.aggregate)
|
||||
color = args.colors[method]
|
||||
kw = dict(color=color, zorder=1000 - 10 * index, alpha=0.1, linewidths=0)
|
||||
ax.fill_between(xs, lo, hi, **kw)
|
||||
ax.plot(xs, mi, label=method, color=color, zorder=10000 - 10 * index - 1)
|
||||
return xs, mi
|
||||
|
||||
|
||||
def baseline(index, method, ax, runs, args):
|
||||
assert all(run.xs is None for run in runs)
|
||||
ys = np.array([run.ys for run in runs])
|
||||
mean, std = ys.mean(), ys.std()
|
||||
color = args.colors[method]
|
||||
kw = dict(color=color, zorder=500 - 20 * index - 1, alpha=0.1, linewidths=0)
|
||||
ax.fill_between([-np.inf, np.inf], [mean - std] * 2, [mean + std] * 2, **kw)
|
||||
kw = dict(ls='--', color=color, zorder=5000 - 10 * index - 1)
|
||||
ax.axhline(mean, label=method, **kw)
|
||||
|
||||
|
||||
def legend(fig, mapping=None, **kwargs):
|
||||
entries = {}
|
||||
for ax in fig.axes:
|
||||
for handle, label in zip(*ax.get_legend_handles_labels()):
|
||||
if mapping and label in mapping:
|
||||
label = mapping[label]
|
||||
entries[label] = handle
|
||||
leg = fig.legend(entries.values(), entries.keys(), **kwargs)
|
||||
leg.get_frame().set_edgecolor('white')
|
||||
extent = leg.get_window_extent(fig.canvas.get_renderer())
|
||||
extent = extent.transformed(fig.transFigure.inverted())
|
||||
yloc, xloc = kwargs['loc'].split()
|
||||
y0 = dict(lower=extent.y1, center=0, upper=0)[yloc]
|
||||
y1 = dict(lower=1, center=1, upper=extent.y0)[yloc]
|
||||
x0 = dict(left=extent.x1, center=0, right=0)[xloc]
|
||||
x1 = dict(left=1, center=1, right=extent.x0)[xloc]
|
||||
fig.tight_layout(rect=[x0, y0, x1, y1], h_pad=0.5, w_pad=0.5)
|
||||
|
||||
|
||||
def save(fig, args):
|
||||
args.outdir.mkdir(parents=True, exist_ok=True)
|
||||
filename = args.outdir / 'curves.png'
|
||||
fig.savefig(filename, dpi=args.dpi)
|
||||
print('Saved to', filename)
|
||||
filename = args.outdir / 'curves.pdf'
|
||||
fig.savefig(filename)
|
||||
try:
|
||||
subprocess.call(['pdfcrop', str(filename), str(filename)])
|
||||
except FileNotFoundError:
|
||||
print('Install texlive-extra-utils to crop PDF outputs.')
|
||||
|
||||
|
||||
def bin_scores(xs, ys, borders, reducer=np.nanmean):
|
||||
order = np.argsort(xs)
|
||||
xs, ys = xs[order], ys[order]
|
||||
binned = []
|
||||
with warnings.catch_warnings(): # Empty buckets become NaN.
|
||||
warnings.simplefilter('ignore', category=RuntimeWarning)
|
||||
for start, stop in zip(borders[:-1], borders[1:]):
|
||||
left = (xs <= start).sum()
|
||||
right = (xs <= stop).sum()
|
||||
binned.append(reducer(ys[left:right]))
|
||||
return borders[1:], np.array(binned)
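The bucketing rule in `bin_scores` can be sketched without NumPy: each bucket covers the half-open interval `(start, stop]`, and an empty bucket yields NaN. The name `bin_scores_sketch` and the sample data are hypothetical, and the mean is hard-coded here instead of taking a `reducer` argument.

```python
import math


def bin_scores_sketch(xs, ys, borders):
  """Average ys per bucket (start, stop]; empty buckets become NaN."""
  pairs = sorted(zip(xs, ys))
  binned = []
  for start, stop in zip(borders[:-1], borders[1:]):
    bucket = [y for x, y in pairs if start < x <= stop]
    binned.append(sum(bucket) / len(bucket) if bucket else math.nan)
  return borders[1:], binned


# Hypothetical run logged at irregular steps, binned every 2 steps.
steps, means = bin_scores_sketch(
    xs=[1, 2, 3, 7], ys=[10.0, 20.0, 30.0, 70.0], borders=[0, 2, 4, 6, 8])
print(steps, means)
```

Each bucket is reported at its right border, which matches the `borders[1:]` returned by the listing.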
|
||||
|
||||
|
||||
def stack_scores(multiple_xs, multiple_ys):
|
||||
longest_xs = sorted(multiple_xs, key=lambda x: len(x))[-1]
|
||||
multiple_padded_ys = []
|
||||
for xs, ys in zip(multiple_xs, multiple_ys):
|
||||
assert (longest_xs[:len(xs)] == xs).all(), (list(xs), list(longest_xs))
|
||||
padding = [np.inf] * (len(longest_xs) - len(xs))
|
||||
padded_ys = np.concatenate([ys, padding])
|
||||
multiple_padded_ys.append(padded_ys)
|
||||
stacked_ys = np.stack(multiple_padded_ys, -1)
|
||||
return longest_xs, stacked_ys
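A small sketch of the padding step in `stack_scores`, assuming plain Python lists: shorter runs are extended with a fill value (infinity in the listing) and the result is arranged with one row per step and one column per run. The helper name `stack_scores_sketch` is hypothetical, and the check that all runs share the same x values is omitted.

```python
def stack_scores_sketch(multiple_ys, fill=float('inf')):
  """Pad shorter runs with `fill` and return one row per step."""
  longest = max(len(ys) for ys in multiple_ys)
  padded = [list(ys) + [fill] * (longest - len(ys)) for ys in multiple_ys]
  # Transpose: one row per step, one column per run.
  return [list(row) for row in zip(*padded)]


stacked = stack_scores_sketch([[1.0, 2.0, 3.0], [4.0, 5.0]])
print(stacked)
```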
|
||||
|
||||
|
||||
def load_jsonl(filename):
|
||||
try:
|
||||
with filename.open() as f:
|
||||
lines = list(f.readlines())
|
||||
records = []
|
||||
for index, line in enumerate(lines):
|
||||
try:
|
||||
records.append(json.loads(line))
|
||||
except Exception:
|
||||
if index == len(lines) - 1:
|
||||
continue # Silently skip last line if it is incomplete.
|
||||
raise ValueError(
|
||||
f'Skipping invalid JSON line ({index+1}/{len(lines)}) in '
|
||||
f'{filename}: {line}')
|
||||
return pd.DataFrame(records)
|
||||
except ValueError as e:
|
||||
print('Invalid', filename, e)
|
||||
return None
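The tolerance rule in `load_jsonl` (ignore a broken final line, reject a broken line anywhere else) can be shown on an in-memory string. This is a sketch using only the standard library. `parse_jsonl` is a hypothetical helper that returns a list of dicts rather than a `pd.DataFrame`.

```python
import json


def parse_jsonl(text):
  """Parse JSON lines, tolerating only a truncated final line."""
  lines = text.splitlines()
  records = []
  for index, line in enumerate(lines):
    try:
      records.append(json.loads(line))
    except json.JSONDecodeError:
      if index == len(lines) - 1:
        continue  # The writer was probably interrupted mid-line.
      raise ValueError(
          f'Invalid JSON on line {index + 1} of {len(lines)}: {line}')
  return records


# The last line is cut off, as happens when a run is still being written.
records = parse_jsonl('{"step": 1}\n{"step": 2}\n{"ste')
print(records)
```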
|
||||
|
||||
|
||||
def save_runs(runs, filename):
|
||||
filename.parent.mkdir(parents=True, exist_ok=True)
|
||||
records = []
|
||||
for run in runs:
|
||||
if run.xs is None:
|
||||
continue
|
||||
records.append(dict(
|
||||
task=run.task, method=run.method, seed=run.seed,
|
||||
xs=run.xs.tolist(), ys=run.ys.tolist()))
|
||||
runs = json.dumps(records)
|
||||
filename.write_text(runs)
|
||||
print('Saved', filename)
|
||||
|
||||
|
||||
def main(args):
|
||||
find_keys(args)
|
||||
runs = load_runs(args)
|
||||
save_runs(runs, args.outdir / 'runs.jsonl')
|
||||
baselines = load_baselines(args.baselines, args.prefix)
|
||||
stats(runs, baselines)
|
||||
methods = order_methods(runs, baselines, args)
|
||||
if not runs:
|
||||
print('Nothing to plot.')
|
||||
return
|
||||
print('Plotting...')
|
||||
fig = figure(runs + baselines, methods, args)
|
||||
save(fig, args)
|
||||
|
||||
|
||||
def parse_args():
|
||||
boolean = lambda x: bool(['False', 'True'].index(x))
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--indir', nargs='+', type=pathlib.Path, required=True)
|
||||
parser.add_argument('--outdir', type=pathlib.Path, required=True)
|
||||
parser.add_argument('--subdir', type=boolean, default=True)
|
||||
parser.add_argument('--xaxis', type=str, required=True)
|
||||
parser.add_argument('--yaxis', type=str, required=True)
|
||||
parser.add_argument('--tasks', nargs='+', default=[r'.*'])
|
||||
parser.add_argument('--methods', nargs='+', default=[r'.*'])
|
||||
parser.add_argument('--baselines', nargs='+', default=DEFAULT_BASELINES)
|
||||
parser.add_argument('--prefix', type=boolean, default=False)
|
||||
parser.add_argument('--bins', type=float, default=-1)
|
||||
parser.add_argument('--aggregate', type=str, default='std1')
|
||||
parser.add_argument('--size', nargs=2, type=float, default=[2.5, 2.3])
|
||||
parser.add_argument('--dpi', type=int, default=80)
|
||||
parser.add_argument('--cols', type=int, default=6)
|
||||
parser.add_argument('--xlim', nargs=2, type=float, default=None)
|
||||
parser.add_argument('--ylim', nargs=2, type=float, default=None)
|
||||
parser.add_argument('--ylimticks', type=boolean, default=True)
|
||||
parser.add_argument('--xlabel', type=str, default=None)
|
||||
parser.add_argument('--ylabel', type=str, default=None)
|
||||
parser.add_argument('--xticks', type=int, default=6)
|
||||
parser.add_argument('--yticks', type=int, default=5)
|
||||
parser.add_argument('--labels', nargs='+', default=None)
|
||||
parser.add_argument('--palette', nargs='+', default=['contrast'])
|
||||
parser.add_argument('--colors', nargs='+', default={})
|
||||
parser.add_argument('--maxval', type=float, default=0)
|
||||
parser.add_argument('--add', nargs='+', type=str, default=[
|
||||
'gamer_median', 'gamer_mean', 'record_mean',
|
||||
'clipped_record_mean', 'num_seeds'])
|
||||
args = parser.parse_args()
|
||||
if args.subdir:
|
||||
args.outdir /= args.indir[0].stem
|
||||
args.indir = [d.expanduser() for d in args.indir]
|
||||
args.outdir = args.outdir.expanduser()
|
||||
if args.labels:
|
||||
assert len(args.labels) % 2 == 0
|
||||
args.labels = {k: v for k, v in zip(args.labels[::2], args.labels[1::2])}
|
||||
if args.colors:
|
||||
assert len(args.colors) % 2 == 0
|
||||
args.colors = {k: v for k, v in zip(args.colors[::2], args.colors[1::2])}
|
||||
args.tasks = [re.compile(p) for p in args.tasks]
|
||||
args.methods = [re.compile(p) for p in args.methods]
|
||||
args.baselines = [re.compile(p) for p in args.baselines]
|
||||
if 'return' not in args.yaxis:
|
||||
args.baselines = []
|
||||
if args.prefix is None:
|
||||
args.prefix = len(args.indir) > 1
|
||||
if len(args.palette) == 1 and args.palette[0] in PALETTES:
|
||||
args.palette = 10 * PALETTES[args.palette[0]]
|
||||
if len(args.add) == 1 and args.add[0] == 'none':
|
||||
args.add = []
|
||||
return args
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main(parse_args())
|
||||
File diff suppressed because one or more lines are too long
@@ -0,0 +1 @@
|
||||
{"dmc_acrobot_swingup": {"d4pg_100m": 91.7, "a3c_100m_proprio": 41.9}, "dmc_cartpole_balance": {"d4pg_100m": 992.8, "a3c_100m_proprio": 951.6}, "dmc_cartpole_swingup": {"d4pg_100m": 862.0, "planet_1e6": 821, "a3c_100m_proprio": 558.4}, "dmc_cartpole_balance_sparse": {"d4pg_100m": 1000.0, "a3c_100m_proprio": 857.4}, "dmc_cartpole_swingup_sparse": {"d4pg_100m": 482.0, "a3c_100m_proprio": 179.8}, "dmc_cheetah_run": {"slac_3e6": 880, "d4pg_100m": 523.8, "planet_1e6": 662, "a3c_100m_proprio": 213.9}, "dmc_cup_catch": {"slac_3e6": 970, "d4pg_100m": 980.5, "planet_1e6": 930, "a3c_100m_proprio": 104.7}, "dmc_finger_spin": {"slac_3e6": 950, "d4pg_100m": 985.7, "planet_1e6": 700, "a3c_100m_proprio": 129.4}, "dmc_finger_turn_easy": {"d4pg_100m": 971.4, "a3c_100m_proprio": 167.3}, "dmc_finger_turn_hard": {"d4pg_100m": 966.0, "a3c_100m_proprio": 88.7}, "dmc_hopper_hop": {"d4pg_100m": 242.0, "a3c_100m_proprio": 0.5}, "dmc_hopper_stand": {"d4pg_100m": 929.9, "a3c_100m_proprio": 27.9}, "dmc_reacher_easy": {"d4pg_100m": 967.4, "planet_1e6": 832, "a3c_100m_proprio": 95.6}, "dmc_reacher_hard": {"d4pg_100m": 957.1, "a3c_100m_proprio": 39.7}, "dmc_walker_stand": {"d4pg_100m": 985.2, "a3c_100m_proprio": 378.4}, "dmc_walker_walk": {"slac_3e6": 840, "d4pg_100m": 968.3, "planet_1e6": 951, "a3c_100m_proprio": 311.0}, "dmc_walker_run": {"d4pg_100m": 567.2, "a3c_100m_proprio": 191.8}, "dmc_pendulum_swingup": {"d4pg_100m": 680.9, "a3c_100m_proprio": 48.6}}
|
||||
@@ -0,0 +1,694 @@
|
||||
import datetime
|
||||
import io
|
||||
import json
|
||||
import pathlib
|
||||
import pickle
|
||||
import re
|
||||
import time
|
||||
import uuid
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
import tensorflow.compat.v1 as tf1
|
||||
import tensorflow_probability as tfp
|
||||
from tensorflow.keras.mixed_precision import experimental as prec
|
||||
from tensorflow_probability import distributions as tfd
|
||||
|
||||
|
||||
# Patch to ignore seed to avoid synchronization across GPUs.
|
||||
_orig_random_categorical = tf.random.categorical
|
||||
def random_categorical(*args, **kwargs):
|
||||
kwargs['seed'] = None
|
||||
return _orig_random_categorical(*args, **kwargs)
|
||||
tf.random.categorical = random_categorical
|
||||
|
||||
# Patch to ignore seed to avoid synchronization across GPUs.
|
||||
_orig_random_normal = tf.random.normal
|
||||
def random_normal(*args, **kwargs):
|
||||
kwargs['seed'] = None
|
||||
return _orig_random_normal(*args, **kwargs)
|
||||
tf.random.normal = random_normal
|
||||
|
||||
|
||||
class AttrDict(dict):
|
||||
|
||||
__setattr__ = dict.__setitem__
|
||||
__getattr__ = dict.__getitem__
|
||||
|
||||
|
||||
class Module(tf.Module):
|
||||
|
||||
def save(self, filename):
|
||||
values = tf.nest.map_structure(lambda x: x.numpy(), self.variables)
|
||||
amount = len(tf.nest.flatten(values))
|
||||
count = int(sum(np.prod(x.shape) for x in tf.nest.flatten(values)))
|
||||
print(f'Save checkpoint with {amount} tensors and {count} parameters.')
|
||||
with pathlib.Path(filename).open('wb') as f:
|
||||
pickle.dump(values, f)
|
||||
|
||||
def load(self, filename):
|
||||
with pathlib.Path(filename).open('rb') as f:
|
||||
values = pickle.load(f)
|
||||
amount = len(tf.nest.flatten(values))
|
||||
count = int(sum(np.prod(x.shape) for x in tf.nest.flatten(values)))
|
||||
print(f'Load checkpoint with {amount} tensors and {count} parameters.')
|
||||
tf.nest.map_structure(lambda x, y: x.assign(y), self.variables, values)
|
||||
|
||||
def get(self, name, ctor, *args, **kwargs):
|
||||
# Create or get layer by name to avoid mentioning it in the constructor.
|
||||
if not hasattr(self, '_modules'):
|
||||
self._modules = {}
|
||||
if name not in self._modules:
|
||||
self._modules[name] = ctor(*args, **kwargs)
|
||||
return self._modules[name]
|
||||
|
||||
|
||||
def var_nest_names(nest):
|
||||
if isinstance(nest, dict):
|
||||
items = ' '.join(f'{k}:{var_nest_names(v)}' for k, v in nest.items())
|
||||
return '{' + items + '}'
|
||||
if isinstance(nest, (list, tuple)):
|
||||
items = ' '.join(var_nest_names(v) for v in nest)
|
||||
return '[' + items + ']'
|
||||
if hasattr(nest, 'name') and hasattr(nest, 'shape'):
|
||||
return nest.name + str(nest.shape).replace(', ', 'x')
|
||||
if hasattr(nest, 'shape'):
|
||||
return str(nest.shape).replace(', ', 'x')
|
||||
return '?'
|
||||
|
||||
|
||||
class Logger:
|
||||
|
||||
def __init__(self, logdir, step):
|
||||
self._logdir = logdir
|
||||
self._writer = tf.summary.create_file_writer(str(logdir), max_queue=1000)
|
||||
self._last_step = None
|
||||
self._last_time = None
|
||||
self._scalars = {}
|
||||
self._images = {}
|
||||
self._videos = {}
|
||||
self.step = step
|
||||
|
||||
def scalar(self, name, value):
|
||||
self._scalars[name] = float(value)
|
||||
|
||||
def image(self, name, value):
|
||||
self._images[name] = np.array(value)
|
||||
|
||||
def video(self, name, value):
|
||||
self._videos[name] = np.array(value)
|
||||
|
||||
def write(self, fps=False):
|
||||
scalars = list(self._scalars.items())
|
||||
if fps:
|
||||
scalars.append(('fps', self._compute_fps(self.step)))
|
||||
print(f'[{self.step}]', ' / '.join(f'{k} {v:.1f}' for k, v in scalars))
|
||||
with (self._logdir / 'metrics.jsonl').open('a') as f:
|
||||
f.write(json.dumps({'step': self.step, **dict(scalars)}) + '\n')
|
||||
with self._writer.as_default():
|
||||
for name, value in scalars:
|
||||
tf.summary.scalar('scalars/' + name, value, self.step)
|
||||
for name, value in self._images.items():
|
||||
tf.summary.image(name, value, self.step)
|
||||
for name, value in self._videos.items():
|
||||
video_summary(name, value, self.step)
|
||||
self._writer.flush()
|
||||
self._scalars = {}
|
||||
self._images = {}
|
||||
self._videos = {}
|
||||
|
||||
def _compute_fps(self, step):
|
||||
if self._last_step is None:
|
||||
self._last_time = time.time()
|
||||
self._last_step = step
|
||||
return 0
|
||||
steps = step - self._last_step
|
||||
duration = time.time() - self._last_time
|
||||
self._last_time += duration
|
||||
self._last_step = step
|
||||
return steps / duration
|
||||
|
||||
|
||||
def graph_summary(writer, step, fn, *args):
|
||||
def inner(*args):
|
||||
tf.summary.experimental.set_step(step.numpy().item())
|
||||
with writer.as_default():
|
||||
fn(*args)
|
||||
return tf.numpy_function(inner, args, [])
|
||||
|
||||
|
||||
def video_summary(name, video, step=None, fps=20):
|
||||
name = name if isinstance(name, str) else name.decode('utf-8')
|
||||
if np.issubdtype(video.dtype, np.floating):
|
||||
video = np.clip(255 * video, 0, 255).astype(np.uint8)
|
||||
B, T, H, W, C = video.shape
|
||||
try:
|
||||
frames = video.transpose((1, 2, 0, 3, 4)).reshape((T, H, B * W, C))
|
||||
summary = tf1.Summary()
|
||||
image = tf1.Summary.Image(height=B * H, width=T * W, colorspace=C)
|
||||
image.encoded_image_string = encode_gif(frames, fps)
|
||||
summary.value.add(tag=name, image=image)
|
||||
tf.summary.experimental.write_raw_pb(summary.SerializeToString(), step)
|
||||
except (IOError, OSError) as e:
|
||||
print('GIF summaries require ffmpeg in $PATH.', e)
|
||||
frames = video.transpose((0, 2, 1, 3, 4)).reshape((1, B * H, T * W, C))
|
||||
tf.summary.image(name, frames, step)
|
||||
|
||||
|
||||
def encode_gif(frames, fps):
|
||||
from subprocess import Popen, PIPE
|
||||
h, w, c = frames[0].shape
|
||||
pxfmt = {1: 'gray', 3: 'rgb24'}[c]
|
||||
cmd = ' '.join([
|
||||
'ffmpeg -y -f rawvideo -vcodec rawvideo',
|
||||
f'-r {fps:.02f} -s {w}x{h} -pix_fmt {pxfmt} -i - -filter_complex',
|
||||
'[0:v]split[x][z];[z]palettegen[y];[x]fifo[x];[x][y]paletteuse',
|
||||
f'-r {fps:.02f} -f gif -'])
|
||||
proc = Popen(cmd.split(' '), stdin=PIPE, stdout=PIPE, stderr=PIPE)
|
||||
for image in frames:
|
||||
proc.stdin.write(image.tobytes())
|
||||
out, err = proc.communicate()
|
||||
if proc.returncode:
|
||||
raise IOError('\n'.join([cmd, err.decode('utf8')]))
|
||||
del proc
|
||||
return out
|
||||
|
||||
|
||||
def simulate(agent, envs, steps=0, episodes=0, state=None):
|
||||
# Initialize or unpack simulation state.
|
||||
if state is None:
|
||||
step, episode = 0, 0
|
||||
done = np.ones(len(envs), bool)
|
||||
length = np.zeros(len(envs), np.int32)
|
||||
obs = [None] * len(envs)
|
||||
agent_state = None
|
||||
else:
|
||||
step, episode, done, length, obs, agent_state = state
|
||||
while (steps and step < steps) or (episodes and episode < episodes):
|
||||
# Reset envs if necessary.
|
||||
if done.any():
|
||||
indices = [index for index, d in enumerate(done) if d]
|
||||
results = [envs[i].reset() for i in indices]
|
||||
for index, result in zip(indices, results):
|
||||
obs[index] = result
|
||||
# Step agents.
|
||||
obs = {k: np.stack([o[k] for o in obs]) for k in obs[0]}
|
||||
action, agent_state = agent(obs, done, agent_state)
|
||||
if isinstance(action, dict):
|
||||
action = [
|
||||
{k: np.array(action[k][i]) for k in action}
|
||||
for i in range(len(envs))]
|
||||
else:
|
||||
action = np.array(action)
|
||||
assert len(action) == len(envs)
|
||||
# Step envs.
|
||||
results = [e.step(a) for e, a in zip(envs, action)]
|
||||
obs, _, done = zip(*[p[:3] for p in results])
|
||||
obs = list(obs)
|
||||
done = np.stack(done)
|
||||
episode += int(done.sum())
|
||||
length += 1
|
||||
step += (done * length).sum()
|
||||
length *= (1 - done)
|
||||
# Return new state to allow resuming the simulation.
|
||||
return (step - steps, episode - episodes, done, length, obs, agent_state)
|
||||
|
||||
|
||||
def save_episodes(directory, episodes):
|
||||
directory = pathlib.Path(directory).expanduser()
|
||||
directory.mkdir(parents=True, exist_ok=True)
|
||||
timestamp = datetime.datetime.now().strftime('%Y%m%dT%H%M%S')
|
||||
filenames = []
|
||||
for episode in episodes:
|
||||
identifier = str(uuid.uuid4().hex)
|
||||
length = len(episode['reward'])
|
||||
filename = directory / f'{timestamp}-{identifier}-{length}.npz'
|
||||
with io.BytesIO() as f1:
|
||||
np.savez_compressed(f1, **episode)
|
||||
f1.seek(0)
|
||||
with filename.open('wb') as f2:
|
||||
f2.write(f1.read())
|
||||
filenames.append(filename)
|
||||
return filenames
|
||||
|
||||
|
||||
def sample_episodes(episodes, length=None, balance=False, seed=0):
|
||||
random = np.random.RandomState(seed)
|
||||
while True:
|
||||
episode = random.choice(list(episodes.values()))
|
||||
if length:
|
||||
total = len(next(iter(episode.values())))
|
||||
available = total - length
|
||||
if available < 1:
|
||||
print(f'Skipped short episode of length {total}.')
|
||||
continue
|
||||
if balance:
|
||||
index = min(random.randint(0, total), available)
|
||||
else:
|
||||
index = int(random.randint(0, available + 1))
|
||||
episode = {k: v[index: index + length] for k, v in episode.items()}
|
||||
yield episode
|
||||
|
||||
|
||||
def load_episodes(directory, limit=None):
|
||||
directory = pathlib.Path(directory).expanduser()
|
||||
episodes = {}
|
||||
total = 0
|
||||
for filename in reversed(sorted(directory.glob('*.npz'))):
|
||||
try:
|
||||
with filename.open('rb') as f:
|
||||
episode = np.load(f)
|
||||
episode = {k: episode[k] for k in episode.keys()}
|
||||
except Exception as e:
|
||||
print(f'Could not load episode: {e}')
|
||||
continue
|
||||
episodes[str(filename)] = episode
|
||||
total += len(episode['reward']) - 1
|
||||
if limit and total >= limit:
|
||||
break
|
||||
return episodes
|
||||
|
||||
|
||||
class DtypeDist:
|
||||
|
||||
def __init__(self, dist, dtype=None):
|
||||
self._dist = dist
|
||||
self._dtype = dtype or prec.global_policy().compute_dtype
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return 'DtypeDist'
|
||||
|
||||
def __getattr__(self, name):
|
||||
return getattr(self._dist, name)
|
||||
|
||||
def mean(self):
|
||||
return tf.cast(self._dist.mean(), self._dtype)
|
||||
|
||||
def mode(self):
|
||||
return tf.cast(self._dist.mode(), self._dtype)
|
||||
|
||||
def entropy(self):
|
||||
return tf.cast(self._dist.entropy(), self._dtype)
|
||||
|
||||
def sample(self, *args, **kwargs):
|
||||
return tf.cast(self._dist.sample(*args, **kwargs), self._dtype)
|
||||
|
||||
|
||||
class SampleDist:
|
||||
|
||||
def __init__(self, dist, samples=100):
|
||||
self._dist = dist
|
||||
self._samples = samples
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return 'SampleDist'
|
||||
|
||||
def __getattr__(self, name):
|
||||
return getattr(self._dist, name)
|
||||
|
||||
def mean(self):
|
||||
samples = self._dist.sample(self._samples)
|
||||
return tf.reduce_mean(samples, 0)
|
||||
|
||||
def mode(self):
|
||||
sample = self._dist.sample(self._samples)
|
||||
logprob = self._dist.log_prob(sample)
|
||||
return tf.gather(sample, tf.argmax(logprob))[0]
|
||||
|
||||
def entropy(self):
|
||||
sample = self._dist.sample(self._samples)
|
||||
logprob = self.log_prob(sample)
|
||||
return -tf.reduce_mean(logprob, 0)
|
||||
|
||||
|
||||
class OneHotDist(tfd.OneHotCategorical):
|
||||
|
||||
def __init__(self, logits=None, probs=None, dtype=None):
|
||||
self._sample_dtype = dtype or prec.global_policy().compute_dtype
|
||||
super().__init__(logits=logits, probs=probs)
|
||||
|
||||
def mode(self):
|
||||
return tf.cast(super().mode(), self._sample_dtype)
|
||||
|
||||
def sample(self, sample_shape=(), seed=None):
|
||||
# Straight through biased gradient estimator.
|
||||
sample = tf.cast(super().sample(sample_shape, seed), self._sample_dtype)
|
||||
probs = super().probs_parameter()
|
||||
while len(probs.shape) < len(sample.shape):
|
||||
probs = probs[None]
|
||||
sample += tf.cast(probs - tf.stop_gradient(probs), self._sample_dtype)
|
||||
return sample
|
||||
|
||||
|
||||
class GumbleDist(tfd.RelaxedOneHotCategorical):
|
||||
|
||||
def __init__(self, temp, logits=None, probs=None, dtype=None):
|
||||
self._sample_dtype = dtype or prec.global_policy().compute_dtype
|
||||
self._exact = tfd.OneHotCategorical(logits=logits, probs=probs)
|
||||
super().__init__(temp, logits=logits, probs=probs)
|
||||
|
||||
def mode(self):
|
||||
return tf.cast(self._exact.mode(), self._sample_dtype)
|
||||
|
||||
def entropy(self):
|
||||
return tf.cast(self._exact.entropy(), self._sample_dtype)
|
||||
|
||||
def sample(self, sample_shape=(), seed=None):
|
||||
return tf.cast(super().sample(sample_shape, seed), self._sample_dtype)
|
||||
|
||||
|
||||
class UnnormalizedHuber(tfd.Normal):
|
||||
|
||||
def __init__(self, loc, scale, threshold=1, **kwargs):
|
||||
self._threshold = tf.cast(threshold, loc.dtype)
|
||||
super().__init__(loc, scale, **kwargs)
|
||||
|
||||
def log_prob(self, event):
|
||||
return -(tf.math.sqrt(
|
||||
(event - self.mean()) ** 2 + self._threshold ** 2) - self._threshold)
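The `log_prob` of `UnnormalizedHuber` is a smooth, Huber-like penalty: close to quadratic for small errors and close to linear for large ones. It matches the pseudo-Huber loss up to a factor of the threshold. A scalar sketch with the standard library follows. The function name is hypothetical and the TensorFlow distribution machinery is left out.

```python
import math


def unnormalized_huber_log_prob(event, mean, threshold=1.0):
  """Negative Huber-like distance between `event` and `mean`."""
  return -(math.sqrt((event - mean) ** 2 + threshold ** 2) - threshold)


exact = unnormalized_huber_log_prob(0.0, 0.0)  # No error gives 0.
medium = unnormalized_huber_log_prob(3.0, 0.0, threshold=4.0)  # 3-4-5 triangle.
large = unnormalized_huber_log_prob(1000.0, 0.0)  # Roughly linear in the error.
print(exact, medium, large)
```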
|
||||
|
||||
|
||||
class SafeTruncatedNormal(tfd.TruncatedNormal):
|
||||
|
||||
def __init__(self, loc, scale, low, high, clip=1e-6, mult=1):
|
||||
super().__init__(loc, scale, low, high)
|
||||
self._clip = clip
|
||||
self._mult = mult
|
||||
|
||||
def sample(self, *args, **kwargs):
|
||||
event = super().sample(*args, **kwargs)
|
||||
if self._clip:
|
||||
clipped = tf.clip_by_value(
|
||||
event, self.low + self._clip, self.high - self._clip)
|
||||
event = event - tf.stop_gradient(event) + tf.stop_gradient(clipped)
|
||||
if self._mult:
|
||||
event *= self._mult
|
||||
return event
|
||||
|
||||
|
||||
class TanhBijector(tfp.bijectors.Bijector):
|
||||
|
||||
def __init__(self, validate_args=False, name='tanh'):
|
||||
super().__init__(
|
||||
forward_min_event_ndims=0,
|
||||
validate_args=validate_args,
|
||||
name=name)
|
||||
|
||||
def _forward(self, x):
|
||||
return tf.nn.tanh(x)
|
||||
|
||||
def _inverse(self, y):
|
||||
dtype = y.dtype
|
||||
y = tf.cast(y, tf.float32)
|
||||
y = tf.where(
|
||||
tf.less_equal(tf.abs(y), 1.),
|
||||
tf.clip_by_value(y, -0.99999997, 0.99999997), y)
|
||||
y = tf.atanh(y)
|
||||
y = tf.cast(y, dtype)
|
||||
return y
|
||||
|
||||
def _forward_log_det_jacobian(self, x):
|
||||
log2 = tf.math.log(tf.constant(2.0, dtype=x.dtype))
|
||||
return 2.0 * (log2 - x - tf.nn.softplus(-2.0 * x))
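The expression in `_forward_log_det_jacobian` is a rewritten form of `log(1 - tanh(x)^2)`, the log-derivative of `tanh`. The sketch below checks that identity numerically with scalar math. Both function names are hypothetical, and the simple softplus used here is only adequate for moderate `x` (a production version would guard against overflow).

```python
import math


def tanh_log_det_jacobian(x):
  """Form used in the listing: 2 * (log 2 - x - softplus(-2x))."""
  softplus = math.log1p(math.exp(-2.0 * x))
  return 2.0 * (math.log(2.0) - x - softplus)


def direct_log_det_jacobian(x):
  """Textbook form: log(1 - tanh(x)^2); loses precision for large |x|."""
  return math.log(1.0 - math.tanh(x) ** 2)


for x in (-2.0, -0.5, 0.0, 0.5, 2.0):
  assert abs(tanh_log_det_jacobian(x) - direct_log_det_jacobian(x)) < 1e-9

# For large x the rewritten form stays finite, while 1 - tanh(x)^2 rounds to 0.
far = tanh_log_det_jacobian(20.0)
print(far)
```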
|
||||
|
||||
|
||||
def lambda_return(
|
||||
reward, value, pcont, bootstrap, lambda_, axis):
|
||||
# Setting lambda=1 gives a discounted Monte Carlo return.
|
||||
# Setting lambda=0 gives a fixed 1-step return.
|
||||
assert reward.shape.ndims == value.shape.ndims, (reward.shape, value.shape)
|
||||
if isinstance(pcont, (int, float)):
|
||||
pcont = pcont * tf.ones_like(reward)
|
||||
dims = list(range(reward.shape.ndims))
|
||||
dims = [axis] + dims[1:axis] + [0] + dims[axis + 1:]
|
||||
if axis != 0:
|
||||
reward = tf.transpose(reward, dims)
|
||||
value = tf.transpose(value, dims)
|
||||
pcont = tf.transpose(pcont, dims)
|
||||
if bootstrap is None:
|
||||
bootstrap = tf.zeros_like(value[-1])
|
||||
next_values = tf.concat([value[1:], bootstrap[None]], 0)
|
||||
inputs = reward + pcont * next_values * (1 - lambda_)
|
||||
returns = static_scan(
|
||||
lambda agg, cur: cur[0] + cur[1] * lambda_ * agg,
|
||||
(inputs, pcont), bootstrap, reverse=True)
|
||||
if axis != 0:
|
||||
returns = tf.transpose(returns, dims)
|
||||
return returns
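The recurrence behind `lambda_return` can be traced on a short 1-D trajectory. The sketch assumes that `static_scan` with `reverse=True` acts as a plain backward fold starting from `bootstrap`. It uses Python lists and a scalar `pcont`, and the name `lambda_return_sketch` is hypothetical.

```python
def lambda_return_sketch(reward, value, pcont, bootstrap, lambda_):
  """Lambda-return for one trajectory, computed backwards in time."""
  next_values = value[1:] + [bootstrap]
  inputs = [
      r + pcont * nv * (1 - lambda_) for r, nv in zip(reward, next_values)]
  returns = []
  agg = bootstrap
  for cur in reversed(inputs):
    agg = cur + pcont * lambda_ * agg
    returns.append(agg)
  return returns[::-1]


# lambda=1: discounted Monte Carlo return; the value estimates are ignored.
monte_carlo = lambda_return_sketch(
    [1.0, 1.0, 1.0], [5.0, 5.0, 5.0], pcont=0.5, bootstrap=0.0, lambda_=1.0)
# lambda=0: one-step return, reward plus discounted next value.
one_step = lambda_return_sketch(
    [1.0, 1.0, 1.0], [10.0, 20.0, 30.0], pcont=0.5, bootstrap=40.0, lambda_=0.0)
print(monte_carlo, one_step)
```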
|
||||
|
||||
|
||||
class Optimizer(tf.Module):
|
||||
|
||||
def __init__(
|
||||
self, name, lr, eps=1e-4, clip=None, wd=None, wd_pattern=r'.*',
|
||||
opt='adam'):
|
||||
assert wd is None or 0 <= wd < 1
|
||||
assert not clip or 1 <= clip
|
||||
self._name = name
|
||||
self._clip = clip
|
||||
self._wd = wd
|
||||
self._wd_pattern = wd_pattern
|
||||
self._opt = {
|
||||
'adam': lambda: tf.optimizers.Adam(lr, epsilon=eps),
|
||||
'nadam': lambda: tf.optimizers.Nadam(lr, epsilon=eps),
|
||||
'adamax': lambda: tf.optimizers.Adamax(lr, epsilon=eps),
|
||||
'sgd': lambda: tf.optimizers.SGD(lr),
|
||||
'momentum': lambda: tf.optimizers.SGD(lr, 0.9),
|
||||
}[opt]()
|
||||
self._mixed = (prec.global_policy().compute_dtype == tf.float16)
|
||||
if self._mixed:
|
||||
self._opt = prec.LossScaleOptimizer(self._opt, 'dynamic')
|
||||
|
||||
@property
|
||||
def variables(self):
|
||||
return self._opt.variables()
|
||||
|
||||
def __call__(self, tape, loss, modules):
|
||||
assert loss.dtype is tf.float32, self._name
|
||||
modules = modules if hasattr(modules, '__len__') else (modules,)
|
||||
varibs = tf.nest.flatten([module.variables for module in modules])
|
||||
count = sum(np.prod(x.shape) for x in varibs)
|
||||
print(f'Found {count} {self._name} parameters.')
|
||||
assert len(loss.shape) == 0, loss.shape
|
||||
tf.debugging.check_numerics(loss, self._name + '_loss')
|
||||
if self._mixed:
|
||||
with tape:
|
||||
loss = self._opt.get_scaled_loss(loss)
|
||||
grads = tape.gradient(loss, varibs)
|
||||
if self._mixed:
|
||||
grads = self._opt.get_unscaled_gradients(grads)
|
||||
norm = tf.linalg.global_norm(grads)
|
||||
if not self._mixed:
|
||||
tf.debugging.check_numerics(norm, self._name + '_norm')
|
||||
if self._clip:
|
||||
grads, _ = tf.clip_by_global_norm(grads, self._clip, norm)
|
||||
if self._wd:
|
||||
self._apply_weight_decay(varibs)
|
||||
self._opt.apply_gradients(zip(grads, varibs))
|
||||
metrics = {}
|
||||
metrics[f'{self._name}_loss'] = loss
|
||||
metrics[f'{self._name}_grad_norm'] = norm
|
||||
if self._mixed:
|
||||
try:
|
||||
metrics[f'{self._name}_loss_scale'] = float(self._opt.loss_scale)
|
||||
except TypeError:
|
||||
metrics[f'{self._name}_loss_scale'] = float(
|
||||
self._opt.loss_scale._current_loss_scale)
|
||||
return metrics
|
||||
|
||||
def _apply_weight_decay(self, varibs):
|
||||
nontrivial = (self._wd_pattern != r'.*')
|
||||
if nontrivial:
|
||||
print('Applied weight decay to variables:')
|
||||
for var in varibs:
|
||||
if re.search(self._wd_pattern, self._name + '/' + var.name):
|
||||
if nontrivial:
|
||||
print('- ' + self._name + '/' + var.name)
|
||||
var.assign((1 - self._wd) * var)
|
||||
|
||||
|
||||
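The weight-decay step above picks variables by running `re.search` with `wd_pattern` against the string `name + '/' + var.name`. The sketch below shows that selection rule in isolation. The helper `select_decayed` and the variable names are made up for illustration and do not come from this repository.

```python
import re


def select_decayed(opt_name, var_names, wd_pattern):
  """Return the variable names that a weight-decay pattern would match."""
  return [
      v for v in var_names
      if re.search(wd_pattern, opt_name + '/' + v)]


names = ['dense/kernel:0', 'dense/bias:0', 'conv/kernel:0']  # Made-up names.
print(select_decayed('actor', names, r'.*'))      # Matches all three names.
print(select_decayed('actor', names, r'kernel'))  # Matches the two kernels.
```

Because the optimizer name is prefixed to each variable name, a pattern such as `r'^actor/'` would restrict decay to one optimizer's variables.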
def args_type(default):
  def parse_string(x):
    if default is None:
      return x
    if isinstance(default, bool):
      return bool(['False', 'True'].index(x))
    if isinstance(default, int):
      return float(x) if ('e' in x or '.' in x) else int(x)
    if isinstance(default, (list, tuple)):
      return tuple(args_type(default[0])(y) for y in x.split(','))
    return type(default)(x)
  def parse_object(x):
    if isinstance(default, (list, tuple)):
      return tuple(x)
    return x
  return lambda x: parse_string(x) if isinstance(x, str) else parse_object(x)


def static_scan(fn, inputs, start, reverse=False):
  last = start
  outputs = [[] for _ in tf.nest.flatten(start)]
  indices = range(len(tf.nest.flatten(inputs)[0]))
  if reverse:
    indices = reversed(indices)
  for index in indices:
    inp = tf.nest.map_structure(lambda x: x[index], inputs)
    last = fn(last, inp)
    for o, l in zip(outputs, tf.nest.flatten(last)):
      o.append(l)
  if reverse:
    outputs = [list(reversed(x)) for x in outputs]
  outputs = [tf.stack(x, 0) for x in outputs]
  return tf.nest.pack_sequence_as(start, outputs)


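The scan above threads a carry through a sequence and collects every intermediate carry. With `reverse=True` it walks the sequence from the end but returns the outputs in the original order. A list-based analogue, assuming a single flat input sequence (`scan_py` is a hypothetical name), makes the ordering explicit.

```python
def scan_py(fn, inputs, start, reverse=False):
  """List-based analogue of static_scan for one flat input sequence."""
  last = start
  outputs = []
  indices = range(len(inputs))
  if reverse:
    indices = reversed(indices)
  for index in indices:
    last = fn(last, inputs[index])
    outputs.append(last)
  if reverse:
    # Restore the original time order of the collected carries.
    outputs.reverse()
  return outputs


print(scan_py(lambda agg, x: agg + x, [1, 2, 3], 0))                # [1, 3, 6]
print(scan_py(lambda agg, x: agg + x, [1, 2, 3], 0, reverse=True))  # [6, 5, 3]
```
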
def uniform_mixture(dist, dtype=None):
  if dist.batch_shape[-1] == 1:
    return tfd.BatchReshape(dist, dist.batch_shape[:-1])
  dtype = dtype or prec.global_policy().compute_dtype
  weights = tfd.Categorical(tf.zeros(dist.batch_shape, dtype))
  return tfd.MixtureSameFamily(weights, dist)


def cat_mixture_entropy(dist):
  if isinstance(dist, tfd.MixtureSameFamily):
    probs = dist.components_distribution.probs_parameter()
  else:
    probs = dist.probs_parameter()
  return -tf.reduce_mean(
      tf.reduce_mean(probs, 2) *
      tf.math.log(tf.reduce_mean(probs, 2) + 1e-8), -1)


@tf.function
def cem_planner(
    state, num_actions, horizon, proposals, topk, iterations, imagine,
    objective):
  dtype = prec.global_policy().compute_dtype
  B, P = list(state.values())[0].shape[0], proposals
  H, A = horizon, num_actions
  flat_state = {k: tf.repeat(v, P, 0) for k, v in state.items()}
  mean = tf.zeros((B, H, A), dtype)
  std = tf.ones((B, H, A), dtype)
  for _ in range(iterations):
    proposals = tf.random.normal((B, P, H, A), dtype=dtype)
    proposals = proposals * std[:, None] + mean[:, None]
    proposals = tf.clip_by_value(proposals, -1, 1)
    flat_proposals = tf.reshape(proposals, (B * P, H, A))
    states = imagine(flat_proposals, flat_state)
    scores = objective(states)
    scores = tf.reshape(tf.reduce_sum(scores, -1), (B, P))
    _, indices = tf.math.top_k(scores, topk, sorted=False)
    best = tf.gather(proposals, indices, axis=1, batch_dims=1)
    mean, var = tf.nn.moments(best, 1)
    std = tf.sqrt(var + 1e-6)
  return mean[:, 0, :]


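The planner above follows the cross-entropy method: sample proposals from a Gaussian, keep the top-scoring ones, and refit the mean and standard deviation to them. The sketch below reduces that loop to a single scalar action with a toy objective. The function `cem_1d`, its default arguments, and the quadratic objective are assumptions for illustration; there is no world model or batching here.

```python
import random
import statistics


def cem_1d(objective, proposals=64, topk=8, iterations=10, seed=0):
  """Cross-entropy method over one scalar action clipped to [-1, 1]."""
  rng = random.Random(seed)
  mean, std = 0.0, 1.0
  for _ in range(iterations):
    samples = [
        max(-1.0, min(1.0, rng.gauss(mean, std))) for _ in range(proposals)]
    # Keep the best proposals and refit the sampling distribution to them.
    best = sorted(samples, key=objective, reverse=True)[:topk]
    mean = sum(best) / len(best)
    std = (statistics.pvariance(best) + 1e-6) ** 0.5
  return mean


# Toy objective with its maximum at 0.5 (chosen only for this example).
action = cem_1d(lambda a: -(a - 0.5) ** 2)
print(action)  # Expected to land close to 0.5.
```
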
@tf.function
def grad_planner(
    state, num_actions, horizon, proposals, iterations, imagine, objective,
    kl_scale, step_size):
  dtype = prec.global_policy().compute_dtype
  B, P = list(state.values())[0].shape[0], proposals
  H, A = horizon, num_actions
  flat_state = {k: tf.repeat(v, P, 0) for k, v in state.items()}
  mean = tf.zeros((B, H, A), dtype)
  rawstd = 0.54 * tf.ones((B, H, A), dtype)
  for _ in range(iterations):
    proposals = tf.random.normal((B, P, H, A), dtype=dtype)
    with tf.GradientTape(watch_accessed_variables=False) as tape:
      tape.watch(mean)
      tape.watch(rawstd)
      std = tf.nn.softplus(rawstd)
      proposals = proposals * std[:, None] + mean[:, None]
      proposals = (
          tf.stop_gradient(tf.clip_by_value(proposals, -1, 1)) +
          proposals - tf.stop_gradient(proposals))
      flat_proposals = tf.reshape(proposals, (B * P, H, A))
      states = imagine(flat_proposals, flat_state)
      scores = objective(states)
      scores = tf.reshape(tf.reduce_sum(scores, -1), (B, P))
      div = tfd.kl_divergence(
          tfd.Normal(mean, std),
          tfd.Normal(tf.zeros_like(mean), tf.ones_like(std)))
      elbo = tf.reduce_sum(scores) - kl_scale * div
      elbo /= tf.cast(tf.reduce_prod(tf.shape(scores)), dtype)
    grad_mean, grad_rawstd = tape.gradient(elbo, [mean, rawstd])
    e, v = tf.nn.moments(grad_mean, [1, 2], keepdims=True)
    grad_mean /= tf.sqrt(e * e + v + 1e-4)
    e, v = tf.nn.moments(grad_rawstd, [1, 2], keepdims=True)
    grad_rawstd /= tf.sqrt(e * e + v + 1e-4)
    mean = tf.clip_by_value(mean + step_size * grad_mean, -1, 1)
    rawstd = rawstd + step_size * grad_rawstd
  return mean[:, 0, :]


class Every:

  def __init__(self, every):
    self._every = every
    self._last = None

  def __call__(self, step):
    if not self._every:
      return False
    if self._last is None:
      self._last = step
      return True
    if step >= self._last + self._every:
      self._last += self._every
      return True
    return False


class Once:

  def __init__(self):
    self._once = True

  def __call__(self):
    if self._once:
      self._once = False
      return True
    return False


class Until:

  def __init__(self, until):
    self._until = until

  def __call__(self, step):
    if not self._until:
      return True
    return step < self._until


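The `Every` trigger fires on its first call and then once per elapsed interval. Because it advances its internal marker by exactly one interval per firing, a large jump in `step` causes it to fire on several consecutive calls until it has caught up. The standalone demo below uses a condensed copy of the class so that it runs on its own; the variable names `should_log` and `should_save` are illustrative.

```python
class Every:
  """Condensed copy of the Every helper, repeated for a standalone demo."""

  def __init__(self, every):
    self._every = every
    self._last = None

  def __call__(self, step):
    if not self._every:
      return False
    if self._last is None:
      self._last = step
      return True
    if step >= self._last + self._every:
      self._last += self._every
      return True
    return False


should_log = Every(3)
print([step for step in range(8) if should_log(step)])  # [0, 3, 6]

# With sparse steps the trigger catches up by one interval per call.
should_save = Every(3)
print([should_save(step) for step in (0, 10, 11, 12)])
# [True, True, True, True]
```
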
def schedule(string, step):
  try:
    return float(string)
  except ValueError:
    step = tf.cast(step, tf.float32)
    match = re.match(r'linear\((.+),(.+),(.+)\)', string)
    if match:
      initial, final, duration = [float(group) for group in match.groups()]
      mix = tf.clip_by_value(step / duration, 0, 1)
      return (1 - mix) * initial + mix * final
    match = re.match(r'warmup\((.+),(.+)\)', string)
    if match:
      warmup, value = [float(group) for group in match.groups()]
      scale = tf.clip_by_value(step / warmup, 0, 1)
      return scale * value
    match = re.match(r'exp\((.+),(.+),(.+)\)', string)
    if match:
      initial, final, halflife = [float(group) for group in match.groups()]
      return (initial - final) * 0.5 ** (step / halflife) + final
    match = re.match(r'horizon\((.+),(.+),(.+)\)', string)
    if match:
      initial, final, duration = [float(group) for group in match.groups()]
      mix = tf.clip_by_value(step / duration, 0, 1)
      horizon = (1 - mix) * initial + mix * final
      return 1 - 1 / horizon
    raise NotImplementedError(string)
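The schedule strings accepted above are either a plain number or one of `linear(initial,final,duration)`, `warmup(warmup,value)`, `exp(initial,final,halflife)`, and `horizon(initial,final,duration)`. The float-only sketch below mirrors the same formulas without TensorFlow so the values can be checked by hand. `schedule_py` is a hypothetical name, and the example strings are chosen only for illustration.

```python
import re


def schedule_py(string, step):
  """Float-only sketch of the schedule string formats."""
  try:
    return float(string)
  except ValueError:
    pass
  match = re.match(r'(linear|warmup|exp|horizon)\((.+)\)', string)
  if not match:
    raise NotImplementedError(string)
  kind = match.group(1)
  args = [float(x) for x in match.group(2).split(',')]
  clip = lambda x: max(0.0, min(1.0, x))
  if kind == 'linear':
    initial, final, duration = args
    mix = clip(step / duration)
    return (1 - mix) * initial + mix * final
  if kind == 'warmup':
    warmup, value = args
    return clip(step / warmup) * value
  if kind == 'exp':
    initial, final, halflife = args
    return (initial - final) * 0.5 ** (step / halflife) + final
  initial, final, duration = args  # kind == 'horizon'
  mix = clip(step / duration)
  horizon = (1 - mix) * initial + mix * final
  # Convert an effective horizon into a discount-like factor.
  return 1 - 1 / horizon


print(schedule_py('3e-4', 0))                  # 0.0003
print(schedule_py('linear(1.0,0.0,100)', 50))  # 0.5
print(schedule_py('warmup(10,2.0)', 5))        # 1.0
print(schedule_py('exp(1.0,0.0,100)', 100))    # 0.5
print(schedule_py('horizon(10,100,1000)', 0))  # 0.9
```
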
@@ -0,0 +1,306 @@
import threading

import gym
import numpy as np


class DeepMindControl:

  def __init__(self, name, action_repeat=1, size=(64, 64), camera=None):
    domain, task = name.split('_', 1)
    if domain == 'cup':  # Only domain with multiple words.
      domain = 'ball_in_cup'
    if isinstance(domain, str):
      from dm_control import suite
      self._env = suite.load(domain, task)
    else:
      assert task is None
      self._env = domain()
    self._action_repeat = action_repeat
    self._size = size
    if camera is None:
      camera = dict(quadruped=2).get(domain, 0)
    self._camera = camera

  @property
  def observation_space(self):
    spaces = {}
    for key, value in self._env.observation_spec().items():
      spaces[key] = gym.spaces.Box(
          -np.inf, np.inf, value.shape, dtype=np.float32)
    spaces['image'] = gym.spaces.Box(
        0, 255, self._size + (3,), dtype=np.uint8)
    return gym.spaces.Dict(spaces)

  @property
  def action_space(self):
    spec = self._env.action_spec()
    return gym.spaces.Box(spec.minimum, spec.maximum, dtype=np.float32)

  def step(self, action):
    assert np.isfinite(action).all(), action
    reward = 0
    for _ in range(self._action_repeat):
      time_step = self._env.step(action)
      reward += time_step.reward or 0
      if time_step.last():
        break
    obs = dict(time_step.observation)
    obs['image'] = self.render()
    done = time_step.last()
    info = {'discount': np.array(time_step.discount, np.float32)}
    return obs, reward, done, info

  def reset(self):
    time_step = self._env.reset()
    obs = dict(time_step.observation)
    obs['image'] = self.render()
    return obs

  def render(self, *args, **kwargs):
    if kwargs.get('mode', 'rgb_array') != 'rgb_array':
      raise ValueError("Only render mode 'rgb_array' is supported.")
    return self._env.physics.render(*self._size, camera_id=self._camera)


class Atari:

  LOCK = threading.Lock()

  def __init__(
      self, name, action_repeat=4, size=(84, 84), grayscale=True, noops=30,
      life_done=False, sticky_actions=True, all_actions=False):
    assert size[0] == size[1]
    import gym.wrappers
    import gym.envs.atari
    if name == 'james_bond':
      name = 'jamesbond'
    with self.LOCK:
      env = gym.envs.atari.AtariEnv(
          game=name, obs_type='image', frameskip=1,
          repeat_action_probability=0.25 if sticky_actions else 0.0,
          full_action_space=all_actions)
    # Avoid unnecessary rendering in inner env.
    env._get_obs = lambda: None
    # Tell wrapper that the inner env has no action repeat.
    env.spec = gym.envs.registration.EnvSpec('NoFrameskip-v0')
    env = gym.wrappers.AtariPreprocessing(
        env, noops, action_repeat, size[0], life_done, grayscale)
    self._env = env
    self._grayscale = grayscale

  @property
  def observation_space(self):
    return gym.spaces.Dict({
        'image': self._env.observation_space,
        'ram': gym.spaces.Box(0, 255, (128,), np.uint8),
    })

  @property
  def action_space(self):
    return self._env.action_space

  def close(self):
    return self._env.close()

  def reset(self):
    with self.LOCK:
      image = self._env.reset()
    if self._grayscale:
      image = image[..., None]
    obs = {'image': image, 'ram': self._env.env._get_ram()}
    return obs

  def step(self, action):
    image, reward, done, info = self._env.step(action)
    if self._grayscale:
      image = image[..., None]
    obs = {'image': image, 'ram': self._env.env._get_ram()}
    return obs, reward, done, info

  def render(self, mode):
    return self._env.render(mode)


class CollectDataset:

  def __init__(self, env, callbacks=None, precision=32):
    self._env = env
    self._callbacks = callbacks or ()
    self._precision = precision
    self._episode = None

  def __getattr__(self, name):
    return getattr(self._env, name)

  def step(self, action):
    obs, reward, done, info = self._env.step(action)
    obs = {k: self._convert(v) for k, v in obs.items()}
    transition = obs.copy()
    if isinstance(action, dict):
      transition.update(action)
    else:
      transition['action'] = action
    transition['reward'] = reward
    transition['discount'] = info.get('discount', np.array(1 - float(done)))
    self._episode.append(transition)
    if done:
      for key, value in self._episode[1].items():
        if key not in self._episode[0]:
          self._episode[0][key] = 0 * value
      episode = {k: [t[k] for t in self._episode] for k in self._episode[0]}
      episode = {k: self._convert(v) for k, v in episode.items()}
      info['episode'] = episode
      for callback in self._callbacks:
        callback(episode)
    return obs, reward, done, info

  def reset(self):
    obs = self._env.reset()
    transition = obs.copy()
    # Missing keys will be filled with a zeroed out version of the first
    # transition, because we do not know what action information the agent will
    # pass yet.
    transition['reward'] = 0.0
    transition['discount'] = 1.0
    self._episode = [transition]
    return obs

  def _convert(self, value):
    value = np.array(value)
    if np.issubdtype(value.dtype, np.floating):
      dtype = {16: np.float16, 32: np.float32, 64: np.float64}[self._precision]
    elif np.issubdtype(value.dtype, np.signedinteger):
      dtype = {16: np.int16, 32: np.int32, 64: np.int64}[self._precision]
    elif np.issubdtype(value.dtype, np.uint8):
      dtype = np.uint8
    else:
      raise NotImplementedError(value.dtype)
    return value.astype(dtype)


class TimeLimit:

  def __init__(self, env, duration):
    self._env = env
    self._duration = duration
    self._step = None

  def __getattr__(self, name):
    return getattr(self._env, name)

  def step(self, action):
    assert self._step is not None, 'Must reset environment.'
    obs, reward, done, info = self._env.step(action)
    self._step += 1
    if self._step >= self._duration:
      done = True
      if 'discount' not in info:
        info['discount'] = np.array(1.0).astype(np.float32)
      self._step = None
    return obs, reward, done, info

  def reset(self):
    self._step = 0
    return self._env.reset()


class NormalizeActions:

  def __init__(self, env):
    self._env = env
    self._mask = np.logical_and(
        np.isfinite(env.action_space.low),
        np.isfinite(env.action_space.high))
    self._low = np.where(self._mask, env.action_space.low, -1)
    self._high = np.where(self._mask, env.action_space.high, 1)

  def __getattr__(self, name):
    return getattr(self._env, name)

  @property
  def action_space(self):
    low = np.where(self._mask, -np.ones_like(self._low), self._low)
    high = np.where(self._mask, np.ones_like(self._low), self._high)
    return gym.spaces.Box(low, high, dtype=np.float32)

  def step(self, action):
    original = (action + 1) / 2 * (self._high - self._low) + self._low
    original = np.where(self._mask, original, action)
    return self._env.step(original)


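For dimensions with finite bounds, the `step` method above maps a normalized action in [-1, 1] back to the environment's range with an affine rescale. A scalar sketch of that formula follows. The helper name `denormalize` and the bounds 0 and 10 are assumptions for illustration.

```python
def denormalize(action, low, high):
  """Map an action from [-1, 1] to [low, high] with an affine rescale."""
  return (action + 1) / 2 * (high - low) + low


print(denormalize(-1.0, 0.0, 10.0))  # 0.0
print(denormalize(0.0, 0.0, 10.0))   # 5.0
print(denormalize(1.0, 0.0, 10.0))   # 10.0
```
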
class OneHotAction:

  def __init__(self, env):
    assert isinstance(env.action_space, gym.spaces.Discrete)
    self._env = env
    self._random = np.random.RandomState()

  def __getattr__(self, name):
    return getattr(self._env, name)

  @property
  def action_space(self):
    shape = (self._env.action_space.n,)
    space = gym.spaces.Box(low=0, high=1, shape=shape, dtype=np.float32)
    space.sample = self._sample_action
    space.discrete = True
    return space

  def step(self, action):
    index = np.argmax(action).astype(int)
    reference = np.zeros_like(action)
    reference[index] = 1
    if not np.allclose(reference, action):
      raise ValueError(f'Invalid one-hot action:\n{action}')
    return self._env.step(index)

  def reset(self):
    return self._env.reset()

  def _sample_action(self):
    actions = self._env.action_space.n
    index = self._random.randint(0, actions)
    reference = np.zeros(actions, dtype=np.float32)
    reference[index] = 1.0
    return reference


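The `step` method above accepts only exact one-hot vectors: it takes the argmax, rebuilds the one-hot vector for that index, and rejects the action if the two differ. The list-based sketch below shows the same check. `onehot_to_index` is a hypothetical helper, and its fixed absolute tolerance is a simplification of `np.allclose`, which also applies a relative tolerance.

```python
def onehot_to_index(action, tol=1e-8):
  """Return the index encoded by a one-hot list, or raise ValueError."""
  index = max(range(len(action)), key=lambda i: action[i])
  reference = [1.0 if i == index else 0.0 for i in range(len(action))]
  if any(abs(a - r) > tol for a, r in zip(action, reference)):
    raise ValueError(f'Invalid one-hot action:\n{action}')
  return index


print(onehot_to_index([0.0, 0.0, 1.0, 0.0]))  # 2
try:
  onehot_to_index([0.5, 0.5])
except ValueError as error:
  print('rejected:', error)
```

Rejecting soft or unnormalized vectors early makes it obvious when a policy emits probabilities instead of sampled actions.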
class RewardObs:

  def __init__(self, env):
    self._env = env

  def __getattr__(self, name):
    return getattr(self._env, name)

  @property
  def observation_space(self):
    # Copy so that the wrapped env's space dict is not mutated on access.
    spaces = self._env.observation_space.spaces.copy()
    assert 'reward' not in spaces
    # Scalar bounds carry no shape, so pass the scalar shape explicitly.
    spaces['reward'] = gym.spaces.Box(
        -np.inf, np.inf, shape=(), dtype=np.float32)
    return gym.spaces.Dict(spaces)

  def step(self, action):
    obs, reward, done, info = self._env.step(action)
    obs['reward'] = reward
    return obs, reward, done, info

  def reset(self):
    obs = self._env.reset()
    obs['reward'] = 0.0
    return obs


class SelectAction:

  def __init__(self, env, key):
    self._env = env
    self._key = key

  def __getattr__(self, name):
    return getattr(self._env, name)

  def step(self, action):
    return self._env.step(action[self._key])