feat: handle OasstError in OasstApiClient

This commit is contained in:
Jack Michaud
2023-01-02 19:06:30 -05:00
parent 5bb9a397b4
commit fe21732f8d
+33 -4
View File
@@ -1,12 +1,15 @@
"""API Client for interacting with the OASST backend."""
import enum
import typing as t
from http import HTTPStatus
from typing import Optional, Type
from uuid import UUID
import aiohttp
from loguru import logger
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
from pydantic import ValidationError
# TODO: Move to `protocol`?
@@ -27,7 +30,7 @@ class TaskType(str, enum.Enum):
class OasstApiClient:
"""API Client for interacting with the OASST backend."""
def __init__(self, backend_url: str, api_key: str):
def __init__(self, backend_url: str, api_key: str, session: Optional[aiohttp.ClientSession] = None):
"""Create a new OasstApiClient.
Args:
@@ -35,8 +38,12 @@ class OasstApiClient:
backend_url (str): The base backend URL.
api_key (str): The API key to use for authentication.
"""
logger.debug("Opening OasstApiClient session")
self.session = aiohttp.ClientSession()
if session is None:
logger.debug("Opening OasstApiClient session")
session = aiohttp.ClientSession()
self.session = session
self.backend_url = backend_url
self.api_key = api_key
@@ -56,7 +63,29 @@ class OasstApiClient:
"""Make a POST request to the backend."""
logger.debug(f"POST {self.backend_url}{path} DATA: {data}")
response = await self.session.post(f"{self.backend_url}{path}", json=data, headers={"X-API-Key": self.api_key})
response.raise_for_status()
# If the response is not a 2XX, check to see
# if the json has the fields to create an
# OasstError.
if response.status >= 300:
data = await response.json()
try:
oasst_error = protocol_schema.OasstErrorResponse(**data)
raise OasstError(
error_code=oasst_error.error_code,
message=oasst_error.message,
)
except ValidationError as e:
logger.debug(f"Got error from API but could not parse: {e}")
raw_response = await response.text()
logger.debug(f"Raw response: {raw_response}")
raise OasstError(
raw_response,
OasstErrorCode.GENERIC_ERROR,
HTTPStatus(response.status),
)
return await response.json()
def _parse_task(self, data: Optional[dict[str, t.Any]]) -> protocol_schema.Task: