From dda668bcd535f71292e0342eeb2f166e4a8fbec9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20K=C3=B6pf?= Date: Wed, 28 Dec 2022 14:10:15 +0100 Subject: [PATCH 1/4] Add OasstError exception class and exception filter --- backend/main.py | 10 ++++++ backend/oasst_backend/api/deps.py | 10 ++++-- backend/oasst_backend/api/v1/tasks.py | 36 ++++++++----------- backend/oasst_backend/database.py | 3 +- backend/oasst_backend/error_codes.py | 26 ++++++++++++++ backend/oasst_backend/exceptions.py | 21 +++++++++++ backend/oasst_backend/prompt_repository.py | 41 +++++++++++++--------- 7 files changed, 105 insertions(+), 42 deletions(-) create mode 100644 backend/oasst_backend/error_codes.py create mode 100644 backend/oasst_backend/exceptions.py diff --git a/backend/main.py b/backend/main.py index abcd2391..7d7f42c3 100644 --- a/backend/main.py +++ b/backend/main.py @@ -7,10 +7,20 @@ import fastapi from loguru import logger from oasst_backend.api.v1.api import api_router from oasst_backend.config import settings +from oasst_backend.exceptions import OasstError from starlette.middleware.cors import CORSMiddleware app = fastapi.FastAPI(title=settings.PROJECT_NAME, openapi_url=f"{settings.API_V1_STR}/openapi.json") + +@app.exception_handler(OasstError) +async def http_exception_handler(request: fastapi.Request, ex: OasstError): + logger.error(f"{request.method} {request.url} failed: {repr(ex)}") + return fastapi.responses.JSONResponse( + status_code=ex.http_status_code, content={"message": ex.message, "error_code": ex.error_code} + ) + + # Set all CORS enabled origins if settings.BACKEND_CORS_ORIGINS: app.add_middleware( diff --git a/backend/oasst_backend/api/deps.py b/backend/oasst_backend/api/deps.py index 505fa2c6..fe5eb24d 100644 --- a/backend/oasst_backend/api/deps.py +++ b/backend/oasst_backend/api/deps.py @@ -3,11 +3,12 @@ from secrets import token_hex from typing import Generator from uuid import UUID -from fastapi import HTTPException, Security +from fastapi import Security from fastapi.security.api_key import APIKey, APIKeyHeader, APIKeyQuery from loguru import logger from oasst_backend.config import settings from oasst_backend.database import engine +from oasst_backend.exceptions import OasstError, error_codes from oasst_backend.models import ApiClient from sqlmodel import Session from starlette.status import HTTP_403_FORBIDDEN @@ -36,9 +37,12 @@ def api_auth( api_key: APIKey, db: Session, ) -> ApiClient: - if api_key is None and not settings.DEBUG_SKIP_API_KEY_CHECK: - raise HTTPException(status_code=HTTP_403_FORBIDDEN, detail="Could not validate credentials") + raise OasstError( + "Could not validate credentials", + error_code=error_codes.API_CLIENT_NOT_AUTHORIZED, + http_status_code=HTTP_403_FORBIDDEN, + ) if settings.DEBUG_SKIP_API_KEY_CHECK or settings.DEBUG_ALLOW_ANY_API_KEY: # make sure that a dummy api key exits in db (foreign key references) diff --git a/backend/oasst_backend/api/v1/tasks.py b/backend/oasst_backend/api/v1/tasks.py index 0778fd4c..a4b3ed8b 100644 --- a/backend/oasst_backend/api/v1/tasks.py +++ b/backend/oasst_backend/api/v1/tasks.py @@ -3,14 +3,14 @@ import random from typing import Any from uuid import UUID -from fastapi import APIRouter, Depends, HTTPException +from fastapi import APIRouter, Depends from fastapi.security.api_key import APIKey from loguru import logger from oasst_backend.api import deps +from oasst_backend.exceptions import OasstError, error_codes from oasst_backend.prompt_repository import PromptRepository from oasst_shared.schemas import protocol as protocol_schema from sqlmodel import Session -from starlette.status import HTTP_400_BAD_REQUEST router = APIRouter() @@ -114,10 +114,7 @@ def generate_task(request: protocol_schema.TaskRequest) -> protocol_schema.Task: ], ) case _: - raise HTTPException( - status_code=HTTP_400_BAD_REQUEST, - detail="Invalid request type.", - ) + raise OasstError("Invalid request type", error_codes.TASK_INVALID_REQUEST_TYPE) logger.info(f"Generated {task=}.") @@ -134,6 +131,7 @@ def request_task( """ Create new task. """ + api_client = deps.api_auth(api_key, db) try: @@ -142,11 +140,11 @@ def request_task( pr = PromptRepository(db, api_client, request.user) pr.store_task(task) + except OasstError: + raise except Exception: - logger.exception("Failed to generate task.") - raise HTTPException( - status_code=HTTP_400_BAD_REQUEST, - ) + logger.exception("Failed to generate task..") + raise OasstError("Failed to generate task.", error_codes.TASK_GENERATION_FAILED) return task @@ -171,11 +169,11 @@ def acknowledge_task( logger.info(f"Frontend acknowledges task {task_id=}, {ack_request=}.") pr.bind_frontend_post_id(task_id=task_id, post_id=ack_request.post_id) + except OasstError: + raise except Exception: logger.exception("Failed to acknowledge task.") - raise HTTPException( - status_code=HTTP_400_BAD_REQUEST, - ) + raise OasstError("Failed to acknowledge task.", error_codes.TASK_ACK_FAILED) return {} @@ -242,13 +240,9 @@ def post_interaction( # here we would store the ranking in the database return protocol_schema.TaskDone() case _: - raise HTTPException( - status_code=HTTP_400_BAD_REQUEST, - detail="Invalid response type.", - ) - + raise OasstError("Invalid response type.", error_codes.TASK_INVALID_RESPONSE_TYPE) + except OasstError: + raise except Exception: logger.exception("Interaction request failed.") - raise HTTPException( - status_code=HTTP_400_BAD_REQUEST, - ) + raise OasstError("Interaction request failed.", error_codes.TASK_INTERACTION_REQUEST_FAILED) diff --git a/backend/oasst_backend/database.py b/backend/oasst_backend/database.py index 66d7a857..dde5446a 100644 --- a/backend/oasst_backend/database.py +++ b/backend/oasst_backend/database.py @@ -1,8 +1,9 @@ # -*- coding: utf-8 -*- from oasst_backend.config import settings +from oasst_backend.exceptions import OasstError, error_codes from sqlmodel import create_engine if settings.DATABASE_URI is None: - raise ValueError("DATABASE_URI is not set") + raise OasstError("DATABASE_URI is not set", error_code=error_codes.DATABASE_URI_NOT_SET) engine = create_engine(settings.DATABASE_URI) diff --git a/backend/oasst_backend/error_codes.py b/backend/oasst_backend/error_codes.py new file mode 100644 index 00000000..54ed5caf --- /dev/null +++ b/backend/oasst_backend/error_codes.py @@ -0,0 +1,26 @@ +# -*- coding: utf-8 -*- +""" +Open-Assistant backend API error codes. +""" +# 0-1000: general errors +GENERIC_ERROR = 0 +DATABASE_URI_NOT_SET = 1 +API_CLIENT_NOT_AUTHORIZED = 2 + +# 1000-2000: tasks endpoint +TASK_INVALID_REQUEST_TYPE = 1000 +TASK_ACK_FAILED = 1001 +TASK_INVALID_RESPONSE_TYPE = 1002 +TASK_INTERACTION_REQUEST_FAILED = 1003 +TASK_GENERATION_FAILED = 1004 + +# 2000-3000: prompt_repository +INVALID_POST_ID = 2000 +POST_NOT_FOUND = 2001 +RATING_OUT_OF_RANGE = 2002 +INVALID_RANKING_VALUE = 2003 +WORK_PACKAGE_NOT_FOUND = 2004 +WORK_PACKAGE_EXPIRED = 2005 +WORK_PACKAGE_PAYLOAD_TYPE_MISMATCH = 2006 +INVALID_TASK_TYPE = 2007 +USER_NOT_SPECIFIED = 2008 diff --git a/backend/oasst_backend/exceptions.py b/backend/oasst_backend/exceptions.py new file mode 100644 index 00000000..c614de88 --- /dev/null +++ b/backend/oasst_backend/exceptions.py @@ -0,0 +1,21 @@ +# -*- coding: utf-8 -*- +import oasst_backend.error_codes as error_codes # noqa: F401 +from starlette.status import HTTP_400_BAD_REQUEST + + +class OasstError(Exception): + """Base class for Open-Assistant exceptions.""" + + message: str + error_code: int + http_status_code: int + + def __init__(self, message: str, error_code: int, http_status_code: int = HTTP_400_BAD_REQUEST): + super().__init__(message, error_code, http_status_code) # make excetpion picklable (fill args member) + self.message = message + self.error_code = error_code + self.http_status_code = http_status_code + + def __repr__(self) -> str: + class_name = self.__class__.__name__ + return f'{class_name}(message="{self.message}", error_code={self.error_code}, http_status_code={self.http_status_code})' diff --git a/backend/oasst_backend/prompt_repository.py b/backend/oasst_backend/prompt_repository.py index f4f83277..eb2ce8b6 100644 --- a/backend/oasst_backend/prompt_repository.py +++ b/backend/oasst_backend/prompt_repository.py @@ -5,6 +5,7 @@ from uuid import UUID, uuid4 import oasst_backend.models.db_payload as db_payload from loguru import logger +from oasst_backend.exceptions import OasstError, error_codes from oasst_backend.journal_writer import JournalWriter from oasst_backend.models import ApiClient, Person, Post, PostReaction, TextLabels, WorkPackage from oasst_backend.models.payload_column_type import PayloadContainer @@ -52,9 +53,9 @@ class PromptRepository: def validate_post_id(self, post_id: str) -> None: if not isinstance(post_id, str): - raise TypeError(f"post_id must be string, not {type(post_id)}") + raise OasstError(f"post_id must be string, not {type(post_id)}", error_codes.INVALID_POST_ID) if not post_id: - raise ValueError("post_id must not be empty") + raise OasstError("post_id must not be empty", error_codes.INVALID_POST_ID) def bind_frontend_post_id(self, task_id: UUID, post_id: str): self.validate_post_id(post_id) @@ -66,9 +67,9 @@ class PromptRepository: .first() ) if work_pack is None: - raise KeyError(f"WorkPackage for task {task_id} not found") + raise OasstError(f"WorkPackage for task {task_id} not found", error_codes.WORK_PACKAGE_NOT_FOUND) if work_pack.expiry_date is not None and datetime.utcnow() > work_pack.expiry_date: - raise RuntimeError("WorkPackage already expired.") + raise OasstError("WorkPackage already expired.", error_codes.WORK_PACKAGE_EXPIRED) # ToDo: check race-condition, transaction @@ -105,7 +106,7 @@ class PromptRepository: .one_or_none() ) if fail_if_missing and post is None: - raise KeyError(f"Post with post_id {frontend_post_id} not found.") + raise OasstError(f"Post with post_id {frontend_post_id} not found.", error_codes.POST_NOT_FOUND) return post def fetch_workpackage_by_postid(self, post_id: str) -> WorkPackage: @@ -134,7 +135,7 @@ class PromptRepository: ) if parent_post is None: - raise KeyError(f"Post for post_id {reply.post_id} not found.") + raise OasstError(f"Post for post_id {reply.post_id} not found.", error_codes.POST_NOT_FOUND) # create reply post user_post_id = uuid4() @@ -156,12 +157,15 @@ class PromptRepository: work_package = self.fetch_workpackage_by_postid(rating.post_id) work_payload: db_payload.RateSummaryPayload = work_package.payload.payload if type(work_payload) != db_payload.RateSummaryPayload: - raise ValueError( - f"work_package payload type mismatch: {type(work_payload)=} != {db_payload.RateSummaryPayload}" + raise OasstError( + f"work_package payload type mismatch: {type(work_payload)=} != {db_payload.RateSummaryPayload}", + error_codes.WORK_PACKAGE_PAYLOAD_TYPE_MISMATCH, ) if rating.rating < work_payload.scale.min or rating.rating > work_payload.scale.max: - raise ValueError(f"Invalid rating value: {rating.rating=} not in {work_payload.scale=}") + raise OasstError( + f"Invalid rating value: {rating.rating=} not in {work_payload.scale=}", error_codes.RATING_OUT_OF_RANGE + ) # store reaction to post reaction_payload = db_payload.RatingReactionPayload(rating=rating.rating) @@ -185,8 +189,9 @@ class PromptRepository: # validate ranking num_replies = len(work_payload.replies) if sorted(ranking.ranking) != list(range(num_replies)): - raise ValueError( - f"Invalid ranking submitted. Each reply index must appear exactly once ({num_replies=})." + raise OasstError( + f"Invalid ranking submitted. Each reply index must appear exactly once ({num_replies=}).", + error_codes.INVALID_RANKING_VALUE, ) # store reaction to post @@ -201,8 +206,9 @@ class PromptRepository: case db_payload.RankInitialPromptsPayload: # validate ranking if sorted(ranking.ranking) != list(range(num_prompts := len(work_payload.prompts))): - raise ValueError( - f"Invalid ranking submitted. Each reply index must appear exactly once ({num_prompts=})." + raise OasstError( + f"Invalid ranking submitted. Each reply index must appear exactly once ({num_prompts=}).", + error_codes.INVALID_RANKING_VALUE, ) # store reaction to post @@ -215,8 +221,9 @@ class PromptRepository: return reaction case _: - raise ValueError( - f"work_package payload type mismatch: {type(work_payload)=} != {db_payload.RankConversationRepliesPayload}" + raise OasstError( + f"work_package payload type mismatch: {type(work_payload)=} != {db_payload.RankConversationRepliesPayload}", + error_codes.WORK_PACKAGE_PAYLOAD_TYPE_MISMATCH, ) def store_task(self, task: protocol_schema.Task) -> WorkPackage: @@ -253,7 +260,7 @@ class PromptRepository: ) case _: - raise ValueError(f"Invalid task type: {type(task)=}") + raise OasstError(f"Invalid task type: {type(task)=}", error_codes.INVALID_TASK_TYPE) wp = self.insert_work_package(payload=payload, id=task.id) assert wp.id == task.id @@ -310,7 +317,7 @@ class PromptRepository: def insert_reaction(self, post_id: UUID, payload: db_payload.ReactionPayload) -> PostReaction: if self.person_id is None: - raise ValueError("User required") + raise OasstError("User required", error_codes.USER_NOT_SPECIFIED) container = PayloadContainer(payload=payload) reaction = PostReaction( From 36c74e238cfbd1e95819d204f6ece3c2f7200a44 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20K=C3=B6pf?= Date: Wed, 28 Dec 2022 14:22:53 +0100 Subject: [PATCH 2/4] rename handler function to --- backend/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/main.py b/backend/main.py index 7d7f42c3..a0745f01 100644 --- a/backend/main.py +++ b/backend/main.py @@ -14,7 +14,7 @@ app = fastapi.FastAPI(title=settings.PROJECT_NAME, openapi_url=f"{settings.API_V @app.exception_handler(OasstError) -async def http_exception_handler(request: fastapi.Request, ex: OasstError): +async def oasst_exception_handler(request: fastapi.Request, ex: OasstError): logger.error(f"{request.method} {request.url} failed: {repr(ex)}") return fastapi.responses.JSONResponse( status_code=ex.http_status_code, content={"message": ex.message, "error_code": ex.error_code} From 0c3103838bb22e65a68278d554dae5a2286dec19 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20K=C3=B6pf?= Date: Wed, 28 Dec 2022 19:25:39 +0100 Subject: [PATCH 3/4] add OasstErrorCode IntEnum, use http.HTTPStatus enum --- backend/main.py | 14 ++++- backend/oasst_backend/api/deps.py | 45 +++++++-------- backend/oasst_backend/api/v1/tasks.py | 13 ++--- backend/oasst_backend/database.py | 4 +- backend/oasst_backend/error_codes.py | 56 +++++++++++-------- backend/oasst_backend/exceptions.py | 42 ++++++++++++-- backend/oasst_backend/prompt_repository.py | 29 +++++----- scripts/frontend-development/run-bot-local.sh | 2 +- 8 files changed, 130 insertions(+), 75 deletions(-) diff --git a/backend/main.py b/backend/main.py index a0745f01..386a495b 100644 --- a/backend/main.py +++ b/backend/main.py @@ -1,4 +1,5 @@ # -*- coding: utf-8 -*- +from http import HTTPStatus from pathlib import Path import alembic.command @@ -7,7 +8,7 @@ import fastapi from loguru import logger from oasst_backend.api.v1.api import api_router from oasst_backend.config import settings -from oasst_backend.exceptions import OasstError +from oasst_backend.exceptions import OasstError, OasstErrorCode from starlette.middleware.cors import CORSMiddleware app = fastapi.FastAPI(title=settings.PROJECT_NAME, openapi_url=f"{settings.API_V1_STR}/openapi.json") @@ -17,7 +18,16 @@ app = fastapi.FastAPI(title=settings.PROJECT_NAME, openapi_url=f"{settings.API_V async def oasst_exception_handler(request: fastapi.Request, ex: OasstError): logger.error(f"{request.method} {request.url} failed: {repr(ex)}") return fastapi.responses.JSONResponse( - status_code=ex.http_status_code, content={"message": ex.message, "error_code": ex.error_code} + status_code=int(ex.http_status_code), content={"message": ex.message, "error_code": ex.error_code} + ) + + +@app.exception_handler(Exception) +async def unhandled_exception_handler(request: fastapi.Request, ex: Exception): + logger.exception(f"{request.method} {request.url} failed [UNHANDLED]: {repr(ex)}") + status = HTTPStatus.INTERNAL_SERVER_ERROR + return fastapi.responses.JSONResponse( + status_code=status.value, content={"message": status.name, "error_code": OasstErrorCode.GENERIC_ERROR} ) diff --git a/backend/oasst_backend/api/deps.py b/backend/oasst_backend/api/deps.py index fe5eb24d..bdbd83eb 100644 --- a/backend/oasst_backend/api/deps.py +++ b/backend/oasst_backend/api/deps.py @@ -1,4 +1,5 @@ # -*- coding: utf-8 -*- +from http import HTTPStatus from secrets import token_hex from typing import Generator from uuid import UUID @@ -8,10 +9,9 @@ from fastapi.security.api_key import APIKey, APIKeyHeader, APIKeyQuery from loguru import logger from oasst_backend.config import settings from oasst_backend.database import engine -from oasst_backend.exceptions import OasstError, error_codes +from oasst_backend.exceptions import OasstError, OasstErrorCode from oasst_backend.models import ApiClient from sqlmodel import Session -from starlette.status import HTTP_403_FORBIDDEN def get_db() -> Generator: @@ -37,25 +37,26 @@ def api_auth( api_key: APIKey, db: Session, ) -> ApiClient: - if api_key is None and not settings.DEBUG_SKIP_API_KEY_CHECK: - raise OasstError( - "Could not validate credentials", - error_code=error_codes.API_CLIENT_NOT_AUTHORIZED, - http_status_code=HTTP_403_FORBIDDEN, - ) + if api_key or settings.DEBUG_SKIP_API_KEY_CHECK: - if settings.DEBUG_SKIP_API_KEY_CHECK or settings.DEBUG_ALLOW_ANY_API_KEY: - # make sure that a dummy api key exits in db (foreign key references) - ANY_API_KEY_ID = UUID("00000000-1111-2222-3333-444444444444") - api_client: ApiClient = db.query(ApiClient).filter(ApiClient.id == ANY_API_KEY_ID).first() - if api_client is None: - token = token_hex(32) - logger.info(f"ANY_API_KEY missing, inserting api_key: {token}") - api_client = ApiClient(id=ANY_API_KEY_ID, api_key=token, description="ANY_API_KEY, random token") - db.add(api_client) - db.commit() - return api_client + if settings.DEBUG_SKIP_API_KEY_CHECK or settings.DEBUG_ALLOW_ANY_API_KEY: + # make sure that a dummy api key exits in db (foreign key references) + ANY_API_KEY_ID = UUID("00000000-1111-2222-3333-444444444444") + api_client: ApiClient = db.query(ApiClient).filter(ApiClient.id == ANY_API_KEY_ID).first() + if api_client is None: + token = token_hex(32) + logger.info(f"ANY_API_KEY missing, inserting api_key: {token}") + api_client = ApiClient(id=ANY_API_KEY_ID, api_key=token, description="ANY_API_KEY, random token") + db.add(api_client) + db.commit() + return api_client - api_client = db.query(ApiClient).filter(ApiClient.api_key == api_key).first() - if api_client is not None and api_client.enabled: - return api_client + api_client = db.query(ApiClient).filter(ApiClient.api_key == api_key).first() + if api_client is not None and api_client.enabled: + return api_client + + raise OasstError( + "Could not validate credentials", + error_code=OasstErrorCode.API_CLIENT_NOT_AUTHORIZED, + http_status_code=HTTPStatus.FORBIDDEN, + ) diff --git a/backend/oasst_backend/api/v1/tasks.py b/backend/oasst_backend/api/v1/tasks.py index a4b3ed8b..17f9ffb1 100644 --- a/backend/oasst_backend/api/v1/tasks.py +++ b/backend/oasst_backend/api/v1/tasks.py @@ -7,7 +7,7 @@ from fastapi import APIRouter, Depends from fastapi.security.api_key import APIKey from loguru import logger from oasst_backend.api import deps -from oasst_backend.exceptions import OasstError, error_codes +from oasst_backend.exceptions import OasstError, OasstErrorCode from oasst_backend.prompt_repository import PromptRepository from oasst_shared.schemas import protocol as protocol_schema from sqlmodel import Session @@ -114,7 +114,7 @@ def generate_task(request: protocol_schema.TaskRequest) -> protocol_schema.Task: ], ) case _: - raise OasstError("Invalid request type", error_codes.TASK_INVALID_REQUEST_TYPE) + raise OasstError("Invalid request type", OasstErrorCode.TASK_INVALID_REQUEST_TYPE) logger.info(f"Generated {task=}.") @@ -131,7 +131,6 @@ def request_task( """ Create new task. """ - api_client = deps.api_auth(api_key, db) try: @@ -144,7 +143,7 @@ def request_task( raise except Exception: logger.exception("Failed to generate task..") - raise OasstError("Failed to generate task.", error_codes.TASK_GENERATION_FAILED) + raise OasstError("Failed to generate task.", OasstErrorCode.TASK_GENERATION_FAILED) return task @@ -173,7 +172,7 @@ def acknowledge_task( raise except Exception: logger.exception("Failed to acknowledge task.") - raise OasstError("Failed to acknowledge task.", error_codes.TASK_ACK_FAILED) + raise OasstError("Failed to acknowledge task.", OasstErrorCode.TASK_ACK_FAILED) return {} @@ -240,9 +239,9 @@ def post_interaction( # here we would store the ranking in the database return protocol_schema.TaskDone() case _: - raise OasstError("Invalid response type.", error_codes.TASK_INVALID_RESPONSE_TYPE) + raise OasstError("Invalid response type.", OasstErrorCode.TASK_INVALID_RESPONSE_TYPE) except OasstError: raise except Exception: logger.exception("Interaction request failed.") - raise OasstError("Interaction request failed.", error_codes.TASK_INTERACTION_REQUEST_FAILED) + raise OasstError("Interaction request failed.", OasstErrorCode.TASK_INTERACTION_REQUEST_FAILED) diff --git a/backend/oasst_backend/database.py b/backend/oasst_backend/database.py index dde5446a..38e5105c 100644 --- a/backend/oasst_backend/database.py +++ b/backend/oasst_backend/database.py @@ -1,9 +1,9 @@ # -*- coding: utf-8 -*- from oasst_backend.config import settings -from oasst_backend.exceptions import OasstError, error_codes +from oasst_backend.exceptions import OasstError, OasstErrorCode from sqlmodel import create_engine if settings.DATABASE_URI is None: - raise OasstError("DATABASE_URI is not set", error_code=error_codes.DATABASE_URI_NOT_SET) + raise OasstError("DATABASE_URI is not set", error_code=OasstErrorCode.DATABASE_URI_NOT_SET) engine = create_engine(settings.DATABASE_URI) diff --git a/backend/oasst_backend/error_codes.py b/backend/oasst_backend/error_codes.py index 54ed5caf..7f0bb541 100644 --- a/backend/oasst_backend/error_codes.py +++ b/backend/oasst_backend/error_codes.py @@ -1,26 +1,36 @@ # -*- coding: utf-8 -*- -""" -Open-Assistant backend API error codes. -""" -# 0-1000: general errors -GENERIC_ERROR = 0 -DATABASE_URI_NOT_SET = 1 -API_CLIENT_NOT_AUTHORIZED = 2 +from enum import IntEnum -# 1000-2000: tasks endpoint -TASK_INVALID_REQUEST_TYPE = 1000 -TASK_ACK_FAILED = 1001 -TASK_INVALID_RESPONSE_TYPE = 1002 -TASK_INTERACTION_REQUEST_FAILED = 1003 -TASK_GENERATION_FAILED = 1004 -# 2000-3000: prompt_repository -INVALID_POST_ID = 2000 -POST_NOT_FOUND = 2001 -RATING_OUT_OF_RANGE = 2002 -INVALID_RANKING_VALUE = 2003 -WORK_PACKAGE_NOT_FOUND = 2004 -WORK_PACKAGE_EXPIRED = 2005 -WORK_PACKAGE_PAYLOAD_TYPE_MISMATCH = 2006 -INVALID_TASK_TYPE = 2007 -USER_NOT_SPECIFIED = 2008 +class OasstErrorCode(IntEnum): + """ + Error codes of the Open-Assistant backend API. + + Ranges: + 0-1000: general errors + 1000-2000: tasks endpoint + 2000-3000: prompt_repository + """ + + # 0-1000: general errors + GENERIC_ERROR = 0 + DATABASE_URI_NOT_SET = 1 + API_CLIENT_NOT_AUTHORIZED = 2 + + # 1000-2000: tasks endpoint + TASK_INVALID_REQUEST_TYPE = 1000 + TASK_ACK_FAILED = 1001 + TASK_INVALID_RESPONSE_TYPE = 1002 + TASK_INTERACTION_REQUEST_FAILED = 1003 + TASK_GENERATION_FAILED = 1004 + + # 2000-3000: prompt_repository + INVALID_POST_ID = 2000 + POST_NOT_FOUND = 2001 + RATING_OUT_OF_RANGE = 2002 + INVALID_RANKING_VALUE = 2003 + WORK_PACKAGE_NOT_FOUND = 2004 + WORK_PACKAGE_EXPIRED = 2005 + WORK_PACKAGE_PAYLOAD_TYPE_MISMATCH = 2006 + INVALID_TASK_TYPE = 2007 + USER_NOT_SPECIFIED = 2008 diff --git a/backend/oasst_backend/exceptions.py b/backend/oasst_backend/exceptions.py index c614de88..eba6c543 100644 --- a/backend/oasst_backend/exceptions.py +++ b/backend/oasst_backend/exceptions.py @@ -1,6 +1,40 @@ # -*- coding: utf-8 -*- -import oasst_backend.error_codes as error_codes # noqa: F401 -from starlette.status import HTTP_400_BAD_REQUEST +from enum import IntEnum +from http import HTTPStatus + + +class OasstErrorCode(IntEnum): + """ + Error codes of the Open-Assistant backend API. + + Ranges: + 0-1000: general errors + 1000-2000: tasks endpoint + 2000-3000: prompt_repository + """ + + # 0-1000: general errors + GENERIC_ERROR = 0 + DATABASE_URI_NOT_SET = 1 + API_CLIENT_NOT_AUTHORIZED = 2 + + # 1000-2000: tasks endpoint + TASK_INVALID_REQUEST_TYPE = 1000 + TASK_ACK_FAILED = 1001 + TASK_INVALID_RESPONSE_TYPE = 1002 + TASK_INTERACTION_REQUEST_FAILED = 1003 + TASK_GENERATION_FAILED = 1004 + + # 2000-3000: prompt_repository + INVALID_POST_ID = 2000 + POST_NOT_FOUND = 2001 + RATING_OUT_OF_RANGE = 2002 + INVALID_RANKING_VALUE = 2003 + WORK_PACKAGE_NOT_FOUND = 2004 + WORK_PACKAGE_EXPIRED = 2005 + WORK_PACKAGE_PAYLOAD_TYPE_MISMATCH = 2006 + INVALID_TASK_TYPE = 2007 + USER_NOT_SPECIFIED = 2008 class OasstError(Exception): @@ -8,9 +42,9 @@ class OasstError(Exception): message: str error_code: int - http_status_code: int + http_status_code: HTTPStatus - def __init__(self, message: str, error_code: int, http_status_code: int = HTTP_400_BAD_REQUEST): + def __init__(self, message: str, error_code: OasstErrorCode, http_status_code: HTTPStatus = HTTPStatus.BAD_REQUEST): super().__init__(message, error_code, http_status_code) # make excetpion picklable (fill args member) self.message = message self.error_code = error_code diff --git a/backend/oasst_backend/prompt_repository.py b/backend/oasst_backend/prompt_repository.py index eb2ce8b6..bce69420 100644 --- a/backend/oasst_backend/prompt_repository.py +++ b/backend/oasst_backend/prompt_repository.py @@ -5,7 +5,7 @@ from uuid import UUID, uuid4 import oasst_backend.models.db_payload as db_payload from loguru import logger -from oasst_backend.exceptions import OasstError, error_codes +from oasst_backend.exceptions import OasstError, OasstErrorCode from oasst_backend.journal_writer import JournalWriter from oasst_backend.models import ApiClient, Person, Post, PostReaction, TextLabels, WorkPackage from oasst_backend.models.payload_column_type import PayloadContainer @@ -53,9 +53,9 @@ class PromptRepository: def validate_post_id(self, post_id: str) -> None: if not isinstance(post_id, str): - raise OasstError(f"post_id must be string, not {type(post_id)}", error_codes.INVALID_POST_ID) + raise OasstError(f"post_id must be string, not {type(post_id)}", OasstErrorCode.INVALID_POST_ID) if not post_id: - raise OasstError("post_id must not be empty", error_codes.INVALID_POST_ID) + raise OasstError("post_id must not be empty", OasstErrorCode.INVALID_POST_ID) def bind_frontend_post_id(self, task_id: UUID, post_id: str): self.validate_post_id(post_id) @@ -67,9 +67,9 @@ class PromptRepository: .first() ) if work_pack is None: - raise OasstError(f"WorkPackage for task {task_id} not found", error_codes.WORK_PACKAGE_NOT_FOUND) + raise OasstError(f"WorkPackage for task {task_id} not found", OasstErrorCode.WORK_PACKAGE_NOT_FOUND) if work_pack.expiry_date is not None and datetime.utcnow() > work_pack.expiry_date: - raise OasstError("WorkPackage already expired.", error_codes.WORK_PACKAGE_EXPIRED) + raise OasstError("WorkPackage already expired.", OasstErrorCode.WORK_PACKAGE_EXPIRED) # ToDo: check race-condition, transaction @@ -106,7 +106,7 @@ class PromptRepository: .one_or_none() ) if fail_if_missing and post is None: - raise OasstError(f"Post with post_id {frontend_post_id} not found.", error_codes.POST_NOT_FOUND) + raise OasstError(f"Post with post_id {frontend_post_id} not found.", OasstErrorCode.POST_NOT_FOUND) return post def fetch_workpackage_by_postid(self, post_id: str) -> WorkPackage: @@ -135,7 +135,7 @@ class PromptRepository: ) if parent_post is None: - raise OasstError(f"Post for post_id {reply.post_id} not found.", error_codes.POST_NOT_FOUND) + raise OasstError(f"Post for post_id {reply.post_id} not found.", OasstErrorCode.POST_NOT_FOUND) # create reply post user_post_id = uuid4() @@ -159,12 +159,13 @@ class PromptRepository: if type(work_payload) != db_payload.RateSummaryPayload: raise OasstError( f"work_package payload type mismatch: {type(work_payload)=} != {db_payload.RateSummaryPayload}", - error_codes.WORK_PACKAGE_PAYLOAD_TYPE_MISMATCH, + OasstErrorCode.WORK_PACKAGE_PAYLOAD_TYPE_MISMATCH, ) if rating.rating < work_payload.scale.min or rating.rating > work_payload.scale.max: raise OasstError( - f"Invalid rating value: {rating.rating=} not in {work_payload.scale=}", error_codes.RATING_OUT_OF_RANGE + f"Invalid rating value: {rating.rating=} not in {work_payload.scale=}", + OasstErrorCode.RATING_OUT_OF_RANGE, ) # store reaction to post @@ -191,7 +192,7 @@ class PromptRepository: if sorted(ranking.ranking) != list(range(num_replies)): raise OasstError( f"Invalid ranking submitted. Each reply index must appear exactly once ({num_replies=}).", - error_codes.INVALID_RANKING_VALUE, + OasstErrorCode.INVALID_RANKING_VALUE, ) # store reaction to post @@ -208,7 +209,7 @@ class PromptRepository: if sorted(ranking.ranking) != list(range(num_prompts := len(work_payload.prompts))): raise OasstError( f"Invalid ranking submitted. Each reply index must appear exactly once ({num_prompts=}).", - error_codes.INVALID_RANKING_VALUE, + OasstErrorCode.INVALID_RANKING_VALUE, ) # store reaction to post @@ -223,7 +224,7 @@ class PromptRepository: case _: raise OasstError( f"work_package payload type mismatch: {type(work_payload)=} != {db_payload.RankConversationRepliesPayload}", - error_codes.WORK_PACKAGE_PAYLOAD_TYPE_MISMATCH, + OasstErrorCode.WORK_PACKAGE_PAYLOAD_TYPE_MISMATCH, ) def store_task(self, task: protocol_schema.Task) -> WorkPackage: @@ -260,7 +261,7 @@ class PromptRepository: ) case _: - raise OasstError(f"Invalid task type: {type(task)=}", error_codes.INVALID_TASK_TYPE) + raise OasstError(f"Invalid task type: {type(task)=}", OasstErrorCode.INVALID_TASK_TYPE) wp = self.insert_work_package(payload=payload, id=task.id) assert wp.id == task.id @@ -317,7 +318,7 @@ class PromptRepository: def insert_reaction(self, post_id: UUID, payload: db_payload.ReactionPayload) -> PostReaction: if self.person_id is None: - raise OasstError("User required", error_codes.USER_NOT_SPECIFIED) + raise OasstError("User required", OasstErrorCode.USER_NOT_SPECIFIED) container = PayloadContainer(payload=payload) reaction = PostReaction( diff --git a/scripts/frontend-development/run-bot-local.sh b/scripts/frontend-development/run-bot-local.sh index 7308c541..56833b0a 100755 --- a/scripts/frontend-development/run-bot-local.sh +++ b/scripts/frontend-development/run-bot-local.sh @@ -2,7 +2,7 @@ parent_path=$( cd "$(dirname "${BASH_SOURCE[0]}")" ; pwd -P ) # switch to bot directory -pushd "$parent_path/../../bot" +pushd "$parent_path/../../discord-bot" python3 __main__.py From 30e65b250fbbbd278ab5af256e1135b9a89239ff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20K=C3=B6pf?= Date: Wed, 28 Dec 2022 22:41:39 +0100 Subject: [PATCH 4/4] remove old error_codes file --- backend/oasst_backend/error_codes.py | 36 ---------------------------- 1 file changed, 36 deletions(-) delete mode 100644 backend/oasst_backend/error_codes.py diff --git a/backend/oasst_backend/error_codes.py b/backend/oasst_backend/error_codes.py deleted file mode 100644 index 7f0bb541..00000000 --- a/backend/oasst_backend/error_codes.py +++ /dev/null @@ -1,36 +0,0 @@ -# -*- coding: utf-8 -*- -from enum import IntEnum - - -class OasstErrorCode(IntEnum): - """ - Error codes of the Open-Assistant backend API. - - Ranges: - 0-1000: general errors - 1000-2000: tasks endpoint - 2000-3000: prompt_repository - """ - - # 0-1000: general errors - GENERIC_ERROR = 0 - DATABASE_URI_NOT_SET = 1 - API_CLIENT_NOT_AUTHORIZED = 2 - - # 1000-2000: tasks endpoint - TASK_INVALID_REQUEST_TYPE = 1000 - TASK_ACK_FAILED = 1001 - TASK_INVALID_RESPONSE_TYPE = 1002 - TASK_INTERACTION_REQUEST_FAILED = 1003 - TASK_GENERATION_FAILED = 1004 - - # 2000-3000: prompt_repository - INVALID_POST_ID = 2000 - POST_NOT_FOUND = 2001 - RATING_OUT_OF_RANGE = 2002 - INVALID_RANKING_VALUE = 2003 - WORK_PACKAGE_NOT_FOUND = 2004 - WORK_PACKAGE_EXPIRED = 2005 - WORK_PACKAGE_PAYLOAD_TYPE_MISMATCH = 2006 - INVALID_TASK_TYPE = 2007 - USER_NOT_SPECIFIED = 2008