From 5bb9a397b494fc0cc7232e0dc852d9b674291e87 Mon Sep 17 00:00:00 2001 From: Jack Michaud Date: Mon, 2 Jan 2023 19:01:29 -0500 Subject: [PATCH 1/9] feat: add OasstErrorResponse to protocols Using a shared protocol to serialize the error in the backend allows clients to use that same protocol to deserialize it. Changes to this protocol will be caught in tests. --- backend/main.py | 7 ++++++- oasst-shared/oasst_shared/schemas/protocol.py | 8 ++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/backend/main.py b/backend/main.py index e6b3bdce..cb682a9f 100644 --- a/backend/main.py +++ b/backend/main.py @@ -26,8 +26,13 @@ app = fastapi.FastAPI(title=settings.PROJECT_NAME, openapi_url=f"{settings.API_V @app.exception_handler(OasstError) async def oasst_exception_handler(request: fastapi.Request, ex: OasstError): logger.error(f"{request.method} {request.url} failed: {repr(ex)}") + return fastapi.responses.JSONResponse( - status_code=int(ex.http_status_code), content={"message": ex.message, "error_code": ex.error_code} + status_code=int(ex.http_status_code), + content=protocol_schema.OasstErrorResponse( + message=ex.message, + error_code=OasstErrorCode(ex.error_code), + ).dict(), ) diff --git a/oasst-shared/oasst_shared/schemas/protocol.py b/oasst-shared/oasst_shared/schemas/protocol.py index 652a2c78..83375d8f 100644 --- a/oasst-shared/oasst_shared/schemas/protocol.py +++ b/oasst-shared/oasst_shared/schemas/protocol.py @@ -4,6 +4,7 @@ from typing import List, Literal, Optional, Union from uuid import UUID, uuid4 import pydantic +from oasst_shared.exceptions import OasstErrorCode from pydantic import BaseModel, Field @@ -293,3 +294,10 @@ class UserScore(BaseModel): class LeaderboardStats(BaseModel): leaderboard: List[UserScore] + + +class OasstErrorResponse(BaseModel): + """The format of an error response from the OASST API.""" + + error_code: OasstErrorCode + message: str From fe21732f8d511bd356f56077831148d23607908a Mon Sep 17 00:00:00 2001 From: Jack Michaud Date: Mon, 2 Jan 2023 19:06:30 -0500 Subject: [PATCH 2/9] feat: handle OasstError in OasstApiClient --- oasst-shared/oasst_shared/api_client.py | 37 ++++++++++++++++++++++--- 1 file changed, 33 insertions(+), 4 deletions(-) diff --git a/oasst-shared/oasst_shared/api_client.py b/oasst-shared/oasst_shared/api_client.py index 3a575e47..09e3739a 100644 --- a/oasst-shared/oasst_shared/api_client.py +++ b/oasst-shared/oasst_shared/api_client.py @@ -1,12 +1,15 @@ """API Client for interacting with the OASST backend.""" import enum import typing as t +from http import HTTPStatus from typing import Optional, Type from uuid import UUID import aiohttp from loguru import logger +from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode from oasst_shared.schemas import protocol as protocol_schema +from pydantic import ValidationError # TODO: Move to `protocol`? @@ -27,7 +30,7 @@ class TaskType(str, enum.Enum): class OasstApiClient: """API Client for interacting with the OASST backend.""" - def __init__(self, backend_url: str, api_key: str): + def __init__(self, backend_url: str, api_key: str, session: Optional[aiohttp.ClientSession] = None): """Create a new OasstApiClient. Args: @@ -35,8 +38,12 @@ class OasstApiClient: backend_url (str): The base backend URL. api_key (str): The API key to use for authentication. """ - logger.debug("Opening OasstApiClient session") - self.session = aiohttp.ClientSession() + + if session is None: + logger.debug("Opening OasstApiClient session") + session = aiohttp.ClientSession() + + self.session = session self.backend_url = backend_url self.api_key = api_key @@ -56,7 +63,29 @@ class OasstApiClient: """Make a POST request to the backend.""" logger.debug(f"POST {self.backend_url}{path} DATA: {data}") response = await self.session.post(f"{self.backend_url}{path}", json=data, headers={"X-API-Key": self.api_key}) - response.raise_for_status() + + # If the response is not a 2XX, check to see + # if the json has the fields to create an + # OasstError. + if response.status >= 300: + data = await response.json() + try: + oasst_error = protocol_schema.OasstErrorResponse(**data) + raise OasstError( + error_code=oasst_error.error_code, + message=oasst_error.message, + ) + except ValidationError as e: + logger.debug(f"Got error from API but could not parse: {e}") + + raw_response = await response.text() + logger.debug(f"Raw response: {raw_response}") + + raise OasstError( + raw_response, + OasstErrorCode.GENERIC_ERROR, + HTTPStatus(response.status), + ) return await response.json() def _parse_task(self, data: Optional[dict[str, t.Any]]) -> protocol_schema.Task: From c2a5e08a32599540bedcf8ea1ebf9f891ef45cf2 Mon Sep 17 00:00:00 2001 From: Jack Michaud Date: Mon, 2 Jan 2023 19:07:01 -0500 Subject: [PATCH 3/9] (WIP) test: add test for handling OasstError --- oasst-shared/tests/test_oasst_api_client.py | 41 +++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/oasst-shared/tests/test_oasst_api_client.py b/oasst-shared/tests/test_oasst_api_client.py index be757e8f..7092a97e 100644 --- a/oasst-shared/tests/test_oasst_api_client.py +++ b/oasst-shared/tests/test_oasst_api_client.py @@ -1,12 +1,20 @@ +from unittest import mock from uuid import uuid4 +import aiohttp import pytest from oasst_shared.api_client import OasstApiClient +from oasst_shared.exceptions import OasstError, OasstErrorCode from oasst_shared.schemas import protocol as protocol_schema @pytest.fixture def oasst_api_client_mocked(): + """ + A an oasst_api_client pointed at the mocked backend. + Relies on ./scripts/backend-development/start-mock-server.sh + being run. + """ client = OasstApiClient(backend_url="http://localhost:8080", api_key="123") yield client # TODO The fixture should close this connection, but there seems to be a bug @@ -15,6 +23,20 @@ def oasst_api_client_mocked(): # await client.close() +@pytest.fixture +def mock_http_session(): + yield mock.AsyncMock(spec=aiohttp.ClientSession) + + +@pytest.fixture +def oasst_api_client_fake_http(mock_http_session): + """ + An oasst_api_client that uses a mocked http session. No real requests are made. + """ + client = OasstApiClient(backend_url="http://localhost:8080", api_key="123", session=mock_http_session) + yield client + + @pytest.mark.asyncio @pytest.mark.parametrize("task_type", protocol_schema.TaskRequestType) async def test_can_fetch_task(task_type: protocol_schema.TaskRequestType, oasst_api_client_mocked: OasstApiClient): @@ -49,3 +71,22 @@ async def test_can_post_interaction(oasst_api_client_mocked: OasstApiClient): ) is not None ) + + +@pytest.mark.asyncio +async def test_can_handle_oasst_error_from_api( + oasst_api_client_fake_http: OasstApiClient, + mock_http_session: mock.AsyncMock, +): + # Return a 400 response with an OasstErrorResponse body + response_body = protocol_schema.OasstErrorResponse( + error_code=OasstErrorCode.GENERIC_ERROR, + message="Some error", + ).json() + status_code = 400 + + mock_http_session.post.return_value.__aenter__.return_value.json.return_value = response_body + mock_http_session.post.return_value.__aenter__.return_value.status = status_code + + with pytest.raises(OasstError): + await oasst_api_client_fake_http.post("/some-path", data={}) From 789593bab74d5ae1ab49c412d055bde02171ff0a Mon Sep 17 00:00:00 2001 From: Jack Michaud Date: Mon, 2 Jan 2023 20:37:34 -0500 Subject: [PATCH 4/9] fix: contract test script should fail when pytest fails --- scripts/oasst-shared-development/test.sh | 2 ++ 1 file changed, 2 insertions(+) diff --git a/scripts/oasst-shared-development/test.sh b/scripts/oasst-shared-development/test.sh index e9324196..fcf94beb 100755 --- a/scripts/oasst-shared-development/test.sh +++ b/scripts/oasst-shared-development/test.sh @@ -4,6 +4,8 @@ parent_path=$( cd "$(dirname "${BASH_SOURCE[0]}")" ; pwd -P ) # switch to backend directory pushd "$parent_path/../../oasst-shared" +set -xe + pytest . popd From fbcb0a09e665a30f300312a486b5f5e37b3102de Mon Sep 17 00:00:00 2001 From: Jack Michaud Date: Mon, 2 Jan 2023 20:38:10 -0500 Subject: [PATCH 5/9] chore: move aiohttp client into oasst-shared --- discord-bot/requirements.txt | 2 -- oasst-shared/setup.py | 2 ++ 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/discord-bot/requirements.txt b/discord-bot/requirements.txt index b65b5e15..f6943cb0 100644 --- a/discord-bot/requirements.txt +++ b/discord-bot/requirements.txt @@ -1,5 +1,3 @@ -aiohttp # http client -aiohttp[speedups] # speedups for aiohttp aiosqlite # database hikari # discord framework hikari-lightbulb # command handler diff --git a/oasst-shared/setup.py b/oasst-shared/setup.py index 22fbcc60..502d3733 100644 --- a/oasst-shared/setup.py +++ b/oasst-shared/setup.py @@ -11,5 +11,7 @@ setup( author="OASST Team", install_requires=[ "pydantic==1.9.1", + "aiohttp", + "aiohttp[speedups]", ], ) From bc796b70ba0efcbe8c8ebe7a2e269fea6e555272 Mon Sep 17 00:00:00 2001 From: Jack Michaud Date: Mon, 2 Jan 2023 20:59:32 -0500 Subject: [PATCH 6/9] test: finish test for handling oasst error --- oasst-shared/tests/test_oasst_api_client.py | 26 +++++++++++++++++---- 1 file changed, 21 insertions(+), 5 deletions(-) diff --git a/oasst-shared/tests/test_oasst_api_client.py b/oasst-shared/tests/test_oasst_api_client.py index 7092a97e..8649d891 100644 --- a/oasst-shared/tests/test_oasst_api_client.py +++ b/oasst-shared/tests/test_oasst_api_client.py @@ -1,3 +1,4 @@ +from typing import Any from unittest import mock from uuid import uuid4 @@ -23,9 +24,19 @@ def oasst_api_client_mocked(): # await client.close() +class MockClientSession(aiohttp.ClientSession): + response: Any + + def set_response(self, response: Any): + self.response = response + + async def post(self, *args, **kwargs): + return self.response + + @pytest.fixture def mock_http_session(): - yield mock.AsyncMock(spec=aiohttp.ClientSession) + yield MockClientSession() @pytest.fixture @@ -76,17 +87,22 @@ async def test_can_post_interaction(oasst_api_client_mocked: OasstApiClient): @pytest.mark.asyncio async def test_can_handle_oasst_error_from_api( oasst_api_client_fake_http: OasstApiClient, - mock_http_session: mock.AsyncMock, + mock_http_session: MockClientSession, ): # Return a 400 response with an OasstErrorResponse body response_body = protocol_schema.OasstErrorResponse( error_code=OasstErrorCode.GENERIC_ERROR, message="Some error", - ).json() + ) status_code = 400 - mock_http_session.post.return_value.__aenter__.return_value.json.return_value = response_body - mock_http_session.post.return_value.__aenter__.return_value.status = status_code + mock_http_session.set_response( + mock.AsyncMock( + status=status_code, + text=mock.AsyncMock(return_value=response_body.json()), + json=mock.AsyncMock(return_value=response_body.dict()), + ) + ) with pytest.raises(OasstError): await oasst_api_client_fake_http.post("/some-path", data={}) From f66ad30c53981a2c0666d1a6409e5a81f83650bf Mon Sep 17 00:00:00 2001 From: Jack Michaud Date: Mon, 2 Jan 2023 21:00:23 -0500 Subject: [PATCH 7/9] fix: handle empty 204 response --- oasst-shared/oasst_shared/api_client.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/oasst-shared/oasst_shared/api_client.py b/oasst-shared/oasst_shared/api_client.py index 09e3739a..e8439517 100644 --- a/oasst-shared/oasst_shared/api_client.py +++ b/oasst-shared/oasst_shared/api_client.py @@ -86,6 +86,10 @@ class OasstApiClient: OasstErrorCode.GENERIC_ERROR, HTTPStatus(response.status), ) + + if response.status == 204: + # No content + return None return await response.json() def _parse_task(self, data: Optional[dict[str, t.Any]]) -> protocol_schema.Task: From 3dcfe7014e13fa327631a60d8b891fb11103a04e Mon Sep 17 00:00:00 2001 From: Jack Michaud Date: Mon, 2 Jan 2023 21:06:00 -0500 Subject: [PATCH 8/9] test: add test for unhandled api error --- oasst-shared/oasst_shared/api_client.py | 2 +- oasst-shared/tests/test_oasst_api_client.py | 20 ++++++++++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/oasst-shared/oasst_shared/api_client.py b/oasst-shared/oasst_shared/api_client.py index e8439517..404521db 100644 --- a/oasst-shared/oasst_shared/api_client.py +++ b/oasst-shared/oasst_shared/api_client.py @@ -70,7 +70,7 @@ class OasstApiClient: if response.status >= 300: data = await response.json() try: - oasst_error = protocol_schema.OasstErrorResponse(**data) + oasst_error = protocol_schema.OasstErrorResponse(**(data or {})) raise OasstError( error_code=oasst_error.error_code, message=oasst_error.message, diff --git a/oasst-shared/tests/test_oasst_api_client.py b/oasst-shared/tests/test_oasst_api_client.py index 8649d891..fdb743ce 100644 --- a/oasst-shared/tests/test_oasst_api_client.py +++ b/oasst-shared/tests/test_oasst_api_client.py @@ -106,3 +106,23 @@ async def test_can_handle_oasst_error_from_api( with pytest.raises(OasstError): await oasst_api_client_fake_http.post("/some-path", data={}) + + +@pytest.mark.asyncio +async def test_can_handle_unknown_error_from_api( + oasst_api_client_fake_http: OasstApiClient, + mock_http_session: MockClientSession, +): + response_body = "Internal Server Error" + status_code = 500 + + mock_http_session.set_response( + mock.AsyncMock( + status=status_code, + text=mock.AsyncMock(return_value=response_body), + json=mock.AsyncMock(return_value=None), + ) + ) + + with pytest.raises(OasstError): + await oasst_api_client_fake_http.post("/some-path", data={}) From c8aba77a48bd4762a81bf5dacaa16433f0960f0f Mon Sep 17 00:00:00 2001 From: Jack Michaud Date: Mon, 2 Jan 2023 21:11:14 -0500 Subject: [PATCH 9/9] chore: pin aiohttp version --- oasst-shared/setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/oasst-shared/setup.py b/oasst-shared/setup.py index 502d3733..a04b34e8 100644 --- a/oasst-shared/setup.py +++ b/oasst-shared/setup.py @@ -11,7 +11,7 @@ setup( author="OASST Team", install_requires=[ "pydantic==1.9.1", - "aiohttp", + "aiohttp==3.8.3", "aiohttp[speedups]", ], )