mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
message embeddings in Messages table
This commit is contained in:
+29
@@ -0,0 +1,29 @@
|
||||
"""added miniLM_embedding column to message
|
||||
|
||||
Revision ID: 023548d474f7
|
||||
Revises: ba61fe17fb6e
|
||||
Create Date: 2023-01-08 11:06:25.613290
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '023548d474f7'
|
||||
down_revision = 'ba61fe17fb6e'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Add the nullable ``miniLM_embedding`` float-array column to the ``message`` table."""
    # ### commands auto generated by Alembic - please adjust! ###
    embedding_column = sa.Column("miniLM_embedding", sa.ARRAY(sa.Float()), nullable=True)
    op.add_column("message", embedding_column)
    # ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Drop the ``miniLM_embedding`` column from the ``message`` table."""
    # ### commands auto generated by Alembic - please adjust! ###
    table_name, column_name = "message", "miniLM_embedding"
    op.drop_column(table_name, column_name)
    # ### end Alembic commands ###
|
||||
@@ -1,19 +1,14 @@
|
||||
from enum import Enum
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from oasst_backend.api import deps
|
||||
from oasst_backend.models import ApiClient
|
||||
from oasst_backend.schemas.hugging_face import ToxicityClassification
|
||||
from oasst_backend.utils.hugging_face import HuggingFaceAPI
|
||||
from oasst_backend.utils.hugging_face import HF_url, HuggingFaceAPI
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class HF_url(str, Enum):
    """Hugging Face Inference API endpoint URLs; members are also ``str`` instances."""

    # Endpoint of the unitary/multilingual-toxic-xlm-roberta model.
    HUGGINGFACE_TOXIC_ROBERTA = "https://api-inference.huggingface.co/models/unitary/multilingual-toxic-xlm-roberta"
