mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-28 16:20:34 +08:00
Added initial protocol classes
Provides a rough draft of story summarization interaction
This commit is contained in:
@@ -1,7 +1,8 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from app.api.v1 import labelers, prompts
|
||||
from app.api.v1 import labelers, prompts, tasks
|
||||
from fastapi import APIRouter
|
||||
|
||||
api_router = APIRouter()
|
||||
api_router.include_router(labelers.router, prefix="/labelers", tags=["labelers"])
|
||||
api_router.include_router(prompts.router, prefix="/prompts", tags=["prompts"])
|
||||
api_router.include_router(tasks.router, prefix="/tasks", tags=["tasks"])
|
||||
|
||||
@@ -0,0 +1,106 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from typing import Any, List
|
||||
from uuid import UUID
|
||||
|
||||
from app.api import deps
|
||||
from app.schemas import protocol as protocol_schema
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.security.api_key import APIKey
|
||||
from loguru import logger
|
||||
from sqlmodel import Session
|
||||
from starlette.status import HTTP_400_BAD_REQUEST
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/", response_model=List[protocol_schema.SummarizeStoryTask]) # work with Union once more types are added
|
||||
def request_task(
|
||||
*,
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_key: APIKey = Depends(deps.get_api_key),
|
||||
request: protocol_schema.GenericTaskRequest, # work with Union once more types are added
|
||||
) -> Any:
|
||||
"""
|
||||
Create new task.
|
||||
"""
|
||||
deps.api_auth(api_key, db, create=True)
|
||||
|
||||
# TODO: Create a task and store it in the database.
|
||||
|
||||
match (request.type):
|
||||
case "generic":
|
||||
# here we create a task at random (and store it in the database)
|
||||
logger.info("Frontend requested a generic task.")
|
||||
task = protocol_schema.SummarizeStoryTask(
|
||||
story="This is a story. A very long story. So long, it needs to be summarized.",
|
||||
)
|
||||
|
||||
case _:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid request type.",
|
||||
)
|
||||
if request.user_id is not None:
|
||||
task.addressed_users = [request.user_id]
|
||||
|
||||
return [task]
|
||||
|
||||
|
||||
@router.post("/{task_id}/ack")
|
||||
def acknowledge_task(
|
||||
*,
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_key: APIKey = Depends(deps.get_api_key),
|
||||
task_id: UUID,
|
||||
response: protocol_schema.PostCreatedTaskResponse,
|
||||
) -> Any:
|
||||
"""
|
||||
The frontend acknowledges a task.
|
||||
"""
|
||||
deps.api_auth(api_key, db, create=True)
|
||||
|
||||
match (response.type):
|
||||
case "post_created":
|
||||
logger.info(f"Frontend acknowledged {task_id=} and created {response.post_id=}.")
|
||||
# here we would store the post id in the database for the task
|
||||
case _:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid response type.",
|
||||
)
|
||||
|
||||
return {}
|
||||
|
||||
|
||||
@router.post("/interaction")
|
||||
def post_interaction(
|
||||
*,
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_key: APIKey = Depends(deps.get_api_key),
|
||||
interaction: protocol_schema.TextReplyToPost,
|
||||
) -> Any:
|
||||
"""
|
||||
The frontend reports an interaction.
|
||||
"""
|
||||
deps.api_auth(api_key, db, create=True)
|
||||
|
||||
response = []
|
||||
match (interaction.type):
|
||||
case "text_reply_to_post":
|
||||
logger.info(
|
||||
f"Frontend reports text reply to {interaction.post_id=} with {interaction.text=} by {interaction.user_id=}."
|
||||
)
|
||||
# here we would store the text reply in the database
|
||||
response.append(
|
||||
protocol_schema.TaskDone(
|
||||
reply_to_post_id=interaction.user_post_id,
|
||||
addressed_users=[interaction.user_id],
|
||||
)
|
||||
)
|
||||
case _:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid response type.",
|
||||
)
|
||||
|
||||
return response
|
||||
@@ -0,0 +1,63 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from typing import Literal, Optional
|
||||
from uuid import UUID
|
||||
|
||||
import pydantic
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class TaskRequest(BaseModel):
|
||||
"""The frontend asks the backend for a task."""
|
||||
|
||||
type: Literal
|
||||
user_id: Optional[str] = None
|
||||
|
||||
|
||||
class GenericTaskRequest(TaskRequest):
|
||||
type: Literal["generic"] = "generic"
|
||||
|
||||
|
||||
class Task(BaseModel):
|
||||
"""A task is a unit of work that the backend gives to the frontend."""
|
||||
|
||||
id: UUID = pydantic.Field(default_factory=UUID)
|
||||
type: Literal
|
||||
addressed_users: Optional[list[str]] = None
|
||||
|
||||
|
||||
class TaskResponse(BaseModel):
|
||||
"""A task response is a message from the frontend to acknowledge the given task."""
|
||||
|
||||
type: Literal
|
||||
status: Literal["success", "failure"]
|
||||
|
||||
|
||||
class PostCreatedTaskResponse(TaskResponse):
|
||||
type: Literal["post_created"] = "post_created"
|
||||
post_id: UUID
|
||||
|
||||
|
||||
class SummarizeStoryTask(Task):
|
||||
type: Literal["summarize_story"] = "summarize_story"
|
||||
story: str
|
||||
|
||||
|
||||
class TaskDone(Task):
|
||||
type: Literal["task_done"] = "task_done"
|
||||
reply_to_post_id: UUID
|
||||
|
||||
|
||||
class Interaction(BaseModel):
|
||||
"""An interaction is a message from the frontend to the backend."""
|
||||
|
||||
type: Literal
|
||||
user_id: str
|
||||
|
||||
|
||||
class TextReplyToPost(Interaction):
|
||||
"""A user has replied to a post with text."""
|
||||
|
||||
type: Literal["text_reply_to_post"] = "text_reply_to_post"
|
||||
post_id: UUID
|
||||
user_post_id: UUID
|
||||
text: str
|
||||
Reference in New Issue
Block a user