mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
add trollboards
This commit is contained in:
@@ -0,0 +1,59 @@
|
||||
"""add troll_stats
|
||||
|
||||
Revision ID: 4d7e0b0ebe84
|
||||
Revises: 9e7ec4a9e3f2
|
||||
Create Date: 2023-02-02 15:44:12.647260
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "4d7e0b0ebe84"
|
||||
down_revision = "9e7ec4a9e3f2"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Create the ``troll_stats`` table and its unique (time_frame, user_id) index."""

    def int_col(name):
        # NOT NULL integer column.
        return sa.Column(name, sa.Integer(), nullable=False)

    def float_col(name):
        # Nullable float column.
        return sa.Column(name, sa.Float(), nullable=True)

    table_name = "troll_stats"

    # Columns are listed in the exact order in which they are created.
    columns = [
        sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False),
        sa.Column("base_date", sa.DateTime(timezone=True), nullable=True),
        sa.Column(
            "modified_date",
            sa.DateTime(timezone=True),
            server_default=sa.text("CURRENT_TIMESTAMP"),
            nullable=False,
        ),
        sa.Column("time_frame", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
        int_col("troll_score"),
        sa.Column("rank", sa.Integer(), nullable=True),
    ]
    columns += [int_col(name) for name in ("red_flags", "upvotes", "downvotes", "spam_prompts")]
    columns += [float_col(name) for name in ("quality", "humor", "toxicity", "violence", "helpfulness")]
    columns += [
        int_col(name)
        for name in (
            "spam",
            "lang_mismach",
            "not_appropriate",
            "pii",
            "hate_speech",
            "sexual_content",
            "political_content",
        )
    ]

    constraints = [
        sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
        sa.PrimaryKeyConstraint("user_id", "time_frame"),
    ]

    op.create_table(table_name, *columns, *constraints)
    op.create_index("ix_troll_stats__timeframe__user_id", table_name, ["time_frame", "user_id"], unique=True)
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Revert the migration: drop the unique index first, then the ``troll_stats`` table."""
    table_name = "troll_stats"
    op.drop_index("ix_troll_stats__timeframe__user_id", table_name=table_name)
    op.drop_table(table_name)
|
||||
@@ -10,6 +10,7 @@ from oasst_backend.api.v1 import (
|
||||
stats,
|
||||
tasks,
|
||||
text_labels,
|
||||
trollboards,
|
||||
users,
|
||||
)
|
||||
|
||||
@@ -22,6 +23,7 @@ api_router.include_router(users.router, prefix="/users", tags=["users"])
|
||||
api_router.include_router(frontend_users.router, prefix="/frontend_users", tags=["frontend_users"])
|
||||
api_router.include_router(stats.router, prefix="/stats", tags=["stats"])
|
||||
api_router.include_router(leaderboards.router, prefix="/leaderboards", tags=["leaderboards"])
|
||||
api_router.include_router(trollboards.router, prefix="/trollboards", tags=["trollboards"])
|
||||
api_router.include_router(hugging_face.router, prefix="/hf", tags=["hugging_face"])
|
||||
api_router.include_router(admin.router, prefix="/admin", tags=["admin"])
|
||||
api_router.include_router(auth.router, prefix="/auth", tags=["auth"])
|
||||
|
||||
@@ -0,0 +1,21 @@
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from oasst_backend.api import deps
|
||||
from oasst_backend.models import ApiClient
|
||||
from oasst_backend.user_stats_repository import UserStatsRepository, UserStatsTimeFrame
|
||||
from oasst_shared.schemas.protocol import TrollboardStats
|
||||
from sqlmodel import Session
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/{time_frame}", response_model=TrollboardStats)
def get_trollboard(
    time_frame: UserStatsTimeFrame,
    max_count: Optional[int] = Query(100, gt=0, le=10000),
    api_client: ApiClient = Depends(deps.get_trusted_api_client),
    db: Session = Depends(deps.get_db),
) -> TrollboardStats:
    # Returns the trollboard for ``time_frame`` with at most ``max_count``
    # entries (default 100, accepted range 1..10000).
    # ``api_client`` is not read in the body; it is declared so that the
    # ``deps.get_trusted_api_client`` dependency is resolved for this route.
    return UserStatsRepository(db).get_trollboard(time_frame, limit=max_count)
|
||||
@@ -8,6 +8,7 @@ from .message_toxicity import MessageToxicity
|
||||
from .message_tree_state import MessageTreeState
|
||||
from .task import Task
|
||||
from .text_labels import TextLabels
|
||||
from .troll_stats import TrollStats
|
||||
from .user import User
|
||||
from .user_stats import UserStats, UserStatsTimeFrame
|
||||
|
||||
@@ -26,4 +27,5 @@ __all__ = [
|
||||
"Journal",
|
||||
"JournalIntegration",
|
||||
"MessageEmoji",
|
||||
"TrollStats",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,59 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
import sqlalchemy as sa
|
||||
import sqlalchemy.dialects.postgresql as pg
|
||||
from sqlmodel import Field, Index, SQLModel
|
||||
|
||||
|
||||
class TrollStats(SQLModel, table=True):
    """Per-user "troll" statistics for a single time frame (table ``troll_stats``).

    A row is identified by the ``(user_id, time_frame)`` pair. Besides the raw
    counters and label means it stores the derived ``troll_score`` and the
    persisted ``rank`` of the user within the time frame.
    """

    __tablename__ = "troll_stats"
    __table_args__ = (Index("ix_troll_stats__timeframe__user_id", "time_frame", "user_id", unique=True),)

    time_frame: Optional[str] = Field(nullable=False, primary_key=True)
    user_id: Optional[UUID] = Field(
        sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("user.id", ondelete="CASCADE"), primary_key=True)
    )
    base_date: Optional[datetime] = Field(sa_column=sa.Column(sa.DateTime(timezone=True), nullable=True))

    troll_score: int = 0
    modified_date: Optional[datetime] = Field(
        sa_column=sa.Column(sa.DateTime(timezone=True), nullable=False, server_default=sa.func.current_timestamp())
    )

    # Nullable column with no value at construction time, hence Optional with
    # an explicit ``None`` default (the annotation previously claimed ``int``).
    rank: Optional[int] = Field(default=None, nullable=True)

    red_flags: int = 0  # num reported messages of user
    upvotes: int = 0  # num up-voted messages of user
    downvotes: int = 0  # num down-voted messages of user

    spam_prompts: int = 0

    # Nullable float columns; ``None`` until a value is assigned.
    quality: Optional[float] = Field(default=None, nullable=True)
    humor: Optional[float] = Field(default=None, nullable=True)
    toxicity: Optional[float] = Field(default=None, nullable=True)
    violence: Optional[float] = Field(default=None, nullable=True)
    helpfulness: Optional[float] = Field(default=None, nullable=True)

    # Integer label counters.
    spam: int = 0
    lang_mismach: int = 0
    not_appropriate: int = 0
    pii: int = 0
    hate_speech: int = 0
    sexual_content: int = 0
    political_content: int = 0

    def compute_troll_score(self) -> int:
        """Return the troll score derived from this row's integer counters.

        Red flags count three times, up-votes are subtracted, and down-votes,
        spam prompts and the listed label counters are each added once. The
        ``spam`` counter and the float label means are not part of the sum.
        The result is returned only; ``troll_score`` is not modified here.
        """
        return (
            self.red_flags * 3
            - self.upvotes
            + self.downvotes
            + self.spam_prompts
            + self.lang_mismach
            + self.not_appropriate
            + self.pii
            + self.hate_speech
            + self.sexual_content
            + self.political_content
        )
|
||||
@@ -5,19 +5,31 @@ from uuid import UUID
|
||||
import sqlalchemy as sa
|
||||
from loguru import logger
|
||||
from oasst_backend.config import settings
|
||||
from oasst_backend.models import Message, MessageReaction, Task, User, UserStats, UserStatsTimeFrame
|
||||
from oasst_backend.models import (
|
||||
Message,
|
||||
MessageReaction,
|
||||
MessageTreeState,
|
||||
Task,
|
||||
TextLabels,
|
||||
TrollStats,
|
||||
User,
|
||||
UserStats,
|
||||
UserStatsTimeFrame,
|
||||
)
|
||||
from oasst_backend.models.db_payload import (
|
||||
LabelAssistantReplyPayload,
|
||||
LabelPrompterReplyPayload,
|
||||
RankingReactionPayload,
|
||||
)
|
||||
from oasst_shared.schemas.protocol import LeaderboardStats, UserScore
|
||||
from oasst_backend.models.message_tree_state import State as TreeState
|
||||
from oasst_shared.schemas.protocol import EmojiCode, LeaderboardStats, TextLabel, TrollboardStats, TrollScore, UserScore
|
||||
from oasst_shared.utils import log_timing, utcnow
|
||||
from sqlalchemy.dialects import postgresql
|
||||
from sqlalchemy.sql.functions import coalesce
|
||||
from sqlmodel import Session, delete, func, text
|
||||
|
||||
|
||||
def _create_user_score(r, highlighted_user_id: UUID | None):
|
||||
def _create_user_score(r, highlighted_user_id: UUID | None) -> UserScore:
|
||||
if r["UserStats"]:
|
||||
d = r["UserStats"].dict()
|
||||
else:
|
||||
@@ -37,6 +49,24 @@ def _create_user_score(r, highlighted_user_id: UUID | None):
|
||||
return UserScore(**d)
|
||||
|
||||
|
||||
def _create_troll_score(r, highlighted_user_id: UUID | None) -> TrollScore:
    """Build a ``TrollScore`` from one result row.

    The row's ``TrollStats`` entity, when truthy, supplies the statistic
    fields; otherwise only ``modified_date`` is seeded with the current time.
    The user columns of the row are then copied over, and ``highlighted`` is
    set only when a ``highlighted_user_id`` was passed.
    """
    stats = r["TrollStats"]
    fields = stats.dict() if stats else {"modified_date": utcnow()}

    user_columns = ("user_id", "username", "auth_method", "display_name", "last_activity_date")
    fields.update((name, r[name]) for name in user_columns)

    if highlighted_user_id:
        fields["highlighted"] = r["user_id"] == highlighted_user_id

    return TrollScore(**fields)
|
||||
|
||||
|
||||
class UserStatsRepository:
|
||||
def __init__(self, session: Session):
|
||||
self.session = session
|
||||
@@ -133,6 +163,38 @@ class UserStatsRepository:
|
||||
stats_by_timeframe = {tf.value: _create_user_score(r, user_id) for tf in UserStatsTimeFrame}
|
||||
return stats_by_timeframe
|
||||
|
||||
def get_trollboard(
    self,
    time_frame: UserStatsTimeFrame,
    limit: int = 100,
    highlighted_user_id: Optional[UUID] = None,
) -> TrollboardStats:
    """
    Get trollboard stats for the specified time frame.

    Rows are read in ascending ``TrollStats.rank`` order and capped at
    ``limit``. ``last_updated`` is the newest ``modified_date`` among the
    returned entries, or the current time when there are none.
    """
    user_columns = (
        User.id.label("user_id"),
        User.username,
        User.auth_method,
        User.display_name,
        User.last_activity_date,
    )
    qry = self.session.query(*user_columns, TrollStats)
    qry = qry.join(TrollStats, User.id == TrollStats.user_id)
    qry = qry.filter(TrollStats.time_frame == time_frame.value)
    qry = qry.order_by(TrollStats.rank).limit(limit)

    trollboard = [_create_troll_score(row, highlighted_user_id) for row in self.session.exec(qry)]
    last_update = max(entry.modified_date for entry in trollboard) if trollboard else utcnow()

    return TrollboardStats(time_frame=time_frame.value, trollboard=trollboard, last_updated=last_update)
|
||||
|
||||
def query_total_prompts_per_user(
|
||||
self, reference_time: Optional[datetime] = None, only_reviewed: Optional[bool] = True
|
||||
):
|
||||
@@ -292,10 +354,145 @@ class UserStatsRepository:
|
||||
self.session.add_all(stats_by_user.values())
|
||||
self.session.flush()
|
||||
|
||||
self.update_ranks(time_frame=time_frame)
|
||||
self.update_leader_ranks(time_frame=time_frame)
|
||||
|
||||
def query_message_emoji_counts_per_user(self, reference_time: Optional[datetime] = None):
    """Build the per-user emoji aggregation query (not executed here).

    For every ``Message.user_id`` the query sums the thumbs-up, thumbs-down
    and red-flag entries of ``Message.emojis`` (a missing entry counts as 0)
    into the columns ``up``, ``down`` and ``flag``. Only non-deleted messages
    with a non-NULL ``emojis`` value are considered; when ``reference_time``
    is given, only messages created at or after it.
    """

    def emoji_total(code, column_name):
        return func.sum(coalesce(Message.emojis[code].cast(sa.Integer), 0)).label(column_name)

    conditions = [Message.deleted == sa.false(), Message.emojis.is_not(None)]
    if reference_time:
        conditions.append(Message.created_date >= reference_time)

    return (
        self.session.query(
            Message.user_id,
            emoji_total(EmojiCode.thumbs_up, "up"),
            emoji_total(EmojiCode.thumbs_down, "down"),
            emoji_total(EmojiCode.red_flag, "flag"),
        )
        .filter(*conditions)
        .group_by(Message.user_id)
    )
|
||||
|
||||
def query_spam_prompts_per_user(self, reference_time: Optional[datetime] = None):
    """Build the per-user count of low-grade aborted message trees (not executed here).

    ``MessageTreeState`` rows in state ``ABORTED_LOW_GRADE`` are joined to the
    message whose id equals the tree id and counted per ``Message.user_id``
    as ``spam_prompts``. When ``reference_time`` is given, only messages
    created at or after it are counted.
    """
    conditions = [MessageTreeState.state == TreeState.ABORTED_LOW_GRADE]
    if reference_time:
        conditions.append(Message.created_date >= reference_time)

    counts = self.session.query(Message.user_id, func.count().label("spam_prompts"))
    counts = counts.select_from(MessageTreeState)
    counts = counts.join(Message, MessageTreeState.message_tree_id == Message.id)
    return counts.filter(*conditions).group_by(Message.user_id)
|
||||
|
||||
def query_labels_per_user(self, reference_time: Optional[datetime] = None):
    """Build the per-user text-label aggregation query (not executed here).

    ``TextLabels`` rows are joined to the message they refer to and grouped
    by ``Message.user_id``. The result has one summed integer column per flag
    label (``spam``, ``lang_mismach``, ``not_appropriate``, ``pii``,
    ``hate_speech``, ``sexual_content``, ``political_content``; a missing
    label counts as 0) and one averaged float column per rating label
    (``quality``, ``humor``, ``toxicity``, ``violence``, ``helpfulness``).

    Only non-deleted messages are considered; when ``reference_time`` is
    given, only messages created at or after it.
    """

    def flag_total(label, column_name):
        # Sum of a label over the user's labelled messages, treating a missing label as 0.
        return func.sum(coalesce(TextLabels.labels[label].cast(sa.Integer), 0)).label(column_name)

    def rating_mean(label, column_name):
        # Mean of a label over the rows where it is present.
        return func.avg(TextLabels.labels[label].cast(sa.Float)).label(column_name)

    qry = (
        self.session.query(
            Message.user_id,
            flag_total(TextLabel.spam, "spam"),
            # Output column keeps the historical spelling used by the troll_stats table.
            flag_total(TextLabel.lang_mismatch, "lang_mismach"),
            flag_total(TextLabel.not_appropriate, "not_appropriate"),
            flag_total(TextLabel.pii, "pii"),
            flag_total(TextLabel.hate_speech, "hate_speech"),
            flag_total(TextLabel.sexual_content, "sexual_content"),
            flag_total(TextLabel.political_content, "political_content"),
            rating_mean(TextLabel.quality, "quality"),
            rating_mean(TextLabel.humor, "humor"),
            rating_mean(TextLabel.toxicity, "toxicity"),
            rating_mean(TextLabel.violence, "violence"),
            rating_mean(TextLabel.helpfulness, "helpfulness"),
        )
        .select_from(TextLabels)
        .join(Message, TextLabels.message_id == Message.id)
        # Label aggregates must not depend on whether a message has emoji
        # reactions: the former ``Message.emojis.is_not(None)`` condition
        # excluded every labelled message without reactions.
        .filter(Message.deleted == sa.false())
    )

    if reference_time:
        qry = qry.filter(Message.created_date >= reference_time)

    return qry.group_by(Message.user_id)
|
||||
|
||||
def _update_troll_stats_internal(self, time_frame: UserStatsTimeFrame, base_date: Optional[datetime] = None):
    """Rebuild all ``TrollStats`` rows of one time frame.

    Aggregates emoji counts, spam-prompt counts and label statistics per
    user (restricted to ``base_date`` when given), deletes the existing rows
    of the time frame, inserts the freshly computed ones, flushes the session
    and finally updates the persisted troll ranks. The session is flushed but
    not committed here.
    """
    time_frame_key = time_frame.value
    stats_by_user: dict[UUID, TrollStats] = {}
    now = utcnow()

    def get_stats(user_id: UUID) -> TrollStats:
        # Return the stats object of the user, creating it on first access.
        stats = stats_by_user.get(user_id)
        if stats is None:
            stats = TrollStats(user_id=user_id, time_frame=time_frame_key, modified_date=now, base_date=base_date)
            stats_by_user[user_id] = stats
        return stats

    # emoji counts of user's messages
    for row in self.query_message_emoji_counts_per_user(reference_time=base_date):
        stats = get_stats(row["user_id"])
        stats.upvotes = row["up"]
        stats.downvotes = row["down"]
        stats.red_flags = row["flag"]

    # num spam prompts
    for user_id, count in self.query_spam_prompts_per_user(reference_time=base_date):
        get_stats(user_id).spam_prompts = count

    label_field_names = (
        "quality",
        "humor",
        "toxicity",
        "violence",
        "helpfulness",
        "spam",
        "lang_mismach",
        "not_appropriate",
        "pii",
        "hate_speech",
        "sexual_content",
        "political_content",
    )

    # label counts / mean values
    for row in self.query_labels_per_user(reference_time=base_date):
        stats = get_stats(row["user_id"])
        for field_name in label_field_names:
            setattr(stats, field_name, row[field_name])

    # delete all existing stats for the time frame
    self.session.execute(delete(TrollStats).where(TrollStats.time_frame == time_frame_key))

    # Rows aggregated under a NULL user_id cannot be stored (user_id is part
    # of the primary key), so they are dropped with a warning.
    if None in stats_by_user:
        logger.warning("Some messages in DB have NULL values in user_id column.")
        del stats_by_user[None]

    # compute troll score
    for stats in stats_by_user.values():
        stats.troll_score = stats.compute_troll_score()

    # insert the new rows; flush so the rank update below can see them
    self.session.add_all(stats_by_user.values())
    self.session.flush()

    self.update_troll_ranks(time_frame=time_frame)
|
||||
|
||||
@log_timing(log_kwargs=True)
|
||||
def update_ranks(self, time_frame: UserStatsTimeFrame = None):
|
||||
def update_leader_ranks(self, time_frame: UserStatsTimeFrame = None):
|
||||
"""
|
||||
Update user_stats ranks. The persisted rank values allow to
|
||||
quickly look up the rank of a single user and to query nearby users.
|
||||
@@ -329,10 +526,41 @@ WHERE
|
||||
r = self.session.execute(
|
||||
text(sql_update_rank), {"time_frame": time_frame.value if time_frame is not None else None}
|
||||
)
|
||||
logger.debug(f"pre_compute_ranks updated({time_frame=}) {r.rowcount} rows.")
|
||||
logger.debug(f"pre_compute_ranks leader updated({time_frame=}) {r.rowcount} rows.")
|
||||
|
||||
def update_stats_time_frame(self, time_frame: UserStatsTimeFrame, reference_time: Optional[datetime] = None):
|
||||
self._update_stats_internal(time_frame, reference_time)
|
||||
@log_timing(log_kwargs=True)
def update_troll_ranks(self, time_frame: UserStatsTimeFrame = None):
    """
    Update the persisted troll_stats ranks.

    Within each time frame, rows are numbered by descending troll_score
    (ties broken by user_id) and the number is stored in the "rank" column.
    When time_frame is None the rows of all time frames are updated.
    """
    sql_update_troll_rank = """
-- update rank
UPDATE troll_stats ts
SET "rank" = r."rank"
FROM
(SELECT
ROW_NUMBER () OVER(
PARTITION BY time_frame
ORDER BY troll_score DESC, user_id
) AS "rank", user_id, time_frame
FROM troll_stats ts2
WHERE (:time_frame IS NULL OR time_frame = :time_frame)) AS r
WHERE
ts.user_id = r.user_id
AND ts.time_frame = r.time_frame;"""
    params = {"time_frame": None if time_frame is None else time_frame.value}
    r = self.session.execute(text(sql_update_troll_rank), params)
    logger.debug(f"pre_compute_ranks troll updated({time_frame=}) {r.rowcount} rows.")
