Add DEBUG_USD_SEED_DATA_PATH in config to make seed data flexible (#395)

* Add DEBUG_USD_SEED_DATA_PATH in config to make seed data flexible

* reformat

* Copy test_data folder in Dockerfile.backend, correct DEBUG_USE_SEED_DATA_PATH in cofig

* - make DEBUG_USE_SEED_DATA_PATH to absolute path
- correct test_data path in  Dockerfile.backend
This commit is contained in:
Ken Tsui
2023-01-07 21:37:30 +08:00
committed by GitHub
parent 99dcfd06ed
commit 043b5eff5a
4 changed files with 71 additions and 62 deletions
+7 -61
View File
@@ -1,3 +1,4 @@
import json
from http import HTTPStatus
from math import ceil
from pathlib import Path
@@ -6,7 +7,6 @@ from typing import Optional
import alembic.command
import alembic.config
import fastapi
import pydantic
import redis.asyncio as redis
from fastapi_limiter import FastAPILimiter
from loguru import logger
@@ -17,6 +17,7 @@ from oasst_backend.database import engine
from oasst_backend.prompt_repository import PromptRepository
from oasst_shared.exceptions import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
from pydantic import BaseModel
from sqlmodel import Session
from starlette.middleware.cors import CORSMiddleware
@@ -97,7 +98,7 @@ if settings.DEBUG_USE_SEED_DATA:
@app.on_event("startup")
def seed_data():
class DummyMessage(pydantic.BaseModel):
class DummyMessage(BaseModel):
task_message_id: str
user_message_id: str
parent_message_id: Optional[str]
@@ -111,64 +112,10 @@ if settings.DEBUG_USE_SEED_DATA:
dummy_user = protocol_schema.User(id="__dummy_user__", display_name="Dummy User", auth_method="local")
pr = PromptRepository(db=db, api_client=api_client, user=dummy_user)
dummy_messages = [
DummyMessage(
task_message_id="de111fa8",
user_message_id="6f1d0711",
parent_message_id=None,
text="Hi!",
role="prompter",
),
DummyMessage(
task_message_id="74c381d4",
user_message_id="4a24530b",
parent_message_id="6f1d0711",
text="Hello! How can I help you?",
role="assistant",
),
DummyMessage(
task_message_id="3d5dc440",
user_message_id="a8c01c04",
parent_message_id="4a24530b",
text="Do you have a recipe for potato soup?",
role="prompter",
),
DummyMessage(
task_message_id="643716c1",
user_message_id="f43a93b7",
parent_message_id="4a24530b",
text="Who were the 8 presidents before George Washington?",
role="prompter",
),
DummyMessage(
task_message_id="2e4e1e6",
user_message_id="c886920",
parent_message_id="6f1d0711",
text="Hey buddy! How can I serve you?",
role="assistant",
),
DummyMessage(
task_message_id="970c437d",
user_message_id="cec432cf",
parent_message_id=None,
text="euirdteunvglfe23908230892309832098 AAAAAAAA",
role="prompter",
),
DummyMessage(
task_message_id="6066118e",
user_message_id="4f85f637",
parent_message_id="cec432cf",
text="Sorry, I did not understand your request and it is unclear to me what you want me to do. Could you describe it in a different way?",
role="assistant",
),
DummyMessage(
task_message_id="ba87780d",
user_message_id="0e276b98",
parent_message_id="cec432cf",
text="I'm unsure how to interpret this. Is it a riddle?",
role="assistant",
),
]
with open(settings.DEBUG_USE_SEED_DATA_PATH) as f:
dummy_messages_raw = json.load(f)
dummy_messages = [DummyMessage(**dm) for dm in dummy_messages_raw]
for msg in dummy_messages:
task = pr.fetch_task_by_frontend_message_id(msg.task_message_id)
@@ -219,7 +166,6 @@ if __name__ == "__main__":
# Importing here so we don't import packages unnecessarily if we're
# importing main as a module.
import argparse
import json
import uvicorn
+5 -1
View File
@@ -1,6 +1,7 @@
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from pydantic import AnyHttpUrl, BaseSettings, PostgresDsn, validator
from pydantic import AnyHttpUrl, BaseSettings, FilePath, PostgresDsn, validator
class Settings(BaseSettings):
@@ -21,6 +22,9 @@ class Settings(BaseSettings):
DEBUG_ALLOW_ANY_API_KEY: bool = False
DEBUG_SKIP_API_KEY_CHECK: bool = False
DEBUG_USE_SEED_DATA: bool = False
DEBUG_USE_SEED_DATA_PATH: Optional[FilePath] = (
Path(__file__).parent.parent / "test_data/generic/test_generic_data.json"
)
HUGGING_FACE_API_KEY: str = ""
@@ -0,0 +1,58 @@
[
{
"task_message_id": "de111fa8",
"user_message_id": "6f1d0711",
"parent_message_id": null,
"text": "Hi!",
"role": "prompter"
},
{
"task_message_id": "74c381d4",
"user_message_id": "4a24530b",
"parent_message_id": "6f1d0711",
"text": "Hello! How can I help you?",
"role": "assistant"
},
{
"task_message_id": "3d5dc440",
"user_message_id": "a8c01c04",
"parent_message_id": "4a24530b",
"text": "Do you have a recipe for potato soup?",
"role": "prompter"
},
{
"task_message_id": "643716c1",
"user_message_id": "f43a93b7",
"parent_message_id": "4a24530b",
"text": "Who were the 8 presidents before George Washington?",
"role": "prompter"
},
{
"task_message_id": "2e4e1e6",
"user_message_id": "c886920",
"parent_message_id": "6f1d0711",
"text": "Hey buddy! How can I serve you?",
"role": "assistant"
},
{
"task_message_id": "970c437d",
"user_message_id": "cec432cf",
"parent_message_id": null,
"text": "euirdteunvglfe23908230892309832098 AAAAAAAA",
"role": "prompter"
},
{
"task_message_id": "6066118e",
"user_message_id": "4f85f637",
"parent_message_id": "cec432cf",
"text": "Sorry, I did not understand your request and it is unclear to me what you want me to do. Could you describe it in a different way?",
"role": "assistant"
},
{
"task_message_id": "ba87780d",
"user_message_id": "0e276b98",
"parent_message_id": "cec432cf",
"text": "I'm unsure how to interpret this. Is it a riddle?",
"role": "assistant"
}
]
+1
View File
@@ -14,3 +14,4 @@ COPY ./backend/alembic /app/alembic
COPY ./backend/alembic.ini /app/alembic.ini
COPY ./backend/main.py /app/main.py
COPY ./backend/oasst_backend /app/oasst_backend
COPY ./backend/test_data /app/test_data