From f74fe68f87918fad4b6b8c2498526092b60fdfb0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20K=C3=B6pf?= Date: Tue, 27 Dec 2022 16:24:21 +0100 Subject: [PATCH 1/2] add journal table and JournalWriter helper class --- ..._27_1444-3358eb6834e6_add_journal_table.py | 75 +++++++++++ backend/oasst_backend/journal_writer.py | 122 ++++++++++++++++++ backend/oasst_backend/models/__init__.py | 3 + backend/oasst_backend/models/journal.py | 56 ++++++++ backend/oasst_backend/prompt_repository.py | 10 ++ oasst-shared/oasst_shared/utils.py | 7 + 6 files changed, 273 insertions(+) create mode 100644 backend/alembic/versions/2022_12_27_1444-3358eb6834e6_add_journal_table.py create mode 100644 backend/oasst_backend/journal_writer.py create mode 100644 backend/oasst_backend/models/journal.py create mode 100644 oasst-shared/oasst_shared/utils.py diff --git a/backend/alembic/versions/2022_12_27_1444-3358eb6834e6_add_journal_table.py b/backend/alembic/versions/2022_12_27_1444-3358eb6834e6_add_journal_table.py new file mode 100644 index 00000000..0dc937a0 --- /dev/null +++ b/backend/alembic/versions/2022_12_27_1444-3358eb6834e6_add_journal_table.py @@ -0,0 +1,75 @@ +# -*- coding: utf-8 -*- +"""add_journal_table + +Revision ID: 3358eb6834e6 +Revises: 067c4002f2d9 +Create Date: 2022-12-27 14:44:59.483868 + +""" +import sqlalchemy as sa +import sqlmodel +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = "3358eb6834e6" +down_revision = "067c4002f2d9" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table( + "journal", + sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False), + sa.Column( + "created_date", sa.DateTime(timezone=True), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False + ), + sa.Column( + "event_payload", + postgresql.JSONB(astext_type=sa.Text()), + nullable=False, + ), + sa.Column("person_id", sqlmodel.sql.sqltypes.GUID(), nullable=True), + sa.Column("post_id", sqlmodel.sql.sqltypes.GUID(), nullable=True), + sa.Column("api_client_id", sqlmodel.sql.sqltypes.GUID(), nullable=False), + sa.Column("event_type", sqlmodel.sql.sqltypes.AutoString(length=200), nullable=False), + sa.ForeignKeyConstraint( + ["api_client_id"], + ["api_client.id"], + ), + sa.ForeignKeyConstraint( + ["person_id"], + ["person.id"], + ), + sa.ForeignKeyConstraint( + ["post_id"], + ["post.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index(op.f("ix_journal_person_id"), "journal", ["person_id"], unique=False) + op.create_table( + "journal_integration", + sa.Column("id", postgresql.UUID(as_uuid=True), server_default=sa.text("gen_random_uuid()"), nullable=False), + sa.Column("last_run", sa.DateTime(), nullable=True), + sa.Column("description", sqlmodel.sql.sqltypes.AutoString(length=512), nullable=False), + sa.Column("last_journal_id", sqlmodel.sql.sqltypes.GUID(), nullable=True), + sa.Column("last_error", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column("next_run", sa.DateTime(), nullable=True), + sa.ForeignKeyConstraint( + ["last_journal_id"], + ["journal.id"], + ), + sa.PrimaryKeyConstraint("id", "description"), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_table("journal_integration") + op.drop_index(op.f("ix_journal_person_id"), table_name="journal") + op.drop_table("journal") + # ### end Alembic commands ### diff --git a/backend/oasst_backend/journal_writer.py b/backend/oasst_backend/journal_writer.py new file mode 100644 index 00000000..897e2dda --- /dev/null +++ b/backend/oasst_backend/journal_writer.py @@ -0,0 +1,122 @@ +# -*- coding: utf-8 -*- +import enum +from typing import Literal, Optional +from uuid import UUID + +from oasst_backend.models import ApiClient, Journal, Person, WorkPackage +from oasst_backend.models.payload_column_type import PayloadContainer, payload_type +from oasst_shared.utils import utcnow +from pydantic import BaseModel +from sqlmodel import Session + + +class JournalEventType(str, enum.Enum): + """The kind of event recorded in a journal entry.""" + + user_created = "user_created" + text_reply_to_post = "text_reply_to_post" + post_rating = "post_rating" + post_ranking = "post_ranking" + + +@payload_type +class JournalEvent(BaseModel): + type: str + person_id: Optional[UUID] + post_id: Optional[UUID] + workpackage_id: Optional[UUID] + task_type: Optional[str] + + +@payload_type +class TextReplyEvent(JournalEvent): + type: Literal[JournalEventType.text_reply_to_post] = JournalEventType.text_reply_to_post + length: int + role: str + + +@payload_type +class RatingEvent(JournalEvent): + type: Literal[JournalEventType.post_rating] = JournalEventType.post_rating + rating: int + + +@payload_type +class RankingEvent(JournalEvent): + type: Literal[JournalEventType.post_ranking] = JournalEventType.post_ranking + ranking: list[int] + + +class JournalWriter: + def __init__(self, db: Session, api_client: ApiClient, person: Person): + self.db = db + self.api_client = api_client + self.person = person + self.person_id = self.person.id if self.person else None + + def log_text_reply(self, work_package: WorkPackage, post_id: UUID, role: str, length: int) -> Journal: + return self.log( +
task_type=work_package.payload_type, + event_type=JournalEventType.text_reply_to_post, + payload=TextReplyEvent(role=role, length=length), + workpackage_id=work_package.id, + post_id=post_id, + ) + + def log_rating(self, work_package: WorkPackage, post_id: UUID, rating: int) -> Journal: + return self.log( + task_type=work_package.payload_type, + event_type=JournalEventType.post_rating, + payload=RatingEvent(rating=rating), + workpackage_id=work_package.id, + post_id=post_id, + ) + + def log_ranking(self, work_package: WorkPackage, post_id: UUID, ranking: list[int]) -> Journal: + return self.log( + task_type=work_package.payload_type, + event_type=JournalEventType.post_ranking, + payload=RankingEvent(ranking=ranking), + workpackage_id=work_package.id, + post_id=post_id, + ) + + def log( + self, + *, + payload: JournalEvent, + task_type: str, + event_type: str = None, + workpackage_id: Optional[UUID] = None, + post_id: Optional[UUID] = None, + commit: bool = True, + ) -> Journal: + if event_type is None: + if payload is None: + event_type = "null" + else: + event_type = type(payload).__name__ + + if payload.person_id is None: + payload.person_id = self.person_id + if payload.post_id is None: + payload.post_id = post_id + if payload.workpackage_id is None: + payload.workpackage_id = workpackage_id + if payload.task_type is None: + payload.task_type = task_type + + entry = Journal( + person_id=self.person_id, + api_client_id=self.api_client.id, + created_date=utcnow(), + event_type=event_type, + event_payload=PayloadContainer(payload=payload), + post_id=post_id, + ) + + self.db.add(entry) + if commit: + self.db.commit() + + return entry diff --git a/backend/oasst_backend/models/__init__.py b/backend/oasst_backend/models/__init__.py index 2a9b0c1f..0acc242c 100644 --- a/backend/oasst_backend/models/__init__.py +++ b/backend/oasst_backend/models/__init__.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- from .api_client import ApiClient +from .journal import Journal, 
JournalIntegration from .person import Person from .person_stats import PersonStats from .post import Post @@ -15,4 +16,6 @@ __all__ = [ "PostReaction", "WorkPackage", "TextLabels", + "Journal", + "JournalIntegration", ] diff --git a/backend/oasst_backend/models/journal.py b/backend/oasst_backend/models/journal.py new file mode 100644 index 00000000..4cec1e99 --- /dev/null +++ b/backend/oasst_backend/models/journal.py @@ -0,0 +1,56 @@ +# -*- coding: utf-8 -*- +from datetime import datetime +from typing import Optional +from uuid import UUID, uuid1, uuid4 + +import sqlalchemy as sa +import sqlalchemy.dialects.postgresql as pg +from sqlmodel import Field, SQLModel + +from .payload_column_type import PayloadContainer, payload_column_type + + +def generate_time_uuid(node=None, clock_seq=None): + """Create a lexicographically sortable time ordered custom (non-standard) UUID by reordering the timestamp fields of a version 1 UUID.""" + (time_low, time_mid, time_hi_version, clock_seq_hi_variant, clock_seq_low, node) = uuid1(node, clock_seq).fields + # reconstruct 60 bit timestamp, see version 1 uuid: https://www.rfc-editor.org/rfc/rfc4122 + timestamp = (time_hi_version & 0xFFF) << 48 | (time_mid << 32) | time_low + version = time_hi_version >> 12 + assert version == 1 + a = timestamp >> 28 # bits 28-59 + b = (timestamp >> 12) & 0xFFFF # bits 12-27 + c = timestamp & 0xFFF # bits 0-11 (clear version bits) + clock_seq_hi_variant &= 0xF # (clear variant bits) + return UUID(fields=(a, b, c, clock_seq_hi_variant, clock_seq_low, node), version=None) + + +class Journal(SQLModel, table=True): + __tablename__ = "journal" + + id: Optional[UUID] = Field( + sa_column=sa.Column(pg.UUID(as_uuid=True), primary_key=True, default=generate_time_uuid), + ) + created_date: Optional[datetime] = Field( + sa_column=sa.Column(sa.DateTime(timezone=True), nullable=False, server_default=sa.func.current_timestamp()) + ) + person_id: UUID = Field(nullable=True, foreign_key="person.id", index=True) + 
post_id: Optional[UUID] = Field(foreign_key="post.id", nullable=True) + api_client_id: UUID = Field(foreign_key="api_client.id") + + event_type: str = Field(nullable=False, max_length=200) + event_payload: PayloadContainer = Field(sa_column=sa.Column(payload_column_type(PayloadContainer), nullable=False)) + + +class JournalIntegration(SQLModel, table=True): + __tablename__ = "journal_integration" + + id: Optional[UUID] = Field( + sa_column=sa.Column( + pg.UUID(as_uuid=True), primary_key=True, default=uuid4, server_default=sa.text("gen_random_uuid()") + ), + ) + description: str = Field(max_length=512, primary_key=True) + last_journal_id: UUID = Field(foreign_key="journal.id", nullable=True) + last_run: datetime = Field(sa_column=sa.Column(sa.DateTime(), nullable=True)) + last_error: str = Field(nullable=True) + next_run: datetime = Field(nullable=True) diff --git a/backend/oasst_backend/prompt_repository.py b/backend/oasst_backend/prompt_repository.py index 6db42de1..f4f83277 100644 --- a/backend/oasst_backend/prompt_repository.py +++ b/backend/oasst_backend/prompt_repository.py @@ -5,6 +5,7 @@ from uuid import UUID, uuid4 import oasst_backend.models.db_payload as db_payload from loguru import logger +from oasst_backend.journal_writer import JournalWriter from oasst_backend.models import ApiClient, Person, Post, PostReaction, TextLabels, WorkPackage from oasst_backend.models.payload_column_type import PayloadContainer from oasst_shared.schemas import protocol as protocol_schema @@ -17,6 +18,7 @@ class PromptRepository: self.api_client = api_client self.person = self.lookup_person(user) self.person_id = self.person.id if self.person else None + self.journal = JournalWriter(db, api_client, self.person) def lookup_person(self, user: protocol_schema.User) -> Person: if not user: @@ -116,6 +118,10 @@ class PromptRepository: self.validate_post_id(reply.post_id) self.validate_post_id(reply.user_post_id) + work_package = self.fetch_workpackage_by_postid(reply.post_id) + 
work_payload: db_payload.TaskPayload = work_package.payload.payload + logger.info(f"found task work package in db: {work_payload}") + # find post with post-id parent_post: Post = ( self.db.query(Post) @@ -141,6 +147,7 @@ class PromptRepository: role=role, payload=db_payload.PostPayload(text=reply.text), ) + self.journal.log_text_reply(work_package=work_package, post_id=user_post_id, role=role, length=len(reply.text)) return user_post def store_rating(self, rating: protocol_schema.PostRating) -> PostReaction: @@ -159,6 +166,7 @@ class PromptRepository: # store reaction to post reaction_payload = db_payload.RatingReactionPayload(rating=rating.rating) reaction = self.insert_reaction(post.id, reaction_payload) + self.journal.log_rating(work_package, post_id=post.id, rating=rating.rating) logger.info(f"Ranking {rating.rating} stored for work_package {work_package.id}.") return reaction @@ -184,6 +192,7 @@ class PromptRepository: # store reaction to post reaction_payload = db_payload.RankingReactionPayload(ranking=ranking.ranking) reaction = self.insert_reaction(post.id, reaction_payload) + self.journal.log_ranking(work_package, post_id=post.id, ranking=ranking.ranking) logger.info(f"Ranking {ranking.ranking} stored for work_package {work_package.id}.") @@ -199,6 +208,7 @@ class PromptRepository: # store reaction to post reaction_payload = db_payload.RankingReactionPayload(ranking=ranking.ranking) reaction = self.insert_reaction(post.id, reaction_payload) + self.journal.log_ranking(work_package, post_id=post.id, ranking=ranking.ranking) logger.info(f"Ranking {ranking.ranking} stored for work_package {work_package.id}.") diff --git a/oasst-shared/oasst_shared/utils.py b/oasst-shared/oasst_shared/utils.py new file mode 100644 index 00000000..dd1cbf07 --- /dev/null +++ b/oasst-shared/oasst_shared/utils.py @@ -0,0 +1,7 @@ +# -*- coding: utf-8 -*- +from datetime import datetime, timezone + + +def utcnow() -> datetime: + """Return the current utc date and time with tzinfo set 
to UTC.""" + return datetime.now(timezone.utc) From 602ec355cdd58cc10206b4cacbb5ed0aab9f1250 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20K=C3=B6pf?= Date: Tue, 27 Dec 2022 18:05:02 +0100 Subject: [PATCH 2/2] moved workpackage-lookup for text-replies to prompt-repository --- backend/oasst_backend/api/v1/tasks.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/backend/oasst_backend/api/v1/tasks.py b/backend/oasst_backend/api/v1/tasks.py index 7ec5aa96..0778fd4c 100644 --- a/backend/oasst_backend/api/v1/tasks.py +++ b/backend/oasst_backend/api/v1/tasks.py @@ -7,7 +7,6 @@ from fastapi import APIRouter, Depends, HTTPException from fastapi.security.api_key import APIKey from loguru import logger from oasst_backend.api import deps -from oasst_backend.models.db_payload import TaskPayload from oasst_backend.prompt_repository import PromptRepository from oasst_shared.schemas import protocol as protocol_schema from sqlmodel import Session @@ -219,10 +218,6 @@ def post_interaction( f"Frontend reports text reply to {interaction.post_id=} with {interaction.text=} by {interaction.user=}." ) - work_package = pr.fetch_workpackage_by_postid(interaction.post_id) - work_payload: TaskPayload = work_package.payload.payload - logger.info(f"found task work package in db: {work_payload}") - # here we store the text reply in the database # ToDo: role user or agent? pr.store_text_reply(interaction, role="unknown")