From 11d55d572a06ba42193f846405128d9da21f9957 Mon Sep 17 00:00:00 2001 From: jojopirker Date: Sun, 8 Jan 2023 12:28:38 +0100 Subject: [PATCH] message embeddings in Messages table --- ...dded_minilm_embedding_column_to_message.py | 29 +++++++++++++++++++ backend/oasst_backend/api/v1/hugging_face.py | 7 +---- backend/oasst_backend/api/v1/tasks.py | 15 +++++++++- backend/oasst_backend/config.py | 1 + backend/oasst_backend/models/message.py | 5 ++-- .../oasst_backend/models/message_embedding.py | 0 backend/oasst_backend/prompt_repository.py | 9 ++++-- backend/oasst_backend/utils/hugging_face.py | 10 +++++++ docker-compose.yaml | 1 + scripts/backend-development/run-local.sh | 1 + 10 files changed, 67 insertions(+), 11 deletions(-) create mode 100644 backend/alembic/versions/2023_01_08_1106-3d96bb92e33a_added_minilm_embedding_column_to_message.py create mode 100644 backend/oasst_backend/models/message_embedding.py diff --git a/backend/alembic/versions/2023_01_08_1106-3d96bb92e33a_added_minilm_embedding_column_to_message.py b/backend/alembic/versions/2023_01_08_1106-3d96bb92e33a_added_minilm_embedding_column_to_message.py new file mode 100644 index 00000000..843f03bc --- /dev/null +++ b/backend/alembic/versions/2023_01_08_1106-3d96bb92e33a_added_minilm_embedding_column_to_message.py @@ -0,0 +1,29 @@ +"""added miniLM_embedding column to message + +Revision ID: 023548d474f7 +Revises: ba61fe17fb6e +Create Date: 2023-01-08 11:06:25.613290 + +""" +from alembic import op +import sqlalchemy as sa +import sqlmodel + + +# revision identifiers, used by Alembic. +revision = '023548d474f7' +down_revision = 'ba61fe17fb6e' +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('message', sa.Column('miniLM_embedding', sa.ARRAY(sa.Float()), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('message', 'miniLM_embedding') + # ### end Alembic commands ### diff --git a/backend/oasst_backend/api/v1/hugging_face.py b/backend/oasst_backend/api/v1/hugging_face.py index 1e7f1ffe..a8d8aeb9 100644 --- a/backend/oasst_backend/api/v1/hugging_face.py +++ b/backend/oasst_backend/api/v1/hugging_face.py @@ -1,19 +1,14 @@ -from enum import Enum from typing import List from fastapi import APIRouter, Depends from oasst_backend.api import deps from oasst_backend.models import ApiClient from oasst_backend.schemas.hugging_face import ToxicityClassification -from oasst_backend.utils.hugging_face import HuggingFaceAPI +from oasst_backend.utils.hugging_face import HF_url, HuggingFaceAPI router = APIRouter() -class HF_url(str, Enum): - HUGGINGFACE_TOXIC_ROBERTA = "https://api-inference.huggingface.co/models/unitary/multilingual-toxic-xlm-roberta" - - @router.get("/text_toxicity") async def get_text_toxicity( msg: str, diff --git a/backend/oasst_backend/api/v1/tasks.py b/backend/oasst_backend/api/v1/tasks.py index adfb2907..2248b85a 100644 --- a/backend/oasst_backend/api/v1/tasks.py +++ b/backend/oasst_backend/api/v1/tasks.py @@ -7,7 +7,9 @@ from fastapi.security.api_key import APIKey from loguru import logger from oasst_backend.api import deps from oasst_backend.api.v1.utils import prepare_conversation +from oasst_backend.config import settings from oasst_backend.prompt_repository import PromptRepository +from oasst_backend.utils.hugging_face import HF_url, HuggingFaceAPI from oasst_shared.exceptions import OasstError, OasstErrorCode from oasst_shared.schemas import protocol as protocol_schema from sqlmodel import Session @@ -253,7 +255,7 @@ def tasks_acknowledge_failure( @router.post("/interaction", response_model=protocol_schema.TaskDone) -def tasks_interaction( +async def tasks_interaction( *, db: Session = Depends(deps.get_db), api_key: APIKey = Depends(deps.get_api_key), @@ -273,11 +275,22 @@ def tasks_interaction( f"Frontend reports text reply to {interaction.message_id=} with {interaction.text=} by {interaction.user=}." ) + embedding = None + if not settings.DEBUG_SKIP_EMBEDDING_COMPUTATION: + try: + hugging_face_api = HuggingFaceAPI(HF_url.HUGGINGFACE_MINILM_EMBEDDING.value) + embedding = await hugging_face_api.post(interaction.text) + except: + logger.error( + f"Could not fetch embbeddings for text reply to {interaction.message_id=} with {interaction.text=} by {interaction.user=}." + ) + # here we store the text reply in the database pr.store_text_reply( text=interaction.text, frontend_message_id=interaction.message_id, user_frontend_message_id=interaction.user_message_id, + miniLM_embedding=embedding, ) return protocol_schema.TaskDone() diff --git a/backend/oasst_backend/config.py b/backend/oasst_backend/config.py index 1765af7a..ed394412 100644 --- a/backend/oasst_backend/config.py +++ b/backend/oasst_backend/config.py @@ -25,6 +25,7 @@ class Settings(BaseSettings): DEBUG_USE_SEED_DATA_PATH: Optional[FilePath] = ( Path(__file__).parent.parent / "test_data/generic/test_generic_data.json" ) + DEBUG_SKIP_EMBEDDING_COMPUTATION: bool = False HUGGING_FACE_API_KEY: str = "" diff --git a/backend/oasst_backend/models/message.py b/backend/oasst_backend/models/message.py index 6d24fd13..c7c2abbb 100644 --- a/backend/oasst_backend/models/message.py +++ b/backend/oasst_backend/models/message.py @@ -1,6 +1,6 @@ from datetime import datetime from http import HTTPStatus -from typing import Optional +from typing import List, Optional from uuid import UUID, uuid4 import sqlalchemy as sa @@ -8,7 +8,7 @@ import sqlalchemy.dialects.postgresql as pg from oasst_backend.models.db_payload import MessagePayload from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode from sqlalchemy import false -from sqlmodel import Field, Index, SQLModel +from sqlmodel import ARRAY, Field, Float, Index, SQLModel from .payload_column_type import PayloadContainer, payload_column_type @@ -40,6 +40,7 @@ class Message(SQLModel, table=True): depth: int = Field(sa_column=sa.Column(sa.Integer, default=0, server_default=sa.text("0"), nullable=False)) children_count: int = Field(sa_column=sa.Column(sa.Integer, default=0, server_default=sa.text("0"), nullable=False)) deleted: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=False, server_default=false())) + miniLM_embedding: List[float] = Field(sa_column=sa.Column(ARRAY(Float)), nullable=True) def ensure_is_message(self) -> None: if not self.payload or not isinstance(self.payload.payload, MessagePayload): diff --git a/backend/oasst_backend/models/message_embedding.py b/backend/oasst_backend/models/message_embedding.py new file mode 100644 index 00000000..e69de29b diff --git a/backend/oasst_backend/prompt_repository.py b/backend/oasst_backend/prompt_repository.py index 7c7dd7b6..8bb1eb4a 100644 --- a/backend/oasst_backend/prompt_repository.py +++ b/backend/oasst_backend/prompt_repository.py @@ -2,7 +2,7 @@ import datetime import random from collections import defaultdict from http import HTTPStatus -from typing import Optional +from typing import List, Optional from uuid import UUID, uuid4 import oasst_backend.models.db_payload as db_payload @@ -122,7 +122,9 @@ class PromptRepository: ) return task - def store_text_reply(self, text: str, frontend_message_id: str, user_frontend_message_id: str) -> Message: + def store_text_reply( + self, text: str, frontend_message_id: str, user_frontend_message_id: str, miniLM_embedding: List[float] = None + ) -> Message: self.validate_frontend_message_id(frontend_message_id) self.validate_frontend_message_id(user_frontend_message_id) @@ -163,6 +165,7 @@ class PromptRepository: role=role, payload=db_payload.MessagePayload(text=text), depth=depth, + miniLM_embedding=miniLM_embedding, ) if not task.collective: task.done = True @@ -366,6 +369,7 @@ class PromptRepository: payload: db_payload.MessagePayload, payload_type: str = None, depth: int = 0, + miniLM_embedding: List[float] = None, ) -> Message: if payload_type is None: if payload is None: @@ -385,6 +389,7 @@ class PromptRepository: payload_type=payload_type, payload=PayloadContainer(payload=payload), depth=depth, + miniLM_embedding=miniLM_embedding, ) self.db.add(message) self.db.commit() diff --git a/backend/oasst_backend/utils/hugging_face.py b/backend/oasst_backend/utils/hugging_face.py index 0df913f5..867c537c 100644 --- a/backend/oasst_backend/utils/hugging_face.py +++ b/backend/oasst_backend/utils/hugging_face.py @@ -1,3 +1,4 @@ +from enum import Enum from typing import Any, Dict import aiohttp @@ -5,6 +6,11 @@ from oasst_backend.config import settings from oasst_shared.exceptions import OasstError, OasstErrorCode +class HF_url(str, Enum): + HUGGINGFACE_TOXIC_ROBERTA = ("https://api-inference.huggingface.co/models/unitary/multilingual-toxic-xlm-roberta",) + HUGGINGFACE_MINILM_EMBEDDING = "https://api-inference.huggingface.co/pipeline/feature-extraction/sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2" + + class HuggingFaceAPI: """Class Object to make post calls to endpoints for inference in models hosted in HuggingFace""" @@ -41,6 +47,10 @@ class HuggingFaceAPI: async with session.post(self.api_url, headers=self.headers, json=payload) as response: # If we get a bad response if response.status != 200: + from loguru import logger + + logger.error(response) + logger.info(self.headers) raise OasstError( "Response Error Detoxify HuggingFace", error_code=OasstErrorCode.HUGGINGFACE_API_ERROR ) diff --git a/docker-compose.yaml b/docker-compose.yaml index 6bc42c51..5b0032f4 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -95,6 +95,7 @@ services: - DEBUG_SKIP_API_KEY_CHECK=True - DEBUG_USE_SEED_DATA=True - MAX_WORKERS=1 + - DEBUG_SKIP_EMBEDDING_COMPUTATION=True depends_on: db: condition: service_healthy diff --git a/scripts/backend-development/run-local.sh b/scripts/backend-development/run-local.sh index 3ed0e936..f0f6d16c 100755 --- a/scripts/backend-development/run-local.sh +++ b/scripts/backend-development/run-local.sh @@ -6,6 +6,7 @@ pushd "$parent_path/../../backend" export DEBUG_SKIP_API_KEY_CHECK=True export DEBUG_USE_SEED_DATA=True +export DEBUG_SKIP_EMBEDDING_COMPUTATION=True uvicorn main:app --reload --port 8080 --host 0.0.0.0