diff --git a/backend/main.py b/backend/main.py index e6b3bdce..cb682a9f 100644 --- a/backend/main.py +++ b/backend/main.py @@ -26,8 +26,13 @@ app = fastapi.FastAPI(title=settings.PROJECT_NAME, openapi_url=f"{settings.API_V @app.exception_handler(OasstError) async def oasst_exception_handler(request: fastapi.Request, ex: OasstError): logger.error(f"{request.method} {request.url} failed: {repr(ex)}") + return fastapi.responses.JSONResponse( - status_code=int(ex.http_status_code), content={"message": ex.message, "error_code": ex.error_code} + status_code=int(ex.http_status_code), + content=protocol_schema.OasstErrorResponse( + message=ex.message, + error_code=OasstErrorCode(ex.error_code), + ).dict(), ) diff --git a/discord-bot/requirements.txt b/discord-bot/requirements.txt index b65b5e15..f6943cb0 100644 --- a/discord-bot/requirements.txt +++ b/discord-bot/requirements.txt @@ -1,5 +1,3 @@ -aiohttp # http client -aiohttp[speedups] # speedups for aiohttp aiosqlite # database hikari # discord framework hikari-lightbulb # command handler diff --git a/oasst-shared/oasst_shared/api_client.py b/oasst-shared/oasst_shared/api_client.py index 3a575e47..404521db 100644 --- a/oasst-shared/oasst_shared/api_client.py +++ b/oasst-shared/oasst_shared/api_client.py @@ -1,12 +1,15 @@ """API Client for interacting with the OASST backend.""" import enum import typing as t +from http import HTTPStatus from typing import Optional, Type from uuid import UUID import aiohttp from loguru import logger +from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode from oasst_shared.schemas import protocol as protocol_schema +from pydantic import ValidationError # TODO: Move to `protocol`? @@ -27,7 +30,7 @@ class TaskType(str, enum.Enum): class OasstApiClient: """API Client for interacting with the OASST backend.""" - def __init__(self, backend_url: str, api_key: str): + def __init__(self, backend_url: str, api_key: str, session: Optional[aiohttp.ClientSession] = None): """Create a new OasstApiClient. Args: @@ -35,8 +38,12 @@ class OasstApiClient: backend_url (str): The base backend URL. api_key (str): The API key to use for authentication. """ - logger.debug("Opening OasstApiClient session") - self.session = aiohttp.ClientSession() + + if session is None: + logger.debug("Opening OasstApiClient session") + session = aiohttp.ClientSession() + + self.session = session self.backend_url = backend_url self.api_key = api_key @@ -56,7 +63,33 @@ class OasstApiClient: """Make a POST request to the backend.""" logger.debug(f"POST {self.backend_url}{path} DATA: {data}") response = await self.session.post(f"{self.backend_url}{path}", json=data, headers={"X-API-Key": self.api_key}) - response.raise_for_status() + + # If the response is not a 2XX, check to see + # if the json has the fields to create an + # OasstError. + if response.status >= 300: + data = await response.json() + try: + oasst_error = protocol_schema.OasstErrorResponse(**(data or {})) + raise OasstError( + error_code=oasst_error.error_code, + message=oasst_error.message, + ) + except ValidationError as e: + logger.debug(f"Got error from API but could not parse: {e}") + + raw_response = await response.text() + logger.debug(f"Raw response: {raw_response}") + + raise OasstError( + raw_response, + OasstErrorCode.GENERIC_ERROR, + HTTPStatus(response.status), + ) + + if response.status == 204: + # No content + return None return await response.json() def _parse_task(self, data: Optional[dict[str, t.Any]]) -> protocol_schema.Task: diff --git a/oasst-shared/oasst_shared/schemas/protocol.py b/oasst-shared/oasst_shared/schemas/protocol.py index 652a2c78..83375d8f 100644 --- a/oasst-shared/oasst_shared/schemas/protocol.py +++ b/oasst-shared/oasst_shared/schemas/protocol.py @@ -4,6 +4,7 @@ from typing import List, Literal, Optional, Union from uuid import UUID, uuid4 import pydantic +from oasst_shared.exceptions import OasstErrorCode from pydantic import BaseModel, Field @@ -293,3 +294,10 @@ class UserScore(BaseModel): class LeaderboardStats(BaseModel): leaderboard: List[UserScore] + + +class OasstErrorResponse(BaseModel): + """The format of an error response from the OASST API.""" + + error_code: OasstErrorCode + message: str diff --git a/oasst-shared/setup.py b/oasst-shared/setup.py index 22fbcc60..a04b34e8 100644 --- a/oasst-shared/setup.py +++ b/oasst-shared/setup.py @@ -11,5 +11,7 @@ setup( author="OASST Team", install_requires=[ "pydantic==1.9.1", + "aiohttp==3.8.3", + "aiohttp[speedups]", ], ) diff --git a/oasst-shared/tests/test_oasst_api_client.py b/oasst-shared/tests/test_oasst_api_client.py index be757e8f..fdb743ce 100644 --- a/oasst-shared/tests/test_oasst_api_client.py +++ b/oasst-shared/tests/test_oasst_api_client.py @@ -1,12 +1,21 @@ +from typing import Any +from unittest import mock from uuid import uuid4 +import aiohttp import pytest from oasst_shared.api_client import OasstApiClient +from oasst_shared.exceptions import OasstError, OasstErrorCode from oasst_shared.schemas import protocol as protocol_schema @pytest.fixture def oasst_api_client_mocked(): + """ + A an oasst_api_client pointed at the mocked backend. + Relies on ./scripts/backend-development/start-mock-server.sh + being run. + """ client = OasstApiClient(backend_url="http://localhost:8080", api_key="123") yield client # TODO The fixture should close this connection, but there seems to be a bug @@ -15,6 +24,30 @@ def oasst_api_client_mocked(): # await client.close() +class MockClientSession(aiohttp.ClientSession): + response: Any + + def set_response(self, response: Any): + self.response = response + + async def post(self, *args, **kwargs): + return self.response + + +@pytest.fixture +def mock_http_session(): + yield MockClientSession() + + +@pytest.fixture +def oasst_api_client_fake_http(mock_http_session): + """ + An oasst_api_client that uses a mocked http session. No real requests are made. + """ + client = OasstApiClient(backend_url="http://localhost:8080", api_key="123", session=mock_http_session) + yield client + + @pytest.mark.asyncio @pytest.mark.parametrize("task_type", protocol_schema.TaskRequestType) async def test_can_fetch_task(task_type: protocol_schema.TaskRequestType, oasst_api_client_mocked: OasstApiClient): @@ -49,3 +82,47 @@ async def test_can_post_interaction(oasst_api_client_mocked: OasstApiClient): ) is not None ) + + +@pytest.mark.asyncio +async def test_can_handle_oasst_error_from_api( + oasst_api_client_fake_http: OasstApiClient, + mock_http_session: MockClientSession, +): + # Return a 400 response with an OasstErrorResponse body + response_body = protocol_schema.OasstErrorResponse( + error_code=OasstErrorCode.GENERIC_ERROR, + message="Some error", + ) + status_code = 400 + + mock_http_session.set_response( + mock.AsyncMock( + status=status_code, + text=mock.AsyncMock(return_value=response_body.json()), + json=mock.AsyncMock(return_value=response_body.dict()), + ) + ) + + with pytest.raises(OasstError): + await oasst_api_client_fake_http.post("/some-path", data={}) + + +@pytest.mark.asyncio +async def test_can_handle_unknown_error_from_api( + oasst_api_client_fake_http: OasstApiClient, + mock_http_session: MockClientSession, +): + response_body = "Internal Server Error" + status_code = 500 + + mock_http_session.set_response( + mock.AsyncMock( + status=status_code, + text=mock.AsyncMock(return_value=response_body), + json=mock.AsyncMock(return_value=None), + ) + ) + + with pytest.raises(OasstError): + await oasst_api_client_fake_http.post("/some-path", data={}) diff --git a/scripts/oasst-shared-development/test.sh b/scripts/oasst-shared-development/test.sh index e9324196..fcf94beb 100755 --- a/scripts/oasst-shared-development/test.sh +++ b/scripts/oasst-shared-development/test.sh @@ -4,6 +4,8 @@ parent_path=$( cd "$(dirname "${BASH_SOURCE[0]}")" ; pwd -P ) # switch to backend directory pushd "$parent_path/../../oasst-shared" +set -xe + pytest . popd