From 0f896d910e0074d5095ef953ba26bb058ebb613a Mon Sep 17 00:00:00 2001 From: Andreas Koepf Date: Tue, 17 Jan 2023 17:50:17 +0000 Subject: [PATCH] make sure we enter READY_FOR_EXPORT after ranking --- backend/oasst_backend/tree_manager.py | 60 +++++++++++++++++---------- 1 file changed, 39 insertions(+), 21 deletions(-) diff --git a/backend/oasst_backend/tree_manager.py b/backend/oasst_backend/tree_manager.py index 265591c1..2f48bca0 100644 --- a/backend/oasst_backend/tree_manager.py +++ b/backend/oasst_backend/tree_manager.py @@ -489,7 +489,8 @@ class TreeManager: _, task = pr.store_ranking(interaction) - self.check_condition_for_scoring_state(task.message_tree_id) + ok, rankings_by_message = self.check_condition_for_scoring_state(task.message_tree_id) + self.update_message_ranks(task.message_tree_id, rankings_by_message) case protocol_schema.TextLabels: logger.info( @@ -589,39 +590,56 @@ class TreeManager: return True @managed_tx_method(CommitMode.COMMIT) - def check_condition_for_scoring_state(self, message_tree_id: UUID) -> bool: + def check_condition_for_scoring_state( + self, message_tree_id: UUID + ) -> Tuple[bool, dict[UUID, list[MessageReaction]]]: logger.debug(f"check_condition_for_scoring_state({message_tree_id=})") - mts: MessageTreeState - mts = self.db.query(MessageTreeState).filter(MessageTreeState.message_tree_id == message_tree_id).one() + + mts = self.pr.fetch_tree_state(message_tree_id) if not mts.active or mts.state != message_tree_state.State.RANKING: logger.debug(f"False {mts.active=}, {mts.state=}") - return False + return False, None ranking_role_filter = None if self.cfg.rank_prompter_replies else "assistant" rankings_by_message = self.query_tree_ranking_results(message_tree_id, role_filter=ranking_role_filter) for parent_msg_id, ranking in rankings_by_message.items(): if len(ranking) < self.cfg.num_required_rankings: logger.debug(f"False {parent_msg_id=} {len(ranking)=}") - return False + return False, None self._enter_state(mts, message_tree_state.State.READY_FOR_SCORING) - self.update_message_ranks(rankings_by_message) - return True + return True, rankings_by_message @managed_tx_method(CommitMode.COMMIT) - def update_message_ranks(self, rankings_by_message: Dict[int, int]) -> None: - for parent_msg_id, ranking in rankings_by_message.items(): - sorted_messages = [] - for msg_reaction in ranking: - sorted_messages.append(msg_reaction.payload.payload.ranked_message_ids) - logger.debug(f"SORTED MESSAGE {sorted_messages}") - consensus = ranked_pairs(sorted_messages) - logger.debug(f"CONSENSUS: {consensus}\n\n") - for rank, message_id in enumerate(consensus): - # set rank for each message_id for Message rows - msg = self.pr.fetch_message(message_id=message_id, fail_if_missing=True) - msg.rank = rank - self.db.add(msg) + def update_message_ranks(self, message_tree_id: UUID, rankings_by_message: Dict[int, int]) -> bool: + + mts = self.pr.fetch_tree_state(message_tree_id) + # check state, allow retry if in SCORING_FAILED state + if mts.state not in (message_tree_state.State.READY_FOR_SCORING, message_tree_state.State.SCORING_FAILED): + logger.debug(f"False {mts.active=}, {mts.state=}") + return False + + try: + for rankings in rankings_by_message.values(): + sorted_messages = [] + for msg_reaction in rankings: + sorted_messages.append(msg_reaction.payload.payload.ranked_message_ids) + logger.debug(f"SORTED MESSAGE {sorted_messages}") + consensus = ranked_pairs(sorted_messages) + logger.debug(f"CONSENSUS: {consensus}\n\n") + for rank, message_id in enumerate(consensus): + # set rank for each message_id for Message rows + msg = self.pr.fetch_message(message_id=message_id, fail_if_missing=True) + msg.rank = rank + self.db.add(msg) + + except Exception: + logger.exception(f"update_message_ranks({message_tree_id=}) failed") + self._enter_state(mts, message_tree_state.State.SCORING_FAILED) + return False + + self._enter_state(mts, message_tree_state.State.READY_FOR_EXPORT) + return True def _calculate_acceptance(self, labels: list[TextLabels]): # calculate acceptance based on spam label