From a37bf6bf41377cb30234bd8fc0521fc062714367 Mon Sep 17 00:00:00 2001
From: Yannic Kilcher
Date: Sun, 25 Dec 2022 17:08:57 +0100
Subject: [PATCH] added text labels to the API

---
Review notes (this section is ignored when the patch is applied):
* `labels` is declared NOT NULL on the column itself, in both the model and
  the migration.
* store_text_labels() stores the id of the post row it looked up.
* The blob hashes on the `index` lines below were not regenerated.

 backend/alembic/script.py.mako                |  1 +
 .../versions/2022_12_25_1705-067c4002f2d9_.py | 47 +++++++++++++++++
 backend/oasst_backend/api/v1/api.py           |  3 +-
 backend/oasst_backend/api/v1/text_labels.py   | 46 +++++++++++++++++
 backend/oasst_backend/models/__init__.py      |  2 +
 backend/oasst_backend/models/text_labels.py   | 27 ++++++++++
 backend/oasst_backend/prompt_repository.py    | 19 ++++++-
 oasst-shared/oasst_shared/schemas/protocol.py | 48 ++++++++++++++++++
 8 files changed, 191 insertions(+), 2 deletions(-)
 create mode 100644 backend/alembic/versions/2022_12_25_1705-067c4002f2d9_.py
 create mode 100644 backend/oasst_backend/api/v1/text_labels.py
 create mode 100644 backend/oasst_backend/models/text_labels.py

