add trollboards

This commit is contained in:
Andreas Köpf
2023-02-02 16:10:38 +01:00
parent fa53505369
commit 2db3450e9a
7 changed files with 420 additions and 8 deletions
@@ -0,0 +1,59 @@
"""add troll_stats
Revision ID: 4d7e0b0ebe84
Revises: 9e7ec4a9e3f2
Create Date: 2023-02-02 15:44:12.647260
"""
import sqlalchemy as sa
import sqlmodel
from alembic import op
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "4d7e0b0ebe84"
down_revision = "9e7ec4a9e3f2"
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Create the ``troll_stats`` table and its unique (time_frame, user_id) index."""

    def required_int(name: str) -> sa.Column:
        # NOT NULL integer column.
        return sa.Column(name, sa.Integer(), nullable=False)

    def optional_float(name: str) -> sa.Column:
        # Nullable float column.
        return sa.Column(name, sa.Float(), nullable=True)

    # Column order below is the order in which the table is created.
    columns = [
        sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False),
        sa.Column("base_date", sa.DateTime(timezone=True), nullable=True),
        sa.Column(
            "modified_date",
            sa.DateTime(timezone=True),
            server_default=sa.text("CURRENT_TIMESTAMP"),
            nullable=False,
        ),
        sa.Column("time_frame", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
        required_int("troll_score"),
        sa.Column("rank", sa.Integer(), nullable=True),
        *(required_int(name) for name in ("red_flags", "upvotes", "downvotes", "spam_prompts")),
        *(optional_float(name) for name in ("quality", "humor", "toxicity", "violence", "helpfulness")),
        *(
            required_int(name)
            for name in (
                "spam",
                "lang_mismach",
                "not_appropriate",
                "pii",
                "hate_speech",
                "sexual_content",
                "political_content",
            )
        ),
    ]
    constraints = [
        # Rows are removed together with the referenced user.
        sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
        sa.PrimaryKeyConstraint("user_id", "time_frame"),
    ]

    op.create_table("troll_stats", *columns, *constraints)
    op.create_index(
        "ix_troll_stats__timeframe__user_id",
        "troll_stats",
        ["time_frame", "user_id"],
        unique=True,
    )
def downgrade() -> None:
    """Undo ``upgrade``: drop the index first, then the ``troll_stats`` table."""
    table = "troll_stats"
    op.drop_index("ix_troll_stats__timeframe__user_id", table_name=table)
    op.drop_table(table)
+2
View File
@@ -10,6 +10,7 @@ from oasst_backend.api.v1 import (
stats,
tasks,
text_labels,
trollboards,
users,
)
@@ -22,6 +23,7 @@ api_router.include_router(users.router, prefix="/users", tags=["users"])
api_router.include_router(frontend_users.router, prefix="/frontend_users", tags=["frontend_users"])
api_router.include_router(stats.router, prefix="/stats", tags=["stats"])
api_router.include_router(leaderboards.router, prefix="/leaderboards", tags=["leaderboards"])
api_router.include_router(trollboards.router, prefix="/trollboards", tags=["trollboards"])
api_router.include_router(hugging_face.router, prefix="/hf", tags=["hugging_face"])
api_router.include_router(admin.router, prefix="/admin", tags=["admin"])
api_router.include_router(auth.router, prefix="/auth", tags=["auth"])
@@ -0,0 +1,21 @@
from typing import Optional
from fastapi import APIRouter, Depends, Query
from oasst_backend.api import deps
from oasst_backend.models import ApiClient
from oasst_backend.user_stats_repository import UserStatsRepository, UserStatsTimeFrame
from oasst_shared.schemas.protocol import TrollboardStats
from sqlmodel import Session
router = APIRouter()
# Returns the trollboard of the requested time frame, limited to `max_count`
# entries (1..10000, default 100).
# Documented with plain comments on purpose: FastAPI publishes a handler's
# docstring as the endpoint description, which would change the OpenAPI output.
@router.get("/{time_frame}", response_model=TrollboardStats)
def get_trollboard(
    time_frame: UserStatsTimeFrame,
    max_count: Optional[int] = Query(100, gt=0, le=10000),
    api_client: ApiClient = Depends(deps.get_trusted_api_client),
    db: Session = Depends(deps.get_db),
) -> TrollboardStats:
    # `api_client` is not read below; declaring it makes the
    # `deps.get_trusted_api_client` dependency run for every request
    # (presumably restricting the endpoint to trusted clients - confirm in deps).
    repository = UserStatsRepository(db)
    board = repository.get_trollboard(time_frame, limit=max_count)
    return board
+2
View File
@@ -8,6 +8,7 @@ from .message_toxicity import MessageToxicity
from .message_tree_state import MessageTreeState
from .task import Task
from .text_labels import TextLabels
from .troll_stats import TrollStats
from .user import User
from .user_stats import UserStats, UserStatsTimeFrame
@@ -26,4 +27,5 @@ __all__ = [
"Journal",
"JournalIntegration",
"MessageEmoji",
"TrollStats",
]
@@ -0,0 +1,59 @@
from datetime import datetime
from typing import Optional
from uuid import UUID
import sqlalchemy as sa
import sqlalchemy.dialects.postgresql as pg
from sqlmodel import Field, Index, SQLModel
class TrollStats(SQLModel, table=True):
    """Aggregated troll statistics of one user for one time frame (table ``troll_stats``)."""

    __tablename__ = "troll_stats"
    __table_args__ = (Index("ix_troll_stats__timeframe__user_id", "time_frame", "user_id", unique=True),)

    # (time_frame, user_id) together form the primary key.
    time_frame: Optional[str] = Field(nullable=False, primary_key=True)
    user_id: Optional[UUID] = Field(
        sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("user.id", ondelete="CASCADE"), primary_key=True)
    )
    base_date: Optional[datetime] = Field(sa_column=sa.Column(sa.DateTime(timezone=True), nullable=True))

    troll_score: int = 0
    # Defaults to the database's current timestamp when not supplied.
    modified_date: Optional[datetime] = Field(
        sa_column=sa.Column(sa.DateTime(timezone=True), nullable=False, server_default=sa.func.current_timestamp())
    )

    rank: int = Field(nullable=True)

    red_flags: int = 0  # num reported messages of user
    upvotes: int = 0  # num up-voted messages of user
    downvotes: int = 0  # num down-voted messages of user

    spam_prompts: int = 0

    # nullable float label values
    quality: float = Field(nullable=True)
    humor: float = Field(nullable=True)
    toxicity: float = Field(nullable=True)
    violence: float = Field(nullable=True)
    helpfulness: float = Field(nullable=True)

    # integer label counters
    spam: int = 0
    lang_mismach: int = 0
    not_appropriate: int = 0
    pii: int = 0
    hate_speech: int = 0
    sexual_content: int = 0
    political_content: int = 0

    def compute_troll_score(self) -> int:
        """Return the troll score derived from the counters of this row.

        Red flags count three times, up-votes are subtracted, and down-votes,
        spam prompts and the label counters listed below each add one point.
        The ``spam`` counter and the float label values are not part of the score.
        """
        penalties = (
            self.downvotes,
            self.spam_prompts,
            self.lang_mismach,
            self.not_appropriate,
            self.pii,
            self.hate_speech,
            self.sexual_content,
            self.political_content,
        )
        base = self.red_flags * 3 - self.upvotes
        return base + sum(penalties)
+236 -8
View File
@@ -5,19 +5,31 @@ from uuid import UUID
import sqlalchemy as sa
from loguru import logger
from oasst_backend.config import settings
from oasst_backend.models import Message, MessageReaction, Task, User, UserStats, UserStatsTimeFrame
from oasst_backend.models import (
Message,
MessageReaction,
MessageTreeState,
Task,
TextLabels,
TrollStats,
User,
UserStats,
UserStatsTimeFrame,
)
from oasst_backend.models.db_payload import (
LabelAssistantReplyPayload,
LabelPrompterReplyPayload,
RankingReactionPayload,
)
from oasst_shared.schemas.protocol import LeaderboardStats, UserScore
from oasst_backend.models.message_tree_state import State as TreeState
from oasst_shared.schemas.protocol import EmojiCode, LeaderboardStats, TextLabel, TrollboardStats, TrollScore, UserScore
from oasst_shared.utils import log_timing, utcnow
from sqlalchemy.dialects import postgresql
from sqlalchemy.sql.functions import coalesce
from sqlmodel import Session, delete, func, text
def _create_user_score(r, highlighted_user_id: UUID | None):
def _create_user_score(r, highlighted_user_id: UUID | None) -> UserScore:
if r["UserStats"]:
d = r["UserStats"].dict()
else:
@@ -37,6 +49,24 @@ def _create_user_score(r, highlighted_user_id: UUID | None):
return UserScore(**d)
def _create_troll_score(r, highlighted_user_id: UUID | None) -> TrollScore:
    """Build a ``TrollScore`` from one result row.

    The row is read by key: ``"TrollStats"`` (may be falsy) plus the user
    columns ``user_id``, ``username``, ``auth_method``, ``display_name`` and
    ``last_activity_date``. Without a stats object only ``modified_date`` is
    pre-filled (with the current UTC time). ``highlighted`` is set only when
    ``highlighted_user_id`` is truthy.
    """
    stats = r["TrollStats"]
    fields = stats.dict() if stats else {"modified_date": utcnow()}

    user_columns = ("user_id", "username", "auth_method", "display_name", "last_activity_date")
    # User columns are copied last so they override same-named stats fields.
    fields.update((column, r[column]) for column in user_columns)

    if highlighted_user_id:
        fields["highlighted"] = r["user_id"] == highlighted_user_id
    return TrollScore(**fields)
class UserStatsRepository:
def __init__(self, session: Session):
self.session = session
@@ -133,6 +163,38 @@ class UserStatsRepository:
stats_by_timeframe = {tf.value: _create_user_score(r, user_id) for tf in UserStatsTimeFrame}
return stats_by_timeframe
def get_trollboard(
    self,
    time_frame: UserStatsTimeFrame,
    limit: int = 100,
    highlighted_user_id: Optional[UUID] = None,
) -> TrollboardStats:
    """
    Get trollboard stats for the specified time frame.

    Users are joined with their ``TrollStats`` row of ``time_frame``, ordered
    by the stored rank and cut off after ``limit`` entries. ``last_updated``
    is the newest ``modified_date`` among the returned entries, or the
    current UTC time when there are none.
    """
    user_columns = (
        User.id.label("user_id"),
        User.username,
        User.auth_method,
        User.display_name,
        User.last_activity_date,
    )
    query = self.session.query(*user_columns, TrollStats)
    query = query.join(TrollStats, User.id == TrollStats.user_id)
    query = query.filter(TrollStats.time_frame == time_frame.value)
    query = query.order_by(TrollStats.rank).limit(limit)

    entries = [_create_troll_score(row, highlighted_user_id) for row in self.session.exec(query)]
    newest = max(entry.modified_date for entry in entries) if entries else utcnow()
    return TrollboardStats(time_frame=time_frame.value, trollboard=entries, last_updated=newest)
def query_total_prompts_per_user(
self, reference_time: Optional[datetime] = None, only_reviewed: Optional[bool] = True
):
@@ -292,10 +354,145 @@ class UserStatsRepository:
self.session.add_all(stats_by_user.values())
self.session.flush()
self.update_ranks(time_frame=time_frame)
self.update_leader_ranks(time_frame=time_frame)
def query_message_emoji_counts_per_user(self, reference_time: Optional[datetime] = None):
    """Build a query summing thumbs-up / thumbs-down / red-flag emoji counts per message author.

    Only non-deleted messages with a non-NULL ``emojis`` value are included;
    when ``reference_time`` is given, only messages created at or after it.
    Result columns: ``user_id``, ``up``, ``down``, ``flag``.
    """

    def emoji_total(code, alias):
        # A missing emoji entry counts as 0.
        return func.sum(coalesce(Message.emojis[code].cast(sa.Integer), 0)).label(alias)

    conditions = [Message.deleted == sa.false(), Message.emojis.is_not(None)]
    if reference_time:
        conditions.append(Message.created_date >= reference_time)

    return (
        self.session.query(
            Message.user_id,
            emoji_total(EmojiCode.thumbs_up, "up"),
            emoji_total(EmojiCode.thumbs_down, "down"),
            emoji_total(EmojiCode.red_flag, "flag"),
        )
        .filter(*conditions)
        .group_by(Message.user_id)
    )
def query_spam_prompts_per_user(self, reference_time: Optional[datetime] = None):
    """Build a query counting, per user, messages whose tree was aborted as low grade.

    ``MessageTreeState`` rows in state ``ABORTED_LOW_GRADE`` are joined to the
    message whose id equals ``message_tree_id`` (presumably the tree's initial
    prompt - confirm). With ``reference_time``, only messages created at or
    after it are counted. Result columns: ``user_id``, ``spam_prompts``.
    """
    conditions = [MessageTreeState.state == TreeState.ABORTED_LOW_GRADE]
    if reference_time:
        conditions.append(Message.created_date >= reference_time)

    spam_count = func.count().label("spam_prompts")
    return (
        self.session.query(Message.user_id, spam_count)
        .select_from(MessageTreeState)
        .join(Message, MessageTreeState.message_tree_id == Message.id)
        .filter(*conditions)
        .group_by(Message.user_id)
    )
def query_labels_per_user(self, reference_time: Optional[datetime] = None):
    """Build a query aggregating text-label values per message author.

    Counter labels are summed (a missing label counts as 0); quality, humor,
    toxicity, violence and helpfulness are averaged. With ``reference_time``,
    only messages created at or after it are included. Result columns carry
    the aliases listed below, plus ``user_id``.
    """

    def label_sum(label, alias):
        return func.sum(coalesce(TextLabels.labels[label].cast(sa.Integer), 0)).label(alias)

    def label_avg(label, alias):
        return func.avg(TextLabels.labels[label].cast(sa.Float)).label(alias)

    summed = (
        (TextLabel.spam, "spam"),
        # The alias keeps the spelling used by the TrollStats field / column.
        (TextLabel.lang_mismatch, "lang_mismach"),
        (TextLabel.not_appropriate, "not_appropriate"),
        (TextLabel.pii, "pii"),
        (TextLabel.hate_speech, "hate_speech"),
        (TextLabel.sexual_content, "sexual_content"),
        (TextLabel.political_content, "political_content"),
    )
    averaged = (
        (TextLabel.quality, "quality"),
        (TextLabel.humor, "humor"),
        (TextLabel.toxicity, "toxicity"),
        (TextLabel.violence, "violence"),
        (TextLabel.helpfulness, "helpfulness"),
    )

    # NOTE(review): the `emojis IS NOT NULL` condition also restricts labels to
    # messages that have emojis - confirm this is intended for label stats.
    conditions = [Message.deleted == sa.false(), Message.emojis.is_not(None)]
    if reference_time:
        conditions.append(Message.created_date >= reference_time)

    return (
        self.session.query(
            Message.user_id,
            *(label_sum(label, alias) for label, alias in summed),
            *(label_avg(label, alias) for label, alias in averaged),
        )
        .select_from(TextLabels)
        .join(Message, TextLabels.message_id == Message.id)
        .filter(*conditions)
        .group_by(Message.user_id)
    )
def _update_troll_stats_internal(self, time_frame: UserStatsTimeFrame, base_date: Optional[datetime] = None):
    """Rebuild all ``troll_stats`` rows of ``time_frame`` from scratch.

    Collects emoji counts, spam-prompt counts and label aggregates per user
    (restricted to data at or after ``base_date`` when given), deletes the
    existing rows of the time frame, inserts the freshly computed ones and
    finally recomputes the ranks. Flushes the session but does not commit.
    """
    time_frame_key = time_frame.value
    stats_by_user: dict[UUID, TrollStats] = {}
    now = utcnow()

    def get_stats(user_id: UUID) -> TrollStats:
        # Return the user's stats object, creating it on first access.
        stats = stats_by_user.get(user_id)
        if stats is None:
            stats = TrollStats(user_id=user_id, time_frame=time_frame_key, modified_date=now, base_date=base_date)
            stats_by_user[user_id] = stats
        return stats

    # emoji counts of user's messages
    for r in self.query_message_emoji_counts_per_user(reference_time=base_date):
        s = get_stats(r["user_id"])
        s.upvotes = r["up"]
        s.downvotes = r["down"]
        s.red_flags = r["flag"]

    # num spam prompts
    for uid, count in self.query_spam_prompts_per_user(reference_time=base_date):
        get_stats(uid).spam_prompts = count

    label_field_names = (
        "quality",
        "humor",
        "toxicity",
        "violence",
        "helpfulness",
        "spam",
        "lang_mismach",
        "not_appropriate",
        "pii",
        "hate_speech",
        "sexual_content",
        "political_content",
    )

    # label counts / mean values
    for r in self.query_labels_per_user(reference_time=base_date):
        s = get_stats(r["user_id"])
        for fn in label_field_names:
            setattr(s, fn, r[fn])

    # delete all existing stats for time frame
    self.session.execute(delete(TrollStats).where(TrollStats.time_frame == time_frame_key))

    # rows without a user cannot be stored (user_id is part of the primary key)
    if None in stats_by_user:
        logger.warning("Some messages in DB have NULL values in user_id column.")
        del stats_by_user[None]

    # compute troll score
    for v in stats_by_user.values():
        v.troll_score = v.compute_troll_score()

    # insert user objects
    self.session.add_all(stats_by_user.values())
    self.session.flush()

    # ranks are derived from the rows just written, so this must come last
    self.update_troll_ranks(time_frame=time_frame)
@log_timing(log_kwargs=True)
def update_ranks(self, time_frame: UserStatsTimeFrame = None):
def update_leader_ranks(self, time_frame: UserStatsTimeFrame = None):
"""
Update user_stats ranks. The persisted rank values allow to
quickly the rank of a single user and to query nearby users.
@@ -329,10 +526,41 @@ WHERE
r = self.session.execute(
text(sql_update_rank), {"time_frame": time_frame.value if time_frame is not None else None}
)
logger.debug(f"pre_compute_ranks updated({time_frame=}) {r.rowcount} rows.")
logger.debug(f"pre_compute_ranks leader updated({time_frame=}) {r.rowcount} rows.")
def update_stats_time_frame(self, time_frame: UserStatsTimeFrame, reference_time: Optional[datetime] = None):
    """Forward ``time_frame`` and ``reference_time`` unchanged to ``_update_stats_internal``."""
    refresh = self._update_stats_internal
    refresh(time_frame, reference_time)
