# Open-Assistant/backend/oasst_backend/tree_manager.py
# (source listing: 799 lines, 34 KiB, Python)
import random
from enum import Enum
from http import HTTPStatus
from typing import Optional, Tuple
from uuid import UUID
import numpy as np
import pydantic
from loguru import logger
from oasst_backend.api.v1.utils import prepare_conversation, prepare_conversation_message_list
from oasst_backend.config import TreeManagerConfiguration, settings
from oasst_backend.models import Message, MessageReaction, MessageTreeState, TextLabels, message_tree_state
from oasst_backend.prompt_repository import PromptRepository
from oasst_backend.utils.hugging_face import HfEmbeddingModel, HfUrl, HuggingFaceAPI
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
from sqlalchemy.sql import text
from sqlmodel import Session, func
class TaskType(Enum):
    """Kinds of tasks the tree manager can hand out to a human worker.

    The non-negative values double as indices into the weight vector that
    ``TreeManager._task_selection`` builds; ``NONE`` signals that no task is available.
    """

    NONE = -1  # sentinel: nothing can be handed out
    RANKING = 0  # rank sibling replies
    LABEL_REPLY = 1  # review/label a reply message
    REPLY = 2  # write a reply
    LABEL_PROMPT = 3  # review/label an initial prompt
    PROMPT = 4  # write a new initial prompt
class TaskRole(Enum):
    """Conversation role a task is restricted to (``ANY`` means unrestricted)."""

    ANY = 0
    PROMPTER = 1
    ASSISTANT = 2
class ActiveTreeSizeRow(pydantic.BaseModel):
    """Result row holding the current and target size of one message tree."""

    message_tree_id: UUID
    tree_size: int
    goal_tree_size: int

    @property
    def remaining_messages(self) -> int:
        """Number of messages still missing to reach the goal size (never negative)."""
        missing = self.goal_tree_size - self.tree_size
        return missing if missing > 0 else 0

    class Config:
        # allow construction from SQL result rows via `from_orm`
        orm_mode = True
class ExtendibleParentRow(pydantic.BaseModel):
    """Result row describing a message that may receive further replies."""

    parent_id: UUID  # id of the message that can be replied to
    depth: int  # depth of that message inside its tree
    message_tree_id: UUID
    active_children_count: int  # children currently counted against the per-parent limit

    class Config:
        # allow construction from SQL result rows via `from_orm`
        orm_mode = True
class IncompleteRankingsRow(pydantic.BaseModel):
    """Result row for a parent whose children still need more ranking results."""

    parent_id: UUID
    children_count: int  # number of reviewed, non-deleted children
    child_min_ranking_count: int  # lowest ranking_count among those children

    class Config:
        # allow construction from SQL result rows via `from_orm`
        orm_mode = True
