insert embedding now to new table

This commit is contained in:
jojopirker
2023-01-08 16:46:53 +01:00
parent 34e7d1db8a
commit a677e40cff
8 changed files with 103 additions and 26 deletions
@@ -5,25 +5,23 @@ Revises: ba61fe17fb6e
Create Date: 2023-01-08 11:06:25.613290
"""
from alembic import op
import sqlalchemy as sa
import sqlmodel
from alembic import op
# revision identifiers, used by Alembic.
revision = '023548d474f7'
down_revision = 'ba61fe17fb6e'
revision = "023548d474f7"
down_revision = "ba61fe17fb6e"
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('message', sa.Column('miniLM_embedding', sa.ARRAY(sa.Float()), nullable=True))
op.add_column("message", sa.Column("miniLM_embedding", sa.ARRAY(sa.Float()), nullable=True))
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column('message', 'miniLM_embedding')
op.drop_column("message", "miniLM_embedding")
# ### end Alembic commands ###
@@ -0,0 +1,49 @@
"""embedding for message now in its own table
Revision ID: 35bdc1a08bb8
Revises: 023548d474f7
Create Date: 2023-01-08 16:03:48.454207
"""
import sqlalchemy as sa
import sqlmodel
from alembic import op
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "35bdc1a08bb8"
down_revision = "023548d474f7"
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"message_embedding",
sa.Column("message_id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("embedding", sa.ARRAY(sa.Float()), nullable=True),
sa.Column("model", sqlmodel.sql.sqltypes.AutoString(length=256), nullable=False),
sa.ForeignKeyConstraint(
["message_id"],
["message.id"],
),
sa.PrimaryKeyConstraint("message_id", "model"),
)
op.drop_column("message", "miniLM_embedding")
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
"message",
sa.Column(
"miniLM_embedding",
postgresql.ARRAY(postgresql.DOUBLE_PRECISION(precision=53)),
autoincrement=False,
nullable=True,
),
)
op.drop_table("message_embedding")
# ### end Alembic commands ###
+14 -11
View File
@@ -9,7 +9,7 @@ from oasst_backend.api import deps
from oasst_backend.api.v1.utils import prepare_conversation
from oasst_backend.config import settings
from oasst_backend.prompt_repository import PromptRepository
from oasst_backend.utils.hugging_face import HF_url, HuggingFaceAPI
from oasst_backend.utils.hugging_face import HF_embeddingModel, HF_url, HuggingFaceAPI
from oasst_shared.exceptions import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
from sqlmodel import Session
@@ -275,24 +275,27 @@ async def tasks_interaction(
f"Frontend reports text reply to {interaction.message_id=} with {interaction.text=} by {interaction.user=}."
)
embedding = None
# here we store the text reply in the database
newMessage = pr.store_text_reply(
text=interaction.text,
frontend_message_id=interaction.message_id,
user_frontend_message_id=interaction.user_message_id,
)
if not settings.DEBUG_SKIP_EMBEDDING_COMPUTATION:
try:
hugging_face_api = HuggingFaceAPI(HF_url.HUGGINGFACE_MINILM_EMBEDDING.value)
hugging_face_api = HuggingFaceAPI(
f"{HF_url.HUGGINGFACE_FEATURE_EXTRACTION.value}{HF_embeddingModel.MINILM.value}"
)
embedding = await hugging_face_api.post(interaction.text)
pr.insert_message_embedding(
message_id=newMessage.id, model=HF_embeddingModel.MINILM.value, embedding=embedding
)
except OasstError:
logger.error(
f"Could not fetch embbeddings for text reply to {interaction.message_id=} with {interaction.text=} by {interaction.user=}."
)
# here we store the text reply in the database
pr.store_text_reply(
text=interaction.text,
frontend_message_id=interaction.message_id,
user_frontend_message_id=interaction.user_message_id,
miniLM_embedding=embedding,
)
return protocol_schema.TaskDone()
case protocol_schema.MessageRating:
logger.info(
+2
View File
@@ -1,6 +1,7 @@
from .api_client import ApiClient
from .journal import Journal, JournalIntegration
from .message import Message
from .message_embedding import MessageEmbedding
from .message_reaction import MessageReaction
from .message_tree_state import MessageTreeState
from .task import Task
@@ -13,6 +14,7 @@ __all__ = [
"User",
"UserStats",
"Message",
"MessageEmbedding",
"MessageReaction",
"MessageTreeState",
"Task",
+2 -3
View File
@@ -1,6 +1,6 @@
from datetime import datetime
from http import HTTPStatus
from typing import List, Optional
from typing import Optional
from uuid import UUID, uuid4
import sqlalchemy as sa
@@ -8,7 +8,7 @@ import sqlalchemy.dialects.postgresql as pg
from oasst_backend.models.db_payload import MessagePayload
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
from sqlalchemy import false
from sqlmodel import ARRAY, Field, Float, Index, SQLModel
from sqlmodel import Field, Index, SQLModel
from .payload_column_type import PayloadContainer, payload_column_type
@@ -40,7 +40,6 @@ class Message(SQLModel, table=True):
depth: int = Field(sa_column=sa.Column(sa.Integer, default=0, server_default=sa.text("0"), nullable=False))
children_count: int = Field(sa_column=sa.Column(sa.Integer, default=0, server_default=sa.text("0"), nullable=False))
deleted: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=False, server_default=false()))
miniLM_embedding: List[float] = Field(sa_column=sa.Column(ARRAY(Float)), nullable=True)
def ensure_is_message(self) -> None:
if not self.payload or not isinstance(self.payload.payload, MessagePayload):
@@ -0,0 +1,15 @@
from typing import List
from uuid import UUID
import sqlalchemy as sa
import sqlalchemy.dialects.postgresql as pg
from sqlmodel import ARRAY, Field, Float, SQLModel
class MessageEmbedding(SQLModel, table=True):
__tablename__ = "message_embedding"
__table_args__ = (sa.PrimaryKeyConstraint("message_id", "model"),)
message_id: UUID = Field(sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("message.id"), nullable=False))
model: str = Field(max_length=256, nullable=False)
embedding: List[float] = Field(sa_column=sa.Column(ARRAY(Float)), nullable=True)
+11 -4
View File
@@ -8,7 +8,7 @@ from uuid import UUID, uuid4
import oasst_backend.models.db_payload as db_payload
from loguru import logger
from oasst_backend.journal_writer import JournalWriter
from oasst_backend.models import ApiClient, Message, MessageReaction, Task, TextLabels, User
from oasst_backend.models import ApiClient, Message, MessageEmbedding, MessageReaction, Task, TextLabels, User
from oasst_backend.models.payload_column_type import PayloadContainer
from oasst_shared.exceptions import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
@@ -165,7 +165,6 @@ class PromptRepository:
role=role,
payload=db_payload.MessagePayload(text=text),
depth=depth,
miniLM_embedding=miniLM_embedding,
)
if not task.collective:
task.done = True
@@ -369,7 +368,6 @@ class PromptRepository:
payload: db_payload.MessagePayload,
payload_type: str = None,
depth: int = 0,
miniLM_embedding: List[float] = None,
) -> Message:
if payload_type is None:
if payload is None:
@@ -389,13 +387,22 @@ class PromptRepository:
payload_type=payload_type,
payload=PayloadContainer(payload=payload),
depth=depth,
miniLM_embedding=miniLM_embedding,
)
self.db.add(message)
self.db.commit()
self.db.refresh(message)
return message
def insert_message_embedding(self, message_id: UUID, model: str, embedding: List[float]) -> MessageEmbedding:
if None in (message_id, model, embedding):
raise OasstError("Paramters missing to add embedding", OasstErrorCode.GENERIC_ERROR)
model = MessageEmbedding(message_id=message_id, model=model, embedding=embedding)
self.db.add(model)
self.db.commit()
self.db.refresh(model)
return model
def insert_reaction(self, task_id: UUID, payload: db_payload.ReactionPayload) -> MessageReaction:
if self.user_id is None:
raise OasstError("User required", OasstErrorCode.USER_NOT_SPECIFIED)
+5 -1
View File
@@ -8,7 +8,11 @@ from oasst_shared.exceptions import OasstError, OasstErrorCode
class HF_url(str, Enum):
HUGGINGFACE_TOXIC_ROBERTA = ("https://api-inference.huggingface.co/models/unitary/multilingual-toxic-xlm-roberta",)
HUGGINGFACE_MINILM_EMBEDDING = "https://api-inference.huggingface.co/pipeline/feature-extraction/sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
HUGGINGFACE_FEATURE_EXTRACTION = "https://api-inference.huggingface.co/pipeline/feature-extraction/"
class HF_embeddingModel(str, Enum):
MINILM = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
class HuggingFaceAPI: