770 tree manager allow to specify desired task_type no prompter ranking (#775)

* only ranking assistant replies by default

* add tasks/availability endpoint allow to specify desired task

* move rank_prompter_replies option to TreeManagerConfiguration

* fix type annotation

* remove desired_task_type from _random_task_selection()

* fix typo

* Convert query_tree_size to SQLAlchemy; return 'full' text-labeling tasks if they were explicitly requested
This commit is contained in:
Andreas Köpf
2023-01-16 20:05:40 +01:00
committed by GitHub
parent ad1bd77039
commit ead51ff423
4 changed files with 363 additions and 252 deletions
+22 -1
View File
@@ -1,4 +1,4 @@
from typing import Any
from typing import Any, Optional
from uuid import UUID
from fastapi import APIRouter, Depends
@@ -48,6 +48,27 @@ def request_task(
return task
# Endpoint: report how many tasks of each TaskRequestType can currently be handed out.
# The optional `user` is forwarded to the PromptRepository as `client_user`.
# OasstError raised while computing the counts is passed through unchanged; any other
# exception is logged and re-raised as OasstError(TASK_AVAILABILITY_QUERY_FAILED).
@router.post("/availability", response_model=dict[protocol_schema.TaskRequestType, int])
def tasks_availability(
    *,
    user: Optional[protocol_schema.User] = None,
    db: Session = Depends(deps.get_db),
    api_key: APIKey = Depends(deps.get_api_key),
):
    # authentication happens outside the try block so its errors are not rewrapped
    api_client = deps.api_auth(api_key, db)

    try:
        repository = PromptRepository(db, api_client, client_user=user)
        return TreeManager(db, repository).determine_task_availability()
    except OasstError:
        # already a well-formed API error, hand it to the caller as is
        raise
    except Exception:
        logger.exception("Task availability query failed.")
        raise OasstError("Task availability query failed.", OasstErrorCode.TASK_AVAILABILITY_QUERY_FAILED)
@router.post("/{task_id}/ack", response_model=None, status_code=HTTP_204_NO_CONTENT)
def tasks_acknowledge(
*,
+2
View File
@@ -55,6 +55,8 @@ class TreeManagerConfiguration(BaseModel):
mandatory_labels_prompter_reply: Optional[list[protocol_schema.TextLabel]] = [protocol_schema.TextLabel.spam]
"""Mandatory labels in text-labeling tasks for prompter replies."""
rank_prompter_replies: bool = False
class Settings(BaseSettings):
PROJECT_NAME: str = "open-assistant backend"
+338 -251
View File
@@ -16,7 +16,7 @@ from oasst_backend.utils.hugging_face import HfClassificationModel, HfEmbeddingM
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
from sqlalchemy.sql import text
from sqlmodel import Session, func
from sqlmodel import Session, func, not_
class TaskType(Enum):
@@ -49,6 +49,7 @@ class ActiveTreeSizeRow(pydantic.BaseModel):
class ExtendibleParentRow(pydantic.BaseModel):
parent_id: UUID
parent_role: str
depth: int
message_tree_id: UUID
active_children_count: int
@@ -59,6 +60,7 @@ class ExtendibleParentRow(pydantic.BaseModel):
class IncompleteRankingsRow(pydantic.BaseModel):
parent_id: UUID
role: str
children_count: int
child_min_ranking_count: int
@@ -70,21 +72,23 @@ class TreeManager:
_all_text_labels = list(map(lambda x: x.value, protocol_schema.TextLabel))
def __init__(
self, db: Session, prompt_repository: PromptRepository, cfg: Optional[TreeManagerConfiguration] = None
self,
db: Session,
prompt_repository: PromptRepository,
cfg: Optional[TreeManagerConfiguration] = None,
):
self.db = db
self.cfg = cfg or settings.tree_manager
self.pr = prompt_repository
def _task_selection(
def _random_task_selection(
self,
desired_task_type: protocol_schema.TaskRequestType,
num_ranking_tasks: int,
num_replies_need_review: int,
num_prompts_need_review: int,
num_missing_prompts: int,
num_missing_replies: int,
) -> Tuple[TaskType, TaskRole]:
) -> TaskType:
"""
Determines which task to hand out to human worker.
The task type is drawn with relative weight (e.g. ranking has highest priority)
@@ -92,75 +96,97 @@ class TreeManager:
"""
logger.debug(
f"TreeManager._task_selection({num_ranking_tasks=}, {num_replies_need_review=}, "
f"TreeManager._random_task_selection({num_ranking_tasks=}, {num_replies_need_review=}, "
f"{num_prompts_need_review=}, {num_missing_prompts=}, {num_missing_replies=})"
)
task_type = TaskType.NONE
task_role = TaskRole.ANY
if desired_task_type == protocol_schema.TaskRequestType.random:
task_weights = [0] * 5
task_weights = [0] * 5
if num_ranking_tasks > 0:
task_weights[TaskType.RANKING.value] = 10
if num_ranking_tasks > 0:
task_weights[TaskType.RANKING.value] = 10
if num_replies_need_review > 0:
task_weights[TaskType.LABEL_REPLY.value] = 5
if num_replies_need_review > 0:
task_weights[TaskType.LABEL_REPLY.value] = 5
if num_prompts_need_review > 0:
task_weights[TaskType.LABEL_PROMPT.value] = 5
if num_prompts_need_review > 0:
task_weights[TaskType.LABEL_PROMPT.value] = 5
if num_missing_replies > 0:
task_weights[TaskType.REPLY.value] = 2
if num_missing_replies > 0:
task_weights[TaskType.REPLY.value] = 2
if num_missing_prompts > 0:
task_weights[TaskType.PROMPT.value] = 1
if num_missing_prompts > 0:
task_weights[TaskType.PROMPT.value] = 1
task_weights = np.array(task_weights)
weight_sum = task_weights.sum()
if weight_sum < 1e-8:
task_type = TaskType.NONE
else:
task_weights = task_weights / weight_sum
task_type = TaskType(np.random.choice(a=len(task_weights), p=task_weights))
else:
match desired_task_type:
case protocol_schema.TaskRequestType.initial_prompt:
if num_missing_prompts > 0:
task_type = TaskType.PROMPT
case protocol_schema.TaskRequestType.label_initial_prompt:
if num_prompts_need_review > 0:
task_type = TaskType.LABEL_PROMPT
case protocol_schema.TaskRequestType.assistant_reply | protocol_schema.TaskRequestType.prompter_reply:
if num_missing_replies > 0:
task_role = (
TaskRole.ASSISTANT
if desired_task_type == protocol_schema.TaskRequestType.assistant_reply
else TaskRole.PROMPTER
)
task_type = TaskType.REPLY
case protocol_schema.TaskRequestType.label_assistant_reply | protocol_schema.TaskRequestType.label_prompter_reply:
if num_replies_need_review > 0:
task_role = (
TaskRole.ASSISTANT
if desired_task_type == protocol_schema.TaskRequestType.label_assistant_reply
else TaskRole.PROMPTER
)
task_type = TaskType.LABEL_REPLY
case protocol_schema.TaskRequestType.rank_assistant_replies | protocol_schema.TaskRequestType.rank_prompter_replies:
if num_ranking_tasks > 0:
task_role = (
TaskRole.ASSISTANT
if desired_task_type == protocol_schema.TaskRequestType.rank_assistant_replies
else TaskRole.PROMPTER
)
task_type = TaskType.RANKING
task_weights = np.array(task_weights)
weight_sum = task_weights.sum()
if weight_sum > 1e-8:
task_weights = task_weights / weight_sum
task_type = TaskType(np.random.choice(a=len(task_weights), p=task_weights))
logger.debug(f"Selected {task_type=}, {task_role=}")
return task_type, task_role
logger.debug(f"Selected {task_type=}")
return task_type
def _determine_task_availability_internal(
    self,
    num_active_trees: int,
    extensible_parents: list[ExtendibleParentRow],
    prompts_need_review: list[Message],
    replies_need_review: list[Message],
    incomplete_rankings: list[IncompleteRankingsRow],
) -> dict[protocol_schema.TaskRequestType, int]:
    """
    Count the available tasks per TaskRequestType from already fetched query results.

    :param num_active_trees: number of currently active message trees
    :param extensible_parents: parent messages that can still receive replies
    :param prompts_need_review: initial prompts that still need reviews
    :param replies_need_review: reply messages that still need reviews
    :param incomplete_rankings: parents whose children still need rankings
    :return: dict with an entry for every TaskRequestType member; the `random`
        entry is the sum of all other entries
    """
    request_type = protocol_schema.TaskRequestType
    task_count_by_type: dict[protocol_schema.TaskRequestType, int] = {t: 0 for t in request_type}

    # new initial prompts are wanted until the configured number of active trees is reached
    task_count_by_type[request_type.initial_prompt] = max(0, self.cfg.max_active_trees - num_active_trees)

    # a reply is written by the opposite role of the parent message it extends
    task_count_by_type[request_type.prompter_reply] = sum(
        1 for p in extensible_parents if p.parent_role == "assistant"
    )
    task_count_by_type[request_type.assistant_reply] = sum(
        1 for p in extensible_parents if p.parent_role == "prompter"
    )

    task_count_by_type[request_type.label_initial_prompt] = len(prompts_need_review)
    task_count_by_type[request_type.label_assistant_reply] = sum(
        1 for m in replies_need_review if m.role == "assistant"
    )
    # Labeling tasks for prompter replies belong under `label_prompter_reply`.
    # Writing them to `prompter_reply` would overwrite the reply count computed above
    # and leave `label_prompter_reply` permanently at 0.
    task_count_by_type[request_type.label_prompter_reply] = sum(
        1 for m in replies_need_review if m.role == "prompter"
    )

    # prompter replies are only ranked when explicitly enabled in the configuration
    if self.cfg.rank_prompter_replies:
        task_count_by_type[request_type.rank_prompter_replies] = sum(
            1 for r in incomplete_rankings if r.role == "prompter"
        )
    task_count_by_type[request_type.rank_assistant_replies] = sum(
        1 for r in incomplete_rankings if r.role == "assistant"
    )

    # `random` can be served by any concrete task type, so it is the total of all others
    task_count_by_type[request_type.random] = sum(
        count for t, count in task_count_by_type.items() if t != request_type.random
    )

    return task_count_by_type
def determine_task_availability(self) -> dict[protocol_schema.TaskRequestType, int]:
    """
    Run the availability queries and return the number of available tasks per request type.

    The queries are executed in the order: active trees, extendible parents,
    prompts needing review, replies needing review, incomplete rankings.
    """
    # keyword arguments are evaluated left to right, which preserves the query order
    return self._determine_task_availability_internal(
        num_active_trees=self.query_num_active_trees(),
        extensible_parents=self.query_extendible_parents(),
        prompts_need_review=self.query_prompts_need_review(),
        replies_need_review=self.query_replies_need_review(),
        incomplete_rankings=self.query_incomplete_rankings(),
    )
def next_task(
self, desired_task_type: protocol_schema.TaskRequestType
self, desired_task_type: protocol_schema.TaskRequestType = protocol_schema.TaskRequestType.random
) -> Tuple[protocol_schema.Task, Optional[UUID], Optional[UUID]]:
logger.debug("TreeManager.next_task()")
@@ -168,148 +194,195 @@ class TreeManager:
num_active_trees = self.query_num_active_trees()
prompts_need_review = self.query_prompts_need_review()
replies_need_review = self.query_replies_need_review()
extensible_parents = self.query_extendible_parents()
incomplete_rankings = self.query_incomplete_rankings()
if not self.cfg.rank_prompter_replies:
incomplete_rankings = list(filter(lambda r: r.role == "assistant", incomplete_rankings))
active_tree_sizes = self.query_extendible_trees()
# determine type of task to generate
num_missing_replies = sum(x.remaining_messages for x in active_tree_sizes)
task_type, task_role = self._task_selection(
desired_task_type,
num_ranking_tasks=len(incomplete_rankings),
num_replies_need_review=len(replies_need_review),
num_prompts_need_review=len(prompts_need_review),
num_missing_prompts=max(0, self.cfg.max_active_trees - num_active_trees),
num_missing_replies=num_missing_replies,
)
if task_type == TaskType.NONE:
raise OasstError(
f"No tasks of type '{desired_task_type.value}' are currently available.",
OasstErrorCode.TASK_REQUESTED_TYPE_NOT_AVAILABLE,
HTTPStatus.SERVICE_UNAVAILABLE,
task_role = TaskRole.ANY
if desired_task_type == protocol_schema.TaskRequestType.random:
task_type = self._random_task_selection(
num_ranking_tasks=len(incomplete_rankings),
num_replies_need_review=len(replies_need_review),
num_prompts_need_review=len(prompts_need_review),
num_missing_prompts=max(0, self.cfg.max_active_trees - num_active_trees),
num_missing_replies=num_missing_replies,
)
if task_role != TaskRole.ANY:
# Todo: Allow role specific message selection...
raise OasstError(
f"No tasks of type '{desired_task_type.value}' are currently available.",
OasstErrorCode.TASK_REQUESTED_TYPE_NOT_AVAILABLE,
HTTPStatus.SERVICE_UNAVAILABLE,
if task_type == TaskType.NONE:
raise OasstError(
f"No tasks of type '{protocol_schema.TaskRequestType.random.value}' are currently available.",
OasstErrorCode.TASK_REQUESTED_TYPE_NOT_AVAILABLE,
HTTPStatus.SERVICE_UNAVAILABLE,
)
else:
task_count_by_type = self._determine_task_availability_internal(
num_active_trees=num_active_trees,
extensible_parents=extensible_parents,
prompts_need_review=prompts_need_review,
replies_need_review=replies_need_review,
incomplete_rankings=incomplete_rankings,
)
available_count = task_count_by_type.get(desired_task_type)
if not available_count:
raise OasstError(
f"No tasks of type '{desired_task_type.value}' are currently available.",
OasstErrorCode.TASK_REQUESTED_TYPE_NOT_AVAILABLE,
HTTPStatus.SERVICE_UNAVAILABLE,
)
task_type_role_map = {
protocol_schema.TaskRequestType.initial_prompt: (TaskType.PROMPT, TaskRole.ANY),
protocol_schema.TaskRequestType.prompter_reply: (TaskType.REPLY, TaskRole.PROMPTER),
protocol_schema.TaskRequestType.assistant_reply: (TaskType.REPLY, TaskRole.ASSISTANT),
protocol_schema.TaskRequestType.rank_prompter_replies: (TaskType.RANKING, TaskRole.PROMPTER),
protocol_schema.TaskRequestType.rank_assistant_replies: (TaskType.RANKING, TaskRole.ASSISTANT),
protocol_schema.TaskRequestType.label_initial_prompt: (TaskType.LABEL_PROMPT, TaskRole.ANY),
protocol_schema.TaskRequestType.label_assistant_reply: (TaskType.LABEL_REPLY, TaskRole.ASSISTANT),
protocol_schema.TaskRequestType.label_prompter_reply: (TaskType.LABEL_REPLY, TaskRole.PROMPTER),
}
task_type, task_role = task_type_role_map[desired_task_type]
message_tree_id = None
parent_message_id = None
logger.debug(f"selected {task_type=}")
match task_type:
case TaskType.RANKING:
assert len(incomplete_rankings) > 0
ranking_parent_id = random.choice(incomplete_rankings).parent_id
if task_role == TaskRole.PROMPTER:
incomplete_rankings = list(filter(lambda m: m.role == "prompter", incomplete_rankings))
elif task_role == TaskRole.ASSISTANT:
incomplete_rankings = list(filter(lambda m: m.role == "assistant", incomplete_rankings))
messages = self.pr.fetch_message_conversation(ranking_parent_id)
assert len(messages) > 1 and messages[-1].id == ranking_parent_id
ranking_parent = messages[-1]
assert not ranking_parent.deleted and ranking_parent.review_result
conversation = prepare_conversation(messages)
replies = self.pr.fetch_message_children(ranking_parent_id, reviewed=True, exclude_deleted=True)
if len(incomplete_rankings) > 0:
ranking_parent_id = random.choice(incomplete_rankings).parent_id
assert len(replies) > 1
random.shuffle(replies) # hand out replies in random order
reply_messages = prepare_conversation_message_list(replies)
replies = [p.text for p in replies]
messages = self.pr.fetch_message_conversation(ranking_parent_id)
assert len(messages) > 1 and messages[-1].id == ranking_parent_id
ranking_parent = messages[-1]
assert not ranking_parent.deleted and ranking_parent.review_result
conversation = prepare_conversation(messages)
replies = self.pr.fetch_message_children(ranking_parent_id, reviewed=True, exclude_deleted=True)
if messages[-1].role == "assistant":
logger.info("Generating a RankPrompterRepliesTask.")
task = protocol_schema.RankPrompterRepliesTask(
conversation=conversation,
replies=replies,
reply_messages=reply_messages,
ranking_parent_id=ranking_parent.id,
message_tree_id=ranking_parent.message_tree_id,
)
else:
logger.info("Generating a RankAssistantRepliesTask.")
task = protocol_schema.RankAssistantRepliesTask(
conversation=conversation,
replies=replies,
reply_messages=reply_messages,
ranking_parent_id=ranking_parent.id,
message_tree_id=ranking_parent.message_tree_id,
)
assert len(replies) > 1
random.shuffle(replies) # hand out replies in random order
reply_messages = prepare_conversation_message_list(replies)
replies = [p.text for p in replies]
parent_message_id = ranking_parent_id
message_tree_id = messages[-1].message_tree_id
if messages[-1].role == "assistant":
logger.info("Generating a RankPrompterRepliesTask.")
task = protocol_schema.RankPrompterRepliesTask(
conversation=conversation,
replies=replies,
reply_messages=reply_messages,
ranking_parent_id=ranking_parent.id,
message_tree_id=ranking_parent.message_tree_id,
)
else:
logger.info("Generating a RankAssistantRepliesTask.")
task = protocol_schema.RankAssistantRepliesTask(
conversation=conversation,
replies=replies,
reply_messages=reply_messages,
ranking_parent_id=ranking_parent.id,
message_tree_id=ranking_parent.message_tree_id,
)
parent_message_id = ranking_parent_id
message_tree_id = messages[-1].message_tree_id
case TaskType.LABEL_REPLY:
assert len(replies_need_review) > 0
random_reply_message_id = random.choice(replies_need_review)
messages = self.pr.fetch_message_conversation(random_reply_message_id)
if task_role == TaskRole.PROMPTER:
replies_need_review = list(filter(lambda m: m.role == "prompter", replies_need_review))
elif task_role == TaskRole.ASSISTANT:
replies_need_review = list(filter(lambda m: m.role == "assistant", replies_need_review))
conversation = prepare_conversation(messages[:-1])
message = messages[-1]
if len(replies_need_review) > 0:
random_reply_message = random.choice(replies_need_review)
messages = self.pr.fetch_message_conversation(random_reply_message)
self.cfg.p_full_labeling_review_reply_prompter: float = 0.1
conversation = prepare_conversation(messages[:-1])
message = messages[-1]
label_mode = protocol_schema.LabelTaskMode.full
valid_labels = self._all_text_labels
self.cfg.p_full_labeling_review_reply_prompter: float = 0.1
if message.role == "assistant":
if random.random() > self.cfg.p_full_labeling_review_reply_assistant:
valid_labels = list(map(lambda x: x.value, self.cfg.mandatory_labels_assistant_reply))
label_mode = protocol_schema.LabelTaskMode.simple
logger.info(f"Generating a LabelAssistantReplyTask. ({label_mode=:s})")
task = protocol_schema.LabelAssistantReplyTask(
message_id=message.id,
conversation=conversation,
reply=message.text,
valid_labels=valid_labels,
mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_assistant_reply)),
mode=label_mode,
)
else:
if random.random() > self.cfg.p_full_labeling_review_reply_prompter:
valid_labels = list(map(lambda x: x.value, self.cfg.mandatory_labels_prompter_reply))
label_mode = protocol_schema.LabelTaskMode.simple
logger.info(f"Generating a LabelPrompterReplyTask. ({label_mode=:s})")
task = protocol_schema.LabelPrompterReplyTask(
message_id=message.id,
conversation=conversation,
reply=message.text,
valid_labels=valid_labels,
mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_prompter_reply)),
mode=label_mode,
)
label_mode = protocol_schema.LabelTaskMode.full
valid_labels = self._all_text_labels
parent_message_id = message.id
message_tree_id = message.message_tree_id
if message.role == "assistant":
if (
desired_task_type == protocol_schema.TaskRequestType.random
and random.random() > self.cfg.p_full_labeling_review_reply_assistant
):
valid_labels = list(map(lambda x: x.value, self.cfg.mandatory_labels_assistant_reply))
label_mode = protocol_schema.LabelTaskMode.simple
logger.info(f"Generating a LabelAssistantReplyTask. ({label_mode=:s})")
task = protocol_schema.LabelAssistantReplyTask(
message_id=message.id,
conversation=conversation,
reply=message.text,
valid_labels=valid_labels,
mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_assistant_reply)),
mode=label_mode,
)
else:
if (
desired_task_type == protocol_schema.TaskRequestType.random
and random.random() > self.cfg.p_full_labeling_review_reply_prompter
):
valid_labels = list(map(lambda x: x.value, self.cfg.mandatory_labels_prompter_reply))
label_mode = protocol_schema.LabelTaskMode.simple
logger.info(f"Generating a LabelPrompterReplyTask. ({label_mode=:s})")
task = protocol_schema.LabelPrompterReplyTask(
message_id=message.id,
conversation=conversation,
reply=message.text,
valid_labels=valid_labels,
mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_prompter_reply)),
mode=label_mode,
)
parent_message_id = message.id
message_tree_id = message.message_tree_id
case TaskType.REPLY:
# select a tree with missing replies
extensible_parents = self.query_extendible_parents()
assert len(extensible_parents) > 0
if task_role == TaskRole.PROMPTER:
extensible_parents = list(filter(lambda x: x.parent_role == "assistant", extensible_parents))
elif task_role == TaskRole.ASSISTANT:
extensible_parents = list(filter(lambda x: x.parent_role == "prompter", extensible_parents))
# fetch random conversation to extend
random_parent = random.choice(extensible_parents)
logger.debug(f"selected {random_parent=}")
messages = self.pr.fetch_message_conversation(random_parent.parent_id)
assert all(m.review_result for m in messages) # ensure all messages have positive review
conversation = prepare_conversation(messages)
if len(extensible_parents) > 0:
random_parent = random.choice(extensible_parents)
# generate reply task depending on last message
if messages[-1].role == "assistant":
logger.info("Generating a PrompterReplyTask.")
task = protocol_schema.PrompterReplyTask(conversation=conversation)
else:
logger.info("Generating a AssistantReplyTask.")
task = protocol_schema.AssistantReplyTask(conversation=conversation)
# fetch random conversation to extend
logger.debug(f"selected {random_parent=}")
messages = self.pr.fetch_message_conversation(random_parent.parent_id)
assert all(m.review_result for m in messages) # ensure all messages have positive review
conversation = prepare_conversation(messages)
parent_message_id = messages[-1].id
message_tree_id = messages[-1].message_tree_id
# generate reply task depending on last message
if messages[-1].role == "assistant":
logger.info("Generating a PrompterReplyTask.")
task = protocol_schema.PrompterReplyTask(conversation=conversation)
else:
logger.info("Generating a AssistantReplyTask.")
task = protocol_schema.AssistantReplyTask(conversation=conversation)
parent_message_id = messages[-1].id
message_tree_id = messages[-1].message_tree_id
case TaskType.LABEL_PROMPT:
assert len(prompts_need_review) > 0
message = self.pr.fetch_message(random.choice(prompts_need_review))
message = random.choice(prompts_need_review)
label_mode = protocol_schema.LabelTaskMode.full
valid_labels = self._all_text_labels
@@ -337,6 +410,13 @@ class TreeManager:
case _:
task = None
if task is None:
raise OasstError(
f"No task of type '{desired_task_type.value}' is currently available.",
OasstErrorCode.TASK_REQUESTED_TYPE_NOT_AVAILABLE,
HTTPStatus.SERVICE_UNAVAILABLE,
)
logger.info(f"Generated {task=}.")
return task, message_tree_id, parent_message_id
@@ -515,7 +595,8 @@ class TreeManager:
logger.debug(f"False {mts.active=}, {mts.state=}")
return False
rankings_by_message = self.query_tree_ranking_results(message_tree_id)
ranking_role_filter = None if self.cfg.rank_prompter_replies else "assistant"
rankings_by_message = self.query_tree_ranking_results(message_tree_id, role_filter=ranking_role_filter)
for parent_msg_id, ranking in rankings_by_message.items():
if len(ranking) < self.cfg.num_required_rankings:
logger.debug(f"False {parent_msg_id=} {len(ranking)=}")
@@ -528,68 +609,59 @@ class TreeManager:
# calculate acceptance based on spam label
return np.mean([1 - l.labels[protocol_schema.TextLabel.spam] for l in labels])
_sql_find_prompts_need_review = """
-- find initial prompts that need more reviews
SELECT m.id
FROM message_tree_state mts
LEFT JOIN message m ON mts.message_tree_id = m.id
WHERE mts.active
AND mts.state = :state
AND NOT m.review_result
AND NOT m.deleted
AND m.review_count < :num_reviews_initial_prompt
AND m.parent_id is NULL
AND (:excluded_user_id IS NULL OR m.user_id != :excluded_user_id)
"""
def query_prompts_need_review(self) -> list[UUID]:
def query_prompts_need_review(self) -> list[Message]:
"""
Select id of initial prompts with less then required rankings in active message tree
Select initial prompt messages with less then required rankings in active message tree
(active == True in message_tree_state)
"""
r = self.db.execute(
text(self._sql_find_prompts_need_review),
{
"state": message_tree_state.State.INITIAL_PROMPT_REVIEW,
"num_reviews_initial_prompt": self.cfg.num_reviews_initial_prompt,
"excluded_user_id": None if settings.DEBUG_ALLOW_SELF_LABELING else self.pr.user_id,
},
qry = (
self.db.query(Message)
.select_from(MessageTreeState)
.outerjoin(Message, MessageTreeState.message_tree_id == Message.message_tree_id)
.filter(
MessageTreeState.active,
MessageTreeState.state == message_tree_state.State.INITIAL_PROMPT_REVIEW,
not_(Message.review_result),
not_(Message.deleted),
Message.review_count < self.cfg.num_reviews_initial_prompt,
Message.parent_id.is_(None),
)
)
return [x["id"] for x in r.all()]
_sql_find_replies_need_review = """
SELECT m.id
FROM message_tree_state mts
LEFT JOIN message m ON mts.message_tree_id = m.message_tree_id
WHERE mts.active
AND mts.state = :breeding_state
AND NOT m.review_result
AND NOT m.deleted
AND m.review_count < :num_required_reviews
AND m.parent_id is NOT NULL
AND (:excluded_user_id IS NULL OR m.user_id != :excluded_user_id)
"""
if not settings.DEBUG_ALLOW_SELF_LABELING:
qry = qry.filter(Message.user_id != self.pr.user_id)
def query_replies_need_review(self) -> list[UUID]:
return qry.all()
def query_replies_need_review(self) -> list[Message]:
"""
Select ids of child messages (parent_id IS NOT NULL) with less then required rankings
Select child messages (parent_id IS NOT NULL) with less then required rankings
in active message tree (active == True in message_tree_state)
"""
r = self.db.execute(
text(self._sql_find_replies_need_review),
{
"breeding_state": message_tree_state.State.GROWING,
"num_required_reviews": self.cfg.num_reviews_reply,
"excluded_user_id": None if settings.DEBUG_ALLOW_SELF_LABELING else self.pr.user_id,
},
qry = (
self.db.query(Message)
.select_from(MessageTreeState)
.outerjoin(Message, MessageTreeState.message_tree_id == Message.message_tree_id)
.filter(
MessageTreeState.active,
MessageTreeState.state == message_tree_state.State.GROWING,
not_(Message.review_result),
not_(Message.deleted),
Message.review_count < self.cfg.num_reviews_reply,
Message.parent_id.is_not(None),
)
)
return [x["id"] for x in r.all()]
if not settings.DEBUG_ALLOW_SELF_LABELING:
qry = qry.filter(Message.user_id != self.pr.user_id)
return qry.all()
_sql_find_incomplete_rankings = """
-- find incomplete rankings
SELECT m.parent_id, COUNT(m.id) children_count, MIN(m.ranking_count) child_min_ranking_count,
SELECT m.parent_id, m.role, COUNT(m.id) children_count, MIN(m.ranking_count) child_min_ranking_count,
COUNT(m.id) FILTER (WHERE m.ranking_count >= :num_required_rankings) as completed_rankings
FROM message_tree_state mts
LEFT JOIN message m ON mts.message_tree_id = m.message_tree_id
@@ -598,7 +670,7 @@ WHERE mts.active -- only consider active trees
AND m.review_result -- must be reviewed
AND NOT m.deleted -- not deleted
AND m.parent_id IS NOT NULL -- ignore initial prompts
GROUP BY m.parent_id
GROUP BY m.parent_id, m.role
HAVING COUNT(m.id) > 1 and MIN(m.ranking_count) < :num_required_rankings
"""
@@ -616,10 +688,10 @@ HAVING COUNT(m.id) > 1 and MIN(m.ranking_count) < :num_required_rankings
_sql_find_extendible_parents = """
-- find all extendible parent nodes
SELECT m.id as parent_id, m.depth, m.message_tree_id, COUNT(c.id) active_children_count
SELECT m.id as parent_id, m.role as parent_role, m.depth, m.message_tree_id, COUNT(c.id) active_children_count
FROM message_tree_state mts
LEFT JOIN message m ON mts.message_tree_id = m.message_tree_id -- all elements of message tree
LEFT JOIN message c ON m.id = c.Id -- child nodes
LEFT JOIN message c ON m.id = c.parent_id -- child nodes
WHERE mts.active -- only consider active trees
AND mts.state = :growing_state -- message tree must be growing
AND NOT m.deleted -- ignore deleted messages as parents
@@ -627,7 +699,7 @@ WHERE mts.active -- only consider active trees
AND m.review_result -- parent node must have positive review
AND NOT c.deleted -- don't count deleted children
AND (c.review_result OR c.review_count < :num_reviews_reply) -- don't count children with negative review but count elements under review
GROUP BY m.id, m.depth, m.message_tree_id, mts.max_children_count
GROUP BY m.id, m.role, m.depth, m.message_tree_id, mts.max_children_count
HAVING COUNT(c.id) < mts.max_children_count -- below maximum number of children
"""
@@ -636,10 +708,7 @@ HAVING COUNT(c.id) < mts.max_children_count -- below maximum number of children
r = self.db.execute(
text(self._sql_find_extendible_parents),
{
"growing_state": message_tree_state.State.GROWING,
"num_reviews_reply": self.cfg.num_reviews_reply,
},
{"growing_state": message_tree_state.State.GROWING, "num_reviews_reply": self.cfg.num_reviews_reply},
)
return [ExtendibleParentRow.from_orm(x) for x in r.all()]
@@ -671,21 +740,27 @@ HAVING COUNT(m.id) < mts.goal_tree_size
)
return [ActiveTreeSizeRow.from_orm(x) for x in r.all()]
_sql_get_tree_size = """
SELECT mts.message_tree_id, mts.goal_tree_size, COUNT(m.id) AS tree_size
FROM message_tree_state mts
LEFT JOIN message m ON mts.message_tree_id = m.message_tree_id
WHERE mts.active
AND NOT m.deleted
AND m.review_result
AND mts.message_tree_id = :message_tree_id
GROUP BY mts.message_tree_id, mts.goal_tree_size
"""
def query_tree_size(self, message_tree_id: UUID) -> ActiveTreeSizeRow:
"""Returns the number of reviewed not deleted messages in the message tree."""
r = self.db.execute(text(self._sql_get_tree_size), {"message_tree_id": message_tree_id})
return ActiveTreeSizeRow.from_orm(r.one())
qry = (
self.db.query(
MessageTreeState.message_tree_id.label("message_tree_id"),
MessageTreeState.goal_tree_size.label("goal_tree_size"),
func.count(Message.id).label("tree_size"),
)
.select_from(MessageTreeState)
.outerjoin(Message, MessageTreeState.message_tree_id == Message.message_tree_id)
.filter(
MessageTreeState.active,
not_(Message.deleted),
Message.review_result,
MessageTreeState.message_tree_id == message_tree_id,
)
.group_by(MessageTreeState.message_tree_id, MessageTreeState.goal_tree_size)
)
return ActiveTreeSizeRow.from_orm(qry.one())
def query_misssing_tree_states(self) -> list[UUID]:
"""Find all initial prompt messages that have no associated message tree state"""
@@ -702,7 +777,7 @@ GROUP BY mts.message_tree_id, mts.goal_tree_size
return [m.id for m in qry_missing_tree_states.all()]
_sql_find_tree_ranking_results = """
-- get all ranking results of completed tasks for all parents with >=2 children
-- get all ranking results of completed tasks for all parents with >= 2 children
SELECT p.parent_id, mr.* FROM
(
-- find parents with > 1 children
@@ -712,7 +787,8 @@ SELECT p.parent_id, mr.* FROM
WHERE m.review_result -- must be reviewed
AND NOT m.deleted -- not deleted
AND m.parent_id IS NOT NULL -- ignore initial prompts
AND mts.message_tree_id = :message_tree_id
AND (:role IS NULL OR m.role = :role) -- children with matching role
AND mts.message_tree_id = :message_tree_id
GROUP BY m.parent_id, m.message_tree_id
HAVING COUNT(m.id) > 1
) as p
@@ -720,11 +796,21 @@ LEFT JOIN task t ON p.parent_id = t.parent_message_id AND t.done AND (t.payload_
LEFT JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'RankingReactionPayload'
"""
def query_tree_ranking_results(self, message_tree_id: UUID) -> dict[UUID, list[MessageReaction]]:
def query_tree_ranking_results(
self,
message_tree_id: UUID,
role_filter: str = "assistant",
) -> dict[UUID, list[MessageReaction]]:
"""Finds all completed ranking restuls for a message_tree"""
assert role_filter in (None, "assistant", "prompter")
r = self.db.execute(
text(self._sql_find_tree_ranking_results),
{"message_tree_id": message_tree_id},
{
"message_tree_id": message_tree_id,
"role": role_filter,
},
)
rankings_by_message = {}
@@ -803,12 +889,12 @@ WHERE t.done = TRUE
if __name__ == "__main__":
from oasst_backend.api.deps import get_dummy_api_client
from oasst_backend.api.deps import api_auth
from oasst_backend.database import engine
from oasst_backend.prompt_repository import PromptRepository
with Session(engine) as db:
api_client = get_dummy_api_client(db)
api_client = api_auth(settings.OFFICIAL_WEB_API_KEY, db=db)
dummy_user = protocol_schema.User(id="__dummy_user__", display_name="Dummy User", auth_method="local")
pr = PromptRepository(db=db, api_client=api_client, client_user=dummy_user)
@@ -817,15 +903,16 @@ if __name__ == "__main__":
tm = TreeManager(db, pr, cfg)
tm.ensure_tree_states()
print("query_num_active_trees", tm.query_num_active_trees())
print("query_incomplete_rankings", tm.query_incomplete_rankings())
print("query_incomplete_reply_reviews", tm.query_replies_need_review())
print("query_incomplete_initial_prompt_reviews", tm.query_prompts_need_review())
print("query_extendible_trees", tm.query_extendible_trees())
print("query_extendible_parents", tm.query_extendible_parents())
# print("query_num_active_trees", tm.query_num_active_trees())
# print("query_incomplete_rankings", tm.query_incomplete_rankings())
# print("query_replies_need_review", tm.query_replies_need_review())
# print("query_incomplete_initial_prompt_reviews", tm.query_prompts_need_review())
# print("query_extendible_trees", tm.query_extendible_trees())
# print("query_extendible_parents", tm.query_extendible_parents())
# print("query_tree_size", tm.query_tree_size(message_tree_id=UUID("bdf434cf-4df5-4b74-949c-a5a157bc3292")))
print("next_task:", tm.next_task())
print(
".query_tree_ranking_results", tm.query_tree_ranking_results(UUID("2ac20d38-6650-43aa-8bb3-f61080c0d921"))
)
# print(
# "query_tree_ranking_results", tm.query_tree_ranking_results(UUID("6036f58f-41b5-48c4-bdd9-b16f34ab1312"))
# )
@@ -32,6 +32,7 @@ class OasstErrorCode(IntEnum):
TASK_INTERACTION_REQUEST_FAILED = 1004
TASK_GENERATION_FAILED = 1005
TASK_REQUESTED_TYPE_NOT_AVAILABLE = 1006
TASK_AVAILABILITY_QUERY_FAILED = 1007
# 2000-3000: prompt_repository
INVALID_FRONTEND_MESSAGE_ID = 2000