mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
1015 lines
44 KiB
Python
1015 lines
44 KiB
Python
import random
|
|
from datetime import datetime
|
|
from enum import Enum
|
|
from http import HTTPStatus
|
|
from typing import Any, Dict, List, Optional, Tuple
|
|
from uuid import UUID
|
|
|
|
import numpy as np
|
|
import pydantic
|
|
from loguru import logger
|
|
from oasst_backend.api.v1.utils import prepare_conversation, prepare_conversation_message_list
|
|
from oasst_backend.config import TreeManagerConfiguration, settings
|
|
from oasst_backend.models import Message, MessageReaction, MessageTreeState, Task, TextLabels, message_tree_state
|
|
from oasst_backend.prompt_repository import PromptRepository
|
|
from oasst_backend.utils.database_utils import CommitMode, async_managed_tx_method, managed_tx_method
|
|
from oasst_backend.utils.hugging_face import HfClassificationModel, HfEmbeddingModel, HfUrl, HuggingFaceAPI
|
|
from oasst_backend.utils.ranking import ranked_pairs
|
|
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
|
|
from oasst_shared.schemas import protocol as protocol_schema
|
|
from sqlalchemy.sql import text
|
|
from sqlmodel import Session, func, not_
|
|
|
|
|
|
class TaskType(Enum):
    """Kind of work item the tree manager can hand out.

    The non-negative values are also used as indices into the weight vector built
    in ``TreeManager._random_task_selection``; ``NONE`` (-1) means that no task
    could be selected.
    """

    NONE = -1  # sentinel: nothing available
    RANKING = 0
    LABEL_REPLY = 1
    REPLY = 2
    LABEL_PROMPT = 3
    PROMPT = 4
|
|
|
|
|
|
class TaskRole(Enum):
    """Author role a requested task is restricted to; ``ANY`` means no restriction."""

    ANY = 0
    PROMPTER = 1
    ASSISTANT = 2
|
|
|
|
|
|
class ActiveTreeSizeRow(pydantic.BaseModel):
    """Current size of a message tree compared with its goal size."""

    message_tree_id: UUID
    tree_size: int  # number of messages counted by the producing query
    goal_tree_size: int  # desired number of messages for the tree

    @property
    def remaining_messages(self) -> int:
        """Number of messages still missing to reach ``goal_tree_size`` (never negative)."""
        shortfall = self.goal_tree_size - self.tree_size
        return shortfall if shortfall > 0 else 0

    class Config:
        # rows are built from SQL result rows via ``from_orm``
        orm_mode = True
|
|
|
|
|
|
class ExtendibleParentRow(pydantic.BaseModel):
    """A reviewed message that can still receive additional replies."""

    parent_id: UUID
    parent_role: str  # role of the parent message, e.g. "prompter" or "assistant"
    depth: int
    message_tree_id: UUID
    active_children_count: int  # children currently counted against the per-parent limit

    class Config:
        # rows are built from raw SQL results via ``from_orm``
        orm_mode = True
|
|
|
|
|
|
class IncompleteRankingsRow(pydantic.BaseModel):
    """A parent message whose children of one role still need ranking votes."""

    parent_id: UUID
    role: str  # role shared by the grouped child messages
    children_count: int
    child_min_ranking_count: int  # lowest ranking count among those children

    class Config:
        # rows are built from raw SQL results via ``from_orm``
        orm_mode = True
|
|
|
|
|
|
class TreeMessageCountStats(pydantic.BaseModel):
    """Per-tree message statistics: message count compared with the goal tree size."""

    message_tree_id: UUID
    state: str
    depth: int
    oldest: datetime
    youngest: datetime
    count: int
    goal_tree_size: int

    @property
    def completed(self) -> float:
        """Fraction of the goal tree size already reached (``count / goal_tree_size``).

        The ratio is not clamped and may exceed 1.0. Returns 0.0 when
        ``goal_tree_size`` is not positive instead of raising ``ZeroDivisionError``.
        """
        if self.goal_tree_size <= 0:
            return 0.0
        return self.count / self.goal_tree_size
|
|
|
|
|
|
class TreeManagerStats(pydantic.BaseModel):
    """Aggregate statistics reported by the tree manager."""

    # presumably number of message trees per state name -- TODO confirm against producer
    state_counts: dict[str, int]
    message_counts: list[TreeMessageCountStats]
|
|
|
|
|
|
class TreeManager:
|
|
# values of every TextLabel member; offered as `valid_labels` for full-mode labeling tasks
_all_text_labels = [label.value for label in protocol_schema.TextLabel]
|
|
|
|
def __init__(
    self,
    db: Session,
    prompt_repository: PromptRepository,
    cfg: Optional[TreeManagerConfiguration] = None,
):
    """Bind the tree manager to a database session and a prompt repository.

    Args:
        db: session used for all queries and state changes.
        prompt_repository: repository used to fetch and store messages.
        cfg: tree manager configuration; ``settings.tree_manager`` is used when
            no (or a falsy) configuration is passed.
    """
    self.db = db
    self.cfg = cfg if cfg else settings.tree_manager
    self.pr = prompt_repository
|
|
|
|
def _random_task_selection(
    self,
    num_ranking_tasks: int,
    num_replies_need_review: int,
    num_prompts_need_review: int,
    num_missing_prompts: int,
    num_missing_replies: int,
) -> TaskType:
    """
    Pick the type of task to hand out to a human worker.

    Each task type with available work gets a fixed relative weight (ranking 10,
    reply review 5, prompt review 5, reply 2, initial prompt 1); one type is then
    drawn at random proportionally to these weights. Returns ``TaskType.NONE``
    when no work of any kind is available.
    """

    logger.debug(
        f"TreeManager._random_task_selection({num_ranking_tasks=}, {num_replies_need_review=}, "
        f"{num_prompts_need_review=}, {num_missing_prompts=}, {num_missing_replies=})"
    )

    # (available work, weight); the list position equals the TaskType value
    candidates = [
        (num_ranking_tasks, 10),  # TaskType.RANKING
        (num_replies_need_review, 5),  # TaskType.LABEL_REPLY
        (num_missing_replies, 2),  # TaskType.REPLY
        (num_prompts_need_review, 5),  # TaskType.LABEL_PROMPT
        (num_missing_prompts, 1),  # TaskType.PROMPT
    ]
    weights = np.array([weight if available > 0 else 0 for available, weight in candidates])

    task_type = TaskType.NONE
    total = weights.sum()
    if total > 1e-8:
        probabilities = weights / total
        task_type = TaskType(np.random.choice(a=len(weights), p=probabilities))

    logger.debug(f"Selected {task_type=}")
    return task_type
