diff --git a/python/ray/tune/integration/wandb.py b/python/ray/tune/integration/wandb.py index b3248cbcb..7f8b35c14 100644 --- a/python/ray/tune/integration/wandb.py +++ b/python/ray/tune/integration/wandb.py @@ -2,6 +2,7 @@ import os import pickle from multiprocessing import Process, Queue from numbers import Number +import numpy as np from ray import logger from ray.tune import Trainable @@ -21,12 +22,21 @@ WANDB_ENV_VAR = "WANDB_API_KEY" _WANDB_QUEUE_END = (None, ) +def _is_allowed_type(obj): + """Return True if type is allowed for logging to wandb""" + if isinstance(obj, np.ndarray) and obj.size == 1: + return isinstance(obj.item(), Number) + return isinstance(obj, Number) + + def _clean_log(obj): # Fixes https://github.com/ray-project/ray/issues/10631 if isinstance(obj, dict): return {k: _clean_log(v) for k, v in obj.items()} elif isinstance(obj, list): return [_clean_log(v) for v in obj] + elif _is_allowed_type(obj): + return obj # Else try: @@ -40,7 +50,24 @@ def _clean_log(obj): return obj except Exception: # give up, similar to _SafeFallBackEncoder - return str(obj) + fallback = str(obj) + + # Try to convert to int + try: + fallback = int(fallback) + return fallback + except ValueError: + pass + + # Try to convert to float + try: + fallback = float(fallback) + return fallback + except ValueError: + pass + + # Else, return string + return fallback def wandb_mixin(func): @@ -182,7 +209,7 @@ class _WandbLoggingProcess(Process): k.startswith(item + "/") or k == item for item in self._exclude): continue - elif not isinstance(v, Number): + elif not _is_allowed_type(v): continue else: log[k] = v diff --git a/python/ray/tune/tests/test_integration_wandb.py b/python/ray/tune/tests/test_integration_wandb.py index aeffaace6..6baa6b341 100644 --- a/python/ray/tune/tests/test_integration_wandb.py +++ b/python/ray/tune/tests/test_integration_wandb.py @@ -4,6 +4,8 @@ from collections import namedtuple from multiprocessing import Queue import unittest +import numpy as np + from ray.tune import Trainable from ray.tune.function_runner import wrap_function from ray.tune.integration.wandb import _WandbLoggingProcess, \ @@ -156,6 +158,8 @@ class WandbIntegrationTest(unittest.TestCase): r1 = { "metric1": 0.8, "metric2": 1.4, + "metric3": np.asarray(32.0), + "metric4": np.float32(32.0), "const": "text", "config": trial_config } @@ -165,6 +169,8 @@ class WandbIntegrationTest(unittest.TestCase): logged = logger._wandb.logs.get(timeout=10) self.assertIn("metric1", logged) self.assertNotIn("metric2", logged) + self.assertIn("metric3", logged) + self.assertIn("metric4", logged) self.assertNotIn("const", logged) self.assertNotIn("config", logged)