diff --git a/python/ray/tune/trial.py b/python/ray/tune/trial.py index 0989ba3f4..be85eeb77 100644 --- a/python/ray/tune/trial.py +++ b/python/ray/tune/trial.py @@ -299,6 +299,24 @@ class Trial(object): self.error_file = None self.num_failures = 0 + # AutoML fields + self.results = None + self.best_result = None + self.param_config = None + self.extra_arg = None + + self._nonjson_fields = [ + "_checkpoint", + "config", + "loggers", + "sync_function", + "last_result", + "results", + "best_result", + "param_config", + "extra_arg", + ] + self.trial_name = None if trial_name_creator: self.trial_name = trial_name_creator(self) @@ -521,17 +539,8 @@ class Trial(object): state = self.__dict__.copy() state["resources"] = resources_to_json(self.resources) - # These are non-pickleable entries. - pickle_data = { - "_checkpoint": self._checkpoint, - "config": self.config, - "loggers": self.loggers, - "sync_function": self.sync_function, - "last_result": self.last_result - } - - for key, value in pickle_data.items(): - state[key] = binary_to_hex(cloudpickle.dumps(value)) + for key in self._nonjson_fields: + state[key] = binary_to_hex(cloudpickle.dumps(state.get(key))) state["runner"] = None state["result_logger"] = None @@ -547,10 +556,7 @@ class Trial(object): def __setstate__(self, state): logger_started = state.pop("__logger_started__") state["resources"] = json_to_resources(state["resources"]) - for key in [ - "_checkpoint", "config", "loggers", "sync_function", - "last_result" - ]: + for key in self._nonjson_fields: state[key] = cloudpickle.loads(hex_to_binary(state[key])) self.__dict__.update(state)