From 813f51769f1b2e572ee50672c5be78bb02535725 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Mon, 5 Nov 2018 00:33:25 -0800 Subject: [PATCH] [rllib] Fix rllib rollouts script and add test (#3211) ## What do these changes do? Clean up the checkpointing to handle the new checkpoint dirs. Add a test for rollout.py ## Related issue number https://github.com/ray-project/ray/issues/3206 https://github.com/ray-project/ray/issues/3204 --- python/ray/rllib/agents/agent.py | 5 ++- python/ray/rllib/rollout.py | 8 ++++- python/ray/rllib/test/test_rollout.sh | 28 ++++++++++++++++ .../ray/rllib/tuned_examples/atari-apex.yaml | 2 +- .../tune/examples/mnist_pytorch_trainable.py | 11 ++++--- python/ray/tune/test/trial_runner_test.py | 5 ++- python/ray/tune/trainable.py | 33 +++++++++++++------ test/jenkins_tests/run_multi_node_tests.sh | 3 ++ 8 files changed, 74 insertions(+), 21 deletions(-) create mode 100755 python/ray/rllib/test/test_rollout.sh diff --git a/python/ray/rllib/agents/agent.py b/python/ray/rllib/agents/agent.py index 6cf18c236..3ceb358c5 100644 --- a/python/ray/rllib/agents/agent.py +++ b/python/ray/rllib/agents/agent.py @@ -380,12 +380,11 @@ class Agent(Trainable): def _save(self, checkpoint_dir): checkpoint_path = os.path.join(checkpoint_dir, "checkpoint-{}".format(self.iteration)) - pickle.dump(self.__getstate__(), - open(checkpoint_path + ".agent_state", "wb")) + pickle.dump(self.__getstate__(), open(checkpoint_path, "wb")) return checkpoint_path def _restore(self, checkpoint_path): - extra_data = pickle.load(open(checkpoint_path + ".agent_state", "rb")) + extra_data = pickle.load(open(checkpoint_path, "rb")) self.__setstate__(extra_data) diff --git a/python/ray/rllib/rollout.py b/python/ray/rllib/rollout.py index 0e33e3d6c..29526e173 100755 --- a/python/ray/rllib/rollout.py +++ b/python/ray/rllib/rollout.py @@ -54,7 +54,7 @@ def create_parser(parser_creator=None): const=True, help="Surpress rendering of the environment.") parser.add_argument( - "--steps", default=None, help="Number of steps to roll out.") + "--steps", default=10000, help="Number of steps to roll out.") parser.add_argument("--out", default=None, help="Output filename.") parser.add_argument( "--config", @@ -70,6 +70,12 @@ def run(args, parser): # Load configuration from file config_dir = os.path.dirname(args.checkpoint) config_path = os.path.join(config_dir, "params.json") + if not os.path.exists(config_path): + config_path = os.path.join(config_dir, "../params.json") + if not os.path.exists(config_path): + raise ValueError( + "Could not find params.json in either the checkpoint dir or " + "its parent directory.") with open(config_path) as f: args.config = json.load(f) diff --git a/python/ray/rllib/test/test_rollout.sh b/python/ray/rllib/test/test_rollout.sh new file mode 100755 index 000000000..04685b2be --- /dev/null +++ b/python/ray/rllib/test/test_rollout.sh @@ -0,0 +1,28 @@ +#!/bin/bash -e + +TRAIN=/ray/python/ray/rllib/train.py +if [ ! -e "$TRAIN" ]; then + TRAIN=../train.py +fi +ROLLOUT=/ray/python/ray/rllib/rollout.py +if [ ! -e "$ROLLOUT" ]; then + ROLLOUT=../rollout.py +fi + +TMP=`mktemp -d` +echo "Saving results to $TMP" + +$TRAIN --local-dir=$TMP --run=IMPALA --checkpoint-freq=1 \ + --config='{"num_workers": 1, "num_gpus": 0}' --env=Pong-ram-v4 \ + --stop='{"training_iteration": 1}' +find $TMP + +CHECKPOINT_PATH=`ls $TMP/default/*/checkpoint_1/checkpoint-1` +echo "Checkpoint path $CHECKPOINT_PATH" +test -e "$CHECKPOINT_PATH" + +$ROLLOUT --run=IMPALA "$CHECKPOINT_PATH" --steps=100 \ + --out="$TMP/rollouts.pkl" --no-render +test -e "$TMP/rollouts.pkl" +rm -rf "$TMP" +echo "OK" diff --git a/python/ray/rllib/tuned_examples/atari-apex.yaml b/python/ray/rllib/tuned_examples/atari-apex.yaml index 19036a32b..23b1f19c1 100644 --- a/python/ray/rllib/tuned_examples/atari-apex.yaml +++ b/python/ray/rllib/tuned_examples/atari-apex.yaml @@ -23,7 +23,7 @@ apex: prioritized_replay_alpha: 0.5 beta_annealing_fraction: 1.0 final_prioritized_replay_beta: 1.0 - gpu: false + gpu: true # APEX num_workers: 8 diff --git a/python/ray/tune/examples/mnist_pytorch_trainable.py b/python/ray/tune/examples/mnist_pytorch_trainable.py index 2c0c68bce..b5ab0f2ab 100644 --- a/python/ray/tune/examples/mnist_pytorch_trainable.py +++ b/python/ray/tune/examples/mnist_pytorch_trainable.py @@ -159,12 +159,13 @@ class TrainMNIST(Trainable): self._train_iteration() return self._test() - def _save(self, path): - torch.save(self.model.state_dict(), os.path.join(path, "model.pth")) - return path + def _save(self, checkpoint_dir): + checkpoint_path = os.path.join(checkpoint_dir, "model.pth") + torch.save(self.model.state_dict(), checkpoint_path) + return checkpoint_path - def _restore(self, path): - self.model.load_state_dict(os.path.join(path, "model.pth")) + def _restore(self, checkpoint_path): + self.model.load_state_dict(checkpoint_path) if __name__ == '__main__': diff --git a/python/ray/tune/test/trial_runner_test.py b/python/ray/tune/test/trial_runner_test.py index 3c9ae43e6..f846467f0 100644 --- a/python/ray/tune/test/trial_runner_test.py +++ b/python/ray/tune/test/trial_runner_test.py @@ -627,7 +627,10 @@ class RunExperimentTest(unittest.TestCase): return {"timesteps_this_iter": 1, "done": True} def _save(self, path): - return path + checkpoint = path + "/checkpoint" + with open(checkpoint, "w") as f: + f.write("OK") + return checkpoint trials = run_experiments({ "foo": { diff --git a/python/ray/tune/trainable.py b/python/ray/tune/trainable.py index 6c8b02cf0..5d5e682e7 100644 --- a/python/ray/tune/trainable.py +++ b/python/ray/tune/trainable.py @@ -214,17 +214,28 @@ class Trainable(object): Checkpoint path that may be passed to restore(). """ - checkpoint_path = tempfile.mkdtemp( - prefix="checkpoint_{}".format(self._iteration), - dir=checkpoint_dir or self.logdir) - checkpoint = self._save(checkpoint_path) + checkpoint_dir = os.path.join(checkpoint_dir or self.logdir, + "checkpoint_{}".format(self._iteration)) + os.makedirs(checkpoint_dir) + checkpoint = self._save(checkpoint_dir) saved_as_dict = False if isinstance(checkpoint, str): + if (not checkpoint.startswith(checkpoint_dir) + or checkpoint == checkpoint_dir): + raise ValueError( + "The returned checkpoint path must be within the " + "given checkpoint dir {}: {}".format( + checkpoint_dir, checkpoint)) + if not os.path.exists(checkpoint): + raise ValueError( + "The returned checkpoint path does not exist: {}".format( + checkpoint)) checkpoint_path = checkpoint elif isinstance(checkpoint, dict): saved_as_dict = True - pickle.dump(checkpoint, open(checkpoint_path + ".tune_state", - "wb")) + checkpoint_path = os.path.join(checkpoint_dir, "checkpoint") + with open(checkpoint_path, "wb") as f: + pickle.dump(checkpoint, f) else: raise ValueError("Return value from `_save` must be dict or str.") pickle.dump({ @@ -286,7 +297,7 @@ class Trainable(object): self._episodes_total = metadata["episodes_total"] saved_as_dict = metadata["saved_as_dict"] if saved_as_dict: - with open(checkpoint_path + ".tune_state", "rb") as loaded_state: + with open(checkpoint_path, "rb") as loaded_state: checkpoint_dict = pickle.load(loaded_state) self._restore(checkpoint_dict) else: @@ -343,7 +354,7 @@ class Trainable(object): Args: checkpoint_dir (str): The directory where the checkpoint - can be stored. + file must be stored. Returns: checkpoint (str | dict): If string, the return value is @@ -352,8 +363,10 @@ class Trainable(object): serialized by Tune and passed to `_restore()`. Examples: - >>> checkpoint_data = trainable._save(checkpoint_dir) - >>> trainable2._restore(checkpoint_data) + >>> print(trainable1._save("/tmp/checkpoint_1")) + "/tmp/checkpoint_1/my_checkpoint_file" + >>> print(trainable2._save("/tmp/checkpoint_2")) + {"some": "data"} """ raise NotImplementedError diff --git a/test/jenkins_tests/run_multi_node_tests.sh b/test/jenkins_tests/run_multi_node_tests.sh index 380730fcb..05fae9346 100755 --- a/test/jenkins_tests/run_multi_node_tests.sh +++ b/test/jenkins_tests/run_multi_node_tests.sh @@ -261,6 +261,9 @@ docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/rllib/test/test_supported_spaces.py +docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ + /ray/python/ray/rllib/test/test_rollout.sh + docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/tune/examples/tune_mnist_ray.py \ --smoke-test