mirror of
https://github.com/wassname/ray.git
synced 2026-07-02 08:29:54 +08:00
[rllib] Fix rllib rollouts script and add test (#3211)
## What do these changes do?

Clean up the checkpointing to handle the new checkpoint dirs. Add a test for rollout.py.

## Related issue number

https://github.com/ray-project/ray/issues/3206
https://github.com/ray-project/ray/issues/3204
This commit is contained in:
@@ -380,12 +380,11 @@ class Agent(Trainable):
|
||||
def _save(self, checkpoint_dir):
|
||||
checkpoint_path = os.path.join(checkpoint_dir,
|
||||
"checkpoint-{}".format(self.iteration))
|
||||
pickle.dump(self.__getstate__(),
|
||||
open(checkpoint_path + ".agent_state", "wb"))
|
||||
pickle.dump(self.__getstate__(), open(checkpoint_path, "wb"))
|
||||
return checkpoint_path
|
||||
|
||||
def _restore(self, checkpoint_path):
|
||||
extra_data = pickle.load(open(checkpoint_path + ".agent_state", "rb"))
|
||||
extra_data = pickle.load(open(checkpoint_path, "rb"))
|
||||
self.__setstate__(extra_data)
|
||||
|
||||
|
||||
|
||||
@@ -54,7 +54,7 @@ def create_parser(parser_creator=None):
|
||||
const=True,
|
||||
help="Surpress rendering of the environment.")
|
||||
parser.add_argument(
|
||||
"--steps", default=None, help="Number of steps to roll out.")
|
||||
"--steps", default=10000, help="Number of steps to roll out.")
|
||||
parser.add_argument("--out", default=None, help="Output filename.")
|
||||
parser.add_argument(
|
||||
"--config",
|
||||
@@ -70,6 +70,12 @@ def run(args, parser):
|
||||
# Load configuration from file
|
||||
config_dir = os.path.dirname(args.checkpoint)
|
||||
config_path = os.path.join(config_dir, "params.json")
|
||||
if not os.path.exists(config_path):
|
||||
config_path = os.path.join(config_dir, "../params.json")
|
||||
if not os.path.exists(config_path):
|
||||
raise ValueError(
|
||||
"Could not find params.json in either the checkpoint dir or "
|
||||
"its parent directory.")
|
||||
with open(config_path) as f:
|
||||
args.config = json.load(f)
|
||||
|
||||
|
||||
Executable
+28
@@ -0,0 +1,28 @@
|
||||
#!/bin/bash -e
|
||||
|
||||
TRAIN=/ray/python/ray/rllib/train.py
|
||||
if [ ! -e "$TRAIN" ]; then
|
||||
TRAIN=../train.py
|
||||
fi
|
||||
ROLLOUT=/ray/python/ray/rllib/rollout.py
|
||||
if [ ! -e "$ROLLOUT" ]; then
|
||||
ROLLOUT=../rollout.py
|
||||
fi
|
||||
|
||||
TMP=`mktemp -d`
|
||||
echo "Saving results to $TMP"
|
||||
|
||||
$TRAIN --local-dir=$TMP --run=IMPALA --checkpoint-freq=1 \
|
||||
--config='{"num_workers": 1, "num_gpus": 0}' --env=Pong-ram-v4 \
|
||||
--stop='{"training_iteration": 1}'
|
||||
find $TMP
|
||||
|
||||
CHECKPOINT_PATH=`ls $TMP/default/*/checkpoint_1/checkpoint-1`
|
||||
echo "Checkpoint path $CHECKPOINT_PATH"
|
||||
test -e "$CHECKPOINT_PATH"
|
||||
|
||||
$ROLLOUT --run=IMPALA "$CHECKPOINT_PATH" --steps=100 \
|
||||
--out="$TMP/rollouts.pkl" --no-render
|
||||
test -e "$TMP/rollouts.pkl"
|
||||
rm -rf "$TMP"
|
||||
echo "OK"
|
||||
@@ -23,7 +23,7 @@ apex:
|
||||
prioritized_replay_alpha: 0.5
|
||||
beta_annealing_fraction: 1.0
|
||||
final_prioritized_replay_beta: 1.0
|
||||
gpu: false
|
||||
gpu: true
|
||||
|
||||
# APEX
|
||||
num_workers: 8
|
||||
|
||||
@@ -159,12 +159,13 @@ class TrainMNIST(Trainable):
|
||||
self._train_iteration()
|
||||
return self._test()
|
||||
|
||||
def _save(self, path):
|
||||
torch.save(self.model.state_dict(), os.path.join(path, "model.pth"))
|
||||
return path
|
||||
def _save(self, checkpoint_dir):
|
||||
checkpoint_path = os.path.join(checkpoint_dir, "model.pth")
|
||||
torch.save(self.model.state_dict(), checkpoint_path)
|
||||
return checkpoint_path
|
||||
|
||||
def _restore(self, path):
|
||||
self.model.load_state_dict(os.path.join(path, "model.pth"))
|
||||
def _restore(self, checkpoint_path):
|
||||
self.model.load_state_dict(checkpoint_path)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@@ -627,7 +627,10 @@ class RunExperimentTest(unittest.TestCase):
|
||||
return {"timesteps_this_iter": 1, "done": True}
|
||||
|
||||
def _save(self, path):
|
||||
return path
|
||||
checkpoint = path + "/checkpoint"
|
||||
with open(checkpoint, "w") as f:
|
||||
f.write("OK")
|
||||
return checkpoint
|
||||
|
||||
trials = run_experiments({
|
||||
"foo": {
|
||||
|
||||
@@ -214,17 +214,28 @@ class Trainable(object):
|
||||
Checkpoint path that may be passed to restore().
|
||||
"""
|
||||
|
||||
checkpoint_path = tempfile.mkdtemp(
|
||||
prefix="checkpoint_{}".format(self._iteration),
|
||||
dir=checkpoint_dir or self.logdir)
|
||||
checkpoint = self._save(checkpoint_path)
|
||||
checkpoint_dir = os.path.join(checkpoint_dir or self.logdir,
|
||||
"checkpoint_{}".format(self._iteration))
|
||||
os.makedirs(checkpoint_dir)
|
||||
checkpoint = self._save(checkpoint_dir)
|
||||
saved_as_dict = False
|
||||
if isinstance(checkpoint, str):
|
||||
if (not checkpoint.startswith(checkpoint_dir)
|
||||
or checkpoint == checkpoint_dir):
|
||||
raise ValueError(
|
||||
"The returned checkpoint path must be within the "
|
||||
"given checkpoint dir {}: {}".format(
|
||||
checkpoint_dir, checkpoint))
|
||||
if not os.path.exists(checkpoint):
|
||||
raise ValueError(
|
||||
"The returned checkpoint path does not exist: {}".format(
|
||||
checkpoint))
|
||||
checkpoint_path = checkpoint
|
||||
elif isinstance(checkpoint, dict):
|
||||
saved_as_dict = True
|
||||
pickle.dump(checkpoint, open(checkpoint_path + ".tune_state",
|
||||
"wb"))
|
||||
checkpoint_path = os.path.join(checkpoint_dir, "checkpoint")
|
||||
with open(checkpoint_path, "wb") as f:
|
||||
pickle.dump(checkpoint, f)
|
||||
else:
|
||||
raise ValueError("Return value from `_save` must be dict or str.")
|
||||
pickle.dump({
|
||||
@@ -286,7 +297,7 @@ class Trainable(object):
|
||||
self._episodes_total = metadata["episodes_total"]
|
||||
saved_as_dict = metadata["saved_as_dict"]
|
||||
if saved_as_dict:
|
||||
with open(checkpoint_path + ".tune_state", "rb") as loaded_state:
|
||||
with open(checkpoint_path, "rb") as loaded_state:
|
||||
checkpoint_dict = pickle.load(loaded_state)
|
||||
self._restore(checkpoint_dict)
|
||||
else:
|
||||
@@ -343,7 +354,7 @@ class Trainable(object):
|
||||
|
||||
Args:
|
||||
checkpoint_dir (str): The directory where the checkpoint
|
||||
can be stored.
|
||||
file must be stored.
|
||||
|
||||
Returns:
|
||||
checkpoint (str | dict): If string, the return value is
|
||||
@@ -352,8 +363,10 @@ class Trainable(object):
|
||||
serialized by Tune and passed to `_restore()`.
|
||||
|
||||
Examples:
|
||||
>>> checkpoint_data = trainable._save(checkpoint_dir)
|
||||
>>> trainable2._restore(checkpoint_data)
|
||||
>>> print(trainable1._save("/tmp/checkpoint_1"))
|
||||
"/tmp/checkpoint_1/my_checkpoint_file"
|
||||
>>> print(trainable2._save("/tmp/checkpoint_2"))
|
||||
{"some": "data"}
|
||||
"""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
Reference in New Issue
Block a user