|
|
|
|
def _determine_task_availability_internal(
    self,
    num_active_trees: int,
    extendible_parents: list[ExtendibleParentRow],
    prompts_need_review: list[Message],
    replies_need_review: list[Message],
    incomplete_rankings: list[IncompleteRankingsRow],
) -> dict[protocol_schema.TaskRequestType, int]:
    """Count how many tasks of each request type could currently be handed out.

    Request types not handled below keep a count of 0. Prompter-reply rankings are
    only counted when ``cfg.rank_prompter_replies`` is enabled. ``random`` gets the
    sum of all other counts.
    """
    req_type = protocol_schema.TaskRequestType

    def count_with(rows, attr: str, role: str) -> int:
        # number of rows whose `attr` equals `role`
        return sum(1 for row in rows if getattr(row, attr) == role)

    counts: dict[protocol_schema.TaskRequestType, int] = dict.fromkeys(req_type, 0)

    counts[req_type.initial_prompt] = max(0, self.cfg.max_active_trees - num_active_trees)

    # prompters reply to assistant messages and vice versa
    counts[req_type.prompter_reply] = count_with(extendible_parents, "parent_role", "assistant")
    counts[req_type.assistant_reply] = count_with(extendible_parents, "parent_role", "prompter")

    counts[req_type.label_initial_prompt] = len(prompts_need_review)
    counts[req_type.label_assistant_reply] = count_with(replies_need_review, "role", "assistant")
    counts[req_type.label_prompter_reply] = count_with(replies_need_review, "role", "prompter")

    if self.cfg.rank_prompter_replies:
        counts[req_type.rank_prompter_replies] = count_with(incomplete_rankings, "role", "prompter")
    counts[req_type.rank_assistant_replies] = count_with(incomplete_rankings, "role", "assistant")

    # `random` is still 0 here, so it does not contribute to its own total
    counts[req_type.random] = sum(counts.values())

    return counts
|
|
|
|
def determine_task_availability(self) -> dict[protocol_schema.TaskRequestType, int]:
    """Query the database and return the number of available tasks per request type."""
    # keyword arguments are evaluated left to right, so queries run in this order
    return self._determine_task_availability_internal(
        num_active_trees=self.query_num_active_trees(),
        extendible_parents=self.query_extendible_parents(),
        prompts_need_review=self.query_prompts_need_review(),
        replies_need_review=self.query_replies_need_review(),
        incomplete_rankings=self.query_incomplete_rankings(),
    )
