From a902c600faa20af2847a95aedd3ffde3abf9752e Mon Sep 17 00:00:00 2001 From: Nil Andreu <65730003+Nil-Andreu@users.noreply.github.com> Date: Sat, 14 Jan 2023 13:22:55 +0100 Subject: [PATCH] Store Message Toxicity in database (#553) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [NEW] MessageToxicity table * [NEW] Alembic message Toxicity * [NEW] Model name enum * [NEW] Refactor Enum HF * [NEW] Settings: DEBUG_SKIP_TOXICITY_CALCULATION * [NEW] Store toxicity values * [FIX] Merge conflict * [FIX] Documentation * [NEW] save_toxicity: function * [FIX] Formatted string * [NEW] DEBUG_SKIP_TOXICITY_CALCULATION=True * [FIX] HfClassificationModel * [FIX] Alembic merge heads * [NEW] Refactor save_toxicity * [NEW] Separating score/label * [NEW] Store score and label * [FIX] Cleaning Alembic * [NEW] Clean HF names * [NEW] Not type hinting * [NEW] Update alembic versions * [NEW] Revert the changes * [NEW] Type hinting label & score * Updated down_revision in migration script Co-authored-by: Andreas Köpf --- ansible/dev.yaml | 1 + ...01_08_2200-bcc2fe18d214_messagetoxicity.py | 40 +++++++++++++++++++ backend/oasst_backend/api/v1/hugging_face.py | 5 ++- backend/oasst_backend/config.py | 1 + backend/oasst_backend/models/__init__.py | 2 + .../oasst_backend/models/message_toxicity.py | 24 +++++++++++ backend/oasst_backend/prompt_repository.py | 23 +++++++++-- backend/oasst_backend/tree_manager.py | 23 ++++++++++- backend/oasst_backend/utils/hugging_face.py | 6 ++- docker-compose.yaml | 1 + scripts/backend-development/run-local.sh | 1 + 11 files changed, 119 insertions(+), 8 deletions(-) create mode 100644 backend/alembic/versions/2023_01_08_2200-bcc2fe18d214_messagetoxicity.py create mode 100644 backend/oasst_backend/models/message_toxicity.py diff --git a/ansible/dev.yaml b/ansible/dev.yaml index 3cf061a5..2bf67b01 100644 --- a/ansible/dev.yaml +++ b/ansible/dev.yaml @@ -83,6 +83,7 @@ MAX_WORKERS: "1" RATE_LIMIT: "false" 
DEBUG_SKIP_EMBEDDING_COMPUTATION: "true" + DEBUG_SKIP_TOXICITY_CALCULATION: "true" ports: - 8080:8080 diff --git a/backend/alembic/versions/2023_01_08_2200-bcc2fe18d214_messagetoxicity.py b/backend/alembic/versions/2023_01_08_2200-bcc2fe18d214_messagetoxicity.py new file mode 100644 index 00000000..1d17b9d2 --- /dev/null +++ b/backend/alembic/versions/2023_01_08_2200-bcc2fe18d214_messagetoxicity.py @@ -0,0 +1,41 @@ +"""MessageToxicity + +Revision ID: bcc2fe18d214 +Revises: 846cc08ac79f +Create Date: 2023-01-08 22:00:43.297719 + +""" +import sqlalchemy as sa +import sqlmodel +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = "bcc2fe18d214" +down_revision = "846cc08ac79f" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "message_toxicity", + sa.Column("message_id", postgresql.UUID(as_uuid=True), nullable=False), + sa.Column("score", sa.Float(), nullable=False), + sa.Column("label", sqlmodel.sql.sqltypes.AutoString(length=256), nullable=False), + sa.Column("created_date", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False), + sa.Column("model", sqlmodel.sql.sqltypes.AutoString(length=256), nullable=False), + sa.ForeignKeyConstraint( + ["message_id"], + ["message.id"], + ), + sa.PrimaryKeyConstraint("message_id", "model"), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_table("message_toxicity") + # ### end Alembic commands ### diff --git a/backend/oasst_backend/api/v1/hugging_face.py b/backend/oasst_backend/api/v1/hugging_face.py index 62d2ea6b..a6715574 100644 --- a/backend/oasst_backend/api/v1/hugging_face.py +++ b/backend/oasst_backend/api/v1/hugging_face.py @@ -4,7 +4,7 @@ from fastapi import APIRouter, Depends from oasst_backend.api import deps from oasst_backend.models import ApiClient from oasst_backend.schemas.hugging_face import ToxicityClassification -from oasst_backend.utils.hugging_face import HfUrl, HuggingFaceAPI +from oasst_backend.utils.hugging_face import HfClassificationModel, HfUrl, HuggingFaceAPI router = APIRouter() @@ -25,7 +25,8 @@ async def get_text_toxicity( ToxicityClassification: the score of toxicity of the message. """ - api_url: str = HfUrl.HUGGINGFACE_TOXIC_ROBERTA.value + api_url: str = HfUrl.HUGGINGFACE_TOXIC_CLASSIFICATION.value + "/" + HfClassificationModel.TOXIC_ROBERTA.value + hugging_face_api = HuggingFaceAPI(api_url) response = await hugging_face_api.post(msg) diff --git a/backend/oasst_backend/config.py b/backend/oasst_backend/config.py index c18bd4c2..2967884c 100644 --- a/backend/oasst_backend/config.py +++ b/backend/oasst_backend/config.py @@ -79,6 +79,7 @@ class Settings(BaseSettings): ) DEBUG_ALLOW_SELF_LABELING: bool = False # allow users to label their own messages DEBUG_SKIP_EMBEDDING_COMPUTATION: bool = False + DEBUG_SKIP_TOXICITY_CALCULATION: bool = False HUGGING_FACE_API_KEY: str = "" diff --git a/backend/oasst_backend/models/__init__.py b/backend/oasst_backend/models/__init__.py index 0873381c..9dc052d7 100644 --- a/backend/oasst_backend/models/__init__.py +++ b/backend/oasst_backend/models/__init__.py @@ -3,6 +3,7 @@ from .journal import Journal, JournalIntegration from .message import Message from .message_embedding import MessageEmbedding from .message_reaction import MessageReaction +from .message_toxicity import MessageToxicity from .message_tree_state import 
MessageTreeState from .task import Task from .text_labels import TextLabels @@ -17,6 +18,7 @@ __all__ = [ "MessageEmbedding", "MessageReaction", "MessageTreeState", + "MessageToxicity", "Task", "TextLabels", "Journal", diff --git a/backend/oasst_backend/models/message_toxicity.py b/backend/oasst_backend/models/message_toxicity.py new file mode 100644 index 00000000..8a78e2dc --- /dev/null +++ b/backend/oasst_backend/models/message_toxicity.py @@ -0,0 +1,24 @@ +from datetime import datetime +from typing import Optional +from uuid import UUID + +import sqlalchemy as sa +import sqlalchemy.dialects.postgresql as pg +from sqlmodel import Field, Float, SQLModel + + +class MessageToxicity(SQLModel, table=True): + __tablename__ = "message_toxicity" + __table_args__ = (sa.PrimaryKeyConstraint("message_id", "model"),) + + message_id: UUID = Field(sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("message.id"), nullable=False)) + model: str = Field(max_length=256, nullable=False) + + # Storing the score and the label of the message + score: float = Field(sa_column=sa.Column(Float, nullable=False)) + label: str = Field(max_length=256, nullable=False) + + # Defaults to the database timestamp at insert, since the toxicity row may be created after the message itself + created_date: Optional[datetime] = Field( + sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp()) + ) diff --git a/backend/oasst_backend/prompt_repository.py b/backend/oasst_backend/prompt_repository.py index cb6e70f7..0c40daa0 100644 --- a/backend/oasst_backend/prompt_repository.py +++ b/backend/oasst_backend/prompt_repository.py @@ -14,6 +14,7 @@ from oasst_backend.models import ( Message, MessageEmbedding, MessageReaction, + MessageToxicity, MessageTreeState, Task, TextLabels, @@ -293,6 +294,25 @@ class PromptRepository: return reaction, task + def insert_toxicity(self, message_id: UUID, model: str, score: float, label: str) -> MessageToxicity: + """Save the toxicity score of a new message in the database. 
+ Args: + message_id (UUID): the identifier of the message we want to save its toxicity score + model (str): the model used for creating the toxicity score + score (float): the toxicity score that we obtained from the model + label (str): the final classification in toxicity of the model + Note: + The arguments are not validated here; the row is added and committed as given + Returns: + MessageToxicity: the instance in the database of the score saved for that message + """ + + message_toxicity = MessageToxicity(message_id=message_id, model=model, score=score, label=label) + self.db.add(message_toxicity) + self.db.commit() + self.db.refresh(message_toxicity) + return message_toxicity + def insert_message_embedding(self, message_id: UUID, model: str, embedding: List[float]) -> MessageEmbedding: """Insert the embedding of a new message in the database. @@ -308,9 +328,6 @@ class PromptRepository: MessageEmbedding: the instance in the database of the embedding saved for that message """ - if None in (message_id, model, embedding): - raise OasstError("Paramters missing to add embedding", OasstErrorCode.GENERIC_ERROR) - message_embedding = MessageEmbedding(message_id=message_id, model=model, embedding=embedding) self.db.add(message_embedding) self.db.commit() diff --git a/backend/oasst_backend/tree_manager.py b/backend/oasst_backend/tree_manager.py index d7abd9f8..fe45ca0e 100644 --- a/backend/oasst_backend/tree_manager.py +++ b/backend/oasst_backend/tree_manager.py @@ -1,7 +1,7 @@ import random from enum import Enum from http import HTTPStatus -from typing import Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple from uuid import UUID import numpy as np @@ -11,7 +11,7 @@ from oasst_backend.api.v1.utils import prepare_conversation, prepare_conversatio from oasst_backend.config import TreeManagerConfiguration, settings from oasst_backend.models import Message, MessageReaction, MessageTreeState, TextLabels, message_tree_state from oasst_backend.prompt_repository import 
PromptRepository -from oasst_backend.utils.hugging_face import HfEmbeddingModel, HfUrl, HuggingFaceAPI +from oasst_backend.utils.hugging_face import HfClassificationModel, HfEmbeddingModel, HfUrl, HuggingFaceAPI from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode from oasst_shared.schemas import protocol as protocol_schema from sqlalchemy.sql import text @@ -363,6 +363,25 @@ class TreeManager: f"Could not fetch embbeddings for text reply to {interaction.message_id=} with {interaction.text=} by {interaction.user=}." ) + if not settings.DEBUG_SKIP_TOXICITY_CALCULATION: + try: + model_name: str = HfClassificationModel.TOXIC_ROBERTA.value + hugging_face_api: HuggingFaceAPI = HuggingFaceAPI( + f"{HfUrl.HUGGINGFACE_TOXIC_CLASSIFICATION.value}/{model_name}" + ) + + toxicity: List[List[Dict[str, Any]]] = await hugging_face_api.post(interaction.text) + toxicity = toxicity[0][0] + + pr.insert_toxicity( + message_id=message.id, model=model_name, score=toxicity["score"], label=toxicity["label"] + ) + + except OasstError: + logger.error( + f"Could not compute toxicity for text reply to {interaction.message_id=} with {interaction.text=} by {interaction.user=}." + ) + case protocol_schema.MessageRating: logger.info( f"Frontend reports rating of {interaction.message_id=} with {interaction.rating=} by {interaction.user=}." 
diff --git a/backend/oasst_backend/utils/hugging_face.py b/backend/oasst_backend/utils/hugging_face.py index 099bc51f..1aef23ce 100644 --- a/backend/oasst_backend/utils/hugging_face.py +++ b/backend/oasst_backend/utils/hugging_face.py @@ -8,10 +8,14 @@ from oasst_shared.exceptions import OasstError, OasstErrorCode class HfUrl(str, Enum): - HUGGINGFACE_TOXIC_ROBERTA = ("https://api-inference.huggingface.co/models/unitary/multilingual-toxic-xlm-roberta",) + HUGGINGFACE_TOXIC_CLASSIFICATION = "https://api-inference.huggingface.co/models" HUGGINGFACE_FEATURE_EXTRACTION = "https://api-inference.huggingface.co/pipeline/feature-extraction" +class HfClassificationModel(str, Enum): + TOXIC_ROBERTA = "unitary/multilingual-toxic-xlm-roberta" + + class HfEmbeddingModel(str, Enum): MINILM = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2" diff --git a/docker-compose.yaml b/docker-compose.yaml index 858acb68..cde65166 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -101,6 +101,7 @@ services: - DEBUG_USE_SEED_DATA=True - DEBUG_ALLOW_SELF_LABELING=True - MAX_WORKERS=1 + - DEBUG_SKIP_TOXICITY_CALCULATION=True - DEBUG_SKIP_EMBEDDING_COMPUTATION=True depends_on: db: diff --git a/scripts/backend-development/run-local.sh b/scripts/backend-development/run-local.sh index 2433c67e..22701ace 100755 --- a/scripts/backend-development/run-local.sh +++ b/scripts/backend-development/run-local.sh @@ -6,6 +6,7 @@ pushd "$parent_path/../../backend" export DEBUG_SKIP_API_KEY_CHECK=True export DEBUG_USE_SEED_DATA=True +export DEBUG_SKIP_TOXICITY_CALCULATION=True export DEBUG_ALLOW_SELF_LABELING=True export DEBUG_SKIP_EMBEDDING_COMPUTATION=True