From 11d55d572a06ba42193f846405128d9da21f9957 Mon Sep 17 00:00:00 2001 From: jojopirker Date: Sun, 8 Jan 2023 12:28:38 +0100 Subject: [PATCH 01/12] message embeddings in Messages table --- ...dded_minilm_embedding_column_to_message.py | 29 +++++++++++++++++++ backend/oasst_backend/api/v1/hugging_face.py | 7 +---- backend/oasst_backend/api/v1/tasks.py | 15 +++++++++- backend/oasst_backend/config.py | 1 + backend/oasst_backend/models/message.py | 5 ++-- .../oasst_backend/models/message_embedding.py | 0 backend/oasst_backend/prompt_repository.py | 9 ++++-- backend/oasst_backend/utils/hugging_face.py | 10 +++++++ docker-compose.yaml | 1 + scripts/backend-development/run-local.sh | 1 + 10 files changed, 67 insertions(+), 11 deletions(-) create mode 100644 backend/alembic/versions/2023_01_08_1106-3d96bb92e33a_added_minilm_embedding_column_to_message.py create mode 100644 backend/oasst_backend/models/message_embedding.py diff --git a/backend/alembic/versions/2023_01_08_1106-3d96bb92e33a_added_minilm_embedding_column_to_message.py b/backend/alembic/versions/2023_01_08_1106-3d96bb92e33a_added_minilm_embedding_column_to_message.py new file mode 100644 index 00000000..843f03bc --- /dev/null +++ b/backend/alembic/versions/2023_01_08_1106-3d96bb92e33a_added_minilm_embedding_column_to_message.py @@ -0,0 +1,29 @@ +"""added miniLM_embedding column to message + +Revision ID: 023548d474f7 +Revises: ba61fe17fb6e +Create Date: 2023-01-08 11:06:25.613290 + +""" +from alembic import op +import sqlalchemy as sa +import sqlmodel + + +# revision identifiers, used by Alembic. +revision = '023548d474f7' +down_revision = 'ba61fe17fb6e' +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('message', sa.Column('miniLM_embedding', sa.ARRAY(sa.Float()), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('message', 'miniLM_embedding') + # ### end Alembic commands ### diff --git a/backend/oasst_backend/api/v1/hugging_face.py b/backend/oasst_backend/api/v1/hugging_face.py index 1e7f1ffe..a8d8aeb9 100644 --- a/backend/oasst_backend/api/v1/hugging_face.py +++ b/backend/oasst_backend/api/v1/hugging_face.py @@ -1,19 +1,14 @@ -from enum import Enum from typing import List from fastapi import APIRouter, Depends from oasst_backend.api import deps from oasst_backend.models import ApiClient from oasst_backend.schemas.hugging_face import ToxicityClassification -from oasst_backend.utils.hugging_face import HuggingFaceAPI +from oasst_backend.utils.hugging_face import HF_url, HuggingFaceAPI router = APIRouter() -class HF_url(str, Enum): - HUGGINGFACE_TOXIC_ROBERTA = "https://api-inference.huggingface.co/models/unitary/multilingual-toxic-xlm-roberta" - - @router.get("/text_toxicity") async def get_text_toxicity( msg: str, diff --git a/backend/oasst_backend/api/v1/tasks.py b/backend/oasst_backend/api/v1/tasks.py index adfb2907..2248b85a 100644 --- a/backend/oasst_backend/api/v1/tasks.py +++ b/backend/oasst_backend/api/v1/tasks.py @@ -7,7 +7,9 @@ from fastapi.security.api_key import APIKey from loguru import logger from oasst_backend.api import deps from oasst_backend.api.v1.utils import prepare_conversation +from oasst_backend.config import settings from oasst_backend.prompt_repository import PromptRepository +from oasst_backend.utils.hugging_face import HF_url, HuggingFaceAPI from oasst_shared.exceptions import OasstError, OasstErrorCode from oasst_shared.schemas import protocol as protocol_schema from sqlmodel import Session @@ -253,7 +255,7 @@ def tasks_acknowledge_failure( @router.post("/interaction", response_model=protocol_schema.TaskDone) -def tasks_interaction( +async def tasks_interaction( *, db: Session = Depends(deps.get_db), api_key: APIKey = Depends(deps.get_api_key), @@ -273,11 +275,22 @@ def tasks_interaction( f"Frontend reports text reply to {interaction.message_id=} with {interaction.text=} by {interaction.user=}." ) + embedding = None + if not settings.DEBUG_SKIP_EMBEDDING_COMPUTATION: + try: + hugging_face_api = HuggingFaceAPI(HF_url.HUGGINGFACE_MINILM_EMBEDDING.value) + embedding = await hugging_face_api.post(interaction.text) + except: + logger.error( + f"Could not fetch embbeddings for text reply to {interaction.message_id=} with {interaction.text=} by {interaction.user=}." + ) + # here we store the text reply in the database pr.store_text_reply( text=interaction.text, frontend_message_id=interaction.message_id, user_frontend_message_id=interaction.user_message_id, + miniLM_embedding=embedding, ) return protocol_schema.TaskDone() diff --git a/backend/oasst_backend/config.py b/backend/oasst_backend/config.py index 1765af7a..ed394412 100644 --- a/backend/oasst_backend/config.py +++ b/backend/oasst_backend/config.py @@ -25,6 +25,7 @@ class Settings(BaseSettings): DEBUG_USE_SEED_DATA_PATH: Optional[FilePath] = ( Path(__file__).parent.parent / "test_data/generic/test_generic_data.json" ) + DEBUG_SKIP_EMBEDDING_COMPUTATION: bool = False HUGGING_FACE_API_KEY: str = "" diff --git a/backend/oasst_backend/models/message.py b/backend/oasst_backend/models/message.py index 6d24fd13..c7c2abbb 100644 --- a/backend/oasst_backend/models/message.py +++ b/backend/oasst_backend/models/message.py @@ -1,6 +1,6 @@ from datetime import datetime from http import HTTPStatus -from typing import Optional +from typing import List, Optional from uuid import UUID, uuid4 import sqlalchemy as sa @@ -8,7 +8,7 @@ import sqlalchemy.dialects.postgresql as pg from oasst_backend.models.db_payload import MessagePayload from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode from sqlalchemy import false -from sqlmodel import Field, Index, SQLModel +from sqlmodel import ARRAY, Field, Float, Index, SQLModel from .payload_column_type import PayloadContainer, payload_column_type @@ -40,6 +40,7 @@ class Message(SQLModel, table=True): depth: int = Field(sa_column=sa.Column(sa.Integer, default=0, server_default=sa.text("0"), nullable=False)) children_count: int = Field(sa_column=sa.Column(sa.Integer, default=0, server_default=sa.text("0"), nullable=False)) deleted: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=False, server_default=false())) + miniLM_embedding: List[float] = Field(sa_column=sa.Column(ARRAY(Float)), nullable=True) def ensure_is_message(self) -> None: if not self.payload or not isinstance(self.payload.payload, MessagePayload): diff --git a/backend/oasst_backend/models/message_embedding.py b/backend/oasst_backend/models/message_embedding.py new file mode 100644 index 00000000..e69de29b diff --git a/backend/oasst_backend/prompt_repository.py b/backend/oasst_backend/prompt_repository.py index 7c7dd7b6..8bb1eb4a 100644 --- a/backend/oasst_backend/prompt_repository.py +++ b/backend/oasst_backend/prompt_repository.py @@ -2,7 +2,7 @@ import datetime import random from collections import defaultdict from http import HTTPStatus -from typing import Optional +from typing import List, Optional from uuid import UUID, uuid4 import oasst_backend.models.db_payload as db_payload @@ -122,7 +122,9 @@ class PromptRepository: ) return task - def store_text_reply(self, text: str, frontend_message_id: str, user_frontend_message_id: str) -> Message: + def store_text_reply( + self, text: str, frontend_message_id: str, user_frontend_message_id: str, miniLM_embedding: List[float] = None + ) -> Message: self.validate_frontend_message_id(frontend_message_id) self.validate_frontend_message_id(user_frontend_message_id) @@ -163,6 +165,7 @@ class PromptRepository: role=role, payload=db_payload.MessagePayload(text=text), depth=depth, + miniLM_embedding=miniLM_embedding, ) if not task.collective: task.done = True @@ -366,6 +369,7 @@ class PromptRepository: payload: db_payload.MessagePayload, payload_type: str = None, depth: int = 0, + miniLM_embedding: List[float] = None, ) -> Message: if payload_type is None: if payload is None: @@ -385,6 +389,7 @@ class PromptRepository: payload_type=payload_type, payload=PayloadContainer(payload=payload), depth=depth, + miniLM_embedding=miniLM_embedding, ) self.db.add(message) self.db.commit() diff --git a/backend/oasst_backend/utils/hugging_face.py b/backend/oasst_backend/utils/hugging_face.py index 0df913f5..867c537c 100644 --- a/backend/oasst_backend/utils/hugging_face.py +++ b/backend/oasst_backend/utils/hugging_face.py @@ -1,3 +1,4 @@ +from enum import Enum from typing import Any, Dict import aiohttp @@ -5,6 +6,11 @@ from oasst_backend.config import settings from oasst_shared.exceptions import OasstError, OasstErrorCode +class HF_url(str, Enum): + HUGGINGFACE_TOXIC_ROBERTA = ("https://api-inference.huggingface.co/models/unitary/multilingual-toxic-xlm-roberta",) + HUGGINGFACE_MINILM_EMBEDDING = "https://api-inference.huggingface.co/pipeline/feature-extraction/sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2" + + class HuggingFaceAPI: """Class Object to make post calls to endpoints for inference in models hosted in HuggingFace""" @@ -41,6 +47,10 @@ class HuggingFaceAPI: async with session.post(self.api_url, headers=self.headers, json=payload) as response: # If we get a bad response if response.status != 200: + from loguru import logger + + logger.error(response) + logger.info(self.headers) raise OasstError( "Response Error Detoxify HuggingFace", error_code=OasstErrorCode.HUGGINGFACE_API_ERROR ) diff --git a/docker-compose.yaml b/docker-compose.yaml index 6bc42c51..5b0032f4 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -95,6 +95,7 @@ services: - DEBUG_SKIP_API_KEY_CHECK=True - DEBUG_USE_SEED_DATA=True - MAX_WORKERS=1 + - DEBUG_SKIP_EMBEDDING_COMPUTATION=True depends_on: db: condition: service_healthy diff --git a/scripts/backend-development/run-local.sh b/scripts/backend-development/run-local.sh index 3ed0e936..f0f6d16c 100755 --- a/scripts/backend-development/run-local.sh +++ b/scripts/backend-development/run-local.sh @@ -6,6 +6,7 @@ pushd "$parent_path/../../backend" export DEBUG_SKIP_API_KEY_CHECK=True export DEBUG_USE_SEED_DATA=True +export DEBUG_SKIP_EMBEDDING_COMPUTATION=True uvicorn main:app --reload --port 8080 --host 0.0.0.0 From 34e7d1db8a272551a9f29bddb1f34ba7a3df8952 Mon Sep 17 00:00:00 2001 From: Nil-Andreu Date: Sun, 8 Jan 2023 14:36:38 +0100 Subject: [PATCH 02/12] [NEW] Except OasstError --- backend/oasst_backend/api/v1/tasks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/oasst_backend/api/v1/tasks.py b/backend/oasst_backend/api/v1/tasks.py index 2248b85a..8656a063 100644 --- a/backend/oasst_backend/api/v1/tasks.py +++ b/backend/oasst_backend/api/v1/tasks.py @@ -280,7 +280,7 @@ async def tasks_interaction( try: hugging_face_api = HuggingFaceAPI(HF_url.HUGGINGFACE_MINILM_EMBEDDING.value) embedding = await hugging_face_api.post(interaction.text) - except: + except OasstError: logger.error( f"Could not fetch embbeddings for text reply to {interaction.message_id=} with {interaction.text=} by {interaction.user=}." ) From a677e40cffcc026a431c333af64b42b823c5af7a Mon Sep 17 00:00:00 2001 From: jojopirker Date: Sun, 8 Jan 2023 16:46:53 +0100 Subject: [PATCH 03/12] insert embedding now to new table --- ...dded_minilm_embedding_column_to_message.py | 12 ++--- ...8_embedding_for_message_now_in_its_own_.py | 49 +++++++++++++++++++ backend/oasst_backend/api/v1/tasks.py | 25 +++++----- backend/oasst_backend/models/__init__.py | 2 + backend/oasst_backend/models/message.py | 5 +- .../oasst_backend/models/message_embedding.py | 15 ++++++ backend/oasst_backend/prompt_repository.py | 15 ++++-- backend/oasst_backend/utils/hugging_face.py | 6 ++- 8 files changed, 103 insertions(+), 26 deletions(-) create mode 100644 backend/alembic/versions/2023_01_08_1603-35bdc1a08bb8_embedding_for_message_now_in_its_own_.py diff --git a/backend/alembic/versions/2023_01_08_1106-3d96bb92e33a_added_minilm_embedding_column_to_message.py b/backend/alembic/versions/2023_01_08_1106-3d96bb92e33a_added_minilm_embedding_column_to_message.py index 843f03bc..9b81105f 100644 --- a/backend/alembic/versions/2023_01_08_1106-3d96bb92e33a_added_minilm_embedding_column_to_message.py +++ b/backend/alembic/versions/2023_01_08_1106-3d96bb92e33a_added_minilm_embedding_column_to_message.py @@ -5,25 +5,23 @@ Revises: ba61fe17fb6e Create Date: 2023-01-08 11:06:25.613290 """ -from alembic import op import sqlalchemy as sa -import sqlmodel - +from alembic import op # revision identifiers, used by Alembic. -revision = '023548d474f7' -down_revision = 'ba61fe17fb6e' +revision = "023548d474f7" +down_revision = "ba61fe17fb6e" branch_labels = None depends_on = None def upgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### - op.add_column('message', sa.Column('miniLM_embedding', sa.ARRAY(sa.Float()), nullable=True)) + op.add_column("message", sa.Column("miniLM_embedding", sa.ARRAY(sa.Float()), nullable=True)) # ### end Alembic commands ### def downgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### - op.drop_column('message', 'miniLM_embedding') + op.drop_column("message", "miniLM_embedding") # ### end Alembic commands ### diff --git a/backend/alembic/versions/2023_01_08_1603-35bdc1a08bb8_embedding_for_message_now_in_its_own_.py b/backend/alembic/versions/2023_01_08_1603-35bdc1a08bb8_embedding_for_message_now_in_its_own_.py new file mode 100644 index 00000000..b732b792 --- /dev/null +++ b/backend/alembic/versions/2023_01_08_1603-35bdc1a08bb8_embedding_for_message_now_in_its_own_.py @@ -0,0 +1,49 @@ +"""embedding for message now in its own table + +Revision ID: 35bdc1a08bb8 +Revises: 023548d474f7 +Create Date: 2023-01-08 16:03:48.454207 + +""" +import sqlalchemy as sa +import sqlmodel +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = "35bdc1a08bb8" +down_revision = "023548d474f7" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "message_embedding", + sa.Column("message_id", postgresql.UUID(as_uuid=True), nullable=False), + sa.Column("embedding", sa.ARRAY(sa.Float()), nullable=True), + sa.Column("model", sqlmodel.sql.sqltypes.AutoString(length=256), nullable=False), + sa.ForeignKeyConstraint( + ["message_id"], + ["message.id"], + ), + sa.PrimaryKeyConstraint("message_id", "model"), + ) + op.drop_column("message", "miniLM_embedding") + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column( + "message", + sa.Column( + "miniLM_embedding", + postgresql.ARRAY(postgresql.DOUBLE_PRECISION(precision=53)), + autoincrement=False, + nullable=True, + ), + ) + op.drop_table("message_embedding") + # ### end Alembic commands ### diff --git a/backend/oasst_backend/api/v1/tasks.py b/backend/oasst_backend/api/v1/tasks.py index 8656a063..e033c25e 100644 --- a/backend/oasst_backend/api/v1/tasks.py +++ b/backend/oasst_backend/api/v1/tasks.py @@ -9,7 +9,7 @@ from oasst_backend.api import deps from oasst_backend.api.v1.utils import prepare_conversation from oasst_backend.config import settings from oasst_backend.prompt_repository import PromptRepository -from oasst_backend.utils.hugging_face import HF_url, HuggingFaceAPI +from oasst_backend.utils.hugging_face import HF_embeddingModel, HF_url, HuggingFaceAPI from oasst_shared.exceptions import OasstError, OasstErrorCode from oasst_shared.schemas import protocol as protocol_schema from sqlmodel import Session @@ -275,24 +275,27 @@ async def tasks_interaction( f"Frontend reports text reply to {interaction.message_id=} with {interaction.text=} by {interaction.user=}." ) - embedding = None + # here we store the text reply in the database + newMessage = pr.store_text_reply( + text=interaction.text, + frontend_message_id=interaction.message_id, + user_frontend_message_id=interaction.user_message_id, + ) + if not settings.DEBUG_SKIP_EMBEDDING_COMPUTATION: try: - hugging_face_api = HuggingFaceAPI(HF_url.HUGGINGFACE_MINILM_EMBEDDING.value) + hugging_face_api = HuggingFaceAPI( + f"{HF_url.HUGGINGFACE_FEATURE_EXTRACTION.value}{HF_embeddingModel.MINILM.value}" + ) embedding = await hugging_face_api.post(interaction.text) + pr.insert_message_embedding( + message_id=newMessage.id, model=HF_embeddingModel.MINILM.value, embedding=embedding + ) except OasstError: logger.error( f"Could not fetch embbeddings for text reply to {interaction.message_id=} with {interaction.text=} by {interaction.user=}." ) - # here we store the text reply in the database - pr.store_text_reply( - text=interaction.text, - frontend_message_id=interaction.message_id, - user_frontend_message_id=interaction.user_message_id, - miniLM_embedding=embedding, - ) - return protocol_schema.TaskDone() case protocol_schema.MessageRating: logger.info( diff --git a/backend/oasst_backend/models/__init__.py b/backend/oasst_backend/models/__init__.py index a856b155..0873381c 100644 --- a/backend/oasst_backend/models/__init__.py +++ b/backend/oasst_backend/models/__init__.py @@ -1,6 +1,7 @@ from .api_client import ApiClient from .journal import Journal, JournalIntegration from .message import Message +from .message_embedding import MessageEmbedding from .message_reaction import MessageReaction from .message_tree_state import MessageTreeState from .task import Task @@ -13,6 +14,7 @@ __all__ = [ "User", "UserStats", "Message", + "MessageEmbedding", "MessageReaction", "MessageTreeState", "Task", diff --git a/backend/oasst_backend/models/message.py b/backend/oasst_backend/models/message.py index c7c2abbb..6d24fd13 100644 --- a/backend/oasst_backend/models/message.py +++ b/backend/oasst_backend/models/message.py @@ -1,6 +1,6 @@ from datetime import datetime from http import HTTPStatus -from typing import List, Optional +from typing import Optional from uuid import UUID, uuid4 import sqlalchemy as sa @@ -8,7 +8,7 @@ import sqlalchemy.dialects.postgresql as pg from oasst_backend.models.db_payload import MessagePayload from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode from sqlalchemy import false -from sqlmodel import ARRAY, Field, Float, Index, SQLModel +from sqlmodel import Field, Index, SQLModel from .payload_column_type import PayloadContainer, payload_column_type @@ -40,7 +40,6 @@ class Message(SQLModel, table=True): depth: int = Field(sa_column=sa.Column(sa.Integer, default=0, server_default=sa.text("0"), nullable=False)) children_count: int = Field(sa_column=sa.Column(sa.Integer, default=0, server_default=sa.text("0"), nullable=False)) deleted: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=False, server_default=false())) - miniLM_embedding: List[float] = Field(sa_column=sa.Column(ARRAY(Float)), nullable=True) def ensure_is_message(self) -> None: if not self.payload or not isinstance(self.payload.payload, MessagePayload): diff --git a/backend/oasst_backend/models/message_embedding.py b/backend/oasst_backend/models/message_embedding.py index e69de29b..697a776b 100644 --- a/backend/oasst_backend/models/message_embedding.py +++ b/backend/oasst_backend/models/message_embedding.py @@ -0,0 +1,15 @@ +from typing import List +from uuid import UUID + +import sqlalchemy as sa +import sqlalchemy.dialects.postgresql as pg +from sqlmodel import ARRAY, Field, Float, SQLModel + + +class MessageEmbedding(SQLModel, table=True): + __tablename__ = "message_embedding" + __table_args__ = (sa.PrimaryKeyConstraint("message_id", "model"),) + + message_id: UUID = Field(sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("message.id"), nullable=False)) + model: str = Field(max_length=256, nullable=False) + embedding: List[float] = Field(sa_column=sa.Column(ARRAY(Float)), nullable=True) diff --git a/backend/oasst_backend/prompt_repository.py b/backend/oasst_backend/prompt_repository.py index 8bb1eb4a..efdc1b33 100644 --- a/backend/oasst_backend/prompt_repository.py +++ b/backend/oasst_backend/prompt_repository.py @@ -8,7 +8,7 @@ from uuid import UUID, uuid4 import oasst_backend.models.db_payload as db_payload from loguru import logger from oasst_backend.journal_writer import JournalWriter -from oasst_backend.models import ApiClient, Message, MessageReaction, Task, TextLabels, User +from oasst_backend.models import ApiClient, Message, MessageEmbedding, MessageReaction, Task, TextLabels, User from oasst_backend.models.payload_column_type import PayloadContainer from oasst_shared.exceptions import OasstError, OasstErrorCode from oasst_shared.schemas import protocol as protocol_schema @@ -165,7 +165,6 @@ class PromptRepository: role=role, payload=db_payload.MessagePayload(text=text), depth=depth, - miniLM_embedding=miniLM_embedding, ) if not task.collective: task.done = True @@ -369,7 +368,6 @@ class PromptRepository: payload: db_payload.MessagePayload, payload_type: str = None, depth: int = 0, - miniLM_embedding: List[float] = None, ) -> Message: if payload_type is None: if payload is None: @@ -389,13 +387,22 @@ class PromptRepository: payload_type=payload_type, payload=PayloadContainer(payload=payload), depth=depth, - miniLM_embedding=miniLM_embedding, ) self.db.add(message) self.db.commit() self.db.refresh(message) return message + def insert_message_embedding(self, message_id: UUID, model: str, embedding: List[float]) -> MessageEmbedding: + if None in (message_id, model, embedding): + raise OasstError("Paramters missing to add embedding", OasstErrorCode.GENERIC_ERROR) + + model = MessageEmbedding(message_id=message_id, model=model, embedding=embedding) + self.db.add(model) + self.db.commit() + self.db.refresh(model) + return model + def insert_reaction(self, task_id: UUID, payload: db_payload.ReactionPayload) -> MessageReaction: if self.user_id is None: raise OasstError("User required", OasstErrorCode.USER_NOT_SPECIFIED) diff --git a/backend/oasst_backend/utils/hugging_face.py b/backend/oasst_backend/utils/hugging_face.py index 867c537c..7a9fa8e3 100644 --- a/backend/oasst_backend/utils/hugging_face.py +++ b/backend/oasst_backend/utils/hugging_face.py @@ -8,7 +8,11 @@ from oasst_shared.exceptions import OasstError, OasstErrorCode class HF_url(str, Enum): HUGGINGFACE_TOXIC_ROBERTA = ("https://api-inference.huggingface.co/models/unitary/multilingual-toxic-xlm-roberta",) - HUGGINGFACE_MINILM_EMBEDDING = "https://api-inference.huggingface.co/pipeline/feature-extraction/sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2" + HUGGINGFACE_FEATURE_EXTRACTION = "https://api-inference.huggingface.co/pipeline/feature-extraction/" + + +class HF_embeddingModel(str, Enum): + MINILM = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2" class HuggingFaceAPI: From 7101e0e7d527ec608362006927185bf0145de088 Mon Sep 17 00:00:00 2001 From: Nil-Andreu Date: Sun, 8 Jan 2023 20:17:14 +0100 Subject: [PATCH 04/12] [NEW] Message embedding created_date --- backend/oasst_backend/models/message_embedding.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/backend/oasst_backend/models/message_embedding.py b/backend/oasst_backend/models/message_embedding.py index 697a776b..74da5004 100644 --- a/backend/oasst_backend/models/message_embedding.py +++ b/backend/oasst_backend/models/message_embedding.py @@ -1,4 +1,5 @@ -from typing import List +from datetime import datetime +from typing import List, Optional from uuid import UUID import sqlalchemy as sa @@ -13,3 +14,8 @@ class MessageEmbedding(SQLModel, table=True): message_id: UUID = Field(sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("message.id"), nullable=False)) model: str = Field(max_length=256, nullable=False) embedding: List[float] = Field(sa_column=sa.Column(ARRAY(Float)), nullable=True) + + # In the case that the Message Embedding is created afterwards + created_date: Optional[datetime] = Field( + sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp()) + ) From 19eee6be58fd7bd7298169207ba195a331fa7185 Mon Sep 17 00:00:00 2001 From: Nil-Andreu Date: Sun, 8 Jan 2023 20:53:25 +0100 Subject: [PATCH 05/12] [NEW] Removing embedding param in function Store Text Reply --- backend/oasst_backend/prompt_repository.py | 1 - 1 file changed, 1 deletion(-) diff --git a/backend/oasst_backend/prompt_repository.py b/backend/oasst_backend/prompt_repository.py index dde5003d..919da19b 100644 --- a/backend/oasst_backend/prompt_repository.py +++ b/backend/oasst_backend/prompt_repository.py @@ -96,7 +96,6 @@ class PromptRepository: text: str, frontend_message_id: str, user_frontend_message_id: str, - miniLM_embedding: Optional[List[float]] = None, ) -> Message: validate_frontend_message_id(frontend_message_id) validate_frontend_message_id(user_frontend_message_id) From 225a136ad1f65948c14da5cf513d7ed2f2420156 Mon Sep 17 00:00:00 2001 From: Nil-Andreu Date: Sun, 8 Jan 2023 20:58:19 +0100 Subject: [PATCH 06/12] [NEW] Refactor name of message_embedding object --- backend/oasst_backend/prompt_repository.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/backend/oasst_backend/prompt_repository.py b/backend/oasst_backend/prompt_repository.py index 919da19b..3feb77bc 100644 --- a/backend/oasst_backend/prompt_repository.py +++ b/backend/oasst_backend/prompt_repository.py @@ -233,11 +233,11 @@ class PromptRepository: if None in (message_id, model, embedding): raise OasstError("Paramters missing to add embedding", OasstErrorCode.GENERIC_ERROR) - model = MessageEmbedding(message_id=message_id, model=model, embedding=embedding) - self.db.add(model) + message_embedding = MessageEmbedding(message_id=message_id, model=model, embedding=embedding) + self.db.add(message_embedding) self.db.commit() - self.db.refresh(model) - return model + self.db.refresh(message_embedding) + return message_embedding def insert_reaction(self, task_id: UUID, payload: db_payload.ReactionPayload) -> MessageReaction: if self.user_id is None: From 412736f52c67d3330733c94f773f0f96b923854e Mon Sep 17 00:00:00 2001 From: Nil-Andreu Date: Sun, 8 Jan 2023 20:59:46 +0100 Subject: [PATCH 07/12] [NEW] insert_message_embedding: documentation --- backend/oasst_backend/prompt_repository.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/backend/oasst_backend/prompt_repository.py b/backend/oasst_backend/prompt_repository.py index 3feb77bc..c31c0061 100644 --- a/backend/oasst_backend/prompt_repository.py +++ b/backend/oasst_backend/prompt_repository.py @@ -230,6 +230,20 @@ class PromptRepository: ) def insert_message_embedding(self, message_id: UUID, model: str, embedding: List[float]) -> MessageEmbedding: + """Insert the embedding of a new message in the database. + + Args: + message_id (UUID): the identifier of the message we want to save its embedding + model (str): the model used for creating the embedding + embedding (List[float]): the values obtained from the message & model + + Raises: + OasstError: if misses some of the before params + + Returns: + MessageEmbedding: the instance in the database of the embedding saved for that message + """ + if None in (message_id, model, embedding): raise OasstError("Paramters missing to add embedding", OasstErrorCode.GENERIC_ERROR) From e241a8bf28019261523f2b873f261a759a35a89c Mon Sep 17 00:00:00 2001 From: Nil-Andreu Date: Sun, 8 Jan 2023 21:01:15 +0100 Subject: [PATCH 08/12] [NEW] Adding consistency in the URLs --- backend/oasst_backend/api/v1/tasks.py | 2 +- backend/oasst_backend/utils/hugging_face.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/backend/oasst_backend/api/v1/tasks.py b/backend/oasst_backend/api/v1/tasks.py index 3bdacf49..186a6ac8 100644 --- a/backend/oasst_backend/api/v1/tasks.py +++ b/backend/oasst_backend/api/v1/tasks.py @@ -285,7 +285,7 @@ async def tasks_interaction( if not settings.DEBUG_SKIP_EMBEDDING_COMPUTATION: try: hugging_face_api = HuggingFaceAPI( - f"{HF_url.HUGGINGFACE_FEATURE_EXTRACTION.value}{HF_embeddingModel.MINILM.value}" + f"{HF_url.HUGGINGFACE_FEATURE_EXTRACTION.value}/{HF_embeddingModel.MINILM.value}" ) embedding = await hugging_face_api.post(interaction.text) pr.insert_message_embedding( diff --git a/backend/oasst_backend/utils/hugging_face.py b/backend/oasst_backend/utils/hugging_face.py index 7a9fa8e3..80be074d 100644 --- a/backend/oasst_backend/utils/hugging_face.py +++ b/backend/oasst_backend/utils/hugging_face.py @@ -8,7 +8,7 @@ from oasst_shared.exceptions import OasstError, OasstErrorCode class HF_url(str, Enum): HUGGINGFACE_TOXIC_ROBERTA = ("https://api-inference.huggingface.co/models/unitary/multilingual-toxic-xlm-roberta",) - HUGGINGFACE_FEATURE_EXTRACTION = "https://api-inference.huggingface.co/pipeline/feature-extraction/" + HUGGINGFACE_FEATURE_EXTRACTION = "https://api-inference.huggingface.co/pipeline/feature-extraction" class HF_embeddingModel(str, Enum): From 70620520b49680e9aa5cee89395fa87ffffeef55 Mon Sep 17 00:00:00 2001 From: Nil-Andreu Date: Sun, 8 Jan 2023 21:29:12 +0100 Subject: [PATCH 09/12] [NEW] Created date --- ...23_01_08_2128-aac6b2f66006_created_date.py | 30 +++++++++++++++++++ 1 file changed, 30 insertions(+) create mode 100644 backend/alembic/versions/2023_01_08_2128-aac6b2f66006_created_date.py diff --git a/backend/alembic/versions/2023_01_08_2128-aac6b2f66006_created_date.py b/backend/alembic/versions/2023_01_08_2128-aac6b2f66006_created_date.py new file mode 100644 index 00000000..6d40d896 --- /dev/null +++ b/backend/alembic/versions/2023_01_08_2128-aac6b2f66006_created_date.py @@ -0,0 +1,30 @@ +"""Created date + +Revision ID: aac6b2f66006 +Revises: 35bdc1a08bb8 +Create Date: 2023-01-08 21:28:27.342729 + +""" +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "aac6b2f66006" +down_revision = "35bdc1a08bb8" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column( + "message_embedding", + sa.Column("created_date", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("message_embedding", "created_date") + # ### end Alembic commands ### From b39b86309cdaf8a6d8b207b3278781fd80a47140 Mon Sep 17 00:00:00 2001 From: Nil-Andreu Date: Mon, 9 Jan 2023 09:03:02 +0100 Subject: [PATCH 10/12] [FIX] Import on top --- backend/oasst_backend/utils/hugging_face.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/oasst_backend/utils/hugging_face.py b/backend/oasst_backend/utils/hugging_face.py index 80be074d..ef440055 100644 --- a/backend/oasst_backend/utils/hugging_face.py +++ b/backend/oasst_backend/utils/hugging_face.py @@ -2,6 +2,7 @@ from enum import Enum from typing import Any, Dict import aiohttp +from loguru import logger from oasst_backend.config import settings from oasst_shared.exceptions import OasstError, OasstErrorCode @@ -51,7 +52,6 @@ class HuggingFaceAPI: async with session.post(self.api_url, headers=self.headers, json=payload) as response: # If we get a bad response if response.status != 200: - from loguru import logger logger.error(response) logger.info(self.headers) From ef7bd89df2124739973f703b51531c62324a6a15 Mon Sep 17 00:00:00 2001 From: Nil-Andreu Date: Mon, 9 Jan 2023 09:07:08 +0100 Subject: [PATCH 11/12] [NEW] ansible: DEBUG_SKIP_EMBEDDING_COMPUTATION --- ansible/dev.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/ansible/dev.yaml b/ansible/dev.yaml index 577abd68..81ab12ee 100644 --- a/ansible/dev.yaml +++ b/ansible/dev.yaml @@ -55,6 +55,7 @@ DEBUG_USE_SEED_DATA: "true" MAX_WORKERS: "1" RATE_LIMIT: "false" + DEBUG_SKIP_EMBEDDING_COMPUTATION: "true" ports: - 8080:8080 From 4b4a564a8fb8a6308527dbeb65b6cafd108e0f75 Mon Sep 17 00:00:00 2001 From: Nil-Andreu Date: Mon, 9 Jan 2023 10:10:19 +0100 Subject: [PATCH 12/12] [NEW] Camelcase & 2x space --- backend/oasst_backend/api/v1/hugging_face.py | 4 ++-- backend/oasst_backend/api/v1/tasks.py | 8 ++++---- backend/oasst_backend/utils/hugging_face.py | 4 ++-- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/backend/oasst_backend/api/v1/hugging_face.py b/backend/oasst_backend/api/v1/hugging_face.py index a8d8aeb9..62d2ea6b 100644 --- a/backend/oasst_backend/api/v1/hugging_face.py +++ b/backend/oasst_backend/api/v1/hugging_face.py @@ -4,7 +4,7 @@ from fastapi import APIRouter, Depends from oasst_backend.api import deps from oasst_backend.models import ApiClient from oasst_backend.schemas.hugging_face import ToxicityClassification -from oasst_backend.utils.hugging_face import HF_url, HuggingFaceAPI +from oasst_backend.utils.hugging_face import HfUrl, HuggingFaceAPI router = APIRouter() @@ -25,7 +25,7 @@ async def get_text_toxicity( ToxicityClassification: the score of toxicity of the message. """ - api_url: str = HF_url.HUGGINGFACE_TOXIC_ROBERTA.value + api_url: str = HfUrl.HUGGINGFACE_TOXIC_ROBERTA.value hugging_face_api = HuggingFaceAPI(api_url) response = await hugging_face_api.post(msg) diff --git a/backend/oasst_backend/api/v1/tasks.py b/backend/oasst_backend/api/v1/tasks.py index 186a6ac8..821ba562 100644 --- a/backend/oasst_backend/api/v1/tasks.py +++ b/backend/oasst_backend/api/v1/tasks.py @@ -9,7 +9,7 @@ from oasst_backend.api import deps from oasst_backend.api.v1.utils import prepare_conversation from oasst_backend.config import settings from oasst_backend.prompt_repository import PromptRepository, TaskRepository -from oasst_backend.utils.hugging_face import HF_embeddingModel, HF_url, HuggingFaceAPI +from oasst_backend.utils.hugging_face import HfEmbeddingModel, HfUrl, HuggingFaceAPI from oasst_shared.exceptions import OasstError, OasstErrorCode from oasst_shared.schemas import protocol as protocol_schema from sqlmodel import Session @@ -285,15 +285,15 @@ async def tasks_interaction( if not settings.DEBUG_SKIP_EMBEDDING_COMPUTATION: try: hugging_face_api = HuggingFaceAPI( - f"{HF_url.HUGGINGFACE_FEATURE_EXTRACTION.value}/{HF_embeddingModel.MINILM.value}" + f"{HfUrl.HUGGINGFACE_FEATURE_EXTRACTION.value}/{HfEmbeddingModel.MINILM.value}" ) embedding = await hugging_face_api.post(interaction.text) pr.insert_message_embedding( - message_id=newMessage.id, model=HF_embeddingModel.MINILM.value, embedding=embedding + message_id=newMessage.id, model=HfEmbeddingModel.MINILM.value, embedding=embedding ) except OasstError: logger.error( - f"Could not fetch embbeddings for text reply to {interaction.message_id=} with {interaction.text=} by {interaction.user=}." + f"Could not fetch embbeddings for text reply to {interaction.message_id=} with {interaction.text=} by {interaction.user=}." ) return protocol_schema.TaskDone() diff --git a/backend/oasst_backend/utils/hugging_face.py b/backend/oasst_backend/utils/hugging_face.py index ef440055..87c6288e 100644 --- a/backend/oasst_backend/utils/hugging_face.py +++ b/backend/oasst_backend/utils/hugging_face.py @@ -7,12 +7,12 @@ from oasst_backend.config import settings from oasst_shared.exceptions import OasstError, OasstErrorCode -class HF_url(str, Enum): +class HfUrl(str, Enum): HUGGINGFACE_TOXIC_ROBERTA = ("https://api-inference.huggingface.co/models/unitary/multilingual-toxic-xlm-roberta",) HUGGINGFACE_FEATURE_EXTRACTION = "https://api-inference.huggingface.co/pipeline/feature-extraction" -class HF_embeddingModel(str, Enum): +class HfEmbeddingModel(str, Enum): MINILM = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"