|
|
|
|
def next_task(
|
|
self, desired_task_type: protocol_schema.TaskRequestType = protocol_schema.TaskRequestType.random
|
|
) -> Tuple[protocol_schema.Task, Optional[UUID], Optional[UUID]]:
|
|
|
|
logger.debug("TreeManager.next_task()")
|
|
|
|
num_active_trees = self.query_num_active_trees()
|
|
prompts_need_review = self.query_prompts_need_review()
|
|
replies_need_review = self.query_replies_need_review()
|
|
extendible_parents = self.query_extendible_parents()
|
|
|
|
incomplete_rankings = self.query_incomplete_rankings()
|
|
if not self.cfg.rank_prompter_replies:
|
|
incomplete_rankings = list(filter(lambda r: r.role == "assistant", incomplete_rankings))
|
|
|
|
active_tree_sizes = self.query_extendible_trees()
|
|
|
|
# determine type of task to generate
|
|
num_missing_replies = sum(x.remaining_messages for x in active_tree_sizes)
|
|
|
|
task_role = TaskRole.ANY
|
|
if desired_task_type == protocol_schema.TaskRequestType.random:
|
|
task_type = self._random_task_selection(
|
|
num_ranking_tasks=len(incomplete_rankings),
|
|
num_replies_need_review=len(replies_need_review),
|
|
num_prompts_need_review=len(prompts_need_review),
|
|
num_missing_prompts=max(0, self.cfg.max_active_trees - num_active_trees),
|
|
num_missing_replies=num_missing_replies,
|
|
)
|
|
|
|
if task_type == TaskType.NONE:
|
|
raise OasstError(
|
|
f"No tasks of type '{protocol_schema.TaskRequestType.random.value}' are currently available.",
|
|
OasstErrorCode.TASK_REQUESTED_TYPE_NOT_AVAILABLE,
|
|
HTTPStatus.SERVICE_UNAVAILABLE,
|
|
)
|
|
else:
|
|
task_count_by_type = self._determine_task_availability_internal(
|
|
num_active_trees=num_active_trees,
|
|
extendible_parents=extendible_parents,
|
|
prompts_need_review=prompts_need_review,
|
|
replies_need_review=replies_need_review,
|
|
incomplete_rankings=incomplete_rankings,
|
|
)
|
|
|
|
available_count = task_count_by_type.get(desired_task_type)
|
|
if not available_count:
|
|
raise OasstError(
|
|
f"No tasks of type '{desired_task_type.value}' are currently available.",
|
|
OasstErrorCode.TASK_REQUESTED_TYPE_NOT_AVAILABLE,
|
|
HTTPStatus.SERVICE_UNAVAILABLE,
|
|
)
|
|
|
|
task_type_role_map = {
|
|
protocol_schema.TaskRequestType.initial_prompt: (TaskType.PROMPT, TaskRole.ANY),
|
|
protocol_schema.TaskRequestType.prompter_reply: (TaskType.REPLY, TaskRole.PROMPTER),
|
|
protocol_schema.TaskRequestType.assistant_reply: (TaskType.REPLY, TaskRole.ASSISTANT),
|
|
protocol_schema.TaskRequestType.rank_prompter_replies: (TaskType.RANKING, TaskRole.PROMPTER),
|
|
protocol_schema.TaskRequestType.rank_assistant_replies: (TaskType.RANKING, TaskRole.ASSISTANT),
|
|
protocol_schema.TaskRequestType.label_initial_prompt: (TaskType.LABEL_PROMPT, TaskRole.ANY),
|
|
protocol_schema.TaskRequestType.label_assistant_reply: (TaskType.LABEL_REPLY, TaskRole.ASSISTANT),
|
|
protocol_schema.TaskRequestType.label_prompter_reply: (TaskType.LABEL_REPLY, TaskRole.PROMPTER),
|
|
}
|
|
|
|
task_type, task_role = task_type_role_map[desired_task_type]
|
|
|
|
message_tree_id = None
|
|
parent_message_id = None
|
|
|
|
logger.debug(f"selected {task_type=}")
|
|
match task_type:
|
|
case TaskType.RANKING:
|
|
if task_role == TaskRole.PROMPTER:
|
|
incomplete_rankings = list(filter(lambda m: m.role == "prompter", incomplete_rankings))
|
|
elif task_role == TaskRole.ASSISTANT:
|
|
incomplete_rankings = list(filter(lambda m: m.role == "assistant", incomplete_rankings))
|
|
|
|
if len(incomplete_rankings) > 0:
|
|
ranking_parent_id = random.choice(incomplete_rankings).parent_id
|
|
|
|
messages = self.pr.fetch_message_conversation(ranking_parent_id)
|
|
assert len(messages) > 0 and messages[-1].id == ranking_parent_id
|
|
ranking_parent = messages[-1]
|
|
assert not ranking_parent.deleted and ranking_parent.review_result
|
|
conversation = prepare_conversation(messages)
|
|
replies = self.pr.fetch_message_children(ranking_parent_id, reviewed=True, exclude_deleted=True)
|
|
|
|
assert len(replies) > 1
|
|
random.shuffle(replies) # hand out replies in random order
|
|
reply_messages = prepare_conversation_message_list(replies)
|
|
replies = [p.text for p in replies]
|
|
|
|
if messages[-1].role == "assistant":
|
|
logger.info("Generating a RankPrompterRepliesTask.")
|
|
task = protocol_schema.RankPrompterRepliesTask(
|
|
conversation=conversation,
|
|
replies=replies,
|
|
reply_messages=reply_messages,
|
|
ranking_parent_id=ranking_parent.id,
|
|
message_tree_id=ranking_parent.message_tree_id,
|
|
)
|
|
else:
|
|
logger.info("Generating a RankAssistantRepliesTask.")
|
|
task = protocol_schema.RankAssistantRepliesTask(
|
|
conversation=conversation,
|
|
replies=replies,
|
|
reply_messages=reply_messages,
|
|
ranking_parent_id=ranking_parent.id,
|
|
message_tree_id=ranking_parent.message_tree_id,
|
|
)
|
|
|
|
parent_message_id = ranking_parent_id
|
|
message_tree_id = messages[-1].message_tree_id
|
|
|
|
case TaskType.LABEL_REPLY:
|
|
if task_role == TaskRole.PROMPTER:
|
|
replies_need_review = list(filter(lambda m: m.role == "prompter", replies_need_review))
|
|
elif task_role == TaskRole.ASSISTANT:
|
|
replies_need_review = list(filter(lambda m: m.role == "assistant", replies_need_review))
|
|
|
|
if len(replies_need_review) > 0:
|
|
random_reply_message = random.choice(replies_need_review)
|
|
messages = self.pr.fetch_message_conversation(random_reply_message)
|
|
|
|
conversation = prepare_conversation(messages[:-1])
|
|
message = messages[-1]
|
|
|
|
self.cfg.p_full_labeling_review_reply_prompter: float = 0.1
|
|
|
|
label_mode = protocol_schema.LabelTaskMode.full
|
|
valid_labels = self._all_text_labels
|
|
|
|
if message.role == "assistant":
|
|
if (
|
|
desired_task_type == protocol_schema.TaskRequestType.random
|
|
and random.random() > self.cfg.p_full_labeling_review_reply_assistant
|
|
):
|
|
valid_labels = list(map(lambda x: x.value, self.cfg.mandatory_labels_assistant_reply))
|
|
label_mode = protocol_schema.LabelTaskMode.simple
|
|
logger.info(f"Generating a LabelAssistantReplyTask. ({label_mode=:s})")
|
|
task = protocol_schema.LabelAssistantReplyTask(
|
|
message_id=message.id,
|
|
conversation=conversation,
|
|
reply=message.text,
|
|
valid_labels=valid_labels,
|
|
mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_assistant_reply)),
|
|
mode=label_mode,
|
|
)
|
|
else:
|
|
if (
|
|
desired_task_type == protocol_schema.TaskRequestType.random
|
|
and random.random() > self.cfg.p_full_labeling_review_reply_prompter
|
|
):
|
|
valid_labels = list(map(lambda x: x.value, self.cfg.mandatory_labels_prompter_reply))
|
|
label_mode = protocol_schema.LabelTaskMode.simple
|
|
logger.info(f"Generating a LabelPrompterReplyTask. ({label_mode=:s})")
|
|
task = protocol_schema.LabelPrompterReplyTask(
|
|
message_id=message.id,
|
|
conversation=conversation,
|
|
reply=message.text,
|
|
valid_labels=valid_labels,
|
|
mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_prompter_reply)),
|
|
mode=label_mode,
|
|
)
|
|
|
|
parent_message_id = message.id
|
|
message_tree_id = message.message_tree_id
|
|
|
|
case TaskType.REPLY:
|
|
# select a tree with missing replies
|
|
if task_role == TaskRole.PROMPTER:
|
|
extendible_parents = list(filter(lambda x: x.parent_role == "assistant", extendible_parents))
|
|
elif task_role == TaskRole.ASSISTANT:
|
|
extendible_parents = list(filter(lambda x: x.parent_role == "prompter", extendible_parents))
|
|
|
|
if len(extendible_parents) > 0:
|
|
random_parent = random.choice(extendible_parents)
|
|
|
|
# fetch random conversation to extend
|
|
logger.debug(f"selected {random_parent=}")
|
|
messages = self.pr.fetch_message_conversation(random_parent.parent_id)
|
|
assert all(m.review_result for m in messages) # ensure all messages have positive review
|
|
conversation = prepare_conversation(messages)
|
|
|
|
# generate reply task depending on last message
|
|
if messages[-1].role == "assistant":
|
|
logger.info("Generating a PrompterReplyTask.")
|
|
task = protocol_schema.PrompterReplyTask(conversation=conversation)
|
|
else:
|
|
logger.info("Generating a AssistantReplyTask.")
|
|
task = protocol_schema.AssistantReplyTask(conversation=conversation)
|
|
|
|
parent_message_id = messages[-1].id
|
|
message_tree_id = messages[-1].message_tree_id
|
|
|
|
case TaskType.LABEL_PROMPT:
|
|
assert len(prompts_need_review) > 0
|
|
message = random.choice(prompts_need_review)
|
|
|
|
label_mode = protocol_schema.LabelTaskMode.full
|
|
valid_labels = self._all_text_labels
|
|
|
|
if random.random() > self.cfg.p_full_labeling_review_prompt:
|
|
valid_labels = list(map(lambda x: x.value, self.cfg.mandatory_labels_initial_prompt))
|
|
label_mode = protocol_schema.LabelTaskMode.simple
|
|
|
|
logger.info(f"Generating a LabelInitialPromptTask ({label_mode=:s}).")
|
|
task = protocol_schema.LabelInitialPromptTask(
|
|
message_id=message.id,
|
|
prompt=message.text,
|
|
valid_labels=valid_labels,
|
|
mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_initial_prompt)),
|
|
mode=label_mode,
|
|
)
|
|
|
|
parent_message_id = message.id
|
|
message_tree_id = message.message_tree_id
|
|
|
|
case TaskType.PROMPT:
|
|
logger.info("Generating an InitialPromptTask.")
|
|
task = protocol_schema.InitialPromptTask(hint=None)
|
|
|
|
case _:
|
|
task = None
|
|
|
|
if task is None:
|
|
raise OasstError(
|
|
f"No task of type '{desired_task_type.value}' is currently available.",
|
|
OasstErrorCode.TASK_REQUESTED_TYPE_NOT_AVAILABLE,
|
|
HTTPStatus.SERVICE_UNAVAILABLE,
|
|
)
|
|
|
|
logger.info(f"Generated {task=}.")
|
|
|
|
return task, message_tree_id, parent_message_id
|
|
|
|
@async_managed_tx_method(CommitMode.COMMIT)
|
|
async def handle_interaction(self, interaction: protocol_schema.AnyInteraction) -> protocol_schema.Task:
|
|
pr = self.pr
|
|
match type(interaction):
|
|
case protocol_schema.TextReplyToMessage:
|
|
logger.info(
|
|
f"Frontend reports text reply to {interaction.message_id=} with {interaction.text=} by {interaction.user=}."
|
|
)
|
|
|
|
# here we store the text reply in the database
|
|
message = pr.store_text_reply(
|
|
text=interaction.text,
|
|
frontend_message_id=interaction.message_id,
|
|
user_frontend_message_id=interaction.user_message_id,
|
|
)
|
|
|
|
if not message.parent_id:
|
|
logger.info(f"TreeManager: Inserting new tree state for initial prompt {message.id=}")
|
|
self._insert_default_state(message.id)
|
|
|
|
if not settings.DEBUG_SKIP_EMBEDDING_COMPUTATION:
|
|
try:
|
|
hugging_face_api = HuggingFaceAPI(
|
|
f"{HfUrl.HUGGINGFACE_FEATURE_EXTRACTION.value}/{HfEmbeddingModel.MINILM.value}"
|
|
)
|
|
embedding = await hugging_face_api.post(interaction.text)
|
|
pr.insert_message_embedding(
|
|
message_id=message.id, model=HfEmbeddingModel.MINILM.value, embedding=embedding
|
|
)
|
|
except OasstError:
|
|
logger.error(
|
|
f"Could not fetch embbeddings for text reply to {interaction.message_id=} with {interaction.text=} by {interaction.user=}."
|
|
)
|
|
|
|
if not settings.DEBUG_SKIP_TOXICITY_CALCULATION:
|
|
try:
|
|
model_name: str = HfClassificationModel.TOXIC_ROBERTA.value
|
|
hugging_face_api: HuggingFaceAPI = HuggingFaceAPI(
|
|
f"{HfUrl.HUGGINGFACE_FEATURE_EXTRACTION.value}/{model_name}"
|
|
)
|
|
|
|
toxicity: List[List[Dict[str, Any]]] = await hugging_face_api.post(interaction.text)
|
|
toxicity = toxicity[0][0]
|
|
|
|
pr.insert_toxicity(
|
|
message_id=message.id, model=model_name, score=toxicity["score"], label=toxicity["label"]
|
|
)
|
|
|
|
except OasstError:
|
|
logger.error(
|
|
f"Could not compute toxicity for text reply to {interaction.message_id=} with {interaction.text=} by {interaction.user=}."
|
|
)
|
|
|
|
case protocol_schema.MessageRating:
|
|
logger.info(
|
|
f"Frontend reports rating of {interaction.message_id=} with {interaction.rating=} by {interaction.user=}."
|
|
)
|
|
|
|
pr.store_rating(interaction)
|
|
|
|
case protocol_schema.MessageRanking:
|
|
logger.info(
|
|
f"Frontend reports ranking of {interaction.message_id=} with {interaction.ranking=} by {interaction.user=}."
|
|
)
|
|
|
|
_, task = pr.store_ranking(interaction)
|
|
|
|
ok, rankings_by_message = self.check_condition_for_scoring_state(task.message_tree_id)
|
|
self.update_message_ranks(task.message_tree_id, rankings_by_message)
|
|
|
|
case protocol_schema.TextLabels:
|
|
logger.info(
|
|
f"Frontend reports labels of {interaction.message_id=} with {interaction.labels=} by {interaction.user=}."
|
|
)
|
|
|
|
_, task, msg = pr.store_text_labels(interaction)
|
|
|
|
# if it was a respones for a task, check if we have enough reviews to calc review_result
|
|
if task and msg:
|
|
reviews = self.query_reviews_for_message(msg.id)
|
|
acceptance_score = self._calculate_acceptance(reviews)
|
|
logger.debug(
|
|
f"Message {msg.id=}, {acceptance_score=}, {len(reviews)=}, {msg.review_result=}, {msg.review_count=}"
|
|
)
|
|
if msg.parent_id is None:
|
|
if not msg.review_result and msg.review_count >= self.cfg.num_reviews_initial_prompt:
|
|
if acceptance_score > self.cfg.acceptance_threshold_initial_prompt:
|
|
msg.review_result = True
|
|
self.db.add(msg)
|
|
logger.info(
|
|
f"Initial prompt message was accepted: {msg.id=}, {acceptance_score=}, {len(reviews)=}"
|
|
)
|
|
else:
|
|
self.enter_low_grade_state(msg.message_tree_id)
|
|
self.check_condition_for_growing_state(msg.message_tree_id)
|
|
elif msg.review_count >= self.cfg.num_reviews_reply:
|
|
if not msg.review_result and acceptance_score > self.cfg.acceptance_threshold_reply:
|
|
msg.review_result = True
|
|
self.db.add(msg)
|
|
logger.info(
|
|
f"Reply message message accepted: {msg.id=}, {acceptance_score=}, {len(reviews)=}"
|
|
)
|
|
|
|
self.check_condition_for_ranking_state(msg.message_tree_id)
|
|
|
|
case _:
|
|
raise OasstError("Invalid response type.", OasstErrorCode.TASK_INVALID_RESPONSE_TYPE)
|
|
|
|
return protocol_schema.TaskDone()
|
|
|
|
@managed_tx_method(CommitMode.FLUSH)
def _enter_state(self, mts: MessageTreeState, state: message_tree_state.State):
    """Switch an active message tree to ``state`` and add the tree state to the session.

    Entering one of ``message_tree_state.TERMINAL_STATES`` also marks the tree inactive.
    Precondition (asserted): ``mts`` is given and still active.
    """
    assert mts and mts.active

    entering_terminal = state in message_tree_state.TERMINAL_STATES
    if entering_terminal:
        mts.active = False  # a terminal state ends the tree's active life
    mts.state = state.value
    self.db.add(mts)

    if not entering_terminal:
        logger.info(f"Tree entered '{mts.state}' state ({mts.message_tree_id=})")
    else:
        logger.info(f"Tree entered terminal '{mts.state}' state ({mts.message_tree_id=})")
