From 0bfb0d2c2998cfd23cdea6a63adafb29548ed0b1 Mon Sep 17 00:00:00 2001 From: Richard Liaw Date: Fri, 12 Apr 2019 21:03:56 -0700 Subject: [PATCH] [tune] Fix checkpointing for Gym Types --- python/ray/tune/trial_runner.py | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/python/ray/tune/trial_runner.py b/python/ray/tune/trial_runner.py index 7382065ab..9ecfcb540 100644 --- a/python/ray/tune/trial_runner.py +++ b/python/ray/tune/trial_runner.py @@ -45,11 +45,18 @@ def _find_newest_ckpt(ckpt_dir): class _TuneFunctionEncoder(json.JSONEncoder): def default(self, obj): if isinstance(obj, function): - return { - "_type": "function", - "value": binary_to_hex(cloudpickle.dumps(obj)) - } - return super(_TuneFunctionEncoder, self).default(obj) + return self._to_cloudpickle(obj) + try: + return super(_TuneFunctionEncoder, self).default(obj) + except Exception: + logger.debug("Unable to encode. Falling back to cloudpickle.") + return self._to_cloudpickle(obj) + + def _to_cloudpickle(self, obj): + return { + "_type": "CLOUDPICKLE_FALLBACK", + "value": binary_to_hex(cloudpickle.dumps(obj)) + } class _TuneFunctionDecoder(json.JSONDecoder): @@ -58,10 +65,13 @@ class _TuneFunctionDecoder(json.JSONDecoder): self, object_hook=self.object_hook, *args, **kwargs) def object_hook(self, obj): - if obj.get("_type") == "function": - return cloudpickle.loads(hex_to_binary(obj["value"])) + if obj.get("_type") == "CLOUDPICKLE_FALLBACK": + return self._from_cloudpickle(obj) return obj + def _from_cloudpickle(self, obj): + return cloudpickle.loads(hex_to_binary(obj["value"])) + class TrialRunner(object): """A TrialRunner implements the event loop for scheduling trials on Ray.