mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 17:18:45 +08:00
[tune] Fix checkpointing for Gym Types
This commit is contained in:
@@ -45,11 +45,18 @@ def _find_newest_ckpt(ckpt_dir):
|
||||
class _TuneFunctionEncoder(json.JSONEncoder):
|
||||
def default(self, obj):
|
||||
if isinstance(obj, function):
|
||||
return {
|
||||
"_type": "function",
|
||||
"value": binary_to_hex(cloudpickle.dumps(obj))
|
||||
}
|
||||
return super(_TuneFunctionEncoder, self).default(obj)
|
||||
return self._to_cloudpickle(obj)
|
||||
try:
|
||||
return super(_TuneFunctionEncoder, self).default(obj)
|
||||
except Exception:
|
||||
logger.debug("Unable to encode. Falling back to cloudpickle.")
|
||||
return self._to_cloudpickle(obj)
|
||||
|
||||
def _to_cloudpickle(self, obj):
|
||||
return {
|
||||
"_type": "CLOUDPICKLE_FALLBACK",
|
||||
"value": binary_to_hex(cloudpickle.dumps(obj))
|
||||
}
|
||||
|
||||
|
||||
class _TuneFunctionDecoder(json.JSONDecoder):
|
||||
@@ -58,10 +65,13 @@ class _TuneFunctionDecoder(json.JSONDecoder):
|
||||
self, object_hook=self.object_hook, *args, **kwargs)
|
||||
|
||||
def object_hook(self, obj):
|
||||
if obj.get("_type") == "function":
|
||||
return cloudpickle.loads(hex_to_binary(obj["value"]))
|
||||
if obj.get("_type") == "CLOUDPICKLE_FALLBACK":
|
||||
return self._from_cloudpickle(obj)
|
||||
return obj
|
||||
|
||||
def _from_cloudpickle(self, obj):
|
||||
return cloudpickle.loads(hex_to_binary(obj["value"]))
|
||||
|
||||
|
||||
class TrialRunner(object):
|
||||
"""A TrialRunner implements the event loop for scheduling trials on Ray.
|
||||
|
||||
Reference in New Issue
Block a user