From 4be324efc36f5f0cd66ba2f6d81bb39fef258e3e Mon Sep 17 00:00:00 2001 From: old-bear Date: Thu, 23 Aug 2018 04:09:14 +0800 Subject: [PATCH] [tune] Support infinity value in report result (#2693) * + Compatibility fix under py2 on ray.tune * + Revert changes on master branch * + Use default JsonEncoder in ray.tune.logger * + Add UT for infinity support --- python/ray/tune/logger.py | 15 --------------- python/ray/tune/test/trial_runner_test.py | 17 +++++++++++++++++ 2 files changed, 17 insertions(+), 15 deletions(-) diff --git a/python/ray/tune/logger.py b/python/ray/tune/logger.py index 2ce7ca2cf..c1ffbcafe 100644 --- a/python/ray/tune/logger.py +++ b/python/ray/tune/logger.py @@ -170,21 +170,6 @@ class _SafeFallbackEncoder(json.JSONEncoder): super(_SafeFallbackEncoder, self).__init__(**kwargs) self.nan_str = nan_str - def iterencode(self, o, _one_shot=False): - if self.ensure_ascii: - _encoder = json.encoder.encode_basestring_ascii - else: - _encoder = json.encoder.encode_basestring - - def floatstr(o, allow_nan=self.allow_nan, nan_str=self.nan_str): - return repr(o) if not np.isnan(o) else nan_str - - _iterencode = json.encoder._make_iterencode( - None, self.default, _encoder, self.indent, floatstr, - self.key_separator, self.item_separator, self.sort_keys, - self.skipkeys, _one_shot) - return _iterencode(o, 0) - def default(self, value): try: if np.isnan(value): diff --git a/python/ray/tune/test/trial_runner_test.py b/python/ray/tune/test/trial_runner_test.py index abfe9e97e..e7814a20c 100644 --- a/python/ray/tune/test/trial_runner_test.py +++ b/python/ray/tune/test/trial_runner_test.py @@ -366,6 +366,23 @@ class TrainableFunctionApiTest(unittest.TestCase): self.assertEqual(trial.status, Trial.TERMINATED) self.assertEqual(trial.last_result[TIMESTEPS_TOTAL], 99) + def testReportInfinity(self): + def train(config, reporter): + for i in range(100): + reporter(mean_accuracy=float('inf')) + + register_trainable("f1", train) + [trial] = run_experiments({ + "foo": { + "run": "f1", + "config": { + "script_min_iter_time_s": 0, + }, + } + }) + self.assertEqual(trial.status, Trial.TERMINATED) + self.assertEqual(trial.last_result['mean_accuracy'], float('inf')) + class RunExperimentTest(unittest.TestCase): def setUp(self):