diff --git a/backend/.gitignore b/backend/.gitignore new file mode 100644 index 00000000..aa177967 --- /dev/null +++ b/backend/.gitignore @@ -0,0 +1,4 @@ +__pycache__ +.env +notes.txt +db/ diff --git a/backend/alembic.ini b/backend/alembic.ini new file mode 100644 index 00000000..3c75af00 --- /dev/null +++ b/backend/alembic.ini @@ -0,0 +1,105 @@ +# A generic, single database configuration. + +[alembic] +# path to migration scripts +script_location = alemic + +# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s +# Uncomment the line below if you want the files to be prepended with date and time +# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file +# for all available tokens +# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s + +# sys.path path, will be prepended to sys.path if present. +# defaults to the current working directory. +prepend_sys_path = . + +# timezone to use when rendering the date within the migration file +# as well as the filename. +# If specified, requires the python-dateutil library that can be +# installed by adding `alembic[tz]` to the pip requirements +# string value is passed to dateutil.tz.gettz() +# leave blank for localtime +# timezone = + +# max length of characters to apply to the +# "slug" field +# truncate_slug_length = 40 + +# set to 'true' to run the environment during +# the 'revision' command, regardless of autogenerate +# revision_environment = false + +# set to 'true' to allow .pyc and .pyo files without +# a source .py file to be detected as revisions in the +# versions/ directory +# sourceless = false + +# version location specification; This defaults +# to alemic/versions. When using multiple version +# directories, initial revisions must be specified with --version-path. +# The path separator used here should be the separator specified by "version_path_separator" below. +# version_locations = %(here)s/bar:%(here)s/bat:alemic/versions + +# version path separator; As mentioned above, this is the character used to split +# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep. +# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas. +# Valid values for version_path_separator are: +# +# version_path_separator = : +# version_path_separator = ; +# version_path_separator = space +version_path_separator = os # Use os.pathsep. Default configuration used for new projects. + +# the output encoding used when revision files +# are written from script.py.mako +# output_encoding = utf-8 + +sqlalchemy.url = postgresql://ocgpt_backend_admin:829b5fc358842652fbdd04d0e6f012f2ff4227684f6b24a073b0ec78666e9d5a@localhost/ocgpt_backend + + +[post_write_hooks] +# post_write_hooks defines scripts or Python functions that are run +# on newly generated revision scripts. See the documentation for further +# detail and examples + +# format using "black" - use the console_scripts runner, against the "black" entrypoint +# hooks = black +# black.type = console_scripts +# black.entrypoint = black +# black.options = -l 79 REVISION_SCRIPT_FILENAME + +# Logging configuration +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARN +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/backend/alemic/README b/backend/alemic/README new file mode 100644 index 00000000..98e4f9c4 --- /dev/null +++ b/backend/alemic/README @@ -0,0 +1 @@ +Generic single-database configuration. \ No newline at end of file diff --git a/backend/alemic/env.py b/backend/alemic/env.py new file mode 100644 index 00000000..6626bfd0 --- /dev/null +++ b/backend/alemic/env.py @@ -0,0 +1,78 @@ +from logging.config import fileConfig + +from sqlalchemy import engine_from_config +from sqlalchemy import pool + +from alembic import context + +# this is the Alembic Config object, which provides +# access to the values within the .ini file in use. +config = context.config + +# Interpret the config file for Python logging. +# This line sets up loggers basically. +if config.config_file_name is not None: + fileConfig(config.config_file_name) + +# add your model's MetaData object here +# for 'autogenerate' support +# from myapp import mymodel +# target_metadata = mymodel.Base.metadata +target_metadata = None + +# other values from the config, defined by the needs of env.py, +# can be acquired: +# my_important_option = config.get_main_option("my_important_option") +# ... etc. + + +def run_migrations_offline() -> None: + """Run migrations in 'offline' mode. + + This configures the context with just a URL + and not an Engine, though an Engine is acceptable + here as well. By skipping the Engine creation + we don't even need a DBAPI to be available. + + Calls to context.execute() here emit the given string to the + script output. + + """ + url = config.get_main_option("sqlalchemy.url") + context.configure( + url=url, + target_metadata=target_metadata, + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + ) + + with context.begin_transaction(): + context.run_migrations() + + +def run_migrations_online() -> None: + """Run migrations in 'online' mode. + + In this scenario we need to create an Engine + and associate a connection with the context. + + """ + connectable = engine_from_config( + config.get_section(config.config_ini_section), + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) + + with connectable.connect() as connection: + context.configure( + connection=connection, target_metadata=target_metadata + ) + + with context.begin_transaction(): + context.run_migrations() + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/backend/alemic/script.py.mako b/backend/alemic/script.py.mako new file mode 100644 index 00000000..55df2863 --- /dev/null +++ b/backend/alemic/script.py.mako @@ -0,0 +1,24 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. +revision = ${repr(up_revision)} +down_revision = ${repr(down_revision)} +branch_labels = ${repr(branch_labels)} +depends_on = ${repr(depends_on)} + + +def upgrade() -> None: + ${upgrades if upgrades else "pass"} + + +def downgrade() -> None: + ${downgrades if downgrades else "pass"} diff --git a/backend/alemic/versions/23e5fea252dd_first_revision.py b/backend/alemic/versions/23e5fea252dd_first_revision.py new file mode 100644 index 00000000..7e51e72a --- /dev/null +++ b/backend/alemic/versions/23e5fea252dd_first_revision.py @@ -0,0 +1,68 @@ +"""first revision + +Revision ID: 23e5fea252dd +Revises: +Create Date: 2022-12-12 12:47:28.801354 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '23e5fea252dd' +down_revision = None +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.create_table( + "service_client", + sa.Column("id", sa.Integer, sa.Identity()), + sa.Column("name", sa.String(200), nullable=False), + sa.Column("service_admin_email", sa.String(128), nullable=True), + sa.Column("api_key", sa.String(300), nullable=False), + sa.Column("can_append", sa.Boolean, nullable=False, server_default='true'), + sa.Column("can_write", sa.Boolean, nullable=False, server_default='false'), + sa.Column("can_delete", sa.Boolean, nullable=False, server_default='false'), + sa.Column("can_read", sa.Boolean, nullable=False, server_default='true'), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index(op.f("ix_service_client_api_key"), "service_client", ["api_key"], unique=True) + + + op.create_table( + "labeler", + sa.Column("id", sa.Integer, sa.Identity()), + sa.Column("display_name", sa.String(96), nullable=False), + sa.Column("discord_username", sa.String(96), nullable=True), + sa.Column("created_date", sa.DateTime, nullable=False, server_default=sa.func.current_timestamp()), + sa.Column("is_enabled", sa.Boolean, nullable=False, server_default='true'), + sa.Column("notes", sa.String(10*1024), nullable=True), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("discord_username") + ) + + op.create_table( + "prompt", + sa.Column("id", sa.Integer, sa.Identity()), + sa.Column("labeler_id", sa.Integer, nullable=False), + sa.Column("prompt", sa.Text, nullable=False), + sa.Column("response", sa.Text, nullable=True), + sa.Column("lang", sa.String(32), nullable=True), + sa.Column("created_date", sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp()), + sa.ForeignKeyConstraint(["labeler_id"], ["labeler.id"],), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index(op.f("prompt_labeler_id"), "prompt", ["labeler_id"], unique=False) + + +def downgrade() -> None: + op.drop_index(op.f("prompt_labeler_id"), table_name="prompt") + op.drop_table("prompt") + + op.drop_table("labeler") + + op.drop_index(op.f("ix_service_client_api_key"), table_name="service_client") + op.drop_table("service_client") diff --git a/backend/app/__init__.py b/backend/app/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/backend/app/api/__init__.py b/backend/app/api/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/backend/app/api/deps.py b/backend/app/api/deps.py new file mode 100644 index 00000000..ba805075 --- /dev/null +++ b/backend/app/api/deps.py @@ -0,0 +1,44 @@ +from typing import Generator +from sqlmodel import Session +from fastapi import Security, HTTPException +from fastapi.security.api_key import APIKeyQuery, APIKeyHeader, APIKey +from app.database import engine +from app.models import ServiceClient + +from starlette.status import HTTP_403_FORBIDDEN + + +def get_db() -> Generator: + with Session(engine) as db: + yield db + + +api_key_query = APIKeyQuery(name="api_key", auto_error=False) +api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False) + + +async def get_api_key( + api_key_query: str = Security(api_key_query), + api_key_header: str = Security(api_key_header), +): + if api_key_query: + return api_key_query + else: + return api_key_header + + +def api_auth( + api_key: APIKey, db: Session, create: bool = False, read: bool = True, update: bool = False, delete: bool = False +) -> ServiceClient: + if api_key is not None: + api_client = db.query(ServiceClient).filter(ServiceClient.api_key == api_key).first() + if api_client is not None: + if ( + (create == False or api_client.can_append) + and (read == False or api_client.can_read) + and (update == False or api_client.can_write) + and (delete == False or api_client.can_delete) + ): + return api_client + + raise HTTPException(status_code=HTTP_403_FORBIDDEN, detail="Could not validate credentials") diff --git a/backend/app/api/v1/__init__.py b/backend/app/api/v1/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/backend/app/api/v1/api.py b/backend/app/api/v1/api.py new file mode 100644 index 00000000..59cdd6c7 --- /dev/null +++ b/backend/app/api/v1/api.py @@ -0,0 +1,7 @@ +from fastapi import APIRouter + +from app.api.v1 import labelers, prompts + +api_router = APIRouter() +api_router.include_router(labelers.router, prefix="/labelers", tags=["labelers"]) +api_router.include_router(prompts.router, prefix="/prompts", tags=["prompts"]) diff --git a/backend/app/api/v1/labelers.py b/backend/app/api/v1/labelers.py new file mode 100644 index 00000000..8161c3b4 --- /dev/null +++ b/backend/app/api/v1/labelers.py @@ -0,0 +1,115 @@ +from typing import Any, List + +from fastapi import APIRouter, Depends, HTTPException +from fastapi.security.api_key import APIKey +from sqlmodel import Session +from starlette.status import HTTP_404_NOT_FOUND, HTTP_400_BAD_REQUEST + +from app import crud, schemas +from app.api import deps + + +router = APIRouter() + + +@router.get("/", response_model=List[schemas.Labeler]) +def read_labelers( + db: Session = Depends(deps.get_db), + api_key: APIKey = Depends(deps.get_api_key), + begin_id: int = 0, + limit: int = 100, +) -> Any: + """ + Retrieve labelers. + """ + deps.api_auth(api_key, db, read=True) + if limit > 10000: + raise HTTPException(status_code=HTTP_400_BAD_REQUEST, detail="Bad request") + labelers = crud.labeler.get_multi(db, begin_id=begin_id, limit=limit) + return labelers + + +@router.post("/", response_model=schemas.Labeler) +def create_labeler( + *, + db: Session = Depends(deps.get_db), + api_key: APIKey = Depends(deps.get_api_key), + item_in: schemas.LabelerCreate, +) -> Any: + """ + Create new labeler. + """ + deps.api_auth(api_key, db, create=True) + item = crud.labeler.create(db=db, obj_in=item_in) + return item + + +@router.put("/{id}", response_model=schemas.Labeler) +def update_labeler( + *, + db: Session = Depends(deps.get_db), + api_key: APIKey = Depends(deps.get_api_key), + id: int, + item_in: schemas.LabelerUpdate, +) -> Any: + """ + Update a labeler. + """ + deps.api_auth(api_key, db, update=True, read=True) + item = crud.labeler.get(db=db, id=id) + if not item: + raise HTTPException(status_code=HTTP_404_NOT_FOUND, detail="Item not found") + item = crud.labeler.update(db=db, db_obj=item, obj_in=item_in) + return item + + +@router.get("/by-username", response_model=schemas.Labeler) +def read_labeler_by_username( + *, + db: Session = Depends(deps.get_db), + api_key: APIKey = Depends(deps.get_api_key), + discord_username: str, +) -> Any: + """ + Get labeler by ID. + """ + deps.api_auth(api_key, db, read=True) + item = crud.labeler.get_by_discord_username(db=db, discord_username=discord_username) + if not item: + raise HTTPException(status_code=404, detail="Item not found") + return item + + +@router.get("/{id}", response_model=schemas.Labeler) +def read_labeler( + *, + db: Session = Depends(deps.get_db), + api_key: APIKey = Depends(deps.get_api_key), + id: int, +) -> Any: + """ + Get labeler by ID. + """ + deps.api_auth(api_key, db, read=True) + item = crud.labeler.get(db=db, id=id) + if not item: + raise HTTPException(status_code=HTTP_404_NOT_FOUND, detail="Item not found") + return item + + +@router.delete("/{id}", response_model=schemas.Labeler) +def delete_labeler( + *, + db: Session = Depends(deps.get_db), + api_key: APIKey = Depends(deps.get_api_key), + id: int, +) -> Any: + """ + Delete a labeler. + """ + deps.api_auth(api_key, db, delete=True) + labeler = crud.labeler.get(db=db, id=id) + if not labeler: + raise HTTPException(status_code=HTTP_404_NOT_FOUND, detail="Item not found") + labeler = crud.labeler.remove(db=db, id=id) + return labeler diff --git a/backend/app/api/v1/prompts.py b/backend/app/api/v1/prompts.py new file mode 100644 index 00000000..b7848229 --- /dev/null +++ b/backend/app/api/v1/prompts.py @@ -0,0 +1,92 @@ +from typing import Any, List + +from fastapi import APIRouter, Depends, HTTPException +from fastapi.security.api_key import APIKey +from sqlmodel import Session +from starlette.status import HTTP_404_NOT_FOUND, HTTP_400_BAD_REQUEST, HTTP_401_UNAUTHORIZED + +from app import crud, schemas +from app.api import deps + + +router = APIRouter() + + +@router.get("/", response_model=List[schemas.Prompt]) +def read_prompts( + db: Session = Depends(deps.get_db), + api_key: APIKey = Depends(deps.get_api_key), + begin_id: int = 0, + limit: int = 1000, +) -> Any: + """ + Retrieve prompts. + """ + deps.api_auth(api_key, db, read=True) + if limit > 10000: + raise HTTPException(status_code=HTTP_400_BAD_REQUEST, detail="Bad request") + return crud.prompt.get_multi(db, begin_id=begin_id, limit=limit) + + +@router.post("/", response_model=schemas.Prompt) +def create_prompt( + *, + db: Session = Depends(deps.get_db), + api_key: APIKey = Depends(deps.get_api_key), + item_in: schemas.PromptCreate, +) -> Any: + """ + Create new prompt. + """ + deps.api_auth(api_key, db, create=True) + if item_in.labeler_id is None: + if item_in.discord_username is None: + raise HTTPException(status_code=HTTP_400_BAD_REQUEST, detail="Bad request") + labeler = crud.labeler.get_by_discord_username(db=db, discord_username=item_in.discord_username) + else: + labeler = crud.labeler.get(db=db, id=item_in.labeler_id) + + if labeler is None: + raise HTTPException(status_code=HTTP_404_NOT_FOUND, detail="Invalid labeler user name") + if not labeler.is_enabled: + raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Labeler disabled") + + item_in.labeler_id = labeler.id + item_in.discord_username = None + item = crud.prompt.create(db=db, obj_in=item_in) + return item + + +@router.get("/{id}", response_model=schemas.Prompt) +def read_prompt( + *, + db: Session = Depends(deps.get_db), + api_key: APIKey = Depends(deps.get_api_key), + id: int, +) -> Any: + """ + Get prompt by ID. + """ + deps.api_auth(api_key, db, read=True) + item = crud.prompt.get(db=db, id=id) + if not item: + raise HTTPException(status_code=HTTP_404_NOT_FOUND, detail="Item not found") + return item + + +@router.delete("/{id}", response_model=schemas.Prompt) +def delete_prompt( + *, + db: Session = Depends(deps.get_db), + api_key: APIKey = Depends(deps.get_api_key), + id: int, +) -> Any: + """ + Delete a prompt. + """ + deps.api_auth(api_key, db, delete=True) + item = crud.prompt.get(db=db, id=id) + if not item: + raise HTTPException(status_code=HTTP_404_NOT_FOUND, detail="Item not found") + item = crud.prompt.remove(db=db, id=id) + return item diff --git a/backend/app/config.py b/backend/app/config.py new file mode 100644 index 00000000..895dbb00 --- /dev/null +++ b/backend/app/config.py @@ -0,0 +1,21 @@ +from typing import List, Optional, Union +from pydantic import AnyHttpUrl, BaseSettings, PostgresDsn, validator + + +class Settings(BaseSettings): + PROJECT_NAME: str = "open-chatGPT backend" + API_V1_STR: str = "/api/v1" + DATABASE_URI: Optional[PostgresDsn] = None + + BACKEND_CORS_ORIGINS: List[AnyHttpUrl] = [] + + @validator("BACKEND_CORS_ORIGINS", pre=True) + def assemble_cors_origins(cls, v: Union[str, List[str]]) -> Union[List[str], str]: + if isinstance(v, str) and not v.startswith("["): + return [i.strip() for i in v.split(",")] + elif isinstance(v, (list, str)): + return v + raise ValueError(v) + + +settings = Settings(_env_file=".env") diff --git a/backend/app/crud/__init__.py b/backend/app/crud/__init__.py new file mode 100644 index 00000000..00548245 --- /dev/null +++ b/backend/app/crud/__init__.py @@ -0,0 +1,2 @@ +from .crud_labeler import labeler +from .crud_prompt import prompt \ No newline at end of file diff --git a/backend/app/crud/base.py b/backend/app/crud/base.py new file mode 100644 index 00000000..927359d2 --- /dev/null +++ b/backend/app/crud/base.py @@ -0,0 +1,65 @@ +from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union + +from fastapi.encoders import jsonable_encoder +from pydantic import BaseModel +from sqlmodel import Session, SQLModel + + +ModelType = TypeVar("ModelType", bound=SQLModel) +CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel) +UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel) + + +class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]): + def __init__(self, model: Type[ModelType]): + """ + CRUD object with default methods to Create, Read, Update, Delete (CRUD). + + **Parameters** + + * `model`: A SQLModel model class + * `schema`: A Pydantic model (schema) class + """ + self.model = model + + def get(self, db: Session, id: Any) -> Optional[ModelType]: + return db.query(self.model).filter(self.model.id == id).first() + + def get_multi( + self, db: Session, *, begin_id: int = 0, limit: int = 100 + ) -> List[ModelType]: + return db.query(self.model).filter(self.model.id >= begin_id).limit(limit).all() + + def create(self, db: Session, *, obj_in: CreateSchemaType) -> ModelType: + obj_in_data = jsonable_encoder(obj_in) + db_obj = self.model(**obj_in_data) # type: ignore + db.add(db_obj) + db.commit() + db.refresh(db_obj) + return db_obj + + def update( + self, + db: Session, + *, + db_obj: ModelType, + obj_in: Union[UpdateSchemaType, Dict[str, Any]] + ) -> ModelType: + obj_data = jsonable_encoder(db_obj) + if isinstance(obj_in, dict): + update_data = obj_in + else: + update_data = obj_in.dict(exclude_unset=True) + for field in obj_data: + if field in update_data: + setattr(db_obj, field, update_data[field]) + db.add(db_obj) + db.commit() + db.refresh(db_obj) + return db_obj + + def delete(self, db: Session, *, id: int) -> ModelType: + obj = db.query(self.model).get(id) + db.delete(obj) + db.commit() + return obj diff --git a/backend/app/crud/crud_labeler.py b/backend/app/crud/crud_labeler.py new file mode 100644 index 00000000..a47aee34 --- /dev/null +++ b/backend/app/crud/crud_labeler.py @@ -0,0 +1,14 @@ +from typing import Optional + +from app.crud.base import CRUDBase +from app.models.labeler import Labeler +from app.schemas.labeler import LabelerCreate, LabelerUpdate +from sqlmodel import Session + + +class CRUDLabeler(CRUDBase[Labeler, LabelerCreate, LabelerUpdate]): + def get_by_discord_username(self, db: Session, discord_username: str) -> Optional[Labeler]: + return db.query(Labeler).filter(Labeler.discord_username == discord_username).first() + + +labeler = CRUDLabeler(Labeler) diff --git a/backend/app/crud/crud_prompt.py b/backend/app/crud/crud_prompt.py new file mode 100644 index 00000000..7288b82c --- /dev/null +++ b/backend/app/crud/crud_prompt.py @@ -0,0 +1,10 @@ +from app.crud.base import CRUDBase +from app.models.prompt import Prompt +from app.schemas.prompt import PromptCreate + + +class CRUDPrompt(CRUDBase[Prompt, PromptCreate, None]): + pass + + +prompt = CRUDPrompt(Prompt) diff --git a/backend/app/database.py b/backend/app/database.py new file mode 100644 index 00000000..c9562b13 --- /dev/null +++ b/backend/app/database.py @@ -0,0 +1,4 @@ +from sqlmodel import create_engine +from app.config import settings + +engine = create_engine(settings.DATABASE_URI) diff --git a/backend/app/db_test.py b/backend/app/db_test.py new file mode 100644 index 00000000..f6d5a374 --- /dev/null +++ b/backend/app/db_test.py @@ -0,0 +1,37 @@ +import dataclasses +from datetime import datetime +import json +from typing import Optional +import argparse +from dataclasses import dataclass + + +from sqlmodel import Session, SQLModel, create_engine +from app.config import settings + +import app.api.deps + + + +def main(): + print(settings.dict()) + quit() + + args = parse_args() + cfg = load_configuration_file(args.config) + + engine = create_engine(cfg.database_url) + app.api.deps.engine = engine + + """ + with Session(engine) as session: + # create a test serivice + #sc1 = ServiceClient(name='blub', api_key='1234') + #session.add(sc1) + + session.commit() + """ + + +if __name__ == '__main__': + main() diff --git a/backend/app/main.py b/backend/app/main.py new file mode 100644 index 00000000..343e142a --- /dev/null +++ b/backend/app/main.py @@ -0,0 +1,21 @@ +from fastapi import FastAPI +from starlette.middleware.cors import CORSMiddleware + +from app.api.v1.api import api_router +from app.config import settings + +app = FastAPI( + title=settings.PROJECT_NAME, openapi_url=f"{settings.API_V1_STR}/openapi.json" +) + +# Set all CORS enabled origins +if settings.BACKEND_CORS_ORIGINS: + app.add_middleware( + CORSMiddleware, + allow_origins=[str(origin) for origin in settings.BACKEND_CORS_ORIGINS], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + +app.include_router(api_router, prefix=settings.API_V1_STR) diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py new file mode 100644 index 00000000..e54a9211 --- /dev/null +++ b/backend/app/models/__init__.py @@ -0,0 +1,3 @@ +from .service_client import ServiceClient +from .labeler import Labeler +from .prompt import Prompt diff --git a/backend/app/models/labeler.py b/backend/app/models/labeler.py new file mode 100644 index 00000000..bfef9232 --- /dev/null +++ b/backend/app/models/labeler.py @@ -0,0 +1,14 @@ +from datetime import datetime +import sqlalchemy as sa +from sqlmodel import Field, SQLModel +from typing import Optional + + +class Labeler(SQLModel, table=True): + __tablename__ = "labeler" + id: Optional[int] = Field(default=None, primary_key=True) + display_name: str + discord_username: str + created_date: Optional[datetime] = Field(sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp()), nullable=False) + is_enabled: bool + notes: str diff --git a/backend/app/models/prompt.py b/backend/app/models/prompt.py new file mode 100644 index 00000000..fdb3a002 --- /dev/null +++ b/backend/app/models/prompt.py @@ -0,0 +1,15 @@ +from datetime import datetime +import sqlalchemy as sa +from sqlmodel import Field, SQLModel +from typing import Optional + + +class Prompt(SQLModel, table=True): + __tablename__ = "prompt" + id: Optional[int] = Field(default=None, primary_key=True) + labeler_id: Optional[int] = Field(default=None, foreign_key="labeler.id") + prompt: str + response: Optional[str] + lang: Optional[str] + created_date: Optional[datetime] = Field(sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp()), nullable=False) + \ No newline at end of file diff --git a/backend/app/models/service_client.py b/backend/app/models/service_client.py new file mode 100644 index 00000000..54960636 --- /dev/null +++ b/backend/app/models/service_client.py @@ -0,0 +1,15 @@ +from sqlmodel import Field, SQLModel +from typing import Optional + + +class ServiceClient(SQLModel, table=True): + __tablename__ = "service_client" + id: Optional[int] = Field(default=None, primary_key=True) + name: str + api_key: str + service_admin_email: Optional[str] = None + api_key: str + can_append: bool = True + can_write: bool = False + can_delete: bool = False + can_read: bool = True diff --git a/backend/app/schemas/__init__.py b/backend/app/schemas/__init__.py new file mode 100644 index 00000000..7fcc61b4 --- /dev/null +++ b/backend/app/schemas/__init__.py @@ -0,0 +1,2 @@ +from .labeler import Labeler, LabelerCreate, LabelerUpdate +from .prompt import Prompt, PromptCreate diff --git a/backend/app/schemas/labeler.py b/backend/app/schemas/labeler.py new file mode 100644 index 00000000..c0774426 --- /dev/null +++ b/backend/app/schemas/labeler.py @@ -0,0 +1,26 @@ +from typing import Optional +from datetime import datetime +from pydantic import BaseModel + + +class Labeler(BaseModel): + id: int + discord_username: str + display_name: str + created_date: datetime + is_enabled: str + notes: Optional[str] + + +class LabelerCreate(BaseModel): + discord_username: str + display_name: Optional[str] + is_enabled: Optional[bool] = True + notes: Optional[str] = None + + +class LabelerUpdate(BaseModel): + discord_username: Optional[str] = None + display_name: Optional[str] = None + enabled: Optional[bool] = None + notes: Optional[str] = None diff --git a/backend/app/schemas/prompt.py b/backend/app/schemas/prompt.py new file mode 100644 index 00000000..928ab74f --- /dev/null +++ b/backend/app/schemas/prompt.py @@ -0,0 +1,20 @@ +from typing import Optional +from datetime import datetime +from pydantic import BaseModel + + +class Prompt(BaseModel): + id: int + labeler_id: int + prompt: str + response: Optional[str] + lang: Optional[str] + created_date: datetime + + +class PromptCreate(BaseModel): + labeler_id: Optional[int] = None + discord_username: Optional[str] = None + prompt: str + response: Optional[str] = None + lang: Optional[str] = None diff --git a/backend/app/tests/__init__.py b/backend/app/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/backend/requirements.txt b/backend/requirements.txt new file mode 100644 index 00000000..e271f16c --- /dev/null +++ b/backend/requirements.txt @@ -0,0 +1,9 @@ +fastapi==0.88.0 +pydantic==1.9.1 +SQLAlchemy==1.4.41 +sqlmodel==0.0.8 +starlette==0.22.0 +uvicorn==0.20.0 +psycopg2-binary==2.9.5 +alembic==1.8.1 +python-dotenv==0.21.0 diff --git a/backend/scripts/run-local.sh b/backend/scripts/run-local.sh new file mode 100755 index 00000000..b6bc997a --- /dev/null +++ b/backend/scripts/run-local.sh @@ -0,0 +1,3 @@ +#!/usr/bin/env bash + +uvicorn app.main:app --reload \ No newline at end of file