diff --git a/backend/oasst_backend/tree_manager.py b/backend/oasst_backend/tree_manager.py index a2c85940..03aac56e 100644 --- a/backend/oasst_backend/tree_manager.py +++ b/backend/oasst_backend/tree_manager.py @@ -617,9 +617,9 @@ class TreeManager: logger.debug(f"SORTED MESSAGE {sorted_messages}") consensus = ranked_pairs(sorted_messages) logger.debug(f"CONSENSUS: {consensus}\n\n") - for rank, uuid in enumerate(consensus): + for rank, message_id in enumerate(consensus): # set rank for each message_id for Message rows - msg = self.db.query(Message).filter(Message.id == uuid).one() + msg = self.pr.fetch_message(message_id=message_id, fail_if_missing=True) msg.rank = rank self.db.add(msg) @@ -639,7 +639,7 @@ class TreeManager: .outerjoin(Message, MessageTreeState.message_tree_id == Message.message_tree_id) .filter( MessageTreeState.active, - MessageTreeState.state == message_tree_state.State.INITIAL_PROMPT_REVIEW, + MessageTreeState.state == message_tree_state.State.INITIAL_PROMPT_REVIEW.value, not_(Message.review_result), not_(Message.deleted), Message.review_count < self.cfg.num_reviews_initial_prompt, @@ -664,7 +664,7 @@ class TreeManager: .outerjoin(Message, MessageTreeState.message_tree_id == Message.message_tree_id) .filter( MessageTreeState.active, - MessageTreeState.state == message_tree_state.State.GROWING, + MessageTreeState.state == message_tree_state.State.GROWING.value, not_(Message.review_result), not_(Message.deleted), Message.review_count < self.cfg.num_reviews_reply, @@ -699,7 +699,7 @@ HAVING COUNT(m.id) > 1 and MIN(m.ranking_count) < :num_required_rankings text(self._sql_find_incomplete_rankings), { "num_required_rankings": self.cfg.num_required_rankings, - "ranking_state": message_tree_state.State.RANKING, + "ranking_state": message_tree_state.State.RANKING.value, }, ) return [IncompleteRankingsRow.from_orm(x) for x in r.all()] @@ -726,7 +726,10 @@ HAVING COUNT(c.id) < mts.max_children_count -- below maximum number of children r = self.db.execute( text(self._sql_find_extendible_parents), - {"growing_state": message_tree_state.State.GROWING, "num_reviews_reply": self.cfg.num_reviews_reply}, + { + "growing_state": message_tree_state.State.GROWING.value, + "num_reviews_reply": self.cfg.num_reviews_reply, + }, ) return [ExtendibleParentRow.from_orm(x) for x in r.all()] @@ -752,7 +755,7 @@ HAVING COUNT(m.id) < mts.goal_tree_size r = self.db.execute( text(self._sql_find_extendible_trees), { - "growing_state": message_tree_state.State.GROWING, + "growing_state": message_tree_state.State.GROWING.value, "num_reviews_reply": self.cfg.num_reviews_reply, }, ) @@ -850,7 +853,7 @@ LEFT JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'Rankin state = message_tree_state.State.INITIAL_PROMPT_REVIEW if tree_size > 1: state = message_tree_state.State.GROWING - logger.info(f"Inserting missing message tree state for message: {id} ({tree_size=}, {state=})") + logger.info(f"Inserting missing message tree state for message: {id} ({tree_size=}, {state=:s})") self._insert_default_state(id, state=state) def query_num_active_trees(self) -> int: