[tune] Add custom field for serializations (#4237)

This commit is contained in:
Richard Liaw
2019-03-08 11:00:25 -08:00
committed by GitHub
parent 7e4b4822cf
commit c3a3360a4a
+21 -15
View File
@@ -299,6 +299,24 @@ class Trial(object):
self.error_file = None
self.num_failures = 0
# AutoML fields
self.results = None
self.best_result = None
self.param_config = None
self.extra_arg = None
self._nonjson_fields = [
"_checkpoint",
"config",
"loggers",
"sync_function",
"last_result",
"results",
"best_result",
"param_config",
"extra_arg",
]
self.trial_name = None
if trial_name_creator:
self.trial_name = trial_name_creator(self)
@@ -521,17 +539,8 @@ class Trial(object):
state = self.__dict__.copy()
state["resources"] = resources_to_json(self.resources)
# These are non-pickleable entries.
pickle_data = {
"_checkpoint": self._checkpoint,
"config": self.config,
"loggers": self.loggers,
"sync_function": self.sync_function,
"last_result": self.last_result
}
for key, value in pickle_data.items():
state[key] = binary_to_hex(cloudpickle.dumps(value))
for key in self._nonjson_fields:
state[key] = binary_to_hex(cloudpickle.dumps(state.get(key)))
state["runner"] = None
state["result_logger"] = None
@@ -547,10 +556,7 @@ class Trial(object):
def __setstate__(self, state):
logger_started = state.pop("__logger_started__")
state["resources"] = json_to_resources(state["resources"])
for key in [
"_checkpoint", "config", "loggers", "sync_function",
"last_result"
]:
for key in self._nonjson_fields:
state[key] = cloudpickle.loads(hex_to_binary(state[key]))
self.__dict__.update(state)