diff --git a/backend/alembic/script.py.mako b/backend/alembic/script.py.mako
index 55df2863..3124b62c 100644
--- a/backend/alembic/script.py.mako
+++ b/backend/alembic/script.py.mako
@@ -7,6 +7,7 @@ Create Date: ${create_date}
 """
 from alembic import op
 import sqlalchemy as sa
+import sqlmodel
 ${imports if imports else ""}
 
 # revision identifiers, used by Alembic.
diff --git a/backend/alembic/versions/2022_12_25_1705-067c4002f2d9_.py b/backend/alembic/versions/2022_12_25_1705-067c4002f2d9_.py
new file mode 100644
index 00000000..070d8d9c
--- /dev/null
+++ b/backend/alembic/versions/2022_12_25_1705-067c4002f2d9_.py
@@ -0,0 +1,47 @@
+# -*- coding: utf-8 -*-
+"""add text_labels table
+
+Revision ID: 067c4002f2d9
+Revises: 0daec5f8135f
+Create Date: 2022-12-25 17:05:21.208843
+
+"""
+import sqlalchemy as sa
+import sqlmodel
+from alembic import op
+from sqlalchemy.dialects import postgresql
+
+# revision identifiers, used by Alembic.
+revision = "067c4002f2d9"
+down_revision = "0daec5f8135f"
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.create_table(
+        "text_labels",
+        sa.Column("id", postgresql.UUID(as_uuid=True), server_default=sa.text("gen_random_uuid()"), nullable=False),
+        sa.Column("created_date", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
+        sa.Column("post_id", postgresql.UUID(as_uuid=True), nullable=True),
+        sa.Column("labels", postgresql.JSONB(astext_type=sa.Text()), nullable=False),
+        sa.Column("api_client_id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
+        sa.Column("text", sqlmodel.sql.sqltypes.AutoString(length=65536), nullable=False),
+        sa.ForeignKeyConstraint(
+            ["api_client_id"],
+            ["api_client.id"],
+        ),
+        sa.ForeignKeyConstraint(
+            ["post_id"],
+            ["post.id"],
+        ),
+        sa.PrimaryKeyConstraint("id"),
+    )
+    # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.drop_table("text_labels")
+    # ### end Alembic commands ###
diff --git a/backend/oasst_backend/api/v1/api.py b/backend/oasst_backend/api/v1/api.py
index cd1119d6..b54f3dd0 100644
--- a/backend/oasst_backend/api/v1/api.py
+++ b/backend/oasst_backend/api/v1/api.py
@@ -1,6 +1,7 @@
 # -*- coding: utf-8 -*-
 from fastapi import APIRouter
-from oasst_backend.api.v1 import tasks
+from oasst_backend.api.v1 import tasks, text_labels
 
 api_router = APIRouter()
 api_router.include_router(tasks.router, prefix="/tasks", tags=["tasks"])
+api_router.include_router(text_labels.router, prefix="/text_labels", tags=["text_labels"])
diff --git a/backend/oasst_backend/api/v1/text_labels.py b/backend/oasst_backend/api/v1/text_labels.py
new file mode 100644
index 00000000..ff8f604d
--- /dev/null
+++ b/backend/oasst_backend/api/v1/text_labels.py
@@ -0,0 +1,46 @@
+# -*- coding: utf-8 -*-
+import pydantic
+from fastapi import APIRouter, Depends, HTTPException
+from fastapi.security.api_key import APIKey
+from loguru import logger
+from oasst_backend.api import deps
+from oasst_backend.prompt_repository import PromptRepository
+from oasst_shared.schemas import protocol as protocol_schema
+from sqlmodel import Session
+from starlette.status import HTTP_400_BAD_REQUEST
+
+router = APIRouter()
+
+
+class LabelTextRequest(pydantic.BaseModel):
+    """Request body for label_text: the text labels and the associated user."""
+
+    text_labels: protocol_schema.TextLabels
+    user: protocol_schema.User
+
+
+@router.post("/")  # work with Union once more types are added
+def label_text(
+    *,
+    db: Session = Depends(deps.get_db),
+    api_key: APIKey = Depends(deps.get_api_key),
+    request: LabelTextRequest,
+) -> None:
+    """
+    Label a piece of text.
+
+    Any error raised while storing the labels is logged and reported to the
+    caller as HTTP 400.
+    """
+    api_client = deps.api_auth(api_key, db)
+
+    try:
+        logger.info(f"Labeling text {request=}.")
+        pr = PromptRepository(db, api_client, user=request.user)
+        pr.store_text_labels(request.text_labels)
+
+    except Exception:
+        logger.exception("Failed to store label.")
+        raise HTTPException(
+            status_code=HTTP_400_BAD_REQUEST,
+        )
diff --git a/backend/oasst_backend/models/__init__.py b/backend/oasst_backend/models/__init__.py
index 414ec385..2a9b0c1f 100644
--- a/backend/oasst_backend/models/__init__.py
+++ b/backend/oasst_backend/models/__init__.py
@@ -4,6 +4,7 @@ from .person import Person
 from .person_stats import PersonStats
 from .post import Post
 from .post_reaction import PostReaction
+from .text_labels import TextLabels
 from .work_package import WorkPackage
 
 __all__ = [
@@ -13,4 +14,5 @@ __all__ = [
     "Post",
     "PostReaction",
     "WorkPackage",
+    "TextLabels",
 ]
diff --git a/backend/oasst_backend/models/text_labels.py b/backend/oasst_backend/models/text_labels.py
new file mode 100644
index 00000000..2699302f
--- /dev/null
+++ b/backend/oasst_backend/models/text_labels.py
@@ -0,0 +1,27 @@
+# -*- coding: utf-8 -*-
+from datetime import datetime
+from typing import Optional
+from uuid import UUID, uuid4
+
+import sqlalchemy as sa
+import sqlalchemy.dialects.postgresql as pg
+from sqlmodel import Field, SQLModel
+
+
+class TextLabels(SQLModel, table=True):
+    """Database row holding the labels assigned to one piece of text."""
+
+    __tablename__ = "text_labels"
+
+    id: Optional[UUID] = Field(
+        sa_column=sa.Column(
+            pg.UUID(as_uuid=True), primary_key=True, default=uuid4, server_default=sa.text("gen_random_uuid()")
+        ),
+    )
+    created_date: Optional[datetime] = Field(
+        sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp()),
+    )
+    api_client_id: UUID = Field(nullable=False, foreign_key="api_client.id")
+    text: str = Field(nullable=False, max_length=2**16)
+    post_id: Optional[UUID] = Field(sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("post.id"), nullable=True))
+    labels: dict[str, float] = Field(default_factory=dict, sa_column=sa.Column(pg.JSONB, nullable=False))
diff --git a/backend/oasst_backend/prompt_repository.py b/backend/oasst_backend/prompt_repository.py
index b0063cdf..6db42de1 100644
--- a/backend/oasst_backend/prompt_repository.py
+++ b/backend/oasst_backend/prompt_repository.py
@@ -5,7 +5,7 @@ from uuid import UUID, uuid4
 
 import oasst_backend.models.db_payload as db_payload
 from loguru import logger
-from oasst_backend.models import ApiClient, Person, Post, PostReaction, WorkPackage
+from oasst_backend.models import ApiClient, Person, Post, PostReaction, TextLabels, WorkPackage
 from oasst_backend.models.payload_column_type import PayloadContainer
 from oasst_shared.schemas import protocol as protocol_schema
 from sqlmodel import Session
@@ -314,3 +314,20 @@ class PromptRepository:
         self.db.commit()
         self.db.refresh(reaction)
         return reaction
+
+    def store_text_labels(self, text_labels: protocol_schema.TextLabels) -> TextLabels:
+        """Persist a set of text labels, linking them to a post when a post_id is given."""
+        model = TextLabels(
+            api_client_id=self.api_client.id,
+            text=text_labels.text,
+            labels=text_labels.labels,
+        )
+        if text_labels.has_post_id:
+            # The client sends the frontend post id; the column is a FK to post.id, so store
+            # the id of the looked-up row (assumes the helper returns the Post - confirm).
+            post = self.fetch_post_by_frontend_post_id(text_labels.post_id, fail_if_missing=True)
+            model.post_id = post.id
+        self.db.add(model)
+        self.db.commit()
+        self.db.refresh(model)
+        return model
diff --git a/oasst-shared/oasst_shared/schemas/protocol.py b/oasst-shared/oasst_shared/schemas/protocol.py
index d5f508b6..4599b4bd 100644
--- a/oasst-shared/oasst_shared/schemas/protocol.py
+++ b/oasst-shared/oasst_shared/schemas/protocol.py
@@ -204,3 +204,51 @@ AnyInteraction = Union[
     PostRating,
     PostRanking,
 ]
+
+
+class TextLabel(str, enum.Enum):
+    """A label for a piece of text."""
+
+    spam = "spam"
+    violence = "violence"
+    sexual_content = "sexual_content"
+    toxicity = "toxicity"
+    political_content = "political_content"
+    humor = "humor"
+    sarcasm = "sarcasm"
+    hate_speech = "hate_speech"
+    profanity = "profanity"
+    ad_hominem = "ad_hominem"
+    insult = "insult"
+    threat = "threat"
+    aggressive = "aggressive"
+    misleading = "misleading"
+    helpful = "helpful"
+    formal = "formal"
+    cringe = "cringe"
+    creative = "creative"
+    beautiful = "beautiful"
+    informative = "informative"
+    based = "based"
+    slang = "slang"
+
+
+class TextLabels(BaseModel):
+    """A set of labels for a piece of text."""
+
+    text: str
+    labels: dict[TextLabel, float]
+    post_id: str | None = None
+
+    @property
+    def has_post_id(self) -> bool:
+        """Whether this TextLabels has a post_id."""
+        return bool(self.post_id)
+
+    # check that each label value is between 0 and 1
+    @pydantic.validator("labels")
+    def check_label_values(cls, v):
+        for key, value in v.items():
+            if not 0 <= value <= 1:
+                raise ValueError(f"Label values must be between 0 and 1, got {value} for {key}.")
+        return v