diff --git a/backend/main.py b/backend/main.py index edbad943..2a3bb230 100644 --- a/backend/main.py +++ b/backend/main.py @@ -1,3 +1,4 @@ +import json from http import HTTPStatus from math import ceil from pathlib import Path @@ -6,7 +7,6 @@ from typing import Optional import alembic.command import alembic.config import fastapi -import pydantic import redis.asyncio as redis from fastapi_limiter import FastAPILimiter from loguru import logger @@ -17,6 +17,7 @@ from oasst_backend.database import engine from oasst_backend.prompt_repository import PromptRepository from oasst_shared.exceptions import OasstError, OasstErrorCode from oasst_shared.schemas import protocol as protocol_schema +from pydantic import BaseModel from sqlmodel import Session from starlette.middleware.cors import CORSMiddleware @@ -97,7 +98,7 @@ if settings.DEBUG_USE_SEED_DATA: @app.on_event("startup") def seed_data(): - class DummyMessage(pydantic.BaseModel): + class DummyMessage(BaseModel): task_message_id: str user_message_id: str parent_message_id: Optional[str] @@ -111,64 +112,10 @@ if settings.DEBUG_USE_SEED_DATA: dummy_user = protocol_schema.User(id="__dummy_user__", display_name="Dummy User", auth_method="local") pr = PromptRepository(db=db, api_client=api_client, user=dummy_user) - dummy_messages = [ - DummyMessage( - task_message_id="de111fa8", - user_message_id="6f1d0711", - parent_message_id=None, - text="Hi!", - role="prompter", - ), - DummyMessage( - task_message_id="74c381d4", - user_message_id="4a24530b", - parent_message_id="6f1d0711", - text="Hello! How can I help you?", - role="assistant", - ), - DummyMessage( - task_message_id="3d5dc440", - user_message_id="a8c01c04", - parent_message_id="4a24530b", - text="Do you have a recipe for potato soup?", - role="prompter", - ), - DummyMessage( - task_message_id="643716c1", - user_message_id="f43a93b7", - parent_message_id="4a24530b", - text="Who were the 8 presidents before George Washington?", - role="prompter", - ), - DummyMessage( - task_message_id="2e4e1e6", - user_message_id="c886920", - parent_message_id="6f1d0711", - text="Hey buddy! How can I serve you?", - role="assistant", - ), - DummyMessage( - task_message_id="970c437d", - user_message_id="cec432cf", - parent_message_id=None, - text="euirdteunvglfe23908230892309832098 AAAAAAAA", - role="prompter", - ), - DummyMessage( - task_message_id="6066118e", - user_message_id="4f85f637", - parent_message_id="cec432cf", - text="Sorry, I did not understand your request and it is unclear to me what you want me to do. Could you describe it in a different way?", - role="assistant", - ), - DummyMessage( - task_message_id="ba87780d", - user_message_id="0e276b98", - parent_message_id="cec432cf", - text="I'm unsure how to interpret this. Is it a riddle?", - role="assistant", - ), - ] + with open(settings.DEBUG_USE_SEED_DATA_PATH) as f: + dummy_messages_raw = json.load(f) + + dummy_messages = [DummyMessage(**dm) for dm in dummy_messages_raw] for msg in dummy_messages: task = pr.fetch_task_by_frontend_message_id(msg.task_message_id) @@ -219,7 +166,6 @@ if __name__ == "__main__": # Importing here so we don't import packages unnecessarily if we're # importing main as a module. import argparse - import json import uvicorn diff --git a/backend/oasst_backend/config.py b/backend/oasst_backend/config.py index df37dc9f..1765af7a 100644 --- a/backend/oasst_backend/config.py +++ b/backend/oasst_backend/config.py @@ -1,6 +1,7 @@ +from pathlib import Path from typing import Any, Dict, List, Optional, Union -from pydantic import AnyHttpUrl, BaseSettings, PostgresDsn, validator +from pydantic import AnyHttpUrl, BaseSettings, FilePath, PostgresDsn, validator class Settings(BaseSettings): @@ -21,6 +22,9 @@ class Settings(BaseSettings): DEBUG_ALLOW_ANY_API_KEY: bool = False DEBUG_SKIP_API_KEY_CHECK: bool = False DEBUG_USE_SEED_DATA: bool = False + DEBUG_USE_SEED_DATA_PATH: Optional[FilePath] = ( + Path(__file__).parent.parent / "test_data/generic/test_generic_data.json" + ) HUGGING_FACE_API_KEY: str = "" diff --git a/backend/test_data/generic/test_generic_data.json b/backend/test_data/generic/test_generic_data.json new file mode 100644 index 00000000..b634902d --- /dev/null +++ b/backend/test_data/generic/test_generic_data.json @@ -0,0 +1,58 @@ +[ + { + "task_message_id": "de111fa8", + "user_message_id": "6f1d0711", + "parent_message_id": null, + "text": "Hi!", + "role": "prompter" + }, + { + "task_message_id": "74c381d4", + "user_message_id": "4a24530b", + "parent_message_id": "6f1d0711", + "text": "Hello! How can I help you?", + "role": "assistant" + }, + { + "task_message_id": "3d5dc440", + "user_message_id": "a8c01c04", + "parent_message_id": "4a24530b", + "text": "Do you have a recipe for potato soup?", + "role": "prompter" + }, + { + "task_message_id": "643716c1", + "user_message_id": "f43a93b7", + "parent_message_id": "4a24530b", + "text": "Who were the 8 presidents before George Washington?", + "role": "prompter" + }, + { + "task_message_id": "2e4e1e6", + "user_message_id": "c886920", + "parent_message_id": "6f1d0711", + "text": "Hey buddy! How can I serve you?", + "role": "assistant" + }, + { + "task_message_id": "970c437d", + "user_message_id": "cec432cf", + "parent_message_id": null, + "text": "euirdteunvglfe23908230892309832098 AAAAAAAA", + "role": "prompter" + }, + { + "task_message_id": "6066118e", + "user_message_id": "4f85f637", + "parent_message_id": "cec432cf", + "text": "Sorry, I did not understand your request and it is unclear to me what you want me to do. Could you describe it in a different way?", + "role": "assistant" + }, + { + "task_message_id": "ba87780d", + "user_message_id": "0e276b98", + "parent_message_id": "cec432cf", + "text": "I'm unsure how to interpret this. Is it a riddle?", + "role": "assistant" + } +] diff --git a/docker/Dockerfile.backend b/docker/Dockerfile.backend index 1f3bdfcd..c89a0280 100644 --- a/docker/Dockerfile.backend +++ b/docker/Dockerfile.backend @@ -14,3 +14,4 @@ COPY ./backend/alembic /app/alembic COPY ./backend/alembic.ini /app/alembic.ini COPY ./backend/main.py /app/main.py COPY ./backend/oasst_backend /app/oasst_backend +COPY ./backend/test_data /app/test_data