mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
Merge branch 'LAION-AI:main' into Open-Assistant-67
This commit is contained in:
@@ -0,0 +1,75 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""add_journal_table
|
||||
|
||||
Revision ID: 3358eb6834e6
|
||||
Revises: 067c4002f2d9
|
||||
Create Date: 2022-12-27 14:44:59.483868
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "3358eb6834e6"
|
||||
down_revision = "067c4002f2d9"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Create the ``journal`` and ``journal_integration`` tables.

    ``journal`` is created first because ``journal_integration`` declares a
    foreign key (``last_journal_id``) that points at ``journal.id``.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    guid = sqlmodel.sql.sqltypes.GUID
    auto_string = sqlmodel.sql.sqltypes.AutoString

    # --- journal -----------------------------------------------------------
    journal_columns = (
        sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
        sa.Column(
            "created_date",
            sa.DateTime(timezone=True),
            server_default=sa.text("CURRENT_TIMESTAMP"),
            nullable=False,
        ),
        sa.Column("event_payload", postgresql.JSONB(astext_type=sa.Text()), nullable=False),
        sa.Column("person_id", guid(), nullable=True),
        sa.Column("post_id", guid(), nullable=True),
        sa.Column("api_client_id", guid(), nullable=False),
        sa.Column("event_type", auto_string(length=200), nullable=False),
    )
    journal_constraints = (
        sa.ForeignKeyConstraint(["api_client_id"], ["api_client.id"]),
        sa.ForeignKeyConstraint(["person_id"], ["person.id"]),
        sa.ForeignKeyConstraint(["post_id"], ["post.id"]),
        sa.PrimaryKeyConstraint("id"),
    )
    op.create_table("journal", *journal_columns, *journal_constraints)
    op.create_index(op.f("ix_journal_person_id"), "journal", ["person_id"], unique=False)

    # --- journal_integration -----------------------------------------------
    integration_columns = (
        sa.Column(
            "id",
            postgresql.UUID(as_uuid=True),
            server_default=sa.text("gen_random_uuid()"),
            nullable=False,
        ),
        sa.Column("last_run", sa.DateTime(), nullable=True),
        sa.Column("description", auto_string(length=512), nullable=False),
        sa.Column("last_journal_id", guid(), nullable=True),
        sa.Column("last_error", auto_string(), nullable=True),
        sa.Column("next_run", sa.DateTime(), nullable=True),
    )
    op.create_table(
        "journal_integration",
        *integration_columns,
        sa.ForeignKeyConstraint(["last_journal_id"], ["journal.id"]),
        # Composite primary key: (id, description).
        sa.PrimaryKeyConstraint("id", "description"),
    )
    # ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Undo :func:`upgrade` by dropping both journal tables.

    ``journal_integration`` is dropped before ``journal`` because it holds a
    foreign key to ``journal.id``.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_table("journal_integration")

    person_index_name = op.f("ix_journal_person_id")
    op.drop_index(person_index_name, table_name="journal")
    op.drop_table("journal")
    # ### end Alembic commands ###
|
||||
@@ -7,7 +7,6 @@ from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.security.api_key import APIKey
|
||||
from loguru import logger
|
||||
from oasst_backend.api import deps
|
||||
from oasst_backend.models.db_payload import TaskPayload
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from sqlmodel import Session
|
||||
@@ -219,10 +218,6 @@ def post_interaction(
|
||||
f"Frontend reports text reply to {interaction.post_id=} with {interaction.text=} by {interaction.user=}."
|
||||
)
|
||||
|
||||
work_package = pr.fetch_workpackage_by_postid(interaction.post_id)
|
||||
work_payload: TaskPayload = work_package.payload.payload
|
||||
logger.info(f"found task work package in db: {work_payload}")
|
||||
|
||||
# here we store the text reply in the database
|
||||
# ToDo: role user or agent?
|
||||
pr.store_text_reply(interaction, role="unknown")
|
||||
|
||||
@@ -0,0 +1,122 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import enum
|
||||
from typing import Literal, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from oasst_backend.models import ApiClient, Journal, Person, WorkPackage
|
||||
from oasst_backend.models.payload_column_type import PayloadContainer, payload_type
|
||||
from oasst_shared.utils import utcnow
|
||||
from pydantic import BaseModel
|
||||
from sqlmodel import Session
|
||||
|
||||
|
||||
class JournalEventType(str, enum.Enum):
    """Kinds of events that can be recorded in the journal.

    Members are also ``str`` instances, so they compare equal to (and can be
    stored as) their plain string value.
    """

    user_created = "user_created"
    text_reply_to_post = "text_reply_to_post"
    post_rating = "post_rating"
    post_ranking = "post_ranking"