|
||||
|
||||
|
||||
@router.get("/text_toxicity")
|
||||
async def get_text_toxicity(
|
||||
msg: str,
|
||||
|
||||
@@ -7,7 +7,9 @@ from fastapi.security.api_key import APIKey
|
||||
from loguru import logger
|
||||
from oasst_backend.api import deps
|
||||
from oasst_backend.api.v1.utils import prepare_conversation
|
||||
from oasst_backend.config import settings
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_backend.utils.hugging_face import HF_url, HuggingFaceAPI
|
||||
from oasst_shared.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from sqlmodel import Session
|
||||
@@ -253,7 +255,7 @@ def tasks_acknowledge_failure(
|
||||
|
||||
|
||||
@router.post("/interaction", response_model=protocol_schema.TaskDone)
|
||||
def tasks_interaction(
|
||||
async def tasks_interaction(
|
||||
*,
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_key: APIKey = Depends(deps.get_api_key),
|
||||
@@ -273,11 +275,22 @@ def tasks_interaction(
|
||||
f"Frontend reports text reply to {interaction.message_id=} with {interaction.text=} by {interaction.user=}."
|
||||
)
|
||||
|
||||
embedding = None
|
||||
if not settings.DEBUG_SKIP_EMBEDDING_COMPUTATION:
|
||||
try:
|
||||
hugging_face_api = HuggingFaceAPI(HF_url.HUGGINGFACE_MINILM_EMBEDDING.value)
|
||||
embedding = await hugging_face_api.post(interaction.text)
|
||||
except:
|
||||
logger.error(
|
||||
f"Could not fetch embbeddings for text reply to {interaction.message_id=} with {interaction.text=} by {interaction.user=}."
|
||||
)
|
||||
|
||||
# here we store the text reply in the database
|
||||
pr.store_text_reply(
|
||||
text=interaction.text,
|
||||
frontend_message_id=interaction.message_id,
|
||||
user_frontend_message_id=interaction.user_message_id,
|
||||
miniLM_embedding=embedding,
|
||||
)
|
||||
|
||||
return protocol_schema.TaskDone()
|
||||
|
||||
@@ -25,6 +25,7 @@ class Settings(BaseSettings):
|
||||
DEBUG_USE_SEED_DATA_PATH: Optional[FilePath] = (
|
||||
Path(__file__).parent.parent / "test_data/generic/test_generic_data.json"
|
||||
)
|
||||
DEBUG_SKIP_EMBEDDING_COMPUTATION: bool = False
|
||||
|
||||
HUGGING_FACE_API_KEY: str = ""
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from datetime import datetime
|
||||
from http import HTTPStatus
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import sqlalchemy as sa
|
||||
@@ -8,7 +8,7 @@ import sqlalchemy.dialects.postgresql as pg
|
||||
from oasst_backend.models.db_payload import MessagePayload
|
||||
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
|
||||
from sqlalchemy import false
|
||||
from sqlmodel import Field, Index, SQLModel
|
||||
from sqlmodel import ARRAY, Field, Float, Index, SQLModel
|
||||
|
||||
from .payload_column_type import PayloadContainer, payload_column_type
|
||||
|
||||
@@ -40,6 +40,7 @@ class Message(SQLModel, table=True):
|
||||
depth: int = Field(sa_column=sa.Column(sa.Integer, default=0, server_default=sa.text("0"), nullable=False))
|
||||
children_count: int = Field(sa_column=sa.Column(sa.Integer, default=0, server_default=sa.text("0"), nullable=False))
|
||||
deleted: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=False, server_default=false()))
|
||||
miniLM_embedding: List[float] = Field(sa_column=sa.Column(ARRAY(Float)), nullable=True)
|
||||
|
||||
def ensure_is_message(self) -> None:
|
||||
if not self.payload or not isinstance(self.payload.payload, MessagePayload):
|
||||
|
||||
@@ -2,7 +2,7 @@ import datetime
|
||||
import random
|
||||
from collections import defaultdict
|
||||
from http import HTTPStatus
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import oasst_backend.models.db_payload as db_payload
|
||||
@@ -122,7 +122,9 @@ class PromptRepository:
|
||||
)
|
||||
return task
|
||||
|
||||
def store_text_reply(self, text: str, frontend_message_id: str, user_frontend_message_id: str) -> Message:
|
||||
def store_text_reply(
|
||||
self, text: str, frontend_message_id: str, user_frontend_message_id: str, miniLM_embedding: List[float] = None
|
||||
) -> Message:
|
||||
self.validate_frontend_message_id(frontend_message_id)
|
||||
self.validate_frontend_message_id(user_frontend_message_id)
|
||||
|
||||
@@ -163,6 +165,7 @@ class PromptRepository:
|
||||
role=role,
|
||||
payload=db_payload.MessagePayload(text=text),
|
||||
depth=depth,
|
||||
miniLM_embedding=miniLM_embedding,
|
||||
)
|
||||
if not task.collective:
|
||||
task.done = True
|
||||
@@ -366,6 +369,7 @@ class PromptRepository:
|
||||
payload: db_payload.MessagePayload,
|
||||
payload_type: str = None,
|
||||
depth: int = 0,
|
||||
miniLM_embedding: List[float] = None,
|
||||
) -> Message:
|
||||
if payload_type is None:
|
||||
if payload is None:
|
||||
@@ -385,6 +389,7 @@ class PromptRepository:
|
||||
payload_type=payload_type,
|
||||
payload=PayloadContainer(payload=payload),
|
||||
depth=depth,
|
||||
miniLM_embedding=miniLM_embedding,
|
||||
)
|
||||
self.db.add(message)
|
||||
self.db.commit()
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from enum import Enum
|
||||
from typing import Any, Dict
|
||||
|
||||
import aiohttp
|
||||
@@ -5,6 +6,11 @@ from oasst_backend.config import settings
|
||||
from oasst_shared.exceptions import OasstError, OasstErrorCode
|
||||
|
||||
|
||||
class HF_url(str, Enum):
    """Hugging Face Inference API endpoint URLs.

    Every member is also a ``str``, so ``member.value`` is the plain URL string
    to hand to an HTTP client.
    """

    # NOTE: this used to be written as a one-element tuple ("...",) because of a
    # stray trailing comma. Enum unpacks tuple values into the ``str`` mixin
    # constructor, so the resulting value was still the URL string, but only by
    # accident. A plain string literal states the intent directly.
    HUGGINGFACE_TOXIC_ROBERTA = "https://api-inference.huggingface.co/models/unitary/multilingual-toxic-xlm-roberta"
    # Feature-extraction pipeline endpoint for paraphrase-multilingual-MiniLM-L12-v2.
    HUGGINGFACE_MINILM_EMBEDDING = "https://api-inference.huggingface.co/pipeline/feature-extraction/sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
|
||||
|
||||
|
||||
class HuggingFaceAPI:
|
||||
"""Class Object to make post calls to endpoints for inference in models hosted in HuggingFace"""
|
||||
|
||||
@@ -41,6 +47,10 @@ class HuggingFaceAPI:
|
||||
async with session.post(self.api_url, headers=self.headers, json=payload) as response:
|
||||
# If we get a bad response
|
||||
if response.status != 200:
|
||||
from loguru import logger
|
||||
|
||||
logger.error(response)
|
||||
logger.info(self.headers)
|
||||
raise OasstError(
|
||||
"Response Error Detoxify HuggingFace", error_code=OasstErrorCode.HUGGINGFACE_API_ERROR
|
||||
)
|
||||
|
||||
@@ -95,6 +95,7 @@ services:
|
||||
- DEBUG_SKIP_API_KEY_CHECK=True
|
||||
- DEBUG_USE_SEED_DATA=True
|
||||
- MAX_WORKERS=1
|
||||
- DEBUG_SKIP_EMBEDDING_COMPUTATION=True
|
||||
depends_on:
|
||||
db:
|
||||
condition: service_healthy
|
||||
|
||||
@@ -6,6 +6,7 @@ pushd "$parent_path/../../backend"
|
||||
|
||||
export DEBUG_SKIP_API_KEY_CHECK=True
|
||||
export DEBUG_USE_SEED_DATA=True
|
||||
export DEBUG_SKIP_EMBEDDING_COMPUTATION=True
|
||||
|
||||
uvicorn main:app --reload --port 8080 --host 0.0.0.0
|
||||
|
||||
|
||||
Reference in New Issue
Block a user