mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-30 16:40:05 +08:00
make sure we enter READY_FOR_EXPORT after ranking
This commit is contained in:
@@ -489,7 +489,8 @@ class TreeManager:
|
||||
|
||||
_, task = pr.store_ranking(interaction)
|
||||
|
||||
self.check_condition_for_scoring_state(task.message_tree_id)
|
||||
ok, rankings_by_message = self.check_condition_for_scoring_state(task.message_tree_id)
|
||||
self.update_message_ranks(task.message_tree_id, rankings_by_message)
|
||||
|
||||
case protocol_schema.TextLabels:
|
||||
logger.info(
|
||||
@@ -589,39 +590,56 @@ class TreeManager:
|
||||
return True
|
||||
|
||||
@managed_tx_method(CommitMode.COMMIT)
|
||||
def check_condition_for_scoring_state(self, message_tree_id: UUID) -> bool:
|
||||
def check_condition_for_scoring_state(
|
||||
self, message_tree_id: UUID
|
||||
) -> Tuple[bool, dict[UUID, list[MessageReaction]]]:
|
||||
logger.debug(f"check_condition_for_scoring_state({message_tree_id=})")
|
||||
mts: MessageTreeState
|
||||
mts = self.db.query(MessageTreeState).filter(MessageTreeState.message_tree_id == message_tree_id).one()
|
||||
|
||||
mts = self.pr.fetch_tree_state(message_tree_id)
|
||||
if not mts.active or mts.state != message_tree_state.State.RANKING:
|
||||
logger.debug(f"False {mts.active=}, {mts.state=}")
|
||||
return False
|
||||
return False, None
|
||||
|
||||
ranking_role_filter = None if self.cfg.rank_prompter_replies else "assistant"
|
||||
rankings_by_message = self.query_tree_ranking_results(message_tree_id, role_filter=ranking_role_filter)
|
||||
for parent_msg_id, ranking in rankings_by_message.items():
|
||||
if len(ranking) < self.cfg.num_required_rankings:
|
||||
logger.debug(f"False {parent_msg_id=} {len(ranking)=}")
|
||||
return False
|
||||
return False, None
|
||||
|
||||
self._enter_state(mts, message_tree_state.State.READY_FOR_SCORING)
|
||||
self.update_message_ranks(rankings_by_message)
|
||||
return True
|
||||
return True, rankings_by_message
|
||||
|
||||
@managed_tx_method(CommitMode.COMMIT)
|
||||
def update_message_ranks(self, rankings_by_message: Dict[int, int]) -> None:
|
||||
for parent_msg_id, ranking in rankings_by_message.items():
|
||||
sorted_messages = []
|
||||
for msg_reaction in ranking:
|
||||
sorted_messages.append(msg_reaction.payload.payload.ranked_message_ids)
|
||||
logger.debug(f"SORTED MESSAGE {sorted_messages}")
|
||||
consensus = ranked_pairs(sorted_messages)
|
||||
logger.debug(f"CONSENSUS: {consensus}\n\n")
|
||||
for rank, message_id in enumerate(consensus):
|
||||
# set rank for each message_id for Message rows
|
||||
msg = self.pr.fetch_message(message_id=message_id, fail_if_missing=True)
|
||||
msg.rank = rank
|
||||
self.db.add(msg)
|
||||
def update_message_ranks(self, message_tree_id: UUID, rankings_by_message: Dict[int, int]) -> bool:
|
||||
|
||||
mts = self.pr.fetch_tree_state(message_tree_id)
|
||||
# check state, allow retry if in SCORING_FAILED state
|
||||
if mts.state not in (message_tree_state.State.READY_FOR_SCORING, message_tree_state.State.SCORING_FAILED):
|
||||
logger.debug(f"False {mts.active=}, {mts.state=}")
|
||||
return False
|
||||
|
||||
try:
|
||||
for rankings in rankings_by_message.values():
|
||||
sorted_messages = []
|
||||
for msg_reaction in rankings:
|
||||
sorted_messages.append(msg_reaction.payload.payload.ranked_message_ids)
|
||||
logger.debug(f"SORTED MESSAGE {sorted_messages}")
|
||||
consensus = ranked_pairs(sorted_messages)
|
||||
logger.debug(f"CONSENSUS: {consensus}\n\n")
|
||||
for rank, message_id in enumerate(consensus):
|
||||
# set rank for each message_id for Message rows
|
||||
msg = self.pr.fetch_message(message_id=message_id, fail_if_missing=True)
|
||||
msg.rank = rank
|
||||
self.db.add(msg)
|
||||
|
||||
except Exception:
|
||||
logger.exception(f"update_message_ranks({message_tree_id=}) failed")
|
||||
self._enter_state(mts, message_tree_state.State.SCORING_FAILED)
|
||||
return False
|
||||
|
||||
self._enter_state(mts, message_tree_state.State.READY_FOR_EXPORT)
|
||||
return True
|
||||
|
||||
def _calculate_acceptance(self, labels: list[TextLabels]):
|
||||
# calculate acceptance based on spam label
|
||||
|
||||
Reference in New Issue
Block a user