|
||||
|
||||
def update_stats_time_frame(
    self,
    time_frame: UserStatsTimeFrame,
    reference_time: Optional[datetime] = None,
    leader_stats: bool = True,
    troll_stats: bool = True,
):
    """Refresh the selected statistics of one time frame and commit.

    The leaderboard stats are refreshed when ``leader_stats`` is true, then
    the troll stats when ``troll_stats`` is true; both receive ``time_frame``
    and ``reference_time``. The session is committed in every case.
    """
    refreshers = (
        (leader_stats, self._update_stats_internal),
        (troll_stats, self._update_troll_stats_internal),
    )
    for enabled, refresh in refreshers:
        if enabled:
            refresh(time_frame, reference_time)
    self.session.commit()
|
||||
|
||||
@log_timing(log_kwargs=True, level="INFO")
|
||||
|
||||
@@ -469,6 +469,47 @@ class LeaderboardStats(BaseModel):
|
||||
leaderboard: List[UserScore]
|
||||
|
||||
|
||||
class TrollScore(BaseModel):
    """One trollboard entry: a user's identity fields plus their troll statistics."""

    # Position on the trollboard; optional.
    rank: Optional[int]
    user_id: UUID
    # Defaults to False.
    highlighted: bool = False
    username: str
    auth_method: str
    display_name: str
    last_activity_date: Optional[datetime]

    troll_score: int = 0

    base_date: Optional[datetime]
    modified_date: Optional[datetime]

    red_flags: int = 0  # num reported messages of user
    upvotes: int = 0  # num up-voted messages of user
    downvotes: int = 0  # num down-voted messages of user

    spam_prompts: int = 0

    # Float label values; None when no value is available.
    quality: Optional[float] = None
    humor: Optional[float] = None
    toxicity: Optional[float] = None
    violence: Optional[float] = None
    helpfulness: Optional[float] = None

    # Integer label counters, all defaulting to 0.
    spam: int = 0
    lang_mismach: int = 0
    not_appropriate: int = 0
    pii: int = 0
    hate_speech: int = 0
    sexual_content: int = 0
    political_content: int = 0
|
||||
|
||||
|
||||
class TrollboardStats(BaseModel):
    """A trollboard for one time frame: its entries and a last-updated timestamp."""

    time_frame: str
    last_updated: datetime
    trollboard: List[TrollScore]
|
||||
|
||||
|
||||
class OasstErrorResponse(BaseModel):
|
||||
"""The format of an error response from the OASST API."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user