diff --git a/oasst-shared/tests/test_oasst_api_client.py b/oasst-shared/tests/test_oasst_api_client.py index be757e8f..7092a97e 100644 --- a/oasst-shared/tests/test_oasst_api_client.py +++ b/oasst-shared/tests/test_oasst_api_client.py @@ -1,12 +1,20 @@ +from unittest import mock from uuid import uuid4 +import aiohttp import pytest from oasst_shared.api_client import OasstApiClient +from oasst_shared.exceptions import OasstError, OasstErrorCode from oasst_shared.schemas import protocol as protocol_schema @pytest.fixture def oasst_api_client_mocked(): + """ + A an oasst_api_client pointed at the mocked backend. + Relies on ./scripts/backend-development/start-mock-server.sh + being run. + """ client = OasstApiClient(backend_url="http://localhost:8080", api_key="123") yield client # TODO The fixture should close this connection, but there seems to be a bug @@ -15,6 +23,20 @@ def oasst_api_client_mocked(): # await client.close() +@pytest.fixture +def mock_http_session(): + yield mock.AsyncMock(spec=aiohttp.ClientSession) + + +@pytest.fixture +def oasst_api_client_fake_http(mock_http_session): + """ + An oasst_api_client that uses a mocked http session. No real requests are made. + """ + client = OasstApiClient(backend_url="http://localhost:8080", api_key="123", session=mock_http_session) + yield client + + @pytest.mark.asyncio @pytest.mark.parametrize("task_type", protocol_schema.TaskRequestType) async def test_can_fetch_task(task_type: protocol_schema.TaskRequestType, oasst_api_client_mocked: OasstApiClient): @@ -49,3 +71,22 @@ async def test_can_post_interaction(oasst_api_client_mocked: OasstApiClient): ) is not None ) + + +@pytest.mark.asyncio +async def test_can_handle_oasst_error_from_api( + oasst_api_client_fake_http: OasstApiClient, + mock_http_session: mock.AsyncMock, +): + # Return a 400 response with an OasstErrorResponse body + response_body = protocol_schema.OasstErrorResponse( + error_code=OasstErrorCode.GENERIC_ERROR, + message="Some error", + ).json() + status_code = 400 + + mock_http_session.post.return_value.__aenter__.return_value.json.return_value = response_body + mock_http_session.post.return_value.__aenter__.return_value.status = status_code + + with pytest.raises(OasstError): + await oasst_api_client_fake_http.post("/some-path", data={})