diff --git a/python/ray/tune/trial_runner.py b/python/ray/tune/trial_runner.py index 7382065ab..9ecfcb540 100644 --- a/python/ray/tune/trial_runner.py +++ b/python/ray/tune/trial_runner.py @@ -45,11 +45,18 @@ def _find_newest_ckpt(ckpt_dir): class _TuneFunctionEncoder(json.JSONEncoder): def default(self, obj): if isinstance(obj, function): - return { - "_type": "function", - "value": binary_to_hex(cloudpickle.dumps(obj)) - } - return super(_TuneFunctionEncoder, self).default(obj) + return self._to_cloudpickle(obj) + try: + return super(_TuneFunctionEncoder, self).default(obj) + except Exception: + logger.debug("Unable to encode. Falling back to cloudpickle.") + return self._to_cloudpickle(obj) + + def _to_cloudpickle(self, obj): + return { + "_type": "CLOUDPICKLE_FALLBACK", + "value": binary_to_hex(cloudpickle.dumps(obj)) + } class _TuneFunctionDecoder(json.JSONDecoder): @@ -58,10 +65,13 @@ class _TuneFunctionDecoder(json.JSONDecoder): self, object_hook=self.object_hook, *args, **kwargs) def object_hook(self, obj): - if obj.get("_type") == "function": - return cloudpickle.loads(hex_to_binary(obj["value"])) + if obj.get("_type") == "CLOUDPICKLE_FALLBACK": + return self._from_cloudpickle(obj) return obj + def _from_cloudpickle(self, obj): + return cloudpickle.loads(hex_to_binary(obj["value"])) + class TrialRunner(object): """A TrialRunner implements the event loop for scheduling trials on Ray.