mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 04:23:03 +08:00
[tune] Support infinity value in report result (#2693)
* + Compatibility fix under py2 on ray.tune * + Revert changes on master branch * + Use default JsonEncoder in ray.tune.logger * + Add UT for infinity support
This commit is contained in:
@@ -170,21 +170,6 @@ class _SafeFallbackEncoder(json.JSONEncoder):
|
||||
super(_SafeFallbackEncoder, self).__init__(**kwargs)
|
||||
self.nan_str = nan_str
|
||||
|
||||
def iterencode(self, o, _one_shot=False):
|
||||
if self.ensure_ascii:
|
||||
_encoder = json.encoder.encode_basestring_ascii
|
||||
else:
|
||||
_encoder = json.encoder.encode_basestring
|
||||
|
||||
def floatstr(o, allow_nan=self.allow_nan, nan_str=self.nan_str):
|
||||
return repr(o) if not np.isnan(o) else nan_str
|
||||
|
||||
_iterencode = json.encoder._make_iterencode(
|
||||
None, self.default, _encoder, self.indent, floatstr,
|
||||
self.key_separator, self.item_separator, self.sort_keys,
|
||||
self.skipkeys, _one_shot)
|
||||
return _iterencode(o, 0)
|
||||
|
||||
def default(self, value):
|
||||
try:
|
||||
if np.isnan(value):
|
||||
|
||||
@@ -366,6 +366,23 @@ class TrainableFunctionApiTest(unittest.TestCase):
|
||||
self.assertEqual(trial.status, Trial.TERMINATED)
|
||||
self.assertEqual(trial.last_result[TIMESTEPS_TOTAL], 99)
|
||||
|
||||
def testReportInfinity(self):
|
||||
def train(config, reporter):
|
||||
for i in range(100):
|
||||
reporter(mean_accuracy=float('inf'))
|
||||
|
||||
register_trainable("f1", train)
|
||||
[trial] = run_experiments({
|
||||
"foo": {
|
||||
"run": "f1",
|
||||
"config": {
|
||||
"script_min_iter_time_s": 0,
|
||||
},
|
||||
}
|
||||
})
|
||||
self.assertEqual(trial.status, Trial.TERMINATED)
|
||||
self.assertEqual(trial.last_result['mean_accuracy'], float('inf'))
|
||||
|
||||
|
||||
class RunExperimentTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
|
||||
Reference in New Issue
Block a user