Message tree state machine (#555)

* add query_incomplete_rankings()

* Add SQL queries for TreeManager task selection

* first working version of TreeManager.next_task()

* remove old generate_task(), add mandatory_labels to text_labels task

* Add ConversationMessage list to Ranking tasks

* add more sophisticated sql queries to find extendible trees

* add TreeManager.query_extendible_parents()

* fix task validation, seed data insertion (reviewed)

* provide user for task selection in text-frontend

* enter 'growing' state

* enter 'aborted_low_grade' state

* enter 'ranking' state

* check tree 'growing' state upon relpy insertion

* exclude user from labeling their own messages (added DEBUG_ALLOW_SELF_LABELING setting)

* add DEBUG_ALLOW_SELF_LABELING to docker-compose.yaml

* fix ranking submission

* add query_tree_ranking_results()

* add ranked_message_ids to RankingReactionPayload

* fix reply_messages instead of prompt_messages

* incorment 'ranking_count' of ranked replies

* added logic to check_condition_for_scoring_state

* changes to msg_tree_state_machine

* pre-commit changes

* enter 'ready_for_scoring' state

* re-add HF embedding call (lost during merge)

* use prepare_conversation() helper for seed-data creation

* Partially add user specified task selection

Co-authored-by: Daniel Hug <danielpatrickhug@gmail.com>
This commit is contained in:
Andreas Köpf
2023-01-11 10:54:03 +01:00
committed by GitHub
parent 23ff01c603
commit 14fa08e2e7
19 changed files with 1212 additions and 323 deletions
+1
View File
@@ -79,6 +79,7 @@
REDIS_HOST: oasst-redis
DEBUG_ALLOW_ANY_API_KEY: "true"
DEBUG_USE_SEED_DATA: "true"
DEBUG_ALLOW_SELF_LABELING: "true"
MAX_WORKERS: "1"
RATE_LIMIT: "false"
DEBUG_SKIP_EMBEDDING_COMPUTATION: "true"
@@ -0,0 +1,63 @@
"""restructure message_tree_state table
Revision ID: 92a367bb9f40
Revises: ba61fe17fb6e
Create Date: 2023-01-08 22:08:46.458195
"""
import sqlalchemy as sa
import sqlmodel
from alembic import op
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "92a367bb9f40"
down_revision = "aac6b2f66006"
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table("message_tree_state")
op.create_table(
"message_tree_state",
sa.Column("message_tree_id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("goal_tree_size", sa.Integer(), nullable=False),
sa.Column("max_depth", sa.Integer(), nullable=False),
sa.Column("max_children_count", sa.Integer(), nullable=False),
sa.Column("state", sqlmodel.sql.sqltypes.AutoString(length=128), nullable=False),
sa.Column("active", sa.Boolean(), nullable=False),
sa.Column("accepted_messages", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(
["message_tree_id"],
["message.id"],
),
sa.PrimaryKeyConstraint("message_tree_id"),
)
op.create_index(op.f("ix_message_tree_state_active"), "message_tree_state", ["active"], unique=False)
op.create_index(op.f("ix_message_tree_state_state"), "message_tree_state", ["state"], unique=False)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index(op.f("ix_message_tree_state_state"), table_name="message_tree_state")
op.drop_index(op.f("ix_message_tree_state_active"), table_name="message_tree_state")
op.drop_table("message_tree_state")
op.create_table(
"message_tree_state",
sa.Column("id", postgresql.UUID(as_uuid=True), server_default=sa.text("gen_random_uuid()"), nullable=False),
sa.Column("message_tree_id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.Column("state", sqlmodel.sql.sqltypes.AutoString(length=128), nullable=False),
sa.Column("goal_tree_size", sa.Integer(), nullable=False),
sa.Column("current_num_non_filtered_messages", sa.Integer(), nullable=False),
sa.Column("max_depth", sa.Integer(), nullable=False),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
op.f("ix_message_tree_state_message_tree_id"), "message_tree_state", ["message_tree_id"], unique=False
)
op.create_index("ix_message_tree_state_tree_id", "message_tree_state", ["message_tree_id"], unique=True)
# ### end Alembic commands ###
@@ -0,0 +1,31 @@
"""add review_count & ranking_count to message
Revision ID: 05975b274a81
Revises: 92a367bb9f40
Create Date: 2023-01-09 00:47:25.496036
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = "05975b274a81"
down_revision = "92a367bb9f40"
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("message", sa.Column("review_count", sa.Integer(), server_default=sa.text("0"), nullable=False))
op.add_column("message", sa.Column("review_result", sa.Boolean(), server_default=sa.text("false"), nullable=False))
op.add_column("message", sa.Column("ranking_count", sa.Integer(), server_default=sa.text("0"), nullable=False))
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("message", "ranking_count")
op.drop_column("message", "review_result")
op.drop_column("message", "review_count")
# ### end Alembic commands ###
+29 -13
View File
@@ -12,9 +12,12 @@ from fastapi_limiter import FastAPILimiter
from loguru import logger
from oasst_backend.api.deps import get_dummy_api_client
from oasst_backend.api.v1.api import api_router
from oasst_backend.api.v1.utils import prepare_conversation
from oasst_backend.config import settings
from oasst_backend.database import engine
from oasst_backend.models import message_tree_state
from oasst_backend.prompt_repository import PromptRepository, TaskRepository, UserRepository
from oasst_backend.tree_manager import TreeManager, TreeManagerConfiguration
from oasst_shared.exceptions import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
from pydantic import BaseModel
@@ -116,6 +119,7 @@ if settings.DEBUG_USE_SEED_DATA:
pr = PromptRepository(
db=db, api_client=api_client, client_user=dummy_user, user_repository=ur, task_repository=tr
)
tm = TreeManager(db, pr, TreeManagerConfiguration())
with open(settings.DEBUG_USE_SEED_DATA_PATH) as f:
dummy_messages_raw = json.load(f)
@@ -138,24 +142,19 @@ if settings.DEBUG_USE_SEED_DATA:
msg.parent_message_id, fail_if_missing=True
)
conversation_messages = pr.fetch_message_conversation(parent_message)
conversation = protocol_schema.Conversation(
messages=[
protocol_schema.ConversationMessage(
text=cmsg.text,
is_assistant=cmsg.role == "assistant",
message_id=cmsg.id,
fronend_message_id=cmsg.frontend_message_id,
)
for cmsg in conversation_messages
]
)
conversation = prepare_conversation(conversation_messages)
task = tr.store_task(
protocol_schema.AssistantReplyTask(conversation=conversation),
message_tree_id=parent_message.message_tree_id,
parent_message_id=parent_message.id,
)
tr.bind_frontend_message_id(task.id, msg.task_message_id)
message = pr.store_text_reply(msg.text, msg.task_message_id, msg.user_message_id)
message = pr.store_text_reply(
msg.text, msg.task_message_id, msg.user_message_id, review_count=5, review_result=True
)
if message.parent_id is None:
tm._insert_default_state(root_message_id=message.id, state=message_tree_state.State.GROWING)
db.commit()
logger.info(
f"Inserted: message_id: {message.id}, payload: {message.payload.payload}, parent_message_id: {message.parent_id}"
@@ -168,6 +167,19 @@ if settings.DEBUG_USE_SEED_DATA:
logger.exception("Seed data insertion failed")
@app.on_event("startup")
def ensure_tree_states():
try:
logger.info("Startup: TreeManager.ensure_tree_states()")
cfg = TreeManagerConfiguration() # TODO: decide where config is stored, e.g. load form json/yaml file
with Session(engine) as db:
tm = TreeManager(db, None, configuration=cfg)
tm.ensure_tree_states()
except Exception:
logger.exception("TreeManager.ensure_tree_states() failed.")
app.include_router(api_router, prefix=settings.API_V1_STR)
@@ -175,7 +187,7 @@ def get_openapi_schema():
return json.dumps(app.openapi())
if __name__ == "__main__":
def main():
# Importing here so we don't import packages unnecessarily if we're
# importing main as a module.
import argparse
@@ -198,3 +210,7 @@ if __name__ == "__main__":
print(get_openapi_schema())
else:
uvicorn.run(app, host=args.host, port=args.port)
if __name__ == "__main__":
main()
+8 -216
View File
@@ -1,15 +1,12 @@
import random
from typing import Any, Optional, Tuple
from typing import Any
from uuid import UUID
from fastapi import APIRouter, Depends
from fastapi.security.api_key import APIKey
from loguru import logger
from oasst_backend.api import deps
from oasst_backend.api.v1.utils import prepare_conversation
from oasst_backend.config import settings
from oasst_backend.prompt_repository import PromptRepository, TaskRepository
from oasst_backend.utils.hugging_face import HfEmbeddingModel, HfUrl, HuggingFaceAPI
from oasst_backend.tree_manager import TreeManager, TreeManagerConfiguration
from oasst_shared.exceptions import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
from sqlmodel import Session
@@ -18,160 +15,6 @@ from starlette.status import HTTP_204_NO_CONTENT
router = APIRouter()
def generate_task(
request: protocol_schema.TaskRequest, pr: PromptRepository
) -> Tuple[protocol_schema.Task, Optional[UUID], Optional[UUID]]:
message_tree_id = None
parent_message_id = None
match request.type:
case protocol_schema.TaskRequestType.random:
logger.info("Frontend requested a random task.")
disabled_tasks = (
protocol_schema.TaskRequestType.random,
protocol_schema.TaskRequestType.summarize_story,
protocol_schema.TaskRequestType.rate_summary,
)
candidate_tasks = set(protocol_schema.TaskRequestType).difference(disabled_tasks)
request.type = random.choice(tuple(candidate_tasks)).value
return generate_task(request, pr)
# AKo: Summary tasks are currently disabled/supported, we focus on the conversation tasks.
# case protocol_schema.TaskRequestType.summarize_story:
# logger.info("Generating a SummarizeStoryTask.")
# task = protocol_schema.SummarizeStoryTask(
# story="This is a story. A very long story. So long, it needs to be summarized.",
# )
# case protocol_schema.TaskRequestType.rate_summary:
# logger.info("Generating a RateSummaryTask.")
# task = protocol_schema.RateSummaryTask(
# full_text="This is a story. A very long story. So long, it needs to be summarized.",
# summary="This is a summary.",
# scale=protocol_schema.RatingScale(min=1, max=5),
# )
case protocol_schema.TaskRequestType.initial_prompt:
logger.info("Generating an InitialPromptTask.")
task = protocol_schema.InitialPromptTask(
hint="Ask the assistant about a current event." # this is optional
)
case protocol_schema.TaskRequestType.prompter_reply:
logger.info("Generating a PrompterReplyTask.")
messages = pr.fetch_random_conversation("assistant")
task_messages = [
protocol_schema.ConversationMessage(
text=msg.text,
is_assistant=(msg.role == "assistant"),
message_id=msg.id,
front_end_id=msg.frontend_message_id,
)
for msg in messages
]
task = protocol_schema.PrompterReplyTask(conversation=protocol_schema.Conversation(messages=task_messages))
message_tree_id = messages[-1].message_tree_id
parent_message_id = messages[-1].id
case protocol_schema.TaskRequestType.assistant_reply:
logger.info("Generating a AssistantReplyTask.")
messages = pr.fetch_random_conversation("prompter")
task_messages = [
protocol_schema.ConversationMessage(
text=msg.text,
is_assistant=(msg.role == "assistant"),
message_id=msg.id,
front_end_id=msg.frontend_message_id,
)
for msg in messages
]
task = protocol_schema.AssistantReplyTask(conversation=protocol_schema.Conversation(messages=task_messages))
message_tree_id = messages[-1].message_tree_id
parent_message_id = messages[-1].id
case protocol_schema.TaskRequestType.rank_initial_prompts:
logger.info("Generating a RankInitialPromptsTask.")
messages = pr.fetch_random_initial_prompts()
task = protocol_schema.RankInitialPromptsTask(prompts=[msg.text for msg in messages])
case protocol_schema.TaskRequestType.rank_prompter_replies:
logger.info("Generating a RankPrompterRepliesTask.")
conversation, replies = pr.fetch_multiple_random_replies(message_role="assistant")
task_messages = [
protocol_schema.ConversationMessage(
text=p.text,
is_assistant=(p.role == "assistant"),
message_id=p.id,
front_end_id=p.frontend_message_id,
)
for p in conversation
]
replies = [p.text for p in replies]
task = protocol_schema.RankPrompterRepliesTask(
conversation=protocol_schema.Conversation(
messages=task_messages,
),
replies=replies,
)
case protocol_schema.TaskRequestType.rank_assistant_replies:
logger.info("Generating a RankAssistantRepliesTask.")
conversation, replies = pr.fetch_multiple_random_replies(message_role="prompter")
task_messages = [
protocol_schema.ConversationMessage(
text=p.text,
is_assistant=(p.role == "assistant"),
message_id=p.id,
front_end_id=p.frontend_message_id,
)
for p in conversation
]
replies = [p.text for p in replies]
task = protocol_schema.RankAssistantRepliesTask(
conversation=prepare_conversation(conversation),
replies=replies,
)
case protocol_schema.TaskRequestType.label_initial_prompt:
logger.info("Generating a LabelInitialPromptTask.")
message = pr.fetch_random_initial_prompts(1)[0]
task = protocol_schema.LabelInitialPromptTask(
message_id=message.id,
prompt=message.text,
valid_labels=list(map(lambda x: x.value, protocol_schema.TextLabel)),
)
case protocol_schema.TaskRequestType.label_prompter_reply:
logger.info("Generating a LabelPrompterReplyTask.")
conversation, messages = pr.fetch_multiple_random_replies(max_size=1, message_role="assistant")
message = messages[0]
task = protocol_schema.LabelPrompterReplyTask(
message_id=message.id,
conversation=prepare_conversation(conversation),
reply=message.text,
valid_labels=list(map(lambda x: x.value, protocol_schema.TextLabel)),
)
case protocol_schema.TaskRequestType.label_assistant_reply:
logger.info("Generating a LabelAssistantReplyTask.")
conversation, messages = pr.fetch_multiple_random_replies(max_size=1, message_role="prompter")
message = messages[0]
task = protocol_schema.LabelAssistantReplyTask(
message_id=message.id,
conversation=prepare_conversation(conversation),
reply=message.text,
valid_labels=list(map(lambda x: x.value, protocol_schema.TextLabel)),
)
case _:
raise OasstError("Invalid request type", OasstErrorCode.TASK_INVALID_REQUEST_TYPE)
logger.info(f"Generated {task=}.")
return task, message_tree_id, parent_message_id
@router.post(
"/",
response_model=protocol_schema.AnyTask,
@@ -193,7 +36,9 @@ def request_task(
try:
pr = PromptRepository(db, api_client, client_user=request.user)
task, message_tree_id, parent_message_id = generate_task(request, pr)
tree_manager_config = TreeManagerConfiguration()
tm = TreeManager(db, pr, tree_manager_config)
task, message_tree_id, parent_message_id = tm.next_task(request.type)
pr.task_repository.store_task(task, message_tree_id, parent_message_id, request.collective)
except OasstError:
@@ -268,63 +113,10 @@ async def tasks_interaction(
try:
pr = PromptRepository(db, api_client, client_user=interaction.user)
tree_manager_config = TreeManagerConfiguration()
tm = TreeManager(db, pr, tree_manager_config)
return await tm.handle_interaction(interaction)
match type(interaction):
case protocol_schema.TextReplyToMessage:
logger.info(
f"Frontend reports text reply to {interaction.message_id=} with {interaction.text=} by {interaction.user=}."
)
# here we store the text reply in the database
newMessage = pr.store_text_reply(
text=interaction.text,
frontend_message_id=interaction.message_id,
user_frontend_message_id=interaction.user_message_id,
)
if not settings.DEBUG_SKIP_EMBEDDING_COMPUTATION:
try:
hugging_face_api = HuggingFaceAPI(
f"{HfUrl.HUGGINGFACE_FEATURE_EXTRACTION.value}/{HfEmbeddingModel.MINILM.value}"
)
embedding = await hugging_face_api.post(interaction.text)
pr.insert_message_embedding(
message_id=newMessage.id, model=HfEmbeddingModel.MINILM.value, embedding=embedding
)
except OasstError:
logger.error(
f"Could not fetch embbeddings for text reply to {interaction.message_id=} with {interaction.text=} by {interaction.user=}."
)
return protocol_schema.TaskDone()
case protocol_schema.MessageRating:
logger.info(
f"Frontend reports rating of {interaction.message_id=} with {interaction.rating=} by {interaction.user=}."
)
# here we store the rating in the database
pr.store_rating(interaction)
return protocol_schema.TaskDone()
case protocol_schema.MessageRanking:
logger.info(
f"Frontend reports ranking of {interaction.message_id=} with {interaction.ranking=} by {interaction.user=}."
)
# TODO: check if the ranking is valid
pr.store_ranking(interaction)
# here we would store the ranking in the database
return protocol_schema.TaskDone()
case protocol_schema.TextLabels:
logger.info(
f"Frontend reports labels of {interaction.message_id=} with {interaction.labels=} by {interaction.user=}."
)
# Labels are implicitly validated when converting str -> TextLabel
# So no need for explicit validation here
pr.store_text_labels(interaction)
return protocol_schema.TaskDone()
case _:
raise OasstError("Invalid response type.", OasstErrorCode.TASK_INVALID_RESPONSE_TYPE)
except OasstError:
raise
except Exception:
+12 -11
View File
@@ -18,19 +18,20 @@ def prepare_message_list(messages: list[Message]) -> list[protocol.Message]:
return [prepare_message(m) for m in messages]
def prepare_conversation(messages: list[Message]) -> protocol.Conversation:
conv_messages = []
for message in messages:
conv_messages.append(
protocol.ConversationMessage(
text=message.text,
is_assistant=(message.role == "assistant"),
message_id=message.id,
frontend_message_id=message.frontend_message_id,
)
def prepare_conversation_message_list(messages: list[Message]) -> list[protocol.ConversationMessage]:
return [
protocol.ConversationMessage(
text=message.text,
is_assistant=(message.role == "assistant"),
message_id=message.id,
frontend_message_id=message.frontend_message_id,
)
for message in messages
]
return protocol.Conversation(messages=conv_messages)
def prepare_conversation(messages: list[Message]) -> protocol.Conversation:
return protocol.Conversation(messages=prepare_conversation_message_list(messages))
def prepare_tree(tree: list[Message], tree_id: UUID) -> protocol.MessageTree:
+1
View File
@@ -25,6 +25,7 @@ class Settings(BaseSettings):
DEBUG_USE_SEED_DATA_PATH: Optional[FilePath] = (
Path(__file__).parent.parent / "test_data/generic/test_generic_data.json"
)
DEBUG_ALLOW_SELF_LABELING: bool = False # allow users to label their own messages
DEBUG_SKIP_EMBEDDING_COMPUTATION: bool = False
HUGGING_FACE_API_KEY: str = ""
+7 -4
View File
@@ -1,4 +1,4 @@
from typing import Literal
from typing import Literal, Optional
from uuid import UUID
from oasst_backend.models.payload_column_type import payload_type
@@ -28,7 +28,7 @@ class RateSummaryPayload(TaskPayload):
@payload_type
class InitialPromptPayload(TaskPayload):
type: Literal["initial_prompt"] = "initial_prompt"
hint: str
hint: str | None
@payload_type
@@ -64,12 +64,13 @@ class RatingReactionPayload(ReactionPayload):
class RankingReactionPayload(ReactionPayload):
type: Literal["message_ranking"] = "message_ranking"
ranking: list[int]
ranked_message_ids: list[UUID]
@payload_type
class RankConversationRepliesPayload(TaskPayload):
conversation: protocol_schema.Conversation # the conversation so far
replies: list[str]
reply_messages: list[protocol_schema.ConversationMessage]
@payload_type
@@ -77,7 +78,7 @@ class RankInitialPromptsPayload(TaskPayload):
"""A task to rank a set of initial prompts."""
type: Literal["rank_initial_prompts"] = "rank_initial_prompts"
prompts: list[str]
prompt_messages: list[protocol_schema.ConversationMessage]
@payload_type
@@ -102,6 +103,7 @@ class LabelInitialPromptPayload(TaskPayload):
message_id: UUID
prompt: str
valid_labels: list[str]
mandatory_labels: Optional[list[str]]
@payload_type
@@ -112,6 +114,7 @@ class LabelConversationReplyPayload(TaskPayload):
conversation: protocol_schema.Conversation
reply: str
valid_labels: list[str]
mandatory_labels: Optional[list[str]]
@payload_type
+4
View File
@@ -41,6 +41,10 @@ class Message(SQLModel, table=True):
children_count: int = Field(sa_column=sa.Column(sa.Integer, default=0, server_default=sa.text("0"), nullable=False))
deleted: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=False, server_default=false()))
review_count: int = Field(sa_column=sa.Column(sa.Integer, default=0, server_default=sa.text("0"), nullable=False))
review_result: bool = Field(sa_column=sa.Column(sa.Boolean, default=False, server_default=false(), nullable=False))
ranking_count: int = Field(sa_column=sa.Column(sa.Integer, default=0, server_default=sa.text("0"), nullable=False))
def ensure_is_message(self) -> None:
if not self.payload or not isinstance(self.payload.payload, MessagePayload):
raise OasstError("Invalid message", OasstErrorCode.INVALID_MESSAGE, HTTPStatus.INTERNAL_SERVER_ERROR)
@@ -1,29 +1,28 @@
from enum import Enum
from typing import Optional
from uuid import UUID, uuid4
from uuid import UUID
import sqlalchemy as sa
import sqlalchemy.dialects.postgresql as pg
from sqlmodel import Field, Index, SQLModel
from sqlmodel import Field, SQLModel
class States(str, Enum):
class State(str, Enum):
"""States of the Open-Assistant message tree state machine."""
INITIAL_PROMPT_REVIEW = "initial_prompt_review"
"""In this state the message tree consists only of a single inital prompt root node.
Initial prompt labeling tasks will determine if the tree goes into `breeding_phase` or
`aborted_low_grade`."""
Initial prompt labeling tasks will determine if the tree goes into `growing` or
`aborted_low_grade` state."""
BREEDING_PHASE = "breeding_phase"
GROWING = "growing"
"""Assistant & prompter human demonstrations are collected. Concurrently labeling tasks
are handed out to check if the quality of the replies surpasses the minimum acceptable
quality.
When the required number of messages passing the initial labelling-quality check has been
collected the tree will enter `ranking_phase`. If too many poor-quality labelling responses
collected the tree will enter `ranking`. If too many poor-quality labelling responses
are received the tree can also enter the `aborted_low_grade` state."""
RANKING_PHASE = "ranking_phase"
RANKING = "ranking"
"""The tree has been successfully populated with the desired number of messages. Ranking
tasks are now handed out for all nodes with more than one child."""
@@ -46,28 +45,26 @@ class States(str, Enum):
VALID_STATES = (
States.INITIAL_PROMPT_REVIEW,
States.BREEDING_PHASE,
States.RANKING_PHASE,
States.READY_FOR_SCORING,
States.READY_FOR_EXPORT,
States.ABORTED_LOW_GRADE,
State.INITIAL_PROMPT_REVIEW,
State.GROWING,
State.RANKING,
State.READY_FOR_SCORING,
State.READY_FOR_EXPORT,
State.ABORTED_LOW_GRADE,
)
TERMINAL_STATES = (States.READY_FOR_EXPORT, States.ABORTED_LOW_GRADE, States.SCORING_FAILED, States.HALTED_BY_MODERATOR)
TERMINAL_STATES = (State.READY_FOR_EXPORT, State.ABORTED_LOW_GRADE, State.SCORING_FAILED, State.HALTED_BY_MODERATOR)
class MessageTreeState(SQLModel, table=True):
__tablename__ = "message_tree_state"
__table_args__ = (Index("ix_message_tree_state_tree_id", "message_tree_id", unique=True),)
id: Optional[UUID] = Field(
sa_column=sa.Column(
pg.UUID(as_uuid=True), primary_key=True, default=uuid4, server_default=sa.text("gen_random_uuid()")
),
message_tree_id: UUID = Field(
sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("message.id"), primary_key=True)
)
message_tree_id: UUID = Field(nullable=False, index=True)
state: str = Field(nullable=False, max_length=128)
goal_tree_size: int = Field(nullable=False)
current_num_non_filtered_messages: int = Field(nullable=False)
max_depth: int = Field(nullable=False)
max_children_count: int = Field(nullable=False)
state: str = Field(nullable=False, max_length=128, index=True)
active: bool = Field(nullable=False, index=True)
accepted_messages: int = Field(nullable=False, default=0)
+185 -42
View File
@@ -2,13 +2,24 @@ import datetime
import random
from collections import defaultdict
from http import HTTPStatus
from typing import List, Optional
from typing import List, Optional, Tuple
from uuid import UUID, uuid4
import oasst_backend.models.db_payload as db_payload
import sqlalchemy as sa
from loguru import logger
from oasst_backend.journal_writer import JournalWriter
from oasst_backend.models import ApiClient, Message, MessageEmbedding, MessageReaction, TextLabels, User
from oasst_backend.models import (
ApiClient,
Message,
MessageEmbedding,
MessageReaction,
MessageTreeState,
Task,
TextLabels,
User,
message_tree_state,
)
from oasst_backend.models.payload_column_type import PayloadContainer
from oasst_backend.task_repository import TaskRepository, validate_frontend_message_id
from oasst_backend.user_repository import UserRepository
@@ -34,6 +45,7 @@ class PromptRepository:
self.user_repository = user_repository or UserRepository(db, api_client)
self.user = self.user_repository.lookup_client_user(client_user, create_missing=True)
self.user_id = self.user.id if self.user else None
logger.debug(f"PromptRepository(api_client_id={self.api_client.id}, {self.user_id=})")
self.task_repository = task_repository or TaskRepository(
db, api_client, client_user, user_repository=self.user_repository
)
@@ -66,6 +78,8 @@ class PromptRepository:
payload: db_payload.MessagePayload,
payload_type: str = None,
depth: int = 0,
review_count: int = 0,
review_result: bool = False,
) -> Message:
if payload_type is None:
if payload is None:
@@ -85,25 +99,24 @@ class PromptRepository:
payload_type=payload_type,
payload=PayloadContainer(payload=payload),
depth=depth,
review_count=review_count,
review_result=review_result,
)
self.db.add(message)
self.db.commit()
self.db.refresh(message)
return message
def store_text_reply(
self,
text: str,
frontend_message_id: str,
user_frontend_message_id: str,
) -> Message:
validate_frontend_message_id(frontend_message_id)
validate_frontend_message_id(user_frontend_message_id)
task = self.task_repository.fetch_task_by_frontend_message_id(frontend_message_id)
def _validate_task(
self, task: Task, *, task_id: Optional[UUID] = None, frontend_message_id: Optional[str] = None
) -> Task:
if task is None:
raise OasstError(f"Task for {frontend_message_id=} not found", OasstErrorCode.TASK_NOT_FOUND)
if task_id:
raise OasstError(f"Task for {task_id=} not found", OasstErrorCode.TASK_NOT_FOUND)
if frontend_message_id:
raise OasstError(f"Task for {frontend_message_id=} not found", OasstErrorCode.TASK_NOT_FOUND)
raise OasstError("Task not found", OasstErrorCode.TASK_NOT_FOUND)
if task.expired:
raise OasstError("Task already expired.", OasstErrorCode.TASK_EXPIRED)
if not task.ack:
@@ -111,12 +124,45 @@ class PromptRepository:
if task.done:
raise OasstError("Task already done.", OasstErrorCode.TASK_ALREADY_DONE)
if (not task.collective or task.user_id is None) and task.user_id != self.user_id:
logger.warning(f"Task was assigned to a different user (expected: {task.user_id}; actual: {self.user_id}).")
raise OasstError("Task was assigned to a different user.", OasstErrorCode.TASK_NOT_ASSIGNED_TO_USER)
return task
def fetch_tree_state(self, message_tree_id: UUID) -> MessageTreeState:
return self.db.query(MessageTreeState).filter(MessageTreeState.message_tree_id == message_tree_id).one()
def store_text_reply(
self,
text: str,
frontend_message_id: str,
user_frontend_message_id: str,
review_count: int = 0,
review_result: bool = False,
) -> Message:
validate_frontend_message_id(frontend_message_id)
validate_frontend_message_id(user_frontend_message_id)
task = self.task_repository.fetch_task_by_frontend_message_id(frontend_message_id)
self._validate_task(task)
# If there's no parent message assume user started new conversation
role = "prompter"
depth = 0
if task.parent_message_id:
parent_message = self.fetch_message(task.parent_message_id)
# check tree state
ts = self.fetch_tree_state(parent_message.message_tree_id)
if not ts.active or ts.state != message_tree_state.State.GROWING:
raise OasstError(
"Message insertion failed. Message tree is no longer in 'growing' state.",
OasstErrorCode.TREE_NOT_IN_GROWING_STATE,
)
parent_message.message_tree_id
parent_message.children_count += 1
self.db.add(parent_message)
@@ -137,6 +183,8 @@ class PromptRepository:
role=role,
payload=db_payload.MessagePayload(text=text),
depth=depth,
review_count=review_count,
review_result=review_result,
)
if not task.collective:
task.done = True
@@ -149,6 +197,7 @@ class PromptRepository:
message = self.fetch_message_by_frontend_message_id(rating.message_id, fail_if_missing=True)
task = self.task_repository.fetch_task_by_frontend_message_id(rating.message_id)
self._validate_task(task)
task_payload: db_payload.RateSummaryPayload = task.payload.payload
if type(task_payload) != db_payload.RateSummaryPayload:
raise OasstError(
@@ -173,9 +222,10 @@ class PromptRepository:
logger.info(f"Ranking {rating.rating} stored for task {task.id}.")
return reaction
def store_ranking(self, ranking: protocol_schema.MessageRanking) -> MessageReaction:
def store_ranking(self, ranking: protocol_schema.MessageRanking) -> Tuple[MessageReaction, Task]:
# fetch task
task = self.task_repository.fetch_task_by_frontend_message_id(ranking.message_id)
self._validate_task(task, frontend_message_id=ranking.message_id)
if not task.collective:
task.done = True
self.db.add(task)
@@ -188,47 +238,59 @@ class PromptRepository:
case db_payload.RankPrompterRepliesPayload | db_payload.RankAssistantRepliesPayload:
# validate ranking
num_replies = len(task_payload.replies)
if sorted(ranking.ranking) != list(range(num_replies)):
if sorted(ranking.ranking) != list(range(num_replies := len(task_payload.reply_messages))):
raise OasstError(
f"Invalid ranking submitted. Each reply index must appear exactly once ({num_replies=}).",
OasstErrorCode.INVALID_RANKING_VALUE,
)
last_conv_message = task_payload.conversation.messages[-1]
parent_msg = self.fetch_message(last_conv_message.message_id)
# store reaction to message
reaction_payload = db_payload.RankingReactionPayload(ranking=ranking.ranking)
ranked_message_ids = [task_payload.reply_messages[i].message_id for i in ranking.ranking]
for mid in ranked_message_ids:
message = self.fetch_message(mid)
if message.parent_id != parent_msg.id:
raise OasstError("Corrupt reply ranking result", OasstErrorCode.CORRUPT_RANKING_RESULT)
message.ranking_count += 1
self.db.add(message)
reaction_payload = db_payload.RankingReactionPayload(
ranking=ranking.ranking, ranked_message_ids=ranked_message_ids
)
reaction = self.insert_reaction(task.id, reaction_payload)
# TODO: resolve message_id
self.journal.log_ranking(task, message_id=None, ranking=ranking.ranking)
self.journal.log_ranking(task, message_id=parent_msg.id, ranking=ranking.ranking)
logger.info(f"Ranking {ranking.ranking} stored for task {task.id}.")
return reaction
case db_payload.RankInitialPromptsPayload:
# validate ranking
if sorted(ranking.ranking) != list(range(num_prompts := len(task_payload.prompts))):
if sorted(ranking.ranking) != list(range(num_prompts := len(task_payload.prompt_messages))):
raise OasstError(
f"Invalid ranking submitted. Each reply index must appear exactly once ({num_prompts=}).",
OasstErrorCode.INVALID_RANKING_VALUE,
)
# store reaction to message
reaction_payload = db_payload.RankingReactionPayload(ranking=ranking.ranking)
ranked_message_ids = [task_payload.prompt_messages[i].message_id for i in ranking.ranking]
reaction_payload = db_payload.RankingReactionPayload(
ranking=ranking.ranking, ranked_message_ids=ranked_message_ids
)
reaction = self.insert_reaction(task.id, reaction_payload)
# TODO: resolve message_id
self.journal.log_ranking(task, message_id=None, ranking=ranking.ranking)
logger.info(f"Ranking {ranking.ranking} stored for task {task.id}.")
return reaction
case _:
raise OasstError(
f"task payload type mismatch: {type(task_payload)=} != {db_payload.RankConversationRepliesPayload}",
OasstErrorCode.TASK_PAYLOAD_TYPE_MISMATCH,
)
return reaction, task
def insert_message_embedding(self, message_id: UUID, model: str, embedding: List[float]) -> MessageEmbedding:
"""Insert the embedding of a new message in the database.
@@ -270,21 +332,80 @@ class PromptRepository:
self.db.refresh(reaction)
return reaction
def store_text_labels(self, text_labels: protocol_schema.TextLabels) -> TextLabels:
def store_text_labels(self, text_labels: protocol_schema.TextLabels) -> Tuple[TextLabels, Task, Message]:
valid_labels: Optional[list[str]] = None
mandatory_labels: Optional[list[str]] = None
text_labels_id: Optional[UUID] = None
message_id: Optional[UUID] = text_labels.message_id
task: Task = None
if text_labels.task_id:
logger.debug(f"text_labels reply has task_id {text_labels.task_id}")
task = self.task_repository.fetch_task_by_id(text_labels.task_id)
self._validate_task(task, task_id=text_labels.task_id)
task_payload: db_payload.TaskPayload = task.payload.payload
if isinstance(task_payload, db_payload.LabelInitialPromptPayload):
if message_id and task_payload.message_id != message_id:
raise OasstError("Task message id mismatch", OasstErrorCode.TEXT_LABELS_WRONG_MESSAGE_ID)
message_id = task_payload.message_id
valid_labels = task_payload.valid_labels
mandatory_labels = task_payload.mandatory_labels
elif isinstance(task_payload, db_payload.LabelConversationReplyPayload):
if message_id and message_id != message_id:
raise OasstError("Task message id mismatch", OasstErrorCode.TEXT_LABELS_WRONG_MESSAGE_ID)
message_id = task_payload.message_id
valid_labels = task_payload.valid_labels
mandatory_labels = task_payload.mandatory_labels
else:
raise OasstError(
"Unexpected text_labels task payload",
OasstErrorCode.TASK_PAYLOAD_TYPE_MISMATCH,
)
logger.debug(f"text_labels relpy: {valid_labels=}, {mandatory_labels=}")
if valid_labels:
if not all([label in valid_labels for label in text_labels.labels.keys()]):
raise OasstError("Invalid text label specified", OasstErrorCode.TEXT_LABELS_INVALID_LABEL)
if isinstance(mandatory_labels, list):
mandatory_set = set(mandatory_labels)
if not mandatory_set.issubset(text_labels.labels.keys()):
missing = ", ".join(mandatory_set - text_labels.labels.keys())
raise OasstError(
f"Mandatory text labels missing: {missing}", OasstErrorCode.TEXT_LABELS_MANDATORY_LABEL_MISSING
)
text_labels_id = task.id # associate with task by sharing the id
if not task.collective:
task.done = True
self.db.add(task)
logger.debug(f"inserting TextLabels for {message_id=}, {text_labels_id=}")
model = TextLabels(
id=text_labels_id,
api_client_id=self.api_client.id,
message_id=text_labels.message_id,
message_id=message_id,
user_id=self.user_id,
text=text_labels.text,
labels=text_labels.labels,
)
if message_id:
message = self.fetch_message(message_id)
if task:
message.review_count += 1
self.db.add(message)
self.db.add(model)
self.db.commit()
self.db.refresh(model)
return model
return model, task, message
def fetch_random_message_tree(self, require_role: str = None) -> list[Message]:
def fetch_random_message_tree(self, require_role: str = None, reviewed: bool = True) -> list[Message]:
"""
Loads all messages of a random message_tree.
@@ -294,13 +415,18 @@ class PromptRepository:
distinct_message_trees = self.db.query(Message.message_tree_id).distinct(Message.message_tree_id)
if require_role:
distinct_message_trees = distinct_message_trees.filter(Message.role == require_role)
if reviewed:
distinct_message_trees = distinct_message_trees.filter(Message.review_result)
distinct_message_trees = distinct_message_trees.subquery()
random_message_tree = self.db.query(distinct_message_trees).order_by(func.random()).limit(1)
message_tree_messages = self.db.query(Message).filter(Message.message_tree_id.in_(random_message_tree)).all()
return message_tree_messages
random_message_tree_id = self.db.query(distinct_message_trees).order_by(func.random()).limit(1).scalar()
if random_message_tree_id:
return self.fetch_message_tree(random_message_tree_id, reviewed)
return None
def fetch_random_conversation(self, last_message_role: str = None) -> list[Message]:
def fetch_random_conversation(
self, last_message_role: str = None, message_tree_id: Optional[UUID] = None, reviewed: bool = True
) -> list[Message]:
"""
Picks a random linear conversation starting from any root message
and ending somewhere in the message_tree, possibly at the root itself.
@@ -310,9 +436,13 @@ class PromptRepository:
the user should reply as a human and hence the last message of the conversation
needs to have "assistant" role.
"""
messages_tree = self.fetch_random_message_tree(last_message_role)
if message_tree_id:
messages_tree = self.fetch_message_tree(message_tree_id, reviewed)
else:
messages_tree = self.fetch_random_message_tree(last_message_role)
if not messages_tree:
raise OasstError("No message tree found", OasstErrorCode.NO_MESSAGE_TREE_FOUND)
if last_message_role:
conv_messages = [m for m in messages_tree if m.role == last_message_role]
conv_messages = [random.choice(conv_messages)]
@@ -334,8 +464,11 @@ class PromptRepository:
messages = self.db.query(Message).filter(Message.parent_id.is_(None)).order_by(func.random()).limit(size).all()
return messages
def fetch_message_tree(self, message_tree_id: UUID):
return self.db.query(Message).filter(Message.message_tree_id == message_tree_id).all()
def fetch_message_tree(self, message_tree_id: UUID, reviewed: bool = True):
qry = self.db.query(Message).filter(Message.message_tree_id == message_tree_id)
if reviewed:
qry = qry.filter(Message.review_result)
return qry.all()
def fetch_multiple_random_replies(self, max_size: int = 5, message_role: str = None):
"""
@@ -388,13 +521,15 @@ class PromptRepository:
messages = {m.id: m for m in messages}
if not isinstance(messages, dict):
# This should not normally happen
raise OasstError("Server error", OasstErrorCode.SERVER_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR)
raise OasstError("Server error", OasstErrorCode.SERVER_ERROR0, HTTPStatus.INTERNAL_SERVER_ERROR)
conv = [last_message]
while conv[-1].parent_id:
if conv[-1].parent_id not in messages:
# Can't form a continuous conversation
raise OasstError("Server error", OasstErrorCode.SERVER_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR)
raise OasstError(
"Broken conversation", OasstErrorCode.BROKEN_CONVERSATION, HTTPStatus.INTERNAL_SERVER_ERROR
)
parent_message = messages[conv[-1].parent_id]
conv.append(parent_message)
@@ -417,16 +552,24 @@ class PromptRepository:
"""
if isinstance(message, UUID):
message = self.fetch_message(message)
logger.debug(f"fetch_message_tree({message.message_tree_id=})")
return self.fetch_message_tree(message.message_tree_id)
def fetch_message_children(self, message: Message | UUID) -> list[Message]:
def fetch_message_children(
self, message: Message | UUID, reviewed: bool = True, exclude_deleted: bool = True
) -> list[Message]:
"""
Get all direct children of this message
"""
if isinstance(message, Message):
message = message.id
children = self.db.query(Message).filter(Message.parent_id == message).all()
qry = self.db.query(Message).filter(Message.parent_id == message)
if reviewed:
qry = qry.filter(Message.review_result)
if exclude_deleted:
qry = qry.filter(Message.deleted == sa.false())
children = qry.all()
return children
@staticmethod
@@ -536,7 +679,7 @@ class PromptRepository:
elif isinstance(message, Message):
ids.append(message.id)
else:
raise OasstError("Server error", OasstErrorCode.SERVER_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR)
raise OasstError("Server error", OasstErrorCode.SERVER_ERROR1, HTTPStatus.INTERNAL_SERVER_ERROR)
query = update(Message).where(Message.id.in_(ids)).values(deleted=True)
self.db.execute(query)
+14 -6
View File
@@ -2,6 +2,7 @@ from typing import Optional
from uuid import UUID
import oasst_backend.models.db_payload as db_payload
from loguru import logger
from oasst_backend.models import ApiClient, Task
from oasst_backend.models.payload_column_type import PayloadContainer
from oasst_backend.user_repository import UserRepository
@@ -62,16 +63,16 @@ class TaskRepository:
payload = db_payload.AssistantReplyPayload(type=task.type, conversation=task.conversation)
case protocol_schema.RankInitialPromptsTask:
payload = db_payload.RankInitialPromptsPayload(type=task.type, prompts=task.prompts)
payload = db_payload.RankInitialPromptsPayload(type=task.type, prompt_messages=task.prompt_messages)
case protocol_schema.RankPrompterRepliesTask:
payload = db_payload.RankPrompterRepliesPayload(
type=task.type, conversation=task.conversation, replies=task.replies
type=task.type, conversation=task.conversation, reply_messages=task.reply_messages
)
case protocol_schema.RankAssistantRepliesTask:
payload = db_payload.RankAssistantRepliesPayload(
type=task.type, conversation=task.conversation, replies=task.replies
type=task.type, conversation=task.conversation, reply_messages=task.reply_messages
)
case protocol_schema.LabelInitialPromptTask:
@@ -86,6 +87,7 @@ class TaskRepository:
conversation=task.conversation,
reply=task.reply,
valid_labels=task.valid_labels,
mandatory_labels=task.mandatory_labels,
)
case protocol_schema.LabelAssistantReplyTask:
@@ -95,20 +97,21 @@ class TaskRepository:
conversation=task.conversation,
reply=task.reply,
valid_labels=task.valid_labels,
mandatory_labels=task.mandatory_labels,
)
case _:
raise OasstError(f"Invalid task type: {type(task)=}", OasstErrorCode.INVALID_TASK_TYPE)
task = self.insert_task(
task_model = self.insert_task(
payload=payload,
id=task.id,
message_tree_id=message_tree_id,
parent_message_id=parent_message_id,
collective=collective,
)
assert task.id == task.id
return task
assert task_model.id == task.id
return task_model
def bind_frontend_message_id(self, task_id: UUID, frontend_message_id: str):
validate_frontend_message_id(frontend_message_id)
@@ -184,6 +187,7 @@ class TaskRepository:
parent_message_id=parent_message_id,
collective=collective,
)
logger.debug(f"inserting {task=}")
self.db.add(task)
self.db.commit()
self.db.refresh(task)
@@ -197,3 +201,7 @@ class TaskRepository:
.one_or_none()
)
return task
def fetch_task_by_id(self, task_id: UUID) -> Task:
task = self.db.query(Task).filter(Task.api_client_id == self.api_client.id, Task.id == task_id).one_or_none()
return task
+804
View File
@@ -0,0 +1,804 @@
import random
from enum import Enum
from typing import Optional, Tuple
from uuid import UUID
import numpy as np
import pydantic
from loguru import logger
from oasst_backend.api.v1.utils import prepare_conversation, prepare_conversation_message_list
from oasst_backend.config import settings
from oasst_backend.models import Message, MessageReaction, MessageTreeState, TextLabels, message_tree_state
from oasst_backend.prompt_repository import PromptRepository
from oasst_backend.utils.hugging_face import HfEmbeddingModel, HfUrl, HuggingFaceAPI
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
from sqlalchemy.sql import text
from sqlmodel import Session, func
class TreeManagerConfiguration(pydantic.BaseModel):
"""Configuration class for the TreeManager"""
max_active_trees: int = 10
"""Maximum number of concurrently active trees in the database.
No new initial prompt tasks will be handed out to users if this
number is reached."""
max_tree_depth: int = 6
"""Maximum depth of message tree."""
max_children_count: int = 5
"""Maximum number of reply messages per tree node."""
goal_tree_size: int = 15
"""Total number of messages to gather per tree"""
num_reviews_initial_prompt: int = 3
"""Number of peer review checks to collect in INITIAL_PROMPT_REVIEW state."""
num_reviews_reply: int = 3
"""Number of peer review checks to collect per reply (other than initial_prompt)"""
acceptance_threshold_initial_prompt: float = 0.6
"""Threshold for accepting an initial prompt"""
acceptance_threshold_reply: float = 0.6
"""Threshold for accepting a reply"""
num_required_rankings: int = 3
"""Number of rankings in which the message participated."""
mandatory_labels_initial_prompt: Optional[list[protocol_schema.TextLabel]] = [protocol_schema.TextLabel.spam]
mandatory_labels_assistant_reply: Optional[list[protocol_schema.TextLabel]] = [protocol_schema.TextLabel.spam]
mandatory_labels_prompter_reply: Optional[list[protocol_schema.TextLabel]] = [protocol_schema.TextLabel.spam]
class TaskType(Enum):
NONE = -1
RANKING = 0
LABEL_REPLY = 1
REPLY = 2
LABEL_PROMPT = 3
PROMPT = 4
class TaskRole(Enum):
ANY = 0
PROMPTER = 1
ASSISTANT = 2
class ActiveTreeSizeRow(pydantic.BaseModel):
message_tree_id: UUID
tree_size: int
goal_tree_size: int
@property
def remaining_messages(self) -> int:
return max(0, self.goal_tree_size - self.tree_size)
class Config:
orm_mode = True
class ExtendibleParentRow(pydantic.BaseModel):
parent_id: UUID
depth: int
message_tree_id: UUID
active_children_count: int
class Config:
orm_mode = True
class IncompleteRankingsRow(pydantic.BaseModel):
parent_id: UUID
children_count: int
child_min_ranking_count: int
class Config:
orm_mode = True
class TreeManager:
def __init__(self, db: Session, prompt_repository: PromptRepository, configuration: TreeManagerConfiguration):
self.db = db
self.cfg = configuration
self.pr = prompt_repository
def _task_selection(
self,
desired_task_type: protocol_schema.TaskRequestType,
num_ranking_tasks: int,
num_replies_need_review: int,
num_prompts_need_review: int,
num_missing_prompts: int,
num_missing_replies: int,
) -> Tuple[TaskType, TaskRole]:
"""
Determines which task to hand out to human worker.
The task type is drawn with relative weight (e.g. ranking has highest priority)
depending on what is possible with the current message trees in the database.
"""
logger.debug(
f"TreeManager._task_selection({num_ranking_tasks=}, {num_replies_need_review=}, "
f"{num_prompts_need_review=}, {num_missing_prompts=}, {num_missing_replies=})"
)
task_type = TaskType.NONE
task_role = TaskRole.ANY
if desired_task_type == protocol_schema.TaskRequestType.random:
task_weights = [0] * 5
if num_ranking_tasks > 0:
task_weights[TaskType.RANKING.value] = 10
if num_replies_need_review > 0:
task_weights[TaskType.LABEL_REPLY.value] = 5
if num_prompts_need_review > 0:
task_weights[TaskType.LABEL_PROMPT.value] = 5
if num_missing_replies > 0:
task_weights[TaskType.REPLY.value] = 2
if num_missing_prompts > 0:
task_weights[TaskType.PROMPT.value] = 1
task_weights = np.array(task_weights)
weight_sum = task_weights.sum()
if weight_sum < 1e-8:
task_type = TaskType.NONE
else:
task_weights = task_weights / weight_sum
task_type = TaskType(np.random.choice(a=len(task_weights), p=task_weights))
else:
match desired_task_type:
case protocol_schema.TaskRequestType.initial_prompt:
if num_missing_prompts > 0:
task_type = TaskType.PROMPT
case protocol_schema.TaskRequestType.label_initial_prompt:
if num_prompts_need_review > 0:
task_type = TaskType.LABEL_PROMPT
case protocol_schema.TaskRequestType.assistant_reply | protocol_schema.TaskRequestType.prompter_reply:
if num_missing_replies > 0:
task_role = (
TaskRole.ASSISTANT
if desired_task_type == protocol_schema.TaskRequestType.assistant_reply
else TaskRole.PROMPTER
)
task_type = TaskType.REPLY
case protocol_schema.TaskRequestType.label_assistant_reply | protocol_schema.TaskRequestType.label_prompter_reply:
if num_replies_need_review > 0:
task_role = (
TaskRole.ASSISTANT
if desired_task_type == protocol_schema.TaskRequestType.label_assistant_reply
else TaskRole.PROMPTER
)
task_type = TaskType.LABEL_REPLY
case protocol_schema.TaskRequestType.rank_assistant_replies | protocol_schema.TaskRequestType.rank_prompter_replies:
if num_ranking_tasks > 0:
task_role = (
TaskRole.ASSISTANT
if desired_task_type == protocol_schema.TaskRequestType.rank_assistant_replies
else TaskRole.PROMPTER
)
task_type = TaskType.RANKING
logger.debug(f"Selected {task_type=}, {task_role=}")
return task_type, task_role
def next_task(
self, desired_task_type: protocol_schema.TaskRequestType
) -> Tuple[protocol_schema.Task, Optional[UUID], Optional[UUID]]:
logger.debug("TreeManager.next_task()")
num_active_trees = self.query_num_active_trees()
prompts_need_review = self.query_prompts_need_review()
replies_need_review = self.query_replies_need_review()
incomplete_rankings = self.query_incomplete_rankings()
active_tree_sizes = self.query_extendible_trees()
# determine type of task to generate
num_missing_replies = sum(x.remaining_messages for x in active_tree_sizes)
task_type, task_role = self._task_selection(
desired_task_type,
num_ranking_tasks=len(incomplete_rankings),
num_replies_need_review=len(replies_need_review),
num_prompts_need_review=len(prompts_need_review),
num_missing_prompts=max(0, self.cfg.max_active_trees - num_active_trees),
num_missing_replies=num_missing_replies,
)
if task_type == TaskType.NONE:
raise OasstError(
f"No tasks of type '{desired_task_type.value}' are currently available.",
OasstErrorCode.TASK_REQUESTED_TYPE_UNAVAILABLE,
)
if task_role != TaskRole.ANY:
# Todo: Allow role specific message selection...
raise OasstError(
f"No tasks of type '{desired_task_type.value}' are currently available.",
OasstErrorCode.TASK_REQUESTED_TYPE_UNAVAILABLE,
)
message_tree_id = None
parent_message_id = None
logger.debug(f"selected {task_type=}")
match task_type:
case TaskType.RANKING:
assert len(incomplete_rankings) > 0
ranking_parent_id = random.choice(incomplete_rankings).parent_id
messages = self.pr.fetch_message_conversation(ranking_parent_id)
conversation = prepare_conversation(messages)
replies = self.pr.fetch_message_children(ranking_parent_id, reviewed=True, exclude_deleted=True)
reply_messages = prepare_conversation_message_list(replies)
replies = [p.text for p in replies]
if messages[-1].role == "assistant":
logger.info("Generating a RankPrompterRepliesTask.")
task = protocol_schema.RankPrompterRepliesTask(
conversation=conversation, replies=replies, reply_messages=reply_messages
)
else:
logger.info("Generating a RankAssistantRepliesTask.")
task = protocol_schema.RankAssistantRepliesTask(
conversation=conversation, replies=replies, reply_messages=reply_messages
)
parent_message_id = ranking_parent_id
message_tree_id = messages[-1].message_tree_id
case TaskType.LABEL_REPLY:
assert len(replies_need_review) > 0
random_reply_message_id = random.choice(replies_need_review)
messages = self.pr.fetch_message_conversation(random_reply_message_id)
conversation = prepare_conversation(messages[:-1])
message = messages[-1]
if message.role == "assistant":
logger.info("Generating a LabelAssistantReplyTask.")
task = protocol_schema.LabelAssistantReplyTask(
message_id=message.id,
conversation=conversation,
reply=message.text,
valid_labels=list(map(lambda x: x.value, protocol_schema.TextLabel)),
mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_assistant_reply)),
)
else:
logger.info("Generating a LabelPrompterReplyTask.")
task = protocol_schema.LabelPrompterReplyTask(
message_id=message.id,
conversation=conversation,
reply=message.text,
valid_labels=list(map(lambda x: x.value, protocol_schema.TextLabel)),
mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_prompter_reply)),
)
parent_message_id = message.id
message_tree_id = message.message_tree_id
case TaskType.REPLY:
# select a tree with missing replies
extensible_parents = self.query_extendible_parents()
assert len(extensible_parents) > 0
# fetch random conversation to extend
random_parent = random.choice(extensible_parents)
logger.debug(f"selected {random_parent=}")
messages = self.pr.fetch_message_conversation(random_parent.parent_id)
assert all(m.review_result for m in messages) # ensure all messages have positive review
conversation = prepare_conversation(messages)
# generate reply task depending on last message
if messages[-1].role == "assistant":
logger.info("Generating a PrompterReplyTask.")
task = protocol_schema.PrompterReplyTask(conversation=conversation)
else:
logger.info("Generating a AssistantReplyTask.")
task = protocol_schema.AssistantReplyTask(conversation=conversation)
parent_message_id = messages[-1].id
message_tree_id = messages[-1].message_tree_id
case TaskType.LABEL_PROMPT:
assert len(prompts_need_review) > 0
message = self.pr.fetch_message(random.choice(prompts_need_review))
logger.info("Generating a LabelInitialPromptTask.")
task = protocol_schema.LabelInitialPromptTask(
message_id=message.id,
prompt=message.text,
valid_labels=list(map(lambda x: x.value, protocol_schema.TextLabel)),
mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_initial_prompt)),
)
parent_message_id = message.id
message_tree_id = message.message_tree_id
case TaskType.PROMPT:
logger.info("Generating an InitialPromptTask.")
task = protocol_schema.InitialPromptTask(hint=None)
case _:
task = None
logger.info(f"Generated {task=}.")
return task, message_tree_id, parent_message_id
async def handle_interaction(self, interaction: protocol_schema.AnyInteraction) -> protocol_schema.Task:
pr = self.pr
match type(interaction):
case protocol_schema.TextReplyToMessage:
logger.info(
f"Frontend reports text reply to {interaction.message_id=} with {interaction.text=} by {interaction.user=}."
)
# here we store the text reply in the database
message = pr.store_text_reply(
text=interaction.text,
frontend_message_id=interaction.message_id,
user_frontend_message_id=interaction.user_message_id,
)
if not message.parent_id:
logger.info(f"TreeManager: Inserting new tree state for initial prompt {message.id=}")
self._insert_default_state(message.id)
self.db.commit()
if not settings.DEBUG_SKIP_EMBEDDING_COMPUTATION:
try:
hugging_face_api = HuggingFaceAPI(
f"{HfUrl.HUGGINGFACE_FEATURE_EXTRACTION.value}/{HfEmbeddingModel.MINILM.value}"
)
embedding = await hugging_face_api.post(interaction.text)
pr.insert_message_embedding(
message_id=message.id, model=HfEmbeddingModel.MINILM.value, embedding=embedding
)
except OasstError:
logger.error(
f"Could not fetch embbeddings for text reply to {interaction.message_id=} with {interaction.text=} by {interaction.user=}."
)
case protocol_schema.MessageRating:
logger.info(
f"Frontend reports rating of {interaction.message_id=} with {interaction.rating=} by {interaction.user=}."
)
pr.store_rating(interaction)
case protocol_schema.MessageRanking:
logger.info(
f"Frontend reports ranking of {interaction.message_id=} with {interaction.ranking=} by {interaction.user=}."
)
_, task = pr.store_ranking(interaction)
self.check_condition_for_scoring_state(task.message_tree_id)
case protocol_schema.TextLabels:
logger.info(
f"Frontend reports labels of {interaction.message_id=} with {interaction.labels=} by {interaction.user=}."
)
_, task, msg = pr.store_text_labels(interaction)
# if it was a respones for a task, check if we have enough reviews to calc review_result
if task and msg:
reviews = self.query_reviews_for_message(msg.id)
acceptance_score = self._calculate_acceptance(reviews)
logger.debug(
f"Message {msg.id=}, {acceptance_score=}, {len(reviews)=}, {msg.review_result=}, {msg.review_count=}"
)
if msg.parent_id is None:
if not msg.review_result and msg.review_count >= self.cfg.num_reviews_initial_prompt:
if acceptance_score > self.cfg.acceptance_threshold_initial_prompt:
msg.review_result = True
self.db.add(msg)
self.db.commit()
logger.info(
f"Initial prompt message was accepted: {msg.id=}, {acceptance_score=}, {len(reviews)=}"
)
else:
self.enter_low_grade_state(msg.message_tree_id)
self.check_condition_for_growing_state(msg.message_tree_id)
elif msg.review_count >= self.cfg.num_reviews_reply:
if not msg.review_result and acceptance_score > self.cfg.acceptance_threshold_reply:
msg.review_result = True
self.db.add(msg)
self.db.commit()
logger.info(
f"Reply message message accepted: {msg.id=}, {acceptance_score=}, {len(reviews)=}"
)
self.check_condition_for_ranking_state(msg.message_tree_id)
case _:
raise OasstError("Invalid response type.", OasstErrorCode.TASK_INVALID_RESPONSE_TYPE)
return protocol_schema.TaskDone()
def _enter_state(self, mts: MessageTreeState, state: message_tree_state.State):
assert mts and mts.active
is_terminal = state in message_tree_state.TERMINAL_STATES
if is_terminal:
mts.active = False
mts.state = state.value
self.db.add(mts)
self.db.commit()
if is_terminal:
logger.info(f"Tree entered terminal '{mts.state}' state ({mts.message_tree_id=})")
else:
logger.info(f"Tree entered '{mts.state}' state ({mts.message_tree_id=})")
def enter_low_grade_state(self, message_tree_id: UUID) -> None:
logger.debug(f"enter_low_grade_state({message_tree_id=})")
mts = self.pr.fetch_tree_state(message_tree_id)
self._enter_state(mts, message_tree_state.State.ABORTED_LOW_GRADE)
def check_condition_for_growing_state(self, message_tree_id: UUID) -> bool:
logger.debug(f"check_condition_for_growing_state({message_tree_id=})")
mts = self.pr.fetch_tree_state(message_tree_id)
if not mts.active or mts.state != message_tree_state.State.INITIAL_PROMPT_REVIEW:
logger.debug(f"False {mts.active=}, {mts.state=}")
return False
# check if initial prompt was accepted
initial_prompt = self.pr.fetch_message(message_tree_id)
if not initial_prompt.review_result:
logger.debug(f"False {initial_prompt.review_result=}")
return False
self._enter_state(mts, message_tree_state.State.GROWING)
return True
def check_condition_for_ranking_state(self, message_tree_id: UUID) -> bool:
logger.debug(f"check_condition_for_ranking_state({message_tree_id=})")
mts = self.pr.fetch_tree_state(message_tree_id)
if not mts.active or mts.state != message_tree_state.State.GROWING:
logger.debug(f"False {mts.active=}, {mts.state=}")
return False
# check if desired tree size has been reached and all nodes have been reviewed
tree_size = self.query_tree_size(message_tree_id)
if tree_size.remaining_messages > 0:
logger.debug(f"False {tree_size.remaining_messages=}")
return False
self._enter_state(mts, message_tree_state.State.RANKING)
return True
def check_condition_for_scoring_state(self, message_tree_id: UUID) -> bool:
logger.debug(f"check_condition_for_scoring_state({message_tree_id=})")
mts: MessageTreeState
mts = self.db.query(MessageTreeState).filter(MessageTreeState.message_tree_id == message_tree_id).one()
if not mts.active or mts.state != message_tree_state.State.RANKING:
logger.debug(f"False {mts.active=}, {mts.state=}")
return False
rankings_by_message = self.query_tree_ranking_results(message_tree_id)
for parent_msg_id, ranking in rankings_by_message.items():
if len(ranking) < self.cfg.num_required_rankings:
logger.debug(f"False {parent_msg_id=} {len(ranking)=}")
return False
self._enter_state(mts, message_tree_state.State.READY_FOR_SCORING)
return True
def _calculate_acceptance(self, labels: list[TextLabels]):
# calculate acceptance based on spam label
return np.mean([1 - l.labels[protocol_schema.TextLabel.spam] for l in labels])
_sql_find_prompts_need_review = """
-- find initial prompts that need more reviews
SELECT m.id
FROM message_tree_state mts
LEFT JOIN message m ON mts.message_tree_id = m.id
WHERE mts.active
AND mts.state = :state
AND NOT m.review_result
AND NOT m.deleted
AND m.review_count < :num_reviews_initial_prompt
AND m.parent_id is NULL
AND (:excluded_user_id IS NULL OR m.user_id != :excluded_user_id)
"""
def query_prompts_need_review(self) -> list[UUID]:
"""
Select id of initial prompts with less then required rankings in active message tree
(active == True in message_tree_state)
"""
r = self.db.execute(
text(self._sql_find_prompts_need_review),
{
"state": message_tree_state.State.INITIAL_PROMPT_REVIEW,
"num_reviews_initial_prompt": self.cfg.num_reviews_initial_prompt,
"excluded_user_id": None if settings.DEBUG_ALLOW_SELF_LABELING else self.pr.user_id,
},
)
return [x["id"] for x in r.all()]
_sql_find_replies_need_review = """
SELECT m.id
FROM message_tree_state mts
LEFT JOIN message m ON mts.message_tree_id = m.message_tree_id
WHERE mts.active
AND mts.state = :breeding_state
AND NOT m.review_result
AND NOT m.deleted
AND m.review_count < :num_required_reviews
AND m.parent_id is NOT NULL
AND (:excluded_user_id IS NULL OR m.user_id != :excluded_user_id)
"""
def query_replies_need_review(self) -> list[UUID]:
"""
Select ids of child messages (parent_id IS NOT NULL) with less then required rankings
in active message tree (active == True in message_tree_state)
"""
r = self.db.execute(
text(self._sql_find_replies_need_review),
{
"breeding_state": message_tree_state.State.GROWING,
"num_required_reviews": self.cfg.num_reviews_reply,
"excluded_user_id": None if settings.DEBUG_ALLOW_SELF_LABELING else self.pr.user_id,
},
)
return [x["id"] for x in r.all()]
_sql_find_incomplete_rankings = """
-- find incomplete rankings
SELECT m.parent_id, COUNT(m.id) children_count, MIN(m.ranking_count) child_min_ranking_count,
COUNT(m.id) FILTER (WHERE m.ranking_count >= :num_required_rankings) as completed_rankings
FROM message_tree_state mts
LEFT JOIN message m ON mts.message_tree_id = m.message_tree_id
WHERE mts.active -- only consider active trees
AND mts.state = :ranking_state -- message tree must be in ranking state
AND m.review_result -- must be reviewed
AND NOT m.deleted -- not deleted
AND m.parent_id IS NOT NULL -- ignore initial prompts
GROUP BY m.parent_id
HAVING COUNT(m.id) > 1 and MIN(m.ranking_count) < :num_required_rankings
"""
def query_incomplete_rankings(self) -> list[IncompleteRankingsRow]:
"""Query parents which have childern that need further rankings"""
r = self.db.execute(
text(self._sql_find_incomplete_rankings),
{
"num_required_rankings": self.cfg.num_required_rankings,
"ranking_state": message_tree_state.State.RANKING,
},
)
return [IncompleteRankingsRow.from_orm(x) for x in r.all()]
_sql_find_extendible_parents = """
-- find all extendible parent nodes
SELECT m.id as parent_id, m.depth, m.message_tree_id, COUNT(c.id) active_children_count
FROM message_tree_state mts
LEFT JOIN message m ON mts.message_tree_id = m.message_tree_id -- all elements of message tree
LEFT JOIN message c ON m.id = c.Id -- child nodes
WHERE mts.active -- only consider active trees
AND mts.state = :growing_state -- message tree must be growing
AND NOT m.deleted -- ignore deleted messages as parents
AND m.depth < mts.max_depth -- ignore leaf nodes as parents
AND m.review_result -- parent node must have positive review
AND NOT c.deleted -- don't count deleted children
AND (c.review_result OR c.review_count < :num_reviews_reply) -- don't count children with negative review but count elements under review
GROUP BY m.id, m.depth, m.message_tree_id, mts.max_children_count
HAVING COUNT(c.id) < mts.max_children_count -- below maximum number of children
"""
def query_extendible_parents(self) -> list[ExtendibleParentRow]:
"""Query parent messages that have not reached the maximum number of replies."""
r = self.db.execute(
text(self._sql_find_extendible_parents),
{
"growing_state": message_tree_state.State.GROWING,
"num_reviews_reply": self.cfg.num_reviews_reply,
},
)
return [ExtendibleParentRow.from_orm(x) for x in r.all()]
_sql_find_extendible_trees = f"""
-- find extendible trees
SELECT m.message_tree_id, mts.goal_tree_size, COUNT(m.id) AS tree_size
FROM (
SELECT DISTINCT message_tree_id FROM ({_sql_find_extendible_parents}) extendible_parents
) trees LEFT JOIN message_tree_state mts ON trees.message_tree_id = mts.message_tree_id
LEFT JOIN message m ON mts.message_tree_id = m.message_tree_id
WHERE NOT m.deleted
AND (
m.parent_id IS NOT NULL AND (m.review_result OR m.review_count < :num_reviews_reply) -- children
OR m.parent_id IS NULL AND m.review_result -- prompts (root nodes) must have positive review
)
GROUP BY m.message_tree_id, mts.goal_tree_size
HAVING COUNT(m.id) < mts.goal_tree_size
"""
def query_extendible_trees(self) -> list[ActiveTreeSizeRow]:
"""Query size of active message trees in growing state."""
r = self.db.execute(
text(self._sql_find_extendible_trees),
{
"growing_state": message_tree_state.State.GROWING,
"num_reviews_reply": self.cfg.num_reviews_reply,
},
)
return [ActiveTreeSizeRow.from_orm(x) for x in r.all()]
_sql_get_tree_size = """
SELECT mts.message_tree_id, mts.goal_tree_size, COUNT(m.id) AS tree_size
FROM message_tree_state mts
LEFT JOIN message m ON mts.message_tree_id = m.message_tree_id
WHERE mts.active
AND NOT m.deleted
AND m.review_result
AND mts.message_tree_id = :message_tree_id
GROUP BY mts.message_tree_id, mts.goal_tree_size
"""
def query_tree_size(self, message_tree_id: UUID) -> ActiveTreeSizeRow:
"""Returns the number of reviewed not deleted messages in the message tree."""
r = self.db.execute(text(self._sql_get_tree_size), {"message_tree_id": message_tree_id})
return ActiveTreeSizeRow.from_orm(r.one())
def query_misssing_tree_states(self) -> list[UUID]:
"""Find all initial prompt messages that have no associated message tree state"""
qry_missing_tree_states = (
self.db.query(Message.id)
.join(MessageTreeState, isouter=True)
.filter(
Message.parent_id.is_(None),
Message.message_tree_id == Message.id,
MessageTreeState.message_tree_id.is_(None),
)
)
return [m.id for m in qry_missing_tree_states.all()]
_sql_find_tree_ranking_results = """
-- get all ranking results of completed tasks for all parents with >=2 children
SELECT p.parent_id, mr.* FROM
(
-- find parents with > 1 children
SELECT m.parent_id, m.message_tree_id, COUNT(m.id) children_count
FROM message_tree_state mts
LEFT JOIN message m ON mts.message_tree_id = m.message_tree_id
WHERE m.review_result -- must be reviewed
AND NOT m.deleted -- not deleted
AND m.parent_id IS NOT NULL -- ignore initial prompts
AND mts.message_tree_id = :message_tree_id
GROUP BY m.parent_id, m.message_tree_id
HAVING COUNT(m.id) > 1
) as p
LEFT JOIN task t ON p.parent_id = t.parent_message_id AND t.done AND (t.payload_type = 'RankPrompterRepliesPayload' OR t.payload_type = 'RankAssistantRepliesPayload')
LEFT JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'RankingReactionPayload'
"""
def query_tree_ranking_results(self, message_tree_id: UUID) -> dict[UUID, list[MessageReaction]]:
"""Finds all completed ranking restuls for a message_tree"""
r = self.db.execute(
text(self._sql_find_tree_ranking_results),
{"message_tree_id": message_tree_id},
)
rankings_by_message = {}
for x in r.all():
parent_id = x["parent_id"]
if parent_id not in rankings_by_message:
rankings_by_message[parent_id] = []
if x["task_id"]:
rankings_by_message[parent_id].append(MessageReaction.from_orm(x))
return rankings_by_message
def ensure_tree_states(self):
"""Add message tree state rows for all root nodes (inital prompt messages)."""
missing_tree_ids = self.query_misssing_tree_states()
for id in missing_tree_ids:
tree_size = self.db.query(func.count(Message.id)).filter(Message.message_tree_id == id).scalar()
state = message_tree_state.State.INITIAL_PROMPT_REVIEW
if tree_size > 1:
state = message_tree_state.State.GROWING
logger.info(f"Inserting missing message tree state for message: {id} ({tree_size=}, {state=})")
self._insert_default_state(id, state=state)
self.db.commit()
def query_num_active_trees(self) -> int:
query = self.db.query(func.count(MessageTreeState.message_tree_id)).filter(MessageTreeState.active)
return query.scalar()
def query_reviews_for_message(self, message_id: UUID) -> list[TextLabels]:
sql_qry = """
SELECT tl.*
FROM task t
INNER JOIN text_labels tl ON tl.id = t.id
WHERE t.done = TRUE
AND tl.message_id = :message_id
"""
r = self.db.execute(text(sql_qry), {"message_id": message_id})
return [TextLabels.from_orm(x) for x in r.all()]
def _insert_tree_state(
self,
root_message_id: UUID,
goal_tree_size: int,
max_depth: int,
max_children_count: int,
active: bool,
state: message_tree_state.State = message_tree_state.State.INITIAL_PROMPT_REVIEW,
) -> MessageTreeState:
model = MessageTreeState(
message_tree_id=root_message_id,
goal_tree_size=goal_tree_size,
max_depth=max_depth,
max_children_count=max_children_count,
state=state.value,
active=active,
accepted_messages=0,
)
self.db.add(model)
return model
def _insert_default_state(
self,
root_message_id: UUID,
state: message_tree_state.State = message_tree_state.State.INITIAL_PROMPT_REVIEW,
) -> MessageTreeState:
return self._insert_tree_state(
root_message_id=root_message_id,
goal_tree_size=self.cfg.goal_tree_size,
max_depth=self.cfg.max_tree_depth,
max_children_count=self.cfg.max_children_count,
state=state,
active=True,
)
if __name__ == "__main__":
from oasst_backend.api.deps import get_dummy_api_client
from oasst_backend.database import engine
from oasst_backend.prompt_repository import PromptRepository
with Session(engine) as db:
api_client = get_dummy_api_client(db)
dummy_user = protocol_schema.User(id="__dummy_user__", display_name="Dummy User", auth_method="local")
pr = PromptRepository(db=db, api_client=api_client, client_user=dummy_user)
cfg = TreeManagerConfiguration()
tm = TreeManager(db, pr, cfg)
tm.ensure_tree_states()
print("query_num_active_trees", tm.query_num_active_trees())
print("query_incomplete_rankings", tm.query_incomplete_rankings())
print("query_incomplete_reply_reviews", tm.query_replies_need_review())
print("query_incomplete_initial_prompt_reviews", tm.query_prompts_need_review())
print("query_extendible_trees", tm.query_extendible_trees())
print("query_extendible_parents", tm.query_extendible_parents())
print("next_task:", tm.next_task())
print(
".query_tree_ranking_results", tm.query_tree_ranking_results(UUID("2ac20d38-6650-43aa-8bb3-f61080c0d921"))
)
+3 -3
View File
@@ -51,12 +51,12 @@ class HuggingFaceAPI:
async with session.post(self.api_url, headers=self.headers, json=payload) as response:
# If we get a bad response
if response.status != 200:
if not response.ok:
logger.error(response)
logger.info(self.headers)
raise OasstError(
"Response Error Detoxify HuggingFace", error_code=OasstErrorCode.HUGGINGFACE_API_ERROR
f"Response Error HuggingFace API (Status: {response.status})",
error_code=OasstErrorCode.HUGGINGFACE_API_ERROR,
)
# Get the response from the API call
+1
View File
@@ -99,6 +99,7 @@ services:
- REDIS_HOST=redis
- DEBUG_SKIP_API_KEY_CHECK=True
- DEBUG_USE_SEED_DATA=True
- DEBUG_ALLOW_SELF_LABELING=True
- MAX_WORKERS=1
- DEBUG_SKIP_EMBEDDING_COMPUTATION=True
depends_on:
@@ -17,9 +17,11 @@ class OasstErrorCode(IntEnum):
GENERIC_ERROR = 0
DATABASE_URI_NOT_SET = 1
API_CLIENT_NOT_AUTHORIZED = 2
SERVER_ERROR = 3
TOO_MANY_REQUESTS = 429
SERVER_ERROR0 = 500
SERVER_ERROR1 = 501
# 1000-2000: tasks endpoint
TASK_INVALID_REQUEST_TYPE = 1000
TASK_ACK_FAILED = 1001
@@ -27,6 +29,7 @@ class OasstErrorCode(IntEnum):
TASK_INVALID_RESPONSE_TYPE = 1003
TASK_INTERACTION_REQUEST_FAILED = 1004
TASK_GENERATION_FAILED = 1005
TASK_REQUESTED_TYPE_NOT_AVAILABLE = 1006
# 2000-3000: prompt_repository
INVALID_FRONTEND_MESSAGE_ID = 2000
@@ -38,6 +41,14 @@ class OasstErrorCode(IntEnum):
NO_MESSAGE_TREE_FOUND = 2006
NO_REPLIES_FOUND = 2007
INVALID_MESSAGE = 2008
BROKEN_CONVERSATION = 2009
TREE_NOT_IN_GROWING_STATE = 2010
CORRUPT_RANKING_RESULT = 2011
TEXT_LABELS_WRONG_MESSAGE_ID = 2050
TEXT_LABELS_INVALID_LABEL = 2051
TEXT_LABELS_MANDATORY_LABEL_MISSING = 2052
TASK_NOT_FOUND = 2100
TASK_EXPIRED = 2101
TASK_PAYLOAD_TYPE_MISMATCH = 2102
@@ -45,6 +56,7 @@ class OasstErrorCode(IntEnum):
TASK_NOT_ACK = 2104
TASK_ALREADY_DONE = 2105
TASK_NOT_COLLECTIVE = 2106
TASK_NOT_ASSIGNED_TO_USER = 2106
USER_NOT_FOUND = 2200
# 3000-4000: external resources
@@ -151,7 +151,8 @@ class RankInitialPromptsTask(Task):
"""A task to rank a set of initial prompts."""
type: Literal["rank_initial_prompts"] = "rank_initial_prompts"
prompts: list[str]
prompts: list[str] # deprecated, use prompt_messages
prompt_messages: list[ConversationMessage]
class RankConversationRepliesTask(Task):
@@ -159,7 +160,8 @@ class RankConversationRepliesTask(Task):
type: Literal["rank_conversation_replies"] = "rank_conversation_replies"
conversation: Conversation # the conversation so far
replies: list[str]
replies: list[str] # deprecated, use reply_messages
reply_messages: list[ConversationMessage]
class RankPrompterRepliesTask(RankConversationRepliesTask):
@@ -181,6 +183,7 @@ class LabelInitialPromptTask(Task):
message_id: UUID
prompt: str
valid_labels: list[str]
mandatory_labels: Optional[list[str]]
class LabelConversationReplyTask(Task):
@@ -191,6 +194,7 @@ class LabelConversationReplyTask(Task):
message_id: UUID
reply: str
valid_labels: list[str]
mandatory_labels: Optional[list[str]]
class LabelPrompterReplyTask(LabelConversationReplyTask):
@@ -304,6 +308,7 @@ class TextLabels(Interaction):
text: str
labels: dict[TextLabel, float]
message_id: UUID
task_id: Optional[UUID]
@property
def has_message_id(self) -> bool:
+1
View File
@@ -6,6 +6,7 @@ pushd "$parent_path/../../backend"
export DEBUG_SKIP_API_KEY_CHECK=True
export DEBUG_USE_SEED_DATA=True
export DEBUG_ALLOW_SELF_LABELING=True
export DEBUG_SKIP_EMBEDDING_COMPUTATION=True
uvicorn main:app --reload --port 8080 --host 0.0.0.0
+7 -1
View File
@@ -36,7 +36,7 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "DUMMY_KEY")
return response.json()
typer.echo("Requesting work...")
tasks = [_post("/api/v1/tasks/", {"type": "random"})]
tasks = [_post("/api/v1/tasks/", {"type": "random", "user": USER})]
while tasks:
task = tasks.pop(0)
match (task["type"]):
@@ -58,6 +58,7 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "DUMMY_KEY")
{
"type": "text_reply_to_message",
"message_id": message_id,
"task_id": task["id"],
"user_message_id": user_message_id,
"text": summary,
"user": USER,
@@ -102,6 +103,7 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "DUMMY_KEY")
{
"type": "text_reply_to_message",
"message_id": message_id,
"task_id": task["id"],
"user_message_id": user_message_id,
"text": prompt,
"user": USER,
@@ -150,6 +152,7 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "DUMMY_KEY")
{
"type": "text_reply_to_message",
"message_id": message_id,
"task_id": task["id"],
"user_message_id": user_message_id,
"text": reply,
"user": USER,
@@ -200,6 +203,7 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "DUMMY_KEY")
{
"type": "message_ranking",
"message_id": message_id,
"task_id": task["id"],
"ranking": ranking,
"user": USER,
},
@@ -232,6 +236,7 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "DUMMY_KEY")
{
"type": "text_labels",
"message_id": task["message_id"],
"task_id": task["id"],
"text": task["prompt"],
"labels": labels_dict,
"user": USER,
@@ -269,6 +274,7 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "DUMMY_KEY")
{
"type": "text_labels",
"message_id": task["message_id"],
"task_id": task["id"],
"text": task["reply"],
"labels": labels_dict,
"user": USER,