[tune/rllib] made wandb compatible with rllib trainables (#10252)

This commit is contained in:
krfricke
2020-08-22 02:25:52 +02:00
committed by GitHub
parent f87669372d
commit c31876002d
2 changed files with 35 additions and 5 deletions
+33 -4
View File
@@ -163,6 +163,10 @@ class WandbLogger(Logger):
Wandb configuration is done by passing a ``wandb`` key to
the ``config`` parameter of ``tune.run()`` (see example below).
The ``wandb`` config key can be optionally included in the
``logger_config`` subkey of ``config`` to be compatible with RLLib
trainables (see second example below).
The content of the ``wandb`` config entry is passed to ``wandb.init()``
as keyword arguments. The exception are the following settings, which
are used to configure the WandbLogger itself:
@@ -205,6 +209,27 @@ class WandbLogger(Logger):
},
loggers=DEFAULT_LOGGERS + (WandbLogger, ))
Example for RLLib:
.. code-block :: python
from ray import tune
from ray.tune.integration.wandb import WandbLogger
tune.run(
"PPO",
config={
"env": "CartPole-v0",
"logger_config": {
"wandb": {
"project": "PPO",
"api_key_file": "~/.wandb_api_key"
}
}
},
loggers=[WandbLogger])
"""
# Do not log these result keys
@@ -222,7 +247,11 @@ class WandbLogger(Logger):
config = self.config.copy()
try:
wandb_config = config.pop("wandb").copy()
if config.get("logger_config", {}).get("wandb"):
logger_config = config.pop("logger_config")
wandb_config = logger_config.get("wandb").copy()
else:
wandb_config = config.pop("wandb").copy()
except KeyError:
raise ValueError(
"Wandb logger specified but no configuration has been passed. "
@@ -296,10 +325,10 @@ class WandbTrainableMixin:
super().__init__(config, *args, **kwargs)
config = config.copy()
_config = config.copy()
try:
wandb_config = config.pop("wandb").copy()
wandb_config = _config.pop("wandb").copy()
except KeyError:
raise ValueError(
"Wandb mixin specified but no configuration has been passed. "
@@ -334,7 +363,7 @@ class WandbTrainableMixin:
allow_val_change=True,
group=wandb_group,
project=wandb_project,
config=config)
config=_config)
wandb_init_kwargs.update(wandb_config)
self.wandb = self._wandb.init(**wandb_init_kwargs)