mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 10:33:16 +08:00
[rllib] (take 2) Add top-level checkpoint/restore/compute_action APIs to rllib (#868)
* add top-level checkpoint/restore api to rllib * todos
This commit is contained in:
committed by
Robert Nishihara
parent
e6de744ef4
commit
46641a642f
@@ -117,11 +117,11 @@ class A3C(Algorithm):
|
||||
gradient_list.extend(
|
||||
[self.agents[info["id"]].compute_gradient.remote(
|
||||
self.parameters)])
|
||||
res = self.fetch_metrics_from_workers()
|
||||
res = self._fetch_metrics_from_workers()
|
||||
self.iteration += 1
|
||||
return res
|
||||
|
||||
def fetch_metrics_from_workers(self):
|
||||
def _fetch_metrics_from_workers(self):
|
||||
episode_rewards = []
|
||||
episode_lengths = []
|
||||
metric_lists = [
|
||||
@@ -134,5 +134,11 @@ class A3C(Algorithm):
|
||||
avg_length = np.mean(episode_lengths) if episode_lengths else None
|
||||
res = TrainingResult(
|
||||
self.experiment_id.hex, self.iteration,
|
||||
avg_reward, avg_length, dict())
|
||||
avg_reward, avg_length, None, dict())
|
||||
return res
|
||||
|
||||
def restore(self, checkpoint_path):
|
||||
raise NotImplementedError # TODO(ekl)
|
||||
|
||||
def compute_action(self, observation):
|
||||
raise NotImplementedError # TODO(ekl)
|
||||
|
||||
@@ -50,6 +50,7 @@ TrainingResult = namedtuple("TrainingResult", [
|
||||
"training_iteration",
|
||||
"episode_reward_mean",
|
||||
"episode_len_mean",
|
||||
"latest_checkpoint",
|
||||
"info"
|
||||
])
|
||||
|
||||
@@ -108,3 +109,17 @@ class Algorithm(object):
|
||||
"""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
def restore(self, checkpoint_path):
|
||||
"""Restores training state from a given checkpoint.
|
||||
|
||||
These checkpoints are returned from calls to train() in the
|
||||
latest_checkpoint field of TrainingResult.
|
||||
"""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
def compute_action(self, observation):
|
||||
"""Computes an action using the current trained policy."""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -235,6 +235,12 @@ class DQN(Algorithm):
|
||||
|
||||
res = TrainingResult(
|
||||
self.experiment_id.hex, self.num_iterations, mean_100ep_reward,
|
||||
mean_100ep_length, info)
|
||||
mean_100ep_length, None, info)
|
||||
self.num_iterations += 1
|
||||
return res
|
||||
|
||||
def restore(self, checkpoint_path):
|
||||
raise NotImplementedError # TODO(ekl)
|
||||
|
||||
def compute_action(self, observation):
|
||||
raise NotImplementedError # TODO(ekl)
|
||||
|
||||
@@ -327,8 +327,14 @@ class EvolutionStrategies(Algorithm):
|
||||
"time_elapsed": step_tend - self.tstart
|
||||
}
|
||||
res = TrainingResult(self.experiment_id.hex, self.iteration,
|
||||
returns_n2.mean(), lengths_n2.mean(), info)
|
||||
returns_n2.mean(), lengths_n2.mean(), None, info)
|
||||
|
||||
self.iteration += 1
|
||||
|
||||
return res
|
||||
|
||||
def restore(self, checkpoint_path):
|
||||
raise NotImplementedError # TODO(ekl)
|
||||
|
||||
def compute_action(self, observation):
|
||||
raise NotImplementedError # TODO(ekl)
|
||||
|
||||
@@ -72,7 +72,7 @@ DEFAULT_CONFIG = {
|
||||
# If True, we write checkpoints and tensorflow logging
|
||||
"write_logs": True,
|
||||
# Name of the model checkpoint file
|
||||
"model_checkpoint_file": "iteration-%s.ckpt"}
|
||||
"model_checkpoint_file": "checkpoint"}
|
||||
|
||||
|
||||
class PolicyGradient(Algorithm):
|
||||
@@ -90,6 +90,14 @@ class PolicyGradient(Algorithm):
|
||||
self.env_name, 1, self.config, self.logdir, True)
|
||||
for _ in range(config["num_agents"])]
|
||||
self.start_time = time.time()
|
||||
# TF does not support writing logs to S3 at the moment
|
||||
write_tf_logs = config["write_logs"] and self.logdir.startswith("file")
|
||||
if write_tf_logs:
|
||||
self.file_writer = tf.summary.FileWriter(
|
||||
self.logdir, self.model.sess.graph)
|
||||
else:
|
||||
self.file_writer = None
|
||||
self.saver = tf.train.Saver(max_to_keep=None)
|
||||
|
||||
def train(self):
|
||||
agents = self.agents
|
||||
@@ -100,22 +108,7 @@ class PolicyGradient(Algorithm):
|
||||
|
||||
print("===> iteration", self.j)
|
||||
|
||||
saver = tf.train.Saver(max_to_keep=None)
|
||||
if "load_checkpoint" in config:
|
||||
saver.restore(model.sess, config["load_checkpoint"])
|
||||
|
||||
# TF does not support to write logs to S3 at the moment
|
||||
write_tf_logs = config["write_logs"] and self.logdir.startswith("file")
|
||||
iter_start = time.time()
|
||||
if write_tf_logs:
|
||||
file_writer = tf.summary.FileWriter(self.logdir, model.sess.graph)
|
||||
if config["model_checkpoint_file"]:
|
||||
checkpoint_path = saver.save(
|
||||
model.sess,
|
||||
os.path.join(
|
||||
self.logdir, config["model_checkpoint_file"] % j))
|
||||
print("Checkpoint saved in file: %s" % checkpoint_path)
|
||||
checkpointing_end = time.time()
|
||||
weights = ray.put(model.get_weights())
|
||||
[a.load_weights.remote(weights) for a in agents]
|
||||
trajectory, total_reward, traj_len_mean = collect_samples(
|
||||
@@ -123,7 +116,7 @@ class PolicyGradient(Algorithm):
|
||||
print("total reward is ", total_reward)
|
||||
print("trajectory length mean is ", traj_len_mean)
|
||||
print("timesteps:", trajectory["dones"].shape[0])
|
||||
if write_tf_logs:
|
||||
if self.file_writer:
|
||||
traj_stats = tf.Summary(value=[
|
||||
tf.Summary.Value(
|
||||
tag="policy_gradient/rollouts/mean_reward",
|
||||
@@ -131,7 +124,7 @@ class PolicyGradient(Algorithm):
|
||||
tf.Summary.Value(
|
||||
tag="policy_gradient/rollouts/traj_len_mean",
|
||||
simple_value=traj_len_mean)])
|
||||
file_writer.add_summary(traj_stats, self.global_step)
|
||||
self.file_writer.add_summary(traj_stats, self.global_step)
|
||||
self.global_step += 1
|
||||
|
||||
def standardized(value):
|
||||
@@ -155,8 +148,7 @@ class PolicyGradient(Algorithm):
|
||||
tuples_per_device = model.load_data(
|
||||
trajectory, j == 0 and config["full_trace_data_load"])
|
||||
load_end = time.time()
|
||||
checkpointing_time = checkpointing_end - iter_start
|
||||
rollouts_time = rollouts_end - checkpointing_end
|
||||
rollouts_time = rollouts_end - iter_start
|
||||
shuffle_time = shuffle_end - rollouts_end
|
||||
load_time = load_end - shuffle_end
|
||||
sgd_time = 0
|
||||
@@ -178,7 +170,7 @@ class PolicyGradient(Algorithm):
|
||||
batch_entropy = model.run_sgd_minibatch(
|
||||
permutation[batch_index] * model.per_device_batch_size,
|
||||
self.kl_coeff, full_trace,
|
||||
file_writer if write_tf_logs else None)
|
||||
self.file_writer)
|
||||
loss.append(batch_loss)
|
||||
policy_loss.append(batch_policy_loss)
|
||||
vf_loss.append(batch_vf_loss)
|
||||
@@ -201,21 +193,19 @@ class PolicyGradient(Algorithm):
|
||||
values.append(tf.Summary.Value(
|
||||
tag=metric_prefix + "kl_coeff",
|
||||
simple_value=self.kl_coeff))
|
||||
else:
|
||||
metric_prefix = "policy_gradient/sgd/intermediate_iters/"
|
||||
values.extend([
|
||||
tf.Summary.Value(
|
||||
tag=metric_prefix + "mean_entropy",
|
||||
simple_value=entropy),
|
||||
tf.Summary.Value(
|
||||
tag=metric_prefix + "mean_loss",
|
||||
simple_value=loss),
|
||||
tf.Summary.Value(
|
||||
tag=metric_prefix + "mean_kl",
|
||||
simple_value=kl)])
|
||||
if write_tf_logs:
|
||||
sgd_stats = tf.Summary(value=values)
|
||||
file_writer.add_summary(sgd_stats, self.global_step)
|
||||
values.extend([
|
||||
tf.Summary.Value(
|
||||
tag=metric_prefix + "mean_entropy",
|
||||
simple_value=entropy),
|
||||
tf.Summary.Value(
|
||||
tag=metric_prefix + "mean_loss",
|
||||
simple_value=loss),
|
||||
tf.Summary.Value(
|
||||
tag=metric_prefix + "mean_kl",
|
||||
simple_value=kl)])
|
||||
if self.file_writer:
|
||||
sgd_stats = tf.Summary(value=values)
|
||||
self.file_writer.add_summary(sgd_stats, self.global_step)
|
||||
self.global_step += 1
|
||||
sgd_time += sgd_end - sgd_start
|
||||
if kl > 2.0 * config["kl_target"]:
|
||||
@@ -223,6 +213,17 @@ class PolicyGradient(Algorithm):
|
||||
elif kl < 0.5 * config["kl_target"]:
|
||||
self.kl_coeff *= 0.5
|
||||
|
||||
checkpointing_start = time.time()
|
||||
if config["model_checkpoint_file"]:
|
||||
checkpoint_path = self.saver.save(
|
||||
model.sess,
|
||||
os.path.join(
|
||||
self.logdir, config["model_checkpoint_file"]),
|
||||
global_step=j)
|
||||
else:
|
||||
checkpoint_path = None
|
||||
checkpointing_time = time.time() - checkpointing_start
|
||||
|
||||
info = {
|
||||
"kl_divergence": kl,
|
||||
"kl_coefficient": self.kl_coeff,
|
||||
@@ -245,6 +246,14 @@ class PolicyGradient(Algorithm):
|
||||
print("total time so far:", time.time() - self.start_time)
|
||||
|
||||
result = TrainingResult(
|
||||
self.experiment_id.hex, j, total_reward, traj_len_mean, info)
|
||||
self.experiment_id.hex, j, total_reward, traj_len_mean,
|
||||
checkpoint_path, info)
|
||||
|
||||
return result
|
||||
|
||||
def restore(self, checkpoint_path):
|
||||
self.saver.restore(self.model.sess, checkpoint_path)
|
||||
|
||||
def compute_action(self, observation):
|
||||
return self.model.common_policy.compute_actions(
|
||||
observation[np.newaxis, :])[0][0]
|
||||
|
||||
Reference in New Issue
Block a user