diff --git a/doc/source/tune-config.rst b/doc/source/tune-config.rst index fc024c79d..005eb3f8f 100644 --- a/doc/source/tune-config.rst +++ b/doc/source/tune-config.rst @@ -22,7 +22,7 @@ a single experiment or a list of experiments to `run_experiments`, as follows: An example of this can be found in `hyperband_example.py `__. -Alternatively, you can pass in a JSON object. This uses the same fields as +Alternatively, you can pass in a Python dict. This uses the same fields as the `ray.tune.Experiment`, except the experiment name is the key of the top level dictionary. diff --git a/python/ray/tune/examples/tune_mnist_ray.py b/python/ray/tune/examples/tune_mnist_ray.py index 176bffbc0..e806a1a68 100755 --- a/python/ray/tune/examples/tune_mnist_ray.py +++ b/python/ray/tune/examples/tune_mnist_ray.py @@ -36,6 +36,7 @@ import ray from ray.tune import grid_search, run_experiments, register_trainable from tensorflow.examples.tutorials.mnist import input_data +import numpy as np import tensorflow as tf @@ -227,6 +228,9 @@ if __name__ == '__main__': }, 'config': { 'activation': grid_search(['relu', 'elu', 'tanh']), + # You can pass any serializable object as well + 'foo': grid_search([np.array([1, 2]), + np.array([2, 3])]), }, } diff --git a/python/ray/tune/logger.py b/python/ray/tune/logger.py index bb708f99a..0be6ed4a3 100644 --- a/python/ray/tune/logger.py +++ b/python/ray/tune/logger.py @@ -89,12 +89,12 @@ class _JsonLogger(Logger): def _init(self): config_out = os.path.join(self.logdir, "params.json") with open(config_out, "w") as f: - json.dump(self.config, f, sort_keys=True, cls=_CustomEncoder) + json.dump(self.config, f, sort_keys=True, cls=_SafeFallbackEncoder) local_file = os.path.join(self.logdir, "result.json") self.local_out = open(local_file, "w") def on_result(self, result): - json.dump(result._asdict(), self, cls=_CustomEncoder) + json.dump(result._asdict(), self, cls=_SafeFallbackEncoder) self.write("\n") def write(self, b): @@ -150,9 +150,9 @@ class _VisKitLogger(Logger): self._file.close() -class _CustomEncoder(json.JSONEncoder): +class _SafeFallbackEncoder(json.JSONEncoder): def __init__(self, nan_str="null", **kwargs): - super(_CustomEncoder, self).__init__(**kwargs) + super(_SafeFallbackEncoder, self).__init__(**kwargs) self.nan_str = nan_str def iterencode(self, o, _one_shot=False): @@ -171,12 +171,15 @@ class _CustomEncoder(json.JSONEncoder): return _iterencode(o, 0) def default(self, value): - if np.isnan(value): - return None - if np.issubdtype(value, float): - return float(value) - if np.issubdtype(value, int): - return int(value) + try: + if np.isnan(value): + return None + if np.issubdtype(value, float): + return float(value) + if np.issubdtype(value, int): + return int(value) + except Exception: + return str(value) # give up, just stringify it (ok for logs) def pretty_print(result): @@ -186,5 +189,5 @@ def pretty_print(result): if v is not None: out[k] = v - cleaned = json.dumps(out, cls=_CustomEncoder) + cleaned = json.dumps(out, cls=_SafeFallbackEncoder) return yaml.safe_dump(json.loads(cleaned), default_flow_style=False) diff --git a/python/ray/tune/variant_generator.py b/python/ray/tune/variant_generator.py index 0d8957c41..7ca8df0ee 100644 --- a/python/ray/tune/variant_generator.py +++ b/python/ray/tune/variant_generator.py @@ -1,3 +1,7 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + import copy import json import numpy @@ -6,6 +10,7 @@ import random import types from ray.tune import TuneError +from ray.tune.logger import _SafeFallbackEncoder from ray.tune.trial import Trial from ray.tune.config_parser import make_parser, json_to_resources @@ -17,7 +22,7 @@ def to_argv(config): if isinstance(v, str): argv.append(v) else: - argv.append(json.dumps(v)) + argv.append(json.dumps(v, cls=_SafeFallbackEncoder)) return argv