diff --git a/backend/oasst_backend/tree_manager.py b/backend/oasst_backend/tree_manager.py index fd7c1c67..a8538652 100644 --- a/backend/oasst_backend/tree_manager.py +++ b/backend/oasst_backend/tree_manager.py @@ -19,11 +19,11 @@ from sqlmodel import Session, func class TreeManagerConfiguration(pydantic.BaseModel): - """Configuration class for the TreeManager""" + """TreeManager configuration settings""" max_active_trees: int = 10 - """Maximum number of concurrently active trees in the database. - No new initial prompt tasks will be handed out to users if this + """Maximum number of concurrently active message trees in the database. + No new initial prompt tasks are handed out to users if this number is reached.""" max_tree_depth: int = 6 @@ -41,6 +41,15 @@ class TreeManagerConfiguration(pydantic.BaseModel): num_reviews_reply: int = 3 """Number of peer review checks to collect per reply (other than initial_prompt)""" + p_full_labeling_review_prompt: float = 0.1 + """Probability of full text-labeling (instead of mandatory only) for initial prompts""" + + p_full_labeling_review_reply_assistant: float = 0.1 + """Probability of full text-labeling (instead of mandatory only) for assistant replies""" + + p_full_labeling_review_reply_prompter: float = 0.1 + """Probability of full text-labeling (instead of mandatory only) for prompter replies""" + acceptance_threshold_initial_prompt: float = 0.6 """Threshold for accepting an initial prompt""" @@ -103,6 +112,8 @@ class IncompleteRankingsRow(pydantic.BaseModel): class TreeManager: + _all_text_labels = list(map(lambda x: x.value, protocol_schema.TextLabel)) + def __init__(self, db: Session, prompt_repository: PromptRepository, configuration: TreeManagerConfiguration): self.db = db self.cfg = configuration @@ -242,6 +253,9 @@ class TreeManager: messages = self.pr.fetch_message_conversation(ranking_parent_id) conversation = prepare_conversation(messages) replies = self.pr.fetch_message_children(ranking_parent_id, reviewed=True, exclude_deleted=True) + + assert len(replies) > 1 + random.shuffle(replies) # hand out replies in random order reply_messages = prepare_conversation_message_list(replies) replies = [p.text for p in replies] @@ -263,26 +277,40 @@ class TreeManager: assert len(replies_need_review) > 0 random_reply_message_id = random.choice(replies_need_review) messages = self.pr.fetch_message_conversation(random_reply_message_id) + conversation = prepare_conversation(messages[:-1]) message = messages[-1] + self.cfg.p_full_labeling_review_reply_prompter: float = 0.1 + + label_mode = protocol_schema.LabelTaskMode.full + valid_labels = self._all_text_labels + if message.role == "assistant": - logger.info("Generating a LabelAssistantReplyTask.") + if random.random() > self.cfg.p_full_labeling_review_reply_assistant: + valid_labels = list(map(lambda x: x.value, self.cfg.mandatory_labels_assistant_reply)) + label_mode = protocol_schema.LabelTaskMode.simple + logger.info(f"Generating a LabelAssistantReplyTask. ({label_mode=:s})") task = protocol_schema.LabelAssistantReplyTask( message_id=message.id, conversation=conversation, reply=message.text, - valid_labels=list(map(lambda x: x.value, protocol_schema.TextLabel)), + valid_labels=valid_labels, mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_assistant_reply)), + mode=label_mode, ) else: - logger.info("Generating a LabelPrompterReplyTask.") + if random.random() > self.cfg.p_full_labeling_review_reply_prompter: + valid_labels = list(map(lambda x: x.value, self.cfg.mandatory_labels_prompter_reply)) + label_mode = protocol_schema.LabelTaskMode.simple + logger.info(f"Generating a LabelPrompterReplyTask. ({label_mode=:s})") task = protocol_schema.LabelPrompterReplyTask( message_id=message.id, conversation=conversation, reply=message.text, - valid_labels=list(map(lambda x: x.value, protocol_schema.TextLabel)), + valid_labels=valid_labels, mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_prompter_reply)), + mode=label_mode, ) parent_message_id = message.id @@ -314,13 +342,21 @@ class TreeManager: case TaskType.LABEL_PROMPT: assert len(prompts_need_review) > 0 message = self.pr.fetch_message(random.choice(prompts_need_review)) - logger.info("Generating a LabelInitialPromptTask.") + label_mode = protocol_schema.LabelTaskMode.full + valid_labels = self._all_text_labels + + if random.random() > self.cfg.p_full_labeling_review_prompt: + valid_labels = list(map(lambda x: x.value, self.cfg.mandatory_labels_initial_prompt)) + label_mode = protocol_schema.LabelTaskMode.simple + + logger.info(f"Generating a LabelInitialPromptTask ({label_mode=:s}).") task = protocol_schema.LabelInitialPromptTask( message_id=message.id, prompt=message.text, - valid_labels=list(map(lambda x: x.value, protocol_schema.TextLabel)), + valid_labels=valid_labels, mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_initial_prompt)), + mode=label_mode, ) parent_message_id = message.id diff --git a/oasst-shared/oasst_shared/schemas/protocol.py b/oasst-shared/oasst_shared/schemas/protocol.py index 63c9ded0..eec852d7 100644 --- a/oasst-shared/oasst_shared/schemas/protocol.py +++ b/oasst-shared/oasst_shared/schemas/protocol.py @@ -179,6 +179,13 @@ class RankAssistantRepliesTask(RankConversationRepliesTask): type: Literal["rank_assistant_replies"] = "rank_assistant_replies" +class LabelTaskMode(str, enum.Enum): + """Label task mode that allows frontends to select an appropriate UI.""" + + simple = "simple" + full = "full" + + class LabelInitialPromptTask(Task): """A task to label an initial prompt.""" @@ -187,6 +194,7 @@ class LabelInitialPromptTask(Task): prompt: str valid_labels: list[str] mandatory_labels: Optional[list[str]] + mode: Optional[LabelTaskMode] class LabelConversationReplyTask(Task): @@ -198,6 +206,7 @@ class LabelConversationReplyTask(Task): reply: str valid_labels: list[str] mandatory_labels: Optional[list[str]] + mode: Optional[LabelTaskMode] class LabelPrompterReplyTask(LabelConversationReplyTask): diff --git a/text-frontend/__main__.py b/text-frontend/__main__.py index b9234d4f..3e662e5f 100644 --- a/text-frontend/__main__.py +++ b/text-frontend/__main__.py @@ -220,15 +220,19 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "DUMMY_KEY") valid_labels = task["valid_labels"] labels_dict = None - while labels_dict is None: - labels_str: str = typer.prompt("Enter labels, separated by commas") - labels = labels_str.lower().replace(" ", "").split(",") + if task["mode"] == "simple" and len(valid_labels) == 1: + answer: str = typer.confirm(f"{valid_labels[0]}?") + labels_dict = {valid_labels[0]: 1 if answer else 0} + else: + while labels_dict is None: + labels_str: str = typer.prompt("Enter labels, separated by commas") + labels = labels_str.lower().replace(" ", "").split(",") - if all([label in valid_labels for label in labels]): - labels_dict = {label: "1" if label in labels else "0" for label in valid_labels} - else: - invalid_labels = [label for label in labels if label not in valid_labels] - typer.echo(f"Invalid labels: {', '.join(invalid_labels)}. Valid: {', '.join(valid_labels)}") + if all([label in valid_labels for label in labels]): + labels_dict = {label: "1" if label in labels else "0" for label in valid_labels} + else: + invalid_labels = [label for label in labels if label not in valid_labels] + typer.echo(f"Invalid labels: {', '.join(invalid_labels)}. Valid: {', '.join(valid_labels)}") # send labels new_task = _post( @@ -258,15 +262,19 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "DUMMY_KEY") valid_labels = task["valid_labels"] labels_dict = None - while labels_dict is None: - labels_str: str = typer.prompt("Enter labels, separated by commas") - labels = labels_str.lower().replace(" ", "").split(",") + if task["mode"] == "simple" and len(valid_labels) == 1: + answer: str = typer.confirm(f"{valid_labels[0]}?") + labels_dict = {valid_labels[0]: 1 if answer else 0} + else: + while labels_dict is None: + labels_str: str = typer.prompt("Enter labels, separated by commas") + labels = labels_str.lower().replace(" ", "").split(",") - if all([label in valid_labels for label in labels]): - labels_dict = {label: "1" if label in labels else "0" for label in valid_labels} - else: - invalid_labels = [label for label in labels if label not in valid_labels] - typer.echo(f"Invalid labels: {', '.join(invalid_labels)}. Valid: {', '.join(valid_labels)}") + if all([label in valid_labels for label in labels]): + labels_dict = {label: "1" if label in labels else "0" for label in valid_labels} + else: + invalid_labels = [label for label in labels if label not in valid_labels] + typer.echo(f"Invalid labels: {', '.join(invalid_labels)}. Valid: {', '.join(valid_labels)}") # send labels new_task = _post(