mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
Merge pull request #540 from jojopirker/messageEmbeddings
Store Message embedding
This commit is contained in:
@@ -81,6 +81,7 @@
|
||||
DEBUG_USE_SEED_DATA: "true"
|
||||
MAX_WORKERS: "1"
|
||||
RATE_LIMIT: "false"
|
||||
DEBUG_SKIP_EMBEDDING_COMPUTATION: "true"
|
||||
ports:
|
||||
- 8080:8080
|
||||
|
||||
|
||||
+27
@@ -0,0 +1,27 @@
|
||||
"""added miniLM_embedding column to message
|
||||
|
||||
Revision ID: 023548d474f7
|
||||
Revises: ba61fe17fb6e
|
||||
Create Date: 2023-01-08 11:06:25.613290
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "023548d474f7"
|
||||
down_revision = "ba61fe17fb6e"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column("message", sa.Column("miniLM_embedding", sa.ARRAY(sa.Float()), nullable=True))
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column("message", "miniLM_embedding")
|
||||
# ### end Alembic commands ###
|
||||
+49
@@ -0,0 +1,49 @@
|
||||
"""embedding for message now in its own table
|
||||
|
||||
Revision ID: 35bdc1a08bb8
|
||||
Revises: 023548d474f7
|
||||
Create Date: 2023-01-08 16:03:48.454207
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "35bdc1a08bb8"
|
||||
down_revision = "023548d474f7"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table(
|
||||
"message_embedding",
|
||||
sa.Column("message_id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column("embedding", sa.ARRAY(sa.Float()), nullable=True),
|
||||
sa.Column("model", sqlmodel.sql.sqltypes.AutoString(length=256), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["message_id"],
|
||||
["message.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("message_id", "model"),
|
||||
)
|
||||
op.drop_column("message", "miniLM_embedding")
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column(
|
||||
"message",
|
||||
sa.Column(
|
||||
"miniLM_embedding",
|
||||
postgresql.ARRAY(postgresql.DOUBLE_PRECISION(precision=53)),
|
||||
autoincrement=False,
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
op.drop_table("message_embedding")
|
||||
# ### end Alembic commands ###
|
||||
@@ -0,0 +1,30 @@
|
||||
"""Created date
|
||||
|
||||
Revision ID: aac6b2f66006
|
||||
Revises: 35bdc1a08bb8
|
||||
Create Date: 2023-01-08 21:28:27.342729
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "aac6b2f66006"
|
||||
down_revision = "35bdc1a08bb8"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column(
|
||||
"message_embedding",
|
||||
sa.Column("created_date", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column("message_embedding", "created_date")
|
||||
# ### end Alembic commands ###
|
||||
@@ -1,19 +1,14 @@
|
||||
from enum import Enum
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from oasst_backend.api import deps
|
||||
from oasst_backend.models import ApiClient
|
||||
from oasst_backend.schemas.hugging_face import ToxicityClassification
|
||||
from oasst_backend.utils.hugging_face import HuggingFaceAPI
|
||||
from oasst_backend.utils.hugging_face import HfUrl, HuggingFaceAPI
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class HF_url(str, Enum):
|
||||
HUGGINGFACE_TOXIC_ROBERTA = "https://api-inference.huggingface.co/models/unitary/multilingual-toxic-xlm-roberta"
|
||||
|
||||
|
||||
@router.get("/text_toxicity")
|
||||
async def get_text_toxicity(
|
||||
msg: str,
|
||||
@@ -30,7 +25,7 @@ async def get_text_toxicity(
|
||||
ToxicityClassification: the score of toxicity of the message.
|
||||
"""
|
||||
|
||||
api_url: str = HF_url.HUGGINGFACE_TOXIC_ROBERTA.value
|
||||
api_url: str = HfUrl.HUGGINGFACE_TOXIC_ROBERTA.value
|
||||
hugging_face_api = HuggingFaceAPI(api_url)
|
||||
response = await hugging_face_api.post(msg)
|
||||
|
||||
|
||||
@@ -7,7 +7,9 @@ from fastapi.security.api_key import APIKey
|
||||
from loguru import logger
|
||||
from oasst_backend.api import deps
|
||||
from oasst_backend.api.v1.utils import prepare_conversation
|
||||
from oasst_backend.config import settings
|
||||
from oasst_backend.prompt_repository import PromptRepository, TaskRepository
|
||||
from oasst_backend.utils.hugging_face import HfEmbeddingModel, HfUrl, HuggingFaceAPI
|
||||
from oasst_shared.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from sqlmodel import Session
|
||||
@@ -253,7 +255,7 @@ def tasks_acknowledge_failure(
|
||||
|
||||
|
||||
@router.post("/interaction", response_model=protocol_schema.TaskDone)
|
||||
def tasks_interaction(
|
||||
async def tasks_interaction(
|
||||
*,
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_key: APIKey = Depends(deps.get_api_key),
|
||||
@@ -274,12 +276,26 @@ def tasks_interaction(
|
||||
)
|
||||
|
||||
# here we store the text reply in the database
|
||||
pr.store_text_reply(
|
||||
newMessage = pr.store_text_reply(
|
||||
text=interaction.text,
|
||||
frontend_message_id=interaction.message_id,
|
||||
user_frontend_message_id=interaction.user_message_id,
|
||||
)
|
||||
|
||||
if not settings.DEBUG_SKIP_EMBEDDING_COMPUTATION:
|
||||
try:
|
||||
hugging_face_api = HuggingFaceAPI(
|
||||
f"{HfUrl.HUGGINGFACE_FEATURE_EXTRACTION.value}/{HfEmbeddingModel.MINILM.value}"
|
||||
)
|
||||
embedding = await hugging_face_api.post(interaction.text)
|
||||
pr.insert_message_embedding(
|
||||
message_id=newMessage.id, model=HfEmbeddingModel.MINILM.value, embedding=embedding
|
||||
)
|
||||
except OasstError:
|
||||
logger.error(
|
||||
f"Could not fetch embbeddings for text reply to {interaction.message_id=} with {interaction.text=} by {interaction.user=}."
|
||||
)
|
||||
|
||||
return protocol_schema.TaskDone()
|
||||
case protocol_schema.MessageRating:
|
||||
logger.info(
|
||||
|
||||
@@ -25,6 +25,7 @@ class Settings(BaseSettings):
|
||||
DEBUG_USE_SEED_DATA_PATH: Optional[FilePath] = (
|
||||
Path(__file__).parent.parent / "test_data/generic/test_generic_data.json"
|
||||
)
|
||||
DEBUG_SKIP_EMBEDDING_COMPUTATION: bool = False
|
||||
|
||||
HUGGING_FACE_API_KEY: str = ""
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from .api_client import ApiClient
|
||||
from .journal import Journal, JournalIntegration
|
||||
from .message import Message
|
||||
from .message_embedding import MessageEmbedding
|
||||
from .message_reaction import MessageReaction
|
||||
from .message_tree_state import MessageTreeState
|
||||
from .task import Task
|
||||
@@ -13,6 +14,7 @@ __all__ = [
|
||||
"User",
|
||||
"UserStats",
|
||||
"Message",
|
||||
"MessageEmbedding",
|
||||
"MessageReaction",
|
||||
"MessageTreeState",
|
||||
"Task",
|
||||
|
||||
@@ -0,0 +1,21 @@
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
import sqlalchemy as sa
|
||||
import sqlalchemy.dialects.postgresql as pg
|
||||
from sqlmodel import ARRAY, Field, Float, SQLModel
|
||||
|
||||
|
||||
class MessageEmbedding(SQLModel, table=True):
|
||||
__tablename__ = "message_embedding"
|
||||
__table_args__ = (sa.PrimaryKeyConstraint("message_id", "model"),)
|
||||
|
||||
message_id: UUID = Field(sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("message.id"), nullable=False))
|
||||
model: str = Field(max_length=256, nullable=False)
|
||||
embedding: List[float] = Field(sa_column=sa.Column(ARRAY(Float)), nullable=True)
|
||||
|
||||
# In the case that the Message Embedding is created afterwards
|
||||
created_date: Optional[datetime] = Field(
|
||||
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp())
|
||||
)
|
||||
@@ -2,13 +2,13 @@ import datetime
|
||||
import random
|
||||
from collections import defaultdict
|
||||
from http import HTTPStatus
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import oasst_backend.models.db_payload as db_payload
|
||||
from loguru import logger
|
||||
from oasst_backend.journal_writer import JournalWriter
|
||||
from oasst_backend.models import ApiClient, Message, MessageReaction, TextLabels, User
|
||||
from oasst_backend.models import ApiClient, Message, MessageEmbedding, MessageReaction, TextLabels, User
|
||||
from oasst_backend.models.payload_column_type import PayloadContainer
|
||||
from oasst_backend.task_repository import TaskRepository, validate_frontend_message_id
|
||||
from oasst_backend.user_repository import UserRepository
|
||||
@@ -91,7 +91,12 @@ class PromptRepository:
|
||||
self.db.refresh(message)
|
||||
return message
|
||||
|
||||
def store_text_reply(self, text: str, frontend_message_id: str, user_frontend_message_id: str) -> Message:
|
||||
def store_text_reply(
|
||||
self,
|
||||
text: str,
|
||||
frontend_message_id: str,
|
||||
user_frontend_message_id: str,
|
||||
) -> Message:
|
||||
validate_frontend_message_id(frontend_message_id)
|
||||
validate_frontend_message_id(user_frontend_message_id)
|
||||
|
||||
@@ -224,6 +229,30 @@ class PromptRepository:
|
||||
OasstErrorCode.TASK_PAYLOAD_TYPE_MISMATCH,
|
||||
)
|
||||
|
||||
def insert_message_embedding(self, message_id: UUID, model: str, embedding: List[float]) -> MessageEmbedding:
|
||||
"""Insert the embedding of a new message in the database.
|
||||
|
||||
Args:
|
||||
message_id (UUID): the identifier of the message we want to save its embedding
|
||||
model (str): the model used for creating the embedding
|
||||
embedding (List[float]): the values obtained from the message & model
|
||||
|
||||
Raises:
|
||||
OasstError: if misses some of the before params
|
||||
|
||||
Returns:
|
||||
MessageEmbedding: the instance in the database of the embedding saved for that message
|
||||
"""
|
||||
|
||||
if None in (message_id, model, embedding):
|
||||
raise OasstError("Paramters missing to add embedding", OasstErrorCode.GENERIC_ERROR)
|
||||
|
||||
message_embedding = MessageEmbedding(message_id=message_id, model=model, embedding=embedding)
|
||||
self.db.add(message_embedding)
|
||||
self.db.commit()
|
||||
self.db.refresh(message_embedding)
|
||||
return message_embedding
|
||||
|
||||
def insert_reaction(self, task_id: UUID, payload: db_payload.ReactionPayload) -> MessageReaction:
|
||||
if self.user_id is None:
|
||||
raise OasstError("User required", OasstErrorCode.USER_NOT_SPECIFIED)
|
||||
|
||||
@@ -1,10 +1,21 @@
|
||||
from enum import Enum
|
||||
from typing import Any, Dict
|
||||
|
||||
import aiohttp
|
||||
from loguru import logger
|
||||
from oasst_backend.config import settings
|
||||
from oasst_shared.exceptions import OasstError, OasstErrorCode
|
||||
|
||||
|
||||
class HfUrl(str, Enum):
|
||||
HUGGINGFACE_TOXIC_ROBERTA = ("https://api-inference.huggingface.co/models/unitary/multilingual-toxic-xlm-roberta",)
|
||||
HUGGINGFACE_FEATURE_EXTRACTION = "https://api-inference.huggingface.co/pipeline/feature-extraction"
|
||||
|
||||
|
||||
class HfEmbeddingModel(str, Enum):
|
||||
MINILM = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
|
||||
|
||||
|
||||
class HuggingFaceAPI:
|
||||
"""Class Object to make post calls to endpoints for inference in models hosted in HuggingFace"""
|
||||
|
||||
@@ -41,6 +52,9 @@ class HuggingFaceAPI:
|
||||
async with session.post(self.api_url, headers=self.headers, json=payload) as response:
|
||||
# If we get a bad response
|
||||
if response.status != 200:
|
||||
|
||||
logger.error(response)
|
||||
logger.info(self.headers)
|
||||
raise OasstError(
|
||||
"Response Error Detoxify HuggingFace", error_code=OasstErrorCode.HUGGINGFACE_API_ERROR
|
||||
)
|
||||
|
||||
@@ -95,6 +95,7 @@ services:
|
||||
- DEBUG_SKIP_API_KEY_CHECK=True
|
||||
- DEBUG_USE_SEED_DATA=True
|
||||
- MAX_WORKERS=1
|
||||
- DEBUG_SKIP_EMBEDDING_COMPUTATION=True
|
||||
depends_on:
|
||||
db:
|
||||
condition: service_healthy
|
||||
|
||||
@@ -6,6 +6,7 @@ pushd "$parent_path/../../backend"
|
||||
|
||||
export DEBUG_SKIP_API_KEY_CHECK=True
|
||||
export DEBUG_USE_SEED_DATA=True
|
||||
export DEBUG_SKIP_EMBEDDING_COMPUTATION=True
|
||||
|
||||
uvicorn main:app --reload --port 8080 --host 0.0.0.0
|
||||
|
||||
|
||||
Reference in New Issue
Block a user