mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-07-05 17:30:48 +08:00
feat: handle OasstError in OasstApiClient
This commit is contained in:
@@ -1,12 +1,15 @@
|
||||
"""API Client for interacting with the OASST backend."""
|
||||
import enum
|
||||
import typing as t
|
||||
from http import HTTPStatus
|
||||
from typing import Optional, Type
|
||||
from uuid import UUID
|
||||
|
||||
import aiohttp
|
||||
from loguru import logger
|
||||
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from pydantic import ValidationError
|
||||
|
||||
|
||||
# TODO: Move to `protocol`?
|
||||
@@ -27,7 +30,7 @@ class TaskType(str, enum.Enum):
|
||||
class OasstApiClient:
|
||||
"""API Client for interacting with the OASST backend."""
|
||||
|
||||
def __init__(self, backend_url: str, api_key: str):
|
||||
def __init__(self, backend_url: str, api_key: str, session: Optional[aiohttp.ClientSession] = None):
|
||||
"""Create a new OasstApiClient.
|
||||
|
||||
Args:
|
||||
@@ -35,8 +38,12 @@ class OasstApiClient:
|
||||
backend_url (str): The base backend URL.
|
||||
api_key (str): The API key to use for authentication.
|
||||
"""
|
||||
logger.debug("Opening OasstApiClient session")
|
||||
self.session = aiohttp.ClientSession()
|
||||
|
||||
if session is None:
|
||||
logger.debug("Opening OasstApiClient session")
|
||||
session = aiohttp.ClientSession()
|
||||
|
||||
self.session = session
|
||||
self.backend_url = backend_url
|
||||
self.api_key = api_key
|
||||
|
||||
@@ -56,7 +63,29 @@ class OasstApiClient:
|
||||
"""Make a POST request to the backend."""
|
||||
logger.debug(f"POST {self.backend_url}{path} DATA: {data}")
|
||||
response = await self.session.post(f"{self.backend_url}{path}", json=data, headers={"X-API-Key": self.api_key})
|
||||
response.raise_for_status()
|
||||
|
||||
# If the response is not a 2XX, check to see
|
||||
# if the json has the fields to create an
|
||||
# OasstError.
|
||||
if response.status >= 300:
|
||||
data = await response.json()
|
||||
try:
|
||||
oasst_error = protocol_schema.OasstErrorResponse(**data)
|
||||
raise OasstError(
|
||||
error_code=oasst_error.error_code,
|
||||
message=oasst_error.message,
|
||||
)
|
||||
except ValidationError as e:
|
||||
logger.debug(f"Got error from API but could not parse: {e}")
|
||||
|
||||
raw_response = await response.text()
|
||||
logger.debug(f"Raw response: {raw_response}")
|
||||
|
||||
raise OasstError(
|
||||
raw_response,
|
||||
OasstErrorCode.GENERIC_ERROR,
|
||||
HTTPStatus(response.status),
|
||||
)
|
||||
return await response.json()
|
||||
|
||||
def _parse_task(self, data: Optional[dict[str, t.Any]]) -> protocol_schema.Task:
|
||||
|
||||
Reference in New Issue
Block a user