Merge branch 'LAION-AI:main' into Open-Assistant-67

This commit is contained in:
Seyf Eddine NECIB
2022-12-27 22:51:46 +01:00
committed by GitHub
7 changed files with 273 additions and 5 deletions
@@ -0,0 +1,75 @@
# -*- coding: utf-8 -*-
"""add_journal_table
Revision ID: 3358eb6834e6
Revises: 067c4002f2d9
Create Date: 2022-12-27 14:44:59.483868
"""
import sqlalchemy as sa
import sqlmodel
from alembic import op
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "3358eb6834e6"
down_revision = "067c4002f2d9"
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Create the ``journal`` and ``journal_integration`` tables.

    ``journal`` (plus its ``person_id`` index) is created first because
    ``journal_integration.last_journal_id`` has a foreign key to ``journal.id``.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    journal_columns = [
        sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
        sa.Column(
            "created_date",
            sa.DateTime(timezone=True),
            server_default=sa.text("CURRENT_TIMESTAMP"),
            nullable=False,
        ),
        sa.Column("event_payload", postgresql.JSONB(astext_type=sa.Text()), nullable=False),
        sa.Column("person_id", sqlmodel.sql.sqltypes.GUID(), nullable=True),
        sa.Column("post_id", sqlmodel.sql.sqltypes.GUID(), nullable=True),
        sa.Column("api_client_id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
        sa.Column("event_type", sqlmodel.sql.sqltypes.AutoString(length=200), nullable=False),
    ]
    journal_constraints = [
        sa.ForeignKeyConstraint(["api_client_id"], ["api_client.id"]),
        sa.ForeignKeyConstraint(["person_id"], ["person.id"]),
        sa.ForeignKeyConstraint(["post_id"], ["post.id"]),
        sa.PrimaryKeyConstraint("id"),
    ]
    op.create_table("journal", *journal_columns, *journal_constraints)
    op.create_index(op.f("ix_journal_person_id"), "journal", ["person_id"], unique=False)

    integration_columns = [
        sa.Column(
            "id",
            postgresql.UUID(as_uuid=True),
            server_default=sa.text("gen_random_uuid()"),
            nullable=False,
        ),
        sa.Column("last_run", sa.DateTime(), nullable=True),
        sa.Column("description", sqlmodel.sql.sqltypes.AutoString(length=512), nullable=False),
        sa.Column("last_journal_id", sqlmodel.sql.sqltypes.GUID(), nullable=True),
        sa.Column("last_error", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
        sa.Column("next_run", sa.DateTime(), nullable=True),
    ]
    integration_constraints = [
        sa.ForeignKeyConstraint(["last_journal_id"], ["journal.id"]),
        # Composite primary key over both columns.
        sa.PrimaryKeyConstraint("id", "description"),
    ]
    op.create_table("journal_integration", *integration_columns, *integration_constraints)
    # ### end Alembic commands ###
def downgrade() -> None:
    """Drop the ``journal_integration`` table, the ``journal`` index and the ``journal`` table."""
    # ### commands auto generated by Alembic - please adjust! ###
    person_index_name = op.f("ix_journal_person_id")
    # journal_integration has a foreign key to journal.id, so it is dropped first.
    op.drop_table("journal_integration")
    op.drop_index(person_index_name, table_name="journal")
    op.drop_table("journal")
    # ### end Alembic commands ###
-5
View File
@@ -7,7 +7,6 @@ from fastapi import APIRouter, Depends, HTTPException
from fastapi.security.api_key import APIKey
from loguru import logger
from oasst_backend.api import deps
from oasst_backend.models.db_payload import TaskPayload
from oasst_backend.prompt_repository import PromptRepository
from oasst_shared.schemas import protocol as protocol_schema
from sqlmodel import Session
@@ -219,10 +218,6 @@ def post_interaction(
f"Frontend reports text reply to {interaction.post_id=} with {interaction.text=} by {interaction.user=}."
)
work_package = pr.fetch_workpackage_by_postid(interaction.post_id)
work_payload: TaskPayload = work_package.payload.payload
logger.info(f"found task work package in db: {work_payload}")
# here we store the text reply in the database
# ToDo: role user or agent?
pr.store_text_reply(interaction, role="unknown")
+122
View File
@@ -0,0 +1,122 @@
# -*- coding: utf-8 -*-
import enum
from typing import Literal, Optional
from uuid import UUID
from oasst_backend.models import ApiClient, Journal, Person, WorkPackage
from oasst_backend.models.payload_column_type import PayloadContainer, payload_type
from oasst_shared.utils import utcnow
from pydantic import BaseModel
from sqlmodel import Session
class JournalEventType(str, enum.Enum):
    """Names of the event kinds that can be written to the journal.

    Members are also ``str`` instances, so each one compares equal to its
    plain string value.
    """

    # A user record was created.
    user_created = "user_created"
    # A text reply was stored for a post.
    text_reply_to_post = "text_reply_to_post"
    # A rating was stored for a post.
    post_rating = "post_rating"
    # A ranking was stored for a post.
    post_ranking = "post_ranking"
@payload_type
class JournalEvent(BaseModel):
    # Base payload shared by every journal event.
    #
    # ``type`` identifies the event kind. The remaining context fields are
    # optional; JournalWriter.log fills any that are still None from its own
    # arguments before the entry is stored.
    type: str
    person_id: Optional[UUID] = None
    post_id: Optional[UUID] = None
    workpackage_id: Optional[UUID] = None
    task_type: Optional[str] = None
@payload_type
class TextReplyEvent(JournalEvent):
    # Payload for a ``text_reply_to_post`` journal event.
    # ``type`` is pinned to JournalEventType.text_reply_to_post and defaults to it.
    type: Literal[JournalEventType.text_reply_to_post] = JournalEventType.text_reply_to_post
    # Length value supplied by the caller (JournalWriter.log_text_reply passes it through).
    length: int
    # Role string supplied by the caller (JournalWriter.log_text_reply passes it through).
    role: str
@payload_type
class RatingEvent(JournalEvent):
    # Payload for a ``post_rating`` journal event.
    # ``type`` is pinned to JournalEventType.post_rating and defaults to it.
    type: Literal[JournalEventType.post_rating] = JournalEventType.post_rating
    # Rating value supplied by the caller (JournalWriter.log_rating passes it through).
    rating: int
@payload_type
class RankingEvent(JournalEvent):
    # Payload for a ``post_ranking`` journal event.
    # ``type`` is pinned to JournalEventType.post_ranking and defaults to it.
    type: Literal[JournalEventType.post_ranking] = JournalEventType.post_ranking
    # Ranking list supplied by the caller (JournalWriter.log_ranking passes it through).
    ranking: list[int]
class JournalWriter:
    """Write ``Journal`` entries on behalf of one API client and, optionally, one person.

    The ``log_*`` helpers build the matching event payload and delegate to
    :meth:`log`, which adds the entry to the database session.
    """

    def __init__(self, db: Session, api_client: ApiClient, person: Optional[Person]):
        """Store the session, API client and person used for all later entries.

        Args:
            db: Session that entries are added to (and committed on).
            api_client: Client whose ``id`` is recorded on every entry.
            person: Acting person; may be falsy (e.g. ``None``), in which
                case ``person_id`` is ``None``.
        """
        self.db = db
        self.api_client = api_client
        self.person = person
        self.person_id = self.person.id if self.person else None

    def log_text_reply(self, work_package: WorkPackage, post_id: UUID, role: str, length: int) -> Journal:
        """Record a ``text_reply_to_post`` event for ``post_id`` and return the entry."""
        return self.log(
            task_type=work_package.payload_type,
            event_type=JournalEventType.text_reply_to_post,
            payload=TextReplyEvent(role=role, length=length),
            workpackage_id=work_package.id,
            post_id=post_id,
        )

    def log_rating(self, work_package: WorkPackage, post_id: UUID, rating: int) -> Journal:
        """Record a ``post_rating`` event for ``post_id`` and return the entry."""
        return self.log(
            task_type=work_package.payload_type,
            event_type=JournalEventType.post_rating,
            payload=RatingEvent(rating=rating),
            workpackage_id=work_package.id,
            post_id=post_id,
        )

    def log_ranking(self, work_package: WorkPackage, post_id: UUID, ranking: list[int]) -> Journal:
        """Record a ``post_ranking`` event for ``post_id`` and return the entry."""
        return self.log(
            task_type=work_package.payload_type,
            event_type=JournalEventType.post_ranking,
            payload=RankingEvent(ranking=ranking),
            workpackage_id=work_package.id,
            post_id=post_id,
        )

    def log(
        self,
        *,
        payload: Optional[JournalEvent],
        task_type: str,
        event_type: Optional[str] = None,
        workpackage_id: Optional[UUID] = None,
        post_id: Optional[UUID] = None,
        commit: bool = True,
    ) -> Journal:
        """Add one journal entry to the session and return it.

        Args:
            payload: Event payload. Context fields on it that are still
                ``None`` (``person_id``, ``post_id``, ``workpackage_id``,
                ``task_type``) are filled in from this writer / the arguments.
            task_type: Task type copied into the payload when it has none.
            event_type: Stored event type. When omitted it is the payload's
                class name, or ``"null"`` for a ``None`` payload.
            workpackage_id: Work package id copied into the payload when it has none.
            post_id: Post id stored on the entry and copied into the payload
                when it has none.
            commit: Commit the session after adding the entry (default ``True``).

        Returns:
            The ``Journal`` entry that was added to the session.
        """
        if event_type is None:
            event_type = "null" if payload is None else type(payload).__name__

        # Only touch the payload when there is one: the "null" event type above
        # shows a None payload is an anticipated input, and dereferencing it
        # here used to raise AttributeError before the entry could be built.
        # NOTE(review): whether PayloadContainer accepts a None payload is not
        # visible from this module - confirm.
        if payload is not None:
            if payload.person_id is None:
                payload.person_id = self.person_id
            if payload.post_id is None:
                payload.post_id = post_id
            if payload.workpackage_id is None:
                payload.workpackage_id = workpackage_id
            if payload.task_type is None:
                payload.task_type = task_type

        entry = Journal(
            person_id=self.person_id,
            api_client_id=self.api_client.id,
            created_date=utcnow(),
            event_type=event_type,
            event_payload=PayloadContainer(payload=payload),
            post_id=post_id,
        )
        self.db.add(entry)
        if commit:
            self.db.commit()
        return entry
+3
View File
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
from .api_client import ApiClient
from .journal import Journal, JournalIntegration
from .person import Person
from .person_stats import PersonStats
from .post import Post
@@ -15,4 +16,6 @@ __all__ = [
"PostReaction",
"WorkPackage",
"TextLabels",
"Journal",
"JournalIntegration",
]
+56
View File
@@ -0,0 +1,56 @@
# -*- coding: utf-8 -*-
from datetime import datetime
from typing import Optional
from uuid import UUID, uuid1, uuid4
import sqlalchemy as sa
import sqlalchemy.dialects.postgresql as pg
from sqlmodel import Field, SQLModel
from .payload_column_type import PayloadContainer, payload_column_type
def generate_time_uuid(node=None, clock_seq=None):
    """Create a lexicographically sortable, time-ordered, custom (non-standard) UUID.

    A version 1 UUID is generated with ``uuid1(node, clock_seq)`` and its
    60-bit timestamp is re-packed most-significant bits first, so that UUIDs
    created later compare greater. The version and variant bits are cleared,
    which is why the result is not a standard RFC 4122 UUID.

    Args:
        node: Forwarded to ``uuid1``.
        clock_seq: Forwarded to ``uuid1``.

    Returns:
        A ``UUID`` whose leading 60 timestamp bits are in big-endian order,
        followed by the clock sequence and node of the underlying UUID.
    """
    time_low, time_mid, time_hi_version, clock_seq_hi_variant, clock_seq_low, node = uuid1(node, clock_seq).fields

    version = time_hi_version >> 12
    assert version == 1

    # reconstruct 60 bit timestamp, see version 1 uuid: https://www.rfc-editor.org/rfc/rfc4122
    timestamp = ((time_hi_version & 0xFFF) << 48) | (time_mid << 32) | time_low

    a = timestamp >> 28  # bits 28-59
    b = (timestamp >> 12) & 0xFFFF  # bits 12-27
    c = timestamp & 0xFFF  # bits 0-11 (the version nibble above them stays zero)

    # The variant occupies only the top two bits of this byte; the lower six
    # bits belong to the clock sequence. Masking with 0x3F clears the variant
    # without throwing away clock-sequence bits (0xF dropped two of them).
    clock_seq_hi = clock_seq_hi_variant & 0x3F

    return UUID(fields=(a, b, c, clock_seq_hi, clock_seq_low, node), version=None)
class Journal(SQLModel, table=True):
    # Table model for journal entries: one row per logged event, linked to the
    # API client and optionally to a person and a post.
    __tablename__ = "journal"

    # Primary key; generate_time_uuid is the column default.
    id: Optional[UUID] = Field(
        sa_column=sa.Column(pg.UUID(as_uuid=True), default=generate_time_uuid, primary_key=True),
    )
    # Timezone-aware creation timestamp; the database supplies CURRENT_TIMESTAMP when unset.
    created_date: Optional[datetime] = Field(
        sa_column=sa.Column(sa.DateTime(timezone=True), server_default=sa.func.current_timestamp(), nullable=False)
    )
    # NOTE(review): annotated as UUID although the column is nullable - confirm
    # whether Optional[UUID] was intended.
    person_id: UUID = Field(foreign_key="person.id", index=True, nullable=True)
    post_id: Optional[UUID] = Field(nullable=True, foreign_key="post.id")
    api_client_id: UUID = Field(foreign_key="api_client.id")
    event_type: str = Field(max_length=200, nullable=False)
    # Event payload stored through the project's payload column type.
    event_payload: PayloadContainer = Field(sa_column=sa.Column(payload_column_type(PayloadContainer), nullable=False))
class JournalIntegration(SQLModel, table=True):
    # Table model for journal integrations; each row can reference a journal
    # entry (last_journal_id) and holds run timestamps and the last error text.
    __tablename__ = "journal_integration"

    # Part of the primary key; uuid4 is the Python-side default and
    # gen_random_uuid() the server-side default.
    id: Optional[UUID] = Field(
        sa_column=sa.Column(
            pg.UUID(as_uuid=True),
            server_default=sa.text("gen_random_uuid()"),
            default=uuid4,
            primary_key=True,
        ),
    )
    # Also flagged as primary key, so the key is composite.
    description: str = Field(primary_key=True, max_length=512)
    last_journal_id: UUID = Field(nullable=True, foreign_key="journal.id")
    # Declared without timezone support (plain sa.DateTime()).
    last_run: datetime = Field(sa_column=sa.Column(sa.DateTime(), nullable=True))
    last_error: str = Field(nullable=True)
    next_run: datetime = Field(nullable=True)
@@ -5,6 +5,7 @@ from uuid import UUID, uuid4
import oasst_backend.models.db_payload as db_payload
from loguru import logger
from oasst_backend.journal_writer import JournalWriter
from oasst_backend.models import ApiClient, Person, Post, PostReaction, TextLabels, WorkPackage
from oasst_backend.models.payload_column_type import PayloadContainer
from oasst_shared.schemas import protocol as protocol_schema
@@ -17,6 +18,7 @@ class PromptRepository:
self.api_client = api_client
self.person = self.lookup_person(user)
self.person_id = self.person.id if self.person else None
self.journal = JournalWriter(db, api_client, self.person)
def lookup_person(self, user: protocol_schema.User) -> Person:
if not user:
@@ -116,6 +118,10 @@ class PromptRepository:
self.validate_post_id(reply.post_id)
self.validate_post_id(reply.user_post_id)
work_package = self.fetch_workpackage_by_postid(reply.post_id)
work_payload: db_payload.TaskPayload = work_package.payload.payload
logger.info(f"found task work package in db: {work_payload}")
# find post with post-id
parent_post: Post = (
self.db.query(Post)
@@ -141,6 +147,7 @@ class PromptRepository:
role=role,
payload=db_payload.PostPayload(text=reply.text),
)
self.journal.log_text_reply(work_package=work_package, post_id=user_post_id, role=role, length=len(reply.text))
return user_post
def store_rating(self, rating: protocol_schema.PostRating) -> PostReaction:
@@ -159,6 +166,7 @@ class PromptRepository:
# store reaction to post
reaction_payload = db_payload.RatingReactionPayload(rating=rating.rating)
reaction = self.insert_reaction(post.id, reaction_payload)
self.journal.log_rating(work_package, post_id=post.id, rating=rating.rating)
logger.info(f"Ranking {rating.rating} stored for work_package {work_package.id}.")
return reaction
@@ -184,6 +192,7 @@ class PromptRepository:
# store reaction to post
reaction_payload = db_payload.RankingReactionPayload(ranking=ranking.ranking)
reaction = self.insert_reaction(post.id, reaction_payload)
self.journal.log_ranking(work_package, post_id=post.id, ranking=ranking.ranking)
logger.info(f"Ranking {ranking.ranking} stored for work_package {work_package.id}.")
@@ -199,6 +208,7 @@ class PromptRepository:
# store reaction to post
reaction_payload = db_payload.RankingReactionPayload(ranking=ranking.ranking)
reaction = self.insert_reaction(post.id, reaction_payload)
self.journal.log_ranking(work_package, post_id=post.id, ranking=ranking.ranking)
logger.info(f"Ranking {ranking.ranking} stored for work_package {work_package.id}.")
+7
View File
@@ -0,0 +1,7 @@
# -*- coding: utf-8 -*-
from datetime import datetime, timezone
def utcnow() -> datetime:
    """Return the current date and time as a timezone-aware ``datetime`` in UTC."""
    return datetime.now(tz=timezone.utc)