mirror of
https://github.com/wassname/ray.git
synced 2026-07-03 02:00:12 +08:00
[tune] Add custom field for serializations (#4237)
This commit is contained in:
+21
-15
@@ -299,6 +299,24 @@ class Trial(object):
|
||||
self.error_file = None
|
||||
self.num_failures = 0
|
||||
|
||||
# AutoML fields
|
||||
self.results = None
|
||||
self.best_result = None
|
||||
self.param_config = None
|
||||
self.extra_arg = None
|
||||
|
||||
self._nonjson_fields = [
|
||||
"_checkpoint",
|
||||
"config",
|
||||
"loggers",
|
||||
"sync_function",
|
||||
"last_result",
|
||||
"results",
|
||||
"best_result",
|
||||
"param_config",
|
||||
"extra_arg",
|
||||
]
|
||||
|
||||
self.trial_name = None
|
||||
if trial_name_creator:
|
||||
self.trial_name = trial_name_creator(self)
|
||||
@@ -521,17 +539,8 @@ class Trial(object):
|
||||
state = self.__dict__.copy()
|
||||
state["resources"] = resources_to_json(self.resources)
|
||||
|
||||
# These are non-pickleable entries.
|
||||
pickle_data = {
|
||||
"_checkpoint": self._checkpoint,
|
||||
"config": self.config,
|
||||
"loggers": self.loggers,
|
||||
"sync_function": self.sync_function,
|
||||
"last_result": self.last_result
|
||||
}
|
||||
|
||||
for key, value in pickle_data.items():
|
||||
state[key] = binary_to_hex(cloudpickle.dumps(value))
|
||||
for key in self._nonjson_fields:
|
||||
state[key] = binary_to_hex(cloudpickle.dumps(state.get(key)))
|
||||
|
||||
state["runner"] = None
|
||||
state["result_logger"] = None
|
||||
@@ -547,10 +556,7 @@ class Trial(object):
|
||||
def __setstate__(self, state):
|
||||
logger_started = state.pop("__logger_started__")
|
||||
state["resources"] = json_to_resources(state["resources"])
|
||||
for key in [
|
||||
"_checkpoint", "config", "loggers", "sync_function",
|
||||
"last_result"
|
||||
]:
|
||||
for key in self._nonjson_fields:
|
||||
state[key] = cloudpickle.loads(hex_to_binary(state[key]))
|
||||
|
||||
self.__dict__.update(state)
|
||||
|
||||
Reference in New Issue
Block a user