mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 10:33:16 +08:00
[RLlib] Issue 7046 cannot restore keras model from h5 file. (#7482)
This commit is contained in:
@@ -473,13 +473,16 @@ class Trainable:
|
||||
export model to local directory.
|
||||
|
||||
Args:
|
||||
export_formats (list): List of formats that should be exported.
|
||||
export_formats (Union[list,str]): Format or list of (str) formats
|
||||
that should be exported.
|
||||
export_dir (str): Optional dir to place the exported model.
|
||||
Defaults to self.logdir.
|
||||
|
||||
Returns:
|
||||
A dict that maps ExportFormats to successfully exported models.
|
||||
"""
|
||||
if isinstance(export_formats, str):
|
||||
export_formats = [export_formats]
|
||||
export_dir = export_dir or self.logdir
|
||||
return self._export_model(export_formats, export_dir)
|
||||
|
||||
|
||||
@@ -48,28 +48,30 @@ class Location:
|
||||
|
||||
|
||||
class ExportFormat:
|
||||
"""Describes the format to export the trial Trainable.
|
||||
"""Describes the format to import/export the trial Trainable.
|
||||
|
||||
This may correspond to different file formats based on the
|
||||
Trainable implementation.
|
||||
"""
|
||||
CHECKPOINT = "checkpoint"
|
||||
MODEL = "model"
|
||||
H5 = "h5"
|
||||
|
||||
@staticmethod
|
||||
def validate(export_formats):
|
||||
"""Validates export_formats.
|
||||
def validate(formats):
|
||||
"""Validates formats.
|
||||
|
||||
Raises:
|
||||
ValueError if the format is unknown.
|
||||
"""
|
||||
for i in range(len(export_formats)):
|
||||
export_formats[i] = export_formats[i].strip().lower()
|
||||
if export_formats[i] not in [
|
||||
ExportFormat.CHECKPOINT, ExportFormat.MODEL
|
||||
for i in range(len(formats)):
|
||||
formats[i] = formats[i].strip().lower()
|
||||
if formats[i] not in [
|
||||
ExportFormat.CHECKPOINT, ExportFormat.MODEL,
|
||||
ExportFormat.H5
|
||||
]:
|
||||
raise TuneError("Unsupported export format: " +
|
||||
export_formats[i])
|
||||
raise TuneError("Unsupported import/export format: " +
|
||||
formats[i])
|
||||
|
||||
|
||||
def checkpoint_deleter(trial_id, runner):
|
||||
|
||||
Reference in New Issue
Block a user