mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-30 16:40:05 +08:00
insert tasks into work_package table, add PromptRepository
This commit is contained in:
+1
-1
@@ -56,7 +56,7 @@ version_path_separator = os # Use os.pathsep. Default configuration used for ne
|
||||
# output_encoding = utf-8
|
||||
|
||||
# sqlalchemy.url = postgresql://<username>:<password>@<host>/<database_name>
|
||||
|
||||
sqlalchemy.url = postgresql://postgres:postgres@localhost:5432/postgres
|
||||
|
||||
[post_write_hooks]
|
||||
# post_write_hooks defines scripts or Python functions that are run
|
||||
|
||||
+13
-1
@@ -1,4 +1,5 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from secrets import token_hex
|
||||
from typing import Generator
|
||||
from uuid import UUID
|
||||
|
||||
@@ -7,6 +8,7 @@ from app.database import engine
|
||||
from app.models import ApiClient
|
||||
from fastapi import HTTPException, Security
|
||||
from fastapi.security.api_key import APIKey, APIKeyHeader, APIKeyQuery
|
||||
from loguru import logger
|
||||
from sqlmodel import Session
|
||||
from starlette.status import HTTP_403_FORBIDDEN
|
||||
|
||||
@@ -37,7 +39,17 @@ def api_auth(
|
||||
|
||||
if api_key is not None:
|
||||
if settings.ALLOW_ANY_API_KEY:
|
||||
return ApiClient(id=UUID("00000000-1111-2222-3333-444444444444"), api_key=api_key, name=api_key)
|
||||
# make sure that a dummy api key exits in db (foreign key references)
|
||||
ANY_API_KEY_ID = UUID("00000000-1111-2222-3333-444444444444")
|
||||
api_client: ApiClient = db.query(ApiClient).filter(ApiClient.id == ANY_API_KEY_ID).first()
|
||||
if api_client is None:
|
||||
token = token_hex(32)
|
||||
logger.info(f"ANY_API_KEY missing, inserting api_key: {token}")
|
||||
api_client = ApiClient(id=ANY_API_KEY_ID, api_key=token, description="ANY_API_KEY, random token")
|
||||
db.add(api_client)
|
||||
db.commit()
|
||||
return api_client
|
||||
|
||||
api_client = db.query(ApiClient).filter(ApiClient.api_key == api_key).first()
|
||||
if api_client is not None and api_client.enabled:
|
||||
return api_client
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from app.api.v1 import labelers, prompts, tasks
|
||||
from app.api.v1 import labelers, prompts, tasks, tasks2
|
||||
from fastapi import APIRouter
|
||||
|
||||
api_router = APIRouter()
|
||||
api_router.include_router(labelers.router, prefix="/labelers", tags=["labelers"])
|
||||
api_router.include_router(prompts.router, prefix="/prompts", tags=["prompts"])
|
||||
api_router.include_router(tasks.router, prefix="/tasks", tags=["tasks"])
|
||||
|
||||
api_router.include_router(tasks2.router, prefix="/task2", tags=["task2"]) # temporary
|
||||
|
||||
@@ -0,0 +1,173 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import random
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from app.api import deps
|
||||
from app.prompt_repository import PromptRepository
|
||||
from app.schemas import protocol as protocol_schema
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.security.api_key import APIKey
|
||||
from loguru import logger
|
||||
from sqlmodel import Session
|
||||
from starlette.status import HTTP_400_BAD_REQUEST
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def generate_task(request: protocol_schema.TaskRequest) -> protocol_schema.Task:
|
||||
match (request.type):
|
||||
case protocol_schema.TaskRequestType.random:
|
||||
logger.info("Frontend requested a random task.")
|
||||
while request.type == protocol_schema.TaskRequestType.random:
|
||||
request.type = random.choice(list(protocol_schema.TaskRequestType)).value
|
||||
return generate_task(request)
|
||||
case protocol_schema.TaskRequestType.summarize_story:
|
||||
logger.info("Generating a SummarizeStoryTask.")
|
||||
task = protocol_schema.SummarizeStoryTask(
|
||||
story="This is a story. A very long story. So long, it needs to be summarized.",
|
||||
)
|
||||
case protocol_schema.TaskRequestType.rate_summary:
|
||||
logger.info("Generating a RateSummaryTask.")
|
||||
task = protocol_schema.RateSummaryTask(
|
||||
full_text="This is a story. A very long story. So long, it needs to be summarized.",
|
||||
summary="This is a summary.",
|
||||
scale=protocol_schema.RatingScale(min=1, max=5),
|
||||
)
|
||||
case protocol_schema.TaskRequestType.initial_prompt:
|
||||
logger.info("Generating an InitialPromptTask.")
|
||||
task = protocol_schema.InitialPromptTask(
|
||||
hint="Ask the assistant about a current event." # this is optional
|
||||
)
|
||||
case protocol_schema.TaskRequestType.user_reply:
|
||||
logger.info("Generating a UserReplyTask.")
|
||||
task = protocol_schema.UserReplyTask(
|
||||
conversation=protocol_schema.Conversation(
|
||||
messages=[
|
||||
protocol_schema.ConversationMessage(
|
||||
text="Hey, assistant, what's going on in the world?",
|
||||
is_assistant=False,
|
||||
),
|
||||
protocol_schema.ConversationMessage(
|
||||
text="I'm not sure I understood correctly, could you rephrase that?",
|
||||
is_assistant=True,
|
||||
),
|
||||
],
|
||||
)
|
||||
)
|
||||
case protocol_schema.TaskRequestType.assistant_reply:
|
||||
logger.info("Generating a AssistantReplyTask.")
|
||||
task = protocol_schema.AssistantReplyTask(
|
||||
conversation=protocol_schema.Conversation(
|
||||
messages=[
|
||||
protocol_schema.ConversationMessage(
|
||||
text="Hey, assistant, write me an English essay about water.",
|
||||
is_assistant=False,
|
||||
),
|
||||
],
|
||||
)
|
||||
)
|
||||
case _:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid request type.",
|
||||
)
|
||||
logger.info(f"Generated {task=}.")
|
||||
if request.user is not None:
|
||||
task.addressed_user = request.user
|
||||
|
||||
return task
|
||||
|
||||
|
||||
@router.post("/", response_model=protocol_schema.AnyTask) # work with Union once more types are added
|
||||
def request_task(
|
||||
*,
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_key: APIKey = Depends(deps.get_api_key),
|
||||
request: protocol_schema.TaskRequest,
|
||||
) -> Any:
|
||||
"""
|
||||
Create new task.
|
||||
"""
|
||||
api_client = deps.api_auth(api_key, db)
|
||||
|
||||
try:
|
||||
task = generate_task(request)
|
||||
|
||||
pr = PromptRepository(db, api_client, request.user)
|
||||
pr.store_task(task)
|
||||
|
||||
except Exception:
|
||||
logger.exception("Failed to generate task.")
|
||||
raise HTTPException(
|
||||
status_code=HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
return task
|
||||
|
||||
|
||||
@router.post("/{task_id}/ack")
|
||||
def acknowledge_task(
|
||||
*,
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_key: APIKey = Depends(deps.get_api_key),
|
||||
task_id: UUID,
|
||||
response: protocol_schema.AnyTaskResponse,
|
||||
) -> Any:
|
||||
"""
|
||||
The frontend acknowledges a task.
|
||||
"""
|
||||
deps.api_auth(api_key, db)
|
||||
|
||||
match (type(response)):
|
||||
case protocol_schema.PostCreatedTaskResponse:
|
||||
logger.info(f"Frontend acknowledged {task_id=} and created {response.post_id=}.")
|
||||
# here we would store the post id in the database for the task
|
||||
case protocol_schema.RatingCreatedTaskResponse:
|
||||
logger.info(f"Frontend acknowledged {task_id=} for {response.post_id=}.")
|
||||
# here we would store the rating id in the database for the task
|
||||
case _:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid response type.",
|
||||
)
|
||||
|
||||
return {}
|
||||
|
||||
|
||||
@router.post("/interaction")
|
||||
def post_interaction(
|
||||
*,
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_key: APIKey = Depends(deps.get_api_key),
|
||||
interaction: protocol_schema.AnyInteraction,
|
||||
) -> Any:
|
||||
"""
|
||||
The frontend reports an interaction.
|
||||
"""
|
||||
deps.api_auth(api_key, db)
|
||||
|
||||
match (type(interaction)):
|
||||
case protocol_schema.TextReplyToPost:
|
||||
logger.info(
|
||||
f"Frontend reports text reply to {interaction.post_id=} with {interaction.text=} by {interaction.user=}."
|
||||
)
|
||||
# here we would store the text reply in the database
|
||||
return protocol_schema.TaskDone(
|
||||
reply_to_post_id=interaction.user_post_id,
|
||||
addressed_user=interaction.user,
|
||||
)
|
||||
case protocol_schema.PostRating:
|
||||
logger.info(
|
||||
f"Frontend reports rating of {interaction.post_id=} with {interaction.rating=} by {interaction.user=}."
|
||||
)
|
||||
# check if rating in range
|
||||
# here we would store the rating in the database
|
||||
return protocol_schema.TaskDone(
|
||||
reply_to_post_id=interaction.post_id,
|
||||
addressed_user=interaction.user,
|
||||
)
|
||||
case _:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid response type.",
|
||||
)
|
||||
@@ -0,0 +1,116 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from typing import Literal, Optional
|
||||
from uuid import UUID
|
||||
|
||||
# from app.models import ApiClient, Person, PersonStats, Post, PostReaction, WorkPackage
|
||||
from app.models import ApiClient, Person, WorkPackage
|
||||
from app.models.payload_column_type import PayloadContainer, payload_tpye
|
||||
from app.schemas import protocol as protocol_schema
|
||||
from pydantic import BaseModel
|
||||
from sqlmodel import Session
|
||||
|
||||
|
||||
@payload_tpye
|
||||
class TaskPayload(BaseModel):
|
||||
type: str
|
||||
|
||||
|
||||
@payload_tpye
|
||||
class SummarizationStoryPayload(TaskPayload):
|
||||
type: Literal["summarize_story"] = "summarize_story"
|
||||
story: str
|
||||
|
||||
|
||||
@payload_tpye
|
||||
class RateSummaryPayload(TaskPayload):
|
||||
type: Literal["rate_summary"] = "rate_summary"
|
||||
full_text: str
|
||||
summary: str
|
||||
scale: protocol_schema.RatingScale
|
||||
|
||||
|
||||
@payload_tpye
|
||||
class InitialPromptPayload(TaskPayload):
|
||||
type: Literal["initial_prompt"] = "initial_prompt"
|
||||
hint: str
|
||||
|
||||
|
||||
@payload_tpye
|
||||
class UserReplyPayload(TaskPayload):
|
||||
type: Literal["user_reply"] = "user_reply"
|
||||
conversation: protocol_schema.Conversation
|
||||
hint: str | None
|
||||
|
||||
|
||||
@payload_tpye
|
||||
class AssistantReplyPayload(TaskPayload):
|
||||
type: Literal["assistant_reply"] = "assistant_reply"
|
||||
conversation: protocol_schema.Conversation
|
||||
|
||||
|
||||
class PromptRepository:
|
||||
def __init__(self, db: Session, api_client: ApiClient, user: Optional[protocol_schema.User]):
|
||||
self.db = db
|
||||
self.api_client = api_client
|
||||
self.person = self.lookup_person(user)
|
||||
self.person_id = self.person.id if self.person else None
|
||||
|
||||
def lookup_person(self, user: protocol_schema.User) -> Person:
|
||||
person: Person = (
|
||||
self.db.query(Person)
|
||||
.filter(Person.api_client_id == self.api_client.id and Person.username == user.id)
|
||||
.first()
|
||||
)
|
||||
if person is None:
|
||||
# user is unknown, create new record
|
||||
person = Person(username=user.id, display_name=user.display_name, api_client_id=self.api_client.id)
|
||||
self.db.add(person)
|
||||
self.db.commit()
|
||||
self.db.refresh(person)
|
||||
elif user.display_name and user.display_name != person.display_name:
|
||||
# we found the user but the display name changed
|
||||
person.display_name = user.display_name
|
||||
self.db.add(person)
|
||||
self.db.commit()
|
||||
return person
|
||||
|
||||
def store_task(self, task: protocol_schema.Task) -> WorkPackage:
|
||||
payload: TaskPayload = None
|
||||
match type(task):
|
||||
case protocol_schema.SummarizeStoryTask:
|
||||
payload = SummarizationStoryPayload(story=task.story)
|
||||
|
||||
case protocol_schema.RateSummaryTask:
|
||||
payload = RateSummaryPayload(full_text=task.full_text, summary=task.summary, scale=task.scale)
|
||||
|
||||
case protocol_schema.InitialPromptTask:
|
||||
payload = InitialPromptPayload(hint=task.hint)
|
||||
|
||||
case protocol_schema.UserReplyTask:
|
||||
payload = UserReplyPayload(conversation=task.conversation, hint=task.hint)
|
||||
|
||||
case protocol_schema.AssistantReplyTask:
|
||||
payload = AssistantReplyPayload(type=task.type, conversation=task.conversation)
|
||||
|
||||
case _:
|
||||
raise RuntimeError(
|
||||
detail="Invalid task type.",
|
||||
)
|
||||
|
||||
wp = self.insert_work_package(payload=payload, id=task.id)
|
||||
assert wp.id == task.id
|
||||
return wp
|
||||
|
||||
def insert_work_package(self, payload: TaskPayload, id: UUID = None) -> WorkPackage:
|
||||
c = PayloadContainer(payload=payload)
|
||||
wp = WorkPackage(
|
||||
id=id,
|
||||
person_id=self.person_id,
|
||||
payload_type=type(payload).__name__,
|
||||
payload=c,
|
||||
api_client_id=self.api_client.id,
|
||||
)
|
||||
self.db.add(wp)
|
||||
self.db.commit()
|
||||
self.db.refresh(wp)
|
||||
return wp
|
||||
Reference in New Issue
Block a user