def upgrade() -> None:
    """Create the ``troll_stats`` table and its unique (time_frame, user_id) index."""

    def required_int(name: str) -> sa.Column:
        # NOT NULL integer counter column
        return sa.Column(name, sa.Integer(), nullable=False)

    def optional_float(name: str) -> sa.Column:
        # nullable float column
        return sa.Column(name, sa.Float(), nullable=True)

    # Column order is kept exactly as in the generated migration.
    columns = [
        sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False),
        sa.Column("base_date", sa.DateTime(timezone=True), nullable=True),
        sa.Column(
            "modified_date", sa.DateTime(timezone=True), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False
        ),
        sa.Column("time_frame", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
        required_int("troll_score"),
        sa.Column("rank", sa.Integer(), nullable=True),
        *map(required_int, ("red_flags", "upvotes", "downvotes", "spam_prompts")),
        *map(optional_float, ("quality", "humor", "toxicity", "violence", "helpfulness")),
        *map(
            required_int,
            (
                "spam",
                "lang_mismach",
                "not_appropriate",
                "pii",
                "hate_speech",
                "sexual_content",
                "political_content",
            ),
        ),
    ]
    constraints = [
        sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
        sa.PrimaryKeyConstraint("user_id", "time_frame"),
    ]

    op.create_table("troll_stats", *columns, *constraints)
    op.create_index("ix_troll_stats__timeframe__user_id", "troll_stats", ["time_frame", "user_id"], unique=True)
class TrollStats(SQLModel, table=True):
    """
    Per-user, per-time-frame aggregate of negative-feedback signals (table ``troll_stats``).

    The primary key is the pair (time_frame, user_id). Rows are removed together with
    the referenced user (``ON DELETE CASCADE`` on ``user.id``).
    """

    __tablename__ = "troll_stats"
    __table_args__ = (Index("ix_troll_stats__timeframe__user_id", "time_frame", "user_id", unique=True),)

    time_frame: Optional[str] = Field(nullable=False, primary_key=True)
    user_id: Optional[UUID] = Field(
        sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("user.id", ondelete="CASCADE"), primary_key=True)
    )
    base_date: Optional[datetime] = Field(sa_column=sa.Column(sa.DateTime(timezone=True), nullable=True))

    troll_score: int = 0
    modified_date: Optional[datetime] = Field(
        sa_column=sa.Column(sa.DateTime(timezone=True), nullable=False, server_default=sa.func.current_timestamp())
    )

    # Nullable column: annotated Optional with an explicit None default so the
    # Python type matches the database column.
    rank: Optional[int] = Field(default=None, nullable=True)

    red_flags: int = 0  # num reported messages of user
    upvotes: int = 0  # num up-voted messages of user
    downvotes: int = 0  # num down-voted messages of user

    spam_prompts: int = 0

    # Nullable float columns; None when no value has been stored.
    quality: Optional[float] = Field(default=None, nullable=True)
    humor: Optional[float] = Field(default=None, nullable=True)
    toxicity: Optional[float] = Field(default=None, nullable=True)
    violence: Optional[float] = Field(default=None, nullable=True)
    helpfulness: Optional[float] = Field(default=None, nullable=True)

    spam: int = 0
    lang_mismach: int = 0  # spelling matches the existing database column name
    not_appropriate: int = 0
    pii: int = 0
    hate_speech: int = 0
    sexual_content: int = 0
    political_content: int = 0

    def compute_troll_score(self) -> int:
        """
        Combine the integer counters into a single troll score.

        Red flags count three times, up-votes are subtracted, and down-votes,
        spam prompts and the listed label counters are added. The ``spam`` label
        counter and the float label values are not part of the sum.
        """
        return (
            self.red_flags * 3
            - self.upvotes
            + self.downvotes
            + self.spam_prompts
            + self.lang_mismach
            + self.not_appropriate
            + self.pii
            + self.hate_speech
            + self.sexual_content
            + self.political_content
        )
def _create_troll_score(r, highlighted_user_id: UUID | None) -> TrollScore:
    """
    Build a TrollScore from one result row.

    The row is read by key: ``TrollStats`` supplies the stats fields (when falsy, only
    ``modified_date`` is pre-filled with the current time) and the user columns are
    copied on top. ``highlighted`` is only set when a highlighted user id is given.
    """
    troll_stats = r["TrollStats"]
    score_fields = troll_stats.dict() if troll_stats else {"modified_date": utcnow()}

    user_columns = ("user_id", "username", "auth_method", "display_name", "last_activity_date")
    score_fields.update({column: r[column] for column in user_columns})

    if highlighted_user_id:
        score_fields["highlighted"] = r["user_id"] == highlighted_user_id

    return TrollScore(**score_fields)
