added text labels to the API

This commit is contained in:
Yannic Kilcher
2022-12-25 17:08:57 +01:00
parent db10c52877
commit a37bf6bf41
8 changed files with 181 additions and 2 deletions
+1
View File
@@ -7,6 +7,7 @@ Create Date: ${create_date}
"""
from alembic import op
import sqlalchemy as sa
import sqlmodel
${imports if imports else ""}
# revision identifiers, used by Alembic.
@@ -0,0 +1,47 @@
# -*- coding: utf-8 -*-
"""empty message
Revision ID: 067c4002f2d9
Revises: 0daec5f8135f
Create Date: 2022-12-25 17:05:21.208843
"""
import sqlalchemy as sa
import sqlmodel
from alembic import op
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "067c4002f2d9"
down_revision = "0daec5f8135f"
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"text_labels",
sa.Column("id", postgresql.UUID(as_uuid=True), server_default=sa.text("gen_random_uuid()"), nullable=False),
sa.Column("created_date", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
sa.Column("post_id", postgresql.UUID(as_uuid=True), nullable=True),
sa.Column("labels", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
sa.Column("api_client_id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.Column("text", sqlmodel.sql.sqltypes.AutoString(length=65536), nullable=False),
sa.ForeignKeyConstraint(
["api_client_id"],
["api_client.id"],
),
sa.ForeignKeyConstraint(
["post_id"],
["post.id"],
),
sa.PrimaryKeyConstraint("id"),
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table("text_labels")
# ### end Alembic commands ###
+2 -1
View File
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
from fastapi import APIRouter
from oasst_backend.api.v1 import tasks
from oasst_backend.api.v1 import tasks, text_labels
api_router = APIRouter()
api_router.include_router(tasks.router, prefix="/tasks", tags=["tasks"])
api_router.include_router(text_labels.router, prefix="/text_labels", tags=["text_labels"])
@@ -0,0 +1,41 @@
# -*- coding: utf-8 -*-
import pydantic
from fastapi import APIRouter, Depends, HTTPException
from fastapi.security.api_key import APIKey
from loguru import logger
from oasst_backend.api import deps
from oasst_backend.prompt_repository import PromptRepository
from oasst_shared.schemas import protocol as protocol_schema
from sqlmodel import Session
from starlette.status import HTTP_400_BAD_REQUEST
router = APIRouter()
class LabelTextRequest(pydantic.BaseModel):
text_labels: protocol_schema.TextLabels
user: protocol_schema.User
@router.post("/") # work with Union once more types are added
def label_text(
*,
db: Session = Depends(deps.get_db),
api_key: APIKey = Depends(deps.get_api_key),
request: LabelTextRequest,
) -> None:
"""
Label a piece of text.
"""
api_client = deps.api_auth(api_key, db)
try:
logger.info(f"Labeling text {request=}.")
pr = PromptRepository(db, api_client, user=request.user)
pr.store_text_labels(request.text_labels)
except Exception:
logger.exception("Failed to store label.")
raise HTTPException(
status_code=HTTP_400_BAD_REQUEST,
)
+2
View File
@@ -4,6 +4,7 @@ from .person import Person
from .person_stats import PersonStats
from .post import Post
from .post_reaction import PostReaction
from .text_labels import TextLabels
from .work_package import WorkPackage
__all__ = [
@@ -13,4 +14,5 @@ __all__ = [
"Post",
"PostReaction",
"WorkPackage",
"TextLabels",
]
@@ -0,0 +1,25 @@
# -*- coding: utf-8 -*-
from datetime import datetime
from typing import Optional
from uuid import UUID, uuid4
import sqlalchemy as sa
import sqlalchemy.dialects.postgresql as pg
from sqlmodel import Field, SQLModel
class TextLabels(SQLModel, table=True):
__tablename__ = "text_labels"
id: Optional[UUID] = Field(
sa_column=sa.Column(
pg.UUID(as_uuid=True), primary_key=True, default=uuid4, server_default=sa.text("gen_random_uuid()")
),
)
created_date: Optional[datetime] = Field(
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp()),
)
api_client_id: UUID = Field(nullable=False, foreign_key="api_client.id")
text: str = Field(nullable=False, max_length=2**16)
post_id: Optional[UUID] = Field(sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("post.id"), nullable=True))
labels: dict[str, float] = Field(default={}, sa_column=sa.Column(pg.JSONB), nullable=False)
+15 -1
View File
@@ -5,7 +5,7 @@ from uuid import UUID, uuid4
import oasst_backend.models.db_payload as db_payload
from loguru import logger
from oasst_backend.models import ApiClient, Person, Post, PostReaction, WorkPackage
from oasst_backend.models import ApiClient, Person, Post, PostReaction, TextLabels, WorkPackage
from oasst_backend.models.payload_column_type import PayloadContainer
from oasst_shared.schemas import protocol as protocol_schema
from sqlmodel import Session
@@ -314,3 +314,17 @@ class PromptRepository:
self.db.commit()
self.db.refresh(reaction)
return reaction
def store_text_labels(self, text_labels: protocol_schema.TextLabels) -> TextLabels:
model = TextLabels(
api_client_id=self.api_client.id,
text=text_labels.text,
labels=text_labels.labels,
)
if text_labels.has_post_id:
self.fetch_post_by_frontend_post_id(text_labels.post_id, fail_if_missing=True)
model.post_id = text_labels.post_id
self.db.add(model)
self.db.commit()
self.db.refresh(model)
return model
@@ -204,3 +204,51 @@ AnyInteraction = Union[
PostRating,
PostRanking,
]
class TextLabel(str, enum.Enum):
"""A label for a piece of text."""
spam = "spam"
violence = "violence"
sexual_content = "sexual_content"
toxicity = "toxicity"
political_content = "political_content"
humor = "humor"
sarcasm = "sarcasm"
hate_speech = "hate_speech"
profanity = "profanity"
ad_hominem = "ad_hominem"
insult = "insult"
threat = "threat"
aggressive = "aggressive"
misleading = "misleading"
helpful = "helpful"
formal = "formal"
cringe = "cringe"
creative = "creative"
beautiful = "beautiful"
informative = "informative"
based = "based"
slang = "slang"
class TextLabels(BaseModel):
"""A set of labels for a piece of text."""
text: str
labels: dict[TextLabel, float]
post_id: str | None = None
@property
def has_post_id(self) -> bool:
"""Whether this TextLabels has a post_id."""
return bool(self.post_id)
# check that each label value is between 0 and 1
@pydantic.validator("labels")
def check_label_values(cls, v):
for key, value in v.items():
if not 0 <= value <= 1:
raise ValueError(f"Label values must be between 0 and 1, got {value} for {key}.")
return v