diff --git a/rllib/env/tests/test_local_inference.sh b/rllib/env/tests/test_local_inference.sh index 7544d0cfb..f553ee32a 100755 --- a/rllib/env/tests/test_local_inference.sh +++ b/rllib/env/tests/test_local_inference.sh @@ -10,7 +10,8 @@ else basedir="rllib/examples/serving" # In bazel. fi -(python $basedir/cartpole_server.py --run=PPO 2>&1 | grep -v 200) & +# Do not attempt to restore from checkpoint; leads to errors on travis. +(python $basedir/cartpole_server.py --run=PPO --no-restore 2>&1 | grep -v 200) & pid=$! echo "Waiting for server to start" diff --git a/rllib/env/tests/test_remote_inference.sh b/rllib/env/tests/test_remote_inference.sh index 88dce1a32..bbfa21c70 100755 --- a/rllib/env/tests/test_remote_inference.sh +++ b/rllib/env/tests/test_remote_inference.sh @@ -10,7 +10,8 @@ else basedir="rllib/examples/serving" # In bazel. fi -(python $basedir/cartpole_server.py --run=DQN 2>&1 | grep -v 200) & +# Do not attempt to restore from checkpoint; leads to errors on travis. +(python $basedir/cartpole_server.py --run=DQN --no-restore 2>&1 | grep -v 200) & pid=$! echo "Waiting for server to start" diff --git a/rllib/examples/serving/cartpole_server.py b/rllib/examples/serving/cartpole_server.py index 167d424a8..297320422 100755 --- a/rllib/examples/serving/cartpole_server.py +++ b/rllib/examples/serving/cartpole_server.py @@ -23,6 +23,11 @@ parser = argparse.ArgumentParser() parser.add_argument("--run", type=str, default="DQN") parser.add_argument( "--framework", type=str, choices=["tf", "torch"], default="tf") +parser.add_argument( + "--no-restore", + action="store_true", + help="Do not restore from a previously saved checkpoint (location of " + "which is saved in `last_checkpoint_[algo-name].out`).") if __name__ == "__main__": args = parser.parse_args() @@ -65,13 +70,13 @@ if __name__ == "__main__": checkpoint_path = CHECKPOINT_FILE.format(args.run) - # Attempt to restore from checkpoint if possible. - if os.path.exists(checkpoint_path): + # Attempt to restore from checkpoint, if possible. + if not args.no_restore and os.path.exists(checkpoint_path): checkpoint_path = open(checkpoint_path).read() print("Restoring from checkpoint path", checkpoint_path) trainer.restore(checkpoint_path) - # Serving and training loop + # Serving and training loop. while True: print(pretty_print(trainer.train())) checkpoint = trainer.save()