From 39a107e8516c1863e1eb97b80bd0368130f2e880 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20K=C3=B6pf?= Date: Wed, 25 Jan 2023 16:21:58 +0100 Subject: [PATCH] Use intersection of ranking ID sets, add fetch_siblings() --- backend/oasst_backend/prompt_repository.py | 18 ++++++++ backend/oasst_backend/tree_manager.py | 49 +++++++++++++++++----- 2 files changed, 56 insertions(+), 11 deletions(-) diff --git a/backend/oasst_backend/prompt_repository.py b/backend/oasst_backend/prompt_repository.py index bbc8abe2..7dddb5cf 100644 --- a/backend/oasst_backend/prompt_repository.py +++ b/backend/oasst_backend/prompt_repository.py @@ -659,6 +659,24 @@ class PromptRepository: children = qry.all() return children + def fetch_message_siblings( + self, message: Message | UUID, reviewed: Optional[bool] = True, deleted: Optional[bool] = False + ) -> list[Message]: + """ + Get siblings of a message (other messages with the same parent_id) + """ + if isinstance(message, Message): + message = message.id + + parent_qry = self.db.query(Message.parent_id).filter(Message.id == message).subquery() + qry = self.db.query(Message).filter(Message.parent_id == parent_qry.c.parent_id) + if reviewed is not None: + qry = qry.filter(Message.review_result == reviewed) + if deleted is not None: + qry = qry.filter(Message.deleted == deleted) + siblings = qry.all() + return siblings + @staticmethod def trace_descendants(root: Message, messages: list[Message]) -> list[Message]: children = defaultdict(list) diff --git a/backend/oasst_backend/tree_manager.py b/backend/oasst_backend/tree_manager.py index 1828f8ab..b55c903c 100644 --- a/backend/oasst_backend/tree_manager.py +++ b/backend/oasst_backend/tree_manager.py @@ -662,19 +662,47 @@ class TreeManager: logger.debug(f"False {mts.active=}, {mts.state=}") return False + if mts.state == message_tree_state.State.SCORING_FAILED: + mts.active = True + mts.state = message_tree_state.State.READY_FOR_SCORING + try: for rankings in rankings_by_message.values(): - sorted_messages = [] - for msg_reaction in rankings: - sorted_messages.append(msg_reaction.payload.payload.ranked_message_ids) - logger.debug(f"SORTED MESSAGE {sorted_messages}") - consensus = ranked_pairs(sorted_messages) + ordered_ids_list: list[list[UUID]] = [ + msg_reaction.payload.payload.ranked_message_ids for msg_reaction in rankings + ] + + common_set: set[UUID] = set.intersection(*map(set, ordered_ids_list)) + if len(common_set) < 2: + logger.warning("The intersection of ranking results ID sets has less than two elements. Skipping.") + continue + + # keep only elements in commond set + ordered_ids_list = [list(filter(lambda x: x in common_set, ids)) for ids in ordered_ids_list] + assert all(len(x) == len(common_set) for x in ordered_ids_list) + + logger.debug(f"SORTED MESSAGE IDS {ordered_ids_list}") + consensus = ranked_pairs(ordered_ids_list) + assert len(consensus) == len(common_set) logger.debug(f"CONSENSUS: {consensus}\n\n") + + # fetch all siblings and clear ranks + siblings = self.pr.fetch_message_siblings(consensus[0], reviewed=None, deleted=None) + for m in siblings: + m.rank = None + self.db.add(m) + + # index by id + siblings = {m.id: m for m in siblings} + + # set rank for each message that was part of the common set for rank, message_id in enumerate(consensus): - # set rank for each message_id for Message rows - msg = self.pr.fetch_message(message_id=message_id, fail_if_missing=True) - msg.rank = rank - self.db.add(msg) + msg = siblings.get(message_id) + if msg: + msg.rank = rank + self.db.add(msg) + else: + logger.warning(f"Message {message_id=} not found among siblings.") except Exception: logger.exception(f"update_message_ranks({message_tree_id=}) failed") @@ -1256,7 +1284,6 @@ if __name__ == "__main__": dummy_user = protocol_schema.User(id="__dummy_user__", display_name="Dummy User", auth_method="local") pr = PromptRepository(db=db, api_client=api_client, client_user=dummy_user) - cfg = TreeManagerConfiguration() tm = TreeManager(db, pr, cfg) tm.ensure_tree_states() @@ -1279,4 +1306,4 @@ if __name__ == "__main__": # ".query_tree_ranking_results", tm.query_tree_ranking_results(UUID("2ac20d38-6650-43aa-8bb3-f61080c0d921")) # ) - print(tm.export_trees_to_file(message_tree_ids=["7e75fb38-e664-4e2b-817c-b9a0b01b0074"], file="lol.jsonl")) + # print(tm.export_trees_to_file(message_tree_ids=["7e75fb38-e664-4e2b-817c-b9a0b01b0074"], file="lol.jsonl"))