diff --git a/python/ray/rllib/a3c/a3c.py b/python/ray/rllib/a3c/a3c.py index f82dea64a..306e4753d 100644 --- a/python/ray/rllib/a3c/a3c.py +++ b/python/ray/rllib/a3c/a3c.py @@ -100,9 +100,8 @@ class A3CAgent(Agent): config["batch_size"], self.logdir) for i in range(config["num_workers"])] self.parameters = self.policy.get_weights() - self.iteration = 0 - def train(self): + def _train(self): gradient_list = [ agent.compute_gradient.remote(self.parameters) for agent in self.agents] @@ -119,7 +118,6 @@ class A3CAgent(Agent): [self.agents[info["id"]].compute_gradient.remote( self.parameters)]) res = self._fetch_metrics_from_workers() - self.iteration += 1 return res def _fetch_metrics_from_workers(self): @@ -131,27 +129,31 @@ class A3CAgent(Agent): for episode in ray.get(metrics): episode_lengths.append(episode.episode_length) episode_rewards.append(episode.episode_reward) - avg_reward = np.mean(episode_rewards) if episode_rewards else None - avg_length = np.mean(episode_lengths) if episode_lengths else None - res = TrainingResult( - self.experiment_id.hex, self.iteration, - avg_reward, avg_length, dict()) - return res + avg_reward = ( + np.mean(episode_rewards) if episode_rewards else float('nan')) + avg_length = ( + np.mean(episode_lengths) if episode_lengths else float('nan')) + timesteps = np.sum(episode_lengths) if episode_lengths else 0 - def save(self): + result = TrainingResult( + episode_reward_mean=avg_reward, + episode_len_mean=avg_length, + timesteps_this_iter=timesteps, + info={}) + + return result + + def _save(self): checkpoint_path = os.path.join( self.logdir, "checkpoint-{}".format(self.iteration)) - objects = [ - self.parameters, - self.iteration] + objects = [self.parameters] pickle.dump(objects, open(checkpoint_path, "wb")) return checkpoint_path - def restore(self, checkpoint_path): + def _restore(self, checkpoint_path): objects = pickle.load(open(checkpoint_path, "rb")) self.parameters = objects[0] self.policy.set_weights(self.parameters) - self.iteration = objects[1] def compute_action(self, observation): actions = self.policy.compute_actions(observation)[0] diff --git a/python/ray/rllib/common.py b/python/ray/rllib/common.py index 6434d87f0..b46987202 100644 --- a/python/ray/rllib/common.py +++ b/python/ray/rllib/common.py @@ -4,8 +4,10 @@ import json import logging import numpy as np import os +import pickle import sys import tempfile +import time import uuid import smart_open @@ -46,13 +48,37 @@ class RLLibLogger(object): TrainingResult = namedtuple("TrainingResult", [ + # Unique string identifier for this experiment. This id is preserved + # across checkpoint / restore calls. "experiment_id", + + # The index of this training iteration, e.g. call to train(). "training_iteration", + + # The mean episode reward reported during this iteration. "episode_reward_mean", + + # The mean episode length reported during this iteration. "episode_len_mean", - "info" + + # Agent-specific metadata to report for this iteration. + "info", + + # Number of timesteps in the simulator in this iteration. + "timesteps_this_iter", + + # Accumulated timesteps for this entire experiment. + "timesteps_total", + + # Time in seconds this iteration took to run. + "time_this_iter_s", + + # Accumulated time in seconds for this entire experiment. + "time_total_s", ]) +TrainingResult.__new__.__defaults__ = (None,) * len(TrainingResult._fields) + class Agent(object): """All RLlib agents extend this base class. @@ -64,8 +90,6 @@ class Agent(object): env_name (str): Name of the OpenAI gym environment to train against. config (obj): Algorithm-specific configuration data. logdir (str): Directory in which training outputs should be placed. - - TODO(ekl): support checkpoint / restore of training state. """ def __init__(self, env_name, config, upload_dir=None): @@ -79,10 +103,11 @@ class Agent(object): like s3://bucketname/. """ upload_dir = "file:///tmp/ray" if upload_dir is None else upload_dir - self.experiment_id = uuid.uuid4() + self.experiment_id = uuid.uuid4().hex self.env_name = env_name + self.config = config - self.config.update({"experiment_id": self.experiment_id.hex}) + self.config.update({"experiment_id": self.experiment_id}) self.config.update({"env_name": env_name}) prefix = "{}_{}_{}".format( env_name, @@ -92,6 +117,8 @@ class Agent(object): self.logdir = tempfile.mkdtemp(prefix=prefix, dir="/tmp/ray") else: self.logdir = os.path.join(upload_dir, prefix) + + # TODO(ekl) consider inlining config into the result jsons log_path = os.path.join(self.logdir, "config.json") with smart_open.smart_open(log_path, "w") as f: json.dump(self.config, f, sort_keys=True, cls=RLLibEncoder) @@ -99,6 +126,10 @@ class Agent(object): "%s algorithm created with logdir '%s'", self.__class__.__name__, self.logdir) + self.iteration = 0 + self.time_total = 0.0 + self.timesteps_total = 0 + def train(self): """Runs one logical iteration of training. @@ -106,7 +137,25 @@ class Agent(object): A TrainingResult that describes training progress. """ - raise NotImplementedError + start = time.time() + self.iteration += 1 + result = self._train() + time_this_iter = time.time() - start + + self.time_total += time_this_iter + self.timesteps_total += result.timesteps_this_iter + + result = result._replace( + experiment_id=self.experiment_id, + training_iteration=self.iteration, + timesteps_total=self.timesteps_total, + time_this_iter_s=time_this_iter, + time_total_s=self.time_total) + + for field in result: + assert field is not None, result + + return result def save(self): """Saves the current model state to a checkpoint. @@ -115,7 +164,12 @@ class Agent(object): Checkpoint path that may be passed to restore(). """ - raise NotImplementedError + checkpoint_path = self._save() + pickle.dump( + [self.experiment_id, self.iteration, self.timesteps_total, + self.time_total_s], + open(checkpoint_path + ".rllib_metadata", "wb")) + return checkpoint_path def restore(self, checkpoint_path): """Restores training state from a given model checkpoint. @@ -123,9 +177,29 @@ class Agent(object): These checkpoints are returned from calls to save(). """ - raise NotImplementedError + self._restore(checkpoint_path) + metadata = pickle.load(open(checkpoint_path + ".rllib_metadata", "rb")) + self.experiment_id = metadata[0] + self.iteration = metadata[1] + self.timesteps_total = metadata[2] + self.time_total_s = metadata[3] def compute_action(self, observation): """Computes an action using the current trained policy.""" raise NotImplementedError + + def _train(self): + """Subclasses should override this to implement train().""" + + raise NotImplementedError + + def _save(self): + """Subclasses should override this to implement save().""" + + raise NotImplementedError + + def _restore(self): + """Subclasses should override this to implement restore().""" + + raise NotImplementedError diff --git a/python/ray/rllib/dqn/dqn.py b/python/ray/rllib/dqn/dqn.py index afbf978d3..77794ea16 100644 --- a/python/ray/rllib/dqn/dqn.py +++ b/python/ray/rllib/dqn/dqn.py @@ -164,9 +164,10 @@ class DQNAgent(Agent): self.file_writer = tf.summary.FileWriter(self.logdir, self.sess.graph) self.saver = tf.train.Saver(max_to_keep=None) - def train(self): + def _train(self): config = self.config sample_time, learn_time = 0, 0 + iter_init_timesteps = self.num_timesteps for _ in range(config["timesteps_per_iteration"]): self.num_timesteps += 1 @@ -243,13 +244,15 @@ class DQNAgent(Agent): int(100 * self.exploration.value(self.num_timesteps))) logger.dump_tabular() - res = TrainingResult( - self.experiment_id.hex, self.num_iterations, mean_100ep_reward, - mean_100ep_length, info) - self.num_iterations += 1 - return res + result = TrainingResult( + episode_reward_mean=mean_100ep_reward, + episode_len_mean=mean_100ep_length, + timesteps_this_iter=self.num_timesteps - iter_init_timesteps, + info=info) - def save(self): + return result + + def _save(self): checkpoint_path = self.saver.save( self.sess, os.path.join(self.logdir, "checkpoint"), @@ -267,7 +270,7 @@ class DQNAgent(Agent): pickle.dump(extra_data, open(checkpoint_path + ".extra_data", "wb")) return checkpoint_path - def restore(self, checkpoint_path): + def _restore(self, checkpoint_path): self.saver.restore(self.sess, checkpoint_path) extra_data = pickle.load(open(checkpoint_path + ".extra_data", "rb")) self.replay_buffer = extra_data[0] diff --git a/python/ray/rllib/es/es.py b/python/ray/rllib/es/es.py index 4052f6ed0..226e03bf0 100644 --- a/python/ray/rllib/es/es.py +++ b/python/ray/rllib/es/es.py @@ -203,7 +203,6 @@ class ESAgent(Agent): self.episodes_so_far = 0 self.timesteps_so_far = 0 self.tstart = time.time() - self.iteration = 0 def _collect_results(self, theta_id, min_eps, min_timesteps): num_eps, num_timesteps = 0, 0 @@ -224,7 +223,7 @@ class ESAgent(Agent): num_timesteps += result.lengths_n2.sum() return results - def train(self): + def _train(self): config = self.config step_tstart = time.time() @@ -314,14 +313,6 @@ class ESAgent(Agent): tlogger.record_tabular("TimeElapsed", step_tend - self.tstart) tlogger.dump_tabular() - if (config["snapshot_freq"] != 0 and - self.iteration % config["snapshot_freq"] == 0): - filename = os.path.join( - self.logdir, "snapshot_iter{:05d}.h5".format(self.iteration)) - assert not os.path.exists(filename) - self.policy.save(filename) - tlogger.log("Saved snapshot {}".format(filename)) - info = { "weights_norm": np.square(self.policy.get_trainable_flat()).sum(), "grad_norm": np.square(g).sum(), @@ -334,14 +325,16 @@ class ESAgent(Agent): "time_elapsed_this_iter": step_tend - step_tstart, "time_elapsed": step_tend - self.tstart } - res = TrainingResult(self.experiment_id.hex, self.iteration, - returns_n2.mean(), lengths_n2.mean(), info) - self.iteration += 1 + result = TrainingResult( + episode_reward_mean=returns_n2.mean(), + episode_len_mean=lengths_n2.mean(), + timesteps_this_iter=lengths_n2.sum(), + info=info) - return res + return result - def save(self): + def _save(self): checkpoint_path = os.path.join( self.logdir, "checkpoint-{}".format(self.iteration)) weights = self.policy.get_trainable_flat() @@ -349,18 +342,16 @@ class ESAgent(Agent): weights, self.ob_stat, self.episodes_so_far, - self.timesteps_so_far, - self.iteration] + self.timesteps_so_far] pickle.dump(objects, open(checkpoint_path, "wb")) return checkpoint_path - def restore(self, checkpoint_path): + def _restore(self, checkpoint_path): objects = pickle.load(open(checkpoint_path, "rb")) self.policy.set_trainable_flat(objects[0]) self.ob_stat = objects[1] self.episodes_so_far = objects[2] self.timesteps_so_far = objects[3] - self.iteration = objects[4] def compute_action(self, observation): return self.policy.act([observation])[0] diff --git a/python/ray/rllib/ppo/ppo.py b/python/ray/rllib/ppo/ppo.py index f742a9ade..c53c49c31 100644 --- a/python/ray/rllib/ppo/ppo.py +++ b/python/ray/rllib/ppo/ppo.py @@ -89,7 +89,6 @@ class PPOAgent(Agent): def _init(self): self.global_step = 0 - self.j = 0 self.kl_coeff = self.config["kl_coeff"] self.model = Runner(self.env_name, 1, self.config, self.logdir, False) self.agents = [ @@ -104,14 +103,12 @@ class PPOAgent(Agent): self.file_writer = None self.saver = tf.train.Saver(max_to_keep=None) - def train(self): + def _train(self): agents = self.agents config = self.config model = self.model - j = self.j - self.j += 1 - print("===> iteration", self.j) + print("===> iteration", self.iteration) iter_start = time.time() weights = ray.put(model.get_weights()) @@ -151,7 +148,7 @@ class PPOAgent(Agent): trajectory = shuffle(trajectory) shuffle_end = time.time() tuples_per_device = model.load_data( - trajectory, j == 0 and config["full_trace_data_load"]) + trajectory, self.iteration == 0 and config["full_trace_data_load"]) load_end = time.time() rollouts_time = rollouts_end - iter_start shuffle_time = shuffle_end - rollouts_end @@ -165,11 +162,11 @@ class PPOAgent(Agent): loss, policy_loss, vf_loss, kl, entropy = [], [], [], [], [] permutation = np.random.permutation(num_batches) # Prepare to drop into the debugger - if j == config["tf_debug_iteration"]: + if self.iteration == config["tf_debug_iteration"]: model.sess = tf_debug.LocalCLIDebugWrapperSession(model.sess) while batch_index < num_batches: full_trace = ( - i == 0 and j == 0 and + i == 0 and self.iteration == 0 and batch_index == config["full_trace_nth_sgd_batch"]) batch_loss, batch_policy_loss, batch_vf_loss, batch_kl, \ batch_entropy = model.run_sgd_minibatch( @@ -238,35 +235,36 @@ class PPOAgent(Agent): print("total time so far:", time.time() - self.start_time) result = TrainingResult( - self.experiment_id.hex, j, total_reward, traj_len_mean, info) + episode_reward_mean=total_reward, + episode_len_mean=traj_len_mean, + timesteps_this_iter=trajectory["dones"].shape[0], + info=info) return result - def save(self): + def _save(self): checkpoint_path = self.saver.save( self.model.sess, os.path.join(self.logdir, "checkpoint"), - global_step=self.j) + global_step=self.iteration) agent_state = ray.get([a.save.remote() for a in self.agents]) extra_data = [ self.model.save(), self.global_step, - self.j, self.kl_coeff, agent_state] pickle.dump(extra_data, open(checkpoint_path + ".extra_data", "wb")) return checkpoint_path - def restore(self, checkpoint_path): + def _restore(self, checkpoint_path): self.saver.restore(self.model.sess, checkpoint_path) extra_data = pickle.load(open(checkpoint_path + ".extra_data", "rb")) self.model.restore(extra_data[0]) self.global_step = extra_data[1] - self.j = extra_data[2] - self.kl_coeff = extra_data[3] + self.kl_coeff = extra_data[2] ray.get([ a.restore.remote(o) - for (a, o) in zip(self.agents, extra_data[4])]) + for (a, o) in zip(self.agents, extra_data[3])]) def compute_action(self, observation): observation = self.model.observation_filter(observation) diff --git a/python/ray/rllib/train.py b/python/ray/rllib/train.py index 0bfb37efe..a606459b3 100755 --- a/python/ray/rllib/train.py +++ b/python/ray/rllib/train.py @@ -7,6 +7,7 @@ from __future__ import print_function import argparse import json import os +import pprint import sys import ray @@ -89,7 +90,8 @@ if __name__ == "__main__": cls=ray.rllib.common.RLLibEncoder) result_logger.write("\n") - print("current status: {}".format(result)) + print("== Iteration {} ==".format(alg.iteration)) + pprint.pprint(result._asdict()) if (i + 1) % args.checkpoint_freq == 0: print("checkpoint path: {}".format(alg.save()))