From b4702de1c2539403deb08403fb296483b117f425 Mon Sep 17 00:00:00 2001 From: Maltimore Date: Mon, 25 Jan 2021 12:56:00 +0100 Subject: [PATCH] [RLlib] move evaluation to trainer.step() such that the result is properly logged (#12708) --- rllib/agents/trainer.py | 8 -------- rllib/agents/trainer_template.py | 12 ++++++++++++ 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/rllib/agents/trainer.py b/rllib/agents/trainer.py index 9055fe378..47e637f6d 100644 --- a/rllib/agents/trainer.py +++ b/rllib/agents/trainer.py @@ -535,14 +535,6 @@ class Trainer(Trainable): if hasattr(self, "workers") and isinstance(self.workers, WorkerSet): self._sync_filters_if_needed(self.workers) - if self.config["evaluation_interval"] == 1 or ( - self._iteration > 0 and self.config["evaluation_interval"] - and self._iteration % self.config["evaluation_interval"] == 0): - evaluation_metrics = self._evaluate() - assert isinstance(evaluation_metrics, dict), \ - "_evaluate() needs to return a dict." - result.update(evaluation_metrics) - return result def _sync_filters_if_needed(self, workers: WorkerSet): diff --git a/rllib/agents/trainer_template.py b/rllib/agents/trainer_template.py index b896958b6..600cbef12 100644 --- a/rllib/agents/trainer_template.py +++ b/rllib/agents/trainer_template.py @@ -146,6 +146,18 @@ def build_trainer( @override(Trainer) def step(self): res = next(self.train_exec_impl) + + # self._iteration gets incremented after this function returns, + # meaning that e. g. the first time this function is called, + # self._iteration will be 0. We check `self._iteration+1` in the + # if-statement below to reflect that the first training iteration + # is already over. + if (self.config["evaluation_interval"] and (self._iteration + 1) % + self.config["evaluation_interval"] == 0): + evaluation_metrics = self._evaluate() + assert isinstance(evaluation_metrics, dict), \ + "_evaluate() needs to return a dict." + res.update(evaluation_metrics) return res @override(Trainer)