mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 18:45:03 +08:00
[tune] Properly handle closing files in Trainable (#4232)
Fixes #3965. Using the with keyword/block will close to file immediately after the block ends
This commit is contained in:
committed by
Richard Liaw
parent
3483282254
commit
9551f2a92e
@@ -245,14 +245,15 @@ class Trainable(object):
|
||||
raise ValueError(
|
||||
"`_save` must return a dict or string type: {}".format(
|
||||
str(type(checkpoint))))
|
||||
pickle.dump({
|
||||
"experiment_id": self._experiment_id,
|
||||
"iteration": self._iteration,
|
||||
"timesteps_total": self._timesteps_total,
|
||||
"time_total": self._time_total,
|
||||
"episodes_total": self._episodes_total,
|
||||
"saved_as_dict": saved_as_dict
|
||||
}, open(checkpoint_path + ".tune_metadata", "wb"))
|
||||
with open(checkpoint_path + ".tune_metadata", "wb") as f:
|
||||
pickle.dump({
|
||||
"experiment_id": self._experiment_id,
|
||||
"iteration": self._iteration,
|
||||
"timesteps_total": self._timesteps_total,
|
||||
"time_total": self._time_total,
|
||||
"episodes_total": self._episodes_total,
|
||||
"saved_as_dict": saved_as_dict
|
||||
}, f)
|
||||
return checkpoint_path
|
||||
|
||||
def save_to_object(self):
|
||||
@@ -271,7 +272,8 @@ class Trainable(object):
|
||||
for path in os.listdir(base_dir):
|
||||
path = os.path.join(base_dir, path)
|
||||
if path.startswith(checkpoint_prefix):
|
||||
data[os.path.basename(path)] = open(path, "rb").read()
|
||||
with open(path, "rb") as f:
|
||||
data[os.path.basename(path)] = f.read()
|
||||
|
||||
out = io.BytesIO()
|
||||
data_dict = pickle.dumps({
|
||||
@@ -294,7 +296,8 @@ class Trainable(object):
|
||||
This method restores additional metadata saved with the checkpoint.
|
||||
"""
|
||||
|
||||
metadata = pickle.load(open(checkpoint_path + ".tune_metadata", "rb"))
|
||||
with open(checkpoint_path + ".tune_metadata", "rb") as f:
|
||||
metadata = pickle.load(f)
|
||||
self._experiment_id = metadata["experiment_id"]
|
||||
self._iteration = metadata["iteration"]
|
||||
self._timesteps_total = metadata["timesteps_total"]
|
||||
|
||||
Reference in New Issue
Block a user