From c31876002d11022646cb878c93eac55f75111f80 Mon Sep 17 00:00:00 2001 From: krfricke Date: Sat, 22 Aug 2020 02:25:52 +0200 Subject: [PATCH] [tune/rllib] made wandb compatible with rllib trainables (#10252) --- python/ray/tune/integration/wandb.py | 37 +++++++++++++++++++++++++--- rllib/agents/trainer.py | 3 ++- 2 files changed, 35 insertions(+), 5 deletions(-) diff --git a/python/ray/tune/integration/wandb.py b/python/ray/tune/integration/wandb.py index bfe3786e3..34d559634 100644 --- a/python/ray/tune/integration/wandb.py +++ b/python/ray/tune/integration/wandb.py @@ -163,6 +163,10 @@ class WandbLogger(Logger): Wandb configuration is done by passing a ``wandb`` key to the ``config`` parameter of ``tune.run()`` (see example below). + The ``wandb`` config key can be optionally included in the + ``logger_config`` subkey of ``config`` to be compatible with RLlib + trainables (see second example below). + The content of the ``wandb`` config entry is passed to ``wandb.init()`` as keyword arguments. The exception are the following settings, which are used to configure the WandbLogger itself: @@ -205,6 +209,27 @@ class WandbLogger(Logger): }, loggers=DEFAULT_LOGGERS + (WandbLogger, )) + Example for RLlib: + + .. code-block :: python + + from ray import tune + from ray.tune.integration.wandb import WandbLogger + + tune.run( + "PPO", + config={ + "env": "CartPole-v0", + "logger_config": { + "wandb": { + "project": "PPO", + "api_key_file": "~/.wandb_api_key" + } + } + }, + loggers=[WandbLogger]) + + """ # Do not log these result keys @@ -222,7 +247,11 @@ class WandbLogger(Logger): config = self.config.copy() try: - wandb_config = config.pop("wandb").copy() + if config.get("logger_config", {}).get("wandb"): + logger_config = config.pop("logger_config") + wandb_config = logger_config.get("wandb").copy() + else: + wandb_config = config.pop("wandb").copy() except KeyError: raise ValueError( "Wandb logger specified but no configuration has been passed. 
" @@ -296,10 +325,10 @@ class WandbTrainableMixin: super().__init__(config, *args, **kwargs) - config = config.copy() + _config = config.copy() try: - wandb_config = config.pop("wandb").copy() + wandb_config = _config.pop("wandb").copy() except KeyError: raise ValueError( "Wandb mixin specified but no configuration has been passed. " @@ -334,7 +363,7 @@ class WandbTrainableMixin: allow_val_change=True, group=wandb_group, project=wandb_project, - config=config) + config=_config) wandb_init_kwargs.update(wandb_config) self.wandb = self._wandb.init(**wandb_init_kwargs) diff --git a/rllib/agents/trainer.py b/rllib/agents/trainer.py index 7904245c1..2f11e203b 100644 --- a/rllib/agents/trainer.py +++ b/rllib/agents/trainer.py @@ -372,7 +372,8 @@ COMMON_CONFIG: TrainerConfigDict = { # === Logger === # Define logger-specific configuration to be used inside Logger - "logger_config": {}, + # Default value None allows overwriting with nested dicts + "logger_config": None, # === Replay Settings === # The number of contiguous environment steps to replay at once. This may