diff --git a/backend/oasst_backend/tree_manager.py b/backend/oasst_backend/tree_manager.py index 03aac56e..9bca0201 100644 --- a/backend/oasst_backend/tree_manager.py +++ b/backend/oasst_backend/tree_manager.py @@ -152,7 +152,7 @@ class TreeManager: task_count_by_type[protocol_schema.TaskRequestType.label_assistant_reply] = len( list(filter(lambda m: m.role == "assistant", replies_need_review)) ) - task_count_by_type[protocol_schema.TaskRequestType.prompter_reply] = len( + task_count_by_type[protocol_schema.TaskRequestType.label_prompter_reply] = len( list(filter(lambda m: m.role == "prompter", replies_need_review)) ) @@ -639,7 +639,7 @@ class TreeManager: .outerjoin(Message, MessageTreeState.message_tree_id == Message.message_tree_id) .filter( MessageTreeState.active, - MessageTreeState.state == message_tree_state.State.INITIAL_PROMPT_REVIEW.value, + MessageTreeState.state == message_tree_state.State.INITIAL_PROMPT_REVIEW, not_(Message.review_result), not_(Message.deleted), Message.review_count < self.cfg.num_reviews_initial_prompt, @@ -664,7 +664,7 @@ class TreeManager: .outerjoin(Message, MessageTreeState.message_tree_id == Message.message_tree_id) .filter( MessageTreeState.active, - MessageTreeState.state == message_tree_state.State.GROWING.value, + MessageTreeState.state == message_tree_state.State.GROWING, not_(Message.review_result), not_(Message.deleted), Message.review_count < self.cfg.num_reviews_reply, @@ -699,7 +699,7 @@ HAVING COUNT(m.id) > 1 and MIN(m.ranking_count) < :num_required_rankings text(self._sql_find_incomplete_rankings), { "num_required_rankings": self.cfg.num_required_rankings, - "ranking_state": message_tree_state.State.RANKING.value, + "ranking_state": message_tree_state.State.RANKING, }, ) return [IncompleteRankingsRow.from_orm(x) for x in r.all()] @@ -727,7 +727,7 @@ HAVING COUNT(c.id) < mts.max_children_count -- below maximum number of children r = self.db.execute( text(self._sql_find_extendible_parents), { - "growing_state": message_tree_state.State.GROWING.value, + "growing_state": message_tree_state.State.GROWING, "num_reviews_reply": self.cfg.num_reviews_reply, }, ) @@ -755,7 +755,7 @@ HAVING COUNT(m.id) < mts.goal_tree_size r = self.db.execute( text(self._sql_find_extendible_trees), { - "growing_state": message_tree_state.State.GROWING.value, + "growing_state": message_tree_state.State.GROWING, "num_reviews_reply": self.cfg.num_reviews_reply, }, ) @@ -924,16 +924,16 @@ if __name__ == "__main__": # print("query_num_active_trees", tm.query_num_active_trees()) # print("query_incomplete_rankings", tm.query_incomplete_rankings()) - # print("query_replies_need_review", tm.query_replies_need_review()) - # print("query_incomplete_initial_prompt_reviews", tm.query_prompts_need_review()) + print("query_replies_need_review", tm.query_replies_need_review()) + print("query_incomplete_initial_prompt_reviews", tm.query_prompts_need_review()) # print("query_extendible_trees", tm.query_extendible_trees()) # print("query_extendible_parents", tm.query_extendible_parents()) # print("query_tree_size", tm.query_tree_size(message_tree_id=UUID("bdf434cf-4df5-4b74-949c-a5a157bc3292"))) - print( - "query_reviews_for_message", - tm.query_reviews_for_message(message_id=UUID("6a444493-0d48-4316-a9f1-7e263f5a2473")), - ) + # print( + # "query_reviews_for_message", + # tm.query_reviews_for_message(message_id=UUID("6a444493-0d48-4316-a9f1-7e263f5a2473")), + # ) # print("next_task:", tm.next_task())