mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 20:56:34 +08:00
[tune/rllib] made wandb compatible with rllib trainables (#10252)
This commit is contained in:
@@ -163,6 +163,10 @@ class WandbLogger(Logger):
|
||||
Wandb configuration is done by passing a ``wandb`` key to
|
||||
the ``config`` parameter of ``tune.run()`` (see example below).
|
||||
|
||||
The ``wandb`` config key can be optionally included in the
|
||||
``logger_config`` subkey of ``config`` to be compatible with RLLib
|
||||
trainables (see second example below).
|
||||
|
||||
The content of the ``wandb`` config entry is passed to ``wandb.init()``
|
||||
as keyword arguments. The exception are the following settings, which
|
||||
are used to configure the WandbLogger itself:
|
||||
@@ -205,6 +209,27 @@ class WandbLogger(Logger):
|
||||
},
|
||||
loggers=DEFAULT_LOGGERS + (WandbLogger, ))
|
||||
|
||||
Example for RLLib:
|
||||
|
||||
.. code-block :: python
|
||||
|
||||
from ray import tune
|
||||
from ray.tune.integration.wandb import WandbLogger
|
||||
|
||||
tune.run(
|
||||
"PPO",
|
||||
config={
|
||||
"env": "CartPole-v0",
|
||||
"logger_config": {
|
||||
"wandb": {
|
||||
"project": "PPO",
|
||||
"api_key_file": "~/.wandb_api_key"
|
||||
}
|
||||
}
|
||||
},
|
||||
loggers=[WandbLogger])
|
||||
|
||||
|
||||
"""
|
||||
|
||||
# Do not log these result keys
|
||||
@@ -222,7 +247,11 @@ class WandbLogger(Logger):
|
||||
config = self.config.copy()
|
||||
|
||||
try:
|
||||
wandb_config = config.pop("wandb").copy()
|
||||
if config.get("logger_config", {}).get("wandb"):
|
||||
logger_config = config.pop("logger_config")
|
||||
wandb_config = logger_config.get("wandb").copy()
|
||||
else:
|
||||
wandb_config = config.pop("wandb").copy()
|
||||
except KeyError:
|
||||
raise ValueError(
|
||||
"Wandb logger specified but no configuration has been passed. "
|
||||
@@ -296,10 +325,10 @@ class WandbTrainableMixin:
|
||||
|
||||
super().__init__(config, *args, **kwargs)
|
||||
|
||||
config = config.copy()
|
||||
_config = config.copy()
|
||||
|
||||
try:
|
||||
wandb_config = config.pop("wandb").copy()
|
||||
wandb_config = _config.pop("wandb").copy()
|
||||
except KeyError:
|
||||
raise ValueError(
|
||||
"Wandb mixin specified but no configuration has been passed. "
|
||||
@@ -334,7 +363,7 @@ class WandbTrainableMixin:
|
||||
allow_val_change=True,
|
||||
group=wandb_group,
|
||||
project=wandb_project,
|
||||
config=config)
|
||||
config=_config)
|
||||
wandb_init_kwargs.update(wandb_config)
|
||||
|
||||
self.wandb = self._wandb.init(**wandb_init_kwargs)
|
||||
|
||||
Reference in New Issue
Block a user