mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-28 16:20:34 +08:00
b5bb5bb7c0
* add leaderboard stats, periodic update via fastapi-utils * count label tasks for assistant and prompter replies * Daily stats update every 15 mins, simplify leaderboard endpoint * add indices for some created_date columns * make user stats update intervals configurable * make sure intervals are positive
58 lines
2.8 KiB
Python
58 lines
2.8 KiB
Python
from datetime import datetime
|
|
from http import HTTPStatus
|
|
from typing import Optional
|
|
from uuid import UUID, uuid4
|
|
|
|
import sqlalchemy as sa
|
|
import sqlalchemy.dialects.postgresql as pg
|
|
from oasst_backend.models.db_payload import MessagePayload
|
|
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
|
|
from sqlalchemy import false
|
|
from sqlmodel import Field, Index, SQLModel
|
|
|
|
from .payload_column_type import PayloadContainer, payload_column_type
|
|
|
|
|
|
class Message(SQLModel, table=True):
    """ORM model mapped to the ``message`` table.

    Besides the column definitions, the class offers two helpers for working
    with the stored payload: ``ensure_is_message`` validates that the payload
    wraps a ``MessagePayload`` and ``text`` returns that payload's text.
    """

    __tablename__ = "message"
    # The pair (api_client_id, frontend_message_id) must be unique.
    __table_args__ = (Index("ix_message_frontend_message_id", "api_client_id", "frontend_message_id", unique=True),)

    # Primary key; generated client-side via uuid4 and server-side via gen_random_uuid().
    id: Optional[UUID] = Field(
        sa_column=sa.Column(
            pg.UUID(as_uuid=True), primary_key=True, default=uuid4, server_default=sa.text("gen_random_uuid()")
        ),
    )
    # NOTE(review): presumably None for the root message of a tree -- confirm with callers.
    parent_id: Optional[UUID] = Field(nullable=True)
    message_tree_id: UUID = Field(nullable=False, index=True)
    task_id: Optional[UUID] = Field(nullable=True, index=True)
    user_id: Optional[UUID] = Field(nullable=True, foreign_key="user.id", index=True)
    # The alternation is grouped so that both anchors apply to each alternative.
    # Ungrouped, "^prompter|assistant$" parses as (^prompter)|(assistant$) and
    # would also accept values such as "prompterX".
    role: str = Field(nullable=False, max_length=128, regex="^(prompter|assistant)$")
    api_client_id: UUID = Field(nullable=False, foreign_key="api_client.id")
    frontend_message_id: str = Field(max_length=200, nullable=False)
    # Filled in by the database (CURRENT_TIMESTAMP) when not supplied; indexed.
    created_date: Optional[datetime] = Field(
        sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp(), index=True)
    )
    payload_type: str = Field(nullable=False, max_length=200)
    # Nullable container column; the helpers below read ``payload.payload``.
    payload: Optional[PayloadContainer] = Field(
        sa_column=sa.Column(payload_column_type(PayloadContainer), nullable=True)
    )
    lang: str = Field(nullable=False, max_length=200, default="en-US")
    depth: int = Field(sa_column=sa.Column(sa.Integer, default=0, server_default=sa.text("0"), nullable=False))
    children_count: int = Field(sa_column=sa.Column(sa.Integer, default=0, server_default=sa.text("0"), nullable=False))
    # Only a server-side default (false) is declared for this column.
    deleted: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=False, server_default=false()))

    review_count: int = Field(sa_column=sa.Column(sa.Integer, default=0, server_default=sa.text("0"), nullable=False))
    review_result: bool = Field(sa_column=sa.Column(sa.Boolean, default=False, server_default=false(), nullable=False))
    ranking_count: int = Field(sa_column=sa.Column(sa.Integer, default=0, server_default=sa.text("0"), nullable=False))

    rank: Optional[int] = Field(nullable=True)

    def ensure_is_message(self) -> None:
        """Check that this row carries a ``MessagePayload``.

        Raises:
            OasstError: with ``OasstErrorCode.INVALID_MESSAGE`` and HTTP status
                500 when ``payload`` is falsy or ``payload.payload`` is not a
                ``MessagePayload`` instance.
        """
        if not self.payload or not isinstance(self.payload.payload, MessagePayload):
            raise OasstError("Invalid message", OasstErrorCode.INVALID_MESSAGE, HTTPStatus.INTERNAL_SERVER_ERROR)

    @property
    def text(self) -> str:
        """Return the text of the wrapped ``MessagePayload``.

        Raises:
            OasstError: if the payload is missing or is not a ``MessagePayload``
                (see ``ensure_is_message``).
        """
        # Validate first so the attribute access below cannot hit a missing payload.
        self.ensure_is_message()
        return self.payload.payload.text
|