mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
Merge branch 'LAION-AI:main' into main
This commit is contained in:
+6
-1
@@ -26,8 +26,13 @@ app = fastapi.FastAPI(title=settings.PROJECT_NAME, openapi_url=f"{settings.API_V
|
||||
@app.exception_handler(OasstError)
|
||||
async def oasst_exception_handler(request: fastapi.Request, ex: OasstError):
|
||||
logger.error(f"{request.method} {request.url} failed: {repr(ex)}")
|
||||
|
||||
return fastapi.responses.JSONResponse(
|
||||
status_code=int(ex.http_status_code), content={"message": ex.message, "error_code": ex.error_code}
|
||||
status_code=int(ex.http_status_code),
|
||||
content=protocol_schema.OasstErrorResponse(
|
||||
message=ex.message,
|
||||
error_code=OasstErrorCode(ex.error_code),
|
||||
).dict(),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
aiohttp # http client
|
||||
aiohttp[speedups] # speedups for aiohttp
|
||||
aiosqlite # database
|
||||
hikari # discord framework
|
||||
hikari-lightbulb # command handler
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
"""API Client for interacting with the OASST backend."""
|
||||
import enum
|
||||
import typing as t
|
||||
from http import HTTPStatus
|
||||
from typing import Optional, Type
|
||||
from uuid import UUID
|
||||
|
||||
import aiohttp
|
||||
from loguru import logger
|
||||
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from pydantic import ValidationError
|
||||
|
||||
|
||||
# TODO: Move to `protocol`?
|
||||
@@ -27,7 +30,7 @@ class TaskType(str, enum.Enum):
|
||||
class OasstApiClient:
|
||||
"""API Client for interacting with the OASST backend."""
|
||||
|
||||
def __init__(self, backend_url: str, api_key: str):
|
||||
def __init__(self, backend_url: str, api_key: str, session: Optional[aiohttp.ClientSession] = None):
|
||||
"""Create a new OasstApiClient.
|
||||
|
||||
Args:
|
||||
@@ -35,8 +38,12 @@ class OasstApiClient:
|
||||
backend_url (str): The base backend URL.
|
||||
api_key (str): The API key to use for authentication.
|
||||
"""
|
||||
logger.debug("Opening OasstApiClient session")
|
||||
self.session = aiohttp.ClientSession()
|
||||
|
||||
if session is None:
|
||||
logger.debug("Opening OasstApiClient session")
|
||||
session = aiohttp.ClientSession()
|
||||
|
||||
self.session = session
|
||||
self.backend_url = backend_url
|
||||
self.api_key = api_key
|
||||
|
||||
@@ -56,7 +63,33 @@ class OasstApiClient:
|
||||
"""Make a POST request to the backend."""
|
||||
logger.debug(f"POST {self.backend_url}{path} DATA: {data}")
|
||||
response = await self.session.post(f"{self.backend_url}{path}", json=data, headers={"X-API-Key": self.api_key})
|
||||
response.raise_for_status()
|
||||
|
||||
# If the response is not a 2XX, check to see
|
||||
# if the json has the fields to create an
|
||||
# OasstError.
|
||||
if response.status >= 300:
|
||||
data = await response.json()
|
||||
try:
|
||||
oasst_error = protocol_schema.OasstErrorResponse(**(data or {}))
|
||||
raise OasstError(
|
||||
error_code=oasst_error.error_code,
|
||||
message=oasst_error.message,
|
||||
)
|
||||
except ValidationError as e:
|
||||
logger.debug(f"Got error from API but could not parse: {e}")
|
||||
|
||||
raw_response = await response.text()
|
||||
logger.debug(f"Raw response: {raw_response}")
|
||||
|
||||
raise OasstError(
|
||||
raw_response,
|
||||
OasstErrorCode.GENERIC_ERROR,
|
||||
HTTPStatus(response.status),
|
||||
)
|
||||
|
||||
if response.status == 204:
|
||||
# No content
|
||||
return None
|
||||
return await response.json()
|
||||
|
||||
def _parse_task(self, data: Optional[dict[str, t.Any]]) -> protocol_schema.Task:
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import List, Literal, Optional, Union
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import pydantic
|
||||
from oasst_shared.exceptions import OasstErrorCode
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
@@ -293,3 +294,10 @@ class UserScore(BaseModel):
|
||||
|
||||
class LeaderboardStats(BaseModel):
|
||||
leaderboard: List[UserScore]
|
||||
|
||||
|
||||
class OasstErrorResponse(BaseModel):
|
||||
"""The format of an error response from the OASST API."""
|
||||
|
||||
error_code: OasstErrorCode
|
||||
message: str
|
||||
|
||||
@@ -11,5 +11,7 @@ setup(
|
||||
author="OASST Team",
|
||||
install_requires=[
|
||||
"pydantic==1.9.1",
|
||||
"aiohttp==3.8.3",
|
||||
"aiohttp[speedups]",
|
||||
],
|
||||
)
|
||||
|
||||
@@ -1,12 +1,21 @@
|
||||
from typing import Any
|
||||
from unittest import mock
|
||||
from uuid import uuid4
|
||||
|
||||
import aiohttp
|
||||
import pytest
|
||||
from oasst_shared.api_client import OasstApiClient
|
||||
from oasst_shared.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def oasst_api_client_mocked():
|
||||
"""
|
||||
A an oasst_api_client pointed at the mocked backend.
|
||||
Relies on ./scripts/backend-development/start-mock-server.sh
|
||||
being run.
|
||||
"""
|
||||
client = OasstApiClient(backend_url="http://localhost:8080", api_key="123")
|
||||
yield client
|
||||
# TODO The fixture should close this connection, but there seems to be a bug
|
||||
@@ -15,6 +24,30 @@ def oasst_api_client_mocked():
|
||||
# await client.close()
|
||||
|
||||
|
||||
class MockClientSession(aiohttp.ClientSession):
|
||||
response: Any
|
||||
|
||||
def set_response(self, response: Any):
|
||||
self.response = response
|
||||
|
||||
async def post(self, *args, **kwargs):
|
||||
return self.response
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_http_session():
|
||||
yield MockClientSession()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def oasst_api_client_fake_http(mock_http_session):
|
||||
"""
|
||||
An oasst_api_client that uses a mocked http session. No real requests are made.
|
||||
"""
|
||||
client = OasstApiClient(backend_url="http://localhost:8080", api_key="123", session=mock_http_session)
|
||||
yield client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("task_type", protocol_schema.TaskRequestType)
|
||||
async def test_can_fetch_task(task_type: protocol_schema.TaskRequestType, oasst_api_client_mocked: OasstApiClient):
|
||||
@@ -49,3 +82,47 @@ async def test_can_post_interaction(oasst_api_client_mocked: OasstApiClient):
|
||||
)
|
||||
is not None
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_can_handle_oasst_error_from_api(
|
||||
oasst_api_client_fake_http: OasstApiClient,
|
||||
mock_http_session: MockClientSession,
|
||||
):
|
||||
# Return a 400 response with an OasstErrorResponse body
|
||||
response_body = protocol_schema.OasstErrorResponse(
|
||||
error_code=OasstErrorCode.GENERIC_ERROR,
|
||||
message="Some error",
|
||||
)
|
||||
status_code = 400
|
||||
|
||||
mock_http_session.set_response(
|
||||
mock.AsyncMock(
|
||||
status=status_code,
|
||||
text=mock.AsyncMock(return_value=response_body.json()),
|
||||
json=mock.AsyncMock(return_value=response_body.dict()),
|
||||
)
|
||||
)
|
||||
|
||||
with pytest.raises(OasstError):
|
||||
await oasst_api_client_fake_http.post("/some-path", data={})
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_can_handle_unknown_error_from_api(
|
||||
oasst_api_client_fake_http: OasstApiClient,
|
||||
mock_http_session: MockClientSession,
|
||||
):
|
||||
response_body = "Internal Server Error"
|
||||
status_code = 500
|
||||
|
||||
mock_http_session.set_response(
|
||||
mock.AsyncMock(
|
||||
status=status_code,
|
||||
text=mock.AsyncMock(return_value=response_body),
|
||||
json=mock.AsyncMock(return_value=None),
|
||||
)
|
||||
)
|
||||
|
||||
with pytest.raises(OasstError):
|
||||
await oasst_api_client_fake_http.post("/some-path", data={})
|
||||
|
||||
@@ -4,6 +4,8 @@ parent_path=$( cd "$(dirname "${BASH_SOURCE[0]}")" ; pwd -P )
|
||||
# switch to backend directory
|
||||
pushd "$parent_path/../../oasst-shared"
|
||||
|
||||
set -xe
|
||||
|
||||
pytest .
|
||||
|
||||
popd
|
||||
|
||||
Reference in New Issue
Block a user