make sure we enter READY_FOR_EXPORT after ranking

This commit is contained in:
Andreas Koepf
2023-01-17 17:50:17 +00:00
parent 3749791bce
commit 0f896d910e
+39 -21
View File
@@ -489,7 +489,8 @@ class TreeManager:
_, task = pr.store_ranking(interaction)
self.check_condition_for_scoring_state(task.message_tree_id)
ok, rankings_by_message = self.check_condition_for_scoring_state(task.message_tree_id)
self.update_message_ranks(task.message_tree_id, rankings_by_message)
case protocol_schema.TextLabels:
logger.info(
@@ -589,39 +590,56 @@ class TreeManager:
return True
@managed_tx_method(CommitMode.COMMIT)
def check_condition_for_scoring_state(self, message_tree_id: UUID) -> bool:
def check_condition_for_scoring_state(
self, message_tree_id: UUID
) -> Tuple[bool, dict[UUID, list[MessageReaction]]]:
logger.debug(f"check_condition_for_scoring_state({message_tree_id=})")
mts: MessageTreeState
mts = self.db.query(MessageTreeState).filter(MessageTreeState.message_tree_id == message_tree_id).one()
mts = self.pr.fetch_tree_state(message_tree_id)
if not mts.active or mts.state != message_tree_state.State.RANKING:
logger.debug(f"False {mts.active=}, {mts.state=}")
return False
return False, None
ranking_role_filter = None if self.cfg.rank_prompter_replies else "assistant"
rankings_by_message = self.query_tree_ranking_results(message_tree_id, role_filter=ranking_role_filter)
for parent_msg_id, ranking in rankings_by_message.items():
if len(ranking) < self.cfg.num_required_rankings:
logger.debug(f"False {parent_msg_id=} {len(ranking)=}")
return False
return False, None
self._enter_state(mts, message_tree_state.State.READY_FOR_SCORING)
self.update_message_ranks(rankings_by_message)
return True
return True, rankings_by_message
@managed_tx_method(CommitMode.COMMIT)
def update_message_ranks(self, rankings_by_message: Dict[int, int]) -> None:
for parent_msg_id, ranking in rankings_by_message.items():
sorted_messages = []
for msg_reaction in ranking:
sorted_messages.append(msg_reaction.payload.payload.ranked_message_ids)
logger.debug(f"SORTED MESSAGE {sorted_messages}")
consensus = ranked_pairs(sorted_messages)
logger.debug(f"CONSENSUS: {consensus}\n\n")
for rank, message_id in enumerate(consensus):
# set rank for each message_id for Message rows
msg = self.pr.fetch_message(message_id=message_id, fail_if_missing=True)
msg.rank = rank
self.db.add(msg)
def update_message_ranks(self, message_tree_id: UUID, rankings_by_message: Dict[int, int]) -> bool:
mts = self.pr.fetch_tree_state(message_tree_id)
# check state, allow retry if in SCORING_FAILED state
if mts.state not in (message_tree_state.State.READY_FOR_SCORING, message_tree_state.State.SCORING_FAILED):
logger.debug(f"False {mts.active=}, {mts.state=}")
return False
try:
for rankings in rankings_by_message.values():
sorted_messages = []
for msg_reaction in rankings:
sorted_messages.append(msg_reaction.payload.payload.ranked_message_ids)
logger.debug(f"SORTED MESSAGE {sorted_messages}")
consensus = ranked_pairs(sorted_messages)
logger.debug(f"CONSENSUS: {consensus}\n\n")
for rank, message_id in enumerate(consensus):
# set rank for each message_id for Message rows
msg = self.pr.fetch_message(message_id=message_id, fail_if_missing=True)
msg.rank = rank
self.db.add(msg)
except Exception:
logger.exception(f"update_message_ranks({message_tree_id=}) failed")
self._enter_state(mts, message_tree_state.State.SCORING_FAILED)
return False
self._enter_state(mts, message_tree_state.State.READY_FOR_EXPORT)
return True
def _calculate_acceptance(self, labels: list[TextLabels]):
# calculate acceptance based on spam label