From 14fa08e2e7ee85d859380fcd6ed2042e1e4d2e3d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20K=C3=B6pf?= Date: Wed, 11 Jan 2023 10:54:03 +0100 Subject: [PATCH] Message tree state machine (#555) * add query_incomplete_rankings() * Add SQL queries for TreeManager task selection * first working version of TreeManager.next_task() * remove old generate_task(), add mandatory_labels to text_labels task * Add ConversationMessage list to Ranking tasks * add more sophisticated sql queries to find extendible trees * add TreeManager.query_extendible_parents() * fix task validation, seed data insertion (reviewed) * provide user for task selection in text-frontend * enter 'growing' state * enter 'aborted_low_grade' state * enter 'ranking' state * check tree 'growing' state upon relpy insertion * exclude user from labeling their own messages (added DEBUG_ALLOW_SELF_LABELING setting) * add DEBUG_ALLOW_SELF_LABELING to docker-compose.yaml * fix ranking submission * add query_tree_ranking_results() * add ranked_message_ids to RankingReactionPayload * fix reply_messages instead of prompt_messages * incorment 'ranking_count' of ranked replies * added logic to check_condition_for_scoring_state * changes to msg_tree_state_machine * pre-commit changes * enter 'ready_for_scoring' state * re-add HF embedding call (lost during merge) * use prepare_conversation() helper for seed-data creation * Partially add user specified task selection Co-authored-by: Daniel Hug --- ansible/dev.yaml | 1 + ...40_restructure_message_tree_state_table.py | 63 ++ ...4a81_add_review_count_ranking_count_to_.py | 31 + backend/main.py | 42 +- backend/oasst_backend/api/v1/tasks.py | 224 +---- backend/oasst_backend/api/v1/utils.py | 23 +- backend/oasst_backend/config.py | 1 + backend/oasst_backend/models/db_payload.py | 11 +- backend/oasst_backend/models/message.py | 4 + .../models/message_tree_state.py | 45 +- backend/oasst_backend/prompt_repository.py | 227 ++++- backend/oasst_backend/task_repository.py | 20 +- backend/oasst_backend/tree_manager.py | 804 ++++++++++++++++++ backend/oasst_backend/utils/hugging_face.py | 6 +- docker-compose.yaml | 1 + .../exceptions/oasst_api_error.py | 14 +- oasst-shared/oasst_shared/schemas/protocol.py | 9 +- scripts/backend-development/run-local.sh | 1 + text-frontend/__main__.py | 8 +- 19 files changed, 1212 insertions(+), 323 deletions(-) create mode 100644 backend/alembic/versions/2023_01_08_2208-92a367bb9f40_restructure_message_tree_state_table.py create mode 100644 backend/alembic/versions/2023_01_09_0047-05975b274a81_add_review_count_ranking_count_to_.py create mode 100644 backend/oasst_backend/tree_manager.py diff --git a/ansible/dev.yaml b/ansible/dev.yaml index 6f93aaa6..3cf061a5 100644 --- a/ansible/dev.yaml +++ b/ansible/dev.yaml @@ -79,6 +79,7 @@ REDIS_HOST: oasst-redis DEBUG_ALLOW_ANY_API_KEY: "true" DEBUG_USE_SEED_DATA: "true" + DEBUG_ALLOW_SELF_LABELING: "true" MAX_WORKERS: "1" RATE_LIMIT: "false" DEBUG_SKIP_EMBEDDING_COMPUTATION: "true" diff --git a/backend/alembic/versions/2023_01_08_2208-92a367bb9f40_restructure_message_tree_state_table.py b/backend/alembic/versions/2023_01_08_2208-92a367bb9f40_restructure_message_tree_state_table.py new file mode 100644 index 00000000..db1f7127 --- /dev/null +++ b/backend/alembic/versions/2023_01_08_2208-92a367bb9f40_restructure_message_tree_state_table.py @@ -0,0 +1,63 @@ +"""restructure message_tree_state table + +Revision ID: 92a367bb9f40 +Revises: ba61fe17fb6e +Create Date: 2023-01-08 22:08:46.458195 + +""" +import sqlalchemy as sa +import sqlmodel +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = "92a367bb9f40" +down_revision = "aac6b2f66006" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("message_tree_state") + op.create_table( + "message_tree_state", + sa.Column("message_tree_id", postgresql.UUID(as_uuid=True), nullable=False), + sa.Column("goal_tree_size", sa.Integer(), nullable=False), + sa.Column("max_depth", sa.Integer(), nullable=False), + sa.Column("max_children_count", sa.Integer(), nullable=False), + sa.Column("state", sqlmodel.sql.sqltypes.AutoString(length=128), nullable=False), + sa.Column("active", sa.Boolean(), nullable=False), + sa.Column("accepted_messages", sa.Integer(), nullable=False), + sa.ForeignKeyConstraint( + ["message_tree_id"], + ["message.id"], + ), + sa.PrimaryKeyConstraint("message_tree_id"), + ) + op.create_index(op.f("ix_message_tree_state_active"), "message_tree_state", ["active"], unique=False) + op.create_index(op.f("ix_message_tree_state_state"), "message_tree_state", ["state"], unique=False) + + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index(op.f("ix_message_tree_state_state"), table_name="message_tree_state") + op.drop_index(op.f("ix_message_tree_state_active"), table_name="message_tree_state") + op.drop_table("message_tree_state") + op.create_table( + "message_tree_state", + sa.Column("id", postgresql.UUID(as_uuid=True), server_default=sa.text("gen_random_uuid()"), nullable=False), + sa.Column("message_tree_id", sqlmodel.sql.sqltypes.GUID(), nullable=False), + sa.Column("state", sqlmodel.sql.sqltypes.AutoString(length=128), nullable=False), + sa.Column("goal_tree_size", sa.Integer(), nullable=False), + sa.Column("current_num_non_filtered_messages", sa.Integer(), nullable=False), + sa.Column("max_depth", sa.Integer(), nullable=False), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + op.f("ix_message_tree_state_message_tree_id"), "message_tree_state", ["message_tree_id"], unique=False + ) + op.create_index("ix_message_tree_state_tree_id", "message_tree_state", ["message_tree_id"], unique=True) + # ### end Alembic commands ### diff --git a/backend/alembic/versions/2023_01_09_0047-05975b274a81_add_review_count_ranking_count_to_.py b/backend/alembic/versions/2023_01_09_0047-05975b274a81_add_review_count_ranking_count_to_.py new file mode 100644 index 00000000..f39a6e69 --- /dev/null +++ b/backend/alembic/versions/2023_01_09_0047-05975b274a81_add_review_count_ranking_count_to_.py @@ -0,0 +1,31 @@ +"""add review_count & ranking_count to message + +Revision ID: 05975b274a81 +Revises: 92a367bb9f40 +Create Date: 2023-01-09 00:47:25.496036 + +""" +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "05975b274a81" +down_revision = "92a367bb9f40" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("message", sa.Column("review_count", sa.Integer(), server_default=sa.text("0"), nullable=False)) + op.add_column("message", sa.Column("review_result", sa.Boolean(), server_default=sa.text("false"), nullable=False)) + op.add_column("message", sa.Column("ranking_count", sa.Integer(), server_default=sa.text("0"), nullable=False)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("message", "ranking_count") + op.drop_column("message", "review_result") + op.drop_column("message", "review_count") + # ### end Alembic commands ### diff --git a/backend/main.py b/backend/main.py index b84a2d9e..78c7a27a 100644 --- a/backend/main.py +++ b/backend/main.py @@ -12,9 +12,12 @@ from fastapi_limiter import FastAPILimiter from loguru import logger from oasst_backend.api.deps import get_dummy_api_client from oasst_backend.api.v1.api import api_router +from oasst_backend.api.v1.utils import prepare_conversation from oasst_backend.config import settings from oasst_backend.database import engine +from oasst_backend.models import message_tree_state from oasst_backend.prompt_repository import PromptRepository, TaskRepository, UserRepository +from oasst_backend.tree_manager import TreeManager, TreeManagerConfiguration from oasst_shared.exceptions import OasstError, OasstErrorCode from oasst_shared.schemas import protocol as protocol_schema from pydantic import BaseModel @@ -116,6 +119,7 @@ if settings.DEBUG_USE_SEED_DATA: pr = PromptRepository( db=db, api_client=api_client, client_user=dummy_user, user_repository=ur, task_repository=tr ) + tm = TreeManager(db, pr, TreeManagerConfiguration()) with open(settings.DEBUG_USE_SEED_DATA_PATH) as f: dummy_messages_raw = json.load(f) @@ -138,24 +142,19 @@ if settings.DEBUG_USE_SEED_DATA: msg.parent_message_id, fail_if_missing=True ) conversation_messages = pr.fetch_message_conversation(parent_message) - conversation = protocol_schema.Conversation( - messages=[ - protocol_schema.ConversationMessage( - text=cmsg.text, - is_assistant=cmsg.role == "assistant", - message_id=cmsg.id, - fronend_message_id=cmsg.frontend_message_id, - ) - for cmsg in conversation_messages - ] - ) + conversation = prepare_conversation(conversation_messages) task = tr.store_task( protocol_schema.AssistantReplyTask(conversation=conversation), message_tree_id=parent_message.message_tree_id, parent_message_id=parent_message.id, ) tr.bind_frontend_message_id(task.id, msg.task_message_id) - message = pr.store_text_reply(msg.text, msg.task_message_id, msg.user_message_id) + message = pr.store_text_reply( + msg.text, msg.task_message_id, msg.user_message_id, review_count=5, review_result=True + ) + if message.parent_id is None: + tm._insert_default_state(root_message_id=message.id, state=message_tree_state.State.GROWING) + db.commit() logger.info( f"Inserted: message_id: {message.id}, payload: {message.payload.payload}, parent_message_id: {message.parent_id}" @@ -168,6 +167,19 @@ if settings.DEBUG_USE_SEED_DATA: logger.exception("Seed data insertion failed") +@app.on_event("startup") +def ensure_tree_states(): + try: + logger.info("Startup: TreeManager.ensure_tree_states()") + cfg = TreeManagerConfiguration() # TODO: decide where config is stored, e.g. load form json/yaml file + with Session(engine) as db: + tm = TreeManager(db, None, configuration=cfg) + tm.ensure_tree_states() + + except Exception: + logger.exception("TreeManager.ensure_tree_states() failed.") + + app.include_router(api_router, prefix=settings.API_V1_STR) @@ -175,7 +187,7 @@ def get_openapi_schema(): return json.dumps(app.openapi()) -if __name__ == "__main__": +def main(): # Importing here so we don't import packages unnecessarily if we're # importing main as a module. import argparse @@ -198,3 +210,7 @@ if __name__ == "__main__": print(get_openapi_schema()) else: uvicorn.run(app, host=args.host, port=args.port) + + +if __name__ == "__main__": + main() diff --git a/backend/oasst_backend/api/v1/tasks.py b/backend/oasst_backend/api/v1/tasks.py index 821ba562..4a976f45 100644 --- a/backend/oasst_backend/api/v1/tasks.py +++ b/backend/oasst_backend/api/v1/tasks.py @@ -1,15 +1,12 @@ -import random -from typing import Any, Optional, Tuple +from typing import Any from uuid import UUID from fastapi import APIRouter, Depends from fastapi.security.api_key import APIKey from loguru import logger from oasst_backend.api import deps -from oasst_backend.api.v1.utils import prepare_conversation -from oasst_backend.config import settings from oasst_backend.prompt_repository import PromptRepository, TaskRepository -from oasst_backend.utils.hugging_face import HfEmbeddingModel, HfUrl, HuggingFaceAPI +from oasst_backend.tree_manager import TreeManager, TreeManagerConfiguration from oasst_shared.exceptions import OasstError, OasstErrorCode from oasst_shared.schemas import protocol as protocol_schema from sqlmodel import Session @@ -18,160 +15,6 @@ from starlette.status import HTTP_204_NO_CONTENT router = APIRouter() -def generate_task( - request: protocol_schema.TaskRequest, pr: PromptRepository -) -> Tuple[protocol_schema.Task, Optional[UUID], Optional[UUID]]: - message_tree_id = None - parent_message_id = None - - match request.type: - case protocol_schema.TaskRequestType.random: - logger.info("Frontend requested a random task.") - disabled_tasks = ( - protocol_schema.TaskRequestType.random, - protocol_schema.TaskRequestType.summarize_story, - protocol_schema.TaskRequestType.rate_summary, - ) - candidate_tasks = set(protocol_schema.TaskRequestType).difference(disabled_tasks) - request.type = random.choice(tuple(candidate_tasks)).value - return generate_task(request, pr) - - # AKo: Summary tasks are currently disabled/supported, we focus on the conversation tasks. - - # case protocol_schema.TaskRequestType.summarize_story: - # logger.info("Generating a SummarizeStoryTask.") - # task = protocol_schema.SummarizeStoryTask( - # story="This is a story. A very long story. So long, it needs to be summarized.", - # ) - # case protocol_schema.TaskRequestType.rate_summary: - # logger.info("Generating a RateSummaryTask.") - # task = protocol_schema.RateSummaryTask( - # full_text="This is a story. A very long story. So long, it needs to be summarized.", - # summary="This is a summary.", - # scale=protocol_schema.RatingScale(min=1, max=5), - # ) - - case protocol_schema.TaskRequestType.initial_prompt: - logger.info("Generating an InitialPromptTask.") - task = protocol_schema.InitialPromptTask( - hint="Ask the assistant about a current event." # this is optional - ) - case protocol_schema.TaskRequestType.prompter_reply: - logger.info("Generating a PrompterReplyTask.") - messages = pr.fetch_random_conversation("assistant") - task_messages = [ - protocol_schema.ConversationMessage( - text=msg.text, - is_assistant=(msg.role == "assistant"), - message_id=msg.id, - front_end_id=msg.frontend_message_id, - ) - for msg in messages - ] - - task = protocol_schema.PrompterReplyTask(conversation=protocol_schema.Conversation(messages=task_messages)) - message_tree_id = messages[-1].message_tree_id - parent_message_id = messages[-1].id - case protocol_schema.TaskRequestType.assistant_reply: - logger.info("Generating a AssistantReplyTask.") - messages = pr.fetch_random_conversation("prompter") - task_messages = [ - protocol_schema.ConversationMessage( - text=msg.text, - is_assistant=(msg.role == "assistant"), - message_id=msg.id, - front_end_id=msg.frontend_message_id, - ) - for msg in messages - ] - - task = protocol_schema.AssistantReplyTask(conversation=protocol_schema.Conversation(messages=task_messages)) - message_tree_id = messages[-1].message_tree_id - parent_message_id = messages[-1].id - case protocol_schema.TaskRequestType.rank_initial_prompts: - logger.info("Generating a RankInitialPromptsTask.") - - messages = pr.fetch_random_initial_prompts() - task = protocol_schema.RankInitialPromptsTask(prompts=[msg.text for msg in messages]) - case protocol_schema.TaskRequestType.rank_prompter_replies: - logger.info("Generating a RankPrompterRepliesTask.") - conversation, replies = pr.fetch_multiple_random_replies(message_role="assistant") - - task_messages = [ - protocol_schema.ConversationMessage( - text=p.text, - is_assistant=(p.role == "assistant"), - message_id=p.id, - front_end_id=p.frontend_message_id, - ) - for p in conversation - ] - replies = [p.text for p in replies] - task = protocol_schema.RankPrompterRepliesTask( - conversation=protocol_schema.Conversation( - messages=task_messages, - ), - replies=replies, - ) - - case protocol_schema.TaskRequestType.rank_assistant_replies: - logger.info("Generating a RankAssistantRepliesTask.") - conversation, replies = pr.fetch_multiple_random_replies(message_role="prompter") - - task_messages = [ - protocol_schema.ConversationMessage( - text=p.text, - is_assistant=(p.role == "assistant"), - message_id=p.id, - front_end_id=p.frontend_message_id, - ) - for p in conversation - ] - replies = [p.text for p in replies] - task = protocol_schema.RankAssistantRepliesTask( - conversation=prepare_conversation(conversation), - replies=replies, - ) - - case protocol_schema.TaskRequestType.label_initial_prompt: - logger.info("Generating a LabelInitialPromptTask.") - message = pr.fetch_random_initial_prompts(1)[0] - task = protocol_schema.LabelInitialPromptTask( - message_id=message.id, - prompt=message.text, - valid_labels=list(map(lambda x: x.value, protocol_schema.TextLabel)), - ) - - case protocol_schema.TaskRequestType.label_prompter_reply: - logger.info("Generating a LabelPrompterReplyTask.") - conversation, messages = pr.fetch_multiple_random_replies(max_size=1, message_role="assistant") - message = messages[0] - task = protocol_schema.LabelPrompterReplyTask( - message_id=message.id, - conversation=prepare_conversation(conversation), - reply=message.text, - valid_labels=list(map(lambda x: x.value, protocol_schema.TextLabel)), - ) - - case protocol_schema.TaskRequestType.label_assistant_reply: - logger.info("Generating a LabelAssistantReplyTask.") - conversation, messages = pr.fetch_multiple_random_replies(max_size=1, message_role="prompter") - message = messages[0] - task = protocol_schema.LabelAssistantReplyTask( - message_id=message.id, - conversation=prepare_conversation(conversation), - reply=message.text, - valid_labels=list(map(lambda x: x.value, protocol_schema.TextLabel)), - ) - - case _: - raise OasstError("Invalid request type", OasstErrorCode.TASK_INVALID_REQUEST_TYPE) - - logger.info(f"Generated {task=}.") - - return task, message_tree_id, parent_message_id - - @router.post( "/", response_model=protocol_schema.AnyTask, @@ -193,7 +36,9 @@ def request_task( try: pr = PromptRepository(db, api_client, client_user=request.user) - task, message_tree_id, parent_message_id = generate_task(request, pr) + tree_manager_config = TreeManagerConfiguration() + tm = TreeManager(db, pr, tree_manager_config) + task, message_tree_id, parent_message_id = tm.next_task(request.type) pr.task_repository.store_task(task, message_tree_id, parent_message_id, request.collective) except OasstError: @@ -268,63 +113,10 @@ async def tasks_interaction( try: pr = PromptRepository(db, api_client, client_user=interaction.user) + tree_manager_config = TreeManagerConfiguration() + tm = TreeManager(db, pr, tree_manager_config) + return await tm.handle_interaction(interaction) - match type(interaction): - case protocol_schema.TextReplyToMessage: - logger.info( - f"Frontend reports text reply to {interaction.message_id=} with {interaction.text=} by {interaction.user=}." - ) - - # here we store the text reply in the database - newMessage = pr.store_text_reply( - text=interaction.text, - frontend_message_id=interaction.message_id, - user_frontend_message_id=interaction.user_message_id, - ) - - if not settings.DEBUG_SKIP_EMBEDDING_COMPUTATION: - try: - hugging_face_api = HuggingFaceAPI( - f"{HfUrl.HUGGINGFACE_FEATURE_EXTRACTION.value}/{HfEmbeddingModel.MINILM.value}" - ) - embedding = await hugging_face_api.post(interaction.text) - pr.insert_message_embedding( - message_id=newMessage.id, model=HfEmbeddingModel.MINILM.value, embedding=embedding - ) - except OasstError: - logger.error( - f"Could not fetch embbeddings for text reply to {interaction.message_id=} with {interaction.text=} by {interaction.user=}." - ) - - return protocol_schema.TaskDone() - case protocol_schema.MessageRating: - logger.info( - f"Frontend reports rating of {interaction.message_id=} with {interaction.rating=} by {interaction.user=}." - ) - - # here we store the rating in the database - pr.store_rating(interaction) - - return protocol_schema.TaskDone() - case protocol_schema.MessageRanking: - logger.info( - f"Frontend reports ranking of {interaction.message_id=} with {interaction.ranking=} by {interaction.user=}." - ) - - # TODO: check if the ranking is valid - pr.store_ranking(interaction) - # here we would store the ranking in the database - return protocol_schema.TaskDone() - case protocol_schema.TextLabels: - logger.info( - f"Frontend reports labels of {interaction.message_id=} with {interaction.labels=} by {interaction.user=}." - ) - # Labels are implicitly validated when converting str -> TextLabel - # So no need for explicit validation here - pr.store_text_labels(interaction) - return protocol_schema.TaskDone() - case _: - raise OasstError("Invalid response type.", OasstErrorCode.TASK_INVALID_RESPONSE_TYPE) except OasstError: raise except Exception: diff --git a/backend/oasst_backend/api/v1/utils.py b/backend/oasst_backend/api/v1/utils.py index 4e20395f..a8b54483 100644 --- a/backend/oasst_backend/api/v1/utils.py +++ b/backend/oasst_backend/api/v1/utils.py @@ -18,19 +18,20 @@ def prepare_message_list(messages: list[Message]) -> list[protocol.Message]: return [prepare_message(m) for m in messages] -def prepare_conversation(messages: list[Message]) -> protocol.Conversation: - conv_messages = [] - for message in messages: - conv_messages.append( - protocol.ConversationMessage( - text=message.text, - is_assistant=(message.role == "assistant"), - message_id=message.id, - frontend_message_id=message.frontend_message_id, - ) +def prepare_conversation_message_list(messages: list[Message]) -> list[protocol.ConversationMessage]: + return [ + protocol.ConversationMessage( + text=message.text, + is_assistant=(message.role == "assistant"), + message_id=message.id, + frontend_message_id=message.frontend_message_id, ) + for message in messages + ] - return protocol.Conversation(messages=conv_messages) + +def prepare_conversation(messages: list[Message]) -> protocol.Conversation: + return protocol.Conversation(messages=prepare_conversation_message_list(messages)) def prepare_tree(tree: list[Message], tree_id: UUID) -> protocol.MessageTree: diff --git a/backend/oasst_backend/config.py b/backend/oasst_backend/config.py index ed394412..7b21b3d5 100644 --- a/backend/oasst_backend/config.py +++ b/backend/oasst_backend/config.py @@ -25,6 +25,7 @@ class Settings(BaseSettings): DEBUG_USE_SEED_DATA_PATH: Optional[FilePath] = ( Path(__file__).parent.parent / "test_data/generic/test_generic_data.json" ) + DEBUG_ALLOW_SELF_LABELING: bool = False # allow users to label their own messages DEBUG_SKIP_EMBEDDING_COMPUTATION: bool = False HUGGING_FACE_API_KEY: str = "" diff --git a/backend/oasst_backend/models/db_payload.py b/backend/oasst_backend/models/db_payload.py index fed60dd8..590e9f5b 100644 --- a/backend/oasst_backend/models/db_payload.py +++ b/backend/oasst_backend/models/db_payload.py @@ -1,4 +1,4 @@ -from typing import Literal +from typing import Literal, Optional from uuid import UUID from oasst_backend.models.payload_column_type import payload_type @@ -28,7 +28,7 @@ class RateSummaryPayload(TaskPayload): @payload_type class InitialPromptPayload(TaskPayload): type: Literal["initial_prompt"] = "initial_prompt" - hint: str + hint: str | None @payload_type @@ -64,12 +64,13 @@ class RatingReactionPayload(ReactionPayload): class RankingReactionPayload(ReactionPayload): type: Literal["message_ranking"] = "message_ranking" ranking: list[int] + ranked_message_ids: list[UUID] @payload_type class RankConversationRepliesPayload(TaskPayload): conversation: protocol_schema.Conversation # the conversation so far - replies: list[str] + reply_messages: list[protocol_schema.ConversationMessage] @payload_type @@ -77,7 +78,7 @@ class RankInitialPromptsPayload(TaskPayload): """A task to rank a set of initial prompts.""" type: Literal["rank_initial_prompts"] = "rank_initial_prompts" - prompts: list[str] + prompt_messages: list[protocol_schema.ConversationMessage] @payload_type @@ -102,6 +103,7 @@ class LabelInitialPromptPayload(TaskPayload): message_id: UUID prompt: str valid_labels: list[str] + mandatory_labels: Optional[list[str]] @payload_type @@ -112,6 +114,7 @@ class LabelConversationReplyPayload(TaskPayload): conversation: protocol_schema.Conversation reply: str valid_labels: list[str] + mandatory_labels: Optional[list[str]] @payload_type diff --git a/backend/oasst_backend/models/message.py b/backend/oasst_backend/models/message.py index 6d24fd13..488656a5 100644 --- a/backend/oasst_backend/models/message.py +++ b/backend/oasst_backend/models/message.py @@ -41,6 +41,10 @@ class Message(SQLModel, table=True): children_count: int = Field(sa_column=sa.Column(sa.Integer, default=0, server_default=sa.text("0"), nullable=False)) deleted: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=False, server_default=false())) + review_count: int = Field(sa_column=sa.Column(sa.Integer, default=0, server_default=sa.text("0"), nullable=False)) + review_result: bool = Field(sa_column=sa.Column(sa.Boolean, default=False, server_default=false(), nullable=False)) + ranking_count: int = Field(sa_column=sa.Column(sa.Integer, default=0, server_default=sa.text("0"), nullable=False)) + def ensure_is_message(self) -> None: if not self.payload or not isinstance(self.payload.payload, MessagePayload): raise OasstError("Invalid message", OasstErrorCode.INVALID_MESSAGE, HTTPStatus.INTERNAL_SERVER_ERROR) diff --git a/backend/oasst_backend/models/message_tree_state.py b/backend/oasst_backend/models/message_tree_state.py index 97ad34eb..ba551129 100644 --- a/backend/oasst_backend/models/message_tree_state.py +++ b/backend/oasst_backend/models/message_tree_state.py @@ -1,29 +1,28 @@ from enum import Enum -from typing import Optional -from uuid import UUID, uuid4 +from uuid import UUID import sqlalchemy as sa import sqlalchemy.dialects.postgresql as pg -from sqlmodel import Field, Index, SQLModel +from sqlmodel import Field, SQLModel -class States(str, Enum): +class State(str, Enum): """States of the Open-Assistant message tree state machine.""" INITIAL_PROMPT_REVIEW = "initial_prompt_review" """In this state the message tree consists only of a single inital prompt root node. - Initial prompt labeling tasks will determine if the tree goes into `breeding_phase` or - `aborted_low_grade`.""" + Initial prompt labeling tasks will determine if the tree goes into `growing` or + `aborted_low_grade` state.""" - BREEDING_PHASE = "breeding_phase" + GROWING = "growing" """Assistant & prompter human demonstrations are collected. Concurrently labeling tasks are handed out to check if the quality of the replies surpasses the minimum acceptable quality. When the required number of messages passing the initial labelling-quality check has been - collected the tree will enter `ranking_phase`. If too many poor-quality labelling responses + collected the tree will enter `ranking`. If too many poor-quality labelling responses are received the tree can also enter the `aborted_low_grade` state.""" - RANKING_PHASE = "ranking_phase" + RANKING = "ranking" """The tree has been successfully populated with the desired number of messages. Ranking tasks are now handed out for all nodes with more than one child.""" @@ -46,28 +45,26 @@ class States(str, Enum): VALID_STATES = ( - States.INITIAL_PROMPT_REVIEW, - States.BREEDING_PHASE, - States.RANKING_PHASE, - States.READY_FOR_SCORING, - States.READY_FOR_EXPORT, - States.ABORTED_LOW_GRADE, + State.INITIAL_PROMPT_REVIEW, + State.GROWING, + State.RANKING, + State.READY_FOR_SCORING, + State.READY_FOR_EXPORT, + State.ABORTED_LOW_GRADE, ) -TERMINAL_STATES = (States.READY_FOR_EXPORT, States.ABORTED_LOW_GRADE, States.SCORING_FAILED, States.HALTED_BY_MODERATOR) +TERMINAL_STATES = (State.READY_FOR_EXPORT, State.ABORTED_LOW_GRADE, State.SCORING_FAILED, State.HALTED_BY_MODERATOR) class MessageTreeState(SQLModel, table=True): __tablename__ = "message_tree_state" - __table_args__ = (Index("ix_message_tree_state_tree_id", "message_tree_id", unique=True),) - id: Optional[UUID] = Field( - sa_column=sa.Column( - pg.UUID(as_uuid=True), primary_key=True, default=uuid4, server_default=sa.text("gen_random_uuid()") - ), + message_tree_id: UUID = Field( + sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("message.id"), primary_key=True) ) - message_tree_id: UUID = Field(nullable=False, index=True) - state: str = Field(nullable=False, max_length=128) goal_tree_size: int = Field(nullable=False) - current_num_non_filtered_messages: int = Field(nullable=False) max_depth: int = Field(nullable=False) + max_children_count: int = Field(nullable=False) + state: str = Field(nullable=False, max_length=128, index=True) + active: bool = Field(nullable=False, index=True) + accepted_messages: int = Field(nullable=False, default=0) diff --git a/backend/oasst_backend/prompt_repository.py b/backend/oasst_backend/prompt_repository.py index c31c0061..75b9b6dc 100644 --- a/backend/oasst_backend/prompt_repository.py +++ b/backend/oasst_backend/prompt_repository.py @@ -2,13 +2,24 @@ import datetime import random from collections import defaultdict from http import HTTPStatus -from typing import List, Optional +from typing import List, Optional, Tuple from uuid import UUID, uuid4 import oasst_backend.models.db_payload as db_payload +import sqlalchemy as sa from loguru import logger from oasst_backend.journal_writer import JournalWriter -from oasst_backend.models import ApiClient, Message, MessageEmbedding, MessageReaction, TextLabels, User +from oasst_backend.models import ( + ApiClient, + Message, + MessageEmbedding, + MessageReaction, + MessageTreeState, + Task, + TextLabels, + User, + message_tree_state, +) from oasst_backend.models.payload_column_type import PayloadContainer from oasst_backend.task_repository import TaskRepository, validate_frontend_message_id from oasst_backend.user_repository import UserRepository @@ -34,6 +45,7 @@ class PromptRepository: self.user_repository = user_repository or UserRepository(db, api_client) self.user = self.user_repository.lookup_client_user(client_user, create_missing=True) self.user_id = self.user.id if self.user else None + logger.debug(f"PromptRepository(api_client_id={self.api_client.id}, {self.user_id=})") self.task_repository = task_repository or TaskRepository( db, api_client, client_user, user_repository=self.user_repository ) @@ -66,6 +78,8 @@ class PromptRepository: payload: db_payload.MessagePayload, payload_type: str = None, depth: int = 0, + review_count: int = 0, + review_result: bool = False, ) -> Message: if payload_type is None: if payload is None: @@ -85,25 +99,24 @@ class PromptRepository: payload_type=payload_type, payload=PayloadContainer(payload=payload), depth=depth, + review_count=review_count, + review_result=review_result, ) self.db.add(message) self.db.commit() self.db.refresh(message) return message - def store_text_reply( - self, - text: str, - frontend_message_id: str, - user_frontend_message_id: str, - ) -> Message: - validate_frontend_message_id(frontend_message_id) - validate_frontend_message_id(user_frontend_message_id) - - task = self.task_repository.fetch_task_by_frontend_message_id(frontend_message_id) - + def _validate_task( + self, task: Task, *, task_id: Optional[UUID] = None, frontend_message_id: Optional[str] = None + ) -> Task: if task is None: - raise OasstError(f"Task for {frontend_message_id=} not found", OasstErrorCode.TASK_NOT_FOUND) + if task_id: + raise OasstError(f"Task for {task_id=} not found", OasstErrorCode.TASK_NOT_FOUND) + if frontend_message_id: + raise OasstError(f"Task for {frontend_message_id=} not found", OasstErrorCode.TASK_NOT_FOUND) + raise OasstError("Task not found", OasstErrorCode.TASK_NOT_FOUND) + if task.expired: raise OasstError("Task already expired.", OasstErrorCode.TASK_EXPIRED) if not task.ack: @@ -111,12 +124,45 @@ class PromptRepository: if task.done: raise OasstError("Task already done.", OasstErrorCode.TASK_ALREADY_DONE) + if (not task.collective or task.user_id is None) and task.user_id != self.user_id: + logger.warning(f"Task was assigned to a different user (expected: {task.user_id}; actual: {self.user_id}).") + raise OasstError("Task was assigned to a different user.", OasstErrorCode.TASK_NOT_ASSIGNED_TO_USER) + + return task + + def fetch_tree_state(self, message_tree_id: UUID) -> MessageTreeState: + return self.db.query(MessageTreeState).filter(MessageTreeState.message_tree_id == message_tree_id).one() + + def store_text_reply( + self, + text: str, + frontend_message_id: str, + user_frontend_message_id: str, + review_count: int = 0, + review_result: bool = False, + ) -> Message: + validate_frontend_message_id(frontend_message_id) + validate_frontend_message_id(user_frontend_message_id) + + task = self.task_repository.fetch_task_by_frontend_message_id(frontend_message_id) + self._validate_task(task) + # If there's no parent message assume user started new conversation role = "prompter" depth = 0 if task.parent_message_id: parent_message = self.fetch_message(task.parent_message_id) + + # check tree state + ts = self.fetch_tree_state(parent_message.message_tree_id) + if not ts.active or ts.state != message_tree_state.State.GROWING: + raise OasstError( + "Message insertion failed. Message tree is no longer in 'growing' state.", + OasstErrorCode.TREE_NOT_IN_GROWING_STATE, + ) + + parent_message.message_tree_id parent_message.children_count += 1 self.db.add(parent_message) @@ -137,6 +183,8 @@ class PromptRepository: role=role, payload=db_payload.MessagePayload(text=text), depth=depth, + review_count=review_count, + review_result=review_result, ) if not task.collective: task.done = True @@ -149,6 +197,7 @@ class PromptRepository: message = self.fetch_message_by_frontend_message_id(rating.message_id, fail_if_missing=True) task = self.task_repository.fetch_task_by_frontend_message_id(rating.message_id) + self._validate_task(task) task_payload: db_payload.RateSummaryPayload = task.payload.payload if type(task_payload) != db_payload.RateSummaryPayload: raise OasstError( @@ -173,9 +222,10 @@ class PromptRepository: logger.info(f"Ranking {rating.rating} stored for task {task.id}.") return reaction - def store_ranking(self, ranking: protocol_schema.MessageRanking) -> MessageReaction: + def store_ranking(self, ranking: protocol_schema.MessageRanking) -> Tuple[MessageReaction, Task]: # fetch task task = self.task_repository.fetch_task_by_frontend_message_id(ranking.message_id) + self._validate_task(task, frontend_message_id=ranking.message_id) if not task.collective: task.done = True self.db.add(task) @@ -188,47 +238,59 @@ class PromptRepository: case db_payload.RankPrompterRepliesPayload | db_payload.RankAssistantRepliesPayload: # validate ranking - num_replies = len(task_payload.replies) - if sorted(ranking.ranking) != list(range(num_replies)): + if sorted(ranking.ranking) != list(range(num_replies := len(task_payload.reply_messages))): raise OasstError( f"Invalid ranking submitted. Each reply index must appear exactly once ({num_replies=}).", OasstErrorCode.INVALID_RANKING_VALUE, ) + last_conv_message = task_payload.conversation.messages[-1] + parent_msg = self.fetch_message(last_conv_message.message_id) + # store reaction to message - reaction_payload = db_payload.RankingReactionPayload(ranking=ranking.ranking) + ranked_message_ids = [task_payload.reply_messages[i].message_id for i in ranking.ranking] + for mid in ranked_message_ids: + message = self.fetch_message(mid) + if message.parent_id != parent_msg.id: + raise OasstError("Corrupt reply ranking result", OasstErrorCode.CORRUPT_RANKING_RESULT) + message.ranking_count += 1 + self.db.add(message) + + reaction_payload = db_payload.RankingReactionPayload( + ranking=ranking.ranking, ranked_message_ids=ranked_message_ids + ) reaction = self.insert_reaction(task.id, reaction_payload) - # TODO: resolve message_id - self.journal.log_ranking(task, message_id=None, ranking=ranking.ranking) + self.journal.log_ranking(task, message_id=parent_msg.id, ranking=ranking.ranking) logger.info(f"Ranking {ranking.ranking} stored for task {task.id}.") - return reaction - case db_payload.RankInitialPromptsPayload: # validate ranking - if sorted(ranking.ranking) != list(range(num_prompts := len(task_payload.prompts))): + if sorted(ranking.ranking) != list(range(num_prompts := len(task_payload.prompt_messages))): raise OasstError( f"Invalid ranking submitted. Each reply index must appear exactly once ({num_prompts=}).", OasstErrorCode.INVALID_RANKING_VALUE, ) # store reaction to message - reaction_payload = db_payload.RankingReactionPayload(ranking=ranking.ranking) + ranked_message_ids = [task_payload.prompt_messages[i].message_id for i in ranking.ranking] + reaction_payload = db_payload.RankingReactionPayload( + ranking=ranking.ranking, ranked_message_ids=ranked_message_ids + ) reaction = self.insert_reaction(task.id, reaction_payload) # TODO: resolve message_id self.journal.log_ranking(task, message_id=None, ranking=ranking.ranking) logger.info(f"Ranking {ranking.ranking} stored for task {task.id}.") - return reaction - case _: raise OasstError( f"task payload type mismatch: {type(task_payload)=} != {db_payload.RankConversationRepliesPayload}", OasstErrorCode.TASK_PAYLOAD_TYPE_MISMATCH, ) + return reaction, task + def insert_message_embedding(self, message_id: UUID, model: str, embedding: List[float]) -> MessageEmbedding: """Insert the embedding of a new message in the database. @@ -270,21 +332,80 @@ class PromptRepository: self.db.refresh(reaction) return reaction - def store_text_labels(self, text_labels: protocol_schema.TextLabels) -> TextLabels: + def store_text_labels(self, text_labels: protocol_schema.TextLabels) -> Tuple[TextLabels, Task, Message]: + + valid_labels: Optional[list[str]] = None + mandatory_labels: Optional[list[str]] = None + text_labels_id: Optional[UUID] = None + message_id: Optional[UUID] = text_labels.message_id + + task: Task = None + if text_labels.task_id: + logger.debug(f"text_labels reply has task_id {text_labels.task_id}") + task = self.task_repository.fetch_task_by_id(text_labels.task_id) + self._validate_task(task, task_id=text_labels.task_id) + + task_payload: db_payload.TaskPayload = task.payload.payload + if isinstance(task_payload, db_payload.LabelInitialPromptPayload): + if message_id and task_payload.message_id != message_id: + raise OasstError("Task message id mismatch", OasstErrorCode.TEXT_LABELS_WRONG_MESSAGE_ID) + message_id = task_payload.message_id + valid_labels = task_payload.valid_labels + mandatory_labels = task_payload.mandatory_labels + elif isinstance(task_payload, db_payload.LabelConversationReplyPayload): + if message_id and message_id != message_id: + raise OasstError("Task message id mismatch", OasstErrorCode.TEXT_LABELS_WRONG_MESSAGE_ID) + message_id = task_payload.message_id + valid_labels = task_payload.valid_labels + mandatory_labels = task_payload.mandatory_labels + else: + raise OasstError( + "Unexpected text_labels task payload", + OasstErrorCode.TASK_PAYLOAD_TYPE_MISMATCH, + ) + + logger.debug(f"text_labels relpy: {valid_labels=}, {mandatory_labels=}") + + if valid_labels: + if not all([label in valid_labels for label in text_labels.labels.keys()]): + raise OasstError("Invalid text label specified", OasstErrorCode.TEXT_LABELS_INVALID_LABEL) + + if isinstance(mandatory_labels, list): + mandatory_set = set(mandatory_labels) + if not mandatory_set.issubset(text_labels.labels.keys()): + missing = ", ".join(mandatory_set - text_labels.labels.keys()) + raise OasstError( + f"Mandatory text labels missing: {missing}", OasstErrorCode.TEXT_LABELS_MANDATORY_LABEL_MISSING + ) + + text_labels_id = task.id # associate with task by sharing the id + + if not task.collective: + task.done = True + self.db.add(task) + + logger.debug(f"inserting TextLabels for {message_id=}, {text_labels_id=}") model = TextLabels( + id=text_labels_id, api_client_id=self.api_client.id, - message_id=text_labels.message_id, + message_id=message_id, user_id=self.user_id, text=text_labels.text, labels=text_labels.labels, ) + if message_id: + message = self.fetch_message(message_id) + if task: + message.review_count += 1 + self.db.add(message) + self.db.add(model) self.db.commit() self.db.refresh(model) - return model + return model, task, message - def fetch_random_message_tree(self, require_role: str = None) -> list[Message]: + def fetch_random_message_tree(self, require_role: str = None, reviewed: bool = True) -> list[Message]: """ Loads all messages of a random message_tree. @@ -294,13 +415,18 @@ class PromptRepository: distinct_message_trees = self.db.query(Message.message_tree_id).distinct(Message.message_tree_id) if require_role: distinct_message_trees = distinct_message_trees.filter(Message.role == require_role) + if reviewed: + distinct_message_trees = distinct_message_trees.filter(Message.review_result) distinct_message_trees = distinct_message_trees.subquery() - random_message_tree = self.db.query(distinct_message_trees).order_by(func.random()).limit(1) - message_tree_messages = self.db.query(Message).filter(Message.message_tree_id.in_(random_message_tree)).all() - return message_tree_messages + random_message_tree_id = self.db.query(distinct_message_trees).order_by(func.random()).limit(1).scalar() + if random_message_tree_id: + return self.fetch_message_tree(random_message_tree_id, reviewed) + return None - def fetch_random_conversation(self, last_message_role: str = None) -> list[Message]: + def fetch_random_conversation( + self, last_message_role: str = None, message_tree_id: Optional[UUID] = None, reviewed: bool = True + ) -> list[Message]: """ Picks a random linear conversation starting from any root message and ending somewhere in the message_tree, possibly at the root itself. @@ -310,9 +436,13 @@ class PromptRepository: the user should reply as a human and hence the last message of the conversation needs to have "assistant" role. """ - messages_tree = self.fetch_random_message_tree(last_message_role) + if message_tree_id: + messages_tree = self.fetch_message_tree(message_tree_id, reviewed) + else: + messages_tree = self.fetch_random_message_tree(last_message_role) if not messages_tree: raise OasstError("No message tree found", OasstErrorCode.NO_MESSAGE_TREE_FOUND) + if last_message_role: conv_messages = [m for m in messages_tree if m.role == last_message_role] conv_messages = [random.choice(conv_messages)] @@ -334,8 +464,11 @@ class PromptRepository: messages = self.db.query(Message).filter(Message.parent_id.is_(None)).order_by(func.random()).limit(size).all() return messages - def fetch_message_tree(self, message_tree_id: UUID): - return self.db.query(Message).filter(Message.message_tree_id == message_tree_id).all() + def fetch_message_tree(self, message_tree_id: UUID, reviewed: bool = True): + qry = self.db.query(Message).filter(Message.message_tree_id == message_tree_id) + if reviewed: + qry = qry.filter(Message.review_result) + return qry.all() def fetch_multiple_random_replies(self, max_size: int = 5, message_role: str = None): """ @@ -388,13 +521,15 @@ class PromptRepository: messages = {m.id: m for m in messages} if not isinstance(messages, dict): # This should not normally happen - raise OasstError("Server error", OasstErrorCode.SERVER_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR) + raise OasstError("Server error", OasstErrorCode.SERVER_ERROR0, HTTPStatus.INTERNAL_SERVER_ERROR) conv = [last_message] while conv[-1].parent_id: if conv[-1].parent_id not in messages: # Can't form a continuous conversation - raise OasstError("Server error", OasstErrorCode.SERVER_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR) + raise OasstError( + "Broken conversation", OasstErrorCode.BROKEN_CONVERSATION, HTTPStatus.INTERNAL_SERVER_ERROR + ) parent_message = messages[conv[-1].parent_id] conv.append(parent_message) @@ -417,16 +552,24 @@ class PromptRepository: """ if isinstance(message, UUID): message = self.fetch_message(message) + logger.debug(f"fetch_message_tree({message.message_tree_id=})") return self.fetch_message_tree(message.message_tree_id) - def fetch_message_children(self, message: Message | UUID) -> list[Message]: + def fetch_message_children( + self, message: Message | UUID, reviewed: bool = True, exclude_deleted: bool = True + ) -> list[Message]: """ Get all direct children of this message """ if isinstance(message, Message): message = message.id - children = self.db.query(Message).filter(Message.parent_id == message).all() + qry = self.db.query(Message).filter(Message.parent_id == message) + if reviewed: + qry = qry.filter(Message.review_result) + if exclude_deleted: + qry = qry.filter(Message.deleted == sa.false()) + children = qry.all() return children @staticmethod @@ -536,7 +679,7 @@ class PromptRepository: elif isinstance(message, Message): ids.append(message.id) else: - raise OasstError("Server error", OasstErrorCode.SERVER_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR) + raise OasstError("Server error", OasstErrorCode.SERVER_ERROR1, HTTPStatus.INTERNAL_SERVER_ERROR) query = update(Message).where(Message.id.in_(ids)).values(deleted=True) self.db.execute(query) diff --git a/backend/oasst_backend/task_repository.py b/backend/oasst_backend/task_repository.py index 15484d66..de7eb28a 100644 --- a/backend/oasst_backend/task_repository.py +++ b/backend/oasst_backend/task_repository.py @@ -2,6 +2,7 @@ from typing import Optional from uuid import UUID import oasst_backend.models.db_payload as db_payload +from loguru import logger from oasst_backend.models import ApiClient, Task from oasst_backend.models.payload_column_type import PayloadContainer from oasst_backend.user_repository import UserRepository @@ -62,16 +63,16 @@ class TaskRepository: payload = db_payload.AssistantReplyPayload(type=task.type, conversation=task.conversation) case protocol_schema.RankInitialPromptsTask: - payload = db_payload.RankInitialPromptsPayload(type=task.type, prompts=task.prompts) + payload = db_payload.RankInitialPromptsPayload(type=task.type, prompt_messages=task.prompt_messages) case protocol_schema.RankPrompterRepliesTask: payload = db_payload.RankPrompterRepliesPayload( - type=task.type, conversation=task.conversation, replies=task.replies + type=task.type, conversation=task.conversation, reply_messages=task.reply_messages ) case protocol_schema.RankAssistantRepliesTask: payload = db_payload.RankAssistantRepliesPayload( - type=task.type, conversation=task.conversation, replies=task.replies + type=task.type, conversation=task.conversation, reply_messages=task.reply_messages ) case protocol_schema.LabelInitialPromptTask: @@ -86,6 +87,7 @@ class TaskRepository: conversation=task.conversation, reply=task.reply, valid_labels=task.valid_labels, + mandatory_labels=task.mandatory_labels, ) case protocol_schema.LabelAssistantReplyTask: @@ -95,20 +97,21 @@ class TaskRepository: conversation=task.conversation, reply=task.reply, valid_labels=task.valid_labels, + mandatory_labels=task.mandatory_labels, ) case _: raise OasstError(f"Invalid task type: {type(task)=}", OasstErrorCode.INVALID_TASK_TYPE) - task = self.insert_task( + task_model = self.insert_task( payload=payload, id=task.id, message_tree_id=message_tree_id, parent_message_id=parent_message_id, collective=collective, ) - assert task.id == task.id - return task + assert task_model.id == task.id + return task_model def bind_frontend_message_id(self, task_id: UUID, frontend_message_id: str): validate_frontend_message_id(frontend_message_id) @@ -184,6 +187,7 @@ class TaskRepository: parent_message_id=parent_message_id, collective=collective, ) + logger.debug(f"inserting {task=}") self.db.add(task) self.db.commit() self.db.refresh(task) @@ -197,3 +201,7 @@ class TaskRepository: .one_or_none() ) return task + + def fetch_task_by_id(self, task_id: UUID) -> Task: + task = self.db.query(Task).filter(Task.api_client_id == self.api_client.id, Task.id == task_id).one_or_none() + return task diff --git a/backend/oasst_backend/tree_manager.py b/backend/oasst_backend/tree_manager.py new file mode 100644 index 00000000..16dc4d16 --- /dev/null +++ b/backend/oasst_backend/tree_manager.py @@ -0,0 +1,804 @@ +import random +from enum import Enum +from typing import Optional, Tuple +from uuid import UUID + +import numpy as np +import pydantic +from loguru import logger +from oasst_backend.api.v1.utils import prepare_conversation, prepare_conversation_message_list +from oasst_backend.config import settings +from oasst_backend.models import Message, MessageReaction, MessageTreeState, TextLabels, message_tree_state +from oasst_backend.prompt_repository import PromptRepository +from oasst_backend.utils.hugging_face import HfEmbeddingModel, HfUrl, HuggingFaceAPI +from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode +from oasst_shared.schemas import protocol as protocol_schema +from sqlalchemy.sql import text +from sqlmodel import Session, func + + +class TreeManagerConfiguration(pydantic.BaseModel): + """Configuration class for the TreeManager""" + + max_active_trees: int = 10 + """Maximum number of concurrently active trees in the database. + No new initial prompt tasks will be handed out to users if this + number is reached.""" + + max_tree_depth: int = 6 + """Maximum depth of message tree.""" + + max_children_count: int = 5 + """Maximum number of reply messages per tree node.""" + + goal_tree_size: int = 15 + """Total number of messages to gather per tree""" + + num_reviews_initial_prompt: int = 3 + """Number of peer review checks to collect in INITIAL_PROMPT_REVIEW state.""" + + num_reviews_reply: int = 3 + """Number of peer review checks to collect per reply (other than initial_prompt)""" + + acceptance_threshold_initial_prompt: float = 0.6 + """Threshold for accepting an initial prompt""" + + acceptance_threshold_reply: float = 0.6 + """Threshold for accepting a reply""" + + num_required_rankings: int = 3 + """Number of rankings in which the message participated.""" + + mandatory_labels_initial_prompt: Optional[list[protocol_schema.TextLabel]] = [protocol_schema.TextLabel.spam] + mandatory_labels_assistant_reply: Optional[list[protocol_schema.TextLabel]] = [protocol_schema.TextLabel.spam] + mandatory_labels_prompter_reply: Optional[list[protocol_schema.TextLabel]] = [protocol_schema.TextLabel.spam] + + +class TaskType(Enum): + NONE = -1 + RANKING = 0 + LABEL_REPLY = 1 + REPLY = 2 + LABEL_PROMPT = 3 + PROMPT = 4 + + +class TaskRole(Enum): + ANY = 0 + PROMPTER = 1 + ASSISTANT = 2 + + +class ActiveTreeSizeRow(pydantic.BaseModel): + message_tree_id: UUID + tree_size: int + goal_tree_size: int + + @property + def remaining_messages(self) -> int: + return max(0, self.goal_tree_size - self.tree_size) + + class Config: + orm_mode = True + + +class ExtendibleParentRow(pydantic.BaseModel): + parent_id: UUID + depth: int + message_tree_id: UUID + active_children_count: int + + class Config: + orm_mode = True + + +class IncompleteRankingsRow(pydantic.BaseModel): + parent_id: UUID + children_count: int + child_min_ranking_count: int + + class Config: + orm_mode = True + + +class TreeManager: + def __init__(self, db: Session, prompt_repository: PromptRepository, configuration: TreeManagerConfiguration): + self.db = db + self.cfg = configuration + self.pr = prompt_repository + + def _task_selection( + self, + desired_task_type: protocol_schema.TaskRequestType, + num_ranking_tasks: int, + num_replies_need_review: int, + num_prompts_need_review: int, + num_missing_prompts: int, + num_missing_replies: int, + ) -> Tuple[TaskType, TaskRole]: + """ + Determines which task to hand out to human worker. + The task type is drawn with relative weight (e.g. ranking has highest priority) + depending on what is possible with the current message trees in the database. + """ + + logger.debug( + f"TreeManager._task_selection({num_ranking_tasks=}, {num_replies_need_review=}, " + f"{num_prompts_need_review=}, {num_missing_prompts=}, {num_missing_replies=})" + ) + + task_type = TaskType.NONE + task_role = TaskRole.ANY + if desired_task_type == protocol_schema.TaskRequestType.random: + task_weights = [0] * 5 + + if num_ranking_tasks > 0: + task_weights[TaskType.RANKING.value] = 10 + + if num_replies_need_review > 0: + task_weights[TaskType.LABEL_REPLY.value] = 5 + + if num_prompts_need_review > 0: + task_weights[TaskType.LABEL_PROMPT.value] = 5 + + if num_missing_replies > 0: + task_weights[TaskType.REPLY.value] = 2 + + if num_missing_prompts > 0: + task_weights[TaskType.PROMPT.value] = 1 + + task_weights = np.array(task_weights) + weight_sum = task_weights.sum() + if weight_sum < 1e-8: + task_type = TaskType.NONE + else: + task_weights = task_weights / weight_sum + task_type = TaskType(np.random.choice(a=len(task_weights), p=task_weights)) + else: + match desired_task_type: + case protocol_schema.TaskRequestType.initial_prompt: + if num_missing_prompts > 0: + task_type = TaskType.PROMPT + case protocol_schema.TaskRequestType.label_initial_prompt: + if num_prompts_need_review > 0: + task_type = TaskType.LABEL_PROMPT + case protocol_schema.TaskRequestType.assistant_reply | protocol_schema.TaskRequestType.prompter_reply: + if num_missing_replies > 0: + task_role = ( + TaskRole.ASSISTANT + if desired_task_type == protocol_schema.TaskRequestType.assistant_reply + else TaskRole.PROMPTER + ) + task_type = TaskType.REPLY + case protocol_schema.TaskRequestType.label_assistant_reply | protocol_schema.TaskRequestType.label_prompter_reply: + if num_replies_need_review > 0: + task_role = ( + TaskRole.ASSISTANT + if desired_task_type == protocol_schema.TaskRequestType.label_assistant_reply + else TaskRole.PROMPTER + ) + task_type = TaskType.LABEL_REPLY + case protocol_schema.TaskRequestType.rank_assistant_replies | protocol_schema.TaskRequestType.rank_prompter_replies: + if num_ranking_tasks > 0: + task_role = ( + TaskRole.ASSISTANT + if desired_task_type == protocol_schema.TaskRequestType.rank_assistant_replies + else TaskRole.PROMPTER + ) + task_type = TaskType.RANKING + + logger.debug(f"Selected {task_type=}, {task_role=}") + return task_type, task_role + + def next_task( + self, desired_task_type: protocol_schema.TaskRequestType + ) -> Tuple[protocol_schema.Task, Optional[UUID], Optional[UUID]]: + + logger.debug("TreeManager.next_task()") + + num_active_trees = self.query_num_active_trees() + prompts_need_review = self.query_prompts_need_review() + replies_need_review = self.query_replies_need_review() + incomplete_rankings = self.query_incomplete_rankings() + active_tree_sizes = self.query_extendible_trees() + + # determine type of task to generate + num_missing_replies = sum(x.remaining_messages for x in active_tree_sizes) + + task_type, task_role = self._task_selection( + desired_task_type, + num_ranking_tasks=len(incomplete_rankings), + num_replies_need_review=len(replies_need_review), + num_prompts_need_review=len(prompts_need_review), + num_missing_prompts=max(0, self.cfg.max_active_trees - num_active_trees), + num_missing_replies=num_missing_replies, + ) + + if task_type == TaskType.NONE: + raise OasstError( + f"No tasks of type '{desired_task_type.value}' are currently available.", + OasstErrorCode.TASK_REQUESTED_TYPE_UNAVAILABLE, + ) + + if task_role != TaskRole.ANY: + # Todo: Allow role specific message selection... + raise OasstError( + f"No tasks of type '{desired_task_type.value}' are currently available.", + OasstErrorCode.TASK_REQUESTED_TYPE_UNAVAILABLE, + ) + + message_tree_id = None + parent_message_id = None + + logger.debug(f"selected {task_type=}") + match task_type: + case TaskType.RANKING: + assert len(incomplete_rankings) > 0 + ranking_parent_id = random.choice(incomplete_rankings).parent_id + + messages = self.pr.fetch_message_conversation(ranking_parent_id) + conversation = prepare_conversation(messages) + replies = self.pr.fetch_message_children(ranking_parent_id, reviewed=True, exclude_deleted=True) + reply_messages = prepare_conversation_message_list(replies) + replies = [p.text for p in replies] + + if messages[-1].role == "assistant": + logger.info("Generating a RankPrompterRepliesTask.") + task = protocol_schema.RankPrompterRepliesTask( + conversation=conversation, replies=replies, reply_messages=reply_messages + ) + else: + logger.info("Generating a RankAssistantRepliesTask.") + task = protocol_schema.RankAssistantRepliesTask( + conversation=conversation, replies=replies, reply_messages=reply_messages + ) + + parent_message_id = ranking_parent_id + message_tree_id = messages[-1].message_tree_id + + case TaskType.LABEL_REPLY: + assert len(replies_need_review) > 0 + random_reply_message_id = random.choice(replies_need_review) + messages = self.pr.fetch_message_conversation(random_reply_message_id) + conversation = prepare_conversation(messages[:-1]) + message = messages[-1] + + if message.role == "assistant": + logger.info("Generating a LabelAssistantReplyTask.") + task = protocol_schema.LabelAssistantReplyTask( + message_id=message.id, + conversation=conversation, + reply=message.text, + valid_labels=list(map(lambda x: x.value, protocol_schema.TextLabel)), + mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_assistant_reply)), + ) + else: + logger.info("Generating a LabelPrompterReplyTask.") + task = protocol_schema.LabelPrompterReplyTask( + message_id=message.id, + conversation=conversation, + reply=message.text, + valid_labels=list(map(lambda x: x.value, protocol_schema.TextLabel)), + mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_prompter_reply)), + ) + + parent_message_id = message.id + message_tree_id = message.message_tree_id + + case TaskType.REPLY: + # select a tree with missing replies + extensible_parents = self.query_extendible_parents() + assert len(extensible_parents) > 0 + + # fetch random conversation to extend + random_parent = random.choice(extensible_parents) + logger.debug(f"selected {random_parent=}") + messages = self.pr.fetch_message_conversation(random_parent.parent_id) + assert all(m.review_result for m in messages) # ensure all messages have positive review + conversation = prepare_conversation(messages) + + # generate reply task depending on last message + if messages[-1].role == "assistant": + logger.info("Generating a PrompterReplyTask.") + task = protocol_schema.PrompterReplyTask(conversation=conversation) + else: + logger.info("Generating a AssistantReplyTask.") + task = protocol_schema.AssistantReplyTask(conversation=conversation) + + parent_message_id = messages[-1].id + message_tree_id = messages[-1].message_tree_id + + case TaskType.LABEL_PROMPT: + assert len(prompts_need_review) > 0 + message = self.pr.fetch_message(random.choice(prompts_need_review)) + logger.info("Generating a LabelInitialPromptTask.") + + task = protocol_schema.LabelInitialPromptTask( + message_id=message.id, + prompt=message.text, + valid_labels=list(map(lambda x: x.value, protocol_schema.TextLabel)), + mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_initial_prompt)), + ) + + parent_message_id = message.id + message_tree_id = message.message_tree_id + + case TaskType.PROMPT: + logger.info("Generating an InitialPromptTask.") + task = protocol_schema.InitialPromptTask(hint=None) + + case _: + task = None + + logger.info(f"Generated {task=}.") + + return task, message_tree_id, parent_message_id + + async def handle_interaction(self, interaction: protocol_schema.AnyInteraction) -> protocol_schema.Task: + pr = self.pr + match type(interaction): + case protocol_schema.TextReplyToMessage: + logger.info( + f"Frontend reports text reply to {interaction.message_id=} with {interaction.text=} by {interaction.user=}." + ) + + # here we store the text reply in the database + message = pr.store_text_reply( + text=interaction.text, + frontend_message_id=interaction.message_id, + user_frontend_message_id=interaction.user_message_id, + ) + + if not message.parent_id: + logger.info(f"TreeManager: Inserting new tree state for initial prompt {message.id=}") + self._insert_default_state(message.id) + self.db.commit() + + if not settings.DEBUG_SKIP_EMBEDDING_COMPUTATION: + try: + hugging_face_api = HuggingFaceAPI( + f"{HfUrl.HUGGINGFACE_FEATURE_EXTRACTION.value}/{HfEmbeddingModel.MINILM.value}" + ) + embedding = await hugging_face_api.post(interaction.text) + pr.insert_message_embedding( + message_id=message.id, model=HfEmbeddingModel.MINILM.value, embedding=embedding + ) + except OasstError: + logger.error( + f"Could not fetch embbeddings for text reply to {interaction.message_id=} with {interaction.text=} by {interaction.user=}." + ) + + case protocol_schema.MessageRating: + logger.info( + f"Frontend reports rating of {interaction.message_id=} with {interaction.rating=} by {interaction.user=}." + ) + + pr.store_rating(interaction) + + case protocol_schema.MessageRanking: + logger.info( + f"Frontend reports ranking of {interaction.message_id=} with {interaction.ranking=} by {interaction.user=}." + ) + + _, task = pr.store_ranking(interaction) + + self.check_condition_for_scoring_state(task.message_tree_id) + + case protocol_schema.TextLabels: + logger.info( + f"Frontend reports labels of {interaction.message_id=} with {interaction.labels=} by {interaction.user=}." + ) + + _, task, msg = pr.store_text_labels(interaction) + + # if it was a respones for a task, check if we have enough reviews to calc review_result + if task and msg: + reviews = self.query_reviews_for_message(msg.id) + acceptance_score = self._calculate_acceptance(reviews) + logger.debug( + f"Message {msg.id=}, {acceptance_score=}, {len(reviews)=}, {msg.review_result=}, {msg.review_count=}" + ) + if msg.parent_id is None: + if not msg.review_result and msg.review_count >= self.cfg.num_reviews_initial_prompt: + if acceptance_score > self.cfg.acceptance_threshold_initial_prompt: + msg.review_result = True + self.db.add(msg) + self.db.commit() + logger.info( + f"Initial prompt message was accepted: {msg.id=}, {acceptance_score=}, {len(reviews)=}" + ) + else: + self.enter_low_grade_state(msg.message_tree_id) + self.check_condition_for_growing_state(msg.message_tree_id) + elif msg.review_count >= self.cfg.num_reviews_reply: + if not msg.review_result and acceptance_score > self.cfg.acceptance_threshold_reply: + msg.review_result = True + self.db.add(msg) + self.db.commit() + logger.info( + f"Reply message message accepted: {msg.id=}, {acceptance_score=}, {len(reviews)=}" + ) + + self.check_condition_for_ranking_state(msg.message_tree_id) + + case _: + raise OasstError("Invalid response type.", OasstErrorCode.TASK_INVALID_RESPONSE_TYPE) + + return protocol_schema.TaskDone() + + def _enter_state(self, mts: MessageTreeState, state: message_tree_state.State): + assert mts and mts.active + + is_terminal = state in message_tree_state.TERMINAL_STATES + + if is_terminal: + mts.active = False + mts.state = state.value + self.db.add(mts) + self.db.commit() + + if is_terminal: + logger.info(f"Tree entered terminal '{mts.state}' state ({mts.message_tree_id=})") + else: + logger.info(f"Tree entered '{mts.state}' state ({mts.message_tree_id=})") + + def enter_low_grade_state(self, message_tree_id: UUID) -> None: + logger.debug(f"enter_low_grade_state({message_tree_id=})") + mts = self.pr.fetch_tree_state(message_tree_id) + self._enter_state(mts, message_tree_state.State.ABORTED_LOW_GRADE) + + def check_condition_for_growing_state(self, message_tree_id: UUID) -> bool: + logger.debug(f"check_condition_for_growing_state({message_tree_id=})") + + mts = self.pr.fetch_tree_state(message_tree_id) + if not mts.active or mts.state != message_tree_state.State.INITIAL_PROMPT_REVIEW: + logger.debug(f"False {mts.active=}, {mts.state=}") + return False + + # check if initial prompt was accepted + initial_prompt = self.pr.fetch_message(message_tree_id) + if not initial_prompt.review_result: + logger.debug(f"False {initial_prompt.review_result=}") + return False + + self._enter_state(mts, message_tree_state.State.GROWING) + return True + + def check_condition_for_ranking_state(self, message_tree_id: UUID) -> bool: + logger.debug(f"check_condition_for_ranking_state({message_tree_id=})") + + mts = self.pr.fetch_tree_state(message_tree_id) + if not mts.active or mts.state != message_tree_state.State.GROWING: + logger.debug(f"False {mts.active=}, {mts.state=}") + return False + + # check if desired tree size has been reached and all nodes have been reviewed + tree_size = self.query_tree_size(message_tree_id) + if tree_size.remaining_messages > 0: + logger.debug(f"False {tree_size.remaining_messages=}") + return False + + self._enter_state(mts, message_tree_state.State.RANKING) + return True + + def check_condition_for_scoring_state(self, message_tree_id: UUID) -> bool: + logger.debug(f"check_condition_for_scoring_state({message_tree_id=})") + mts: MessageTreeState + mts = self.db.query(MessageTreeState).filter(MessageTreeState.message_tree_id == message_tree_id).one() + if not mts.active or mts.state != message_tree_state.State.RANKING: + logger.debug(f"False {mts.active=}, {mts.state=}") + return False + + rankings_by_message = self.query_tree_ranking_results(message_tree_id) + for parent_msg_id, ranking in rankings_by_message.items(): + if len(ranking) < self.cfg.num_required_rankings: + logger.debug(f"False {parent_msg_id=} {len(ranking)=}") + return False + + self._enter_state(mts, message_tree_state.State.READY_FOR_SCORING) + return True + + def _calculate_acceptance(self, labels: list[TextLabels]): + # calculate acceptance based on spam label + return np.mean([1 - l.labels[protocol_schema.TextLabel.spam] for l in labels]) + + _sql_find_prompts_need_review = """ +-- find initial prompts that need more reviews +SELECT m.id +FROM message_tree_state mts + LEFT JOIN message m ON mts.message_tree_id = m.id +WHERE mts.active + AND mts.state = :state + AND NOT m.review_result + AND NOT m.deleted + AND m.review_count < :num_reviews_initial_prompt + AND m.parent_id is NULL + AND (:excluded_user_id IS NULL OR m.user_id != :excluded_user_id) +""" + + def query_prompts_need_review(self) -> list[UUID]: + """ + Select id of initial prompts with less then required rankings in active message tree + (active == True in message_tree_state) + """ + + r = self.db.execute( + text(self._sql_find_prompts_need_review), + { + "state": message_tree_state.State.INITIAL_PROMPT_REVIEW, + "num_reviews_initial_prompt": self.cfg.num_reviews_initial_prompt, + "excluded_user_id": None if settings.DEBUG_ALLOW_SELF_LABELING else self.pr.user_id, + }, + ) + return [x["id"] for x in r.all()] + + _sql_find_replies_need_review = """ +SELECT m.id +FROM message_tree_state mts + LEFT JOIN message m ON mts.message_tree_id = m.message_tree_id +WHERE mts.active + AND mts.state = :breeding_state + AND NOT m.review_result + AND NOT m.deleted + AND m.review_count < :num_required_reviews + AND m.parent_id is NOT NULL + AND (:excluded_user_id IS NULL OR m.user_id != :excluded_user_id) +""" + + def query_replies_need_review(self) -> list[UUID]: + """ + Select ids of child messages (parent_id IS NOT NULL) with less then required rankings + in active message tree (active == True in message_tree_state) + """ + + r = self.db.execute( + text(self._sql_find_replies_need_review), + { + "breeding_state": message_tree_state.State.GROWING, + "num_required_reviews": self.cfg.num_reviews_reply, + "excluded_user_id": None if settings.DEBUG_ALLOW_SELF_LABELING else self.pr.user_id, + }, + ) + return [x["id"] for x in r.all()] + + _sql_find_incomplete_rankings = """ +-- find incomplete rankings +SELECT m.parent_id, COUNT(m.id) children_count, MIN(m.ranking_count) child_min_ranking_count, + COUNT(m.id) FILTER (WHERE m.ranking_count >= :num_required_rankings) as completed_rankings +FROM message_tree_state mts + LEFT JOIN message m ON mts.message_tree_id = m.message_tree_id +WHERE mts.active -- only consider active trees + AND mts.state = :ranking_state -- message tree must be in ranking state + AND m.review_result -- must be reviewed + AND NOT m.deleted -- not deleted + AND m.parent_id IS NOT NULL -- ignore initial prompts +GROUP BY m.parent_id +HAVING COUNT(m.id) > 1 and MIN(m.ranking_count) < :num_required_rankings +""" + + def query_incomplete_rankings(self) -> list[IncompleteRankingsRow]: + """Query parents which have childern that need further rankings""" + + r = self.db.execute( + text(self._sql_find_incomplete_rankings), + { + "num_required_rankings": self.cfg.num_required_rankings, + "ranking_state": message_tree_state.State.RANKING, + }, + ) + return [IncompleteRankingsRow.from_orm(x) for x in r.all()] + + _sql_find_extendible_parents = """ +-- find all extendible parent nodes +SELECT m.id as parent_id, m.depth, m.message_tree_id, COUNT(c.id) active_children_count +FROM message_tree_state mts + LEFT JOIN message m ON mts.message_tree_id = m.message_tree_id -- all elements of message tree + LEFT JOIN message c ON m.id = c.Id -- child nodes +WHERE mts.active -- only consider active trees + AND mts.state = :growing_state -- message tree must be growing + AND NOT m.deleted -- ignore deleted messages as parents + AND m.depth < mts.max_depth -- ignore leaf nodes as parents + AND m.review_result -- parent node must have positive review + AND NOT c.deleted -- don't count deleted children + AND (c.review_result OR c.review_count < :num_reviews_reply) -- don't count children with negative review but count elements under review +GROUP BY m.id, m.depth, m.message_tree_id, mts.max_children_count +HAVING COUNT(c.id) < mts.max_children_count -- below maximum number of children +""" + + def query_extendible_parents(self) -> list[ExtendibleParentRow]: + """Query parent messages that have not reached the maximum number of replies.""" + + r = self.db.execute( + text(self._sql_find_extendible_parents), + { + "growing_state": message_tree_state.State.GROWING, + "num_reviews_reply": self.cfg.num_reviews_reply, + }, + ) + return [ExtendibleParentRow.from_orm(x) for x in r.all()] + + _sql_find_extendible_trees = f""" +-- find extendible trees +SELECT m.message_tree_id, mts.goal_tree_size, COUNT(m.id) AS tree_size +FROM ( + SELECT DISTINCT message_tree_id FROM ({_sql_find_extendible_parents}) extendible_parents + ) trees LEFT JOIN message_tree_state mts ON trees.message_tree_id = mts.message_tree_id + LEFT JOIN message m ON mts.message_tree_id = m.message_tree_id +WHERE NOT m.deleted + AND ( + m.parent_id IS NOT NULL AND (m.review_result OR m.review_count < :num_reviews_reply) -- children + OR m.parent_id IS NULL AND m.review_result -- prompts (root nodes) must have positive review + ) +GROUP BY m.message_tree_id, mts.goal_tree_size +HAVING COUNT(m.id) < mts.goal_tree_size +""" + + def query_extendible_trees(self) -> list[ActiveTreeSizeRow]: + """Query size of active message trees in growing state.""" + + r = self.db.execute( + text(self._sql_find_extendible_trees), + { + "growing_state": message_tree_state.State.GROWING, + "num_reviews_reply": self.cfg.num_reviews_reply, + }, + ) + return [ActiveTreeSizeRow.from_orm(x) for x in r.all()] + + _sql_get_tree_size = """ +SELECT mts.message_tree_id, mts.goal_tree_size, COUNT(m.id) AS tree_size +FROM message_tree_state mts + LEFT JOIN message m ON mts.message_tree_id = m.message_tree_id +WHERE mts.active + AND NOT m.deleted + AND m.review_result + AND mts.message_tree_id = :message_tree_id +GROUP BY mts.message_tree_id, mts.goal_tree_size +""" + + def query_tree_size(self, message_tree_id: UUID) -> ActiveTreeSizeRow: + """Returns the number of reviewed not deleted messages in the message tree.""" + r = self.db.execute(text(self._sql_get_tree_size), {"message_tree_id": message_tree_id}) + return ActiveTreeSizeRow.from_orm(r.one()) + + def query_misssing_tree_states(self) -> list[UUID]: + """Find all initial prompt messages that have no associated message tree state""" + qry_missing_tree_states = ( + self.db.query(Message.id) + .join(MessageTreeState, isouter=True) + .filter( + Message.parent_id.is_(None), + Message.message_tree_id == Message.id, + MessageTreeState.message_tree_id.is_(None), + ) + ) + + return [m.id for m in qry_missing_tree_states.all()] + + _sql_find_tree_ranking_results = """ +-- get all ranking results of completed tasks for all parents with >=2 children +SELECT p.parent_id, mr.* FROM +( + -- find parents with > 1 children + SELECT m.parent_id, m.message_tree_id, COUNT(m.id) children_count + FROM message_tree_state mts + LEFT JOIN message m ON mts.message_tree_id = m.message_tree_id + WHERE m.review_result -- must be reviewed + AND NOT m.deleted -- not deleted + AND m.parent_id IS NOT NULL -- ignore initial prompts + AND mts.message_tree_id = :message_tree_id + GROUP BY m.parent_id, m.message_tree_id + HAVING COUNT(m.id) > 1 +) as p +LEFT JOIN task t ON p.parent_id = t.parent_message_id AND t.done AND (t.payload_type = 'RankPrompterRepliesPayload' OR t.payload_type = 'RankAssistantRepliesPayload') +LEFT JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'RankingReactionPayload' +""" + + def query_tree_ranking_results(self, message_tree_id: UUID) -> dict[UUID, list[MessageReaction]]: + """Finds all completed ranking restuls for a message_tree""" + r = self.db.execute( + text(self._sql_find_tree_ranking_results), + {"message_tree_id": message_tree_id}, + ) + + rankings_by_message = {} + for x in r.all(): + parent_id = x["parent_id"] + if parent_id not in rankings_by_message: + rankings_by_message[parent_id] = [] + if x["task_id"]: + rankings_by_message[parent_id].append(MessageReaction.from_orm(x)) + return rankings_by_message + + def ensure_tree_states(self): + """Add message tree state rows for all root nodes (inital prompt messages).""" + + missing_tree_ids = self.query_misssing_tree_states() + for id in missing_tree_ids: + tree_size = self.db.query(func.count(Message.id)).filter(Message.message_tree_id == id).scalar() + state = message_tree_state.State.INITIAL_PROMPT_REVIEW + if tree_size > 1: + state = message_tree_state.State.GROWING + logger.info(f"Inserting missing message tree state for message: {id} ({tree_size=}, {state=})") + self._insert_default_state(id, state=state) + self.db.commit() + + def query_num_active_trees(self) -> int: + query = self.db.query(func.count(MessageTreeState.message_tree_id)).filter(MessageTreeState.active) + return query.scalar() + + def query_reviews_for_message(self, message_id: UUID) -> list[TextLabels]: + sql_qry = """ +SELECT tl.* +FROM task t + INNER JOIN text_labels tl ON tl.id = t.id +WHERE t.done = TRUE + AND tl.message_id = :message_id +""" + r = self.db.execute(text(sql_qry), {"message_id": message_id}) + return [TextLabels.from_orm(x) for x in r.all()] + + def _insert_tree_state( + self, + root_message_id: UUID, + goal_tree_size: int, + max_depth: int, + max_children_count: int, + active: bool, + state: message_tree_state.State = message_tree_state.State.INITIAL_PROMPT_REVIEW, + ) -> MessageTreeState: + model = MessageTreeState( + message_tree_id=root_message_id, + goal_tree_size=goal_tree_size, + max_depth=max_depth, + max_children_count=max_children_count, + state=state.value, + active=active, + accepted_messages=0, + ) + + self.db.add(model) + return model + + def _insert_default_state( + self, + root_message_id: UUID, + state: message_tree_state.State = message_tree_state.State.INITIAL_PROMPT_REVIEW, + ) -> MessageTreeState: + return self._insert_tree_state( + root_message_id=root_message_id, + goal_tree_size=self.cfg.goal_tree_size, + max_depth=self.cfg.max_tree_depth, + max_children_count=self.cfg.max_children_count, + state=state, + active=True, + ) + + +if __name__ == "__main__": + from oasst_backend.api.deps import get_dummy_api_client + from oasst_backend.database import engine + from oasst_backend.prompt_repository import PromptRepository + + with Session(engine) as db: + api_client = get_dummy_api_client(db) + dummy_user = protocol_schema.User(id="__dummy_user__", display_name="Dummy User", auth_method="local") + + pr = PromptRepository(db=db, api_client=api_client, client_user=dummy_user) + + cfg = TreeManagerConfiguration() + tm = TreeManager(db, pr, cfg) + tm.ensure_tree_states() + + print("query_num_active_trees", tm.query_num_active_trees()) + print("query_incomplete_rankings", tm.query_incomplete_rankings()) + print("query_incomplete_reply_reviews", tm.query_replies_need_review()) + print("query_incomplete_initial_prompt_reviews", tm.query_prompts_need_review()) + print("query_extendible_trees", tm.query_extendible_trees()) + print("query_extendible_parents", tm.query_extendible_parents()) + + print("next_task:", tm.next_task()) + + print( + ".query_tree_ranking_results", tm.query_tree_ranking_results(UUID("2ac20d38-6650-43aa-8bb3-f61080c0d921")) + ) diff --git a/backend/oasst_backend/utils/hugging_face.py b/backend/oasst_backend/utils/hugging_face.py index 87c6288e..099bc51f 100644 --- a/backend/oasst_backend/utils/hugging_face.py +++ b/backend/oasst_backend/utils/hugging_face.py @@ -51,12 +51,12 @@ class HuggingFaceAPI: async with session.post(self.api_url, headers=self.headers, json=payload) as response: # If we get a bad response - if response.status != 200: - + if not response.ok: logger.error(response) logger.info(self.headers) raise OasstError( - "Response Error Detoxify HuggingFace", error_code=OasstErrorCode.HUGGINGFACE_API_ERROR + f"Response Error HuggingFace API (Status: {response.status})", + error_code=OasstErrorCode.HUGGINGFACE_API_ERROR, ) # Get the response from the API call diff --git a/docker-compose.yaml b/docker-compose.yaml index 224c9efd..858acb68 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -99,6 +99,7 @@ services: - REDIS_HOST=redis - DEBUG_SKIP_API_KEY_CHECK=True - DEBUG_USE_SEED_DATA=True + - DEBUG_ALLOW_SELF_LABELING=True - MAX_WORKERS=1 - DEBUG_SKIP_EMBEDDING_COMPUTATION=True depends_on: diff --git a/oasst-shared/oasst_shared/exceptions/oasst_api_error.py b/oasst-shared/oasst_shared/exceptions/oasst_api_error.py index ce08d31d..c365eed2 100644 --- a/oasst-shared/oasst_shared/exceptions/oasst_api_error.py +++ b/oasst-shared/oasst_shared/exceptions/oasst_api_error.py @@ -17,9 +17,11 @@ class OasstErrorCode(IntEnum): GENERIC_ERROR = 0 DATABASE_URI_NOT_SET = 1 API_CLIENT_NOT_AUTHORIZED = 2 - SERVER_ERROR = 3 TOO_MANY_REQUESTS = 429 + SERVER_ERROR0 = 500 + SERVER_ERROR1 = 501 + # 1000-2000: tasks endpoint TASK_INVALID_REQUEST_TYPE = 1000 TASK_ACK_FAILED = 1001 @@ -27,6 +29,7 @@ class OasstErrorCode(IntEnum): TASK_INVALID_RESPONSE_TYPE = 1003 TASK_INTERACTION_REQUEST_FAILED = 1004 TASK_GENERATION_FAILED = 1005 + TASK_REQUESTED_TYPE_NOT_AVAILABLE = 1006 # 2000-3000: prompt_repository INVALID_FRONTEND_MESSAGE_ID = 2000 @@ -38,6 +41,14 @@ class OasstErrorCode(IntEnum): NO_MESSAGE_TREE_FOUND = 2006 NO_REPLIES_FOUND = 2007 INVALID_MESSAGE = 2008 + BROKEN_CONVERSATION = 2009 + TREE_NOT_IN_GROWING_STATE = 2010 + CORRUPT_RANKING_RESULT = 2011 + + TEXT_LABELS_WRONG_MESSAGE_ID = 2050 + TEXT_LABELS_INVALID_LABEL = 2051 + TEXT_LABELS_MANDATORY_LABEL_MISSING = 2052 + TASK_NOT_FOUND = 2100 TASK_EXPIRED = 2101 TASK_PAYLOAD_TYPE_MISMATCH = 2102 @@ -45,6 +56,7 @@ class OasstErrorCode(IntEnum): TASK_NOT_ACK = 2104 TASK_ALREADY_DONE = 2105 TASK_NOT_COLLECTIVE = 2106 + TASK_NOT_ASSIGNED_TO_USER = 2106 USER_NOT_FOUND = 2200 # 3000-4000: external resources diff --git a/oasst-shared/oasst_shared/schemas/protocol.py b/oasst-shared/oasst_shared/schemas/protocol.py index 3372cafa..374e8d26 100644 --- a/oasst-shared/oasst_shared/schemas/protocol.py +++ b/oasst-shared/oasst_shared/schemas/protocol.py @@ -151,7 +151,8 @@ class RankInitialPromptsTask(Task): """A task to rank a set of initial prompts.""" type: Literal["rank_initial_prompts"] = "rank_initial_prompts" - prompts: list[str] + prompts: list[str] # deprecated, use prompt_messages + prompt_messages: list[ConversationMessage] class RankConversationRepliesTask(Task): @@ -159,7 +160,8 @@ class RankConversationRepliesTask(Task): type: Literal["rank_conversation_replies"] = "rank_conversation_replies" conversation: Conversation # the conversation so far - replies: list[str] + replies: list[str] # deprecated, use reply_messages + reply_messages: list[ConversationMessage] class RankPrompterRepliesTask(RankConversationRepliesTask): @@ -181,6 +183,7 @@ class LabelInitialPromptTask(Task): message_id: UUID prompt: str valid_labels: list[str] + mandatory_labels: Optional[list[str]] class LabelConversationReplyTask(Task): @@ -191,6 +194,7 @@ class LabelConversationReplyTask(Task): message_id: UUID reply: str valid_labels: list[str] + mandatory_labels: Optional[list[str]] class LabelPrompterReplyTask(LabelConversationReplyTask): @@ -304,6 +308,7 @@ class TextLabels(Interaction): text: str labels: dict[TextLabel, float] message_id: UUID + task_id: Optional[UUID] @property def has_message_id(self) -> bool: diff --git a/scripts/backend-development/run-local.sh b/scripts/backend-development/run-local.sh index f0f6d16c..2433c67e 100755 --- a/scripts/backend-development/run-local.sh +++ b/scripts/backend-development/run-local.sh @@ -6,6 +6,7 @@ pushd "$parent_path/../../backend" export DEBUG_SKIP_API_KEY_CHECK=True export DEBUG_USE_SEED_DATA=True +export DEBUG_ALLOW_SELF_LABELING=True export DEBUG_SKIP_EMBEDDING_COMPUTATION=True uvicorn main:app --reload --port 8080 --host 0.0.0.0 diff --git a/text-frontend/__main__.py b/text-frontend/__main__.py index a6e5f947..b9234d4f 100644 --- a/text-frontend/__main__.py +++ b/text-frontend/__main__.py @@ -36,7 +36,7 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "DUMMY_KEY") return response.json() typer.echo("Requesting work...") - tasks = [_post("/api/v1/tasks/", {"type": "random"})] + tasks = [_post("/api/v1/tasks/", {"type": "random", "user": USER})] while tasks: task = tasks.pop(0) match (task["type"]): @@ -58,6 +58,7 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "DUMMY_KEY") { "type": "text_reply_to_message", "message_id": message_id, + "task_id": task["id"], "user_message_id": user_message_id, "text": summary, "user": USER, @@ -102,6 +103,7 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "DUMMY_KEY") { "type": "text_reply_to_message", "message_id": message_id, + "task_id": task["id"], "user_message_id": user_message_id, "text": prompt, "user": USER, @@ -150,6 +152,7 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "DUMMY_KEY") { "type": "text_reply_to_message", "message_id": message_id, + "task_id": task["id"], "user_message_id": user_message_id, "text": reply, "user": USER, @@ -200,6 +203,7 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "DUMMY_KEY") { "type": "message_ranking", "message_id": message_id, + "task_id": task["id"], "ranking": ranking, "user": USER, }, @@ -232,6 +236,7 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "DUMMY_KEY") { "type": "text_labels", "message_id": task["message_id"], + "task_id": task["id"], "text": task["prompt"], "labels": labels_dict, "user": USER, @@ -269,6 +274,7 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "DUMMY_KEY") { "type": "text_labels", "message_id": task["message_id"], + "task_id": task["id"], "text": task["reply"], "labels": labels_dict, "user": USER,