|
||||
|
||||
|
||||
@payload_type
class JournalEvent(BaseModel):
    """Base payload shared by all journal events.

    The optional fields default to ``None`` explicitly so that subclasses can
    be constructed with only their own fields; ``JournalWriter.log`` fills in
    any of them that are still ``None`` before the event is stored.
    """

    # Discriminator naming the kind of event; subclasses pin it to a literal.
    type: str
    person_id: Optional[UUID] = None
    post_id: Optional[UUID] = None
    workpackage_id: Optional[UUID] = None
    task_type: Optional[str] = None
|
||||
|
||||
|
||||
@payload_type
class TextReplyEvent(JournalEvent):
    """Journal payload recorded when a text reply to a post is stored."""

    # Fixed discriminator for this event kind.
    type: Literal[JournalEventType.text_reply_to_post] = JournalEventType.text_reply_to_post
    # Length of the reply text (the caller passes ``len(reply.text)``).
    length: int
    # Role string supplied by the caller for the author of the reply.
    role: str
|
||||
|
||||
|
||||
@payload_type
class RatingEvent(JournalEvent):
    """Journal payload recorded when a rating for a post is stored."""

    # Fixed discriminator for this event kind.
    type: Literal[JournalEventType.post_rating] = JournalEventType.post_rating
    # The rating value that was submitted.
    rating: int
|
||||
|
||||
|
||||
@payload_type
class RankingEvent(JournalEvent):
    """Journal payload recorded when a ranking for a post is stored."""

    # Fixed discriminator for this event kind.
    type: Literal[JournalEventType.post_ranking] = JournalEventType.post_ranking
    # The submitted ranking, as a list of integers.
    ranking: list[int]
|
||||
|
||||
|
||||
class JournalWriter:
    """Write event records to the journal on behalf of one API client / person.

    Each ``log_*`` helper builds the matching event payload and delegates to
    :meth:`log`, which creates the ``Journal`` row and adds it to the session.
    """

    def __init__(self, db: Session, api_client: ApiClient, person: Optional[Person]):
        """Bind the writer to a DB session, an API client and (optionally) a person.

        Args:
            db: Session used to add (and optionally commit) journal entries.
            api_client: Client whose ``id`` is stored on every entry.
            person: Person the events are attributed to; may be ``None``.
        """
        self.db = db
        self.api_client = api_client
        self.person = person
        # Entries can be written without an associated person.
        self.person_id = self.person.id if self.person else None

    def log_text_reply(self, work_package: WorkPackage, post_id: UUID, role: str, length: int) -> Journal:
        """Record a ``text_reply_to_post`` event for ``post_id``."""
        return self.log(
            task_type=work_package.payload_type,
            event_type=JournalEventType.text_reply_to_post,
            payload=TextReplyEvent(role=role, length=length),
            workpackage_id=work_package.id,
            post_id=post_id,
        )

    def log_rating(self, work_package: WorkPackage, post_id: UUID, rating: int) -> Journal:
        """Record a ``post_rating`` event for ``post_id``."""
        return self.log(
            task_type=work_package.payload_type,
            event_type=JournalEventType.post_rating,
            payload=RatingEvent(rating=rating),
            workpackage_id=work_package.id,
            post_id=post_id,
        )

    def log_ranking(self, work_package: WorkPackage, post_id: UUID, ranking: list[int]) -> Journal:
        """Record a ``post_ranking`` event for ``post_id``."""
        return self.log(
            task_type=work_package.payload_type,
            event_type=JournalEventType.post_ranking,
            payload=RankingEvent(ranking=ranking),
            workpackage_id=work_package.id,
            post_id=post_id,
        )

    def log(
        self,
        *,
        payload: JournalEvent,
        task_type: str,
        event_type: Optional[str] = None,
        workpackage_id: Optional[UUID] = None,
        post_id: Optional[UUID] = None,
        commit: bool = True,
    ) -> Journal:
        """Create a journal entry for ``payload`` and add it to the session.

        Context fields on ``payload`` (``person_id``, ``post_id``,
        ``workpackage_id``, ``task_type``) that are still ``None`` are filled
        in from this writer and the given arguments; values already set on
        the payload are kept.

        Args:
            payload: Event payload to store; must not be ``None``.
            task_type: Task type written to the payload if it has none.
            event_type: Value for the entry's ``event_type``; defaults to the
                payload's class name.
            workpackage_id: Work package id written to the payload if it has none.
            post_id: Post id stored on the entry and, if unset, on the payload.
            commit: Commit the session after adding the entry when true.

        Returns:
            The ``Journal`` entry that was added to the session.

        Raises:
            ValueError: If ``payload`` is ``None``.
        """
        # The payload is dereferenced below, so a missing payload previously
        # surfaced as an opaque AttributeError; reject it up front instead.
        if payload is None:
            raise ValueError("JournalWriter.log() requires a payload, got None")

        if event_type is None:
            event_type = type(payload).__name__

        # Fill in context the caller did not set explicitly on the payload.
        if payload.person_id is None:
            payload.person_id = self.person_id
        if payload.post_id is None:
            payload.post_id = post_id
        if payload.workpackage_id is None:
            payload.workpackage_id = workpackage_id
        if payload.task_type is None:
            payload.task_type = task_type

        entry = Journal(
            person_id=self.person_id,
            api_client_id=self.api_client.id,
            created_date=utcnow(),
            event_type=event_type,
            event_payload=PayloadContainer(payload=payload),
            post_id=post_id,
        )

        self.db.add(entry)
        if commit:
            self.db.commit()

        return entry
