mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 22:53:20 +08:00
[tune] convert fallback representation to numbers in wandb integration (#10799)
This commit is contained in:
committed by
Barak Michener
parent
7bf5f1af8b
commit
2d08b2bb1c
@@ -2,6 +2,7 @@ import os
|
||||
import pickle
|
||||
from multiprocessing import Process, Queue
|
||||
from numbers import Number
|
||||
import numpy as np
|
||||
|
||||
from ray import logger
|
||||
from ray.tune import Trainable
|
||||
@@ -21,12 +22,21 @@ WANDB_ENV_VAR = "WANDB_API_KEY"
|
||||
_WANDB_QUEUE_END = (None, )
|
||||
|
||||
|
||||
def _is_allowed_type(obj):
|
||||
"""Return True if type is allowed for logging to wandb"""
|
||||
if isinstance(obj, np.ndarray) and obj.size == 1:
|
||||
return isinstance(obj.item(), Number)
|
||||
return isinstance(obj, Number)
|
||||
|
||||
|
||||
def _clean_log(obj):
|
||||
# Fixes https://github.com/ray-project/ray/issues/10631
|
||||
if isinstance(obj, dict):
|
||||
return {k: _clean_log(v) for k, v in obj.items()}
|
||||
elif isinstance(obj, list):
|
||||
return [_clean_log(v) for v in obj]
|
||||
elif _is_allowed_type(obj):
|
||||
return obj
|
||||
|
||||
# Else
|
||||
try:
|
||||
@@ -40,7 +50,24 @@ def _clean_log(obj):
|
||||
return obj
|
||||
except Exception:
|
||||
# give up, similar to _SafeFallBackEncoder
|
||||
return str(obj)
|
||||
fallback = str(obj)
|
||||
|
||||
# Try to convert to int
|
||||
try:
|
||||
fallback = int(fallback)
|
||||
return fallback
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# Try to convert to float
|
||||
try:
|
||||
fallback = float(fallback)
|
||||
return fallback
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# Else, return string
|
||||
return fallback
|
||||
|
||||
|
||||
def wandb_mixin(func):
|
||||
@@ -182,7 +209,7 @@ class _WandbLoggingProcess(Process):
|
||||
k.startswith(item + "/") or k == item
|
||||
for item in self._exclude):
|
||||
continue
|
||||
elif not isinstance(v, Number):
|
||||
elif not _is_allowed_type(v):
|
||||
continue
|
||||
else:
|
||||
log[k] = v
|
||||
|
||||
@@ -4,6 +4,8 @@ from collections import namedtuple
|
||||
from multiprocessing import Queue
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ray.tune import Trainable
|
||||
from ray.tune.function_runner import wrap_function
|
||||
from ray.tune.integration.wandb import _WandbLoggingProcess, \
|
||||
@@ -156,6 +158,8 @@ class WandbIntegrationTest(unittest.TestCase):
|
||||
r1 = {
|
||||
"metric1": 0.8,
|
||||
"metric2": 1.4,
|
||||
"metric3": np.asarray(32.0),
|
||||
"metric4": np.float32(32.0),
|
||||
"const": "text",
|
||||
"config": trial_config
|
||||
}
|
||||
@@ -165,6 +169,8 @@ class WandbIntegrationTest(unittest.TestCase):
|
||||
logged = logger._wandb.logs.get(timeout=10)
|
||||
self.assertIn("metric1", logged)
|
||||
self.assertNotIn("metric2", logged)
|
||||
self.assertIn("metric3", logged)
|
||||
self.assertIn("metric4", logged)
|
||||
self.assertNotIn("const", logged)
|
||||
self.assertNotIn("config", logged)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user