From d4688835d54a0da51c986cf2b43cfbe14fdfdb01 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20K=C3=B6pf?= Date: Thu, 26 Jan 2023 16:33:03 +0100 Subject: [PATCH] check condition for scoring on startup --- backend/oasst_backend/tree_manager.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/backend/oasst_backend/tree_manager.py b/backend/oasst_backend/tree_manager.py index 929a9297..a4cad0c5 100644 --- a/backend/oasst_backend/tree_manager.py +++ b/backend/oasst_backend/tree_manager.py @@ -542,9 +542,7 @@ class TreeManager: ) _, task = pr.store_ranking(interaction) - - ok, rankings_by_message = self.check_condition_for_scoring_state(task.message_tree_id) - self.update_message_ranks(task.message_tree_id, rankings_by_message) + self.check_condition_for_scoring_state(task.message_tree_id) case protocol_schema.TextLabels: logger.info( @@ -659,7 +657,8 @@ class TreeManager: return False, None self._enter_state(mts, message_tree_state.State.READY_FOR_SCORING) - return True, rankings_by_message + self.update_message_ranks(message_tree_id, rankings_by_message) + return True def update_message_ranks( self, message_tree_id: UUID, rankings_by_message: dict[UUID, list[MessageReaction]] @@ -976,7 +975,7 @@ INNER JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'Ranki return rankings_by_message @managed_tx_method(CommitMode.COMMIT) - def ensure_tree_states(self): + def ensure_tree_states(self) -> None: """Add message tree state rows for all root nodes (inital prompt messages).""" missing_tree_ids = self.query_misssing_tree_states() @@ -988,6 +987,14 @@ INNER JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'Ranki logger.info(f"Inserting missing message tree state for message: {id} ({tree_size=}, {state=:s})") self._insert_default_state(id, state=state) + rankings = ( + self.db.query(MessageTreeState).filter(MessageTreeState.state == message_tree_state.State.RANKING).all() + ) + if len(rankings) > 0: + logger.info(f"Checking state of {len(rankings)} message trees in ranking state.") + for r in rankings: + self.check_condition_for_scoring_state(r.message_tree_id) + def query_num_active_trees(self, lang: str) -> int: query = ( self.db.query(func.count(MessageTreeState.message_tree_id))