class TreeManager:
_all_text_labels = list(map(lambda x: x.value, protocol_schema.TextLabel))
def __init__(
self, db: Session, prompt_repository: PromptRepository, cfg: Optional[TreeManagerConfiguration] = None
):
self.db = db
self.cfg = cfg or settings.tree_manager
self.pr = prompt_repository
def _task_selection(
self,
desired_task_type: protocol_schema.TaskRequestType,
num_ranking_tasks: int,
num_replies_need_review: int,
num_prompts_need_review: int,
num_missing_prompts: int,
num_missing_replies: int,
) -> Tuple[TaskType, TaskRole]:
"""
Determines which task to hand out to human worker.
The task type is drawn with relative weight (e.g. ranking has highest priority)
depending on what is possible with the current message trees in the database.
"""
logger.debug(
f"TreeManager._task_selection({num_ranking_tasks=}, {num_replies_need_review=}, "
f"{num_prompts_need_review=}, {num_missing_prompts=}, {num_missing_replies=})"
)
task_type = TaskType.NONE
task_role = TaskRole.ANY
if desired_task_type == protocol_schema.TaskRequestType.random:
task_weights = [0] * 5
if num_ranking_tasks > 0:
task_weights[TaskType.RANKING.value] = 10
if num_replies_need_review > 0:
task_weights[TaskType.LABEL_REPLY.value] = 5
if num_prompts_need_review > 0:
task_weights[TaskType.LABEL_PROMPT.value] = 5
if num_missing_replies > 0:
task_weights[TaskType.REPLY.value] = 2
if num_missing_prompts > 0:
task_weights[TaskType.PROMPT.value] = 1
task_weights = np.array(task_weights)
weight_sum = task_weights.sum()
if weight_sum < 1e-8:
task_type = TaskType.NONE
else:
task_weights = task_weights / weight_sum
task_type = TaskType(np.random.choice(a=len(task_weights), p=task_weights))
else:
match desired_task_type:
case protocol_schema.TaskRequestType.initial_prompt:
if num_missing_prompts > 0:
task_type = TaskType.PROMPT
case protocol_schema.TaskRequestType.label_initial_prompt:
if num_prompts_need_review > 0:
task_type = TaskType.LABEL_PROMPT
case protocol_schema.TaskRequestType.assistant_reply | protocol_schema.TaskRequestType.prompter_reply:
if num_missing_replies > 0:
task_role = (
TaskRole.ASSISTANT
if desired_task_type == protocol_schema.TaskRequestType.assistant_reply
else TaskRole.PROMPTER
)
task_type = TaskType.REPLY
case protocol_schema.TaskRequestType.label_assistant_reply | protocol_schema.TaskRequestType.label_prompter_reply:
if num_replies_need_review > 0:
task_role = (
TaskRole.ASSISTANT
if desired_task_type == protocol_schema.TaskRequestType.label_assistant_reply
else TaskRole.PROMPTER
)
task_type = TaskType.LABEL_REPLY
case protocol_schema.TaskRequestType.rank_assistant_replies | protocol_schema.TaskRequestType.rank_prompter_replies:
if num_ranking_tasks > 0:
task_role = (
TaskRole.ASSISTANT
if desired_task_type == protocol_schema.TaskRequestType.rank_assistant_replies
else TaskRole.PROMPTER
)
task_type = TaskType.RANKING
logger.debug(f"Selected {task_type=}, {task_role=}")
return task_type, task_role
def next_task(
self, desired_task_type: protocol_schema.TaskRequestType
) -> Tuple[protocol_schema.Task, Optional[UUID], Optional[UUID]]:
logger.debug("TreeManager.next_task()")
num_active_trees = self.query_num_active_trees()
prompts_need_review = self.query_prompts_need_review()
replies_need_review = self.query_replies_need_review()
incomplete_rankings = self.query_incomplete_rankings()
active_tree_sizes = self.query_extendible_trees()
# determine type of task to generate
num_missing_replies = sum(x.remaining_messages for x in active_tree_sizes)
task_type, task_role = self._task_selection(
desired_task_type,
num_ranking_tasks=len(incomplete_rankings),
num_replies_need_review=len(replies_need_review),
num_prompts_need_review=len(prompts_need_review),
num_missing_prompts=max(0, self.cfg.max_active_trees - num_active_trees),
num_missing_replies=num_missing_replies,
)
if task_type == TaskType.NONE:
raise OasstError(
f"No tasks of type '{desired_task_type.value}' are currently available.",
OasstErrorCode.TASK_REQUESTED_TYPE_NOT_AVAILABLE,
HTTPStatus.SERVICE_UNAVAILABLE,
)
if task_role != TaskRole.ANY:
# Todo: Allow role specific message selection...
raise OasstError(
f"No tasks of type '{desired_task_type.value}' are currently available.",
OasstErrorCode.TASK_REQUESTED_TYPE_NOT_AVAILABLE,
HTTPStatus.SERVICE_UNAVAILABLE,
)
message_tree_id = None
parent_message_id = None
logger.debug(f"selected {task_type=}")
match task_type:
case TaskType.RANKING:
assert len(incomplete_rankings) > 0
ranking_parent_id = random.choice(incomplete_rankings).parent_id
messages = self.pr.fetch_message_conversation(ranking_parent_id)
conversation = prepare_conversation(messages)
replies = self.pr.fetch_message_children(ranking_parent_id, reviewed=True, exclude_deleted=True)
assert len(replies) > 1
random.shuffle(replies) # hand out replies in random order
reply_messages = prepare_conversation_message_list(replies)
replies = [p.text for p in replies]
if messages[-1].role == "assistant":
logger.info("Generating a RankPrompterRepliesTask.")
task = protocol_schema.RankPrompterRepliesTask(
conversation=conversation, replies=replies, reply_messages=reply_messages
)
else:
logger.info("Generating a RankAssistantRepliesTask.")
task = protocol_schema.RankAssistantRepliesTask(
conversation=conversation, replies=replies, reply_messages=reply_messages
)
parent_message_id = ranking_parent_id
message_tree_id = messages[-1].message_tree_id
case TaskType.LABEL_REPLY:
assert len(replies_need_review) > 0
random_reply_message_id = random.choice(replies_need_review)
messages = self.pr.fetch_message_conversation(random_reply_message_id)
conversation = prepare_conversation(messages[:-1])
message = messages[-1]
self.cfg.p_full_labeling_review_reply_prompter: float = 0.1
label_mode = protocol_schema.LabelTaskMode.full
valid_labels = self._all_text_labels
if message.role == "assistant":
if random.random() > self.cfg.p_full_labeling_review_reply_assistant:
valid_labels = list(map(lambda x: x.value, self.cfg.mandatory_labels_assistant_reply))
label_mode = protocol_schema.LabelTaskMode.simple
logger.info(f"Generating a LabelAssistantReplyTask. ({label_mode=:s})")
task = protocol_schema.LabelAssistantReplyTask(
message_id=message.id,
conversation=conversation,
reply=message.text,
valid_labels=valid_labels,
mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_assistant_reply)),
mode=label_mode,
)
else:
if random.random() > self.cfg.p_full_labeling_review_reply_prompter:
valid_labels = list(map(lambda x: x.value, self.cfg.mandatory_labels_prompter_reply))
label_mode = protocol_schema.LabelTaskMode.simple
logger.info(f"Generating a LabelPrompterReplyTask. ({label_mode=:s})")
task = protocol_schema.LabelPrompterReplyTask(
message_id=message.id,
conversation=conversation,
reply=message.text,
valid_labels=valid_labels,
mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_prompter_reply)),
mode=label_mode,
)
parent_message_id = message.id
message_tree_id = message.message_tree_id
case TaskType.REPLY:
# select a tree with missing replies
extensible_parents = self.query_extendible_parents()
assert len(extensible_parents) > 0
# fetch random conversation to extend
random_parent = random.choice(extensible_parents)
logger.debug(f"selected {random_parent=}")
messages = self.pr.fetch_message_conversation(random_parent.parent_id)
assert all(m.review_result for m in messages) # ensure all messages have positive review
conversation = prepare_conversation(messages)
# generate reply task depending on last message
if messages[-1].role == "assistant":
logger.info("Generating a PrompterReplyTask.")
task = protocol_schema.PrompterReplyTask(conversation=conversation)
else:
logger.info("Generating a AssistantReplyTask.")
task = protocol_schema.AssistantReplyTask(conversation=conversation)
parent_message_id = messages[-1].id
message_tree_id = messages[-1].message_tree_id
case TaskType.LABEL_PROMPT:
assert len(prompts_need_review) > 0
message = self.pr.fetch_message(random.choice(prompts_need_review))
label_mode = protocol_schema.LabelTaskMode.full
valid_labels = self._all_text_labels
if random.random() > self.cfg.p_full_labeling_review_prompt:
valid_labels = list(map(lambda x: x.value, self.cfg.mandatory_labels_initial_prompt))
label_mode = protocol_schema.LabelTaskMode.simple
logger.info(f"Generating a LabelInitialPromptTask ({label_mode=:s}).")
task = protocol_schema.LabelInitialPromptTask(
message_id=message.id,
prompt=message.text,
valid_labels=valid_labels,
mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_initial_prompt)),
mode=label_mode,
)
parent_message_id = message.id
message_tree_id = message.message_tree_id
case TaskType.PROMPT:
logger.info("Generating an InitialPromptTask.")
task = protocol_schema.InitialPromptTask(hint=None)
case _:
task = None
logger.info(f"Generated {task=}.")
return task, message_tree_id, parent_message_id
async def handle_interaction(self, interaction: protocol_schema.AnyInteraction) -> protocol_schema.Task:
pr = self.pr
match type(interaction):
case protocol_schema.TextReplyToMessage:
logger.info(
f"Frontend reports text reply to {interaction.message_id=} with {interaction.text=} by {interaction.user=}."
)
# here we store the text reply in the database
message = pr.store_text_reply(
text=interaction.text,
frontend_message_id=interaction.message_id,
user_frontend_message_id=interaction.user_message_id,
)
if not message.parent_id:
logger.info(f"TreeManager: Inserting new tree state for initial prompt {message.id=}")
self._insert_default_state(message.id)
self.db.commit()
if not settings.DEBUG_SKIP_EMBEDDING_COMPUTATION:
try:
hugging_face_api = HuggingFaceAPI(
f"{HfUrl.HUGGINGFACE_FEATURE_EXTRACTION.value}/{HfEmbeddingModel.MINILM.value}"
)
embedding = await hugging_face_api.post(interaction.text)
pr.insert_message_embedding(
message_id=message.id, model=HfEmbeddingModel.MINILM.value, embedding=embedding
)
except OasstError:
logger.error(
f"Could not fetch embbeddings for text reply to {interaction.message_id=} with {interaction.text=} by {interaction.user=}."
)
case protocol_schema.MessageRating:
logger.info(
f"Frontend reports rating of {interaction.message_id=} with {interaction.rating=} by {interaction.user=}."
)
pr.store_rating(interaction)
case protocol_schema.MessageRanking:
logger.info(
f"Frontend reports ranking of {interaction.message_id=} with {interaction.ranking=} by {interaction.user=}."
)
_, task = pr.store_ranking(interaction)
self.check_condition_for_scoring_state(task.message_tree_id)
case protocol_schema.TextLabels:
logger.info(
f"Frontend reports labels of {interaction.message_id=} with {interaction.labels=} by {interaction.user=}."
)
_, task, msg = pr.store_text_labels(interaction)
# if it was a respones for a task, check if we have enough reviews to calc review_result
if task and msg:
reviews = self.query_reviews_for_message(msg.id)
acceptance_score = self._calculate_acceptance(reviews)
logger.debug(
f"Message {msg.id=}, {acceptance_score=}, {len(reviews)=}, {msg.review_result=}, {msg.review_count=}"
)
if msg.parent_id is None:
if not msg.review_result and msg.review_count >= self.cfg.num_reviews_initial_prompt:
if acceptance_score > self.cfg.acceptance_threshold_initial_prompt:
msg.review_result = True
self.db.add(msg)
self.db.commit()
logger.info(
f"Initial prompt message was accepted: {msg.id=}, {acceptance_score=}, {len(reviews)=}"
)
else:
self.enter_low_grade_state(msg.message_tree_id)
self.check_condition_for_growing_state(msg.message_tree_id)
elif msg.review_count >= self.cfg.num_reviews_reply:
if not msg.review_result and acceptance_score > self.cfg.acceptance_threshold_reply:
msg.review_result = True
self.db.add(msg)
self.db.commit()
logger.info(
f"Reply message message accepted: {msg.id=}, {acceptance_score=}, {len(reviews)=}"
)
self.check_condition_for_ranking_state(msg.message_tree_id)
case _:
raise OasstError("Invalid response type.", OasstErrorCode.TASK_INVALID_RESPONSE_TYPE)
return protocol_schema.TaskDone()
def _enter_state(self, mts: MessageTreeState, state: message_tree_state.State):
assert mts and mts.active
is_terminal = state in message_tree_state.TERMINAL_STATES
if is_terminal:
mts.active = False
mts.state = state.value
self.db.add(mts)
self.db.commit()
if is_terminal:
logger.info(f"Tree entered terminal '{mts.state}' state ({mts.message_tree_id=})")
else:
logger.info(f"Tree entered '{mts.state}' state ({mts.message_tree_id=})")
def enter_low_grade_state(self, message_tree_id: UUID) -> None:
logger.debug(f"enter_low_grade_state({message_tree_id=})")
mts = self.pr.fetch_tree_state(message_tree_id)
self._enter_state(mts, message_tree_state.State.ABORTED_LOW_GRADE)
def check_condition_for_growing_state(self, message_tree_id: UUID) -> bool:
logger.debug(f"check_condition_for_growing_state({message_tree_id=})")
mts = self.pr.fetch_tree_state(message_tree_id)
if not mts.active or mts.state != message_tree_state.State.INITIAL_PROMPT_REVIEW:
logger.debug(f"False {mts.active=}, {mts.state=}")
return False
# check if initial prompt was accepted
initial_prompt = self.pr.fetch_message(message_tree_id)
if not initial_prompt.review_result:
logger.debug(f"False {initial_prompt.review_result=}")
return False
self._enter_state(mts, message_tree_state.State.GROWING)
return True
def check_condition_for_ranking_state(self, message_tree_id: UUID) -> bool:
logger.debug(f"check_condition_for_ranking_state({message_tree_id=})")
mts = self.pr.fetch_tree_state(message_tree_id)
if not mts.active or mts.state != message_tree_state.State.GROWING:
logger.debug(f"False {mts.active=}, {mts.state=}")
return False
# check if desired tree size has been reached and all nodes have been reviewed
tree_size = self.query_tree_size(message_tree_id)
if tree_size.remaining_messages > 0:
logger.debug(f"False {tree_size.remaining_messages=}")
return False
self._enter_state(mts, message_tree_state.State.RANKING)
return True
def check_condition_for_scoring_state(self, message_tree_id: UUID) -> bool:
logger.debug(f"check_condition_for_scoring_state({message_tree_id=})")
mts: MessageTreeState
mts = self.db.query(MessageTreeState).filter(MessageTreeState.message_tree_id == message_tree_id).one()
if not mts.active or mts.state != message_tree_state.State.RANKING:
logger.debug(f"False {mts.active=}, {mts.state=}")
return False
rankings_by_message = self.query_tree_ranking_results(message_tree_id)
for parent_msg_id, ranking in rankings_by_message.items():
if len(ranking) < self.cfg.num_required_rankings:
logger.debug(f"False {parent_msg_id=} {len(ranking)=}")
return False
self._enter_state(mts, message_tree_state.State.READY_FOR_SCORING)
return True
def _calculate_acceptance(self, labels: list[TextLabels]):
# calculate acceptance based on spam label
return np.mean([1 - l.labels[protocol_schema.TextLabel.spam] for l in labels])
_sql_find_prompts_need_review = """
-- find initial prompts that need more reviews
SELECT m.id
FROM message_tree_state mts
LEFT JOIN message m ON mts.message_tree_id = m.id
WHERE mts.active
AND mts.state = :state
AND NOT m.review_result
AND NOT m.deleted
AND m.review_count < :num_reviews_initial_prompt
AND m.parent_id is NULL
AND (:excluded_user_id IS NULL OR m.user_id != :excluded_user_id)
"""
def query_prompts_need_review(self) -> list[UUID]:
"""
Select id of initial prompts with less then required rankings in active message tree
(active == True in message_tree_state)
"""
r = self.db.execute(
text(self._sql_find_prompts_need_review),
{
"state": message_tree_state.State.INITIAL_PROMPT_REVIEW,
"num_reviews_initial_prompt": self.cfg.num_reviews_initial_prompt,
"excluded_user_id": None if settings.DEBUG_ALLOW_SELF_LABELING else self.pr.user_id,
},
)
return [x["id"] for x in r.all()]
_sql_find_replies_need_review = """
SELECT m.id
FROM message_tree_state mts
LEFT JOIN message m ON mts.message_tree_id = m.message_tree_id
WHERE mts.active
AND mts.state = :breeding_state
AND NOT m.review_result
AND NOT m.deleted
AND m.review_count < :num_required_reviews
AND m.parent_id is NOT NULL
AND (:excluded_user_id IS NULL OR m.user_id != :excluded_user_id)
"""
def query_replies_need_review(self) -> list[UUID]:
"""
Select ids of child messages (parent_id IS NOT NULL) with less then required rankings
in active message tree (active == True in message_tree_state)
"""
r = self.db.execute(
text(self._sql_find_replies_need_review),
{
"breeding_state": message_tree_state.State.GROWING,
"num_required_reviews": self.cfg.num_reviews_reply,
"excluded_user_id": None if settings.DEBUG_ALLOW_SELF_LABELING else self.pr.user_id,
},
)
return [x["id"] for x in r.all()]
_sql_find_incomplete_rankings = """
-- find incomplete rankings
SELECT m.parent_id, COUNT(m.id) children_count, MIN(m.ranking_count) child_min_ranking_count,
COUNT(m.id) FILTER (WHERE m.ranking_count >= :num_required_rankings) as completed_rankings
FROM message_tree_state mts
LEFT JOIN message m ON mts.message_tree_id = m.message_tree_id
WHERE mts.active -- only consider active trees
AND mts.state = :ranking_state -- message tree must be in ranking state
AND m.review_result -- must be reviewed
AND NOT m.deleted -- not deleted
AND m.parent_id IS NOT NULL -- ignore initial prompts
GROUP BY m.parent_id
HAVING COUNT(m.id) > 1 and MIN(m.ranking_count) < :num_required_rankings
"""
def query_incomplete_rankings(self) -> list[IncompleteRankingsRow]:
"""Query parents which have childern that need further rankings"""
r = self.db.execute(
text(self._sql_find_incomplete_rankings),
{
"num_required_rankings": self.cfg.num_required_rankings,
"ranking_state": message_tree_state.State.RANKING,
},
)
return [IncompleteRankingsRow.from_orm(x) for x in r.all()]
_sql_find_extendible_parents = """
-- find all extendible parent nodes
SELECT m.id as parent_id, m.depth, m.message_tree_id, COUNT(c.id) active_children_count
FROM message_tree_state mts
LEFT JOIN message m ON mts.message_tree_id = m.message_tree_id -- all elements of message tree
LEFT JOIN message c ON m.id = c.Id -- child nodes
WHERE mts.active -- only consider active trees
AND mts.state = :growing_state -- message tree must be growing
AND NOT m.deleted -- ignore deleted messages as parents
AND m.depth < mts.max_depth -- ignore leaf nodes as parents
AND m.review_result -- parent node must have positive review
AND NOT c.deleted -- don't count deleted children
AND (c.review_result OR c.review_count < :num_reviews_reply) -- don't count children with negative review but count elements under review
GROUP BY m.id, m.depth, m.message_tree_id, mts.max_children_count
HAVING COUNT(c.id) < mts.max_children_count -- below maximum number of children
"""
def query_extendible_parents(self) -> list[ExtendibleParentRow]:
"""Query parent messages that have not reached the maximum number of replies."""
r = self.db.execute(
text(self._sql_find_extendible_parents),
{
"growing_state": message_tree_state.State.GROWING,
"num_reviews_reply": self.cfg.num_reviews_reply,
},
)
return [ExtendibleParentRow.from_orm(x) for x in r.all()]
_sql_find_extendible_trees = f"""
-- find extendible trees
SELECT m.message_tree_id, mts.goal_tree_size, COUNT(m.id) AS tree_size
FROM (
SELECT DISTINCT message_tree_id FROM ({_sql_find_extendible_parents}) extendible_parents
) trees LEFT JOIN message_tree_state mts ON trees.message_tree_id = mts.message_tree_id
LEFT JOIN message m ON mts.message_tree_id = m.message_tree_id
WHERE NOT m.deleted
AND (
m.parent_id IS NOT NULL AND (m.review_result OR m.review_count < :num_reviews_reply) -- children
OR m.parent_id IS NULL AND m.review_result -- prompts (root nodes) must have positive review
)
GROUP BY m.message_tree_id, mts.goal_tree_size
HAVING COUNT(m.id) < mts.goal_tree_size
"""
def query_extendible_trees(self) -> list[ActiveTreeSizeRow]:
"""Query size of active message trees in growing state."""
r = self.db.execute(
text(self._sql_find_extendible_trees),
{
"growing_state": message_tree_state.State.GROWING,
"num_reviews_reply": self.cfg.num_reviews_reply,
},
)
return [ActiveTreeSizeRow.from_orm(x) for x in r.all()]
_sql_get_tree_size = """
SELECT mts.message_tree_id, mts.goal_tree_size, COUNT(m.id) AS tree_size
FROM message_tree_state mts
LEFT JOIN message m ON mts.message_tree_id = m.message_tree_id
WHERE mts.active
AND NOT m.deleted
AND m.review_result
AND mts.message_tree_id = :message_tree_id
GROUP BY mts.message_tree_id, mts.goal_tree_size
"""
def query_tree_size(self, message_tree_id: UUID) -> ActiveTreeSizeRow:
"""Returns the number of reviewed not deleted messages in the message tree."""
r = self.db.execute(text(self._sql_get_tree_size), {"message_tree_id": message_tree_id})
return ActiveTreeSizeRow.from_orm(r.one())
def query_misssing_tree_states(self) -> list[UUID]:
"""Find all initial prompt messages that have no associated message tree state"""
qry_missing_tree_states = (
self.db.query(Message.id)
.join(MessageTreeState, isouter=True)
.filter(
Message.parent_id.is_(None),
Message.message_tree_id == Message.id,
MessageTreeState.message_tree_id.is_(None),
)
)
return [m.id for m in qry_missing_tree_states.all()]
_sql_find_tree_ranking_results = """
-- get all ranking results of completed tasks for all parents with >=2 children
SELECT p.parent_id, mr.* FROM
(
-- find parents with > 1 children
SELECT m.parent_id, m.message_tree_id, COUNT(m.id) children_count
FROM message_tree_state mts
LEFT JOIN message m ON mts.message_tree_id = m.message_tree_id
WHERE m.review_result -- must be reviewed
AND NOT m.deleted -- not deleted
AND m.parent_id IS NOT NULL -- ignore initial prompts
AND mts.message_tree_id = :message_tree_id
GROUP BY m.parent_id, m.message_tree_id
HAVING COUNT(m.id) > 1
) as p
LEFT JOIN task t ON p.parent_id = t.parent_message_id AND t.done AND (t.payload_type = 'RankPrompterRepliesPayload' OR t.payload_type = 'RankAssistantRepliesPayload')
LEFT JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'RankingReactionPayload'
"""
def query_tree_ranking_results(self, message_tree_id: UUID) -> dict[UUID, list[MessageReaction]]:
"""Finds all completed ranking restuls for a message_tree"""
r = self.db.execute(
text(self._sql_find_tree_ranking_results),
{"message_tree_id": message_tree_id},
)
rankings_by_message = {}
for x in r.all():
parent_id = x["parent_id"]
if parent_id not in rankings_by_message:
rankings_by_message[parent_id] = []
if x["task_id"]:
rankings_by_message[parent_id].append(MessageReaction.from_orm(x))
return rankings_by_message
def ensure_tree_states(self):
"""Add message tree state rows for all root nodes (inital prompt messages)."""
missing_tree_ids = self.query_misssing_tree_states()
for id in missing_tree_ids:
tree_size = self.db.query(func.count(Message.id)).filter(Message.message_tree_id == id).scalar()
state = message_tree_state.State.INITIAL_PROMPT_REVIEW
if tree_size > 1:
state = message_tree_state.State.GROWING
logger.info(f"Inserting missing message tree state for message: {id} ({tree_size=}, {state=})")
self._insert_default_state(id, state=state)
self.db.commit()
def query_num_active_trees(self) -> int:
query = self.db.query(func.count(MessageTreeState.message_tree_id)).filter(MessageTreeState.active)
return query.scalar()
def query_reviews_for_message(self, message_id: UUID) -> list[TextLabels]:
sql_qry = """
SELECT tl.*
FROM task t
INNER JOIN text_labels tl ON tl.id = t.id
WHERE t.done = TRUE
AND tl.message_id = :message_id
"""
r = self.db.execute(text(sql_qry), {"message_id": message_id})
return [TextLabels.from_orm(x) for x in r.all()]
def _insert_tree_state(
self,
root_message_id: UUID,
goal_tree_size: int,
max_depth: int,
max_children_count: int,
active: bool,
state: message_tree_state.State = message_tree_state.State.INITIAL_PROMPT_REVIEW,
) -> MessageTreeState:
model = MessageTreeState(
message_tree_id=root_message_id,
goal_tree_size=goal_tree_size,
max_depth=max_depth,
max_children_count=max_children_count,
state=state.value,
active=active,
)
self.db.add(model)
return model
def _insert_default_state(
self,
root_message_id: UUID,
state: message_tree_state.State = message_tree_state.State.INITIAL_PROMPT_REVIEW,
) -> MessageTreeState:
return self._insert_tree_state(
root_message_id=root_message_id,
goal_tree_size=self.cfg.goal_tree_size,
max_depth=self.cfg.max_tree_depth,
max_children_count=self.cfg.max_children_count,
state=state,
active=True,
)
if __name__ == "__main__":
    # Manual smoke test against the configured database: prints the results of the
    # tree-manager queries and generates one random task.
    from oasst_backend.api.deps import get_dummy_api_client
    from oasst_backend.database import engine

    with Session(engine) as db:
        api_client = get_dummy_api_client(db)
        dummy_user = protocol_schema.User(id="__dummy_user__", display_name="Dummy User", auth_method="local")

        pr = PromptRepository(db=db, api_client=api_client, client_user=dummy_user)

        cfg = TreeManagerConfiguration()
        tm = TreeManager(db, pr, cfg)
        tm.ensure_tree_states()

        print("query_num_active_trees", tm.query_num_active_trees())
        print("query_incomplete_rankings", tm.query_incomplete_rankings())
        print("query_incomplete_reply_reviews", tm.query_replies_need_review())
        print("query_incomplete_initial_prompt_reviews", tm.query_prompts_need_review())
        print("query_extendible_trees", tm.query_extendible_trees())
        print("query_extendible_parents", tm.query_extendible_parents())

        # next_task requires the desired task type; let the manager pick one at random
        print("next_task:", tm.next_task(protocol_schema.TaskRequestType.random))

        print(
            ".query_tree_ranking_results", tm.query_tree_ranking_results(UUID("2ac20d38-6650-43aa-8bb3-f61080c0d921"))
        )