diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 00000000..6313b56c --- /dev/null +++ b/.gitattributes @@ -0,0 +1 @@ +* text=auto eol=lf diff --git a/README.md b/README.md index 2fe68738..a2b92789 100644 --- a/README.md +++ b/README.md @@ -60,31 +60,3 @@ In case you haven't done this, have already committed, and CI is failing, you ca ### Deployment Upon making a release on GitHub, all docker images are automatically built and pushed to ghcr.io. The docker images are tagged with the release version, and the `latest` tag. Further, the ansible playbook in `ansible/dev.yaml` is run to automatically deploy the built release to the dev machine. - -# (Older version of the readme below) - -## How do I start helping out? - -Check out these pages to learn more about the project. - -Ping Birger on discord if you want help to get started. - -http://**discordapp.com/users/birger#6875** - -## More information in the notion - -https://roan-iguanadon-a58.notion.site/Open-Chat-Gpt-83dd217eeeb84907a155b8a9d716fa46 - -## Code structure - -### Bot - -We have a folder named bot where code related to the bot lives. - -### Backend - -We have a backend folder for backend development of the api that the discord bot sends it information to. - -### Website - -We have a folder for the website, live at https://projects.laion.ai/Open-Chat-GPT/ .The website is built using Next.js diff --git a/backend/alembic/script.py.mako b/backend/alembic/script.py.mako index 55df2863..3124b62c 100644 --- a/backend/alembic/script.py.mako +++ b/backend/alembic/script.py.mako @@ -7,6 +7,7 @@ Create Date: ${create_date} """ from alembic import op import sqlalchemy as sa +import sqlmodel ${imports if imports else ""} # revision identifiers, used by Alembic. diff --git a/backend/alembic/versions/2022_12_25_1705-067c4002f2d9_add_text_labels.py b/backend/alembic/versions/2022_12_25_1705-067c4002f2d9_add_text_labels.py new file mode 100644 index 00000000..94e1c514 --- /dev/null +++ b/backend/alembic/versions/2022_12_25_1705-067c4002f2d9_add_text_labels.py @@ -0,0 +1,47 @@ +# -*- coding: utf-8 -*- +"""Adds text labels table. + +Revision ID: 067c4002f2d9 +Revises: 0daec5f8135f +Create Date: 2022-12-25 17:05:21.208843 + +""" +import sqlalchemy as sa +import sqlmodel +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = "067c4002f2d9" +down_revision = "0daec5f8135f" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "text_labels", + sa.Column("id", postgresql.UUID(as_uuid=True), server_default=sa.text("gen_random_uuid()"), nullable=False), + sa.Column("created_date", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False), + sa.Column("post_id", postgresql.UUID(as_uuid=True), nullable=True), + sa.Column("labels", postgresql.JSONB(astext_type=sa.Text()), nullable=True), + sa.Column("api_client_id", sqlmodel.sql.sqltypes.GUID(), nullable=False), + sa.Column("text", sqlmodel.sql.sqltypes.AutoString(length=65536), nullable=False), + sa.ForeignKeyConstraint( + ["api_client_id"], + ["api_client.id"], + ), + sa.ForeignKeyConstraint( + ["post_id"], + ["post.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("text_labels") + # ### end Alembic commands ### diff --git a/backend/alembic/versions/2022_12_27_1444-3358eb6834e6_add_journal_table.py b/backend/alembic/versions/2022_12_27_1444-3358eb6834e6_add_journal_table.py new file mode 100644 index 00000000..0dc937a0 --- /dev/null +++ b/backend/alembic/versions/2022_12_27_1444-3358eb6834e6_add_journal_table.py @@ -0,0 +1,75 @@ +# -*- coding: utf-8 -*- +"""add_journal_table + +Revision ID: 3358eb6834e6 +Revises: 067c4002f2d9 +Create Date: 2022-12-27 14:44:59.483868 + +""" +import sqlalchemy as sa +import sqlmodel +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = "3358eb6834e6" +down_revision = "067c4002f2d9" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "journal", + sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False), + sa.Column( + "created_date", sa.DateTime(timezone=True), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False + ), + sa.Column( + "event_payload", + postgresql.JSONB(astext_type=sa.Text()), + nullable=False, + ), + sa.Column("person_id", sqlmodel.sql.sqltypes.GUID(), nullable=True), + sa.Column("post_id", sqlmodel.sql.sqltypes.GUID(), nullable=True), + sa.Column("api_client_id", sqlmodel.sql.sqltypes.GUID(), nullable=False), + sa.Column("event_type", sqlmodel.sql.sqltypes.AutoString(length=200), nullable=False), + sa.ForeignKeyConstraint( + ["api_client_id"], + ["api_client.id"], + ), + sa.ForeignKeyConstraint( + ["person_id"], + ["person.id"], + ), + sa.ForeignKeyConstraint( + ["post_id"], + ["post.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index(op.f("ix_journal_person_id"), "journal", ["person_id"], unique=False) + op.create_table( + "journal_integration", + sa.Column("id", postgresql.UUID(as_uuid=True), server_default=sa.text("gen_random_uuid()"), nullable=False), + sa.Column("last_run", sa.DateTime(), nullable=True), + sa.Column("description", sqlmodel.sql.sqltypes.AutoString(length=512), nullable=False), + sa.Column("last_journal_id", sqlmodel.sql.sqltypes.GUID(), nullable=True), + sa.Column("last_error", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column("next_run", sa.DateTime(), nullable=True), + sa.ForeignKeyConstraint( + ["last_journal_id"], + ["journal.id"], + ), + sa.PrimaryKeyConstraint("id", "description"), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("journal_integration") + op.drop_index(op.f("ix_journal_person_id"), table_name="journal") + op.drop_table("journal") + # ### end Alembic commands ### diff --git a/backend/oasst_backend/api/v1/api.py b/backend/oasst_backend/api/v1/api.py index cd1119d6..b54f3dd0 100644 --- a/backend/oasst_backend/api/v1/api.py +++ b/backend/oasst_backend/api/v1/api.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- from fastapi import APIRouter -from oasst_backend.api.v1 import tasks +from oasst_backend.api.v1 import tasks, text_labels api_router = APIRouter() api_router.include_router(tasks.router, prefix="/tasks", tags=["tasks"]) +api_router.include_router(text_labels.router, prefix="/text_labels", tags=["text_labels"]) diff --git a/backend/oasst_backend/api/v1/tasks.py b/backend/oasst_backend/api/v1/tasks.py index 7ec5aa96..0778fd4c 100644 --- a/backend/oasst_backend/api/v1/tasks.py +++ b/backend/oasst_backend/api/v1/tasks.py @@ -7,7 +7,6 @@ from fastapi import APIRouter, Depends, HTTPException from fastapi.security.api_key import APIKey from loguru import logger from oasst_backend.api import deps -from oasst_backend.models.db_payload import TaskPayload from oasst_backend.prompt_repository import PromptRepository from oasst_shared.schemas import protocol as protocol_schema from sqlmodel import Session @@ -219,10 +218,6 @@ def post_interaction( f"Frontend reports text reply to {interaction.post_id=} with {interaction.text=} by {interaction.user=}." ) - work_package = pr.fetch_workpackage_by_postid(interaction.post_id) - work_payload: TaskPayload = work_package.payload.payload - logger.info(f"found task work package in db: {work_payload}") - # here we store the text reply in the database # ToDo: role user or agent? pr.store_text_reply(interaction, role="unknown") diff --git a/backend/oasst_backend/api/v1/text_labels.py b/backend/oasst_backend/api/v1/text_labels.py new file mode 100644 index 00000000..09933304 --- /dev/null +++ b/backend/oasst_backend/api/v1/text_labels.py @@ -0,0 +1,41 @@ +# -*- coding: utf-8 -*- +import pydantic +from fastapi import APIRouter, Depends, HTTPException +from fastapi.security.api_key import APIKey +from loguru import logger +from oasst_backend.api import deps +from oasst_backend.prompt_repository import PromptRepository +from oasst_shared.schemas import protocol as protocol_schema +from sqlmodel import Session +from starlette.status import HTTP_400_BAD_REQUEST + +router = APIRouter() + + +class LabelTextRequest(pydantic.BaseModel): + text_labels: protocol_schema.TextLabels + user: protocol_schema.User + + +@router.post("/") +def label_text( + *, + db: Session = Depends(deps.get_db), + api_key: APIKey = Depends(deps.get_api_key), + request: LabelTextRequest, +) -> None: + """ + Label a piece of text. + """ + api_client = deps.api_auth(api_key, db) + + try: + logger.info(f"Labeling text {request=}.") + pr = PromptRepository(db, api_client, user=request.user) + pr.store_text_labels(request.text_labels) + + except Exception: + logger.exception("Failed to store label.") + raise HTTPException( + status_code=HTTP_400_BAD_REQUEST, + ) diff --git a/backend/oasst_backend/journal_writer.py b/backend/oasst_backend/journal_writer.py new file mode 100644 index 00000000..897e2dda --- /dev/null +++ b/backend/oasst_backend/journal_writer.py @@ -0,0 +1,122 @@ +# -*- coding: utf-8 -*- +import enum +from typing import Literal, Optional +from uuid import UUID + +from oasst_backend.models import ApiClient, Journal, Person, WorkPackage +from oasst_backend.models.payload_column_type import PayloadContainer, payload_type +from oasst_shared.utils import utcnow +from pydantic import BaseModel +from sqlmodel import Session + + +class JournalEventType(str, enum.Enum): + """A label for a piece of text.""" + + user_created = "user_created" + text_reply_to_post = "text_reply_to_post" + post_rating = "post_rating" + post_ranking = "post_ranking" + + +@payload_type +class JournalEvent(BaseModel): + type: str + person_id: Optional[UUID] + post_id: Optional[UUID] + workpackage_id: Optional[UUID] + task_type: Optional[str] + + +@payload_type +class TextReplyEvent(JournalEvent): + type: Literal[JournalEventType.text_reply_to_post] = JournalEventType.text_reply_to_post + length: int + role: str + + +@payload_type +class RatingEvent(JournalEvent): + type: Literal[JournalEventType.post_rating] = JournalEventType.post_rating + rating: int + + +@payload_type +class RankingEvent(JournalEvent): + type: Literal[JournalEventType.post_ranking] = JournalEventType.post_ranking + ranking: list[int] + + +class JournalWriter: + def __init__(self, db: Session, api_client: ApiClient, person: Person): + self.db = db + self.api_client = api_client + self.person = person + self.person_id = self.person.id if self.person else None + + def log_text_reply(self, work_package: WorkPackage, post_id: UUID, role: str, length: int) -> Journal: + return self.log( + task_type=work_package.payload_type, + event_type=JournalEventType.text_reply_to_post, + payload=TextReplyEvent(role=role, length=length), + workpackage_id=work_package.id, + post_id=post_id, + ) + + def log_rating(self, work_package: WorkPackage, post_id: UUID, rating: int) -> Journal: + return self.log( + task_type=work_package.payload_type, + event_type=JournalEventType.post_rating, + payload=RatingEvent(rating=rating), + workpackage_id=work_package.id, + post_id=post_id, + ) + + def log_ranking(self, work_package: WorkPackage, post_id: UUID, ranking: list[int]) -> Journal: + return self.log( + task_type=work_package.payload_type, + event_type=JournalEventType.post_ranking, + payload=RankingEvent(ranking=ranking), + workpackage_id=work_package.id, + post_id=post_id, + ) + + def log( + self, + *, + payload: JournalEvent, + task_type: str, + event_type: str = None, + workpackage_id: Optional[UUID] = None, + post_id: Optional[UUID] = None, + commit: bool = True, + ) -> Journal: + if event_type is None: + if payload is None: + event_type = "null" + else: + event_type = type(payload).__name__ + + if payload.person_id is None: + payload.person_id = self.person_id + if payload.post_id is None: + payload.post_id = post_id + if payload.workpackage_id is None: + payload.workpackage_id = workpackage_id + if payload.task_type is None: + payload.task_type = task_type + + entry = Journal( + person_id=self.person_id, + api_client_id=self.api_client.id, + created_date=utcnow(), + event_type=event_type, + event_payload=PayloadContainer(payload=payload), + post_id=post_id, + ) + + self.db.add(entry) + if commit: + self.db.commit() + + return entry diff --git a/backend/oasst_backend/models/__init__.py b/backend/oasst_backend/models/__init__.py index 414ec385..0acc242c 100644 --- a/backend/oasst_backend/models/__init__.py +++ b/backend/oasst_backend/models/__init__.py @@ -1,9 +1,11 @@ # -*- coding: utf-8 -*- from .api_client import ApiClient +from .journal import Journal, JournalIntegration from .person import Person from .person_stats import PersonStats from .post import Post from .post_reaction import PostReaction +from .text_labels import TextLabels from .work_package import WorkPackage __all__ = [ @@ -13,4 +15,7 @@ __all__ = [ "Post", "PostReaction", "WorkPackage", + "TextLabels", + "Journal", + "JournalIntegration", ] diff --git a/backend/oasst_backend/models/journal.py b/backend/oasst_backend/models/journal.py new file mode 100644 index 00000000..4cec1e99 --- /dev/null +++ b/backend/oasst_backend/models/journal.py @@ -0,0 +1,56 @@ +# -*- coding: utf-8 -*- +from datetime import datetime +from typing import Optional +from uuid import UUID, uuid1, uuid4 + +import sqlalchemy as sa +import sqlalchemy.dialects.postgresql as pg +from sqlmodel import Field, SQLModel + +from .payload_column_type import PayloadContainer, payload_column_type + + +def generate_time_uuid(node=None, clock_seq=None): + """Create a lexicographically sortable time ordered custom (non-standard) UUID by reordering the timestamp fields of a version 1 UUID.""" + (time_low, time_mid, time_hi_version, clock_seq_hi_variant, clock_seq_low, node) = uuid1(node, clock_seq).fields + # reconstruct 60 bit timestamp, see version 1 uuid: https://www.rfc-editor.org/rfc/rfc4122 + timestamp = (time_hi_version & 0xFFF) << 48 | (time_mid << 32) | time_low + version = time_hi_version >> 12 + assert version == 1 + a = timestamp >> 28 # bits 28-59 + b = (timestamp >> 12) & 0xFFFF # bits 12-27 + c = timestamp & 0xFFF # bits 0-11 (clear version bits) + clock_seq_hi_variant &= 0xF # (clear variant bits) + return UUID(fields=(a, b, c, clock_seq_hi_variant, clock_seq_low, node), version=None) + + +class Journal(SQLModel, table=True): + __tablename__ = "journal" + + id: Optional[UUID] = Field( + sa_column=sa.Column(pg.UUID(as_uuid=True), primary_key=True, default=generate_time_uuid), + ) + created_date: Optional[datetime] = Field( + sa_column=sa.Column(sa.DateTime(timezone=True), nullable=False, server_default=sa.func.current_timestamp()) + ) + person_id: UUID = Field(nullable=True, foreign_key="person.id", index=True) + post_id: Optional[UUID] = Field(foreign_key="post.id", nullable=True) + api_client_id: UUID = Field(foreign_key="api_client.id") + + event_type: str = Field(nullable=False, max_length=200) + event_payload: PayloadContainer = Field(sa_column=sa.Column(payload_column_type(PayloadContainer), nullable=False)) + + +class JournalIntegration(SQLModel, table=True): + __tablename__ = "journal_integration" + + id: Optional[UUID] = Field( + sa_column=sa.Column( + pg.UUID(as_uuid=True), primary_key=True, default=uuid4, server_default=sa.text("gen_random_uuid()") + ), + ) + description: str = Field(max_length=512, primary_key=True) + last_journal_id: UUID = Field(foreign_key="journal.id", nullable=True) + last_run: datetime = Field(sa_column=sa.Column(sa.DateTime(), nullable=True)) + last_error: str = Field(nullable=True) + next_run: datetime = Field(nullable=True) diff --git a/backend/oasst_backend/models/text_labels.py b/backend/oasst_backend/models/text_labels.py new file mode 100644 index 00000000..2699302f --- /dev/null +++ b/backend/oasst_backend/models/text_labels.py @@ -0,0 +1,25 @@ +# -*- coding: utf-8 -*- +from datetime import datetime +from typing import Optional +from uuid import UUID, uuid4 + +import sqlalchemy as sa +import sqlalchemy.dialects.postgresql as pg +from sqlmodel import Field, SQLModel + + +class TextLabels(SQLModel, table=True): + __tablename__ = "text_labels" + + id: Optional[UUID] = Field( + sa_column=sa.Column( + pg.UUID(as_uuid=True), primary_key=True, default=uuid4, server_default=sa.text("gen_random_uuid()") + ), + ) + created_date: Optional[datetime] = Field( + sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp()), + ) + api_client_id: UUID = Field(nullable=False, foreign_key="api_client.id") + text: str = Field(nullable=False, max_length=2**16) + post_id: Optional[UUID] = Field(sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("post.id"), nullable=True)) + labels: dict[str, float] = Field(default={}, sa_column=sa.Column(pg.JSONB), nullable=False) diff --git a/backend/oasst_backend/prompt_repository.py b/backend/oasst_backend/prompt_repository.py index b0063cdf..f4f83277 100644 --- a/backend/oasst_backend/prompt_repository.py +++ b/backend/oasst_backend/prompt_repository.py @@ -5,7 +5,8 @@ from uuid import UUID, uuid4 import oasst_backend.models.db_payload as db_payload from loguru import logger -from oasst_backend.models import ApiClient, Person, Post, PostReaction, WorkPackage +from oasst_backend.journal_writer import JournalWriter +from oasst_backend.models import ApiClient, Person, Post, PostReaction, TextLabels, WorkPackage from oasst_backend.models.payload_column_type import PayloadContainer from oasst_shared.schemas import protocol as protocol_schema from sqlmodel import Session @@ -17,6 +18,7 @@ class PromptRepository: self.api_client = api_client self.person = self.lookup_person(user) self.person_id = self.person.id if self.person else None + self.journal = JournalWriter(db, api_client, self.person) def lookup_person(self, user: protocol_schema.User) -> Person: if not user: @@ -116,6 +118,10 @@ class PromptRepository: self.validate_post_id(reply.post_id) self.validate_post_id(reply.user_post_id) + work_package = self.fetch_workpackage_by_postid(reply.post_id) + work_payload: db_payload.TaskPayload = work_package.payload.payload + logger.info(f"found task work package in db: {work_payload}") + # find post with post-id parent_post: Post = ( self.db.query(Post) @@ -141,6 +147,7 @@ class PromptRepository: role=role, payload=db_payload.PostPayload(text=reply.text), ) + self.journal.log_text_reply(work_package=work_package, post_id=user_post_id, role=role, length=len(reply.text)) return user_post def store_rating(self, rating: protocol_schema.PostRating) -> PostReaction: @@ -159,6 +166,7 @@ class PromptRepository: # store reaction to post reaction_payload = db_payload.RatingReactionPayload(rating=rating.rating) reaction = self.insert_reaction(post.id, reaction_payload) + self.journal.log_rating(work_package, post_id=post.id, rating=rating.rating) logger.info(f"Ranking {rating.rating} stored for work_package {work_package.id}.") return reaction @@ -184,6 +192,7 @@ class PromptRepository: # store reaction to post reaction_payload = db_payload.RankingReactionPayload(ranking=ranking.ranking) reaction = self.insert_reaction(post.id, reaction_payload) + self.journal.log_ranking(work_package, post_id=post.id, ranking=ranking.ranking) logger.info(f"Ranking {ranking.ranking} stored for work_package {work_package.id}.") @@ -199,6 +208,7 @@ class PromptRepository: # store reaction to post reaction_payload = db_payload.RankingReactionPayload(ranking=ranking.ranking) reaction = self.insert_reaction(post.id, reaction_payload) + self.journal.log_ranking(work_package, post_id=post.id, ranking=ranking.ranking) logger.info(f"Ranking {ranking.ranking} stored for work_package {work_package.id}.") @@ -314,3 +324,17 @@ class PromptRepository: self.db.commit() self.db.refresh(reaction) return reaction + + def store_text_labels(self, text_labels: protocol_schema.TextLabels) -> TextLabels: + model = TextLabels( + api_client_id=self.api_client.id, + text=text_labels.text, + labels=text_labels.labels, + ) + if text_labels.has_post_id: + self.fetch_post_by_frontend_post_id(text_labels.post_id, fail_if_missing=True) + model.post_id = text_labels.post_id + self.db.add(model) + self.db.commit() + self.db.refresh(model) + return model diff --git a/backend/requirements.txt b/backend/requirements.txt index b882d594..dd11aa18 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -5,6 +5,7 @@ numpy==1.22.4 psycopg2-binary==2.9.5 pydantic==1.9.1 python-dotenv==0.21.0 +scipy==1.8.1 SQLAlchemy==1.4.41 sqlmodel==0.0.8 starlette==0.22.0 diff --git a/oasst-shared/oasst_shared/schemas/protocol.py b/oasst-shared/oasst_shared/schemas/protocol.py index d5f508b6..17ee23f0 100644 --- a/oasst-shared/oasst_shared/schemas/protocol.py +++ b/oasst-shared/oasst_shared/schemas/protocol.py @@ -204,3 +204,51 @@ AnyInteraction = Union[ PostRating, PostRanking, ] + + +class TextLabel(str, enum.Enum): + """A label for a piece of text.""" + + spam = "spam" + violence = "violence" + sexual_content = "sexual_content" + toxicity = "toxicity" + political_content = "political_content" + humor = "humor" + sarcasm = "sarcasm" + hate_speech = "hate_speech" + profanity = "profanity" + ad_hominem = "ad_hominem" + insult = "insult" + threat = "threat" + aggressive = "aggressive" + misleading = "misleading" + helpful = "helpful" + formal = "formal" + cringe = "cringe" + creative = "creative" + beautiful = "beautiful" + informative = "informative" + based = "based" + slang = "slang" + + +class TextLabels(BaseModel): + """A set of labels for a piece of text.""" + + text: str + labels: dict[TextLabel, float] + post_id: str | None = None + + @property + def has_post_id(self) -> bool: + """Whether this TextLabels has a post_id.""" + return bool(self.post_id) + + # check that each label value is between 0 and 1 + @pydantic.validator("labels") + def check_label_values(cls, v): + for key, value in v.items(): + if not (0 <= value <= 1): + raise ValueError(f"Label values must be between 0 and 1, got {value} for {key}.") + return v diff --git a/oasst-shared/oasst_shared/utils.py b/oasst-shared/oasst_shared/utils.py new file mode 100644 index 00000000..dd1cbf07 --- /dev/null +++ b/oasst-shared/oasst_shared/utils.py @@ -0,0 +1,7 @@ +# -*- coding: utf-8 -*- +from datetime import datetime, timezone + + +def utcnow() -> datetime: + """Return the current utc date and time with tzinfo set to UTC.""" + return datetime.now(timezone.utc) diff --git a/scripts/backend-development/docker-compose.yaml b/scripts/backend-development/docker-compose.yaml index 65a65e73..0445cf34 100644 --- a/scripts/backend-development/docker-compose.yaml +++ b/scripts/backend-development/docker-compose.yaml @@ -9,6 +9,7 @@ services: environment: POSTGRES_USER: postgres POSTGRES_PASSWORD: postgres + POSTGRES_DB: postgres healthcheck: test: ["CMD", "pg_isready", "-U", "postgres"] interval: 2s diff --git a/scripts/frontend-development/docker-compose.yaml b/scripts/frontend-development/docker-compose.yaml index ef0f3489..e34c2d8f 100644 --- a/scripts/frontend-development/docker-compose.yaml +++ b/scripts/frontend-development/docker-compose.yaml @@ -16,6 +16,7 @@ services: environment: POSTGRES_USER: postgres POSTGRES_PASSWORD: postgres + POSTGRES_DB: ocgpt_website healthcheck: test: ["CMD", "pg_isready", "-U", "postgres"] interval: 2s diff --git a/scripts/postprocessing/infogain_selector.py b/scripts/postprocessing/infogain_selector.py new file mode 100644 index 00000000..51f60fa7 --- /dev/null +++ b/scripts/postprocessing/infogain_selector.py @@ -0,0 +1,98 @@ +# -*- coding: utf-8 -*- +import numpy as np +from scipy import log2 +from scipy.integrate import nquad +from scipy.special import gammaln, psi +from scipy.stats import dirichlet + + +def make_range(*x): + """ + constructs leftover values for the simplex given the first k entries + (0,x_k) = 1-(x_1+...+x_(k-1)) + """ + return (0, max(0, 1 - sum(x))) + + +def relative_entropy(p, q): + """ + relative entropy of the two given dirichlet distributions + """ + + def tmp(*x): + """ + First adds the last always forced entry to the input (the last x_last = 1-(x_1+...+x_(N)) ) + Then computes the relative entropy of posterior and prior for that datapoint + """ + x_new = np.append(x, 1 - sum(x)) + return p(x_new) * log2(p(x_new) / q(x_new)) + + return tmp + + +def naive_monte_carlo_integral(fun, dim, samples=10_000_000): + s = np.random.rand(dim - 1, samples) + s = np.sort(np.concatenate((np.zeros((1, samples)), s, np.ones((1, samples)))), 0) + # print(s) + pos = np.diff(s, axis=0) + # print(pos) + res = fun(pos) + return np.mean(res) + + +def analytic_solution(a_post, a_prior): + """ + Analytic solution to the KL-divergence between two dirichlet distributions. + Proof is in the Notion design doc. + """ + post_sum = np.sum(a_post) + prior_sum = np.sum(a_prior) + info = ( + gammaln(post_sum) + - gammaln(prior_sum) + - np.sum(gammaln(a_post)) + + np.sum(gammaln(a_prior)) + - np.sum((a_post - a_prior) * (psi(a_post) - psi(post_sum))) + ) + + return info + + +def infogain(a_post, a_prior): + raise ( + """For the love of good don't use this: + it's insanely poorly conditioned, the worst numerical code I have ever written + and it's slow as molasses. Use the analytic solution instead. + + Maybe remove + """ + ) + args = len(a_prior) + p = dirichlet(a_post).pdf + q = dirichlet(a_prior).pdf + (info, _) = nquad(relative_entropy(p, q), [make_range for _ in range(args - 1)], opts={"epsabs": 1e-8}) + # info = naive_monte_carlo_integral(relative_entropy(p,q), len(a_post)) + return info + + +def uniform_expected_infogain(a_prior): + mean_weight = dirichlet.mean(a_prior) + print("weight", mean_weight) + results = [] + for i, w in enumerate(mean_weight): + a_post = a_prior.copy() + a_post[i] = a_post[i] + 1 + results.append(w * analytic_solution(a_post, a_prior)) + return np.sum(results) + + +if __name__ == "__main__": + a_prior = np.array([1, 1, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + a_post = np.array([1, 1, 20, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + + print("algebraic", analytic_solution(a_post, a_prior)) + # print("raw",infogain(a_post, a_prior)) + print("large infogain", uniform_expected_infogain(a_prior)) + print("post infogain", uniform_expected_infogain(a_post)) + # a_prior = np.array([1,1,1000]) + # print("small infogain",uniform_expected_infogain(a_prior)) diff --git a/scripts/postprocessing/rankings.py b/scripts/postprocessing/rankings.py index 38686f67..7b28399c 100644 --- a/scripts/postprocessing/rankings.py +++ b/scripts/postprocessing/rankings.py @@ -68,7 +68,7 @@ def get_winner(pairs): def get_ranking(pairs): """ Abuses concordance property to get a (not necessarily unqiue) ranking. - The lack of uniqueness is due to the potential existance of multiple + The lack of uniqueness is due to the potential existence of multiple equally ranked winners. We have to pick one, which is where the non-uniqueness comes from """ @@ -99,7 +99,7 @@ def ranked_pairs(ranks: List[List[int]]): tallies = tallies - tallies.T # print(tallies) # note: the resulting tally matrix should be skew-symmetric - # order by strenght of victory (using tideman's original method, don't think it would make a difference for us) + # order by strength of victory (using tideman's original method, don't think it would make a difference for us) sorted_majorities = [] for i in range(len(ranks[0])): for j in range(len(ranks[i])): diff --git a/scripts/postprocessing/scoring.py b/scripts/postprocessing/scoring.py new file mode 100644 index 00000000..3c145b28 --- /dev/null +++ b/scripts/postprocessing/scoring.py @@ -0,0 +1,183 @@ +# -*- coding: utf-8 -*- +from dataclasses import dataclass, replace +from typing import Any + +import numpy as np +import numpy.typing as npt +from scipy.stats import kendalltau + + +@dataclass +class Voter: + """ + Represents a single voter. + This tabulates the number of good votes, total votes, + and points. + We only put well-behaved people on the scoreboard and filter out the badly behaved ones + """ + + uid: Any + num_votes: int + num_good_votes: int + num_prompts: int + num_good_prompts: int + num_rankings: int + num_good_rankings: int + + ##################### + voting_points: int + prompt_points: int + ranking_points: int + + def voter_quality(self): + return self.num_good_votes / self.num_votes + + def rank_quality(self): + return self.num_good_rankings / self.num_rankings + + def prompt_quality(self): + return self.num_good_prompts / self.num_prompts + + def is_well_behaved(self, threshhold_vote, threshhold_prompt, threshhold_rank): + return ( + self.voter_quality() > threshhold_vote + and self.prompt_quality() > threshhold_prompt + and self.rank_quality() > threshhold_rank + ) + + def total_points(self, voting_weight, prompt_weight, ranking_weight): + return ( + voting_weight * self.voting_points + + prompt_weight * self.prompt_points + + ranking_weight * self.ranking_points + ) + + +def score_update_votes(new_vote: int, consensus: npt.ArrayLike, voter_data: Voter) -> Voter: + """ + This function returns the new "quality score" and points for a voter, + after that voter cast a vote on a question. + + This function is only to be run when archiving a question + i.e. the question has had sufficiently many votes, or we cann't get more than "K" bits of information + + The consensus is the array of all votes cast by all voters for that question + We then update the voter data using the new information + + Parameters: + new_vote (int): the index of the vote cast by the voter + consensus (ArrayLike): all votes cast for this question + voter_data (Voter): a "Voter" object that represents the person casting the "new_vote" + + Returns: + updated_voter (Voter): the new "quality score" and points for the voter + """ + # produces the ranking of votes, e.g. for [100,300,200] it returns [0, 2, 1], + # since 100 is the lowest, 300 the highest and 200 the middle value + consensus_ranking = np.argsort(np.argsort(consensus)) + new_points = consensus_ranking[new_vote] + voter_data.voting_points + + # we need to correct for 0 indexing, if you are closer to "right" than "wrong" of the conensus, + # it's a good vote + new_good_votes = int(consensus_ranking[new_vote] > (len(consensus) - 1) / 2) + voter_data.num_good_votes + new_num_votes = voter_data.num_votes + 1 + return replace(voter_data, num_votes=new_num_votes, num_good_votes=new_good_votes, voting_points=new_points) + + +def score_update_prompts(consensus: npt.ArrayLike, voter_data: Voter) -> Voter: + """ + This function returns the gain of points for a given prompt's votes + + This function is only to be run when archiving a question + i.e. the question has had sufficiently many votes, or we cann't get more than "K" bits of information + + Parameters: + consensus (ArrayLike): all votes cast for this question + voter_data (Voter): a "Voter" object that represents the person that wrote the prompt + + Returns: + updated_voter (Voter): the new "quality score" and points for the voter + """ + # produces the ranking of votes, e.g. for [100,300,200] it returns [0, 2, 1], + # since 100 is the lowest, 300 the highest and 200 the middle value + consensus_ranking = np.arange(len(consensus)) - len(consensus) // 2 + 1 + delta_votes = np.sum(consensus_ranking * consensus) + new_points = delta_votes + voter_data.prompt_points + + # we need to correct for 0 indexing, if you are closer to "right" than "wrong" of the conensus, + # it's a good vote + new_good_prompts = int(delta_votes > 0) + voter_data.num_good_prompts + new_num_prompts = voter_data.num_prompts + 1 + return replace( + voter_data, + num_prompts=new_num_prompts, + num_good_prompts=new_good_prompts, + prompt_points=new_points, + ) + + +def score_update_ranking(user_ranking: npt.ArrayLike, consensus_ranking: npt.ArrayLike, voter_data: Voter) -> Voter: + """ + This function returns the gain of points for a given ranking's votes + + This function is only to be run when archiving a question + i.e. the question has had sufficiently many votes, or we cann't get more than "K" bits of information + + we use the bubble-sort distance (or "kendall-tau" distance) to compare the two rankings + we use this over spearman correlation since: + "[Kendall's τ] approaches a normal distribution more rapidly than ρ, as N, the sample size, increases; + and τ is also more tractable mathematically, particularly when ties are present" + Gilpin, A. R. (1993). Table for conversion of Kendall's Tau to Spearman's + Rho within the context measures of magnitude of effect for meta-analysis + + Further in + "research design and statistical analyses, second edition, 2003" + the authors note that at least from an significance test POV they will yield the same p-values + + Parameters: + user_ranking (ArrayLike): ranking produced by the user + consensus (ArrayLike): ranking produced after running the voting algorithm to merge into the consensus ranking + voter_data (Voter): a "Voter" object that represents the person that wrote the prompt + + Returns: + updated_voter (Voter): the new "quality score" and points for the voter + """ + bubble_sort_distance, p_value = kendalltau(user_ranking, consensus_ranking) + # normalize kendall-tau from [-1,1] into [0,1] range + bubble_sort_distance = (1 + bubble_sort_distance) / 2 + new_points = bubble_sort_distance + voter_data.ranking_points + new_good_rankings = int(bubble_sort_distance > 0.5) + voter_data.num_good_rankings + new_num_rankings = voter_data.num_rankings + 1 + return replace( + voter_data, + num_rankings=new_num_rankings, + num_good_rankings=new_good_rankings, + ranking_points=new_points, + ) + + +if __name__ == "__main__": + demo_voter = Voter( + "abc", + num_votes=10, + num_good_votes=2, + num_prompts=10, + num_good_prompts=2, + num_rankings=10, + num_good_rankings=2, + voting_points=6, + prompt_points=0, + ranking_points=0, + ) + new_vote = 3 + consensus = np.array([200, 300, 100, 500]) + print(demo_voter) + print("best vote ", score_update_votes(new_vote, consensus, demo_voter)) + new_vote = 2 + print("worst vote ", score_update_votes(new_vote, consensus, demo_voter)) + new_vote = 1 + print("medium vote ", score_update_votes(new_vote, consensus, demo_voter)) + print("prompt writer", score_update_prompts(consensus, demo_voter)) + print("best rank ", score_update_ranking(np.array([0, 2, 1]), np.array([0, 2, 1]), demo_voter)) + print("medium rank ", score_update_ranking(np.array([2, 0, 1]), np.array([0, 2, 1]), demo_voter)) + print("worst rank ", score_update_ranking(np.array([1, 0, 2]), np.array([0, 2, 1]), demo_voter)) diff --git a/website/.env b/website/.env index a95df390..e4f3a202 100644 --- a/website/.env +++ b/website/.env @@ -5,7 +5,7 @@ DATABASE_URL=postgres://postgres:postgres@localhost:5433/ocgpt_website FASTAPI_URL=http://localhost:8080 FASTAPI_KEY=1234 -# A dev Auth Secret. Can be exposed if we never use this publically. +# A dev Auth Secret. Can be exposed if we never use this publicly. NEXTAUTH_SECRET=O/M2uIbGj+lDD2oyNa8ax4jEOJqCPJzO53UbWShmq98= # The SMTP host and port found by running the jobs in /scripts/frontend-development/docker-compose.yaml diff --git a/website/README.md b/website/README.md index 7d57ceaa..8b03043a 100644 --- a/website/README.md +++ b/website/README.md @@ -63,6 +63,14 @@ If you're doing active development we suggest the following workflow: navigate to `http://localhost:1080`. Check the email listed and click the log in link. You're now logged in and authenticated. +### Using debug user credentials + +Whenever the website runs in development mode, you can use the debug credentials provider to log in without fancy emails or OAuth. + +1. Development mode is automatically active when you start the website with `npm run dev`. +1. Use the `Login` button in the top right to go to the login page. +1. You should see a section for debug credentials. Enter any username you wish, you will be logged in as that user. + ## Code Layout ### React Code diff --git a/website/src/components/Avatar.tsx b/website/src/components/Avatar.tsx index 26bdec0b..d5706946 100644 --- a/website/src/components/Avatar.tsx +++ b/website/src/components/Avatar.tsx @@ -12,7 +12,7 @@ export function Avatar() { return <>; } if (session && session.user) { - const email = session.user.email; + const displayName = session.user.name || session.user.email; const accountOptions = [ { name: "Account Settings", @@ -35,7 +35,7 @@ export function Avatar() { height="40" className="rounded-full" > -

