mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 21:38:18 +08:00
[rllib] Make sure to always record stats like time elapsed, timesteps (#965)
* always record training stats * fix * comments * revert assert * nan * fix
This commit is contained in:
committed by
Philipp Moritz
parent
74ac80631b
commit
9f42ef6a4f
+17
-15
@@ -100,9 +100,8 @@ class A3CAgent(Agent):
|
||||
config["batch_size"], self.logdir)
|
||||
for i in range(config["num_workers"])]
|
||||
self.parameters = self.policy.get_weights()
|
||||
self.iteration = 0
|
||||
|
||||
def train(self):
|
||||
def _train(self):
|
||||
gradient_list = [
|
||||
agent.compute_gradient.remote(self.parameters)
|
||||
for agent in self.agents]
|
||||
@@ -119,7 +118,6 @@ class A3CAgent(Agent):
|
||||
[self.agents[info["id"]].compute_gradient.remote(
|
||||
self.parameters)])
|
||||
res = self._fetch_metrics_from_workers()
|
||||
self.iteration += 1
|
||||
return res
|
||||
|
||||
def _fetch_metrics_from_workers(self):
|
||||
@@ -131,27 +129,31 @@ class A3CAgent(Agent):
|
||||
for episode in ray.get(metrics):
|
||||
episode_lengths.append(episode.episode_length)
|
||||
episode_rewards.append(episode.episode_reward)
|
||||
avg_reward = np.mean(episode_rewards) if episode_rewards else None
|
||||
avg_length = np.mean(episode_lengths) if episode_lengths else None
|
||||
res = TrainingResult(
|
||||
self.experiment_id.hex, self.iteration,
|
||||
avg_reward, avg_length, dict())
|
||||
return res
|
||||
avg_reward = (
|
||||
np.mean(episode_rewards) if episode_rewards else float('nan'))
|
||||
avg_length = (
|
||||
np.mean(episode_lengths) if episode_lengths else float('nan'))
|
||||
timesteps = np.sum(episode_lengths) if episode_lengths else 0
|
||||
|
||||
def save(self):
|
||||
result = TrainingResult(
|
||||
episode_reward_mean=avg_reward,
|
||||
episode_len_mean=avg_length,
|
||||
timesteps_this_iter=timesteps,
|
||||
info={})
|
||||
|
||||
return result
|
||||
|
||||
def _save(self):
|
||||
checkpoint_path = os.path.join(
|
||||
self.logdir, "checkpoint-{}".format(self.iteration))
|
||||
objects = [
|
||||
self.parameters,
|
||||
self.iteration]
|
||||
objects = [self.parameters]
|
||||
pickle.dump(objects, open(checkpoint_path, "wb"))
|
||||
return checkpoint_path
|
||||
|
||||
def restore(self, checkpoint_path):
|
||||
def _restore(self, checkpoint_path):
|
||||
objects = pickle.load(open(checkpoint_path, "rb"))
|
||||
self.parameters = objects[0]
|
||||
self.policy.set_weights(self.parameters)
|
||||
self.iteration = objects[1]
|
||||
|
||||
def compute_action(self, observation):
|
||||
actions = self.policy.compute_actions(observation)[0]
|
||||
|
||||
@@ -4,8 +4,10 @@ import json
|
||||
import logging
|
||||
import numpy as np
|
||||
import os
|
||||
import pickle
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
import uuid
|
||||
import smart_open
|
||||
|
||||
@@ -46,13 +48,37 @@ class RLLibLogger(object):
|
||||
|
||||
|
||||
TrainingResult = namedtuple("TrainingResult", [
|
||||
# Unique string identifier for this experiment. This id is preserved
|
||||
# across checkpoint / restore calls.
|
||||
"experiment_id",
|
||||
|
||||
# The index of this training iteration, e.g. call to train().
|
||||
"training_iteration",
|
||||
|
||||
# The mean episode reward reported during this iteration.
|
||||
"episode_reward_mean",
|
||||
|
||||
# The mean episode length reported during this iteration.
|
||||
"episode_len_mean",
|
||||
"info"
|
||||
|
||||
# Agent-specific metadata to report for this iteration.
|
||||
"info",
|
||||
|
||||
# Number of timesteps in the simulator in this iteration.
|
||||
"timesteps_this_iter",
|
||||
|
||||
# Accumulated timesteps for this entire experiment.
|
||||
"timesteps_total",
|
||||
|
||||
# Time in seconds this iteration took to run.
|
||||
"time_this_iter_s",
|
||||
|
||||
# Accumulated time in seconds for this entire experiment.
|
||||
"time_total_s",
|
||||
])
|
||||
|
||||
TrainingResult.__new__.__defaults__ = (None,) * len(TrainingResult._fields)
|
||||
|
||||
|
||||
class Agent(object):
|
||||
"""All RLlib agents extend this base class.
|
||||
@@ -64,8 +90,6 @@ class Agent(object):
|
||||
env_name (str): Name of the OpenAI gym environment to train against.
|
||||
config (obj): Algorithm-specific configuration data.
|
||||
logdir (str): Directory in which training outputs should be placed.
|
||||
|
||||
TODO(ekl): support checkpoint / restore of training state.
|
||||
"""
|
||||
|
||||
def __init__(self, env_name, config, upload_dir=None):
|
||||
@@ -79,10 +103,11 @@ class Agent(object):
|
||||
like s3://bucketname/.
|
||||
"""
|
||||
upload_dir = "file:///tmp/ray" if upload_dir is None else upload_dir
|
||||
self.experiment_id = uuid.uuid4()
|
||||
self.experiment_id = uuid.uuid4().hex
|
||||
self.env_name = env_name
|
||||
|
||||
self.config = config
|
||||
self.config.update({"experiment_id": self.experiment_id.hex})
|
||||
self.config.update({"experiment_id": self.experiment_id})
|
||||
self.config.update({"env_name": env_name})
|
||||
prefix = "{}_{}_{}".format(
|
||||
env_name,
|
||||
@@ -92,6 +117,8 @@ class Agent(object):
|
||||
self.logdir = tempfile.mkdtemp(prefix=prefix, dir="/tmp/ray")
|
||||
else:
|
||||
self.logdir = os.path.join(upload_dir, prefix)
|
||||
|
||||
# TODO(ekl) consider inlining config into the result jsons
|
||||
log_path = os.path.join(self.logdir, "config.json")
|
||||
with smart_open.smart_open(log_path, "w") as f:
|
||||
json.dump(self.config, f, sort_keys=True, cls=RLLibEncoder)
|
||||
@@ -99,6 +126,10 @@ class Agent(object):
|
||||
"%s algorithm created with logdir '%s'",
|
||||
self.__class__.__name__, self.logdir)
|
||||
|
||||
self.iteration = 0
|
||||
self.time_total = 0.0
|
||||
self.timesteps_total = 0
|
||||
|
||||
def train(self):
|
||||
"""Runs one logical iteration of training.
|
||||
|
||||
@@ -106,7 +137,25 @@ class Agent(object):
|
||||
A TrainingResult that describes training progress.
|
||||
"""
|
||||
|
||||
raise NotImplementedError
|
||||
start = time.time()
|
||||
self.iteration += 1
|
||||
result = self._train()
|
||||
time_this_iter = time.time() - start
|
||||
|
||||
self.time_total += time_this_iter
|
||||
self.timesteps_total += result.timesteps_this_iter
|
||||
|
||||
result = result._replace(
|
||||
experiment_id=self.experiment_id,
|
||||
training_iteration=self.iteration,
|
||||
timesteps_total=self.timesteps_total,
|
||||
time_this_iter_s=time_this_iter,
|
||||
time_total_s=self.time_total)
|
||||
|
||||
for field in result:
|
||||
assert field is not None, result
|
||||
|
||||
return result
|
||||
|
||||
def save(self):
|
||||
"""Saves the current model state to a checkpoint.
|
||||
@@ -115,7 +164,12 @@ class Agent(object):
|
||||
Checkpoint path that may be passed to restore().
|
||||
"""
|
||||
|
||||
raise NotImplementedError
|
||||
checkpoint_path = self._save()
|
||||
pickle.dump(
|
||||
[self.experiment_id, self.iteration, self.timesteps_total,
|
||||
self.time_total_s],
|
||||
open(checkpoint_path + ".rllib_metadata", "wb"))
|
||||
return checkpoint_path
|
||||
|
||||
def restore(self, checkpoint_path):
|
||||
"""Restores training state from a given model checkpoint.
|
||||
@@ -123,9 +177,29 @@ class Agent(object):
|
||||
These checkpoints are returned from calls to save().
|
||||
"""
|
||||
|
||||
raise NotImplementedError
|
||||
self._restore(checkpoint_path)
|
||||
metadata = pickle.load(open(checkpoint_path + ".rllib_metadata", "rb"))
|
||||
self.experiment_id = metadata[0]
|
||||
self.iteration = metadata[1]
|
||||
self.timesteps_total = metadata[2]
|
||||
self.time_total_s = metadata[3]
|
||||
|
||||
def compute_action(self, observation):
|
||||
"""Computes an action using the current trained policy."""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
def _train(self):
|
||||
"""Subclasses should override this to implement train()."""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
def _save(self):
|
||||
"""Subclasses should override this to implement save()."""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
def _restore(self):
|
||||
"""Subclasses should override this to implement restore()."""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -164,9 +164,10 @@ class DQNAgent(Agent):
|
||||
self.file_writer = tf.summary.FileWriter(self.logdir, self.sess.graph)
|
||||
self.saver = tf.train.Saver(max_to_keep=None)
|
||||
|
||||
def train(self):
|
||||
def _train(self):
|
||||
config = self.config
|
||||
sample_time, learn_time = 0, 0
|
||||
iter_init_timesteps = self.num_timesteps
|
||||
|
||||
for _ in range(config["timesteps_per_iteration"]):
|
||||
self.num_timesteps += 1
|
||||
@@ -243,13 +244,15 @@ class DQNAgent(Agent):
|
||||
int(100 * self.exploration.value(self.num_timesteps)))
|
||||
logger.dump_tabular()
|
||||
|
||||
res = TrainingResult(
|
||||
self.experiment_id.hex, self.num_iterations, mean_100ep_reward,
|
||||
mean_100ep_length, info)
|
||||
self.num_iterations += 1
|
||||
return res
|
||||
result = TrainingResult(
|
||||
episode_reward_mean=mean_100ep_reward,
|
||||
episode_len_mean=mean_100ep_length,
|
||||
timesteps_this_iter=self.num_timesteps - iter_init_timesteps,
|
||||
info=info)
|
||||
|
||||
def save(self):
|
||||
return result
|
||||
|
||||
def _save(self):
|
||||
checkpoint_path = self.saver.save(
|
||||
self.sess,
|
||||
os.path.join(self.logdir, "checkpoint"),
|
||||
@@ -267,7 +270,7 @@ class DQNAgent(Agent):
|
||||
pickle.dump(extra_data, open(checkpoint_path + ".extra_data", "wb"))
|
||||
return checkpoint_path
|
||||
|
||||
def restore(self, checkpoint_path):
|
||||
def _restore(self, checkpoint_path):
|
||||
self.saver.restore(self.sess, checkpoint_path)
|
||||
extra_data = pickle.load(open(checkpoint_path + ".extra_data", "rb"))
|
||||
self.replay_buffer = extra_data[0]
|
||||
|
||||
+10
-19
@@ -203,7 +203,6 @@ class ESAgent(Agent):
|
||||
self.episodes_so_far = 0
|
||||
self.timesteps_so_far = 0
|
||||
self.tstart = time.time()
|
||||
self.iteration = 0
|
||||
|
||||
def _collect_results(self, theta_id, min_eps, min_timesteps):
|
||||
num_eps, num_timesteps = 0, 0
|
||||
@@ -224,7 +223,7 @@ class ESAgent(Agent):
|
||||
num_timesteps += result.lengths_n2.sum()
|
||||
return results
|
||||
|
||||
def train(self):
|
||||
def _train(self):
|
||||
config = self.config
|
||||
|
||||
step_tstart = time.time()
|
||||
@@ -314,14 +313,6 @@ class ESAgent(Agent):
|
||||
tlogger.record_tabular("TimeElapsed", step_tend - self.tstart)
|
||||
tlogger.dump_tabular()
|
||||
|
||||
if (config["snapshot_freq"] != 0 and
|
||||
self.iteration % config["snapshot_freq"] == 0):
|
||||
filename = os.path.join(
|
||||
self.logdir, "snapshot_iter{:05d}.h5".format(self.iteration))
|
||||
assert not os.path.exists(filename)
|
||||
self.policy.save(filename)
|
||||
tlogger.log("Saved snapshot {}".format(filename))
|
||||
|
||||
info = {
|
||||
"weights_norm": np.square(self.policy.get_trainable_flat()).sum(),
|
||||
"grad_norm": np.square(g).sum(),
|
||||
@@ -334,14 +325,16 @@ class ESAgent(Agent):
|
||||
"time_elapsed_this_iter": step_tend - step_tstart,
|
||||
"time_elapsed": step_tend - self.tstart
|
||||
}
|
||||
res = TrainingResult(self.experiment_id.hex, self.iteration,
|
||||
returns_n2.mean(), lengths_n2.mean(), info)
|
||||
|
||||
self.iteration += 1
|
||||
result = TrainingResult(
|
||||
episode_reward_mean=returns_n2.mean(),
|
||||
episode_len_mean=lengths_n2.mean(),
|
||||
timesteps_this_iter=lengths_n2.sum(),
|
||||
info=info)
|
||||
|
||||
return res
|
||||
return result
|
||||
|
||||
def save(self):
|
||||
def _save(self):
|
||||
checkpoint_path = os.path.join(
|
||||
self.logdir, "checkpoint-{}".format(self.iteration))
|
||||
weights = self.policy.get_trainable_flat()
|
||||
@@ -349,18 +342,16 @@ class ESAgent(Agent):
|
||||
weights,
|
||||
self.ob_stat,
|
||||
self.episodes_so_far,
|
||||
self.timesteps_so_far,
|
||||
self.iteration]
|
||||
self.timesteps_so_far]
|
||||
pickle.dump(objects, open(checkpoint_path, "wb"))
|
||||
return checkpoint_path
|
||||
|
||||
def restore(self, checkpoint_path):
|
||||
def _restore(self, checkpoint_path):
|
||||
objects = pickle.load(open(checkpoint_path, "rb"))
|
||||
self.policy.set_trainable_flat(objects[0])
|
||||
self.ob_stat = objects[1]
|
||||
self.episodes_so_far = objects[2]
|
||||
self.timesteps_so_far = objects[3]
|
||||
self.iteration = objects[4]
|
||||
|
||||
def compute_action(self, observation):
|
||||
return self.policy.act([observation])[0]
|
||||
|
||||
+14
-16
@@ -89,7 +89,6 @@ class PPOAgent(Agent):
|
||||
|
||||
def _init(self):
|
||||
self.global_step = 0
|
||||
self.j = 0
|
||||
self.kl_coeff = self.config["kl_coeff"]
|
||||
self.model = Runner(self.env_name, 1, self.config, self.logdir, False)
|
||||
self.agents = [
|
||||
@@ -104,14 +103,12 @@ class PPOAgent(Agent):
|
||||
self.file_writer = None
|
||||
self.saver = tf.train.Saver(max_to_keep=None)
|
||||
|
||||
def train(self):
|
||||
def _train(self):
|
||||
agents = self.agents
|
||||
config = self.config
|
||||
model = self.model
|
||||
j = self.j
|
||||
self.j += 1
|
||||
|
||||
print("===> iteration", self.j)
|
||||
print("===> iteration", self.iteration)
|
||||
|
||||
iter_start = time.time()
|
||||
weights = ray.put(model.get_weights())
|
||||
@@ -151,7 +148,7 @@ class PPOAgent(Agent):
|
||||
trajectory = shuffle(trajectory)
|
||||
shuffle_end = time.time()
|
||||
tuples_per_device = model.load_data(
|
||||
trajectory, j == 0 and config["full_trace_data_load"])
|
||||
trajectory, self.iteration == 0 and config["full_trace_data_load"])
|
||||
load_end = time.time()
|
||||
rollouts_time = rollouts_end - iter_start
|
||||
shuffle_time = shuffle_end - rollouts_end
|
||||
@@ -165,11 +162,11 @@ class PPOAgent(Agent):
|
||||
loss, policy_loss, vf_loss, kl, entropy = [], [], [], [], []
|
||||
permutation = np.random.permutation(num_batches)
|
||||
# Prepare to drop into the debugger
|
||||
if j == config["tf_debug_iteration"]:
|
||||
if self.iteration == config["tf_debug_iteration"]:
|
||||
model.sess = tf_debug.LocalCLIDebugWrapperSession(model.sess)
|
||||
while batch_index < num_batches:
|
||||
full_trace = (
|
||||
i == 0 and j == 0 and
|
||||
i == 0 and self.iteration == 0 and
|
||||
batch_index == config["full_trace_nth_sgd_batch"])
|
||||
batch_loss, batch_policy_loss, batch_vf_loss, batch_kl, \
|
||||
batch_entropy = model.run_sgd_minibatch(
|
||||
@@ -238,35 +235,36 @@ class PPOAgent(Agent):
|
||||
print("total time so far:", time.time() - self.start_time)
|
||||
|
||||
result = TrainingResult(
|
||||
self.experiment_id.hex, j, total_reward, traj_len_mean, info)
|
||||
episode_reward_mean=total_reward,
|
||||
episode_len_mean=traj_len_mean,
|
||||
timesteps_this_iter=trajectory["dones"].shape[0],
|
||||
info=info)
|
||||
|
||||
return result
|
||||
|
||||
def save(self):
|
||||
def _save(self):
|
||||
checkpoint_path = self.saver.save(
|
||||
self.model.sess,
|
||||
os.path.join(self.logdir, "checkpoint"),
|
||||
global_step=self.j)
|
||||
global_step=self.iteration)
|
||||
agent_state = ray.get([a.save.remote() for a in self.agents])
|
||||
extra_data = [
|
||||
self.model.save(),
|
||||
self.global_step,
|
||||
self.j,
|
||||
self.kl_coeff,
|
||||
agent_state]
|
||||
pickle.dump(extra_data, open(checkpoint_path + ".extra_data", "wb"))
|
||||
return checkpoint_path
|
||||
|
||||
def restore(self, checkpoint_path):
|
||||
def _restore(self, checkpoint_path):
|
||||
self.saver.restore(self.model.sess, checkpoint_path)
|
||||
extra_data = pickle.load(open(checkpoint_path + ".extra_data", "rb"))
|
||||
self.model.restore(extra_data[0])
|
||||
self.global_step = extra_data[1]
|
||||
self.j = extra_data[2]
|
||||
self.kl_coeff = extra_data[3]
|
||||
self.kl_coeff = extra_data[2]
|
||||
ray.get([
|
||||
a.restore.remote(o)
|
||||
for (a, o) in zip(self.agents, extra_data[4])])
|
||||
for (a, o) in zip(self.agents, extra_data[3])])
|
||||
|
||||
def compute_action(self, observation):
|
||||
observation = self.model.observation_filter(observation)
|
||||
|
||||
@@ -7,6 +7,7 @@ from __future__ import print_function
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import pprint
|
||||
import sys
|
||||
|
||||
import ray
|
||||
@@ -89,7 +90,8 @@ if __name__ == "__main__":
|
||||
cls=ray.rllib.common.RLLibEncoder)
|
||||
result_logger.write("\n")
|
||||
|
||||
print("current status: {}".format(result))
|
||||
print("== Iteration {} ==".format(alg.iteration))
|
||||
pprint.pprint(result._asdict())
|
||||
|
||||
if (i + 1) % args.checkpoint_freq == 0:
|
||||
print("checkpoint path: {}".format(alg.save()))
|
||||
|
||||
Reference in New Issue
Block a user