mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
Merge pull request #13 from LAION-AI/simple-text-frontend
implemented a simple text-based frontend
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from typing import Generator
|
||||
|
||||
from app.config import settings
|
||||
from app.database import engine
|
||||
from app.models import ServiceClient
|
||||
from fastapi import HTTPException, Security
|
||||
@@ -37,6 +38,10 @@ def api_auth(
|
||||
delete: bool = False,
|
||||
) -> ServiceClient:
|
||||
if api_key is not None:
|
||||
if settings.ALLOW_ANY_API_KEY:
|
||||
return ServiceClient(
|
||||
api_key=api_key, name=api_key, can_append=True, can_read=True, can_write=True, can_delete=True
|
||||
)
|
||||
api_client = db.query(ServiceClient).filter(ServiceClient.api_key == api_key).first()
|
||||
if api_client is not None:
|
||||
if (
|
||||
|
||||
+66
-37
@@ -1,5 +1,6 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from typing import Any, List
|
||||
import random
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from app.api import deps
|
||||
@@ -13,37 +14,58 @@ from starlette.status import HTTP_400_BAD_REQUEST
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/", response_model=List[protocol_schema.SummarizeStoryTask]) # work with Union once more types are added
|
||||
def generate_task(request: protocol_schema.TaskRequest) -> protocol_schema.Task:
|
||||
match (request.type):
|
||||
case protocol_schema.TaskRequestType.generic:
|
||||
logger.info("Frontend requested a generic task.")
|
||||
while request.type == protocol_schema.TaskRequestType.generic:
|
||||
request.type = random.choice(list(protocol_schema.TaskRequestType)).value
|
||||
return generate_task(request)
|
||||
case protocol_schema.TaskRequestType.summarize_story:
|
||||
logger.info("Generating a SummarizeStoryTask.")
|
||||
task = protocol_schema.SummarizeStoryTask(
|
||||
story="This is a story. A very long story. So long, it needs to be summarized.",
|
||||
)
|
||||
case protocol_schema.TaskRequestType.rate_summary:
|
||||
logger.info("Generating a RateSummaryTask.")
|
||||
task = protocol_schema.RateSummaryTask(
|
||||
full_text="This is a story. A very long story. So long, it needs to be summarized.",
|
||||
summary="This is a summary.",
|
||||
scale=protocol_schema.RatingScale(min=1, max=5),
|
||||
)
|
||||
case _:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid request type.",
|
||||
)
|
||||
logger.info(f"Generated {task=}.")
|
||||
if request.user is not None:
|
||||
task.addressed_users = [request.user]
|
||||
|
||||
return task
|
||||
|
||||
|
||||
@router.post("/", response_model=protocol_schema.AnyTask) # work with Union once more types are added
|
||||
def request_task(
|
||||
*,
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_key: APIKey = Depends(deps.get_api_key),
|
||||
request: protocol_schema.GenericTaskRequest, # work with Union once more types are added
|
||||
request: protocol_schema.TaskRequest,
|
||||
) -> Any:
|
||||
"""
|
||||
Create new task.
|
||||
"""
|
||||
deps.api_auth(api_key, db, create=True)
|
||||
|
||||
# TODO: Create a task and store it in the database.
|
||||
|
||||
match (request.type):
|
||||
case "generic":
|
||||
# here we create a task at random (and store it in the database)
|
||||
logger.info("Frontend requested a generic task.")
|
||||
task = protocol_schema.SummarizeStoryTask(
|
||||
story="This is a story. A very long story. So long, it needs to be summarized.",
|
||||
)
|
||||
|
||||
case _:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid request type.",
|
||||
)
|
||||
if request.user_id is not None:
|
||||
task.addressed_users = [request.user_id]
|
||||
|
||||
return [task]
|
||||
try:
|
||||
task = generate_task(request)
|
||||
# TODO: store task in database
|
||||
except Exception:
|
||||
logger.exception("Failed to generate task.")
|
||||
raise HTTPException(
|
||||
status_code=HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
return task
|
||||
|
||||
|
||||
@router.post("/{task_id}/ack")
|
||||
@@ -52,17 +74,20 @@ def acknowledge_task(
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_key: APIKey = Depends(deps.get_api_key),
|
||||
task_id: UUID,
|
||||
response: protocol_schema.PostCreatedTaskResponse,
|
||||
response: protocol_schema.AnyTaskResponse,
|
||||
) -> Any:
|
||||
"""
|
||||
The frontend acknowledges a task.
|
||||
"""
|
||||
deps.api_auth(api_key, db, create=True)
|
||||
|
||||
match (response.type):
|
||||
case "post_created":
|
||||
match (type(response)):
|
||||
case protocol_schema.PostCreatedTaskResponse:
|
||||
logger.info(f"Frontend acknowledged {task_id=} and created {response.post_id=}.")
|
||||
# here we would store the post id in the database for the task
|
||||
case protocol_schema.RatingCreatedTaskResponse:
|
||||
logger.info(f"Frontend acknowledged {task_id=} for {response.post_id=}.")
|
||||
# here we would store the rating id in the database for the task
|
||||
case _:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_400_BAD_REQUEST,
|
||||
@@ -77,30 +102,34 @@ def post_interaction(
|
||||
*,
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_key: APIKey = Depends(deps.get_api_key),
|
||||
interaction: protocol_schema.TextReplyToPost,
|
||||
interaction: protocol_schema.AnyInteraction,
|
||||
) -> Any:
|
||||
"""
|
||||
The frontend reports an interaction.
|
||||
"""
|
||||
deps.api_auth(api_key, db, create=True)
|
||||
|
||||
response = []
|
||||
match (interaction.type):
|
||||
case "text_reply_to_post":
|
||||
match (type(interaction)):
|
||||
case protocol_schema.TextReplyToPost:
|
||||
logger.info(
|
||||
f"Frontend reports text reply to {interaction.post_id=} with {interaction.text=} by {interaction.user_id=}."
|
||||
f"Frontend reports text reply to {interaction.post_id=} with {interaction.text=} by {interaction.user=}."
|
||||
)
|
||||
# here we would store the text reply in the database
|
||||
response.append(
|
||||
protocol_schema.TaskDone(
|
||||
reply_to_post_id=interaction.user_post_id,
|
||||
addressed_users=[interaction.user_id],
|
||||
)
|
||||
return protocol_schema.TaskDone(
|
||||
reply_to_post_id=interaction.user_post_id,
|
||||
addressed_users=[interaction.user],
|
||||
)
|
||||
case protocol_schema.PostRating:
|
||||
logger.info(
|
||||
f"Frontend reports rating of {interaction.post_id=} with {interaction.rating=} by {interaction.user=}."
|
||||
)
|
||||
# here we would store the rating in the database
|
||||
return protocol_schema.TaskDone(
|
||||
reply_to_post_id=interaction.post_id,
|
||||
addressed_users=[interaction.user],
|
||||
)
|
||||
case _:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid response type.",
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
@@ -14,6 +14,8 @@ class Settings(BaseSettings):
|
||||
POSTGRES_DB: str = "postgres"
|
||||
DATABASE_URI: Optional[PostgresDsn] = None
|
||||
|
||||
ALLOW_ANY_API_KEY: bool = False
|
||||
|
||||
@validator("DATABASE_URI", pre=True)
|
||||
def assemble_db_connection(cls, v: Optional[str], values: Dict[str, Any]) -> Any:
|
||||
if isinstance(v, str):
|
||||
|
||||
@@ -1,20 +1,28 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from typing import Literal, Optional
|
||||
import enum
|
||||
from typing import Literal, Optional, Union
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import pydantic
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class TaskRequestType(str, enum.Enum):
|
||||
generic = "generic"
|
||||
summarize_story = "summarize_story"
|
||||
rate_summary = "rate_summary"
|
||||
|
||||
|
||||
class User(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
|
||||
|
||||
class TaskRequest(BaseModel):
|
||||
"""The frontend asks the backend for a task."""
|
||||
|
||||
type: str
|
||||
user_id: Optional[str] = None
|
||||
|
||||
|
||||
class GenericTaskRequest(TaskRequest):
|
||||
type: Literal["generic"] = "generic"
|
||||
type: TaskRequestType = TaskRequestType.generic
|
||||
user: Optional[User] = None
|
||||
|
||||
|
||||
class Task(BaseModel):
|
||||
@@ -22,14 +30,14 @@ class Task(BaseModel):
|
||||
|
||||
id: UUID = pydantic.Field(default_factory=uuid4)
|
||||
type: str
|
||||
addressed_users: Optional[list[str]] = None
|
||||
addressed_users: Optional[list[User]] = None
|
||||
|
||||
|
||||
class TaskResponse(BaseModel):
|
||||
"""A task response is a message from the frontend to acknowledge the given task."""
|
||||
|
||||
type: str
|
||||
status: Literal["success", "failure"]
|
||||
status: Literal["success", "failure"] = "success"
|
||||
|
||||
|
||||
class PostCreatedTaskResponse(TaskResponse):
|
||||
@@ -37,21 +45,51 @@ class PostCreatedTaskResponse(TaskResponse):
|
||||
post_id: str
|
||||
|
||||
|
||||
class RatingCreatedTaskResponse(TaskResponse):
|
||||
type: Literal["rating_created"] = "rating_created"
|
||||
post_id: str
|
||||
|
||||
|
||||
AnyTaskResponse = Union[
|
||||
PostCreatedTaskResponse,
|
||||
RatingCreatedTaskResponse,
|
||||
]
|
||||
|
||||
|
||||
class SummarizeStoryTask(Task):
|
||||
type: Literal["summarize_story"] = "summarize_story"
|
||||
story: str
|
||||
|
||||
|
||||
class RatingScale(BaseModel):
|
||||
min: int
|
||||
max: int
|
||||
|
||||
|
||||
class RateSummaryTask(Task):
|
||||
type: Literal["rate_summary"] = "rate_summary"
|
||||
full_text: str
|
||||
summary: str
|
||||
scale: RatingScale = RatingScale(min=1, max=5)
|
||||
|
||||
|
||||
class TaskDone(Task):
|
||||
type: Literal["task_done"] = "task_done"
|
||||
reply_to_post_id: str
|
||||
|
||||
|
||||
AnyTask = Union[
|
||||
SummarizeStoryTask,
|
||||
RateSummaryTask,
|
||||
TaskDone,
|
||||
]
|
||||
|
||||
|
||||
class Interaction(BaseModel):
|
||||
"""An interaction is a message from the frontend to the backend."""
|
||||
"""An interaction is a user-generated action in the frontend."""
|
||||
|
||||
type: str
|
||||
user_id: str
|
||||
user: User
|
||||
|
||||
|
||||
class TextReplyToPost(Interaction):
|
||||
@@ -61,3 +99,17 @@ class TextReplyToPost(Interaction):
|
||||
post_id: str
|
||||
user_post_id: str
|
||||
text: str
|
||||
|
||||
|
||||
class PostRating(Interaction):
|
||||
"""A user has replied to a post with text."""
|
||||
|
||||
type: Literal["post_rating"] = "post_rating"
|
||||
post_id: str
|
||||
rating: int
|
||||
|
||||
|
||||
AnyInteraction = Union[
|
||||
TextReplyToPost,
|
||||
PostRating,
|
||||
]
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
export ALLOW_ANY_API_KEY=True
|
||||
|
||||
uvicorn app.main:app --reload
|
||||
|
||||
@@ -0,0 +1,74 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Simple REPL frontend."""
|
||||
|
||||
import requests
|
||||
import typer
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
@app.command()
|
||||
def main(backend_url: str, api_key: str):
|
||||
"""Simple REPL frontend."""
|
||||
|
||||
def _post(path: str, json: dict) -> dict:
|
||||
response = requests.post(f"{backend_url}{path}", json=json, headers={"X-API-Key": api_key})
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
typer.echo("Requesting work...")
|
||||
tasks = [_post("/api/v1/tasks/", {"type": "generic"})]
|
||||
while tasks:
|
||||
task = tasks.pop(0)
|
||||
match (task["type"]):
|
||||
case "summarize_story":
|
||||
typer.echo("Summarize the following story:")
|
||||
typer.echo(task["story"])
|
||||
|
||||
# acknowledge task
|
||||
_post(f"/api/v1/tasks/{task['id']}/ack", {"type": "post_created", "post_id": "1234"})
|
||||
|
||||
summary = typer.prompt("Enter your summary")
|
||||
|
||||
# send interaction
|
||||
new_task = _post(
|
||||
"/api/v1/tasks/interaction",
|
||||
{
|
||||
"type": "text_reply_to_post",
|
||||
"post_id": "1234",
|
||||
"user_post_id": "5678",
|
||||
"text": summary,
|
||||
"user": {"id": "1234", "name": "John Doe"},
|
||||
},
|
||||
)
|
||||
tasks.append(new_task)
|
||||
case "rate_summary":
|
||||
typer.echo("Rate the following summary:")
|
||||
typer.echo(task["summary"])
|
||||
typer.echo("Full text:")
|
||||
typer.echo(task["full_text"])
|
||||
typer.echo(f"Rating scale: {task['scale']['min']} - {task['scale']['max']}")
|
||||
|
||||
# acknowledge task
|
||||
_post(f"/api/v1/tasks/{task['id']}/ack", {"type": "rating_created", "post_id": "1234"})
|
||||
|
||||
rating = typer.prompt("Enter your rating", type=int)
|
||||
# send interaction
|
||||
new_task = _post(
|
||||
"/api/v1/tasks/interaction",
|
||||
{
|
||||
"type": "post_rating",
|
||||
"post_id": "1234",
|
||||
"rating": rating,
|
||||
"user": {"id": "1234", "name": "John Doe"},
|
||||
},
|
||||
)
|
||||
tasks.append(new_task)
|
||||
case "task_done":
|
||||
typer.echo("Task done!")
|
||||
case _:
|
||||
typer.echo(f"Unknown task type {task['type']}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app()
|
||||
@@ -0,0 +1,2 @@
|
||||
requests==2.18.1
|
||||
typer==0.7.0
|
||||
Reference in New Issue
Block a user