mirror of
https://github.com/wassname/ray.git
synced 2026-07-01 16:54:21 +08:00
[rllib] Refactor DQN to use an Evaluator abstraction (#1276)
This introduces rllib.Evaluator and rllib.Optimizer classes. Optimizers encapsulate a particular distributed optimization strategy for RL. Evaluators encapsulate the model graph, and once implemented, any Optimizer may be "plugged in" to any algorithm that implements the Evaluator interface.
This commit is contained in:
@@ -0,0 +1,131 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
import ray
|
||||
from ray.rllib.dqn import models
|
||||
from ray.rllib.dqn.common.wrappers import wrap_dqn
|
||||
from ray.rllib.dqn.common.schedules import LinearSchedule
|
||||
from ray.rllib.evaluator import TFMultiGPUSupport
|
||||
|
||||
|
||||
class DQNEvaluator(TFMultiGPUSupport):
|
||||
"""The base DQN Evaluator that does not include the replay buffer."""
|
||||
|
||||
def __init__(self, env_creator, config, logdir):
|
||||
env = env_creator()
|
||||
env = wrap_dqn(env, config["model"])
|
||||
self.env = env
|
||||
self.config = config
|
||||
|
||||
tf_config = tf.ConfigProto(**config["tf_session_args"])
|
||||
self.sess = tf.Session(config=tf_config)
|
||||
self.dqn_graph = models.DQNGraph(env, config, logdir)
|
||||
|
||||
# Create the schedule for exploration starting from 1.
|
||||
self.exploration = LinearSchedule(
|
||||
schedule_timesteps=int(
|
||||
config["exploration_fraction"] *
|
||||
config["schedule_max_timesteps"]),
|
||||
initial_p=1.0,
|
||||
final_p=config["exploration_final_eps"])
|
||||
|
||||
# Initialize the parameters and copy them to the target network.
|
||||
self.sess.run(tf.global_variables_initializer())
|
||||
self.dqn_graph.update_target(self.sess)
|
||||
self.global_timestep = 0
|
||||
self.local_timestep = 0
|
||||
|
||||
# Note that this encompasses both the Q and target network
|
||||
self.variables = ray.experimental.TensorFlowVariables(
|
||||
tf.group(self.dqn_graph.q_t, self.dqn_graph.q_tp1), self.sess)
|
||||
|
||||
self.episode_rewards = [0.0]
|
||||
self.episode_lengths = [0.0]
|
||||
self.saved_mean_reward = None
|
||||
self.obs = self.env.reset()
|
||||
|
||||
def set_global_timestep(self, global_timestep):
|
||||
self.global_timestep = global_timestep
|
||||
|
||||
def update_target(self):
|
||||
self.dqn_graph.update_target(self.sess)
|
||||
|
||||
def sample(self):
|
||||
output = []
|
||||
for _ in range(self.config["sample_batch_size"]):
|
||||
result = self._step(self.global_timestep)
|
||||
output.append(result)
|
||||
return output
|
||||
|
||||
def compute_gradients(self, samples):
|
||||
if self.config["prioritized_replay"]:
|
||||
obses_t, actions, rewards, obses_tp1, dones, _ = samples
|
||||
else:
|
||||
obses_t, actions, rewards, obses_tp1, dones = samples
|
||||
_, grad = self.dqn_graph.compute_gradients(
|
||||
self.sess, obses_t, actions, rewards, obses_tp1, dones,
|
||||
np.ones_like(rewards))
|
||||
return grad
|
||||
|
||||
def apply_gradients(self, grads):
|
||||
self.dqn_graph.apply_gradients(self.sess, grads)
|
||||
|
||||
def get_weights(self):
|
||||
return self.variables.get_weights()
|
||||
|
||||
def set_weights(self, weights):
|
||||
self.variables.set_weights(weights)
|
||||
|
||||
def tf_loss_inputs(self):
|
||||
return self.dqn_graph.loss_inputs
|
||||
|
||||
def build_tf_loss(self, input_placeholders):
|
||||
return self.dqn_graph.build_loss(*input_placeholders)
|
||||
|
||||
def _step(self, global_timestep):
|
||||
"""Takes a single step, and returns the result of the step."""
|
||||
action = self.dqn_graph.act(
|
||||
self.sess, np.array(self.obs)[None],
|
||||
self.exploration.value(global_timestep))[0]
|
||||
new_obs, rew, done, _ = self.env.step(action)
|
||||
ret = (self.obs, action, rew, new_obs, float(done))
|
||||
self.obs = new_obs
|
||||
self.episode_rewards[-1] += rew
|
||||
self.episode_lengths[-1] += 1
|
||||
if done:
|
||||
self.obs = self.env.reset()
|
||||
self.episode_rewards.append(0.0)
|
||||
self.episode_lengths.append(0.0)
|
||||
self.local_timestep += 1
|
||||
return ret
|
||||
|
||||
def stats(self):
|
||||
mean_100ep_reward = round(np.mean(self.episode_rewards[-101:-1]), 5)
|
||||
mean_100ep_length = round(np.mean(self.episode_lengths[-101:-1]), 5)
|
||||
exploration = self.exploration.value(self.global_timestep)
|
||||
return {
|
||||
"mean_100ep_reward": mean_100ep_reward,
|
||||
"mean_100ep_length": mean_100ep_length,
|
||||
"num_episodes": len(self.episode_rewards),
|
||||
"exploration": exploration,
|
||||
"local_timestep": self.local_timestep,
|
||||
}
|
||||
|
||||
def save(self):
|
||||
return [
|
||||
self.exploration,
|
||||
self.episode_rewards,
|
||||
self.episode_lengths,
|
||||
self.saved_mean_reward,
|
||||
self.obs]
|
||||
|
||||
def restore(self, data):
|
||||
self.exploration = data[0]
|
||||
self.episode_rewards = data[1]
|
||||
self.episode_lengths = data[2]
|
||||
self.saved_mean_reward = data[3]
|
||||
self.obs = data[4]
|
||||
+99
-487
@@ -2,21 +2,17 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import pickle
|
||||
import os
|
||||
import sys
|
||||
import tensorflow as tf
|
||||
|
||||
import ray
|
||||
from ray.rllib.dqn.base_evaluator import DQNEvaluator
|
||||
from ray.rllib.dqn.replay_evaluator import DQNReplayEvaluator
|
||||
from ray.rllib.optimizers import AsyncOptimizer, LocalMultiGPUOptimizer, \
|
||||
LocalSyncOptimizer
|
||||
from ray.rllib.agent import Agent
|
||||
from ray.rllib.dqn import logger, models
|
||||
from ray.rllib.dqn.common.wrappers import wrap_dqn
|
||||
from ray.rllib.dqn.common.schedules import LinearSchedule
|
||||
from ray.rllib.dqn.replay_buffer import ReplayBuffer, PrioritizedReplayBuffer
|
||||
from ray.rllib.utils.filter import RunningStat
|
||||
from ray.tune.result import TrainingResult
|
||||
|
||||
|
||||
@@ -81,6 +77,8 @@ DEFAULT_CONFIG = dict(
|
||||
sgd_batch_size=32,
|
||||
# If not None, clip gradients during optimization at this value
|
||||
grad_norm_clipping=10,
|
||||
# Arguments to pass to the rllib optimizer
|
||||
optimizer={},
|
||||
|
||||
# === Tensorflow ===
|
||||
# Arguments to pass to tensorflow
|
||||
@@ -96,8 +94,8 @@ DEFAULT_CONFIG = dict(
|
||||
# Number of workers for collecting samples with. Note that the typical
|
||||
# setting is 1 unless your environment is particularly slow to sample.
|
||||
num_workers=1,
|
||||
# Whether to allocate GPUs for workers (if num_workers > 1).
|
||||
use_gpu_for_workers=False,
|
||||
# Whether to allocate GPUs for workers (if > 0).
|
||||
num_gpus_per_worker=0,
|
||||
# (Experimental) Whether to update the model asynchronously from
|
||||
# workers. In this mode, gradients will be computed on workers instead of
|
||||
# on the driver, and workers will each have their own replay buffer.
|
||||
@@ -111,522 +109,136 @@ DEFAULT_CONFIG = dict(
|
||||
devices=["/gpu:0"])
|
||||
|
||||
|
||||
class Actor(object):
|
||||
def __init__(self, env_creator, config, logdir):
|
||||
env = env_creator()
|
||||
env = wrap_dqn(env, config["model"])
|
||||
self.env = env
|
||||
self.config = config
|
||||
|
||||
tf_config = tf.ConfigProto(**config["tf_session_args"])
|
||||
self.sess = tf.Session(config=tf_config)
|
||||
self.dqn_graph = models.DQNGraph(env, config, logdir)
|
||||
|
||||
# Create the replay buffer
|
||||
if config["prioritized_replay"]:
|
||||
self.replay_buffer = PrioritizedReplayBuffer(
|
||||
config["buffer_size"],
|
||||
alpha=config["prioritized_replay_alpha"])
|
||||
prioritized_replay_beta_iters = \
|
||||
config["prioritized_replay_beta_iters"]
|
||||
if prioritized_replay_beta_iters is None:
|
||||
prioritized_replay_beta_iters = \
|
||||
config["schedule_max_timesteps"]
|
||||
self.beta_schedule = LinearSchedule(
|
||||
prioritized_replay_beta_iters,
|
||||
initial_p=config["prioritized_replay_beta0"],
|
||||
final_p=1.0)
|
||||
else:
|
||||
self.replay_buffer = ReplayBuffer(config["buffer_size"])
|
||||
self.beta_schedule = None
|
||||
# Create the schedule for exploration starting from 1.
|
||||
self.exploration = LinearSchedule(
|
||||
schedule_timesteps=int(
|
||||
config["exploration_fraction"] *
|
||||
config["schedule_max_timesteps"]),
|
||||
initial_p=1.0,
|
||||
final_p=config["exploration_final_eps"])
|
||||
|
||||
# Initialize the parameters and copy them to the target network.
|
||||
self.sess.run(tf.global_variables_initializer())
|
||||
self.dqn_graph.update_target(self.sess)
|
||||
self.set_weights_time = RunningStat(())
|
||||
self.sample_time = RunningStat(())
|
||||
self.grad_time = RunningStat(())
|
||||
|
||||
# Note that workers don't need target vars to be synced
|
||||
self.variables = ray.experimental.TensorFlowVariables(
|
||||
tf.group(self.dqn_graph.q_t, self.dqn_graph.q_tp1), self.sess)
|
||||
|
||||
self.episode_rewards = [0.0]
|
||||
self.episode_lengths = [0.0]
|
||||
self.saved_mean_reward = None
|
||||
self.obs = self.env.reset()
|
||||
self.file_writer = tf.summary.FileWriter(logdir, self.sess.graph)
|
||||
|
||||
def step(self, cur_timestep):
|
||||
"""Takes a single step, and returns the result of the step."""
|
||||
action = self.dqn_graph.act(
|
||||
self.sess, np.array(self.obs)[None],
|
||||
self.exploration.value(cur_timestep))[0]
|
||||
new_obs, rew, done, _ = self.env.step(action)
|
||||
ret = (self.obs, action, rew, new_obs, float(done))
|
||||
self.obs = new_obs
|
||||
self.episode_rewards[-1] += rew
|
||||
self.episode_lengths[-1] += 1
|
||||
if done:
|
||||
self.obs = self.env.reset()
|
||||
self.episode_rewards.append(0.0)
|
||||
self.episode_lengths.append(0.0)
|
||||
return ret
|
||||
|
||||
def do_steps(self, num_steps, cur_timestep, store):
|
||||
"""Takes N steps.
|
||||
|
||||
If store is True, the steps will be stored in the local replay buffer.
|
||||
Otherwise, the steps will be returned.
|
||||
"""
|
||||
|
||||
output = []
|
||||
for _ in range(num_steps):
|
||||
result = self.step(cur_timestep)
|
||||
if store:
|
||||
obs, action, rew, new_obs, done = result
|
||||
self.replay_buffer.add(obs, action, rew, new_obs, done)
|
||||
else:
|
||||
output.append(result)
|
||||
if not store:
|
||||
return output
|
||||
|
||||
def do_multi_gpu_optimize(self, cur_timestep):
|
||||
"""Performs N iters of multi-gpu SGD over the local replay buffer."""
|
||||
dt = time.time()
|
||||
if self.config["prioritized_replay"]:
|
||||
experience = self.replay_buffer.sample(
|
||||
self.config["train_batch_size"],
|
||||
beta=self.beta_schedule.value(cur_timestep))
|
||||
(obses_t, actions, rewards, obses_tp1,
|
||||
dones, _, batch_idxes) = experience
|
||||
else:
|
||||
obses_t, actions, rewards, obses_tp1, dones = \
|
||||
self.replay_buffer.sample(self.config["train_batch_size"])
|
||||
batch_idxes = None
|
||||
replay_buffer_read_time = (time.time() - dt)
|
||||
dt = time.time()
|
||||
tuples_per_device = self.dqn_graph.multi_gpu_optimizer.load_data(
|
||||
self.sess,
|
||||
[obses_t, actions, rewards, obses_tp1, dones,
|
||||
np.ones_like(rewards)])
|
||||
per_device_batch_size = (
|
||||
self.dqn_graph.multi_gpu_optimizer.per_device_batch_size)
|
||||
num_batches = (int(tuples_per_device) // int(per_device_batch_size))
|
||||
data_load_time = (time.time() - dt)
|
||||
dt = time.time()
|
||||
for _ in range(self.config["num_sgd_iter"]):
|
||||
batches = list(range(num_batches))
|
||||
np.random.shuffle(batches)
|
||||
for i in batches:
|
||||
self.dqn_graph.multi_gpu_optimizer.optimize(
|
||||
self.sess, i * per_device_batch_size)
|
||||
sgd_time = (time.time() - dt)
|
||||
dt = time.time()
|
||||
if self.config["prioritized_replay"]:
|
||||
dt = time.time()
|
||||
td_errors = self.dqn_graph.compute_td_error(
|
||||
self.sess, obses_t, actions, rewards, obses_tp1, dones,
|
||||
np.ones_like(rewards))
|
||||
dt = time.time()
|
||||
new_priorities = (
|
||||
np.abs(td_errors) + self.config["prioritized_replay_eps"])
|
||||
self.replay_buffer.update_priorities(
|
||||
batch_idxes, new_priorities)
|
||||
prioritization_time = (time.time() - dt)
|
||||
return {
|
||||
"replay_buffer_read_time": replay_buffer_read_time,
|
||||
"data_load_time": data_load_time,
|
||||
"sgd_time": sgd_time,
|
||||
"prioritization_time": prioritization_time,
|
||||
}
|
||||
|
||||
def do_async_step(self, worker_id, cur_timestep, params, gradient_id):
|
||||
"""Takes steps and returns grad to apply async in the driver."""
|
||||
dt = time.time()
|
||||
self.set_weights(params)
|
||||
self.set_weights_time.push(time.time() - dt)
|
||||
dt = time.time()
|
||||
self.do_steps(
|
||||
self.config["sample_batch_size"], cur_timestep, store=True)
|
||||
self.sample_time.push(time.time() - dt)
|
||||
if (cur_timestep > self.config["learning_starts"] and
|
||||
len(self.replay_buffer) > self.config["train_batch_size"]):
|
||||
dt = time.time()
|
||||
gradient = self.sample_buffer_gradient(cur_timestep)
|
||||
self.grad_time.push(time.time() - dt)
|
||||
else:
|
||||
gradient = None
|
||||
return gradient, {"id": worker_id, "gradient_id": gradient_id}
|
||||
|
||||
def sample_buffer_gradient(self, cur_timestep):
|
||||
"""Returns grad over a batch sampled from the local replay buffer."""
|
||||
if self.config["prioritized_replay"]:
|
||||
experience = self.replay_buffer.sample(
|
||||
self.config["sgd_batch_size"],
|
||||
beta=self.beta_schedule.value(cur_timestep))
|
||||
(obses_t, actions, rewards, obses_tp1,
|
||||
dones, _, batch_idxes) = experience
|
||||
else:
|
||||
obses_t, actions, rewards, obses_tp1, dones = \
|
||||
self.replay_buffer.sample(self.config["sgd_batch_size"])
|
||||
batch_idxes = None
|
||||
td_errors, grad = self.dqn_graph.compute_gradients(
|
||||
self.sess, obses_t, actions, rewards, obses_tp1, dones,
|
||||
np.ones_like(rewards))
|
||||
if self.config["prioritized_replay"]:
|
||||
new_priorities = (
|
||||
np.abs(td_errors) + self.config["prioritized_replay_eps"])
|
||||
self.replay_buffer.update_priorities(
|
||||
batch_idxes, new_priorities)
|
||||
return grad
|
||||
|
||||
def apply_gradients(self, grad):
|
||||
self.dqn_graph.apply_gradients(self.sess, grad)
|
||||
|
||||
# TODO(ekl) return a dictionary and use that everywhere to clean up the
|
||||
# bookkeeping of stats
|
||||
def stats(self, num_timesteps):
|
||||
mean_100ep_reward = round(np.mean(self.episode_rewards[-101:-1]), 5)
|
||||
mean_100ep_length = round(np.mean(self.episode_lengths[-101:-1]), 5)
|
||||
exploration = self.exploration.value(num_timesteps)
|
||||
return (
|
||||
mean_100ep_reward,
|
||||
mean_100ep_length,
|
||||
len(self.episode_rewards),
|
||||
exploration,
|
||||
len(self.replay_buffer),
|
||||
float(self.set_weights_time.mean),
|
||||
float(self.sample_time.mean),
|
||||
float(self.grad_time.mean))
|
||||
|
||||
def get_weights(self):
|
||||
return self.variables.get_weights()
|
||||
|
||||
def set_weights(self, weights):
|
||||
self.variables.set_weights(weights)
|
||||
|
||||
def save(self):
|
||||
return [
|
||||
self.beta_schedule,
|
||||
self.exploration,
|
||||
self.episode_rewards,
|
||||
self.episode_lengths,
|
||||
self.saved_mean_reward,
|
||||
self.obs,
|
||||
self.replay_buffer]
|
||||
|
||||
def restore(self, data):
|
||||
self.beta_schedule = data[0]
|
||||
self.exploration = data[1]
|
||||
self.episode_rewards = data[2]
|
||||
self.episode_lengths = data[3]
|
||||
self.saved_mean_reward = data[4]
|
||||
self.obs = data[5]
|
||||
self.replay_buffer = data[6]
|
||||
|
||||
|
||||
@ray.remote
|
||||
class RemoteActor(Actor):
|
||||
def __init__(self, env_creator, config, logdir):
|
||||
Actor.__init__(self, env_creator, config, logdir)
|
||||
|
||||
def stop(self):
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
@ray.remote(num_gpus=1)
|
||||
class GPURemoteActor(Actor):
|
||||
def __init__(self, env_creator, config, logdir):
|
||||
Actor.__init__(self, env_creator, config, logdir)
|
||||
|
||||
def stop(self):
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
class DQNAgent(Agent):
|
||||
_agent_name = "DQN"
|
||||
_default_config = DEFAULT_CONFIG
|
||||
|
||||
def _stop(self):
|
||||
for w in self.workers:
|
||||
w.stop.remote()
|
||||
|
||||
def _init(self):
|
||||
self.actor = Actor(self.env_creator, self.config, self.logdir)
|
||||
if self.config["use_gpu_for_workers"]:
|
||||
remote_cls = GPURemoteActor
|
||||
if self.config["async_updates"]:
|
||||
self.local_evaluator = DQNEvaluator(
|
||||
self.env_creator, self.config, self.logdir)
|
||||
remote_cls = ray.remote(
|
||||
num_cpus=1, num_gpus=self.config["num_gpus_per_worker"])(
|
||||
DQNReplayEvaluator)
|
||||
remote_config = dict(self.config, num_workers=1)
|
||||
# In async mode, we create N remote evaluators, each with their
|
||||
# own replay buffer (i.e. the replay buffer is sharded).
|
||||
self.remote_evaluators = [
|
||||
remote_cls.remote(
|
||||
self.env_creator, remote_config, self.logdir)
|
||||
for _ in range(self.config["num_workers"])]
|
||||
optimizer_cls = AsyncOptimizer
|
||||
else:
|
||||
remote_cls = RemoteActor
|
||||
# Use remote workers
|
||||
if self.config["num_workers"] > 1 or self.config["async_updates"]:
|
||||
self.workers = [
|
||||
remote_cls.remote(self.env_creator, self.config, self.logdir)
|
||||
for i in range(self.config["num_workers"])]
|
||||
else:
|
||||
# Use a single local worker and avoid object store overheads
|
||||
self.workers = []
|
||||
self.local_evaluator = DQNReplayEvaluator(
|
||||
self.env_creator, self.config, self.logdir)
|
||||
# No remote evaluators. If num_workers > 1, the DQNReplayEvaluator
|
||||
# will internally create more workers for parallelism. This means
|
||||
# there is only one replay buffer regardless of num_workers.
|
||||
self.remote_evaluators = []
|
||||
if self.config["multi_gpu_optimize"]:
|
||||
optimizer_cls = LocalMultiGPUOptimizer
|
||||
else:
|
||||
optimizer_cls = LocalSyncOptimizer
|
||||
|
||||
self.cur_timestep = 0
|
||||
self.num_iterations = 0
|
||||
self.num_target_updates = 0
|
||||
self.steps_since_update = 0
|
||||
self.file_writer = tf.summary.FileWriter(
|
||||
self.logdir, self.actor.sess.graph)
|
||||
self.optimizer = optimizer_cls(
|
||||
self.config["optimizer"], self.local_evaluator,
|
||||
self.remote_evaluators)
|
||||
self.saver = tf.train.Saver(max_to_keep=None)
|
||||
|
||||
def _update_worker_weights(self):
|
||||
if self.workers:
|
||||
w = self.actor.get_weights()
|
||||
weights = ray.put(self.actor.get_weights())
|
||||
for w in self.workers:
|
||||
w.set_weights.remote(weights)
|
||||
self.global_timestep = 0
|
||||
self.last_target_update_ts = 0
|
||||
self.num_target_updates = 0
|
||||
|
||||
def _train(self):
|
||||
if self.config["async_updates"]:
|
||||
return self._train_async()
|
||||
else:
|
||||
return self._train_sync()
|
||||
start_timestep = self.global_timestep
|
||||
|
||||
def _train_async(self):
|
||||
apply_time = RunningStat(())
|
||||
wait_time = RunningStat(())
|
||||
gradient_lag = RunningStat(())
|
||||
iter_init_timesteps = self.cur_timestep
|
||||
num_gradients_applied = 0
|
||||
gradient_list = [
|
||||
worker.do_async_step.remote(
|
||||
i, self.cur_timestep, self.actor.get_weights(),
|
||||
num_gradients_applied)
|
||||
for i, worker in enumerate(self.workers)]
|
||||
steps = self.config["sample_batch_size"] * len(gradient_list)
|
||||
self.cur_timestep += steps
|
||||
self.steps_since_update += steps
|
||||
while (self.global_timestep - start_timestep <
|
||||
self.config["timesteps_per_iteration"]):
|
||||
|
||||
while gradient_list:
|
||||
dt = time.time()
|
||||
gradient, info = ray.get(gradient_list[0])
|
||||
gradient_list = gradient_list[1:]
|
||||
wait_time.push(time.time() - dt)
|
||||
|
||||
if gradient is not None:
|
||||
dt = time.time()
|
||||
self.actor.apply_gradients(gradient)
|
||||
apply_time.push(time.time() - dt)
|
||||
gradient_lag.push(num_gradients_applied - info["gradient_id"])
|
||||
num_gradients_applied += 1
|
||||
|
||||
if (self.cur_timestep - iter_init_timesteps <
|
||||
self.config["timesteps_per_iteration"]):
|
||||
worker_id = info["id"]
|
||||
gradient_list.append(
|
||||
self.workers[info["id"]].do_async_step.remote(
|
||||
worker_id, self.cur_timestep,
|
||||
self.actor.get_weights(), num_gradients_applied))
|
||||
self.cur_timestep += self.config["sample_batch_size"]
|
||||
self.steps_since_update += self.config["sample_batch_size"]
|
||||
|
||||
if (self.cur_timestep > self.config["learning_starts"] and
|
||||
self.steps_since_update >
|
||||
self.config["target_network_update_freq"]):
|
||||
# Update target network periodically.
|
||||
self.actor.dqn_graph.update_target(self.actor.sess)
|
||||
self.steps_since_update -= (
|
||||
self.config["target_network_update_freq"])
|
||||
self.num_target_updates += 1
|
||||
|
||||
mean_100ep_reward = 0.0
|
||||
mean_100ep_length = 0.0
|
||||
num_episodes = 0
|
||||
buffer_size_sum = 0
|
||||
stats = ray.get(
|
||||
[w.stats.remote(self.cur_timestep) for w in self.workers])
|
||||
for stat in stats:
|
||||
mean_100ep_reward += stat[0]
|
||||
mean_100ep_length += stat[1]
|
||||
num_episodes += stat[2]
|
||||
exploration = stat[3]
|
||||
buffer_size_sum += stat[4]
|
||||
set_weights_time = stat[5]
|
||||
sample_time = stat[6]
|
||||
grad_time = stat[7]
|
||||
mean_100ep_reward /= self.config["num_workers"]
|
||||
mean_100ep_length /= self.config["num_workers"]
|
||||
|
||||
info = [
|
||||
("mean_100ep_reward", mean_100ep_reward),
|
||||
("exploration_frac", exploration),
|
||||
("steps", self.cur_timestep),
|
||||
("episodes", num_episodes),
|
||||
("buffer_sizes_sum", buffer_size_sum),
|
||||
("target_updates", self.num_target_updates),
|
||||
("mean_set_weights_time", set_weights_time),
|
||||
("mean_sample_time", sample_time),
|
||||
("mean_grad_time", grad_time),
|
||||
("mean_apply_time", float(apply_time.mean)),
|
||||
("mean_ray_wait_time", float(wait_time.mean)),
|
||||
("gradient_lag_mean", float(gradient_lag.mean)),
|
||||
("gradient_lag_stdev", float(gradient_lag.std)),
|
||||
]
|
||||
|
||||
for k, v in info:
|
||||
logger.record_tabular(k, v)
|
||||
logger.dump_tabular()
|
||||
|
||||
result = TrainingResult(
|
||||
episode_reward_mean=mean_100ep_reward,
|
||||
episode_len_mean=mean_100ep_length,
|
||||
timesteps_this_iter=self.cur_timestep - iter_init_timesteps,
|
||||
info=info)
|
||||
|
||||
return result
|
||||
|
||||
def _train_sync(self):
|
||||
config = self.config
|
||||
sample_time, sync_time, learn_time, apply_time = 0, 0, 0, 0
|
||||
iter_init_timesteps = self.cur_timestep
|
||||
|
||||
num_loop_iters = 0
|
||||
while (self.cur_timestep - iter_init_timesteps <
|
||||
config["timesteps_per_iteration"]):
|
||||
dt = time.time()
|
||||
if self.workers:
|
||||
worker_steps = ray.get([
|
||||
w.do_steps.remote(
|
||||
config["sample_batch_size"] // len(self.workers),
|
||||
self.cur_timestep, store=False)
|
||||
for w in self.workers])
|
||||
for steps in worker_steps:
|
||||
for obs, action, rew, new_obs, done in steps:
|
||||
self.actor.replay_buffer.add(
|
||||
obs, action, rew, new_obs, done)
|
||||
if self.global_timestep < self.config["learning_starts"]:
|
||||
self._populate_replay_buffer()
|
||||
else:
|
||||
self.actor.do_steps(
|
||||
config["sample_batch_size"], self.cur_timestep, store=True)
|
||||
num_loop_iters += 1
|
||||
self.cur_timestep += config["sample_batch_size"]
|
||||
self.steps_since_update += config["sample_batch_size"]
|
||||
sample_time += time.time() - dt
|
||||
self.optimizer.step()
|
||||
|
||||
if self.cur_timestep > config["learning_starts"]:
|
||||
if config["multi_gpu_optimize"]:
|
||||
dt = time.time()
|
||||
times = self.actor.do_multi_gpu_optimize(self.cur_timestep)
|
||||
if num_loop_iters <= 1:
|
||||
print("Multi-GPU times", times)
|
||||
learn_time += (time.time() - dt)
|
||||
else:
|
||||
# Minimize the error in Bellman's equation on a batch
|
||||
# sampled from replay buffer.
|
||||
for _ in range(
|
||||
max(1, config["train_batch_size"] //
|
||||
config["sgd_batch_size"])):
|
||||
dt = time.time()
|
||||
gradients = [
|
||||
self.actor.sample_buffer_gradient(
|
||||
self.cur_timestep)]
|
||||
learn_time += (time.time() - dt)
|
||||
dt = time.time()
|
||||
for grad in gradients:
|
||||
self.actor.apply_gradients(grad)
|
||||
apply_time += (time.time() - dt)
|
||||
dt = time.time()
|
||||
self._update_worker_weights()
|
||||
sync_time += (time.time() - dt)
|
||||
stats = self._update_global_stats()
|
||||
|
||||
if (self.cur_timestep > config["learning_starts"] and
|
||||
self.steps_since_update >
|
||||
config["target_network_update_freq"]):
|
||||
# Update target network periodically.
|
||||
self.actor.dqn_graph.update_target(self.actor.sess)
|
||||
self.steps_since_update -= config["target_network_update_freq"]
|
||||
if self.global_timestep - self.last_target_update_ts > \
|
||||
self.config["target_network_update_freq"]:
|
||||
self.local_evaluator.update_target()
|
||||
self.last_target_update_ts = self.global_timestep
|
||||
self.num_target_updates += 1
|
||||
|
||||
mean_100ep_reward = 0.0
|
||||
mean_100ep_length = 0.0
|
||||
num_episodes = 0
|
||||
buffer_size_sum = 0
|
||||
if not self.workers:
|
||||
stats = self.actor.stats(self.cur_timestep)
|
||||
mean_100ep_reward += stats[0]
|
||||
mean_100ep_length += stats[1]
|
||||
num_episodes += stats[2]
|
||||
exploration = stats[3]
|
||||
buffer_size_sum += stats[4]
|
||||
for mean_rew, mean_len, episodes, exploration, buf_sz in ray.get(
|
||||
[w.stats.remote(self.cur_timestep) for w in self.workers]):
|
||||
mean_100ep_reward += mean_rew
|
||||
mean_100ep_length += mean_len
|
||||
num_episodes += episodes
|
||||
buffer_size_sum += buf_sz
|
||||
mean_100ep_reward /= config["num_workers"]
|
||||
mean_100ep_length /= config["num_workers"]
|
||||
exploration = -1
|
||||
|
||||
info = [
|
||||
("mean_100ep_reward", mean_100ep_reward),
|
||||
("exploration_frac", exploration),
|
||||
("steps", self.cur_timestep),
|
||||
("episodes", num_episodes),
|
||||
("buffer_sizes_sum", buffer_size_sum),
|
||||
("target_updates", self.num_target_updates),
|
||||
("sample_time", sample_time),
|
||||
("weight_sync_time", sync_time),
|
||||
("apply_time", apply_time),
|
||||
("learn_time", learn_time),
|
||||
("samples_per_s",
|
||||
num_loop_iters * np.float64(config["sample_batch_size"]) /
|
||||
sample_time),
|
||||
("learn_samples_per_s",
|
||||
num_loop_iters * np.float64(config["train_batch_size"]) /
|
||||
learn_time),
|
||||
]
|
||||
|
||||
for k, v in info:
|
||||
logger.record_tabular(k, v)
|
||||
logger.dump_tabular()
|
||||
for s in stats:
|
||||
mean_100ep_reward += s["mean_100ep_reward"] / len(stats)
|
||||
mean_100ep_length += s["mean_100ep_length"] / len(stats)
|
||||
num_episodes += s["num_episodes"]
|
||||
exploration = s["exploration"]
|
||||
|
||||
result = TrainingResult(
|
||||
episode_reward_mean=mean_100ep_reward,
|
||||
episode_len_mean=mean_100ep_length,
|
||||
timesteps_this_iter=self.cur_timestep - iter_init_timesteps,
|
||||
info=info)
|
||||
episodes_total=num_episodes,
|
||||
timesteps_this_iter=self.global_timestep - start_timestep,
|
||||
info=dict({
|
||||
"exploration": exploration,
|
||||
"num_target_updates": self.num_target_updates,
|
||||
}, **self.optimizer.stats()))
|
||||
|
||||
return result
|
||||
|
||||
def _update_global_stats(self):
|
||||
if self.remote_evaluators:
|
||||
stats = ray.get([
|
||||
e.stats.remote() for e in self.remote_evaluators])
|
||||
else:
|
||||
stats = self.local_evaluator.stats()
|
||||
if not isinstance(stats, list):
|
||||
stats = [stats]
|
||||
new_timestep = sum(s["local_timestep"] for s in stats)
|
||||
assert new_timestep > self.global_timestep, new_timestep
|
||||
self.global_timestep = new_timestep
|
||||
self.local_evaluator.set_global_timestep(self.global_timestep)
|
||||
for e in self.remote_evaluators:
|
||||
e.set_global_timestep.remote(self.global_timestep)
|
||||
return stats
|
||||
|
||||
def _populate_replay_buffer(self):
|
||||
if self.remote_evaluators:
|
||||
for e in self.remote_evaluators:
|
||||
e.sample.remote(no_replay=True)
|
||||
else:
|
||||
self.local_evaluator.sample(no_replay=True)
|
||||
|
||||
def _save(self):
|
||||
checkpoint_path = self.saver.save(
|
||||
self.actor.sess,
|
||||
self.local_evaluator.sess,
|
||||
os.path.join(self.logdir, "checkpoint"),
|
||||
global_step=self.num_iterations)
|
||||
global_step=self.iteration)
|
||||
extra_data = [
|
||||
self.actor.save(),
|
||||
ray.get([w.save.remote() for w in self.workers]),
|
||||
self.cur_timestep,
|
||||
self.num_iterations,
|
||||
self.local_evaluator.save(),
|
||||
ray.get([e.save.remote() for e in self.remote_evaluators]),
|
||||
self.global_timestep,
|
||||
self.num_target_updates,
|
||||
self.steps_since_update]
|
||||
self.last_target_update_ts]
|
||||
pickle.dump(extra_data, open(checkpoint_path + ".extra_data", "wb"))
|
||||
return checkpoint_path
|
||||
|
||||
def _restore(self, checkpoint_path):
|
||||
self.saver.restore(self.actor.sess, checkpoint_path)
|
||||
self.saver.restore(self.local_evaluator.sess, checkpoint_path)
|
||||
extra_data = pickle.load(open(checkpoint_path + ".extra_data", "rb"))
|
||||
self.actor.restore(extra_data[0])
|
||||
self.local_evaluator.restore(extra_data[0])
|
||||
ray.get([
|
||||
w.restore.remote(d) for (d, w)
|
||||
in zip(extra_data[1], self.workers)])
|
||||
self.cur_timestep = extra_data[2]
|
||||
self.num_iterations = extra_data[3]
|
||||
self.num_target_updates = extra_data[4]
|
||||
self.steps_since_update = extra_data[5]
|
||||
e.restore.remote(d) for (d, e)
|
||||
in zip(extra_data[1], self.remote_evaluators)])
|
||||
self.global_timestep = extra_data[2]
|
||||
self.num_target_updates = extra_data[3]
|
||||
self.last_target_update_ts = extra_data[4]
|
||||
|
||||
def compute_action(self, observation):
|
||||
return self.actor.dqn_graph.act(
|
||||
self.actor.sess, np.array(observation)[None], 0.0)[0]
|
||||
return self.local_evaluator.dqn_graph.act(
|
||||
self.local_evaluator.sess, np.array(observation)[None], 0.0)[0]
|
||||
|
||||
@@ -1,314 +0,0 @@
|
||||
"""
|
||||
|
||||
See README.md for a description of the logging API.
|
||||
|
||||
OFF state corresponds to having Logger.CURRENT == Logger.DEFAULT
|
||||
ON state is otherwise
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from collections import OrderedDict
|
||||
import os
|
||||
import sys
|
||||
import shutil
|
||||
import os.path as osp
|
||||
import json
|
||||
|
||||
LOG_OUTPUT_FORMATS = ['stdout', 'log', 'json']
|
||||
|
||||
DEBUG = 10
|
||||
INFO = 20
|
||||
WARN = 30
|
||||
ERROR = 40
|
||||
|
||||
DISABLED = 50
|
||||
|
||||
|
||||
class OutputFormat(object):
|
||||
def writekvs(self, kvs):
|
||||
"""
|
||||
Write key-value pairs
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def writeseq(self, args):
|
||||
"""
|
||||
Write a sequence of other data (e.g. a logging message)
|
||||
"""
|
||||
pass
|
||||
|
||||
def close(self):
|
||||
return
|
||||
|
||||
|
||||
class HumanOutputFormat(OutputFormat):
|
||||
def __init__(self, file):
|
||||
self.file = file
|
||||
|
||||
def writekvs(self, kvs):
|
||||
# Create strings for printing
|
||||
key2str = OrderedDict()
|
||||
for (key, val) in kvs.items():
|
||||
valstr = '%-8.3g' % (val,) if hasattr(val, '__float__') else val
|
||||
key2str[self._truncate(key)] = self._truncate(valstr)
|
||||
|
||||
# Find max widths
|
||||
keywidth = max(map(len, key2str.keys()))
|
||||
valwidth = max(map(len, key2str.values()))
|
||||
|
||||
# Write out the data
|
||||
dashes = '-' * (keywidth + valwidth + 7)
|
||||
lines = [dashes]
|
||||
for (key, val) in key2str.items():
|
||||
lines.append('| %s%s | %s%s |' % (
|
||||
key,
|
||||
' ' * (keywidth - len(key)),
|
||||
val,
|
||||
' ' * (valwidth - len(val)),
|
||||
))
|
||||
lines.append(dashes)
|
||||
self.file.write('\n'.join(lines) + '\n')
|
||||
|
||||
# Flush the output to the file
|
||||
self.file.flush()
|
||||
|
||||
def _truncate(self, s):
|
||||
return s[:20] + '...' if len(s) > 23 else s
|
||||
|
||||
def writeseq(self, args):
|
||||
for arg in args:
|
||||
self.file.write(arg)
|
||||
self.file.write('\n')
|
||||
self.file.flush()
|
||||
|
||||
|
||||
class JSONOutputFormat(OutputFormat):
|
||||
def __init__(self, file):
|
||||
self.file = file
|
||||
|
||||
def writekvs(self, kvs):
|
||||
for k, v in kvs.items():
|
||||
if hasattr(v, 'dtype'):
|
||||
v = v.tolist()
|
||||
kvs[k] = float(v)
|
||||
self.file.write(json.dumps(kvs) + '\n')
|
||||
self.file.flush()
|
||||
|
||||
|
||||
def make_output_format(format, ev_dir):
|
||||
os.makedirs(ev_dir, exist_ok=True)
|
||||
if format == 'stdout':
|
||||
return HumanOutputFormat(sys.stdout)
|
||||
elif format == 'log':
|
||||
log_file = open(osp.join(ev_dir, 'log.txt'), 'wt')
|
||||
return HumanOutputFormat(log_file)
|
||||
elif format == 'json':
|
||||
json_file = open(osp.join(ev_dir, 'progress.json'), 'wt')
|
||||
return JSONOutputFormat(json_file)
|
||||
else:
|
||||
raise ValueError('Unknown format specified: %s' % (format,))
|
||||
|
||||
# ================================================================
|
||||
# API
|
||||
# ================================================================
|
||||
|
||||
|
||||
def logkv(key, val):
|
||||
"""
|
||||
Log a value of some diagnostic
|
||||
Call this once for each diagnostic quantity, each iteration
|
||||
"""
|
||||
Logger.CURRENT.logkv(key, val)
|
||||
|
||||
|
||||
def dumpkvs():
|
||||
"""
|
||||
Write all of the diagnostics from the current iteration
|
||||
|
||||
level: int. (see logger.py docs) If the global logger level is higher than
|
||||
the level argument here, don't print to stdout.
|
||||
"""
|
||||
Logger.CURRENT.dumpkvs()
|
||||
|
||||
|
||||
# for backwards compatibility
|
||||
record_tabular = logkv
|
||||
dump_tabular = dumpkvs
|
||||
|
||||
|
||||
def log(*args, **kwargs):
|
||||
"""
|
||||
Write the sequence of args, with no separators, to the console and output
|
||||
files (if you've configured an output file).
|
||||
"""
|
||||
if "level" in kwargs:
|
||||
level = kwargs["level"]
|
||||
else:
|
||||
level = INFO
|
||||
Logger.CURRENT.log(*args, level=level)
|
||||
|
||||
|
||||
def debug(*args):
|
||||
log(*args, level=DEBUG)
|
||||
|
||||
|
||||
def info(*args):
|
||||
log(*args, level=INFO)
|
||||
|
||||
|
||||
def warn(*args):
|
||||
log(*args, level=WARN)
|
||||
|
||||
|
||||
def error(*args):
|
||||
log(*args, level=ERROR)
|
||||
|
||||
|
||||
def set_level(level):
|
||||
"""
|
||||
Set logging threshold on current logger.
|
||||
"""
|
||||
Logger.CURRENT.set_level(level)
|
||||
|
||||
|
||||
def get_dir():
|
||||
"""
|
||||
Get directory that log files are being written to.
|
||||
will be None if there is no output directory (i.e., if you didn't call
|
||||
start)
|
||||
"""
|
||||
return Logger.CURRENT.get_dir()
|
||||
|
||||
|
||||
def get_expt_dir():
|
||||
sys.stderr.write(
|
||||
"get_expt_dir() is Deprecated. Switch to get_dir() [%s]\n" %
|
||||
(get_dir(),))
|
||||
return get_dir()
|
||||
|
||||
|
||||
# ================================================================
|
||||
# Backend
|
||||
# ================================================================
|
||||
|
||||
|
||||
class Logger(object):
|
||||
# A logger with no output files. (See right below class definition)
|
||||
# So that you can still log to the terminal without setting up any output
|
||||
DEFAULT = None
|
||||
|
||||
# Current logger being used by the free functions above
|
||||
CURRENT = None
|
||||
|
||||
def __init__(self, dir, output_formats):
|
||||
self.name2val = OrderedDict() # values this iteration
|
||||
self.level = INFO
|
||||
self.dir = dir
|
||||
self.output_formats = output_formats
|
||||
|
||||
# Logging API, forwarded
|
||||
# ----------------------------------------
|
||||
def logkv(self, key, val):
|
||||
self.name2val[key] = val
|
||||
|
||||
def dumpkvs(self):
|
||||
for fmt in self.output_formats:
|
||||
fmt.writekvs(self.name2val)
|
||||
self.name2val.clear()
|
||||
|
||||
def log(self, *args, **kwargs):
|
||||
if "level" in kwargs:
|
||||
level = kwargs["level"]
|
||||
else:
|
||||
level = INFO
|
||||
if self.level <= level:
|
||||
self._do_log(args)
|
||||
|
||||
# Configuration
|
||||
# ----------------------------------------
|
||||
def set_level(self, level):
|
||||
self.level = level
|
||||
|
||||
def get_dir(self):
|
||||
return self.dir
|
||||
|
||||
def close(self):
|
||||
for fmt in self.output_formats:
|
||||
fmt.close()
|
||||
|
||||
# Misc
|
||||
# ----------------------------------------
|
||||
def _do_log(self, args):
|
||||
for fmt in self.output_formats:
|
||||
fmt.writeseq(args)
|
||||
|
||||
|
||||
# ================================================================
|
||||
|
||||
Logger.DEFAULT = Logger(
|
||||
output_formats=[HumanOutputFormat(sys.stdout)], dir=None)
|
||||
Logger.CURRENT = Logger.DEFAULT
|
||||
|
||||
|
||||
class session(object):
|
||||
"""
|
||||
Context manager that sets up the loggers for an experiment.
|
||||
"""
|
||||
|
||||
CURRENT = None # Set to a LoggerContext object using enter/exit or cm
|
||||
|
||||
def __init__(self, dir, format_strs=None):
|
||||
self.dir = dir
|
||||
if format_strs is None:
|
||||
format_strs = LOG_OUTPUT_FORMATS
|
||||
output_formats = [make_output_format(f, dir) for f in format_strs]
|
||||
Logger.CURRENT = Logger(dir=dir, output_formats=output_formats)
|
||||
|
||||
def __enter__(self):
|
||||
os.makedirs(self.evaluation_dir(), exist_ok=True)
|
||||
output_formats = [
|
||||
make_output_format(
|
||||
f, self.evaluation_dir()) for f in LOG_OUTPUT_FORMATS]
|
||||
Logger.CURRENT = Logger(dir=self.dir, output_formats=output_formats)
|
||||
|
||||
def __exit__(self, *args):
|
||||
Logger.CURRENT.close()
|
||||
Logger.CURRENT = Logger.DEFAULT
|
||||
|
||||
def evaluation_dir(self):
|
||||
return self.dir
|
||||
|
||||
|
||||
# ================================================================
|
||||
|
||||
|
||||
def _demo():
|
||||
info("hi")
|
||||
debug("shouldn't appear")
|
||||
set_level(DEBUG)
|
||||
debug("should appear")
|
||||
dir = "/tmp/testlogging"
|
||||
if os.path.exists(dir):
|
||||
shutil.rmtree(dir)
|
||||
with session(dir=dir):
|
||||
record_tabular("a", 3)
|
||||
record_tabular("b", 2.5)
|
||||
dump_tabular()
|
||||
record_tabular("b", -2.5)
|
||||
record_tabular("a", 5.5)
|
||||
dump_tabular()
|
||||
info("^^^ should see a = 5.5")
|
||||
|
||||
record_tabular("b", -2.5)
|
||||
dump_tabular()
|
||||
|
||||
record_tabular("a", "longasslongasslongasslongasslongasslongassvalue")
|
||||
dump_tabular()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
_demo()
|
||||
@@ -193,6 +193,11 @@ class DQNGraph(object):
|
||||
num_actions, config,
|
||||
obs_t, act_t, rew_t, obs_tp1, done_mask, importance_weights)
|
||||
|
||||
self.loss_inputs = [
|
||||
self.obs_t, self.act_t, self.rew_t, self.obs_tp1, self.done_mask,
|
||||
self.importance_weights]
|
||||
self.build_loss = build_loss
|
||||
|
||||
if config["multi_gpu_optimize"]:
|
||||
self.multi_gpu_optimizer = LocalSyncParallelOptimizer(
|
||||
optimizer,
|
||||
|
||||
@@ -0,0 +1,140 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
import ray
|
||||
from ray.rllib.dqn.base_evaluator import DQNEvaluator
|
||||
from ray.rllib.dqn.common.schedules import LinearSchedule
|
||||
from ray.rllib.dqn.replay_buffer import ReplayBuffer, PrioritizedReplayBuffer
|
||||
|
||||
|
||||
class DQNReplayEvaluator(DQNEvaluator):
|
||||
"""Wraps DQNEvaluators to provide replay buffer functionality.
|
||||
|
||||
This has two modes:
|
||||
If config["num_workers"] == 1:
|
||||
Samples will be collected locally.
|
||||
If config["num_workers"] > 1:
|
||||
Samples will be collected from a number of remote workers.
|
||||
"""
|
||||
|
||||
def __init__(self, env_creator, config, logdir):
|
||||
DQNEvaluator.__init__(self, env_creator, config, logdir)
|
||||
|
||||
# Create extra workers if needed
|
||||
if self.config["num_workers"] > 1:
|
||||
remote_cls = ray.remote(num_cpus=1)(DQNEvaluator)
|
||||
self.workers = [
|
||||
remote_cls.remote(env_creator, config, logdir)
|
||||
for _ in range(self.config["num_workers"])]
|
||||
else:
|
||||
self.workers = []
|
||||
|
||||
# Create the replay buffer
|
||||
if config["prioritized_replay"]:
|
||||
self.replay_buffer = PrioritizedReplayBuffer(
|
||||
config["buffer_size"],
|
||||
alpha=config["prioritized_replay_alpha"])
|
||||
prioritized_replay_beta_iters = \
|
||||
config["prioritized_replay_beta_iters"]
|
||||
if prioritized_replay_beta_iters is None:
|
||||
prioritized_replay_beta_iters = \
|
||||
config["schedule_max_timesteps"]
|
||||
self.beta_schedule = LinearSchedule(
|
||||
prioritized_replay_beta_iters,
|
||||
initial_p=config["prioritized_replay_beta0"],
|
||||
final_p=1.0)
|
||||
else:
|
||||
self.replay_buffer = ReplayBuffer(config["buffer_size"])
|
||||
self.beta_schedule = None
|
||||
|
||||
self.samples_to_prioritize = None
|
||||
|
||||
def sample(self, no_replay=False):
|
||||
# First seed the replay buffer with a few new samples
|
||||
if self.workers:
|
||||
weights = ray.put(self.get_weights())
|
||||
for w in self.workers:
|
||||
w.set_weights.remote(weights)
|
||||
samples = ray.get([w.sample.remote() for w in self.workers])
|
||||
else:
|
||||
samples = [DQNEvaluator.sample(self)]
|
||||
|
||||
for s in samples:
|
||||
for obs, action, rew, new_obs, done in s:
|
||||
self.replay_buffer.add(obs, action, rew, new_obs, done)
|
||||
|
||||
if no_replay:
|
||||
return samples
|
||||
|
||||
# Then return a batch sampled from the buffer
|
||||
if self.config["prioritized_replay"]:
|
||||
experience = self.replay_buffer.sample(
|
||||
self.config["train_batch_size"],
|
||||
beta=self.beta_schedule.value(self.global_timestep))
|
||||
(obses_t, actions, rewards, obses_tp1,
|
||||
dones, _, batch_idxes) = experience
|
||||
self._update_priorities_if_needed()
|
||||
self.samples_to_prioritize = (
|
||||
obses_t, actions, rewards, obses_tp1, dones, batch_idxes)
|
||||
else:
|
||||
obses_t, actions, rewards, obses_tp1, dones = \
|
||||
self.replay_buffer.sample(self.config["train_batch_size"])
|
||||
batch_idxes = None
|
||||
|
||||
return self.samples_to_prioritize
|
||||
|
||||
def compute_gradients(self, samples):
|
||||
obses_t, actions, rewards, obses_tp1, dones, batch_indxes = samples
|
||||
td_errors, grad = self.dqn_graph.compute_gradients(
|
||||
self.sess, obses_t, actions, rewards, obses_tp1, dones,
|
||||
np.ones_like(rewards))
|
||||
if self.config["prioritized_replay"]:
|
||||
new_priorities = (
|
||||
np.abs(td_errors) + self.config["prioritized_replay_eps"])
|
||||
self.replay_buffer.update_priorities(batch_indxes, new_priorities)
|
||||
self.samples_to_prioritize = None
|
||||
return grad
|
||||
|
||||
def _update_priorities_if_needed(self):
|
||||
"""Manually updates replay buffer priorities on the last batch.
|
||||
|
||||
Note that this is only needed when not computing gradients on this
|
||||
Evaluator (e.g. when using local multi-GPU). Otherwise, priorities
|
||||
can be updated more efficiently as part of computing gradients.
|
||||
"""
|
||||
|
||||
if not self.samples_to_prioritize:
|
||||
return
|
||||
|
||||
obses_t, actions, rewards, obses_tp1, dones, batch_idxes = \
|
||||
self.samples_to_prioritize
|
||||
td_errors = self.dqn_graph.compute_td_error(
|
||||
self.sess, obses_t, actions, rewards, obses_tp1, dones,
|
||||
np.ones_like(rewards))
|
||||
new_priorities = (
|
||||
np.abs(td_errors) + self.config["prioritized_replay_eps"])
|
||||
self.replay_buffer.update_priorities(batch_idxes, new_priorities)
|
||||
self.samples_to_prioritize = None
|
||||
|
||||
def stats(self):
|
||||
if self.workers:
|
||||
return ray.get([s.stats.remote() for s in self.workers])
|
||||
else:
|
||||
return DQNEvaluator.stats(self)
|
||||
|
||||
def save(self):
|
||||
return [
|
||||
DQNEvaluator.save(self),
|
||||
ray.get([w.save.remote() for w in self.workers]),
|
||||
self.beta_schedule,
|
||||
self.replay_buffer]
|
||||
|
||||
def restore(self, data):
|
||||
DQNEvaluator.restore(self, data[0])
|
||||
for (w, d) in zip(self.workers, data[1]):
|
||||
w.restore.remote(d)
|
||||
self.beta_schedule = data[2]
|
||||
self.replay_buffer = data[3]
|
||||
@@ -0,0 +1,50 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
|
||||
class Evaluator(object):
|
||||
"""RLlib optimizers require RL algorithms to implement this interface.
|
||||
|
||||
Any algorithm that implements Evaluator can plug in any RLLib optimizer,
|
||||
e.g. async SGD, local multi-GPU SGD, etc.
|
||||
"""
|
||||
|
||||
def sample(self):
|
||||
"""Returns experience samples from this Evaluator."""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
def compute_gradients(self, samples):
|
||||
"""Returns a gradient computed w.r.t the specified samples."""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
def apply_gradients(self, grads):
|
||||
"""Applies the given gradients to this Evaluator's weights."""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
def get_weights(self):
|
||||
"""Returns the model weights of this Evaluator."""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
def set_weights(self, weights):
|
||||
"""Sets the model weights of this Evaluator."""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class TFMultiGPUSupport(Evaluator):
|
||||
"""The multi-GPU TF optimizer requires this additional interface."""
|
||||
|
||||
def tf_loss_inputs(self):
|
||||
"""Returns a list of the input placeholders required for the loss."""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
def build_tf_loss(self, input_placeholders):
|
||||
"""Returns a new loss tensor graph for the specified inputs."""
|
||||
|
||||
raise NotImplementedError
|
||||
@@ -0,0 +1,6 @@
|
||||
from ray.rllib.optimizers.async import AsyncOptimizer
|
||||
from ray.rllib.optimizers.local_sync import LocalSyncOptimizer
|
||||
from ray.rllib.optimizers.multi_gpu import LocalMultiGPUOptimizer
|
||||
|
||||
|
||||
__all__ = ["AsyncOptimizer", "LocalSyncOptimizer", "LocalMultiGPUOptimizer"]
|
||||
@@ -0,0 +1,58 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import ray
|
||||
from ray.rllib.optimizers.optimizer import Optimizer
|
||||
from ray.rllib.utils.timer import TimerStat
|
||||
|
||||
|
||||
class AsyncOptimizer(Optimizer):
|
||||
"""An asynchronous RL optimizer, e.g. for implementing A3C.
|
||||
|
||||
This optimizer asynchronously pulls and applies gradients from remote
|
||||
evaluators, sending updated weights back as needed. This pipelines the
|
||||
gradient computations on the remote workers.
|
||||
"""
|
||||
def _init(self):
|
||||
self.apply_timer = TimerStat()
|
||||
self.wait_timer = TimerStat()
|
||||
self.dispatch_timer = TimerStat()
|
||||
self.grads_per_step = self.config.get("grads_per_step", 100)
|
||||
|
||||
def step(self):
|
||||
weights = ray.put(self.local_evaluator.get_weights())
|
||||
gradient_queue = []
|
||||
num_gradients = 0
|
||||
|
||||
# Kick off the first wave of async tasks
|
||||
for e in self.remote_evaluators:
|
||||
e.set_weights.remote(weights)
|
||||
fut = e.compute_gradients.remote(e.sample.remote())
|
||||
gradient_queue.append((fut, e))
|
||||
num_gradients += 1
|
||||
|
||||
# Note: can't use wait: https://github.com/ray-project/ray/issues/1128
|
||||
while gradient_queue:
|
||||
with self.wait_timer:
|
||||
fut, e = gradient_queue[0]
|
||||
gradient_queue = gradient_queue[1:]
|
||||
gradient = ray.get(fut)
|
||||
|
||||
if gradient is not None:
|
||||
with self.apply_timer:
|
||||
self.local_evaluator.apply_gradients(gradient)
|
||||
|
||||
if num_gradients < self.grads_per_step:
|
||||
with self.dispatch_timer:
|
||||
e.set_weights.remote(self.local_evaluator.get_weights())
|
||||
fut = e.compute_gradients.remote(e.sample.remote())
|
||||
gradient_queue.append((fut, e))
|
||||
num_gradients += 1
|
||||
|
||||
def stats(self):
|
||||
return {
|
||||
"wait_time_ms": round(1000 * self.wait_timer.mean, 3),
|
||||
"apply_time_ms": round(1000 * self.apply_timer.mean, 3),
|
||||
"dispatch_time_ms": round(1000 * self.dispatch_timer.mean, 3),
|
||||
}
|
||||
@@ -0,0 +1,55 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import ray
|
||||
from ray.rllib.optimizers.optimizer import Optimizer
|
||||
from ray.rllib.utils.timer import TimerStat
|
||||
|
||||
|
||||
class LocalSyncOptimizer(Optimizer):
|
||||
"""A simple synchronous RL optimizer.
|
||||
|
||||
In each step, this optimizer pulls samples from a number of remote
|
||||
evaluators, concatenates them, and then updates a local model. The updated
|
||||
model weights are then broadcast to all remote evaluators.
|
||||
"""
|
||||
|
||||
def _init(self):
|
||||
self.update_weights_timer = TimerStat()
|
||||
self.sample_timer = TimerStat()
|
||||
self.grad_timer = TimerStat()
|
||||
|
||||
def step(self):
|
||||
with self.update_weights_timer:
|
||||
if self.remote_evaluators:
|
||||
weights = ray.put(self.local_evaluator.get_weights())
|
||||
for e in self.remote_evaluators:
|
||||
e.set_weights.remote(weights)
|
||||
|
||||
with self.sample_timer:
|
||||
if self.remote_evaluators:
|
||||
samples = _concat(
|
||||
ray.get(
|
||||
[e.sample.remote() for e in self.remote_evaluators]))
|
||||
else:
|
||||
samples = self.local_evaluator.sample()
|
||||
|
||||
with self.grad_timer:
|
||||
grad = self.local_evaluator.compute_gradients(samples)
|
||||
self.local_evaluator.apply_gradients(grad)
|
||||
|
||||
def stats(self):
|
||||
return {
|
||||
"sample_time_ms": round(1000 * self.sample_timer.mean, 3),
|
||||
"grad_time_ms": round(1000 * self.grad_timer.mean, 3),
|
||||
"update_time_ms": round(1000 * self.update_weights_timer.mean, 3),
|
||||
}
|
||||
|
||||
|
||||
# TODO(ekl) this should be implemented by some sample batch class
|
||||
def _concat(samples):
|
||||
result = []
|
||||
for s in samples:
|
||||
result.extend(s)
|
||||
return result
|
||||
@@ -0,0 +1,9 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from ray.rllib.optimizers.optimizer import Optimizer
|
||||
|
||||
|
||||
class LocalMultiGPUOptimizer(Optimizer):
|
||||
pass # TODO(ekl)
|
||||
@@ -0,0 +1,45 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
|
||||
class Optimizer(object):
|
||||
"""RLlib optimizers encapsulate distributed RL optimization strategies.
|
||||
|
||||
For example, AsyncOptimizer is used for A3C, and LocalMultiGPUOptimizer is
|
||||
used for PPO. These optimizers are all pluggable however, it is possible
|
||||
to mix as match as needed.
|
||||
|
||||
In order for an algorithm to use an RLlib optimizer, it must implement
|
||||
the Evaluator interface and pass a number of Evaluators to its Optimizer
|
||||
of choice. The Optimizer uses these Evaluators to sample from the
|
||||
environment and compute model gradient updates.
|
||||
"""
|
||||
|
||||
def __init__(self, config, local_evaluator, remote_evaluators):
|
||||
"""Create an optimizer instance.
|
||||
|
||||
Args:
|
||||
config (dict): Optimizer-specific configuration data.
|
||||
local_evaluator (Evaluator): Local evaluator instance, required.
|
||||
remote_evaluators (list): A list of handles to remote evaluators.
|
||||
if empty, the optimizer should fall back to to using only the
|
||||
local evaluator.
|
||||
"""
|
||||
self.config = config
|
||||
self.local_evaluator = local_evaluator
|
||||
self.remote_evaluators = remote_evaluators
|
||||
self._init()
|
||||
|
||||
def _init(self):
|
||||
pass
|
||||
|
||||
def step(self):
|
||||
"""Takes a logical optimization step."""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
def stats(self):
|
||||
"""Returns a dictionary of internal performance statistics."""
|
||||
|
||||
return {}
|
||||
@@ -0,0 +1,33 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import time
|
||||
|
||||
# TODO(ekl) move this to a common location
|
||||
from ray.rllib.ppo.filter import RunningStat
|
||||
|
||||
|
||||
class TimerStat(RunningStat):
|
||||
"""A running stat for conveniently logging the duration of a code block.
|
||||
|
||||
Example:
|
||||
wait_timer = TimeStat()
|
||||
with wait_timer:
|
||||
ray.wait(...)
|
||||
|
||||
Note that this class is *not* thread-safe.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
RunningStat.__init__(self, ())
|
||||
self._start_time = None
|
||||
|
||||
def __enter__(self):
|
||||
assert self._start_time is None, "concurrent updates not supported"
|
||||
self._start_time = time.monotonic()
|
||||
|
||||
def __exit__(self, type, value, tb):
|
||||
assert self._start_time is not None
|
||||
self.push(time.monotonic() - self._start_time)
|
||||
self._start_time = None
|
||||
@@ -0,0 +1,32 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import time
|
||||
|
||||
from ray.rllib.utils.filter import RunningStat
|
||||
|
||||
|
||||
class TimerStat(RunningStat):
|
||||
"""A running stat for conveniently logging the duration of a code block.
|
||||
|
||||
Example:
|
||||
wait_timer = TimeStat()
|
||||
with wait_timer:
|
||||
ray.wait(...)
|
||||
|
||||
Note that this class is *not* thread-safe.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
RunningStat.__init__(self, ())
|
||||
self._start_time = None
|
||||
|
||||
def __enter__(self):
|
||||
assert self._start_time is None, "concurrent updates not supported"
|
||||
self._start_time = time.monotonic()
|
||||
|
||||
def __exit__(self, type, value, tb):
|
||||
assert self._start_time is not None
|
||||
self.push(time.monotonic() - self._start_time)
|
||||
self._start_time = None
|
||||
@@ -3,6 +3,13 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from collections import namedtuple
|
||||
import json
|
||||
|
||||
try:
|
||||
import yaml
|
||||
except ImportError:
|
||||
print("Could not import YAML module, falling back to JSON pretty-printing")
|
||||
yaml = None
|
||||
|
||||
"""
|
||||
When using ray.tune with custom training scripts, you must periodically report
|
||||
@@ -30,6 +37,9 @@ TrainingResult = namedtuple("TrainingResult", [
|
||||
# (Optional) The mean episode length if applicable.
|
||||
"episode_len_mean",
|
||||
|
||||
# (Optional) The number of episodes total.
|
||||
"episodes_total",
|
||||
|
||||
# (Optional) The current training accuracy if applicable>
|
||||
"mean_accuracy",
|
||||
|
||||
@@ -63,4 +73,16 @@ TrainingResult = namedtuple("TrainingResult", [
|
||||
"hostname",
|
||||
])
|
||||
|
||||
|
||||
def pretty_print(result):
|
||||
out = {}
|
||||
for k, v in result._asdict().items():
|
||||
if v is not None:
|
||||
out[k] = v
|
||||
if yaml:
|
||||
return yaml.dump(out, default_flow_style=False)
|
||||
else:
|
||||
return json.dumps(out) + "\n"
|
||||
|
||||
|
||||
TrainingResult.__new__.__defaults__ = (None,) * len(TrainingResult._fields)
|
||||
|
||||
@@ -8,6 +8,7 @@ import time
|
||||
import traceback
|
||||
|
||||
from ray.tune import TuneError
|
||||
from ray.tune.result import pretty_print
|
||||
from ray.tune.trial import Trial, Resources
|
||||
from ray.tune.trial_scheduler import FIFOScheduler, TrialScheduler
|
||||
|
||||
@@ -163,10 +164,7 @@ class TrialRunner(object):
|
||||
result = ray.get(result_id)
|
||||
trial.result_logger.on_result(result)
|
||||
print("TrainingResult for {}:".format(trial))
|
||||
for k, v in result._asdict().items():
|
||||
if v is not None:
|
||||
print(" {}={}".format(k, v))
|
||||
print()
|
||||
print(" {}".format(pretty_print(result).replace("\n", "\n ")))
|
||||
trial.last_result = result
|
||||
self._total_time += result.time_this_iter_s
|
||||
|
||||
|
||||
Reference in New Issue
Block a user