[tune] Support all serializable objects in config (#2287)

* wip

* order

* lint
This commit is contained in:
Eric Liang
2018-06-23 16:13:46 -07:00
committed by GitHub
parent aa42331844
commit 9c3bab5c42
4 changed files with 25 additions and 13 deletions
@@ -36,6 +36,7 @@ import ray
from ray.tune import grid_search, run_experiments, register_trainable
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
import tensorflow as tf
@@ -227,6 +228,9 @@ if __name__ == '__main__':
},
'config': {
'activation': grid_search(['relu', 'elu', 'tanh']),
# You can pass any serializable object as well
'foo': grid_search([np.array([1, 2]),
np.array([2, 3])]),
},
}
+14 -11
View File
@@ -89,12 +89,12 @@ class _JsonLogger(Logger):
def _init(self):
config_out = os.path.join(self.logdir, "params.json")
with open(config_out, "w") as f:
json.dump(self.config, f, sort_keys=True, cls=_CustomEncoder)
json.dump(self.config, f, sort_keys=True, cls=_SafeFallbackEncoder)
local_file = os.path.join(self.logdir, "result.json")
self.local_out = open(local_file, "w")
def on_result(self, result):
json.dump(result._asdict(), self, cls=_CustomEncoder)
json.dump(result._asdict(), self, cls=_SafeFallbackEncoder)
self.write("\n")
def write(self, b):
@@ -150,9 +150,9 @@ class _VisKitLogger(Logger):
self._file.close()
class _CustomEncoder(json.JSONEncoder):
class _SafeFallbackEncoder(json.JSONEncoder):
def __init__(self, nan_str="null", **kwargs):
super(_CustomEncoder, self).__init__(**kwargs)
super(_SafeFallbackEncoder, self).__init__(**kwargs)
self.nan_str = nan_str
def iterencode(self, o, _one_shot=False):
@@ -171,12 +171,15 @@ class _CustomEncoder(json.JSONEncoder):
return _iterencode(o, 0)
def default(self, value):
if np.isnan(value):
return None
if np.issubdtype(value, float):
return float(value)
if np.issubdtype(value, int):
return int(value)
try:
if np.isnan(value):
return None
if np.issubdtype(value, float):
return float(value)
if np.issubdtype(value, int):
return int(value)
except Exception:
return str(value) # give up, just stringify it (ok for logs)
def pretty_print(result):
@@ -186,5 +189,5 @@ def pretty_print(result):
if v is not None:
out[k] = v
cleaned = json.dumps(out, cls=_CustomEncoder)
cleaned = json.dumps(out, cls=_SafeFallbackEncoder)
return yaml.safe_dump(json.loads(cleaned), default_flow_style=False)
+6 -1
View File
@@ -1,3 +1,7 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import json
import numpy
@@ -6,6 +10,7 @@ import random
import types
from ray.tune import TuneError
from ray.tune.logger import _SafeFallbackEncoder
from ray.tune.trial import Trial
from ray.tune.config_parser import make_parser, json_to_resources
@@ -17,7 +22,7 @@ def to_argv(config):
if isinstance(v, str):
argv.append(v)
else:
argv.append(json.dumps(v))
argv.append(json.dumps(v, cls=_SafeFallbackEncoder))
return argv