Store Message Toxicity in database (#553)

* [NEW] MessageToxicity table

* [NEW] Alembic message Toxicity

* [NEW] Model name enum

* [NEW] Refactor Enum HF

* [NEW] Settings: DEBUT_SKIP_TOXICITY_CALCULATION

* [NEW] Store toxicity values

* [FIX] Merge conflict

* [FIX] Documentation

* [NEW] save_toxicity: function

* [FIX] Formatted string

* [NEW] DEBUG_SKIP_TOXICITY_CALCULATION=True

* [FIX] HfClassificationModel

* [FIX] Alembic merge heads

* [NEW] Refactor save_toxicity

* [NEW] Separating score/label

* [NEW] Store score and label

* [FIX] Cleaning Alembic

* [NEW] Clean HF names

* [NEW] Not type hinting

* [NEW] Update alembic versions

* [NEW] Revert the changes

* [NEW] Type hinting label & score

* Updated down_revision in migration script

Co-authored-by: Andreas Köpf <andreas.koepf@xamla.com>
This commit is contained in:
Nil Andreu
2023-01-14 13:22:55 +01:00
committed by GitHub
parent dbf8f77072
commit a902c600fa
11 changed files with 119 additions and 8 deletions
+1
View File
@@ -83,6 +83,7 @@
MAX_WORKERS: "1"
RATE_LIMIT: "false"
DEBUG_SKIP_EMBEDDING_COMPUTATION: "true"
DEBUG_SKIP_TOXICITY_CALCULATION: "true"
ports:
- 8080:8080
@@ -0,0 +1,40 @@
"""MessageToxicity
Revision ID: bcc2fe18d214
Revises: 20cd871f4ec7
Create Date: 2023-01-08 22:00:43.297719
"""
import sqlalchemy as sa
import sqlmodel
from alembic import op
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "bcc2fe18d214"
down_revision = "846cc08ac79f"
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"message_toxicity",
sa.Column("message_id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("toxicity", sa.Float(), nullable=True),
sa.Column("created_date", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
sa.Column("model", sqlmodel.sql.sqltypes.AutoString(length=256), nullable=False),
sa.ForeignKeyConstraint(
["message_id"],
["message.id"],
),
sa.PrimaryKeyConstraint("message_id", "model"),
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table("message_toxicity")
# ### end Alembic commands ###
+3 -2
View File
@@ -4,7 +4,7 @@ from fastapi import APIRouter, Depends
from oasst_backend.api import deps
from oasst_backend.models import ApiClient
from oasst_backend.schemas.hugging_face import ToxicityClassification
from oasst_backend.utils.hugging_face import HfUrl, HuggingFaceAPI
from oasst_backend.utils.hugging_face import HfClassificationModel, HfUrl, HuggingFaceAPI
router = APIRouter()
@@ -25,7 +25,8 @@ async def get_text_toxicity(
ToxicityClassification: the score of toxicity of the message.
"""
api_url: str = HfUrl.HUGGINGFACE_TOXIC_ROBERTA.value
api_url: str = HfUrl.HUGGINGFACE_TOXIC_CLASSIFICATION.value + "/" + HfClassificationModel.TOXIC_ROBERTA.value
hugging_face_api = HuggingFaceAPI(api_url)
response = await hugging_face_api.post(msg)
+1
View File
@@ -79,6 +79,7 @@ class Settings(BaseSettings):
)
DEBUG_ALLOW_SELF_LABELING: bool = False # allow users to label their own messages
DEBUG_SKIP_EMBEDDING_COMPUTATION: bool = False
DEBUG_SKIP_TOXICITY_CALCULATION: bool = False
HUGGING_FACE_API_KEY: str = ""
+2
View File
@@ -3,6 +3,7 @@ from .journal import Journal, JournalIntegration
from .message import Message
from .message_embedding import MessageEmbedding
from .message_reaction import MessageReaction
from .message_toxicity import MessageToxicity
from .message_tree_state import MessageTreeState
from .task import Task
from .text_labels import TextLabels
@@ -17,6 +18,7 @@ __all__ = [
"MessageEmbedding",
"MessageReaction",
"MessageTreeState",
"MessageToxicity",
"Task",
"TextLabels",
"Journal",
@@ -0,0 +1,24 @@
from datetime import datetime
from typing import Optional
from uuid import UUID
import sqlalchemy as sa
import sqlalchemy.dialects.postgresql as pg
from sqlmodel import Field, Float, SQLModel
class MessageToxicity(SQLModel, table=True):
__tablename__ = "message_toxicity"
__table_args__ = (sa.PrimaryKeyConstraint("message_id", "model"),)
message_id: UUID = Field(sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("message.id"), nullable=False))
model: str = Field(max_length=256, nullable=False)
# Storing the score and the label of the message
score: float = Field(sa_column=sa.Column(Float), nullable=False)
label: str = Field(max_length=256, nullable=False)
# In the case that the Message Embedding is created afterwards
created_date: Optional[datetime] = Field(
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp())
)
+20 -3
View File
@@ -14,6 +14,7 @@ from oasst_backend.models import (
Message,
MessageEmbedding,
MessageReaction,
MessageToxicity,
MessageTreeState,
Task,
TextLabels,
@@ -293,6 +294,25 @@ class PromptRepository:
return reaction, task
def insert_toxicity(self, message_id: UUID, model: str, score: float, label: str) -> MessageToxicity:
"""Save the toxicity score of a new message in the database.
Args:
message_id (UUID): the identifier of the message we want to save its toxicity score
model (str): the model used for creating the toxicity score
score (float): the toxicity score that we obtained from the model
label (str): the final classification in toxicity of the model
Raises:
OasstError: if misses some of the before params
Returns:
MessageToxicity: the instance in the database of the score saved for that message
"""
message_toxicity = MessageToxicity(message_id=message_id, model=model, score=score, label=label)
self.db.add(message_toxicity)
self.db.commit()
self.db.refresh(message_toxicity)
return message_toxicity
def insert_message_embedding(self, message_id: UUID, model: str, embedding: List[float]) -> MessageEmbedding:
"""Insert the embedding of a new message in the database.
@@ -308,9 +328,6 @@ class PromptRepository:
MessageEmbedding: the instance in the database of the embedding saved for that message
"""
if None in (message_id, model, embedding):
raise OasstError("Paramters missing to add embedding", OasstErrorCode.GENERIC_ERROR)
message_embedding = MessageEmbedding(message_id=message_id, model=model, embedding=embedding)
self.db.add(message_embedding)
self.db.commit()
+21 -2
View File
@@ -1,7 +1,7 @@
import random
from enum import Enum
from http import HTTPStatus
from typing import Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple
from uuid import UUID
import numpy as np
@@ -11,7 +11,7 @@ from oasst_backend.api.v1.utils import prepare_conversation, prepare_conversatio
from oasst_backend.config import TreeManagerConfiguration, settings
from oasst_backend.models import Message, MessageReaction, MessageTreeState, TextLabels, message_tree_state
from oasst_backend.prompt_repository import PromptRepository
from oasst_backend.utils.hugging_face import HfEmbeddingModel, HfUrl, HuggingFaceAPI
from oasst_backend.utils.hugging_face import HfClassificationModel, HfEmbeddingModel, HfUrl, HuggingFaceAPI
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
from sqlalchemy.sql import text
@@ -363,6 +363,25 @@ class TreeManager:
f"Could not fetch embbeddings for text reply to {interaction.message_id=} with {interaction.text=} by {interaction.user=}."
)
if not settings.DEBUG_SKIP_TOXICITY_CALCULATION:
try:
model_name: str = HfClassificationModel.TOXIC_ROBERTA.value
hugging_face_api: HuggingFaceAPI = HuggingFaceAPI(
f"{HfUrl.HUGGINGFACE_FEATURE_EXTRACTION.value}/{model_name}"
)
toxicity: List[List[Dict[str, Any]]] = await hugging_face_api.post(interaction.text)
toxicity = toxicity[0][0]
pr.insert_toxicity(
message_id=message.id, model=model_name, score=toxicity["score"], label=toxicity["label"]
)
except OasstError:
logger.error(
f"Could not compute toxicity for text reply to {interaction.message_id=} with {interaction.text=} by {interaction.user=}."
)
case protocol_schema.MessageRating:
logger.info(
f"Frontend reports rating of {interaction.message_id=} with {interaction.rating=} by {interaction.user=}."
+5 -1
View File
@@ -8,10 +8,14 @@ from oasst_shared.exceptions import OasstError, OasstErrorCode
class HfUrl(str, Enum):
HUGGINGFACE_TOXIC_ROBERTA = ("https://api-inference.huggingface.co/models/unitary/multilingual-toxic-xlm-roberta",)
HUGGINGFACE_TOXIC_CLASSIFICATION = "https://api-inference.huggingface.co/models"
HUGGINGFACE_FEATURE_EXTRACTION = "https://api-inference.huggingface.co/pipeline/feature-extraction"
class HfClassificationModel(str, Enum):
TOXIC_ROBERTA = "unitary/multilingual-toxic-xlm-roberta"
class HfEmbeddingModel(str, Enum):
MINILM = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
+1
View File
@@ -101,6 +101,7 @@ services:
- DEBUG_USE_SEED_DATA=True
- DEBUG_ALLOW_SELF_LABELING=True
- MAX_WORKERS=1
- DEBUG_SKIP_TOXICITY_CALCULATION=True
- DEBUG_SKIP_EMBEDDING_COMPUTATION=True
depends_on:
db:
+1
View File
@@ -6,6 +6,7 @@ pushd "$parent_path/../../backend"
export DEBUG_SKIP_API_KEY_CHECK=True
export DEBUG_USE_SEED_DATA=True
export DEBUG_SKIP_TOXICITY_CALCULATION=True
export DEBUG_ALLOW_SELF_LABELING=True
export DEBUG_SKIP_EMBEDDING_COMPUTATION=True