mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-07-01 16:50:12 +08:00
renamed to open assistant
This commit is contained in:
@@ -3,7 +3,7 @@ from logging.config import fileConfig
|
||||
|
||||
import sqlmodel
|
||||
from alembic import context
|
||||
from ocgpt import models # noqa: F401
|
||||
from oasst import models # noqa: F401
|
||||
from sqlalchemy import engine_from_config, pool
|
||||
|
||||
# this is the Alembic Config object, which provides
|
||||
|
||||
+2
-2
@@ -5,8 +5,8 @@ import alembic.command
|
||||
import alembic.config
|
||||
import fastapi
|
||||
from loguru import logger
|
||||
from ocgpt.api.v1.api import api_router
|
||||
from ocgpt.config import settings
|
||||
from oasst.api.v1.api import api_router
|
||||
from oasst.config import settings
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
|
||||
app = fastapi.FastAPI(title=settings.PROJECT_NAME, openapi_url=f"{settings.API_V1_STR}/openapi.json")
|
||||
|
||||
@@ -0,0 +1,57 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from secrets import token_hex
|
||||
from typing import Generator
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException, Security
|
||||
from fastapi.security.api_key import APIKey, APIKeyHeader, APIKeyQuery
|
||||
from loguru import logger
|
||||
from oasst.config import settings
|
||||
from oasst.database import engine
|
||||
from oasst.models import ApiClient
|
||||
from sqlmodel import Session
|
||||
from starlette.status import HTTP_403_FORBIDDEN
|
||||
|
||||
|
||||
def get_db() -> Generator:
|
||||
with Session(engine) as db:
|
||||
yield db
|
||||
|
||||
|
||||
api_key_query = APIKeyQuery(name="api_key", auto_error=False)
|
||||
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
|
||||
|
||||
|
||||
async def get_api_key(
|
||||
api_key_query: str = Security(api_key_query),
|
||||
api_key_header: str = Security(api_key_header),
|
||||
):
|
||||
if api_key_query:
|
||||
return api_key_query
|
||||
else:
|
||||
return api_key_header
|
||||
|
||||
|
||||
def api_auth(
|
||||
api_key: APIKey,
|
||||
db: Session,
|
||||
) -> ApiClient:
|
||||
|
||||
if api_key is not None:
|
||||
if settings.ALLOW_ANY_API_KEY:
|
||||
# make sure that a dummy api key exits in db (foreign key references)
|
||||
ANY_API_KEY_ID = UUID("00000000-1111-2222-3333-444444444444")
|
||||
api_client: ApiClient = db.query(ApiClient).filter(ApiClient.id == ANY_API_KEY_ID).first()
|
||||
if api_client is None:
|
||||
token = token_hex(32)
|
||||
logger.info(f"ANY_API_KEY missing, inserting api_key: {token}")
|
||||
api_client = ApiClient(id=ANY_API_KEY_ID, api_key=token, description="ANY_API_KEY, random token")
|
||||
db.add(api_client)
|
||||
db.commit()
|
||||
return api_client
|
||||
|
||||
api_client = db.query(ApiClient).filter(ApiClient.api_key == api_key).first()
|
||||
if api_client is not None and api_client.enabled:
|
||||
return api_client
|
||||
|
||||
raise HTTPException(status_code=HTTP_403_FORBIDDEN, detail="Could not validate credentials")
|
||||
@@ -0,0 +1,6 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from fastapi import APIRouter
|
||||
from oasst.api.v1 import tasks
|
||||
|
||||
api_router = APIRouter()
|
||||
api_router.include_router(tasks.router, prefix="/tasks", tags=["tasks"])
|
||||
@@ -0,0 +1,259 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import random
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.security.api_key import APIKey
|
||||
from loguru import logger
|
||||
from oasst.api import deps
|
||||
from oasst.models.db_payload import TaskPayload
|
||||
from oasst.prompt_repository import PromptRepository
|
||||
from oasst.schemas import protocol as protocol_schema
|
||||
from sqlmodel import Session
|
||||
from starlette.status import HTTP_400_BAD_REQUEST
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def generate_task(request: protocol_schema.TaskRequest) -> protocol_schema.Task:
|
||||
match (request.type):
|
||||
case protocol_schema.TaskRequestType.random:
|
||||
logger.info("Frontend requested a random task.")
|
||||
while request.type == protocol_schema.TaskRequestType.random:
|
||||
request.type = random.choice(list(protocol_schema.TaskRequestType)).value
|
||||
return generate_task(request)
|
||||
case protocol_schema.TaskRequestType.summarize_story:
|
||||
logger.info("Generating a SummarizeStoryTask.")
|
||||
task = protocol_schema.SummarizeStoryTask(
|
||||
story="This is a story. A very long story. So long, it needs to be summarized.",
|
||||
)
|
||||
case protocol_schema.TaskRequestType.rate_summary:
|
||||
logger.info("Generating a RateSummaryTask.")
|
||||
task = protocol_schema.RateSummaryTask(
|
||||
full_text="This is a story. A very long story. So long, it needs to be summarized.",
|
||||
summary="This is a summary.",
|
||||
scale=protocol_schema.RatingScale(min=1, max=5),
|
||||
)
|
||||
case protocol_schema.TaskRequestType.initial_prompt:
|
||||
logger.info("Generating an InitialPromptTask.")
|
||||
task = protocol_schema.InitialPromptTask(
|
||||
hint="Ask the assistant about a current event." # this is optional
|
||||
)
|
||||
case protocol_schema.TaskRequestType.user_reply:
|
||||
logger.info("Generating a UserReplyTask.")
|
||||
task = protocol_schema.UserReplyTask(
|
||||
conversation=protocol_schema.Conversation(
|
||||
messages=[
|
||||
protocol_schema.ConversationMessage(
|
||||
text="Hey, assistant, what's going on in the world?",
|
||||
is_assistant=False,
|
||||
),
|
||||
protocol_schema.ConversationMessage(
|
||||
text="I'm not sure I understood correctly, could you rephrase that?",
|
||||
is_assistant=True,
|
||||
),
|
||||
],
|
||||
)
|
||||
)
|
||||
case protocol_schema.TaskRequestType.assistant_reply:
|
||||
logger.info("Generating a AssistantReplyTask.")
|
||||
task = protocol_schema.AssistantReplyTask(
|
||||
conversation=protocol_schema.Conversation(
|
||||
messages=[
|
||||
protocol_schema.ConversationMessage(
|
||||
text="Hey, assistant, write me an English essay about water.",
|
||||
is_assistant=False,
|
||||
),
|
||||
],
|
||||
)
|
||||
)
|
||||
case protocol_schema.TaskRequestType.rank_initial_prompts:
|
||||
logger.info("Generating a RankInitialPromptsTask.")
|
||||
task = protocol_schema.RankInitialPromptsTask(
|
||||
prompts=[
|
||||
"Please write a story about a time you were happy.",
|
||||
"Please write a story about a time you were sad.",
|
||||
]
|
||||
)
|
||||
case protocol_schema.TaskRequestType.rank_user_replies:
|
||||
logger.info("Generating a RankUserRepliesTask.")
|
||||
task = protocol_schema.RankUserRepliesTask(
|
||||
conversation=protocol_schema.Conversation(
|
||||
messages=[
|
||||
protocol_schema.ConversationMessage(
|
||||
text="Hey, assistant, what's going on in the world?",
|
||||
is_assistant=False,
|
||||
),
|
||||
protocol_schema.ConversationMessage(
|
||||
text="I'm not sure I understood correctly, could you rephrase that?",
|
||||
is_assistant=True,
|
||||
),
|
||||
],
|
||||
),
|
||||
replies=[
|
||||
"Oh come oooooon!",
|
||||
"What are the news?",
|
||||
],
|
||||
)
|
||||
|
||||
case protocol_schema.TaskRequestType.rank_assistant_replies:
|
||||
logger.info("Generating a RankAssistantRepliesTask.")
|
||||
task = protocol_schema.RankAssistantRepliesTask(
|
||||
conversation=protocol_schema.Conversation(
|
||||
messages=[
|
||||
protocol_schema.ConversationMessage(
|
||||
text="Hey, assistant, what's going on in the world?",
|
||||
is_assistant=False,
|
||||
),
|
||||
],
|
||||
),
|
||||
replies=[
|
||||
"I'm not sure I understood correctly, could you rephrase that?",
|
||||
"The world is fine. All good.",
|
||||
"Crap is hitting the fan. Start farming.",
|
||||
],
|
||||
)
|
||||
case _:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid request type.",
|
||||
)
|
||||
|
||||
logger.info(f"Generated {task=}.")
|
||||
|
||||
return task
|
||||
|
||||
|
||||
@router.post("/", response_model=protocol_schema.AnyTask) # work with Union once more types are added
|
||||
def request_task(
|
||||
*,
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_key: APIKey = Depends(deps.get_api_key),
|
||||
request: protocol_schema.TaskRequest,
|
||||
) -> Any:
|
||||
"""
|
||||
Create new task.
|
||||
"""
|
||||
api_client = deps.api_auth(api_key, db)
|
||||
|
||||
try:
|
||||
task = generate_task(request)
|
||||
|
||||
pr = PromptRepository(db, api_client, request.user)
|
||||
pr.store_task(task)
|
||||
|
||||
except Exception:
|
||||
logger.exception("Failed to generate task.")
|
||||
raise HTTPException(
|
||||
status_code=HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
return task
|
||||
|
||||
|
||||
@router.post("/{task_id}/ack")
|
||||
def acknowledge_task(
|
||||
*,
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_key: APIKey = Depends(deps.get_api_key),
|
||||
task_id: UUID,
|
||||
ack_request: protocol_schema.TaskAck,
|
||||
) -> Any:
|
||||
"""
|
||||
The frontend acknowledges a task.
|
||||
"""
|
||||
|
||||
api_client = deps.api_auth(api_key, db)
|
||||
|
||||
try:
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
|
||||
# here we store the post id in the database for the task
|
||||
pr.bind_frontend_post_id(task_id=task_id, post_id=ack_request.post_id)
|
||||
logger.info(f"Frontend acknowledges task {task_id=}, {ack_request=}.")
|
||||
|
||||
except Exception:
|
||||
logger.exception("Failed to acknowledge task.")
|
||||
raise HTTPException(
|
||||
status_code=HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
return {}
|
||||
|
||||
|
||||
@router.post("/{task_id}/nack")
|
||||
def acknowledge_task_failure(
|
||||
*,
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_key: APIKey = Depends(deps.get_api_key),
|
||||
task_id: UUID,
|
||||
nack_request: protocol_schema.TaskNAck,
|
||||
) -> Any:
|
||||
"""
|
||||
The frontend reports failure to implement a task.
|
||||
"""
|
||||
deps.api_auth(api_key, db)
|
||||
|
||||
logger.info(f"Frontend reports failure to implement task {task_id=}, {nack_request=}.")
|
||||
# here we would store the post id in the database for the task
|
||||
return {}
|
||||
|
||||
|
||||
@router.post("/interaction")
|
||||
def post_interaction(
|
||||
*,
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_key: APIKey = Depends(deps.get_api_key),
|
||||
interaction: protocol_schema.AnyInteraction,
|
||||
) -> Any:
|
||||
"""
|
||||
The frontend reports an interaction.
|
||||
"""
|
||||
api_client = deps.api_auth(api_key, db)
|
||||
|
||||
try:
|
||||
pr = PromptRepository(db, api_client, user=interaction.user)
|
||||
|
||||
match type(interaction):
|
||||
case protocol_schema.TextReplyToPost:
|
||||
logger.info(
|
||||
f"Frontend reports text reply to {interaction.post_id=} with {interaction.text=} by {interaction.user=}."
|
||||
)
|
||||
|
||||
work_package = pr.fetch_workpackage_by_postid(interaction.post_id)
|
||||
work_payload: TaskPayload = work_package.payload.payload
|
||||
logger.info(f"found task work package in db: {work_payload}")
|
||||
|
||||
# here we store the text reply in the database
|
||||
# ToDo: role user or agent?
|
||||
pr.store_text_reply(interaction, role="unknown")
|
||||
|
||||
return protocol_schema.TaskDone()
|
||||
case protocol_schema.PostRating:
|
||||
logger.info(
|
||||
f"Frontend reports rating of {interaction.post_id=} with {interaction.rating=} by {interaction.user=}."
|
||||
)
|
||||
|
||||
# here we store the rating in the database
|
||||
pr.store_rating(interaction)
|
||||
|
||||
return protocol_schema.TaskDone()
|
||||
case protocol_schema.PostRanking:
|
||||
logger.info(
|
||||
f"Frontend reports ranking of {interaction.post_id=} with {interaction.ranking=} by {interaction.user=}."
|
||||
)
|
||||
|
||||
# TODO: check if the ranking is valid
|
||||
pr.store_ranking(interaction)
|
||||
# here we would store the ranking in the database
|
||||
return protocol_schema.TaskDone()
|
||||
case _:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid response type.",
|
||||
)
|
||||
|
||||
except Exception:
|
||||
logger.exception("Interaction request failed.")
|
||||
raise HTTPException(
|
||||
status_code=HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
@@ -0,0 +1,8 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from oasst.config import settings
|
||||
from sqlmodel import create_engine
|
||||
|
||||
if settings.DATABASE_URI is None:
|
||||
raise ValueError("DATABASE_URI is not set")
|
||||
|
||||
engine = create_engine(settings.DATABASE_URI)
|
||||
@@ -0,0 +1,94 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from typing import Literal
|
||||
|
||||
from oasst.models.payload_column_type import payload_type
|
||||
from oasst.schemas import protocol as protocol_schema
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
@payload_type
|
||||
class TaskPayload(BaseModel):
|
||||
type: str
|
||||
|
||||
|
||||
@payload_type
|
||||
class SummarizationStoryPayload(TaskPayload):
|
||||
type: Literal["summarize_story"] = "summarize_story"
|
||||
story: str
|
||||
|
||||
|
||||
@payload_type
|
||||
class RateSummaryPayload(TaskPayload):
|
||||
type: Literal["rate_summary"] = "rate_summary"
|
||||
full_text: str
|
||||
summary: str
|
||||
scale: protocol_schema.RatingScale
|
||||
|
||||
|
||||
@payload_type
|
||||
class InitialPromptPayload(TaskPayload):
|
||||
type: Literal["initial_prompt"] = "initial_prompt"
|
||||
hint: str
|
||||
|
||||
|
||||
@payload_type
|
||||
class UserReplyPayload(TaskPayload):
|
||||
type: Literal["user_reply"] = "user_reply"
|
||||
conversation: protocol_schema.Conversation
|
||||
hint: str | None
|
||||
|
||||
|
||||
@payload_type
|
||||
class AssistantReplyPayload(TaskPayload):
|
||||
type: Literal["assistant_reply"] = "assistant_reply"
|
||||
conversation: protocol_schema.Conversation
|
||||
|
||||
|
||||
@payload_type
|
||||
class PostPayload(BaseModel):
|
||||
text: str
|
||||
|
||||
|
||||
@payload_type
|
||||
class ReactionPayload(BaseModel):
|
||||
type: str
|
||||
|
||||
|
||||
@payload_type
|
||||
class RatingReactionPayload(ReactionPayload):
|
||||
type: Literal["post_rating"] = "post_rating"
|
||||
rating: str
|
||||
|
||||
|
||||
@payload_type
|
||||
class RankingReactionPayload(ReactionPayload):
|
||||
type: Literal["post_ranking"] = "post_ranking"
|
||||
ranking: list[int]
|
||||
|
||||
|
||||
@payload_type
|
||||
class RankConversationRepliesPayload(TaskPayload):
|
||||
conversation: protocol_schema.Conversation # the conversation so far
|
||||
replies: list[str]
|
||||
|
||||
|
||||
@payload_type
|
||||
class RankInitialPromptsPayload(TaskPayload):
|
||||
"""A task to rank a set of initial prompts."""
|
||||
|
||||
type: Literal["rank_initial_prompts"] = "rank_initial_prompts"
|
||||
prompts: list[str]
|
||||
|
||||
|
||||
@payload_type
|
||||
class RankUserRepliesPayload(RankConversationRepliesPayload):
|
||||
"""A task to rank a set of user replies to a conversation."""
|
||||
|
||||
type: Literal["rank_user_replies"] = "rank_user_replies"
|
||||
|
||||
|
||||
@payload_type
|
||||
class RankAssistantRepliesPayload(RankConversationRepliesPayload):
|
||||
"""A task to rank a set of assistant replies to a conversation."""
|
||||
|
||||
type: Literal["rank_assistant_replies"] = "rank_assistant_replies"
|
||||
@@ -0,0 +1,311 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import oasst.models.db_payload as db_payload
|
||||
from loguru import logger
|
||||
from oasst.models import ApiClient, Person, Post, PostReaction, WorkPackage
|
||||
from oasst.models.payload_column_type import PayloadContainer
|
||||
from oasst.schemas import protocol as protocol_schema
|
||||
from sqlmodel import Session
|
||||
|
||||
|
||||
class PromptRepository:
|
||||
def __init__(self, db: Session, api_client: ApiClient, user: Optional[protocol_schema.User]):
|
||||
self.db = db
|
||||
self.api_client = api_client
|
||||
self.person = self.lookup_person(user)
|
||||
self.person_id = self.person.id if self.person else None
|
||||
|
||||
def lookup_person(self, user: protocol_schema.User) -> Person:
|
||||
if not user:
|
||||
return None
|
||||
person: Person = (
|
||||
self.db.query(Person)
|
||||
.filter(
|
||||
Person.api_client_id == self.api_client.id,
|
||||
Person.username == user.id,
|
||||
Person.auth_method == user.auth_method,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if person is None:
|
||||
# user is unknown, create new record
|
||||
person = Person(username=user.id, display_name=user.display_name, api_client_id=self.api_client.id)
|
||||
self.db.add(person)
|
||||
self.db.commit()
|
||||
self.db.refresh(person)
|
||||
elif user.display_name and user.display_name != person.display_name:
|
||||
# we found the user but the display name changed
|
||||
person.display_name = user.display_name
|
||||
self.db.add(person)
|
||||
self.db.commit()
|
||||
return person
|
||||
|
||||
def validate_post_id(self, post_id: str) -> None:
|
||||
if not isinstance(post_id, str):
|
||||
raise TypeError(f"post_id must be string, not {type(post_id)}")
|
||||
if not post_id:
|
||||
raise ValueError("post_id must not be empty")
|
||||
|
||||
def bind_frontend_post_id(self, task_id: UUID, post_id: str):
|
||||
self.validate_post_id(post_id)
|
||||
|
||||
# find work package
|
||||
work_pack: WorkPackage = (
|
||||
self.db.query(WorkPackage)
|
||||
.filter(WorkPackage.id == task_id, WorkPackage.api_client_id == self.api_client.id)
|
||||
.first()
|
||||
)
|
||||
if work_pack is None:
|
||||
raise KeyError(f"WorkPackage for task {task_id} not found")
|
||||
if work_pack.expiry_date is not None and datetime.utcnow() > work_pack.expiry_date:
|
||||
raise RuntimeError("WorkPackage already expired.")
|
||||
|
||||
# ToDo: check race-condition, transaction
|
||||
|
||||
# check if task thread exits
|
||||
thread_root = (
|
||||
self.db.query(Post)
|
||||
.filter(
|
||||
Post.workpackage_id == work_pack.id,
|
||||
Post.frontend_post_id == post_id,
|
||||
Post.parent_id is None,
|
||||
Post.api_client_id == self.api_client.id,
|
||||
)
|
||||
.one_or_none()
|
||||
)
|
||||
if thread_root is None:
|
||||
thread_id = uuid4()
|
||||
thread_root = self.insert_post(
|
||||
post_id=thread_id,
|
||||
thread_id=thread_id,
|
||||
frontend_post_id=post_id,
|
||||
parent_id=None,
|
||||
role="system",
|
||||
workpackage_id=work_pack.id,
|
||||
payload=None,
|
||||
payload_type="bind",
|
||||
)
|
||||
return thread_root
|
||||
|
||||
def fetch_post_by_frontend_post_id(self, frontend_post_id: str, fail_if_missing: bool = True) -> Post:
|
||||
self.validate_post_id(frontend_post_id)
|
||||
post: Post = (
|
||||
self.db.query(Post)
|
||||
.filter(Post.api_client_id == self.api_client.id, Post.frontend_post_id == frontend_post_id)
|
||||
.one_or_none()
|
||||
)
|
||||
if fail_if_missing and post is None:
|
||||
raise KeyError(f"Post with post_id {frontend_post_id} not found.")
|
||||
return post
|
||||
|
||||
def fetch_workpackage_by_postid(self, post_id: str) -> WorkPackage:
|
||||
self.validate_post_id(post_id)
|
||||
post = self.fetch_post_by_frontend_post_id(post_id, fail_if_missing=True)
|
||||
work_pack = self.db.query(WorkPackage).filter(WorkPackage.id == post.workpackage_id).one()
|
||||
return work_pack
|
||||
|
||||
def store_text_reply(self, reply: protocol_schema.TextReplyToPost, role: str) -> Post:
|
||||
self.validate_post_id(reply.post_id)
|
||||
self.validate_post_id(reply.user_post_id)
|
||||
|
||||
# find post with post-id
|
||||
parent_post: Post = (
|
||||
self.db.query(Post)
|
||||
.filter(
|
||||
Post.api_client_id == self.api_client.id,
|
||||
Post.frontend_post_id == reply.post_id,
|
||||
# Post.person_id == self.person_id
|
||||
)
|
||||
.one_or_none()
|
||||
)
|
||||
|
||||
if parent_post is None:
|
||||
raise KeyError(f"Post for post_id {reply.post_id} not found.")
|
||||
|
||||
# create reply post
|
||||
user_post_id = uuid4()
|
||||
user_post = self.insert_post(
|
||||
post_id=user_post_id,
|
||||
frontend_post_id=reply.user_post_id,
|
||||
parent_id=parent_post.id,
|
||||
thread_id=parent_post.thread_id,
|
||||
workpackage_id=parent_post.workpackage_id,
|
||||
role=role,
|
||||
payload=db_payload.PostPayload(text=reply.text),
|
||||
)
|
||||
return user_post
|
||||
|
||||
def store_rating(self, rating: protocol_schema.PostRating) -> PostReaction:
|
||||
post = self.fetch_post_by_frontend_post_id(rating.post_id, fail_if_missing=True)
|
||||
|
||||
work_package = self.fetch_workpackage_by_postid(rating.post_id)
|
||||
work_payload: db_payload.RateSummaryPayload = work_package.payload.payload
|
||||
if type(work_payload) != db_payload.RateSummaryPayload:
|
||||
raise ValueError(
|
||||
f"work_package payload type mismatch: {type(work_payload)=} != {db_payload.RateSummaryPayload}"
|
||||
)
|
||||
|
||||
if rating.rating < work_payload.scale.min or rating.rating > work_payload.scale.max:
|
||||
raise ValueError(f"Invalid rating value: {rating.rating=} not in {work_payload.scale=}")
|
||||
|
||||
# store reaction to post
|
||||
reaction_payload = db_payload.RatingReactionPayload(rating=rating.rating)
|
||||
reaction = self.insert_reaction(post.id, reaction_payload)
|
||||
logger.info(f"Ranking {rating.rating} stored for work_package {work_package.id}.")
|
||||
return reaction
|
||||
|
||||
def store_ranking(self, ranking: protocol_schema.PostRanking) -> PostReaction:
|
||||
post = self.fetch_post_by_frontend_post_id(ranking.post_id, fail_if_missing=True)
|
||||
|
||||
# fetch work_package
|
||||
work_package = self.fetch_workpackage_by_postid(ranking.post_id)
|
||||
work_payload: db_payload.RankConversationRepliesPayload | db_payload.RankInitialPromptsPayload = (
|
||||
work_package.payload.payload
|
||||
)
|
||||
|
||||
match type(work_payload):
|
||||
|
||||
case db_payload.RankUserRepliesPayload | db_payload.RankAssistantRepliesPayload:
|
||||
# validate ranking
|
||||
num_replies = len(work_payload.replies)
|
||||
if sorted(ranking.ranking) != list(range(num_replies)):
|
||||
raise ValueError(
|
||||
f"Invalid ranking submitted. Each reply index must appear exactly once ({num_replies=})."
|
||||
)
|
||||
|
||||
# store reaction to post
|
||||
reaction_payload = db_payload.RankingReactionPayload(ranking=ranking.ranking)
|
||||
reaction = self.insert_reaction(post.id, reaction_payload)
|
||||
|
||||
logger.info(f"Ranking {ranking.ranking} stored for work_package {work_package.id}.")
|
||||
|
||||
return reaction
|
||||
|
||||
case db_payload.RankInitialPromptsPayload:
|
||||
# validate ranking
|
||||
if sorted(ranking.ranking) != list(range(num_prompts := len(work_payload.prompts))):
|
||||
raise ValueError(
|
||||
f"Invalid ranking submitted. Each reply index must appear exactly once ({num_prompts=})."
|
||||
)
|
||||
|
||||
# store reaction to post
|
||||
reaction_payload = db_payload.RankingReactionPayload(ranking=ranking.ranking)
|
||||
reaction = self.insert_reaction(post.id, reaction_payload)
|
||||
|
||||
logger.info(f"Ranking {ranking.ranking} stored for work_package {work_package.id}.")
|
||||
|
||||
return reaction
|
||||
|
||||
case _:
|
||||
raise ValueError(
|
||||
f"work_package payload type mismatch: {type(work_payload)=} != {db_payload.RankConversationRepliesPayload}"
|
||||
)
|
||||
|
||||
def store_task(self, task: protocol_schema.Task) -> WorkPackage:
|
||||
payload: db_payload.TaskPayload
|
||||
match type(task):
|
||||
case protocol_schema.SummarizeStoryTask:
|
||||
payload = db_payload.SummarizationStoryPayload(story=task.story)
|
||||
|
||||
case protocol_schema.RateSummaryTask:
|
||||
payload = db_payload.RateSummaryPayload(
|
||||
full_text=task.full_text, summary=task.summary, scale=task.scale
|
||||
)
|
||||
|
||||
case protocol_schema.InitialPromptTask:
|
||||
payload = db_payload.InitialPromptPayload(hint=task.hint)
|
||||
|
||||
case protocol_schema.UserReplyTask:
|
||||
payload = db_payload.UserReplyPayload(conversation=task.conversation, hint=task.hint)
|
||||
|
||||
case protocol_schema.AssistantReplyTask:
|
||||
payload = db_payload.AssistantReplyPayload(type=task.type, conversation=task.conversation)
|
||||
|
||||
case protocol_schema.RankInitialPromptsTask:
|
||||
payload = db_payload.RankInitialPromptsPayload(tpye=task.type, prompts=task.prompts)
|
||||
|
||||
case protocol_schema.RankUserRepliesTask:
|
||||
payload = db_payload.RankUserRepliesPayload(
|
||||
tpye=task.type, conversation=task.conversation, replies=task.replies
|
||||
)
|
||||
|
||||
case protocol_schema.RankAssistantRepliesTask:
|
||||
payload = db_payload.RankAssistantRepliesPayload(
|
||||
tpye=task.type, conversation=task.conversation, replies=task.replies
|
||||
)
|
||||
|
||||
case _:
|
||||
raise ValueError(f"Invalid task type: {type(task)=}")
|
||||
|
||||
wp = self.insert_work_package(payload=payload, id=task.id)
|
||||
assert wp.id == task.id
|
||||
return wp
|
||||
|
||||
def insert_work_package(self, payload: db_payload.TaskPayload, id: UUID = None) -> WorkPackage:
|
||||
c = PayloadContainer(payload=payload)
|
||||
wp = WorkPackage(
|
||||
id=id,
|
||||
person_id=self.person_id,
|
||||
payload_type=type(payload).__name__,
|
||||
payload=c,
|
||||
api_client_id=self.api_client.id,
|
||||
)
|
||||
self.db.add(wp)
|
||||
self.db.commit()
|
||||
self.db.refresh(wp)
|
||||
return wp
|
||||
|
||||
def insert_post(
|
||||
self,
|
||||
*,
|
||||
post_id: UUID,
|
||||
frontend_post_id: str,
|
||||
parent_id: UUID,
|
||||
thread_id: UUID,
|
||||
workpackage_id: UUID,
|
||||
role: str,
|
||||
payload: db_payload.PostPayload,
|
||||
payload_type: str = None,
|
||||
) -> Post:
|
||||
if payload_type is None:
|
||||
if payload is None:
|
||||
payload_type = "null"
|
||||
else:
|
||||
payload_type = type(payload).__name__
|
||||
|
||||
post = Post(
|
||||
id=post_id,
|
||||
parent_id=parent_id,
|
||||
thread_id=thread_id,
|
||||
workpackage_id=workpackage_id,
|
||||
person_id=self.person_id,
|
||||
role=role,
|
||||
frontend_post_id=frontend_post_id,
|
||||
api_client_id=self.api_client.id,
|
||||
payload_type=payload_type,
|
||||
payload=PayloadContainer(payload=payload),
|
||||
)
|
||||
self.db.add(post)
|
||||
self.db.commit()
|
||||
self.db.refresh(post)
|
||||
return post
|
||||
|
||||
def insert_reaction(self, post_id: UUID, payload: db_payload.ReactionPayload) -> PostReaction:
|
||||
if self.person_id is None:
|
||||
raise ValueError("User required")
|
||||
|
||||
container = PayloadContainer(payload=payload)
|
||||
reaction = PostReaction(
|
||||
post_id=post_id,
|
||||
person_id=self.person_id,
|
||||
payload=container,
|
||||
api_client_id=self.api_client.id,
|
||||
payload_type=type(payload).__name__,
|
||||
)
|
||||
self.db.add(reaction)
|
||||
self.db.commit()
|
||||
self.db.refresh(reaction)
|
||||
return reaction
|
||||
Reference in New Issue
Block a user