renamed to open assistant

This commit is contained in:
Yannic Kilcher
2022-12-17 23:55:55 +01:00
parent 1acdc66973
commit a6c957ccfd
8 changed files with 738 additions and 3 deletions
+1 -1
View File
@@ -3,7 +3,7 @@ from logging.config import fileConfig
import sqlmodel
from alembic import context
from ocgpt import models # noqa: F401
from oasst import models # noqa: F401
from sqlalchemy import engine_from_config, pool
# this is the Alembic Config object, which provides
+2 -2
View File
@@ -5,8 +5,8 @@ import alembic.command
import alembic.config
import fastapi
from loguru import logger
from ocgpt.api.v1.api import api_router
from ocgpt.config import settings
from oasst.api.v1.api import api_router
from oasst.config import settings
from starlette.middleware.cors import CORSMiddleware
app = fastapi.FastAPI(title=settings.PROJECT_NAME, openapi_url=f"{settings.API_V1_STR}/openapi.json")
+57
View File
@@ -0,0 +1,57 @@
# -*- coding: utf-8 -*-
from secrets import token_hex
from typing import Generator
from uuid import UUID
from fastapi import HTTPException, Security
from fastapi.security.api_key import APIKey, APIKeyHeader, APIKeyQuery
from loguru import logger
from oasst.config import settings
from oasst.database import engine
from oasst.models import ApiClient
from sqlmodel import Session
from starlette.status import HTTP_403_FORBIDDEN
def get_db() -> Generator:
with Session(engine) as db:
yield db
api_key_query = APIKeyQuery(name="api_key", auto_error=False)
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
async def get_api_key(
api_key_query: str = Security(api_key_query),
api_key_header: str = Security(api_key_header),
):
if api_key_query:
return api_key_query
else:
return api_key_header
def api_auth(
api_key: APIKey,
db: Session,
) -> ApiClient:
if api_key is not None:
if settings.ALLOW_ANY_API_KEY:
# make sure that a dummy api key exits in db (foreign key references)
ANY_API_KEY_ID = UUID("00000000-1111-2222-3333-444444444444")
api_client: ApiClient = db.query(ApiClient).filter(ApiClient.id == ANY_API_KEY_ID).first()
if api_client is None:
token = token_hex(32)
logger.info(f"ANY_API_KEY missing, inserting api_key: {token}")
api_client = ApiClient(id=ANY_API_KEY_ID, api_key=token, description="ANY_API_KEY, random token")
db.add(api_client)
db.commit()
return api_client
api_client = db.query(ApiClient).filter(ApiClient.api_key == api_key).first()
if api_client is not None and api_client.enabled:
return api_client
raise HTTPException(status_code=HTTP_403_FORBIDDEN, detail="Could not validate credentials")
+6
View File
@@ -0,0 +1,6 @@
# -*- coding: utf-8 -*-
from fastapi import APIRouter
from oasst.api.v1 import tasks
api_router = APIRouter()
api_router.include_router(tasks.router, prefix="/tasks", tags=["tasks"])
+259
View File
@@ -0,0 +1,259 @@
# -*- coding: utf-8 -*-
import random
from typing import Any
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException
from fastapi.security.api_key import APIKey
from loguru import logger
from oasst.api import deps
from oasst.models.db_payload import TaskPayload
from oasst.prompt_repository import PromptRepository
from oasst.schemas import protocol as protocol_schema
from sqlmodel import Session
from starlette.status import HTTP_400_BAD_REQUEST
router = APIRouter()
def generate_task(request: protocol_schema.TaskRequest) -> protocol_schema.Task:
match (request.type):
case protocol_schema.TaskRequestType.random:
logger.info("Frontend requested a random task.")
while request.type == protocol_schema.TaskRequestType.random:
request.type = random.choice(list(protocol_schema.TaskRequestType)).value
return generate_task(request)
case protocol_schema.TaskRequestType.summarize_story:
logger.info("Generating a SummarizeStoryTask.")
task = protocol_schema.SummarizeStoryTask(
story="This is a story. A very long story. So long, it needs to be summarized.",
)
case protocol_schema.TaskRequestType.rate_summary:
logger.info("Generating a RateSummaryTask.")
task = protocol_schema.RateSummaryTask(
full_text="This is a story. A very long story. So long, it needs to be summarized.",
summary="This is a summary.",
scale=protocol_schema.RatingScale(min=1, max=5),
)
case protocol_schema.TaskRequestType.initial_prompt:
logger.info("Generating an InitialPromptTask.")
task = protocol_schema.InitialPromptTask(
hint="Ask the assistant about a current event." # this is optional
)
case protocol_schema.TaskRequestType.user_reply:
logger.info("Generating a UserReplyTask.")
task = protocol_schema.UserReplyTask(
conversation=protocol_schema.Conversation(
messages=[
protocol_schema.ConversationMessage(
text="Hey, assistant, what's going on in the world?",
is_assistant=False,
),
protocol_schema.ConversationMessage(
text="I'm not sure I understood correctly, could you rephrase that?",
is_assistant=True,
),
],
)
)
case protocol_schema.TaskRequestType.assistant_reply:
logger.info("Generating a AssistantReplyTask.")
task = protocol_schema.AssistantReplyTask(
conversation=protocol_schema.Conversation(
messages=[
protocol_schema.ConversationMessage(
text="Hey, assistant, write me an English essay about water.",
is_assistant=False,
),
],
)
)
case protocol_schema.TaskRequestType.rank_initial_prompts:
logger.info("Generating a RankInitialPromptsTask.")
task = protocol_schema.RankInitialPromptsTask(
prompts=[
"Please write a story about a time you were happy.",
"Please write a story about a time you were sad.",
]
)
case protocol_schema.TaskRequestType.rank_user_replies:
logger.info("Generating a RankUserRepliesTask.")
task = protocol_schema.RankUserRepliesTask(
conversation=protocol_schema.Conversation(
messages=[
protocol_schema.ConversationMessage(
text="Hey, assistant, what's going on in the world?",
is_assistant=False,
),
protocol_schema.ConversationMessage(
text="I'm not sure I understood correctly, could you rephrase that?",
is_assistant=True,
),
],
),
replies=[
"Oh come oooooon!",
"What are the news?",
],
)
case protocol_schema.TaskRequestType.rank_assistant_replies:
logger.info("Generating a RankAssistantRepliesTask.")
task = protocol_schema.RankAssistantRepliesTask(
conversation=protocol_schema.Conversation(
messages=[
protocol_schema.ConversationMessage(
text="Hey, assistant, what's going on in the world?",
is_assistant=False,
),
],
),
replies=[
"I'm not sure I understood correctly, could you rephrase that?",
"The world is fine. All good.",
"Crap is hitting the fan. Start farming.",
],
)
case _:
raise HTTPException(
status_code=HTTP_400_BAD_REQUEST,
detail="Invalid request type.",
)
logger.info(f"Generated {task=}.")
return task
@router.post("/", response_model=protocol_schema.AnyTask) # work with Union once more types are added
def request_task(
*,
db: Session = Depends(deps.get_db),
api_key: APIKey = Depends(deps.get_api_key),
request: protocol_schema.TaskRequest,
) -> Any:
"""
Create new task.
"""
api_client = deps.api_auth(api_key, db)
try:
task = generate_task(request)
pr = PromptRepository(db, api_client, request.user)
pr.store_task(task)
except Exception:
logger.exception("Failed to generate task.")
raise HTTPException(
status_code=HTTP_400_BAD_REQUEST,
)
return task
@router.post("/{task_id}/ack")
def acknowledge_task(
*,
db: Session = Depends(deps.get_db),
api_key: APIKey = Depends(deps.get_api_key),
task_id: UUID,
ack_request: protocol_schema.TaskAck,
) -> Any:
"""
The frontend acknowledges a task.
"""
api_client = deps.api_auth(api_key, db)
try:
pr = PromptRepository(db, api_client, user=None)
# here we store the post id in the database for the task
pr.bind_frontend_post_id(task_id=task_id, post_id=ack_request.post_id)
logger.info(f"Frontend acknowledges task {task_id=}, {ack_request=}.")
except Exception:
logger.exception("Failed to acknowledge task.")
raise HTTPException(
status_code=HTTP_400_BAD_REQUEST,
)
return {}
@router.post("/{task_id}/nack")
def acknowledge_task_failure(
*,
db: Session = Depends(deps.get_db),
api_key: APIKey = Depends(deps.get_api_key),
task_id: UUID,
nack_request: protocol_schema.TaskNAck,
) -> Any:
"""
The frontend reports failure to implement a task.
"""
deps.api_auth(api_key, db)
logger.info(f"Frontend reports failure to implement task {task_id=}, {nack_request=}.")
# here we would store the post id in the database for the task
return {}
@router.post("/interaction")
def post_interaction(
*,
db: Session = Depends(deps.get_db),
api_key: APIKey = Depends(deps.get_api_key),
interaction: protocol_schema.AnyInteraction,
) -> Any:
"""
The frontend reports an interaction.
"""
api_client = deps.api_auth(api_key, db)
try:
pr = PromptRepository(db, api_client, user=interaction.user)
match type(interaction):
case protocol_schema.TextReplyToPost:
logger.info(
f"Frontend reports text reply to {interaction.post_id=} with {interaction.text=} by {interaction.user=}."
)
work_package = pr.fetch_workpackage_by_postid(interaction.post_id)
work_payload: TaskPayload = work_package.payload.payload
logger.info(f"found task work package in db: {work_payload}")
# here we store the text reply in the database
# ToDo: role user or agent?
pr.store_text_reply(interaction, role="unknown")
return protocol_schema.TaskDone()
case protocol_schema.PostRating:
logger.info(
f"Frontend reports rating of {interaction.post_id=} with {interaction.rating=} by {interaction.user=}."
)
# here we store the rating in the database
pr.store_rating(interaction)
return protocol_schema.TaskDone()
case protocol_schema.PostRanking:
logger.info(
f"Frontend reports ranking of {interaction.post_id=} with {interaction.ranking=} by {interaction.user=}."
)
# TODO: check if the ranking is valid
pr.store_ranking(interaction)
# here we would store the ranking in the database
return protocol_schema.TaskDone()
case _:
raise HTTPException(
status_code=HTTP_400_BAD_REQUEST,
detail="Invalid response type.",
)
except Exception:
logger.exception("Interaction request failed.")
raise HTTPException(
status_code=HTTP_400_BAD_REQUEST,
)
+8
View File
@@ -0,0 +1,8 @@
# -*- coding: utf-8 -*-
from oasst.config import settings
from sqlmodel import create_engine
if settings.DATABASE_URI is None:
raise ValueError("DATABASE_URI is not set")
engine = create_engine(settings.DATABASE_URI)
+94
View File
@@ -0,0 +1,94 @@
# -*- coding: utf-8 -*-
from typing import Literal
from oasst.models.payload_column_type import payload_type
from oasst.schemas import protocol as protocol_schema
from pydantic import BaseModel
@payload_type
class TaskPayload(BaseModel):
type: str
@payload_type
class SummarizationStoryPayload(TaskPayload):
type: Literal["summarize_story"] = "summarize_story"
story: str
@payload_type
class RateSummaryPayload(TaskPayload):
type: Literal["rate_summary"] = "rate_summary"
full_text: str
summary: str
scale: protocol_schema.RatingScale
@payload_type
class InitialPromptPayload(TaskPayload):
type: Literal["initial_prompt"] = "initial_prompt"
hint: str
@payload_type
class UserReplyPayload(TaskPayload):
type: Literal["user_reply"] = "user_reply"
conversation: protocol_schema.Conversation
hint: str | None
@payload_type
class AssistantReplyPayload(TaskPayload):
type: Literal["assistant_reply"] = "assistant_reply"
conversation: protocol_schema.Conversation
@payload_type
class PostPayload(BaseModel):
text: str
@payload_type
class ReactionPayload(BaseModel):
type: str
@payload_type
class RatingReactionPayload(ReactionPayload):
type: Literal["post_rating"] = "post_rating"
rating: str
@payload_type
class RankingReactionPayload(ReactionPayload):
type: Literal["post_ranking"] = "post_ranking"
ranking: list[int]
@payload_type
class RankConversationRepliesPayload(TaskPayload):
conversation: protocol_schema.Conversation # the conversation so far
replies: list[str]
@payload_type
class RankInitialPromptsPayload(TaskPayload):
"""A task to rank a set of initial prompts."""
type: Literal["rank_initial_prompts"] = "rank_initial_prompts"
prompts: list[str]
@payload_type
class RankUserRepliesPayload(RankConversationRepliesPayload):
"""A task to rank a set of user replies to a conversation."""
type: Literal["rank_user_replies"] = "rank_user_replies"
@payload_type
class RankAssistantRepliesPayload(RankConversationRepliesPayload):
"""A task to rank a set of assistant replies to a conversation."""
type: Literal["rank_assistant_replies"] = "rank_assistant_replies"
+311
View File
@@ -0,0 +1,311 @@
# -*- coding: utf-8 -*-
from datetime import datetime
from typing import Optional
from uuid import UUID, uuid4
import oasst.models.db_payload as db_payload
from loguru import logger
from oasst.models import ApiClient, Person, Post, PostReaction, WorkPackage
from oasst.models.payload_column_type import PayloadContainer
from oasst.schemas import protocol as protocol_schema
from sqlmodel import Session
class PromptRepository:
def __init__(self, db: Session, api_client: ApiClient, user: Optional[protocol_schema.User]):
self.db = db
self.api_client = api_client
self.person = self.lookup_person(user)
self.person_id = self.person.id if self.person else None
def lookup_person(self, user: protocol_schema.User) -> Person:
if not user:
return None
person: Person = (
self.db.query(Person)
.filter(
Person.api_client_id == self.api_client.id,
Person.username == user.id,
Person.auth_method == user.auth_method,
)
.first()
)
if person is None:
# user is unknown, create new record
person = Person(username=user.id, display_name=user.display_name, api_client_id=self.api_client.id)
self.db.add(person)
self.db.commit()
self.db.refresh(person)
elif user.display_name and user.display_name != person.display_name:
# we found the user but the display name changed
person.display_name = user.display_name
self.db.add(person)
self.db.commit()
return person
def validate_post_id(self, post_id: str) -> None:
if not isinstance(post_id, str):
raise TypeError(f"post_id must be string, not {type(post_id)}")
if not post_id:
raise ValueError("post_id must not be empty")
def bind_frontend_post_id(self, task_id: UUID, post_id: str):
self.validate_post_id(post_id)
# find work package
work_pack: WorkPackage = (
self.db.query(WorkPackage)
.filter(WorkPackage.id == task_id, WorkPackage.api_client_id == self.api_client.id)
.first()
)
if work_pack is None:
raise KeyError(f"WorkPackage for task {task_id} not found")
if work_pack.expiry_date is not None and datetime.utcnow() > work_pack.expiry_date:
raise RuntimeError("WorkPackage already expired.")
# ToDo: check race-condition, transaction
# check if task thread exits
thread_root = (
self.db.query(Post)
.filter(
Post.workpackage_id == work_pack.id,
Post.frontend_post_id == post_id,
Post.parent_id is None,
Post.api_client_id == self.api_client.id,
)
.one_or_none()
)
if thread_root is None:
thread_id = uuid4()
thread_root = self.insert_post(
post_id=thread_id,
thread_id=thread_id,
frontend_post_id=post_id,
parent_id=None,
role="system",
workpackage_id=work_pack.id,
payload=None,
payload_type="bind",
)
return thread_root
def fetch_post_by_frontend_post_id(self, frontend_post_id: str, fail_if_missing: bool = True) -> Post:
self.validate_post_id(frontend_post_id)
post: Post = (
self.db.query(Post)
.filter(Post.api_client_id == self.api_client.id, Post.frontend_post_id == frontend_post_id)
.one_or_none()
)
if fail_if_missing and post is None:
raise KeyError(f"Post with post_id {frontend_post_id} not found.")
return post
def fetch_workpackage_by_postid(self, post_id: str) -> WorkPackage:
self.validate_post_id(post_id)
post = self.fetch_post_by_frontend_post_id(post_id, fail_if_missing=True)
work_pack = self.db.query(WorkPackage).filter(WorkPackage.id == post.workpackage_id).one()
return work_pack
def store_text_reply(self, reply: protocol_schema.TextReplyToPost, role: str) -> Post:
self.validate_post_id(reply.post_id)
self.validate_post_id(reply.user_post_id)
# find post with post-id
parent_post: Post = (
self.db.query(Post)
.filter(
Post.api_client_id == self.api_client.id,
Post.frontend_post_id == reply.post_id,
# Post.person_id == self.person_id
)
.one_or_none()
)
if parent_post is None:
raise KeyError(f"Post for post_id {reply.post_id} not found.")
# create reply post
user_post_id = uuid4()
user_post = self.insert_post(
post_id=user_post_id,
frontend_post_id=reply.user_post_id,
parent_id=parent_post.id,
thread_id=parent_post.thread_id,
workpackage_id=parent_post.workpackage_id,
role=role,
payload=db_payload.PostPayload(text=reply.text),
)
return user_post
def store_rating(self, rating: protocol_schema.PostRating) -> PostReaction:
post = self.fetch_post_by_frontend_post_id(rating.post_id, fail_if_missing=True)
work_package = self.fetch_workpackage_by_postid(rating.post_id)
work_payload: db_payload.RateSummaryPayload = work_package.payload.payload
if type(work_payload) != db_payload.RateSummaryPayload:
raise ValueError(
f"work_package payload type mismatch: {type(work_payload)=} != {db_payload.RateSummaryPayload}"
)
if rating.rating < work_payload.scale.min or rating.rating > work_payload.scale.max:
raise ValueError(f"Invalid rating value: {rating.rating=} not in {work_payload.scale=}")
# store reaction to post
reaction_payload = db_payload.RatingReactionPayload(rating=rating.rating)
reaction = self.insert_reaction(post.id, reaction_payload)
logger.info(f"Ranking {rating.rating} stored for work_package {work_package.id}.")
return reaction
def store_ranking(self, ranking: protocol_schema.PostRanking) -> PostReaction:
post = self.fetch_post_by_frontend_post_id(ranking.post_id, fail_if_missing=True)
# fetch work_package
work_package = self.fetch_workpackage_by_postid(ranking.post_id)
work_payload: db_payload.RankConversationRepliesPayload | db_payload.RankInitialPromptsPayload = (
work_package.payload.payload
)
match type(work_payload):
case db_payload.RankUserRepliesPayload | db_payload.RankAssistantRepliesPayload:
# validate ranking
num_replies = len(work_payload.replies)
if sorted(ranking.ranking) != list(range(num_replies)):
raise ValueError(
f"Invalid ranking submitted. Each reply index must appear exactly once ({num_replies=})."
)
# store reaction to post
reaction_payload = db_payload.RankingReactionPayload(ranking=ranking.ranking)
reaction = self.insert_reaction(post.id, reaction_payload)
logger.info(f"Ranking {ranking.ranking} stored for work_package {work_package.id}.")
return reaction
case db_payload.RankInitialPromptsPayload:
# validate ranking
if sorted(ranking.ranking) != list(range(num_prompts := len(work_payload.prompts))):
raise ValueError(
f"Invalid ranking submitted. Each reply index must appear exactly once ({num_prompts=})."
)
# store reaction to post
reaction_payload = db_payload.RankingReactionPayload(ranking=ranking.ranking)
reaction = self.insert_reaction(post.id, reaction_payload)
logger.info(f"Ranking {ranking.ranking} stored for work_package {work_package.id}.")
return reaction
case _:
raise ValueError(
f"work_package payload type mismatch: {type(work_payload)=} != {db_payload.RankConversationRepliesPayload}"
)
def store_task(self, task: protocol_schema.Task) -> WorkPackage:
payload: db_payload.TaskPayload
match type(task):
case protocol_schema.SummarizeStoryTask:
payload = db_payload.SummarizationStoryPayload(story=task.story)
case protocol_schema.RateSummaryTask:
payload = db_payload.RateSummaryPayload(
full_text=task.full_text, summary=task.summary, scale=task.scale
)
case protocol_schema.InitialPromptTask:
payload = db_payload.InitialPromptPayload(hint=task.hint)
case protocol_schema.UserReplyTask:
payload = db_payload.UserReplyPayload(conversation=task.conversation, hint=task.hint)
case protocol_schema.AssistantReplyTask:
payload = db_payload.AssistantReplyPayload(type=task.type, conversation=task.conversation)
case protocol_schema.RankInitialPromptsTask:
payload = db_payload.RankInitialPromptsPayload(tpye=task.type, prompts=task.prompts)
case protocol_schema.RankUserRepliesTask:
payload = db_payload.RankUserRepliesPayload(
tpye=task.type, conversation=task.conversation, replies=task.replies
)
case protocol_schema.RankAssistantRepliesTask:
payload = db_payload.RankAssistantRepliesPayload(
tpye=task.type, conversation=task.conversation, replies=task.replies
)
case _:
raise ValueError(f"Invalid task type: {type(task)=}")
wp = self.insert_work_package(payload=payload, id=task.id)
assert wp.id == task.id
return wp
def insert_work_package(self, payload: db_payload.TaskPayload, id: UUID = None) -> WorkPackage:
c = PayloadContainer(payload=payload)
wp = WorkPackage(
id=id,
person_id=self.person_id,
payload_type=type(payload).__name__,
payload=c,
api_client_id=self.api_client.id,
)
self.db.add(wp)
self.db.commit()
self.db.refresh(wp)
return wp
def insert_post(
self,
*,
post_id: UUID,
frontend_post_id: str,
parent_id: UUID,
thread_id: UUID,
workpackage_id: UUID,
role: str,
payload: db_payload.PostPayload,
payload_type: str = None,
) -> Post:
if payload_type is None:
if payload is None:
payload_type = "null"
else:
payload_type = type(payload).__name__
post = Post(
id=post_id,
parent_id=parent_id,
thread_id=thread_id,
workpackage_id=workpackage_id,
person_id=self.person_id,
role=role,
frontend_post_id=frontend_post_id,
api_client_id=self.api_client.id,
payload_type=payload_type,
payload=PayloadContainer(payload=payload),
)
self.db.add(post)
self.db.commit()
self.db.refresh(post)
return post
def insert_reaction(self, post_id: UUID, payload: db_payload.ReactionPayload) -> PostReaction:
if self.person_id is None:
raise ValueError("User required")
container = PayloadContainer(payload=payload)
reaction = PostReaction(
post_id=post_id,
person_id=self.person_id,
payload=container,
api_client_id=self.api_client.id,
payload_type=type(payload).__name__,
)
self.db.add(reaction)
self.db.commit()
self.db.refresh(reaction)
return reaction