mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-28 16:20:34 +08:00
Merge branch 'v1_db_schema' into main
This commit is contained in:
@@ -0,0 +1,192 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""v1 db structure
|
||||
|
||||
Revision ID: cd7de470586e
|
||||
Revises: 23e5fea252dd
|
||||
Create Date: 2022-12-15 11:15:32.830225
|
||||
|
||||
"""
|
||||
import uuid
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "cd7de470586e"
|
||||
down_revision = "23e5fea252dd"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# remove database objects
|
||||
op.drop_index(op.f("prompt_labeler_id"), table_name="prompt")
|
||||
op.drop_table("prompt")
|
||||
op.drop_table("labeler")
|
||||
op.drop_index(op.f("ix_service_client_api_key"), table_name="service_client")
|
||||
op.drop_table("service_client")
|
||||
|
||||
# wreate new database structure
|
||||
op.create_table(
|
||||
"api_client",
|
||||
sa.Column("id", UUID(as_uuid=True), default=uuid.uuid4, server_default=sa.text("gen_random_uuid()")),
|
||||
sa.Column("api_key", sa.String(512), nullable=False),
|
||||
sa.Column("description", sa.String(256), nullable=False),
|
||||
sa.Column("admin_email", sa.String(256), nullable=True),
|
||||
sa.Column("enabled", sa.Boolean, default=True, nullable=False),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(op.f("ix_api_client_api_key"), "api_client", ["api_key"], unique=True)
|
||||
|
||||
op.create_table(
|
||||
"person",
|
||||
sa.Column("id", UUID(as_uuid=True), default=uuid.uuid4, server_default=sa.text("gen_random_uuid()")),
|
||||
sa.Column("username", sa.String(128), nullable=False), # unique in combination with api_client_id
|
||||
sa.Column("display_name", sa.String(256), nullable=False), # cached last seen display_name
|
||||
sa.Column("created_date", sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp()),
|
||||
sa.Column("api_client_id", UUID(as_uuid=True), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.ForeignKeyConstraint(["api_client_id"], ["api_client.id"]),
|
||||
)
|
||||
op.create_index(op.f("ix_person_username"), "person", ["api_client_id", "username"], unique=True)
|
||||
|
||||
op.create_table(
|
||||
"person_stats",
|
||||
sa.Column("person_id", UUID(as_uuid=True)),
|
||||
sa.Column("leader_score", sa.Integer, default=0, nullable=False), # determines position on leader board
|
||||
sa.Column("modified_date", sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp()),
|
||||
sa.Column("reactions", sa.Integer, default=0, nullable=False), # reactions sent by user
|
||||
sa.Column("posts", sa.Integer, default=0, nullable=False), # posts sent by user
|
||||
sa.Column("upvotes", sa.Integer, default=0, nullable=False), # received upvotes (form other users)
|
||||
sa.Column("downvotes", sa.Integer, default=0, nullable=False), # received downvotes (from other users)
|
||||
sa.Column("work_reward", sa.Integer, default=0, nullable=False), # reward for workpackage completions
|
||||
sa.Column("compare_wins", sa.Integer, default=0, nullable=False), # num times user's post won compare tasks
|
||||
sa.Column("compare_losses", sa.Integer, default=0, nullable=False), # num times users's post lost compare tasks
|
||||
sa.PrimaryKeyConstraint("person_id"),
|
||||
sa.ForeignKeyConstraint(["person_id"], ["person.id"]),
|
||||
)
|
||||
|
||||
op.create_table(
|
||||
"work_package",
|
||||
sa.Column("id", UUID(as_uuid=True), default=uuid.uuid4, server_default=sa.text("gen_random_uuid()")),
|
||||
sa.Column("created_date", sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp()),
|
||||
sa.Column("expiry_date", sa.DateTime(), nullable=True),
|
||||
sa.Column("person_id", UUID(as_uuid=True), nullable=True),
|
||||
sa.Column("payload_type", sa.String(200), nullable=False), # deserialization hint & dbg aid
|
||||
sa.Column("payload", JSONB(astext_type=sa.Text()), nullable=False),
|
||||
sa.Column("api_client_id", UUID(as_uuid=True), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.ForeignKeyConstraint(["person_id"], ["person.id"]),
|
||||
sa.ForeignKeyConstraint(["api_client_id"], ["api_client.id"]),
|
||||
)
|
||||
op.create_index(op.f("ix_work_package_person_id"), "work_package", ["person_id"], unique=False)
|
||||
|
||||
op.create_table(
|
||||
"post",
|
||||
sa.Column("id", UUID(as_uuid=True), default=uuid.uuid4, server_default=sa.text("gen_random_uuid()")),
|
||||
sa.Column("parent_id", UUID(as_uuid=True), nullable=True), # root posts have NULL parent
|
||||
sa.Column("thread_id", UUID(as_uuid=True), nullable=False), # id of thread root
|
||||
sa.Column("workpackage_id", UUID(as_uuid=True), nullable=True), # workpackage id to pass to handler on reply
|
||||
sa.Column("person_id", UUID(as_uuid=True), nullable=True), # sender (recipients are part of payload)
|
||||
sa.Column("api_client_id", UUID(as_uuid=True), nullable=False),
|
||||
sa.Column("role", sa.String(128), nullable=False), # 'assistant', 'user' or something else
|
||||
sa.Column("frontend_post_id", sa.String(200), nullable=False), # unique together with api_client_id
|
||||
sa.Column("created_date", sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp()),
|
||||
sa.Column("payload_type", sa.String(200), nullable=False), # deserialization hint & dbg aid
|
||||
sa.Column("payload", JSONB(astext_type=sa.Text()), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.ForeignKeyConstraint(["person_id"], ["person.id"]),
|
||||
sa.ForeignKeyConstraint(["api_client_id"], ["api_client.id"]),
|
||||
)
|
||||
op.create_index(op.f("ix_post_frontend_post_id"), "post", ["api_client_id", "frontend_post_id"], unique=True)
|
||||
op.create_index(op.f("ix_post_thread_id"), "post", ["thread_id"], unique=False)
|
||||
op.create_index(op.f("ix_post_workpackage_id"), "post", ["workpackage_id"], unique=False)
|
||||
op.create_index(op.f("ix_post_person_id"), "post", ["person_id"], unique=False)
|
||||
|
||||
op.create_table(
|
||||
"post_reaction",
|
||||
sa.Column("post_id", UUID(as_uuid=True), nullable=False),
|
||||
sa.Column("person_id", UUID(as_uuid=True), nullable=False), # sender (recipients are part of payload)
|
||||
sa.Column("created_date", sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp()),
|
||||
sa.Column("payload_type", sa.String(200), nullable=False), # deserialization hint & dbg aid
|
||||
sa.Column("payload", JSONB(astext_type=sa.Text()), nullable=False),
|
||||
sa.Column("api_client_id", UUID(as_uuid=True), nullable=False),
|
||||
sa.PrimaryKeyConstraint("post_id", "person_id"),
|
||||
sa.ForeignKeyConstraint(["post_id"], ["post.id"]),
|
||||
sa.ForeignKeyConstraint(["person_id"], ["person.id"]),
|
||||
sa.ForeignKeyConstraint(["api_client_id"], ["api_client.id"]),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("post_reaction")
|
||||
|
||||
op.drop_index("ix_post_person_id")
|
||||
op.drop_index("ix_post_workpackage_id")
|
||||
op.drop_index("ix_post_thread_id")
|
||||
op.drop_index("ix_post_frontend_post_id")
|
||||
op.drop_table("post")
|
||||
|
||||
op.drop_index("ix_work_package_person_id")
|
||||
op.drop_table("work_package")
|
||||
|
||||
op.drop_table("person_stats")
|
||||
|
||||
op.drop_index("ix_person_username")
|
||||
op.drop_table("person")
|
||||
|
||||
op.drop_index("ix_api_client_api_key")
|
||||
op.drop_table("api_client")
|
||||
|
||||
op.create_table(
|
||||
"service_client",
|
||||
sa.Column("id", sa.Integer, sa.Identity()),
|
||||
sa.Column("name", sa.String(200), nullable=False),
|
||||
sa.Column("service_admin_email", sa.String(128), nullable=True),
|
||||
sa.Column("api_key", sa.String(300), nullable=False),
|
||||
sa.Column("can_append", sa.Boolean, nullable=False, server_default="true"),
|
||||
sa.Column("can_write", sa.Boolean, nullable=False, server_default="false"),
|
||||
sa.Column("can_delete", sa.Boolean, nullable=False, server_default="false"),
|
||||
sa.Column("can_read", sa.Boolean, nullable=False, server_default="true"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(op.f("ix_service_client_api_key"), "service_client", ["api_key"], unique=True)
|
||||
|
||||
op.create_table(
|
||||
"labeler",
|
||||
sa.Column("id", sa.Integer, sa.Identity()),
|
||||
sa.Column("display_name", sa.String(96), nullable=False),
|
||||
sa.Column("discord_username", sa.String(96), nullable=True),
|
||||
sa.Column(
|
||||
"created_date",
|
||||
sa.DateTime,
|
||||
nullable=False,
|
||||
server_default=sa.func.current_timestamp(),
|
||||
),
|
||||
sa.Column("is_enabled", sa.Boolean, nullable=False, server_default="true"),
|
||||
sa.Column("notes", sa.String(10 * 1024), nullable=True),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("discord_username"),
|
||||
)
|
||||
|
||||
op.create_table(
|
||||
"prompt",
|
||||
sa.Column("id", sa.Integer, sa.Identity()),
|
||||
sa.Column("labeler_id", sa.Integer, nullable=False),
|
||||
sa.Column("prompt", sa.Text, nullable=False),
|
||||
sa.Column("response", sa.Text, nullable=True),
|
||||
sa.Column("lang", sa.String(32), nullable=True),
|
||||
sa.Column(
|
||||
"created_date",
|
||||
sa.DateTime(),
|
||||
nullable=False,
|
||||
server_default=sa.func.current_timestamp(),
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["labeler_id"],
|
||||
["labeler.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(op.f("prompt_labeler_id"), "prompt", ["labeler_id"], unique=False)
|
||||
+10
-18
@@ -1,9 +1,9 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from typing import Generator
|
||||
|
||||
from uuid import UUID
|
||||
from app.config import settings
|
||||
from app.database import engine
|
||||
from app.models import ServiceClient
|
||||
from app.models import ApiClient
|
||||
from fastapi import HTTPException, Security
|
||||
from fastapi.security.api_key import APIKey, APIKeyHeader, APIKeyQuery
|
||||
from sqlmodel import Session
|
||||
@@ -32,24 +32,16 @@ async def get_api_key(
|
||||
def api_auth(
|
||||
api_key: APIKey,
|
||||
db: Session,
|
||||
create: bool = False,
|
||||
read: bool = True,
|
||||
update: bool = False,
|
||||
delete: bool = False,
|
||||
) -> ServiceClient:
|
||||
) -> ApiClient:
|
||||
|
||||
if api_key is not None:
|
||||
if settings.ALLOW_ANY_API_KEY:
|
||||
return ServiceClient(
|
||||
api_key=api_key, name=api_key, can_append=True, can_read=True, can_write=True, can_delete=True
|
||||
return ApiClient(
|
||||
id=UUID('00000000-1111-2222-3333-444444444444'),
|
||||
api_key=api_key, name=api_key
|
||||
)
|
||||
api_client = db.query(ServiceClient).filter(ServiceClient.api_key == api_key).first()
|
||||
if api_client is not None:
|
||||
if (
|
||||
(create is False or api_client.can_append)
|
||||
and (read is False or api_client.can_read)
|
||||
and (update is False or api_client.can_write)
|
||||
and (delete is False or api_client.can_delete)
|
||||
):
|
||||
return api_client
|
||||
api_client = db.query(ApiClient).filter(ApiClient.api_key == api_key).first()
|
||||
if api_client is not None and api_client.enabled:
|
||||
return api_client
|
||||
|
||||
raise HTTPException(status_code=HTTP_403_FORBIDDEN, detail="Could not validate credentials")
|
||||
|
||||
@@ -21,7 +21,7 @@ def read_labelers(
|
||||
"""
|
||||
Retrieve labelers.
|
||||
"""
|
||||
deps.api_auth(api_key, db, read=True)
|
||||
deps.api_auth(api_key, db)
|
||||
if limit > 10000:
|
||||
raise HTTPException(status_code=HTTP_400_BAD_REQUEST, detail="Bad request")
|
||||
labelers = crud.labeler.get_multi(db, begin_id=begin_id, limit=limit)
|
||||
@@ -38,7 +38,7 @@ def create_labeler(
|
||||
"""
|
||||
Create new labeler.
|
||||
"""
|
||||
deps.api_auth(api_key, db, create=True)
|
||||
deps.api_auth(api_key, db)
|
||||
item = crud.labeler.create(db=db, obj_in=item_in)
|
||||
return item
|
||||
|
||||
@@ -54,7 +54,7 @@ def update_labeler(
|
||||
"""
|
||||
Update a labeler.
|
||||
"""
|
||||
deps.api_auth(api_key, db, update=True, read=True)
|
||||
deps.api_auth(api_key, db)
|
||||
item = crud.labeler.get(db=db, id=id)
|
||||
if not item:
|
||||
raise HTTPException(status_code=HTTP_404_NOT_FOUND, detail="Item not found")
|
||||
@@ -72,7 +72,7 @@ def read_labeler_by_username(
|
||||
"""
|
||||
Get labeler by ID.
|
||||
"""
|
||||
deps.api_auth(api_key, db, read=True)
|
||||
deps.api_auth(api_key, db)
|
||||
item = crud.labeler.get_by_discord_username(db=db, discord_username=discord_username)
|
||||
if not item:
|
||||
raise HTTPException(status_code=404, detail="Item not found")
|
||||
@@ -89,7 +89,7 @@ def read_labeler(
|
||||
"""
|
||||
Get labeler by ID.
|
||||
"""
|
||||
deps.api_auth(api_key, db, read=True)
|
||||
deps.api_auth(api_key, db)
|
||||
item = crud.labeler.get(db=db, id=id)
|
||||
if not item:
|
||||
raise HTTPException(status_code=HTTP_404_NOT_FOUND, detail="Item not found")
|
||||
@@ -106,7 +106,7 @@ def delete_labeler(
|
||||
"""
|
||||
Delete a labeler.
|
||||
"""
|
||||
deps.api_auth(api_key, db, delete=True)
|
||||
deps.api_auth(api_key, db)
|
||||
labeler = crud.labeler.get(db=db, id=id)
|
||||
if not labeler:
|
||||
raise HTTPException(status_code=HTTP_404_NOT_FOUND, detail="Item not found")
|
||||
|
||||
@@ -21,7 +21,7 @@ def read_prompts(
|
||||
"""
|
||||
Retrieve prompts.
|
||||
"""
|
||||
deps.api_auth(api_key, db, read=True)
|
||||
deps.api_auth(api_key, db)
|
||||
if limit > 10000:
|
||||
raise HTTPException(status_code=HTTP_400_BAD_REQUEST, detail="Bad request")
|
||||
return crud.prompt.get_multi(db, begin_id=begin_id, limit=limit)
|
||||
@@ -37,7 +37,7 @@ def create_prompt(
|
||||
"""
|
||||
Create new prompt.
|
||||
"""
|
||||
deps.api_auth(api_key, db, create=True)
|
||||
deps.api_auth(api_key, db)
|
||||
if item_in.labeler_id is None:
|
||||
if item_in.discord_username is None:
|
||||
raise HTTPException(status_code=HTTP_400_BAD_REQUEST, detail="Bad request")
|
||||
@@ -66,7 +66,7 @@ def read_prompt(
|
||||
"""
|
||||
Get prompt by ID.
|
||||
"""
|
||||
deps.api_auth(api_key, db, read=True)
|
||||
deps.api_auth(api_key, db)
|
||||
item = crud.prompt.get(db=db, id=id)
|
||||
if not item:
|
||||
raise HTTPException(status_code=HTTP_404_NOT_FOUND, detail="Item not found")
|
||||
@@ -83,7 +83,7 @@ def delete_prompt(
|
||||
"""
|
||||
Delete a prompt.
|
||||
"""
|
||||
deps.api_auth(api_key, db, delete=True)
|
||||
deps.api_auth(api_key, db)
|
||||
item = crud.prompt.get(db=db, id=id)
|
||||
if not item:
|
||||
raise HTTPException(status_code=HTTP_404_NOT_FOUND, detail="Item not found")
|
||||
|
||||
@@ -55,7 +55,7 @@ def request_task(
|
||||
"""
|
||||
Create new task.
|
||||
"""
|
||||
deps.api_auth(api_key, db, create=True)
|
||||
deps.api_auth(api_key, db)
|
||||
|
||||
try:
|
||||
task = generate_task(request)
|
||||
@@ -79,7 +79,7 @@ def acknowledge_task(
|
||||
"""
|
||||
The frontend acknowledges a task.
|
||||
"""
|
||||
deps.api_auth(api_key, db, create=True)
|
||||
deps.api_auth(api_key, db)
|
||||
|
||||
match (type(response)):
|
||||
case protocol_schema.PostCreatedTaskResponse:
|
||||
@@ -107,7 +107,7 @@ def post_interaction(
|
||||
"""
|
||||
The frontend reports an interaction.
|
||||
"""
|
||||
deps.api_auth(api_key, db, create=True)
|
||||
deps.api_auth(api_key, db)
|
||||
|
||||
match (type(interaction)):
|
||||
case protocol_schema.TextReplyToPost:
|
||||
|
||||
@@ -1,6 +1,22 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from .api_client import ApiClient
|
||||
from .labeler import Labeler
|
||||
from .person import Person
|
||||
from .person_stats import PersonStats
|
||||
from .post import Post
|
||||
from .post_reaction import PostReaction
|
||||
from .prompt import Prompt
|
||||
from .service_client import ServiceClient
|
||||
from .work_package import WorkPackage
|
||||
|
||||
__all__ = ["Labeler", "Prompt", "ServiceClient"]
|
||||
__all__ = [
|
||||
"ApiClient",
|
||||
"Person",
|
||||
"PersonStats",
|
||||
"Post",
|
||||
"PostReaction",
|
||||
"WorkPackage",
|
||||
"Labeler",
|
||||
"Prompt",
|
||||
"ServiceClient",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,21 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from typing import Optional
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import sqlalchemy as sa
|
||||
import sqlalchemy.dialects.postgresql as pg
|
||||
from sqlmodel import Field, SQLModel
|
||||
|
||||
|
||||
class ApiClient(SQLModel, table=True):
|
||||
__tablename__ = "api_client"
|
||||
|
||||
id: Optional[UUID] = Field(
|
||||
sa_column=sa.Column(
|
||||
pg.UUID(as_uuid=True), primary_key=True, default=uuid4, server_default=sa.text("gen_random_uuid()")
|
||||
),
|
||||
)
|
||||
api_key: str = Field(max_length=512, index=True, unique=True)
|
||||
description: str = Field(max_length=256)
|
||||
admin_email: Optional[str] = Field(max_length=256, nullable=True)
|
||||
enabled: bool = Field(default=True)
|
||||
@@ -12,8 +12,7 @@ class Labeler(SQLModel, table=True):
|
||||
display_name: str
|
||||
discord_username: str
|
||||
created_date: Optional[datetime] = Field(
|
||||
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp()),
|
||||
nullable=False,
|
||||
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp())
|
||||
)
|
||||
is_enabled: bool
|
||||
notes: str
|
||||
|
||||
@@ -0,0 +1,102 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import json
|
||||
from typing import Any, Generic, Type, TypeVar
|
||||
|
||||
import sqlalchemy.dialects.postgresql as pg
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from pydantic import BaseModel, parse_obj_as, validator
|
||||
from pydantic.main import ModelMetaclass
|
||||
from sqlalchemy.types import TypeDecorator
|
||||
|
||||
payload_type_registry = {}
|
||||
|
||||
|
||||
P = TypeVar("P", bound=BaseModel)
|
||||
|
||||
|
||||
def payload_tpye(cls: Type[P]) -> Type[P]:
|
||||
payload_type_registry[cls.__name__] = cls
|
||||
return cls
|
||||
|
||||
|
||||
class PayloadContainer(BaseModel):
|
||||
payload_type: str = ""
|
||||
payload: BaseModel = None
|
||||
|
||||
def __init__(self, **v):
|
||||
p = v["payload"]
|
||||
if isinstance(p, dict):
|
||||
t = v["payload_type"]
|
||||
if t not in payload_type_registry:
|
||||
raise RuntimeError(f"Payload type '{t}' not registered")
|
||||
cls = payload_type_registry[t]
|
||||
v["payload"] = cls(**p)
|
||||
super().__init__(**v)
|
||||
|
||||
@validator("payload", pre=True)
|
||||
def check_payload(cls, v: BaseModel, values: dict[str, Any]) -> BaseModel:
|
||||
values["payload_type"] = type(v).__name__
|
||||
return v
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def payload_column_type(pydantic_type):
|
||||
class PayloadJSONBType(TypeDecorator, Generic[T]):
|
||||
impl = pg.JSONB()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
json_encoder=json,
|
||||
):
|
||||
self.json_encoder = json_encoder
|
||||
super(PayloadJSONBType, self).__init__()
|
||||
|
||||
# serialize
|
||||
def bind_processor(self, dialect):
|
||||
impl_processor = self.impl.bind_processor(dialect)
|
||||
dumps = self.json_encoder.dumps
|
||||
|
||||
def process(value: T):
|
||||
if value is not None:
|
||||
if isinstance(pydantic_type, ModelMetaclass):
|
||||
# This allows to assign non-InDB models and if they're
|
||||
# compatible, they're directly parsed into the InDB
|
||||
# representation, thus hiding the implementation in the
|
||||
# background. However, the InDB model will still be returned
|
||||
value_to_dump = pydantic_type.from_orm(value)
|
||||
else:
|
||||
value_to_dump = value
|
||||
|
||||
value = jsonable_encoder(value_to_dump)
|
||||
|
||||
if impl_processor:
|
||||
return impl_processor(value)
|
||||
else:
|
||||
return dumps(jsonable_encoder(value_to_dump))
|
||||
|
||||
return process
|
||||
|
||||
# deserialize
|
||||
def result_processor(self, dialect, coltype) -> T:
|
||||
impl_processor = self.impl.result_processor(dialect, coltype)
|
||||
|
||||
def process(value):
|
||||
if impl_processor:
|
||||
value = impl_processor(value)
|
||||
if value is None:
|
||||
return None
|
||||
# Explicitly use the generic directly, not type(T)
|
||||
full_obj = parse_obj_as(pydantic_type, value)
|
||||
return full_obj
|
||||
|
||||
return process
|
||||
|
||||
def compare_values(self, x, y):
|
||||
return x == y
|
||||
|
||||
return PayloadJSONBType
|
||||
@@ -0,0 +1,25 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import sqlalchemy as sa
|
||||
import sqlalchemy.dialects.postgresql as pg
|
||||
from sqlmodel import Field, Index, SQLModel
|
||||
|
||||
|
||||
class Person(SQLModel, table=True):
|
||||
__tablename__ = "person"
|
||||
__table_args__ = (Index("ix_person_username", "api_client_id", "username", unique=True),)
|
||||
|
||||
id: Optional[UUID] = Field(
|
||||
sa_column=sa.Column(
|
||||
pg.UUID(as_uuid=True), primary_key=True, default=uuid4, server_default=sa.text("gen_random_uuid()")
|
||||
),
|
||||
)
|
||||
username: str = Field(nullable=False, max_length=128)
|
||||
display_name: str = Field(nullable=False, max_length=256)
|
||||
created_date: Optional[datetime] = Field(
|
||||
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp())
|
||||
)
|
||||
api_client_id: UUID = Field(foreign_key="api_client.id")
|
||||
@@ -0,0 +1,28 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
import sqlalchemy as sa
|
||||
import sqlalchemy.dialects.postgresql as pg
|
||||
from sqlmodel import Field, SQLModel
|
||||
|
||||
|
||||
class PersonStats(SQLModel, table=True):
|
||||
__tablename__ = "person_stats"
|
||||
|
||||
person_id: Optional[UUID] = Field(
|
||||
sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("person.id"), primary_key=True)
|
||||
)
|
||||
leader_score: int = 0
|
||||
modified_date: Optional[datetime] = Field(
|
||||
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp())
|
||||
)
|
||||
|
||||
reactions: int = 0 # reactions sent by user
|
||||
posts: int = 0 # posts sent by user
|
||||
upvotes: int = 0 # received upvotes (form other users)
|
||||
downvotes: int = 0 # received downvotes (from other users)
|
||||
work_reward: int = 0 # reward for workpackage completions
|
||||
compare_wins: int = 0 # num times user's post won compare tasks
|
||||
compare_losses: int = 0 # num times users's post lost compare tasks
|
||||
@@ -0,0 +1,33 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import sqlalchemy as sa
|
||||
import sqlalchemy.dialects.postgresql as pg
|
||||
from sqlmodel import Field, Index, SQLModel
|
||||
|
||||
from .payload_column_type import PayloadContainer, payload_column_type
|
||||
|
||||
|
||||
class Post(SQLModel, table=True):
|
||||
__tablename__ = "post"
|
||||
__table_args__ = (Index("ix_post_frontend_post_id", "api_client_id", "frontend_post_id", unique=True),)
|
||||
|
||||
id: Optional[UUID] = Field(
|
||||
sa_column=sa.Column(
|
||||
pg.UUID(as_uuid=True), primary_key=True, default=uuid4, server_default=sa.text("gen_random_uuid()")
|
||||
),
|
||||
)
|
||||
parent_id: UUID = Field(nullable=True)
|
||||
thread_id: UUID = Field(nullable=False, index=True)
|
||||
workpackage_id: UUID = Field(nullable=True, index=True)
|
||||
person_id: UUID = Field(nullable=True, foreign_key="person.id", index=True)
|
||||
role: str = Field(nullable=False, max_length=128)
|
||||
api_client_id: UUID = Field(nullable=False, foreign_key="api_client.id")
|
||||
frontend_post_id: str = Field(max_length=200, nullable=False)
|
||||
created_date: Optional[datetime] = Field(
|
||||
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp())
|
||||
)
|
||||
payload_type: str = Field(nullable=False, max_length=200)
|
||||
payload: PayloadContainer = Field(sa_column=sa.Column(payload_column_type(PayloadContainer), nullable=False))
|
||||
@@ -0,0 +1,27 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
import sqlalchemy as sa
|
||||
import sqlalchemy.dialects.postgresql as pg
|
||||
from sqlmodel import Field, SQLModel
|
||||
|
||||
from .payload_column_type import PayloadContainer, payload_column_type
|
||||
|
||||
|
||||
class PostReaction(SQLModel, table=True):
|
||||
__tablename__ = "post_reaction"
|
||||
|
||||
post_id: Optional[UUID] = Field(
|
||||
sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("post.id"), nullable=False, primary_key=True)
|
||||
)
|
||||
person_id: UUID = Field(
|
||||
sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("person.id"), nullable=False, primary_key=True)
|
||||
)
|
||||
created_date: Optional[datetime] = Field(
|
||||
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp())
|
||||
)
|
||||
payload_type: str = Field(nullable=False, max_length=200)
|
||||
payload: PayloadContainer = Field(sa_column=sa.Column(payload_column_type(PayloadContainer), nullable=False))
|
||||
api_client_id: UUID = Field(nullable=False, foreign_key="api_client.id")
|
||||
@@ -14,6 +14,5 @@ class Prompt(SQLModel, table=True):
|
||||
response: Optional[str]
|
||||
lang: Optional[str]
|
||||
created_date: Optional[datetime] = Field(
|
||||
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp()),
|
||||
nullable=False,
|
||||
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp())
|
||||
)
|
||||
|
||||
@@ -0,0 +1,28 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import sqlalchemy as sa
|
||||
import sqlalchemy.dialects.postgresql as pg
|
||||
from sqlmodel import Field, SQLModel
|
||||
|
||||
from .payload_column_type import PayloadContainer, payload_column_type
|
||||
|
||||
|
||||
class WorkPackage(SQLModel, table=True):
|
||||
__tablename__ = "work_package"
|
||||
|
||||
id: Optional[UUID] = Field(
|
||||
sa_column=sa.Column(
|
||||
pg.UUID(as_uuid=True), primary_key=True, default=uuid4, server_default=sa.text("gen_random_uuid()")
|
||||
),
|
||||
)
|
||||
created_date: Optional[datetime] = Field(
|
||||
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp()),
|
||||
)
|
||||
expiry_date: Optional[datetime] = Field(sa_column=sa.Column(sa.DateTime(), nullable=True))
|
||||
person_id: UUID = Field(nullable=True, foreign_key="person.id", index=True)
|
||||
payload_type: str = Field(nullable=False, max_length=200)
|
||||
payload: PayloadContainer = Field(sa_column=sa.Column(payload_column_type(PayloadContainer), nullable=False))
|
||||
api_client_id: UUID = Field(nullable=False, foreign_key="api_client.id")
|
||||
Reference in New Issue
Block a user