Merge pull request #540 from jojopirker/messageEmbeddings

Store Message embedding
This commit is contained in:
Yannic Kilcher
2023-01-09 21:22:56 +01:00
committed by GitHub
13 changed files with 199 additions and 12 deletions
+1
View File
@@ -81,6 +81,7 @@
DEBUG_USE_SEED_DATA: "true"
MAX_WORKERS: "1"
RATE_LIMIT: "false"
DEBUG_SKIP_EMBEDDING_COMPUTATION: "true"
ports:
- 8080:8080
@@ -0,0 +1,27 @@
"""added miniLM_embedding column to message
Revision ID: 023548d474f7
Revises: ba61fe17fb6e
Create Date: 2023-01-08 11:06:25.613290
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = "023548d474f7"
down_revision = "ba61fe17fb6e"
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("message", sa.Column("miniLM_embedding", sa.ARRAY(sa.Float()), nullable=True))
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("message", "miniLM_embedding")
# ### end Alembic commands ###
@@ -0,0 +1,49 @@
"""embedding for message now in its own table
Revision ID: 35bdc1a08bb8
Revises: 023548d474f7
Create Date: 2023-01-08 16:03:48.454207
"""
import sqlalchemy as sa
import sqlmodel
from alembic import op
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "35bdc1a08bb8"
down_revision = "023548d474f7"
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"message_embedding",
sa.Column("message_id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("embedding", sa.ARRAY(sa.Float()), nullable=True),
sa.Column("model", sqlmodel.sql.sqltypes.AutoString(length=256), nullable=False),
sa.ForeignKeyConstraint(
["message_id"],
["message.id"],
),
sa.PrimaryKeyConstraint("message_id", "model"),
)
op.drop_column("message", "miniLM_embedding")
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
"message",
sa.Column(
"miniLM_embedding",
postgresql.ARRAY(postgresql.DOUBLE_PRECISION(precision=53)),
autoincrement=False,
nullable=True,
),
)
op.drop_table("message_embedding")
# ### end Alembic commands ###
@@ -0,0 +1,30 @@
"""Created date
Revision ID: aac6b2f66006
Revises: 35bdc1a08bb8
Create Date: 2023-01-08 21:28:27.342729
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = "aac6b2f66006"
down_revision = "35bdc1a08bb8"
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
"message_embedding",
sa.Column("created_date", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("message_embedding", "created_date")
# ### end Alembic commands ###
+2 -7
View File
@@ -1,19 +1,14 @@
from enum import Enum
from typing import List
from fastapi import APIRouter, Depends
from oasst_backend.api import deps
from oasst_backend.models import ApiClient
from oasst_backend.schemas.hugging_face import ToxicityClassification
from oasst_backend.utils.hugging_face import HuggingFaceAPI
from oasst_backend.utils.hugging_face import HfUrl, HuggingFaceAPI
router = APIRouter()
class HF_url(str, Enum):
HUGGINGFACE_TOXIC_ROBERTA = "https://api-inference.huggingface.co/models/unitary/multilingual-toxic-xlm-roberta"
@router.get("/text_toxicity")
async def get_text_toxicity(
msg: str,
@@ -30,7 +25,7 @@ async def get_text_toxicity(
ToxicityClassification: the score of toxicity of the message.
"""
api_url: str = HF_url.HUGGINGFACE_TOXIC_ROBERTA.value
api_url: str = HfUrl.HUGGINGFACE_TOXIC_ROBERTA.value
hugging_face_api = HuggingFaceAPI(api_url)
response = await hugging_face_api.post(msg)
+18 -2
View File
@@ -7,7 +7,9 @@ from fastapi.security.api_key import APIKey
from loguru import logger
from oasst_backend.api import deps
from oasst_backend.api.v1.utils import prepare_conversation
from oasst_backend.config import settings
from oasst_backend.prompt_repository import PromptRepository, TaskRepository
from oasst_backend.utils.hugging_face import HfEmbeddingModel, HfUrl, HuggingFaceAPI
from oasst_shared.exceptions import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
from sqlmodel import Session
@@ -253,7 +255,7 @@ def tasks_acknowledge_failure(
@router.post("/interaction", response_model=protocol_schema.TaskDone)
def tasks_interaction(
async def tasks_interaction(
*,
db: Session = Depends(deps.get_db),
api_key: APIKey = Depends(deps.get_api_key),
@@ -274,12 +276,26 @@ def tasks_interaction(
)
# here we store the text reply in the database
pr.store_text_reply(
newMessage = pr.store_text_reply(
text=interaction.text,
frontend_message_id=interaction.message_id,
user_frontend_message_id=interaction.user_message_id,
)
if not settings.DEBUG_SKIP_EMBEDDING_COMPUTATION:
try:
hugging_face_api = HuggingFaceAPI(
f"{HfUrl.HUGGINGFACE_FEATURE_EXTRACTION.value}/{HfEmbeddingModel.MINILM.value}"
)
embedding = await hugging_face_api.post(interaction.text)
pr.insert_message_embedding(
message_id=newMessage.id, model=HfEmbeddingModel.MINILM.value, embedding=embedding
)
except OasstError:
logger.error(
f"Could not fetch embbeddings for text reply to {interaction.message_id=} with {interaction.text=} by {interaction.user=}."
)
return protocol_schema.TaskDone()
case protocol_schema.MessageRating:
logger.info(
+1
View File
@@ -25,6 +25,7 @@ class Settings(BaseSettings):
DEBUG_USE_SEED_DATA_PATH: Optional[FilePath] = (
Path(__file__).parent.parent / "test_data/generic/test_generic_data.json"
)
DEBUG_SKIP_EMBEDDING_COMPUTATION: bool = False
HUGGING_FACE_API_KEY: str = ""
+2
View File
@@ -1,6 +1,7 @@
from .api_client import ApiClient
from .journal import Journal, JournalIntegration
from .message import Message
from .message_embedding import MessageEmbedding
from .message_reaction import MessageReaction
from .message_tree_state import MessageTreeState
from .task import Task
@@ -13,6 +14,7 @@ __all__ = [
"User",
"UserStats",
"Message",
"MessageEmbedding",
"MessageReaction",
"MessageTreeState",
"Task",
@@ -0,0 +1,21 @@
from datetime import datetime
from typing import List, Optional
from uuid import UUID
import sqlalchemy as sa
import sqlalchemy.dialects.postgresql as pg
from sqlmodel import ARRAY, Field, Float, SQLModel
class MessageEmbedding(SQLModel, table=True):
__tablename__ = "message_embedding"
__table_args__ = (sa.PrimaryKeyConstraint("message_id", "model"),)
message_id: UUID = Field(sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("message.id"), nullable=False))
model: str = Field(max_length=256, nullable=False)
embedding: List[float] = Field(sa_column=sa.Column(ARRAY(Float)), nullable=True)
# In the case that the Message Embedding is created afterwards
created_date: Optional[datetime] = Field(
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp())
)
+32 -3
View File
@@ -2,13 +2,13 @@ import datetime
import random
from collections import defaultdict
from http import HTTPStatus
from typing import Optional
from typing import List, Optional
from uuid import UUID, uuid4
import oasst_backend.models.db_payload as db_payload
from loguru import logger
from oasst_backend.journal_writer import JournalWriter
from oasst_backend.models import ApiClient, Message, MessageReaction, TextLabels, User
from oasst_backend.models import ApiClient, Message, MessageEmbedding, MessageReaction, TextLabels, User
from oasst_backend.models.payload_column_type import PayloadContainer
from oasst_backend.task_repository import TaskRepository, validate_frontend_message_id
from oasst_backend.user_repository import UserRepository
@@ -91,7 +91,12 @@ class PromptRepository:
self.db.refresh(message)
return message
def store_text_reply(self, text: str, frontend_message_id: str, user_frontend_message_id: str) -> Message:
def store_text_reply(
self,
text: str,
frontend_message_id: str,
user_frontend_message_id: str,
) -> Message:
validate_frontend_message_id(frontend_message_id)
validate_frontend_message_id(user_frontend_message_id)
@@ -224,6 +229,30 @@ class PromptRepository:
OasstErrorCode.TASK_PAYLOAD_TYPE_MISMATCH,
)
def insert_message_embedding(self, message_id: UUID, model: str, embedding: List[float]) -> MessageEmbedding:
"""Insert the embedding of a new message in the database.
Args:
message_id (UUID): the identifier of the message we want to save its embedding
model (str): the model used for creating the embedding
embedding (List[float]): the values obtained from the message & model
Raises:
OasstError: if misses some of the before params
Returns:
MessageEmbedding: the instance in the database of the embedding saved for that message
"""
if None in (message_id, model, embedding):
raise OasstError("Paramters missing to add embedding", OasstErrorCode.GENERIC_ERROR)
message_embedding = MessageEmbedding(message_id=message_id, model=model, embedding=embedding)
self.db.add(message_embedding)
self.db.commit()
self.db.refresh(message_embedding)
return message_embedding
def insert_reaction(self, task_id: UUID, payload: db_payload.ReactionPayload) -> MessageReaction:
if self.user_id is None:
raise OasstError("User required", OasstErrorCode.USER_NOT_SPECIFIED)
@@ -1,10 +1,21 @@
from enum import Enum
from typing import Any, Dict
import aiohttp
from loguru import logger
from oasst_backend.config import settings
from oasst_shared.exceptions import OasstError, OasstErrorCode
class HfUrl(str, Enum):
HUGGINGFACE_TOXIC_ROBERTA = ("https://api-inference.huggingface.co/models/unitary/multilingual-toxic-xlm-roberta",)
HUGGINGFACE_FEATURE_EXTRACTION = "https://api-inference.huggingface.co/pipeline/feature-extraction"
class HfEmbeddingModel(str, Enum):
MINILM = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
class HuggingFaceAPI:
"""Class Object to make post calls to endpoints for inference in models hosted in HuggingFace"""
@@ -41,6 +52,9 @@ class HuggingFaceAPI:
async with session.post(self.api_url, headers=self.headers, json=payload) as response:
# If we get a bad response
if response.status != 200:
logger.error(response)
logger.info(self.headers)
raise OasstError(
"Response Error Detoxify HuggingFace", error_code=OasstErrorCode.HUGGINGFACE_API_ERROR
)
+1
View File
@@ -95,6 +95,7 @@ services:
- DEBUG_SKIP_API_KEY_CHECK=True
- DEBUG_USE_SEED_DATA=True
- MAX_WORKERS=1
- DEBUG_SKIP_EMBEDDING_COMPUTATION=True
depends_on:
db:
condition: service_healthy
+1
View File
@@ -6,6 +6,7 @@ pushd "$parent_path/../../backend"
export DEBUG_SKIP_API_KEY_CHECK=True
export DEBUG_USE_SEED_DATA=True
export DEBUG_SKIP_EMBEDDING_COMPUTATION=True
uvicorn main:app --reload --port 8080 --host 0.0.0.0