def query_spam_prompts_per_user(self, reference_time: Optional[datetime] = None):
    """
    Build a query counting, per user, the message trees in state ABORTED_LOW_GRADE.

    Each tree is joined to the message whose id equals the tree id. When
    ``reference_time`` is given, only messages created at or after it are counted.
    Result columns: ``user_id`` and ``spam_prompts``.
    """
    conditions = [MessageTreeState.state == TreeState.ABORTED_LOW_GRADE]
    if reference_time:
        conditions.append(Message.created_date >= reference_time)

    spam_count = func.count().label("spam_prompts")
    per_tree = self.session.query(Message.user_id, spam_count).select_from(MessageTreeState)
    with_messages = per_tree.join(Message, MessageTreeState.message_tree_id == Message.id)
    return with_messages.filter(*conditions).group_by(Message.user_id)
def _update_troll_stats_internal(self, time_frame: UserStatsTimeFrame, base_date: Optional[datetime] = None):
    """
    Rebuild all ``troll_stats`` rows of one time frame.

    Collects per-user emoji counts, spam-prompt counts and text-label aggregates,
    deletes the existing rows of the time frame, inserts the freshly computed ones
    and finally updates the troll ranks. The session is flushed here but not
    committed by this method.

    :param time_frame: time frame whose rows are replaced.
    :param base_date: forwarded to the per-user queries as ``reference_time`` and
        stored as ``base_date`` on every new row.
    """
    time_frame_key = time_frame.value
    now = utcnow()
    stats_by_user: dict[UUID, TrollStats] = {}

    def get_stats(user_id: UUID) -> TrollStats:
        # Create the stats row the first time a query result mentions the user.
        stats = stats_by_user.get(user_id)
        if stats is None:
            stats = TrollStats(user_id=user_id, time_frame=time_frame_key, modified_date=now, base_date=base_date)
            stats_by_user[user_id] = stats
        return stats

    # emoji counts of user's messages
    for r in self.query_message_emoji_counts_per_user(reference_time=base_date):
        s = get_stats(r["user_id"])
        s.upvotes = r["up"]
        s.downvotes = r["down"]
        s.red_flags = r["flag"]

    # num spam prompts
    for uid, count in self.query_spam_prompts_per_user(reference_time=base_date):
        get_stats(uid).spam_prompts = count

    label_field_names = (
        "quality",
        "humor",
        "toxicity",
        "violence",
        "helpfulness",
        "spam",
        "lang_mismach",
        "not_appropriate",
        "pii",
        "hate_speech",
        "sexual_content",
        "political_content",
    )

    # label counts / mean values
    for r in self.query_labels_per_user(reference_time=base_date):
        s = get_stats(r["user_id"])
        for fn in label_field_names:
            setattr(s, fn, r[fn])

    # delete all existing stats for the time frame
    d = delete(TrollStats).where(TrollStats.time_frame == time_frame_key)
    self.session.execute(d)

    # rows without a user cannot be stored (user_id is part of the primary key)
    if None in stats_by_user:
        logger.warning("Some messages in DB have NULL values in user_id column.")
        del stats_by_user[None]

    # compute troll score
    for v in stats_by_user.values():
        v.troll_score = v.compute_troll_score()

    # insert user objects
    self.session.add_all(stats_by_user.values())
    self.session.flush()

    self.update_troll_ranks(time_frame=time_frame)
@log_timing(log_kwargs=True)
def update_troll_ranks(self, time_frame: Optional[UserStatsTimeFrame] = None):
    """
    Update the persisted ``rank`` column of ``troll_stats``.

    Ranks are row numbers assigned separately per time frame, ordered by
    ``troll_score`` descending with ``user_id`` as tie-breaker, so rank 1 is the
    highest troll score. The statement is executed on the session; this method
    itself does not commit.

    :param time_frame: restrict the update to this time frame; when None the
        SQL filter is disabled and rows of all time frames are re-ranked.
    """
    sql_update_troll_rank = """
-- update rank
UPDATE troll_stats ts
SET "rank" = r."rank"
FROM
    (SELECT
        ROW_NUMBER () OVER(
            PARTITION BY time_frame
            ORDER BY troll_score DESC, user_id
        ) AS "rank", user_id, time_frame
    FROM troll_stats ts2
    WHERE (:time_frame IS NULL OR time_frame = :time_frame)) AS r
WHERE
    ts.user_id = r.user_id
    AND ts.time_frame = r.time_frame;"""
    # The bound parameter is NULL when no time frame is given, which makes the
    # inner WHERE clause match every row.
    r = self.session.execute(
        text(sql_update_troll_rank), {"time_frame": time_frame.value if time_frame is not None else None}
    )
    logger.debug(f"pre_compute_ranks troll updated({time_frame=}) {r.rowcount} rows.")
class TrollboardStats(BaseModel):
    """Trollboard payload: a time frame label, a last-update timestamp and a list of TrollScore entries."""

    time_frame: str  # time frame the entries belong to
    # NOTE(review): the repository code visible in this change fills this with the newest
    # modified_date of the entries, or the current time when the list is empty.
    last_updated: datetime
    trollboard: List[TrollScore]  # one entry per user