[rllib] (take 2) Add top-level checkpoint/restore/compute_action APIs to rllib (#868)

* add top-level checkpoint/restore api to rllib

* todos
This commit is contained in:
Eric Liang
2017-08-24 00:09:33 -07:00
committed by Robert Nishihara
parent e6de744ef4
commit 46641a642f
5 changed files with 84 additions and 42 deletions
+9 -3
View File
@@ -117,11 +117,11 @@ class A3C(Algorithm):
gradient_list.extend(
[self.agents[info["id"]].compute_gradient.remote(
self.parameters)])
res = self.fetch_metrics_from_workers()
res = self._fetch_metrics_from_workers()
self.iteration += 1
return res
def fetch_metrics_from_workers(self):
def _fetch_metrics_from_workers(self):
episode_rewards = []
episode_lengths = []
metric_lists = [
@@ -134,5 +134,11 @@ class A3C(Algorithm):
avg_length = np.mean(episode_lengths) if episode_lengths else None
res = TrainingResult(
self.experiment_id.hex, self.iteration,
avg_reward, avg_length, dict())
avg_reward, avg_length, None, dict())
return res
def restore(self, checkpoint_path):
    """Restore training state from a checkpoint (not yet supported by A3C).

    Args:
        checkpoint_path: Path of the checkpoint to restore from. It is
            currently unused because restoring is not implemented.

    Raises:
        NotImplementedError: Always.
    """
    # TODO(ekl)
    raise NotImplementedError("A3C.restore() is not implemented yet.")
def compute_action(self, observation):
    """Compute an action for an observation (not yet supported by A3C).

    Args:
        observation: Observation to act on. It is currently unused because
            action computation is not implemented.

    Raises:
        NotImplementedError: Always.
    """
    # TODO(ekl)
    raise NotImplementedError("A3C.compute_action() is not implemented yet.")
+15
View File
@@ -50,6 +50,7 @@ TrainingResult = namedtuple("TrainingResult", [
"training_iteration",
"episode_reward_mean",
"episode_len_mean",
"latest_checkpoint",
"info"
])
@@ -108,3 +109,17 @@ class Algorithm(object):
"""
raise NotImplementedError
def restore(self, checkpoint_path):
    """Restores training state from a given checkpoint.

    These checkpoints are returned from calls to train() in the
    latest_checkpoint field of TrainingResult.

    Args:
        checkpoint_path: Checkpoint to restore from, as reported in a
            TrainingResult.

    Raises:
        NotImplementedError: Always in this base implementation;
            subclasses are expected to override it.
    """
    raise NotImplementedError
def compute_action(self, observation):
    """Computes an action using the current trained policy.

    Args:
        observation: The observation to compute an action for.

    Raises:
        NotImplementedError: Always in this base implementation.
    """
    raise NotImplementedError
+7 -1
View File
@@ -235,6 +235,12 @@ class DQN(Algorithm):
res = TrainingResult(
self.experiment_id.hex, self.num_iterations, mean_100ep_reward,
mean_100ep_length, info)
mean_100ep_length, None, info)
self.num_iterations += 1
return res
def restore(self, checkpoint_path):
    """Restore training state from a checkpoint (not yet supported by DQN).

    Args:
        checkpoint_path: Path of the checkpoint to restore from. It is
            currently unused because restoring is not implemented.

    Raises:
        NotImplementedError: Always.
    """
    # TODO(ekl)
    raise NotImplementedError("DQN.restore() is not implemented yet.")
def compute_action(self, observation):
    """Compute an action for an observation (not yet supported by DQN).

    Args:
        observation: Observation to act on. It is currently unused because
            action computation is not implemented.

    Raises:
        NotImplementedError: Always.
    """
    # TODO(ekl)
    raise NotImplementedError("DQN.compute_action() is not implemented yet.")
@@ -327,8 +327,14 @@ class EvolutionStrategies(Algorithm):
"time_elapsed": step_tend - self.tstart
}
res = TrainingResult(self.experiment_id.hex, self.iteration,
returns_n2.mean(), lengths_n2.mean(), info)
returns_n2.mean(), lengths_n2.mean(), None, info)
self.iteration += 1
return res
def restore(self, checkpoint_path):
    """Restore training state from a checkpoint (not yet supported here).

    This is the EvolutionStrategies placeholder for the restore API.

    Args:
        checkpoint_path: Path of the checkpoint to restore from. It is
            currently unused because restoring is not implemented.

    Raises:
        NotImplementedError: Always.
    """
    # TODO(ekl)
    raise NotImplementedError(
        "EvolutionStrategies.restore() is not implemented yet.")
def compute_action(self, observation):
    """Compute an action for an observation (not yet supported here).

    This is the EvolutionStrategies placeholder for the compute_action API.

    Args:
        observation: Observation to act on. It is currently unused because
            action computation is not implemented.

    Raises:
        NotImplementedError: Always.
    """
    # TODO(ekl)
    raise NotImplementedError(
        "EvolutionStrategies.compute_action() is not implemented yet.")
@@ -72,7 +72,7 @@ DEFAULT_CONFIG = {
# If True, we write checkpoints and tensorflow logging
"write_logs": True,
# Name of the model checkpoint file
"model_checkpoint_file": "iteration-%s.ckpt"}
"model_checkpoint_file": "checkpoint"}
class PolicyGradient(Algorithm):
@@ -90,6 +90,14 @@ class PolicyGradient(Algorithm):
self.env_name, 1, self.config, self.logdir, True)
for _ in range(config["num_agents"])]
self.start_time = time.time()
# TF does not support writing logs to S3 at the moment
write_tf_logs = config["write_logs"] and self.logdir.startswith("file")
if write_tf_logs:
self.file_writer = tf.summary.FileWriter(
self.logdir, self.model.sess.graph)
else:
self.file_writer = None
self.saver = tf.train.Saver(max_to_keep=None)
def train(self):
agents = self.agents
@@ -100,22 +108,7 @@ class PolicyGradient(Algorithm):
print("===> iteration", self.j)
saver = tf.train.Saver(max_to_keep=None)
if "load_checkpoint" in config:
saver.restore(model.sess, config["load_checkpoint"])
# TF does not support writing logs to S3 at the moment
write_tf_logs = config["write_logs"] and self.logdir.startswith("file")
iter_start = time.time()
if write_tf_logs:
file_writer = tf.summary.FileWriter(self.logdir, model.sess.graph)
if config["model_checkpoint_file"]:
checkpoint_path = saver.save(
model.sess,
os.path.join(
self.logdir, config["model_checkpoint_file"] % j))
print("Checkpoint saved in file: %s" % checkpoint_path)
checkpointing_end = time.time()
weights = ray.put(model.get_weights())
[a.load_weights.remote(weights) for a in agents]
trajectory, total_reward, traj_len_mean = collect_samples(
@@ -123,7 +116,7 @@ class PolicyGradient(Algorithm):
print("total reward is ", total_reward)
print("trajectory length mean is ", traj_len_mean)
print("timesteps:", trajectory["dones"].shape[0])
if write_tf_logs:
if self.file_writer:
traj_stats = tf.Summary(value=[
tf.Summary.Value(
tag="policy_gradient/rollouts/mean_reward",
@@ -131,7 +124,7 @@ class PolicyGradient(Algorithm):
tf.Summary.Value(
tag="policy_gradient/rollouts/traj_len_mean",
simple_value=traj_len_mean)])
file_writer.add_summary(traj_stats, self.global_step)
self.file_writer.add_summary(traj_stats, self.global_step)
self.global_step += 1
def standardized(value):
@@ -155,8 +148,7 @@ class PolicyGradient(Algorithm):
tuples_per_device = model.load_data(
trajectory, j == 0 and config["full_trace_data_load"])
load_end = time.time()
checkpointing_time = checkpointing_end - iter_start
rollouts_time = rollouts_end - checkpointing_end
rollouts_time = rollouts_end - iter_start
shuffle_time = shuffle_end - rollouts_end
load_time = load_end - shuffle_end
sgd_time = 0
@@ -178,7 +170,7 @@ class PolicyGradient(Algorithm):
batch_entropy = model.run_sgd_minibatch(
permutation[batch_index] * model.per_device_batch_size,
self.kl_coeff, full_trace,
file_writer if write_tf_logs else None)
self.file_writer)
loss.append(batch_loss)
policy_loss.append(batch_policy_loss)
vf_loss.append(batch_vf_loss)
@@ -201,21 +193,19 @@ class PolicyGradient(Algorithm):
values.append(tf.Summary.Value(
tag=metric_prefix + "kl_coeff",
simple_value=self.kl_coeff))
else:
metric_prefix = "policy_gradient/sgd/intermediate_iters/"
values.extend([
tf.Summary.Value(
tag=metric_prefix + "mean_entropy",
simple_value=entropy),
tf.Summary.Value(
tag=metric_prefix + "mean_loss",
simple_value=loss),
tf.Summary.Value(
tag=metric_prefix + "mean_kl",
simple_value=kl)])
if write_tf_logs:
sgd_stats = tf.Summary(value=values)
file_writer.add_summary(sgd_stats, self.global_step)
values.extend([
tf.Summary.Value(
tag=metric_prefix + "mean_entropy",
simple_value=entropy),
tf.Summary.Value(
tag=metric_prefix + "mean_loss",
simple_value=loss),
tf.Summary.Value(
tag=metric_prefix + "mean_kl",
simple_value=kl)])
if self.file_writer:
sgd_stats = tf.Summary(value=values)
self.file_writer.add_summary(sgd_stats, self.global_step)
self.global_step += 1
sgd_time += sgd_end - sgd_start
if kl > 2.0 * config["kl_target"]:
@@ -223,6 +213,17 @@ class PolicyGradient(Algorithm):
elif kl < 0.5 * config["kl_target"]:
self.kl_coeff *= 0.5
checkpointing_start = time.time()
if config["model_checkpoint_file"]:
checkpoint_path = self.saver.save(
model.sess,
os.path.join(
self.logdir, config["model_checkpoint_file"]),
global_step=j)
else:
checkpoint_path = None
checkpointing_time = time.time() - checkpointing_start
info = {
"kl_divergence": kl,
"kl_coefficient": self.kl_coeff,
@@ -245,6 +246,14 @@ class PolicyGradient(Algorithm):
print("total time so far:", time.time() - self.start_time)
result = TrainingResult(
self.experiment_id.hex, j, total_reward, traj_len_mean, info)
self.experiment_id.hex, j, total_reward, traj_len_mean,
checkpoint_path, info)
return result
def restore(self, checkpoint_path):
    """Load a saved checkpoint into the model's session.

    Args:
        checkpoint_path: Checkpoint location handed to the saver's
            ``restore`` call.
    """
    # NOTE(review): only the saver's restore is invoked here; attributes
    # kept directly on this object are not modified by this method.
    session = self.model.sess
    self.saver.restore(session, checkpoint_path)
def compute_action(self, observation):
    """Compute one action for a single observation with the current policy.

    The observation is wrapped into a batch of size one, passed to
    ``self.model.common_policy.compute_actions``, and element ``[0][0]`` of
    the result is returned.

    Args:
        observation: A single observation. Anything ``np.asarray`` accepts
            (ndarray, nested list, scalar) is allowed.

    Returns:
        Element ``[0][0]`` of the ``compute_actions`` result.
        NOTE(review): presumably the action for this observation --
        confirm against the policy's return layout.
    """
    # asarray + Ellipsis leaves ndarray inputs unchanged while also accepting
    # lists and 0-d observations, which `observation[np.newaxis, :]` rejects.
    batch = np.asarray(observation)[np.newaxis, ...]
    return self.model.common_policy.compute_actions(batch)[0][0]