|
|
|
|
def enter_low_grade_state(self, message_tree_id: UUID) -> None:
    """Abort a message tree as low grade (state ``ABORTED_LOW_GRADE``).

    Does nothing if the tree is no longer active. Several label tasks for the same
    initial prompt can be outstanding, so a late review may arrive after the tree was
    already aborted. Without this guard, ``_enter_state`` would fail its
    ``mts.active`` assertion and the whole request would error.
    """
    logger.debug(f"enter_low_grade_state({message_tree_id=})")
    mts = self.pr.fetch_tree_state(message_tree_id)
    if not mts.active:
        logger.debug(f"Tree already inactive, state unchanged ({mts.state=})")
        return
    self._enter_state(mts, message_tree_state.State.ABORTED_LOW_GRADE)
|
|
|
|
@managed_tx_method(CommitMode.COMMIT)
def check_condition_for_growing_state(self, message_tree_id: UUID) -> bool:
    """Move a tree from INITIAL_PROMPT_REVIEW to GROWING once its prompt is accepted.

    Returns:
        True if the tree entered the GROWING state, otherwise False.
    """
    logger.debug(f"check_condition_for_growing_state({message_tree_id=})")

    mts = self.pr.fetch_tree_state(message_tree_id)
    in_prompt_review = mts.active and mts.state == message_tree_state.State.INITIAL_PROMPT_REVIEW
    if not in_prompt_review:
        logger.debug(f"False {mts.active=}, {mts.state=}")
        return False

    # the initial prompt's message id is the tree id
    initial_prompt = self.pr.fetch_message(message_tree_id)
    if initial_prompt.review_result:
        self._enter_state(mts, message_tree_state.State.GROWING)
        return True

    logger.debug(f"False {initial_prompt.review_result=}")
    return False
