[tune] Fix checkpointing for Gym Types

This commit is contained in:
Richard Liaw
2019-04-12 21:03:56 -07:00
committed by Eric Liang
parent 6e7680bf21
commit 0bfb0d2c29
+17 -7
View File
@@ -45,11 +45,18 @@ def _find_newest_ckpt(ckpt_dir):
class _TuneFunctionEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, function):
return {
"_type": "function",
"value": binary_to_hex(cloudpickle.dumps(obj))
}
return super(_TuneFunctionEncoder, self).default(obj)
return self._to_cloudpickle(obj)
try:
return super(_TuneFunctionEncoder, self).default(obj)
except Exception:
logger.debug("Unable to encode. Falling back to cloudpickle.")
return self._to_cloudpickle(obj)
def _to_cloudpickle(self, obj):
return {
"_type": "CLOUDPICKLE_FALLBACK",
"value": binary_to_hex(cloudpickle.dumps(obj))
}
class _TuneFunctionDecoder(json.JSONDecoder):
@@ -58,10 +65,13 @@ class _TuneFunctionDecoder(json.JSONDecoder):
self, object_hook=self.object_hook, *args, **kwargs)
def object_hook(self, obj):
if obj.get("_type") == "function":
return cloudpickle.loads(hex_to_binary(obj["value"]))
if obj.get("_type") == "CLOUDPICKLE_FALLBACK":
return self._from_cloudpickle(obj)
return obj
def _from_cloudpickle(self, obj):
return cloudpickle.loads(hex_to_binary(obj["value"]))
class TrialRunner(object):
"""A TrialRunner implements the event loop for scheduling trials on Ray.