mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
insert embedding now to new table
This commit is contained in:
+5
-7
@@ -5,25 +5,23 @@ Revises: ba61fe17fb6e
|
||||
Create Date: 2023-01-08 11:06:25.613290
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '023548d474f7'
|
||||
down_revision = 'ba61fe17fb6e'
|
||||
revision = "023548d474f7"
|
||||
down_revision = "ba61fe17fb6e"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column('message', sa.Column('miniLM_embedding', sa.ARRAY(sa.Float()), nullable=True))
|
||||
op.add_column("message", sa.Column("miniLM_embedding", sa.ARRAY(sa.Float()), nullable=True))
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column('message', 'miniLM_embedding')
|
||||
op.drop_column("message", "miniLM_embedding")
|
||||
# ### end Alembic commands ###
|
||||
|
||||
+49
@@ -0,0 +1,49 @@
|
||||
"""embedding for message now in its own table
|
||||
|
||||
Revision ID: 35bdc1a08bb8
|
||||
Revises: 023548d474f7
|
||||
Create Date: 2023-01-08 16:03:48.454207
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "35bdc1a08bb8"
|
||||
down_revision = "023548d474f7"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Move message embeddings from ``message.miniLM_embedding`` to ``message_embedding``.

    Creates the ``message_embedding`` table keyed by ``(message_id, model)``,
    copies any embeddings already stored on ``message`` into it, and then drops
    the old column.
    """
    # Embeddings held in message.miniLM_embedding came from this model, so the
    # copied rows are tagged with it.
    minilm_model = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"

    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table(
        "message_embedding",
        sa.Column("message_id", postgresql.UUID(as_uuid=True), nullable=False),
        sa.Column("embedding", sa.ARRAY(sa.Float()), nullable=True),
        sa.Column("model", sqlmodel.sql.sqltypes.AutoString(length=256), nullable=False),
        sa.ForeignKeyConstraint(
            ["message_id"],
            ["message.id"],
        ),
        sa.PrimaryKeyConstraint("message_id", "model"),
    )

    # Preserve existing data: without this copy the drop below would silently
    # discard every embedding computed so far. The mixed-case column name must
    # be double-quoted for PostgreSQL.
    op.execute(
        "INSERT INTO message_embedding (message_id, model, embedding) "
        f"SELECT id, '{minilm_model}', \"miniLM_embedding\" "
        "FROM message "
        'WHERE "miniLM_embedding" IS NOT NULL'
    )

    op.drop_column("message", "miniLM_embedding")
    # ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Restore ``message.miniLM_embedding`` and remove the ``message_embedding`` table.

    Re-adds the nullable array column on ``message``, copies the MiniLM
    embeddings back from ``message_embedding``, and then drops that table.
    Embeddings stored for any other model have no column to return to and are
    discarded together with the table.
    """
    # Only rows produced by this model map onto the miniLM_embedding column.
    minilm_model = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"

    # ### commands auto generated by Alembic - please adjust! ###
    op.add_column(
        "message",
        sa.Column(
            "miniLM_embedding",
            postgresql.ARRAY(postgresql.DOUBLE_PRECISION(precision=53)),
            autoincrement=False,
            nullable=True,
        ),
    )

    # Copy the data back before the table disappears; the mixed-case column
    # name must be double-quoted for PostgreSQL.
    op.execute(
        'UPDATE message SET "miniLM_embedding" = me.embedding '
        "FROM message_embedding AS me "
        f"WHERE me.message_id = message.id AND me.model = '{minilm_model}'"
    )

    op.drop_table("message_embedding")
    # ### end Alembic commands ###
|
||||
@@ -9,7 +9,7 @@ from oasst_backend.api import deps
|
||||
from oasst_backend.api.v1.utils import prepare_conversation
|
||||
from oasst_backend.config import settings
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_backend.utils.hugging_face import HF_url, HuggingFaceAPI
|
||||
from oasst_backend.utils.hugging_face import HF_embeddingModel, HF_url, HuggingFaceAPI
|
||||
from oasst_shared.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from sqlmodel import Session
|
||||
@@ -275,24 +275,27 @@ async def tasks_interaction(
|
||||
f"Frontend reports text reply to {interaction.message_id=} with {interaction.text=} by {interaction.user=}."
|
||||
)
|
||||
|
||||
embedding = None
|
||||
# here we store the text reply in the database
|
||||
newMessage = pr.store_text_reply(
|
||||
text=interaction.text,
|
||||
frontend_message_id=interaction.message_id,
|
||||
user_frontend_message_id=interaction.user_message_id,
|
||||
)
|
||||
|
||||
if not settings.DEBUG_SKIP_EMBEDDING_COMPUTATION:
|
||||
try:
|
||||
hugging_face_api = HuggingFaceAPI(HF_url.HUGGINGFACE_MINILM_EMBEDDING.value)
|
||||
hugging_face_api = HuggingFaceAPI(
|
||||
f"{HF_url.HUGGINGFACE_FEATURE_EXTRACTION.value}{HF_embeddingModel.MINILM.value}"
|
||||
)
|
||||
embedding = await hugging_face_api.post(interaction.text)
|
||||
pr.insert_message_embedding(
|
||||
message_id=newMessage.id, model=HF_embeddingModel.MINILM.value, embedding=embedding
|
||||
)
|
||||
except OasstError:
|
||||
logger.error(
|
||||
f"Could not fetch embbeddings for text reply to {interaction.message_id=} with {interaction.text=} by {interaction.user=}."
|
||||
)
|
||||
|
||||
# here we store the text reply in the database
|
||||
pr.store_text_reply(
|
||||
text=interaction.text,
|
||||
frontend_message_id=interaction.message_id,
|
||||
user_frontend_message_id=interaction.user_message_id,
|
||||
miniLM_embedding=embedding,
|
||||
)
|
||||
|
||||
return protocol_schema.TaskDone()
|
||||
case protocol_schema.MessageRating:
|
||||
logger.info(
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from .api_client import ApiClient
|
||||
from .journal import Journal, JournalIntegration
|
||||
from .message import Message
|
||||
from .message_embedding import MessageEmbedding
|
||||
from .message_reaction import MessageReaction
|
||||
from .message_tree_state import MessageTreeState
|
||||
from .task import Task
|
||||
@@ -13,6 +14,7 @@ __all__ = [
|
||||
"User",
|
||||
"UserStats",
|
||||
"Message",
|
||||
"MessageEmbedding",
|
||||
"MessageReaction",
|
||||
"MessageTreeState",
|
||||
"Task",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from datetime import datetime
|
||||
from http import HTTPStatus
|
||||
from typing import List, Optional
|
||||
from typing import Optional
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import sqlalchemy as sa
|
||||
@@ -8,7 +8,7 @@ import sqlalchemy.dialects.postgresql as pg
|
||||
from oasst_backend.models.db_payload import MessagePayload
|
||||
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
|
||||
from sqlalchemy import false
|
||||
from sqlmodel import ARRAY, Field, Float, Index, SQLModel
|
||||
from sqlmodel import Field, Index, SQLModel
|
||||
|
||||
from .payload_column_type import PayloadContainer, payload_column_type
|
||||
|
||||
@@ -40,7 +40,6 @@ class Message(SQLModel, table=True):
|
||||
depth: int = Field(sa_column=sa.Column(sa.Integer, default=0, server_default=sa.text("0"), nullable=False))
|
||||
children_count: int = Field(sa_column=sa.Column(sa.Integer, default=0, server_default=sa.text("0"), nullable=False))
|
||||
deleted: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=False, server_default=false()))
|
||||
miniLM_embedding: List[float] = Field(sa_column=sa.Column(ARRAY(Float)), nullable=True)
|
||||
|
||||
def ensure_is_message(self) -> None:
|
||||
if not self.payload or not isinstance(self.payload.payload, MessagePayload):
|
||||
|
||||
@@ -0,0 +1,15 @@
|
||||
from typing import List
|
||||
from uuid import UUID
|
||||
|
||||
import sqlalchemy as sa
|
||||
import sqlalchemy.dialects.postgresql as pg
|
||||
from sqlmodel import ARRAY, Field, Float, SQLModel
|
||||
|
||||
|
||||
class MessageEmbedding(SQLModel, table=True):
    """Embedding vector of a message, stored once per embedding model.

    Rows live in the ``message_embedding`` table and are identified by the
    composite primary key ``(message_id, model)``.
    """

    __tablename__ = "message_embedding"
    __table_args__ = (sa.PrimaryKeyConstraint("message_id", "model"),)

    # Foreign key to message.id.
    message_id: UUID = Field(sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("message.id"), nullable=False))
    # Name of the model that produced the embedding (at most 256 characters).
    model: str = Field(max_length=256, nullable=False)
    # `nullable` is set on the Column itself: when an explicit sa_column is
    # given, column options passed to Field() are not applied to it (and, as far
    # as I know, newer SQLModel releases reject the combination outright).
    embedding: List[float] = Field(sa_column=sa.Column(ARRAY(Float), nullable=True))
|
||||
|
||||
@@ -8,7 +8,7 @@ from uuid import UUID, uuid4
|
||||
import oasst_backend.models.db_payload as db_payload
|
||||
from loguru import logger
|
||||
from oasst_backend.journal_writer import JournalWriter
|
||||
from oasst_backend.models import ApiClient, Message, MessageReaction, Task, TextLabels, User
|
||||
from oasst_backend.models import ApiClient, Message, MessageEmbedding, MessageReaction, Task, TextLabels, User
|
||||
from oasst_backend.models.payload_column_type import PayloadContainer
|
||||
from oasst_shared.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
@@ -165,7 +165,6 @@ class PromptRepository:
|
||||
role=role,
|
||||
payload=db_payload.MessagePayload(text=text),
|
||||
depth=depth,
|
||||
miniLM_embedding=miniLM_embedding,
|
||||
)
|
||||
if not task.collective:
|
||||
task.done = True
|
||||
@@ -369,7 +368,6 @@ class PromptRepository:
|
||||
payload: db_payload.MessagePayload,
|
||||
payload_type: str = None,
|
||||
depth: int = 0,
|
||||
miniLM_embedding: List[float] = None,
|
||||
) -> Message:
|
||||
if payload_type is None:
|
||||
if payload is None:
|
||||
@@ -389,13 +387,22 @@ class PromptRepository:
|
||||
payload_type=payload_type,
|
||||
payload=PayloadContainer(payload=payload),
|
||||
depth=depth,
|
||||
miniLM_embedding=miniLM_embedding,
|
||||
)
|
||||
self.db.add(message)
|
||||
self.db.commit()
|
||||
self.db.refresh(message)
|
||||
return message
|
||||
|
||||
def insert_message_embedding(self, message_id: UUID, model: str, embedding: List[float]) -> MessageEmbedding:
    """Store an embedding for a message and return the persisted row.

    Args:
        message_id: ID of the message the embedding belongs to.
        model: Name of the model that produced the embedding.
        embedding: The embedding vector.

    Returns:
        The ``MessageEmbedding`` instance after it has been added, committed
        and refreshed on ``self.db``.

    Raises:
        OasstError: With ``OasstErrorCode.GENERIC_ERROR`` if any argument is ``None``.
    """
    # Identity checks instead of `None in (...)`: membership testing falls back
    # to `==`, which is unsafe for values with custom or element-wise equality.
    if any(arg is None for arg in (message_id, model, embedding)):
        raise OasstError("Paramters missing to add embedding", OasstErrorCode.GENERIC_ERROR)

    # Use a distinct local name so the `model` parameter (a str) is not rebound
    # to the ORM object.
    message_embedding = MessageEmbedding(message_id=message_id, model=model, embedding=embedding)
    self.db.add(message_embedding)
    self.db.commit()
    self.db.refresh(message_embedding)
    return message_embedding
|
||||
|
||||
def insert_reaction(self, task_id: UUID, payload: db_payload.ReactionPayload) -> MessageReaction:
|
||||
if self.user_id is None:
|
||||
raise OasstError("User required", OasstErrorCode.USER_NOT_SPECIFIED)
|
||||
|
||||
@@ -8,7 +8,11 @@ from oasst_shared.exceptions import OasstError, OasstErrorCode
|
||||
|
||||
class HF_url(str, Enum):
|
||||
HUGGINGFACE_TOXIC_ROBERTA = ("https://api-inference.huggingface.co/models/unitary/multilingual-toxic-xlm-roberta",)
|
||||
HUGGINGFACE_MINILM_EMBEDDING = "https://api-inference.huggingface.co/pipeline/feature-extraction/sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
|
||||
HUGGINGFACE_FEATURE_EXTRACTION = "https://api-inference.huggingface.co/pipeline/feature-extraction/"
|
||||
|
||||
|
||||
class HF_embeddingModel(str, Enum):
    """Identifiers of the Hugging Face models used to compute embeddings."""

    MINILM = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
|
||||
|
||||
|
||||
class HuggingFaceAPI:
|
||||
|
||||
Reference in New Issue
Block a user