mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
770 tree manager allow to specify desired task_type no prompter ranking (#775)
* rank only assistant replies by default * add tasks/availability endpoint, allow specifying the desired task * move rank_prompter_replies option to TreeManagerConfiguration * fix type annotation * remove desired_task_type from _random_task_selection() * fix typo * convert query_tree_size to SQLAlchemy, return 'full' text-labeling tasks if they were explicitly requested
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
from typing import Any
|
||||
from typing import Any, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
@@ -48,6 +48,27 @@ def request_task(
|
||||
return task
|
||||
|
||||
|
||||
@router.post("/availability", response_model=dict[protocol_schema.TaskRequestType, int])
def tasks_availability(
    *,
    user: Optional[protocol_schema.User] = None,
    db: Session = Depends(deps.get_db),
    api_key: APIKey = Depends(deps.get_api_key),
):
    # Endpoint: report, per task request type, how many tasks are currently available.
    # Kept as comments rather than a docstring because FastAPI publishes an
    # endpoint's docstring as its OpenAPI description.

    # Authentication runs outside the try block so that whatever it raises
    # reaches the caller untouched.
    api_client = deps.api_auth(api_key, db)

    try:
        tree_manager = TreeManager(db, PromptRepository(db, api_client, client_user=user))
        return tree_manager.determine_task_availability()
    except OasstError:
        # Already a well-formed API error: pass it through as is.
        raise
    except Exception:
        # Anything unexpected is logged with its traceback and reported to the
        # client as a generic availability failure.
        logger.exception("Task availability query failed.")
        raise OasstError("Task availability query failed.", OasstErrorCode.TASK_AVAILABILITY_QUERY_FAILED)
|
||||
|
||||
|
||||
@router.post("/{task_id}/ack", response_model=None, status_code=HTTP_204_NO_CONTENT)
|
||||
def tasks_acknowledge(
|
||||
*,
|
||||
|
||||
@@ -55,6 +55,8 @@ class TreeManagerConfiguration(BaseModel):
|
||||
mandatory_labels_prompter_reply: Optional[list[protocol_schema.TextLabel]] = [protocol_schema.TextLabel.spam]
|
||||
"""Mandatory labels in text-labeling tasks for prompter replies."""
|
||||
|
||||
rank_prompter_replies: bool = False
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
PROJECT_NAME: str = "open-assistant backend"
|
||||
|
||||
@@ -16,7 +16,7 @@ from oasst_backend.utils.hugging_face import HfClassificationModel, HfEmbeddingM
|
||||
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from sqlalchemy.sql import text
|
||||
from sqlmodel import Session, func
|
||||
from sqlmodel import Session, func, not_
|
||||
|
||||
|
||||
class TaskType(Enum):
|
||||
@@ -49,6 +49,7 @@ class ActiveTreeSizeRow(pydantic.BaseModel):
|
||||
|
||||
class ExtendibleParentRow(pydantic.BaseModel):
|
||||
parent_id: UUID
|
||||
parent_role: str
|
||||
depth: int
|
||||
message_tree_id: UUID
|
||||
active_children_count: int
|
||||
@@ -59,6 +60,7 @@ class ExtendibleParentRow(pydantic.BaseModel):
|
||||
|
||||
class IncompleteRankingsRow(pydantic.BaseModel):
|
||||
parent_id: UUID
|
||||
role: str
|
||||
children_count: int
|
||||
child_min_ranking_count: int
|
||||
|
||||
@@ -70,21 +72,23 @@ class TreeManager:
|
||||
_all_text_labels = list(map(lambda x: x.value, protocol_schema.TextLabel))
|
||||
|
||||
def __init__(
|
||||
self, db: Session, prompt_repository: PromptRepository, cfg: Optional[TreeManagerConfiguration] = None
|
||||
self,
|
||||
db: Session,
|
||||
prompt_repository: PromptRepository,
|
||||
cfg: Optional[TreeManagerConfiguration] = None,
|
||||
):
|
||||
self.db = db
|
||||
self.cfg = cfg or settings.tree_manager
|
||||
self.pr = prompt_repository
|
||||
|
||||
def _task_selection(
|
||||
def _random_task_selection(
|
||||
self,
|
||||
desired_task_type: protocol_schema.TaskRequestType,
|
||||
num_ranking_tasks: int,
|
||||
num_replies_need_review: int,
|
||||
num_prompts_need_review: int,
|
||||
num_missing_prompts: int,
|
||||
num_missing_replies: int,
|
||||
) -> Tuple[TaskType, TaskRole]:
|
||||
) -> TaskType:
|
||||
"""
|
||||
Determines which task to hand out to human worker.
|
||||
The task type is drawn with relative weight (e.g. ranking has highest priority)
|
||||
@@ -92,75 +96,97 @@ class TreeManager:
|
||||
"""
|
||||
|
||||
logger.debug(
|
||||
f"TreeManager._task_selection({num_ranking_tasks=}, {num_replies_need_review=}, "
|
||||
f"TreeManager._random_task_selection({num_ranking_tasks=}, {num_replies_need_review=}, "
|
||||
f"{num_prompts_need_review=}, {num_missing_prompts=}, {num_missing_replies=})"
|
||||
)
|
||||
|
||||
task_type = TaskType.NONE
|
||||
task_role = TaskRole.ANY
|
||||
if desired_task_type == protocol_schema.TaskRequestType.random:
|
||||
task_weights = [0] * 5
|
||||
task_weights = [0] * 5
|
||||
|
||||
if num_ranking_tasks > 0:
|
||||
task_weights[TaskType.RANKING.value] = 10
|
||||
if num_ranking_tasks > 0:
|
||||
task_weights[TaskType.RANKING.value] = 10
|
||||
|
||||
if num_replies_need_review > 0:
|
||||
task_weights[TaskType.LABEL_REPLY.value] = 5
|
||||
if num_replies_need_review > 0:
|
||||
task_weights[TaskType.LABEL_REPLY.value] = 5
|
||||
|
||||
if num_prompts_need_review > 0:
|
||||
task_weights[TaskType.LABEL_PROMPT.value] = 5
|
||||
if num_prompts_need_review > 0:
|
||||
task_weights[TaskType.LABEL_PROMPT.value] = 5
|
||||
|
||||
if num_missing_replies > 0:
|
||||
task_weights[TaskType.REPLY.value] = 2
|
||||
if num_missing_replies > 0:
|
||||
task_weights[TaskType.REPLY.value] = 2
|
||||
|
||||
if num_missing_prompts > 0:
|
||||
task_weights[TaskType.PROMPT.value] = 1
|
||||
if num_missing_prompts > 0:
|
||||
task_weights[TaskType.PROMPT.value] = 1
|
||||
|
||||
task_weights = np.array(task_weights)
|
||||
weight_sum = task_weights.sum()
|
||||
if weight_sum < 1e-8:
|
||||
task_type = TaskType.NONE
|
||||
else:
|
||||
task_weights = task_weights / weight_sum
|
||||
task_type = TaskType(np.random.choice(a=len(task_weights), p=task_weights))
|
||||
else:
|
||||
match desired_task_type:
|
||||
case protocol_schema.TaskRequestType.initial_prompt:
|
||||
if num_missing_prompts > 0:
|
||||
task_type = TaskType.PROMPT
|
||||
case protocol_schema.TaskRequestType.label_initial_prompt:
|
||||
if num_prompts_need_review > 0:
|
||||
task_type = TaskType.LABEL_PROMPT
|
||||
case protocol_schema.TaskRequestType.assistant_reply | protocol_schema.TaskRequestType.prompter_reply:
|
||||
if num_missing_replies > 0:
|
||||
task_role = (
|
||||
TaskRole.ASSISTANT
|
||||
if desired_task_type == protocol_schema.TaskRequestType.assistant_reply
|
||||
else TaskRole.PROMPTER
|
||||
)
|
||||
task_type = TaskType.REPLY
|
||||
case protocol_schema.TaskRequestType.label_assistant_reply | protocol_schema.TaskRequestType.label_prompter_reply:
|
||||
if num_replies_need_review > 0:
|
||||
task_role = (
|
||||
TaskRole.ASSISTANT
|
||||
if desired_task_type == protocol_schema.TaskRequestType.label_assistant_reply
|
||||
else TaskRole.PROMPTER
|
||||
)
|
||||
task_type = TaskType.LABEL_REPLY
|
||||
case protocol_schema.TaskRequestType.rank_assistant_replies | protocol_schema.TaskRequestType.rank_prompter_replies:
|
||||
if num_ranking_tasks > 0:
|
||||
task_role = (
|
||||
TaskRole.ASSISTANT
|
||||
if desired_task_type == protocol_schema.TaskRequestType.rank_assistant_replies
|
||||
else TaskRole.PROMPTER
|
||||
)
|
||||
task_type = TaskType.RANKING
|
||||
task_weights = np.array(task_weights)
|
||||
weight_sum = task_weights.sum()
|
||||
if weight_sum > 1e-8:
|
||||
task_weights = task_weights / weight_sum
|
||||
task_type = TaskType(np.random.choice(a=len(task_weights), p=task_weights))
|
||||
|
||||
logger.debug(f"Selected {task_type=}, {task_role=}")
|
||||
return task_type, task_role
|
||||
logger.debug(f"Selected {task_type=}")
|
||||
return task_type
|
||||
|
||||
def _determine_task_availability_internal(
    self,
    num_active_trees: int,
    extensible_parents: list[ExtendibleParentRow],
    prompts_need_review: list[Message],
    replies_need_review: list[Message],
    incomplete_rankings: list[IncompleteRankingsRow],
) -> dict[protocol_schema.TaskRequestType, int]:
    """Count the available tasks for every task request type.

    Args:
        num_active_trees: Number of currently active message trees; the gap to
            ``self.cfg.max_active_trees`` is the number of initial prompts wanted.
        extensible_parents: Parent rows that can still receive replies, split
            by ``parent_role`` into prompter and assistant reply tasks.
        prompts_need_review: Initial prompts waiting for labeling.
        replies_need_review: Replies waiting for labeling, split by ``role``.
        incomplete_rankings: Ranking rows that are not complete yet, split by
            ``role``.

    Returns:
        A dict with one entry per ``TaskRequestType`` member. ``random`` holds
        the sum of all other entries.
    """
    request_type = protocol_schema.TaskRequestType

    # Start every type at 0 so that callers can look up any member safely.
    task_count_by_type: dict[protocol_schema.TaskRequestType, int] = {t: 0 for t in request_type}

    task_count_by_type[request_type.initial_prompt] = max(0, self.cfg.max_active_trees - num_active_trees)

    # A reply task takes the opposite role of the parent being extended:
    # assistant parents need prompter replies and vice versa.
    task_count_by_type[request_type.prompter_reply] = sum(
        1 for parent in extensible_parents if parent.parent_role == "assistant"
    )
    task_count_by_type[request_type.assistant_reply] = sum(
        1 for parent in extensible_parents if parent.parent_role == "prompter"
    )

    task_count_by_type[request_type.label_initial_prompt] = len(prompts_need_review)
    task_count_by_type[request_type.label_assistant_reply] = sum(
        1 for message in replies_need_review if message.role == "assistant"
    )
    # Fix: this count used to be stored under `prompter_reply`, which both
    # clobbered the reply-task count above and left `label_prompter_reply` at 0.
    task_count_by_type[request_type.label_prompter_reply] = sum(
        1 for message in replies_need_review if message.role == "prompter"
    )

    # Prompter replies are only ranked when the configuration asks for it.
    if self.cfg.rank_prompter_replies:
        task_count_by_type[request_type.rank_prompter_replies] = sum(
            1 for ranking in incomplete_rankings if ranking.role == "prompter"
        )

    task_count_by_type[request_type.rank_assistant_replies] = sum(
        1 for ranking in incomplete_rankings if ranking.role == "assistant"
    )

    # `random` is still 0 here, so summing all values yields the total of the
    # concrete task types.
    task_count_by_type[request_type.random] = sum(task_count_by_type.values())

    return task_count_by_type
|
||||
|
||||
def determine_task_availability(self) -> dict[protocol_schema.TaskRequestType, int]:
    """Query the current state and return the available task count per request type.

    Runs the five state queries and hands their results to
    ``_determine_task_availability_internal``, whose result is returned as is.
    """
    # Keyword arguments are evaluated left to right, so the queries execute in
    # the order in which they are listed.
    return self._determine_task_availability_internal(
        num_active_trees=self.query_num_active_trees(),
        extensible_parents=self.query_extendible_parents(),
        prompts_need_review=self.query_prompts_need_review(),
        replies_need_review=self.query_replies_need_review(),
        incomplete_rankings=self.query_incomplete_rankings(),
    )
|
||||
|
||||
def next_task(
|
||||
self, desired_task_type: protocol_schema.TaskRequestType
|
||||
self, desired_task_type: protocol_schema.TaskRequestType = protocol_schema.TaskRequestType.random
|
||||
) -> Tuple[protocol_schema.Task, Optional[UUID], Optional[UUID]]:
|
||||
|
||||
logger.debug("TreeManager.next_task()")
|
||||
@@ -168,148 +194,195 @@ class TreeManager:
|
||||
num_active_trees = self.query_num_active_trees()
|
||||
prompts_need_review = self.query_prompts_need_review()
|
||||
replies_need_review = self.query_replies_need_review()
|
||||
extensible_parents = self.query_extendible_parents()
|
||||
|
||||
incomplete_rankings = self.query_incomplete_rankings()
|
||||
if not self.cfg.rank_prompter_replies:
|
||||
incomplete_rankings = list(filter(lambda r: r.role == "assistant", incomplete_rankings))
|
||||
|
||||
active_tree_sizes = self.query_extendible_trees()
|
||||
|
||||
# determine type of task to generate
|
||||
num_missing_replies = sum(x.remaining_messages for x in active_tree_sizes)
|
||||
|
||||
task_type, task_role = self._task_selection(
|
||||
desired_task_type,
|
||||
num_ranking_tasks=len(incomplete_rankings),
|
||||
num_replies_need_review=len(replies_need_review),
|
||||
num_prompts_need_review=len(prompts_need_review),
|
||||
num_missing_prompts=max(0, self.cfg.max_active_trees - num_active_trees),
|
||||
num_missing_replies=num_missing_replies,
|
||||
)
|
||||
|
||||
if task_type == TaskType.NONE:
|
||||
raise OasstError(
|
||||
f"No tasks of type '{desired_task_type.value}' are currently available.",
|
||||
OasstErrorCode.TASK_REQUESTED_TYPE_NOT_AVAILABLE,
|
||||
HTTPStatus.SERVICE_UNAVAILABLE,
|
||||
task_role = TaskRole.ANY
|
||||
if desired_task_type == protocol_schema.TaskRequestType.random:
|
||||
task_type = self._random_task_selection(
|
||||
num_ranking_tasks=len(incomplete_rankings),
|
||||
num_replies_need_review=len(replies_need_review),
|
||||
num_prompts_need_review=len(prompts_need_review),
|
||||
num_missing_prompts=max(0, self.cfg.max_active_trees - num_active_trees),
|
||||
num_missing_replies=num_missing_replies,
|
||||
)
|
||||
|
||||
if task_role != TaskRole.ANY:
|
||||
# Todo: Allow role specific message selection...
|
||||
raise OasstError(
|
||||
f"No tasks of type '{desired_task_type.value}' are currently available.",
|
||||
OasstErrorCode.TASK_REQUESTED_TYPE_NOT_AVAILABLE,
|
||||
HTTPStatus.SERVICE_UNAVAILABLE,
|
||||
if task_type == TaskType.NONE:
|
||||
raise OasstError(
|
||||
f"No tasks of type '{protocol_schema.TaskRequestType.random.value}' are currently available.",
|
||||
OasstErrorCode.TASK_REQUESTED_TYPE_NOT_AVAILABLE,
|
||||
HTTPStatus.SERVICE_UNAVAILABLE,
|
||||
)
|
||||
else:
|
||||
task_count_by_type = self._determine_task_availability_internal(
|
||||
num_active_trees=num_active_trees,
|
||||
extensible_parents=extensible_parents,
|
||||
prompts_need_review=prompts_need_review,
|
||||
replies_need_review=replies_need_review,
|
||||
incomplete_rankings=incomplete_rankings,
|
||||
)
|
||||
|
||||
available_count = task_count_by_type.get(desired_task_type)
|
||||
if not available_count:
|
||||
raise OasstError(
|
||||
f"No tasks of type '{desired_task_type.value}' are currently available.",
|
||||
OasstErrorCode.TASK_REQUESTED_TYPE_NOT_AVAILABLE,
|
||||
HTTPStatus.SERVICE_UNAVAILABLE,
|
||||
)
|
||||
|
||||
task_type_role_map = {
|
||||
protocol_schema.TaskRequestType.initial_prompt: (TaskType.PROMPT, TaskRole.ANY),
|
||||
protocol_schema.TaskRequestType.prompter_reply: (TaskType.REPLY, TaskRole.PROMPTER),
|
||||
protocol_schema.TaskRequestType.assistant_reply: (TaskType.REPLY, TaskRole.ASSISTANT),
|
||||
protocol_schema.TaskRequestType.rank_prompter_replies: (TaskType.RANKING, TaskRole.PROMPTER),
|
||||
protocol_schema.TaskRequestType.rank_assistant_replies: (TaskType.RANKING, TaskRole.ASSISTANT),
|
||||
protocol_schema.TaskRequestType.label_initial_prompt: (TaskType.LABEL_PROMPT, TaskRole.ANY),
|
||||
protocol_schema.TaskRequestType.label_assistant_reply: (TaskType.LABEL_REPLY, TaskRole.ASSISTANT),
|
||||
protocol_schema.TaskRequestType.label_prompter_reply: (TaskType.LABEL_REPLY, TaskRole.PROMPTER),
|
||||
}
|
||||
|
||||
task_type, task_role = task_type_role_map[desired_task_type]
|
||||
|
||||
message_tree_id = None
|
||||
parent_message_id = None
|
||||
|
||||
logger.debug(f"selected {task_type=}")
|
||||
match task_type:
|
||||
case TaskType.RANKING:
|
||||
assert len(incomplete_rankings) > 0
|
||||
ranking_parent_id = random.choice(incomplete_rankings).parent_id
|
||||
if task_role == TaskRole.PROMPTER:
|
||||
incomplete_rankings = list(filter(lambda m: m.role == "prompter", incomplete_rankings))
|
||||
elif task_role == TaskRole.ASSISTANT:
|
||||
incomplete_rankings = list(filter(lambda m: m.role == "assistant", incomplete_rankings))
|
||||
|
||||
messages = self.pr.fetch_message_conversation(ranking_parent_id)
|
||||
assert len(messages) > 1 and messages[-1].id == ranking_parent_id
|
||||
ranking_parent = messages[-1]
|
||||
assert not ranking_parent.deleted and ranking_parent.review_result
|
||||
conversation = prepare_conversation(messages)
|
||||
replies = self.pr.fetch_message_children(ranking_parent_id, reviewed=True, exclude_deleted=True)
|
||||
if len(incomplete_rankings) > 0:
|
||||
ranking_parent_id = random.choice(incomplete_rankings).parent_id
|
||||
|
||||
assert len(replies) > 1
|
||||
random.shuffle(replies) # hand out replies in random order
|
||||
reply_messages = prepare_conversation_message_list(replies)
|
||||
replies = [p.text for p in replies]
|
||||
messages = self.pr.fetch_message_conversation(ranking_parent_id)
|
||||
assert len(messages) > 1 and messages[-1].id == ranking_parent_id
|
||||
ranking_parent = messages[-1]
|
||||
assert not ranking_parent.deleted and ranking_parent.review_result
|
||||
conversation = prepare_conversation(messages)
|
||||
replies = self.pr.fetch_message_children(ranking_parent_id, reviewed=True, exclude_deleted=True)
|
||||
|
||||
if messages[-1].role == "assistant":
|
||||
logger.info("Generating a RankPrompterRepliesTask.")
|
||||
task = protocol_schema.RankPrompterRepliesTask(
|
||||
conversation=conversation,
|
||||
replies=replies,
|
||||
reply_messages=reply_messages,
|
||||
ranking_parent_id=ranking_parent.id,
|
||||
message_tree_id=ranking_parent.message_tree_id,
|
||||
)
|
||||
else:
|
||||
logger.info("Generating a RankAssistantRepliesTask.")
|
||||
task = protocol_schema.RankAssistantRepliesTask(
|
||||
conversation=conversation,
|
||||
replies=replies,
|
||||
reply_messages=reply_messages,
|
||||
ranking_parent_id=ranking_parent.id,
|
||||
message_tree_id=ranking_parent.message_tree_id,
|
||||
)
|
||||
assert len(replies) > 1
|
||||
random.shuffle(replies) # hand out replies in random order
|
||||
reply_messages = prepare_conversation_message_list(replies)
|
||||
replies = [p.text for p in replies]
|
||||
|
||||
parent_message_id = ranking_parent_id
|
||||
message_tree_id = messages[-1].message_tree_id
|
||||
if messages[-1].role == "assistant":
|
||||
logger.info("Generating a RankPrompterRepliesTask.")
|
||||
task = protocol_schema.RankPrompterRepliesTask(
|
||||
conversation=conversation,
|
||||
replies=replies,
|
||||
reply_messages=reply_messages,
|
||||
ranking_parent_id=ranking_parent.id,
|
||||
message_tree_id=ranking_parent.message_tree_id,
|
||||
)
|
||||
else:
|
||||
logger.info("Generating a RankAssistantRepliesTask.")
|
||||
task = protocol_schema.RankAssistantRepliesTask(
|
||||
conversation=conversation,
|
||||
replies=replies,
|
||||
reply_messages=reply_messages,
|
||||
ranking_parent_id=ranking_parent.id,
|
||||
message_tree_id=ranking_parent.message_tree_id,
|
||||
)
|
||||
|
||||
parent_message_id = ranking_parent_id
|
||||
message_tree_id = messages[-1].message_tree_id
|
||||
|
||||
case TaskType.LABEL_REPLY:
|
||||
assert len(replies_need_review) > 0
|
||||
random_reply_message_id = random.choice(replies_need_review)
|
||||
messages = self.pr.fetch_message_conversation(random_reply_message_id)
|
||||
if task_role == TaskRole.PROMPTER:
|
||||
replies_need_review = list(filter(lambda m: m.role == "prompter", replies_need_review))
|
||||
elif task_role == TaskRole.ASSISTANT:
|
||||
replies_need_review = list(filter(lambda m: m.role == "assistant", replies_need_review))
|
||||
|
||||
conversation = prepare_conversation(messages[:-1])
|
||||
message = messages[-1]
|
||||
if len(replies_need_review) > 0:
|
||||
random_reply_message = random.choice(replies_need_review)
|
||||
messages = self.pr.fetch_message_conversation(random_reply_message)
|
||||
|
||||
self.cfg.p_full_labeling_review_reply_prompter: float = 0.1
|
||||
conversation = prepare_conversation(messages[:-1])
|
||||
message = messages[-1]
|
||||
|
||||
label_mode = protocol_schema.LabelTaskMode.full
|
||||
valid_labels = self._all_text_labels
|
||||
self.cfg.p_full_labeling_review_reply_prompter: float = 0.1
|
||||
|
||||
if message.role == "assistant":
|
||||
if random.random() > self.cfg.p_full_labeling_review_reply_assistant:
|
||||
valid_labels = list(map(lambda x: x.value, self.cfg.mandatory_labels_assistant_reply))
|
||||
label_mode = protocol_schema.LabelTaskMode.simple
|
||||
logger.info(f"Generating a LabelAssistantReplyTask. ({label_mode=:s})")
|
||||
task = protocol_schema.LabelAssistantReplyTask(
|
||||
message_id=message.id,
|
||||
conversation=conversation,
|
||||
reply=message.text,
|
||||
valid_labels=valid_labels,
|
||||
mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_assistant_reply)),
|
||||
mode=label_mode,
|
||||
)
|
||||
else:
|
||||
if random.random() > self.cfg.p_full_labeling_review_reply_prompter:
|
||||
valid_labels = list(map(lambda x: x.value, self.cfg.mandatory_labels_prompter_reply))
|
||||
label_mode = protocol_schema.LabelTaskMode.simple
|
||||
logger.info(f"Generating a LabelPrompterReplyTask. ({label_mode=:s})")
|
||||
task = protocol_schema.LabelPrompterReplyTask(
|
||||
message_id=message.id,
|
||||
conversation=conversation,
|
||||
reply=message.text,
|
||||
valid_labels=valid_labels,
|
||||
mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_prompter_reply)),
|
||||
mode=label_mode,
|
||||
)
|
||||
label_mode = protocol_schema.LabelTaskMode.full
|
||||
valid_labels = self._all_text_labels
|
||||
|
||||
parent_message_id = message.id
|
||||
message_tree_id = message.message_tree_id
|
||||
if message.role == "assistant":
|
||||
if (
|
||||
desired_task_type == protocol_schema.TaskRequestType.random
|
||||
and random.random() > self.cfg.p_full_labeling_review_reply_assistant
|
||||
):
|
||||
valid_labels = list(map(lambda x: x.value, self.cfg.mandatory_labels_assistant_reply))
|
||||
label_mode = protocol_schema.LabelTaskMode.simple
|
||||
logger.info(f"Generating a LabelAssistantReplyTask. ({label_mode=:s})")
|
||||
task = protocol_schema.LabelAssistantReplyTask(
|
||||
message_id=message.id,
|
||||
conversation=conversation,
|
||||
reply=message.text,
|
||||
valid_labels=valid_labels,
|
||||
mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_assistant_reply)),
|
||||
mode=label_mode,
|
||||
)
|
||||
else:
|
||||
if (
|
||||
desired_task_type == protocol_schema.TaskRequestType.random
|
||||
and random.random() > self.cfg.p_full_labeling_review_reply_prompter
|
||||
):
|
||||
valid_labels = list(map(lambda x: x.value, self.cfg.mandatory_labels_prompter_reply))
|
||||
label_mode = protocol_schema.LabelTaskMode.simple
|
||||
logger.info(f"Generating a LabelPrompterReplyTask. ({label_mode=:s})")
|
||||
task = protocol_schema.LabelPrompterReplyTask(
|
||||
message_id=message.id,
|
||||
conversation=conversation,
|
||||
reply=message.text,
|
||||
valid_labels=valid_labels,
|
||||
mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_prompter_reply)),
|
||||
mode=label_mode,
|
||||
)
|
||||
|
||||
parent_message_id = message.id
|
||||
message_tree_id = message.message_tree_id
|
||||
|
||||
case TaskType.REPLY:
|
||||
# select a tree with missing replies
|
||||
extensible_parents = self.query_extendible_parents()
|
||||
assert len(extensible_parents) > 0
|
||||
if task_role == TaskRole.PROMPTER:
|
||||
extensible_parents = list(filter(lambda x: x.parent_role == "assistant", extensible_parents))
|
||||
elif task_role == TaskRole.ASSISTANT:
|
||||
extensible_parents = list(filter(lambda x: x.parent_role == "prompter", extensible_parents))
|
||||
|
||||
# fetch random conversation to extend
|
||||
random_parent = random.choice(extensible_parents)
|
||||
logger.debug(f"selected {random_parent=}")
|
||||
messages = self.pr.fetch_message_conversation(random_parent.parent_id)
|
||||
assert all(m.review_result for m in messages) # ensure all messages have positive review
|
||||
conversation = prepare_conversation(messages)
|
||||
if len(extensible_parents) > 0:
|
||||
random_parent = random.choice(extensible_parents)
|
||||
|
||||
# generate reply task depending on last message
|
||||
if messages[-1].role == "assistant":
|
||||
logger.info("Generating a PrompterReplyTask.")
|
||||
task = protocol_schema.PrompterReplyTask(conversation=conversation)
|
||||
else:
|
||||
logger.info("Generating a AssistantReplyTask.")
|
||||
task = protocol_schema.AssistantReplyTask(conversation=conversation)
|
||||
# fetch random conversation to extend
|
||||
logger.debug(f"selected {random_parent=}")
|
||||
messages = self.pr.fetch_message_conversation(random_parent.parent_id)
|
||||
assert all(m.review_result for m in messages) # ensure all messages have positive review
|
||||
conversation = prepare_conversation(messages)
|
||||
|
||||
parent_message_id = messages[-1].id
|
||||
message_tree_id = messages[-1].message_tree_id
|
||||
# generate reply task depending on last message
|
||||
if messages[-1].role == "assistant":
|
||||
logger.info("Generating a PrompterReplyTask.")
|
||||
task = protocol_schema.PrompterReplyTask(conversation=conversation)
|
||||
else:
|
||||
logger.info("Generating a AssistantReplyTask.")
|
||||
task = protocol_schema.AssistantReplyTask(conversation=conversation)
|
||||
|
||||
parent_message_id = messages[-1].id
|
||||
message_tree_id = messages[-1].message_tree_id
|
||||
|
||||
case TaskType.LABEL_PROMPT:
|
||||
assert len(prompts_need_review) > 0
|
||||
message = self.pr.fetch_message(random.choice(prompts_need_review))
|
||||
message = random.choice(prompts_need_review)
|
||||
|
||||
label_mode = protocol_schema.LabelTaskMode.full
|
||||
valid_labels = self._all_text_labels
|
||||
@@ -337,6 +410,13 @@ class TreeManager:
|
||||
case _:
|
||||
task = None
|
||||
|
||||
if task is None:
|
||||
raise OasstError(
|
||||
f"No task of type '{desired_task_type.value}' is currently available.",
|
||||
OasstErrorCode.TASK_REQUESTED_TYPE_NOT_AVAILABLE,
|
||||
HTTPStatus.SERVICE_UNAVAILABLE,
|
||||
)
|
||||
|
||||
logger.info(f"Generated {task=}.")
|
||||
|
||||
return task, message_tree_id, parent_message_id
|
||||
@@ -515,7 +595,8 @@ class TreeManager:
|
||||
logger.debug(f"False {mts.active=}, {mts.state=}")
|
||||
return False
|
||||
|
||||
rankings_by_message = self.query_tree_ranking_results(message_tree_id)
|
||||
ranking_role_filter = None if self.cfg.rank_prompter_replies else "assistant"
|
||||
rankings_by_message = self.query_tree_ranking_results(message_tree_id, role_filter=ranking_role_filter)
|
||||
for parent_msg_id, ranking in rankings_by_message.items():
|
||||
if len(ranking) < self.cfg.num_required_rankings:
|
||||
logger.debug(f"False {parent_msg_id=} {len(ranking)=}")
|
||||
@@ -528,68 +609,59 @@ class TreeManager:
|
||||
# calculate acceptance based on spam label
|
||||
return np.mean([1 - l.labels[protocol_schema.TextLabel.spam] for l in labels])
|
||||
|
||||
_sql_find_prompts_need_review = """
|
||||
-- find initial prompts that need more reviews
|
||||
SELECT m.id
|
||||
FROM message_tree_state mts
|
||||
LEFT JOIN message m ON mts.message_tree_id = m.id
|
||||
WHERE mts.active
|
||||
AND mts.state = :state
|
||||
AND NOT m.review_result
|
||||
AND NOT m.deleted
|
||||
AND m.review_count < :num_reviews_initial_prompt
|
||||
AND m.parent_id is NULL
|
||||
AND (:excluded_user_id IS NULL OR m.user_id != :excluded_user_id)
|
||||
"""
|
||||
|
||||
def query_prompts_need_review(self) -> list[UUID]:
|
||||
def query_prompts_need_review(self) -> list[Message]:
|
||||
"""
|
||||
Select id of initial prompts with less then required rankings in active message tree
|
||||
Select initial prompt messages with less then required rankings in active message tree
|
||||
(active == True in message_tree_state)
|
||||
"""
|
||||
|
||||
r = self.db.execute(
|
||||
text(self._sql_find_prompts_need_review),
|
||||
{
|
||||
"state": message_tree_state.State.INITIAL_PROMPT_REVIEW,
|
||||
"num_reviews_initial_prompt": self.cfg.num_reviews_initial_prompt,
|
||||
"excluded_user_id": None if settings.DEBUG_ALLOW_SELF_LABELING else self.pr.user_id,
|
||||
},
|
||||
qry = (
|
||||
self.db.query(Message)
|
||||
.select_from(MessageTreeState)
|
||||
.outerjoin(Message, MessageTreeState.message_tree_id == Message.message_tree_id)
|
||||
.filter(
|
||||
MessageTreeState.active,
|
||||
MessageTreeState.state == message_tree_state.State.INITIAL_PROMPT_REVIEW,
|
||||
not_(Message.review_result),
|
||||
not_(Message.deleted),
|
||||
Message.review_count < self.cfg.num_reviews_initial_prompt,
|
||||
Message.parent_id.is_(None),
|
||||
)
|
||||
)
|
||||
return [x["id"] for x in r.all()]
|
||||
|
||||
_sql_find_replies_need_review = """
|
||||
SELECT m.id
|
||||
FROM message_tree_state mts
|
||||
LEFT JOIN message m ON mts.message_tree_id = m.message_tree_id
|
||||
WHERE mts.active
|
||||
AND mts.state = :breeding_state
|
||||
AND NOT m.review_result
|
||||
AND NOT m.deleted
|
||||
AND m.review_count < :num_required_reviews
|
||||
AND m.parent_id is NOT NULL
|
||||
AND (:excluded_user_id IS NULL OR m.user_id != :excluded_user_id)
|
||||
"""
|
||||
if not settings.DEBUG_ALLOW_SELF_LABELING:
|
||||
qry = qry.filter(Message.user_id != self.pr.user_id)
|
||||
|
||||
def query_replies_need_review(self) -> list[UUID]:
|
||||
return qry.all()
|
||||
|
||||
def query_replies_need_review(self) -> list[Message]:
|
||||
"""
|
||||
Select ids of child messages (parent_id IS NOT NULL) with less then required rankings
|
||||
Select child messages (parent_id IS NOT NULL) with less then required rankings
|
||||
in active message tree (active == True in message_tree_state)
|
||||
"""
|
||||
|
||||
r = self.db.execute(
|
||||
text(self._sql_find_replies_need_review),
|
||||
{
|
||||
"breeding_state": message_tree_state.State.GROWING,
|
||||
"num_required_reviews": self.cfg.num_reviews_reply,
|
||||
"excluded_user_id": None if settings.DEBUG_ALLOW_SELF_LABELING else self.pr.user_id,
|
||||
},
|
||||
qry = (
|
||||
self.db.query(Message)
|
||||
.select_from(MessageTreeState)
|
||||
.outerjoin(Message, MessageTreeState.message_tree_id == Message.message_tree_id)
|
||||
.filter(
|
||||
MessageTreeState.active,
|
||||
MessageTreeState.state == message_tree_state.State.GROWING,
|
||||
not_(Message.review_result),
|
||||
not_(Message.deleted),
|
||||
Message.review_count < self.cfg.num_reviews_reply,
|
||||
Message.parent_id.is_not(None),
|
||||
)
|
||||
)
|
||||
return [x["id"] for x in r.all()]
|
||||
|
||||
if not settings.DEBUG_ALLOW_SELF_LABELING:
|
||||
qry = qry.filter(Message.user_id != self.pr.user_id)
|
||||
|
||||
return qry.all()
|
||||
|
||||
_sql_find_incomplete_rankings = """
|
||||
-- find incomplete rankings
|
||||
SELECT m.parent_id, COUNT(m.id) children_count, MIN(m.ranking_count) child_min_ranking_count,
|
||||
SELECT m.parent_id, m.role, COUNT(m.id) children_count, MIN(m.ranking_count) child_min_ranking_count,
|
||||
COUNT(m.id) FILTER (WHERE m.ranking_count >= :num_required_rankings) as completed_rankings
|
||||
FROM message_tree_state mts
|
||||
LEFT JOIN message m ON mts.message_tree_id = m.message_tree_id
|
||||
@@ -598,7 +670,7 @@ WHERE mts.active -- only consider active trees
|
||||
AND m.review_result -- must be reviewed
|
||||
AND NOT m.deleted -- not deleted
|
||||
AND m.parent_id IS NOT NULL -- ignore initial prompts
|
||||
GROUP BY m.parent_id
|
||||
GROUP BY m.parent_id, m.role
|
||||
HAVING COUNT(m.id) > 1 and MIN(m.ranking_count) < :num_required_rankings
|
||||
"""
|
||||
|
||||
@@ -616,10 +688,10 @@ HAVING COUNT(m.id) > 1 and MIN(m.ranking_count) < :num_required_rankings
|
||||
|
||||
_sql_find_extendible_parents = """
|
||||
-- find all extendible parent nodes
|
||||
SELECT m.id as parent_id, m.depth, m.message_tree_id, COUNT(c.id) active_children_count
|
||||
SELECT m.id as parent_id, m.role as parent_role, m.depth, m.message_tree_id, COUNT(c.id) active_children_count
|
||||
FROM message_tree_state mts
|
||||
LEFT JOIN message m ON mts.message_tree_id = m.message_tree_id -- all elements of message tree
|
||||
LEFT JOIN message c ON m.id = c.Id -- child nodes
|
||||
LEFT JOIN message c ON m.id = c.parent_id -- child nodes
|
||||
WHERE mts.active -- only consider active trees
|
||||
AND mts.state = :growing_state -- message tree must be growing
|
||||
AND NOT m.deleted -- ignore deleted messages as parents
|
||||
@@ -627,7 +699,7 @@ WHERE mts.active -- only consider active trees
|
||||
AND m.review_result -- parent node must have positive review
|
||||
AND NOT c.deleted -- don't count deleted children
|
||||
AND (c.review_result OR c.review_count < :num_reviews_reply) -- don't count children with negative review but count elements under review
|
||||
GROUP BY m.id, m.depth, m.message_tree_id, mts.max_children_count
|
||||
GROUP BY m.id, m.role, m.depth, m.message_tree_id, mts.max_children_count
|
||||
HAVING COUNT(c.id) < mts.max_children_count -- below maximum number of children
|
||||
"""
|
||||
|
||||
@@ -636,10 +708,7 @@ HAVING COUNT(c.id) < mts.max_children_count -- below maximum number of children
|
||||
|
||||
r = self.db.execute(
|
||||
text(self._sql_find_extendible_parents),
|
||||
{
|
||||
"growing_state": message_tree_state.State.GROWING,
|
||||
"num_reviews_reply": self.cfg.num_reviews_reply,
|
||||
},
|
||||
{"growing_state": message_tree_state.State.GROWING, "num_reviews_reply": self.cfg.num_reviews_reply},
|
||||
)
|
||||
return [ExtendibleParentRow.from_orm(x) for x in r.all()]
|
||||
|
||||
@@ -671,21 +740,27 @@ HAVING COUNT(m.id) < mts.goal_tree_size
|
||||
)
|
||||
return [ActiveTreeSizeRow.from_orm(x) for x in r.all()]
|
||||
|
||||
_sql_get_tree_size = """
|
||||
SELECT mts.message_tree_id, mts.goal_tree_size, COUNT(m.id) AS tree_size
|
||||
FROM message_tree_state mts
|
||||
LEFT JOIN message m ON mts.message_tree_id = m.message_tree_id
|
||||
WHERE mts.active
|
||||
AND NOT m.deleted
|
||||
AND m.review_result
|
||||
AND mts.message_tree_id = :message_tree_id
|
||||
GROUP BY mts.message_tree_id, mts.goal_tree_size
|
||||
"""
|
||||
|
||||
def query_tree_size(self, message_tree_id: UUID) -> ActiveTreeSizeRow:
|
||||
"""Returns the number of reviewed, non-deleted messages in the message tree."""
|
||||
r = self.db.execute(text(self._sql_get_tree_size), {"message_tree_id": message_tree_id})
|
||||
return ActiveTreeSizeRow.from_orm(r.one())
|
||||
|
||||
qry = (
|
||||
self.db.query(
|
||||
MessageTreeState.message_tree_id.label("message_tree_id"),
|
||||
MessageTreeState.goal_tree_size.label("goal_tree_size"),
|
||||
func.count(Message.id).label("tree_size"),
|
||||
)
|
||||
.select_from(MessageTreeState)
|
||||
.outerjoin(Message, MessageTreeState.message_tree_id == Message.message_tree_id)
|
||||
.filter(
|
||||
MessageTreeState.active,
|
||||
not_(Message.deleted),
|
||||
Message.review_result,
|
||||
MessageTreeState.message_tree_id == message_tree_id,
|
||||
)
|
||||
.group_by(MessageTreeState.message_tree_id, MessageTreeState.goal_tree_size)
|
||||
)
|
||||
|
||||
return ActiveTreeSizeRow.from_orm(qry.one())
|
||||
|
||||
def query_misssing_tree_states(self) -> list[UUID]:
|
||||
"""Find all initial prompt messages that have no associated message tree state"""
|
||||
@@ -702,7 +777,7 @@ GROUP BY mts.message_tree_id, mts.goal_tree_size
|
||||
return [m.id for m in qry_missing_tree_states.all()]
|
||||
|
||||
_sql_find_tree_ranking_results = """
|
||||
-- get all ranking results of completed tasks for all parents with >=2 children
|
||||
-- get all ranking results of completed tasks for all parents with >= 2 children
|
||||
SELECT p.parent_id, mr.* FROM
|
||||
(
|
||||
-- find parents with > 1 children
|
||||
@@ -712,7 +787,8 @@ SELECT p.parent_id, mr.* FROM
|
||||
WHERE m.review_result -- must be reviewed
|
||||
AND NOT m.deleted -- not deleted
|
||||
AND m.parent_id IS NOT NULL -- ignore initial prompts
|
||||
AND mts.message_tree_id = :message_tree_id
|
||||
AND (:role IS NULL OR m.role = :role) -- children with matching role
|
||||
AND mts.message_tree_id = :message_tree_id
|
||||
GROUP BY m.parent_id, m.message_tree_id
|
||||
HAVING COUNT(m.id) > 1
|
||||
) as p
|
||||
@@ -720,11 +796,21 @@ LEFT JOIN task t ON p.parent_id = t.parent_message_id AND t.done AND (t.payload_
|
||||
LEFT JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'RankingReactionPayload'
|
||||
"""
|
||||
|
||||
def query_tree_ranking_results(self, message_tree_id: UUID) -> dict[UUID, list[MessageReaction]]:
|
||||
def query_tree_ranking_results(
|
||||
self,
|
||||
message_tree_id: UUID,
|
||||
role_filter: str = "assistant",
|
||||
) -> dict[UUID, list[MessageReaction]]:
|
||||
"""Finds all completed ranking results for a message_tree"""
|
||||
|
||||
assert role_filter in (None, "assistant", "prompter")
|
||||
|
||||
r = self.db.execute(
|
||||
text(self._sql_find_tree_ranking_results),
|
||||
{"message_tree_id": message_tree_id},
|
||||
{
|
||||
"message_tree_id": message_tree_id,
|
||||
"role": role_filter,
|
||||
},
|
||||
)
|
||||
|
||||
rankings_by_message = {}
|
||||
@@ -803,12 +889,12 @@ WHERE t.done = TRUE
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from oasst_backend.api.deps import get_dummy_api_client
|
||||
from oasst_backend.api.deps import api_auth
|
||||
from oasst_backend.database import engine
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
|
||||
with Session(engine) as db:
|
||||
api_client = get_dummy_api_client(db)
|
||||
api_client = api_auth(settings.OFFICIAL_WEB_API_KEY, db=db)
|
||||
dummy_user = protocol_schema.User(id="__dummy_user__", display_name="Dummy User", auth_method="local")
|
||||
|
||||
pr = PromptRepository(db=db, api_client=api_client, client_user=dummy_user)
|
||||
@@ -817,15 +903,16 @@ if __name__ == "__main__":
|
||||
tm = TreeManager(db, pr, cfg)
|
||||
tm.ensure_tree_states()
|
||||
|
||||
print("query_num_active_trees", tm.query_num_active_trees())
|
||||
print("query_incomplete_rankings", tm.query_incomplete_rankings())
|
||||
print("query_incomplete_reply_reviews", tm.query_replies_need_review())
|
||||
print("query_incomplete_initial_prompt_reviews", tm.query_prompts_need_review())
|
||||
print("query_extendible_trees", tm.query_extendible_trees())
|
||||
print("query_extendible_parents", tm.query_extendible_parents())
|
||||
# print("query_num_active_trees", tm.query_num_active_trees())
|
||||
# print("query_incomplete_rankings", tm.query_incomplete_rankings())
|
||||
# print("query_replies_need_review", tm.query_replies_need_review())
|
||||
# print("query_incomplete_initial_prompt_reviews", tm.query_prompts_need_review())
|
||||
# print("query_extendible_trees", tm.query_extendible_trees())
|
||||
# print("query_extendible_parents", tm.query_extendible_parents())
|
||||
# print("query_tree_size", tm.query_tree_size(message_tree_id=UUID("bdf434cf-4df5-4b74-949c-a5a157bc3292")))
|
||||
|
||||
print("next_task:", tm.next_task())
|
||||
|
||||
print(
|
||||
".query_tree_ranking_results", tm.query_tree_ranking_results(UUID("2ac20d38-6650-43aa-8bb3-f61080c0d921"))
|
||||
)
|
||||
# print(
|
||||
# "query_tree_ranking_results", tm.query_tree_ranking_results(UUID("6036f58f-41b5-48c4-bdd9-b16f34ab1312"))
|
||||
# )
|
||||
|
||||
@@ -32,6 +32,7 @@ class OasstErrorCode(IntEnum):
|
||||
TASK_INTERACTION_REQUEST_FAILED = 1004
|
||||
TASK_GENERATION_FAILED = 1005
|
||||
TASK_REQUESTED_TYPE_NOT_AVAILABLE = 1006
|
||||
TASK_AVAILABILITY_QUERY_FAILED = 1007
|
||||
|
||||
# 2000-3000: prompt_repository
|
||||
INVALID_FRONTEND_MESSAGE_ID = 2000
|
||||
|
||||
Reference in New Issue
Block a user