[rllib] Fix rllib rollouts script and add test (#3211)

## What do these changes do?

Clean up the checkpointing to handle the new checkpoint dirs. Add a test for rollout.py

## Related issue number

https://github.com/ray-project/ray/issues/3206
https://github.com/ray-project/ray/issues/3204
This commit is contained in:
Eric Liang
2018-11-05 00:33:25 -08:00
committed by Richard Liaw
parent 99bac44375
commit 813f51769f
8 changed files with 74 additions and 21 deletions
+2 -3
View File
@@ -380,12 +380,11 @@ class Agent(Trainable):
def _save(self, checkpoint_dir):
checkpoint_path = os.path.join(checkpoint_dir,
"checkpoint-{}".format(self.iteration))
pickle.dump(self.__getstate__(),
open(checkpoint_path + ".agent_state", "wb"))
pickle.dump(self.__getstate__(), open(checkpoint_path, "wb"))
return checkpoint_path
def _restore(self, checkpoint_path):
extra_data = pickle.load(open(checkpoint_path + ".agent_state", "rb"))
extra_data = pickle.load(open(checkpoint_path, "rb"))
self.__setstate__(extra_data)
+7 -1
View File
@@ -54,7 +54,7 @@ def create_parser(parser_creator=None):
const=True,
help="Surpress rendering of the environment.")
parser.add_argument(
"--steps", default=None, help="Number of steps to roll out.")
"--steps", default=10000, help="Number of steps to roll out.")
parser.add_argument("--out", default=None, help="Output filename.")
parser.add_argument(
"--config",
@@ -70,6 +70,12 @@ def run(args, parser):
# Load configuration from file
config_dir = os.path.dirname(args.checkpoint)
config_path = os.path.join(config_dir, "params.json")
if not os.path.exists(config_path):
config_path = os.path.join(config_dir, "../params.json")
if not os.path.exists(config_path):
raise ValueError(
"Could not find params.json in either the checkpoint dir or "
"its parent directory.")
with open(config_path) as f:
args.config = json.load(f)
+28
View File
@@ -0,0 +1,28 @@
#!/bin/bash -e
TRAIN=/ray/python/ray/rllib/train.py
if [ ! -e "$TRAIN" ]; then
TRAIN=../train.py
fi
ROLLOUT=/ray/python/ray/rllib/rollout.py
if [ ! -e "$ROLLOUT" ]; then
ROLLOUT=../rollout.py
fi
TMP=`mktemp -d`
echo "Saving results to $TMP"
$TRAIN --local-dir=$TMP --run=IMPALA --checkpoint-freq=1 \
--config='{"num_workers": 1, "num_gpus": 0}' --env=Pong-ram-v4 \
--stop='{"training_iteration": 1}'
find $TMP
CHECKPOINT_PATH=`ls $TMP/default/*/checkpoint_1/checkpoint-1`
echo "Checkpoint path $CHECKPOINT_PATH"
test -e "$CHECKPOINT_PATH"
$ROLLOUT --run=IMPALA "$CHECKPOINT_PATH" --steps=100 \
--out="$TMP/rollouts.pkl" --no-render
test -e "$TMP/rollouts.pkl"
rm -rf "$TMP"
echo "OK"
@@ -23,7 +23,7 @@ apex:
prioritized_replay_alpha: 0.5
beta_annealing_fraction: 1.0
final_prioritized_replay_beta: 1.0
gpu: false
gpu: true
# APEX
num_workers: 8
@@ -159,12 +159,13 @@ class TrainMNIST(Trainable):
self._train_iteration()
return self._test()
def _save(self, path):
torch.save(self.model.state_dict(), os.path.join(path, "model.pth"))
return path
def _save(self, checkpoint_dir):
checkpoint_path = os.path.join(checkpoint_dir, "model.pth")
torch.save(self.model.state_dict(), checkpoint_path)
return checkpoint_path
def _restore(self, path):
self.model.load_state_dict(os.path.join(path, "model.pth"))
def _restore(self, checkpoint_path):
self.model.load_state_dict(checkpoint_path)
if __name__ == '__main__':
+4 -1
View File
@@ -627,7 +627,10 @@ class RunExperimentTest(unittest.TestCase):
return {"timesteps_this_iter": 1, "done": True}
def _save(self, path):
return path
checkpoint = path + "/checkpoint"
with open(checkpoint, "w") as f:
f.write("OK")
return checkpoint
trials = run_experiments({
"foo": {
+23 -10
View File
@@ -214,17 +214,28 @@ class Trainable(object):
Checkpoint path that may be passed to restore().
"""
checkpoint_path = tempfile.mkdtemp(
prefix="checkpoint_{}".format(self._iteration),
dir=checkpoint_dir or self.logdir)
checkpoint = self._save(checkpoint_path)
checkpoint_dir = os.path.join(checkpoint_dir or self.logdir,
"checkpoint_{}".format(self._iteration))
os.makedirs(checkpoint_dir)
checkpoint = self._save(checkpoint_dir)
saved_as_dict = False
if isinstance(checkpoint, str):
if (not checkpoint.startswith(checkpoint_dir)
or checkpoint == checkpoint_dir):
raise ValueError(
"The returned checkpoint path must be within the "
"given checkpoint dir {}: {}".format(
checkpoint_dir, checkpoint))
if not os.path.exists(checkpoint):
raise ValueError(
"The returned checkpoint path does not exist: {}".format(
checkpoint))
checkpoint_path = checkpoint
elif isinstance(checkpoint, dict):
saved_as_dict = True
pickle.dump(checkpoint, open(checkpoint_path + ".tune_state",
"wb"))
checkpoint_path = os.path.join(checkpoint_dir, "checkpoint")
with open(checkpoint_path, "wb") as f:
pickle.dump(checkpoint, f)
else:
raise ValueError("Return value from `_save` must be dict or str.")
pickle.dump({
@@ -286,7 +297,7 @@ class Trainable(object):
self._episodes_total = metadata["episodes_total"]
saved_as_dict = metadata["saved_as_dict"]
if saved_as_dict:
with open(checkpoint_path + ".tune_state", "rb") as loaded_state:
with open(checkpoint_path, "rb") as loaded_state:
checkpoint_dict = pickle.load(loaded_state)
self._restore(checkpoint_dict)
else:
@@ -343,7 +354,7 @@ class Trainable(object):
Args:
checkpoint_dir (str): The directory where the checkpoint
can be stored.
file must be stored.
Returns:
checkpoint (str | dict): If string, the return value is
@@ -352,8 +363,10 @@ class Trainable(object):
serialized by Tune and passed to `_restore()`.
Examples:
>>> checkpoint_data = trainable._save(checkpoint_dir)
>>> trainable2._restore(checkpoint_data)
>>> print(trainable1._save("/tmp/checkpoint_1"))
"/tmp/checkpoint_1/my_checkpoint_file"
>>> print(trainable2._save("/tmp/checkpoint_2"))
{"some": "data"}
"""
raise NotImplementedError