From 348999a93636bfcadb72256175a0a7115ff0cf63 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20K=C3=B6pf?= Date: Thu, 26 Jan 2023 19:06:25 +0100 Subject: [PATCH] exclude trees in ranking state in acitve tree count --- backend/oasst_backend/tree_manager.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/backend/oasst_backend/tree_manager.py b/backend/oasst_backend/tree_manager.py index a4cad0c5..992e75dd 100644 --- a/backend/oasst_backend/tree_manager.py +++ b/backend/oasst_backend/tree_manager.py @@ -202,7 +202,7 @@ class TreeManager: lang = "en" logger.warning("Task availability request without lang tag received, assuming lang='en'.") - num_active_trees = self.query_num_active_trees(lang=lang) + num_active_trees = self.query_num_active_trees(lang=lang, exclude_ranking=True) extendible_parents = self.query_extendible_parents(lang=lang) prompts_need_review = self.query_prompts_need_review(lang=lang) replies_need_review = self.query_replies_need_review(lang=lang) @@ -230,7 +230,7 @@ class TreeManager: lang = "en" logger.warning("Task request without lang tag received, assuming 'en'.") - num_active_trees = self.query_num_active_trees(lang=lang) + num_active_trees = self.query_num_active_trees(lang=lang, exclude_ranking=True) prompts_need_review = self.query_prompts_need_review(lang=lang) replies_need_review = self.query_replies_need_review(lang=lang) extendible_parents = self.query_extendible_parents(lang=lang) @@ -995,12 +995,15 @@ INNER JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'Ranki for r in rankings: self.check_condition_for_scoring_state(r.message_tree_id) - def query_num_active_trees(self, lang: str) -> int: + def query_num_active_trees(self, lang: str, exclude_ranking: bool = True) -> int: + """Count all active trees (optionally exclude those in ranking state).""" query = ( self.db.query(func.count(MessageTreeState.message_tree_id)) .join(Message, MessageTreeState.message_tree_id == Message.id) .filter(MessageTreeState.active, Message.lang == lang) ) + if exclude_ranking: + query = query.filter(MessageTreeState.state != message_tree_state.State.RANKING) return query.scalar() def query_reviews_for_message(self, message_id: UUID) -> list[TextLabels]: