Merge branch '55_backend_HTTP_error_feedback' into main

This commit is contained in:
Andreas Köpf
2022-12-29 00:06:22 +01:00
6 changed files with 154 additions and 69 deletions
+20
View File
@@ -1,4 +1,5 @@
# -*- coding: utf-8 -*-
from http import HTTPStatus
from pathlib import Path
import alembic.command
@@ -7,10 +8,29 @@ import fastapi
from loguru import logger
from oasst_backend.api.v1.api import api_router
from oasst_backend.config import settings
from oasst_backend.exceptions import OasstError, OasstErrorCode
from starlette.middleware.cors import CORSMiddleware
app = fastapi.FastAPI(title=settings.PROJECT_NAME, openapi_url=f"{settings.API_V1_STR}/openapi.json")
@app.exception_handler(OasstError)
async def oasst_exception_handler(request: fastapi.Request, ex: OasstError):
logger.error(f"{request.method} {request.url} failed: {repr(ex)}")
return fastapi.responses.JSONResponse(
status_code=int(ex.http_status_code), content={"message": ex.message, "error_code": ex.error_code}
)
@app.exception_handler(Exception)
async def unhandled_exception_handler(request: fastapi.Request, ex: Exception):
logger.exception(f"{request.method} {request.url} failed [UNHANDLED]: {repr(ex)}")
status = HTTPStatus.INTERNAL_SERVER_ERROR
return fastapi.responses.JSONResponse(
status_code=status.value, content={"message": status.name, "error_code": OasstErrorCode.GENERIC_ERROR}
)
# Set all CORS enabled origins
if settings.BACKEND_CORS_ORIGINS:
app.add_middleware(
+23 -18
View File
@@ -1,16 +1,17 @@
# -*- coding: utf-8 -*-
from http import HTTPStatus
from secrets import token_hex
from typing import Generator
from uuid import UUID
from fastapi import HTTPException, Security
from fastapi import Security
from fastapi.security.api_key import APIKey, APIKeyHeader, APIKeyQuery
from loguru import logger
from oasst_backend.config import settings
from oasst_backend.database import engine
from oasst_backend.exceptions import OasstError, OasstErrorCode
from oasst_backend.models import ApiClient
from sqlmodel import Session
from starlette.status import HTTP_403_FORBIDDEN
def get_db() -> Generator:
@@ -36,22 +37,26 @@ def api_auth(
api_key: APIKey,
db: Session,
) -> ApiClient:
if api_key or settings.DEBUG_SKIP_API_KEY_CHECK:
if api_key is None and not settings.DEBUG_SKIP_API_KEY_CHECK:
raise HTTPException(status_code=HTTP_403_FORBIDDEN, detail="Could not validate credentials")
if settings.DEBUG_SKIP_API_KEY_CHECK or settings.DEBUG_ALLOW_ANY_API_KEY:
# make sure that a dummy api key exits in db (foreign key references)
ANY_API_KEY_ID = UUID("00000000-1111-2222-3333-444444444444")
api_client: ApiClient = db.query(ApiClient).filter(ApiClient.id == ANY_API_KEY_ID).first()
if api_client is None:
token = token_hex(32)
logger.info(f"ANY_API_KEY missing, inserting api_key: {token}")
api_client = ApiClient(id=ANY_API_KEY_ID, api_key=token, description="ANY_API_KEY, random token")
db.add(api_client)
db.commit()
return api_client
if settings.DEBUG_SKIP_API_KEY_CHECK or settings.DEBUG_ALLOW_ANY_API_KEY:
# make sure that a dummy api key exits in db (foreign key references)
ANY_API_KEY_ID = UUID("00000000-1111-2222-3333-444444444444")
api_client: ApiClient = db.query(ApiClient).filter(ApiClient.id == ANY_API_KEY_ID).first()
if api_client is None:
token = token_hex(32)
logger.info(f"ANY_API_KEY missing, inserting api_key: {token}")
api_client = ApiClient(id=ANY_API_KEY_ID, api_key=token, description="ANY_API_KEY, random token")
db.add(api_client)
db.commit()
return api_client
api_client = db.query(ApiClient).filter(ApiClient.api_key == api_key).first()
if api_client is not None and api_client.enabled:
return api_client
api_client = db.query(ApiClient).filter(ApiClient.api_key == api_key).first()
if api_client is not None and api_client.enabled:
return api_client
raise OasstError(
"Could not validate credentials",
error_code=OasstErrorCode.API_CLIENT_NOT_AUTHORIZED,
http_status_code=HTTPStatus.FORBIDDEN,
)
+16 -25
View File
@@ -3,14 +3,14 @@ import random
from typing import Any, Optional, Tuple
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException
from fastapi import APIRouter, Depends
from fastapi.security.api_key import APIKey
from loguru import logger
from oasst_backend.api import deps
from oasst_backend.exceptions import OasstError, OasstErrorCode
from oasst_backend.prompt_repository import PromptRepository
from oasst_shared.schemas import protocol as protocol_schema
from sqlmodel import Session
from starlette.status import HTTP_400_BAD_REQUEST
router = APIRouter()
@@ -107,10 +107,7 @@ def generate_task(
replies=replies,
)
case _:
raise HTTPException(
status_code=HTTP_400_BAD_REQUEST,
detail="Invalid request type.",
)
raise OasstError("Invalid request type", OasstErrorCode.TASK_INVALID_REQUEST_TYPE)
logger.info(f"Generated {task=}.")
@@ -134,11 +131,11 @@ def request_task(
task, thread_id, parent_post_id = generate_task(request, pr)
pr.store_task(task, thread_id, parent_post_id)
except OasstError:
raise
except Exception:
logger.exception("Failed to generate task.")
raise HTTPException(
status_code=HTTP_400_BAD_REQUEST,
)
logger.exception("Failed to generate task..")
raise OasstError("Failed to generate task.", OasstErrorCode.TASK_GENERATION_FAILED)
return task
@@ -163,11 +160,11 @@ def acknowledge_task(
logger.info(f"Frontend acknowledges task {task_id=}, {ack_request=}.")
pr.bind_frontend_post_id(task_id=task_id, post_id=ack_request.post_id)
except OasstError:
raise
except Exception:
logger.exception("Failed to acknowledge task.")
raise HTTPException(
status_code=HTTP_400_BAD_REQUEST,
)
raise OasstError("Failed to acknowledge task.", OasstErrorCode.TASK_ACK_FAILED)
return {}
@@ -189,10 +186,8 @@ def acknowledge_task_failure(
pr = PromptRepository(db, api_client, user=None)
pr.acknowledge_task_failure(task_id)
except (KeyError, RuntimeError):
logger.exception("Failed to acknowledge task.")
raise HTTPException(
status_code=HTTP_400_BAD_REQUEST,
)
logger.exception("Failed to not acknowledge task.")
raise OasstError("Failed to not acknowledge task.", OasstErrorCode.TASK_NACK_FAILED)
@router.post("/interaction")
@@ -241,13 +236,9 @@ def post_interaction(
# here we would store the ranking in the database
return protocol_schema.TaskDone()
case _:
raise HTTPException(
status_code=HTTP_400_BAD_REQUEST,
detail="Invalid response type.",
)
raise OasstError("Invalid response type.", OasstErrorCode.TASK_INVALID_RESPONSE_TYPE)
except OasstError:
raise
except Exception:
logger.exception("Interaction request failed.")
raise HTTPException(
status_code=HTTP_400_BAD_REQUEST,
)
raise OasstError("Interaction request failed.", OasstErrorCode.TASK_INTERACTION_REQUEST_FAILED)
+2 -1
View File
@@ -1,8 +1,9 @@
# -*- coding: utf-8 -*-
from oasst_backend.config import settings
from oasst_backend.exceptions import OasstError, OasstErrorCode
from sqlmodel import create_engine
if settings.DATABASE_URI is None:
raise ValueError("DATABASE_URI is not set")
raise OasstError("DATABASE_URI is not set", error_code=OasstErrorCode.DATABASE_URI_NOT_SET)
engine = create_engine(settings.DATABASE_URI)
+60
View File
@@ -0,0 +1,60 @@
# -*- coding: utf-8 -*-
from enum import IntEnum
from http import HTTPStatus
class OasstErrorCode(IntEnum):
"""
Error codes of the Open-Assistant backend API.
Ranges:
0-1000: general errors
1000-2000: tasks endpoint
2000-3000: prompt_repository
"""
# 0-1000: general errors
GENERIC_ERROR = 0
DATABASE_URI_NOT_SET = 1
API_CLIENT_NOT_AUTHORIZED = 2
# 1000-2000: tasks endpoint
TASK_INVALID_REQUEST_TYPE = 1000
TASK_ACK_FAILED = 1001
TASK_NACK_FAILED = 1002
TASK_INVALID_RESPONSE_TYPE = 1003
TASK_INTERACTION_REQUEST_FAILED = 1004
TASK_GENERATION_FAILED = 1005
# 2000-3000: prompt_repository
INVALID_POST_ID = 2000
POST_NOT_FOUND = 2001
RATING_OUT_OF_RANGE = 2002
INVALID_RANKING_VALUE = 2003
INVALID_TASK_TYPE = 2004
USER_NOT_SPECIFIED = 2005
NO_THREADS_FOUND = 2006
WORK_PACKAGE_NOT_FOUND = 2100
WORK_PACKAGE_EXPIRED = 2101
WORK_PACKAGE_PAYLOAD_TYPE_MISMATCH = 2102
WORK_PACKAGE_ALREADY_UPDATED = 2103
WORK_PACKAGE_NOT_ACK = 2104
WORK_PACKAGE_ALREADY_DONE = 2105
class OasstError(Exception):
"""Base class for Open-Assistant exceptions."""
message: str
error_code: int
http_status_code: HTTPStatus
def __init__(self, message: str, error_code: OasstErrorCode, http_status_code: HTTPStatus = HTTPStatus.BAD_REQUEST):
super().__init__(message, error_code, http_status_code) # make excetpion picklable (fill args member)
self.message = message
self.error_code = error_code
self.http_status_code = http_status_code
def __repr__(self) -> str:
class_name = self.__class__.__name__
return f'{class_name}(message="{self.message}", error_code={self.error_code}, http_status_code={self.http_status_code})'
+33 -25
View File
@@ -5,6 +5,7 @@ from uuid import UUID, uuid4
import oasst_backend.models.db_payload as db_payload
from loguru import logger
from oasst_backend.exceptions import OasstError, OasstErrorCode
from oasst_backend.journal_writer import JournalWriter
from oasst_backend.models import ApiClient, Person, Post, PostReaction, TextLabels, WorkPackage
from oasst_backend.models.payload_column_type import PayloadContainer
@@ -52,9 +53,9 @@ class PromptRepository:
def validate_post_id(self, post_id: str) -> None:
if not isinstance(post_id, str):
raise TypeError(f"post_id must be string, not {type(post_id)}")
raise OasstError(f"post_id must be string, not {type(post_id)}", OasstErrorCode.INVALID_POST_ID)
if not post_id:
raise ValueError("post_id must not be empty")
raise OasstError("post_id must not be empty", OasstErrorCode.INVALID_POST_ID)
def bind_frontend_post_id(self, task_id: UUID, post_id: str):
self.validate_post_id(post_id)
@@ -66,11 +67,11 @@ class PromptRepository:
.first()
)
if work_pack is None:
raise KeyError(f"WorkPackage for task {task_id} not found")
raise OasstError(f"WorkPackage for task {task_id} not found", OasstErrorCode.WORK_PACKAGE_NOT_FOUND)
if work_pack.expired:
raise RuntimeError("WorkPackage already expired.")
raise OasstError("WorkPackage already expired.", OasstErrorCode.WORK_PACKAGE_EXPIRED)
if work_pack.done or work_pack.ack is not None:
raise RuntimeError("WorkPackage already updated.")
raise OasstError("WorkPackage already updated.", OasstErrorCode.WORK_PACKAGE_ALREADY_UPDATED)
work_pack.frontend_ref_post_id = post_id
work_pack.ack = True
@@ -86,11 +87,11 @@ class PromptRepository:
.first()
)
if work_pack is None:
raise KeyError(f"WorkPackage for task {task_id} not found")
raise OasstError(f"WorkPackage for task {task_id} not found", OasstErrorCode.WORK_PACKAGE_NOT_FOUND)
if work_pack.expired:
raise RuntimeError("WorkPackage already expired.")
raise OasstError("WorkPackage already expired.", OasstErrorCode.WORK_PACKAGE_EXPIRED)
if work_pack.done or work_pack.ack is not None:
raise RuntimeError("WorkPackage already updated.")
raise OasstError("WorkPackage already updated.", OasstErrorCode.WORK_PACKAGE_ALREADY_UPDATED)
work_pack.ack = False
# ToDo: check race-condition, transaction
@@ -105,7 +106,7 @@ class PromptRepository:
.one_or_none()
)
if fail_if_missing and post is None:
raise KeyError(f"Post with post_id {frontend_post_id} not found.")
raise OasstError(f"Post with post_id {frontend_post_id} not found.", OasstErrorCode.POST_NOT_FOUND)
return post
def fetch_workpackage_by_postid(self, post_id: str) -> WorkPackage:
@@ -124,13 +125,13 @@ class PromptRepository:
wp = self.fetch_workpackage_by_postid(post_id)
if wp is None:
raise KeyError(f"WorkPackage for {post_id=} not found")
raise OasstError(f"WorkPackage for {post_id=} not found", OasstErrorCode.WORK_PACKAGE_NOT_FOUND)
if wp.expired:
raise RuntimeError("WorkPackage already expired.")
raise OasstError("WorkPackage already expired.", OasstErrorCode.WORK_PACKAGE_EXPIRED)
if not wp.ack:
raise RuntimeError("WorkPackage is not acknowledged.")
raise OasstError("WorkPackage is not acknowledged.", OasstErrorCode.WORK_PACKAGE_NOT_ACK)
if wp.done:
raise RuntimeError("WorkPackage already done.")
raise OasstError("WorkPackage already done.", OasstErrorCode.WORK_PACKAGE_ALREADY_DONE)
# If there's no parent post assume user started new conversation
role = "user"
@@ -171,12 +172,16 @@ class PromptRepository:
work_package = self.fetch_workpackage_by_postid(rating.post_id)
work_payload: db_payload.RateSummaryPayload = work_package.payload.payload
if type(work_payload) != db_payload.RateSummaryPayload:
raise ValueError(
f"work_package payload type mismatch: {type(work_payload)=} != {db_payload.RateSummaryPayload}"
raise OasstError(
f"work_package payload type mismatch: {type(work_payload)=} != {db_payload.RateSummaryPayload}",
OasstErrorCode.WORK_PACKAGE_PAYLOAD_TYPE_MISMATCH,
)
if rating.rating < work_payload.scale.min or rating.rating > work_payload.scale.max:
raise ValueError(f"Invalid rating value: {rating.rating=} not in {work_payload.scale=}")
raise OasstError(
f"Invalid rating value: {rating.rating=} not in {work_payload.scale=}",
OasstErrorCode.RATING_OUT_OF_RANGE,
)
# store reaction to post
reaction_payload = db_payload.RatingReactionPayload(rating=rating.rating)
@@ -201,8 +206,9 @@ class PromptRepository:
# validate ranking
num_replies = len(work_payload.replies)
if sorted(ranking.ranking) != list(range(num_replies)):
raise ValueError(
f"Invalid ranking submitted. Each reply index must appear exactly once ({num_replies=})."
raise OasstError(
f"Invalid ranking submitted. Each reply index must appear exactly once ({num_replies=}).",
OasstErrorCode.INVALID_RANKING_VALUE,
)
# store reaction to post
@@ -218,8 +224,9 @@ class PromptRepository:
case db_payload.RankInitialPromptsPayload:
# validate ranking
if sorted(ranking.ranking) != list(range(num_prompts := len(work_payload.prompts))):
raise ValueError(
f"Invalid ranking submitted. Each reply index must appear exactly once ({num_prompts=})."
raise OasstError(
f"Invalid ranking submitted. Each reply index must appear exactly once ({num_prompts=}).",
OasstErrorCode.INVALID_RANKING_VALUE,
)
# store reaction to post
@@ -233,8 +240,9 @@ class PromptRepository:
return reaction
case _:
raise ValueError(
f"work_package payload type mismatch: {type(work_payload)=} != {db_payload.RankConversationRepliesPayload}"
raise OasstError(
f"work_package payload type mismatch: {type(work_payload)=} != {db_payload.RankConversationRepliesPayload}",
OasstErrorCode.WORK_PACKAGE_PAYLOAD_TYPE_MISMATCH,
)
def store_task(
@@ -276,7 +284,7 @@ class PromptRepository:
)
case _:
raise ValueError(f"Invalid task type: {type(task)=}")
raise OasstError(f"Invalid task type: {type(task)=}", OasstErrorCode.INVALID_TASK_TYPE)
wp = self.insert_work_package(
payload=payload,
@@ -348,7 +356,7 @@ class PromptRepository:
def insert_reaction(self, work_package_id: UUID, payload: db_payload.ReactionPayload) -> PostReaction:
if self.person_id is None:
raise ValueError("User required")
raise OasstError("User required", OasstErrorCode.USER_NOT_SPECIFIED)
container = PayloadContainer(payload=payload)
reaction = PostReaction(
@@ -405,7 +413,7 @@ class PromptRepository:
"""
thread_posts = self.fetch_random_thread(last_post_role)
if not thread_posts:
raise RuntimeError("No threads found")
raise OasstError("No threads found", OasstErrorCode.NO_THREADS_FOUND)
if last_post_role:
conv_posts = [p for p in thread_posts if p.role == last_post_role]
conv_posts = [random.choice(conv_posts)]