@log_timing(log_kwargs=True)
def update_troll_ranks(self, time_frame: UserStatsTimeFrame = None):
    """Recompute the persisted ``rank`` column of ``troll_stats``.

    Ranks are row numbers within each time frame, ordered by ``troll_score``
    descending with ``user_id`` as tie-breaker. Passing ``None`` updates the
    rows of every time frame; otherwise only those of ``time_frame``.
    """
    # Bound as NULL when no time frame is given, which disables the filter in the SQL.
    time_frame_param = None if time_frame is None else time_frame.value

    sql_update_troll_rank = """
-- update rank
UPDATE troll_stats ts
SET "rank" = r."rank"
FROM
(SELECT
ROW_NUMBER () OVER(
PARTITION BY time_frame
ORDER BY troll_score DESC, user_id
) AS "rank", user_id, time_frame
FROM troll_stats ts2
WHERE (:time_frame IS NULL OR time_frame = :time_frame)) AS r
WHERE
ts.user_id = r.user_id
AND ts.time_frame = r.time_frame;"""

    statement = text(sql_update_troll_rank)
    r = self.session.execute(statement, {"time_frame": time_frame_param})
    logger.debug(f"pre_compute_ranks troll updated({time_frame=}) {r.rowcount} rows.")
def update_stats_time_frame(
    self,
    time_frame: UserStatsTimeFrame,
    reference_time: Optional[datetime] = None,
    leader_stats: bool = True,
    troll_stats: bool = True,
):
    """Refresh the selected statistics of ``time_frame`` and commit the session.

    ``leader_stats`` / ``troll_stats`` choose which of the two refreshes run;
    the leader refresh runs first. The commit happens even when both flags
    are false.
    """
    refreshers = (
        (leader_stats, self._update_stats_internal),
        (troll_stats, self._update_troll_stats_internal),
    )
    for enabled, refresh in refreshers:
        if enabled:
            refresh(time_frame, reference_time)
    self.session.commit()
