mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-30 16:40:05 +08:00
check condition for scoring on startup
This commit is contained in:
@@ -542,9 +542,7 @@ class TreeManager:
|
||||
)
|
||||
|
||||
_, task = pr.store_ranking(interaction)
|
||||
|
||||
ok, rankings_by_message = self.check_condition_for_scoring_state(task.message_tree_id)
|
||||
self.update_message_ranks(task.message_tree_id, rankings_by_message)
|
||||
self.check_condition_for_scoring_state(task.message_tree_id)
|
||||
|
||||
case protocol_schema.TextLabels:
|
||||
logger.info(
|
||||
@@ -659,7 +657,8 @@ class TreeManager:
|
||||
return False, None
|
||||
|
||||
self._enter_state(mts, message_tree_state.State.READY_FOR_SCORING)
|
||||
return True, rankings_by_message
|
||||
self.update_message_ranks(message_tree_id, rankings_by_message)
|
||||
return True
|
||||
|
||||
def update_message_ranks(
|
||||
self, message_tree_id: UUID, rankings_by_message: dict[UUID, list[MessageReaction]]
|
||||
@@ -976,7 +975,7 @@ INNER JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'Ranki
|
||||
return rankings_by_message
|
||||
|
||||
@managed_tx_method(CommitMode.COMMIT)
|
||||
def ensure_tree_states(self):
|
||||
def ensure_tree_states(self) -> None:
|
||||
"""Add message tree state rows for all root nodes (inital prompt messages)."""
|
||||
|
||||
missing_tree_ids = self.query_misssing_tree_states()
|
||||
@@ -988,6 +987,14 @@ INNER JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'Ranki
|
||||
logger.info(f"Inserting missing message tree state for message: {id} ({tree_size=}, {state=:s})")
|
||||
self._insert_default_state(id, state=state)
|
||||
|
||||
rankings = (
|
||||
self.db.query(MessageTreeState).filter(MessageTreeState.state == message_tree_state.State.RANKING).all()
|
||||
)
|
||||
if len(rankings) > 0:
|
||||
logger.info(f"Checking state of {len(rankings)} message trees in ranking state.")
|
||||
for r in rankings:
|
||||
self.check_condition_for_scoring_state(r.message_tree_id)
|
||||
|
||||
def query_num_active_trees(self, lang: str) -> int:
|
||||
query = (
|
||||
self.db.query(func.count(MessageTreeState.message_tree_id))
|
||||
|
||||
Reference in New Issue
Block a user