Use intersection of ranking ID sets, add fetch_siblings()

This commit is contained in:
Andreas Köpf
2023-01-25 16:21:58 +01:00
parent 1020dcb024
commit 39a107e851
2 changed files with 56 additions and 11 deletions
@@ -659,6 +659,24 @@ class PromptRepository:
children = qry.all()
return children
def fetch_message_siblings(
self, message: Message | UUID, reviewed: Optional[bool] = True, deleted: Optional[bool] = False
) -> list[Message]:
"""
Get siblings of a message (other messages with the same parent_id)
"""
if isinstance(message, Message):
message = message.id
parent_qry = self.db.query(Message.parent_id).filter(Message.id == message).subquery()
qry = self.db.query(Message).filter(Message.parent_id == parent_qry.c.parent_id)
if reviewed is not None:
qry = qry.filter(Message.review_result == reviewed)
if deleted is not None:
qry = qry.filter(Message.deleted == deleted)
siblings = qry.all()
return siblings
@staticmethod
def trace_descendants(root: Message, messages: list[Message]) -> list[Message]:
children = defaultdict(list)
+38 -11
View File
@@ -662,19 +662,47 @@ class TreeManager:
logger.debug(f"False {mts.active=}, {mts.state=}")
return False
if mts.state == message_tree_state.State.SCORING_FAILED:
mts.active = True
mts.state = message_tree_state.State.READY_FOR_SCORING
try:
for rankings in rankings_by_message.values():
sorted_messages = []
for msg_reaction in rankings:
sorted_messages.append(msg_reaction.payload.payload.ranked_message_ids)
logger.debug(f"SORTED MESSAGE {sorted_messages}")
consensus = ranked_pairs(sorted_messages)
ordered_ids_list: list[list[UUID]] = [
msg_reaction.payload.payload.ranked_message_ids for msg_reaction in rankings
]
common_set: set[UUID] = set.intersection(*map(set, ordered_ids_list))
if len(common_set) < 2:
logger.warning("The intersection of ranking results ID sets has less than two elements. Skipping.")
continue
# keep only elements in commond set
ordered_ids_list = [list(filter(lambda x: x in common_set, ids)) for ids in ordered_ids_list]
assert all(len(x) == len(common_set) for x in ordered_ids_list)
logger.debug(f"SORTED MESSAGE IDS {ordered_ids_list}")
consensus = ranked_pairs(ordered_ids_list)
assert len(consensus) == len(common_set)
logger.debug(f"CONSENSUS: {consensus}\n\n")
# fetch all siblings and clear ranks
siblings = self.pr.fetch_message_siblings(consensus[0], reviewed=None, deleted=None)
for m in siblings:
m.rank = None
self.db.add(m)
# index by id
siblings = {m.id: m for m in siblings}
# set rank for each message that was part of the common set
for rank, message_id in enumerate(consensus):
# set rank for each message_id for Message rows
msg = self.pr.fetch_message(message_id=message_id, fail_if_missing=True)
msg.rank = rank
self.db.add(msg)
msg = siblings.get(message_id)
if msg:
msg.rank = rank
self.db.add(msg)
else:
logger.warning(f"Message {message_id=} not found among siblings.")
except Exception:
logger.exception(f"update_message_ranks({message_tree_id=}) failed")
@@ -1256,7 +1284,6 @@ if __name__ == "__main__":
dummy_user = protocol_schema.User(id="__dummy_user__", display_name="Dummy User", auth_method="local")
pr = PromptRepository(db=db, api_client=api_client, client_user=dummy_user)
cfg = TreeManagerConfiguration()
tm = TreeManager(db, pr, cfg)
tm.ensure_tree_states()
@@ -1279,4 +1306,4 @@ if __name__ == "__main__":
# ".query_tree_ranking_results", tm.query_tree_ranking_results(UUID("2ac20d38-6650-43aa-8bb3-f61080c0d921"))
# )
print(tm.export_trees_to_file(message_tree_ids=["7e75fb38-e664-4e2b-817c-b9a0b01b0074"], file="lol.jsonl"))
# print(tm.export_trees_to_file(message_tree_ids=["7e75fb38-e664-4e2b-817c-b9a0b01b0074"], file="lol.jsonl"))