mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-07-04 17:20:19 +08:00
Use intersection of ranking ID sets, add fetch_siblings()
This commit is contained in:
@@ -659,6 +659,24 @@ class PromptRepository:
|
||||
children = qry.all()
|
||||
return children
|
||||
|
||||
def fetch_message_siblings(
|
||||
self, message: Message | UUID, reviewed: Optional[bool] = True, deleted: Optional[bool] = False
|
||||
) -> list[Message]:
|
||||
"""
|
||||
Get siblings of a message (other messages with the same parent_id)
|
||||
"""
|
||||
if isinstance(message, Message):
|
||||
message = message.id
|
||||
|
||||
parent_qry = self.db.query(Message.parent_id).filter(Message.id == message).subquery()
|
||||
qry = self.db.query(Message).filter(Message.parent_id == parent_qry.c.parent_id)
|
||||
if reviewed is not None:
|
||||
qry = qry.filter(Message.review_result == reviewed)
|
||||
if deleted is not None:
|
||||
qry = qry.filter(Message.deleted == deleted)
|
||||
siblings = qry.all()
|
||||
return siblings
|
||||
|
||||
@staticmethod
|
||||
def trace_descendants(root: Message, messages: list[Message]) -> list[Message]:
|
||||
children = defaultdict(list)
|
||||
|
||||
@@ -662,19 +662,47 @@ class TreeManager:
|
||||
logger.debug(f"False {mts.active=}, {mts.state=}")
|
||||
return False
|
||||
|
||||
if mts.state == message_tree_state.State.SCORING_FAILED:
|
||||
mts.active = True
|
||||
mts.state = message_tree_state.State.READY_FOR_SCORING
|
||||
|
||||
try:
|
||||
for rankings in rankings_by_message.values():
|
||||
sorted_messages = []
|
||||
for msg_reaction in rankings:
|
||||
sorted_messages.append(msg_reaction.payload.payload.ranked_message_ids)
|
||||
logger.debug(f"SORTED MESSAGE {sorted_messages}")
|
||||
consensus = ranked_pairs(sorted_messages)
|
||||
ordered_ids_list: list[list[UUID]] = [
|
||||
msg_reaction.payload.payload.ranked_message_ids for msg_reaction in rankings
|
||||
]
|
||||
|
||||
common_set: set[UUID] = set.intersection(*map(set, ordered_ids_list))
|
||||
if len(common_set) < 2:
|
||||
logger.warning("The intersection of ranking results ID sets has less than two elements. Skipping.")
|
||||
continue
|
||||
|
||||
# keep only elements in commond set
|
||||
ordered_ids_list = [list(filter(lambda x: x in common_set, ids)) for ids in ordered_ids_list]
|
||||
assert all(len(x) == len(common_set) for x in ordered_ids_list)
|
||||
|
||||
logger.debug(f"SORTED MESSAGE IDS {ordered_ids_list}")
|
||||
consensus = ranked_pairs(ordered_ids_list)
|
||||
assert len(consensus) == len(common_set)
|
||||
logger.debug(f"CONSENSUS: {consensus}\n\n")
|
||||
|
||||
# fetch all siblings and clear ranks
|
||||
siblings = self.pr.fetch_message_siblings(consensus[0], reviewed=None, deleted=None)
|
||||
for m in siblings:
|
||||
m.rank = None
|
||||
self.db.add(m)
|
||||
|
||||
# index by id
|
||||
siblings = {m.id: m for m in siblings}
|
||||
|
||||
# set rank for each message that was part of the common set
|
||||
for rank, message_id in enumerate(consensus):
|
||||
# set rank for each message_id for Message rows
|
||||
msg = self.pr.fetch_message(message_id=message_id, fail_if_missing=True)
|
||||
msg.rank = rank
|
||||
self.db.add(msg)
|
||||
msg = siblings.get(message_id)
|
||||
if msg:
|
||||
msg.rank = rank
|
||||
self.db.add(msg)
|
||||
else:
|
||||
logger.warning(f"Message {message_id=} not found among siblings.")
|
||||
|
||||
except Exception:
|
||||
logger.exception(f"update_message_ranks({message_tree_id=}) failed")
|
||||
@@ -1256,7 +1284,6 @@ if __name__ == "__main__":
|
||||
dummy_user = protocol_schema.User(id="__dummy_user__", display_name="Dummy User", auth_method="local")
|
||||
|
||||
pr = PromptRepository(db=db, api_client=api_client, client_user=dummy_user)
|
||||
|
||||
cfg = TreeManagerConfiguration()
|
||||
tm = TreeManager(db, pr, cfg)
|
||||
tm.ensure_tree_states()
|
||||
@@ -1279,4 +1306,4 @@ if __name__ == "__main__":
|
||||
# ".query_tree_ranking_results", tm.query_tree_ranking_results(UUID("2ac20d38-6650-43aa-8bb3-f61080c0d921"))
|
||||
# )
|
||||
|
||||
print(tm.export_trees_to_file(message_tree_ids=["7e75fb38-e664-4e2b-817c-b9a0b01b0074"], file="lol.jsonl"))
|
||||
# print(tm.export_trees_to_file(message_tree_ids=["7e75fb38-e664-4e2b-817c-b9a0b01b0074"], file="lol.jsonl"))
|
||||
|
||||
Reference in New Issue
Block a user