message embeddings in Messages table

This commit is contained in:
jojopirker
2023-01-08 12:28:38 +01:00
parent 9194e15b80
commit 11d55d572a
10 changed files with 67 additions and 11 deletions
@@ -0,0 +1,29 @@
"""added miniLM_embedding column to message
Revision ID: 023548d474f7
Revises: ba61fe17fb6e
Create Date: 2023-01-08 11:06:25.613290
"""
from alembic import op
import sqlalchemy as sa
import sqlmodel
# revision identifiers, used by Alembic.
revision = '023548d474f7'
down_revision = 'ba61fe17fb6e'
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Apply this revision: add the nullable ``miniLM_embedding`` column to ``message``.

    The column is declared as an array of floats.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    embedding_column = sa.Column("miniLM_embedding", sa.ARRAY(sa.Float()), nullable=True)
    op.add_column("message", embedding_column)
    # ### end Alembic commands ###
def downgrade() -> None:
    """Revert this revision: drop the ``miniLM_embedding`` column from ``message``."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_column(table_name="message", column_name="miniLM_embedding")
    # ### end Alembic commands ###
+1 -6
View File
@@ -1,19 +1,14 @@
from enum import Enum
from typing import List
from fastapi import APIRouter, Depends
from oasst_backend.api import deps
from oasst_backend.models import ApiClient
from oasst_backend.schemas.hugging_face import ToxicityClassification
from oasst_backend.utils.hugging_face import HuggingFaceAPI
from oasst_backend.utils.hugging_face import HF_url, HuggingFaceAPI
router = APIRouter()
class HF_url(str, Enum):
    """Hugging Face Inference API endpoint URLs; each member is also a ``str``."""

    HUGGINGFACE_TOXIC_ROBERTA = (
        "https://api-inference.huggingface.co/models/unitary/multilingual-toxic-xlm-roberta"
    )
@router.get("/text_toxicity")
async def get_text_toxicity(
msg: str,
+14 -1
View File
@@ -7,7 +7,9 @@ from fastapi.security.api_key import APIKey
from loguru import logger
from oasst_backend.api import deps
from oasst_backend.api.v1.utils import prepare_conversation
from oasst_backend.config import settings
from oasst_backend.prompt_repository import PromptRepository
from oasst_backend.utils.hugging_face import HF_url, HuggingFaceAPI
from oasst_shared.exceptions import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
from sqlmodel import Session
@@ -253,7 +255,7 @@ def tasks_acknowledge_failure(
@router.post("/interaction", response_model=protocol_schema.TaskDone)
def tasks_interaction(
async def tasks_interaction(
*,
db: Session = Depends(deps.get_db),
api_key: APIKey = Depends(deps.get_api_key),
@@ -273,11 +275,22 @@ def tasks_interaction(
f"Frontend reports text reply to {interaction.message_id=} with {interaction.text=} by {interaction.user=}."
)
embedding = None
if not settings.DEBUG_SKIP_EMBEDDING_COMPUTATION:
try:
hugging_face_api = HuggingFaceAPI(HF_url.HUGGINGFACE_MINILM_EMBEDDING.value)
embedding = await hugging_face_api.post(interaction.text)
except:
logger.error(
f"Could not fetch embbeddings for text reply to {interaction.message_id=} with {interaction.text=} by {interaction.user=}."
)
# here we store the text reply in the database
pr.store_text_reply(
text=interaction.text,
frontend_message_id=interaction.message_id,
user_frontend_message_id=interaction.user_message_id,
miniLM_embedding=embedding,
)
return protocol_schema.TaskDone()
+1
View File
@@ -25,6 +25,7 @@ class Settings(BaseSettings):
DEBUG_USE_SEED_DATA_PATH: Optional[FilePath] = (
Path(__file__).parent.parent / "test_data/generic/test_generic_data.json"
)
DEBUG_SKIP_EMBEDDING_COMPUTATION: bool = False
HUGGING_FACE_API_KEY: str = ""
+3 -2
View File
@@ -1,6 +1,6 @@
from datetime import datetime
from http import HTTPStatus
from typing import Optional
from typing import List, Optional
from uuid import UUID, uuid4
import sqlalchemy as sa
@@ -8,7 +8,7 @@ import sqlalchemy.dialects.postgresql as pg
from oasst_backend.models.db_payload import MessagePayload
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
from sqlalchemy import false
from sqlmodel import Field, Index, SQLModel
from sqlmodel import ARRAY, Field, Float, Index, SQLModel
from .payload_column_type import PayloadContainer, payload_column_type
@@ -40,6 +40,7 @@ class Message(SQLModel, table=True):
depth: int = Field(sa_column=sa.Column(sa.Integer, default=0, server_default=sa.text("0"), nullable=False))
children_count: int = Field(sa_column=sa.Column(sa.Integer, default=0, server_default=sa.text("0"), nullable=False))
deleted: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=False, server_default=false()))
miniLM_embedding: List[float] = Field(sa_column=sa.Column(ARRAY(Float)), nullable=True)
def ensure_is_message(self) -> None:
if not self.payload or not isinstance(self.payload.payload, MessagePayload):
+7 -2
View File
@@ -2,7 +2,7 @@ import datetime
import random
from collections import defaultdict
from http import HTTPStatus
from typing import Optional
from typing import List, Optional
from uuid import UUID, uuid4
import oasst_backend.models.db_payload as db_payload
@@ -122,7 +122,9 @@ class PromptRepository:
)
return task
def store_text_reply(self, text: str, frontend_message_id: str, user_frontend_message_id: str) -> Message:
def store_text_reply(
self, text: str, frontend_message_id: str, user_frontend_message_id: str, miniLM_embedding: List[float] = None
) -> Message:
self.validate_frontend_message_id(frontend_message_id)
self.validate_frontend_message_id(user_frontend_message_id)
@@ -163,6 +165,7 @@ class PromptRepository:
role=role,
payload=db_payload.MessagePayload(text=text),
depth=depth,
miniLM_embedding=miniLM_embedding,
)
if not task.collective:
task.done = True
@@ -366,6 +369,7 @@ class PromptRepository:
payload: db_payload.MessagePayload,
payload_type: str = None,
depth: int = 0,
miniLM_embedding: List[float] = None,
) -> Message:
if payload_type is None:
if payload is None:
@@ -385,6 +389,7 @@ class PromptRepository:
payload_type=payload_type,
payload=PayloadContainer(payload=payload),
depth=depth,
miniLM_embedding=miniLM_embedding,
)
self.db.add(message)
self.db.commit()
@@ -1,3 +1,4 @@
from enum import Enum
from typing import Any, Dict
import aiohttp
@@ -5,6 +6,11 @@ from oasst_backend.config import settings
from oasst_shared.exceptions import OasstError, OasstErrorCode
class HF_url(str, Enum):
    """Hugging Face Inference API endpoint URLs used by the backend.

    Members are also ``str`` instances; ``.value`` holds the endpoint URL.
    """

    # Plain string on purpose: a stray trailing comma previously made this a
    # one-element tuple, which only produced a string because Enum unpacks
    # tuple values into the ``str`` mix-in constructor.
    HUGGINGFACE_TOXIC_ROBERTA = "https://api-inference.huggingface.co/models/unitary/multilingual-toxic-xlm-roberta"
    HUGGINGFACE_MINILM_EMBEDDING = "https://api-inference.huggingface.co/pipeline/feature-extraction/sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
class HuggingFaceAPI:
"""Class Object to make post calls to endpoints for inference in models hosted in HuggingFace"""
@@ -41,6 +47,10 @@ class HuggingFaceAPI:
async with session.post(self.api_url, headers=self.headers, json=payload) as response:
# If we get a bad response
if response.status != 200:
from loguru import logger
logger.error(response)
logger.info(self.headers)
raise OasstError(
"Response Error Detoxify HuggingFace", error_code=OasstErrorCode.HUGGINGFACE_API_ERROR
)
+1
View File
@@ -95,6 +95,7 @@ services:
- DEBUG_SKIP_API_KEY_CHECK=True
- DEBUG_USE_SEED_DATA=True
- MAX_WORKERS=1
- DEBUG_SKIP_EMBEDDING_COMPUTATION=True
depends_on:
db:
condition: service_healthy
+1
View File
@@ -6,6 +6,7 @@ pushd "$parent_path/../../backend"
export DEBUG_SKIP_API_KEY_CHECK=True
export DEBUG_USE_SEED_DATA=True
export DEBUG_SKIP_EMBEDDING_COMPUTATION=True
uvicorn main:app --reload --port 8080 --host 0.0.0.0