From ead51ff4232c7e637bc933a06afaceeaecf859c7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20K=C3=B6pf?= Date: Mon, 16 Jan 2023 20:05:40 +0100 Subject: [PATCH] 770 tree manager allow to specify desired task_type no prompter ranking (#775) * only ranking assistant replies by default * add tasks/availability endpoint allow to specify desired task * move rank_prompter_replies option to TreeManagerConfiguration * fix type annotation * remove desired_task_type from _random_task_selection() * fix typo * Convert query_tree_size to sqlachemy, return 'full' text-labeling tasks if they were explicitly requested --- backend/oasst_backend/api/v1/tasks.py | 23 +- backend/oasst_backend/config.py | 2 + backend/oasst_backend/tree_manager.py | 589 ++++++++++-------- .../exceptions/oasst_api_error.py | 1 + 4 files changed, 363 insertions(+), 252 deletions(-) diff --git a/backend/oasst_backend/api/v1/tasks.py b/backend/oasst_backend/api/v1/tasks.py index 17df814f..c65500fb 100644 --- a/backend/oasst_backend/api/v1/tasks.py +++ b/backend/oasst_backend/api/v1/tasks.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, Optional from uuid import UUID from fastapi import APIRouter, Depends @@ -48,6 +48,27 @@ def request_task( return task +@router.post("/availability", response_model=dict[protocol_schema.TaskRequestType, int]) +def tasks_availability( + *, + user: Optional[protocol_schema.User] = None, + db: Session = Depends(deps.get_db), + api_key: APIKey = Depends(deps.get_api_key), +): + api_client = deps.api_auth(api_key, db) + + try: + pr = PromptRepository(db, api_client, client_user=user) + tm = TreeManager(db, pr) + return tm.determine_task_availability() + + except OasstError: + raise + except Exception: + logger.exception("Task availability query failed.") + raise OasstError("Task availability query failed.", OasstErrorCode.TASK_AVAILABILITY_QUERY_FAILED) + + @router.post("/{task_id}/ack", response_model=None, status_code=HTTP_204_NO_CONTENT) def tasks_acknowledge( *, diff --git a/backend/oasst_backend/config.py b/backend/oasst_backend/config.py index 71a36160..99b10cb4 100644 --- a/backend/oasst_backend/config.py +++ b/backend/oasst_backend/config.py @@ -55,6 +55,8 @@ class TreeManagerConfiguration(BaseModel): mandatory_labels_prompter_reply: Optional[list[protocol_schema.TextLabel]] = [protocol_schema.TextLabel.spam] """Mandatory labels in text-labeling tasks for prompter replies.""" + rank_prompter_replies: bool = False + class Settings(BaseSettings): PROJECT_NAME: str = "open-assistant backend" diff --git a/backend/oasst_backend/tree_manager.py b/backend/oasst_backend/tree_manager.py index 2d8a7f4d..34672e2e 100644 --- a/backend/oasst_backend/tree_manager.py +++ b/backend/oasst_backend/tree_manager.py @@ -16,7 +16,7 @@ from oasst_backend.utils.hugging_face import HfClassificationModel, HfEmbeddingM from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode from oasst_shared.schemas import protocol as protocol_schema from sqlalchemy.sql import text -from sqlmodel import Session, func +from sqlmodel import Session, func, not_ class TaskType(Enum): @@ -49,6 +49,7 @@ class ActiveTreeSizeRow(pydantic.BaseModel): class ExtendibleParentRow(pydantic.BaseModel): parent_id: UUID + parent_role: str depth: int message_tree_id: UUID active_children_count: int @@ -59,6 +60,7 @@ class ExtendibleParentRow(pydantic.BaseModel): class IncompleteRankingsRow(pydantic.BaseModel): parent_id: UUID + role: str children_count: int child_min_ranking_count: int @@ -70,21 +72,23 @@ class TreeManager: _all_text_labels = list(map(lambda x: x.value, protocol_schema.TextLabel)) def __init__( - self, db: Session, prompt_repository: PromptRepository, cfg: Optional[TreeManagerConfiguration] = None + self, + db: Session, + prompt_repository: PromptRepository, + cfg: Optional[TreeManagerConfiguration] = None, ): self.db = db self.cfg = cfg or settings.tree_manager self.pr = prompt_repository - def _task_selection( + def _random_task_selection( self, - desired_task_type: protocol_schema.TaskRequestType, num_ranking_tasks: int, num_replies_need_review: int, num_prompts_need_review: int, num_missing_prompts: int, num_missing_replies: int, - ) -> Tuple[TaskType, TaskRole]: + ) -> TaskType: """ Determines which task to hand out to human worker. The task type is drawn with relative weight (e.g. ranking has highest priority) @@ -92,75 +96,97 @@ class TreeManager: """ logger.debug( - f"TreeManager._task_selection({num_ranking_tasks=}, {num_replies_need_review=}, " + f"TreeManager._random_task_selection({num_ranking_tasks=}, {num_replies_need_review=}, " f"{num_prompts_need_review=}, {num_missing_prompts=}, {num_missing_replies=})" ) task_type = TaskType.NONE - task_role = TaskRole.ANY - if desired_task_type == protocol_schema.TaskRequestType.random: - task_weights = [0] * 5 + task_weights = [0] * 5 - if num_ranking_tasks > 0: - task_weights[TaskType.RANKING.value] = 10 + if num_ranking_tasks > 0: + task_weights[TaskType.RANKING.value] = 10 - if num_replies_need_review > 0: - task_weights[TaskType.LABEL_REPLY.value] = 5 + if num_replies_need_review > 0: + task_weights[TaskType.LABEL_REPLY.value] = 5 - if num_prompts_need_review > 0: - task_weights[TaskType.LABEL_PROMPT.value] = 5 + if num_prompts_need_review > 0: + task_weights[TaskType.LABEL_PROMPT.value] = 5 - if num_missing_replies > 0: - task_weights[TaskType.REPLY.value] = 2 + if num_missing_replies > 0: + task_weights[TaskType.REPLY.value] = 2 - if num_missing_prompts > 0: - task_weights[TaskType.PROMPT.value] = 1 + if num_missing_prompts > 0: + task_weights[TaskType.PROMPT.value] = 1 - task_weights = np.array(task_weights) - weight_sum = task_weights.sum() - if weight_sum < 1e-8: - task_type = TaskType.NONE - else: - task_weights = task_weights / weight_sum - task_type = TaskType(np.random.choice(a=len(task_weights), p=task_weights)) - else: - match desired_task_type: - case protocol_schema.TaskRequestType.initial_prompt: - if num_missing_prompts > 0: - task_type = TaskType.PROMPT - case protocol_schema.TaskRequestType.label_initial_prompt: - if num_prompts_need_review > 0: - task_type = TaskType.LABEL_PROMPT - case protocol_schema.TaskRequestType.assistant_reply | protocol_schema.TaskRequestType.prompter_reply: - if num_missing_replies > 0: - task_role = ( - TaskRole.ASSISTANT - if desired_task_type == protocol_schema.TaskRequestType.assistant_reply - else TaskRole.PROMPTER - ) - task_type = TaskType.REPLY - case protocol_schema.TaskRequestType.label_assistant_reply | protocol_schema.TaskRequestType.label_prompter_reply: - if num_replies_need_review > 0: - task_role = ( - TaskRole.ASSISTANT - if desired_task_type == protocol_schema.TaskRequestType.label_assistant_reply - else TaskRole.PROMPTER - ) - task_type = TaskType.LABEL_REPLY - case protocol_schema.TaskRequestType.rank_assistant_replies | protocol_schema.TaskRequestType.rank_prompter_replies: - if num_ranking_tasks > 0: - task_role = ( - TaskRole.ASSISTANT - if desired_task_type == protocol_schema.TaskRequestType.rank_assistant_replies - else TaskRole.PROMPTER - ) - task_type = TaskType.RANKING + task_weights = np.array(task_weights) + weight_sum = task_weights.sum() + if weight_sum > 1e-8: + task_weights = task_weights / weight_sum + task_type = TaskType(np.random.choice(a=len(task_weights), p=task_weights)) - logger.debug(f"Selected {task_type=}, {task_role=}") - return task_type, task_role + logger.debug(f"Selected {task_type=}") + return task_type + + def _determine_task_availability_internal( + self, + num_active_trees: int, + extensible_parents: list[ExtendibleParentRow], + prompts_need_review: list[Message], + replies_need_review: list[Message], + incomplete_rankings: list[IncompleteRankingsRow], + ) -> dict[protocol_schema.TaskRequestType, int]: + task_count_by_type: dict[protocol_schema.TaskRequestType, int] = {t: 0 for t in protocol_schema.TaskRequestType} + + num_missing_prompts = max(0, self.cfg.max_active_trees - num_active_trees) + task_count_by_type[protocol_schema.TaskRequestType.initial_prompt] = num_missing_prompts + + task_count_by_type[protocol_schema.TaskRequestType.prompter_reply] = len( + list(filter(lambda x: x.parent_role == "assistant", extensible_parents)) + ) + task_count_by_type[protocol_schema.TaskRequestType.assistant_reply] = len( + list(filter(lambda x: x.parent_role == "prompter", extensible_parents)) + ) + + task_count_by_type[protocol_schema.TaskRequestType.label_initial_prompt] = len(prompts_need_review) + task_count_by_type[protocol_schema.TaskRequestType.label_assistant_reply] = len( + list(filter(lambda m: m.role == "assistant", replies_need_review)) + ) + task_count_by_type[protocol_schema.TaskRequestType.prompter_reply] = len( + list(filter(lambda m: m.role == "prompter", replies_need_review)) + ) + + if self.cfg.rank_prompter_replies: + task_count_by_type[protocol_schema.TaskRequestType.rank_prompter_replies] = len( + list(filter(lambda r: r.role == "prompter", incomplete_rankings)) + ) + + task_count_by_type[protocol_schema.TaskRequestType.rank_assistant_replies] = len( + list(filter(lambda r: r.role == "assistant", incomplete_rankings)) + ) + + task_count_by_type[protocol_schema.TaskRequestType.random] = sum( + task_count_by_type[t] for t in protocol_schema.TaskRequestType if t in task_count_by_type + ) + + return task_count_by_type + + def determine_task_availability(self) -> dict[protocol_schema.TaskRequestType, int]: + num_active_trees = self.query_num_active_trees() + extensible_parents = self.query_extendible_parents() + prompts_need_review = self.query_prompts_need_review() + replies_need_review = self.query_replies_need_review() + incomplete_rankings = self.query_incomplete_rankings() + + return self._determine_task_availability_internal( + num_active_trees=num_active_trees, + extensible_parents=extensible_parents, + prompts_need_review=prompts_need_review, + replies_need_review=replies_need_review, + incomplete_rankings=incomplete_rankings, + ) def next_task( - self, desired_task_type: protocol_schema.TaskRequestType + self, desired_task_type: protocol_schema.TaskRequestType = protocol_schema.TaskRequestType.random ) -> Tuple[protocol_schema.Task, Optional[UUID], Optional[UUID]]: logger.debug("TreeManager.next_task()") @@ -168,148 +194,195 @@ class TreeManager: num_active_trees = self.query_num_active_trees() prompts_need_review = self.query_prompts_need_review() replies_need_review = self.query_replies_need_review() + extensible_parents = self.query_extendible_parents() + incomplete_rankings = self.query_incomplete_rankings() + if not self.cfg.rank_prompter_replies: + incomplete_rankings = list(filter(lambda r: r.role == "assistant", incomplete_rankings)) + active_tree_sizes = self.query_extendible_trees() # determine type of task to generate num_missing_replies = sum(x.remaining_messages for x in active_tree_sizes) - task_type, task_role = self._task_selection( - desired_task_type, - num_ranking_tasks=len(incomplete_rankings), - num_replies_need_review=len(replies_need_review), - num_prompts_need_review=len(prompts_need_review), - num_missing_prompts=max(0, self.cfg.max_active_trees - num_active_trees), - num_missing_replies=num_missing_replies, - ) - - if task_type == TaskType.NONE: - raise OasstError( - f"No tasks of type '{desired_task_type.value}' are currently available.", - OasstErrorCode.TASK_REQUESTED_TYPE_NOT_AVAILABLE, - HTTPStatus.SERVICE_UNAVAILABLE, + task_role = TaskRole.ANY + if desired_task_type == protocol_schema.TaskRequestType.random: + task_type = self._random_task_selection( + num_ranking_tasks=len(incomplete_rankings), + num_replies_need_review=len(replies_need_review), + num_prompts_need_review=len(prompts_need_review), + num_missing_prompts=max(0, self.cfg.max_active_trees - num_active_trees), + num_missing_replies=num_missing_replies, ) - if task_role != TaskRole.ANY: - # Todo: Allow role specific message selection... - raise OasstError( - f"No tasks of type '{desired_task_type.value}' are currently available.", - OasstErrorCode.TASK_REQUESTED_TYPE_NOT_AVAILABLE, - HTTPStatus.SERVICE_UNAVAILABLE, + if task_type == TaskType.NONE: + raise OasstError( + f"No tasks of type '{protocol_schema.TaskRequestType.random.value}' are currently available.", + OasstErrorCode.TASK_REQUESTED_TYPE_NOT_AVAILABLE, + HTTPStatus.SERVICE_UNAVAILABLE, + ) + else: + task_count_by_type = self._determine_task_availability_internal( + num_active_trees=num_active_trees, + extensible_parents=extensible_parents, + prompts_need_review=prompts_need_review, + replies_need_review=replies_need_review, + incomplete_rankings=incomplete_rankings, ) + available_count = task_count_by_type.get(desired_task_type) + if not available_count: + raise OasstError( + f"No tasks of type '{desired_task_type.value}' are currently available.", + OasstErrorCode.TASK_REQUESTED_TYPE_NOT_AVAILABLE, + HTTPStatus.SERVICE_UNAVAILABLE, + ) + + task_type_role_map = { + protocol_schema.TaskRequestType.initial_prompt: (TaskType.PROMPT, TaskRole.ANY), + protocol_schema.TaskRequestType.prompter_reply: (TaskType.REPLY, TaskRole.PROMPTER), + protocol_schema.TaskRequestType.assistant_reply: (TaskType.REPLY, TaskRole.ASSISTANT), + protocol_schema.TaskRequestType.rank_prompter_replies: (TaskType.RANKING, TaskRole.PROMPTER), + protocol_schema.TaskRequestType.rank_assistant_replies: (TaskType.RANKING, TaskRole.ASSISTANT), + protocol_schema.TaskRequestType.label_initial_prompt: (TaskType.LABEL_PROMPT, TaskRole.ANY), + protocol_schema.TaskRequestType.label_assistant_reply: (TaskType.LABEL_REPLY, TaskRole.ASSISTANT), + protocol_schema.TaskRequestType.label_prompter_reply: (TaskType.LABEL_REPLY, TaskRole.PROMPTER), + } + + task_type, task_role = task_type_role_map[desired_task_type] + message_tree_id = None parent_message_id = None logger.debug(f"selected {task_type=}") match task_type: case TaskType.RANKING: - assert len(incomplete_rankings) > 0 - ranking_parent_id = random.choice(incomplete_rankings).parent_id + if task_role == TaskRole.PROMPTER: + incomplete_rankings = list(filter(lambda m: m.role == "prompter", incomplete_rankings)) + elif task_role == TaskRole.ASSISTANT: + incomplete_rankings = list(filter(lambda m: m.role == "assistant", incomplete_rankings)) - messages = self.pr.fetch_message_conversation(ranking_parent_id) - assert len(messages) > 1 and messages[-1].id == ranking_parent_id - ranking_parent = messages[-1] - assert not ranking_parent.deleted and ranking_parent.review_result - conversation = prepare_conversation(messages) - replies = self.pr.fetch_message_children(ranking_parent_id, reviewed=True, exclude_deleted=True) + if len(incomplete_rankings) > 0: + ranking_parent_id = random.choice(incomplete_rankings).parent_id - assert len(replies) > 1 - random.shuffle(replies) # hand out replies in random order - reply_messages = prepare_conversation_message_list(replies) - replies = [p.text for p in replies] + messages = self.pr.fetch_message_conversation(ranking_parent_id) + assert len(messages) > 1 and messages[-1].id == ranking_parent_id + ranking_parent = messages[-1] + assert not ranking_parent.deleted and ranking_parent.review_result + conversation = prepare_conversation(messages) + replies = self.pr.fetch_message_children(ranking_parent_id, reviewed=True, exclude_deleted=True) - if messages[-1].role == "assistant": - logger.info("Generating a RankPrompterRepliesTask.") - task = protocol_schema.RankPrompterRepliesTask( - conversation=conversation, - replies=replies, - reply_messages=reply_messages, - ranking_parent_id=ranking_parent.id, - message_tree_id=ranking_parent.message_tree_id, - ) - else: - logger.info("Generating a RankAssistantRepliesTask.") - task = protocol_schema.RankAssistantRepliesTask( - conversation=conversation, - replies=replies, - reply_messages=reply_messages, - ranking_parent_id=ranking_parent.id, - message_tree_id=ranking_parent.message_tree_id, - ) + assert len(replies) > 1 + random.shuffle(replies) # hand out replies in random order + reply_messages = prepare_conversation_message_list(replies) + replies = [p.text for p in replies] - parent_message_id = ranking_parent_id - message_tree_id = messages[-1].message_tree_id + if messages[-1].role == "assistant": + logger.info("Generating a RankPrompterRepliesTask.") + task = protocol_schema.RankPrompterRepliesTask( + conversation=conversation, + replies=replies, + reply_messages=reply_messages, + ranking_parent_id=ranking_parent.id, + message_tree_id=ranking_parent.message_tree_id, + ) + else: + logger.info("Generating a RankAssistantRepliesTask.") + task = protocol_schema.RankAssistantRepliesTask( + conversation=conversation, + replies=replies, + reply_messages=reply_messages, + ranking_parent_id=ranking_parent.id, + message_tree_id=ranking_parent.message_tree_id, + ) + + parent_message_id = ranking_parent_id + message_tree_id = messages[-1].message_tree_id case TaskType.LABEL_REPLY: - assert len(replies_need_review) > 0 - random_reply_message_id = random.choice(replies_need_review) - messages = self.pr.fetch_message_conversation(random_reply_message_id) + if task_role == TaskRole.PROMPTER: + replies_need_review = list(filter(lambda m: m.role == "prompter", replies_need_review)) + elif task_role == TaskRole.ASSISTANT: + replies_need_review = list(filter(lambda m: m.role == "assistant", replies_need_review)) - conversation = prepare_conversation(messages[:-1]) - message = messages[-1] + if len(replies_need_review) > 0: + random_reply_message = random.choice(replies_need_review) + messages = self.pr.fetch_message_conversation(random_reply_message) - self.cfg.p_full_labeling_review_reply_prompter: float = 0.1 + conversation = prepare_conversation(messages[:-1]) + message = messages[-1] - label_mode = protocol_schema.LabelTaskMode.full - valid_labels = self._all_text_labels + self.cfg.p_full_labeling_review_reply_prompter: float = 0.1 - if message.role == "assistant": - if random.random() > self.cfg.p_full_labeling_review_reply_assistant: - valid_labels = list(map(lambda x: x.value, self.cfg.mandatory_labels_assistant_reply)) - label_mode = protocol_schema.LabelTaskMode.simple - logger.info(f"Generating a LabelAssistantReplyTask. ({label_mode=:s})") - task = protocol_schema.LabelAssistantReplyTask( - message_id=message.id, - conversation=conversation, - reply=message.text, - valid_labels=valid_labels, - mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_assistant_reply)), - mode=label_mode, - ) - else: - if random.random() > self.cfg.p_full_labeling_review_reply_prompter: - valid_labels = list(map(lambda x: x.value, self.cfg.mandatory_labels_prompter_reply)) - label_mode = protocol_schema.LabelTaskMode.simple - logger.info(f"Generating a LabelPrompterReplyTask. ({label_mode=:s})") - task = protocol_schema.LabelPrompterReplyTask( - message_id=message.id, - conversation=conversation, - reply=message.text, - valid_labels=valid_labels, - mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_prompter_reply)), - mode=label_mode, - ) + label_mode = protocol_schema.LabelTaskMode.full + valid_labels = self._all_text_labels - parent_message_id = message.id - message_tree_id = message.message_tree_id + if message.role == "assistant": + if ( + desired_task_type == protocol_schema.TaskRequestType.random + and random.random() > self.cfg.p_full_labeling_review_reply_assistant + ): + valid_labels = list(map(lambda x: x.value, self.cfg.mandatory_labels_assistant_reply)) + label_mode = protocol_schema.LabelTaskMode.simple + logger.info(f"Generating a LabelAssistantReplyTask. ({label_mode=:s})") + task = protocol_schema.LabelAssistantReplyTask( + message_id=message.id, + conversation=conversation, + reply=message.text, + valid_labels=valid_labels, + mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_assistant_reply)), + mode=label_mode, + ) + else: + if ( + desired_task_type == protocol_schema.TaskRequestType.random + and random.random() > self.cfg.p_full_labeling_review_reply_prompter + ): + valid_labels = list(map(lambda x: x.value, self.cfg.mandatory_labels_prompter_reply)) + label_mode = protocol_schema.LabelTaskMode.simple + logger.info(f"Generating a LabelPrompterReplyTask. ({label_mode=:s})") + task = protocol_schema.LabelPrompterReplyTask( + message_id=message.id, + conversation=conversation, + reply=message.text, + valid_labels=valid_labels, + mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_prompter_reply)), + mode=label_mode, + ) + + parent_message_id = message.id + message_tree_id = message.message_tree_id case TaskType.REPLY: # select a tree with missing replies - extensible_parents = self.query_extendible_parents() - assert len(extensible_parents) > 0 + if task_role == TaskRole.PROMPTER: + extensible_parents = list(filter(lambda x: x.parent_role == "assistant", extensible_parents)) + elif task_role == TaskRole.ASSISTANT: + extensible_parents = list(filter(lambda x: x.parent_role == "prompter", extensible_parents)) - # fetch random conversation to extend - random_parent = random.choice(extensible_parents) - logger.debug(f"selected {random_parent=}") - messages = self.pr.fetch_message_conversation(random_parent.parent_id) - assert all(m.review_result for m in messages) # ensure all messages have positive review - conversation = prepare_conversation(messages) + if len(extensible_parents) > 0: + random_parent = random.choice(extensible_parents) - # generate reply task depending on last message - if messages[-1].role == "assistant": - logger.info("Generating a PrompterReplyTask.") - task = protocol_schema.PrompterReplyTask(conversation=conversation) - else: - logger.info("Generating a AssistantReplyTask.") - task = protocol_schema.AssistantReplyTask(conversation=conversation) + # fetch random conversation to extend + logger.debug(f"selected {random_parent=}") + messages = self.pr.fetch_message_conversation(random_parent.parent_id) + assert all(m.review_result for m in messages) # ensure all messages have positive review + conversation = prepare_conversation(messages) - parent_message_id = messages[-1].id - message_tree_id = messages[-1].message_tree_id + # generate reply task depending on last message + if messages[-1].role == "assistant": + logger.info("Generating a PrompterReplyTask.") + task = protocol_schema.PrompterReplyTask(conversation=conversation) + else: + logger.info("Generating a AssistantReplyTask.") + task = protocol_schema.AssistantReplyTask(conversation=conversation) + + parent_message_id = messages[-1].id + message_tree_id = messages[-1].message_tree_id case TaskType.LABEL_PROMPT: assert len(prompts_need_review) > 0 - message = self.pr.fetch_message(random.choice(prompts_need_review)) + message = random.choice(prompts_need_review) label_mode = protocol_schema.LabelTaskMode.full valid_labels = self._all_text_labels @@ -337,6 +410,13 @@ class TreeManager: case _: task = None + if task is None: + raise OasstError( + f"No task of type '{desired_task_type.value}' is currently available.", + OasstErrorCode.TASK_REQUESTED_TYPE_NOT_AVAILABLE, + HTTPStatus.SERVICE_UNAVAILABLE, + ) + logger.info(f"Generated {task=}.") return task, message_tree_id, parent_message_id @@ -515,7 +595,8 @@ class TreeManager: logger.debug(f"False {mts.active=}, {mts.state=}") return False - rankings_by_message = self.query_tree_ranking_results(message_tree_id) + ranking_role_filter = None if self.cfg.rank_prompter_replies else "assistant" + rankings_by_message = self.query_tree_ranking_results(message_tree_id, role_filter=ranking_role_filter) for parent_msg_id, ranking in rankings_by_message.items(): if len(ranking) < self.cfg.num_required_rankings: logger.debug(f"False {parent_msg_id=} {len(ranking)=}") @@ -528,68 +609,59 @@ class TreeManager: # calculate acceptance based on spam label return np.mean([1 - l.labels[protocol_schema.TextLabel.spam] for l in labels]) - _sql_find_prompts_need_review = """ --- find initial prompts that need more reviews -SELECT m.id -FROM message_tree_state mts - LEFT JOIN message m ON mts.message_tree_id = m.id -WHERE mts.active - AND mts.state = :state - AND NOT m.review_result - AND NOT m.deleted - AND m.review_count < :num_reviews_initial_prompt - AND m.parent_id is NULL - AND (:excluded_user_id IS NULL OR m.user_id != :excluded_user_id) -""" - - def query_prompts_need_review(self) -> list[UUID]: + def query_prompts_need_review(self) -> list[Message]: """ - Select id of initial prompts with less then required rankings in active message tree + Select initial prompt messages with less then required rankings in active message tree (active == True in message_tree_state) """ - r = self.db.execute( - text(self._sql_find_prompts_need_review), - { - "state": message_tree_state.State.INITIAL_PROMPT_REVIEW, - "num_reviews_initial_prompt": self.cfg.num_reviews_initial_prompt, - "excluded_user_id": None if settings.DEBUG_ALLOW_SELF_LABELING else self.pr.user_id, - }, + qry = ( + self.db.query(Message) + .select_from(MessageTreeState) + .outerjoin(Message, MessageTreeState.message_tree_id == Message.message_tree_id) + .filter( + MessageTreeState.active, + MessageTreeState.state == message_tree_state.State.INITIAL_PROMPT_REVIEW, + not_(Message.review_result), + not_(Message.deleted), + Message.review_count < self.cfg.num_reviews_initial_prompt, + Message.parent_id.is_(None), + ) ) - return [x["id"] for x in r.all()] - _sql_find_replies_need_review = """ -SELECT m.id -FROM message_tree_state mts - LEFT JOIN message m ON mts.message_tree_id = m.message_tree_id -WHERE mts.active - AND mts.state = :breeding_state - AND NOT m.review_result - AND NOT m.deleted - AND m.review_count < :num_required_reviews - AND m.parent_id is NOT NULL - AND (:excluded_user_id IS NULL OR m.user_id != :excluded_user_id) -""" + if not settings.DEBUG_ALLOW_SELF_LABELING: + qry = qry.filter(Message.user_id != self.pr.user_id) - def query_replies_need_review(self) -> list[UUID]: + return qry.all() + + def query_replies_need_review(self) -> list[Message]: """ - Select ids of child messages (parent_id IS NOT NULL) with less then required rankings + Select child messages (parent_id IS NOT NULL) with less then required rankings in active message tree (active == True in message_tree_state) """ - r = self.db.execute( - text(self._sql_find_replies_need_review), - { - "breeding_state": message_tree_state.State.GROWING, - "num_required_reviews": self.cfg.num_reviews_reply, - "excluded_user_id": None if settings.DEBUG_ALLOW_SELF_LABELING else self.pr.user_id, - }, + qry = ( + self.db.query(Message) + .select_from(MessageTreeState) + .outerjoin(Message, MessageTreeState.message_tree_id == Message.message_tree_id) + .filter( + MessageTreeState.active, + MessageTreeState.state == message_tree_state.State.GROWING, + not_(Message.review_result), + not_(Message.deleted), + Message.review_count < self.cfg.num_reviews_reply, + Message.parent_id.is_not(None), + ) ) - return [x["id"] for x in r.all()] + + if not settings.DEBUG_ALLOW_SELF_LABELING: + qry = qry.filter(Message.user_id != self.pr.user_id) + + return qry.all() _sql_find_incomplete_rankings = """ -- find incomplete rankings -SELECT m.parent_id, COUNT(m.id) children_count, MIN(m.ranking_count) child_min_ranking_count, +SELECT m.parent_id, m.role, COUNT(m.id) children_count, MIN(m.ranking_count) child_min_ranking_count, COUNT(m.id) FILTER (WHERE m.ranking_count >= :num_required_rankings) as completed_rankings FROM message_tree_state mts LEFT JOIN message m ON mts.message_tree_id = m.message_tree_id @@ -598,7 +670,7 @@ WHERE mts.active -- only consider active trees AND m.review_result -- must be reviewed AND NOT m.deleted -- not deleted AND m.parent_id IS NOT NULL -- ignore initial prompts -GROUP BY m.parent_id +GROUP BY m.parent_id, m.role HAVING COUNT(m.id) > 1 and MIN(m.ranking_count) < :num_required_rankings """ @@ -616,10 +688,10 @@ HAVING COUNT(m.id) > 1 and MIN(m.ranking_count) < :num_required_rankings _sql_find_extendible_parents = """ -- find all extendible parent nodes -SELECT m.id as parent_id, m.depth, m.message_tree_id, COUNT(c.id) active_children_count +SELECT m.id as parent_id, m.role as parent_role, m.depth, m.message_tree_id, COUNT(c.id) active_children_count FROM message_tree_state mts LEFT JOIN message m ON mts.message_tree_id = m.message_tree_id -- all elements of message tree - LEFT JOIN message c ON m.id = c.Id -- child nodes + LEFT JOIN message c ON m.id = c.parent_id -- child nodes WHERE mts.active -- only consider active trees AND mts.state = :growing_state -- message tree must be growing AND NOT m.deleted -- ignore deleted messages as parents @@ -627,7 +699,7 @@ WHERE mts.active -- only consider active trees AND m.review_result -- parent node must have positive review AND NOT c.deleted -- don't count deleted children AND (c.review_result OR c.review_count < :num_reviews_reply) -- don't count children with negative review but count elements under review -GROUP BY m.id, m.depth, m.message_tree_id, mts.max_children_count +GROUP BY m.id, m.role, m.depth, m.message_tree_id, mts.max_children_count HAVING COUNT(c.id) < mts.max_children_count -- below maximum number of children """ @@ -636,10 +708,7 @@ HAVING COUNT(c.id) < mts.max_children_count -- below maximum number of children r = self.db.execute( text(self._sql_find_extendible_parents), - { - "growing_state": message_tree_state.State.GROWING, - "num_reviews_reply": self.cfg.num_reviews_reply, - }, + {"growing_state": message_tree_state.State.GROWING, "num_reviews_reply": self.cfg.num_reviews_reply}, ) return [ExtendibleParentRow.from_orm(x) for x in r.all()] @@ -671,21 +740,27 @@ HAVING COUNT(m.id) < mts.goal_tree_size ) return [ActiveTreeSizeRow.from_orm(x) for x in r.all()] - _sql_get_tree_size = """ -SELECT mts.message_tree_id, mts.goal_tree_size, COUNT(m.id) AS tree_size -FROM message_tree_state mts - LEFT JOIN message m ON mts.message_tree_id = m.message_tree_id -WHERE mts.active - AND NOT m.deleted - AND m.review_result - AND mts.message_tree_id = :message_tree_id -GROUP BY mts.message_tree_id, mts.goal_tree_size -""" - def query_tree_size(self, message_tree_id: UUID) -> ActiveTreeSizeRow: """Returns the number of reviewed not deleted messages in the message tree.""" - r = self.db.execute(text(self._sql_get_tree_size), {"message_tree_id": message_tree_id}) - return ActiveTreeSizeRow.from_orm(r.one()) + + qry = ( + self.db.query( + MessageTreeState.message_tree_id.label("message_tree_id"), + MessageTreeState.goal_tree_size.label("goal_tree_size"), + func.count(Message.id).label("tree_size"), + ) + .select_from(MessageTreeState) + .outerjoin(Message, MessageTreeState.message_tree_id == Message.message_tree_id) + .filter( + MessageTreeState.active, + not_(Message.deleted), + Message.review_result, + MessageTreeState.message_tree_id == message_tree_id, + ) + .group_by(MessageTreeState.message_tree_id, MessageTreeState.goal_tree_size) + ) + + return ActiveTreeSizeRow.from_orm(qry.one()) def query_misssing_tree_states(self) -> list[UUID]: """Find all initial prompt messages that have no associated message tree state""" @@ -702,7 +777,7 @@ GROUP BY mts.message_tree_id, mts.goal_tree_size return [m.id for m in qry_missing_tree_states.all()] _sql_find_tree_ranking_results = """ --- get all ranking results of completed tasks for all parents with >=2 children +-- get all ranking results of completed tasks for all parents with >= 2 children SELECT p.parent_id, mr.* FROM ( -- find parents with > 1 children @@ -712,7 +787,8 @@ SELECT p.parent_id, mr.* FROM WHERE m.review_result -- must be reviewed AND NOT m.deleted -- not deleted AND m.parent_id IS NOT NULL -- ignore initial prompts - AND mts.message_tree_id = :message_tree_id + AND (:role IS NULL OR m.role = :role) -- children with matching role + AND mts.message_tree_id = :message_tree_id GROUP BY m.parent_id, m.message_tree_id HAVING COUNT(m.id) > 1 ) as p @@ -720,11 +796,21 @@ LEFT JOIN task t ON p.parent_id = t.parent_message_id AND t.done AND (t.payload_ LEFT JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'RankingReactionPayload' """ - def query_tree_ranking_results(self, message_tree_id: UUID) -> dict[UUID, list[MessageReaction]]: + def query_tree_ranking_results( + self, + message_tree_id: UUID, + role_filter: str = "assistant", + ) -> dict[UUID, list[MessageReaction]]: """Finds all completed ranking restuls for a message_tree""" + + assert role_filter in (None, "assistant", "prompter") + r = self.db.execute( text(self._sql_find_tree_ranking_results), - {"message_tree_id": message_tree_id}, + { + "message_tree_id": message_tree_id, + "role": role_filter, + }, ) rankings_by_message = {} @@ -803,12 +889,12 @@ WHERE t.done = TRUE if __name__ == "__main__": - from oasst_backend.api.deps import get_dummy_api_client + from oasst_backend.api.deps import api_auth from oasst_backend.database import engine from oasst_backend.prompt_repository import PromptRepository with Session(engine) as db: - api_client = get_dummy_api_client(db) + api_client = api_auth(settings.OFFICIAL_WEB_API_KEY, db=db) dummy_user = protocol_schema.User(id="__dummy_user__", display_name="Dummy User", auth_method="local") pr = PromptRepository(db=db, api_client=api_client, client_user=dummy_user) @@ -817,15 +903,16 @@ if __name__ == "__main__": tm = TreeManager(db, pr, cfg) tm.ensure_tree_states() - print("query_num_active_trees", tm.query_num_active_trees()) - print("query_incomplete_rankings", tm.query_incomplete_rankings()) - print("query_incomplete_reply_reviews", tm.query_replies_need_review()) - print("query_incomplete_initial_prompt_reviews", tm.query_prompts_need_review()) - print("query_extendible_trees", tm.query_extendible_trees()) - print("query_extendible_parents", tm.query_extendible_parents()) + # print("query_num_active_trees", tm.query_num_active_trees()) + # print("query_incomplete_rankings", tm.query_incomplete_rankings()) + # print("query_replies_need_review", tm.query_replies_need_review()) + # print("query_incomplete_initial_prompt_reviews", tm.query_prompts_need_review()) + # print("query_extendible_trees", tm.query_extendible_trees()) + # print("query_extendible_parents", tm.query_extendible_parents()) + # print("query_tree_size", tm.query_tree_size(message_tree_id=UUID("bdf434cf-4df5-4b74-949c-a5a157bc3292"))) print("next_task:", tm.next_task()) - print( - ".query_tree_ranking_results", tm.query_tree_ranking_results(UUID("2ac20d38-6650-43aa-8bb3-f61080c0d921")) - ) + # print( + # "query_tree_ranking_results", tm.query_tree_ranking_results(UUID("6036f58f-41b5-48c4-bdd9-b16f34ab1312")) + # ) diff --git a/oasst-shared/oasst_shared/exceptions/oasst_api_error.py b/oasst-shared/oasst_shared/exceptions/oasst_api_error.py index 7ad0b65e..e60ad746 100644 --- a/oasst-shared/oasst_shared/exceptions/oasst_api_error.py +++ b/oasst-shared/oasst_shared/exceptions/oasst_api_error.py @@ -32,6 +32,7 @@ class OasstErrorCode(IntEnum): TASK_INTERACTION_REQUEST_FAILED = 1004 TASK_GENERATION_FAILED = 1005 TASK_REQUESTED_TYPE_NOT_AVAILABLE = 1006 + TASK_AVAILABILITY_QUERY_FAILED = 1007 # 2000-3000: prompt_repository INVALID_FRONTEND_MESSAGE_ID = 2000