|
|
|
|
@managed_tx_method(CommitMode.COMMIT)
def check_condition_for_ranking_state(self, message_tree_id: UUID) -> bool:
    """Move a GROWING tree to RANKING once it has reached its goal size.

    The size is the number of accepted (``review_result``), non-deleted messages
    counted by ``query_tree_size``.

    Returns:
        True if the tree entered the RANKING state, otherwise False.
    """
    logger.debug(f"check_condition_for_ranking_state({message_tree_id=})")

    mts = self.pr.fetch_tree_state(message_tree_id)
    is_growing = mts.active and mts.state == message_tree_state.State.GROWING
    if not is_growing:
        logger.debug(f"False {mts.active=}, {mts.state=}")
        return False

    tree_size = self.query_tree_size(message_tree_id)
    if tree_size.remaining_messages <= 0:
        self._enter_state(mts, message_tree_state.State.RANKING)
        return True

    logger.debug(f"False {tree_size.remaining_messages=}")
    return False
|
|
|
|
@managed_tx_method(CommitMode.COMMIT)
def check_condition_for_scoring_state(
    self, message_tree_id: UUID
) -> Tuple[bool, Optional[dict[UUID, list[MessageReaction]]]]:
    """Move a RANKING tree to READY_FOR_SCORING when every ranked parent has enough rankings.

    Only children of role "assistant" are considered unless
    ``cfg.rank_prompter_replies`` is enabled.

    Returns:
        ``(True, rankings_by_message)`` when the tree entered READY_FOR_SCORING;
        otherwise ``(False, None)``.
    """
    logger.debug(f"check_condition_for_scoring_state({message_tree_id=})")

    mts = self.pr.fetch_tree_state(message_tree_id)
    if not mts.active or mts.state != message_tree_state.State.RANKING:
        logger.debug(f"False {mts.active=}, {mts.state=}")
        return False, None

    ranking_role_filter = None if self.cfg.rank_prompter_replies else "assistant"
    rankings_by_message = self.query_tree_ranking_results(message_tree_id, role_filter=ranking_role_filter)
    # NOTE(review): `_sql_find_tree_ranking_results` INNER JOINs completed ranking tasks, so a
    # parent without any completed ranking presumably does not appear here and is not checked.
    for parent_msg_id, ranking in rankings_by_message.items():
        if len(ranking) < self.cfg.num_required_rankings:
            logger.debug(f"False {parent_msg_id=} {len(ranking)=}")
            return False, None

    self._enter_state(mts, message_tree_state.State.READY_FOR_SCORING)
    return True, rankings_by_message
|
|
|
|
@managed_tx_method(CommitMode.COMMIT)
def update_message_ranks(
    self, message_tree_id: UUID, rankings_by_message: dict[UUID, list[MessageReaction]]
) -> bool:
    """Compute consensus ranks from ranking reactions and store them on the messages.

    For each ranked parent, the ``ranked_message_ids`` lists of all its ranking
    reactions are merged with ``ranked_pairs``. Each message's ``rank`` is then set to
    its 0-based position in the consensus order. The tree enters READY_FOR_EXPORT on
    success and SCORING_FAILED if any step raises.

    Args:
        message_tree_id: tree to score; must be in READY_FOR_SCORING or SCORING_FAILED.
        rankings_by_message: ranking reactions grouped by parent message id.

    Returns:
        True if ranks were written, False if the tree was in another state or scoring failed.
    """
    mts = self.pr.fetch_tree_state(message_tree_id)
    # check state, allow retry if in SCORING_FAILED state
    if mts.state not in (message_tree_state.State.READY_FOR_SCORING, message_tree_state.State.SCORING_FAILED):
        logger.debug(f"False {mts.active=}, {mts.state=}")
        return False

    try:
        for rankings in rankings_by_message.values():
            sorted_messages = [msg_reaction.payload.payload.ranked_message_ids for msg_reaction in rankings]
            logger.debug(f"SORTED MESSAGE {sorted_messages}")
            consensus = ranked_pairs(sorted_messages)
            logger.debug(f"CONSENSUS: {consensus}\n\n")
            for rank, message_id in enumerate(consensus):
                # set rank for each message_id for Message rows
                msg = self.pr.fetch_message(message_id=message_id, fail_if_missing=True)
                msg.rank = rank
                self.db.add(msg)
    except Exception:
        logger.exception(f"update_message_ranks({message_tree_id=}) failed")
        self._enter_state(mts, message_tree_state.State.SCORING_FAILED)
        return False

    self._enter_state(mts, message_tree_state.State.READY_FOR_EXPORT)
    return True
