mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
Store Message Toxicity in database (#553)
* [NEW] MessageToxicity table * [NEW] Alembic message Toxicity * [NEW] Model name enum * [NEW] Refactor Enum HF * [NEW] Settings: DEBUT_SKIP_TOXICITY_CALCULATION * [NEW] Store toxicity values * [FIX] Merge conflict * [FIX] Documentation * [NEW] save_toxicity: function * [FIX] Formatted string * [NEW] DEBUG_SKIP_TOXICITY_CALCULATION=True * [FIX] HfClassificationModel * [FIX] Alembic merge heads * [NEW] Refactor save_toxicity * [NEW] Separating score/label * [NEW] Store score and label * [FIX] Cleaning Alembic * [NEW] Clean HF names * [NEW] Not type hinting * [NEW] Update alembic versions * [NEW] Revert the changes * [NEW] Type hinting label & score * Updated down_revision in migration script Co-authored-by: Andreas Köpf <andreas.koepf@xamla.com>
This commit is contained in:
@@ -83,6 +83,7 @@
|
||||
MAX_WORKERS: "1"
|
||||
RATE_LIMIT: "false"
|
||||
DEBUG_SKIP_EMBEDDING_COMPUTATION: "true"
|
||||
DEBUG_SKIP_TOXICITY_CALCULATION: "true"
|
||||
ports:
|
||||
- 8080:8080
|
||||
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
"""MessageToxicity
|
||||
|
||||
Revision ID: bcc2fe18d214
|
||||
Revises: 20cd871f4ec7
|
||||
Create Date: 2023-01-08 22:00:43.297719
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "bcc2fe18d214"
|
||||
down_revision = "846cc08ac79f"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table(
|
||||
"message_toxicity",
|
||||
sa.Column("message_id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column("toxicity", sa.Float(), nullable=True),
|
||||
sa.Column("created_date", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
|
||||
sa.Column("model", sqlmodel.sql.sqltypes.AutoString(length=256), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["message_id"],
|
||||
["message.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("message_id", "model"),
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_table("message_toxicity")
|
||||
# ### end Alembic commands ###
|
||||
@@ -4,7 +4,7 @@ from fastapi import APIRouter, Depends
|
||||
from oasst_backend.api import deps
|
||||
from oasst_backend.models import ApiClient
|
||||
from oasst_backend.schemas.hugging_face import ToxicityClassification
|
||||
from oasst_backend.utils.hugging_face import HfUrl, HuggingFaceAPI
|
||||
from oasst_backend.utils.hugging_face import HfClassificationModel, HfUrl, HuggingFaceAPI
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -25,7 +25,8 @@ async def get_text_toxicity(
|
||||
ToxicityClassification: the score of toxicity of the message.
|
||||
"""
|
||||
|
||||
api_url: str = HfUrl.HUGGINGFACE_TOXIC_ROBERTA.value
|
||||
api_url: str = HfUrl.HUGGINGFACE_TOXIC_CLASSIFICATION.value + "/" + HfClassificationModel.TOXIC_ROBERTA.value
|
||||
|
||||
hugging_face_api = HuggingFaceAPI(api_url)
|
||||
response = await hugging_face_api.post(msg)
|
||||
|
||||
|
||||
@@ -79,6 +79,7 @@ class Settings(BaseSettings):
|
||||
)
|
||||
DEBUG_ALLOW_SELF_LABELING: bool = False # allow users to label their own messages
|
||||
DEBUG_SKIP_EMBEDDING_COMPUTATION: bool = False
|
||||
DEBUG_SKIP_TOXICITY_CALCULATION: bool = False
|
||||
|
||||
HUGGING_FACE_API_KEY: str = ""
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ from .journal import Journal, JournalIntegration
|
||||
from .message import Message
|
||||
from .message_embedding import MessageEmbedding
|
||||
from .message_reaction import MessageReaction
|
||||
from .message_toxicity import MessageToxicity
|
||||
from .message_tree_state import MessageTreeState
|
||||
from .task import Task
|
||||
from .text_labels import TextLabels
|
||||
@@ -17,6 +18,7 @@ __all__ = [
|
||||
"MessageEmbedding",
|
||||
"MessageReaction",
|
||||
"MessageTreeState",
|
||||
"MessageToxicity",
|
||||
"Task",
|
||||
"TextLabels",
|
||||
"Journal",
|
||||
|
||||
@@ -0,0 +1,24 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
import sqlalchemy as sa
|
||||
import sqlalchemy.dialects.postgresql as pg
|
||||
from sqlmodel import Field, Float, SQLModel
|
||||
|
||||
|
||||
class MessageToxicity(SQLModel, table=True):
|
||||
__tablename__ = "message_toxicity"
|
||||
__table_args__ = (sa.PrimaryKeyConstraint("message_id", "model"),)
|
||||
|
||||
message_id: UUID = Field(sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("message.id"), nullable=False))
|
||||
model: str = Field(max_length=256, nullable=False)
|
||||
|
||||
# Storing the score and the label of the message
|
||||
score: float = Field(sa_column=sa.Column(Float), nullable=False)
|
||||
label: str = Field(max_length=256, nullable=False)
|
||||
|
||||
# In the case that the Message Embedding is created afterwards
|
||||
created_date: Optional[datetime] = Field(
|
||||
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp())
|
||||
)
|
||||
@@ -14,6 +14,7 @@ from oasst_backend.models import (
|
||||
Message,
|
||||
MessageEmbedding,
|
||||
MessageReaction,
|
||||
MessageToxicity,
|
||||
MessageTreeState,
|
||||
Task,
|
||||
TextLabels,
|
||||
@@ -293,6 +294,25 @@ class PromptRepository:
|
||||
|
||||
return reaction, task
|
||||
|
||||
def insert_toxicity(self, message_id: UUID, model: str, score: float, label: str) -> MessageToxicity:
|
||||
"""Save the toxicity score of a new message in the database.
|
||||
Args:
|
||||
message_id (UUID): the identifier of the message we want to save its toxicity score
|
||||
model (str): the model used for creating the toxicity score
|
||||
score (float): the toxicity score that we obtained from the model
|
||||
label (str): the final classification in toxicity of the model
|
||||
Raises:
|
||||
OasstError: if misses some of the before params
|
||||
Returns:
|
||||
MessageToxicity: the instance in the database of the score saved for that message
|
||||
"""
|
||||
|
||||
message_toxicity = MessageToxicity(message_id=message_id, model=model, score=score, label=label)
|
||||
self.db.add(message_toxicity)
|
||||
self.db.commit()
|
||||
self.db.refresh(message_toxicity)
|
||||
return message_toxicity
|
||||
|
||||
def insert_message_embedding(self, message_id: UUID, model: str, embedding: List[float]) -> MessageEmbedding:
|
||||
"""Insert the embedding of a new message in the database.
|
||||
|
||||
@@ -308,9 +328,6 @@ class PromptRepository:
|
||||
MessageEmbedding: the instance in the database of the embedding saved for that message
|
||||
"""
|
||||
|
||||
if None in (message_id, model, embedding):
|
||||
raise OasstError("Paramters missing to add embedding", OasstErrorCode.GENERIC_ERROR)
|
||||
|
||||
message_embedding = MessageEmbedding(message_id=message_id, model=model, embedding=embedding)
|
||||
self.db.add(message_embedding)
|
||||
self.db.commit()
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import random
|
||||
from enum import Enum
|
||||
from http import HTTPStatus
|
||||
from typing import Optional, Tuple
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from uuid import UUID
|
||||
|
||||
import numpy as np
|
||||
@@ -11,7 +11,7 @@ from oasst_backend.api.v1.utils import prepare_conversation, prepare_conversatio
|
||||
from oasst_backend.config import TreeManagerConfiguration, settings
|
||||
from oasst_backend.models import Message, MessageReaction, MessageTreeState, TextLabels, message_tree_state
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_backend.utils.hugging_face import HfEmbeddingModel, HfUrl, HuggingFaceAPI
|
||||
from oasst_backend.utils.hugging_face import HfClassificationModel, HfEmbeddingModel, HfUrl, HuggingFaceAPI
|
||||
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from sqlalchemy.sql import text
|
||||
@@ -363,6 +363,25 @@ class TreeManager:
|
||||
f"Could not fetch embbeddings for text reply to {interaction.message_id=} with {interaction.text=} by {interaction.user=}."
|
||||
)
|
||||
|
||||
if not settings.DEBUG_SKIP_TOXICITY_CALCULATION:
|
||||
try:
|
||||
model_name: str = HfClassificationModel.TOXIC_ROBERTA.value
|
||||
hugging_face_api: HuggingFaceAPI = HuggingFaceAPI(
|
||||
f"{HfUrl.HUGGINGFACE_FEATURE_EXTRACTION.value}/{model_name}"
|
||||
)
|
||||
|
||||
toxicity: List[List[Dict[str, Any]]] = await hugging_face_api.post(interaction.text)
|
||||
toxicity = toxicity[0][0]
|
||||
|
||||
pr.insert_toxicity(
|
||||
message_id=message.id, model=model_name, score=toxicity["score"], label=toxicity["label"]
|
||||
)
|
||||
|
||||
except OasstError:
|
||||
logger.error(
|
||||
f"Could not compute toxicity for text reply to {interaction.message_id=} with {interaction.text=} by {interaction.user=}."
|
||||
)
|
||||
|
||||
case protocol_schema.MessageRating:
|
||||
logger.info(
|
||||
f"Frontend reports rating of {interaction.message_id=} with {interaction.rating=} by {interaction.user=}."
|
||||
|
||||
@@ -8,10 +8,14 @@ from oasst_shared.exceptions import OasstError, OasstErrorCode
|
||||
|
||||
|
||||
class HfUrl(str, Enum):
|
||||
HUGGINGFACE_TOXIC_ROBERTA = ("https://api-inference.huggingface.co/models/unitary/multilingual-toxic-xlm-roberta",)
|
||||
HUGGINGFACE_TOXIC_CLASSIFICATION = "https://api-inference.huggingface.co/models"
|
||||
HUGGINGFACE_FEATURE_EXTRACTION = "https://api-inference.huggingface.co/pipeline/feature-extraction"
|
||||
|
||||
|
||||
class HfClassificationModel(str, Enum):
|
||||
TOXIC_ROBERTA = "unitary/multilingual-toxic-xlm-roberta"
|
||||
|
||||
|
||||
class HfEmbeddingModel(str, Enum):
|
||||
MINILM = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
|
||||
|
||||
|
||||
@@ -101,6 +101,7 @@ services:
|
||||
- DEBUG_USE_SEED_DATA=True
|
||||
- DEBUG_ALLOW_SELF_LABELING=True
|
||||
- MAX_WORKERS=1
|
||||
- DEBUG_SKIP_TOXICITY_CALCULATION=True
|
||||
- DEBUG_SKIP_EMBEDDING_COMPUTATION=True
|
||||
depends_on:
|
||||
db:
|
||||
|
||||
@@ -6,6 +6,7 @@ pushd "$parent_path/../../backend"
|
||||
|
||||
export DEBUG_SKIP_API_KEY_CHECK=True
|
||||
export DEBUG_USE_SEED_DATA=True
|
||||
export DEBUG_SKIP_TOXICITY_CALCULATION=True
|
||||
export DEBUG_ALLOW_SELF_LABELING=True
|
||||
export DEBUG_SKIP_EMBEDDING_COMPUTATION=True
|
||||
|
||||
|
||||
Reference in New Issue
Block a user