mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
Message tree state machine (#555)
* add query_incomplete_rankings() * Add SQL queries for TreeManager task selection * first working version of TreeManager.next_task() * remove old generate_task(), add mandatory_labels to text_labels task * Add ConversationMessage list to Ranking tasks * add more sophisticated sql queries to find extendible trees * add TreeManager.query_extendible_parents() * fix task validation, seed data insertion (reviewed) * provide user for task selection in text-frontend * enter 'growing' state * enter 'aborted_low_grade' state * enter 'ranking' state * check tree 'growing' state upon reply insertion * exclude user from labeling their own messages (added DEBUG_ALLOW_SELF_LABELING setting) * add DEBUG_ALLOW_SELF_LABELING to docker-compose.yaml * fix ranking submission * add query_tree_ranking_results() * add ranked_message_ids to RankingReactionPayload * fix reply_messages instead of prompt_messages * increment 'ranking_count' of ranked replies * added logic to check_condition_for_scoring_state * changes to msg_tree_state_machine * pre-commit changes * enter 'ready_for_scoring' state * re-add HF embedding call (lost during merge) * use prepare_conversation() helper for seed-data creation * Partially add user specified task selection Co-authored-by: Daniel Hug <danielpatrickhug@gmail.com>
This commit is contained in:
@@ -79,6 +79,7 @@
|
||||
REDIS_HOST: oasst-redis
|
||||
DEBUG_ALLOW_ANY_API_KEY: "true"
|
||||
DEBUG_USE_SEED_DATA: "true"
|
||||
DEBUG_ALLOW_SELF_LABELING: "true"
|
||||
MAX_WORKERS: "1"
|
||||
RATE_LIMIT: "false"
|
||||
DEBUG_SKIP_EMBEDDING_COMPUTATION: "true"
|
||||
|
||||
+63
@@ -0,0 +1,63 @@
|
||||
"""restructure message_tree_state table
|
||||
|
||||
Revision ID: 92a367bb9f40
|
||||
Revises: aac6b2f66006
|
||||
Create Date: 2023-01-08 22:08:46.458195
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "92a367bb9f40"
|
||||
down_revision = "aac6b2f66006"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Rebuild ``message_tree_state`` keyed by ``message_tree_id``.

    The existing table is dropped (no rows are copied over) and recreated with
    ``message_tree_id`` as both primary key and foreign key to ``message.id``.
    Non-unique indexes are then created on ``active`` and ``state``.
    """
    table_name = "message_tree_state"

    # Originally auto-generated by Alembic; drop first, then recreate.
    op.drop_table(table_name)

    columns = (
        sa.Column("message_tree_id", postgresql.UUID(as_uuid=True), nullable=False),
        sa.Column("goal_tree_size", sa.Integer(), nullable=False),
        sa.Column("max_depth", sa.Integer(), nullable=False),
        sa.Column("max_children_count", sa.Integer(), nullable=False),
        sa.Column("state", sqlmodel.sql.sqltypes.AutoString(length=128), nullable=False),
        sa.Column("active", sa.Boolean(), nullable=False),
        sa.Column("accepted_messages", sa.Integer(), nullable=False),
    )
    constraints = (
        sa.ForeignKeyConstraint(["message_tree_id"], ["message.id"]),
        sa.PrimaryKeyConstraint("message_tree_id"),
    )
    op.create_table(table_name, *columns, *constraints)

    # (index name, indexed column) pairs, created in the same order as before.
    index_specs = (
        ("ix_message_tree_state_active", "active"),
        ("ix_message_tree_state_state", "state"),
    )
    for index_name, column_name in index_specs:
        op.create_index(op.f(index_name), table_name, [column_name], unique=False)
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Restore the previous ``message_tree_state`` layout.

    Drops the ``state`` and ``active`` indexes and the table itself, then
    recreates the table with a surrogate ``id`` primary key (server default
    ``gen_random_uuid()``) and two indexes on ``message_tree_id``: one
    non-unique, one unique. Rows present before the downgrade are not kept.
    """
    table_name = "message_tree_state"

    # Remove the indexes before dropping the table, in the original order.
    for index_name in ("ix_message_tree_state_state", "ix_message_tree_state_active"):
        op.drop_index(op.f(index_name), table_name=table_name)
    op.drop_table(table_name)

    legacy_columns = (
        sa.Column("id", postgresql.UUID(as_uuid=True), server_default=sa.text("gen_random_uuid()"), nullable=False),
        sa.Column("message_tree_id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
        sa.Column("state", sqlmodel.sql.sqltypes.AutoString(length=128), nullable=False),
        sa.Column("goal_tree_size", sa.Integer(), nullable=False),
        sa.Column("current_num_non_filtered_messages", sa.Integer(), nullable=False),
        sa.Column("max_depth", sa.Integer(), nullable=False),
    )
    op.create_table(table_name, *legacy_columns, sa.PrimaryKeyConstraint("id"))

    indexed_columns = ["message_tree_id"]
    # Only the first index name is passed through op.f(), as in the original.
    op.create_index(op.f("ix_message_tree_state_message_tree_id"), table_name, indexed_columns, unique=False)
    op.create_index("ix_message_tree_state_tree_id", table_name, indexed_columns, unique=True)
|
||||
+31
@@ -0,0 +1,31 @@
|
||||
"""add review_count & ranking_count to message
|
||||
|
||||
Revision ID: 05975b274a81
|
||||
Revises: 92a367bb9f40
|
||||
Create Date: 2023-01-09 00:47:25.496036
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "05975b274a81"
|
||||
down_revision = "92a367bb9f40"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Add ``review_count``, ``review_result`` and ``ranking_count`` to ``message``.

    All three columns are NOT NULL and carry a server default (``0`` for the
    counters, ``false`` for the flag).
    """
    new_columns = (
        sa.Column("review_count", sa.Integer(), server_default=sa.text("0"), nullable=False),
        sa.Column("review_result", sa.Boolean(), server_default=sa.text("false"), nullable=False),
        sa.Column("ranking_count", sa.Integer(), server_default=sa.text("0"), nullable=False),
    )
    for column in new_columns:
        op.add_column("message", column)
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Drop the three review/ranking columns from ``message``.

    Columns are removed in the reverse of the order in which ``upgrade`` adds
    them.
    """
    for column_name in ("ranking_count", "review_result", "review_count"):
        op.drop_column("message", column_name)
|
||||
+29
-13
@@ -12,9 +12,12 @@ from fastapi_limiter import FastAPILimiter
|
||||
from loguru import logger
|
||||
from oasst_backend.api.deps import get_dummy_api_client
|
||||
from oasst_backend.api.v1.api import api_router
|
||||
from oasst_backend.api.v1.utils import prepare_conversation
|
||||
from oasst_backend.config import settings
|
||||
from oasst_backend.database import engine
|
||||
from oasst_backend.models import message_tree_state
|
||||
from oasst_backend.prompt_repository import PromptRepository, TaskRepository, UserRepository
|
||||
from oasst_backend.tree_manager import TreeManager, TreeManagerConfiguration
|
||||
from oasst_shared.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from pydantic import BaseModel
|
||||
@@ -116,6 +119,7 @@ if settings.DEBUG_USE_SEED_DATA:
|
||||
pr = PromptRepository(
|
||||
db=db, api_client=api_client, client_user=dummy_user, user_repository=ur, task_repository=tr
|
||||
)
|
||||
tm = TreeManager(db, pr, TreeManagerConfiguration())
|
||||
|
||||
with open(settings.DEBUG_USE_SEED_DATA_PATH) as f:
|
||||
dummy_messages_raw = json.load(f)
|
||||
@@ -138,24 +142,19 @@ if settings.DEBUG_USE_SEED_DATA:
|
||||
msg.parent_message_id, fail_if_missing=True
|
||||
)
|
||||
conversation_messages = pr.fetch_message_conversation(parent_message)
|
||||
conversation = protocol_schema.Conversation(
|
||||
messages=[
|
||||
protocol_schema.ConversationMessage(
|
||||
text=cmsg.text,
|
||||
is_assistant=cmsg.role == "assistant",
|
||||
message_id=cmsg.id,
|
||||
fronend_message_id=cmsg.frontend_message_id,
|
||||
)
|
||||
for cmsg in conversation_messages
|
||||
]
|
||||
)
|
||||
conversation = prepare_conversation(conversation_messages)
|
||||
task = tr.store_task(
|
||||
protocol_schema.AssistantReplyTask(conversation=conversation),
|
||||
message_tree_id=parent_message.message_tree_id,
|
||||
parent_message_id=parent_message.id,
|
||||
)
|
||||
tr.bind_frontend_message_id(task.id, msg.task_message_id)
|
||||
message = pr.store_text_reply(msg.text, msg.task_message_id, msg.user_message_id)
|
||||
message = pr.store_text_reply(
|
||||
msg.text, msg.task_message_id, msg.user_message_id, review_count=5, review_result=True
|
||||
)
|
||||
if message.parent_id is None:
|
||||
tm._insert_default_state(root_message_id=message.id, state=message_tree_state.State.GROWING)
|
||||
db.commit()
|
||||
|
||||
logger.info(
|
||||
f"Inserted: message_id: {message.id}, payload: {message.payload.payload}, parent_message_id: {message.parent_id}"
|
||||
@@ -168,6 +167,19 @@ if settings.DEBUG_USE_SEED_DATA:
|
||||
logger.exception("Seed data insertion failed")
|
||||
|
||||
|
||||
@app.on_event("startup")
def ensure_tree_states():
    """Startup hook that runs ``TreeManager.ensure_tree_states()`` once.

    A default ``TreeManagerConfiguration`` and a fresh database session are
    used; no prompt repository is passed to the ``TreeManager``. Any exception
    is caught and logged with its traceback instead of being propagated.
    """
    try:
        logger.info("Startup: TreeManager.ensure_tree_states()")
        # TODO: decide where config is stored, e.g. load from json/yaml file
        tree_manager_config = TreeManagerConfiguration()
        with Session(engine) as session:
            tree_manager = TreeManager(session, None, configuration=tree_manager_config)
            tree_manager.ensure_tree_states()
    except Exception:
        logger.exception("TreeManager.ensure_tree_states() failed.")
|
||||
|
||||
|
||||
app.include_router(api_router, prefix=settings.API_V1_STR)
|
||||
|
||||
|
||||
@@ -175,7 +187,7 @@ def get_openapi_schema():
|
||||
return json.dumps(app.openapi())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
def main():
|
||||
# Importing here so we don't import packages unnecessarily if we're
|
||||
# importing main as a module.
|
||||
import argparse
|
||||
@@ -198,3 +210,7 @@ if __name__ == "__main__":
|
||||
print(get_openapi_schema())
|
||||
else:
|
||||
uvicorn.run(app, host=args.host, port=args.port)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -1,15 +1,12 @@
|
||||
import random
|
||||
from typing import Any, Optional, Tuple
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi.security.api_key import APIKey
|
||||
from loguru import logger
|
||||
from oasst_backend.api import deps
|
||||
from oasst_backend.api.v1.utils import prepare_conversation
|
||||
from oasst_backend.config import settings
|
||||
from oasst_backend.prompt_repository import PromptRepository, TaskRepository
|
||||
from oasst_backend.utils.hugging_face import HfEmbeddingModel, HfUrl, HuggingFaceAPI
|
||||
from oasst_backend.tree_manager import TreeManager, TreeManagerConfiguration
|
||||
from oasst_shared.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from sqlmodel import Session
|
||||
@@ -18,160 +15,6 @@ from starlette.status import HTTP_204_NO_CONTENT
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def generate_task(
|
||||
request: protocol_schema.TaskRequest, pr: PromptRepository
|
||||
) -> Tuple[protocol_schema.Task, Optional[UUID], Optional[UUID]]:
|
||||
message_tree_id = None
|
||||
parent_message_id = None
|
||||
|
||||
match request.type:
|
||||
case protocol_schema.TaskRequestType.random:
|
||||
logger.info("Frontend requested a random task.")
|
||||
disabled_tasks = (
|
||||
protocol_schema.TaskRequestType.random,
|
||||
protocol_schema.TaskRequestType.summarize_story,
|
||||
protocol_schema.TaskRequestType.rate_summary,
|
||||
)
|
||||
candidate_tasks = set(protocol_schema.TaskRequestType).difference(disabled_tasks)
|
||||
request.type = random.choice(tuple(candidate_tasks)).value
|
||||
return generate_task(request, pr)
|
||||
|
||||
# AKo: Summary tasks are currently disabled/supported, we focus on the conversation tasks.
|
||||
|
||||
# case protocol_schema.TaskRequestType.summarize_story:
|
||||
# logger.info("Generating a SummarizeStoryTask.")
|
||||
# task = protocol_schema.SummarizeStoryTask(
|
||||
# story="This is a story. A very long story. So long, it needs to be summarized.",
|
||||
# )
|
||||
# case protocol_schema.TaskRequestType.rate_summary:
|
||||
# logger.info("Generating a RateSummaryTask.")
|
||||
# task = protocol_schema.RateSummaryTask(
|
||||
# full_text="This is a story. A very long story. So long, it needs to be summarized.",
|
||||
# summary="This is a summary.",
|
||||
# scale=protocol_schema.RatingScale(min=1, max=5),
|
||||
# )
|
||||
|
||||
case protocol_schema.TaskRequestType.initial_prompt:
|
||||
logger.info("Generating an InitialPromptTask.")
|
||||
task = protocol_schema.InitialPromptTask(
|
||||
hint="Ask the assistant about a current event." # this is optional
|
||||
)
|
||||
case protocol_schema.TaskRequestType.prompter_reply:
|
||||
logger.info("Generating a PrompterReplyTask.")
|
||||
messages = pr.fetch_random_conversation("assistant")
|
||||
task_messages = [
|
||||
protocol_schema.ConversationMessage(
|
||||
text=msg.text,
|
||||
is_assistant=(msg.role == "assistant"),
|
||||
message_id=msg.id,
|
||||
front_end_id=msg.frontend_message_id,
|
||||
)
|
||||
for msg in messages
|
||||
]
|
||||
|
||||
task = protocol_schema.PrompterReplyTask(conversation=protocol_schema.Conversation(messages=task_messages))
|
||||
message_tree_id = messages[-1].message_tree_id
|
||||
parent_message_id = messages[-1].id
|
||||
case protocol_schema.TaskRequestType.assistant_reply:
|
||||
logger.info("Generating a AssistantReplyTask.")
|
||||
messages = pr.fetch_random_conversation("prompter")
|
||||
task_messages = [
|
||||
protocol_schema.ConversationMessage(
|
||||
text=msg.text,
|
||||
is_assistant=(msg.role == "assistant"),
|
||||
message_id=msg.id,
|
||||
front_end_id=msg.frontend_message_id,
|
||||
)
|
||||
for msg in messages
|
||||
]
|
||||
|
||||
task = protocol_schema.AssistantReplyTask(conversation=protocol_schema.Conversation(messages=task_messages))
|
||||
message_tree_id = messages[-1].message_tree_id
|
||||
parent_message_id = messages[-1].id
|
||||
case protocol_schema.TaskRequestType.rank_initial_prompts:
|
||||
logger.info("Generating a RankInitialPromptsTask.")
|
||||
|
||||
messages = pr.fetch_random_initial_prompts()
|
||||
task = protocol_schema.RankInitialPromptsTask(prompts=[msg.text for msg in messages])
|
||||
case protocol_schema.TaskRequestType.rank_prompter_replies:
|
||||
logger.info("Generating a RankPrompterRepliesTask.")
|
||||
conversation, replies = pr.fetch_multiple_random_replies(message_role="assistant")
|
||||
|
||||
task_messages = [
|
||||
protocol_schema.ConversationMessage(
|
||||
text=p.text,
|
||||
is_assistant=(p.role == "assistant"),
|
||||
message_id=p.id,
|
||||
front_end_id=p.frontend_message_id,
|
||||
)
|
||||
for p in conversation
|
||||
]
|
||||
replies = [p.text for p in replies]
|
||||
task = protocol_schema.RankPrompterRepliesTask(
|
||||
conversation=protocol_schema.Conversation(
|
||||
messages=task_messages,
|
||||
),
|
||||
replies=replies,
|
||||
)
|
||||
|
||||
case protocol_schema.TaskRequestType.rank_assistant_replies:
|
||||
logger.info("Generating a RankAssistantRepliesTask.")
|
||||
conversation, replies = pr.fetch_multiple_random_replies(message_role="prompter")
|
||||
|
||||
task_messages = [
|
||||
protocol_schema.ConversationMessage(
|
||||
text=p.text,
|
||||
is_assistant=(p.role == "assistant"),
|
||||
message_id=p.id,
|
||||
front_end_id=p.frontend_message_id,
|
||||
)
|
||||
for p in conversation
|
||||
]
|
||||
replies = [p.text for p in replies]
|
||||
task = protocol_schema.RankAssistantRepliesTask(
|
||||
conversation=prepare_conversation(conversation),
|
||||
replies=replies,
|
||||
)
|
||||
|
||||
case protocol_schema.TaskRequestType.label_initial_prompt:
|
||||
logger.info("Generating a LabelInitialPromptTask.")
|
||||
message = pr.fetch_random_initial_prompts(1)[0]
|
||||
task = protocol_schema.LabelInitialPromptTask(
|
||||
message_id=message.id,
|
||||
prompt=message.text,
|
||||
valid_labels=list(map(lambda x: x.value, protocol_schema.TextLabel)),
|
||||
)
|
||||
|
||||
case protocol_schema.TaskRequestType.label_prompter_reply:
|
||||
logger.info("Generating a LabelPrompterReplyTask.")
|
||||
conversation, messages = pr.fetch_multiple_random_replies(max_size=1, message_role="assistant")
|
||||
message = messages[0]
|
||||
task = protocol_schema.LabelPrompterReplyTask(
|
||||
message_id=message.id,
|
||||
conversation=prepare_conversation(conversation),
|
||||
reply=message.text,
|
||||
valid_labels=list(map(lambda x: x.value, protocol_schema.TextLabel)),
|
||||
)
|
||||
|
||||
case protocol_schema.TaskRequestType.label_assistant_reply:
|
||||
logger.info("Generating a LabelAssistantReplyTask.")
|
||||
conversation, messages = pr.fetch_multiple_random_replies(max_size=1, message_role="prompter")
|
||||
message = messages[0]
|
||||
task = protocol_schema.LabelAssistantReplyTask(
|
||||
message_id=message.id,
|
||||
conversation=prepare_conversation(conversation),
|
||||
reply=message.text,
|
||||
valid_labels=list(map(lambda x: x.value, protocol_schema.TextLabel)),
|
||||
)
|
||||
|
||||
case _:
|
||||
raise OasstError("Invalid request type", OasstErrorCode.TASK_INVALID_REQUEST_TYPE)
|
||||
|
||||
logger.info(f"Generated {task=}.")
|
||||
|
||||
return task, message_tree_id, parent_message_id
|
||||
|
||||
|
||||
@router.post(
|
||||
"/",
|
||||
response_model=protocol_schema.AnyTask,
|
||||
@@ -193,7 +36,9 @@ def request_task(
|
||||
|
||||
try:
|
||||
pr = PromptRepository(db, api_client, client_user=request.user)
|
||||
task, message_tree_id, parent_message_id = generate_task(request, pr)
|
||||
tree_manager_config = TreeManagerConfiguration()
|
||||
tm = TreeManager(db, pr, tree_manager_config)
|
||||
task, message_tree_id, parent_message_id = tm.next_task(request.type)
|
||||
pr.task_repository.store_task(task, message_tree_id, parent_message_id, request.collective)
|
||||
|
||||
except OasstError:
|
||||
@@ -268,63 +113,10 @@ async def tasks_interaction(
|
||||
|
||||
try:
|
||||
pr = PromptRepository(db, api_client, client_user=interaction.user)
|
||||
tree_manager_config = TreeManagerConfiguration()
|
||||
tm = TreeManager(db, pr, tree_manager_config)
|
||||
return await tm.handle_interaction(interaction)
|
||||
|
||||
match type(interaction):
|
||||
case protocol_schema.TextReplyToMessage:
|
||||
logger.info(
|
||||
f"Frontend reports text reply to {interaction.message_id=} with {interaction.text=} by {interaction.user=}."
|
||||
)
|
||||
|
||||
# here we store the text reply in the database
|
||||
newMessage = pr.store_text_reply(
|
||||
text=interaction.text,
|
||||
frontend_message_id=interaction.message_id,
|
||||
user_frontend_message_id=interaction.user_message_id,
|
||||
)
|
||||
|
||||
if not settings.DEBUG_SKIP_EMBEDDING_COMPUTATION:
|
||||
try:
|
||||
hugging_face_api = HuggingFaceAPI(
|
||||
f"{HfUrl.HUGGINGFACE_FEATURE_EXTRACTION.value}/{HfEmbeddingModel.MINILM.value}"
|
||||
)
|
||||
embedding = await hugging_face_api.post(interaction.text)
|
||||
pr.insert_message_embedding(
|
||||
message_id=newMessage.id, model=HfEmbeddingModel.MINILM.value, embedding=embedding
|
||||
)
|
||||
except OasstError:
|
||||
logger.error(
|
||||
f"Could not fetch embbeddings for text reply to {interaction.message_id=} with {interaction.text=} by {interaction.user=}."
|
||||
)
|
||||
|
||||
return protocol_schema.TaskDone()
|
||||
case protocol_schema.MessageRating:
|
||||
logger.info(
|
||||
f"Frontend reports rating of {interaction.message_id=} with {interaction.rating=} by {interaction.user=}."
|
||||
)
|
||||
|
||||
# here we store the rating in the database
|
||||
pr.store_rating(interaction)
|
||||
|
||||
return protocol_schema.TaskDone()
|
||||
case protocol_schema.MessageRanking:
|
||||
logger.info(
|
||||
f"Frontend reports ranking of {interaction.message_id=} with {interaction.ranking=} by {interaction.user=}."
|
||||
)
|
||||
|
||||
# TODO: check if the ranking is valid
|
||||
pr.store_ranking(interaction)
|
||||
# here we would store the ranking in the database
|
||||
return protocol_schema.TaskDone()
|
||||
case protocol_schema.TextLabels:
|
||||
logger.info(
|
||||
f"Frontend reports labels of {interaction.message_id=} with {interaction.labels=} by {interaction.user=}."
|
||||
)
|
||||
# Labels are implicitly validated when converting str -> TextLabel
|
||||
# So no need for explicit validation here
|
||||
pr.store_text_labels(interaction)
|
||||
return protocol_schema.TaskDone()
|
||||
case _:
|
||||
raise OasstError("Invalid response type.", OasstErrorCode.TASK_INVALID_RESPONSE_TYPE)
|
||||
except OasstError:
|
||||
raise
|
||||
except Exception:
|
||||
|
||||
@@ -18,19 +18,20 @@ def prepare_message_list(messages: list[Message]) -> list[protocol.Message]:
|
||||
return [prepare_message(m) for m in messages]
|
||||
|
||||
|
||||
def prepare_conversation(messages: list[Message]) -> protocol.Conversation:
|
||||
conv_messages = []
|
||||
for message in messages:
|
||||
conv_messages.append(
|
||||
protocol.ConversationMessage(
|
||||
text=message.text,
|
||||
is_assistant=(message.role == "assistant"),
|
||||
message_id=message.id,
|
||||
frontend_message_id=message.frontend_message_id,
|
||||
)
|
||||
def prepare_conversation_message_list(messages: list[Message]) -> list[protocol.ConversationMessage]:
|
||||
return [
|
||||
protocol.ConversationMessage(
|
||||
text=message.text,
|
||||
is_assistant=(message.role == "assistant"),
|
||||
message_id=message.id,
|
||||
frontend_message_id=message.frontend_message_id,
|
||||
)
|
||||
for message in messages
|
||||
]
|
||||
|
||||
return protocol.Conversation(messages=conv_messages)
|
||||
|
||||
def prepare_conversation(messages: list[Message]) -> protocol.Conversation:
|
||||
return protocol.Conversation(messages=prepare_conversation_message_list(messages))
|
||||
|
||||
|
||||
def prepare_tree(tree: list[Message], tree_id: UUID) -> protocol.MessageTree:
|
||||
|
||||
@@ -25,6 +25,7 @@ class Settings(BaseSettings):
|
||||
DEBUG_USE_SEED_DATA_PATH: Optional[FilePath] = (
|
||||
Path(__file__).parent.parent / "test_data/generic/test_generic_data.json"
|
||||
)
|
||||
DEBUG_ALLOW_SELF_LABELING: bool = False # allow users to label their own messages
|
||||
DEBUG_SKIP_EMBEDDING_COMPUTATION: bool = False
|
||||
|
||||
HUGGING_FACE_API_KEY: str = ""
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Literal
|
||||
from typing import Literal, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from oasst_backend.models.payload_column_type import payload_type
|
||||
@@ -28,7 +28,7 @@ class RateSummaryPayload(TaskPayload):
|
||||
@payload_type
|
||||
class InitialPromptPayload(TaskPayload):
|
||||
type: Literal["initial_prompt"] = "initial_prompt"
|
||||
hint: str
|
||||
hint: str | None
|
||||
|
||||
|
||||
@payload_type
|
||||
@@ -64,12 +64,13 @@ class RatingReactionPayload(ReactionPayload):
|
||||
class RankingReactionPayload(ReactionPayload):
|
||||
type: Literal["message_ranking"] = "message_ranking"
|
||||
ranking: list[int]
|
||||
ranked_message_ids: list[UUID]
|
||||
|
||||
|
||||
@payload_type
|
||||
class RankConversationRepliesPayload(TaskPayload):
|
||||
conversation: protocol_schema.Conversation # the conversation so far
|
||||
replies: list[str]
|
||||
reply_messages: list[protocol_schema.ConversationMessage]
|
||||
|
||||
|
||||
@payload_type
|
||||
@@ -77,7 +78,7 @@ class RankInitialPromptsPayload(TaskPayload):
|
||||
"""A task to rank a set of initial prompts."""
|
||||
|
||||
type: Literal["rank_initial_prompts"] = "rank_initial_prompts"
|
||||
prompts: list[str]
|
||||
prompt_messages: list[protocol_schema.ConversationMessage]
|
||||
|
||||
|
||||
@payload_type
|
||||
@@ -102,6 +103,7 @@ class LabelInitialPromptPayload(TaskPayload):
|
||||
message_id: UUID
|
||||
prompt: str
|
||||
valid_labels: list[str]
|
||||
mandatory_labels: Optional[list[str]]
|
||||
|
||||
|
||||
@payload_type
|
||||
@@ -112,6 +114,7 @@ class LabelConversationReplyPayload(TaskPayload):
|
||||
conversation: protocol_schema.Conversation
|
||||
reply: str
|
||||
valid_labels: list[str]
|
||||
mandatory_labels: Optional[list[str]]
|
||||
|
||||
|
||||
@payload_type
|
||||
|
||||
@@ -41,6 +41,10 @@ class Message(SQLModel, table=True):
|
||||
children_count: int = Field(sa_column=sa.Column(sa.Integer, default=0, server_default=sa.text("0"), nullable=False))
|
||||
deleted: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=False, server_default=false()))
|
||||
|
||||
review_count: int = Field(sa_column=sa.Column(sa.Integer, default=0, server_default=sa.text("0"), nullable=False))
|
||||
review_result: bool = Field(sa_column=sa.Column(sa.Boolean, default=False, server_default=false(), nullable=False))
|
||||
ranking_count: int = Field(sa_column=sa.Column(sa.Integer, default=0, server_default=sa.text("0"), nullable=False))
|
||||
|
||||
def ensure_is_message(self) -> None:
|
||||
if not self.payload or not isinstance(self.payload.payload, MessagePayload):
|
||||
raise OasstError("Invalid message", OasstErrorCode.INVALID_MESSAGE, HTTPStatus.INTERNAL_SERVER_ERROR)
|
||||
|
||||
@@ -1,29 +1,28 @@
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
from uuid import UUID, uuid4
|
||||
from uuid import UUID
|
||||
|
||||
import sqlalchemy as sa
|
||||
import sqlalchemy.dialects.postgresql as pg
|
||||
from sqlmodel import Field, Index, SQLModel
|
||||
from sqlmodel import Field, SQLModel
|
||||
|
||||
|
||||
class States(str, Enum):
|
||||
class State(str, Enum):
|
||||
"""States of the Open-Assistant message tree state machine."""
|
||||
|
||||
INITIAL_PROMPT_REVIEW = "initial_prompt_review"
|
||||
"""In this state the message tree consists only of a single inital prompt root node.
|
||||
Initial prompt labeling tasks will determine if the tree goes into `breeding_phase` or
|
||||
`aborted_low_grade`."""
|
||||
Initial prompt labeling tasks will determine if the tree goes into `growing` or
|
||||
`aborted_low_grade` state."""
|
||||
|
||||
BREEDING_PHASE = "breeding_phase"
|
||||
GROWING = "growing"
|
||||
"""Assistant & prompter human demonstrations are collected. Concurrently labeling tasks
|
||||
are handed out to check if the quality of the replies surpasses the minimum acceptable
|
||||
quality.
|
||||
When the required number of messages passing the initial labelling-quality check has been
|
||||
collected the tree will enter `ranking_phase`. If too many poor-quality labelling responses
|
||||
collected the tree will enter `ranking`. If too many poor-quality labelling responses
|
||||
are received the tree can also enter the `aborted_low_grade` state."""
|
||||
|
||||
RANKING_PHASE = "ranking_phase"
|
||||
RANKING = "ranking"
|
||||
"""The tree has been successfully populated with the desired number of messages. Ranking
|
||||
tasks are now handed out for all nodes with more than one child."""
|
||||
|
||||
@@ -46,28 +45,26 @@ class States(str, Enum):
|
||||
|
||||
|
||||
VALID_STATES = (
|
||||
States.INITIAL_PROMPT_REVIEW,
|
||||
States.BREEDING_PHASE,
|
||||
States.RANKING_PHASE,
|
||||
States.READY_FOR_SCORING,
|
||||
States.READY_FOR_EXPORT,
|
||||
States.ABORTED_LOW_GRADE,
|
||||
State.INITIAL_PROMPT_REVIEW,
|
||||
State.GROWING,
|
||||
State.RANKING,
|
||||
State.READY_FOR_SCORING,
|
||||
State.READY_FOR_EXPORT,
|
||||
State.ABORTED_LOW_GRADE,
|
||||
)
|
||||
|
||||
TERMINAL_STATES = (States.READY_FOR_EXPORT, States.ABORTED_LOW_GRADE, States.SCORING_FAILED, States.HALTED_BY_MODERATOR)
|
||||
TERMINAL_STATES = (State.READY_FOR_EXPORT, State.ABORTED_LOW_GRADE, State.SCORING_FAILED, State.HALTED_BY_MODERATOR)
|
||||
|
||||
|
||||
class MessageTreeState(SQLModel, table=True):
|
||||
__tablename__ = "message_tree_state"
|
||||
__table_args__ = (Index("ix_message_tree_state_tree_id", "message_tree_id", unique=True),)
|
||||
|
||||
id: Optional[UUID] = Field(
|
||||
sa_column=sa.Column(
|
||||
pg.UUID(as_uuid=True), primary_key=True, default=uuid4, server_default=sa.text("gen_random_uuid()")
|
||||
),
|
||||
message_tree_id: UUID = Field(
|
||||
sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("message.id"), primary_key=True)
|
||||
)
|
||||
message_tree_id: UUID = Field(nullable=False, index=True)
|
||||
state: str = Field(nullable=False, max_length=128)
|
||||
goal_tree_size: int = Field(nullable=False)
|
||||
current_num_non_filtered_messages: int = Field(nullable=False)
|
||||
max_depth: int = Field(nullable=False)
|
||||
max_children_count: int = Field(nullable=False)
|
||||
state: str = Field(nullable=False, max_length=128, index=True)
|
||||
active: bool = Field(nullable=False, index=True)
|
||||
accepted_messages: int = Field(nullable=False, default=0)
|
||||
|
||||
@@ -2,13 +2,24 @@ import datetime
|
||||
import random
|
||||
from collections import defaultdict
|
||||
from http import HTTPStatus
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Tuple
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import oasst_backend.models.db_payload as db_payload
|
||||
import sqlalchemy as sa
|
||||
from loguru import logger
|
||||
from oasst_backend.journal_writer import JournalWriter
|
||||
from oasst_backend.models import ApiClient, Message, MessageEmbedding, MessageReaction, TextLabels, User
|
||||
from oasst_backend.models import (
|
||||
ApiClient,
|
||||
Message,
|
||||
MessageEmbedding,
|
||||
MessageReaction,
|
||||
MessageTreeState,
|
||||
Task,
|
||||
TextLabels,
|
||||
User,
|
||||
message_tree_state,
|
||||
)
|
||||
from oasst_backend.models.payload_column_type import PayloadContainer
|
||||
from oasst_backend.task_repository import TaskRepository, validate_frontend_message_id
|
||||
from oasst_backend.user_repository import UserRepository
|
||||
@@ -34,6 +45,7 @@ class PromptRepository:
|
||||
self.user_repository = user_repository or UserRepository(db, api_client)
|
||||
self.user = self.user_repository.lookup_client_user(client_user, create_missing=True)
|
||||
self.user_id = self.user.id if self.user else None
|
||||
logger.debug(f"PromptRepository(api_client_id={self.api_client.id}, {self.user_id=})")
|
||||
self.task_repository = task_repository or TaskRepository(
|
||||
db, api_client, client_user, user_repository=self.user_repository
|
||||
)
|
||||
@@ -66,6 +78,8 @@ class PromptRepository:
|
||||
payload: db_payload.MessagePayload,
|
||||
payload_type: str = None,
|
||||
depth: int = 0,
|
||||
review_count: int = 0,
|
||||
review_result: bool = False,
|
||||
) -> Message:
|
||||
if payload_type is None:
|
||||
if payload is None:
|
||||
@@ -85,25 +99,24 @@ class PromptRepository:
|
||||
payload_type=payload_type,
|
||||
payload=PayloadContainer(payload=payload),
|
||||
depth=depth,
|
||||
review_count=review_count,
|
||||
review_result=review_result,
|
||||
)
|
||||
self.db.add(message)
|
||||
self.db.commit()
|
||||
self.db.refresh(message)
|
||||
return message
|
||||
|
||||
def store_text_reply(
|
||||
self,
|
||||
text: str,
|
||||
frontend_message_id: str,
|
||||
user_frontend_message_id: str,
|
||||
) -> Message:
|
||||
validate_frontend_message_id(frontend_message_id)
|
||||
validate_frontend_message_id(user_frontend_message_id)
|
||||
|
||||
task = self.task_repository.fetch_task_by_frontend_message_id(frontend_message_id)
|
||||
|
||||
def _validate_task(
|
||||
self, task: Task, *, task_id: Optional[UUID] = None, frontend_message_id: Optional[str] = None
|
||||
) -> Task:
|
||||
if task is None:
|
||||
raise OasstError(f"Task for {frontend_message_id=} not found", OasstErrorCode.TASK_NOT_FOUND)
|
||||
if task_id:
|
||||
raise OasstError(f"Task for {task_id=} not found", OasstErrorCode.TASK_NOT_FOUND)
|
||||
if frontend_message_id:
|
||||
raise OasstError(f"Task for {frontend_message_id=} not found", OasstErrorCode.TASK_NOT_FOUND)
|
||||
raise OasstError("Task not found", OasstErrorCode.TASK_NOT_FOUND)
|
||||
|
||||
if task.expired:
|
||||
raise OasstError("Task already expired.", OasstErrorCode.TASK_EXPIRED)
|
||||
if not task.ack:
|
||||
@@ -111,12 +124,45 @@ class PromptRepository:
|
||||
if task.done:
|
||||
raise OasstError("Task already done.", OasstErrorCode.TASK_ALREADY_DONE)
|
||||
|
||||
if (not task.collective or task.user_id is None) and task.user_id != self.user_id:
|
||||
logger.warning(f"Task was assigned to a different user (expected: {task.user_id}; actual: {self.user_id}).")
|
||||
raise OasstError("Task was assigned to a different user.", OasstErrorCode.TASK_NOT_ASSIGNED_TO_USER)
|
||||
|
||||
return task
|
||||
|
||||
def fetch_tree_state(self, message_tree_id: UUID) -> MessageTreeState:
    """Load the state row that belongs to the given message tree.

    The lookup ends in ``Query.one()``, i.e. it expects exactly one
    ``MessageTreeState`` row for ``message_tree_id``.
    """
    tree_state_query = self.db.query(MessageTreeState)
    matching_rows = tree_state_query.filter(MessageTreeState.message_tree_id == message_tree_id)
    return matching_rows.one()
|
||||
|
||||
def store_text_reply(
|
||||
self,
|
||||
text: str,
|
||||
frontend_message_id: str,
|
||||
user_frontend_message_id: str,
|
||||
review_count: int = 0,
|
||||
review_result: bool = False,
|
||||
) -> Message:
|
||||
validate_frontend_message_id(frontend_message_id)
|
||||
validate_frontend_message_id(user_frontend_message_id)
|
||||
|
||||
task = self.task_repository.fetch_task_by_frontend_message_id(frontend_message_id)
|
||||
self._validate_task(task)
|
||||
|
||||
# If there's no parent message assume user started new conversation
|
||||
role = "prompter"
|
||||
depth = 0
|
||||
|
||||
if task.parent_message_id:
|
||||
parent_message = self.fetch_message(task.parent_message_id)
|
||||
|
||||
# check tree state
|
||||
ts = self.fetch_tree_state(parent_message.message_tree_id)
|
||||
if not ts.active or ts.state != message_tree_state.State.GROWING:
|
||||
raise OasstError(
|
||||
"Message insertion failed. Message tree is no longer in 'growing' state.",
|
||||
OasstErrorCode.TREE_NOT_IN_GROWING_STATE,
|
||||
)
|
||||
|
||||
parent_message.message_tree_id
|
||||
parent_message.children_count += 1
|
||||
self.db.add(parent_message)
|
||||
|
||||
@@ -137,6 +183,8 @@ class PromptRepository:
|
||||
role=role,
|
||||
payload=db_payload.MessagePayload(text=text),
|
||||
depth=depth,
|
||||
review_count=review_count,
|
||||
review_result=review_result,
|
||||
)
|
||||
if not task.collective:
|
||||
task.done = True
|
||||
@@ -149,6 +197,7 @@ class PromptRepository:
|
||||
message = self.fetch_message_by_frontend_message_id(rating.message_id, fail_if_missing=True)
|
||||
|
||||
task = self.task_repository.fetch_task_by_frontend_message_id(rating.message_id)
|
||||
self._validate_task(task)
|
||||
task_payload: db_payload.RateSummaryPayload = task.payload.payload
|
||||
if type(task_payload) != db_payload.RateSummaryPayload:
|
||||
raise OasstError(
|
||||
@@ -173,9 +222,10 @@ class PromptRepository:
|
||||
logger.info(f"Ranking {rating.rating} stored for task {task.id}.")
|
||||
return reaction
|
||||
|
||||
def store_ranking(self, ranking: protocol_schema.MessageRanking) -> MessageReaction:
|
||||
def store_ranking(self, ranking: protocol_schema.MessageRanking) -> Tuple[MessageReaction, Task]:
|
||||
# fetch task
|
||||
task = self.task_repository.fetch_task_by_frontend_message_id(ranking.message_id)
|
||||
self._validate_task(task, frontend_message_id=ranking.message_id)
|
||||
if not task.collective:
|
||||
task.done = True
|
||||
self.db.add(task)
|
||||
@@ -188,47 +238,59 @@ class PromptRepository:
|
||||
|
||||
case db_payload.RankPrompterRepliesPayload | db_payload.RankAssistantRepliesPayload:
|
||||
# validate ranking
|
||||
num_replies = len(task_payload.replies)
|
||||
if sorted(ranking.ranking) != list(range(num_replies)):
|
||||
if sorted(ranking.ranking) != list(range(num_replies := len(task_payload.reply_messages))):
|
||||
raise OasstError(
|
||||
f"Invalid ranking submitted. Each reply index must appear exactly once ({num_replies=}).",
|
||||
OasstErrorCode.INVALID_RANKING_VALUE,
|
||||
)
|
||||
|
||||
last_conv_message = task_payload.conversation.messages[-1]
|
||||
parent_msg = self.fetch_message(last_conv_message.message_id)
|
||||
|
||||
# store reaction to message
|
||||
reaction_payload = db_payload.RankingReactionPayload(ranking=ranking.ranking)
|
||||
ranked_message_ids = [task_payload.reply_messages[i].message_id for i in ranking.ranking]
|
||||
for mid in ranked_message_ids:
|
||||
message = self.fetch_message(mid)
|
||||
if message.parent_id != parent_msg.id:
|
||||
raise OasstError("Corrupt reply ranking result", OasstErrorCode.CORRUPT_RANKING_RESULT)
|
||||
message.ranking_count += 1
|
||||
self.db.add(message)
|
||||
|
||||
reaction_payload = db_payload.RankingReactionPayload(
|
||||
ranking=ranking.ranking, ranked_message_ids=ranked_message_ids
|
||||
)
|
||||
reaction = self.insert_reaction(task.id, reaction_payload)
|
||||
# TODO: resolve message_id
|
||||
self.journal.log_ranking(task, message_id=None, ranking=ranking.ranking)
|
||||
self.journal.log_ranking(task, message_id=parent_msg.id, ranking=ranking.ranking)
|
||||
|
||||
logger.info(f"Ranking {ranking.ranking} stored for task {task.id}.")
|
||||
|
||||
return reaction
|
||||
|
||||
case db_payload.RankInitialPromptsPayload:
|
||||
# validate ranking
|
||||
if sorted(ranking.ranking) != list(range(num_prompts := len(task_payload.prompts))):
|
||||
if sorted(ranking.ranking) != list(range(num_prompts := len(task_payload.prompt_messages))):
|
||||
raise OasstError(
|
||||
f"Invalid ranking submitted. Each reply index must appear exactly once ({num_prompts=}).",
|
||||
OasstErrorCode.INVALID_RANKING_VALUE,
|
||||
)
|
||||
|
||||
# store reaction to message
|
||||
reaction_payload = db_payload.RankingReactionPayload(ranking=ranking.ranking)
|
||||
ranked_message_ids = [task_payload.prompt_messages[i].message_id for i in ranking.ranking]
|
||||
reaction_payload = db_payload.RankingReactionPayload(
|
||||
ranking=ranking.ranking, ranked_message_ids=ranked_message_ids
|
||||
)
|
||||
reaction = self.insert_reaction(task.id, reaction_payload)
|
||||
# TODO: resolve message_id
|
||||
self.journal.log_ranking(task, message_id=None, ranking=ranking.ranking)
|
||||
|
||||
logger.info(f"Ranking {ranking.ranking} stored for task {task.id}.")
|
||||
|
||||
return reaction
|
||||
|
||||
case _:
|
||||
raise OasstError(
|
||||
f"task payload type mismatch: {type(task_payload)=} != {db_payload.RankConversationRepliesPayload}",
|
||||
OasstErrorCode.TASK_PAYLOAD_TYPE_MISMATCH,
|
||||
)
|
||||
|
||||
return reaction, task
|
||||
|
||||
def insert_message_embedding(self, message_id: UUID, model: str, embedding: List[float]) -> MessageEmbedding:
|
||||
"""Insert the embedding of a new message in the database.
|
||||
|
||||
@@ -270,21 +332,80 @@ class PromptRepository:
|
||||
self.db.refresh(reaction)
|
||||
return reaction
|
||||
|
||||
def store_text_labels(
    self, text_labels: protocol_schema.TextLabels
) -> Tuple[TextLabels, Optional[Task], Optional[Message]]:
    """Validate and persist a set of text labels.

    When ``text_labels.task_id`` is set, the referenced task is fetched and
    validated, the labelled message id is taken from the task payload (and must
    agree with an explicitly submitted ``message_id``), the submitted labels are
    checked against the task's ``valid_labels`` / ``mandatory_labels``, and the
    stored ``TextLabels`` row shares its id with the task. Non-collective tasks
    are marked as done.

    Returns:
        ``(model, task, message)`` where ``task`` is ``None`` if the labels were
        submitted without a task, and ``message`` is ``None`` if no message id
        is known.

    Raises:
        OasstError: on task validation failure, message id mismatch, unexpected
            task payload type, invalid labels or missing mandatory labels.
    """
    valid_labels: Optional[list[str]] = None
    mandatory_labels: Optional[list[str]] = None
    text_labels_id: Optional[UUID] = None
    message_id: Optional[UUID] = text_labels.message_id
    # Initialised up front so the final return cannot hit an unbound local
    # when no message id is available.
    message: Optional[Message] = None

    task: Task = None
    if text_labels.task_id:
        logger.debug(f"text_labels reply has task_id {text_labels.task_id}")
        task = self.task_repository.fetch_task_by_id(text_labels.task_id)
        self._validate_task(task, task_id=text_labels.task_id)

        task_payload: db_payload.TaskPayload = task.payload.payload
        if isinstance(task_payload, db_payload.LabelInitialPromptPayload):
            if message_id and task_payload.message_id != message_id:
                raise OasstError("Task message id mismatch", OasstErrorCode.TEXT_LABELS_WRONG_MESSAGE_ID)
            message_id = task_payload.message_id
            valid_labels = task_payload.valid_labels
            mandatory_labels = task_payload.mandatory_labels
        elif isinstance(task_payload, db_payload.LabelConversationReplyPayload):
            # Compare against the task payload (the original compared
            # message_id with itself, so the check could never fail).
            if message_id and task_payload.message_id != message_id:
                raise OasstError("Task message id mismatch", OasstErrorCode.TEXT_LABELS_WRONG_MESSAGE_ID)
            message_id = task_payload.message_id
            valid_labels = task_payload.valid_labels
            mandatory_labels = task_payload.mandatory_labels
        else:
            raise OasstError(
                "Unexpected text_labels task payload",
                OasstErrorCode.TASK_PAYLOAD_TYPE_MISMATCH,
            )

        logger.debug(f"text_labels relpy: {valid_labels=}, {mandatory_labels=}")

        if valid_labels:
            if not all(label in valid_labels for label in text_labels.labels.keys()):
                raise OasstError("Invalid text label specified", OasstErrorCode.TEXT_LABELS_INVALID_LABEL)

        if isinstance(mandatory_labels, list):
            mandatory_set = set(mandatory_labels)
            if not mandatory_set.issubset(text_labels.labels.keys()):
                missing = ", ".join(mandatory_set - text_labels.labels.keys())
                raise OasstError(
                    f"Mandatory text labels missing: {missing}", OasstErrorCode.TEXT_LABELS_MANDATORY_LABEL_MISSING
                )

        text_labels_id = task.id  # associate with task by sharing the id

        if not task.collective:
            task.done = True
            self.db.add(task)

    logger.debug(f"inserting TextLabels for {message_id=}, {text_labels_id=}")
    model = TextLabels(
        id=text_labels_id,
        api_client_id=self.api_client.id,
        message_id=message_id,
        user_id=self.user_id,
        text=text_labels.text,
        labels=text_labels.labels,
    )

    if message_id:
        message = self.fetch_message(message_id)
        if task:
            # Only task-driven labels count as a review of the message.
            message.review_count += 1
            self.db.add(message)

    self.db.add(model)
    self.db.commit()
    self.db.refresh(model)
    return model, task, message
|
||||
|
||||
def fetch_random_message_tree(self, require_role: str = None) -> list[Message]:
|
||||
def fetch_random_message_tree(self, require_role: str = None, reviewed: bool = True) -> list[Message]:
|
||||
"""
|
||||
Loads all messages of a random message_tree.
|
||||
|
||||
@@ -294,13 +415,18 @@ class PromptRepository:
|
||||
distinct_message_trees = self.db.query(Message.message_tree_id).distinct(Message.message_tree_id)
|
||||
if require_role:
|
||||
distinct_message_trees = distinct_message_trees.filter(Message.role == require_role)
|
||||
if reviewed:
|
||||
distinct_message_trees = distinct_message_trees.filter(Message.review_result)
|
||||
distinct_message_trees = distinct_message_trees.subquery()
|
||||
|
||||
random_message_tree = self.db.query(distinct_message_trees).order_by(func.random()).limit(1)
|
||||
message_tree_messages = self.db.query(Message).filter(Message.message_tree_id.in_(random_message_tree)).all()
|
||||
return message_tree_messages
|
||||
random_message_tree_id = self.db.query(distinct_message_trees).order_by(func.random()).limit(1).scalar()
|
||||
if random_message_tree_id:
|
||||
return self.fetch_message_tree(random_message_tree_id, reviewed)
|
||||
return None
|
||||
|
||||
def fetch_random_conversation(self, last_message_role: str = None) -> list[Message]:
|
||||
def fetch_random_conversation(
|
||||
self, last_message_role: str = None, message_tree_id: Optional[UUID] = None, reviewed: bool = True
|
||||
) -> list[Message]:
|
||||
"""
|
||||
Picks a random linear conversation starting from any root message
|
||||
and ending somewhere in the message_tree, possibly at the root itself.
|
||||
@@ -310,9 +436,13 @@ class PromptRepository:
|
||||
the user should reply as a human and hence the last message of the conversation
|
||||
needs to have "assistant" role.
|
||||
"""
|
||||
messages_tree = self.fetch_random_message_tree(last_message_role)
|
||||
if message_tree_id:
|
||||
messages_tree = self.fetch_message_tree(message_tree_id, reviewed)
|
||||
else:
|
||||
messages_tree = self.fetch_random_message_tree(last_message_role)
|
||||
if not messages_tree:
|
||||
raise OasstError("No message tree found", OasstErrorCode.NO_MESSAGE_TREE_FOUND)
|
||||
|
||||
if last_message_role:
|
||||
conv_messages = [m for m in messages_tree if m.role == last_message_role]
|
||||
conv_messages = [random.choice(conv_messages)]
|
||||
@@ -334,8 +464,11 @@ class PromptRepository:
|
||||
messages = self.db.query(Message).filter(Message.parent_id.is_(None)).order_by(func.random()).limit(size).all()
|
||||
return messages
|
||||
|
||||
def fetch_message_tree(self, message_tree_id: UUID, reviewed: bool = True):
    """Return all messages that belong to the given message tree.

    With ``reviewed=True`` (the default) the query is additionally filtered
    on ``Message.review_result``.
    """
    tree_messages = self.db.query(Message).filter(Message.message_tree_id == message_tree_id)
    if not reviewed:
        return tree_messages.all()
    return tree_messages.filter(Message.review_result).all()
|
||||
|
||||
def fetch_multiple_random_replies(self, max_size: int = 5, message_role: str = None):
|
||||
"""
|
||||
@@ -388,13 +521,15 @@ class PromptRepository:
|
||||
messages = {m.id: m for m in messages}
|
||||
if not isinstance(messages, dict):
|
||||
# This should not normally happen
|
||||
raise OasstError("Server error", OasstErrorCode.SERVER_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR)
|
||||
raise OasstError("Server error", OasstErrorCode.SERVER_ERROR0, HTTPStatus.INTERNAL_SERVER_ERROR)
|
||||
|
||||
conv = [last_message]
|
||||
while conv[-1].parent_id:
|
||||
if conv[-1].parent_id not in messages:
|
||||
# Can't form a continuous conversation
|
||||
raise OasstError("Server error", OasstErrorCode.SERVER_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR)
|
||||
raise OasstError(
|
||||
"Broken conversation", OasstErrorCode.BROKEN_CONVERSATION, HTTPStatus.INTERNAL_SERVER_ERROR
|
||||
)
|
||||
|
||||
parent_message = messages[conv[-1].parent_id]
|
||||
conv.append(parent_message)
|
||||
@@ -417,16 +552,24 @@ class PromptRepository:
|
||||
"""
|
||||
if isinstance(message, UUID):
|
||||
message = self.fetch_message(message)
|
||||
logger.debug(f"fetch_message_tree({message.message_tree_id=})")
|
||||
return self.fetch_message_tree(message.message_tree_id)
|
||||
|
||||
def fetch_message_children(
    self, message: Message | UUID, reviewed: bool = True, exclude_deleted: bool = True
) -> list[Message]:
    """
    Get all direct children of this message.

    ``message`` may be a ``Message`` instance or its id. ``reviewed`` adds a
    filter on ``Message.review_result``; ``exclude_deleted`` adds a filter
    that keeps only rows whose ``deleted`` flag is false.
    """
    parent_id = message.id if isinstance(message, Message) else message

    child_query = self.db.query(Message).filter(Message.parent_id == parent_id)
    if reviewed:
        child_query = child_query.filter(Message.review_result)
    if exclude_deleted:
        child_query = child_query.filter(Message.deleted == sa.false())
    return child_query.all()
|
||||
|
||||
@staticmethod
|
||||
@@ -536,7 +679,7 @@ class PromptRepository:
|
||||
elif isinstance(message, Message):
|
||||
ids.append(message.id)
|
||||
else:
|
||||
raise OasstError("Server error", OasstErrorCode.SERVER_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR)
|
||||
raise OasstError("Server error", OasstErrorCode.SERVER_ERROR1, HTTPStatus.INTERNAL_SERVER_ERROR)
|
||||
|
||||
query = update(Message).where(Message.id.in_(ids)).values(deleted=True)
|
||||
self.db.execute(query)
|
||||
|
||||
@@ -2,6 +2,7 @@ from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
import oasst_backend.models.db_payload as db_payload
|
||||
from loguru import logger
|
||||
from oasst_backend.models import ApiClient, Task
|
||||
from oasst_backend.models.payload_column_type import PayloadContainer
|
||||
from oasst_backend.user_repository import UserRepository
|
||||
@@ -62,16 +63,16 @@ class TaskRepository:
|
||||
payload = db_payload.AssistantReplyPayload(type=task.type, conversation=task.conversation)
|
||||
|
||||
case protocol_schema.RankInitialPromptsTask:
|
||||
payload = db_payload.RankInitialPromptsPayload(type=task.type, prompts=task.prompts)
|
||||
payload = db_payload.RankInitialPromptsPayload(type=task.type, prompt_messages=task.prompt_messages)
|
||||
|
||||
case protocol_schema.RankPrompterRepliesTask:
|
||||
payload = db_payload.RankPrompterRepliesPayload(
|
||||
type=task.type, conversation=task.conversation, replies=task.replies
|
||||
type=task.type, conversation=task.conversation, reply_messages=task.reply_messages
|
||||
)
|
||||
|
||||
case protocol_schema.RankAssistantRepliesTask:
|
||||
payload = db_payload.RankAssistantRepliesPayload(
|
||||
type=task.type, conversation=task.conversation, replies=task.replies
|
||||
type=task.type, conversation=task.conversation, reply_messages=task.reply_messages
|
||||
)
|
||||
|
||||
case protocol_schema.LabelInitialPromptTask:
|
||||
@@ -86,6 +87,7 @@ class TaskRepository:
|
||||
conversation=task.conversation,
|
||||
reply=task.reply,
|
||||
valid_labels=task.valid_labels,
|
||||
mandatory_labels=task.mandatory_labels,
|
||||
)
|
||||
|
||||
case protocol_schema.LabelAssistantReplyTask:
|
||||
@@ -95,20 +97,21 @@ class TaskRepository:
|
||||
conversation=task.conversation,
|
||||
reply=task.reply,
|
||||
valid_labels=task.valid_labels,
|
||||
mandatory_labels=task.mandatory_labels,
|
||||
)
|
||||
|
||||
case _:
|
||||
raise OasstError(f"Invalid task type: {type(task)=}", OasstErrorCode.INVALID_TASK_TYPE)
|
||||
|
||||
task = self.insert_task(
|
||||
task_model = self.insert_task(
|
||||
payload=payload,
|
||||
id=task.id,
|
||||
message_tree_id=message_tree_id,
|
||||
parent_message_id=parent_message_id,
|
||||
collective=collective,
|
||||
)
|
||||
assert task.id == task.id
|
||||
return task
|
||||
assert task_model.id == task.id
|
||||
return task_model
|
||||
|
||||
def bind_frontend_message_id(self, task_id: UUID, frontend_message_id: str):
|
||||
validate_frontend_message_id(frontend_message_id)
|
||||
@@ -184,6 +187,7 @@ class TaskRepository:
|
||||
parent_message_id=parent_message_id,
|
||||
collective=collective,
|
||||
)
|
||||
logger.debug(f"inserting {task=}")
|
||||
self.db.add(task)
|
||||
self.db.commit()
|
||||
self.db.refresh(task)
|
||||
@@ -197,3 +201,7 @@ class TaskRepository:
|
||||
.one_or_none()
|
||||
)
|
||||
return task
|
||||
|
||||
def fetch_task_by_id(self, task_id: UUID) -> Task:
    """Look up a task of the current API client by its id.

    The query is restricted to ``self.api_client.id`` and finishes with
    ``one_or_none()``.
    """
    own_tasks = self.db.query(Task).filter(
        Task.api_client_id == self.api_client.id,
        Task.id == task_id,
    )
    return own_tasks.one_or_none()
|
||||
|
||||
@@ -0,0 +1,804 @@
|
||||
import random
|
||||
from enum import Enum
|
||||
from typing import Optional, Tuple
|
||||
from uuid import UUID
|
||||
|
||||
import numpy as np
|
||||
import pydantic
|
||||
from loguru import logger
|
||||
from oasst_backend.api.v1.utils import prepare_conversation, prepare_conversation_message_list
|
||||
from oasst_backend.config import settings
|
||||
from oasst_backend.models import Message, MessageReaction, MessageTreeState, TextLabels, message_tree_state
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_backend.utils.hugging_face import HfEmbeddingModel, HfUrl, HuggingFaceAPI
|
||||
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from sqlalchemy.sql import text
|
||||
from sqlmodel import Session, func
|
||||
|
||||
|
||||
class TreeManagerConfiguration(pydantic.BaseModel):
    """Configuration class for the TreeManager"""

    # Maximum number of concurrently active trees in the database.
    # No new initial prompt tasks will be handed out to users if this
    # number is reached.
    max_active_trees: int = 10

    # Maximum depth of message tree.
    max_tree_depth: int = 6

    # Maximum number of reply messages per tree node.
    max_children_count: int = 5

    # Total number of messages to gather per tree.
    goal_tree_size: int = 15

    # Number of peer review checks to collect in INITIAL_PROMPT_REVIEW state.
    num_reviews_initial_prompt: int = 3

    # Number of peer review checks to collect per reply (other than initial_prompt).
    num_reviews_reply: int = 3

    # Threshold for accepting an initial prompt.
    acceptance_threshold_initial_prompt: float = 0.6

    # Threshold for accepting a reply.
    acceptance_threshold_reply: float = 0.6

    # Number of rankings in which the message participated.
    num_required_rankings: int = 3

    # Labels that must be present in a label submission, per kind of labelled message.
    mandatory_labels_initial_prompt: Optional[list[protocol_schema.TextLabel]] = [protocol_schema.TextLabel.spam]
    mandatory_labels_assistant_reply: Optional[list[protocol_schema.TextLabel]] = [protocol_schema.TextLabel.spam]
    mandatory_labels_prompter_reply: Optional[list[protocol_schema.TextLabel]] = [protocol_schema.TextLabel.spam]
|
||||
|
||||
|
||||
class TaskType(Enum):
    """Kinds of task the tree manager can hand out.

    The non-negative values are used as indices into the five-element
    weight list built in ``TreeManager._task_selection``.
    """

    NONE = -1  # nothing available / nothing selected

    RANKING = 0
    LABEL_REPLY = 1
    REPLY = 2
    LABEL_PROMPT = 3
    PROMPT = 4
|
||||
|
||||
|
||||
class TaskRole(Enum):
    """Role restriction attached to a selected task (``ANY`` means unrestricted)."""

    ANY = 0

    PROMPTER = 1
    ASSISTANT = 2
|
||||
|
||||
|
||||
class ActiveTreeSizeRow(pydantic.BaseModel):
    """Row model holding the current and the goal size of one message tree."""

    message_tree_id: UUID
    tree_size: int
    goal_tree_size: int

    @property
    def remaining_messages(self) -> int:
        """Number of messages still missing to reach the goal size (never negative)."""
        missing = self.goal_tree_size - self.tree_size
        return missing if missing > 0 else 0

    class Config:
        orm_mode = True
|
||||
|
||||
|
||||
class ExtendibleParentRow(pydantic.BaseModel):
    """Row model describing a message that can be used as parent for a reply task."""

    parent_id: UUID
    depth: int
    message_tree_id: UUID
    active_children_count: int

    class Config:
        orm_mode = True
|
||||
|
||||
|
||||
class IncompleteRankingsRow(pydantic.BaseModel):
    """Row model describing a parent message whose replies can be used for a ranking task."""

    parent_id: UUID
    children_count: int
    child_min_ranking_count: int

    class Config:
        orm_mode = True
|
||||
|
||||
|
||||
class TreeManager:
|
||||
def __init__(self, db: Session, prompt_repository: PromptRepository, configuration: TreeManagerConfiguration):
    """Keep references to the database session, the prompt repository and the configuration."""
    self.pr = prompt_repository
    self.cfg = configuration
    self.db = db
|
||||
|
||||
def _task_selection(
    self,
    desired_task_type: protocol_schema.TaskRequestType,
    num_ranking_tasks: int,
    num_replies_need_review: int,
    num_prompts_need_review: int,
    num_missing_prompts: int,
    num_missing_replies: int,
) -> Tuple[TaskType, TaskRole]:
    """
    Decide which kind of task should be handed to a human worker.

    For a ``random`` request the task type is drawn at random, weighted per
    type (ranking has the highest weight) and limited to the types for which
    work is currently available. For a specific request the matching type is
    chosen if work of that kind exists; reply, label-reply and ranking
    requests also set the requested role.

    Returns:
        ``(task_type, task_role)``; ``task_type`` is ``TaskType.NONE`` when
        nothing suitable is available.
    """

    logger.debug(
        f"TreeManager._task_selection({num_ranking_tasks=}, {num_replies_need_review=}, "
        f"{num_prompts_need_review=}, {num_missing_prompts=}, {num_missing_replies=})"
    )

    request_types = protocol_schema.TaskRequestType
    task_type = TaskType.NONE
    task_role = TaskRole.ANY

    if desired_task_type == request_types.random:
        # (task type, amount of available work, relative weight)
        candidates = (
            (TaskType.RANKING, num_ranking_tasks, 10),
            (TaskType.LABEL_REPLY, num_replies_need_review, 5),
            (TaskType.LABEL_PROMPT, num_prompts_need_review, 5),
            (TaskType.REPLY, num_missing_replies, 2),
            (TaskType.PROMPT, num_missing_prompts, 1),
        )
        # The enum value of each task type is its slot in the weight vector.
        raw_weights = [0] * 5
        for candidate_type, available, weight in candidates:
            if available > 0:
                raw_weights[candidate_type.value] = weight

        weights = np.array(raw_weights)
        total = weights.sum()
        if total >= 1e-8:
            probabilities = weights / total
            task_type = TaskType(np.random.choice(a=len(probabilities), p=probabilities))
    elif desired_task_type == request_types.initial_prompt:
        if num_missing_prompts > 0:
            task_type = TaskType.PROMPT
    elif desired_task_type == request_types.label_initial_prompt:
        if num_prompts_need_review > 0:
            task_type = TaskType.LABEL_PROMPT
    elif desired_task_type in (request_types.assistant_reply, request_types.prompter_reply):
        if num_missing_replies > 0:
            if desired_task_type == request_types.assistant_reply:
                task_role = TaskRole.ASSISTANT
            else:
                task_role = TaskRole.PROMPTER
            task_type = TaskType.REPLY
    elif desired_task_type in (request_types.label_assistant_reply, request_types.label_prompter_reply):
        if num_replies_need_review > 0:
            if desired_task_type == request_types.label_assistant_reply:
                task_role = TaskRole.ASSISTANT
            else:
                task_role = TaskRole.PROMPTER
            task_type = TaskType.LABEL_REPLY
    elif desired_task_type in (request_types.rank_assistant_replies, request_types.rank_prompter_replies):
        if num_ranking_tasks > 0:
            if desired_task_type == request_types.rank_assistant_replies:
                task_role = TaskRole.ASSISTANT
            else:
                task_role = TaskRole.PROMPTER
            task_type = TaskType.RANKING

    logger.debug(f"Selected {task_type=}, {task_role=}")
    return task_type, task_role
|
||||
|
||||
def next_task(
    self, desired_task_type: protocol_schema.TaskRequestType
) -> Tuple[protocol_schema.Task, Optional[UUID], Optional[UUID]]:
    """Select and build the next task to hand out.

    Queries the current backlog (active trees, prompts and replies awaiting
    review, incomplete rankings, extendible trees), lets ``_task_selection``
    pick a task type and then builds the matching protocol task.

    Returns:
        ``(task, message_tree_id, parent_message_id)``. Both ids stay ``None``
        for an initial-prompt task.

    Raises:
        OasstError: ``TASK_REQUESTED_TYPE_UNAVAILABLE`` when no task of the
            requested type is available, or when the selection carries a
            specific role (role-specific selection is not implemented yet).
    """

    logger.debug("TreeManager.next_task()")

    num_active_trees = self.query_num_active_trees()
    prompts_need_review = self.query_prompts_need_review()
    replies_need_review = self.query_replies_need_review()
    incomplete_rankings = self.query_incomplete_rankings()
    active_tree_sizes = self.query_extendible_trees()

    # determine type of task to generate
    num_missing_replies = sum(tree.remaining_messages for tree in active_tree_sizes)
    num_missing_prompts = max(0, self.cfg.max_active_trees - num_active_trees)

    task_type, task_role = self._task_selection(
        desired_task_type,
        num_ranking_tasks=len(incomplete_rankings),
        num_replies_need_review=len(replies_need_review),
        num_prompts_need_review=len(prompts_need_review),
        num_missing_prompts=num_missing_prompts,
        num_missing_replies=num_missing_replies,
    )

    # Todo: Allow role specific message selection...
    # Until then a role-specific selection is reported like "nothing available".
    if task_type == TaskType.NONE or task_role != TaskRole.ANY:
        raise OasstError(
            f"No tasks of type '{desired_task_type.value}' are currently available.",
            OasstErrorCode.TASK_REQUESTED_TYPE_UNAVAILABLE,
        )

    def label_values(labels):
        """Map label enum members to their plain values."""
        return [label.value for label in labels]

    message_tree_id = None
    parent_message_id = None

    logger.debug(f"selected {task_type=}")
    if task_type == TaskType.RANKING:
        assert len(incomplete_rankings) > 0
        ranking_parent_id = random.choice(incomplete_rankings).parent_id

        messages = self.pr.fetch_message_conversation(ranking_parent_id)
        conversation = prepare_conversation(messages)
        children = self.pr.fetch_message_children(ranking_parent_id, reviewed=True, exclude_deleted=True)
        reply_messages = prepare_conversation_message_list(children)
        replies = [child.text for child in children]

        # The replies to rank have the opposite role of the last conversation message.
        if messages[-1].role == "assistant":
            logger.info("Generating a RankPrompterRepliesTask.")
            ranking_task_cls = protocol_schema.RankPrompterRepliesTask
        else:
            logger.info("Generating a RankAssistantRepliesTask.")
            ranking_task_cls = protocol_schema.RankAssistantRepliesTask
        task = ranking_task_cls(conversation=conversation, replies=replies, reply_messages=reply_messages)

        parent_message_id = ranking_parent_id
        message_tree_id = messages[-1].message_tree_id

    elif task_type == TaskType.LABEL_REPLY:
        assert len(replies_need_review) > 0
        random_reply_message_id = random.choice(replies_need_review)
        messages = self.pr.fetch_message_conversation(random_reply_message_id)
        # The reply to label is the last message; everything before it is context.
        conversation = prepare_conversation(messages[:-1])
        message = messages[-1]

        if message.role == "assistant":
            logger.info("Generating a LabelAssistantReplyTask.")
            task = protocol_schema.LabelAssistantReplyTask(
                message_id=message.id,
                conversation=conversation,
                reply=message.text,
                valid_labels=label_values(protocol_schema.TextLabel),
                mandatory_labels=label_values(self.cfg.mandatory_labels_assistant_reply),
            )
        else:
            logger.info("Generating a LabelPrompterReplyTask.")
            task = protocol_schema.LabelPrompterReplyTask(
                message_id=message.id,
                conversation=conversation,
                reply=message.text,
                valid_labels=label_values(protocol_schema.TextLabel),
                mandatory_labels=label_values(self.cfg.mandatory_labels_prompter_reply),
            )

        parent_message_id = message.id
        message_tree_id = message.message_tree_id

    elif task_type == TaskType.REPLY:
        # select a tree with missing replies
        extensible_parents = self.query_extendible_parents()
        assert len(extensible_parents) > 0

        # fetch random conversation to extend
        random_parent = random.choice(extensible_parents)
        logger.debug(f"selected {random_parent=}")
        messages = self.pr.fetch_message_conversation(random_parent.parent_id)
        assert all(m.review_result for m in messages)  # ensure all messages have positive review
        conversation = prepare_conversation(messages)

        # generate reply task depending on last message
        last_message = messages[-1]
        if last_message.role == "assistant":
            logger.info("Generating a PrompterReplyTask.")
            task = protocol_schema.PrompterReplyTask(conversation=conversation)
        else:
            logger.info("Generating a AssistantReplyTask.")
            task = protocol_schema.AssistantReplyTask(conversation=conversation)

        parent_message_id = last_message.id
        message_tree_id = last_message.message_tree_id

    elif task_type == TaskType.LABEL_PROMPT:
        assert len(prompts_need_review) > 0
        message = self.pr.fetch_message(random.choice(prompts_need_review))
        logger.info("Generating a LabelInitialPromptTask.")

        task = protocol_schema.LabelInitialPromptTask(
            message_id=message.id,
            prompt=message.text,
            valid_labels=label_values(protocol_schema.TextLabel),
            mandatory_labels=label_values(self.cfg.mandatory_labels_initial_prompt),
        )

        parent_message_id = message.id
        message_tree_id = message.message_tree_id

    elif task_type == TaskType.PROMPT:
        logger.info("Generating an InitialPromptTask.")
        task = protocol_schema.InitialPromptTask(hint=None)

    else:
        task = None

    logger.info(f"Generated {task=}.")

    return task, message_tree_id, parent_message_id
|
||||
|
||||
async def handle_interaction(self, interaction: protocol_schema.AnyInteraction) -> protocol_schema.Task:
|
||||
pr = self.pr
|
||||
match type(interaction):
|
||||
case protocol_schema.TextReplyToMessage:
|
||||
logger.info(
|
||||
f"Frontend reports text reply to {interaction.message_id=} with {interaction.text=} by {interaction.user=}."
|
||||
)
|
||||
|
||||
# here we store the text reply in the database
|
||||
message = pr.store_text_reply(
|
||||
text=interaction.text,
|
||||
frontend_message_id=interaction.message_id,
|
||||
user_frontend_message_id=interaction.user_message_id,
|
||||
)
|
||||
|
||||
if not message.parent_id:
|
||||
logger.info(f"TreeManager: Inserting new tree state for initial prompt {message.id=}")
|
||||
self._insert_default_state(message.id)
|
||||
self.db.commit()
|
||||
|
||||
if not settings.DEBUG_SKIP_EMBEDDING_COMPUTATION:
|
||||
try:
|
||||
hugging_face_api = HuggingFaceAPI(
|
||||
f"{HfUrl.HUGGINGFACE_FEATURE_EXTRACTION.value}/{HfEmbeddingModel.MINILM.value}"
|
||||
)
|
||||
embedding = await hugging_face_api.post(interaction.text)
|
||||
pr.insert_message_embedding(
|
||||
message_id=message.id, model=HfEmbeddingModel.MINILM.value, embedding=embedding
|
||||
)
|
||||
except OasstError:
|
||||
logger.error(
|
||||
f"Could not fetch embbeddings for text reply to {interaction.message_id=} with {interaction.text=} by {interaction.user=}."
|
||||
)
|
||||
|
||||
case protocol_schema.MessageRating:
|
||||
logger.info(
|
||||
f"Frontend reports rating of {interaction.message_id=} with {interaction.rating=} by {interaction.user=}."
|
||||
)
|
||||
|
||||
pr.store_rating(interaction)
|
||||
|
||||
case protocol_schema.MessageRanking:
|
||||
logger.info(
|
||||
f"Frontend reports ranking of {interaction.message_id=} with {interaction.ranking=} by {interaction.user=}."
|
||||
)
|
||||
|
||||
_, task = pr.store_ranking(interaction)
|
||||
|
||||
self.check_condition_for_scoring_state(task.message_tree_id)
|
||||
|
||||
case protocol_schema.TextLabels:
|
||||
logger.info(
|
||||
f"Frontend reports labels of {interaction.message_id=} with {interaction.labels=} by {interaction.user=}."
|
||||
)
|
||||
|
||||
_, task, msg = pr.store_text_labels(interaction)
|
||||
|
||||
# if it was a respones for a task, check if we have enough reviews to calc review_result
|
||||
if task and msg:
|
||||
reviews = self.query_reviews_for_message(msg.id)
|
||||
acceptance_score = self._calculate_acceptance(reviews)
|
||||
logger.debug(
|
||||
f"Message {msg.id=}, {acceptance_score=}, {len(reviews)=}, {msg.review_result=}, {msg.review_count=}"
|
||||
)
|
||||
if msg.parent_id is None:
|
||||
if not msg.review_result and msg.review_count >= self.cfg.num_reviews_initial_prompt:
|
||||
if acceptance_score > self.cfg.acceptance_threshold_initial_prompt:
|
||||
msg.review_result = True
|
||||
self.db.add(msg)
|
||||
self.db.commit()
|
||||
logger.info(
|
||||
f"Initial prompt message was accepted: {msg.id=}, {acceptance_score=}, {len(reviews)=}"
|
||||
)
|
||||
else:
|
||||
self.enter_low_grade_state(msg.message_tree_id)
|
||||
self.check_condition_for_growing_state(msg.message_tree_id)
|
||||
elif msg.review_count >= self.cfg.num_reviews_reply:
|
||||
if not msg.review_result and acceptance_score > self.cfg.acceptance_threshold_reply:
|
||||
msg.review_result = True
|
||||
self.db.add(msg)
|
||||
self.db.commit()
|
||||
logger.info(
|
||||
f"Reply message message accepted: {msg.id=}, {acceptance_score=}, {len(reviews)=}"
|
||||
)
|
||||
|
||||
self.check_condition_for_ranking_state(msg.message_tree_id)
|
||||
|
||||
case _:
|
||||
raise OasstError("Invalid response type.", OasstErrorCode.TASK_INVALID_RESPONSE_TYPE)
|
||||
|
||||
return protocol_schema.TaskDone()
|
||||
|
||||
def _enter_state(self, mts: MessageTreeState, state: message_tree_state.State):
    """Move an active tree state row into ``state`` and commit the change.

    When ``state`` is one of ``message_tree_state.TERMINAL_STATES`` the row is also
    marked inactive. The transition is logged after the commit.
    """
    assert mts and mts.active

    reached_terminal = state in message_tree_state.TERMINAL_STATES
    if reached_terminal:
        mts.active = False

    mts.state = state.value
    self.db.add(mts)
    self.db.commit()

    # same two log messages as before, differing only by the word "terminal"
    qualifier = "terminal " if reached_terminal else ""
    logger.info(f"Tree entered {qualifier}'{mts.state}' state ({mts.message_tree_id=})")
|
||||
|
||||
def enter_low_grade_state(self, message_tree_id: UUID) -> None:
    """Put the tree identified by ``message_tree_id`` into the ``ABORTED_LOW_GRADE`` state."""
    logger.debug(f"enter_low_grade_state({message_tree_id=})")
    self._enter_state(
        self.pr.fetch_tree_state(message_tree_id),
        message_tree_state.State.ABORTED_LOW_GRADE,
    )
|
||||
|
||||
def check_condition_for_growing_state(self, message_tree_id: UUID) -> bool:
    """Enter the ``GROWING`` state if the tree's initial prompt has been accepted.

    Returns:
        True when the transition was performed; False when the tree is inactive, is not
        in ``INITIAL_PROMPT_REVIEW``, or its initial prompt has no positive review result.
    """
    logger.debug(f"check_condition_for_growing_state({message_tree_id=})")

    mts = self.pr.fetch_tree_state(message_tree_id)
    awaiting_prompt_review = mts.active and mts.state == message_tree_state.State.INITIAL_PROMPT_REVIEW
    if not awaiting_prompt_review:
        logger.debug(f"False {mts.active=}, {mts.state=}")
        return False

    # the root message shares its id with the tree
    initial_prompt = self.pr.fetch_message(message_tree_id)
    if initial_prompt.review_result:
        self._enter_state(mts, message_tree_state.State.GROWING)
        return True

    logger.debug(f"False {initial_prompt.review_result=}")
    return False
|
||||
|
||||
def check_condition_for_ranking_state(self, message_tree_id: UUID) -> bool:
    """Enter the ``RANKING`` state once the tree has no remaining messages to collect.

    Returns:
        True when the transition was performed; False when the tree is inactive, is not
        in ``GROWING``, or ``query_tree_size`` still reports remaining messages.
    """
    logger.debug(f"check_condition_for_ranking_state({message_tree_id=})")

    mts = self.pr.fetch_tree_state(message_tree_id)
    still_growing = mts.active and mts.state == message_tree_state.State.GROWING
    if not still_growing:
        logger.debug(f"False {mts.active=}, {mts.state=}")
        return False

    # only accepted, non-deleted messages are counted by query_tree_size
    tree_size = self.query_tree_size(message_tree_id)
    if tree_size.remaining_messages > 0:
        logger.debug(f"False {tree_size.remaining_messages=}")
        return False

    self._enter_state(mts, message_tree_state.State.RANKING)
    return True
|
||||
|
||||
def check_condition_for_scoring_state(self, message_tree_id: UUID) -> bool:
    """Enter ``READY_FOR_SCORING`` once every ranked parent has enough rankings.

    Returns:
        True when the transition was performed; False when the tree is inactive, is not
        in ``RANKING``, or some parent returned by ``query_tree_ranking_results`` has
        fewer than ``cfg.num_required_rankings`` rankings.
    """
    logger.debug(f"check_condition_for_scoring_state({message_tree_id=})")

    state_query = self.db.query(MessageTreeState).filter(MessageTreeState.message_tree_id == message_tree_id)
    mts: MessageTreeState = state_query.one()
    in_ranking = mts.active and mts.state == message_tree_state.State.RANKING
    if not in_ranking:
        logger.debug(f"False {mts.active=}, {mts.state=}")
        return False

    # stop at the first parent that still lacks rankings
    for parent_msg_id, ranking in self.query_tree_ranking_results(message_tree_id).items():
        if len(ranking) < self.cfg.num_required_rankings:
            logger.debug(f"False {parent_msg_id=} {len(ranking)=}")
            return False

    self._enter_state(mts, message_tree_state.State.READY_FOR_SCORING)
    return True
|
||||
|
||||
def _calculate_acceptance(self, labels: list[TextLabels]):
    """Return the mean of ``1 - spam`` over the given label sets.

    Acceptance is currently derived from the spam label only.
    """
    spam_key = protocol_schema.TextLabel.spam
    non_spam_scores = [1 - review.labels[spam_key] for review in labels]
    return np.mean(non_spam_scores)
|
||||
|
||||
_sql_find_prompts_need_review = """
|
||||
-- find initial prompts that need more reviews
|
||||
SELECT m.id
|
||||
FROM message_tree_state mts
|
||||
LEFT JOIN message m ON mts.message_tree_id = m.id
|
||||
WHERE mts.active
|
||||
AND mts.state = :state
|
||||
AND NOT m.review_result
|
||||
AND NOT m.deleted
|
||||
AND m.review_count < :num_reviews_initial_prompt
|
||||
AND m.parent_id is NULL
|
||||
AND (:excluded_user_id IS NULL OR m.user_id != :excluded_user_id)
|
||||
"""
|
||||
|
||||
def query_prompts_need_review(self) -> list[UUID]:
|
||||
"""
|
||||
Select id of initial prompts with less then required rankings in active message tree
|
||||
(active == True in message_tree_state)
|
||||
"""
|
||||
|
||||
r = self.db.execute(
|
||||
text(self._sql_find_prompts_need_review),
|
||||
{
|
||||
"state": message_tree_state.State.INITIAL_PROMPT_REVIEW,
|
||||
"num_reviews_initial_prompt": self.cfg.num_reviews_initial_prompt,
|
||||
"excluded_user_id": None if settings.DEBUG_ALLOW_SELF_LABELING else self.pr.user_id,
|
||||
},
|
||||
)
|
||||
return [x["id"] for x in r.all()]
|
||||
|
||||
_sql_find_replies_need_review = """
|
||||
SELECT m.id
|
||||
FROM message_tree_state mts
|
||||
LEFT JOIN message m ON mts.message_tree_id = m.message_tree_id
|
||||
WHERE mts.active
|
||||
AND mts.state = :breeding_state
|
||||
AND NOT m.review_result
|
||||
AND NOT m.deleted
|
||||
AND m.review_count < :num_required_reviews
|
||||
AND m.parent_id is NOT NULL
|
||||
AND (:excluded_user_id IS NULL OR m.user_id != :excluded_user_id)
|
||||
"""
|
||||
|
||||
def query_replies_need_review(self) -> list[UUID]:
|
||||
"""
|
||||
Select ids of child messages (parent_id IS NOT NULL) with less then required rankings
|
||||
in active message tree (active == True in message_tree_state)
|
||||
"""
|
||||
|
||||
r = self.db.execute(
|
||||
text(self._sql_find_replies_need_review),
|
||||
{
|
||||
"breeding_state": message_tree_state.State.GROWING,
|
||||
"num_required_reviews": self.cfg.num_reviews_reply,
|
||||
"excluded_user_id": None if settings.DEBUG_ALLOW_SELF_LABELING else self.pr.user_id,
|
||||
},
|
||||
)
|
||||
return [x["id"] for x in r.all()]
|
||||
|
||||
# Per parent message of active trees in the ranking state: number of accepted,
# non-deleted children, their minimum ranking_count and how many of them already
# reached :num_required_rankings. Only parents with more than one such child and at
# least one child below the required ranking count are returned.
_sql_find_incomplete_rankings = """
-- find incomplete rankings
SELECT m.parent_id, COUNT(m.id) children_count, MIN(m.ranking_count) child_min_ranking_count,
COUNT(m.id) FILTER (WHERE m.ranking_count >= :num_required_rankings) as completed_rankings
FROM message_tree_state mts
LEFT JOIN message m ON mts.message_tree_id = m.message_tree_id
WHERE mts.active -- only consider active trees
AND mts.state = :ranking_state -- message tree must be in ranking state
AND m.review_result -- must be reviewed
AND NOT m.deleted -- not deleted
AND m.parent_id IS NOT NULL -- ignore initial prompts
GROUP BY m.parent_id
HAVING COUNT(m.id) > 1 and MIN(m.ranking_count) < :num_required_rankings
"""

def query_incomplete_rankings(self) -> list[IncompleteRankingsRow]:
    """Query parents whose children still need further rankings."""
    params = {
        "num_required_rankings": self.cfg.num_required_rankings,
        "ranking_state": message_tree_state.State.RANKING,
    }
    rows = self.db.execute(text(self._sql_find_incomplete_rankings), params)
    return [IncompleteRankingsRow.from_orm(row) for row in rows.all()]
|
||||
|
||||
_sql_find_extendible_parents = """
|
||||
-- find all extendible parent nodes
|
||||
SELECT m.id as parent_id, m.depth, m.message_tree_id, COUNT(c.id) active_children_count
|
||||
FROM message_tree_state mts
|
||||
LEFT JOIN message m ON mts.message_tree_id = m.message_tree_id -- all elements of message tree
|
||||
LEFT JOIN message c ON m.id = c.Id -- child nodes
|
||||
WHERE mts.active -- only consider active trees
|
||||
AND mts.state = :growing_state -- message tree must be growing
|
||||
AND NOT m.deleted -- ignore deleted messages as parents
|
||||
AND m.depth < mts.max_depth -- ignore leaf nodes as parents
|
||||
AND m.review_result -- parent node must have positive review
|
||||
AND NOT c.deleted -- don't count deleted children
|
||||
AND (c.review_result OR c.review_count < :num_reviews_reply) -- don't count children with negative review but count elements under review
|
||||
GROUP BY m.id, m.depth, m.message_tree_id, mts.max_children_count
|
||||
HAVING COUNT(c.id) < mts.max_children_count -- below maximum number of children
|
||||
"""
|
||||
|
||||
def query_extendible_parents(self) -> list[ExtendibleParentRow]:
    """Query parent messages that have not reached the maximum number of replies.

    Only trees in the GROWING state are considered; cfg.num_reviews_reply is passed
    as the review limit used by the query.
    """
    params = {
        "growing_state": message_tree_state.State.GROWING,
        "num_reviews_reply": self.cfg.num_reviews_reply,
    }
    rows = self.db.execute(text(self._sql_find_extendible_parents), params)
    return [ExtendibleParentRow.from_orm(row) for row in rows.all()]
|
||||
|
||||
# Trees that own at least one extendible parent (the parent query is embedded as a
# sub-select) and whose counted size is still below goal_tree_size. Counted are
# non-deleted replies that are accepted or still under review, plus accepted root prompts.
_sql_find_extendible_trees = f"""
-- find extendible trees
SELECT m.message_tree_id, mts.goal_tree_size, COUNT(m.id) AS tree_size
FROM (
SELECT DISTINCT message_tree_id FROM ({_sql_find_extendible_parents}) extendible_parents
) trees LEFT JOIN message_tree_state mts ON trees.message_tree_id = mts.message_tree_id
LEFT JOIN message m ON mts.message_tree_id = m.message_tree_id
WHERE NOT m.deleted
AND (
m.parent_id IS NOT NULL AND (m.review_result OR m.review_count < :num_reviews_reply) -- children
OR m.parent_id IS NULL AND m.review_result -- prompts (root nodes) must have positive review
)
GROUP BY m.message_tree_id, mts.goal_tree_size
HAVING COUNT(m.id) < mts.goal_tree_size
"""

def query_extendible_trees(self) -> list[ActiveTreeSizeRow]:
    """Query the size of message trees that still have extendible parents and are below their goal size."""
    params = {
        "growing_state": message_tree_state.State.GROWING,
        "num_reviews_reply": self.cfg.num_reviews_reply,
    }
    rows = self.db.execute(text(self._sql_find_extendible_trees), params)
    return [ActiveTreeSizeRow.from_orm(row) for row in rows.all()]
|
||||
|
||||
# Goal size and number of accepted, non-deleted messages of one active tree.
_sql_get_tree_size = """
SELECT mts.message_tree_id, mts.goal_tree_size, COUNT(m.id) AS tree_size
FROM message_tree_state mts
LEFT JOIN message m ON mts.message_tree_id = m.message_tree_id
WHERE mts.active
AND NOT m.deleted
AND m.review_result
AND mts.message_tree_id = :message_tree_id
GROUP BY mts.message_tree_id, mts.goal_tree_size
"""

def query_tree_size(self, message_tree_id: UUID) -> ActiveTreeSizeRow:
    """Returns the number of reviewed not deleted messages in the message tree.

    Exactly one result row is expected (``Result.one()`` is used to fetch it).
    """
    params = {"message_tree_id": message_tree_id}
    result = self.db.execute(text(self._sql_get_tree_size), params)
    return ActiveTreeSizeRow.from_orm(result.one())
|
||||
|
||||
def query_misssing_tree_states(self) -> list[UUID]:
    """Find all initial prompt messages that have no associated message tree state."""
    # root messages (no parent, tree id equals own id) for which the outer join found no state row
    root_without_state = (
        Message.parent_id.is_(None),
        Message.message_tree_id == Message.id,
        MessageTreeState.message_tree_id.is_(None),
    )
    candidates = self.db.query(Message.id).join(MessageTreeState, isouter=True)
    return [row.id for row in candidates.filter(*root_without_state).all()]
|
||||
|
||||
# For one tree: every parent with more than one accepted, non-deleted child, joined
# with the ranking reactions of its completed ranking tasks. Parents without such
# reactions still produce a row (with NULL reaction columns) because of the LEFT JOINs.
_sql_find_tree_ranking_results = """
-- get all ranking results of completed tasks for all parents with >=2 children
SELECT p.parent_id, mr.* FROM
(
-- find parents with > 1 children
SELECT m.parent_id, m.message_tree_id, COUNT(m.id) children_count
FROM message_tree_state mts
LEFT JOIN message m ON mts.message_tree_id = m.message_tree_id
WHERE m.review_result -- must be reviewed
AND NOT m.deleted -- not deleted
AND m.parent_id IS NOT NULL -- ignore initial prompts
AND mts.message_tree_id = :message_tree_id
GROUP BY m.parent_id, m.message_tree_id
HAVING COUNT(m.id) > 1
) as p
LEFT JOIN task t ON p.parent_id = t.parent_message_id AND t.done AND (t.payload_type = 'RankPrompterRepliesPayload' OR t.payload_type = 'RankAssistantRepliesPayload')
LEFT JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'RankingReactionPayload'
"""

def query_tree_ranking_results(self, message_tree_id: UUID) -> dict[UUID, list[MessageReaction]]:
    """Find all completed ranking results for a message tree.

    Returns:
        A dict keyed by parent message id. Every parent returned by the query has an
        entry; the list stays empty when no ranking reaction exists for it yet.
    """
    result = self.db.execute(
        text(self._sql_find_tree_ranking_results),
        {"message_tree_id": message_tree_id},
    )

    rankings_by_message = {}
    for row in result.all():
        # register the parent even if the row carries no reaction
        collected = rankings_by_message.setdefault(row["parent_id"], [])
        if row["task_id"]:
            collected.append(MessageReaction.from_orm(row))
    return rankings_by_message
|
||||
|
||||
def ensure_tree_states(self):
    """Add message tree state rows for all root nodes (initial prompt messages) that lack one.

    Trees that already contain more than one message start in GROWING, all others in
    INITIAL_PROMPT_REVIEW. The inserted rows are committed at the end.
    """
    # NOTE(review): commit placement (after the loop) was reconstructed from a listing
    # without indentation; confirm against the original file.
    for message_tree_id in self.query_misssing_tree_states():
        tree_size = (
            self.db.query(func.count(Message.id)).filter(Message.message_tree_id == message_tree_id).scalar()
        )
        state = (
            message_tree_state.State.GROWING
            if tree_size > 1
            else message_tree_state.State.INITIAL_PROMPT_REVIEW
        )
        logger.info(
            f"Inserting missing message tree state for message: {message_tree_id} ({tree_size=}, {state=})"
        )
        self._insert_default_state(message_tree_id, state=state)

    self.db.commit()
|
||||
|
||||
def query_num_active_trees(self) -> int:
    """Return the number of message tree state rows flagged as active."""
    tree_count = func.count(MessageTreeState.message_tree_id)
    return self.db.query(tree_count).filter(MessageTreeState.active).scalar()
|
||||
|
||||
def query_reviews_for_message(self, message_id: UUID) -> list[TextLabels]:
    """Return the text label rows for ``message_id`` that are joined to a completed task."""
    sql_qry = """
SELECT tl.*
FROM task t
INNER JOIN text_labels tl ON tl.id = t.id
WHERE t.done = TRUE
AND tl.message_id = :message_id
"""
    result = self.db.execute(text(sql_qry), {"message_id": message_id})
    return [TextLabels.from_orm(row) for row in result.all()]
|
||||
|
||||
def _insert_tree_state(
    self,
    root_message_id: UUID,
    goal_tree_size: int,
    max_depth: int,
    max_children_count: int,
    active: bool,
    state: message_tree_state.State = message_tree_state.State.INITIAL_PROMPT_REVIEW,
) -> MessageTreeState:
    """Create a ``MessageTreeState`` row for a root message and add it to the session.

    The row starts with ``accepted_messages=0``. It is only added to ``self.db``;
    committing is left to the caller.

    Returns:
        The newly created ``MessageTreeState`` instance.
    """
    tree_state = MessageTreeState(
        message_tree_id=root_message_id,
        state=state.value,
        active=active,
        goal_tree_size=goal_tree_size,
        max_depth=max_depth,
        max_children_count=max_children_count,
        accepted_messages=0,
    )
    self.db.add(tree_state)
    return tree_state
|
||||
|
||||
def _insert_default_state(
    self,
    root_message_id: UUID,
    state: message_tree_state.State = message_tree_state.State.INITIAL_PROMPT_REVIEW,
) -> MessageTreeState:
    """Insert an active tree state for ``root_message_id`` using the limits from ``self.cfg``."""
    cfg = self.cfg
    return self._insert_tree_state(
        root_message_id=root_message_id,
        goal_tree_size=cfg.goal_tree_size,
        max_depth=cfg.max_tree_depth,
        max_children_count=cfg.max_children_count,
        active=True,
        state=state,
    )
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test: runs the tree manager queries against the configured database
    # and prints their results.
    from oasst_backend.api.deps import get_dummy_api_client
    from oasst_backend.database import engine
    from oasst_backend.prompt_repository import PromptRepository

    with Session(engine) as db:
        api_client = get_dummy_api_client(db)
        dummy_user = protocol_schema.User(id="__dummy_user__", display_name="Dummy User", auth_method="local")

        pr = PromptRepository(db=db, api_client=api_client, client_user=dummy_user)

        cfg = TreeManagerConfiguration()
        tm = TreeManager(db, pr, cfg)
        tm.ensure_tree_states()

        # each query is executed right before its result is printed, in this order
        reports = (
            ("query_num_active_trees", tm.query_num_active_trees),
            ("query_incomplete_rankings", tm.query_incomplete_rankings),
            ("query_incomplete_reply_reviews", tm.query_replies_need_review),
            ("query_incomplete_initial_prompt_reviews", tm.query_prompts_need_review),
            ("query_extendible_trees", tm.query_extendible_trees),
            ("query_extendible_parents", tm.query_extendible_parents),
            ("next_task:", tm.next_task),
        )
        for label, run_query in reports:
            print(label, run_query())

        print(
            ".query_tree_ranking_results", tm.query_tree_ranking_results(UUID("2ac20d38-6650-43aa-8bb3-f61080c0d921"))
        )
|
||||
@@ -51,12 +51,12 @@ class HuggingFaceAPI:
|
||||
|
||||
async with session.post(self.api_url, headers=self.headers, json=payload) as response:
|
||||
# If we get a bad response
|
||||
if response.status != 200:
|
||||
|
||||
if not response.ok:
|
||||
logger.error(response)
|
||||
logger.info(self.headers)
|
||||
raise OasstError(
|
||||
"Response Error Detoxify HuggingFace", error_code=OasstErrorCode.HUGGINGFACE_API_ERROR
|
||||
f"Response Error HuggingFace API (Status: {response.status})",
|
||||
error_code=OasstErrorCode.HUGGINGFACE_API_ERROR,
|
||||
)
|
||||
|
||||
# Get the response from the API call
|
||||
|
||||
@@ -99,6 +99,7 @@ services:
|
||||
- REDIS_HOST=redis
|
||||
- DEBUG_SKIP_API_KEY_CHECK=True
|
||||
- DEBUG_USE_SEED_DATA=True
|
||||
- DEBUG_ALLOW_SELF_LABELING=True
|
||||
- MAX_WORKERS=1
|
||||
- DEBUG_SKIP_EMBEDDING_COMPUTATION=True
|
||||
depends_on:
|
||||
|
||||
@@ -17,9 +17,11 @@ class OasstErrorCode(IntEnum):
|
||||
GENERIC_ERROR = 0
|
||||
DATABASE_URI_NOT_SET = 1
|
||||
API_CLIENT_NOT_AUTHORIZED = 2
|
||||
SERVER_ERROR = 3
|
||||
TOO_MANY_REQUESTS = 429
|
||||
|
||||
SERVER_ERROR0 = 500
|
||||
SERVER_ERROR1 = 501
|
||||
|
||||
# 1000-2000: tasks endpoint
|
||||
TASK_INVALID_REQUEST_TYPE = 1000
|
||||
TASK_ACK_FAILED = 1001
|
||||
@@ -27,6 +29,7 @@ class OasstErrorCode(IntEnum):
|
||||
TASK_INVALID_RESPONSE_TYPE = 1003
|
||||
TASK_INTERACTION_REQUEST_FAILED = 1004
|
||||
TASK_GENERATION_FAILED = 1005
|
||||
TASK_REQUESTED_TYPE_NOT_AVAILABLE = 1006
|
||||
|
||||
# 2000-3000: prompt_repository
|
||||
INVALID_FRONTEND_MESSAGE_ID = 2000
|
||||
@@ -38,6 +41,14 @@ class OasstErrorCode(IntEnum):
|
||||
NO_MESSAGE_TREE_FOUND = 2006
|
||||
NO_REPLIES_FOUND = 2007
|
||||
INVALID_MESSAGE = 2008
|
||||
BROKEN_CONVERSATION = 2009
|
||||
TREE_NOT_IN_GROWING_STATE = 2010
|
||||
CORRUPT_RANKING_RESULT = 2011
|
||||
|
||||
TEXT_LABELS_WRONG_MESSAGE_ID = 2050
|
||||
TEXT_LABELS_INVALID_LABEL = 2051
|
||||
TEXT_LABELS_MANDATORY_LABEL_MISSING = 2052
|
||||
|
||||
TASK_NOT_FOUND = 2100
|
||||
TASK_EXPIRED = 2101
|
||||
TASK_PAYLOAD_TYPE_MISMATCH = 2102
|
||||
@@ -45,6 +56,7 @@ class OasstErrorCode(IntEnum):
|
||||
TASK_NOT_ACK = 2104
|
||||
TASK_ALREADY_DONE = 2105
|
||||
TASK_NOT_COLLECTIVE = 2106
|
||||
TASK_NOT_ASSIGNED_TO_USER = 2106
|
||||
USER_NOT_FOUND = 2200
|
||||
|
||||
# 3000-4000: external resources
|
||||
|
||||
@@ -151,7 +151,8 @@ class RankInitialPromptsTask(Task):
|
||||
"""A task to rank a set of initial prompts."""
|
||||
|
||||
type: Literal["rank_initial_prompts"] = "rank_initial_prompts"
|
||||
prompts: list[str]
|
||||
prompts: list[str] # deprecated, use prompt_messages
|
||||
prompt_messages: list[ConversationMessage]
|
||||
|
||||
|
||||
class RankConversationRepliesTask(Task):
|
||||
@@ -159,7 +160,8 @@ class RankConversationRepliesTask(Task):
|
||||
|
||||
type: Literal["rank_conversation_replies"] = "rank_conversation_replies"
|
||||
conversation: Conversation # the conversation so far
|
||||
replies: list[str]
|
||||
replies: list[str] # deprecated, use reply_messages
|
||||
reply_messages: list[ConversationMessage]
|
||||
|
||||
|
||||
class RankPrompterRepliesTask(RankConversationRepliesTask):
|
||||
@@ -181,6 +183,7 @@ class LabelInitialPromptTask(Task):
|
||||
message_id: UUID
|
||||
prompt: str
|
||||
valid_labels: list[str]
|
||||
mandatory_labels: Optional[list[str]]
|
||||
|
||||
|
||||
class LabelConversationReplyTask(Task):
|
||||
@@ -191,6 +194,7 @@ class LabelConversationReplyTask(Task):
|
||||
message_id: UUID
|
||||
reply: str
|
||||
valid_labels: list[str]
|
||||
mandatory_labels: Optional[list[str]]
|
||||
|
||||
|
||||
class LabelPrompterReplyTask(LabelConversationReplyTask):
|
||||
@@ -304,6 +308,7 @@ class TextLabels(Interaction):
|
||||
text: str
|
||||
labels: dict[TextLabel, float]
|
||||
message_id: UUID
|
||||
task_id: Optional[UUID]
|
||||
|
||||
@property
|
||||
def has_message_id(self) -> bool:
|
||||
|
||||
@@ -6,6 +6,7 @@ pushd "$parent_path/../../backend"
|
||||
|
||||
export DEBUG_SKIP_API_KEY_CHECK=True
|
||||
export DEBUG_USE_SEED_DATA=True
|
||||
export DEBUG_ALLOW_SELF_LABELING=True
|
||||
export DEBUG_SKIP_EMBEDDING_COMPUTATION=True
|
||||
|
||||
uvicorn main:app --reload --port 8080 --host 0.0.0.0
|
||||
|
||||
@@ -36,7 +36,7 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "DUMMY_KEY")
|
||||
return response.json()
|
||||
|
||||
typer.echo("Requesting work...")
|
||||
tasks = [_post("/api/v1/tasks/", {"type": "random"})]
|
||||
tasks = [_post("/api/v1/tasks/", {"type": "random", "user": USER})]
|
||||
while tasks:
|
||||
task = tasks.pop(0)
|
||||
match (task["type"]):
|
||||
@@ -58,6 +58,7 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "DUMMY_KEY")
|
||||
{
|
||||
"type": "text_reply_to_message",
|
||||
"message_id": message_id,
|
||||
"task_id": task["id"],
|
||||
"user_message_id": user_message_id,
|
||||
"text": summary,
|
||||
"user": USER,
|
||||
@@ -102,6 +103,7 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "DUMMY_KEY")
|
||||
{
|
||||
"type": "text_reply_to_message",
|
||||
"message_id": message_id,
|
||||
"task_id": task["id"],
|
||||
"user_message_id": user_message_id,
|
||||
"text": prompt,
|
||||
"user": USER,
|
||||
@@ -150,6 +152,7 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "DUMMY_KEY")
|
||||
{
|
||||
"type": "text_reply_to_message",
|
||||
"message_id": message_id,
|
||||
"task_id": task["id"],
|
||||
"user_message_id": user_message_id,
|
||||
"text": reply,
|
||||
"user": USER,
|
||||
@@ -200,6 +203,7 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "DUMMY_KEY")
|
||||
{
|
||||
"type": "message_ranking",
|
||||
"message_id": message_id,
|
||||
"task_id": task["id"],
|
||||
"ranking": ranking,
|
||||
"user": USER,
|
||||
},
|
||||
@@ -232,6 +236,7 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "DUMMY_KEY")
|
||||
{
|
||||
"type": "text_labels",
|
||||
"message_id": task["message_id"],
|
||||
"task_id": task["id"],
|
||||
"text": task["prompt"],
|
||||
"labels": labels_dict,
|
||||
"user": USER,
|
||||
@@ -269,6 +274,7 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "DUMMY_KEY")
|
||||
{
|
||||
"type": "text_labels",
|
||||
"message_id": task["message_id"],
|
||||
"task_id": task["id"],
|
||||
"text": task["reply"],
|
||||
"labels": labels_dict,
|
||||
"user": USER,
|
||||
|
||||
Reference in New Issue
Block a user