mirror of
https://github.com/wassname/ray.git
synced 2026-06-30 19:08:52 +08:00
[tune] Support all serializable objects in config (#2287)
* wip * order * lint
This commit is contained in:
@@ -36,6 +36,7 @@ import ray
|
||||
from ray.tune import grid_search, run_experiments, register_trainable
|
||||
|
||||
from tensorflow.examples.tutorials.mnist import input_data
|
||||
import numpy as np
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
@@ -227,6 +228,9 @@ if __name__ == '__main__':
|
||||
},
|
||||
'config': {
|
||||
'activation': grid_search(['relu', 'elu', 'tanh']),
|
||||
# You can pass any serializable object as well
|
||||
'foo': grid_search([np.array([1, 2]),
|
||||
np.array([2, 3])]),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
+14
-11
@@ -89,12 +89,12 @@ class _JsonLogger(Logger):
|
||||
def _init(self):
|
||||
config_out = os.path.join(self.logdir, "params.json")
|
||||
with open(config_out, "w") as f:
|
||||
json.dump(self.config, f, sort_keys=True, cls=_CustomEncoder)
|
||||
json.dump(self.config, f, sort_keys=True, cls=_SafeFallbackEncoder)
|
||||
local_file = os.path.join(self.logdir, "result.json")
|
||||
self.local_out = open(local_file, "w")
|
||||
|
||||
def on_result(self, result):
|
||||
json.dump(result._asdict(), self, cls=_CustomEncoder)
|
||||
json.dump(result._asdict(), self, cls=_SafeFallbackEncoder)
|
||||
self.write("\n")
|
||||
|
||||
def write(self, b):
|
||||
@@ -150,9 +150,9 @@ class _VisKitLogger(Logger):
|
||||
self._file.close()
|
||||
|
||||
|
||||
class _CustomEncoder(json.JSONEncoder):
|
||||
class _SafeFallbackEncoder(json.JSONEncoder):
|
||||
def __init__(self, nan_str="null", **kwargs):
|
||||
super(_CustomEncoder, self).__init__(**kwargs)
|
||||
super(_SafeFallbackEncoder, self).__init__(**kwargs)
|
||||
self.nan_str = nan_str
|
||||
|
||||
def iterencode(self, o, _one_shot=False):
|
||||
@@ -171,12 +171,15 @@ class _CustomEncoder(json.JSONEncoder):
|
||||
return _iterencode(o, 0)
|
||||
|
||||
def default(self, value):
|
||||
if np.isnan(value):
|
||||
return None
|
||||
if np.issubdtype(value, float):
|
||||
return float(value)
|
||||
if np.issubdtype(value, int):
|
||||
return int(value)
|
||||
try:
|
||||
if np.isnan(value):
|
||||
return None
|
||||
if np.issubdtype(value, float):
|
||||
return float(value)
|
||||
if np.issubdtype(value, int):
|
||||
return int(value)
|
||||
except Exception:
|
||||
return str(value) # give up, just stringify it (ok for logs)
|
||||
|
||||
|
||||
def pretty_print(result):
|
||||
@@ -186,5 +189,5 @@ def pretty_print(result):
|
||||
if v is not None:
|
||||
out[k] = v
|
||||
|
||||
cleaned = json.dumps(out, cls=_CustomEncoder)
|
||||
cleaned = json.dumps(out, cls=_SafeFallbackEncoder)
|
||||
return yaml.safe_dump(json.loads(cleaned), default_flow_style=False)
|
||||
|
||||
@@ -1,3 +1,7 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import copy
|
||||
import json
|
||||
import numpy
|
||||
@@ -6,6 +10,7 @@ import random
|
||||
import types
|
||||
|
||||
from ray.tune import TuneError
|
||||
from ray.tune.logger import _SafeFallbackEncoder
|
||||
from ray.tune.trial import Trial
|
||||
from ray.tune.config_parser import make_parser, json_to_resources
|
||||
|
||||
@@ -17,7 +22,7 @@ def to_argv(config):
|
||||
if isinstance(v, str):
|
||||
argv.append(v)
|
||||
else:
|
||||
argv.append(json.dumps(v))
|
||||
argv.append(json.dumps(v, cls=_SafeFallbackEncoder))
|
||||
return argv
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user