|
|
|
|
def _calculate_acceptance(self, labels: list[TextLabels]):
    """Acceptance score of a message: mean over its reviews of ``1 - spam label value``."""
    spam_key = protocol_schema.TextLabel.spam
    per_review_scores = [1 - review.labels[spam_key] for review in labels]
    return np.mean(per_review_scores)
|
|
|
|
def query_prompts_need_review(self) -> list[Message]:
    """
    Select initial prompts (parent_id IS NULL) that still need reviews.

    A prompt qualifies when its tree is active and in INITIAL_PROMPT_REVIEW state,
    it is neither accepted (review_result) nor deleted, and it has fewer than
    ``cfg.num_reviews_initial_prompt`` reviews. The current user's own messages are
    excluded unless ``settings.DEBUG_ALLOW_SELF_LABELING`` is set.
    """
    conditions = [
        MessageTreeState.active,
        MessageTreeState.state == message_tree_state.State.INITIAL_PROMPT_REVIEW,
        not_(Message.review_result),
        not_(Message.deleted),
        Message.review_count < self.cfg.num_reviews_initial_prompt,
        Message.parent_id.is_(None),
    ]
    if not settings.DEBUG_ALLOW_SELF_LABELING:
        conditions.append(Message.user_id != self.pr.user_id)

    return (
        self.db.query(Message)
        .select_from(MessageTreeState)
        .join(Message, MessageTreeState.message_tree_id == Message.message_tree_id)
        .filter(*conditions)
        .all()
    )
|
|
|
|
def query_replies_need_review(self) -> list[Message]:
    """
    Select reply messages (parent_id IS NOT NULL) that still need reviews.

    A reply qualifies when its tree is active and in GROWING state, it is neither
    accepted (review_result) nor deleted, and it has fewer than
    ``cfg.num_reviews_reply`` reviews. The current user's own messages are excluded
    unless ``settings.DEBUG_ALLOW_SELF_LABELING`` is set.
    """
    conditions = [
        MessageTreeState.active,
        MessageTreeState.state == message_tree_state.State.GROWING,
        not_(Message.review_result),
        not_(Message.deleted),
        Message.review_count < self.cfg.num_reviews_reply,
        Message.parent_id.is_not(None),
    ]
    if not settings.DEBUG_ALLOW_SELF_LABELING:
        conditions.append(Message.user_id != self.pr.user_id)

    return (
        self.db.query(Message)
        .select_from(MessageTreeState)
        .join(Message, MessageTreeState.message_tree_id == Message.message_tree_id)
        .filter(*conditions)
        .all()
    )
|
|
|
|
# Groups reviewed, non-deleted replies in active RANKING-state trees by (parent, role) and
# keeps groups with more than one child where some child has too few rankings.
_sql_find_incomplete_rankings = """
-- find incomplete rankings
SELECT m.parent_id, m.role, COUNT(m.id) children_count, MIN(m.ranking_count) child_min_ranking_count,
COUNT(m.id) FILTER (WHERE m.ranking_count >= :num_required_rankings) as completed_rankings
FROM message_tree_state mts
INNER JOIN message m ON mts.message_tree_id = m.message_tree_id
WHERE mts.active -- only consider active trees
AND mts.state = :ranking_state -- message tree must be in ranking state
AND m.review_result -- must be reviewed
AND NOT m.deleted -- not deleted
AND m.parent_id IS NOT NULL -- ignore initial prompts
GROUP BY m.parent_id, m.role
HAVING COUNT(m.id) > 1 and MIN(m.ranking_count) < :num_required_rankings
"""

def query_incomplete_rankings(self) -> list[IncompleteRankingsRow]:
    """Return parents whose reviewed children (of one role) still need further rankings."""
    params = {
        "num_required_rankings": self.cfg.num_required_rankings,
        "ranking_state": message_tree_state.State.RANKING,
    }
    rows = self.db.execute(text(self._sql_find_incomplete_rankings), params).all()
    return list(map(IncompleteRankingsRow.from_orm, rows))
|
|
|
|
# Parents in active GROWING trees that may receive more replies. The child filters sit in
# the LEFT JOIN's ON clause rather than in WHERE, so a parent whose children are all deleted
# or rejected is still returned with a count of 0 instead of being dropped entirely.
_sql_find_extendible_parents = """
-- find all extendible parent nodes
SELECT m.id as parent_id, m.role as parent_role, m.depth, m.message_tree_id, COUNT(c.id) active_children_count
FROM message_tree_state mts
    INNER JOIN message m ON mts.message_tree_id = m.message_tree_id -- all elements of message tree
    LEFT JOIN message c ON m.id = c.parent_id -- child nodes
        AND NOT coalesce(c.deleted, FALSE) -- don't count deleted children
        AND (c.review_result OR coalesce(c.review_count, 0) < :num_reviews_reply) -- don't count children with negative review but count elements under review
WHERE mts.active -- only consider active trees
    AND mts.state = :growing_state -- message tree must be growing
    AND NOT m.deleted -- ignore deleted messages as parents
    AND m.depth < mts.max_depth -- ignore leaf nodes as parents
    AND m.review_result -- parent node must have positive review
GROUP BY m.id, m.role, m.depth, m.message_tree_id, mts.max_children_count
HAVING COUNT(c.id) < mts.max_children_count -- below maximum number of children
"""

def query_extendible_parents(self) -> list[ExtendibleParentRow]:
    """Query parent messages that have not reached the maximum number of replies.

    Only non-deleted children that are accepted or still under review (fewer than
    ``cfg.num_reviews_reply`` reviews) count towards the limit.
    """
    r = self.db.execute(
        text(self._sql_find_extendible_parents),
        {
            "growing_state": message_tree_state.State.GROWING,
            "num_reviews_reply": self.cfg.num_reviews_reply,
        },
    )
    return [ExtendibleParentRow.from_orm(x) for x in r.all()]
