From 9551f2a92ecdb85d38ef1f745945c8d848506131 Mon Sep 17 00:00:00 2001 From: Adi Zimmerman Date: Sun, 3 Mar 2019 14:23:05 -0800 Subject: [PATCH] [tune] Properly handle closing files in Trainable (#4232) Fixes #3965. Using the with keyword/block will close to file immediately after the block ends --- python/ray/tune/trainable.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/python/ray/tune/trainable.py b/python/ray/tune/trainable.py index 938aa0b63..32ea413b6 100644 --- a/python/ray/tune/trainable.py +++ b/python/ray/tune/trainable.py @@ -245,14 +245,15 @@ class Trainable(object): raise ValueError( "`_save` must return a dict or string type: {}".format( str(type(checkpoint)))) - pickle.dump({ - "experiment_id": self._experiment_id, - "iteration": self._iteration, - "timesteps_total": self._timesteps_total, - "time_total": self._time_total, - "episodes_total": self._episodes_total, - "saved_as_dict": saved_as_dict - }, open(checkpoint_path + ".tune_metadata", "wb")) + with open(checkpoint_path + ".tune_metadata", "wb") as f: + pickle.dump({ + "experiment_id": self._experiment_id, + "iteration": self._iteration, + "timesteps_total": self._timesteps_total, + "time_total": self._time_total, + "episodes_total": self._episodes_total, + "saved_as_dict": saved_as_dict + }, f) return checkpoint_path def save_to_object(self): @@ -271,7 +272,8 @@ class Trainable(object): for path in os.listdir(base_dir): path = os.path.join(base_dir, path) if path.startswith(checkpoint_prefix): - data[os.path.basename(path)] = open(path, "rb").read() + with open(path, "rb") as f: + data[os.path.basename(path)] = f.read() out = io.BytesIO() data_dict = pickle.dumps({ @@ -294,7 +296,8 @@ class Trainable(object): This method restores additional metadata saved with the checkpoint. """ - metadata = pickle.load(open(checkpoint_path + ".tune_metadata", "rb")) + with open(checkpoint_path + ".tune_metadata", "rb") as f: + metadata = pickle.load(f) self._experiment_id = metadata["experiment_id"] self._iteration = metadata["iteration"] self._timesteps_total = metadata["timesteps_total"]