mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-29 16:30:24 +08:00
added text labels to the API
This commit is contained in:
@@ -7,6 +7,7 @@ Create Date: ${create_date}
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
${imports if imports else ""}
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
|
||||
@@ -0,0 +1,47 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""empty message
|
||||
|
||||
Revision ID: 067c4002f2d9
|
||||
Revises: 0daec5f8135f
|
||||
Create Date: 2022-12-25 17:05:21.208843
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# Alembic revision identifiers: this migration's own ID and the ID of the
# revision it revises.
revision = "067c4002f2d9"
down_revision = "0daec5f8135f"

# No branch label and no dependency on another revision is declared.
branch_labels = None
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Create the ``text_labels`` table.

    The table has a server-generated UUID primary key (``id``), a
    server-defaulted ``created_date``, a required ``api_client_id`` referencing
    ``api_client.id``, a required ``text`` column, an optional ``post_id``
    referencing ``post.id`` and an optional JSONB ``labels`` column.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    columns = (
        sa.Column("id", postgresql.UUID(as_uuid=True), server_default=sa.text("gen_random_uuid()"), nullable=False),
        sa.Column("created_date", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
        sa.Column("post_id", postgresql.UUID(as_uuid=True), nullable=True),
        sa.Column("labels", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
        sa.Column("api_client_id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
        sa.Column("text", sqlmodel.sql.sqltypes.AutoString(length=65536), nullable=False),
    )
    constraints = (
        sa.ForeignKeyConstraint(["api_client_id"], ["api_client.id"]),
        sa.ForeignKeyConstraint(["post_id"], ["post.id"]),
        sa.PrimaryKeyConstraint("id"),
    )
    # Columns first, then constraints: same argument order as the generated call.
    op.create_table("text_labels", *columns, *constraints)
    # ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Drop the ``text_labels`` table, undoing what ``upgrade`` created."""
    # ### commands auto generated by Alembic - please adjust! ###
    table_name = "text_labels"
    op.drop_table(table_name)
    # ### end Alembic commands ###
|
||||
@@ -1,6 +1,7 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from fastapi import APIRouter
|
||||
from oasst_backend.api.v1 import tasks
|
||||
from oasst_backend.api.v1 import tasks, text_labels
|
||||
|
||||
# Aggregate router: each feature module's router is mounted under its own
# URL prefix and tag.
api_router = APIRouter()
api_router.include_router(
    tasks.router,
    prefix="/tasks",
    tags=["tasks"],
)
api_router.include_router(
    text_labels.router,
    prefix="/text_labels",
    tags=["text_labels"],
)
|
||||
|
||||
@@ -0,0 +1,41 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import pydantic
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.security.api_key import APIKey
|
||||
from loguru import logger
|
||||
from oasst_backend.api import deps
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from sqlmodel import Session
|
||||
from starlette.status import HTTP_400_BAD_REQUEST
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class LabelTextRequest(pydantic.BaseModel):
    """Request body accepted by the text-labeling endpoint."""

    # The text being labeled together with its label scores.
    text_labels: protocol_schema.TextLabels

    # The user on whose behalf the labels are submitted.
    user: protocol_schema.User
|
||||
|
||||
|
||||
@router.post("/")  # work with Union once more types are added
def label_text(
    *,
    db: Session = Depends(deps.get_db),
    api_key: APIKey = Depends(deps.get_api_key),
    request: LabelTextRequest,
) -> None:
    """
    Label a piece of text.

    Authenticates the API client from ``api_key`` and stores the submitted
    labels through a ``PromptRepository`` bound to the requesting user.

    Raises:
        HTTPException: with status 400 when storing the labels fails for any
            reason other than an ``HTTPException``; an ``HTTPException``
            raised while storing is propagated unchanged.
    """
    # Authentication happens outside the try block so that its failures are
    # never rewritten into a generic 400.
    api_client = deps.api_auth(api_key, db)

    try:
        logger.info(f"Labeling text {request=}.")
        pr = PromptRepository(db, api_client, user=request.user)
        pr.store_text_labels(request.text_labels)
    except HTTPException:
        # Already an HTTP-level error: keep its status code and detail rather
        # than masking it as a generic 400.
        raise
    except Exception as exc:
        logger.exception("Failed to store label.")
        # Chain the original error so the cause stays visible in tracebacks.
        raise HTTPException(
            status_code=HTTP_400_BAD_REQUEST,
        ) from exc
|
||||
@@ -4,6 +4,7 @@ from .person import Person
|
||||
from .person_stats import PersonStats
|
||||
from .post import Post
|
||||
from .post_reaction import PostReaction
|
||||
from .text_labels import TextLabels
|
||||
from .work_package import WorkPackage
|
||||
|
||||
__all__ = [
|
||||
@@ -13,4 +14,5 @@ __all__ = [
|
||||
"Post",
|
||||
"PostReaction",
|
||||
"WorkPackage",
|
||||
"TextLabels",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,25 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import sqlalchemy as sa
|
||||
import sqlalchemy.dialects.postgresql as pg
|
||||
from sqlmodel import Field, SQLModel
|
||||
|
||||
|
||||
class TextLabels(SQLModel, table=True):
    """Database row holding label scores for a piece of text.

    Each row belongs to an API client and may optionally reference a post.
    """

    __tablename__ = "text_labels"

    # Primary key; generated client-side by uuid4 or server-side by
    # gen_random_uuid() when no value is supplied.
    id: Optional[UUID] = Field(
        sa_column=sa.Column(
            pg.UUID(as_uuid=True), primary_key=True, default=uuid4, server_default=sa.text("gen_random_uuid()")
        ),
    )
    # Creation time; filled in by the database when omitted.
    created_date: Optional[datetime] = Field(
        sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp()),
    )
    # API client that submitted the labels.
    api_client_id: UUID = Field(nullable=False, foreign_key="api_client.id")
    # The labeled text, at most 2**16 characters.
    text: str = Field(nullable=False, max_length=2**16)
    # Optional reference to the post the text belongs to.
    post_id: Optional[UUID] = Field(sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("post.id"), nullable=True))
    # Mapping of label name to score, stored as JSONB.
    # The column definition comes entirely from ``sa_column``; a ``nullable``
    # argument passed to ``Field`` next to it had no effect on the schema (the
    # generated migration creates this column as nullable), so it is omitted.
    # NOTE(review): if NOT NULL is intended, set it on the Column and add a
    # migration. ``default_factory`` avoids a shared mutable default.
    labels: dict[str, float] = Field(default_factory=dict, sa_column=sa.Column(pg.JSONB))
|
||||
@@ -5,7 +5,7 @@ from uuid import UUID, uuid4
|
||||
|
||||
import oasst_backend.models.db_payload as db_payload
|
||||
from loguru import logger
|
||||
from oasst_backend.models import ApiClient, Person, Post, PostReaction, WorkPackage
|
||||
from oasst_backend.models import ApiClient, Person, Post, PostReaction, TextLabels, WorkPackage
|
||||
from oasst_backend.models.payload_column_type import PayloadContainer
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from sqlmodel import Session
|
||||
@@ -314,3 +314,17 @@ class PromptRepository:
|
||||
self.db.commit()
|
||||
self.db.refresh(reaction)
|
||||
return reaction
|
||||
|
||||
def store_text_labels(self, text_labels: protocol_schema.TextLabels) -> TextLabels:
    """Persist a set of text labels for the current API client.

    Builds a ``TextLabels`` row from the submitted text and labels, attaches
    the post id when one was supplied, then adds, commits and refreshes the
    row before returning it.
    """
    record = TextLabels(
        api_client_id=self.api_client.id,
        text=text_labels.text,
        labels=text_labels.labels,
    )

    if text_labels.has_post_id:
        # Called with fail_if_missing=True; the return value is not used.
        self.fetch_post_by_frontend_post_id(text_labels.post_id, fail_if_missing=True)
        # NOTE(review): this stores the frontend post id as-is; confirm it is
        # the same value as the referenced post's primary key.
        record.post_id = text_labels.post_id

    session = self.db
    session.add(record)
    session.commit()
    session.refresh(record)
    return record
|
||||
|
||||
@@ -204,3 +204,51 @@ AnyInteraction = Union[
|
||||
PostRating,
|
||||
PostRanking,
|
||||
]
|
||||
|
||||
|
||||
class TextLabel(str, enum.Enum):
    """Closed set of label names that can be attached to a piece of text.

    Each member is also a ``str`` whose value equals its name.
    """

    spam = "spam"
    violence = "violence"
    sexual_content = "sexual_content"
    toxicity = "toxicity"
    political_content = "political_content"
    humor = "humor"
    sarcasm = "sarcasm"
    hate_speech = "hate_speech"
    profanity = "profanity"
    ad_hominem = "ad_hominem"
    insult = "insult"
    threat = "threat"
    aggressive = "aggressive"
    misleading = "misleading"
    helpful = "helpful"
    formal = "formal"
    cringe = "cringe"
    creative = "creative"
    beautiful = "beautiful"
    informative = "informative"
    based = "based"
    slang = "slang"
|
||||
|
||||
|
||||
class TextLabels(BaseModel):
    """Label scores for a piece of text, optionally tied to a post."""

    text: str
    labels: dict[TextLabel, float]
    post_id: str | None = None

    @property
    def has_post_id(self) -> bool:
        """Return True when ``post_id`` holds a non-empty value."""
        return True if self.post_id else False

    @pydantic.validator("labels")
    def check_label_values(cls, v):
        """Reject any label score outside the closed interval [0, 1]."""
        # The chained comparison also rejects NaN, which fails both bounds.
        offender = next(
            ((label, score) for label, score in v.items() if not 0 <= score <= 1),
            None,
        )
        if offender is not None:
            key, value = offender
            raise ValueError(f"Label values must be between 0 and 1, got {value} for {key}.")
        return v
|
||||
|
||||
Reference in New Issue
Block a user