|
|
|
|
# Size vs. goal size of every tree that has at least one extendible parent and is still
# below its goal size; embeds `_sql_find_extendible_parents` as a subquery.
_sql_find_extendible_trees = f"""
-- find extendible trees
SELECT m.message_tree_id, mts.goal_tree_size, COUNT(m.id) AS tree_size
FROM (
SELECT DISTINCT message_tree_id FROM ({_sql_find_extendible_parents}) extendible_parents
) trees INNER JOIN message_tree_state mts ON trees.message_tree_id = mts.message_tree_id
INNER JOIN message m ON mts.message_tree_id = m.message_tree_id
WHERE NOT m.deleted
AND (
m.parent_id IS NOT NULL AND (m.review_result OR m.review_count < :num_reviews_reply) -- children
OR m.parent_id IS NULL AND m.review_result -- prompts (root nodes) must have positive review
)
GROUP BY m.message_tree_id, mts.goal_tree_size
HAVING COUNT(m.id) < mts.goal_tree_size
"""

def query_extendible_trees(self) -> list[ActiveTreeSizeRow]:
    """Return current vs. goal size of growing trees that can still be extended.

    Counted messages are non-deleted replies that are accepted or still under review,
    plus the accepted initial prompt.
    """
    params = {
        "growing_state": message_tree_state.State.GROWING,
        "num_reviews_reply": self.cfg.num_reviews_reply,
    }
    result = self.db.execute(text(self._sql_find_extendible_trees), params)
    return list(map(ActiveTreeSizeRow.from_orm, result.all()))
|
|
|
|
def query_tree_size(self, message_tree_id: UUID) -> ActiveTreeSizeRow:
    """Count the accepted (review_result), non-deleted messages of an active message tree.

    ``Query.one()`` raises when no row matches, e.g. when the tree is inactive or has
    no accepted message.
    """
    selected_columns = (
        MessageTreeState.message_tree_id.label("message_tree_id"),
        MessageTreeState.goal_tree_size.label("goal_tree_size"),
        func.count(Message.id).label("tree_size"),
    )
    row = (
        self.db.query(*selected_columns)
        .select_from(MessageTreeState)
        .outerjoin(Message, MessageTreeState.message_tree_id == Message.message_tree_id)
        .filter(
            MessageTreeState.active,
            not_(Message.deleted),
            Message.review_result,
            MessageTreeState.message_tree_id == message_tree_id,
        )
        .group_by(MessageTreeState.message_tree_id, MessageTreeState.goal_tree_size)
        .one()
    )
    return ActiveTreeSizeRow.from_orm(row)
|
|
|
|
def query_misssing_tree_states(self) -> list[UUID]:
    """Return ids of initial prompt messages that have no message tree state row.

    An initial prompt is a message without a parent whose ``message_tree_id``
    equals its own id. The left outer join against ``MessageTreeState`` leaves
    ``MessageTreeState.message_tree_id`` NULL when no state row exists.
    """
    is_root_without_state = (
        Message.parent_id.is_(None),
        Message.message_tree_id == Message.id,
        MessageTreeState.message_tree_id.is_(None),
    )
    qry = (
        self.db.query(Message.id)
        .outerjoin(MessageTreeState, Message.message_tree_id == MessageTreeState.message_tree_id)
        .filter(*is_root_without_state)
    )
    return [row.id for row in qry.all()]
|
|
|
|
_sql_find_tree_ranking_results = """
|
|
-- get all ranking results of completed tasks for all parents with >= 2 children
|
|
SELECT p.parent_id, mr.* FROM
|
|
(
|
|
-- find parents with > 1 children
|
|
SELECT m.parent_id, m.message_tree_id, COUNT(m.id) children_count
|
|
FROM message_tree_state mts
|
|
INNER JOIN message m ON mts.message_tree_id = m.message_tree_id
|
|
WHERE m.review_result -- must be reviewed
|
|
AND NOT m.deleted -- not deleted
|
|
AND m.parent_id IS NOT NULL -- ignore initial prompts
|
|
AND (:role IS NULL OR m.role = :role) -- children with matching role
|
|
AND mts.message_tree_id = :message_tree_id
|
|
GROUP BY m.parent_id, m.message_tree_id
|
|
HAVING COUNT(m.id) > 1
|
|
) as p
|
|
INNER JOIN task t ON p.parent_id = t.parent_message_id AND t.done AND (t.payload_type = 'RankPrompterRepliesPayload' OR t.payload_type = 'RankAssistantRepliesPayload')
|
|
INNER JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'RankingReactionPayload'
|
|
"""
|
|
|
|
def query_tree_ranking_results(
    self,
    message_tree_id: UUID,
    role_filter: Optional[str] = "assistant",
) -> dict[UUID, list[MessageReaction]]:
    """Find all completed ranking results for a message tree.

    :param message_tree_id: id of the message tree to inspect.
    :param role_filter: only parents with more than one reviewed, non-deleted
        child of this role ("assistant" or "prompter") are considered;
        ``None`` counts children of any role.
    :returns: mapping of parent message id to the ranking reactions found for it.
    :raises ValueError: if ``role_filter`` is not ``None``, "assistant" or "prompter".
    """
    # Explicit check instead of ``assert``: asserts are stripped under ``python -O``.
    if role_filter not in (None, "assistant", "prompter"):
        raise ValueError(f"invalid role_filter: {role_filter!r}")

    r = self.db.execute(
        text(self._sql_find_tree_ranking_results),
        {
            "message_tree_id": message_tree_id,
            "role": role_filter,
        },
    )

    rankings_by_message: dict[UUID, list[MessageReaction]] = {}
    for x in r.all():
        # Register the parent even if the row carries no reaction.
        reactions = rankings_by_message.setdefault(x["parent_id"], [])
        if x["task_id"]:
            reactions.append(MessageReaction.from_orm(x))
    return rankings_by_message
|
|
|
|
@managed_tx_method(CommitMode.COMMIT)
def ensure_tree_states(self):
    """Create a default tree state row for every initial prompt that lacks one.

    The initial state depends on how many messages the tree already holds
    (deleted messages are included in the count). With more than one message
    it is GROWING; otherwise it is INITIAL_PROMPT_REVIEW.
    """
    for root_id in self.query_misssing_tree_states():
        tree_size = self.db.query(func.count(Message.id)).filter(Message.message_tree_id == root_id).scalar()
        if tree_size > 1:
            state = message_tree_state.State.GROWING
        else:
            state = message_tree_state.State.INITIAL_PROMPT_REVIEW
        # NOTE: `tree_size` and `state` keep their names; the `=` f-string
        # specifiers below print them verbatim in the log line.
        logger.info(f"Inserting missing message tree state for message: {root_id} ({tree_size=}, {state=:s})")
        self._insert_default_state(root_id, state=state)
