mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-30 16:40:05 +08:00
add OasstErrorCode IntEnum, use http.HTTPStatus enum
This commit is contained in:
+12
-2
@@ -1,4 +1,5 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from http import HTTPStatus
|
||||
from pathlib import Path
|
||||
|
||||
import alembic.command
|
||||
@@ -7,7 +8,7 @@ import fastapi
|
||||
from loguru import logger
|
||||
from oasst_backend.api.v1.api import api_router
|
||||
from oasst_backend.config import settings
|
||||
from oasst_backend.exceptions import OasstError
|
||||
from oasst_backend.exceptions import OasstError, OasstErrorCode
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
|
||||
app = fastapi.FastAPI(title=settings.PROJECT_NAME, openapi_url=f"{settings.API_V1_STR}/openapi.json")
|
||||
@@ -17,7 +18,16 @@ app = fastapi.FastAPI(title=settings.PROJECT_NAME, openapi_url=f"{settings.API_V
|
||||
async def oasst_exception_handler(request: fastapi.Request, ex: OasstError):
|
||||
logger.error(f"{request.method} {request.url} failed: {repr(ex)}")
|
||||
return fastapi.responses.JSONResponse(
|
||||
status_code=ex.http_status_code, content={"message": ex.message, "error_code": ex.error_code}
|
||||
status_code=int(ex.http_status_code), content={"message": ex.message, "error_code": ex.error_code}
|
||||
)
|
||||
|
||||
|
||||
@app.exception_handler(Exception)
|
||||
async def unhandled_exception_handler(request: fastapi.Request, ex: Exception):
|
||||
logger.exception(f"{request.method} {request.url} failed [UNHANDLED]: {repr(ex)}")
|
||||
status = HTTPStatus.INTERNAL_SERVER_ERROR
|
||||
return fastapi.responses.JSONResponse(
|
||||
status_code=status.value, content={"message": status.name, "error_code": OasstErrorCode.GENERIC_ERROR}
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from http import HTTPStatus
|
||||
from secrets import token_hex
|
||||
from typing import Generator
|
||||
from uuid import UUID
|
||||
@@ -8,10 +9,9 @@ from fastapi.security.api_key import APIKey, APIKeyHeader, APIKeyQuery
|
||||
from loguru import logger
|
||||
from oasst_backend.config import settings
|
||||
from oasst_backend.database import engine
|
||||
from oasst_backend.exceptions import OasstError, error_codes
|
||||
from oasst_backend.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_backend.models import ApiClient
|
||||
from sqlmodel import Session
|
||||
from starlette.status import HTTP_403_FORBIDDEN
|
||||
|
||||
|
||||
def get_db() -> Generator:
|
||||
@@ -37,25 +37,26 @@ def api_auth(
|
||||
api_key: APIKey,
|
||||
db: Session,
|
||||
) -> ApiClient:
|
||||
if api_key is None and not settings.DEBUG_SKIP_API_KEY_CHECK:
|
||||
raise OasstError(
|
||||
"Could not validate credentials",
|
||||
error_code=error_codes.API_CLIENT_NOT_AUTHORIZED,
|
||||
http_status_code=HTTP_403_FORBIDDEN,
|
||||
)
|
||||
if api_key or settings.DEBUG_SKIP_API_KEY_CHECK:
|
||||
|
||||
if settings.DEBUG_SKIP_API_KEY_CHECK or settings.DEBUG_ALLOW_ANY_API_KEY:
|
||||
# make sure that a dummy api key exits in db (foreign key references)
|
||||
ANY_API_KEY_ID = UUID("00000000-1111-2222-3333-444444444444")
|
||||
api_client: ApiClient = db.query(ApiClient).filter(ApiClient.id == ANY_API_KEY_ID).first()
|
||||
if api_client is None:
|
||||
token = token_hex(32)
|
||||
logger.info(f"ANY_API_KEY missing, inserting api_key: {token}")
|
||||
api_client = ApiClient(id=ANY_API_KEY_ID, api_key=token, description="ANY_API_KEY, random token")
|
||||
db.add(api_client)
|
||||
db.commit()
|
||||
return api_client
|
||||
if settings.DEBUG_SKIP_API_KEY_CHECK or settings.DEBUG_ALLOW_ANY_API_KEY:
|
||||
# make sure that a dummy api key exits in db (foreign key references)
|
||||
ANY_API_KEY_ID = UUID("00000000-1111-2222-3333-444444444444")
|
||||
api_client: ApiClient = db.query(ApiClient).filter(ApiClient.id == ANY_API_KEY_ID).first()
|
||||
if api_client is None:
|
||||
token = token_hex(32)
|
||||
logger.info(f"ANY_API_KEY missing, inserting api_key: {token}")
|
||||
api_client = ApiClient(id=ANY_API_KEY_ID, api_key=token, description="ANY_API_KEY, random token")
|
||||
db.add(api_client)
|
||||
db.commit()
|
||||
return api_client
|
||||
|
||||
api_client = db.query(ApiClient).filter(ApiClient.api_key == api_key).first()
|
||||
if api_client is not None and api_client.enabled:
|
||||
return api_client
|
||||
api_client = db.query(ApiClient).filter(ApiClient.api_key == api_key).first()
|
||||
if api_client is not None and api_client.enabled:
|
||||
return api_client
|
||||
|
||||
raise OasstError(
|
||||
"Could not validate credentials",
|
||||
error_code=OasstErrorCode.API_CLIENT_NOT_AUTHORIZED,
|
||||
http_status_code=HTTPStatus.FORBIDDEN,
|
||||
)
|
||||
|
||||
@@ -7,7 +7,7 @@ from fastapi import APIRouter, Depends
|
||||
from fastapi.security.api_key import APIKey
|
||||
from loguru import logger
|
||||
from oasst_backend.api import deps
|
||||
from oasst_backend.exceptions import OasstError, error_codes
|
||||
from oasst_backend.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from sqlmodel import Session
|
||||
@@ -114,7 +114,7 @@ def generate_task(request: protocol_schema.TaskRequest) -> protocol_schema.Task:
|
||||
],
|
||||
)
|
||||
case _:
|
||||
raise OasstError("Invalid request type", error_codes.TASK_INVALID_REQUEST_TYPE)
|
||||
raise OasstError("Invalid request type", OasstErrorCode.TASK_INVALID_REQUEST_TYPE)
|
||||
|
||||
logger.info(f"Generated {task=}.")
|
||||
|
||||
@@ -131,7 +131,6 @@ def request_task(
|
||||
"""
|
||||
Create new task.
|
||||
"""
|
||||
|
||||
api_client = deps.api_auth(api_key, db)
|
||||
|
||||
try:
|
||||
@@ -144,7 +143,7 @@ def request_task(
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("Failed to generate task..")
|
||||
raise OasstError("Failed to generate task.", error_codes.TASK_GENERATION_FAILED)
|
||||
raise OasstError("Failed to generate task.", OasstErrorCode.TASK_GENERATION_FAILED)
|
||||
return task
|
||||
|
||||
|
||||
@@ -173,7 +172,7 @@ def acknowledge_task(
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("Failed to acknowledge task.")
|
||||
raise OasstError("Failed to acknowledge task.", error_codes.TASK_ACK_FAILED)
|
||||
raise OasstError("Failed to acknowledge task.", OasstErrorCode.TASK_ACK_FAILED)
|
||||
return {}
|
||||
|
||||
|
||||
@@ -240,9 +239,9 @@ def post_interaction(
|
||||
# here we would store the ranking in the database
|
||||
return protocol_schema.TaskDone()
|
||||
case _:
|
||||
raise OasstError("Invalid response type.", error_codes.TASK_INVALID_RESPONSE_TYPE)
|
||||
raise OasstError("Invalid response type.", OasstErrorCode.TASK_INVALID_RESPONSE_TYPE)
|
||||
except OasstError:
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("Interaction request failed.")
|
||||
raise OasstError("Interaction request failed.", error_codes.TASK_INTERACTION_REQUEST_FAILED)
|
||||
raise OasstError("Interaction request failed.", OasstErrorCode.TASK_INTERACTION_REQUEST_FAILED)
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from oasst_backend.config import settings
|
||||
from oasst_backend.exceptions import OasstError, error_codes
|
||||
from oasst_backend.exceptions import OasstError, OasstErrorCode
|
||||
from sqlmodel import create_engine
|
||||
|
||||
if settings.DATABASE_URI is None:
|
||||
raise OasstError("DATABASE_URI is not set", error_code=error_codes.DATABASE_URI_NOT_SET)
|
||||
raise OasstError("DATABASE_URI is not set", error_code=OasstErrorCode.DATABASE_URI_NOT_SET)
|
||||
|
||||
engine = create_engine(settings.DATABASE_URI)
|
||||
|
||||
@@ -1,26 +1,36 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Open-Assistant backend API error codes.
|
||||
"""
|
||||
# 0-1000: general errors
|
||||
GENERIC_ERROR = 0
|
||||
DATABASE_URI_NOT_SET = 1
|
||||
API_CLIENT_NOT_AUTHORIZED = 2
|
||||
from enum import IntEnum
|
||||
|
||||
# 1000-2000: tasks endpoint
|
||||
TASK_INVALID_REQUEST_TYPE = 1000
|
||||
TASK_ACK_FAILED = 1001
|
||||
TASK_INVALID_RESPONSE_TYPE = 1002
|
||||
TASK_INTERACTION_REQUEST_FAILED = 1003
|
||||
TASK_GENERATION_FAILED = 1004
|
||||
|
||||
# 2000-3000: prompt_repository
|
||||
INVALID_POST_ID = 2000
|
||||
POST_NOT_FOUND = 2001
|
||||
RATING_OUT_OF_RANGE = 2002
|
||||
INVALID_RANKING_VALUE = 2003
|
||||
WORK_PACKAGE_NOT_FOUND = 2004
|
||||
WORK_PACKAGE_EXPIRED = 2005
|
||||
WORK_PACKAGE_PAYLOAD_TYPE_MISMATCH = 2006
|
||||
INVALID_TASK_TYPE = 2007
|
||||
USER_NOT_SPECIFIED = 2008
|
||||
class OasstErrorCode(IntEnum):
|
||||
"""
|
||||
Error codes of the Open-Assistant backend API.
|
||||
|
||||
Ranges:
|
||||
0-1000: general errors
|
||||
1000-2000: tasks endpoint
|
||||
2000-3000: prompt_repository
|
||||
"""
|
||||
|
||||
# 0-1000: general errors
|
||||
GENERIC_ERROR = 0
|
||||
DATABASE_URI_NOT_SET = 1
|
||||
API_CLIENT_NOT_AUTHORIZED = 2
|
||||
|
||||
# 1000-2000: tasks endpoint
|
||||
TASK_INVALID_REQUEST_TYPE = 1000
|
||||
TASK_ACK_FAILED = 1001
|
||||
TASK_INVALID_RESPONSE_TYPE = 1002
|
||||
TASK_INTERACTION_REQUEST_FAILED = 1003
|
||||
TASK_GENERATION_FAILED = 1004
|
||||
|
||||
# 2000-3000: prompt_repository
|
||||
INVALID_POST_ID = 2000
|
||||
POST_NOT_FOUND = 2001
|
||||
RATING_OUT_OF_RANGE = 2002
|
||||
INVALID_RANKING_VALUE = 2003
|
||||
WORK_PACKAGE_NOT_FOUND = 2004
|
||||
WORK_PACKAGE_EXPIRED = 2005
|
||||
WORK_PACKAGE_PAYLOAD_TYPE_MISMATCH = 2006
|
||||
INVALID_TASK_TYPE = 2007
|
||||
USER_NOT_SPECIFIED = 2008
|
||||
|
||||
@@ -1,6 +1,40 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import oasst_backend.error_codes as error_codes # noqa: F401
|
||||
from starlette.status import HTTP_400_BAD_REQUEST
|
||||
from enum import IntEnum
|
||||
from http import HTTPStatus
|
||||
|
||||
|
||||
class OasstErrorCode(IntEnum):
|
||||
"""
|
||||
Error codes of the Open-Assistant backend API.
|
||||
|
||||
Ranges:
|
||||
0-1000: general errors
|
||||
1000-2000: tasks endpoint
|
||||
2000-3000: prompt_repository
|
||||
"""
|
||||
|
||||
# 0-1000: general errors
|
||||
GENERIC_ERROR = 0
|
||||
DATABASE_URI_NOT_SET = 1
|
||||
API_CLIENT_NOT_AUTHORIZED = 2
|
||||
|
||||
# 1000-2000: tasks endpoint
|
||||
TASK_INVALID_REQUEST_TYPE = 1000
|
||||
TASK_ACK_FAILED = 1001
|
||||
TASK_INVALID_RESPONSE_TYPE = 1002
|
||||
TASK_INTERACTION_REQUEST_FAILED = 1003
|
||||
TASK_GENERATION_FAILED = 1004
|
||||
|
||||
# 2000-3000: prompt_repository
|
||||
INVALID_POST_ID = 2000
|
||||
POST_NOT_FOUND = 2001
|
||||
RATING_OUT_OF_RANGE = 2002
|
||||
INVALID_RANKING_VALUE = 2003
|
||||
WORK_PACKAGE_NOT_FOUND = 2004
|
||||
WORK_PACKAGE_EXPIRED = 2005
|
||||
WORK_PACKAGE_PAYLOAD_TYPE_MISMATCH = 2006
|
||||
INVALID_TASK_TYPE = 2007
|
||||
USER_NOT_SPECIFIED = 2008
|
||||
|
||||
|
||||
class OasstError(Exception):
|
||||
@@ -8,9 +42,9 @@ class OasstError(Exception):
|
||||
|
||||
message: str
|
||||
error_code: int
|
||||
http_status_code: int
|
||||
http_status_code: HTTPStatus
|
||||
|
||||
def __init__(self, message: str, error_code: int, http_status_code: int = HTTP_400_BAD_REQUEST):
|
||||
def __init__(self, message: str, error_code: OasstErrorCode, http_status_code: HTTPStatus = HTTPStatus.BAD_REQUEST):
|
||||
super().__init__(message, error_code, http_status_code) # make excetpion picklable (fill args member)
|
||||
self.message = message
|
||||
self.error_code = error_code
|
||||
|
||||
@@ -5,7 +5,7 @@ from uuid import UUID, uuid4
|
||||
|
||||
import oasst_backend.models.db_payload as db_payload
|
||||
from loguru import logger
|
||||
from oasst_backend.exceptions import OasstError, error_codes
|
||||
from oasst_backend.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_backend.journal_writer import JournalWriter
|
||||
from oasst_backend.models import ApiClient, Person, Post, PostReaction, TextLabels, WorkPackage
|
||||
from oasst_backend.models.payload_column_type import PayloadContainer
|
||||
@@ -53,9 +53,9 @@ class PromptRepository:
|
||||
|
||||
def validate_post_id(self, post_id: str) -> None:
|
||||
if not isinstance(post_id, str):
|
||||
raise OasstError(f"post_id must be string, not {type(post_id)}", error_codes.INVALID_POST_ID)
|
||||
raise OasstError(f"post_id must be string, not {type(post_id)}", OasstErrorCode.INVALID_POST_ID)
|
||||
if not post_id:
|
||||
raise OasstError("post_id must not be empty", error_codes.INVALID_POST_ID)
|
||||
raise OasstError("post_id must not be empty", OasstErrorCode.INVALID_POST_ID)
|
||||
|
||||
def bind_frontend_post_id(self, task_id: UUID, post_id: str):
|
||||
self.validate_post_id(post_id)
|
||||
@@ -67,9 +67,9 @@ class PromptRepository:
|
||||
.first()
|
||||
)
|
||||
if work_pack is None:
|
||||
raise OasstError(f"WorkPackage for task {task_id} not found", error_codes.WORK_PACKAGE_NOT_FOUND)
|
||||
raise OasstError(f"WorkPackage for task {task_id} not found", OasstErrorCode.WORK_PACKAGE_NOT_FOUND)
|
||||
if work_pack.expiry_date is not None and datetime.utcnow() > work_pack.expiry_date:
|
||||
raise OasstError("WorkPackage already expired.", error_codes.WORK_PACKAGE_EXPIRED)
|
||||
raise OasstError("WorkPackage already expired.", OasstErrorCode.WORK_PACKAGE_EXPIRED)
|
||||
|
||||
# ToDo: check race-condition, transaction
|
||||
|
||||
@@ -106,7 +106,7 @@ class PromptRepository:
|
||||
.one_or_none()
|
||||
)
|
||||
if fail_if_missing and post is None:
|
||||
raise OasstError(f"Post with post_id {frontend_post_id} not found.", error_codes.POST_NOT_FOUND)
|
||||
raise OasstError(f"Post with post_id {frontend_post_id} not found.", OasstErrorCode.POST_NOT_FOUND)
|
||||
return post
|
||||
|
||||
def fetch_workpackage_by_postid(self, post_id: str) -> WorkPackage:
|
||||
@@ -135,7 +135,7 @@ class PromptRepository:
|
||||
)
|
||||
|
||||
if parent_post is None:
|
||||
raise OasstError(f"Post for post_id {reply.post_id} not found.", error_codes.POST_NOT_FOUND)
|
||||
raise OasstError(f"Post for post_id {reply.post_id} not found.", OasstErrorCode.POST_NOT_FOUND)
|
||||
|
||||
# create reply post
|
||||
user_post_id = uuid4()
|
||||
@@ -159,12 +159,13 @@ class PromptRepository:
|
||||
if type(work_payload) != db_payload.RateSummaryPayload:
|
||||
raise OasstError(
|
||||
f"work_package payload type mismatch: {type(work_payload)=} != {db_payload.RateSummaryPayload}",
|
||||
error_codes.WORK_PACKAGE_PAYLOAD_TYPE_MISMATCH,
|
||||
OasstErrorCode.WORK_PACKAGE_PAYLOAD_TYPE_MISMATCH,
|
||||
)
|
||||
|
||||
if rating.rating < work_payload.scale.min or rating.rating > work_payload.scale.max:
|
||||
raise OasstError(
|
||||
f"Invalid rating value: {rating.rating=} not in {work_payload.scale=}", error_codes.RATING_OUT_OF_RANGE
|
||||
f"Invalid rating value: {rating.rating=} not in {work_payload.scale=}",
|
||||
OasstErrorCode.RATING_OUT_OF_RANGE,
|
||||
)
|
||||
|
||||
# store reaction to post
|
||||
@@ -191,7 +192,7 @@ class PromptRepository:
|
||||
if sorted(ranking.ranking) != list(range(num_replies)):
|
||||
raise OasstError(
|
||||
f"Invalid ranking submitted. Each reply index must appear exactly once ({num_replies=}).",
|
||||
error_codes.INVALID_RANKING_VALUE,
|
||||
OasstErrorCode.INVALID_RANKING_VALUE,
|
||||
)
|
||||
|
||||
# store reaction to post
|
||||
@@ -208,7 +209,7 @@ class PromptRepository:
|
||||
if sorted(ranking.ranking) != list(range(num_prompts := len(work_payload.prompts))):
|
||||
raise OasstError(
|
||||
f"Invalid ranking submitted. Each reply index must appear exactly once ({num_prompts=}).",
|
||||
error_codes.INVALID_RANKING_VALUE,
|
||||
OasstErrorCode.INVALID_RANKING_VALUE,
|
||||
)
|
||||
|
||||
# store reaction to post
|
||||
@@ -223,7 +224,7 @@ class PromptRepository:
|
||||
case _:
|
||||
raise OasstError(
|
||||
f"work_package payload type mismatch: {type(work_payload)=} != {db_payload.RankConversationRepliesPayload}",
|
||||
error_codes.WORK_PACKAGE_PAYLOAD_TYPE_MISMATCH,
|
||||
OasstErrorCode.WORK_PACKAGE_PAYLOAD_TYPE_MISMATCH,
|
||||
)
|
||||
|
||||
def store_task(self, task: protocol_schema.Task) -> WorkPackage:
|
||||
@@ -260,7 +261,7 @@ class PromptRepository:
|
||||
)
|
||||
|
||||
case _:
|
||||
raise OasstError(f"Invalid task type: {type(task)=}", error_codes.INVALID_TASK_TYPE)
|
||||
raise OasstError(f"Invalid task type: {type(task)=}", OasstErrorCode.INVALID_TASK_TYPE)
|
||||
|
||||
wp = self.insert_work_package(payload=payload, id=task.id)
|
||||
assert wp.id == task.id
|
||||
@@ -317,7 +318,7 @@ class PromptRepository:
|
||||
|
||||
def insert_reaction(self, post_id: UUID, payload: db_payload.ReactionPayload) -> PostReaction:
|
||||
if self.person_id is None:
|
||||
raise OasstError("User required", error_codes.USER_NOT_SPECIFIED)
|
||||
raise OasstError("User required", OasstErrorCode.USER_NOT_SPECIFIED)
|
||||
|
||||
container = PayloadContainer(payload=payload)
|
||||
reaction = PostReaction(
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
parent_path=$( cd "$(dirname "${BASH_SOURCE[0]}")" ; pwd -P )
|
||||
|
||||
# switch to bot directory
|
||||
pushd "$parent_path/../../bot"
|
||||
pushd "$parent_path/../../discord-bot"
|
||||
|
||||
python3 __main__.py
|
||||
|
||||
|
||||
Reference in New Issue
Block a user