[tune] convert fallback representation to numbers in wandb integration (#10799)

This commit is contained in:
Kai Fricke
2020-09-15 19:47:11 +01:00
committed by Barak Michener
parent 7bf5f1af8b
commit 2d08b2bb1c
2 changed files with 35 additions and 2 deletions
+29 -2
View File
@@ -2,6 +2,7 @@ import os
import pickle
from multiprocessing import Process, Queue
from numbers import Number
import numpy as np
from ray import logger
from ray.tune import Trainable
@@ -21,12 +22,21 @@ WANDB_ENV_VAR = "WANDB_API_KEY"
_WANDB_QUEUE_END = (None, )
def _is_allowed_type(obj):
"""Return True if type is allowed for logging to wandb"""
if isinstance(obj, np.ndarray) and obj.size == 1:
return isinstance(obj.item(), Number)
return isinstance(obj, Number)
def _clean_log(obj):
# Fixes https://github.com/ray-project/ray/issues/10631
if isinstance(obj, dict):
return {k: _clean_log(v) for k, v in obj.items()}
elif isinstance(obj, list):
return [_clean_log(v) for v in obj]
elif _is_allowed_type(obj):
return obj
# Else
try:
@@ -40,7 +50,24 @@ def _clean_log(obj):
return obj
except Exception:
# give up, similar to _SafeFallBackEncoder
return str(obj)
fallback = str(obj)
# Try to convert to int
try:
fallback = int(fallback)
return fallback
except ValueError:
pass
# Try to convert to float
try:
fallback = float(fallback)
return fallback
except ValueError:
pass
# Else, return string
return fallback
def wandb_mixin(func):
@@ -182,7 +209,7 @@ class _WandbLoggingProcess(Process):
k.startswith(item + "/") or k == item
for item in self._exclude):
continue
elif not isinstance(v, Number):
elif not _is_allowed_type(v):
continue
else:
log[k] = v
@@ -4,6 +4,8 @@ from collections import namedtuple
from multiprocessing import Queue
import unittest
import numpy as np
from ray.tune import Trainable
from ray.tune.function_runner import wrap_function
from ray.tune.integration.wandb import _WandbLoggingProcess, \
@@ -156,6 +158,8 @@ class WandbIntegrationTest(unittest.TestCase):
r1 = {
"metric1": 0.8,
"metric2": 1.4,
"metric3": np.asarray(32.0),
"metric4": np.float32(32.0),
"const": "text",
"config": trial_config
}
@@ -165,6 +169,8 @@ class WandbIntegrationTest(unittest.TestCase):
logged = logger._wandb.logs.get(timeout=10)
self.assertIn("metric1", logged)
self.assertNotIn("metric2", logged)
self.assertIn("metric3", logged)
self.assertIn("metric4", logged)
self.assertNotIn("const", logged)
self.assertNotIn("config", logged)