mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 13:19:38 +08:00
[rllib] Basic port of baselines/deepq to rllib (#709)
* rllib v0 * fix imports * lint * comments * update docs * a3c wip * a3c wip * report stats * update doc * add common logdir attr * name is too long * fix small bug * propagate exception on error * fetch metrics * initial port * fix lint * add right license * port to common alg format * fix lint * rename dqn * add imports from future * fix lint
This commit is contained in:
committed by
Philipp Moritz
parent
6c45657280
commit
f012e597c2
@@ -0,0 +1 @@
|
||||
Code in this package is adapted from https://github.com/openai/baselines/tree/master/baselines/deepq.
|
||||
@@ -0,0 +1,7 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from ray.rllib.dqn.dqn import DQN, DEFAULT_CONFIG
|
||||
|
||||
__all__ = ["DQN", "DEFAULT_CONFIG"]
|
||||
@@ -0,0 +1,277 @@
|
||||
"""Deep Q learning graph
|
||||
|
||||
The functions in this file can are used to create the following functions:
|
||||
|
||||
======= act ========
|
||||
|
||||
Function to chose an action given an observation
|
||||
|
||||
Parameters
|
||||
----------
|
||||
observation: object
|
||||
Observation that can be feed into the output of make_obs_ph
|
||||
stochastic: bool
|
||||
if set to False all the actions are always deterministic
|
||||
(default False)
|
||||
update_eps_ph: float
|
||||
update epsilon a new value, if negative not update happens
|
||||
(default: no update)
|
||||
|
||||
Returns
|
||||
-------
|
||||
Tensor of dtype tf.int64 and shape (BATCH_SIZE,) with an action to be
|
||||
performed for every element of the batch.
|
||||
|
||||
|
||||
======= train =======
|
||||
|
||||
Function that takes a transition (s,a,r,s') and optimizes Bellman
|
||||
equation's error:
|
||||
|
||||
td_error = Q(s,a) - (r + gamma * max_a' Q(s', a'))
|
||||
loss = huber_loss[td_error]
|
||||
|
||||
Parameters
|
||||
----------
|
||||
obs_t: object
|
||||
a batch of observations
|
||||
action: np.array
|
||||
actions that were selected upon seeing obs_t.
|
||||
dtype must be int32 and shape must be (batch_size,)
|
||||
reward: np.array
|
||||
immediate reward attained after executing those actions
|
||||
dtype must be float32 and shape must be (batch_size,)
|
||||
obs_tp1: object
|
||||
observations that followed obs_t
|
||||
done: np.array
|
||||
1 if obs_t was the last observation in the episode and 0 otherwise
|
||||
obs_tp1 gets ignored, but must be of the valid shape.
|
||||
dtype must be float32 and shape must be (batch_size,)
|
||||
weight: np.array
|
||||
imporance weights for every element of the batch (gradient is
|
||||
multiplied by the importance weight) dtype must be float32 and shape
|
||||
must be (batch_size,)
|
||||
|
||||
Returns
|
||||
-------
|
||||
td_error: np.array
|
||||
a list of differences between Q(s,a) and the target in Bellman's
|
||||
equation. dtype is float32 and shape is (batch_size,)
|
||||
|
||||
======= update_target ========
|
||||
|
||||
copy the parameters from optimized Q function to the target Q function.
|
||||
In Q learning we actually optimize the following error:
|
||||
|
||||
Q(s,a) - (r + gamma * max_a' Q'(s', a'))
|
||||
|
||||
Where Q' is lagging behind Q to stablize the learning. For example for
|
||||
Atari
|
||||
|
||||
Q' is set to Q once every 10000 updates training steps.
|
||||
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import tensorflow as tf
|
||||
from ray.rllib.dqn.common import tf_util as U
|
||||
|
||||
|
||||
def build_act(make_obs_ph, q_func, num_actions, scope="deepq", reuse=None):
|
||||
"""Creates the act function:
|
||||
|
||||
Parameters
|
||||
----------
|
||||
make_obs_ph: str -> tf.placeholder or TfInput
|
||||
a function that take a name and creates a placeholder of input with
|
||||
that name
|
||||
q_func: (tf.Variable, int, str, bool) -> tf.Variable
|
||||
the model that takes the following inputs:
|
||||
observation_in: object
|
||||
the output of observation placeholder
|
||||
num_actions: int
|
||||
number of actions
|
||||
scope: str
|
||||
reuse: bool
|
||||
should be passed to outer variable scope
|
||||
and returns a tensor of shape (batch_size, num_actions) with values of
|
||||
every action.
|
||||
num_actions: int
|
||||
number of actions.
|
||||
scope: str or VariableScope
|
||||
optional scope for variable_scope.
|
||||
reuse: bool or None
|
||||
whether or not the variables should be reused. To be able to reuse the
|
||||
scope must be given.
|
||||
|
||||
Returns
|
||||
-------
|
||||
act: (tf.Variable, bool, float) -> tf.Variable
|
||||
function to select and action given observation.
|
||||
` See the top of the file for details.
|
||||
"""
|
||||
with tf.variable_scope(scope, reuse=reuse):
|
||||
observations_ph = U.ensure_tf_input(make_obs_ph("observation"))
|
||||
stochastic_ph = tf.placeholder(tf.bool, (), name="stochastic")
|
||||
update_eps_ph = tf.placeholder(tf.float32, (), name="update_eps")
|
||||
|
||||
eps = tf.get_variable(
|
||||
"eps", (), initializer=tf.constant_initializer(0))
|
||||
|
||||
q_values = q_func(observations_ph.get(), num_actions, scope="q_func")
|
||||
deterministic_actions = tf.argmax(q_values, axis=1)
|
||||
|
||||
batch_size = tf.shape(observations_ph.get())[0]
|
||||
random_actions = tf.random_uniform(
|
||||
tf.stack([batch_size]), minval=0, maxval=num_actions, dtype=tf.int64)
|
||||
chose_random = tf.random_uniform(
|
||||
tf.stack([batch_size]), minval=0, maxval=1, dtype=tf.float32) < eps
|
||||
stochastic_actions = tf.where(
|
||||
chose_random, random_actions, deterministic_actions)
|
||||
|
||||
output_actions = tf.cond(
|
||||
stochastic_ph, lambda: stochastic_actions,
|
||||
lambda: deterministic_actions)
|
||||
update_eps_expr = eps.assign(
|
||||
tf.cond(update_eps_ph >= 0, lambda: update_eps_ph, lambda: eps))
|
||||
|
||||
act = U.function(
|
||||
inputs=[observations_ph, stochastic_ph, update_eps_ph],
|
||||
outputs=output_actions,
|
||||
givens={update_eps_ph: -1.0, stochastic_ph: True},
|
||||
updates=[update_eps_expr])
|
||||
return act
|
||||
|
||||
|
||||
def build_train(
|
||||
make_obs_ph, q_func, num_actions, optimizer, grad_norm_clipping=None,
|
||||
gamma=1.0, double_q=True, scope="deepq", reuse=None):
|
||||
"""Creates the train function:
|
||||
|
||||
Parameters
|
||||
----------
|
||||
make_obs_ph: str -> tf.placeholder or TfInput
|
||||
a function that takes a name and creates a placeholder of input with
|
||||
that name
|
||||
q_func: (tf.Variable, int, str, bool) -> tf.Variable
|
||||
the model that takes the following inputs:
|
||||
observation_in: object
|
||||
the output of observation placeholder
|
||||
num_actions: int
|
||||
number of actions
|
||||
scope: str
|
||||
reuse: bool
|
||||
should be passed to outer variable scope
|
||||
and returns a tensor of shape (batch_size, num_actions) with values of
|
||||
every action.
|
||||
num_actions: int
|
||||
number of actions
|
||||
reuse: bool
|
||||
whether or not to reuse the graph variables
|
||||
optimizer: tf.train.Optimizer
|
||||
optimizer to use for the Q-learning objective.
|
||||
grad_norm_clipping: float or None
|
||||
clip gradient norms to this value. If None no clipping is performed.
|
||||
gamma: float
|
||||
discount rate.
|
||||
double_q: bool
|
||||
if true will use Double Q Learning (https://arxiv.org/abs/1509.06461).
|
||||
In general it is a good idea to keep it enabled.
|
||||
scope: str or VariableScope
|
||||
optional scope for variable_scope.
|
||||
reuse: bool or None
|
||||
whether or not the variables should be reused. To be able to reuse the
|
||||
scope must be given.
|
||||
|
||||
Returns
|
||||
-------
|
||||
act: (tf.Variable, bool, float) -> tf.Variable
|
||||
function to select and action given observation.
|
||||
` See the top of the file for details.
|
||||
train: (object, np.array, np.array, object, np.array, np.array) -> np.array
|
||||
optimize the error in Bellman's equation.
|
||||
` See the top of the file for details.
|
||||
update_target: () -> ()
|
||||
copy the parameters from optimized Q function to the target Q function.
|
||||
` See the top of the file for details.
|
||||
debug: {str: function}
|
||||
a bunch of functions to print debug data like q_values.
|
||||
"""
|
||||
act_f = build_act(make_obs_ph, q_func, num_actions, scope=scope, reuse=reuse)
|
||||
|
||||
with tf.variable_scope(scope, reuse=reuse):
|
||||
# set up placeholders
|
||||
obs_t_input = U.ensure_tf_input(make_obs_ph("obs_t"))
|
||||
act_t_ph = tf.placeholder(tf.int32, [None], name="action")
|
||||
rew_t_ph = tf.placeholder(tf.float32, [None], name="reward")
|
||||
obs_tp1_input = U.ensure_tf_input(make_obs_ph("obs_tp1"))
|
||||
done_mask_ph = tf.placeholder(tf.float32, [None], name="done")
|
||||
importance_weights_ph = tf.placeholder(tf.float32, [None], name="weight")
|
||||
|
||||
# q network evaluation
|
||||
q_t = q_func(
|
||||
obs_t_input.get(), num_actions, scope="q_func",
|
||||
reuse=True) # reuse parameters from act
|
||||
q_func_vars = U.scope_vars(U.absolute_scope_name("q_func"))
|
||||
|
||||
# target q network evalution
|
||||
q_tp1 = q_func(obs_tp1_input.get(), num_actions, scope="target_q_func")
|
||||
target_q_func_vars = U.scope_vars(U.absolute_scope_name("target_q_func"))
|
||||
|
||||
# q scores for actions which we know were selected in the given state.
|
||||
q_t_selected = tf.reduce_sum(q_t * tf.one_hot(act_t_ph, num_actions), 1)
|
||||
|
||||
# compute estimate of best possible value starting from state at t + 1
|
||||
if double_q:
|
||||
q_tp1_using_online_net = q_func(
|
||||
obs_tp1_input.get(), num_actions, scope="q_func", reuse=True)
|
||||
q_tp1_best_using_online_net = tf.arg_max(q_tp1_using_online_net, 1)
|
||||
q_tp1_best = tf.reduce_sum(
|
||||
q_tp1 * tf.one_hot(q_tp1_best_using_online_net, num_actions), 1)
|
||||
else:
|
||||
q_tp1_best = tf.reduce_max(q_tp1, 1)
|
||||
q_tp1_best_masked = (1.0 - done_mask_ph) * q_tp1_best
|
||||
|
||||
# compute RHS of bellman equation
|
||||
q_t_selected_target = rew_t_ph + gamma * q_tp1_best_masked
|
||||
|
||||
# compute the error (potentially clipped)
|
||||
td_error = q_t_selected - tf.stop_gradient(q_t_selected_target)
|
||||
errors = U.huber_loss(td_error)
|
||||
weighted_error = tf.reduce_mean(importance_weights_ph * errors)
|
||||
# compute optimization op (potentially with gradient clipping)
|
||||
if grad_norm_clipping is not None:
|
||||
optimize_expr = U.minimize_and_clip(
|
||||
optimizer, weighted_error, var_list=q_func_vars,
|
||||
clip_val=grad_norm_clipping)
|
||||
else:
|
||||
optimize_expr = optimizer.minimize(weighted_error, var_list=q_func_vars)
|
||||
|
||||
# update_target_fn will be called periodically to copy Q network to
|
||||
# target Q network
|
||||
update_target_expr = []
|
||||
for var, var_target in zip(
|
||||
sorted(q_func_vars, key=lambda v: v.name),
|
||||
sorted(target_q_func_vars, key=lambda v: v.name)):
|
||||
update_target_expr.append(var_target.assign(var))
|
||||
update_target_expr = tf.group(*update_target_expr)
|
||||
|
||||
# Create callable functions
|
||||
train = U.function(
|
||||
inputs=[
|
||||
obs_t_input,
|
||||
act_t_ph,
|
||||
rew_t_ph,
|
||||
obs_tp1_input,
|
||||
done_mask_ph,
|
||||
importance_weights_ph
|
||||
],
|
||||
outputs=td_error,
|
||||
updates=[optimize_expr])
|
||||
update_target = U.function([], [], updates=[update_target_expr])
|
||||
|
||||
q_values = U.function([obs_t_input], q_t)
|
||||
|
||||
return act_f, train, update_target, {'q_values': q_values}
|
||||
@@ -0,0 +1,246 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import cv2
|
||||
import gym
|
||||
import numpy as np
|
||||
|
||||
from collections import deque
|
||||
from gym import spaces
|
||||
|
||||
|
||||
class NoopResetEnv(gym.Wrapper):
|
||||
def __init__(self, env=None, noop_max=30):
|
||||
"""Sample initial states by taking random number of no-ops on reset.
|
||||
No-op is assumed to be action 0.
|
||||
"""
|
||||
super(NoopResetEnv, self).__init__(env)
|
||||
self.noop_max = noop_max
|
||||
self.override_num_noops = None
|
||||
assert env.unwrapped.get_action_meanings()[0] == 'NOOP'
|
||||
|
||||
def _reset(self):
|
||||
""" Do no-op action for a number of steps in [1, noop_max]."""
|
||||
self.env.reset()
|
||||
if self.override_num_noops is not None:
|
||||
noops = self.override_num_noops
|
||||
else:
|
||||
noops = np.random.randint(1, self.noop_max + 1)
|
||||
assert noops > 0
|
||||
obs = None
|
||||
for _ in range(noops):
|
||||
obs, _, done, _ = self.env.step(0)
|
||||
if done:
|
||||
obs = self.env.reset()
|
||||
return obs
|
||||
|
||||
|
||||
class FireResetEnv(gym.Wrapper):
|
||||
def __init__(self, env=None):
|
||||
"""For environments where the user need to press FIRE for the game to
|
||||
start."""
|
||||
super(FireResetEnv, self).__init__(env)
|
||||
assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
|
||||
assert len(env.unwrapped.get_action_meanings()) >= 3
|
||||
|
||||
def _reset(self):
|
||||
self.env.reset()
|
||||
obs, _, done, _ = self.env.step(1)
|
||||
if done:
|
||||
self.env.reset()
|
||||
obs, _, done, _ = self.env.step(2)
|
||||
if done:
|
||||
self.env.reset()
|
||||
return obs
|
||||
|
||||
|
||||
class EpisodicLifeEnv(gym.Wrapper):
|
||||
def __init__(self, env=None):
|
||||
"""Make end-of-life == end-of-episode, but only reset on true game over.
|
||||
Done by DeepMind for the DQN and co. since it helps value estimation.
|
||||
"""
|
||||
super(EpisodicLifeEnv, self).__init__(env)
|
||||
self.lives = 0
|
||||
self.was_real_done = True
|
||||
self.was_real_reset = False
|
||||
|
||||
def _step(self, action):
|
||||
obs, reward, done, info = self.env.step(action)
|
||||
self.was_real_done = done
|
||||
# check current lives, make loss of life terminal,
|
||||
# then update lives to handle bonus lives
|
||||
lives = self.env.unwrapped.ale.lives()
|
||||
if lives < self.lives and lives > 0:
|
||||
# for Qbert somtimes we stay in lives == 0 condtion for a few frames
|
||||
# so its important to keep lives > 0, so that we only reset once
|
||||
# the environment advertises done.
|
||||
done = True
|
||||
self.lives = lives
|
||||
return obs, reward, done, info
|
||||
|
||||
def _reset(self):
|
||||
"""Reset only when lives are exhausted.
|
||||
This way all states are still reachable even though lives are episodic,
|
||||
and the learner need not know about any of this behind-the-scenes.
|
||||
"""
|
||||
if self.was_real_done:
|
||||
obs = self.env.reset()
|
||||
self.was_real_reset = True
|
||||
else:
|
||||
# no-op step to advance from terminal/lost life state
|
||||
obs, _, _, _ = self.env.step(0)
|
||||
self.was_real_reset = False
|
||||
self.lives = self.env.unwrapped.ale.lives()
|
||||
return obs
|
||||
|
||||
|
||||
class MaxAndSkipEnv(gym.Wrapper):
|
||||
def __init__(self, env=None, skip=4):
|
||||
"""Return only every `skip`-th frame"""
|
||||
super(MaxAndSkipEnv, self).__init__(env)
|
||||
# most recent raw observations (for max pooling across time steps)
|
||||
self._obs_buffer = deque(maxlen=2)
|
||||
self._skip = skip
|
||||
|
||||
def _step(self, action):
|
||||
total_reward = 0.0
|
||||
done = None
|
||||
for _ in range(self._skip):
|
||||
obs, reward, done, info = self.env.step(action)
|
||||
self._obs_buffer.append(obs)
|
||||
total_reward += reward
|
||||
if done:
|
||||
break
|
||||
|
||||
max_frame = np.max(np.stack(self._obs_buffer), axis=0)
|
||||
|
||||
return max_frame, total_reward, done, info
|
||||
|
||||
def _reset(self):
|
||||
"""Clear past frame buffer and init. to first obs. from inner env."""
|
||||
self._obs_buffer.clear()
|
||||
obs = self.env.reset()
|
||||
self._obs_buffer.append(obs)
|
||||
return obs
|
||||
|
||||
|
||||
class ProcessFrame84(gym.ObservationWrapper):
|
||||
def __init__(self, env=None):
|
||||
super(ProcessFrame84, self).__init__(env)
|
||||
self.observation_space = spaces.Box(low=0, high=255, shape=(84, 84, 1))
|
||||
|
||||
def _observation(self, obs):
|
||||
return ProcessFrame84.process(obs)
|
||||
|
||||
@staticmethod
|
||||
def process(frame):
|
||||
if frame.size == 210 * 160 * 3:
|
||||
img = np.reshape(frame, [210, 160, 3]).astype(np.float32)
|
||||
elif frame.size == 250 * 160 * 3:
|
||||
img = np.reshape(frame, [250, 160, 3]).astype(np.float32)
|
||||
else:
|
||||
assert False, "Unknown resolution."
|
||||
img = img[:, :, 0] * 0.299 + img[:, :, 1] * 0.587 + img[:, :, 2] * 0.114
|
||||
resized_screen = cv2.resize(img, (84, 110), interpolation=cv2.INTER_AREA)
|
||||
x_t = resized_screen[18:102, :]
|
||||
x_t = np.reshape(x_t, [84, 84, 1])
|
||||
return x_t.astype(np.uint8)
|
||||
|
||||
|
||||
class ClippedRewardsWrapper(gym.RewardWrapper):
|
||||
def _reward(self, reward):
|
||||
"""Change all the positive rewards to 1, negative to -1 and keep zero."""
|
||||
return np.sign(reward)
|
||||
|
||||
|
||||
class LazyFrames(object):
|
||||
def __init__(self, frames):
|
||||
"""This object ensures that common frames between the observations are only
|
||||
stored once. It exists purely to optimize memory usage which can be huge
|
||||
for DQN's 1M frames replay buffers.
|
||||
|
||||
This object should only be converted to numpy array before being passed to
|
||||
the model.
|
||||
|
||||
You'd not belive how complex the previous solution was."""
|
||||
self._frames = frames
|
||||
|
||||
def __array__(self, dtype=None):
|
||||
out = np.concatenate(self._frames, axis=2)
|
||||
if dtype is not None:
|
||||
out = out.astype(dtype)
|
||||
return out
|
||||
|
||||
|
||||
class FrameStack(gym.Wrapper):
|
||||
def __init__(self, env, k):
|
||||
"""Stack k last frames.
|
||||
|
||||
Returns lazy array, which is much more memory efficient.
|
||||
|
||||
See Also
|
||||
--------
|
||||
ray.rllib.dqn.common.atari_wrappers.LazyFrames
|
||||
"""
|
||||
gym.Wrapper.__init__(self, env)
|
||||
self.k = k
|
||||
self.frames = deque([], maxlen=k)
|
||||
shp = env.observation_space.shape
|
||||
self.observation_space = spaces.Box(
|
||||
low=0, high=255, shape=(shp[0], shp[1], shp[2] * k))
|
||||
|
||||
def _reset(self):
|
||||
ob = self.env.reset()
|
||||
for _ in range(self.k):
|
||||
self.frames.append(ob)
|
||||
return self._get_ob()
|
||||
|
||||
def _step(self, action):
|
||||
ob, reward, done, info = self.env.step(action)
|
||||
self.frames.append(ob)
|
||||
return self._get_ob(), reward, done, info
|
||||
|
||||
def _get_ob(self):
|
||||
assert len(self.frames) == self.k
|
||||
return LazyFrames(list(self.frames))
|
||||
|
||||
|
||||
class ScaledFloatFrame(gym.ObservationWrapper):
|
||||
def _observation(self, obs):
|
||||
# careful! This undoes the memory optimization, use
|
||||
# with smaller replay buffers only.
|
||||
return np.array(obs).astype(np.float32) / 255.0
|
||||
|
||||
|
||||
def wrap_dqn(env):
|
||||
"""Apply a common set of wrappers for Atari games."""
|
||||
assert 'NoFrameskip' in env.spec.id
|
||||
env = EpisodicLifeEnv(env)
|
||||
env = NoopResetEnv(env, noop_max=30)
|
||||
env = MaxAndSkipEnv(env, skip=4)
|
||||
if 'FIRE' in env.unwrapped.get_action_meanings():
|
||||
env = FireResetEnv(env)
|
||||
env = ProcessFrame84(env)
|
||||
env = FrameStack(env, 4)
|
||||
env = ClippedRewardsWrapper(env)
|
||||
return env
|
||||
|
||||
|
||||
class A2cProcessFrame(gym.Wrapper):
|
||||
def __init__(self, env):
|
||||
gym.Wrapper.__init__(self, env)
|
||||
self.observation_space = spaces.Box(low=0, high=255, shape=(84, 84, 1))
|
||||
|
||||
def _step(self, action):
|
||||
ob, reward, done, info = self.env.step(action)
|
||||
return A2cProcessFrame.process(ob), reward, done, info
|
||||
|
||||
def _reset(self):
|
||||
return A2cProcessFrame.process(self.env.reset())
|
||||
|
||||
@staticmethod
|
||||
def process(frame):
|
||||
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
|
||||
frame = cv2.resize(frame, (84, 84), interpolation=cv2.INTER_AREA)
|
||||
return frame.reshape(84, 84, 1)
|
||||
@@ -0,0 +1,107 @@
|
||||
"""This file is used for specifying various schedules that evolve over
|
||||
time throughout the execution of the algorithm, such as:
|
||||
- learning rate for the optimizer
|
||||
- exploration epsilon for the epsilon greedy exploration strategy
|
||||
- beta parameter for beta parameter in prioritized replay
|
||||
|
||||
Each schedule has a function `value(t)` which returns the current value
|
||||
of the parameter given the timestep t of the optimization procedure.
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
|
||||
class Schedule(object):
|
||||
def value(self, t):
|
||||
"""Value of the schedule at time t"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class ConstantSchedule(object):
|
||||
def __init__(self, value):
|
||||
"""Value remains constant over time.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
value: float
|
||||
Constant value of the schedule
|
||||
"""
|
||||
self._v = value
|
||||
|
||||
def value(self, t):
|
||||
"""See Schedule.value"""
|
||||
return self._v
|
||||
|
||||
|
||||
def linear_interpolation(l, r, alpha):
|
||||
return l + alpha * (r - l)
|
||||
|
||||
|
||||
class PiecewiseSchedule(object):
|
||||
def __init__(
|
||||
self, endpoints, interpolation=linear_interpolation,
|
||||
outside_value=None):
|
||||
|
||||
"""Piecewise schedule.
|
||||
|
||||
endpoints: [(int, int)]
|
||||
list of pairs `(time, value)` meanining that schedule should output
|
||||
`value` when `t==time`. All the values for time must be sorted in
|
||||
an increasing order. When t is between two times, e.g.
|
||||
`(time_a, value_a)`
|
||||
and `(time_b, value_b)`, such that `time_a <= t < time_b` then value
|
||||
outputs `interpolation(value_a, value_b, alpha)` where alpha is a
|
||||
fraction of time passed between `time_a` and `time_b` for time `t`.
|
||||
interpolation: lambda float, float, float: float
|
||||
a function that takes value to the left and to the right of t according
|
||||
to the `endpoints`. Alpha is the fraction of distance from left endpoint
|
||||
to right endpoint that t has covered. See linear_interpolation for
|
||||
example.
|
||||
outside_value: float
|
||||
if the value is requested outside of all the intervals sepecified in
|
||||
`endpoints` this value is returned. If None then AssertionError is
|
||||
raised when outside value is requested.
|
||||
"""
|
||||
idxes = [e[0] for e in endpoints]
|
||||
assert idxes == sorted(idxes)
|
||||
self._interpolation = interpolation
|
||||
self._outside_value = outside_value
|
||||
self._endpoints = endpoints
|
||||
|
||||
def value(self, t):
|
||||
"""See Schedule.value"""
|
||||
for (l_t, l), (r_t, r) in zip(self._endpoints[:-1], self._endpoints[1:]):
|
||||
if l_t <= t and t < r_t:
|
||||
alpha = float(t - l_t) / (r_t - l_t)
|
||||
return self._interpolation(l, r, alpha)
|
||||
|
||||
# t does not belong to any of the pieces, so doom.
|
||||
assert self._outside_value is not None
|
||||
return self._outside_value
|
||||
|
||||
|
||||
class LinearSchedule(object):
|
||||
def __init__(self, schedule_timesteps, final_p, initial_p=1.0):
|
||||
"""Linear interpolation between initial_p and final_p over
|
||||
schedule_timesteps. After this many timesteps pass final_p is
|
||||
returned.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
schedule_timesteps: int
|
||||
Number of timesteps for which to linearly anneal initial_p
|
||||
to final_p
|
||||
initial_p: float
|
||||
initial output value
|
||||
final_p: float
|
||||
final output value
|
||||
"""
|
||||
self.schedule_timesteps = schedule_timesteps
|
||||
self.final_p = final_p
|
||||
self.initial_p = initial_p
|
||||
|
||||
def value(self, t):
|
||||
"""See Schedule.value"""
|
||||
fraction = min(float(t) / self.schedule_timesteps, 1.0)
|
||||
return self.initial_p + fraction * (self.final_p - self.initial_p)
|
||||
@@ -0,0 +1,151 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import operator
|
||||
|
||||
|
||||
class SegmentTree(object):
|
||||
def __init__(self, capacity, operation, neutral_element):
|
||||
"""Build a Segment Tree data structure.
|
||||
|
||||
https://en.wikipedia.org/wiki/Segment_tree
|
||||
|
||||
Can be used as regular array, but with two
|
||||
important differences:
|
||||
|
||||
a) setting item's value is slightly slower.
|
||||
It is O(lg capacity) instead of O(1).
|
||||
b) user has access to an efficient `reduce`
|
||||
operation which reduces `operation` over
|
||||
a contiguous subsequence of items in the
|
||||
array.
|
||||
|
||||
Paramters
|
||||
---------
|
||||
capacity: int
|
||||
Total size of the array - must be a power of two.
|
||||
operation: lambda obj, obj -> obj
|
||||
and operation for combining elements (eg. sum, max)
|
||||
must for a mathematical group together with the set of
|
||||
possible values for array elements.
|
||||
neutral_element: obj
|
||||
neutral element for the operation above. eg. float('-inf')
|
||||
for max and 0 for sum.
|
||||
"""
|
||||
|
||||
assert capacity > 0 and capacity & (capacity - 1) == 0, \
|
||||
"capacity must be positive and a power of 2."
|
||||
self._capacity = capacity
|
||||
self._value = [neutral_element for _ in range(2 * capacity)]
|
||||
self._operation = operation
|
||||
|
||||
def _reduce_helper(self, start, end, node, node_start, node_end):
|
||||
if start == node_start and end == node_end:
|
||||
return self._value[node]
|
||||
mid = (node_start + node_end) // 2
|
||||
if end <= mid:
|
||||
return self._reduce_helper(start, end, 2 * node, node_start, mid)
|
||||
else:
|
||||
if mid + 1 <= start:
|
||||
return self._reduce_helper(start, end, 2 * node + 1, mid + 1, node_end)
|
||||
else:
|
||||
return self._operation(
|
||||
self._reduce_helper(start, mid, 2 * node, node_start, mid),
|
||||
self._reduce_helper(mid + 1, end, 2 * node + 1, mid + 1, node_end)
|
||||
)
|
||||
|
||||
def reduce(self, start=0, end=None):
|
||||
"""Returns result of applying `self.operation`
|
||||
to a contiguous subsequence of the array.
|
||||
|
||||
self.operation(
|
||||
arr[start], operation(arr[start+1], operation(... arr[end])))
|
||||
|
||||
Parameters
|
||||
----------
|
||||
start: int
|
||||
beginning of the subsequence
|
||||
end: int
|
||||
end of the subsequences
|
||||
|
||||
Returns
|
||||
-------
|
||||
reduced: obj
|
||||
result of reducing self.operation over the specified range of array
|
||||
elements.
|
||||
"""
|
||||
if end is None:
|
||||
end = self._capacity
|
||||
if end < 0:
|
||||
end += self._capacity
|
||||
end -= 1
|
||||
return self._reduce_helper(start, end, 1, 0, self._capacity - 1)
|
||||
|
||||
def __setitem__(self, idx, val):
|
||||
# index of the leaf
|
||||
idx += self._capacity
|
||||
self._value[idx] = val
|
||||
idx //= 2
|
||||
while idx >= 1:
|
||||
self._value[idx] = self._operation(
|
||||
self._value[2 * idx],
|
||||
self._value[2 * idx + 1])
|
||||
idx //= 2
|
||||
|
||||
def __getitem__(self, idx):
|
||||
assert 0 <= idx < self._capacity
|
||||
return self._value[self._capacity + idx]
|
||||
|
||||
|
||||
class SumSegmentTree(SegmentTree):
|
||||
def __init__(self, capacity):
|
||||
super(SumSegmentTree, self).__init__(
|
||||
capacity=capacity,
|
||||
operation=operator.add,
|
||||
neutral_element=0.0)
|
||||
|
||||
def sum(self, start=0, end=None):
|
||||
"""Returns arr[start] + ... + arr[end]"""
|
||||
return super(SumSegmentTree, self).reduce(start, end)
|
||||
|
||||
def find_prefixsum_idx(self, prefixsum):
|
||||
"""Find the highest index `i` in the array such that
|
||||
sum(arr[0] + arr[1] + ... + arr[i - i]) <= prefixsum
|
||||
|
||||
if array values are probabilities, this function
|
||||
allows to sample indexes according to the discrete
|
||||
probability efficiently.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
perfixsum: float
|
||||
upperbound on the sum of array prefix
|
||||
|
||||
Returns
|
||||
-------
|
||||
idx: int
|
||||
highest index satisfying the prefixsum constraint
|
||||
"""
|
||||
assert 0 <= prefixsum <= self.sum() + 1e-5
|
||||
idx = 1
|
||||
while idx < self._capacity: # while non-leaf
|
||||
if self._value[2 * idx] > prefixsum:
|
||||
idx = 2 * idx
|
||||
else:
|
||||
prefixsum -= self._value[2 * idx]
|
||||
idx = 2 * idx + 1
|
||||
return idx - self._capacity
|
||||
|
||||
|
||||
class MinSegmentTree(SegmentTree):
|
||||
def __init__(self, capacity):
|
||||
super(MinSegmentTree, self).__init__(
|
||||
capacity=capacity,
|
||||
operation=min,
|
||||
neutral_element=float('inf'))
|
||||
|
||||
def min(self, start=0, end=None):
|
||||
"""Returns min(arr[start], ..., arr[end])"""
|
||||
|
||||
return super(MinSegmentTree, self).reduce(start, end)
|
||||
@@ -0,0 +1,107 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ray.rllib.dqn.common.segment_tree import SumSegmentTree, MinSegmentTree
|
||||
|
||||
|
||||
def test_tree_set():
|
||||
tree = SumSegmentTree(4)
|
||||
|
||||
tree[2] = 1.0
|
||||
tree[3] = 3.0
|
||||
|
||||
assert np.isclose(tree.sum(), 4.0)
|
||||
assert np.isclose(tree.sum(0, 2), 0.0)
|
||||
assert np.isclose(tree.sum(0, 3), 1.0)
|
||||
assert np.isclose(tree.sum(2, 3), 1.0)
|
||||
assert np.isclose(tree.sum(2, -1), 1.0)
|
||||
assert np.isclose(tree.sum(2, 4), 4.0)
|
||||
|
||||
|
||||
def test_tree_set_overlap():
|
||||
tree = SumSegmentTree(4)
|
||||
|
||||
tree[2] = 1.0
|
||||
tree[2] = 3.0
|
||||
|
||||
assert np.isclose(tree.sum(), 3.0)
|
||||
assert np.isclose(tree.sum(2, 3), 3.0)
|
||||
assert np.isclose(tree.sum(2, -1), 3.0)
|
||||
assert np.isclose(tree.sum(2, 4), 3.0)
|
||||
assert np.isclose(tree.sum(1, 2), 0.0)
|
||||
|
||||
|
||||
def test_prefixsum_idx():
|
||||
tree = SumSegmentTree(4)
|
||||
|
||||
tree[2] = 1.0
|
||||
tree[3] = 3.0
|
||||
|
||||
assert tree.find_prefixsum_idx(0.0) == 2
|
||||
assert tree.find_prefixsum_idx(0.5) == 2
|
||||
assert tree.find_prefixsum_idx(0.99) == 2
|
||||
assert tree.find_prefixsum_idx(1.01) == 3
|
||||
assert tree.find_prefixsum_idx(3.00) == 3
|
||||
assert tree.find_prefixsum_idx(4.00) == 3
|
||||
|
||||
|
||||
def test_prefixsum_idx2():
|
||||
tree = SumSegmentTree(4)
|
||||
|
||||
tree[0] = 0.5
|
||||
tree[1] = 1.0
|
||||
tree[2] = 1.0
|
||||
tree[3] = 3.0
|
||||
|
||||
assert tree.find_prefixsum_idx(0.00) == 0
|
||||
assert tree.find_prefixsum_idx(0.55) == 1
|
||||
assert tree.find_prefixsum_idx(0.99) == 1
|
||||
assert tree.find_prefixsum_idx(1.51) == 2
|
||||
assert tree.find_prefixsum_idx(3.00) == 3
|
||||
assert tree.find_prefixsum_idx(5.50) == 3
|
||||
|
||||
|
||||
def test_max_interval_tree():
|
||||
tree = MinSegmentTree(4)
|
||||
|
||||
tree[0] = 1.0
|
||||
tree[2] = 0.5
|
||||
tree[3] = 3.0
|
||||
|
||||
assert np.isclose(tree.min(), 0.5)
|
||||
assert np.isclose(tree.min(0, 2), 1.0)
|
||||
assert np.isclose(tree.min(0, 3), 0.5)
|
||||
assert np.isclose(tree.min(0, -1), 0.5)
|
||||
assert np.isclose(tree.min(2, 4), 0.5)
|
||||
assert np.isclose(tree.min(3, 4), 3.0)
|
||||
|
||||
tree[2] = 0.7
|
||||
|
||||
assert np.isclose(tree.min(), 0.7)
|
||||
assert np.isclose(tree.min(0, 2), 1.0)
|
||||
assert np.isclose(tree.min(0, 3), 0.7)
|
||||
assert np.isclose(tree.min(0, -1), 0.7)
|
||||
assert np.isclose(tree.min(2, 4), 0.7)
|
||||
assert np.isclose(tree.min(3, 4), 3.0)
|
||||
|
||||
tree[2] = 4.0
|
||||
|
||||
assert np.isclose(tree.min(), 1.0)
|
||||
assert np.isclose(tree.min(0, 2), 1.0)
|
||||
assert np.isclose(tree.min(0, 3), 1.0)
|
||||
assert np.isclose(tree.min(0, -1), 1.0)
|
||||
assert np.isclose(tree.min(2, 4), 3.0)
|
||||
assert np.isclose(tree.min(2, 3), 4.0)
|
||||
assert np.isclose(tree.min(2, -1), 4.0)
|
||||
assert np.isclose(tree.min(3, 4), 3.0)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_tree_set()
|
||||
test_tree_set_overlap()
|
||||
test_prefixsum_idx()
|
||||
test_prefixsum_idx2()
|
||||
test_max_interval_tree()
|
||||
@@ -0,0 +1,73 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
# tests for tf_util
|
||||
import tensorflow as tf
|
||||
from ray.rllib.dqn.common.tf_util import (
|
||||
function,
|
||||
initialize,
|
||||
set_value,
|
||||
single_threaded_session
|
||||
)
|
||||
|
||||
|
||||
def test_set_value():
|
||||
a = tf.Variable(42.)
|
||||
with single_threaded_session():
|
||||
set_value(a, 5)
|
||||
assert a.eval() == 5
|
||||
g = tf.get_default_graph()
|
||||
g.finalize()
|
||||
set_value(a, 6)
|
||||
assert a.eval() == 6
|
||||
|
||||
# test the test
|
||||
try:
|
||||
assert a.eval() == 7
|
||||
except AssertionError:
|
||||
pass
|
||||
else:
|
||||
assert False, "assertion should have failed"
|
||||
|
||||
|
||||
def test_function():
|
||||
tf.reset_default_graph()
|
||||
x = tf.placeholder(tf.int32, (), name="x")
|
||||
y = tf.placeholder(tf.int32, (), name="y")
|
||||
z = 3 * x + 2 * y
|
||||
lin = function([x, y], z, givens={y: 0})
|
||||
|
||||
with single_threaded_session():
|
||||
initialize()
|
||||
|
||||
assert lin(2) == 6
|
||||
assert lin(x=3) == 9
|
||||
assert lin(2, 2) == 10
|
||||
assert lin(x=2, y=3) == 12
|
||||
|
||||
|
||||
def test_multikwargs():
|
||||
tf.reset_default_graph()
|
||||
x = tf.placeholder(tf.int32, (), name="x")
|
||||
with tf.variable_scope("other"):
|
||||
x2 = tf.placeholder(tf.int32, (), name="x")
|
||||
z = 3 * x + 2 * x2
|
||||
|
||||
lin = function([x, x2], z, givens={x2: 0})
|
||||
with single_threaded_session():
|
||||
initialize()
|
||||
assert lin(2) == 6
|
||||
assert lin(2, 2) == 10
|
||||
expt_caught = False
|
||||
try:
|
||||
lin(x=2)
|
||||
except AssertionError:
|
||||
expt_caught = True
|
||||
assert expt_caught
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_set_value()
|
||||
test_function()
|
||||
test_multikwargs()
|
||||
@@ -0,0 +1,782 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf # pylint: ignore-module
|
||||
import builtins
|
||||
import functools
|
||||
import copy
|
||||
import os
|
||||
import collections
|
||||
|
||||
|
||||
# ================================================================
|
||||
# Make consistent with numpy
|
||||
# ================================================================
|
||||
|
||||
clip = tf.clip_by_value
|
||||
|
||||
|
||||
def sum(x, axis=None, keepdims=False):
|
||||
axis = None if axis is None else [axis]
|
||||
return tf.reduce_sum(x, axis=axis, keep_dims=keepdims)
|
||||
|
||||
|
||||
def mean(x, axis=None, keepdims=False):
|
||||
axis = None if axis is None else [axis]
|
||||
return tf.reduce_mean(x, axis=axis, keep_dims=keepdims)
|
||||
|
||||
|
||||
def var(x, axis=None, keepdims=False):
|
||||
meanx = mean(x, axis=axis, keepdims=keepdims)
|
||||
return mean(tf.square(x - meanx), axis=axis, keepdims=keepdims)
|
||||
|
||||
|
||||
def std(x, axis=None, keepdims=False):
|
||||
return tf.sqrt(var(x, axis=axis, keepdims=keepdims))
|
||||
|
||||
|
||||
def max(x, axis=None, keepdims=False):
|
||||
axis = None if axis is None else [axis]
|
||||
return tf.reduce_max(x, axis=axis, keep_dims=keepdims)
|
||||
|
||||
|
||||
def min(x, axis=None, keepdims=False):
|
||||
axis = None if axis is None else [axis]
|
||||
return tf.reduce_min(x, axis=axis, keep_dims=keepdims)
|
||||
|
||||
|
||||
def concatenate(arrs, axis=0):
|
||||
return tf.concat(axis=axis, values=arrs)
|
||||
|
||||
|
||||
def argmax(x, axis=None):
|
||||
return tf.argmax(x, axis=axis)
|
||||
|
||||
|
||||
def switch(condition, then_expression, else_expression):
|
||||
"""Switches between two operations depending on a scalar value (int or bool).
|
||||
Note that both `then_expression` and `else_expression`
|
||||
should be symbolic tensors of the *same shape*.
|
||||
|
||||
# Arguments
|
||||
condition: scalar tensor.
|
||||
then_expression: TensorFlow operation.
|
||||
else_expression: TensorFlow operation.
|
||||
"""
|
||||
x_shape = copy.copy(then_expression.get_shape())
|
||||
x = tf.cond(tf.cast(condition, 'bool'),
|
||||
lambda: then_expression, lambda: else_expression)
|
||||
x.set_shape(x_shape)
|
||||
return x
|
||||
|
||||
# ================================================================
|
||||
# Extras
|
||||
# ================================================================
|
||||
|
||||
|
||||
def l2loss(params):
|
||||
if len(params) == 0:
|
||||
return tf.constant(0.0)
|
||||
else:
|
||||
return tf.add_n([sum(tf.square(p)) for p in params])
|
||||
|
||||
|
||||
def lrelu(x, leak=0.2):
|
||||
f1 = 0.5 * (1 + leak)
|
||||
f2 = 0.5 * (1 - leak)
|
||||
return f1 * x + f2 * abs(x)
|
||||
|
||||
|
||||
def categorical_sample_logits(X):
|
||||
# https://github.com/tensorflow/tensorflow/issues/456
|
||||
U = tf.random_uniform(tf.shape(X))
|
||||
return argmax(X - tf.log(-tf.log(U)), axis=1)
|
||||
|
||||
|
||||
# ================================================================
|
||||
# Inputs
|
||||
# ================================================================
|
||||
|
||||
|
||||
def is_placeholder(x):
|
||||
return type(x) is tf.Tensor and len(x.op.inputs) == 0
|
||||
|
||||
|
||||
class TfInput(object):
|
||||
def __init__(self, name="(unnamed)"):
|
||||
"""Generalized Tensorflow placeholder. The main differences are:
|
||||
- possibly uses multiple placeholders internally and returns multiple
|
||||
values
|
||||
- can apply light postprocessing to the value feed to placeholder.
|
||||
"""
|
||||
self.name = name
|
||||
|
||||
def get(self):
|
||||
"""Return the tf variable(s) representing the possibly postprocessed value
|
||||
of placeholder(s).
|
||||
"""
|
||||
raise NotImplemented()
|
||||
|
||||
def make_feed_dict(data):
|
||||
"""Given data input it to the placeholder(s)."""
|
||||
raise NotImplemented()
|
||||
|
||||
|
||||
class PlacholderTfInput(TfInput):
|
||||
def __init__(self, placeholder):
|
||||
"""Wrapper for regular tensorflow placeholder."""
|
||||
super().__init__(placeholder.name)
|
||||
self._placeholder = placeholder
|
||||
|
||||
def get(self):
|
||||
return self._placeholder
|
||||
|
||||
def make_feed_dict(self, data):
|
||||
return {self._placeholder: data}
|
||||
|
||||
|
||||
class BatchInput(PlacholderTfInput):
|
||||
def __init__(self, shape, dtype=tf.float32, name=None):
|
||||
"""Creates a placeholder for a batch of tensors of a given shape and dtype
|
||||
|
||||
Parameters
|
||||
----------
|
||||
shape: [int]
|
||||
shape of a single elemenet of the batch
|
||||
dtype: tf.dtype
|
||||
number representation used for tensor contents
|
||||
name: str
|
||||
name of the underlying placeholder
|
||||
"""
|
||||
super().__init__(tf.placeholder(dtype, [None] + list(shape), name=name))
|
||||
|
||||
|
||||
class Uint8Input(PlacholderTfInput):
|
||||
def __init__(self, shape, name=None):
|
||||
"""Takes input in uint8 format which is cast to float32 and divided by 255
|
||||
before passing it to the model.
|
||||
|
||||
On GPU this ensures lower data transfer times.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
shape: [int]
|
||||
shape of the tensor.
|
||||
name: str
|
||||
name of the underlying placeholder
|
||||
"""
|
||||
|
||||
super().__init__(tf.placeholder(tf.uint8, [None] + list(shape), name=name))
|
||||
self._shape = shape
|
||||
self._output = tf.cast(super().get(), tf.float32) / 255.0
|
||||
|
||||
def get(self):
|
||||
return self._output
|
||||
|
||||
|
||||
def ensure_tf_input(thing):
|
||||
"""Takes either tf.placeholder of TfInput and outputs equivalent TfInput"""
|
||||
if isinstance(thing, TfInput):
|
||||
return thing
|
||||
elif is_placeholder(thing):
|
||||
return PlacholderTfInput(thing)
|
||||
else:
|
||||
raise ValueError("Must be a placeholder or TfInput")
|
||||
|
||||
# ================================================================
|
||||
# Mathematical utils
|
||||
# ================================================================
|
||||
|
||||
|
||||
def huber_loss(x, delta=1.0):
|
||||
"""Reference: https://en.wikipedia.org/wiki/Huber_loss"""
|
||||
return tf.where(
|
||||
tf.abs(x) < delta,
|
||||
tf.square(x) * 0.5,
|
||||
delta * (tf.abs(x) - 0.5 * delta))
|
||||
|
||||
# ================================================================
|
||||
# Optimizer utils
|
||||
# ================================================================
|
||||
|
||||
|
||||
def minimize_and_clip(optimizer, objective, var_list, clip_val=10):
|
||||
"""Minimized `objective` using `optimizer` w.r.t. variables in
|
||||
`var_list` while ensure the norm of the gradients for each
|
||||
variable is clipped to `clip_val`
|
||||
"""
|
||||
gradients = optimizer.compute_gradients(objective, var_list=var_list)
|
||||
for i, (grad, var) in enumerate(gradients):
|
||||
if grad is not None:
|
||||
gradients[i] = (tf.clip_by_norm(grad, clip_val), var)
|
||||
return optimizer.apply_gradients(gradients)
|
||||
|
||||
|
||||
# ================================================================
|
||||
# Global session
|
||||
# ================================================================
|
||||
|
||||
def get_session():
|
||||
"""Returns recently made Tensorflow session"""
|
||||
return tf.get_default_session()
|
||||
|
||||
|
||||
def make_session(num_cpu):
|
||||
"""Returns a session that will use <num_cpu> CPU's only"""
|
||||
tf_config = tf.ConfigProto(
|
||||
inter_op_parallelism_threads=num_cpu,
|
||||
intra_op_parallelism_threads=num_cpu)
|
||||
return tf.Session(config=tf_config)
|
||||
|
||||
|
||||
def single_threaded_session():
|
||||
"""Returns a session which will only use a single CPU"""
|
||||
return make_session(1)
|
||||
|
||||
|
||||
ALREADY_INITIALIZED = set()
|
||||
|
||||
|
||||
def initialize():
|
||||
"""Initialize all the uninitialized variables in the global scope."""
|
||||
new_variables = set(tf.global_variables()) - ALREADY_INITIALIZED
|
||||
get_session().run(tf.variables_initializer(new_variables))
|
||||
ALREADY_INITIALIZED.update(new_variables)
|
||||
|
||||
|
||||
def eval(expr, feed_dict=None):
|
||||
if feed_dict is None:
|
||||
feed_dict = {}
|
||||
return get_session().run(expr, feed_dict=feed_dict)
|
||||
|
||||
|
||||
VALUE_SETTERS = collections.OrderedDict()
|
||||
|
||||
|
||||
def set_value(v, val):
|
||||
global VALUE_SETTERS
|
||||
if v in VALUE_SETTERS:
|
||||
set_op, set_endpoint = VALUE_SETTERS[v]
|
||||
else:
|
||||
set_endpoint = tf.placeholder(v.dtype)
|
||||
set_op = v.assign(set_endpoint)
|
||||
VALUE_SETTERS[v] = (set_op, set_endpoint)
|
||||
get_session().run(set_op, feed_dict={set_endpoint: val})
|
||||
|
||||
|
||||
# ================================================================
|
||||
# Saving variables
|
||||
# ================================================================
|
||||
|
||||
|
||||
def load_state(fname):
|
||||
saver = tf.train.Saver()
|
||||
saver.restore(get_session(), fname)
|
||||
|
||||
|
||||
def save_state(fname):
|
||||
os.makedirs(os.path.dirname(fname), exist_ok=True)
|
||||
saver = tf.train.Saver()
|
||||
saver.save(get_session(), fname)
|
||||
|
||||
# ================================================================
|
||||
# Model components
|
||||
# ================================================================
|
||||
|
||||
|
||||
def normc_initializer(std=1.0):
|
||||
# pylint: disable=W0613
|
||||
def _initializer(shape, dtype=None, partition_info=None):
|
||||
out = np.random.randn(*shape).astype(np.float32)
|
||||
out *= std / np.sqrt(np.square(out).sum(axis=0, keepdims=True))
|
||||
return tf.constant(out)
|
||||
return _initializer
|
||||
|
||||
|
||||
def conv2d(
|
||||
x, num_filters, name, filter_size=(3, 3), stride=(1, 1), pad="SAME",
|
||||
dtype=tf.float32, collections=None, summary_tag=None):
|
||||
with tf.variable_scope(name):
|
||||
stride_shape = [1, stride[0], stride[1], 1]
|
||||
filter_shape = [
|
||||
filter_size[0], filter_size[1], int(x.get_shape()[3]), num_filters]
|
||||
|
||||
# there are "num input feature maps * filter height * filter width"
|
||||
# inputs to each hidden unit
|
||||
fan_in = intprod(filter_shape[:3])
|
||||
# each unit in the lower layer receives a gradient from:
|
||||
# "num output feature maps * filter height * filter width" /
|
||||
# pooling size
|
||||
fan_out = intprod(filter_shape[:2]) * num_filters
|
||||
# initialize weights with random weights
|
||||
w_bound = np.sqrt(6. / (fan_in + fan_out))
|
||||
|
||||
w = tf.get_variable(
|
||||
"W", filter_shape, dtype,
|
||||
tf.random_uniform_initializer(-w_bound, w_bound),
|
||||
collections=collections)
|
||||
b = tf.get_variable(
|
||||
"b", [1, 1, 1, num_filters], initializer=tf.zeros_initializer(),
|
||||
collections=collections)
|
||||
|
||||
if summary_tag is not None:
|
||||
tf.summary.image(
|
||||
summary_tag,
|
||||
tf.transpose(tf.reshape(w, [filter_size[0], filter_size[1], -1, 1]),
|
||||
[2, 0, 1, 3]),
|
||||
max_images=10)
|
||||
|
||||
return tf.nn.conv2d(x, w, stride_shape, pad) + b
|
||||
|
||||
|
||||
def dense(x, size, name, weight_init=None, bias=True):
|
||||
w = tf.get_variable(name + "/w", [x.get_shape()[1], size],
|
||||
initializer=weight_init)
|
||||
ret = tf.matmul(x, w)
|
||||
if bias:
|
||||
b = tf.get_variable(
|
||||
name + "/b", [size], initializer=tf.zeros_initializer())
|
||||
return ret + b
|
||||
else:
|
||||
return ret
|
||||
|
||||
|
||||
def wndense(x, size, name, init_scale=1.0):
|
||||
v = tf.get_variable(
|
||||
name + "/V", [int(x.get_shape()[1]), size],
|
||||
initializer=tf.random_normal_initializer(0, 0.05))
|
||||
g = tf.get_variable(
|
||||
name + "/g", [size], initializer=tf.constant_initializer(init_scale))
|
||||
b = tf.get_variable(
|
||||
name + "/b", [size], initializer=tf.constant_initializer(0.0))
|
||||
|
||||
# use weight normalization (Salimans & Kingma, 2016)
|
||||
x = tf.matmul(x, v)
|
||||
scaler = g / tf.sqrt(sum(tf.square(v), axis=0, keepdims=True))
|
||||
return tf.reshape(scaler, [1, size]) * x + tf.reshape(b, [1, size])
|
||||
|
||||
|
||||
def densenobias(x, size, name, weight_init=None):
|
||||
return dense(x, size, name, weight_init=weight_init, bias=False)
|
||||
|
||||
|
||||
def dropout(x, pkeep, phase=None, mask=None):
|
||||
mask = tf.floor(
|
||||
pkeep + tf.random_uniform(tf.shape(x))) if mask is None else mask
|
||||
if phase is None:
|
||||
return mask * x
|
||||
else:
|
||||
return switch(phase, mask * x, pkeep * x)
|
||||
|
||||
|
||||
# ================================================================
|
||||
# Theano-like Function
|
||||
# ================================================================
|
||||
|
||||
|
||||
def function(inputs, outputs, updates=None, givens=None):
|
||||
"""Just like Theano function. Take a bunch of tensorflow placeholders and
|
||||
expressions computed based on those placeholders and produces f(inputs) ->
|
||||
outputs. Function f takes values to be fed to the input's placeholders and
|
||||
produces the values of the expressions in outputs.
|
||||
|
||||
Input values can be passed in the same order as inputs or can be provided as
|
||||
kwargs based on placeholder name (passed to constructor or accessible via
|
||||
placeholder.op.name).
|
||||
|
||||
Example:
|
||||
x = tf.placeholder(tf.int32, (), name="x")
|
||||
y = tf.placeholder(tf.int32, (), name="y")
|
||||
z = 3 * x + 2 * y
|
||||
lin = function([x, y], z, givens={y: 0})
|
||||
|
||||
with single_threaded_session():
|
||||
initialize()
|
||||
|
||||
assert lin(2) == 6
|
||||
assert lin(x=3) == 9
|
||||
assert lin(2, 2) == 10
|
||||
assert lin(x=2, y=3) == 12
|
||||
|
||||
Parameters
|
||||
----------
|
||||
inputs: [tf.placeholder or TfInput]
|
||||
list of input arguments
|
||||
outputs: [tf.Variable] or tf.Variable
|
||||
list of outputs or a single output to be returned from function. Returned
|
||||
value will also have the same shape.
|
||||
"""
|
||||
if isinstance(outputs, list):
|
||||
return _Function(inputs, outputs, updates, givens=givens)
|
||||
elif isinstance(outputs, (dict, collections.OrderedDict)):
|
||||
f = _Function(inputs, outputs.values(), updates, givens=givens)
|
||||
return lambda *args, **kwargs: type(outputs)(
|
||||
zip(outputs.keys(), f(*args, **kwargs)))
|
||||
else:
|
||||
f = _Function(inputs, [outputs], updates, givens=givens)
|
||||
return lambda *args, **kwargs: f(*args, **kwargs)[0]
|
||||
|
||||
|
||||
class _Function(object):
|
||||
def __init__(self, inputs, outputs, updates, givens, check_nan=False):
|
||||
for inpt in inputs:
|
||||
if not issubclass(type(inpt), TfInput):
|
||||
assert len(inpt.op.inputs) == 0, \
|
||||
"inputs should all be placeholders of ray.rllib.dqn.common.TfInput"
|
||||
self.inputs = inputs
|
||||
updates = updates or []
|
||||
self.update_group = tf.group(*updates)
|
||||
self.outputs_update = list(outputs) + [self.update_group]
|
||||
self.givens = {} if givens is None else givens
|
||||
self.check_nan = check_nan
|
||||
|
||||
def _feed_input(self, feed_dict, inpt, value):
|
||||
if issubclass(type(inpt), TfInput):
|
||||
feed_dict.update(inpt.make_feed_dict(value))
|
||||
elif is_placeholder(inpt):
|
||||
feed_dict[inpt] = value
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
assert len(args) <= len(self.inputs), "Too many arguments provided"
|
||||
feed_dict = {}
|
||||
# Update the args
|
||||
for inpt, value in zip(self.inputs, args):
|
||||
self._feed_input(feed_dict, inpt, value)
|
||||
# Update the kwargs
|
||||
kwargs_passed_inpt_names = set()
|
||||
for inpt in self.inputs[len(args):]:
|
||||
inpt_name = inpt.name.split(':')[0]
|
||||
inpt_name = inpt_name.split('/')[-1]
|
||||
assert inpt_name not in kwargs_passed_inpt_names, \
|
||||
("this function has two arguments with the same name \"{}\", " +
|
||||
"so kwargs cannot be used.".format(inpt_name))
|
||||
if inpt_name in kwargs:
|
||||
kwargs_passed_inpt_names.add(inpt_name)
|
||||
self._feed_input(feed_dict, inpt, kwargs.pop(inpt_name))
|
||||
else:
|
||||
assert inpt in self.givens, "Missing argument " + inpt_name
|
||||
assert len(kwargs) == 0, \
|
||||
"Function got extra arguments " + str(list(kwargs.keys()))
|
||||
# Update feed dict with givens.
|
||||
for inpt in self.givens:
|
||||
feed_dict[inpt] = feed_dict.get(inpt, self.givens[inpt])
|
||||
results = get_session().run(self.outputs_update, feed_dict=feed_dict)[:-1]
|
||||
if self.check_nan:
|
||||
if any(np.isnan(r).any() for r in results):
|
||||
raise RuntimeError("Nan detected")
|
||||
return results
|
||||
|
||||
|
||||
def mem_friendly_function(nondata_inputs, data_inputs, outputs, batch_size):
|
||||
if isinstance(outputs, list):
|
||||
return _MemFriendlyFunction(
|
||||
nondata_inputs, data_inputs, outputs, batch_size)
|
||||
else:
|
||||
f = _MemFriendlyFunction(
|
||||
nondata_inputs, data_inputs, [outputs], batch_size)
|
||||
return lambda *inputs: f(*inputs)[0]
|
||||
|
||||
|
||||
class _MemFriendlyFunction(object):
|
||||
def __init__(self, nondata_inputs, data_inputs, outputs, batch_size):
|
||||
self.nondata_inputs = nondata_inputs
|
||||
self.data_inputs = data_inputs
|
||||
self.outputs = list(outputs)
|
||||
self.batch_size = batch_size
|
||||
|
||||
def __call__(self, *inputvals):
|
||||
assert len(inputvals) == len(self.nondata_inputs) + len(self.data_inputs)
|
||||
nondata_vals = inputvals[0:len(self.nondata_inputs)]
|
||||
data_vals = inputvals[len(self.nondata_inputs):]
|
||||
feed_dict = dict(zip(self.nondata_inputs, nondata_vals))
|
||||
n = data_vals[0].shape[0]
|
||||
for v in data_vals[1:]:
|
||||
assert v.shape[0] == n
|
||||
for i_start in range(0, n, self.batch_size):
|
||||
slice_vals = [
|
||||
v[i_start:builtins.min(i_start + self.batch_size, n)]
|
||||
for v in data_vals]
|
||||
for (var, val) in zip(self.data_inputs, slice_vals):
|
||||
feed_dict[var] = val
|
||||
results = tf.get_default_session().run(self.outputs, feed_dict=feed_dict)
|
||||
if i_start == 0:
|
||||
sum_results = results
|
||||
else:
|
||||
for i in range(len(results)):
|
||||
sum_results[i] = sum_results[i] + results[i]
|
||||
for i in range(len(results)):
|
||||
sum_results[i] = sum_results[i] / n
|
||||
return sum_results
|
||||
|
||||
# ================================================================
|
||||
# Modules
|
||||
# ================================================================
|
||||
|
||||
|
||||
class Module(object):
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
self.first_time = True
|
||||
self.scope = None
|
||||
self.cache = {}
|
||||
|
||||
def __call__(self, *args):
|
||||
if args in self.cache:
|
||||
print("(%s) retrieving value from cache" % (self.name,))
|
||||
return self.cache[args]
|
||||
with tf.variable_scope(self.name, reuse=not self.first_time):
|
||||
scope = tf.get_variable_scope().name
|
||||
if self.first_time:
|
||||
self.scope = scope
|
||||
print("(%s) running function for the first time" % (self.name,))
|
||||
else:
|
||||
assert self.scope == scope, \
|
||||
"Tried calling function with a different scope"
|
||||
print("(%s) running function on new inputs" % (self.name,))
|
||||
self.first_time = False
|
||||
out = self._call(*args)
|
||||
self.cache[args] = out
|
||||
return out
|
||||
|
||||
def _call(self, *args):
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def trainable_variables(self):
|
||||
assert self.scope is not None, \
|
||||
"need to call module once before getting variables"
|
||||
return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, self.scope)
|
||||
|
||||
@property
|
||||
def variables(self):
|
||||
assert self.scope is not None, \
|
||||
"need to call module once before getting variables"
|
||||
return tf.get_collection(tf.GraphKeys.VARIABLES, self.scope)
|
||||
|
||||
|
||||
def module(name):
|
||||
@functools.wraps
|
||||
def wrapper(f):
|
||||
class WrapperModule(Module):
|
||||
def _call(self, *args):
|
||||
return f(*args)
|
||||
return WrapperModule(name)
|
||||
return wrapper
|
||||
|
||||
# ================================================================
|
||||
# Graph traversal
|
||||
# ================================================================
|
||||
|
||||
|
||||
VARIABLES = {}
|
||||
|
||||
|
||||
def get_parents(node):
|
||||
return node.op.inputs
|
||||
|
||||
|
||||
def topsorted(outputs):
|
||||
"""
|
||||
Topological sort via non-recursive depth-first search
|
||||
"""
|
||||
assert isinstance(outputs, (list, tuple))
|
||||
marks = {}
|
||||
out = []
|
||||
stack = [] # pylint: disable=W0621
|
||||
# i: node
|
||||
# jidx = number of children visited so far from that node
|
||||
# marks: state of each node, which is one of
|
||||
# 0: haven't visited
|
||||
# 1: have visited, but not done visiting children
|
||||
# 2: done visiting children
|
||||
for x in outputs:
|
||||
stack.append((x, 0))
|
||||
while stack:
|
||||
(i, jidx) = stack.pop()
|
||||
if jidx == 0:
|
||||
m = marks.get(i, 0)
|
||||
if m == 0:
|
||||
marks[i] = 1
|
||||
elif m == 1:
|
||||
raise ValueError("not a dag")
|
||||
else:
|
||||
continue
|
||||
ps = get_parents(i)
|
||||
if jidx == len(ps):
|
||||
marks[i] = 2
|
||||
out.append(i)
|
||||
else:
|
||||
stack.append((i, jidx + 1))
|
||||
j = ps[jidx]
|
||||
stack.append((j, 0))
|
||||
return out
|
||||
|
||||
|
||||
# ================================================================
|
||||
# Flat vectors
|
||||
# ================================================================
|
||||
|
||||
def var_shape(x):
|
||||
out = x.get_shape().as_list()
|
||||
assert all(isinstance(a, int) for a in out), \
|
||||
"shape function assumes that shape is fully known"
|
||||
return out
|
||||
|
||||
|
||||
def numel(x):
|
||||
return intprod(var_shape(x))
|
||||
|
||||
|
||||
def intprod(x):
|
||||
return int(np.prod(x))
|
||||
|
||||
|
||||
def flatgrad(loss, var_list):
|
||||
grads = tf.gradients(loss, var_list)
|
||||
return tf.concat(axis=0, values=[
|
||||
tf.reshape(grad if grad is not None else tf.zeros_like(v), [numel(v)])
|
||||
for (v, grad) in zip(var_list, grads)
|
||||
])
|
||||
|
||||
|
||||
class SetFromFlat(object):
|
||||
def __init__(self, var_list, dtype=tf.float32):
|
||||
assigns = []
|
||||
shapes = list(map(var_shape, var_list))
|
||||
total_size = np.sum([intprod(shape) for shape in shapes])
|
||||
|
||||
self.theta = theta = tf.placeholder(dtype, [total_size])
|
||||
start = 0
|
||||
assigns = []
|
||||
for (shape, v) in zip(shapes, var_list):
|
||||
size = intprod(shape)
|
||||
assigns.append(
|
||||
tf.assign(v, tf.reshape(theta[start:start + size], shape)))
|
||||
start += size
|
||||
self.op = tf.group(*assigns)
|
||||
|
||||
def __call__(self, theta):
|
||||
get_session().run(self.op, feed_dict={self.theta: theta})
|
||||
|
||||
|
||||
class GetFlat(object):
|
||||
def __init__(self, var_list):
|
||||
self.op = tf.concat(
|
||||
axis=0, values=[tf.reshape(v, [numel(v)]) for v in var_list])
|
||||
|
||||
def __call__(self):
|
||||
return get_session().run(self.op)
|
||||
|
||||
# ================================================================
|
||||
# Misc
|
||||
# ================================================================
|
||||
|
||||
|
||||
def fancy_slice_2d(X, inds0, inds1):
|
||||
"""
|
||||
like numpy X[inds0, inds1]
|
||||
XXX this implementation is bad
|
||||
"""
|
||||
inds0 = tf.cast(inds0, tf.int64)
|
||||
inds1 = tf.cast(inds1, tf.int64)
|
||||
shape = tf.cast(tf.shape(X), tf.int64)
|
||||
ncols = shape[1]
|
||||
Xflat = tf.reshape(X, [-1])
|
||||
return tf.gather(Xflat, inds0 * ncols + inds1)
|
||||
|
||||
|
||||
# ================================================================
|
||||
# Scopes
|
||||
# ================================================================
|
||||
|
||||
|
||||
def scope_vars(scope, trainable_only=False):
|
||||
"""
|
||||
Get variables inside a scope
|
||||
The scope can be specified as a string
|
||||
|
||||
Parameters
|
||||
----------
|
||||
scope: str or VariableScope
|
||||
scope in which the variables reside.
|
||||
trainable_only: bool
|
||||
whether or not to return only the variables that were marked as trainable.
|
||||
|
||||
Returns
|
||||
-------
|
||||
vars: [tf.Variable]
|
||||
list of variables in `scope`.
|
||||
"""
|
||||
return tf.get_collection(
|
||||
tf.GraphKeys.TRAINABLE_VARIABLES
|
||||
if trainable_only else tf.GraphKeys.VARIABLES,
|
||||
scope=scope if isinstance(scope, str) else scope.name)
|
||||
|
||||
|
||||
def scope_name():
|
||||
"""Returns the name of current scope as a string, e.g. deepq/q_func"""
|
||||
return tf.get_variable_scope().name
|
||||
|
||||
|
||||
def absolute_scope_name(relative_scope_name):
|
||||
"""Appends parent scope name to `relative_scope_name`"""
|
||||
return scope_name() + "/" + relative_scope_name
|
||||
|
||||
|
||||
def lengths_to_mask(lengths_b, max_length):
|
||||
"""
|
||||
Turns a vector of lengths into a boolean mask
|
||||
|
||||
Args:
|
||||
lengths_b: an integer vector of lengths
|
||||
max_length: maximum length to fill the mask
|
||||
|
||||
Returns:
|
||||
a boolean array of shape (batch_size, max_length)
|
||||
row[i] consists of True repeated lengths_b[i] times, followed by False
|
||||
"""
|
||||
lengths_b = tf.convert_to_tensor(lengths_b)
|
||||
assert lengths_b.get_shape().ndims == 1
|
||||
mask_bt = tf.expand_dims(
|
||||
tf.range(max_length), 0) < tf.expand_dims(lengths_b, 1)
|
||||
return mask_bt
|
||||
|
||||
|
||||
def in_session(f):
|
||||
@functools.wraps(f)
|
||||
def newfunc(*args, **kwargs):
|
||||
with tf.Session():
|
||||
f(*args, **kwargs)
|
||||
return newfunc
|
||||
|
||||
|
||||
_PLACEHOLDER_CACHE = {} # name -> (placeholder, dtype, shape)
|
||||
|
||||
|
||||
def get_placeholder(name, dtype, shape):
|
||||
if name in _PLACEHOLDER_CACHE:
|
||||
out, dtype1, shape1 = _PLACEHOLDER_CACHE[name]
|
||||
assert dtype1 == dtype and shape1 == shape
|
||||
return out
|
||||
else:
|
||||
out = tf.placeholder(dtype=dtype, shape=shape, name=name)
|
||||
_PLACEHOLDER_CACHE[name] = (out, dtype, shape)
|
||||
return out
|
||||
|
||||
|
||||
def get_placeholder_cached(name):
|
||||
return _PLACEHOLDER_CACHE[name][0]
|
||||
|
||||
|
||||
def flattenallbut0(x):
|
||||
return tf.reshape(x, [-1, intprod(x.get_shape().as_list()[1:])])
|
||||
|
||||
|
||||
def reset():
|
||||
global _PLACEHOLDER_CACHE
|
||||
global VARIABLES
|
||||
_PLACEHOLDER_CACHE = {}
|
||||
VARIABLES = {}
|
||||
tf.reset_default_graph()
|
||||
@@ -0,0 +1,208 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import time
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from ray.rllib.common import Algorithm, TrainingResult
|
||||
from ray.rllib.dqn.build_graph import build_train
|
||||
from ray.rllib.dqn import logger, models
|
||||
from ray.rllib.dqn.common.atari_wrappers_deprecated \
|
||||
import wrap_dqn, ScaledFloatFrame
|
||||
from ray.rllib.dqn.common import tf_util as U
|
||||
from ray.rllib.dqn.common.schedules import LinearSchedule
|
||||
from ray.rllib.dqn.replay_buffer import ReplayBuffer, PrioritizedReplayBuffer
|
||||
|
||||
|
||||
"""The default configuration dict for the DQN algorithm.
|
||||
|
||||
lr: float
|
||||
learning rate for adam optimizer
|
||||
schedule_max_timesteps: int
|
||||
max num timesteps for annealing schedules
|
||||
timesteps_per_iteration: int
|
||||
number of env steps to optimize for before returning
|
||||
buffer_size: int
|
||||
size of the replay buffer
|
||||
exploration_fraction: float
|
||||
fraction of entire training period over which the exploration rate is
|
||||
annealed
|
||||
exploration_final_eps: float
|
||||
final value of random action probability
|
||||
train_freq: int
|
||||
update the model every `train_freq` steps.
|
||||
batch_size: int
|
||||
size of a batched sampled from replay buffer for training
|
||||
print_freq: int
|
||||
how often to print out training progress
|
||||
set to None to disable printing
|
||||
checkpoint_freq: int
|
||||
how often to save the model. This is so that the best version is restored
|
||||
at the end of the training. If you do not wish to restore the best version
|
||||
at the end of the training set this variable to None.
|
||||
learning_starts: int
|
||||
how many steps of the model to collect transitions for before learning
|
||||
starts
|
||||
gamma: float
|
||||
discount factor
|
||||
target_network_update_freq: int
|
||||
update the target network every `target_network_update_freq` steps.
|
||||
prioritized_replay: True
|
||||
if True prioritized replay buffer will be used.
|
||||
prioritized_replay_alpha: float
|
||||
alpha parameter for prioritized replay buffer
|
||||
prioritized_replay_beta0: float
|
||||
initial value of beta for prioritized replay buffer
|
||||
prioritized_replay_beta_iters: int
|
||||
number of iterations over which beta will be annealed from initial value
|
||||
to 1.0. If set to None equals to schedule_max_timesteps
|
||||
prioritized_replay_eps: float
|
||||
epsilon to add to the TD errors when updating priorities.
|
||||
num_cpu: int
|
||||
number of cpus to use for training
|
||||
"""
|
||||
DEFAULT_CONFIG = dict(
|
||||
lr=5e-4,
|
||||
schedule_max_timesteps=100000,
|
||||
timesteps_per_iteration=1000,
|
||||
buffer_size=50000,
|
||||
exploration_fraction=0.1,
|
||||
exploration_final_eps=0.02,
|
||||
train_freq=1,
|
||||
batch_size=32,
|
||||
print_freq=1,
|
||||
checkpoint_freq=10000,
|
||||
learning_starts=1000,
|
||||
gamma=1.0,
|
||||
target_network_update_freq=500,
|
||||
prioritized_replay=False,
|
||||
prioritized_replay_alpha=0.6,
|
||||
prioritized_replay_beta0=0.4,
|
||||
prioritized_replay_beta_iters=None,
|
||||
prioritized_replay_eps=1e-6,
|
||||
num_cpu=16)
|
||||
|
||||
|
||||
class DQN(Algorithm):
|
||||
def __init__(self, env_name, config):
|
||||
Algorithm.__init__(self, env_name, config)
|
||||
env = gym.make(env_name)
|
||||
env = ScaledFloatFrame(wrap_dqn(env))
|
||||
self.env = env
|
||||
model = models.cnn_to_mlp(
|
||||
convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)],
|
||||
hiddens=[256], dueling=True)
|
||||
sess = U.make_session(num_cpu=config["num_cpu"])
|
||||
sess.__enter__()
|
||||
|
||||
def make_obs_ph(name):
|
||||
return U.BatchInput(env.observation_space.shape, name=name)
|
||||
|
||||
self.act, self.optimize, self.update_target, self.debug = build_train(
|
||||
make_obs_ph=make_obs_ph,
|
||||
q_func=model,
|
||||
num_actions=env.action_space.n,
|
||||
optimizer=tf.train.AdamOptimizer(learning_rate=config["lr"]),
|
||||
gamma=config["gamma"],
|
||||
grad_norm_clipping=10)
|
||||
# Create the replay buffer
|
||||
if config["prioritized_replay"]:
|
||||
self.replay_buffer = PrioritizedReplayBuffer(
|
||||
config["buffer_size"], alpha=config["prioritized_replay_alpha"])
|
||||
prioritized_replay_beta_iters = config["prioritized_replay_beta_iters"]
|
||||
if prioritized_replay_beta_iters is None:
|
||||
prioritized_replay_beta_iters = config["schedule_max_timesteps"]
|
||||
self.beta_schedule = LinearSchedule(
|
||||
prioritized_replay_beta_iters,
|
||||
initial_p=config["prioritized_replay_beta0"],
|
||||
final_p=1.0)
|
||||
else:
|
||||
self.replay_buffer = ReplayBuffer(config["buffer_size"])
|
||||
self.beta_schedule = None
|
||||
# Create the schedule for exploration starting from 1.
|
||||
self.exploration = LinearSchedule(
|
||||
schedule_timesteps=int(
|
||||
config["exploration_fraction"] * config["schedule_max_timesteps"]),
|
||||
initial_p=1.0,
|
||||
final_p=config["exploration_final_eps"])
|
||||
|
||||
# Initialize the parameters and copy them to the target network.
|
||||
U.initialize()
|
||||
self.update_target()
|
||||
|
||||
self.episode_rewards = [0.0]
|
||||
self.episode_lengths = [0.0]
|
||||
self.saved_mean_reward = None
|
||||
self.obs = self.env.reset()
|
||||
self.num_timesteps = 0
|
||||
self.num_iterations = 0
|
||||
|
||||
def train(self):
|
||||
config = self.config
|
||||
sample_time, learn_time = 0, 0
|
||||
|
||||
for t in range(config["timesteps_per_iteration"]):
|
||||
self.num_timesteps += 1
|
||||
dt = time.time()
|
||||
# Take action and update exploration to the newest value
|
||||
action = self.act(
|
||||
np.array(self.obs)[None], update_eps=self.exploration.value(t))[0]
|
||||
new_obs, rew, done, _ = self.env.step(action)
|
||||
# Store transition in the replay buffer.
|
||||
self.replay_buffer.add(self.obs, action, rew, new_obs, float(done))
|
||||
self.obs = new_obs
|
||||
|
||||
self.episode_rewards[-1] += rew
|
||||
self.episode_lengths[-1] += 1
|
||||
if done:
|
||||
self.obs = self.env.reset()
|
||||
self.episode_rewards.append(0.0)
|
||||
self.episode_lengths.append(0.0)
|
||||
sample_time += time.time() - dt
|
||||
|
||||
if self.num_timesteps > config["learning_starts"] and \
|
||||
self.num_timesteps % config["train_freq"] == 0:
|
||||
dt = time.time()
|
||||
# Minimize the error in Bellman's equation on a batch sampled from
|
||||
# replay buffer.
|
||||
if config["prioritized_replay"]:
|
||||
experience = self.replay_buffer.sample(
|
||||
config["batch_size"], beta=self.beta_schedule.value(t))
|
||||
(obses_t, actions, rewards, obses_tp1,
|
||||
dones, _, batch_idxes) = experience
|
||||
else:
|
||||
obses_t, actions, rewards, obses_tp1, dones = \
|
||||
self.replay_buffer.sample(config["batch_size"])
|
||||
batch_idxes = None
|
||||
td_errors = self.optimize(
|
||||
obses_t, actions, rewards, obses_tp1, dones, np.ones_like(rewards))
|
||||
if config["prioritized_replay"]:
|
||||
new_priorities = np.abs(td_errors) + config["prioritized_replay_eps"]
|
||||
self.replay_buffer.update_priorities(batch_idxes, new_priorities)
|
||||
learn_time += (time.time() - dt)
|
||||
|
||||
if self.num_timesteps > config["learning_starts"] and \
|
||||
self.num_timesteps % config["target_network_update_freq"] == 0:
|
||||
# Update target network periodically.
|
||||
self.update_target()
|
||||
|
||||
mean_100ep_reward = round(np.mean(self.episode_rewards[-101:-1]), 1)
|
||||
mean_100ep_length = round(np.mean(self.episode_lengths[-101:-1]), 1)
|
||||
num_episodes = len(self.episode_rewards)
|
||||
logger.record_tabular("sample_time", sample_time)
|
||||
logger.record_tabular("learn_time", learn_time)
|
||||
logger.record_tabular("steps", self.num_timesteps)
|
||||
logger.record_tabular("episodes", num_episodes)
|
||||
logger.record_tabular("mean 100 episode reward", mean_100ep_reward)
|
||||
logger.record_tabular(
|
||||
"% time spent exploring", int(100 * self.exploration.value(t)))
|
||||
logger.dump_tabular()
|
||||
|
||||
res = TrainingResult(
|
||||
self.num_iterations, mean_100ep_reward, mean_100ep_length)
|
||||
self.num_iterations += 1
|
||||
return res
|
||||
Executable
+32
@@ -0,0 +1,32 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from ray.rllib.dqn import DQN, DEFAULT_CONFIG
|
||||
|
||||
|
||||
def main():
|
||||
config = DEFAULT_CONFIG.copy()
|
||||
config.update(dict(
|
||||
lr=1e-4,
|
||||
schedule_max_timesteps=2000000,
|
||||
buffer_size=10000,
|
||||
exploration_fraction=0.1,
|
||||
exploration_final_eps=0.01,
|
||||
train_freq=4,
|
||||
learning_starts=10000,
|
||||
target_network_update_freq=1000,
|
||||
gamma=0.99,
|
||||
prioritized_replay=True))
|
||||
|
||||
dqn = DQN("PongNoFrameskip-v4", config)
|
||||
|
||||
while True:
|
||||
res = dqn.train()
|
||||
print("current status: {}".format(res))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@@ -0,0 +1,305 @@
|
||||
"""
|
||||
|
||||
See README.md for a description of the logging API.
|
||||
|
||||
OFF state corresponds to having Logger.CURRENT == Logger.DEFAULT
|
||||
ON state is otherwise
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from collections import OrderedDict
|
||||
import os
|
||||
import sys
|
||||
import shutil
|
||||
import os.path as osp
|
||||
import json
|
||||
|
||||
LOG_OUTPUT_FORMATS = ['stdout', 'log', 'json']
|
||||
|
||||
DEBUG = 10
|
||||
INFO = 20
|
||||
WARN = 30
|
||||
ERROR = 40
|
||||
|
||||
DISABLED = 50
|
||||
|
||||
|
||||
class OutputFormat(object):
|
||||
def writekvs(self, kvs):
|
||||
"""
|
||||
Write key-value pairs
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def writeseq(self, args):
|
||||
"""
|
||||
Write a sequence of other data (e.g. a logging message)
|
||||
"""
|
||||
pass
|
||||
|
||||
def close(self):
|
||||
return
|
||||
|
||||
|
||||
class HumanOutputFormat(OutputFormat):
|
||||
def __init__(self, file):
|
||||
self.file = file
|
||||
|
||||
def writekvs(self, kvs):
|
||||
# Create strings for printing
|
||||
key2str = OrderedDict()
|
||||
for (key, val) in kvs.items():
|
||||
valstr = '%-8.3g' % (val,) if hasattr(val, '__float__') else val
|
||||
key2str[self._truncate(key)] = self._truncate(valstr)
|
||||
|
||||
# Find max widths
|
||||
keywidth = max(map(len, key2str.keys()))
|
||||
valwidth = max(map(len, key2str.values()))
|
||||
|
||||
# Write out the data
|
||||
dashes = '-' * (keywidth + valwidth + 7)
|
||||
lines = [dashes]
|
||||
for (key, val) in key2str.items():
|
||||
lines.append('| %s%s | %s%s |' % (
|
||||
key,
|
||||
' ' * (keywidth - len(key)),
|
||||
val,
|
||||
' ' * (valwidth - len(val)),
|
||||
))
|
||||
lines.append(dashes)
|
||||
self.file.write('\n'.join(lines) + '\n')
|
||||
|
||||
# Flush the output to the file
|
||||
self.file.flush()
|
||||
|
||||
def _truncate(self, s):
|
||||
return s[:20] + '...' if len(s) > 23 else s
|
||||
|
||||
def writeseq(self, args):
|
||||
for arg in args:
|
||||
self.file.write(arg)
|
||||
self.file.write('\n')
|
||||
self.file.flush()
|
||||
|
||||
|
||||
class JSONOutputFormat(OutputFormat):
|
||||
def __init__(self, file):
|
||||
self.file = file
|
||||
|
||||
def writekvs(self, kvs):
|
||||
for k, v in kvs.items():
|
||||
if hasattr(v, 'dtype'):
|
||||
v = v.tolist()
|
||||
kvs[k] = float(v)
|
||||
self.file.write(json.dumps(kvs) + '\n')
|
||||
self.file.flush()
|
||||
|
||||
|
||||
def make_output_format(format, ev_dir):
|
||||
os.makedirs(ev_dir, exist_ok=True)
|
||||
if format == 'stdout':
|
||||
return HumanOutputFormat(sys.stdout)
|
||||
elif format == 'log':
|
||||
log_file = open(osp.join(ev_dir, 'log.txt'), 'wt')
|
||||
return HumanOutputFormat(log_file)
|
||||
elif format == 'json':
|
||||
json_file = open(osp.join(ev_dir, 'progress.json'), 'wt')
|
||||
return JSONOutputFormat(json_file)
|
||||
else:
|
||||
raise ValueError('Unknown format specified: %s' % (format,))
|
||||
|
||||
# ================================================================
|
||||
# API
|
||||
# ================================================================
|
||||
|
||||
|
||||
def logkv(key, val):
|
||||
"""
|
||||
Log a value of some diagnostic
|
||||
Call this once for each diagnostic quantity, each iteration
|
||||
"""
|
||||
Logger.CURRENT.logkv(key, val)
|
||||
|
||||
|
||||
def dumpkvs():
|
||||
"""
|
||||
Write all of the diagnostics from the current iteration
|
||||
|
||||
level: int. (see logger.py docs) If the global logger level is higher than
|
||||
the level argument here, don't print to stdout.
|
||||
"""
|
||||
Logger.CURRENT.dumpkvs()
|
||||
|
||||
|
||||
# for backwards compatibility
|
||||
record_tabular = logkv
|
||||
dump_tabular = dumpkvs
|
||||
|
||||
|
||||
def log(*args, level=INFO):
|
||||
"""
|
||||
Write the sequence of args, with no separators, to the console and output
|
||||
files (if you've configured an output file).
|
||||
"""
|
||||
Logger.CURRENT.log(*args, level=level)
|
||||
|
||||
|
||||
def debug(*args):
|
||||
log(*args, level=DEBUG)
|
||||
|
||||
|
||||
def info(*args):
|
||||
log(*args, level=INFO)
|
||||
|
||||
|
||||
def warn(*args):
|
||||
log(*args, level=WARN)
|
||||
|
||||
|
||||
def error(*args):
|
||||
log(*args, level=ERROR)
|
||||
|
||||
|
||||
def set_level(level):
|
||||
"""
|
||||
Set logging threshold on current logger.
|
||||
"""
|
||||
Logger.CURRENT.set_level(level)
|
||||
|
||||
|
||||
def get_dir():
|
||||
"""
|
||||
Get directory that log files are being written to.
|
||||
will be None if there is no output directory (i.e., if you didn't call start)
|
||||
"""
|
||||
return Logger.CURRENT.get_dir()
|
||||
|
||||
|
||||
def get_expt_dir():
|
||||
sys.stderr.write(
|
||||
"get_expt_dir() is Deprecated. Switch to get_dir() [%s]\n" %
|
||||
(get_dir(),))
|
||||
return get_dir()
|
||||
|
||||
|
||||
# ================================================================
|
||||
# Backend
|
||||
# ================================================================
|
||||
|
||||
|
||||
class Logger(object):
|
||||
# A logger with no output files. (See right below class definition)
|
||||
# So that you can still log to the terminal without setting up any output
|
||||
DEFAULT = None
|
||||
|
||||
# Current logger being used by the free functions above
|
||||
CURRENT = None
|
||||
|
||||
def __init__(self, dir, output_formats):
|
||||
self.name2val = OrderedDict() # values this iteration
|
||||
self.level = INFO
|
||||
self.dir = dir
|
||||
self.output_formats = output_formats
|
||||
|
||||
# Logging API, forwarded
|
||||
# ----------------------------------------
|
||||
def logkv(self, key, val):
|
||||
self.name2val[key] = val
|
||||
|
||||
def dumpkvs(self):
|
||||
for fmt in self.output_formats:
|
||||
fmt.writekvs(self.name2val)
|
||||
self.name2val.clear()
|
||||
|
||||
def log(self, *args, level=INFO):
|
||||
if self.level <= level:
|
||||
self._do_log(args)
|
||||
|
||||
# Configuration
|
||||
# ----------------------------------------
|
||||
def set_level(self, level):
|
||||
self.level = level
|
||||
|
||||
def get_dir(self):
|
||||
return self.dir
|
||||
|
||||
def close(self):
|
||||
for fmt in self.output_formats:
|
||||
fmt.close()
|
||||
|
||||
# Misc
|
||||
# ----------------------------------------
|
||||
def _do_log(self, args):
|
||||
for fmt in self.output_formats:
|
||||
fmt.writeseq(args)
|
||||
|
||||
|
||||
# ================================================================
|
||||
|
||||
Logger.DEFAULT = Logger(
|
||||
output_formats=[HumanOutputFormat(sys.stdout)], dir=None)
|
||||
Logger.CURRENT = Logger.DEFAULT
|
||||
|
||||
|
||||
class session(object):
|
||||
"""
|
||||
Context manager that sets up the loggers for an experiment.
|
||||
"""
|
||||
|
||||
CURRENT = None # Set to a LoggerContext object using enter/exit or cm
|
||||
|
||||
def __init__(self, dir, format_strs=None):
|
||||
self.dir = dir
|
||||
if format_strs is None:
|
||||
format_strs = LOG_OUTPUT_FORMATS
|
||||
output_formats = [make_output_format(f, dir) for f in format_strs]
|
||||
Logger.CURRENT = Logger(dir=dir, output_formats=output_formats)
|
||||
|
||||
def __enter__(self):
|
||||
os.makedirs(self.evaluation_dir(), exist_ok=True)
|
||||
output_formats = [
|
||||
make_output_format(
|
||||
f, self.evaluation_dir()) for f in LOG_OUTPUT_FORMATS]
|
||||
Logger.CURRENT = Logger(dir=self.dir, output_formats=output_formats)
|
||||
|
||||
def __exit__(self, *args):
|
||||
Logger.CURRENT.close()
|
||||
Logger.CURRENT = Logger.DEFAULT
|
||||
|
||||
def evaluation_dir(self):
|
||||
return self.dir
|
||||
|
||||
|
||||
# ================================================================
|
||||
|
||||
|
||||
def _demo():
|
||||
info("hi")
|
||||
debug("shouldn't appear")
|
||||
set_level(DEBUG)
|
||||
debug("should appear")
|
||||
dir = "/tmp/testlogging"
|
||||
if os.path.exists(dir):
|
||||
shutil.rmtree(dir)
|
||||
with session(dir=dir):
|
||||
record_tabular("a", 3)
|
||||
record_tabular("b", 2.5)
|
||||
dump_tabular()
|
||||
record_tabular("b", -2.5)
|
||||
record_tabular("a", 5.5)
|
||||
dump_tabular()
|
||||
info("^^^ should see a = 5.5")
|
||||
|
||||
record_tabular("b", -2.5)
|
||||
dump_tabular()
|
||||
|
||||
record_tabular("a", "longasslongasslongasslongasslongasslongassvalue")
|
||||
dump_tabular()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
_demo()
|
||||
@@ -0,0 +1,97 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import tensorflow as tf
|
||||
import tensorflow.contrib.layers as layers
|
||||
|
||||
|
||||
def _mlp(hiddens, inpt, num_actions, scope, reuse=False):
|
||||
with tf.variable_scope(scope, reuse=reuse):
|
||||
out = inpt
|
||||
for hidden in hiddens:
|
||||
out = layers.fully_connected(
|
||||
out, num_outputs=hidden, activation_fn=tf.nn.relu)
|
||||
out = layers.fully_connected(
|
||||
out, num_outputs=num_actions, activation_fn=None)
|
||||
return out
|
||||
|
||||
|
||||
def mlp(hiddens=[]):
|
||||
"""This model takes as input an observation and returns values of all
|
||||
actions.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
hiddens: [int]
|
||||
list of sizes of hidden layers
|
||||
|
||||
Returns
|
||||
-------
|
||||
q_func: function
|
||||
q_function for DQN algorithm.
|
||||
"""
|
||||
return lambda *args, **kwargs: _mlp(hiddens, *args, **kwargs)
|
||||
|
||||
|
||||
def _cnn_to_mlp(
|
||||
convs, hiddens, dueling, inpt, num_actions, scope, reuse=False):
|
||||
|
||||
with tf.variable_scope(scope, reuse=reuse):
|
||||
out = inpt
|
||||
with tf.variable_scope("convnet"):
|
||||
for num_outputs, kernel_size, stride in convs:
|
||||
out = layers.convolution2d(
|
||||
out,
|
||||
num_outputs=num_outputs,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
activation_fn=tf.nn.relu)
|
||||
out = layers.flatten(out)
|
||||
with tf.variable_scope("action_value"):
|
||||
action_out = out
|
||||
for hidden in hiddens:
|
||||
action_out = layers.fully_connected(
|
||||
action_out, num_outputs=hidden, activation_fn=tf.nn.relu)
|
||||
action_scores = layers.fully_connected(
|
||||
action_out, num_outputs=num_actions, activation_fn=None)
|
||||
|
||||
if dueling:
|
||||
with tf.variable_scope("state_value"):
|
||||
state_out = out
|
||||
for hidden in hiddens:
|
||||
state_out = layers.fully_connected(
|
||||
state_out, num_outputs=hidden, activation_fn=tf.nn.relu)
|
||||
state_score = layers.fully_connected(
|
||||
state_out, num_outputs=1, activation_fn=None)
|
||||
action_scores_mean = tf.reduce_mean(action_scores, 1)
|
||||
action_scores_centered = action_scores - tf.expand_dims(
|
||||
action_scores_mean, 1)
|
||||
return state_score + action_scores_centered
|
||||
else:
|
||||
return action_scores
|
||||
return out
|
||||
|
||||
|
||||
def cnn_to_mlp(convs, hiddens, dueling=False):
|
||||
"""This model takes as input an observation and returns values of all actions.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
convs: [(int, int int)]
|
||||
list of convolutional layers in form of
|
||||
(num_outputs, kernel_size, stride)
|
||||
hiddens: [int]
|
||||
list of sizes of hidden layers
|
||||
dueling: bool
|
||||
if true double the output MLP to compute a baseline
|
||||
for action scores
|
||||
|
||||
Returns
|
||||
-------
|
||||
q_func: function
|
||||
q_function for DQN algorithm.
|
||||
"""
|
||||
|
||||
return lambda *args, **kwargs: _cnn_to_mlp(
|
||||
convs, hiddens, dueling, *args, **kwargs)
|
||||
@@ -0,0 +1,196 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import random
|
||||
|
||||
from ray.rllib.dqn.common.segment_tree import SumSegmentTree, MinSegmentTree
|
||||
|
||||
|
||||
class ReplayBuffer(object):
|
||||
def __init__(self, size):
|
||||
"""Create Prioritized Replay buffer.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
size: int
|
||||
Max number of transitions to store in the buffer. When the buffer
|
||||
overflows the old memories are dropped.
|
||||
"""
|
||||
self._storage = []
|
||||
self._maxsize = size
|
||||
self._next_idx = 0
|
||||
|
||||
def __len__(self):
|
||||
return len(self._storage)
|
||||
|
||||
def add(self, obs_t, action, reward, obs_tp1, done):
|
||||
data = (obs_t, action, reward, obs_tp1, done)
|
||||
|
||||
if self._next_idx >= len(self._storage):
|
||||
self._storage.append(data)
|
||||
else:
|
||||
self._storage[self._next_idx] = data
|
||||
self._next_idx = (self._next_idx + 1) % self._maxsize
|
||||
|
||||
def _encode_sample(self, idxes):
|
||||
obses_t, actions, rewards, obses_tp1, dones = [], [], [], [], []
|
||||
for i in idxes:
|
||||
data = self._storage[i]
|
||||
obs_t, action, reward, obs_tp1, done = data
|
||||
obses_t.append(np.array(obs_t, copy=False))
|
||||
actions.append(np.array(action, copy=False))
|
||||
rewards.append(reward)
|
||||
obses_tp1.append(np.array(obs_tp1, copy=False))
|
||||
dones.append(done)
|
||||
return np.array(obses_t), np.array(actions), np.array(rewards), \
|
||||
np.array(obses_tp1), np.array(dones)
|
||||
|
||||
def sample(self, batch_size):
|
||||
"""Sample a batch of experiences.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
batch_size: int
|
||||
How many transitions to sample.
|
||||
|
||||
Returns
|
||||
-------
|
||||
obs_batch: np.array
|
||||
batch of observations
|
||||
act_batch: np.array
|
||||
batch of actions executed given obs_batch
|
||||
rew_batch: np.array
|
||||
rewards received as results of executing act_batch
|
||||
next_obs_batch: np.array
|
||||
next set of observations seen after executing act_batch
|
||||
done_mask: np.array
|
||||
done_mask[i] = 1 if executing act_batch[i] resulted in
|
||||
the end of an episode and 0 otherwise.
|
||||
"""
|
||||
idxes = [random.randint(0, len(self._storage) - 1)
|
||||
for _ in range(batch_size)]
|
||||
return self._encode_sample(idxes)
|
||||
|
||||
|
||||
class PrioritizedReplayBuffer(ReplayBuffer):
|
||||
def __init__(self, size, alpha):
|
||||
"""Create Prioritized Replay buffer.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
size: int
|
||||
Max number of transitions to store in the buffer. When the buffer
|
||||
overflows the old memories are dropped.
|
||||
alpha: float
|
||||
how much prioritization is used
|
||||
(0 - no prioritization, 1 - full prioritization)
|
||||
|
||||
See Also
|
||||
--------
|
||||
ReplayBuffer.__init__
|
||||
"""
|
||||
super(PrioritizedReplayBuffer, self).__init__(size)
|
||||
assert alpha > 0
|
||||
self._alpha = alpha
|
||||
|
||||
it_capacity = 1
|
||||
while it_capacity < size:
|
||||
it_capacity *= 2
|
||||
|
||||
self._it_sum = SumSegmentTree(it_capacity)
|
||||
self._it_min = MinSegmentTree(it_capacity)
|
||||
self._max_priority = 1.0
|
||||
|
||||
def add(self, *args, **kwargs):
|
||||
"""See ReplayBuffer.store_effect"""
|
||||
idx = self._next_idx
|
||||
super().add(*args, **kwargs)
|
||||
self._it_sum[idx] = self._max_priority ** self._alpha
|
||||
self._it_min[idx] = self._max_priority ** self._alpha
|
||||
|
||||
def _sample_proportional(self, batch_size):
|
||||
res = []
|
||||
for _ in range(batch_size):
|
||||
# TODO(szymon): should we ensure no repeats?
|
||||
mass = random.random() * self._it_sum.sum(0, len(self._storage) - 1)
|
||||
idx = self._it_sum.find_prefixsum_idx(mass)
|
||||
res.append(idx)
|
||||
return res
|
||||
|
||||
def sample(self, batch_size, beta):
|
||||
"""Sample a batch of experiences.
|
||||
|
||||
compared to ReplayBuffer.sample
|
||||
it also returns importance weights and idxes
|
||||
of sampled experiences.
|
||||
|
||||
|
||||
Parameters
|
||||
----------
|
||||
batch_size: int
|
||||
How many transitions to sample.
|
||||
beta: float
|
||||
To what degree to use importance weights
|
||||
(0 - no corrections, 1 - full correction)
|
||||
|
||||
Returns
|
||||
-------
|
||||
obs_batch: np.array
|
||||
batch of observations
|
||||
act_batch: np.array
|
||||
batch of actions executed given obs_batch
|
||||
rew_batch: np.array
|
||||
rewards received as results of executing act_batch
|
||||
next_obs_batch: np.array
|
||||
next set of observations seen after executing act_batch
|
||||
done_mask: np.array
|
||||
done_mask[i] = 1 if executing act_batch[i] resulted in
|
||||
the end of an episode and 0 otherwise.
|
||||
weights: np.array
|
||||
Array of shape (batch_size,) and dtype np.float32
|
||||
denoting importance weight of each sampled transition
|
||||
idxes: np.array
|
||||
Array of shape (batch_size,) and dtype np.int32
|
||||
idexes in buffer of sampled experiences
|
||||
"""
|
||||
assert beta > 0
|
||||
|
||||
idxes = self._sample_proportional(batch_size)
|
||||
|
||||
weights = []
|
||||
p_min = self._it_min.min() / self._it_sum.sum()
|
||||
max_weight = (p_min * len(self._storage)) ** (-beta)
|
||||
|
||||
for idx in idxes:
|
||||
p_sample = self._it_sum[idx] / self._it_sum.sum()
|
||||
weight = (p_sample * len(self._storage)) ** (-beta)
|
||||
weights.append(weight / max_weight)
|
||||
weights = np.array(weights)
|
||||
encoded_sample = self._encode_sample(idxes)
|
||||
return tuple(list(encoded_sample) + [weights, idxes])
|
||||
|
||||
def update_priorities(self, idxes, priorities):
|
||||
"""Update priorities of sampled transitions.
|
||||
|
||||
sets priority of transition at index idxes[i] in buffer
|
||||
to priorities[i].
|
||||
|
||||
Parameters
|
||||
----------
|
||||
idxes: [int]
|
||||
List of idxes of sampled transitions
|
||||
priorities: [float]
|
||||
List of updated priorities corresponding to
|
||||
transitions at the sampled idxes denoted by
|
||||
variable `idxes`.
|
||||
"""
|
||||
assert len(idxes) == len(priorities)
|
||||
for idx, priority in zip(idxes, priorities):
|
||||
assert priority > 0
|
||||
assert 0 <= idx < len(self._storage)
|
||||
self._it_sum[idx] = priority ** self._alpha
|
||||
self._it_min[idx] = priority ** self._alpha
|
||||
|
||||
self._max_priority = max(self._max_priority, priority)
|
||||
Reference in New Issue
Block a user