Merge branch 'db_task_protocol' into main

This commit is contained in:
Andreas Köpf
2022-12-16 14:56:44 +01:00
8 changed files with 521 additions and 6 deletions
+1 -1
View File
@@ -56,7 +56,7 @@ version_path_separator = os # Use os.pathsep. Default configuration used for ne
# output_encoding = utf-8
# sqlalchemy.url = postgresql://<username>:<password>@<host>/<database_name>
sqlalchemy.url = postgresql://postgres:postgres@localhost:5432/postgres
[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
@@ -94,7 +94,7 @@ def upgrade() -> None:
sa.Column("frontend_post_id", sa.String(200), nullable=False), # unique together with api_client_id
sa.Column("created_date", sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp()),
sa.Column("payload_type", sa.String(200), nullable=False), # deserialization hint & dbg aid
sa.Column("payload", JSONB(astext_type=sa.Text()), nullable=False),
sa.Column("payload", JSONB(astext_type=sa.Text()), nullable=True),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(["person_id"], ["person.id"]),
sa.ForeignKeyConstraint(["api_client_id"], ["api_client.id"]),
+13 -1
View File
@@ -1,4 +1,5 @@
# -*- coding: utf-8 -*-
from secrets import token_hex
from typing import Generator
from uuid import UUID
@@ -7,6 +8,7 @@ from app.database import engine
from app.models import ApiClient
from fastapi import HTTPException, Security
from fastapi.security.api_key import APIKey, APIKeyHeader, APIKeyQuery
from loguru import logger
from sqlmodel import Session
from starlette.status import HTTP_403_FORBIDDEN
@@ -37,7 +39,17 @@ def api_auth(
if api_key is not None:
if settings.ALLOW_ANY_API_KEY:
return ApiClient(id=UUID("00000000-1111-2222-3333-444444444444"), api_key=api_key, name=api_key)
# make sure that a dummy api key exits in db (foreign key references)
ANY_API_KEY_ID = UUID("00000000-1111-2222-3333-444444444444")
api_client: ApiClient = db.query(ApiClient).filter(ApiClient.id == ANY_API_KEY_ID).first()
if api_client is None:
token = token_hex(32)
logger.info(f"ANY_API_KEY missing, inserting api_key: {token}")
api_client = ApiClient(id=ANY_API_KEY_ID, api_key=token, description="ANY_API_KEY, random token")
db.add(api_client)
db.commit()
return api_client
api_client = db.query(ApiClient).filter(ApiClient.api_key == api_key).first()
if api_client is not None and api_client.enabled:
return api_client
+3 -1
View File
@@ -1,8 +1,10 @@
# -*- coding: utf-8 -*-
from app.api.v1 import labelers, prompts, tasks
from app.api.v1 import labelers, prompts, tasks, tasks2
from fastapi import APIRouter
api_router = APIRouter()
api_router.include_router(labelers.router, prefix="/labelers", tags=["labelers"])
api_router.include_router(prompts.router, prefix="/prompts", tags=["prompts"])
api_router.include_router(tasks.router, prefix="/tasks", tags=["tasks"])
api_router.include_router(tasks2.router, prefix="/task2", tags=["task2"]) # temporary
+206
View File
@@ -0,0 +1,206 @@
# -*- coding: utf-8 -*-
import random
from typing import Any
from uuid import UUID
from app.api import deps
from app.prompt_repository import PromptRepository, RateSummaryPayload, TaskPayload
from app.schemas import protocol as protocol_schema
from fastapi import APIRouter, Depends, HTTPException
from fastapi.security.api_key import APIKey
from loguru import logger
from sqlmodel import Session
from starlette.status import HTTP_400_BAD_REQUEST
router = APIRouter()
def generate_task(request: protocol_schema.TaskRequest) -> protocol_schema.Task:
match (request.type):
case protocol_schema.TaskRequestType.random:
logger.info("Frontend requested a random task.")
while request.type == protocol_schema.TaskRequestType.random:
request.type = random.choice(list(protocol_schema.TaskRequestType)).value
return generate_task(request)
case protocol_schema.TaskRequestType.summarize_story:
logger.info("Generating a SummarizeStoryTask.")
task = protocol_schema.SummarizeStoryTask(
story="This is a story. A very long story. So long, it needs to be summarized.",
)
case protocol_schema.TaskRequestType.rate_summary:
logger.info("Generating a RateSummaryTask.")
task = protocol_schema.RateSummaryTask(
full_text="This is a story. A very long story. So long, it needs to be summarized.",
summary="This is a summary.",
scale=protocol_schema.RatingScale(min=1, max=5),
)
case protocol_schema.TaskRequestType.initial_prompt:
logger.info("Generating an InitialPromptTask.")
task = protocol_schema.InitialPromptTask(
hint="Ask the assistant about a current event." # this is optional
)
case protocol_schema.TaskRequestType.user_reply:
logger.info("Generating a UserReplyTask.")
task = protocol_schema.UserReplyTask(
conversation=protocol_schema.Conversation(
messages=[
protocol_schema.ConversationMessage(
text="Hey, assistant, what's going on in the world?",
is_assistant=False,
),
protocol_schema.ConversationMessage(
text="I'm not sure I understood correctly, could you rephrase that?",
is_assistant=True,
),
],
)
)
case protocol_schema.TaskRequestType.assistant_reply:
logger.info("Generating a AssistantReplyTask.")
task = protocol_schema.AssistantReplyTask(
conversation=protocol_schema.Conversation(
messages=[
protocol_schema.ConversationMessage(
text="Hey, assistant, write me an English essay about water.",
is_assistant=False,
),
],
)
)
case _:
raise HTTPException(
status_code=HTTP_400_BAD_REQUEST,
detail="Invalid request type.",
)
logger.info(f"Generated {task=}.")
if request.user is not None:
task.addressed_user = request.user
return task
@router.post("/", response_model=protocol_schema.AnyTask) # work with Union once more types are added
def request_task(
*,
db: Session = Depends(deps.get_db),
api_key: APIKey = Depends(deps.get_api_key),
request: protocol_schema.TaskRequest,
) -> Any:
"""
Create new task.
"""
api_client = deps.api_auth(api_key, db)
try:
task = generate_task(request)
pr = PromptRepository(db, api_client, request.user)
pr.store_task(task)
except Exception:
logger.exception("Failed to generate task.")
raise HTTPException(
status_code=HTTP_400_BAD_REQUEST,
)
return task
@router.post("/{task_id}/ack")
def acknowledge_task(
*,
db: Session = Depends(deps.get_db),
api_key: APIKey = Depends(deps.get_api_key),
task_id: UUID,
response: protocol_schema.AnyTaskResponse,
) -> Any:
"""
The frontend acknowledges a task.
"""
api_client = deps.api_auth(api_key, db)
pr = PromptRepository(db, api_client, user=None)
match (type(response)):
case protocol_schema.PostCreatedTaskResponse:
logger.info(f"Frontend acknowledged {task_id=} and created {response.post_id=}.")
if response.status == "success":
# here we store the post id in the database for the task
pr.bind_frontend_post_id(task_id=task_id, post_id=response.post_id)
case protocol_schema.RatingCreatedTaskResponse:
logger.info(f"Frontend acknowledged {task_id=} for {response.post_id=}.")
if response.status == "success":
# here we would store the rating id in the database for the task
pr.bind_frontend_post_id(task_id=task_id, post_id=response.post_id)
case _:
raise HTTPException(
status_code=HTTP_400_BAD_REQUEST,
detail="Invalid response type.",
)
return {}
@router.post("/interaction")
def post_interaction(
*,
db: Session = Depends(deps.get_db),
api_key: APIKey = Depends(deps.get_api_key),
interaction: protocol_schema.AnyInteraction,
) -> Any:
"""
The frontend reports an interaction.
"""
api_client = deps.api_auth(api_key, db)
pr = PromptRepository(db, api_client, user=interaction.user)
match (type(interaction)):
case protocol_schema.TextReplyToPost:
logger.info(
f"Frontend reports text reply to {interaction.post_id=} with {interaction.text=} by {interaction.user=}."
)
work_package = pr.fetch_workpackage_by_postid(interaction.post_id)
work_payload: TaskPayload = work_package.payload.payload
logger.info(f"found task work package in db: {work_payload}")
# here we store the text reply in the database
# ToDo: role user or agent?
pr.store_text_reply(interaction, role="unknown")
return protocol_schema.TaskDone(
reply_to_post_id=interaction.user_post_id,
addressed_user=interaction.user,
)
case protocol_schema.PostRating:
logger.info(
f"Frontend reports rating of {interaction.post_id=} with {interaction.rating=} by {interaction.user=}."
)
# check if rating in range
work_package = pr.fetch_workpackage_by_postid(interaction.post_id)
work_payload: RateSummaryPayload = work_package.payload.payload
if (
type(work_payload) != RateSummaryPayload
or interaction.rating < work_payload.scale.min
or interaction.rating > work_payload.scale.max
):
raise HTTPException(
status_code=HTTP_400_BAD_REQUEST,
detail="Invalid response type.",
)
pr.store_rating(interaction)
# here we would store the rating in the database
return protocol_schema.TaskDone(
reply_to_post_id=interaction.post_id,
addressed_user=interaction.user,
)
case _:
raise HTTPException(
status_code=HTTP_400_BAD_REQUEST,
detail="Invalid response type.",
)
+1 -1
View File
@@ -14,7 +14,7 @@ payload_type_registry = {}
P = TypeVar("P", bound=BaseModel)
def payload_tpye(cls: Type[P]) -> Type[P]:
def payload_type(cls: Type[P]) -> Type[P]:
payload_type_registry[cls.__name__] = cls
return cls
+1 -1
View File
@@ -30,4 +30,4 @@ class Post(SQLModel, table=True):
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp())
)
payload_type: str = Field(nullable=False, max_length=200)
payload: PayloadContainer = Field(sa_column=sa.Column(payload_column_type(PayloadContainer), nullable=False))
payload: PayloadContainer = Field(sa_column=sa.Column(payload_column_type(PayloadContainer), nullable=True))
+295
View File
@@ -0,0 +1,295 @@
# -*- coding: utf-8 -*-
from datetime import datetime
from typing import Literal, Optional
from uuid import UUID, uuid4
from app.models import ApiClient, Person, Post, PostReaction, WorkPackage
from app.models.payload_column_type import PayloadContainer, payload_type
from app.schemas import protocol as protocol_schema
from pydantic import BaseModel
from sqlmodel import Session
@payload_type
class TaskPayload(BaseModel):
type: str
@payload_type
class SummarizationStoryPayload(TaskPayload):
type: Literal["summarize_story"] = "summarize_story"
story: str
@payload_type
class RateSummaryPayload(TaskPayload):
type: Literal["rate_summary"] = "rate_summary"
full_text: str
summary: str
scale: protocol_schema.RatingScale
@payload_type
class InitialPromptPayload(TaskPayload):
type: Literal["initial_prompt"] = "initial_prompt"
hint: str
@payload_type
class UserReplyPayload(TaskPayload):
type: Literal["user_reply"] = "user_reply"
conversation: protocol_schema.Conversation
hint: str | None
@payload_type
class AssistantReplyPayload(TaskPayload):
type: Literal["assistant_reply"] = "assistant_reply"
conversation: protocol_schema.Conversation
@payload_type
class PostPayload(BaseModel):
text: str
@payload_type
class ReactionPayload(BaseModel):
type: str
@payload_type
class RatingReactionPayload(ReactionPayload):
type: Literal["post_rating"] = "post_rating"
rating: str
class PromptRepository:
def __init__(self, db: Session, api_client: ApiClient, user: Optional[protocol_schema.User]):
self.db = db
self.api_client = api_client
self.person = self.lookup_person(user)
self.person_id = self.person.id if self.person else None
def lookup_person(self, user: protocol_schema.User) -> Person:
if not user:
return None
person: Person = (
self.db.query(Person).filter(Person.api_client_id == self.api_client.id, Person.username == user.id).first()
)
if person is None:
# user is unknown, create new record
person = Person(username=user.id, display_name=user.display_name, api_client_id=self.api_client.id)
self.db.add(person)
self.db.commit()
self.db.refresh(person)
elif user.display_name and user.display_name != person.display_name:
# we found the user but the display name changed
person.display_name = user.display_name
self.db.add(person)
self.db.commit()
return person
def validate_post_id(self, post_id: str) -> None:
if not isinstance(post_id, str):
raise TypeError("post_id must be string")
if not post_id:
raise ValueError("post_id must not be empty")
def bind_frontend_post_id(self, task_id: UUID, post_id: str):
self.validate_post_id(post_id)
# find work package
work_pack: WorkPackage = (
self.db.query(WorkPackage)
.filter(WorkPackage.id == task_id, WorkPackage.api_client_id == self.api_client.id)
.first()
)
if work_pack is None:
raise RuntimeError(f"WorkPackage for task {task_id} not found")
if work_pack.expiry_date is not None and datetime.utcnow() > work_pack.expiry_date:
raise RuntimeError("WorkPackage already expired.")
# ToDo: check race-condition, transaction
# check if task thread exits
thread_root = (
self.db.query(Post)
.filter(
Post.workpackage_id == work_pack.id,
Post.frontend_post_id == post_id,
Post.parent_id is None,
self.api_client == self.api_client,
)
.one_or_none()
)
if thread_root is None:
thread_id = uuid4()
thread_root = self.insert_post(
post_id=thread_id,
thread_id=thread_id,
frontend_post_id=post_id,
parent_id=None,
role="system",
workpackage_id=work_pack.id,
payload=None,
payload_type="bind",
)
return thread_root
def fetch_post_by_frontend_post_id(self, frontend_post_id: str, fail_if_missing: bool = True) -> Post:
self.validate_post_id(frontend_post_id)
post: Post = (
self.db.query(Post)
.filter(Post.api_client_id == self.api_client.id, Post.frontend_post_id == frontend_post_id)
.one_or_none()
)
if fail_if_missing and post is None:
raise RuntimeError(f"Post with post_id {frontend_post_id} not found.")
return post
def fetch_workpackage_by_postid(self, post_id: str) -> WorkPackage:
self.validate_post_id(post_id)
post = self.fetch_post_by_frontend_post_id(post_id, fail_if_missing=True)
work_pack = self.db.query(WorkPackage).filter(WorkPackage.id == post.workpackage_id).one()
return work_pack
def store_text_reply(self, reply: protocol_schema.TextReplyToPost, role: str) -> Post:
self.validate_post_id(reply.post_id)
self.validate_post_id(reply.user_post_id)
# find post with post-id
parent_post: Post = (
self.db.query(Post)
.filter(
Post.api_client_id == self.api_client.id,
Post.frontend_post_id == reply.post_id,
# Post.person_id == self.person_id
)
.one_or_none()
)
if parent_post is None:
raise RuntimeError(f"Post for post_id {reply.post_id} not found.")
# create reply post
user_post_id = uuid4()
user_post = self.insert_post(
post_id=user_post_id,
frontend_post_id=reply.user_post_id,
parent_id=parent_post.id,
thread_id=parent_post.thread_id,
workpackage_id=parent_post.workpackage_id,
role=role,
payload=PostPayload(text=reply.text),
)
return user_post
def store_rating(self, rating: protocol_schema.PostRating) -> Post:
post = self.fetch_post_by_frontend_post_id(rating.post_id, fail_if_missing=True)
work_package = self.fetch_workpackage_by_postid(rating.post_id)
work_payload: RateSummaryPayload = work_package.payload.payload
if type(work_payload) != RateSummaryPayload:
raise RuntimeError("work_package payload type missmatch")
if rating.rating < work_payload.scale.min or rating.rating > work_payload.scale.max:
raise ValueError("Invalid rating value")
# store reaction to post
reaction_payload = RatingReactionPayload(rating=rating.rating)
reaction = self.insert_reaction(post.id, reaction_payload)
return reaction
def store_task(self, task: protocol_schema.Task) -> WorkPackage:
payload: TaskPayload = None
match type(task):
case protocol_schema.SummarizeStoryTask:
payload = SummarizationStoryPayload(story=task.story)
case protocol_schema.RateSummaryTask:
payload = RateSummaryPayload(full_text=task.full_text, summary=task.summary, scale=task.scale)
case protocol_schema.InitialPromptTask:
payload = InitialPromptPayload(hint=task.hint)
case protocol_schema.UserReplyTask:
payload = UserReplyPayload(conversation=task.conversation, hint=task.hint)
case protocol_schema.AssistantReplyTask:
payload = AssistantReplyPayload(type=task.type, conversation=task.conversation)
case _:
raise RuntimeError(
detail="Invalid task type.",
)
wp = self.insert_work_package(payload=payload, id=task.id)
assert wp.id == task.id
return wp
def insert_work_package(self, payload: TaskPayload, id: UUID = None) -> WorkPackage:
c = PayloadContainer(payload=payload)
wp = WorkPackage(
id=id,
person_id=self.person_id,
payload_type=type(payload).__name__,
payload=c,
api_client_id=self.api_client.id,
)
self.db.add(wp)
self.db.commit()
self.db.refresh(wp)
return wp
def insert_post(
self,
*,
post_id: UUID,
frontend_post_id: str,
parent_id: UUID,
thread_id: UUID,
workpackage_id: UUID,
role: str,
payload: PostPayload,
payload_type: str = None,
) -> Post:
if payload_type is None:
if payload is None:
payload_type = "null"
else:
payload_type = type(payload).__name__
post = Post(
id=post_id,
parent_id=parent_id,
thread_id=thread_id,
workpackage_id=workpackage_id,
person_id=self.person_id,
role=role,
frontend_post_id=frontend_post_id,
api_client_id=self.api_client.id,
payload_type=payload_type,
payload=PayloadContainer(payload=payload),
)
self.db.add(post)
self.db.commit()
self.db.refresh(post)
return post
def insert_reaction(self, post_id: UUID, payload: ReactionPayload) -> PostReaction:
if self.person_id is None:
raise RuntimeError("User required")
container = PayloadContainer(payload=payload)
reaction = PostReaction(
post_id=post_id,
person_id=self.person_id,
payload=container,
api_client_id=self.api_client.id,
payload_type=type(payload).__name__,
)
self.db.add(reaction)
self.db.commit()
self.db.refresh(reaction)
return reaction