[RLlib] Deflake 2x remote & local inference tests (external env). (#13459)

This commit is contained in:
Sven Mika
2021-01-14 20:44:26 +01:00
committed by GitHub
parent c89ebdd94a
commit d98235cc84
3 changed files with 12 additions and 5 deletions
+2 -1
View File
@@ -10,7 +10,8 @@ else
basedir="rllib/examples/serving" # In bazel.
fi
(python $basedir/cartpole_server.py --run=PPO 2>&1 | grep -v 200) &
# Do not attempt to restore from checkpoint; leads to errors on travis.
(python $basedir/cartpole_server.py --run=PPO --no-restore 2>&1 | grep -v 200) &
pid=$!
echo "Waiting for server to start"
+2 -1
View File
@@ -10,7 +10,8 @@ else
basedir="rllib/examples/serving" # In bazel.
fi
(python $basedir/cartpole_server.py --run=DQN 2>&1 | grep -v 200) &
# Do not attempt to restore from checkpoint; leads to errors on travis.
(python $basedir/cartpole_server.py --run=DQN --no-restore 2>&1 | grep -v 200) &
pid=$!
echo "Waiting for server to start"
+8 -3
View File
@@ -23,6 +23,11 @@ parser = argparse.ArgumentParser()
parser.add_argument("--run", type=str, default="DQN")
parser.add_argument(
"--framework", type=str, choices=["tf", "torch"], default="tf")
parser.add_argument(
"--no-restore",
action="store_true",
help="Do not restore from a previously saved checkpoint (location of "
"which is saved in `last_checkpoint_[algo-name].out`).")
if __name__ == "__main__":
args = parser.parse_args()
@@ -65,13 +70,13 @@ if __name__ == "__main__":
checkpoint_path = CHECKPOINT_FILE.format(args.run)
# Attempt to restore from checkpoint if possible.
if os.path.exists(checkpoint_path):
# Attempt to restore from checkpoint, if possible.
if not args.no_restore and os.path.exists(checkpoint_path):
checkpoint_path = open(checkpoint_path).read()
print("Restoring from checkpoint path", checkpoint_path)
trainer.restore(checkpoint_path)
# Serving and training loop
# Serving and training loop.
while True:
print(pretty_print(trainer.train()))
checkpoint = trainer.save()