From 3f83b2daa95baf0de88b161399cf086d9b6e4a51 Mon Sep 17 00:00:00 2001 From: David Bignell Date: Wed, 6 Nov 2019 04:34:18 +0000 Subject: [PATCH] [rllib] Rollout extensions (#6065) * Rollout improvements * Make info-saving optional, to avoid breaking change. * Store generating ray version in checkpoint metadata * Keep the linter happy * Add small rollout test * Terse. * Update test_io.py --- doc/source/rllib-training.rst | 2 + python/ray/tune/trainable.py | 3 +- rllib/rollout.py | 219 +++++++++++++++++++++++++++++++--- rllib/tests/test_rollout.sh | 7 +- 4 files changed, 211 insertions(+), 20 deletions(-) diff --git a/doc/source/rllib-training.rst b/doc/source/rllib-training.rst index 973aaa889..d9f3c407e 100644 --- a/doc/source/rllib-training.rst +++ b/doc/source/rllib-training.rst @@ -59,6 +59,8 @@ The ``rollout.py`` helper script reconstructs a DQN policy from the checkpoint located at ``~/ray_results/default/DQN_CartPole-v0_0upjmdgr0/checkpoint_1/checkpoint-1`` and renders its behavior in the environment specified by ``--env``. +(Type ``rllib rollout --help`` to see the available evaluation options.) + Configuration ------------- diff --git a/python/ray/tune/trainable.py b/python/ray/tune/trainable.py index 2cbb8be69..8bd24767c 100644 --- a/python/ray/tune/trainable.py +++ b/python/ray/tune/trainable.py @@ -292,7 +292,8 @@ class Trainable(object): "timesteps_total": self._timesteps_total, "time_total": self._time_total, "episodes_total": self._episodes_total, - "saved_as_dict": saved_as_dict + "saved_as_dict": saved_as_dict, + "ray_version": ray.__version__ }, f) return checkpoint_path diff --git a/rllib/rollout.py b/rllib/rollout.py index f7ea31e96..7dff98c5f 100755 --- a/rllib/rollout.py +++ b/rllib/rollout.py @@ -9,6 +9,8 @@ import collections import json import os import pickle +import shelve +from pathlib import Path import gym import ray @@ -35,6 +37,129 @@ Example Usage via executable: # register_env("pa_cartpole", lambda _: ParametricActionCartpole(10)) +class RolloutSaver: + """Utility class for storing rollouts. + + Currently supports two behaviours: the original, which + simply dumps everything to a pickle file once complete, + and a mode which stores each rollout as an entry in a Python + shelf db file. The latter mode is more robust to memory problems + or crashes part-way through the rollout generation. Each rollout + is stored with a key based on the episode number (0-indexed), + and the number of episodes is stored with the key "num_episodes", + so to load the shelf file, use something like: + + with shelve.open('rollouts.pkl') as rollouts: + for episode_index in range(rollouts["num_episodes"]): + rollout = rollouts[str(episode_index)] + + If outfile is None, this class does nothing. + """ + + def __init__(self, + outfile=None, + use_shelve=False, + write_update_file=False, + target_steps=None, + target_episodes=None, + save_info=False): + self._outfile = outfile + self._update_file = None + self._use_shelve = use_shelve + self._write_update_file = write_update_file + self._shelf = None + self._num_episodes = 0 + self._rollouts = [] + self._current_rollout = [] + self._total_steps = 0 + self._target_episodes = target_episodes + self._target_steps = target_steps + self._save_info = save_info + + def _get_tmp_progress_filename(self): + outpath = Path(self._outfile) + return outpath.parent / ("__progress_" + outpath.name) + + @property + def outfile(self): + return self._outfile + + def __enter__(self): + if self._outfile: + if self._use_shelve: + # Open a shelf file to store each rollout as they come in + self._shelf = shelve.open(self._outfile) + else: + # Original behaviour - keep all rollouts in memory and save + # them all at the end. + # But check we can actually write to the outfile before going + # through the effort of generating the rollouts: + try: + with open(self._outfile, "wb") as _: + pass + except IOError as x: + print("Can not open {} for writing - cancelling rollouts.". + format(self._outfile)) + raise x + if self._write_update_file: + # Open a file to track rollout progress: + self._update_file = self._get_tmp_progress_filename().open( + mode="w") + return self + + def __exit__(self, type, value, traceback): + if self._shelf: + # Close the shelf file, and store the number of episodes for ease + self._shelf["num_episodes"] = self._num_episodes + self._shelf.close() + elif self._outfile and not self._use_shelve: + # Dump everything as one big pickle: + pickle.dump(self._rollouts, open(self._outfile, "wb")) + if self._update_file: + # Remove the temp progress file: + self._get_tmp_progress_filename().unlink() + self._update_file = None + + def _get_progress(self): + if self._target_episodes: + return "{} / {} episodes completed".format(self._num_episodes, + self._target_episodes) + elif self._target_steps: + return "{} / {} steps completed".format(self._total_steps, + self._target_steps) + else: + return "{} episodes completed".format(self._num_episodes) + + def begin_rollout(self): + self._current_rollout = [] + + def end_rollout(self): + if self._outfile: + if self._use_shelve: + # Save this episode as a new entry in the shelf database, + # using the episode number as the key. + self._shelf[str(self._num_episodes)] = self._current_rollout + else: + # Append this rollout to our list, to save laer. + self._rollouts.append(self._current_rollout) + self._num_episodes += 1 + if self._update_file: + self._update_file.seek(0) + self._update_file.write(self._get_progress() + "\n") + self._update_file.flush() + + def append_step(self, obs, action, next_obs, reward, done, info): + """Add a step to the current rollout, if we are saving them""" + if self._outfile: + if self._save_info: + self._current_rollout.append( + [obs, action, next_obs, reward, done, info]) + else: + self._current_rollout.append( + [obs, action, next_obs, reward, done]) + self._total_steps += 1 + + def create_parser(parser_creator=None): parser_creator = parser_creator or argparse.ArgumentParser parser = parser_creator( @@ -62,6 +187,12 @@ def create_parser(parser_creator=None): action="store_const", const=True, help="Surpress rendering of the environment.") + parser.add_argument( + "--monitor", + default=False, + action="store_const", + const=True, + help="Wrap environment in gym Monitor to record video.") parser.add_argument( "--steps", default=10000, help="Number of steps to roll out.") parser.add_argument("--out", default=None, help="Output filename.") @@ -71,6 +202,29 @@ def create_parser(parser_creator=None): type=json.loads, help="Algorithm-specific configuration (e.g. env, hyperparams). " "Surpresses loading of configuration from checkpoint.") + parser.add_argument( + "--episodes", + default=0, + help="Number of complete episodes to roll out. (Overrides --steps)") + parser.add_argument( + "--save-info", + default=False, + action="store_true", + help="Save the info field generated by the step() method, " + "as well as the action, observations, rewards and done fields.") + parser.add_argument( + "--use-shelve", + default=False, + action="store_true", + help="Save rollouts into a python shelf file (will save each episode " + "as it is generated). An output filename must be set using --out.") + parser.add_argument( + "--track-progress", + default=False, + action="store_true", + help="Write progress to a temporary file (updated " + "after each episode). An output filename must be set using --out; " + "the progress file will live in the same folder.") return parser @@ -103,7 +257,16 @@ def run(args, parser): agent = cls(env=args.env, config=config) agent.restore(args.checkpoint) num_steps = int(args.steps) - rollout(agent, args.env, num_steps, args.out, args.no_render) + num_episodes = int(args.episodes) + with RolloutSaver( + args.out, + args.use_shelve, + write_update_file=args.track_progress, + target_steps=num_steps, + target_episodes=num_episodes, + save_info=args.save_info) as saver: + rollout(agent, args.env, num_steps, num_episodes, saver, + args.no_render, args.monitor) class DefaultMapping(collections.defaultdict): @@ -118,7 +281,25 @@ def default_policy_agent_mapping(unused_agent_id): return DEFAULT_POLICY_ID -def rollout(agent, env_name, num_steps, out=None, no_render=True): +def keep_going(steps, num_steps, episodes, num_episodes): + """Determine whether we've collected enough data""" + # if num_episodes is set, this overrides num_steps + if num_episodes: + return episodes < num_episodes + # if num_steps is set, continue until we reach the limit + if num_steps: + return steps < num_steps + # otherwise keep going forever + return True + + +def rollout(agent, + env_name, + num_steps, + num_episodes=0, + saver=RolloutSaver(), + no_render=True, + monitor=False): policy_agent_mapping = default_policy_agent_mapping if hasattr(agent, "workers"): @@ -140,13 +321,19 @@ def rollout(agent, env_name, num_steps, out=None, no_render=True): multiagent = False use_lstm = {DEFAULT_POLICY_ID: False} - if out is not None: - rollouts = [] + if monitor and not no_render and saver and saver.outfile is not None: + # If monitoring has been requested, + # manually wrap our environment with a gym monitor + # which is set to record every episode. + env = gym.wrappers.Monitor( + env, os.path.join(os.path.dirname(saver.outfile), "monitor"), + lambda x: True) + steps = 0 - while steps < (num_steps or steps + 1): + episodes = 0 + while keep_going(steps, num_steps, episodes, num_episodes): mapping_cache = {} # in case policy_agent_mapping is stochastic - if out is not None: - rollout = [] + saver.begin_rollout() obs = env.reset() agent_states = DefaultMapping( lambda agent_id: state_init[mapping_cache[agent_id]]) @@ -155,7 +342,8 @@ def rollout(agent, env_name, num_steps, out=None, no_render=True): prev_rewards = collections.defaultdict(lambda: 0.) done = False reward_total = 0.0 - while not done and steps < (num_steps or steps + 1): + while not done and keep_going(steps, num_steps, episodes, + num_episodes): multi_obs = obs if multiagent else {_DUMMY_AGENT_ID: obs} action_dict = {} for agent_id, a_obs in multi_obs.items(): @@ -183,7 +371,7 @@ def rollout(agent, env_name, num_steps, out=None, no_render=True): action = action_dict action = action if multiagent else action[_DUMMY_AGENT_ID] - next_obs, reward, done, _ = env.step(action) + next_obs, reward, done, info = env.step(action) if multiagent: for agent_id, r in reward.items(): prev_rewards[agent_id] = r @@ -197,16 +385,13 @@ def rollout(agent, env_name, num_steps, out=None, no_render=True): reward_total += reward if not no_render: env.render() - if out is not None: - rollout.append([obs, action, next_obs, reward, done]) + saver.append_step(obs, action, next_obs, reward, done, info) steps += 1 obs = next_obs - if out is not None: - rollouts.append(rollout) - print("Episode reward", reward_total) - - if out is not None: - pickle.dump(rollouts, open(out, "wb")) + saver.end_rollout() + print("Episode #{}: reward: {}".format(episodes, reward_total)) + if done: + episodes += 1 if __name__ == "__main__": diff --git a/rllib/tests/test_rollout.sh b/rllib/tests/test_rollout.sh index 02a65a7c5..eaafe3de7 100755 --- a/rllib/tests/test_rollout.sh +++ b/rllib/tests/test_rollout.sh @@ -22,7 +22,10 @@ echo "Checkpoint path $CHECKPOINT_PATH" test -e "$CHECKPOINT_PATH" $ROLLOUT --run=IMPALA "$CHECKPOINT_PATH" --steps=100 \ - --out="$TMP/rollouts.pkl" --no-render -test -e "$TMP/rollouts.pkl" + --out="$TMP/rollouts_100steps.pkl" --no-render +test -e "$TMP/rollouts_100steps.pkl" +$ROLLOUT --run=IMPALA "$CHECKPOINT_PATH" --episodes=1 \ + --out="$TMP/rollouts_1episode.pkl" --no-render +test -e "$TMP/rollouts_1episode.pkl" rm -rf "$TMP" echo "OK"