diff --git a/python/ray/rllib/a3c/a3c.py b/python/ray/rllib/a3c/a3c.py index e19b75120..28acc320a 100644 --- a/python/ray/rllib/a3c/a3c.py +++ b/python/ray/rllib/a3c/a3c.py @@ -117,11 +117,11 @@ class A3C(Algorithm): gradient_list.extend( [self.agents[info["id"]].compute_gradient.remote( self.parameters)]) - res = self.fetch_metrics_from_workers() + res = self._fetch_metrics_from_workers() self.iteration += 1 return res - def fetch_metrics_from_workers(self): + def _fetch_metrics_from_workers(self): episode_rewards = [] episode_lengths = [] metric_lists = [ @@ -134,5 +134,11 @@ class A3C(Algorithm): avg_length = np.mean(episode_lengths) if episode_lengths else None res = TrainingResult( self.experiment_id.hex, self.iteration, - avg_reward, avg_length, dict()) + avg_reward, avg_length, None, dict()) return res + + def restore(self, checkpoint_path): + raise NotImplementedError # TODO(ekl) + + def compute_action(self, observation): + raise NotImplementedError # TODO(ekl) diff --git a/python/ray/rllib/common.py b/python/ray/rllib/common.py index c30a8267b..5410edcbf 100644 --- a/python/ray/rllib/common.py +++ b/python/ray/rllib/common.py @@ -50,6 +50,7 @@ TrainingResult = namedtuple("TrainingResult", [ "training_iteration", "episode_reward_mean", "episode_len_mean", + "latest_checkpoint", "info" ]) @@ -108,3 +109,17 @@ class Algorithm(object): """ raise NotImplementedError + + def restore(self, checkpoint_path): + """Restores training state from a given checkpoint. + + These checkpoints are returned from calls to train() in the + checkpoint_path field of TrainingResult. + """ + + raise NotImplementedError + + def compute_action(self, observation): + """Computes an action using the current trained policy.""" + + raise NotImplementedError diff --git a/python/ray/rllib/dqn/dqn.py b/python/ray/rllib/dqn/dqn.py index 0e2202226..b0dd7f641 100644 --- a/python/ray/rllib/dqn/dqn.py +++ b/python/ray/rllib/dqn/dqn.py @@ -235,6 +235,12 @@ class DQN(Algorithm): res = TrainingResult( self.experiment_id.hex, self.num_iterations, mean_100ep_reward, - mean_100ep_length, info) + mean_100ep_length, None, info) self.num_iterations += 1 return res + + def restore(self, checkpoint_path): + raise NotImplementedError # TODO(ekl) + + def compute_action(self, observation): + raise NotImplementedError # TODO(ekl) diff --git a/python/ray/rllib/evolution_strategies/evolution_strategies.py b/python/ray/rllib/evolution_strategies/evolution_strategies.py index 138eb00c3..0667df8d6 100644 --- a/python/ray/rllib/evolution_strategies/evolution_strategies.py +++ b/python/ray/rllib/evolution_strategies/evolution_strategies.py @@ -327,8 +327,14 @@ class EvolutionStrategies(Algorithm): "time_elapsed": step_tend - self.tstart } res = TrainingResult(self.experiment_id.hex, self.iteration, - returns_n2.mean(), lengths_n2.mean(), info) + returns_n2.mean(), lengths_n2.mean(), None, info) self.iteration += 1 return res + + def restore(self, checkpoint_path): + raise NotImplementedError # TODO(ekl) + + def compute_action(self, observation): + raise NotImplementedError # TODO(ekl) diff --git a/python/ray/rllib/policy_gradient/policy_gradient.py b/python/ray/rllib/policy_gradient/policy_gradient.py index 8361185c7..dbee42167 100644 --- a/python/ray/rllib/policy_gradient/policy_gradient.py +++ b/python/ray/rllib/policy_gradient/policy_gradient.py @@ -72,7 +72,7 @@ DEFAULT_CONFIG = { # If True, we write checkpoints and tensorflow logging "write_logs": True, # Name of the model checkpoint file - "model_checkpoint_file": "iteration-%s.ckpt"} + "model_checkpoint_file": "checkpoint"} class PolicyGradient(Algorithm): @@ -90,6 +90,14 @@ class PolicyGradient(Algorithm): self.env_name, 1, self.config, self.logdir, True) for _ in range(config["num_agents"])] self.start_time = time.time() + # TF does not support to write logs to S3 at the moment + write_tf_logs = config["write_logs"] and self.logdir.startswith("file") + if write_tf_logs: + self.file_writer = tf.summary.FileWriter( + self.logdir, self.model.sess.graph) + else: + self.file_writer = None + self.saver = tf.train.Saver(max_to_keep=None) def train(self): agents = self.agents @@ -100,22 +108,7 @@ class PolicyGradient(Algorithm): print("===> iteration", self.j) - saver = tf.train.Saver(max_to_keep=None) - if "load_checkpoint" in config: - saver.restore(model.sess, config["load_checkpoint"]) - - # TF does not support to write logs to S3 at the moment - write_tf_logs = config["write_logs"] and self.logdir.startswith("file") iter_start = time.time() - if write_tf_logs: - file_writer = tf.summary.FileWriter(self.logdir, model.sess.graph) - if config["model_checkpoint_file"]: - checkpoint_path = saver.save( - model.sess, - os.path.join( - self.logdir, config["model_checkpoint_file"] % j)) - print("Checkpoint saved in file: %s" % checkpoint_path) - checkpointing_end = time.time() weights = ray.put(model.get_weights()) [a.load_weights.remote(weights) for a in agents] trajectory, total_reward, traj_len_mean = collect_samples( @@ -123,7 +116,7 @@ class PolicyGradient(Algorithm): print("total reward is ", total_reward) print("trajectory length mean is ", traj_len_mean) print("timesteps:", trajectory["dones"].shape[0]) - if write_tf_logs: + if self.file_writer: traj_stats = tf.Summary(value=[ tf.Summary.Value( tag="policy_gradient/rollouts/mean_reward", @@ -131,7 +124,7 @@ class PolicyGradient(Algorithm): tf.Summary.Value( tag="policy_gradient/rollouts/traj_len_mean", simple_value=traj_len_mean)]) - file_writer.add_summary(traj_stats, self.global_step) + self.file_writer.add_summary(traj_stats, self.global_step) self.global_step += 1 def standardized(value): @@ -155,8 +148,7 @@ class PolicyGradient(Algorithm): tuples_per_device = model.load_data( trajectory, j == 0 and config["full_trace_data_load"]) load_end = time.time() - checkpointing_time = checkpointing_end - iter_start - rollouts_time = rollouts_end - checkpointing_end + rollouts_time = rollouts_end - iter_start shuffle_time = shuffle_end - rollouts_end load_time = load_end - shuffle_end sgd_time = 0 @@ -178,7 +170,7 @@ class PolicyGradient(Algorithm): batch_entropy = model.run_sgd_minibatch( permutation[batch_index] * model.per_device_batch_size, self.kl_coeff, full_trace, - file_writer if write_tf_logs else None) + self.file_writer) loss.append(batch_loss) policy_loss.append(batch_policy_loss) vf_loss.append(batch_vf_loss) @@ -201,21 +193,19 @@ class PolicyGradient(Algorithm): values.append(tf.Summary.Value( tag=metric_prefix + "kl_coeff", simple_value=self.kl_coeff)) - else: - metric_prefix = "policy_gradient/sgd/intermediate_iters/" - values.extend([ - tf.Summary.Value( - tag=metric_prefix + "mean_entropy", - simple_value=entropy), - tf.Summary.Value( - tag=metric_prefix + "mean_loss", - simple_value=loss), - tf.Summary.Value( - tag=metric_prefix + "mean_kl", - simple_value=kl)]) - if write_tf_logs: - sgd_stats = tf.Summary(value=values) - file_writer.add_summary(sgd_stats, self.global_step) + values.extend([ + tf.Summary.Value( + tag=metric_prefix + "mean_entropy", + simple_value=entropy), + tf.Summary.Value( + tag=metric_prefix + "mean_loss", + simple_value=loss), + tf.Summary.Value( + tag=metric_prefix + "mean_kl", + simple_value=kl)]) + if self.file_writer: + sgd_stats = tf.Summary(value=values) + self.file_writer.add_summary(sgd_stats, self.global_step) self.global_step += 1 sgd_time += sgd_end - sgd_start if kl > 2.0 * config["kl_target"]: @@ -223,6 +213,17 @@ class PolicyGradient(Algorithm): elif kl < 0.5 * config["kl_target"]: self.kl_coeff *= 0.5 + checkpointing_start = time.time() + if config["model_checkpoint_file"]: + checkpoint_path = self.saver.save( + model.sess, + os.path.join( + self.logdir, config["model_checkpoint_file"]), + global_step=j) + else: + checkpoint_path = None + checkpointing_time = time.time() - checkpointing_start + info = { "kl_divergence": kl, "kl_coefficient": self.kl_coeff, @@ -245,6 +246,14 @@ class PolicyGradient(Algorithm): print("total time so far:", time.time() - self.start_time) result = TrainingResult( - self.experiment_id.hex, j, total_reward, traj_len_mean, info) + self.experiment_id.hex, j, total_reward, traj_len_mean, + checkpoint_path, info) return result + + def restore(self, checkpoint_path): + self.saver.restore(self.model.sess, checkpoint_path) + + def compute_action(self, observation): + return self.model.common_policy.compute_actions( + observation[np.newaxis, :])[0][0]