[rllib] Make sure to always record stats like time elapsed, timesteps (#965)

* always record training stats

* fix

* comments

* revert assert

* nan

* fix
This commit is contained in:
Eric Liang
2017-09-12 14:28:16 -07:00
committed by Philipp Moritz
parent 74ac80631b
commit 9f42ef6a4f
6 changed files with 137 additions and 67 deletions
+17 -15
View File
@@ -100,9 +100,8 @@ class A3CAgent(Agent):
config["batch_size"], self.logdir)
for i in range(config["num_workers"])]
self.parameters = self.policy.get_weights()
self.iteration = 0
def train(self):
def _train(self):
gradient_list = [
agent.compute_gradient.remote(self.parameters)
for agent in self.agents]
@@ -119,7 +118,6 @@ class A3CAgent(Agent):
[self.agents[info["id"]].compute_gradient.remote(
self.parameters)])
res = self._fetch_metrics_from_workers()
self.iteration += 1
return res
def _fetch_metrics_from_workers(self):
@@ -131,27 +129,31 @@ class A3CAgent(Agent):
for episode in ray.get(metrics):
episode_lengths.append(episode.episode_length)
episode_rewards.append(episode.episode_reward)
avg_reward = np.mean(episode_rewards) if episode_rewards else None
avg_length = np.mean(episode_lengths) if episode_lengths else None
res = TrainingResult(
self.experiment_id.hex, self.iteration,
avg_reward, avg_length, dict())
return res
avg_reward = (
np.mean(episode_rewards) if episode_rewards else float('nan'))
avg_length = (
np.mean(episode_lengths) if episode_lengths else float('nan'))
timesteps = np.sum(episode_lengths) if episode_lengths else 0
def save(self):
result = TrainingResult(
episode_reward_mean=avg_reward,
episode_len_mean=avg_length,
timesteps_this_iter=timesteps,
info={})
return result
def _save(self):
checkpoint_path = os.path.join(
self.logdir, "checkpoint-{}".format(self.iteration))
objects = [
self.parameters,
self.iteration]
objects = [self.parameters]
pickle.dump(objects, open(checkpoint_path, "wb"))
return checkpoint_path
def restore(self, checkpoint_path):
def _restore(self, checkpoint_path):
objects = pickle.load(open(checkpoint_path, "rb"))
self.parameters = objects[0]
self.policy.set_weights(self.parameters)
self.iteration = objects[1]
def compute_action(self, observation):
actions = self.policy.compute_actions(observation)[0]
+82 -8
View File
@@ -4,8 +4,10 @@ import json
import logging
import numpy as np
import os
import pickle
import sys
import tempfile
import time
import uuid
import smart_open
@@ -46,13 +48,37 @@ class RLLibLogger(object):
TrainingResult = namedtuple("TrainingResult", [
# Unique string identifier for this experiment. This id is preserved
# across checkpoint / restore calls.
"experiment_id",
# The index of this training iteration, e.g. call to train().
"training_iteration",
# The mean episode reward reported during this iteration.
"episode_reward_mean",
# The mean episode length reported during this iteration.
"episode_len_mean",
"info"
# Agent-specific metadata to report for this iteration.
"info",
# Number of timesteps in the simulator in this iteration.
"timesteps_this_iter",
# Accumulated timesteps for this entire experiment.
"timesteps_total",
# Time in seconds this iteration took to run.
"time_this_iter_s",
# Accumulated time in seconds for this entire experiment.
"time_total_s",
])
TrainingResult.__new__.__defaults__ = (None,) * len(TrainingResult._fields)
class Agent(object):
"""All RLlib agents extend this base class.
@@ -64,8 +90,6 @@ class Agent(object):
env_name (str): Name of the OpenAI gym environment to train against.
config (obj): Algorithm-specific configuration data.
logdir (str): Directory in which training outputs should be placed.
TODO(ekl): support checkpoint / restore of training state.
"""
def __init__(self, env_name, config, upload_dir=None):
@@ -79,10 +103,11 @@ class Agent(object):
like s3://bucketname/.
"""
upload_dir = "file:///tmp/ray" if upload_dir is None else upload_dir
self.experiment_id = uuid.uuid4()
self.experiment_id = uuid.uuid4().hex
self.env_name = env_name
self.config = config
self.config.update({"experiment_id": self.experiment_id.hex})
self.config.update({"experiment_id": self.experiment_id})
self.config.update({"env_name": env_name})
prefix = "{}_{}_{}".format(
env_name,
@@ -92,6 +117,8 @@ class Agent(object):
self.logdir = tempfile.mkdtemp(prefix=prefix, dir="/tmp/ray")
else:
self.logdir = os.path.join(upload_dir, prefix)
# TODO(ekl) consider inlining config into the result jsons
log_path = os.path.join(self.logdir, "config.json")
with smart_open.smart_open(log_path, "w") as f:
json.dump(self.config, f, sort_keys=True, cls=RLLibEncoder)
@@ -99,6 +126,10 @@ class Agent(object):
"%s algorithm created with logdir '%s'",
self.__class__.__name__, self.logdir)
self.iteration = 0
self.time_total = 0.0
self.timesteps_total = 0
def train(self):
"""Runs one logical iteration of training.
@@ -106,7 +137,25 @@ class Agent(object):
A TrainingResult that describes training progress.
"""
raise NotImplementedError
start = time.time()
self.iteration += 1
result = self._train()
time_this_iter = time.time() - start
self.time_total += time_this_iter
self.timesteps_total += result.timesteps_this_iter
result = result._replace(
experiment_id=self.experiment_id,
training_iteration=self.iteration,
timesteps_total=self.timesteps_total,
time_this_iter_s=time_this_iter,
time_total_s=self.time_total)
for field in result:
assert field is not None, result
return result
def save(self):
"""Saves the current model state to a checkpoint.
@@ -115,7 +164,12 @@ class Agent(object):
Checkpoint path that may be passed to restore().
"""
raise NotImplementedError
checkpoint_path = self._save()
pickle.dump(
[self.experiment_id, self.iteration, self.timesteps_total,
self.time_total_s],
open(checkpoint_path + ".rllib_metadata", "wb"))
return checkpoint_path
def restore(self, checkpoint_path):
"""Restores training state from a given model checkpoint.
@@ -123,9 +177,29 @@ class Agent(object):
These checkpoints are returned from calls to save().
"""
raise NotImplementedError
self._restore(checkpoint_path)
metadata = pickle.load(open(checkpoint_path + ".rllib_metadata", "rb"))
self.experiment_id = metadata[0]
self.iteration = metadata[1]
self.timesteps_total = metadata[2]
self.time_total_s = metadata[3]
def compute_action(self, observation):
"""Computes an action using the current trained policy."""
raise NotImplementedError
def _train(self):
"""Subclasses should override this to implement train()."""
raise NotImplementedError
def _save(self):
"""Subclasses should override this to implement save()."""
raise NotImplementedError
def _restore(self):
"""Subclasses should override this to implement restore()."""
raise NotImplementedError
+11 -8
View File
@@ -164,9 +164,10 @@ class DQNAgent(Agent):
self.file_writer = tf.summary.FileWriter(self.logdir, self.sess.graph)
self.saver = tf.train.Saver(max_to_keep=None)
def train(self):
def _train(self):
config = self.config
sample_time, learn_time = 0, 0
iter_init_timesteps = self.num_timesteps
for _ in range(config["timesteps_per_iteration"]):
self.num_timesteps += 1
@@ -243,13 +244,15 @@ class DQNAgent(Agent):
int(100 * self.exploration.value(self.num_timesteps)))
logger.dump_tabular()
res = TrainingResult(
self.experiment_id.hex, self.num_iterations, mean_100ep_reward,
mean_100ep_length, info)
self.num_iterations += 1
return res
result = TrainingResult(
episode_reward_mean=mean_100ep_reward,
episode_len_mean=mean_100ep_length,
timesteps_this_iter=self.num_timesteps - iter_init_timesteps,
info=info)
def save(self):
return result
def _save(self):
checkpoint_path = self.saver.save(
self.sess,
os.path.join(self.logdir, "checkpoint"),
@@ -267,7 +270,7 @@ class DQNAgent(Agent):
pickle.dump(extra_data, open(checkpoint_path + ".extra_data", "wb"))
return checkpoint_path
def restore(self, checkpoint_path):
def _restore(self, checkpoint_path):
self.saver.restore(self.sess, checkpoint_path)
extra_data = pickle.load(open(checkpoint_path + ".extra_data", "rb"))
self.replay_buffer = extra_data[0]
+10 -19
View File
@@ -203,7 +203,6 @@ class ESAgent(Agent):
self.episodes_so_far = 0
self.timesteps_so_far = 0
self.tstart = time.time()
self.iteration = 0
def _collect_results(self, theta_id, min_eps, min_timesteps):
num_eps, num_timesteps = 0, 0
@@ -224,7 +223,7 @@ class ESAgent(Agent):
num_timesteps += result.lengths_n2.sum()
return results
def train(self):
def _train(self):
config = self.config
step_tstart = time.time()
@@ -314,14 +313,6 @@ class ESAgent(Agent):
tlogger.record_tabular("TimeElapsed", step_tend - self.tstart)
tlogger.dump_tabular()
if (config["snapshot_freq"] != 0 and
self.iteration % config["snapshot_freq"] == 0):
filename = os.path.join(
self.logdir, "snapshot_iter{:05d}.h5".format(self.iteration))
assert not os.path.exists(filename)
self.policy.save(filename)
tlogger.log("Saved snapshot {}".format(filename))
info = {
"weights_norm": np.square(self.policy.get_trainable_flat()).sum(),
"grad_norm": np.square(g).sum(),
@@ -334,14 +325,16 @@ class ESAgent(Agent):
"time_elapsed_this_iter": step_tend - step_tstart,
"time_elapsed": step_tend - self.tstart
}
res = TrainingResult(self.experiment_id.hex, self.iteration,
returns_n2.mean(), lengths_n2.mean(), info)
self.iteration += 1
result = TrainingResult(
episode_reward_mean=returns_n2.mean(),
episode_len_mean=lengths_n2.mean(),
timesteps_this_iter=lengths_n2.sum(),
info=info)
return res
return result
def save(self):
def _save(self):
checkpoint_path = os.path.join(
self.logdir, "checkpoint-{}".format(self.iteration))
weights = self.policy.get_trainable_flat()
@@ -349,18 +342,16 @@ class ESAgent(Agent):
weights,
self.ob_stat,
self.episodes_so_far,
self.timesteps_so_far,
self.iteration]
self.timesteps_so_far]
pickle.dump(objects, open(checkpoint_path, "wb"))
return checkpoint_path
def restore(self, checkpoint_path):
def _restore(self, checkpoint_path):
objects = pickle.load(open(checkpoint_path, "rb"))
self.policy.set_trainable_flat(objects[0])
self.ob_stat = objects[1]
self.episodes_so_far = objects[2]
self.timesteps_so_far = objects[3]
self.iteration = objects[4]
def compute_action(self, observation):
return self.policy.act([observation])[0]
+14 -16
View File
@@ -89,7 +89,6 @@ class PPOAgent(Agent):
def _init(self):
self.global_step = 0
self.j = 0
self.kl_coeff = self.config["kl_coeff"]
self.model = Runner(self.env_name, 1, self.config, self.logdir, False)
self.agents = [
@@ -104,14 +103,12 @@ class PPOAgent(Agent):
self.file_writer = None
self.saver = tf.train.Saver(max_to_keep=None)
def train(self):
def _train(self):
agents = self.agents
config = self.config
model = self.model
j = self.j
self.j += 1
print("===> iteration", self.j)
print("===> iteration", self.iteration)
iter_start = time.time()
weights = ray.put(model.get_weights())
@@ -151,7 +148,7 @@ class PPOAgent(Agent):
trajectory = shuffle(trajectory)
shuffle_end = time.time()
tuples_per_device = model.load_data(
trajectory, j == 0 and config["full_trace_data_load"])
trajectory, self.iteration == 0 and config["full_trace_data_load"])
load_end = time.time()
rollouts_time = rollouts_end - iter_start
shuffle_time = shuffle_end - rollouts_end
@@ -165,11 +162,11 @@ class PPOAgent(Agent):
loss, policy_loss, vf_loss, kl, entropy = [], [], [], [], []
permutation = np.random.permutation(num_batches)
# Prepare to drop into the debugger
if j == config["tf_debug_iteration"]:
if self.iteration == config["tf_debug_iteration"]:
model.sess = tf_debug.LocalCLIDebugWrapperSession(model.sess)
while batch_index < num_batches:
full_trace = (
i == 0 and j == 0 and
i == 0 and self.iteration == 0 and
batch_index == config["full_trace_nth_sgd_batch"])
batch_loss, batch_policy_loss, batch_vf_loss, batch_kl, \
batch_entropy = model.run_sgd_minibatch(
@@ -238,35 +235,36 @@ class PPOAgent(Agent):
print("total time so far:", time.time() - self.start_time)
result = TrainingResult(
self.experiment_id.hex, j, total_reward, traj_len_mean, info)
episode_reward_mean=total_reward,
episode_len_mean=traj_len_mean,
timesteps_this_iter=trajectory["dones"].shape[0],
info=info)
return result
def save(self):
def _save(self):
checkpoint_path = self.saver.save(
self.model.sess,
os.path.join(self.logdir, "checkpoint"),
global_step=self.j)
global_step=self.iteration)
agent_state = ray.get([a.save.remote() for a in self.agents])
extra_data = [
self.model.save(),
self.global_step,
self.j,
self.kl_coeff,
agent_state]
pickle.dump(extra_data, open(checkpoint_path + ".extra_data", "wb"))
return checkpoint_path
def restore(self, checkpoint_path):
def _restore(self, checkpoint_path):
self.saver.restore(self.model.sess, checkpoint_path)
extra_data = pickle.load(open(checkpoint_path + ".extra_data", "rb"))
self.model.restore(extra_data[0])
self.global_step = extra_data[1]
self.j = extra_data[2]
self.kl_coeff = extra_data[3]
self.kl_coeff = extra_data[2]
ray.get([
a.restore.remote(o)
for (a, o) in zip(self.agents, extra_data[4])])
for (a, o) in zip(self.agents, extra_data[3])])
def compute_action(self, observation):
observation = self.model.observation_filter(observation)
+3 -1
View File
@@ -7,6 +7,7 @@ from __future__ import print_function
import argparse
import json
import os
import pprint
import sys
import ray
@@ -89,7 +90,8 @@ if __name__ == "__main__":
cls=ray.rllib.common.RLLibEncoder)
result_logger.write("\n")
print("current status: {}".format(result))
print("== Iteration {} ==".format(alg.iteration))
pprint.pprint(result._asdict())
if (i + 1) % args.checkpoint_freq == 0:
print("checkpoint path: {}".format(alg.save()))