Merge pull request #300 from jack-michaud/jm/oasst-api-client-handle-error

feat: handle OasstError in OasstApiClient
This commit is contained in:
Yannic Kilcher
2023-01-03 08:55:26 +01:00
committed by GitHub
7 changed files with 132 additions and 7 deletions
+6 -1
View File
@@ -26,8 +26,13 @@ app = fastapi.FastAPI(title=settings.PROJECT_NAME, openapi_url=f"{settings.API_V
@app.exception_handler(OasstError)
async def oasst_exception_handler(request: fastapi.Request, ex: OasstError):
logger.error(f"{request.method} {request.url} failed: {repr(ex)}")
return fastapi.responses.JSONResponse(
status_code=int(ex.http_status_code), content={"message": ex.message, "error_code": ex.error_code}
status_code=int(ex.http_status_code),
content=protocol_schema.OasstErrorResponse(
message=ex.message,
error_code=OasstErrorCode(ex.error_code),
).dict(),
)
-2
View File
@@ -1,5 +1,3 @@
aiohttp # http client
aiohttp[speedups] # speedups for aiohttp
aiosqlite # database
hikari # discord framework
hikari-lightbulb # command handler
+37 -4
View File
@@ -1,12 +1,15 @@
"""API Client for interacting with the OASST backend."""
import enum
import typing as t
from http import HTTPStatus
from typing import Optional, Type
from uuid import UUID
import aiohttp
from loguru import logger
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
from pydantic import ValidationError
# TODO: Move to `protocol`?
@@ -27,7 +30,7 @@ class TaskType(str, enum.Enum):
class OasstApiClient:
"""API Client for interacting with the OASST backend."""
def __init__(self, backend_url: str, api_key: str):
def __init__(self, backend_url: str, api_key: str, session: Optional[aiohttp.ClientSession] = None):
"""Create a new OasstApiClient.
Args:
@@ -35,8 +38,12 @@ class OasstApiClient:
backend_url (str): The base backend URL.
api_key (str): The API key to use for authentication.
"""
logger.debug("Opening OasstApiClient session")
self.session = aiohttp.ClientSession()
if session is None:
logger.debug("Opening OasstApiClient session")
session = aiohttp.ClientSession()
self.session = session
self.backend_url = backend_url
self.api_key = api_key
@@ -56,7 +63,33 @@ class OasstApiClient:
"""Make a POST request to the backend."""
logger.debug(f"POST {self.backend_url}{path} DATA: {data}")
response = await self.session.post(f"{self.backend_url}{path}", json=data, headers={"X-API-Key": self.api_key})
response.raise_for_status()
# If the response is not a 2XX, check to see
# if the json has the fields to create an
# OasstError.
if response.status >= 300:
data = await response.json()
try:
oasst_error = protocol_schema.OasstErrorResponse(**(data or {}))
raise OasstError(
error_code=oasst_error.error_code,
message=oasst_error.message,
)
except ValidationError as e:
logger.debug(f"Got error from API but could not parse: {e}")
raw_response = await response.text()
logger.debug(f"Raw response: {raw_response}")
raise OasstError(
raw_response,
OasstErrorCode.GENERIC_ERROR,
HTTPStatus(response.status),
)
if response.status == 204:
# No content
return None
return await response.json()
def _parse_task(self, data: Optional[dict[str, t.Any]]) -> protocol_schema.Task:
@@ -4,6 +4,7 @@ from typing import List, Literal, Optional, Union
from uuid import UUID, uuid4
import pydantic
from oasst_shared.exceptions import OasstErrorCode
from pydantic import BaseModel, Field
@@ -293,3 +294,10 @@ class UserScore(BaseModel):
class LeaderboardStats(BaseModel):
leaderboard: List[UserScore]
class OasstErrorResponse(BaseModel):
"""The format of an error response from the OASST API."""
error_code: OasstErrorCode
message: str
+2
View File
@@ -11,5 +11,7 @@ setup(
author="OASST Team",
install_requires=[
"pydantic==1.9.1",
"aiohttp==3.8.3",
"aiohttp[speedups]",
],
)
@@ -1,12 +1,21 @@
from typing import Any
from unittest import mock
from uuid import uuid4
import aiohttp
import pytest
from oasst_shared.api_client import OasstApiClient
from oasst_shared.exceptions import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
@pytest.fixture
def oasst_api_client_mocked():
"""
A an oasst_api_client pointed at the mocked backend.
Relies on ./scripts/backend-development/start-mock-server.sh
being run.
"""
client = OasstApiClient(backend_url="http://localhost:8080", api_key="123")
yield client
# TODO The fixture should close this connection, but there seems to be a bug
@@ -15,6 +24,30 @@ def oasst_api_client_mocked():
# await client.close()
class MockClientSession(aiohttp.ClientSession):
response: Any
def set_response(self, response: Any):
self.response = response
async def post(self, *args, **kwargs):
return self.response
@pytest.fixture
def mock_http_session():
yield MockClientSession()
@pytest.fixture
def oasst_api_client_fake_http(mock_http_session):
"""
An oasst_api_client that uses a mocked http session. No real requests are made.
"""
client = OasstApiClient(backend_url="http://localhost:8080", api_key="123", session=mock_http_session)
yield client
@pytest.mark.asyncio
@pytest.mark.parametrize("task_type", protocol_schema.TaskRequestType)
async def test_can_fetch_task(task_type: protocol_schema.TaskRequestType, oasst_api_client_mocked: OasstApiClient):
@@ -49,3 +82,47 @@ async def test_can_post_interaction(oasst_api_client_mocked: OasstApiClient):
)
is not None
)
@pytest.mark.asyncio
async def test_can_handle_oasst_error_from_api(
oasst_api_client_fake_http: OasstApiClient,
mock_http_session: MockClientSession,
):
# Return a 400 response with an OasstErrorResponse body
response_body = protocol_schema.OasstErrorResponse(
error_code=OasstErrorCode.GENERIC_ERROR,
message="Some error",
)
status_code = 400
mock_http_session.set_response(
mock.AsyncMock(
status=status_code,
text=mock.AsyncMock(return_value=response_body.json()),
json=mock.AsyncMock(return_value=response_body.dict()),
)
)
with pytest.raises(OasstError):
await oasst_api_client_fake_http.post("/some-path", data={})
@pytest.mark.asyncio
async def test_can_handle_unknown_error_from_api(
oasst_api_client_fake_http: OasstApiClient,
mock_http_session: MockClientSession,
):
response_body = "Internal Server Error"
status_code = 500
mock_http_session.set_response(
mock.AsyncMock(
status=status_code,
text=mock.AsyncMock(return_value=response_body),
json=mock.AsyncMock(return_value=None),
)
)
with pytest.raises(OasstError):
await oasst_api_client_fake_http.post("/some-path", data={})
+2
View File
@@ -4,6 +4,8 @@ parent_path=$( cd "$(dirname "${BASH_SOURCE[0]}")" ; pwd -P )
# switch to backend directory
pushd "$parent_path/../../oasst-shared"
set -xe
pytest .
popd