@log_timing(log_kwargs=True, level="INFO")
@@ -469,6 +469,47 @@ class LeaderboardStats(BaseModel):
leaderboard: List[UserScore]
class TrollScore(BaseModel):
    """One trollboard entry: identity fields of a user plus that user's troll statistics."""

    rank: Optional[int]
    user_id: UUID
    highlighted: bool = False
    username: str
    auth_method: str
    display_name: str
    last_activity_date: Optional[datetime]

    troll_score: int = 0

    base_date: Optional[datetime]
    modified_date: Optional[datetime]

    red_flags: int = 0  # num reported messages of user
    upvotes: int = 0  # num up-voted messages of user
    downvotes: int = 0  # num down-voted messages of user

    spam_prompts: int = 0

    # label values that may be absent
    quality: Optional[float] = None
    humor: Optional[float] = None
    toxicity: Optional[float] = None
    violence: Optional[float] = None
    helpfulness: Optional[float] = None

    # label counters; `lang_mismach` keeps the spelling of the TrollStats field
    spam: int = 0
    lang_mismach: int = 0
    not_appropriate: int = 0
    pii: int = 0
    hate_speech: int = 0
    sexual_content: int = 0
    political_content: int = 0
class TrollboardStats(BaseModel):
    """Trollboard of one time frame: its entries and the time of the last update."""

    time_frame: str
    last_updated: datetime
    trollboard: List[TrollScore]
class OasstErrorResponse(BaseModel):
"""The format of an error response from the OASST API."""