[RLlib] Issue 7046 cannot restore keras model from h5 file. (#7482)

This commit is contained in:
Sven Mika
2020-03-23 20:19:30 +01:00
committed by GitHub
parent ee8c9ff732
commit 1138f2ebed
15 changed files with 364 additions and 40 deletions
+4 -1
View File
@@ -473,13 +473,16 @@ class Trainable:
export model to local directory.
Args:
export_formats (list): List of formats that should be exported.
export_formats (Union[list,str]): Format or list of (str) formats
that should be exported.
export_dir (str): Optional dir to place the exported model.
Defaults to self.logdir.
Returns:
A dict that maps ExportFormats to successfully exported models.
"""
if isinstance(export_formats, str):
export_formats = [export_formats]
export_dir = export_dir or self.logdir
return self._export_model(export_formats, export_dir)
+11 -9
View File
@@ -48,28 +48,30 @@ class Location:
class ExportFormat:
"""Describes the format to export the trial Trainable.
"""Describes the format to import/export the trial Trainable.
This may correspond to different file formats based on the
Trainable implementation.
"""
CHECKPOINT = "checkpoint"
MODEL = "model"
H5 = "h5"
@staticmethod
def validate(export_formats):
"""Validates export_formats.
def validate(formats):
"""Validates formats.
Raises:
ValueError if the format is unknown.
"""
for i in range(len(export_formats)):
export_formats[i] = export_formats[i].strip().lower()
if export_formats[i] not in [
ExportFormat.CHECKPOINT, ExportFormat.MODEL
for i in range(len(formats)):
formats[i] = formats[i].strip().lower()
if formats[i] not in [
ExportFormat.CHECKPOINT, ExportFormat.MODEL,
ExportFormat.H5
]:
raise TuneError("Unsupported export format: " +
export_formats[i])
raise TuneError("Unsupported import/export format: " +
formats[i])
def checkpoint_deleter(trial_id, runner):