|
|
|
|
def query_num_active_trees(self) -> int:
    """Return the number of message tree states flagged as active."""
    return (
        self.db.query(func.count(MessageTreeState.message_tree_id))
        .filter(MessageTreeState.active)
        .scalar()
    )
|
|
|
|
def query_reviews_for_message(self, message_id: UUID) -> list[TextLabels]:
    """Return the text-label rows for ``message_id`` whose associated task is done.

    NOTE(review): the join matches ``Task.id`` against ``TextLabels.id``.
    Presumably text-label rows reuse the id of the task they answer; confirm
    against the code that stores them.
    """
    labels_query = (
        self.db.query(TextLabels)
        .select_from(Task)
        .join(TextLabels, Task.id == TextLabels.id)
    )
    return labels_query.filter(Task.done, TextLabels.message_id == message_id).all()
|
|
|
|
@managed_tx_method(CommitMode.FLUSH)
def _insert_tree_state(
    self,
    root_message_id: UUID,
    goal_tree_size: int,
    max_depth: int,
    max_children_count: int,
    active: bool,
    state: message_tree_state.State = message_tree_state.State.INITIAL_PROMPT_REVIEW,
) -> MessageTreeState:
    """Add a new ``MessageTreeState`` for the tree rooted at ``root_message_id``.

    The tree id is the root message id, and ``state`` is stored as its enum
    ``.value``. The new row is added to the session and returned; the
    decorator runs with ``CommitMode.FLUSH``.
    """
    tree_state = MessageTreeState(
        message_tree_id=root_message_id,
        goal_tree_size=goal_tree_size,
        max_depth=max_depth,
        max_children_count=max_children_count,
        state=state.value,
        active=active,
    )
    self.db.add(tree_state)
    return tree_state
|
|
|
|
@managed_tx_method(CommitMode.FLUSH)
def _insert_default_state(
    self,
    root_message_id: UUID,
    state: message_tree_state.State = message_tree_state.State.INITIAL_PROMPT_REVIEW,
) -> MessageTreeState:
    """Insert an active tree state for ``root_message_id`` using the configured limits.

    Goal tree size, maximum depth and maximum children count are read from
    ``self.cfg``. Delegates to ``_insert_tree_state``.
    """
    cfg = self.cfg
    return self._insert_tree_state(
        root_message_id=root_message_id,
        goal_tree_size=cfg.goal_tree_size,
        max_depth=cfg.max_tree_depth,
        max_children_count=cfg.max_children_count,
        state=state,
        active=True,
    )
|
|
|
|
def tree_counts_by_state(self) -> dict[str, int]:
    """Return how many message trees exist in each state, keyed by state."""
    tree_count = func.count(MessageTreeState.message_tree_id).label("count")
    qry = self.db.query(MessageTreeState.state, tree_count).group_by(MessageTreeState.state)
    counts = {}
    for row in qry:
        counts[row["state"]] = row["count"]
    return counts
|
|
|
|
def tree_message_count_stats(self, only_active: bool = True) -> list[TreeMessageCountStats]:
    """Return per-tree message statistics, by default only for active trees.

    Each entry has the tree id, the maximum message depth, the oldest and
    youngest message creation dates, the message count, the goal tree size and
    the tree state. All aggregates are computed over non-deleted messages only.
    """
    columns = (
        MessageTreeState.message_tree_id,
        func.max(Message.depth).label("depth"),
        func.min(Message.created_date).label("oldest"),
        func.max(Message.created_date).label("youngest"),
        func.count(Message.id).label("count"),
        MessageTreeState.goal_tree_size,
        MessageTreeState.state,
    )
    qry = (
        self.db.query(*columns)
        .select_from(MessageTreeState)
        .join(Message, MessageTreeState.message_tree_id == Message.message_tree_id)
        .filter(not_(Message.deleted))
    )
    if only_active:
        qry = qry.filter(MessageTreeState.active)
    qry = qry.group_by(MessageTreeState.message_tree_id)

    return [TreeMessageCountStats(**row) for row in qry]
|
|
|
|
def stats(self) -> TreeManagerStats:
    """Bundle tree counts per state with per-tree message stats of active trees."""
    state_counts = self.tree_counts_by_state()
    message_counts = self.tree_message_count_stats(only_active=True)
    return TreeManagerStats(state_counts=state_counts, message_counts=message_counts)
|
|
|
|
|
|
if __name__ == "__main__":
    # Manual maintenance entry point: authenticates as the official web API
    # client with a placeholder local user, then inserts any missing message
    # tree state rows. (PromptRepository is already imported at file top.)
    from oasst_backend.api.deps import api_auth
    from oasst_backend.database import engine

    with Session(engine) as db:
        api_client = api_auth(settings.OFFICIAL_WEB_API_KEY, db=db)
        dummy_user = protocol_schema.User(id="__dummy_user__", display_name="Dummy User", auth_method="local")
        repo = PromptRepository(db=db, api_client=api_client, client_user=dummy_user)
        tm = TreeManager(db, repo, TreeManagerConfiguration())
        tm.ensure_tree_states()

        # Ad-hoc diagnostics; uncomment as needed:
        # print("query_num_active_trees", tm.query_num_active_trees())
        # print("query_incomplete_rankings", tm.query_incomplete_rankings())
        # print("query_replies_need_review", tm.query_replies_need_review())
        # print("query_incomplete_initial_prompt_reviews", tm.query_prompts_need_review())
        # print("query_extendible_trees", tm.query_extendible_trees())
        # print("query_extendible_parents", tm.query_extendible_parents())
        # print("query_tree_size", tm.query_tree_size(message_tree_id=UUID("bdf434cf-4df5-4b74-949c-a5a157bc3292")))
        # print(
        #     "query_reviews_for_message",
        #     tm.query_reviews_for_message(message_id=UUID("6a444493-0d48-4316-a9f1-7e263f5a2473")),
        # )
        # print("next_task:", tm.next_task())
        # print(
        #     "query_tree_ranking_results", tm.query_tree_ranking_results(UUID("6036f58f-41b5-48c4-bdd9-b16f34ab1312"))
        # )
|