|
||||
@@ -1,5 +1,6 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from .api_client import ApiClient
|
||||
from .journal import Journal, JournalIntegration
|
||||
from .person import Person
|
||||
from .person_stats import PersonStats
|
||||
from .post import Post
|
||||
@@ -15,4 +16,6 @@ __all__ = [
|
||||
"PostReaction",
|
||||
"WorkPackage",
|
||||
"TextLabels",
|
||||
"Journal",
|
||||
"JournalIntegration",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,56 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from uuid import UUID, uuid1, uuid4
|
||||
|
||||
import sqlalchemy as sa
|
||||
import sqlalchemy.dialects.postgresql as pg
|
||||
from sqlmodel import Field, SQLModel
|
||||
|
||||
from .payload_column_type import PayloadContainer, payload_column_type
|
||||
|
||||
|
||||
def generate_time_uuid(node=None, clock_seq=None):
    """Create a lexicographically sortable, time-ordered custom (non-standard) UUID.

    A version 1 UUID is generated and its 60-bit timestamp is re-packed with
    the most significant bits first, so that comparing the resulting UUIDs
    compares their timestamps. Version and variant bits are cleared, which
    means the result is not a valid RFC 4122 UUID.

    Args:
        node: Forwarded to ``uuid.uuid1``.
        clock_seq: Forwarded to ``uuid.uuid1``.

    Returns:
        A ``UUID`` built from the reordered timestamp, the clock sequence
        and the node of the underlying version 1 UUID.
    """
    (time_low, time_mid, time_hi_version, clock_seq_hi_variant, clock_seq_low, node) = uuid1(node, clock_seq).fields
    # reconstruct 60 bit timestamp, see version 1 uuid: https://www.rfc-editor.org/rfc/rfc4122
    timestamp = (time_hi_version & 0xFFF) << 48 | (time_mid << 32) | time_low
    version = time_hi_version >> 12
    assert version == 1, "uuid1() must return a version 1 UUID"
    a = timestamp >> 28  # bits 28-59
    b = (timestamp >> 12) & 0xFFFF  # bits 12-27
    c = timestamp & 0xFFF  # bits 0-11 (clear version bits)
    # The variant occupies only the two most significant bits of this octet;
    # mask with 0x3F so all six clock-sequence bits are preserved.
    clock_seq_hi_variant &= 0x3F  # (clear variant bits)
    return UUID(fields=(a, b, c, clock_seq_hi_variant, clock_seq_low, node), version=None)
|
||||
|
||||
|
||||
class Journal(SQLModel, table=True):
    """ORM model for one row of the ``journal`` table (a recorded event)."""

    __tablename__ = "journal"

    # Primary key; generated client-side by ``generate_time_uuid`` when unset.
    id: Optional[UUID] = Field(
        sa_column=sa.Column(pg.UUID(as_uuid=True), primary_key=True, default=generate_time_uuid),
    )
    # Timezone-aware timestamp; the database fills it in when none is given.
    created_date: Optional[datetime] = Field(
        sa_column=sa.Column(sa.DateTime(timezone=True), nullable=False, server_default=sa.func.current_timestamp())
    )
    # Nullable: events may be recorded without an associated person.
    person_id: Optional[UUID] = Field(nullable=True, foreign_key="person.id", index=True)
    post_id: Optional[UUID] = Field(foreign_key="post.id", nullable=True)
    api_client_id: UUID = Field(foreign_key="api_client.id")

    event_type: str = Field(nullable=False, max_length=200)
    event_payload: PayloadContainer = Field(sa_column=sa.Column(payload_column_type(PayloadContainer), nullable=False))
|
||||
|
||||
|
||||
class JournalIntegration(SQLModel, table=True):
    """ORM model for one row of the ``journal_integration`` table.

    Both ``id`` and ``description`` are flagged as primary key, so the table
    has a composite primary key ``(id, description)``.
    """

    __tablename__ = "journal_integration"

    # Random UUID; generated client-side (uuid4) or by the database.
    id: Optional[UUID] = Field(
        sa_column=sa.Column(
            pg.UUID(as_uuid=True), primary_key=True, default=uuid4, server_default=sa.text("gen_random_uuid()")
        ),
    )
    description: str = Field(max_length=512, primary_key=True)
    # The remaining columns are nullable, hence Optional.
    last_journal_id: Optional[UUID] = Field(foreign_key="journal.id", nullable=True)
    last_run: Optional[datetime] = Field(sa_column=sa.Column(sa.DateTime(), nullable=True))
    last_error: Optional[str] = Field(nullable=True)
    next_run: Optional[datetime] = Field(nullable=True)
|
||||
@@ -5,6 +5,7 @@ from uuid import UUID, uuid4
|
||||
|
||||
import oasst_backend.models.db_payload as db_payload
|
||||
from loguru import logger
|
||||
from oasst_backend.journal_writer import JournalWriter
|
||||
from oasst_backend.models import ApiClient, Person, Post, PostReaction, TextLabels, WorkPackage
|
||||
from oasst_backend.models.payload_column_type import PayloadContainer
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
@@ -17,6 +18,7 @@ class PromptRepository:
|
||||
self.api_client = api_client
|
||||
self.person = self.lookup_person(user)
|
||||
self.person_id = self.person.id if self.person else None
|
||||
self.journal = JournalWriter(db, api_client, self.person)
|
||||
|
||||
def lookup_person(self, user: protocol_schema.User) -> Person:
|
||||
if not user:
|
||||
@@ -116,6 +118,10 @@ class PromptRepository:
|
||||
self.validate_post_id(reply.post_id)
|
||||
self.validate_post_id(reply.user_post_id)
|
||||
|
||||
work_package = self.fetch_workpackage_by_postid(reply.post_id)
|
||||
work_payload: db_payload.TaskPayload = work_package.payload.payload
|
||||
logger.info(f"found task work package in db: {work_payload}")
|
||||
|
||||
# find post with post-id
|
||||
parent_post: Post = (
|
||||
self.db.query(Post)
|
||||
@@ -141,6 +147,7 @@ class PromptRepository:
|
||||
role=role,
|
||||
payload=db_payload.PostPayload(text=reply.text),
|
||||
)
|
||||
self.journal.log_text_reply(work_package=work_package, post_id=user_post_id, role=role, length=len(reply.text))
|
||||
return user_post
|
||||
|
||||
def store_rating(self, rating: protocol_schema.PostRating) -> PostReaction:
|
||||
@@ -159,6 +166,7 @@ class PromptRepository:
|
||||
# store reaction to post
|
||||
reaction_payload = db_payload.RatingReactionPayload(rating=rating.rating)
|
||||
reaction = self.insert_reaction(post.id, reaction_payload)
|
||||
self.journal.log_rating(work_package, post_id=post.id, rating=rating.rating)
|
||||
logger.info(f"Ranking {rating.rating} stored for work_package {work_package.id}.")
|
||||
return reaction
|
||||
|
||||
@@ -184,6 +192,7 @@ class PromptRepository:
|
||||
# store reaction to post
|
||||
reaction_payload = db_payload.RankingReactionPayload(ranking=ranking.ranking)
|
||||
reaction = self.insert_reaction(post.id, reaction_payload)
|
||||
self.journal.log_ranking(work_package, post_id=post.id, ranking=ranking.ranking)
|
||||
|
||||
logger.info(f"Ranking {ranking.ranking} stored for work_package {work_package.id}.")
|
||||
|
||||
@@ -199,6 +208,7 @@ class PromptRepository:
|
||||
# store reaction to post
|
||||
reaction_payload = db_payload.RankingReactionPayload(ranking=ranking.ranking)
|
||||
reaction = self.insert_reaction(post.id, reaction_payload)
|
||||
self.journal.log_ranking(work_package, post_id=post.id, ranking=ranking.ranking)
|
||||
|
||||
logger.info(f"Ranking {ranking.ranking} stored for work_package {work_package.id}.")
|
||||
|
||||
|
||||
@@ -0,0 +1,7 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from datetime import datetime, timezone
|
||||
|
||||
|
||||
def utcnow() -> datetime:
    """Return the current date and time as a timezone-aware UTC ``datetime``."""
    utc = timezone.utc
    return datetime.now(tz=utc)
|
||||
Reference in New Issue
Block a user