{email}

+

{displayName}

{/* Will be changed to username once it is implemented */} diff --git a/website/src/components/Button.tsx b/website/src/components/Button.tsx new file mode 100644 index 00000000..5ae1e7b4 --- /dev/null +++ b/website/src/components/Button.tsx @@ -0,0 +1,19 @@ +import clsx from "clsx"; + +export const Button = ( + props: React.DetailedHTMLProps, HTMLButtonElement> +) => { + const { className, children, ...rest } = props; + return ( + + ); +}; diff --git a/website/src/components/Footer.tsx b/website/src/components/Footer.tsx index fa3d8c0d..5e6ac47d 100644 --- a/website/src/components/Footer.tsx +++ b/website/src/components/Footer.tsx @@ -12,13 +12,7 @@ export function Footer() {
- logo + logo
diff --git a/website/src/components/Messages.tsx b/website/src/components/Messages.tsx new file mode 100644 index 00000000..4bc747d5 --- /dev/null +++ b/website/src/components/Messages.tsx @@ -0,0 +1,18 @@ +export interface Message { + text: string; + is_assistant: boolean; +} + +const getColor = (isAssistant: boolean) => (isAssistant ? "bg-slate-800" : "bg-sky-900"); + +export const Messages = ({ messages }: { messages: Message[] }) => { + const items = messages.map(({ text, is_assistant }: Message, i: number) => { + return ( +
+ {text} +
+ ); + }); + // Maybe also show a legend of the colors? + return <>{items}; +}; diff --git a/website/src/components/RankItem.tsx b/website/src/components/RankItem.tsx new file mode 100644 index 00000000..3ba9da70 --- /dev/null +++ b/website/src/components/RankItem.tsx @@ -0,0 +1,12 @@ +const RankItem = ({ username, score }) => { + return ( +
+
1
+
@username
+
20.5
+
gold
+
+ ); +}; + +export default RankItem; diff --git a/website/src/components/TaskSelection/TaskOption.tsx b/website/src/components/TaskSelection/TaskOption.tsx new file mode 100644 index 00000000..764efa68 --- /dev/null +++ b/website/src/components/TaskSelection/TaskOption.tsx @@ -0,0 +1,39 @@ +import { Card, CardBody, Flex, Heading } from "@chakra-ui/react"; +import Image from "next/image"; +import Link from "next/link"; + +export type OptionProps = { + img: string; + alt: string; + title: string; + link: string; +}; + +export const TaskOption = (props: OptionProps) => { + const { alt, img, title, link } = props; + return ( + + + + + {alt} + + {title} + + + + + + ); +}; diff --git a/website/src/components/TaskSelection/TaskOptions.tsx b/website/src/components/TaskSelection/TaskOptions.tsx new file mode 100644 index 00000000..fe24b393 --- /dev/null +++ b/website/src/components/TaskSelection/TaskOptions.tsx @@ -0,0 +1,23 @@ +import { Divider, Flex, Heading } from "@chakra-ui/react"; +import React from "react"; + +export type TaskOptionsProps = { + title: string; + children: JSX.Element | JSX.Element[]; +}; + +export const TaskOptions = (props: TaskOptionsProps) => { + const { title, children } = props; + return ( + + + {title} + + + {children} + + ); +}; diff --git a/website/src/components/TaskSelection/TaskSelection.tsx b/website/src/components/TaskSelection/TaskSelection.tsx new file mode 100644 index 00000000..81c067e8 --- /dev/null +++ b/website/src/components/TaskSelection/TaskSelection.tsx @@ -0,0 +1,29 @@ +import React from "react"; +import { TaskOptions } from "./TaskOptions"; +import { Flex } from "@chakra-ui/react"; +import { TaskOption } from "./TaskOption"; + +export const TaskSelection = () => { + return ( + + + + + + + + + + + ); +}; diff --git a/website/src/components/TaskSelection/index.ts b/website/src/components/TaskSelection/index.ts new file mode 100644 index 00000000..4da7ea7f --- /dev/null +++ b/website/src/components/TaskSelection/index.ts @@ -0,0 +1,3 @@ +export { TaskSelection } from "./TaskSelection"; +export { TaskOptions } from "./TaskOptions"; +export { TaskOption } from "./TaskOption"; diff --git a/website/src/components/TwoColumns.tsx b/website/src/components/TwoColumns.tsx new file mode 100644 index 00000000..5792f7d0 --- /dev/null +++ b/website/src/components/TwoColumns.tsx @@ -0,0 +1,14 @@ +export const TwoColumns = ({ children }: { children: React.ReactNode[] }) => { + if (!Array.isArray(children) || children.length !== 2) { + throw new Error("TwoColumns expects 2 children"); + } + + const [first, second] = children; + + return ( +
+
{first}
+
{second}
+
+ ); +}; diff --git a/website/src/pages/api/auth/[...nextauth].ts b/website/src/pages/api/auth/[...nextauth].ts index f823ed41..62767c98 100644 --- a/website/src/pages/api/auth/[...nextauth].ts +++ b/website/src/pages/api/auth/[...nextauth].ts @@ -2,6 +2,7 @@ import type { AuthOptions } from "next-auth"; import NextAuth from "next-auth"; import DiscordProvider from "next-auth/providers/discord"; import EmailProvider from "next-auth/providers/email"; +import CredentialsProvider from "next-auth/providers/credentials"; import { PrismaAdapter } from "@next-auth/prisma-adapter"; import prisma from "src/lib/prismadb"; @@ -32,6 +33,23 @@ if (process.env.DISCORD_CLIENT_ID) { ); } +if (process.env.NODE_ENV === "development") { + providers.push( + CredentialsProvider({ + name: "Debug Credentials", + credentials: { + username: { label: "Username", type: "text" }, + }, + async authorize(credentials) { + return { + id: credentials.username, + name: credentials.username, + }; + }, + }) + ); +} + export const authOptions: AuthOptions = { // Ensure we can store user data in a database. adapter: PrismaAdapter(prisma), diff --git a/website/src/pages/auth/signup.tsx b/website/src/pages/auth/signup.tsx index 4c890fcd..e34fcd05 100644 --- a/website/src/pages/auth/signup.tsx +++ b/website/src/pages/auth/signup.tsx @@ -1,6 +1,7 @@ import { Button, Input, Stack } from "@chakra-ui/react"; import Head from "next/head"; import Link from "next/link"; +import { FaDiscord, FaEnvelope, FaBug, FaGithub } from "react-icons/fa"; import { getCsrfToken, getProviders, signIn } from "next-auth/react"; import { useRef } from "react"; import { FaDiscord, FaEnvelope, FaGithub } from "react-icons/fa"; @@ -8,19 +9,28 @@ import { FaDiscord, FaEnvelope, FaGithub } from "react-icons/fa"; import { AuthLayout } from "src/components/AuthLayout"; export default function Signin({ csrfToken, providers }) { - const { discord, email, github } = providers; + const { discord, email, github, credentials } = providers; const emailEl = useRef(null); + const debugUsernameEl = useRef(null); const signinWithDiscord = () => { signIn(discord.id, { callbackUrl: "/" }); }; - const signinWithEmail = () => { + + const signinWithEmail = (ev: React.FormEvent) => { + ev.preventDefault(); signIn(email.id, { callbackUrl: "/", email: emailEl.current.value }); }; + const signinWithGithub = () => { signIn(github.id, { callbackUrl: "/" }); }; + function signinWithDebugCredentials(ev: React.FormEvent) { + ev.preventDefault(); + signIn(credentials.id, { callbackUrl: "/", username: debugUsernameEl.current.value }); + } + return ( <> @@ -28,14 +38,27 @@ export default function Signin({ csrfToken, providers }) { - + + {credentials && ( +
+ For Debugging Only + + + + +
+ )} {email && ( - - - - +
+ + + + +
)} {discord && (