diff --git a/oasst-shared/oasst_shared/api_client.py b/oasst-shared/oasst_shared/api_client.py index 3a575e47..09e3739a 100644 --- a/oasst-shared/oasst_shared/api_client.py +++ b/oasst-shared/oasst_shared/api_client.py @@ -1,12 +1,15 @@ """API Client for interacting with the OASST backend.""" import enum import typing as t +from http import HTTPStatus from typing import Optional, Type from uuid import UUID import aiohttp from loguru import logger +from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode from oasst_shared.schemas import protocol as protocol_schema +from pydantic import ValidationError # TODO: Move to `protocol`? @@ -27,7 +30,7 @@ class TaskType(str, enum.Enum): class OasstApiClient: """API Client for interacting with the OASST backend.""" - def __init__(self, backend_url: str, api_key: str): + def __init__(self, backend_url: str, api_key: str, session: Optional[aiohttp.ClientSession] = None): """Create a new OasstApiClient. Args: @@ -35,8 +38,12 @@ class OasstApiClient: backend_url (str): The base backend URL. api_key (str): The API key to use for authentication. """ - logger.debug("Opening OasstApiClient session") - self.session = aiohttp.ClientSession() + + if session is None: + logger.debug("Opening OasstApiClient session") + session = aiohttp.ClientSession() + + self.session = session self.backend_url = backend_url self.api_key = api_key @@ -56,7 +63,29 @@ class OasstApiClient: """Make a POST request to the backend.""" logger.debug(f"POST {self.backend_url}{path} DATA: {data}") response = await self.session.post(f"{self.backend_url}{path}", json=data, headers={"X-API-Key": self.api_key}) - response.raise_for_status() + + # If the response is not a 2XX, check to see + # if the json has the fields to create an + # OasstError. + if response.status >= 300: + data = await response.json() + try: + oasst_error = protocol_schema.OasstErrorResponse(**data) + raise OasstError( + error_code=oasst_error.error_code, + message=oasst_error.message, + ) + except ValidationError as e: + logger.debug(f"Got error from API but could not parse: {e}") + + raw_response = await response.text() + logger.debug(f"Raw response: {raw_response}") + + raise OasstError( + raw_response, + OasstErrorCode.GENERIC_ERROR, + HTTPStatus(response.status), + ) return await response.json() def _parse_task(self, data: Optional[dict[str, t.Any]]) -> protocol_schema.Task: